import { createRng } from "@/lib/playground/prng";
import {
  LogisticsTruckRollAction,
  LogisticsTruckRollEnv,
  type LogisticsTruckRollConfig,
  type LogisticsTruckRollObservation,
} from "@/lib/playground/logisticsTruckRollEnv";

/**
 * Tabular action-value store: a discretized observation key (as produced by `obsKey`)
 * mapped to one Q-value per action.
 */
export type QTable = Map<string, Float32Array>;

/** Hyperparameters consumed by `trainLogisticsQAgentChunked`. */
export type QLearningParams = {
  // --- schedule ---
  /** Number of episodes to run before training reports `done`. */
  episodes: number;
  /** Per-episode step budget; an episode can also end earlier if the env terminates or truncates. */
  stepsPerEpisode: number;

  // --- learning hyperparameters (the trainer clamps these before use) ---
  /** Learning rate; clamped to [0.001, 1] by the trainer. */
  alpha: number;
  /** Discount factor; clamped to [0, 0.999] by the trainer. */
  gamma: number;
  /** Exploration probability for epsilon-greedy selection; clamped to [0, 1] by the trainer. */
  epsilon: number;

  // --- reproducibility ---
  /** Base seed; the trainer starts from `seed >>> 0` when no resume cursor is supplied. */
  seed: number;
};

/**
 * Size of the discrete action space: every Q-table row is allocated with this many entries.
 * NOTE(review): presumably equals the number of `LogisticsTruckRollAction` values — confirm.
 */
export const NUM_ACTIONS = 5;

/**
 * Returns the Q-value row stored under `key`, lazily creating (and storing) an all-zero row of
 * length `NUM_ACTIONS` the first time a key is seen. The returned array is the live row in `q`,
 * so writes to it update the table.
 */
function getRow(q: QTable, key: string): Float32Array {
  let stored = q.get(key);
  if (stored === undefined) {
    stored = new Float32Array(NUM_ACTIONS);
    q.set(key, stored);
  }
  return stored;
}

/**
 * Index of the largest entry in `row`. Ties resolve to the lowest index and an empty row yields 0.
 */
function argMax(row: Float32Array): number {
  let best = 0;
  for (let i = 1; i < row.length; i++) {
    // Strict `>` keeps the earliest index on ties. Comparisons involving NaN are false, so a NaN
    // entry is never chosen unless it sits at index 0 (where nothing can displace it).
    if ((row[i] ?? 0) > (row[best] ?? 0)) best = i;
  }
  return best;
}

/**
 * Clamps `v` into [min, max]. Any non-finite input (NaN, +Infinity, -Infinity) falls back to
 * `min` rather than being clamped, so e.g. `clampNumber(Infinity, 0, 1)` is 0.
 */
function clampNumber(v: number, min: number, max: number): number {
  return Number.isFinite(v) ? Math.max(min, Math.min(max, v)) : min;
}

// Intentionally simple discretization for tabular Q-learning (demo-quality).

/** Ascending bin boundaries for each continuous observation feature used by `obsKey`. */
const OBS_BIN_CUTS = {
  temperature: [2, 6, 10],
  locationStability: [0.75, 0.9],
  supply: [15, 40, 70],
  depletion: [0.3, 0.9],
  service: [20, 60],
  risk: [0.4, 0.7, 0.85],
} as const;

/** How many ascending `cuts` satisfy `!(value < cut)` (value at or above the cut); NaN reaches every cut. */
function countCutsReached(value: number, cuts: readonly number[]): number {
  let reached = 0;
  for (const cut of cuts) {
    if (!(value < cut)) reached += 1;
  }
  return reached;
}

/** How many ascending `cuts` `value` strictly exceeds; NaN exceeds none. */
function countCutsExceeded(value: number, cuts: readonly number[]): number {
  let exceeded = 0;
  for (const cut of cuts) {
    if (value > cut) exceeded += 1;
  }
  return exceeded;
}

/**
 * Maps an observation to a discrete string key for Q-table lookup by bucketing each feature.
 * The key has the form `t…|tr…|d…|l…|s…|dr…|svc…|r…`; `doorState` is embedded as-is.
 */
