import * as tf from "@tensorflow/tfjs";
import type {
  SerializedInventoryPolicy,
  SerializedMlpPolicy,
  TrainerOptions,
  TrainingEpochMetric,
  TrainingSample,
  TrainingTransition,
} from "@/lib/inventoryRl/types";

/**
 * Clamps `v` into [min, max]. Non-finite inputs (NaN, ±Infinity) collapse to
 * `min` so downstream math never sees a non-finite value.
 */
function clamp(v: number, min: number, max: number): number {
  if (!Number.isFinite(v)) return min;
  const upperBounded = Math.min(max, v);
  return Math.max(min, upperBounded);
}

/**
 * Clamps `v` into [min, max] (non-finite -> min) and truncates toward zero,
 * yielding an integer whenever the bounds are integers.
 */
function clampInt(v: number, min: number, max: number): number {
  const bounded = Number.isFinite(v) ? Math.max(min, Math.min(max, v)) : min;
  return Math.trunc(bounded);
}

/**
 * Baseline hyper-parameters shared by every trainer in this module; callers
 * override individual fields via Partial<TrainerOptions>. `maxOrderPerSku`
 * is the only situational value and is threaded through from the caller.
 */
function defaultTrainerOptions(maxOrderPerSku: number): TrainerOptions {
  return {
    maxOrderPerSku,
    epochs: 25,
    batchSize: 128,
    learningRate: 1e-3,
    l2: 1e-5,
    loss: "huber",
    seed: 42,
    earlyStoppingPatience: 12,
    earlyStoppingMinDelta: 1e-4,
  };
}

/**
 * Returns a uniformly shuffled permutation of [0, length) via Fisher–Yates.
 * Uses unseeded Math.random, so results are not reproducible across runs.
 */
function shuffledIndices(length: number): number[] {
  const order: number[] = [];
  for (let i = 0; i < length; i++) order.push(i);
  for (let i = order.length - 1; i > 0; i--) {
    const j = Math.floor(Math.random() * (i + 1));
    // `??` fallbacks only satisfy noUncheckedIndexedAccess; indices are always present.
    const held = order[i];
    order[i] = order[j] ?? order[i];
    order[j] = held ?? order[j];
  }
  return order;
}

/**
 * Builds the cumulative mass array for a list of sampling priorities.
 * Each entry is floored at 1e-9 so every index keeps strictly positive mass
 * and remains reachable by inverse-CDF sampling.
 */
function buildCdf(priorities: number[]): { cdf: Float64Array; total: number } {
  const cdf = new Float64Array(priorities.length);
  let acc = 0;
  let i = 0;
  while (i < priorities.length) {
    acc += Math.max(1e-9, priorities[i] ?? 0);
    cdf[i] = acc;
    i += 1;
  }
  return { cdf, total: acc };
}

/**
 * Draws `count` indices (with replacement) from the discrete distribution
 * described by a cumulative mass array, using inverse-CDF sampling with a
 * lower-bound binary search. Returns [] when the distribution is empty or
 * carries no mass.
 */
function sampleIndicesFromCdf(cdf: Float64Array, total: number, count: number): number[] {
  const size = cdf.length;
  if (size <= 0 || total <= 0) return [];
  const picks = new Array<number>(count);
  for (let k = 0; k < count; k++) {
    const target = Math.random() * total;
    // lower_bound: first index whose cumulative mass reaches `target`.
    let left = 0;
    let right = size - 1;
    while (left < right) {
      const mid = (left + right) >>> 1;
      if ((cdf[mid] ?? 0) >= target) right = mid;
      else left = mid + 1;
    }
    picks[k] = left;
  }
  return picks;
}

/**
 * Draws `count` indices as a mix of prioritized (inverse-CDF) and uniform
 * samples. `uniformFrac` in [0, 1] is the target fraction of uniform draws;
 * if the CDF carries no mass the uniform fallback fills the whole batch.
 */
function sampleHybridIndices(cdf: Float64Array, total: number, count: number, uniformFrac: number): number[] {
  const size = cdf.length;
  if (size <= 0 || count <= 0) return [];
  const frac = clamp(uniformFrac, 0, 1);
  const prioritizedCount = Math.max(0, Math.min(count, Math.round(count * (1 - frac))));
  const picks = sampleIndicesFromCdf(cdf, total, prioritizedCount);
  while (picks.length < count) {
    picks.push(Math.floor(Math.random() * size));
  }
  // Fisher–Yates shuffle so prioritized and uniform draws are interleaved
  // and no positional bias leaks into mini-batch construction.
  for (let i = picks.length - 1; i > 0; i--) {
    const j = Math.floor(Math.random() * (i + 1));
    const held = picks[i];
    picks[i] = picks[j] ?? picks[i];
    picks[j] = held ?? picks[j];
  }
  return picks;
}

/**
 * Shared factory for the sequential state -> output networks used below.
 * The four public builders were previously four near-identical copies; they
 * differ only in output width and output activation, so that pair is the
 * only thing parameterized here. Hidden layers are ReLU with He-normal
 * initialization; the output layer uses Glorot-uniform; an optional L2
 * regularizer is shared across all layers (it is stateless config).
 */
function buildSequentialNet(args: {
  stateDim: number;
  outputDim: number;
  hidden: number[];
  l2: number;
  outputActivation: "relu" | "linear" | "softmax" | "sigmoid";
}): tf.LayersModel {
  const { stateDim, outputDim, hidden, l2, outputActivation } = args;
  const model = tf.sequential();
  const regularizer = l2 > 0 ? tf.regularizers.l2({ l2 }) : undefined;
  hidden.forEach((units, idx) => {
    model.add(
      tf.layers.dense({
        units,
        // Only the first layer declares the input shape; the rest infer it.
        inputShape: idx === 0 ? [stateDim] : undefined,
        activation: "relu",
        kernelInitializer: "heNormal",
        kernelRegularizer: regularizer,
      })
    );
  });
  model.add(
    tf.layers.dense({
      units: outputDim,
      activation: outputActivation,
      kernelInitializer: "glorotUniform",
      kernelRegularizer: regularizer,
    })
  );
  return model;
}

/** Behavior-cloning regressor: ReLU output keeps per-SKU order quantities non-negative. */
function buildModel(args: {
  stateDim: number;
  actionDim: number;
  hidden: number[];
  l2: number;
}): tf.LayersModel {
  return buildSequentialNet({
    stateDim: args.stateDim,
    outputDim: args.actionDim,
    hidden: args.hidden,
    l2: args.l2,
    outputActivation: "relu",
  });
}

/** Discrete Q-network over the BCQ action bank: one linear Q-value per bank entry. */
function buildQModel(args: { stateDim: number; numActions: number; hidden: number[]; l2: number }): tf.LayersModel {
  return buildSequentialNet({
    stateDim: args.stateDim,
    outputDim: args.numActions,
    hidden: args.hidden,
    l2: args.l2,
    outputActivation: "linear",
  });
}

/** Behavior-cloning classifier over the action bank: softmax action probabilities. */
function buildBehaviorModel(args: { stateDim: number; numActions: number; hidden: number[]; l2: number }): tf.LayersModel {
  return buildSequentialNet({
    stateDim: args.stateDim,
    outputDim: args.numActions,
    hidden: args.hidden,
    l2: args.l2,
    outputActivation: "softmax",
  });
}

/** DDPG actor: sigmoid output in [0, 1]; callers scale it to order units. */
function buildActorModel(args: { stateDim: number; actionDim: number; hidden: number[]; l2: number }): tf.LayersModel {
  return buildSequentialNet({
    stateDim: args.stateDim,
    outputDim: args.actionDim,
    hidden: args.hidden,
    l2: args.l2,
    outputActivation: "sigmoid",
  });
}

/**
 * Builds the DDPG critic Q(s, a): a functional-API model that concatenates
 * the state and action inputs, passes them through ReLU hidden layers, and
 * emits a single linear Q-value.
 */
