import { defaultInventoryEnvConfig } from "@/lib/inventoryRl/catalog";
import { buildInventoryObservation } from "@/lib/inventoryRl/features";
import type {
  EnvironmentContext,
  InventoryAction,
  InventoryEnvConfig,
  InventoryObservation,
  InventoryState,
  InventoryStepInfo,
  ResetOptions,
  SkuDefinition,
} from "@/lib/inventoryRl/types";

/** A batch of units of a single SKU that entered stock together and ages as one. */
type Lot = {
  /** Units still remaining in this lot. */
  qty: number;
  /** Age of the lot; starts at 0 and is incremented by one on every environment step. */
  ageHours: number;
};

/** Handle for the deterministic pseudo-random generator built by `createRng`. */
type SeededRng = {
  /** The seed after coercion to an unsigned 32-bit integer (captured once at creation). */
  seed: number;
  /** Advances the generator and returns the next value in [0, 1). */
  next: () => number;
};

/** Return value of `InventoryEnv.step`: [observation, reward, terminated, truncated, info]. */
type StepTuple = [InventoryObservation, number, boolean, boolean, InventoryStepInfo];
/** Return value of `InventoryEnv.reset`: [observation, { per-SKU inventory, current hour }]. */
type ResetTuple = [InventoryObservation, { inventory: number[]; hour: number }];

/**
 * Creates a deterministic linear congruential generator.
 *
 * The recurrence is `s = (1664525 * s + 1013904223) mod 2^32`; each call to
 * `next()` advances the state and returns it scaled into [0, 1).
 *
 * @param seed Any number; it is coerced to an unsigned 32-bit integer.
 * @returns The generator handle. `seed` holds the coerced starting value and
 *          does not change as the generator advances.
 */
function createRng(seed: number): SeededRng {
  const MULTIPLIER = 1664525;
  const INCREMENT = 1013904223;
  const RANGE = 0x100000000;

  const initial = seed >>> 0;
  let current = initial;

  const next = (): number => {
    // Math.imul yields the product modulo 2^32, which is all the recurrence needs.
    current = (Math.imul(MULTIPLIER, current) + INCREMENT) >>> 0;
    return current / RANGE;
  };

  return { seed: initial, next };
}

/**
 * Restricts `v` to the closed interval [min, max].
 *
 * @param v   Value to restrict.
 * @param min Lower bound; also the fallback for non-finite input.
 * @param max Upper bound.
 * @returns `min` when `v` is NaN or ±Infinity, otherwise `v` limited to the interval.
 */
function clamp(v: number, min: number, max: number): number {
  const capped = Math.min(max, v);
  // Non-finite input (including +Infinity) deliberately falls back to the lower bound.
  return Number.isFinite(v) ? Math.max(min, capped) : min;
}

/**
 * Restricts `v` to [min, max] via `clamp`, then drops any fractional part.
 *
 * @param v   Value to restrict.
 * @param min Lower bound (returned, truncated, for non-finite input).
 * @param max Upper bound.
 * @returns The clamped value truncated toward zero.
 */
function clampInt(v: number, min: number, max: number): number {
  const bounded = clamp(v, min, max);
  return Math.trunc(bounded);
}

/**
 * Reports whether `hour` is one of the listed hours.
 *
 * @param hour   Hour to look for.
 * @param values Hours that make up the window.
 * @returns `true` when `hour` appears in `values`.
 */
function isInHours(hour: number, values: number[]): boolean {
  for (const candidate of values) {
    // The second clause keeps NaN matching NaN, as Array.prototype.includes does.
    if (candidate === hour || (candidate !== candidate && hour !== hour)) {
      return true;
    }
  }
  return false;
}

/**
 * Draws a Poisson-distributed count using the supplied generator.
 *
 * Means above 50 use a rounded normal approximation (two uniform draws);
 * smaller means multiply uniform draws together until the product falls to
 * `exp(-mean)` or below and count how many draws that took.
 *
 * @param mean Expected value; negative values are treated as 0.
 * @param rng  Source of uniform values in [0, 1).
 * @returns A non-negative integer sample (always 0 when the mean is not positive).
 */