export function obsKey(obs: LogisticsTruckRollObservation): string {
  const temperature = countCutsReached(obs.temperatureCelsius, OBS_BIN_CUTS.temperature);

  // Trend is a signed three-way bin: falling, flat (including NaN), rising.
  let trend = 0;
  if (obs.temperatureTrend < -0.2) {
    trend = -1;
  } else if (obs.temperatureTrend > 0.2) {
    trend = 1;
  }

  const parts = [
    `t${temperature}`,
    `tr${trend}`,
    `d${obs.doorState}`,
    `l${countCutsExceeded(obs.locationStability, OBS_BIN_CUTS.locationStability)}`,
    `s${countCutsExceeded(obs.supplyLevelPercent, OBS_BIN_CUTS.supply)}`,
    `dr${countCutsReached(obs.supplyDepletionRate, OBS_BIN_CUTS.depletion)}`,
    `svc${countCutsReached(obs.timeSinceLastService, OBS_BIN_CUTS.service)}`,
    `r${countCutsReached(obs.assetRiskScore, OBS_BIN_CUTS.risk)}`,
  ];
  return parts.join("|");
}

/**
 * Epsilon-greedy action selection, deterministic for a given `seed`.
 *
 * A fresh RNG is created from `seed`. If its first `nextFloat()` draw is below `epsilon`, the
 * action is `rng.nextInt(NUM_ACTIONS)` (presumably uniform over [0, NUM_ACTIONS) — confirm against
 * the prng module); otherwise it is the greedy action for `obs`.
 *
 * The greedy lookup is read-only: an observation with no row in `q` behaves like an all-zero row
 * (argMax → action 0) and is NOT inserted, so choosing an action never mutates the Q-table.
 *
 * @param q Q-table to read.
 * @param obs current observation; discretized with `obsKey`.
 * @param epsilon exploration probability; compared directly against `rng.nextFloat()` (not clamped here).
 * @param seed seed for the per-call RNG.
 * @returns the selected action.
 */
export function selectActionEpsilonGreedy(
  q: QTable,
  obs: LogisticsTruckRollObservation,
  epsilon: number,
  seed: number
): LogisticsTruckRollAction {
  const rng = createRng(seed);
  if (rng.nextFloat() < epsilon) return rng.nextInt(NUM_ACTIONS) as LogisticsTruckRollAction;
  // Read-only lookup: unseen states map to action 0, exactly what argMax of a fresh zero row gives.
  const row = q.get(obsKey(obs));
  return (row === undefined ? 0 : argMax(row)) as LogisticsTruckRollAction;
}

/**
 * Greedy (exploitation-only) action for `obs` under `q`.
 *
 * The lookup is read-only, so evaluating a learned policy never inserts rows into (or otherwise
 * mutates) the Q-table. An observation with no stored row behaves like an all-zero row, whose
 * argMax is action 0.
 *
 * @param q Q-table to read.
 * @param obs current observation; discretized with `obsKey`.
 * @returns the action with the highest stored Q-value (lowest index on ties).
 */
export function actGreedy(q: QTable, obs: LogisticsTruckRollObservation): LogisticsTruckRollAction {
  const row = q.get(obsKey(obs));
  return (row === undefined ? 0 : argMax(row)) as LogisticsTruckRollAction;
}

/** Creates a brand-new environment from `config` and resets it with `seed`. */
function startFreshEpisode(config: LogisticsTruckRollConfig, seed: number) {
  const env = new LogisticsTruckRollEnv();
  env.init(config);
  const { observation } = env.reset({ seed });
  return { env, observation };
}