function buildCriticModel(args: { stateDim: number; actionDim: number; hidden: number[]; l2: number }): tf.LayersModel {
  const { stateDim, actionDim, hidden, l2 } = args;
  const regularizer = l2 > 0 ? tf.regularizers.l2({ l2 }) : undefined;
  // Input names are part of the model's serialized structure; keep them stable.
  const stateInput = tf.input({ shape: [stateDim], name: "state" });
  const actionInput = tf.input({ shape: [actionDim], name: "action" });
  let features = tf.layers.concatenate().apply([stateInput, actionInput]) as tf.SymbolicTensor;
  for (const width of hidden) {
    const layer = tf.layers.dense({
      units: width,
      activation: "relu",
      kernelInitializer: "heNormal",
      kernelRegularizer: regularizer,
    });
    features = layer.apply(features) as tf.SymbolicTensor;
  }
  const qHead = tf.layers.dense({
    units: 1,
    activation: "linear",
    kernelInitializer: "glorotUniform",
    kernelRegularizer: regularizer,
  });
  const qValue = qHead.apply(features) as tf.SymbolicTensor;
  return tf.model({ inputs: [stateInput, actionInput], outputs: qValue });
}

/**
 * Sum of absolute differences over the overlapping prefix of `a` and `b`.
 * Trailing entries of the longer vector are ignored.
 */
function l1Distance(a: number[], b: number[]): number {
  const overlap = Math.min(a.length, b.length);
  let sum = 0;
  for (let i = 0; i < overlap; i++) {
    sum += Math.abs((a[i] ?? 0) - (b[i] ?? 0));
  }
  return sum;
}

/**
 * Returns the index of the bank entry closest to `action` under L1 distance
 * (comparing only the overlapping prefix of each pair). Ties keep the
 * earliest entry; an empty bank yields 0.
 */
function nearestActionIndex(action: number[], bank: number[][]): number {
  let winner = 0;
  let winnerDist = Number.POSITIVE_INFINITY;
  for (let i = 0; i < bank.length; i++) {
    const row = bank[i] ?? [];
    const overlap = Math.min(action.length, row.length);
    let dist = 0;
    for (let k = 0; k < overlap; k++) {
      dist += Math.abs((action[k] ?? 0) - (row[k] ?? 0));
    }
    if (dist < winnerDist) {
      winnerDist = dist;
      winner = i;
    }
  }
  return winner;
}

/**
 * Derives a discrete action bank for BCQ from observed transitions.
 * Actions are clamped to integers in [0, maxOrderPerSku] and de-duplicated
 * with frequency counts. If few enough distinct actions exist they are all
 * kept (zero-vector prepended when absent); otherwise a greedy farthest-point
 * pass — seeded with the zero action — picks diverse, frequent actions.
 */
function buildActionBank(args: {
  transitions: TrainingTransition[];
  actionDim: number;
  maxOrderPerSku: number;
  maxActions: number;
}): number[][] {
  const { transitions, actionDim, maxOrderPerSku, maxActions } = args;

  // Histogram of distinct clamped action vectors, keyed by their CSV form.
  const histogram = new Map<string, { action: number[]; count: number }>();
  for (const transition of transitions) {
    const clamped = new Array<number>(actionDim).fill(0);
    for (let j = 0; j < actionDim; j++) {
      clamped[j] = clampInt(transition.action[j] ?? 0, 0, maxOrderPerSku);
    }
    const key = clamped.join(",");
    const existing = histogram.get(key);
    if (existing) existing.count += 1;
    else histogram.set(key, { action: clamped, count: 1 });
  }

  const zeroAction = new Array<number>(actionDim).fill(0);
  // Stable sort by descending frequency (ties keep insertion order).
  const ranked = [...histogram.values()].sort((a, b) => b.count - a.count);
  if (ranked.length === 0) return [zeroAction];
  if (ranked.length <= maxActions) {
    const bank = ranked.map((entry) => entry.action);
    const hasZero = bank.some((row) => row.every((v) => v === 0));
    if (!hasZero) bank.unshift(zeroAction);
    return bank.slice(0, Math.max(1, maxActions));
  }

  // Greedy farthest-point selection: score = min L1 distance to anything
  // already kept, plus a log bonus for how often the action occurred.
  const selected: number[][] = [zeroAction];
  const remaining = ranked.slice();
  while (selected.length < maxActions && remaining.length > 0) {
    let pickAt = 0;
    let pickScore = Number.NEGATIVE_INFINITY;
    for (let i = 0; i < remaining.length; i++) {
      const entry = remaining[i];
      if (!entry) continue;
      let nearest = Number.POSITIVE_INFINITY;
      for (const kept of selected) {
        nearest = Math.min(nearest, l1Distance(entry.action, kept));
      }
      const score = nearest + Math.log1p(entry.count);
      if (score > pickScore) {
        pickScore = score;
        pickAt = i;
      }
    }
    const chosen = remaining.splice(pickAt, 1)[0];
    if (chosen) selected.push(chosen.action);
  }
  return selected;
}

/**
 * Hard-copies `model`'s weights into `target`. Works on cloned snapshots so
 * no live layer variable is ever handed to dispose(); setWeights copies the
 * values, after which the snapshots are released.
 */
function cloneWeightsTo(model: tf.LayersModel, target: tf.LayersModel): void {
  const snapshot = model.getWeights().map((tensor) => tensor.clone());
  target.setWeights(snapshot);
  for (const tensor of snapshot) tensor.dispose();
}

/**
 * Polyak (soft) target-network update: target ← (1 − tau)·target + tau·model.
 * tau is clamped to (0, 1]; tau = 1 degenerates to a hard copy.
 */
function softUpdateTo(model: tf.LayersModel, target: tf.LayersModel, tau: number): void {
  const clampedTau = clamp(tau, 1e-6, 1);
  // Use cloned snapshots so we never dispose live layer variables.
  const source = model.getWeights().map((w) => w.clone());
  const current = target.getWeights().map((w) => w.clone());
  const mixed = source.map((w, i) => {
    const t = current[i];
    // Weight lists should be parallel; if an index is somehow missing on the
    // target side, fall back to a hard copy of the source weight.
    if (!t) return w.clone();
    // tidy disposes the intermediate scalars/products but keeps the returned tensor.
    return tf.tidy(() => t.mul(tf.scalar(1 - clampedTau)).add(w.mul(tf.scalar(clampedTau))));
  });
  // setWeights copies values into the target's variables, so all three
  // snapshot lists can be released afterwards.
  target.setWeights(mixed);
  source.forEach((w) => w.dispose());
  current.forEach((w) => w.dispose());
  mixed.forEach((w) => w.dispose());
}

/**
 * Unwraps the raw tf.Variable behind each trainable LayersVariable so the
 * list can be passed to optimizer.minimize(..., varList).
 * NOTE(review): relies on the private `.val` field of LayersVariable — an
 * internal tfjs detail; verify it still exists when upgrading tfjs.
 */
function trainableVars(model: tf.LayersModel): tf.Variable[] {
  const vars: tf.Variable[] = [];
  for (const weight of model.trainableWeights) {
    vars.push((weight as unknown as { val: tf.Variable }).val);
  }
  return vars;
}

/** Constructor args for the default behavior-cloning MLP backend (no `mode` tag). */
type MlpCtorArgs = {
  stateDim: number;
  actionDim: number;
  maxOrderPerSku: number;
  skuIds: string[];
  hidden?: number[]; // hidden-layer widths; defaults to [256, 256, 128]
  model?: tf.LayersModel; // pre-built network, e.g. when deserializing
};

/** Constructor args for the discrete batch-constrained Q-learning (BCQ) backend. */
type BcqCtorArgs = {
  mode: "bcq";
  stateDim: number;
  actionDim: number;
  maxOrderPerSku: number;
  skuIds: string[];
  qModel: tf.LayersModel; // Q-values, one per action-bank entry
  behaviorModel: tf.LayersModel; // softmax behavior probabilities over the bank
  actionBank: number[][]; // discrete candidate actions (per-SKU order vectors)
  behaviorThreshold: number; // min behavior prob for an action to be Q-eligible
  qHidden: number[];
  behaviorHidden: number[];
};

/** Constructor args for the continuous DDPG-actor backend. */
type DdpgCtorArgs = {
  mode: "ddpg";
  stateDim: number;
  actionDim: number;
  maxOrderPerSku: number;
  skuIds: string[];
  actorModel: tf.LayersModel; // sigmoid actor; output scaled by maxOrderPerSku
  actorHidden: number[];
};

/**
 * Inference-side wrapper around one of three offline-RL policy backends:
 *  - "mlp":  behavior-cloning regressor mapping state -> per-SKU order vector
 *  - "bcq":  discrete batch-constrained Q-learning over a fixed action bank
 *  - "ddpg": continuous actor whose sigmoid output is scaled to order units
 * All modes expose the same predict() / toJSON() / fromJSON() surface.
 */
