import { createRng, type Rng } from "@/lib/playground/prng";

/**
 * Discrete actions accepted by `LogisticsTruckRollEnv.step`.
 *
 * The numeric values are the action indices (0..4); `step` truncates and
 * clamps any finite number into that range and treats anything else as `WAIT`.
 */
export enum LogisticsTruckRollAction {
  /** No intervention; penalized when the current risk score is at or above `highRiskThreshold`. */
  WAIT = 0,
  /** Turns on boosted sampling for `samplingBoostSteps` steps, which raises the security detection probability. */
  INCREASE_SAMPLING = 1,
  /** Sets the internal `prebooked` flag; no other effect is visible in this file. */
  PREBOOK_SERVICE = 2,
  /** Ends the episode with a service dispatch (unless a failure occurs on the same step). */
  DISPATCH_SERVICE = 3,
  /** Adds 0.55 to the security detection probability for this step only. */
  ESCALATE_SECURITY = 4,
}

/**
 * Human-readable label for every action, listed in ascending action-index order.
 * Each entry pairs an enum member with its display label.
 */
export const LOGISTICS_TRUCK_ROLL_ACTIONS: Array<{ action: LogisticsTruckRollAction; label: string }> = [
  { action: LogisticsTruckRollAction.WAIT, label: "Wait" },
  { action: LogisticsTruckRollAction.INCREASE_SAMPLING, label: "Increase sampling" },
  { action: LogisticsTruckRollAction.PREBOOK_SERVICE, label: "Prebook service" },
  { action: LogisticsTruckRollAction.DISPATCH_SERVICE, label: "Dispatch service" },
  { action: LogisticsTruckRollAction.ESCALATE_SECURITY, label: "Escalate security" },
];

/**
 * Snapshot of the simulated asset returned by `reset` and `step`.
 * The ranges quoted below are the clamps the environment applies to its state.
 */
export type LogisticsTruckRollObservation = {
  /** Current temperature, kept within -50..80. */
  temperatureCelsius: number;
  /** Per-step temperature change that is added to the temperature each step, kept within -5..5. */
  temperatureTrend: number;
  /** 1 after a security event opened the door, 0 once it has relaxed back to closed. */
  doorState: 0 | 1;
  /** Location stability in 0..1; each security event lowers it by 0.12. */
  locationStability: number;
  /** Remaining supply in 0..100. */
  supplyLevelPercent: number;
  /** Supply consumed on the most recent step (redrawn every step), in 0..50. */
  supplyDepletionRate: number;
  /** Steps since the last service; grows by 1 per step, capped at 10_000. */
  timeSinceLastService: number;
  /** Weighted risk score in 0..1 derived from the other fields and the config. */
  assetRiskScore: number;
};

/**
 * Observation field names in declaration order.
 * NOTE(review): this order appears to match the element order of the `low`/`high`
 * bounds in `LogisticsTruckRollEnv.observation_space` — keep the two in sync.
 */
export const LOGISTICS_TRUCK_ROLL_OBS_KEYS: Array<keyof LogisticsTruckRollObservation> = [
  "temperatureCelsius",
  "temperatureTrend",
  "doorState",
  "locationStability",
  "supplyLevelPercent",
  "supplyDepletionRate",
  "timeSinceLastService",
  "assetRiskScore",
];

