import { createRng } from "@/lib/playground/prng";
import { stateKey, stepWarehouse, type WarehouseConfig, type WarehouseState, WarehouseAction } from "@/lib/playground/warehouseEnv";

/**
 * Tabular Q-function: maps a state key (as produced by `stateKey`) to a row of
 * Q-values with one entry per action. Rows are created lazily with length
 * `WAREHOUSE_NUM_ACTIONS`.
 */
export type QTable = Map<string, Float32Array>;

/** Hyperparameters for `trainWarehouseQAgentChunked`. */
export type QLearningParams = {
  /** Number of episodes to train before the run reports `done`. */
  episodes: number;
  /** Per-episode step cap; reaching it ends the current episode. */
  stepsPerEpisode: number;
  /** Discount factor; the trainer clamps it to [0, 0.999]. */
  gamma: number;
  /** Learning rate; the trainer clamps it to [0.001, 1]. */
  alpha: number;
  /** Exploration probability for epsilon-greedy selection; the trainer clamps it to [0, 1]. */
  epsilon: number;
  /** Base seed for action selection; the trainer coerces it with `>>> 0` (unsigned 32-bit). */
  seed: number;
};

/** Snapshot of training progress returned after each training chunk. */
export type QLearningProgress = {
  /** Number of episodes completed so far (equivalently, the index of the current episode). */
  episode: number;
  /** Steps already taken in the current, unfinished episode. */
  stepInEpisode: number;
  /** Undiscounted sum of rewards collected so far in the current episode. */
  episodeReturn: number;
  /** Number of distinct state keys that currently have a row in the Q-table. */
  qSize: number;
};

/**
 * Size of the discrete action space: Q-table rows have this many entries, and
 * random exploration passes it to `rng.nextInt`.
 * NOTE(review): must stay in sync with the actions `stepWarehouse` accepts — confirm.
 */
export const WAREHOUSE_NUM_ACTIONS = 6;

/**
 * Fetches the Q-value row stored under `key`. The first time a key is seen,
 * a zero-filled row with one slot per action is created and stored, so the
 * returned array is always the live row inside `q`.
 */
function getRow(q: QTable, key: string): Float32Array {
  let row = q.get(key);
  if (row === undefined) {
    row = new Float32Array(WAREHOUSE_NUM_ACTIONS);
    q.set(key, row);
  }
  return row;
}

/**
 * Returns the index of the largest entry in `row`.
 * Ties resolve to the lowest index, and an empty row yields 0. NaN never wins
 * a `>` comparison, so a NaN in slot 0 keeps the answer at 0.
 */
function argMax(row: Float32Array): number {
  let winner = 0;
  for (let idx = 1; idx < row.length; idx++) {
    // Strict `>` keeps the earliest index when values tie.
    if ((row[idx] ?? 0) > (row[winner] ?? 0)) winner = idx;
  }
  return winner;
}

/**
 * Epsilon-greedy action selection over the tabular Q-function.
 *
 * A fresh RNG is created from `seed` on every call, so the same
 * (q, state, epsilon, seed) always yields the same action. With probability
 * `epsilon` (first `nextFloat()` draw) a random action is returned via
 * `nextInt(WAREHOUSE_NUM_ACTIONS)`. Otherwise the greedy action for `state` is
 * returned.
 *
 * This is a read-only query: a state with no row in `q` is treated as having
 * all-zero Q-values, so the greedy choice is action 0. No row is inserted,
 * which keeps `q.size` (reported as `qSize`) limited to states that were
 * actually trained.
 */
export function selectActionEpsilonGreedy(q: QTable, state: WarehouseState, epsilon: number, seed: number): WarehouseAction {
  const rng = createRng(seed);
  if (rng.nextFloat() < epsilon) {
    return rng.nextInt(WAREHOUSE_NUM_ACTIONS) as WarehouseAction;
  }
  const row = q.get(stateKey(state));
  // argMax of an all-zero row is 0, so an unseen state maps to action 0.
  return (row ? argMax(row) : 0) as WarehouseAction;
}

/**
 * Returns the greedy action (highest Q-value, lowest index on ties) for `state`.
 *
 * This is a read-only query: a state with no row in `q` is treated as having
 * all-zero Q-values and yields action 0. No row is inserted, so evaluating or
 * visualizing a policy does not grow the Q-table or inflate `qSize`.
 */
export function actGreedy(q: QTable, state: WarehouseState): WarehouseAction {
  const row = q.get(stateKey(state));
  // argMax of an all-zero row is 0, so an unseen state maps to action 0.
  return (row ? argMax(row) : 0) as WarehouseAction;
}