export class MlpInventoryPolicy {
  readonly stateDim: number;
  readonly actionDim: number;
  readonly maxOrderPerSku: number;
  readonly skuIds: string[];
  readonly hidden: number[];
  readonly mode: "mlp" | "bcq" | "ddpg";

  // Exactly one backend's models are populated according to `mode`; the
  // remaining fields hold undefined / empty placeholder values.
  private readonly model?: tf.LayersModel;
  private readonly actorModel?: tf.LayersModel;
  private readonly qModel?: tf.LayersModel;
  private readonly behaviorModel?: tf.LayersModel;
  private readonly actionBank: number[][];
  private readonly behaviorThreshold: number;
  private readonly qHidden: number[];
  private readonly behaviorHidden: number[];
  private readonly actorHidden: number[];

  // Dispatches on the `mode` discriminant ("bcq" / "ddpg" / absent -> "mlp").
  // Arrays are defensively copied so later caller mutation cannot leak in.
  constructor(args: MlpCtorArgs | BcqCtorArgs | DdpgCtorArgs) {
    this.stateDim = args.stateDim;
    this.actionDim = args.actionDim;
    this.maxOrderPerSku = args.maxOrderPerSku;
    this.skuIds = args.skuIds.slice();

    if ((args as BcqCtorArgs).mode === "bcq") {
      const bcq = args as BcqCtorArgs;
      this.mode = "bcq";
      this.hidden = [];
      this.model = undefined;
      this.actorModel = undefined;
      this.qModel = bcq.qModel;
      this.behaviorModel = bcq.behaviorModel;
      this.actionBank = bcq.actionBank.map((a) => a.slice());
      this.behaviorThreshold = clamp(bcq.behaviorThreshold, 0, 1);
      this.qHidden = bcq.qHidden.slice();
      this.behaviorHidden = bcq.behaviorHidden.slice();
      this.actorHidden = [];
      return;
    }

    if ((args as DdpgCtorArgs).mode === "ddpg") {
      const ddpg = args as DdpgCtorArgs;
      this.mode = "ddpg";
      this.hidden = [];
      this.model = undefined;
      this.actorModel = ddpg.actorModel;
      this.qModel = undefined;
      this.behaviorModel = undefined;
      this.actionBank = [];
      this.behaviorThreshold = 0;
      this.qHidden = [];
      this.behaviorHidden = [];
      this.actorHidden = ddpg.actorHidden.slice();
      return;
    }

    // Default: plain MLP regressor; builds an unregularized network when no
    // pre-built model is supplied.
    const mlp = args as MlpCtorArgs;
    this.mode = "mlp";
    this.hidden = (mlp.hidden ?? [256, 256, 128]).slice();
    this.model =
      mlp.model ??
      buildModel({
        stateDim: this.stateDim,
        actionDim: this.actionDim,
        hidden: this.hidden,
        l2: 0,
      });
    this.actorModel = undefined;
    this.qModel = undefined;
    this.behaviorModel = undefined;
    this.actionBank = [];
    this.behaviorThreshold = 0;
    this.qHidden = [];
    this.behaviorHidden = [];
    this.actorHidden = [];
  }

  /**
   * Maps a state vector to a per-SKU order-quantity vector in
   * [0, maxOrderPerSku]. Throws when the state length does not match
   * stateDim or the mode's models are missing. All tensor work runs inside
   * tf.tidy, so no GPU/heap tensors leak per call.
   */
  predict(state: number[]): number[] {
    if (state.length !== this.stateDim) {
      throw new Error(`state dim mismatch: expected ${this.stateDim}, got ${state.length}`);
    }

    if (this.mode === "bcq") {
      if (!this.qModel || !this.behaviorModel) {
        throw new Error("BCQ policy models are missing");
      }
      const qModel = this.qModel;
      const behaviorModel = this.behaviorModel;
      if (this.actionBank.length === 0) {
        return new Array<number>(this.actionDim).fill(0);
      }
      return tf.tidy(() => {
        const x = tf.tensor2d(state, [1, this.stateDim], "float32");
        const q = qModel.predict(x) as tf.Tensor2D;
        const p = behaviorModel.predict(x) as tf.Tensor2D;
        const qValues = Array.from(q.dataSync());
        const pValues = Array.from(p.dataSync());

        // Behavior argmax: always an admissible fallback action.
        let bestBehaviorIdx = 0;
        let bestBehaviorProb = Number.NEGATIVE_INFINITY;
        for (let i = 0; i < pValues.length; i++) {
          const prob = pValues[i] ?? 0;
          if (prob > bestBehaviorProb) {
            bestBehaviorProb = prob;
            bestBehaviorIdx = i;
          }
        }

        // Batch-constrained argmax: maximize Q only over actions whose
        // behavior probability clears the threshold (the behavior argmax is
        // always eligible, so the candidate set is never empty).
        let bestIdx = bestBehaviorIdx;
        let bestQ = Number.NEGATIVE_INFINITY;
        for (let i = 0; i < qValues.length; i++) {
          const prob = pValues[i] ?? 0;
          const allowed = prob >= this.behaviorThreshold || i === bestBehaviorIdx;
          if (!allowed) continue;
          const qVal = qValues[i] ?? Number.NEGATIVE_INFINITY;
          if (qVal > bestQ) {
            bestQ = qVal;
            bestIdx = i;
          }
        }

        const picked = this.actionBank[bestIdx] ?? this.actionBank[bestBehaviorIdx] ?? new Array<number>(this.actionDim).fill(0);
        return picked.map((v) => clamp(v, 0, this.maxOrderPerSku));
      });
    }

    if (this.mode === "ddpg") {
      if (!this.actorModel) {
        throw new Error("DDPG actor model is missing");
      }
      const actorModel = this.actorModel;
      return tf.tidy(() => {
        const x = tf.tensor2d(state, [1, this.stateDim], "float32");
        // Actor output is sigmoid-bounded in [0, 1]; the clip is defensive.
        // Scale to order units.
        const y01 = actorModel.predict(x) as tf.Tensor;
        const y = tf.clipByValue(y01, 0, 1).mul(tf.scalar(this.maxOrderPerSku));
        return Array.from(y.dataSync()).map((v) => clamp(v, 0, this.maxOrderPerSku));
      });
    }

    if (!this.model) {
      throw new Error("MLP policy model is missing");
    }
    const model = this.model;
    return tf.tidy(() => {
      const x = tf.tensor2d(state, [1, this.stateDim], "float32");
      const y = model.predict(x) as tf.Tensor;
      const clipped = tf.clipByValue(y, 0, this.maxOrderPerSku);
      return Array.from(clipped.dataSync());
    });
  }

  /**
   * Serializes the policy (mode-tagged via `modelType`) with raw weight
   * arrays, so it can be persisted as JSON and restored by fromJSON().
   */
  toJSON(): SerializedInventoryPolicy {
    if (this.mode === "bcq") {
      if (!this.qModel || !this.behaviorModel) {
        throw new Error("BCQ policy models are missing for serialization");
      }
      const qWeights = this.qModel.getWeights().map((tensor) => ({
        shape: tensor.shape.slice(),
        data: Array.from(tensor.dataSync()),
      }));
      const behaviorWeights = this.behaviorModel.getWeights().map((tensor) => ({
        shape: tensor.shape.slice(),
        data: Array.from(tensor.dataSync()),
      }));
      return {
        modelType: "bcq_tfjs",
        stateDim: this.stateDim,
        actionDim: this.actionDim,
        maxOrderPerSku: this.maxOrderPerSku,
        skuIds: this.skuIds.slice(),
        behaviorThreshold: this.behaviorThreshold,
        actionBank: this.actionBank.map((a) => a.slice()),
        qHidden: this.qHidden.slice(),
        behaviorHidden: this.behaviorHidden.slice(),
        qWeights,
        behaviorWeights,
        createdAt: new Date().toISOString(),
      };
    }

    if (this.mode === "ddpg") {
      if (!this.actorModel) {
        throw new Error("DDPG actor model is missing for serialization");
      }
      const actorWeights = this.actorModel.getWeights().map((tensor) => ({
        shape: tensor.shape.slice(),
        data: Array.from(tensor.dataSync()),
      }));
      return {
        modelType: "ddpg_tfjs",
        stateDim: this.stateDim,
        actionDim: this.actionDim,
        maxOrderPerSku: this.maxOrderPerSku,
        skuIds: this.skuIds.slice(),
        actorHidden: this.actorHidden.slice(),
        actorWeights,
        createdAt: new Date().toISOString(),
      };
    }

    if (!this.model) {
      throw new Error("MLP policy model is missing for serialization");
    }
    const weights = this.model.getWeights().map((tensor) => ({
      shape: tensor.shape.slice(),
      data: Array.from(tensor.dataSync()),
    }));
    return {
      modelType: "mlp_tfjs",
      stateDim: this.stateDim,
      actionDim: this.actionDim,
      maxOrderPerSku: this.maxOrderPerSku,
      hidden: this.hidden.slice(),
      weights,
      skuIds: this.skuIds.slice(),
      createdAt: new Date().toISOString(),
    };
  }