/** Diagnostic details that accompany every `reset` / `step` result. */
export type LogisticsTruckRollInfo = {
  /** Number of steps taken so far in the episode (0 right after reset). */
  step: number;
  /**
   * What happened on this step. Normally mirrors the chosen action, but is
   * overridden by "failure" or "maxSteps" when the episode ends that way, and
   * is "already_done" when `step` is called on a finished episode.
   */
  event:
    | "wait"
    | "sampling"
    | "prebook"
    | "dispatch"
    | "escalate"
    | "failure"
    | "maxSteps"
    | "already_done";
  /** Per-cause reward contributions for this step; the entries add up to the step reward. */
  rewardBreakdown: {
    /** Bonus for a necessary dispatch or for reaching the step limit. */
    avoidFailure: number;
    /** Bonus for reaching the step limit without having dispatched. */
    avoidUnnecessaryService: number;
    /** Penalty for a dispatch that was judged unnecessary. */
    unnecessaryService: number;
    /** Penalty applied when a failure threshold is crossed. */
    failure: number;
    /** Penalty for a security event that occurred but was not detected. */
    undetectedSecurity: number;
    /** Penalty for choosing WAIT while the risk score is at or above `highRiskThreshold`. */
    waitHighRisk: number;
  };
  /** Whether a security event happened this step and, if so, whether it was detected. */
  securityEvent: { occurred: boolean; detected: boolean };
  /** Only set on a dispatch-terminated step: whether the dispatch was judged necessary. */
  dispatchNecessary?: boolean;
  /** Set on the step that terminates the episode. */
  terminatedReason?: "failure" | "dispatch";
  /** Set on the step that truncates the episode at the step limit. */
  truncatedReason?: "maxSteps";
};

/** Descriptor for a discrete action space with `n` choices. */
export type DiscreteSpace = { kind: "discrete"; n: number };
/** Descriptor for a box (per-element bounded) space: element-wise `low`/`high` bounds plus the array `shape`. */
export type BoxSpace = { kind: "box"; low: number[]; high: number[]; shape: number[] };

/**
 * Tunable parameters of the environment.
 * `normalizeLogisticsTruckRollConfig` clamps every field into a supported
 * range before the environment uses it.
 */
export type LogisticsTruckRollConfig = {
  /** Step count at which an unfinished episode is truncated. */
  maxSteps: number;

  // Dynamics: initial-state distributions and per-step noise
  /** Mean of the initial temperature drawn at reset. */
  tempInitMean: number;
  /** Standard deviation of the initial temperature. */
  tempInitStd: number;
  /** Mean of the initial temperature trend drawn at reset. */
  tempTrendInitMean: number;
  /** Standard deviation of the initial temperature trend. */
  tempTrendInitStd: number;
  /** Standard deviation of the noise added to the temperature each step. */
  tempNoiseStd: number;
  /** Standard deviation of the random drift added to the temperature trend each step. */
  tempTrendDriftStd: number;

  /** Mean of the initial location stability. */
  locationStabilityInitMean: number; // 0..1
  /** Noise standard deviation, used both for the initial draw and for the per-step update. */
  locationStabilityNoiseStd: number;

  /** Mean of the initial supply level. */
  supplyInitMean: number; // 0..100
  /** Standard deviation of the initial supply level. */
  supplyInitStd: number;
  /** Mean of the depletion rate, which is redrawn every step. */
  supplyDepletionMean: number; // percent per step
  /** Standard deviation of the depletion rate. */
  supplyDepletionStd: number;

  /** Mean of the initial time since last service. */
  timeSinceServiceInitMean: number;
  /** Standard deviation of the initial time since last service. */
  timeSinceServiceInitStd: number;

  // Sampling + security
  /** Base per-step security event probability, before the stability and risk terms are added. */
  baseSecurityEventProb: number; // 0..1
  /** Number of steps boosted sampling stays on after INCREASE_SAMPLING. */
  samplingBoostSteps: number;

  // Failure thresholds
  /** The episode fails when the temperature reaches or exceeds this value. */
  tempFailureC: number;
  /** The episode fails when the supply level drops to or below this value. */
  supplyFailurePercent: number;
  /** Service interval used by the risk score and the dispatch-necessity check; not itself a failure condition in `step`. */
  serviceDueTime: number;

  // Decision thresholds
  /** Risk score at or above which WAIT is penalized. */
  highRiskThreshold: number; // 0..1
  /** Risk score at or above which a dispatch counts as necessary. */
  dispatchRiskThreshold: number; // 0..1

  // Rewards (trailing values are the defaults from defaultLogisticsTruckRollConfig)
  waitHighRiskPenalty: number; // default -1
  failurePenalty: number; // default -10
  undetectedSecurityPenalty: number; // default -8
  unnecessaryServicePenalty: number; // default -2
  avoidFailureReward: number; // default +5
  avoidUnnecessaryServiceReward: number; // default +3
};