function poisson(mean: number, rng: SeededRng): number {
  const lambda = Math.max(0, mean);
  if (lambda <= 0) return 0;

  if (lambda > 50) {
    // Normal approximation for large lambda; the first draw is floored so log() stays finite.
    const radiusSource = Math.max(rng.next(), 1e-7);
    const angleSource = rng.next();
    const radius = Math.sqrt(-2 * Math.log(radiusSource));
    const gaussian = radius * Math.cos(2 * Math.PI * angleSource);
    const approx = Math.round(lambda + gaussian * Math.sqrt(lambda));
    return Math.max(0, approx);
  }

  const threshold = Math.exp(-lambda);
  let draws = 0;
  // Each draw is floored at 1e-7 so the running product cannot become exactly zero in one step.
  for (let product = 1; product > threshold; product *= Math.max(rng.next(), 1e-7)) {
    draws += 1;
  }
  return Math.max(0, draws - 1);
}

/**
 * Produces a complete environment config from a partial one.
 *
 * Every missing field is taken from `defaultInventoryEnvConfig()`, and numeric
 * fields are forced into fixed ranges with `clamp` / `clampInt`. An absent or
 * empty SKU list falls back to the default SKUs.
 *
 * @param configRaw Caller-supplied overrides; any subset of fields.
 * @returns A fully populated config. `laborCostPerUnit` is always 0.
 */
function normalizeConfig(configRaw: Partial<InventoryEnvConfig>): InventoryEnvConfig {
  const defaults = defaultInventoryEnvConfig();

  const providedSkus = configRaw.skus;
  const sourceSkus = providedSkus && providedSkus.length > 0 ? providedSkus : defaults.skus;
  const skus = sourceSkus.map((entry) => {
    const shelfLifeHours = clampInt(entry.shelfLifeHours, 1, 48);
    const minOrder = clampInt(entry.minOrder, 0, 1000);
    const unitCost = clamp(entry.unitCost, 0, 1000);
    const unitPrice = clamp(entry.unitPrice, 0, 1000);
    return { ...entry, shelfLifeHours, minOrder, unitCost, unitPrice };
  });

  const hours = configRaw.businessHours;
  const businessHours = {
    open: clampInt(hours?.open ?? defaults.businessHours.open, 0, 23),
    close: clampInt(hours?.close ?? defaults.businessHours.close, 1, 24),
  };

  const windows = configRaw.demandWindows;
  const demandWindows = {
    pickupMorning: windows?.pickupMorning ?? defaults.demandWindows.pickupMorning,
    highDemandMorning: windows?.highDemandMorning ?? defaults.demandWindows.highDemandMorning,
    highDemandLunch: windows?.highDemandLunch ?? defaults.demandWindows.highDemandLunch,
    slowAfternoon: windows?.slowAfternoon ?? defaults.demandWindows.slowAfternoon,
    evening: windows?.evening ?? defaults.demandWindows.evening,
  };

  // The capacity floor scales with the number of SKUs (10 units per SKU).
  const capacity = clampInt(configRaw.capacity ?? defaults.capacity, skus.length * 10, 1_000_000);
  const maxOrderPerSku = clampInt(configRaw.maxOrderPerSku ?? defaults.maxOrderPerSku, 1, 10_000);
  const maxStepsPerEpisode = clampInt(configRaw.maxStepsPerEpisode ?? defaults.maxStepsPerEpisode, 1, 10_000);
  const wasteRatePerHour = clamp(configRaw.wasteRatePerHour ?? defaults.wasteRatePerHour, 0, 0.5);
  const stockoutPenaltyMultiplier = clamp(
    configRaw.stockoutPenaltyMultiplier ?? defaults.stockoutPenaltyMultiplier,
    0,
    10
  );
  const targetServiceLevel = clamp(configRaw.targetServiceLevel ?? defaults.targetServiceLevel, 0, 1);
  const serviceShortfallPenaltyMultiplier = clamp(
    configRaw.serviceShortfallPenaltyMultiplier ?? defaults.serviceShortfallPenaltyMultiplier,
    0,
    20
  );
  const holdingCostPerUnit = clamp(configRaw.holdingCostPerUnit ?? defaults.holdingCostPerUnit, 0, 100);

  return {
    skus,
    businessHours,
    demandWindows,
    capacity,
    maxOrderPerSku,
    maxStepsPerEpisode,
    wasteRatePerHour,
    stockoutPenaltyMultiplier,
    targetServiceLevel,
    serviceShortfallPenaltyMultiplier,
    // Labor cost is intentionally disabled for imitation setup.
    laborCostPerUnit: 0,
    holdingCostPerUnit,
  };
}