  /**
   * Rebuilds a policy from its toJSON() output: reconstructs each network
   * from the recorded layer widths, then loads the saved weights. Weight
   * tensors are disposed after setWeights copies their values.
   */
  static fromJSON(json: SerializedInventoryPolicy): MlpInventoryPolicy {
    if (json.modelType === "bcq_tfjs") {
      const qModel = buildQModel({
        stateDim: json.stateDim,
        numActions: json.actionBank.length,
        hidden: json.qHidden,
        l2: 0,
      });
      const qTensors = json.qWeights.map((w) => tf.tensor(w.data, w.shape, "float32"));
      qModel.setWeights(qTensors);
      qTensors.forEach((tensor) => tensor.dispose());

      const behaviorModel = buildBehaviorModel({
        stateDim: json.stateDim,
        numActions: json.actionBank.length,
        hidden: json.behaviorHidden,
        l2: 0,
      });
      const behaviorTensors = json.behaviorWeights.map((w) => tf.tensor(w.data, w.shape, "float32"));
      behaviorModel.setWeights(behaviorTensors);
      behaviorTensors.forEach((tensor) => tensor.dispose());

      return new MlpInventoryPolicy({
        mode: "bcq",
        stateDim: json.stateDim,
        actionDim: json.actionDim,
        maxOrderPerSku: json.maxOrderPerSku,
        skuIds: json.skuIds,
        qModel,
        behaviorModel,
        actionBank: json.actionBank,
        behaviorThreshold: json.behaviorThreshold,
        qHidden: json.qHidden,
        behaviorHidden: json.behaviorHidden,
      });
    }

    if (json.modelType === "ddpg_tfjs") {
      const actorModel = buildActorModel({
        stateDim: json.stateDim,
        actionDim: json.actionDim,
        hidden: json.actorHidden,
        l2: 0,
      });
      const actorTensors = json.actorWeights.map((w) => tf.tensor(w.data, w.shape, "float32"));
      actorModel.setWeights(actorTensors);
      actorTensors.forEach((tensor) => tensor.dispose());
      return new MlpInventoryPolicy({
        mode: "ddpg",
        stateDim: json.stateDim,
        actionDim: json.actionDim,
        maxOrderPerSku: json.maxOrderPerSku,
        skuIds: json.skuIds,
        actorModel,
        actorHidden: json.actorHidden,
      });
    }

    const model = buildModel({
      stateDim: json.stateDim,
      actionDim: json.actionDim,
      hidden: json.hidden,
      l2: 0,
    });
    const tensors = json.weights.map((w) => tf.tensor(w.data, w.shape, "float32"));
    model.setWeights(tensors);
    tensors.forEach((tensor) => tensor.dispose());
    return new MlpInventoryPolicy({
      stateDim: json.stateDim,
      actionDim: json.actionDim,
      maxOrderPerSku: json.maxOrderPerSku,
      skuIds: json.skuIds,
      hidden: json.hidden,
      model,
    });
  }

  /**
   * Trainer-only accessor for the underlying MLP network; throws for the
   * BCQ and DDPG backends, which have no single regression model.
   */
  _model(): tf.LayersModel {
    if (!this.model) {
      throw new Error("Policy is non-MLP-backed and does not expose a single MLP model");
    }
    return this.model;
  }
}

/**
 * Trains a behavior-cloning MLP policy on weighted (state, action) samples.
 *
 * Minimizes a per-sample-weighted regression loss — Huber with delta = 1 by
 * default, plain MSE when opts.loss === "mse" — using Adam, fresh mini-batch
 * shuffling each epoch, and best-loss early stopping. Returns the trained
 * policy together with per-epoch loss metrics (phase "bc").
 *
 * NOTE(review): opts.seed is accepted (default 42) but never used here —
 * shuffling relies on unseeded Math.random, so runs are not reproducible.
 * Confirm whether seeding was intended.
 *
 * @throws when `samples` is empty or the first sample has zero-length
 *         state/action vectors.
 */
export async function trainOfflineMlpPolicy(args: {
  samples: TrainingSample[];
  skuIds: string[];
  options?: Partial<TrainerOptions>;
  hidden?: number[];
}): Promise<{ policy: MlpInventoryPolicy; metrics: TrainingEpochMetric[] }> {
  const { samples, skuIds } = args;
  if (samples.length === 0) {
    throw new Error("no training samples available");
  }

  // Dimensions are inferred from the first sample; later samples are assumed
  // to match (out-of-range reads below fall back to 0 via `?? 0`).
  const stateDim = samples[0]?.state.length ?? 0;
  const actionDim = samples[0]?.action.length ?? 0;
  if (stateDim === 0 || actionDim === 0) {
    throw new Error("training sample dimensions are invalid");
  }

  // All hyper-parameters are clamped into safe ranges before use.
  const opts = { ...defaultTrainerOptions(120), ...args.options };
  const epochs = clampInt(opts.epochs, 1, 1000);
  const batchSize = clampInt(opts.batchSize, 8, 8192);
  const learningRate = clamp(opts.learningRate, 1e-6, 1);
  const l2 = clamp(opts.l2, 0, 10);
  const earlyStoppingPatience = clampInt(opts.earlyStoppingPatience ?? 12, 0, 200);
  const earlyStoppingMinDelta = clamp(opts.earlyStoppingMinDelta ?? 1e-4, 0, 10);
  const hidden = (args.hidden ?? [256, 256, 128]).map((h) => clampInt(h, 8, 2048));

  const policy = new MlpInventoryPolicy({
    stateDim,
    actionDim,
    maxOrderPerSku: opts.maxOrderPerSku,
    skuIds,
    hidden,
    model: buildModel({ stateDim, actionDim, hidden, l2 }),
  });
  const model = policy._model();
  const optimizer = tf.train.adam(learningRate);

  // Flatten all samples into typed arrays once; batches are later gathered
  // by index from the resulting tensors.
  const n = samples.length;
  const xValues = new Float32Array(n * stateDim);
  const yValues = new Float32Array(n * actionDim);
  const wValues = new Float32Array(n);
  for (let i = 0; i < n; i++) {
    const sample = samples[i];
    if (!sample) continue;
    for (let j = 0; j < stateDim; j++) xValues[i * stateDim + j] = sample.state[j] ?? 0;
    for (let j = 0; j < actionDim; j++) yValues[i * actionDim + j] = sample.action[j] ?? 0;
    // Sample weights are bounded to [0.1, 100] to keep the loss well-scaled.
    wValues[i] = clamp(sample.weight, 0.1, 100);
  }

  const x = tf.tensor2d(xValues, [n, stateDim], "float32");
  const y = tf.tensor2d(yValues, [n, actionDim], "float32");
  const w = tf.tensor1d(wValues, "float32");

  const metrics: TrainingEpochMetric[] = [];
  let bestLoss = Number.POSITIVE_INFINITY;
  let staleEpochs = 0;
  try {
    for (let epoch = 0; epoch < epochs; epoch++) {
      // Fresh shuffle each epoch; batches are contiguous slices of the permutation.
      const order = shuffledIndices(n);
      let epochLoss = 0;
      let batchCount = 0;
      for (let start = 0; start < n; start += batchSize) {
        const end = Math.min(start + batchSize, n);
        const batchIndices = order.slice(start, end);
        // tidy scopes the whole step so gathered batches and loss intermediates
        // are all released; only the numeric loss value escapes.
        const batchLoss = tf.tidy(() => {
          const idx = tf.tensor1d(batchIndices, "int32");
          const xBatch = tf.gather(x, idx) as tf.Tensor2D;
          const yBatch = tf.gather(y, idx) as tf.Tensor2D;
          const wBatch = tf.gather(w, idx) as tf.Tensor1D;

          const lossTensor = optimizer.minimize(() => {
            const yPred = model.apply(xBatch, { training: true }) as tf.Tensor2D;
            const diff = yPred.sub(yBatch);
            // Huber with delta = 1: 0.5·q² + (|d| − q), where q = min(|d|, 1) —
            // quadratic near zero, linear in the tails.
            const perActionLoss =
              opts.loss === "mse"
                ? diff.square()
                : (() => {
                    const abs = diff.abs();
                    const one = tf.scalar(1, "float32");
                    const quadratic = abs.minimum(one);
                    const linear = abs.sub(quadratic);
                    return quadratic.square().mul(0.5).add(linear);
                  })();
            // Mean over actions, then weighted mean over the batch.
            const perSample = perActionLoss.mean(1);
            return perSample.mul(wBatch).mean();
          }, true);

          if (!lossTensor) {
            throw new Error("optimizer failed to compute loss");
          }

          const lossValue = Number(lossTensor.dataSync()[0] ?? 0);
          lossTensor.dispose();
          return lossValue;
        });
        epochLoss += batchLoss;
        batchCount += 1;
      }
      metrics.push({
        epoch: epoch + 1,
        loss: epochLoss / Math.max(1, batchCount),
        phase: "bc",
      });
      // Early stopping on best mean epoch loss with a minimum-improvement delta.
      const latestLoss = metrics[metrics.length - 1]?.loss ?? Number.POSITIVE_INFINITY;
      if (latestLoss < bestLoss - earlyStoppingMinDelta) {
        bestLoss = latestLoss;
        staleEpochs = 0;
      } else {
        staleEpochs += 1;
      }
      if (earlyStoppingPatience > 0 && staleEpochs >= earlyStoppingPatience) {
        break;
      }
      // Yield to the event loop / UI between epochs.
      await tf.nextFrame();
    }
  } finally {
    // The policy keeps its model; only the training-time tensors and
    // optimizer state are released here.
    x.dispose();
    y.dispose();
    w.dispose();
    optimizer.dispose();
  }

  return { policy, metrics };
}

