import { createRng, type Rng } from "@/lib/playground/prng";

/**
 * Discrete actions accepted by `StockTradingEnv.step`.
 * The numeric values (Buy = 0, Sell = 1, Hold = 2) are part of the public contract.
 */
export enum TradingAction {
  /** Spend cash on as many whole shares as allowed (all-in). */
  Buy = 0,
  /** Liquidate the entire position. */
  Sell,
  /** Do not trade this step. */
  Hold,
}

/** Every `TradingAction` paired with a human-readable label, in numeric order. */
export const TRADING_ACTIONS: Array<{ action: TradingAction; label: string }> = (
  [
    [TradingAction.Buy, "Buy"],
    [TradingAction.Sell, "Sell"],
    [TradingAction.Hold, "Hold"],
  ] as const
).map(([action, label]) => ({ action, label }));

/**
 * One bar of price history. `close` is the price the environment trades and values at.
 * NOTE(review): `t` is not read in this module; presumably a timestamp or ordinal — confirm with the data source.
 */
export type PricePoint = { t: number; close: number };

/**
 * Environment configuration. The ranges noted below are enforced by `normalizeTradingConfig`,
 * which `StockTradingEnv` applies to the config passed to its constructor.
 */
export type TradingConfig = {
  symbol: string; // upper-cased and truncated to 16 chars on normalization
  prices: PricePoint[]; // points with a non-finite or non-positive `close` are dropped on normalization
  baseInvestment: number; // initial cash; clamped to [10, 1e9]
  transactionFeeRate: number; // proportional fee per trade, e.g., 0.001 = 0.1%; clamped to [0, 0.02]
  episodeLength: number; // timesteps per episode; truncated to an integer in [2, 2000]
  startIndex: number | null; // null => random start; otherwise clamped so a full episode fits in `prices`
  maxPositionShares: number | null; // null => unlimited (within cash); otherwise an integer in [0, 1_000_000]
  rewardScale: number; // multiply reward by this (default 1); clamped to [0.0001, 100]
};

/** Snapshot of an episode; `StockTradingEnv` replaces it (rather than mutating it) on each reset/step. */
export type TradingState = {
  step: number; // number of `step` calls since reset; 0..episodeLength during a normal episode
  index: number; // index in prices array of the bar the next `step` trades at
  cash: number;
  shares: number;
  lastPrice: number; // close of the preceding bar; used for the `return1` observation
  lastPortfolioValue: number; // value after the previous trade, marked at that trade's price; reward baseline
};

/** Observation returned by `StockTradingEnv.reset`, `observe`, and `step`. */
export type TradingObservation = {
  // Minimal, gym-like numeric observation.
  // NOTE: `price` is the raw close and is NOT normalized; the other fields are bounded as noted.
  price: number; // current close at the state's index
  return1: number; // log return vs. the previous bar, clamped to [-0.2, 0.2]
  cashFrac: number; // cash / portfolio value, clamped to [0, 1]
  positionFrac: number; // value in stock / portfolio value, clamped to [0, 1]
};

/** Diagnostic details returned alongside each `step` result. */
export type TradingInfo = {
  symbol: string;
  price: number; // close of the bar the env advanced to (the next bar), not the trade price
  portfolioValue: number; // value after this step's trade, marked at the trade price
  pnl: number; // portfolioValue - baseInvestment (cumulative for the episode)
  feePaid: number; // fee charged on this step's trade; 0 when nothing traded
  action: TradingAction; // the action passed to `step`
};

/** Gym-style result of one `StockTradingEnv.step` call. */
export type TradingStepResult = {
  obs: TradingObservation; // observation of the state after the step
  reward: number; // (change in portfolio value / baseInvestment) * rewardScale
  done: boolean; // true once `episodeLength` steps have been taken
  truncated: boolean; // currently always false
  info: TradingInfo;
  state: TradingState; // the state after the step (same object the env now holds)
};

/**
 * Clamp `v` into `[min, max]`.
 * Any non-finite input (NaN, +Infinity, -Infinity) is replaced by `min`.
 */
function clampNumber(v: number, min: number, max: number): number {
  return Number.isFinite(v) ? Math.max(min, Math.min(max, v)) : min;
}

/**
 * Natural-log return `ln(curr / prev)`.
 * Returns 0 when either price is non-finite or not strictly positive, where the log is undefined.
 */
function safeLogReturn(curr: number, prev: number): number {
  const bothUsable = [curr, prev].every((x) => Number.isFinite(x) && x > 0);
  return bothUsable ? Math.log(curr / prev) : 0;
}

/** Total account value: cash plus the position marked at `price`. */
function portfolioValue(cash: number, shares: number, price: number): number {
  const positionValue = shares * price;
  return positionValue + cash;
}

/**
 * Build a config with sensible defaults for `symbol` and `prices`:
 * 10,000 starting cash, 0.1% fee, 120-step episodes, random start, no share cap, unscaled reward.
 * The `prices` array is used by reference (not copied) and `symbol` is passed through unchanged.
 */