/** Options for `reset`. `seed` is coerced to an unsigned 32-bit integer; 1234 is used when it is omitted. */
export type ResetOptions = { seed?: number };

/** Result of one `step` call. */
export type StepResult = {
  /** State of the asset after the step's dynamics were applied. */
  observation: LogisticsTruckRollObservation;
  /** Sum of all reward contributions for this step (see `info.rewardBreakdown`). */
  reward: number;
  /** True when the episode ended by failure or dispatch. */
  terminated: boolean;
  /** True when the episode ended because the step limit was reached. */
  truncated: boolean;
  /** Diagnostics for this step. */
  info: LogisticsTruckRollInfo;
};

/** Result of `reset`: the initial observation plus an info record with step 0 and an all-zero reward breakdown. */
export type ResetResult = { observation: LogisticsTruckRollObservation; info: LogisticsTruckRollInfo };

/** Mutable per-episode state owned by the environment; created by `reset` and updated in place by `step`. */
type InternalState = {
  /** Steps taken so far in the episode. */
  step: number;
  /** Seeded random source; every stochastic draw in the episode comes from it. */
  rng: Rng;

  // Observable quantities (same meaning as in LogisticsTruckRollObservation)
  temperatureCelsius: number;
  temperatureTrend: number;
  doorState: 0 | 1;
  locationStability: number;
  supplyLevelPercent: number;
  supplyDepletionRate: number;
  timeSinceLastService: number;

  /** Steps of boosted sampling left; counts down by one per step. */
  samplingRemaining: number;
  /** Set by PREBOOK_SERVICE. NOTE(review): never read anywhere in this file. */
  prebooked: boolean;
  /** Set when the episode ends with a dispatch. */
  dispatched: boolean;
  /** True once the episode has terminated or been truncated. */
  done: boolean;
};

/**
 * Restricts `v` to the closed interval [`min`, `max`].
 *
 * @param v - Value to restrict.
 * @param min - Lower bound; also the fallback for any non-finite `v` (NaN, +Infinity, -Infinity).
 * @param max - Upper bound.
 * @returns `v` limited to the interval, or `min` when `v` is not a finite number.
 */
function clampNumber(v: number, min: number, max: number): number {
  if (Number.isFinite(v)) {
    const cappedAbove = Math.min(max, v);
    return Math.max(min, cappedAbove);
  }
  // Non-finite input carries no usable magnitude, so fall back to the lower bound.
  return min;
}

/**
 * Drops the fractional part of `v` (truncation toward zero) and restricts the
 * result to the closed interval [`min`, `max`].
 *
 * @param v - Value to convert.
 * @param min - Lower bound; also the fallback for any non-finite `v`.
 * @param max - Upper bound.
 * @returns The truncated, bounded value, or `min` when `v` is not a finite number.
 */
function clampInt(v: number, min: number, max: number): number {
  if (Number.isFinite(v)) {
    const whole = Math.trunc(v);
    const cappedAbove = Math.min(max, whole);
    return Math.max(min, cappedAbove);
  }
  return min;
}

/**
 * Draws one standard-normal sample using the Box-Muller transform.
 *
 * Consumes exactly two `nextFloat()` values from `rng`: the first drives the
 * radius, the second the angle.
 *
 * @param rng - Random source providing `nextFloat()`.
 * @returns A sample built from the cosine branch of the transform.
 */
function randn(rng: Rng): number {
  // Both uniforms are floored at a tiny positive value; for the first one this keeps log() finite.
  const floor = 1e-12;
  const radiusUniform = Math.max(floor, rng.nextFloat());
  const angleUniform = Math.max(floor, rng.nextFloat());

  const radius = Math.sqrt(-2 * Math.log(radiusUniform));
  const angle = 2 * Math.PI * angleUniform;
  return radius * Math.cos(angle);
}

/**
 * Combines the observable state into a single risk score in 0..1.
 *
 * Each factor is first normalized to 0..1 and then weighted; the weights add
 * up to 1.00.
 *
 * @param s - The observable portion of the internal state.
 * @param cfg - Supplies `tempFailureC` and `serviceDueTime` for normalization.
 * @returns The weighted sum of the factors, clamped to 0..1.
 */