/**
 * Looks up the built-in mean demand per step for a SKU from its category.
 *
 * @param sku SKU whose `category` selects the value.
 * @returns 8 for beverages, 5 for donuts, 3 for sandwiches, and 2 for snacks
 *          or any other category.
 */
function defaultDemandForSku(sku: SkuDefinition): number {
  switch (sku.category) {
    case "beverages":
      return 8;
    case "donuts":
      return 5;
    case "sandwiches":
      return 3;
    case "snacks":
    default:
      return 2;
  }
}

/**
 * Hourly inventory simulation exposing a `reset` / `step` API.
 *
 * Stock is held per SKU as a queue of lots. Orders append a fresh lot, demand
 * consumes from the front of the queue (oldest lot first), and every step ages
 * each lot by one hour, discarding lots that reach the SKU's shelf life and
 * applying a random fractional spoilage to the rest.
 *
 * One `step` proceeds in this order: sanitize and capacity-cap the order,
 * add it to stock, resolve demand, sell, age/waste, verify flow invariants,
 * update rolling signals, compute the reward, then advance the clock.
 */
export class InventoryEnv {
  /** Normalized configuration used by this environment. */
  public readonly config: InventoryEnvConfig;
  /** SKU ids in the same order as action and inventory vectors. */
  public readonly skuIds: string[];
  /** Declared observation size; see the constructor for how it is derived. */
  public readonly stateDim: number;
  /** Number of SKUs, i.e. the length of an action vector. */
  public readonly actionDim: number;

  private rng: SeededRng;
  private demandMean: number[];
  private lotsBySku: Lot[][];
  private state: InventoryState;

  /**
   * @param configRaw Partial configuration; missing fields are defaulted and all
   *                  numeric fields are range-checked by `normalizeConfig`.
   * @param seed      Seed for the internal random generator.
   */
  constructor(configRaw: Partial<InventoryEnvConfig> = {}, seed = 42) {
    this.config = normalizeConfig(configRaw);
    // A closing hour at or before opening would leave no business hours;
    // force the day to last at least one hour.
    if (this.config.businessHours.close <= this.config.businessHours.open) {
      this.config.businessHours.close = Math.min(this.config.businessHours.open + 1, 24);
    }
    this.skuIds = this.config.skus.map((s) => s.id);
    this.actionDim = this.config.skus.length;
    // inventory + fixed time/context features + rolling sales/loss/waste signals
    // NOTE(review): this must agree with what buildInventoryObservation emits; not verifiable from this file.
    this.stateDim = this.actionDim * 4 + 15;
    this.rng = createRng(seed);
    this.demandMean = this.config.skus.map((sku) => defaultDemandForSku(sku));
    this.state = this.buildInitialState();
    // Seed the lots from the initial inventory so that the stock reported by
    // getState() actually exists when step() is called before reset().
    this.lotsBySku = this.lotsFromInventory(this.state.inventory);
  }

  /**
   * Replaces the per-SKU mean demand used when demand is sampled.
   *
   * @param values One mean per SKU; each is clamped to [0, 1_000_000].
   * @throws Error when `values.length` differs from `actionDim`.
   */
  setDemandMean(values: number[]): void {
    if (values.length !== this.actionDim) {
      throw new Error(`demand mean length mismatch: expected ${this.actionDim}, got ${values.length}`);
    }
    this.demandMean = values.map((v) => clamp(v, 0, 1_000_000));
  }

  /**
   * Starts a new episode.
   *
   * @param options Optional overrides for the starting state. When
   *                `options.seed` is given the random generator is re-seeded;
   *                otherwise it continues from its current position.
   * @returns The initial observation and a summary of inventory and hour.
   */
  reset(options: ResetOptions = {}): ResetTuple {
    if (options.seed !== undefined) {
      this.rng = createRng(options.seed);
    }
    this.state = this.buildInitialState(options);
    this.lotsBySku = this.lotsFromInventory(this.state.inventory);
    return [this.buildObservation(this.state), { inventory: this.state.inventory.slice(), hour: this.state.context.hour }];
  }

