import { createRng, type Rng } from "@/lib/playground/prng";

/** Discrete action space with explicit numeric ids 0-2. */
export enum TradingAction {
  /** Spend available cash on as many whole shares as it covers (fees included), subject to maxPositionShares. */
  Buy = 0,
  /** Liquidate the entire position. */
  Sell = 1,
  /** Leave the position unchanged this step. */
  Hold = 2,
}

/** Every action paired with its display label, in id order. */
export const TRADING_ACTIONS: Array<{ action: TradingAction; label: string }> = (
  [
    [TradingAction.Buy, "Buy"],
    [TradingAction.Sell, "Sell"],
    [TradingAction.Hold, "Hold"],
  ] as const
).map(([action, label]) => ({ action, label }));

/**
 * One bar of price history. Only `close` is required: the environment uses it
 * both as the execution price and for mark-to-market. `normalizeTradingConfig`
 * drops bars whose `close` is not a finite positive number.
 */
export type PricePoint = {
  t: number; // unix seconds
  open?: number;
  high?: number;
  low?: number;
  close: number;
  volume?: number;
};

/**
 * Environment configuration. The `StockTradingEnv` constructor passes it through
 * `normalizeTradingConfig`, which clamps values into supported ranges.
 */
export type TradingConfig = {
  symbol: string;
  prices: PricePoint[];
  baseInvestment: number; // initial cash; also the divisor used to normalize rewards
  transactionFeeRate: number; // proportional fee on each trade's notional, e.g., 0.001 = 0.1%
  episodeLength: number; // timesteps (step() calls) per episode
  startIndex: number | null; // fixed start bar in `prices`; null => RNG-chosen start on each reset
  maxPositionShares: number | null; // cap on total shares held; null => unlimited (within cash)
  rewardScale: number; // multiply reward by this (default 1); shaping penalties are applied after scaling, unscaled
  rewardMode: "deltaPV" | "realizedOnSell"; // change in portfolio value each step, or realized PnL on sells only
  idlePenalty: number; // subtracted when choosing Hold while flat (no shares)
  invalidActionPenalty: number; // subtracted for Sell with 0 shares, or a Buy that buys nothing (insufficient cash or already at maxPositionShares)
};

/**
 * Snapshot of episode bookkeeping. `step()` builds a new object each call rather
 * than mutating the previous one.
 */
export type TradingState = {
  step: number; // steps taken so far: 0 after reset, episodeLength when the episode is done
  index: number; // index in prices array of the bar currently observed
  cash: number;
  shares: number;
  costBasis: number; // total cost spent to acquire current shares (incl. buy fees); reset to 0 on sell
  lastPrice: number; // previous bar's close; reference price for the observation's return1
  lastPortfolioValue: number; // portfolio value marked at the current bar's close; baseline for the next deltaPV reward
};

export type TradingObservation = {
  // Minimal, gym-like numeric observation.
  // Every field except `price` is bounded for learning stability:
  // return1 is clamped to [-0.2, 0.2] and both fractions to [0, 1].
  price: number; // current close, NOT normalized
  return1: number; // one-step log return vs. the previous bar's close, clamped to [-0.2, 0.2]
  cashFrac: number; // cash / portfolio value, clamped to [0, 1]
  positionFrac: number; // value in stock / portfolio value, clamped to [0, 1]
};

/** Diagnostics returned alongside each `step()` result. */
export type TradingInfo = {
  symbol: string;
  t: number; // unix seconds of the bar reached after this step (the new observation); 0 if unavailable
  ohlcv?: { open?: number; high?: number; low?: number; close: number; volume?: number }; // bar reached after this step
  price: number; // close of the bar reached after this step
  portfolioValue: number; // cash + shares * price, after this step
  pnl: number; // portfolioValue - baseInvestment
  feePaid: number; // fee charged on this step's trade (0 if no trade)
  tradeShares: number; // +buyShares, -sellShares, 0 hold/no-op
  tradeNotional: number; // abs(shares)*price at execution time
  realizedPnl: number; // only when selling (net of fees and cost basis)
  unrealizedPnl: number; // mark-to-market value of held shares minus cost basis (net of buy fees, excl sell fees); 0 when flat
  costBasisAfter: number;
  cashAfter: number;
  sharesAfter: number;
  noop: boolean; // true when no trade executed: Hold, or a Buy/Sell that could not trade
  action: TradingAction;
};