function computeRiskScore(s: Pick<
  InternalState,
  | "temperatureCelsius"
  | "temperatureTrend"
  | "doorState"
  | "locationStability"
  | "supplyLevelPercent"
  | "supplyDepletionRate"
  | "timeSinceLastService"
>, cfg: LogisticsTruckRollConfig): number {
  const toUnit = (x: number): number => clampNumber(x, 0, 1);

  // [weight, normalized factor] pairs; the order fixes the summation order.
  const weightedFactors: Array<[number, number]> = [
    // Temperature: starts rising 6 degrees below the failure threshold, saturates 2 degrees above it.
    [0.26, toUnit((s.temperatureCelsius - (cfg.tempFailureC - 6)) / 8)],
    // Trend magnitude in either direction, saturating at 2.5 per step.
    [0.10, toUnit(Math.abs(s.temperatureTrend) / 2.5)],
    // An open door counts fully.
    [0.14, s.doorState ? 1 : 0],
    // Lack of location stability.
    [0.10, toUnit(1 - s.locationStability)],
    // Supply: rises once the level drops below 20, reaching 1 at 0.
    [0.18, toUnit((20 - s.supplyLevelPercent) / 20)],
    // Depletion rate, saturating at 10 per step.
    [0.07, toUnit(s.supplyDepletionRate / 10)],
    // Fraction of the service interval (at least 1) that has elapsed.
    [0.15, toUnit(s.timeSinceLastService / Math.max(1, cfg.serviceDueTime))],
  ];

  let total = 0;
  for (const [weight, factor] of weightedFactors) {
    total += weight * factor;
  }
  return toUnit(total);
}

/**
 * Builds the default environment configuration.
 *
 * @returns A freshly allocated config object on every call, so callers may
 * modify it without affecting later calls.
 */
export function defaultLogisticsTruckRollConfig(): LogisticsTruckRollConfig {
  // Initial-state distributions and per-step noise.
  const dynamics = {
    tempInitMean: 4,
    tempInitStd: 1.2,
    tempTrendInitMean: 0,
    tempTrendInitStd: 0.15,
    tempNoiseStd: 0.35,
    tempTrendDriftStd: 0.05,

    locationStabilityInitMean: 0.92,
    locationStabilityNoiseStd: 0.03,

    supplyInitMean: 85,
    supplyInitStd: 6,
    supplyDepletionMean: 0.6,
    supplyDepletionStd: 0.25,

    timeSinceServiceInitMean: 18,
    timeSinceServiceInitStd: 5,
  };

  // Security event probability and sampling boost duration.
  const security = {
    baseSecurityEventProb: 0.02,
    samplingBoostSteps: 16,
  };

  // Failure limits followed by decision thresholds.
  const thresholds = {
    tempFailureC: 12,
    supplyFailurePercent: 2,
    serviceDueTime: 60,

    highRiskThreshold: 0.7,
    dispatchRiskThreshold: 0.8,
  };

  // Penalties are negative, bonuses positive.
  const rewards = {
    waitHighRiskPenalty: -1,
    failurePenalty: -10,
    undetectedSecurityPenalty: -8,
    unnecessaryServicePenalty: -2,
    avoidFailureReward: 5,
    avoidUnnecessaryServiceReward: 3,
  };

  return { maxSteps: 200, ...dynamics, ...security, ...thresholds, ...rewards };
}

/**
 * Returns a copy of `raw` with every known field forced into its supported range.
 *
 * Non-finite values collapse to the lower bound of their range. `maxSteps` and
 * `samplingBoostSteps` are additionally truncated to integers. Any extra
 * properties on `raw` are carried over unchanged, and `raw` itself is not modified.
 *
 * @param raw - Config as supplied by the caller.
 * @returns A new, range-checked config object.
 */