export async function trainOfflineDdpgPolicy(args: {
  samples: TrainingSample[];
  transitions: TrainingTransition[];
  skuIds: string[];
  options?: Partial<
    TrainerOptions & {
      gamma: number;
      tau: number;
      actorLearningRate: number;
      criticLearningRate: number;
      actorBcWeight: number;
      priorityAlpha: number;
      uniformSampleFrac: number;
      actorUpdateEvery: number;
      behaviorPretrainEpochs: number;
      targetPolicyNoiseRatio: number;
      targetPolicyNoiseClipRatio: number;
      targetUpdateEvery: number;
      criticEarlyStoppingPatience: number;
      criticEarlyStoppingMinDelta: number;
    }
  >;
  hidden?: number[];
}): Promise<{ policy: MlpInventoryPolicy; metrics: TrainingEpochMetric[] }> {
  const { samples, transitions, skuIds } = args;
  if (samples.length === 0) {
    throw new Error("no training samples available");
  }
  const stateDim = samples[0]?.state.length ?? 0;
  const actionDim = samples[0]?.action.length ?? 0;
  if (stateDim === 0 || actionDim === 0) {
    throw new Error("training sample dimensions are invalid");
  }

  const opts = { ...defaultTrainerOptions(120), ...args.options };
  const validTransitions = transitions.filter(
    (t) => t.state.length === stateDim && t.nextState.length === stateDim && t.action.length === actionDim
  );
  if (validTransitions.length === 0) {
    throw new Error("DDPG requires transition data; no valid transitions were found");
  }

  const epochs = clampInt(opts.epochs, 1, 500);
  const batchSize = clampInt(opts.batchSize, 8, 8192);
  const learningRate = clamp(opts.learningRate, 1e-6, 1);
  const l2 = clamp(opts.l2, 0, 10);
  const gamma = clamp(Number(args.options?.gamma ?? 0.99), 0, 0.999);
  const tau = clamp(Number(args.options?.tau ?? 0.02), 1e-6, 1);
  const actorLearningRate = clamp(Number(args.options?.actorLearningRate ?? learningRate * 0.5), 1e-6, 1);
  const criticLearningRate = clamp(Number(args.options?.criticLearningRate ?? learningRate), 1e-6, 1);
  const actorBcWeight = clamp(Number(args.options?.actorBcWeight ?? 0.2), 0, 100);
  const priorityAlpha = clamp(Number(args.options?.priorityAlpha ?? 0.7), 0, 1);
  const uniformSampleFrac = clamp(Number(args.options?.uniformSampleFrac ?? 0.3), 0, 1);
  const actorUpdateEvery = clampInt(Number(args.options?.actorUpdateEvery ?? 2), 1, 32);
  const behaviorPretrainEpochs = clampInt(
    Number(args.options?.behaviorPretrainEpochs ?? Math.max(1, Math.round(epochs * 0.1))),
    0,
    200
  );
  const targetPolicyNoiseRatio = clamp(Number(args.options?.targetPolicyNoiseRatio ?? 0.04), 0, 1);
  const targetPolicyNoiseClipRatio = clamp(Number(args.options?.targetPolicyNoiseClipRatio ?? 0.08), 0, 1);
  const targetUpdateEvery = clampInt(Number(args.options?.targetUpdateEvery ?? 1), 1, 50);

  const earlyStoppingPatience = clampInt(opts.earlyStoppingPatience ?? 12, 0, 200);
  const earlyStoppingMinDelta = clamp(opts.earlyStoppingMinDelta ?? 1e-4, 0, 10);
  const criticEarlyStoppingPatience = clampInt(
    Number(args.options?.criticEarlyStoppingPatience ?? earlyStoppingPatience),
    0,
    200
  );
  const criticEarlyStoppingMinDelta = clamp(
    Number(args.options?.criticEarlyStoppingMinDelta ?? earlyStoppingMinDelta),
    0,
    10
  );

  const hidden = (args.hidden ?? [256, 256, 128]).map((h) => clampInt(h, 8, 2048));
  const actorHidden = [Math.max(128, hidden[0] ?? 256), Math.max(96, hidden[1] ?? 192)];
  const criticHidden = [Math.max(192, hidden[0] ?? 256), Math.max(128, hidden[1] ?? 192)];

  const n = validTransitions.length;
  const sValues = new Float32Array(n * stateDim);
  const nsValues = new Float32Array(n * stateDim);
  const aValues = new Float32Array(n * actionDim);
  const rValues = new Float32Array(n);
  const dValues = new Float32Array(n);
  const wValues = new Float32Array(n);
  for (let i = 0; i < n; i++) {
    const t = validTransitions[i];
    if (!t) continue;
    for (let j = 0; j < stateDim; j++) {
      sValues[i * stateDim + j] = t.state[j] ?? 0;
      nsValues[i * stateDim + j] = t.nextState[j] ?? 0;
    }
    for (let j = 0; j < actionDim; j++) {
      aValues[i * actionDim + j] = clamp(t.action[j] ?? 0, 0, opts.maxOrderPerSku);
    }
    rValues[i] = t.reward;
    dValues[i] = t.done ? 1 : 0;
    wValues[i] = clamp(t.weight, 0.25, 8);
  }

  let rewardMean = 0;
  for (let i = 0; i < n; i++) rewardMean += rValues[i] ?? 0;
  rewardMean /= Math.max(1, n);
  let rewardVar = 0;
  for (let i = 0; i < n; i++) {
    const dR = (rValues[i] ?? 0) - rewardMean;
    rewardVar += dR * dR;
  }
  rewardVar /= Math.max(1, n);
  const rewardStd = Math.sqrt(Math.max(1e-6, rewardVar));

  const priorities = new Array<number>(n);
  for (let i = 0; i < n; i++) {
    const rewardMag = Math.abs(rValues[i] ?? 0);
    const base = Math.max(1e-6, (wValues[i] ?? 1) * (1 + rewardMag));
    priorities[i] = Math.pow(base, priorityAlpha);
  }
  const { cdf: priorityCdf, total: priorityTotal } = buildCdf(priorities);

  const s = tf.tensor2d(sValues, [n, stateDim], "float32");
  const ns = tf.tensor2d(nsValues, [n, stateDim], "float32");
  const a = tf.tensor2d(aValues, [n, actionDim], "float32");
  const r = tf.tensor1d(rValues, "float32");
  const d = tf.tensor1d(dValues, "float32");
  const w = tf.tensor1d(wValues, "float32");

  const actor = buildActorModel({ stateDim, actionDim, hidden: actorHidden, l2 });
  const critic = buildCriticModel({ stateDim, actionDim, hidden: criticHidden, l2 });
  const actorTarget = buildActorModel({ stateDim, actionDim, hidden: actorHidden, l2: 0 });
  const criticTarget = buildCriticModel({ stateDim, actionDim, hidden: criticHidden, l2: 0 });
  cloneWeightsTo(actor, actorTarget);
  cloneWeightsTo(critic, criticTarget);

  const actorVars = trainableVars(actor);
  const criticVars = trainableVars(critic);
  const actorOpt = tf.train.adam(actorLearningRate);
  const criticOpt = tf.train.adam(criticLearningRate);

  const metrics: TrainingEpochMetric[] = [];
  let bestCriticLoss = Number.POSITIVE_INFINITY;
  let staleCriticEpochs = 0;
  let updateStep = 0;

  try {
    for (let preEpoch = 0; preEpoch < behaviorPretrainEpochs; preEpoch++) {
      const preBatches = Math.max(1, Math.ceil(n / batchSize));
      let preLossSum = 0;
      for (let b = 0; b < preBatches; b++) {
        const size = Math.min(batchSize, n);
        const batchIndices = sampleHybridIndices(priorityCdf, priorityTotal, size, uniformSampleFrac);
        const preLoss = tf.tidy(() => {
          const idx = tf.tensor1d(batchIndices, "int32");
          const sBatch = tf.gather(s, idx) as tf.Tensor2D;
          const aBatch = tf.gather(a, idx) as tf.Tensor2D;
          const wBatch = tf.gather(w, idx) as tf.Tensor1D;
          const preLossTensor = actorOpt.minimize(() => {
            const predA01 = actor.apply(sBatch, { training: true }) as tf.Tensor2D;
            const predA = tf.clipByValue(predA01, 0, 1).mul(tf.scalar(opts.maxOrderPerSku));
            const diff = predA.sub(aBatch).div(tf.scalar(Math.max(1, opts.maxOrderPerSku)));
            const bcPer = diff.square().mean(1);
            return bcPer.mul(wBatch).mean();
          }, true, actorVars);
          if (!preLossTensor) throw new Error("optimizer failed for DDPG actor behavior pretraining");
          const value = Number(preLossTensor.dataSync()[0] ?? 0);
          preLossTensor.dispose();
          return value;
        });
        preLossSum += preLoss;
      }
      metrics.push({
        epoch: preEpoch + 1,
        loss: preLossSum / Math.max(1, preBatches),
        phase: "ddpg_actor",
      });
      await tf.nextFrame();
    }

    for (let epoch = 0; epoch < epochs; epoch++) {
      let criticLossSum = 0;
      let actorLossSum = 0;
      let batchCount = 0;
      let actorBatchCount = 0;

      for (let start = 0; start < n; start += batchSize) {
        const end = Math.min(start + batchSize, n);
        const size = Math.max(1, end - start);
        const batchIndices = sampleHybridIndices(priorityCdf, priorityTotal, size, uniformSampleFrac);
        const [criticLossValue, actorLossValue] = tf.tidy(() => {
          const idx = tf.tensor1d(batchIndices, "int32");
          const sBatch = tf.gather(s, idx) as tf.Tensor2D;
          const nsBatch = tf.gather(ns, idx) as tf.Tensor2D;
          const aBatch = tf.gather(a, idx) as tf.Tensor2D;
          const rBatch = tf.gather(r, idx) as tf.Tensor1D;
          const dBatch = tf.gather(d, idx) as tf.Tensor1D;
          const wBatch = tf.gather(w, idx) as tf.Tensor1D;

          const criticLossTensor = criticOpt.minimize(() => {
            const nextA01 = actorTarget.apply(nsBatch, { training: false }) as tf.Tensor2D;
            const noiseStd = targetPolicyNoiseRatio * opts.maxOrderPerSku;
            const noiseClip = targetPolicyNoiseClipRatio * opts.maxOrderPerSku;
            const noise =
              noiseStd > 0
                ? tf.randomNormal(nextA01.shape, 0, noiseStd, "float32").clipByValue(-noiseClip, noiseClip)
                : tf.zerosLike(nextA01);
            const nextA = tf
              .clipByValue(nextA01, 0, 1)
              .mul(tf.scalar(opts.maxOrderPerSku))
              .add(noise)
              .clipByValue(0, opts.maxOrderPerSku);
            const targetQ2d = criticTarget.apply([nsBatch, nextA], { training: false }) as tf.Tensor2D;
            const targetQ = targetQ2d.squeeze([1]);
            const rNorm = rBatch.sub(tf.scalar(rewardMean)).div(tf.scalar(Math.max(1e-6, rewardStd)));
            const y = rNorm.add(tf.scalar(gamma).mul(tf.scalar(1).sub(dBatch)).mul(targetQ));

            const qPred2d = critic.apply([sBatch, aBatch], { training: true }) as tf.Tensor2D;
            const qPred = qPred2d.squeeze([1]);
            const tdError = qPred.sub(y);
            const perSample =
              opts.loss === "mse"
                ? tdError.square()
                : (() => {
                    const abs = tdError.abs();
                    const quadratic = abs.minimum(tf.scalar(1));
                    const linear = abs.sub(quadratic);
                    return quadratic.square().mul(0.5).add(linear);
                  })();
            return perSample.mul(wBatch).mean();
          }, true, criticVars);
          if (!criticLossTensor) throw new Error("optimizer failed for DDPG critic training");
          const criticLoss = Number(criticLossTensor.dataSync()[0] ?? 0);
          criticLossTensor.dispose();

          let actorLoss = 0;
          if (updateStep % actorUpdateEvery === 0) {
            const actorLossTensor = actorOpt.minimize(() => {
              const predA01 = actor.apply(sBatch, { training: true }) as tf.Tensor2D;
              const predA = tf.clipByValue(predA01, 0, 1).mul(tf.scalar(opts.maxOrderPerSku));
              const q2d = critic.apply([sBatch, predA], { training: false }) as tf.Tensor2D;
              const q = q2d.squeeze([1]);
              const absQ = q.abs().mean().add(tf.scalar(1e-6));
              const lambda = tf.scalar(2.5).div(absQ);
              const pgLoss = q.mean().mul(lambda).neg();

              const diff = predA.sub(aBatch).div(tf.scalar(Math.max(1, opts.maxOrderPerSku)));
              const bcPer = diff.square().mean(1);
              const bcLoss = bcPer.mul(wBatch).mean();
              return pgLoss.add(bcLoss.mul(tf.scalar(actorBcWeight)));
            }, true, actorVars);
            if (!actorLossTensor) throw new Error("optimizer failed for DDPG actor training");
            actorLoss = Number(actorLossTensor.dataSync()[0] ?? 0);
            actorLossTensor.dispose();
          }

          return [criticLoss, actorLoss] as const;
        });

        criticLossSum += criticLossValue;
        if (updateStep % actorUpdateEvery === 0) {
          actorLossSum += actorLossValue;
          actorBatchCount += 1;
        }
        batchCount += 1;
        updateStep += 1;
        if (updateStep % targetUpdateEvery === 0) {
          softUpdateTo(actor, actorTarget, tau);
          softUpdateTo(critic, criticTarget, tau);
        }
      }

      const criticLoss = criticLossSum / Math.max(1, batchCount);
      const actorLoss = actorLossSum / Math.max(1, actorBatchCount);
      metrics.push({ epoch: epoch + 1, loss: criticLoss, phase: "ddpg_critic" });
      metrics.push({ epoch: epoch + 1, loss: actorLoss, phase: "ddpg_actor" });

      if (criticLoss < bestCriticLoss - criticEarlyStoppingMinDelta) {
        bestCriticLoss = criticLoss;
        staleCriticEpochs = 0;
      } else {
        staleCriticEpochs += 1;
      }
      if (criticEarlyStoppingPatience > 0 && staleCriticEpochs >= criticEarlyStoppingPatience) {
        break;
      }
      await tf.nextFrame();
    }
  } finally {
    s.dispose();
    ns.dispose();
    a.dispose();
    r.dispose();
    d.dispose();
    w.dispose();
    critic.dispose();
    actorTarget.dispose();
    criticTarget.dispose();
    actorOpt.dispose();
    criticOpt.dispose();
  }

  const policy = new MlpInventoryPolicy({
    mode: "ddpg",
    stateDim,
    actionDim,
    maxOrderPerSku: opts.maxOrderPerSku,
    skuIds,
    actorModel: actor,
    actorHidden,
  });
  return { policy, metrics };
}