  /**
   * Advances the simulation by one hour.
   *
   * @param action         Requested order quantity per SKU.
   * @param demandOverride Optional exact demand per SKU; when omitted, demand
   *                       is sampled from the configured means.
   * @returns `[observation, reward, terminated, truncated, info]`. `terminated`
   *          is true when the clock has wrapped back to the opening hour;
   *          `truncated` is true when the step limit is reached without that.
   * @throws Error when `demandOverride` has the wrong length or when a stock
   *         flow invariant is violated.
   */
  step(action: InventoryAction, demandOverride?: number[]): StepTuple {
    const inventoryStart = this.currentInventoryBySku();
    const order = this.sanitizeAction(action);
    const capacityLeft = Math.max(0, this.state.context.capacity - sum(inventoryStart));
    const orderCapped = this.capOrderByCapacity(order, capacityLeft);
    this.applyOrder(orderCapped);
    const inventoryAfterOrder = this.currentInventoryBySku();

    const demand = this.resolveDemand(demandOverride);
    const sales = new Array<number>(this.actionDim).fill(0);
    const lostSales = new Array<number>(this.actionDim).fill(0);
    for (let i = 0; i < this.actionDim; i++) {
      const requested = demand[i] ?? 0;
      const sold = this.consumeFromLots(i, requested);
      sales[i] = sold;
      lostSales[i] = Math.max(0, requested - sold);
    }

    const waste = this.ageAndWasteLots();
    const inventory = this.currentInventoryBySku();
    this.assertStepInvariants({
      inventoryStart,
      inventoryAfterOrder,
      order: orderCapped,
      demand,
      sales,
      lostSales,
      waste,
      inventoryEnd: inventory,
    });
    this.state.inventory = inventory;
    this.updateRollingSignals(sales, lostSales, waste);

    const margin = this.calculateMargin(sales);
    const stockoutPenalty = this.calculateStockoutPenalty(lostSales, demand);
    const servicePenalty = this.calculateServiceShortfallPenalty(sales, demand);
    const understockRiskPenalty = this.calculateUnderstockRiskPenalty(inventoryAfterOrder, demand);
    // Labor cost is reported as zero and does not enter the reward.
    const laborCost = 0;
    const holdingCost = this.config.holdingCostPerUnit * sum(inventory);
    const wasteCost = this.calculateWasteCost(waste);
    const reward = margin - wasteCost - stockoutPenalty - servicePenalty - understockRiskPenalty - holdingCost;

    const nextContext = this.advanceContext();
    this.state.context = nextContext;
    this.state.stepCount += 1;

    // The clock only returns to the opening hour when the business day rolls over.
    const terminated = nextContext.hour === this.config.businessHours.open;
    const truncated = !terminated && this.state.stepCount >= this.config.maxStepsPerEpisode;
    const event: InventoryStepInfo["event"] = terminated ? "terminated" : truncated ? "truncated" : "step";

    const info: InventoryStepInfo = {
      sales,
      waste,
      lostSales,
      order: orderCapped,
      demand,
      inventoryStart,
      inventoryAfterOrder,
      inventoryEnd: inventory,
      margin,
      stockoutPenalty,
      servicePenalty,
      understockRiskPenalty,
      laborCost,
      holdingCost,
      event,
    };

    return [this.buildObservation(this.state), reward, terminated, truncated, info];
  }

  /**
   * @returns A snapshot of the current state. The inventory and rolling-signal
   *          arrays are copies, so mutating them does not affect the environment.
   */
  getState(): InventoryState {
    return {
      inventory: this.state.inventory.slice(),
      context: {
        ...this.state.context,
        sales7d: this.state.context.sales7d.slice(),
        lost7d: this.state.context.lost7d.slice(),
        waste7d: this.state.context.waste7d.slice(),
      },
      stepCount: this.state.stepCount,
    };
  }

  /** Builds a starting state from reset options, clamping every field into its valid range. */
  private buildInitialState(options: ResetOptions = {}): InventoryState {
    const capacity = clampInt(options.capacity ?? this.config.capacity, this.actionDim * 10, 1_000_000);
    const inventory = this.normalizeInventory(options.inventory, capacity);
    // The starting hour must fall inside business hours (close is exclusive).
    const hour = clampInt(options.hour ?? this.config.businessHours.open, this.config.businessHours.open, this.config.businessHours.close - 1);
    const weekday = clampInt(options.weekday ?? 0, 0, 6);
    // Weekdays 5 and 6 count as the weekend unless the caller says otherwise.
    const isWeekend = options.isWeekend ?? weekday >= 5;
    const context: EnvironmentContext = {
      hour,
      weekday,
      isWeekend,
      isHoliday: options.isHoliday ?? false,
      temp: clamp(options.temp ?? 60, -30, 130),
      precip: clamp(options.precip ?? 0, 0, 1),
      sales7d: this.normalizeSales7d(options.sales7d),
      lost7d: this.normalizeLost7d(options.lost7d),
      waste7d: this.normalizeWaste7d(options.waste7d),
      capacity,
    };
    return { inventory, context, stepCount: 0 };
  }