/** Return value of `StockTradingEnv.step`. */
export type TradingStepResult = {
  obs: TradingObservation; // observation of the bar reached after the step
  reward: number;
  done: boolean; // true once episodeLength steps have been taken
  truncated: boolean; // currently always false
  info: TradingInfo;
  state: TradingState; // environment state after the step
};

/**
 * Clamp `v` into the closed interval [min, max].
 *
 * NaN has no position on the number line, so it falls back to `min` (the
 * conservative bound). Infinities are ordinary out-of-range values and clamp to
 * the nearest bound: +Infinity -> max, -Infinity -> min.
 */
function clampNumber(v: number, min: number, max: number): number {
  if (Number.isNaN(v)) return min;
  return Math.max(min, Math.min(max, v));
}

/**
 * Natural-log return ln(curr / prev).
 * Yields 0 whenever either price is non-finite or not strictly positive, so
 * callers never see NaN or +/-Infinity.
 */
function safeLogReturn(curr: number, prev: number): number {
  const bothUsable = Number.isFinite(curr) && Number.isFinite(prev) && curr > 0 && prev > 0;
  return bothUsable ? Math.log(curr / prev) : 0;
}

/** Total account value: cash plus the position marked at `price`. */
function portfolioValue(cash: number, shares: number, price: number): number {
  const positionValue = shares * price;
  return cash + positionValue;
}

/**
 * Build a config for `symbol` populated with default settings.
 * `StockTradingEnv` still runs it through `normalizeTradingConfig`, so individual
 * fields can be overridden before constructing the env.
 */
export function defaultTradingConfig(symbol: string, prices: PricePoint[]): TradingConfig {
  const startingCash = 10_000;
  const feeRate = 0.001; // 0.1% of notional per trade
  const config: TradingConfig = {
    symbol,
    prices,
    baseInvestment: startingCash,
    transactionFeeRate: feeRate,
    episodeLength: 120,
    startIndex: null, // pick a start on every reset
    maxPositionShares: null, // limited only by available cash
    rewardScale: 1,
    rewardMode: "deltaPV",
    idlePenalty: 0.00001,
    invalidActionPenalty: 0.00005,
  };
  return config;
}

/**
 * Sanitize a user-supplied config into the ranges the environment supports.
 *
 * - Numeric fields are clamped into fixed bounds. NaN falls back to the lower
 *   bound instead of leaking into index and share arithmetic.
 * - `symbol` is upper-cased and truncated to 16 characters.
 * - Missing bars, and bars whose `close` is not a finite positive number, are dropped.
 * - A fixed `startIndex` is clamped so that a full episode fits in the remaining bars.
 *
 * Fields not handled here are copied through unchanged. The input object is not mutated.
 */
export function normalizeTradingConfig(cfg: TradingConfig): TradingConfig {
  // Truncate to an integer, then clamp, sending NaN to `min`. Math.max/Math.min
  // alone propagate NaN (Math.max(0, NaN) is NaN), which previously produced NaN
  // episode lengths, start indices and share caps.
  const clampInt = (v: number, min: number, max: number): number => {
    const t = Math.trunc(v);
    return Number.isNaN(t) ? min : Math.max(min, Math.min(max, t));
  };

  const baseInvestment = clampNumber(cfg.baseInvestment, 10, 1_000_000_000);
  const fee = clampNumber(cfg.transactionFeeRate, 0, 0.02);
  const episodeLength = clampInt(cfg.episodeLength, 2, 2000);
  const rewardScale = clampNumber(cfg.rewardScale, 0.0001, 100);
  const rewardMode = cfg.rewardMode === "realizedOnSell" ? "realizedOnSell" : "deltaPV";
  const idlePenalty = clampNumber(cfg.idlePenalty ?? 0, 0, 1);
  const invalidActionPenalty = clampNumber(cfg.invalidActionPenalty ?? 0, 0, 1);
  const symbol = (cfg.symbol ?? "").toUpperCase().slice(0, 16);
  const prices = (cfg.prices ?? []).filter((p) => p != null && Number.isFinite(p.close) && p.close > 0);

  // Latest start that still leaves `episodeLength` bars after it (same rule as reset()).
  const maxStart = Math.max(0, prices.length - episodeLength - 1);
  const startIndex = cfg.startIndex == null ? null : clampInt(cfg.startIndex, 0, maxStart);

  const maxPositionShares =
    cfg.maxPositionShares == null ? null : clampInt(cfg.maxPositionShares, 0, 1_000_000);

  return {
    ...cfg,
    symbol,
    prices,
    baseInvestment,
    transactionFeeRate: fee,
    episodeLength,
    startIndex,
    maxPositionShares,
    rewardScale,
    rewardMode,
    idlePenalty,
    invalidActionPenalty,
  };
}