export function normalizeLogisticsTruckRollConfig(raw: LogisticsTruckRollConfig): LogisticsTruckRollConfig {
  const bounded: LogisticsTruckRollConfig = {
    maxSteps: clampInt(raw.maxSteps, 1, 50_000),

    // Temperature dynamics
    tempInitMean: clampNumber(raw.tempInitMean, -50, 80),
    tempInitStd: clampNumber(raw.tempInitStd, 0, 50),
    tempTrendInitMean: clampNumber(raw.tempTrendInitMean, -5, 5),
    tempTrendInitStd: clampNumber(raw.tempTrendInitStd, 0, 5),
    tempNoiseStd: clampNumber(raw.tempNoiseStd, 0, 10),
    tempTrendDriftStd: clampNumber(raw.tempTrendDriftStd, 0, 5),

    // Location stability
    locationStabilityInitMean: clampNumber(raw.locationStabilityInitMean, 0, 1),
    locationStabilityNoiseStd: clampNumber(raw.locationStabilityNoiseStd, 0, 0.5),

    // Supply
    supplyInitMean: clampNumber(raw.supplyInitMean, 0, 100),
    supplyInitStd: clampNumber(raw.supplyInitStd, 0, 100),
    supplyDepletionMean: clampNumber(raw.supplyDepletionMean, 0, 50),
    supplyDepletionStd: clampNumber(raw.supplyDepletionStd, 0, 50),

    // Service age
    timeSinceServiceInitMean: clampNumber(raw.timeSinceServiceInitMean, 0, 10_000),
    timeSinceServiceInitStd: clampNumber(raw.timeSinceServiceInitStd, 0, 10_000),

    // Sampling + security
    baseSecurityEventProb: clampNumber(raw.baseSecurityEventProb, 0, 1),
    samplingBoostSteps: clampInt(raw.samplingBoostSteps, 0, 10_000),

    // Failure thresholds
    tempFailureC: clampNumber(raw.tempFailureC, -50, 80),
    supplyFailurePercent: clampNumber(raw.supplyFailurePercent, 0, 50),
    serviceDueTime: clampNumber(raw.serviceDueTime, 1, 10_000),

    // Decision thresholds
    highRiskThreshold: clampNumber(raw.highRiskThreshold, 0, 1),
    dispatchRiskThreshold: clampNumber(raw.dispatchRiskThreshold, 0, 1),

    // Rewards: penalties may not be positive, bonuses may not be negative.
    waitHighRiskPenalty: clampNumber(raw.waitHighRiskPenalty, -100, 0),
    failurePenalty: clampNumber(raw.failurePenalty, -1_000, 0),
    undetectedSecurityPenalty: clampNumber(raw.undetectedSecurityPenalty, -1_000, 0),
    unnecessaryServicePenalty: clampNumber(raw.unnecessaryServicePenalty, -1_000, 0),
    avoidFailureReward: clampNumber(raw.avoidFailureReward, 0, 1_000),
    avoidUnnecessaryServiceReward: clampNumber(raw.avoidUnnecessaryServiceReward, 0, 1_000),
  };

  // Spread raw first so unknown extra properties survive, then overlay the bounded values.
  return { ...raw, ...bounded };
}

/**
 * Gym-style environment simulating a monitored asset whose temperature,
 * supply level, door state and location stability evolve stochastically.
 *
 * Lifecycle: `init(config)` -> `reset(options?)` -> repeated `step(action)`.
 * An episode terminates on failure (temperature or supply threshold crossed)
 * or on a service dispatch, and is truncated once `maxSteps` steps have run.
 * All randomness comes from a generator seeded in `reset`, so a given config,
 * seed and action sequence always produce the same trajectory.
 */
export class LogisticsTruckRollEnv {
  /** Five discrete actions, indexed by the `LogisticsTruckRollAction` values 0..4. */
  action_space: DiscreteSpace = { kind: "discrete", n: 5 };
  /** Element-wise bounds for the eight observation fields. */
  observation_space: BoxSpace = {
    kind: "box",
    low: [-50, -10, 0, 0, 0, 0, 0, 0],
    high: [80, 10, 1, 1, 100, 50, 10_000, 1],
    shape: [8],
  };