  /** Converts a per-SKU inventory vector into lots: one fresh lot per SKU with stock, none otherwise. */
  private lotsFromInventory(inventory: number[]): Lot[][] {
    return inventory.map((qty) => (qty > 0 ? [{ qty, ageHours: 0 }] : []));
  }

  /**
   * Validates a starting inventory vector. A missing or wrong-length input is
   * replaced by a default of 1.2x mean demand plus the SKU's minimum order.
   * Each entry is clamped individually; the total is not checked against capacity.
   */
  private normalizeInventory(input: number[] | undefined, capacity: number): number[] {
    if (!input || input.length !== this.actionDim) {
      return this.config.skus.map((sku, idx) =>
        clampInt(this.demandMean[idx] * 1.2 + sku.minOrder, 0, Math.min(this.config.maxOrderPerSku * 2, capacity))
      );
    }
    return input.map((v) => clampInt(v, 0, capacity));
  }

  /**
   * Shared validation for the rolling per-SKU signals: a missing or
   * wrong-length input yields `fallback()`, otherwise each entry is clamped
   * to [0, 1_000_000].
   */
  private normalizeRollingSignal(input: number[] | undefined, fallback: () => number[]): number[] {
    if (!input || input.length !== this.actionDim) {
      return fallback();
    }
    return input.map((v) => clamp(v, 0, 1_000_000));
  }

  /** Rolling sales signal; defaults to seven times the mean demand per SKU. */
  private normalizeSales7d(input?: number[]): number[] {
    return this.normalizeRollingSignal(input, () => this.demandMean.map((v) => Math.max(0, v * 7)));
  }

  /** Rolling lost-sales signal; defaults to all zeros. */
  private normalizeLost7d(input?: number[]): number[] {
    return this.normalizeRollingSignal(input, () => new Array<number>(this.actionDim).fill(0));
  }

  /** Rolling waste signal; defaults to all zeros. */
  private normalizeWaste7d(input?: number[]): number[] {
    return this.normalizeRollingSignal(input, () => new Array<number>(this.actionDim).fill(0));
  }

  /** Delegates to `buildInventoryObservation` with the state's inventory, context and the demand windows. */
  private buildObservation(state: InventoryState): InventoryObservation {
    return buildInventoryObservation({
      inventory: state.inventory,
      context: state.context,
      demandWindows: this.config.demandWindows,
    });
  }

  /**
   * Turns a raw action into integer order quantities: each entry is clamped to
   * [0, maxOrderPerSku], and any positive order is raised to the SKU's minimum order.
   */
  private sanitizeAction(action: InventoryAction): number[] {
    const out = new Array<number>(this.actionDim).fill(0);
    for (let i = 0; i < this.actionDim; i++) {
      const proposed = clampInt(action[i] ?? 0, 0, this.config.maxOrderPerSku);
      out[i] = proposed > 0 ? Math.max(proposed, this.config.skus[i]?.minOrder ?? 0) : 0;
    }
    return out;
  }

  /**
   * Shrinks an order so its total fits in the remaining capacity. Orders that
   * already fit are returned unchanged; otherwise every entry is scaled by the
   * same ratio and rounded down (so scaled entries may fall below a SKU's minimum order).
   */
  private capOrderByCapacity(order: number[], capacityLeft: number): number[] {
    const total = sum(order);
    if (total <= capacityLeft) return order;
    if (total <= 0 || capacityLeft <= 0) return order.map(() => 0);
    const ratio = capacityLeft / total;
    return order.map((v) => Math.max(0, Math.floor(v * ratio)));
  }

  /** Appends one new zero-age lot per SKU with a positive order quantity. */
  private applyOrder(order: number[]): void {
    for (let i = 0; i < this.actionDim; i++) {
      const qty = order[i] ?? 0;
      if (qty > 0) {
        this.lotsBySku[i]?.push({ qty, ageHours: 0 });
      }
    }
  }