/**
 * Single-symbol, all-in / all-out trading environment with a gym-like API
 * (`reset` / `step` / `observe`).
 *
 * Each `step` executes the chosen action at the close of the currently observed
 * bar, then advances one bar and marks the portfolio to that bar's close. An
 * episode ends after `episodeLength` steps. It starts at `startIndex`, or at an
 * RNG-chosen bar when `startIndex` is null.
 */
export class StockTradingEnv {
  private readonly cfg: TradingConfig;
  private rng: Rng;
  private initialSeed: number;
  private _state: TradingState | null = null;

  /**
   * @param cfg  Raw configuration; normalized with `normalizeTradingConfig`.
   * @param seed Seed for the start-index RNG, coerced to an unsigned 32-bit integer.
   */
  constructor(cfg: TradingConfig, seed: number = 1234) {
    this.cfg = normalizeTradingConfig(cfg);
    this.initialSeed = seed >>> 0;
    this.rng = createRng(this.initialSeed);
  }

  /** The normalized configuration in use. */
  get config(): TradingConfig {
    return this.cfg;
  }

  /**
   * Current episode state.
   * @throws Error if `reset()` has not been called yet.
   */
  get state(): TradingState {
    if (!this._state) throw new Error("Env not reset");
    return this._state;
  }

  /**
   * Start a new episode with full cash and no position.
   *
   * @param seed When given, re-seeds the RNG (coerced to uint32) before choosing the start bar.
   * @returns The observation of the starting bar.
   * @throws Error if `prices` has fewer than `episodeLength + 2` bars.
   */
  reset(seed?: number): TradingObservation {
    if (seed != null) {
      this.initialSeed = seed >>> 0;
      this.rng = createRng(this.initialSeed);
    }
    const cfg = this.cfg;
    if (cfg.prices.length < cfg.episodeLength + 2) {
      throw new Error("Not enough price data for episode length");
    }

    // Latest start that still leaves `episodeLength` bars after it, so every
    // in-episode step has a "next bar" to advance to.
    const maxStart = cfg.prices.length - cfg.episodeLength - 1;
    // NOTE(review): assumes rng.nextInt(n) yields an integer in [0, n) — confirm against prng.
    const idx = cfg.startIndex == null ? this.rng.nextInt(maxStart + 1) : cfg.startIndex;
    const p0 = cfg.prices[idx]?.close ?? 1;
    // The previous close seeds return1 of the first observation; at index 0 the
    // bar's own close is used, giving a zero return.
    const pPrev = cfg.prices[Math.max(0, idx - 1)]?.close ?? p0;

    const cash = cfg.baseInvestment;
    const shares = 0;

    this._state = {
      step: 0,
      index: idx,
      cash,
      shares,
      costBasis: 0,
      lastPrice: pPrev,
      lastPortfolioValue: portfolioValue(cash, shares, p0),
    };

    return this.observe();
  }

  /**
   * Build the observation for the current bar without advancing the episode.
   * @throws Error if `reset()` has not been called yet.
   */
  observe(): TradingObservation {
    const s = this.state;
    const price = this.cfg.prices[s.index]?.close ?? s.lastPrice;
    const pv = portfolioValue(s.cash, s.shares, price);
    const cashFrac = pv > 0 ? s.cash / pv : 1;
    const posFrac = pv > 0 ? (s.shares * price) / pv : 0;
    const r1 = safeLogReturn(price, s.lastPrice);

    return {
      price,
      return1: clampNumber(r1, -0.2, 0.2),
      cashFrac: clampNumber(cashFrac, 0, 1),
      positionFrac: clampNumber(posFrac, 0, 1),
    };
  }