/**
 * Trains a tabular Q-learning agent on the warehouse environment, taking at
 * most `maxTransitions` environment steps per call. A long run can therefore
 * be spread across several calls: the caller feeds each result back in via
 * `startAt`.
 *
 * The Q-table `args.q` is updated in place with the one-step Q-learning rule
 * `Q(s,a) += alpha * (target - Q(s,a))`. Hyperparameters are clamped before
 * use: gamma to [0, 0.999], alpha to [0.001, 1], epsilon to [0, 1].
 *
 * Bootstrapping: the target is the bare reward only when the environment
 * reports a terminal state (`result.done`). An episode cut short by
 * `result.truncated` or by `params.stepsPerEpisode` was interrupted, not
 * finished, so its target still includes `gamma * max_a' Q(s', a')`.
 * NOTE(review): this assumes `truncated` means "cut off" (e.g. a time limit)
 * rather than "terminal" — confirm against warehouseEnv.
 *
 * @param args.config - environment configuration passed to `stepWarehouse`.
 * @param args.initialState - state every episode starts from.
 * @param args.q - Q-table to train; mutated in place.
 * @param args.params - training hyperparameters.
 * @param args.startAt - resume point taken from a previous call's result; omit to start a new run.
 * @param args.maxTransitions - upper bound on environment steps taken in this call.
 * @returns `done` once `params.episodes` episodes have finished; the values
 *   needed to resume (`state`, `rngSeed`, `episode`, `stepInEpisode`,
 *   `episodeReturn`); and the undiscounted returns of the episodes completed
 *   during this call (`episodeReturnsDelta`).
 */
export function trainWarehouseQAgentChunked(args: {
  config: WarehouseConfig;
  initialState: WarehouseState;
  q: QTable;
  params: QLearningParams;
  startAt?: { episode: number; stepInEpisode: number; episodeReturn: number; rngSeed: number; state: WarehouseState };
  maxTransitions: number;
}): {
  done: boolean;
  progress: QLearningProgress;
  state: WarehouseState;
  episodeReturnsDelta: number[]; // returns for episodes completed in this chunk
  rngSeed: number;
  episode: number;
  stepInEpisode: number;
  episodeReturn: number;
} {
  const { config, initialState, q, params, maxTransitions } = args;
  const gamma = Math.max(0, Math.min(0.999, params.gamma));
  const alpha = Math.max(0.001, Math.min(1, params.alpha));
  const epsilon = Math.max(0, Math.min(1, params.epsilon));

  let episode = args.startAt?.episode ?? 0;
  let stepInEpisode = args.startAt?.stepInEpisode ?? 0;
  let episodeReturn = args.startAt?.episodeReturn ?? 0;
  let rngSeed = args.startAt?.rngSeed ?? (params.seed >>> 0);
  let state = args.startAt?.state ?? initialState;

  const episodeReturnsDelta: number[] = [];

  for (let t = 0; t < maxTransitions; t++) {
    if (episode >= params.episodes) break;

    // A run at the very start (episode 0, step 0) always begins from
    // `initialState` with a zero return, even if `startAt` supplied something else.
    if (stepInEpisode === 0 && t === 0 && (args.startAt?.episode ?? 0) === 0) {
      state = initialState;
      episodeReturn = 0;
    }

    // selectActionEpsilonGreedy re-seeds its RNG on every call, so mix the
    // running counter with the (episode, step) position to give each
    // transition its own deterministic seed.
    const actionSeed = (rngSeed + (episode * 1_000_003 + stepInEpisode * 97)) >>> 0;
    const a = selectActionEpsilonGreedy(q, state, epsilon, actionSeed);

    const row = getRow(q, stateKey(state));

    const result = stepWarehouse(config, state, a);
    const r = result.reward;
    const next = result.state;

    const hitStepCap = stepInEpisode + 1 >= params.stepsPerEpisode;
    const episodeEnds = result.done || result.truncated || hitStepCap;

    // Only a genuine terminal state has zero future value. Truncation and the
    // step cap merely stop the episode, so they still bootstrap from `next`.
    let target = r;
    if (!result.done) {
      const nextRow = getRow(q, stateKey(next));
      target = r + gamma * (nextRow[argMax(nextRow)] ?? 0);
    }

    const old = row[a] ?? 0;
    row[a] = old + alpha * (target - old);

    // Every episode restarts from the initial state.
    state = episodeEnds ? initialState : next;
    episodeReturn += r;
    stepInEpisode += 1;
    rngSeed = (rngSeed + 1) >>> 0;

    if (episodeEnds) {
      episode += 1;
      episodeReturnsDelta.push(episodeReturn);
      stepInEpisode = 0;
      episodeReturn = 0;
    }
  }

  const done = episode >= params.episodes;

  return {
    done,
    progress: {
      episode,
      stepInEpisode,
      episodeReturn,
      qSize: q.size,
    },
    state,
    episodeReturnsDelta,
    rngSeed,
    episode,
    stepInEpisode,
    episodeReturn,
  };
}

