import { defaultInventoryEnvConfig } from "@/lib/inventoryRl/catalog";
import {
  buildTrainingSamplesFromRows,
  buildTrainingTransitionsFromRows,
  evaluateExpertPolicyFromRows,
  parseInventoryDatasetCsv,
  parseInventoryDatasetJson,
} from "@/lib/inventoryRl/ingestion";
import { trainOfflineDdpgPolicy } from "@/lib/inventoryRl/offlineTrainer";
import type {
  ExpertPolicyEvaluation,
  InventoryEnvConfig,
  SerializedInventoryPolicy,
  TrainerOptions,
  ValidationResult,
} from "@/lib/inventoryRl/types";

/**
 * Outcome of one upload-and-train run, as returned by
 * `trainPolicyFromUploadedCsv` and `trainPolicyFromUploadedJson`.
 *
 * When the uploaded dataset fails validation the pipeline returns early with
 * `samples: 0`, an empty `metrics` array, and neither `expertEvaluation` nor
 * `modelJson`. Callers should therefore check `validation.valid` before
 * reading the optional fields.
 */
export type TrainingPipelineResult = {
  /** Validation report produced by the dataset parser; always present. */
  validation: ValidationResult;
  /** Number of training samples built from the parsed rows (0 when validation failed). */
  samples: number;
  /** Per-epoch loss values reported by the trainer (empty when validation failed). */
  metrics: { epoch: number; loss: number }[];
  /** Result of `evaluateExpertPolicyFromRows` on the parsed rows; only set when validation passed. */
  expertEvaluation?: ExpertPolicyEvaluation;
  /** Trained policy serialized via its `toJSON()`; only set when validation passed. */
  modelJson?: SerializedInventoryPolicy;
};

/**
 * Returns a shallow copy of `overrides` without the keys whose value is
 * `undefined`.
 *
 * Object spread copies `undefined`-valued own properties, so spreading a
 * `Partial<>` such as `{ skus: undefined }` over a set of defaults would erase
 * the default instead of keeping it. Filtering first makes "key absent" and
 * "key explicitly undefined" behave the same way.
 */
function definedOnly<T extends object>(
  overrides: T | null | undefined,
): Partial<T> {
  const kept: Partial<T> = {};
  if (overrides == null) {
    return kept;
  }
  for (const key of Object.keys(overrides) as (keyof T)[]) {
    const value = overrides[key];
    if (value !== undefined) {
      kept[key] = value;
    }
  }
  return kept;
}

/** Merges caller-supplied environment overrides on top of the default config. */
function resolveEnvConfig(overrides: Partial<InventoryEnvConfig> | undefined) {
  return { ...defaultInventoryEnvConfig(), ...definedOnly(overrides) };
}

/** Environment config after defaults and overrides have been merged. */
type ResolvedEnvConfig = ReturnType<typeof resolveEnvConfig>;

/** Parser output of either supported upload format. */
type ParsedInventoryDataset =
  | ReturnType<typeof parseInventoryDatasetCsv>
  | ReturnType<typeof parseInventoryDatasetJson>;

/**
 * Shared tail of the upload pipelines: given an already-parsed dataset, either
 * reports the validation failure or builds samples/transitions, evaluates the
 * expert policy, trains the offline policy and packages the result.
 */
async function trainPolicyFromParsedDataset(
  parsed: ParsedInventoryDataset,
  config: ResolvedEnvConfig,
  trainer: Partial<TrainerOptions> | undefined,
): Promise<TrainingPipelineResult> {
  // Invalid uploads short-circuit: nothing is built or trained.
  if (!parsed.validation.valid) {
    return {
      validation: parsed.validation,
      samples: 0,
      metrics: [],
    };
  }

  const samples = buildTrainingSamplesFromRows(parsed.rows, config);
  const transitions = buildTrainingTransitionsFromRows(parsed.rows, config);
  const expertEvaluation = evaluateExpertPolicyFromRows(parsed.rows, config);
  const trained = await trainOfflineDdpgPolicy({
    samples,
    transitions,
    skuIds: config.skus.map((s) => s.id),
    options: {
      // The env config supplies the order cap unless the caller gives a
      // defined trainer-level value; an explicit `undefined` must not drop it.
      maxOrderPerSku: config.maxOrderPerSku,
      ...definedOnly(trainer),
    },
  });

  return {
    validation: parsed.validation,
    samples: samples.length,
    metrics: trained.metrics,
    expertEvaluation,
    modelJson: trained.policy.toJSON(),
  };
}

/**
 * Parses an uploaded CSV dataset and, if it validates, trains an offline
 * policy from it.
 *
 * @param args.csvText Raw CSV text of the uploaded dataset.
 * @param args.config Optional overrides merged over `defaultInventoryEnvConfig()`;
 *   keys set to `undefined` are ignored so the default is kept.
 * @param args.trainer Optional trainer options; defined values take precedence
 *   over the `maxOrderPerSku` derived from the environment config.
 * @returns The validation report plus, when validation passed, the sample
 *   count, training metrics, expert-policy evaluation and serialized model.
 */
export async function trainPolicyFromUploadedCsv(args: {
  csvText: string;
  config?: Partial<InventoryEnvConfig>;
  trainer?: Partial<TrainerOptions>;
}): Promise<TrainingPipelineResult> {
  const config = resolveEnvConfig(args.config);
  const parsed = parseInventoryDatasetCsv(args.csvText, config);
  return trainPolicyFromParsedDataset(parsed, config, args.trainer);
}

/**
 * Parses an uploaded JSON dataset and, if it validates, trains an offline
 * policy from it.
 *
 * @param args.rowsJson Untrusted, already-decoded JSON value holding the rows;
 *   its shape is checked by `parseInventoryDatasetJson`.
 * @param args.config Optional overrides merged over `defaultInventoryEnvConfig()`;
 *   keys set to `undefined` are ignored so the default is kept.
 * @param args.trainer Optional trainer options; defined values take precedence
 *   over the `maxOrderPerSku` derived from the environment config.
 * @returns The validation report plus, when validation passed, the sample
 *   count, training metrics, expert-policy evaluation and serialized model.
 */
export async function trainPolicyFromUploadedJson(args: {
  rowsJson: unknown;
  config?: Partial<InventoryEnvConfig>;
  trainer?: Partial<TrainerOptions>;
}): Promise<TrainingPipelineResult> {
  const config = resolveEnvConfig(args.config);
  const parsed = parseInventoryDatasetJson(args.rowsJson, config);
  return trainPolicyFromParsedDataset(parsed, config, args.trainer);
}