  /**
   * Execute `action` at the current bar's close, then advance one bar.
   *
   * Reward:
   * - "deltaPV": (portfolio value at the next close − value recorded at the
   *   previous step/reset) / baseInvestment × rewardScale. Fees therefore count
   *   against the reward.
   * - "realizedOnSell": realized PnL of a Sell / baseInvestment × rewardScale, else 0.
   * Shaping penalties (idlePenalty, invalidActionPenalty) are then subtracted unscaled.
   *
   * Calling `step` after `done` is not guarded. The bar index is clamped to the
   * last price, so further steps reuse that bar.
   *
   * @throws Error if `reset()` has not been called yet.
   */
  step(action: TradingAction): TradingStepResult {
    const cfg = this.cfg;
    const s0 = this.state;

    // Trades execute at the close of the bar currently being observed.
    const price = cfg.prices[s0.index]?.close ?? s0.lastPrice;
    let cash = s0.cash;
    let shares = s0.shares;
    let costBasis = s0.costBasis;
    let feePaid = 0;
    let tradeShares = 0;
    let tradeNotional = 0;
    let realizedPnl = 0;

    if (action === TradingAction.Buy) {
      // Buy as many whole shares as cash covers including the fee (all-in),
      // subject to maxPositionShares.
      const feeMult = 1 + cfg.transactionFeeRate;
      const maxSharesFromCash = price > 0 ? Math.floor(Math.max(0, cash) / (price * feeMult)) : 0;
      const maxAllowed =
        cfg.maxPositionShares == null ? maxSharesFromCash : Math.max(0, cfg.maxPositionShares - shares);
      const buyShares = Math.max(0, Math.min(maxSharesFromCash, maxAllowed));
      if (buyShares > 0) {
        const notional = buyShares * price;
        feePaid = notional * cfg.transactionFeeRate;
        cash = cash - notional - feePaid;
        if (cash < 0 && cash > -1e-6) cash = 0; // guard tiny float drift
        shares = shares + buyShares;
        costBasis = costBasis + notional + feePaid;
        tradeShares = buyShares;
        tradeNotional = notional;
      }
    } else if (action === TradingAction.Sell) {
      // Liquidate the entire position.
      const sellShares = shares;
      if (sellShares > 0) {
        const notional = sellShares * price;
        feePaid = notional * cfg.transactionFeeRate;
        cash = cash + notional - feePaid;
        shares = 0;
        realizedPnl = (notional - feePaid) - costBasis;
        costBasis = 0;
        tradeShares = -sellShares;
        tradeNotional = notional;
      }
    }
    // These all leave the position untouched and are reported as no-ops: Hold,
    // a Buy that buys nothing, a Sell while flat, or an unrecognized action value.
    const noop = tradeShares === 0;

    const nextStep = s0.step + 1;
    const done = nextStep >= cfg.episodeLength;
    const truncated = false;

    // reset() picks a start that leaves room for every in-episode step; the
    // clamp only matters when step() is called again after `done`.
    const nextIndex = Math.min(s0.index + 1, cfg.prices.length - 1);
    const nextBar = cfg.prices[nextIndex];
    const nextPrice = nextBar?.close ?? price;
    const pvBefore = s0.lastPortfolioValue;
    const pvNext = portfolioValue(cash, shares, nextPrice);
    // Mark the position to the same price used for portfolioValue/pnl so that
    // unrealizedPnl agrees with the rest of `info`. Previously the stale
    // execution price was used here.
    const unrealizedPnl = shares > 0 ? shares * nextPrice - costBasis : 0;

    const rawReward =
      cfg.rewardMode === "realizedOnSell"
        ? (action === TradingAction.Sell ? realizedPnl : 0)
        : pvNext - pvBefore;
    let reward = (rawReward / cfg.baseInvestment) * cfg.rewardScale;

    // Shaping penalties for unrealistic / no-op behavior.
    if (action === TradingAction.Hold && s0.shares === 0) reward -= cfg.idlePenalty;
    if (action === TradingAction.Sell && s0.shares === 0) reward -= cfg.invalidActionPenalty;
    if (action === TradingAction.Buy && tradeShares === 0) reward -= cfg.invalidActionPenalty;

    const nextState: TradingState = {
      step: nextStep,
      index: nextIndex,
      cash,
      shares,
      costBasis,
      lastPrice: price,
      lastPortfolioValue: pvNext,
    };
    this._state = nextState;

    const obs = this.observe();
    const info: TradingInfo = {
      symbol: cfg.symbol,
      t: nextBar?.t ?? 0,
      ohlcv: nextBar
        ? { open: nextBar.open, high: nextBar.high, low: nextBar.low, close: nextBar.close, volume: nextBar.volume }
        : { close: nextPrice },
      price: nextPrice,
      portfolioValue: pvNext,
      pnl: pvNext - cfg.baseInvestment,
      feePaid,
      tradeShares,
      tradeNotional,
      realizedPnl,
      unrealizedPnl,
      costBasisAfter: costBasis,
      cashAfter: cash,
      sharesAfter: shares,
      noop,
      action,
    };

    return { obs, reward, done, truncated, info, state: nextState };
  }
}