  /**
   * Determines this step's demand per SKU. An override is used verbatim after
   * clamping to integers; otherwise each SKU's demand is a Poisson draw around
   * its mean scaled by hour, weekend, holiday and weather factors.
   *
   * @throws Error when the override length differs from `actionDim`.
   */
  private resolveDemand(demandOverride?: number[]): number[] {
    if (demandOverride) {
      if (demandOverride.length !== this.actionDim) {
        throw new Error(`demand override length mismatch: expected ${this.actionDim}, got ${demandOverride.length}`);
      }
      return demandOverride.map((v) => clampInt(v, 0, 1_000_000));
    }

    const factor = this.hourDemandFactor(this.state.context.hour);
    const weekendFactor = this.state.context.isWeekend ? 1.2 : 1;
    const holidayFactor = this.state.context.isHoliday ? 1.25 : 1;
    // Full precipitation removes 15% of demand; temperature is scaled relative to 70.
    const precipFactor = 1 - 0.15 * this.state.context.precip;
    const weatherFactor = clamp(precipFactor * (0.9 + 0.1 * (this.state.context.temp / 70)), 0.5, 1.5);
    const totalFactor = factor * weekendFactor * holidayFactor * weatherFactor;

    return this.demandMean.map((base) => poisson(base * totalFactor, this.rng));
  }

  /**
   * Demand multiplier for an hour of the day. Windows are checked in priority
   * order (high-demand, pickup, slow afternoon, evening); hours in no window get 1.
   */
  private hourDemandFactor(hour: number): number {
    if (
      isInHours(hour, this.config.demandWindows.highDemandMorning) ||
      isInHours(hour, this.config.demandWindows.highDemandLunch)
    ) {
      return 2.2;
    }
    if (isInHours(hour, this.config.demandWindows.pickupMorning)) {
      return 1.8;
    }
    if (isInHours(hour, this.config.demandWindows.slowAfternoon)) {
      return 0.5;
    }
    if (isInHours(hour, this.config.demandWindows.evening)) {
      return 0.9;
    }
    return 1;
  }

  /**
   * Sells up to `demandQty` units of one SKU, taking from the front lot first
   * and dropping lots as they empty.
   *
   * @returns The number of units actually sold.
   */
  private consumeFromLots(skuIdx: number, demandQty: number): number {
    const lots = this.lotsBySku[skuIdx] ?? [];
    let remaining = demandQty;
    let sold = 0;
    while (remaining > 0 && lots.length > 0) {
      const lot = lots[0];
      if (!lot || lot.qty <= 0) {
        lots.shift();
        continue;
      }
      const take = Math.min(lot.qty, remaining);
      lot.qty -= take;
      remaining -= take;
      sold += take;
      if (lot.qty <= 0) lots.shift();
    }
    this.lotsBySku[skuIdx] = lots;
    return sold;
  }

  /**
   * Ages every lot by one hour. Lots that reach the SKU's shelf life are
   * discarded whole; surviving lots lose a random fraction of their units
   * (the waste rate scaled by a factor between 0.8 and 1.2, rounded down).
   *
   * @returns Units wasted per SKU during this step.
   */
  private ageAndWasteLots(): number[] {
    const waste = new Array<number>(this.actionDim).fill(0);
    for (let i = 0; i < this.actionDim; i++) {
      const sku = this.config.skus[i];
      if (!sku) continue;
      const nextLots: Lot[] = [];
      for (const lot of this.lotsBySku[i] ?? []) {
        lot.ageHours += 1;
        if (lot.ageHours >= sku.shelfLifeHours) {
          waste[i] += lot.qty;
          continue;
        }
        const spoilage = Math.min(lot.qty, Math.floor(lot.qty * this.config.wasteRatePerHour * (0.8 + this.rng.next() * 0.4)));
        if (spoilage > 0) {
          lot.qty -= spoilage;
          waste[i] += spoilage;
        }
        if (lot.qty > 0) {
          nextLots.push(lot);
        }
      }
      this.lotsBySku[i] = nextLots;
    }
    return waste;
  }

  /** Sums the lot quantities of each SKU. */
  private currentInventoryBySku(): number[] {
    return this.lotsBySku.map((lots) => sum(lots.map((lot) => lot.qty)));
  }

  /** Total units in stock across all SKUs. */
  private totalInventory(): number {
    return sum(this.currentInventoryBySku());
  }