/**
 * Trains an offline discrete-BCQ (Batch-Constrained Q-learning) policy from
 * logged transitions.
 *
 * Phase 1 behavior-clones a classifier over a discretized action bank built
 * from the logged actions. Phase 2 fits a Q-network whose bootstrap target
 * only maximizes over actions the behavior model rates at probability
 * >= `behaviorThreshold` (the best behavior action is always allowed), which
 * is the BCQ constraint keeping the learned policy close to the data
 * distribution.
 *
 * @param args.samples     Used only to derive state/action dimensionality;
 *                         must be non-empty.
 * @param args.transitions (state, action, reward, nextState, done, weight)
 *                         tuples; entries whose dimensions disagree with the
 *                         samples are dropped. At least one must survive.
 * @param args.skuIds      SKU identifiers baked into the returned policy.
 * @param args.options     Optional hyperparameter overrides (clamped to safe
 *                         ranges); BCQ-specific knobs fall back to the shared
 *                         TrainerOptions values.
 * @param args.hidden      Hidden-layer sizes (each clamped to [8, 2048]) used
 *                         to derive both the Q and behavior network widths.
 * @returns The trained policy plus per-epoch loss metrics for both phases
 *          ("bcq_behavior" then "bcq_q").
 * @throws If no samples exist, dimensions are invalid, no valid transitions
 *         remain, or a tfjs optimizer step returns no loss tensor.
 */