export function defaultTradingConfig(symbol: string, prices: PricePoint[]): TradingConfig {
  const baseInvestment = 10_000; // initial cash
  const transactionFeeRate = 0.001; // 0.1% per trade
  const episodeLength = 120;
  const startIndex = null; // random start on each reset
  const maxPositionShares = null; // limited only by cash
  const rewardScale = 1;

  return {
    symbol,
    prices,
    baseInvestment,
    transactionFeeRate,
    episodeLength,
    startIndex,
    maxPositionShares,
    rewardScale,
  };
}

/**
 * Return a sanitized copy of `cfg` with every field forced into a usable range.
 * The input object is not mutated; extra properties on `cfg` are preserved via spread.
 *
 * - `baseInvestment` → [10, 1e9], `transactionFeeRate` → [0, 0.02], `rewardScale` → [0.0001, 100];
 *   non-finite values fall back to the lower bound.
 * - `episodeLength` is truncated to an integer in [2, 2000]; NaN falls back to 2.
 * - `symbol` is upper-cased and cut to 16 characters.
 * - `prices` keeps only points whose `close` is a finite, strictly positive number.
 * - `startIndex` is truncated and clamped so that `startIndex + episodeLength` is a valid index
 *   into the filtered prices; NaN is treated as `null` (random start).
 * - `maxPositionShares` is truncated and clamped to [0, 1_000_000]; NaN is treated as `null` (unlimited).
 */
export function normalizeTradingConfig(cfg: TradingConfig): TradingConfig {
  const baseInvestment = clampNumber(cfg.baseInvestment, 10, 1_000_000_000);
  const fee = clampNumber(cfg.transactionFeeRate, 0, 0.02);
  const rewardScale = clampNumber(cfg.rewardScale, 0.0001, 100);

  // Math.min/Math.max propagate NaN, so it has to be caught explicitly; ±Infinity still clamps to a bound.
  const rawLength = Math.trunc(cfg.episodeLength);
  const episodeLength = Number.isNaN(rawLength) ? 2 : Math.max(2, Math.min(2000, rawLength));

  const symbol = (cfg.symbol ?? "").toUpperCase().slice(0, 16);
  // Null entries can appear when the config comes from untyped input (e.g. parsed JSON).
  const prices = (cfg.prices ?? []).filter((p) => p != null && Number.isFinite(p.close) && p.close > 0);

  const maxStart = Math.max(0, prices.length - episodeLength - 1);
  const rawStart = cfg.startIndex == null ? NaN : Math.trunc(cfg.startIndex);
  const startIndex = Number.isNaN(rawStart) ? null : Math.max(0, Math.min(maxStart, rawStart));

  const rawMaxShares = cfg.maxPositionShares == null ? NaN : Math.trunc(cfg.maxPositionShares);
  const maxPositionShares = Number.isNaN(rawMaxShares) ? null : Math.max(0, Math.min(1_000_000, rawMaxShares));

  return {
    ...cfg,
    symbol,
    prices,
    baseInvestment,
    transactionFeeRate: fee,
    episodeLength,
    startIndex,
    maxPositionShares,
    rewardScale,
  };
}

/**
 * Single-asset, long-only trading environment with a gym-like `reset` / `step` API.
 *
 * An episode starts fully in cash (`baseInvestment`) at `startIndex`, or at a seeded random
 * index when `startIndex` is null, and is `done` after `episodeLength` calls to `step`. Each
 * step executes the action at the current bar's close and then advances one bar.
 *
 * Reward is the change in portfolio value since the previous step, with holdings marked at the
 * bar they were traded on, divided by `baseInvestment` and multiplied by `rewardScale`. As a
 * consequence, the price exposure created by an action is credited on the following step; only
 * the fee is charged on the step the action is taken.
 */
export class StockTradingEnv {
  private readonly cfg: TradingConfig;
  private rng: Rng;
  private initialSeed: number;
  private _state: TradingState | null = null;

  /**
   * @param cfg Raw configuration; a normalized copy (see `normalizeTradingConfig`) is used.
   * @param seed Seed for the start-index RNG, coerced to an unsigned 32-bit integer.
   */
  constructor(cfg: TradingConfig, seed: number = 1234) {
    this.cfg = normalizeTradingConfig(cfg);
    this.initialSeed = seed >>> 0;
    this.rng = createRng(this.initialSeed);
  }

  /** The normalized configuration in use. */
  get config(): TradingConfig {
    return this.cfg;
  }

  /**
   * Current episode state.
   * @throws Error if `reset` has not been called yet.
   */
  get state(): TradingState {
    if (!this._state) throw new Error("Env not reset");
    return this._state;
  }