  /** Decays each rolling signal by 6/7 and adds this step's sales, lost sales and waste. */
  private updateRollingSignals(sales: number[], lostSales: number[], waste: number[]): void {
    for (let i = 0; i < this.actionDim; i++) {
      const prevSales = this.state.context.sales7d[i] ?? 0;
      const prevLost = this.state.context.lost7d[i] ?? 0;
      const prevWaste = this.state.context.waste7d[i] ?? 0;
      this.state.context.sales7d[i] = prevSales * (6 / 7) + (sales[i] ?? 0);
      this.state.context.lost7d[i] = prevLost * (6 / 7) + (lostSales[i] ?? 0);
      this.state.context.waste7d[i] = prevWaste * (6 / 7) + (waste[i] ?? 0);
    }
  }

  /** Gross margin: (unit price - unit cost) times units sold, summed over SKUs. */
  private calculateMargin(sales: number[]): number {
    let total = 0;
    for (let i = 0; i < this.actionDim; i++) {
      const sku = this.config.skus[i];
      if (!sku) continue;
      total += (sku.unitPrice - sku.unitCost) * (sales[i] ?? 0);
    }
    return total;
  }

  /** Cost of wasted units, valued at unit cost and summed over SKUs. */
  private calculateWasteCost(waste: number[]): number {
    let total = 0;
    for (let i = 0; i < this.actionDim; i++) {
      const sku = this.config.skus[i];
      if (!sku) continue;
      total += sku.unitCost * (waste[i] ?? 0);
    }
    return total;
  }

  /**
   * Penalty for unmet demand: a linear term in lost units plus a severity term
   * proportional to lost^2 / demand, both scaled by unit price and the stockout multiplier.
   */
  private calculateStockoutPenalty(lostSales: number[], demand: number[]): number {
    let total = 0;
    for (let i = 0; i < this.actionDim; i++) {
      const sku = this.config.skus[i];
      if (!sku) continue;
      const lost = Math.max(0, lostSales[i] ?? 0);
      const requested = Math.max(0, demand[i] ?? 0);
      const basePenalty = sku.unitPrice * this.config.stockoutPenaltyMultiplier * lost;
      // Penalize severe service failures super-linearly: lost^2 / demand.
      const severityPenalty =
        requested > 0 ? sku.unitPrice * this.config.stockoutPenaltyMultiplier * ((lost * lost) / requested) : 0;
      total += basePenalty + severityPenalty;
    }
    return total;
  }

  /**
   * Penalty for serving less than `targetServiceLevel` of demand: the shortfall
   * in units times unit price and the service-shortfall multiplier. Zero when the target is 0.
   */
  private calculateServiceShortfallPenalty(sales: number[], demand: number[]): number {
    const target = this.config.targetServiceLevel;
    if (target <= 0) return 0;
    let total = 0;
    for (let i = 0; i < this.actionDim; i++) {
      const sku = this.config.skus[i];
      if (!sku) continue;
      const requested = Math.max(0, demand[i] ?? 0);
      const served = Math.max(0, sales[i] ?? 0);
      const requiredServed = requested * target;
      const shortfall = Math.max(0, requiredServed - served);
      total += sku.unitPrice * this.config.serviceShortfallPenaltyMultiplier * shortfall;
    }
    return total;
  }

  /**
   * Penalty for holding less stock after ordering than the target share of
   * demand plus the SKU's minimum order; the gap is weighted at 0.4 of the
   * stockout multiplier. Zero when the target is 0.
   */
  private calculateUnderstockRiskPenalty(inventoryAfterOrder: number[], demand: number[]): number {
    const target = Math.max(0, Math.min(1, this.config.targetServiceLevel));
    if (target <= 0) return 0;
    let total = 0;
    for (let i = 0; i < this.actionDim; i++) {
      const sku = this.config.skus[i];
      if (!sku) continue;
      const available = Math.max(0, inventoryAfterOrder[i] ?? 0);
      const requested = Math.max(0, demand[i] ?? 0);
      const desired = requested * target + Math.max(0, sku.minOrder);
      const gap = Math.max(0, desired - available);
      total += sku.unitPrice * this.config.stockoutPenaltyMultiplier * 0.4 * gap;
    }
    return total;
  }