  private cfg: LogisticsTruckRollConfig | null = null;
  private s: InternalState | null = null;
  /**
   * How the current episode ended, or null while it is still running.
   * Remembered so that `step` calls made after the end report flags that
   * agree with the step that actually finished the episode.
   */
  private endedBy: "failure" | "dispatch" | "maxSteps" | null = null;

  /** Returns a fresh reward breakdown with every contribution set to zero. */
  private static zeroBreakdown(): LogisticsTruckRollInfo["rewardBreakdown"] {
    return {
      avoidFailure: 0,
      avoidUnnecessaryService: 0,
      unnecessaryService: 0,
      failure: 0,
      undetectedSecurity: 0,
      waitHighRisk: 0,
    };
  }

  /**
   * Stores a normalized copy of `config` and discards any episode in progress;
   * `reset` must be called before the next `step`.
   */
  init(config: LogisticsTruckRollConfig) {
    this.cfg = normalizeLogisticsTruckRollConfig(config);
    this.s = null;
    this.endedBy = null;
  }

  /**
   * Starts a new episode.
   *
   * @param options - Optional seed (coerced to unsigned 32-bit; defaults to 1234).
   * @returns The initial observation and an info record with an all-zero reward breakdown.
   * @throws Error if `init` has not been called.
   */
  reset(options?: ResetOptions): ResetResult {
    const cfg = this.cfg;
    if (!cfg) throw new Error("Env not initialized (call init(config) first)");

    const seed = (options?.seed ?? 1234) >>> 0;
    const rng = createRng(seed);

    // The draw order below determines the seeded trajectory; do not reorder.
    const temperatureCelsius = clampNumber(cfg.tempInitMean + cfg.tempInitStd * randn(rng), -50, 80);
    const temperatureTrend = clampNumber(cfg.tempTrendInitMean + cfg.tempTrendInitStd * randn(rng), -5, 5);
    const locationStability = clampNumber(cfg.locationStabilityInitMean + cfg.locationStabilityNoiseStd * randn(rng), 0, 1);
    const supplyLevelPercent = clampNumber(cfg.supplyInitMean + cfg.supplyInitStd * randn(rng), 0, 100);
    const supplyDepletionRate = clampNumber(cfg.supplyDepletionMean + cfg.supplyDepletionStd * randn(rng), 0, 50);
    const timeSinceLastService = clampNumber(cfg.timeSinceServiceInitMean + cfg.timeSinceServiceInitStd * randn(rng), 0, 10_000);

    const s: InternalState = {
      step: 0,
      rng,
      temperatureCelsius,
      temperatureTrend,
      doorState: 0,
      locationStability,
      supplyLevelPercent,
      supplyDepletionRate,
      timeSinceLastService,
      samplingRemaining: 0,
      prebooked: false,
      dispatched: false,
      done: false,
    };
    this.s = s;
    this.endedBy = null;

    const obs = this.observe();
    return {
      observation: obs,
      info: {
        step: 0,
        event: "wait",
        rewardBreakdown: LogisticsTruckRollEnv.zeroBreakdown(),
        securityEvent: { occurred: false, detected: false },
      },
    };
  }

  /**
   * Builds the observation for the current state, including the derived risk score.
   *
   * @throws Error if the environment has not been initialized or reset.
   */
  private observe(): LogisticsTruckRollObservation {
    const cfg = this.cfg;
    const s = this.s;
    if (!cfg) throw new Error("Env not initialized");
    if (!s) throw new Error("Env not reset");
    const assetRiskScore = computeRiskScore(s, cfg);
    return {
      temperatureCelsius: s.temperatureCelsius,
      temperatureTrend: s.temperatureTrend,
      doorState: s.doorState,
      locationStability: s.locationStability,
      supplyLevelPercent: s.supplyLevelPercent,
      supplyDepletionRate: s.supplyDepletionRate,
      timeSinceLastService: s.timeSinceLastService,
      assetRiskScore,
    };
  }

