import { createRng } from "@/lib/playground/prng";
import { StockTradingEnv, TradingAction, type TradingObservation, type TradingConfig } from "@/lib/playground/stockTradingEnv";

// Q-table: maps a discretized observation key (see obsKey) to a row of
// per-action Q-values of length NUM_ACTIONS.
export type QTable = Map<string, Float32Array>;

// Hyperparameters for a tabular Q-learning training run.
export type QLearningParams = {
  episodes: number; // total number of training episodes
  stepsPerEpisode: number; // hard cap on env steps per episode
  gamma: number; // discount factor (trainer clamps to [0, 0.999])
  alpha: number; // learning rate (trainer clamps to [0.001, 1])
  epsilon: number; // exploration probability (trainer clamps to [0, 1])
  seed: number; // base RNG seed for reproducible runs
};

export const NUM_ACTIONS = 3;

/**
 * Returns the Q-value row for `key`, lazily inserting a zero-initialized
 * row (one slot per action) the first time a state key is seen.
 */
function getRow(q: QTable, key: string): Float32Array {
  let row = q.get(key);
  if (row === undefined) {
    row = new Float32Array(NUM_ACTIONS);
    q.set(key, row);
  }
  return row;
}

/**
 * Index of the largest value in `row`. Ties resolve to the earliest index;
 * an empty row yields 0.
 */
function argMax(row: Float32Array): number {
  let best = 0;
  for (let i = 1; i < row.length; i++) {
    if ((row[i] ?? 0) > (row[best] ?? 0)) {
      best = i;
    }
  }
  return best;
}

// Coarse state discretization for tabular Q-learning. Intentionally simple:
// fine for demos, far too crude for real trading decisions.
export function obsKey(obs: TradingObservation): string {
  // Five-way bin for the one-step return, symmetric around zero:
  // below -1% | -1%..-0.2% | roughly flat | 0.2%..1% | above 1%.
  let rBin = -2;
  for (const edge of [-0.01, -0.002, 0.002, 0.01]) {
    if (obs.return1 < edge) break;
    rBin += 1;
  }
  // Three-way bin shared by position and cash fractions: ~zero / minority / majority.
  const fracBin = (f: number): number => (f < 0.01 ? 0 : f < 0.5 ? 1 : 2);
  return `r${rBin}|p${fracBin(obs.positionFrac)}|c${fracBin(obs.cashFrac)}`;
}

/**
 * Epsilon-greedy action selection: with probability `epsilon` pick a uniform
 * random action, otherwise the greedy action for the discretized state.
 * A fresh RNG is seeded from `seed` on each call, so the same arguments
 * produce the same action.
 */
export function selectActionEpsilonGreedy(q: QTable, obs: TradingObservation, epsilon: number, seed: number): TradingAction {
  const rng = createRng(seed);
  const explore = rng.nextFloat() < epsilon;
  if (explore) {
    return rng.nextInt(NUM_ACTIONS) as TradingAction;
  }
  return argMax(getRow(q, obsKey(obs))) as TradingAction;
}

/**
 * Purely greedy policy: always the highest-valued action for the
 * discretized state, no exploration.
 */
export function actGreedy(q: QTable, obs: TradingObservation): TradingAction {
  return argMax(getRow(q, obsKey(obs))) as TradingAction;
}

/**
 * Runs a bounded chunk of tabular Q-learning on the stock-trading env.
 *
 * Training is sliced into chunks of at most `maxTransitions` env steps so a
 * caller (UI loop, worker) can interleave progress updates; pass the returned
 * `resume` state back in as `startAt` to continue where this chunk stopped.
 *
 * @returns `done` once all episodes have finished, the returns of episodes
 *   completed during this chunk, coarse progress info, and the resume cursor.
 */
export function trainStockQAgentChunked(args: {
  config: TradingConfig;
  params: QLearningParams;
  q: QTable;
  startAt?: { episode: number; stepInEpisode: number; episodeReturn: number; rngSeed: number };
  maxTransitions: number;
}): {
  done: boolean;
  episodeReturnsDelta: number[];
  progress: { episode: number; qSize: number };
  resume: { episode: number; stepInEpisode: number; episodeReturn: number; rngSeed: number };
} {
  const { config, params, q, maxTransitions } = args;
  // Clamp hyperparameters to sane ranges regardless of caller input.
  const gamma = Math.max(0, Math.min(0.999, params.gamma));
  const alpha = Math.max(0.001, Math.min(1, params.alpha));
  const epsilon = Math.max(0, Math.min(1, params.epsilon));

  // Resume cursor, or a fresh run seeded from params.seed (forced to uint32).
  let episode = args.startAt?.episode ?? 0;
  let stepInEpisode = args.startAt?.stepInEpisode ?? 0;
  let episodeReturn = args.startAt?.episodeReturn ?? 0;
  let rngSeed = args.startAt?.rngSeed ?? (params.seed >>> 0);

  const episodeReturnsDelta: number[] = [];

  // Env is reset each episode; start index is random by default in config.
  // NOTE(review): when resuming with stepInEpisode > 0, the env is re-created
  // and reset here rather than restored mid-episode, so the in-progress
  // episode continues its step count and accumulated return against a fresh
  // env state — confirm this approximation is acceptable for the demo.
  let env = new StockTradingEnv(config, rngSeed);
  let obs = env.reset(rngSeed);

  for (let t = 0; t < maxTransitions; t++) {
    if (episode >= params.episodes) break;

    // stepInEpisode === 0 with t > 0 means the previous iteration closed an
    // episode; begin the next one on a fresh env (rngSeed has advanced).
    if (stepInEpisode === 0 && t > 0) {
      env = new StockTradingEnv(config, rngSeed);
      obs = env.reset(rngSeed);
    }

    // Per-step action seed, mixed from episode/step so exploration is
    // reproducible but does not repeat across steps.
    const aSeed = (rngSeed + (episode * 1_000_003 + stepInEpisode * 97)) >>> 0;
    const a = selectActionEpsilonGreedy(q, obs, epsilon, aSeed);
    // Row for the current state; fetched before stepping (env never touches q).
    const row = getRow(q, obsKey(obs));
    const res = env.step(a);
    const r = res.reward;
    // Episode ends when the env terminates or the per-episode step cap hits.
    const done = res.done || stepInEpisode + 1 >= params.stepsPerEpisode;

    // Q-learning target: r + gamma * max_a' Q(s', a'); terminal target is r.
    let target = r;
    if (!done) {
      const nextRow = getRow(q, obsKey(res.obs));
      target = r + gamma * (nextRow[argMax(nextRow)] ?? 0);
    }

    // TD update toward the target at learning rate alpha.
    row[a] = (row[a] ?? 0) + alpha * (target - (row[a] ?? 0));

    obs = res.obs;
    episodeReturn += r;
    stepInEpisode += 1;
    rngSeed = (rngSeed + 1) >>> 0; // advance the seed once per transition

    if (done) {
      // Close out the episode and record its total reward.
      episode += 1;
      episodeReturnsDelta.push(episodeReturn);
      episodeReturn = 0;
      stepInEpisode = 0;
    }
  }

  const done = episode >= params.episodes;
  return {
    done,
    episodeReturnsDelta,
    progress: { episode, qSize: q.size },
    resume: { episode, stepInEpisode, episodeReturn, rngSeed },
  };
}