  /**
   * Computes the context for the next hour. Within a day, temperature and
   * precipitation drift slightly while the weekend and holiday flags carry
   * over. When the clock reaches closing time it wraps to the opening hour of
   * the next weekday: the weekend flag is recomputed, the holiday flag is
   * cleared, and new weather is drawn.
   */
  private advanceContext(): EnvironmentContext {
    const current = this.state.context;
    let nextHour = current.hour + 1;
    let nextWeekday = current.weekday;
    // Carry the weekend flag through the day so an explicit reset override is
    // honored until rollover, the same way isHoliday is.
    let nextWeekend = current.isWeekend;
    let nextHoliday = current.isHoliday;
    // The two drift draws always happen, in this order, to keep the random stream stable.
    // The drift is clamped to the same range used everywhere else temperature is set.
    let nextTemp = clamp(current.temp + (this.rng.next() * 2 - 1), -30, 130);
    let nextPrecip = clamp(current.precip + (this.rng.next() * 0.1 - 0.05), 0, 1);
    if (nextHour >= this.config.businessHours.close) {
      nextHour = this.config.businessHours.open;
      nextWeekday = (nextWeekday + 1) % 7;
      nextWeekend = nextWeekday >= 5;
      nextHoliday = false;
      nextTemp = clamp(60 + this.rng.next() * 20, -30, 130);
      nextPrecip = clamp(this.rng.next() < 0.25 ? 1 : 0, 0, 1);
    }
    return {
      ...current,
      hour: nextHour,
      weekday: nextWeekday,
      isWeekend: nextWeekend,
      isHoliday: nextHoliday,
      temp: nextTemp,
      precip: nextPrecip,
      sales7d: current.sales7d.slice(),
      lost7d: current.lost7d.slice(),
      waste7d: current.waste7d.slice(),
    };
  }

  /**
   * Verifies, per SKU, that stock is conserved across the step:
   * start + order = after-order; sold <= after-order; lost = max(0, demand - sold);
   * waste <= remaining after sales; end = after-order - sold - waste.
   *
   * @throws Error describing the first violated invariant.
   */
  private assertStepInvariants(args: {
    inventoryStart: number[];
    inventoryAfterOrder: number[];
    order: number[];
    demand: number[];
    sales: number[];
    lostSales: number[];
    waste: number[];
    inventoryEnd: number[];
  }): void {
    const { inventoryStart, inventoryAfterOrder, order, demand, sales, lostSales, waste, inventoryEnd } = args;
    for (let i = 0; i < this.actionDim; i++) {
      const start = Math.max(0, Math.trunc(inventoryStart[i] ?? 0));
      const afterOrder = Math.max(0, Math.trunc(inventoryAfterOrder[i] ?? 0));
      const ordered = Math.max(0, Math.trunc(order[i] ?? 0));
      const requested = Math.max(0, Math.trunc(demand[i] ?? 0));
      const sold = Math.max(0, Math.trunc(sales[i] ?? 0));
      const lost = Math.max(0, Math.trunc(lostSales[i] ?? 0));
      const wasted = Math.max(0, Math.trunc(waste[i] ?? 0));
      const end = Math.max(0, Math.trunc(inventoryEnd[i] ?? 0));

      const expectedAfterOrder = start + ordered;
      if (afterOrder !== expectedAfterOrder) {
        throw new Error(`inventory flow violation after order for sku ${this.skuIds[i] ?? i}: ${afterOrder} != ${expectedAfterOrder}`);
      }
      if (sold > afterOrder) {
        throw new Error(`sales exceed inventory for sku ${this.skuIds[i] ?? i}: sold=${sold}, available=${afterOrder}`);
      }
      const expectedLost = Math.max(0, requested - sold);
      if (lost !== expectedLost) {
        throw new Error(`lost sales mismatch for sku ${this.skuIds[i] ?? i}: ${lost} != ${expectedLost}`);
      }
      const afterSales = afterOrder - sold;
      if (wasted > afterSales) {
        throw new Error(`waste exceed remaining inventory for sku ${this.skuIds[i] ?? i}: waste=${wasted}, remaining=${afterSales}`);
      }
      const expectedEnd = afterSales - wasted;
      if (end !== expectedEnd) {
        throw new Error(`inventory end mismatch for sku ${this.skuIds[i] ?? i}: ${end} != ${expectedEnd}`);
      }
    }
  }
}

/**
 * Adds up the numbers in `values`, left to right.
 *
 * @param values Numbers to add.
 * @returns Their total, or 0 for an empty array.
 */
function sum(values: number[]): number {
  return values.reduce((running, value) => running + value, 0);
}