  /**
   * Start a new episode and return its first observation.
   *
   * @param seed When given, re-seeds the RNG (unsigned 32-bit) before a random start is chosen.
   * @throws Error if `prices` has fewer than `episodeLength + 2` points.
   */
  reset(seed?: number): TradingObservation {
    if (seed != null) {
      this.initialSeed = seed >>> 0;
      this.rng = createRng(this.initialSeed);
    }
    const cfg = this.cfg;
    if (cfg.prices.length < cfg.episodeLength + 2) {
      throw new Error("Not enough price data for episode length");
    }

    // Latest start for which start + episodeLength (the bar the final step advances onto) is valid.
    const maxStart = cfg.prices.length - cfg.episodeLength - 1;
    // NOTE(review): assumes `rng.nextInt(n)` yields an integer in [0, n) — confirm in prng.
    const idx = cfg.startIndex == null ? this.rng.nextInt(maxStart + 1) : cfg.startIndex;
    const p0 = cfg.prices[idx]?.close ?? 1;
    // Previous close seeds the first `return1`; at index 0 it is the start bar itself (return 0).
    const pPrev = cfg.prices[Math.max(0, idx - 1)]?.close ?? p0;

    const cash = cfg.baseInvestment;
    const shares = 0;

    this._state = {
      step: 0,
      index: idx,
      cash,
      shares,
      lastPrice: pPrev,
      lastPortfolioValue: portfolioValue(cash, shares, p0),
    };

    return this.observe();
  }

  /**
   * Observation for the current state (does not advance the episode).
   * @throws Error if `reset` has not been called yet.
   */
  observe(): TradingObservation {
    const s = this.state;
    const cfg = this.cfg;
    const price = cfg.prices[s.index]?.close ?? s.lastPrice;
    const pv = portfolioValue(s.cash, s.shares, price);
    const cashFrac = pv > 0 ? s.cash / pv : 1;
    const posFrac = pv > 0 ? (s.shares * price) / pv : 0;
    const r1 = safeLogReturn(price, s.lastPrice);

    return {
      price,
      return1: clampNumber(r1, -0.2, 0.2),
      cashFrac: clampNumber(cashFrac, 0, 1),
      positionFrac: clampNumber(posFrac, 0, 1),
    };
  }

  /**
   * Execute `action` at the current bar's close, advance one bar, and report the outcome.
   *
   * - Buy: spend cash on as many whole shares as can be paid for *including* the fee, capped so
   *   total holdings do not exceed `maxPositionShares` when it is set.
   * - Sell: liquidate the whole position.
   * - Hold (or any other value): no trade.
   *
   * `done` becomes true once `episodeLength` steps have been taken. Stepping further is not
   * blocked; the index then stays pinned to the last price, so callers should `reset` after `done`.
   *
   * @throws Error if `reset` has not been called yet.
   */
  step(action: TradingAction): TradingStepResult {
    const cfg = this.cfg;
    const s0 = this.state;

    const price = cfg.prices[s0.index]?.close ?? s0.lastPrice;
    const { cash, shares, feePaid } = this.executeTrade(action, price, s0.cash, s0.shares);

    const pvBefore = s0.lastPortfolioValue;
    const pvNow = portfolioValue(cash, shares, price);
    const reward = ((pvNow - pvBefore) / cfg.baseInvestment) * cfg.rewardScale;

    const nextStep = s0.step + 1;
    const done = nextStep >= cfg.episodeLength;
    const truncated = false;

    // Never walk past the end of the price series.
    const nextIndex = Math.min(s0.index + 1, cfg.prices.length - 1);
    const nextPrice = cfg.prices[nextIndex]?.close ?? price;

    const nextState: TradingState = {
      step: nextStep,
      index: nextIndex,
      cash,
      shares,
      lastPrice: price,
      lastPortfolioValue: pvNow,
    };
    this._state = nextState;

    const obs = this.observe();
    const info: TradingInfo = {
      symbol: cfg.symbol,
      price: nextPrice,
      portfolioValue: pvNow,
      pnl: pvNow - cfg.baseInvestment,
      feePaid,
      action,
    };

    return { obs, reward, done, truncated, info, state: nextState };
  }

  /**
   * Apply `action` at `price` to the given holdings.
   * @returns Post-trade cash and share count, plus the fee charged (0 when nothing trades).
   */
  private executeTrade(
    action: TradingAction,
    price: number,
    cash: number,
    shares: number,
  ): { cash: number; shares: number; feePaid: number } {
    const feeRate = this.cfg.transactionFeeRate;

    if (action === TradingAction.Buy) {
      // Same arithmetic as the settlement below, so the affordability check matches it exactly.
      const cashAfterBuying = (n: number): number => {
        const notional = n * price;
        return cash - notional - notional * feeRate;
      };

      // Size on price *plus* fee: sizing on price alone lets the fee push cash below zero.
      let buyShares = Math.floor(cash / (price * (1 + feeRate)));
      const cap = this.cfg.maxPositionShares;
      if (cap != null) buyShares = Math.min(buyShares, cap - shares);
      // Floating-point rounding can leave the total cost a hair above cash; back off until it fits.
      while (buyShares > 0 && cashAfterBuying(buyShares) < 0) buyShares -= 1;

      if (buyShares > 0) {
        const notional = buyShares * price;
        const fee = notional * feeRate;
        return { cash: cash - notional - fee, shares: shares + buyShares, feePaid: fee };
      }
    } else if (action === TradingAction.Sell && shares > 0) {
      const notional = shares * price;
      const fee = notional * feeRate;
      return { cash: cash + notional - fee, shares: 0, feePaid: fee };
    }

    return { cash, shares, feePaid: 0 };
  }
}