/**
 * Runs tabular Q-learning for at most `maxTransitions` environment steps, updating `q` in place,
 * so that a long training run can be split across several calls.
 *
 * Each step picks an action with `selectActionEpsilonGreedy`, applies the one-step Q-learning
 * update `Q(s,a) += alpha * (target - Q(s,a))`, and advances the resume cursor. The target
 * bootstraps from `max_a' Q(s',a')` unless the env reports `terminated`. Truncation — either
 * `res.truncated` or this trainer's own `stepsPerEpisode` budget — still ends the episode, but it
 * is not treated as a zero-value terminal state.
 *
 * @param args.config environment configuration passed to `env.init` for every episode.
 * @param args.params hyperparameters. gamma, alpha, and epsilon are clamped to [0, 0.999], [0.001, 1],
 *   and [0, 1]; non-finite values fall back to the lower bound.
 * @param args.q Q-table, mutated in place.
 * @param args.startAt cursor returned as `resume` by a previous call. Omit it to start at episode 0
 *   with seed `params.seed >>> 0`.
 * @param args.maxTransitions maximum number of env steps performed by this call.
 * @returns
 *   - `done`: true once `params.episodes` episodes have completed.
 *   - `episodeReturnsDelta`: the undiscounted reward sum of each episode that finished during this
 *     call, including any return carried in via `startAt`.
 *   - `progress`: the current episode index and Q-table size.
 *   - `resume`: the cursor to pass back as `startAt`.
 *
 * NOTE(review): environment state is not carried across calls. If a chunk ends mid-episode, the next
 * call starts from a freshly reset env (seeded with the current `rngSeed`) while keeping the carried
 * `stepInEpisode`/`episodeReturn`, so the remainder of that episode follows a different trajectory.
 */
export function trainLogisticsQAgentChunked(args: {
  config: LogisticsTruckRollConfig;
  params: QLearningParams;
  q: QTable;
  startAt?: { episode: number; stepInEpisode: number; episodeReturn: number; rngSeed: number };
  maxTransitions: number;
}): {
  done: boolean;
  episodeReturnsDelta: number[];
  progress: { episode: number; qSize: number };
  resume: { episode: number; stepInEpisode: number; episodeReturn: number; rngSeed: number };
} {
  const { config, params, q, maxTransitions } = args;
  const gamma = clampNumber(params.gamma, 0, 0.999);
  const alpha = clampNumber(params.alpha, 0.001, 1);
  const epsilon = clampNumber(params.epsilon, 0, 1);

  let episode = args.startAt?.episode ?? 0;
  let stepInEpisode = args.startAt?.stepInEpisode ?? 0;
  let episodeReturn = args.startAt?.episodeReturn ?? 0;
  let rngSeed = args.startAt?.rngSeed ?? (params.seed >>> 0);

  const episodeReturnsDelta: number[] = [];

  let { env, observation: obs } = startFreshEpisode(config, rngSeed);

  for (let t = 0; t < maxTransitions; t++) {
    if (episode >= params.episodes) break;

    // A new episode inside this chunk gets a fresh env; at t === 0 the env above is already fresh.
    if (stepInEpisode === 0 && t > 0) {
      ({ env, observation: obs } = startFreshEpisode(config, rngSeed));
    }

    // Per-step action seed derived from the cursor so a run is reproducible across chunking.
    const aSeed = (rngSeed + (episode * 1_000_003 + stepInEpisode * 97)) >>> 0;
    const action = selectActionEpsilonGreedy(q, obs, epsilon, aSeed);
    const row = getRow(q, obsKey(obs));

    const res = env.step(action);
    const r = res.reward;
    const episodeOver = res.terminated || res.truncated || stepInEpisode + 1 >= params.stepsPerEpisode;

    // Only a genuine terminal state has zero future value. Time-limit style truncation is an
    // artificial cut-off, so bootstrapping from s' avoids biasing Q-values toward the budget.
    let target = r;
    if (!res.terminated) {
      const nextRow = getRow(q, obsKey(res.observation));
      target = r + gamma * (nextRow[argMax(nextRow)] ?? 0);
    }
    row[action] = (row[action] ?? 0) + alpha * (target - (row[action] ?? 0));

    obs = res.observation;
    episodeReturn += r;
    stepInEpisode += 1;
    rngSeed = (rngSeed + 1) >>> 0;

    if (episodeOver) {
      episode += 1;
      episodeReturnsDelta.push(episodeReturn);
      episodeReturn = 0;
      stepInEpisode = 0;
    }
  }

  const done = episode >= params.episodes;
  return {
    done,
    episodeReturnsDelta,
    progress: { episode, qSize: q.size },
    resume: { episode, stepInEpisode, episodeReturn, rngSeed },
  };
}