  /**
   * Advances the episode by one step.
   *
   * Order of operations: apply the action's immediate effect, charge the
   * wait-under-high-risk penalty, roll the security event, update the
   * dynamics, then evaluate failure / dispatch termination and finally the
   * `maxSteps` truncation.
   *
   * @param actionRaw - Action index. Finite numbers are truncated and clamped
   *   to 0..4; anything else is treated as WAIT.
   * @returns Observation after the step, the step reward, end-of-episode flags
   *   and diagnostics. Once the episode has ended, further calls leave the
   *   state untouched and return reward 0 with event "already_done" and the
   *   same terminated/truncated flags as the step that ended the episode.
   * @throws Error if the environment has not been initialized or reset.
   */
  step(actionRaw: LogisticsTruckRollAction): StepResult {
    const cfg = this.cfg;
    const s0 = this.s;
    if (!cfg) throw new Error("Env not initialized");
    if (!s0) throw new Error("Env not reset");

    if (s0.done) {
      // Report the same kind of ending as the step that finished the episode,
      // so a truncated episode is never re-labelled as terminated.
      const endedAtHorizon = this.endedBy === "maxSteps";
      const obs = this.observe();
      return {
        observation: obs,
        reward: 0,
        terminated: !endedAtHorizon,
        truncated: endedAtHorizon,
        info: {
          step: s0.step,
          event: "already_done",
          rewardBreakdown: LogisticsTruckRollEnv.zeroBreakdown(),
          securityEvent: { occurred: false, detected: false },
        },
      };
    }

    const action =
      typeof actionRaw === "number" && Number.isFinite(actionRaw)
        ? (Math.max(0, Math.min(4, Math.trunc(actionRaw))) as LogisticsTruckRollAction)
        : LogisticsTruckRollAction.WAIT;

    const breakdown = LogisticsTruckRollEnv.zeroBreakdown();

    let reward = 0;
    let event: LogisticsTruckRollInfo["event"] = "wait";

    // --- Action effects ---
    if (action === LogisticsTruckRollAction.INCREASE_SAMPLING) {
      // Re-arming never shortens a boost that is already running.
      s0.samplingRemaining = Math.max(s0.samplingRemaining, cfg.samplingBoostSteps);
      event = "sampling";
    } else if (action === LogisticsTruckRollAction.PREBOOK_SERVICE) {
      s0.prebooked = true;
      event = "prebook";
    } else if (action === LogisticsTruckRollAction.ESCALATE_SECURITY) {
      // Raises the detection probability for this step (see detectProb below).
      event = "escalate";
    } else if (action === LogisticsTruckRollAction.DISPATCH_SERVICE) {
      event = "dispatch";
    }

    // Risk is evaluated on the state before this step's dynamics are applied.
    const risk0 = computeRiskScore(s0, cfg);
    if (action === LogisticsTruckRollAction.WAIT && risk0 >= cfg.highRiskThreshold) {
      reward += cfg.waitHighRiskPenalty;
      breakdown.waitHighRisk += cfg.waitHighRiskPenalty;
    }

    // --- Stochastic security event ---
    // Event probability grows with location instability and with overall risk.
    const secProb = clampNumber(cfg.baseSecurityEventProb + 0.06 * (1 - s0.locationStability) + 0.04 * risk0, 0, 1);
    const securityOccurred = s0.rng.nextFloat() < secProb;
    const samplingOn = s0.samplingRemaining > 0;
    const detectProb = clampNumber((samplingOn ? 0.75 : 0.20) + (action === LogisticsTruckRollAction.ESCALATE_SECURITY ? 0.55 : 0), 0, 1);
    // The detection draw is only consumed when an event actually occurred.
    const securityDetected = securityOccurred && s0.rng.nextFloat() < detectProb;

    if (securityOccurred) {
      // A security event opens the door and degrades location stability.
      s0.doorState = 1;
      s0.locationStability = clampNumber(s0.locationStability - 0.12, 0, 1);
      if (!securityDetected) {
        reward += cfg.undetectedSecurityPenalty;
        breakdown.undetectedSecurity += cfg.undetectedSecurityPenalty;
      }
    } else {
      // Relax door back closed with some probability.
      if (s0.doorState === 1 && s0.rng.nextFloat() < 0.35) s0.doorState = 0;
    }

    // --- Dynamics update ---
    s0.temperatureTrend = clampNumber(
      s0.temperatureTrend + cfg.tempTrendDriftStd * randn(s0.rng),
      -5,
      5
    );
    s0.temperatureCelsius = clampNumber(
      s0.temperatureCelsius + s0.temperatureTrend + cfg.tempNoiseStd * randn(s0.rng),
      -50,
      80
    );
    s0.locationStability = clampNumber(
      s0.locationStability + cfg.locationStabilityNoiseStd * randn(s0.rng),
      0,
      1
    );
    // The depletion rate is redrawn each step rather than evolved from its previous value.
    s0.supplyDepletionRate = clampNumber(
      cfg.supplyDepletionMean + cfg.supplyDepletionStd * randn(s0.rng),
      0,
      50
    );
    s0.supplyLevelPercent = clampNumber(s0.supplyLevelPercent - s0.supplyDepletionRate, 0, 100);
    s0.timeSinceLastService = clampNumber(s0.timeSinceLastService + 1, 0, 10_000);

    if (s0.samplingRemaining > 0) s0.samplingRemaining -= 1;

    // --- Termination checks ---
    const failure =
      s0.temperatureCelsius >= cfg.tempFailureC ||
      s0.supplyLevelPercent <= cfg.supplyFailurePercent;

    let terminated = false;
    let truncated = false;
    let terminatedReason: LogisticsTruckRollInfo["terminatedReason"] | undefined;
    let truncatedReason: LogisticsTruckRollInfo["truncatedReason"] | undefined;
    let dispatchNecessary: boolean | undefined;

    if (failure) {
      // Failure takes precedence over a dispatch issued on the same step.
      terminated = true;
      terminatedReason = "failure";
      reward += cfg.failurePenalty;
      breakdown.failure += cfg.failurePenalty;
      event = "failure";
    } else if (action === LogisticsTruckRollAction.DISPATCH_SERVICE) {
      terminated = true;
      terminatedReason = "dispatch";
      s0.dispatched = true;

      // Necessity is judged on the post-dynamics state.
      const risk1 = computeRiskScore(s0, cfg);
      dispatchNecessary =
        risk1 >= cfg.dispatchRiskThreshold ||
        s0.timeSinceLastService >= cfg.serviceDueTime ||
        s0.temperatureCelsius >= cfg.tempFailureC - 1.5 ||
        s0.supplyLevelPercent <= cfg.supplyFailurePercent + 6;

      if (dispatchNecessary) {
        reward += cfg.avoidFailureReward;
        breakdown.avoidFailure += cfg.avoidFailureReward;
      } else {
        reward += cfg.unnecessaryServicePenalty;
        breakdown.unnecessaryService += cfg.unnecessaryServicePenalty;
      }
    }

    const nextStep = s0.step + 1;
    s0.step = nextStep;

    if (!terminated && nextStep >= cfg.maxSteps) {
      truncated = true;
      truncatedReason = "maxSteps";
      event = "maxSteps";
      // End-of-horizon shaping: avoided failure and avoided service.
      reward += cfg.avoidFailureReward;
      breakdown.avoidFailure += cfg.avoidFailureReward;
      // A dispatch always terminates the episode, so this guard currently always passes here.
      if (!s0.dispatched) {
        reward += cfg.avoidUnnecessaryServiceReward;
        breakdown.avoidUnnecessaryService += cfg.avoidUnnecessaryServiceReward;
      }
    }

    if (terminated || truncated) {
      s0.done = true;
      this.endedBy = terminatedReason ?? truncatedReason ?? null;
    }

    const obs = this.observe();
    const info: LogisticsTruckRollInfo = {
      step: nextStep,
      event,
      rewardBreakdown: breakdown,
      securityEvent: { occurred: securityOccurred, detected: securityDetected },
      dispatchNecessary,
      terminatedReason,
      truncatedReason,
    };

    return { observation: obs, reward, terminated, truncated, info };
  }

  /** No visual output is produced; always returns null. */
  render() {
    return null;
  }

  /** Drops the current episode state; `reset` must be called before stepping again. */
  close() {
    this.s = null;
    this.endedBy = null;
  }
}