export async function trainOfflineBcqPolicy(args: {
  samples: TrainingSample[];
  transitions: TrainingTransition[];
  skuIds: string[];
  options?: Partial<
    TrainerOptions & {
      gamma: number;
      actionBankSize: number;
      behaviorThreshold: number;
      qEpochs: number;
      behaviorEpochs: number;
      targetUpdateEvery: number;
      qEarlyStoppingPatience: number;
      qEarlyStoppingMinDelta: number;
      behaviorEarlyStoppingPatience: number;
      behaviorEarlyStoppingMinDelta: number;
    }
  >;
  hidden?: number[];
}): Promise<{ policy: MlpInventoryPolicy; metrics: TrainingEpochMetric[] }> {
  const { samples, transitions, skuIds } = args;
  if (samples.length === 0) {
    throw new Error("no training samples available");
  }

  const opts = { ...defaultTrainerOptions(120), ...args.options };
  const stateDim = samples[0]?.state.length ?? 0;
  const actionDim = samples[0]?.action.length ?? 0;
  if (stateDim === 0 || actionDim === 0) {
    throw new Error("training sample dimensions are invalid");
  }

  // Drop transitions whose dimensions don't match the sample layout; a mixed
  // dataset would otherwise corrupt the packed tensors below.
  const validTransitions = transitions.filter(
    (t) => t.state.length === stateDim && t.nextState.length === stateDim && t.action.length === actionDim
  );
  if (validTransitions.length === 0) {
    throw new Error("BCQ requires transition data; no valid transitions were found");
  }

  const hidden = (args.hidden ?? [256, 256, 128]).map((h) => clampInt(h, 8, 2048));
  const qHidden = [Math.max(128, hidden[0] ?? 256), Math.max(96, hidden[1] ?? 192)];
  const behaviorHidden = [Math.max(128, hidden[0] ?? 256), Math.max(96, hidden[1] ?? 192)];
  const batchSize = clampInt(opts.batchSize, 8, 8192);
  const learningRate = clamp(opts.learningRate, 1e-6, 1);
  const l2 = clamp(opts.l2, 0, 10);
  const gamma = clamp(Number(args.options?.gamma ?? 0.97), 0, 0.999);
  const actionBankSize = clampInt(Number(args.options?.actionBankSize ?? 24), 4, 128);
  const behaviorThreshold = clamp(Number(args.options?.behaviorThreshold ?? 0.05), 0, 1);
  const targetUpdateEvery = clampInt(Number(args.options?.targetUpdateEvery ?? 2), 1, 50);

  const behaviorEpochs = clampInt(Number(args.options?.behaviorEpochs ?? opts.epochs), 1, 500);
  const qEpochs = clampInt(Number(args.options?.qEpochs ?? opts.epochs), 1, 500);

  // Phase-specific early-stopping knobs default to the shared trainer values.
  const earlyStoppingPatience = clampInt(opts.earlyStoppingPatience ?? 12, 0, 200);
  const earlyStoppingMinDelta = clamp(opts.earlyStoppingMinDelta ?? 1e-4, 0, 10);
  const qEarlyStoppingPatience = clampInt(
    Number(args.options?.qEarlyStoppingPatience ?? earlyStoppingPatience),
    0,
    200
  );
  const qEarlyStoppingMinDelta = clamp(
    Number(args.options?.qEarlyStoppingMinDelta ?? earlyStoppingMinDelta),
    0,
    10
  );
  const behaviorEarlyStoppingPatience = clampInt(
    Number(args.options?.behaviorEarlyStoppingPatience ?? earlyStoppingPatience),
    0,
    200
  );
  const behaviorEarlyStoppingMinDelta = clamp(
    Number(args.options?.behaviorEarlyStoppingMinDelta ?? earlyStoppingMinDelta),
    0,
    10
  );

  // Discretize the continuous logged actions into a finite bank so the
  // behavior classifier and Q-network can operate over a categorical space.
  const actionBank = buildActionBank({
    transitions: validTransitions,
    actionDim,
    maxOrderPerSku: opts.maxOrderPerSku,
    maxActions: actionBankSize,
  });
  const numActions = actionBank.length;

  const behaviorModel = buildBehaviorModel({
    stateDim,
    numActions,
    hidden: behaviorHidden,
    l2,
  });
  const qModel = buildQModel({
    stateDim,
    numActions,
    hidden: qHidden,
    l2,
  });
  // Target network uses no L2 (never trained directly) and starts as a copy.
  const qTarget = buildQModel({
    stateDim,
    numActions,
    hidden: qHidden,
    l2: 0,
  });
  cloneWeightsTo(qModel, qTarget);

  const behaviorOptimizer = tf.train.adam(learningRate);
  // Q-learning is run at a lower rate for stability.
  const qOptimizer = tf.train.adam(Math.max(1e-5, learningRate * 0.5));

  // Pack transitions into flat typed arrays, then into tensors once, so each
  // mini-batch is a cheap gather rather than a fresh host->device upload.
  const n = validTransitions.length;
  const sValues = new Float32Array(n * stateDim);
  const nsValues = new Float32Array(n * stateDim);
  const aValues = new Int32Array(n);
  const rValues = new Float32Array(n);
  const dValues = new Float32Array(n);
  const wValues = new Float32Array(n);
  for (let i = 0; i < n; i++) {
    const t = validTransitions[i];
    if (!t) continue;
    for (let j = 0; j < stateDim; j++) {
      sValues[i * stateDim + j] = t.state[j] ?? 0;
      nsValues[i * stateDim + j] = t.nextState[j] ?? 0;
    }
    // Map the continuous logged action onto its closest bank entry.
    aValues[i] = nearestActionIndex(t.action, actionBank);
    rValues[i] = t.reward;
    dValues[i] = t.done ? 1 : 0;
    wValues[i] = clamp(t.weight, 0.1, 100);
  }

  const s = tf.tensor2d(sValues, [n, stateDim], "float32");
  const ns = tf.tensor2d(nsValues, [n, stateDim], "float32");
  const a = tf.tensor1d(aValues, "int32");
  const r = tf.tensor1d(rValues, "float32");
  const d = tf.tensor1d(dValues, "float32");
  const w = tf.tensor1d(wValues, "float32");

  const behaviorMetrics: TrainingEpochMetric[] = [];
  let bestBehaviorLoss = Number.POSITIVE_INFINITY;
  let staleBehaviorEpochs = 0;

  const qMetrics: TrainingEpochMetric[] = [];
  let bestQLoss = Number.POSITIVE_INFINITY;
  let staleQEpochs = 0;

  // Flag so the finally block can reclaim the models if training throws;
  // on success they are handed to the returned policy and must stay alive.
  let trainingSucceeded = false;

  try {
    // ---- Phase 1: behavior cloning (weighted cross-entropy) ----
    for (let epoch = 0; epoch < behaviorEpochs; epoch++) {
      const order = shuffledIndices(n);
      let epochLoss = 0;
      let batchCount = 0;
      for (let start = 0; start < n; start += batchSize) {
        const end = Math.min(start + batchSize, n);
        const batchIndices = order.slice(start, end);
        const batchLoss = tf.tidy(() => {
          const idx = tf.tensor1d(batchIndices, "int32");
          const sBatch = tf.gather(s, idx) as tf.Tensor2D;
          const aBatch = tf.gather(a, idx) as tf.Tensor1D;
          const wBatch = tf.gather(w, idx) as tf.Tensor1D;

          // NOTE(review): no varList is passed, unlike the DDPG trainer above
          // which scopes minimize() to each model's variables; tfjs will
          // differentiate against every registered trainable variable and
          // silently skip unconnected ones. Confirm and consider passing the
          // behavior model's trainable variables explicitly.
          const lossTensor = behaviorOptimizer.minimize(() => {
            const probs = behaviorModel.apply(sBatch, { training: true }) as tf.Tensor2D;
            const oneHot = tf.oneHot(aBatch, numActions).asType("float32");
            // Epsilon keeps log() finite when a class probability hits zero.
            const logProbs = probs.add(tf.scalar(1e-7)).log();
            const cePer = oneHot.mul(logProbs).sum(1).neg();
            return cePer.mul(wBatch).mean();
          }, true);

          if (!lossTensor) throw new Error("optimizer failed for BCQ behavior training");
          const lossValue = Number(lossTensor.dataSync()[0] ?? 0);
          lossTensor.dispose();
          return lossValue;
        });
        epochLoss += batchLoss;
        batchCount += 1;
      }

      behaviorMetrics.push({
        epoch: epoch + 1,
        loss: epochLoss / Math.max(1, batchCount),
        phase: "bcq_behavior",
      });

      const latestLoss = behaviorMetrics[behaviorMetrics.length - 1]?.loss ?? Number.POSITIVE_INFINITY;
      if (latestLoss < bestBehaviorLoss - behaviorEarlyStoppingMinDelta) {
        bestBehaviorLoss = latestLoss;
        staleBehaviorEpochs = 0;
      } else {
        staleBehaviorEpochs += 1;
      }
      if (behaviorEarlyStoppingPatience > 0 && staleBehaviorEpochs >= behaviorEarlyStoppingPatience) {
        break;
      }
      // Yield to the event loop so long trainings don't freeze the UI thread.
      await tf.nextFrame();
    }

    // ---- Phase 2: batch-constrained Q-learning ----
    for (let epoch = 0; epoch < qEpochs; epoch++) {
      const order = shuffledIndices(n);
      let epochLoss = 0;
      let batchCount = 0;
      for (let start = 0; start < n; start += batchSize) {
        const end = Math.min(start + batchSize, n);
        const batchIndices = order.slice(start, end);
        const batchLoss = tf.tidy(() => {
          const idx = tf.tensor1d(batchIndices, "int32");
          const sBatch = tf.gather(s, idx) as tf.Tensor2D;
          const nsBatch = tf.gather(ns, idx) as tf.Tensor2D;
          const aBatch = tf.gather(a, idx) as tf.Tensor1D;
          const rBatch = tf.gather(r, idx) as tf.Tensor1D;
          const dBatch = tf.gather(d, idx) as tf.Tensor1D;
          const wBatch = tf.gather(w, idx) as tf.Tensor1D;

          const lossTensor = qOptimizer.minimize(() => {
            const qAll = qModel.apply(sBatch, { training: true }) as tf.Tensor2D;
            const oneHot = tf.oneHot(aBatch, numActions).asType("float32");
            const qChosen = qAll.mul(oneHot).sum(1);

            // BCQ constraint: the bootstrap max only ranges over actions the
            // behavior model considers plausible. The argmax action is always
            // admitted so the mask can never be all-zero.
            const nextQAll = qTarget.apply(nsBatch, { training: false }) as tf.Tensor2D;
            const behaviorProbs = behaviorModel.apply(nsBatch, { training: false }) as tf.Tensor2D;
            const behaviorMask = behaviorProbs.greaterEqual(tf.scalar(behaviorThreshold)).asType("float32");
            const bestBehaviorIdx = behaviorProbs.argMax(1).asType("int32");
            const bestBehaviorOneHot = tf.oneHot(bestBehaviorIdx, numActions).asType("float32");
            const allowedMask = tf.maximum(behaviorMask, bestBehaviorOneHot);

            // Disallowed actions get a large negative value so max() ignores them.
            const maskedNextQ = nextQAll.mul(allowedMask).add(tf.scalar(-1e9).mul(tf.scalar(1).sub(allowedMask)));
            const maxNextQ = maskedNextQ.max(1);

            // Standard TD target; terminal transitions (done=1) drop the bootstrap.
            const bootstrap = maxNextQ.mul(tf.scalar(gamma)).mul(tf.scalar(1).sub(dBatch));
            const tdTarget = rBatch.add(bootstrap);
            const tdError = qChosen.sub(tdTarget);

            // Huber (delta=1) by default; plain squared error when configured.
            const tdLossPer =
              opts.loss === "mse"
                ? tdError.square()
                : (() => {
                    const abs = tdError.abs();
                    const quadratic = abs.minimum(tf.scalar(1));
                    const linear = abs.sub(quadratic);
                    return quadratic.square().mul(0.5).add(linear);
                  })();

            return tdLossPer.mul(wBatch).mean();
          }, true);

          if (!lossTensor) throw new Error("optimizer failed for BCQ Q training");
          const lossValue = Number(lossTensor.dataSync()[0] ?? 0);
          lossTensor.dispose();
          return lossValue;
        });
        epochLoss += batchLoss;
        batchCount += 1;
      }

      // Periodic hard sync of the target network (every `targetUpdateEvery` epochs).
      if ((epoch + 1) % targetUpdateEvery === 0) cloneWeightsTo(qModel, qTarget);

      qMetrics.push({
        // Continue the epoch numbering after the behavior phase for display.
        epoch: behaviorMetrics.length + epoch + 1,
        loss: epochLoss / Math.max(1, batchCount),
        phase: "bcq_q",
      });

      const latestLoss = qMetrics[qMetrics.length - 1]?.loss ?? Number.POSITIVE_INFINITY;
      if (latestLoss < bestQLoss - qEarlyStoppingMinDelta) {
        bestQLoss = latestLoss;
        staleQEpochs = 0;
      } else {
        staleQEpochs += 1;
      }
      if (qEarlyStoppingPatience > 0 && staleQEpochs >= qEarlyStoppingPatience) {
        break;
      }
      await tf.nextFrame();
    }

    trainingSucceeded = true;
  } finally {
    s.dispose();
    ns.dispose();
    a.dispose();
    r.dispose();
    d.dispose();
    w.dispose();
    qTarget.dispose();
    behaviorOptimizer.dispose();
    qOptimizer.dispose();
    if (!trainingSucceeded) {
      // Training threw before the policy could take ownership of the models;
      // dispose them here so their weights don't leak.
      qModel.dispose();
      behaviorModel.dispose();
    }
  }

  const policy = new MlpInventoryPolicy({
    mode: "bcq",
    stateDim,
    actionDim,
    maxOrderPerSku: opts.maxOrderPerSku,
    skuIds,
    qModel,
    behaviorModel,
    actionBank,
    behaviorThreshold,
    qHidden,
    behaviorHidden,
  });

  return {
    policy,
    metrics: [...behaviorMetrics, ...qMetrics],
  };
}

// Backward-compatible exports: older call sites used these names before the
// trainers/policies were renamed. Each alias re-exports the current
// implementation unchanged, so removing them is a breaking API change.
export const trainOfflineBcDqnPolicy = trainOfflineBcqPolicy; // BC-DQN was renamed to BCQ
export const LinearInventoryPolicy = MlpInventoryPolicy; // linear policy superseded by the MLP policy
export const trainOfflineLinearPolicy = trainOfflineMlpPolicy; // linear trainer superseded by the MLP trainer
