/**
 * @license
 * Copyright 2020 Google Inc. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */

import * as fs from 'fs';
import * as path from 'path';

import {getCustomConverterOpsModule, getCustomModuleString} from './custom_module';
import {getOpsForConfig} from './model_parser';
import {CustomTFJSBundleConfig, ImportProvider, ModuleProvider, SupportedBackend} from './types';
import {bail, kernelNameToVariableName, opNameToFileName} from './util';

/**
 * Creates the module provider used to emit custom tfjs modules.
 *
 * @param opts Currently ignored by this implementation.
 * @returns A freshly constructed `ESMModuleProvider`.
 */
export function getModuleProvider(opts: {}): ModuleProvider {
  const provider: ModuleProvider = new ESMModuleProvider();
  return provider;
}

/**
 * Module provider that writes the generated custom tfjs modules to disk as
 * ES modules, using `esmImportProvider` to build the import statements.
 */
class ESMModuleProvider implements ModuleProvider {
  /**
   * Writes out custom tfjs module(s) to disk.
   *
   * Creates `config.normalizedOutputPath` (recursively) if needed, then
   * writes:
   *   - `custom_tfjs_core.js` from the generated `core` module string,
   *   - `custom_tfjs.js` from the generated `tfjs` module string,
   *   - `custom_ops_for_converter.js`, only when `getOpsForConfig` returns at
   *     least one op.
   *
   * If the kernel-to-ops mapping file cannot be resolved, read or parsed,
   * `bail` is called with a message naming the file and the underlying
   * cause.
   *
   * @param config The bundle configuration; `normalizedOutputPath` is the
   *     directory all files are written to.
   */
  produceCustomTFJSModule(config: CustomTFJSBundleConfig) {
    const {normalizedOutputPath} = config;

    const moduleStrs = getCustomModuleString(config, esmImportProvider);

    fs.mkdirSync(normalizedOutputPath, {recursive: true});
    console.log(`Will write custom tfjs modules to ${normalizedOutputPath}`);

    const customTfjsFileName = 'custom_tfjs.js';
    const customTfjsCoreFileName = 'custom_tfjs_core.js';

    // Write a custom module for @tensorflow/tfjs and @tensorflow/tfjs-core
    fs.writeFileSync(
        path.join(normalizedOutputPath, customTfjsCoreFileName),
        moduleStrs.core);
    fs.writeFileSync(
        path.join(normalizedOutputPath, customTfjsFileName), moduleStrs.tfjs);

    // Write a custom module of the tfjs-core ops used by converter executors.
    const mappingModuleId =
        '@tensorflow/tfjs-converter/metadata/kernel2op.json';

    let kernelToOps;
    // Start from the module specifier so that the error message still names
    // the file when `require.resolve` itself is what fails.
    let mappingPath = mappingModuleId;
    try {
      mappingPath = require.resolve(mappingModuleId);
      kernelToOps = JSON.parse(fs.readFileSync(mappingPath, 'utf-8'));
    } catch (e) {
      // Surface the underlying cause (unresolvable module, unreadable file,
      // malformed JSON) instead of discarding it.
      const reason = e instanceof Error ? e.message : String(e);
      // NOTE(review): `bail` is assumed to abort execution; if it returns,
      // `kernelToOps` is still undefined below — confirm against ./util.
      bail(`Error loading kernel to ops mapping file ${mappingPath}: ${
          reason}`);
    }

    const converterOps = getOpsForConfig(config, kernelToOps);
    if (converterOps.length > 0) {
      const converterOpsModule =
          getCustomConverterOpsModule(converterOps, esmImportProvider);

      const customConverterOpsFileName = 'custom_ops_for_converter.js';

      fs.writeFileSync(
          path.join(normalizedOutputPath, customConverterOpsFileName),
          converterOpsModule);
    }
  }
}

/**
 * Import provider that emits ES module `import`/`export` statements pointing
 * at the `dist` folders of the @tensorflow/tfjs-* packages.
 */
// Exported for tests.
export const esmImportProvider: ImportProvider = {
  /**
   * Builds the statements that pull in and re-export tfjs-core's base module.
   * The `registerGradient` import is only appended when `forwardModeOnly` is
   * false.
   */
  importCoreStr(forwardModeOnly: boolean) {
    const gradientImport =
        `import {registerGradient} from '@tensorflow/tfjs-core/dist/base';`;

    const statements = [
      `import {registerKernel} from '@tensorflow/tfjs-core/dist/base';`,
      `import '@tensorflow/tfjs-core/dist/base_side_effects';`,
      `export * from '@tensorflow/tfjs-core/dist/base';`,
      ...(forwardModeOnly ? [] : [gradientImport]),
    ];

    return statements.join('\n');
  },

  /** Builds the statement re-exporting everything from tfjs-converter. */
  importConverterStr() {
    return `export * from '@tensorflow/tfjs-converter';`;
  },

  /** Builds the statement re-exporting the base module of a backend. */
  importBackendStr(backend: SupportedBackend) {
    return `export * from '${getBackendPath(backend)}/dist/base';`;
  },

  /**
   * Builds the import of a kernel config from a backend package, aliased to
   * `<kernelName>_<backend>`.
   *
   * @returns The import path, the full import statement, and the alias used
   *     for the imported config.
   */
  importKernelStr(kernelName: string, backend: SupportedBackend) {
    const importPath = `${getBackendPath(backend)}/dist/kernels/${kernelName}`;
    const kernelConfigId = `${kernelName}_${backend}`;
    const exportedConfigName = `${kernelNameToVariableName(kernelName)}Config`;
    const importStatement = `import {${exportedConfigName} as ${
        kernelConfigId}} from '${importPath}';`;

    return {importPath, importStatement, kernelConfigId};
  },

  /**
   * Builds the import of a kernel's gradient config from tfjs-core.
   *
   * @returns The import path, the full import statement, and the identifier
   *     of the imported gradient config.
   */
  importGradientConfigStr(kernelName: string) {
    const gradConfigId = `${kernelNameToVariableName(kernelName)}GradConfig`;
    const importPath =
        `@tensorflow/tfjs-core/dist/gradients/${kernelName}_grad`;

    return {
      importPath,
      importStatement: `import {${gradConfigId}} from '${importPath}';`,
      gradConfigId,
    };
  },

  /** Builds the statement re-exporting a single op from tfjs-core. */
  importOpForConverterStr(opSymbol) {
    const opModule =
        `@tensorflow/tfjs-core/dist/ops/${opNameToFileName(opSymbol)}`;
    return `export {${opSymbol}} from '${opModule}';`;
  },

  /**
   * Builds one aliased import per op in `opSymbols` (alias
   * `<opSymbol>_<namespace>`), followed by an exported object literal named
   * `namespace` that maps each op symbol to its alias.
   */
  importNamespacedOpsForConverterStr(namespace, opSymbols) {
    const importStatements: string[] = [];
    const namespaceMembers: string[] = [];

    for (const opSymbol of opSymbols) {
      const alias = `${opSymbol}_${namespace}`;
      const opModule = `@tensorflow/tfjs-core/dist/ops/${namespace}/${
          opNameToFileName(opSymbol)}`;

      importStatements.push(`import {${opSymbol} as ${alias}} from '${
          opModule}';`);
      namespaceMembers.push(`\t${opSymbol}: ${alias},`);
    }

    return [
      ...importStatements,
      `export const ${namespace} = {`,
      ...namespaceMembers,
      `};`,
    ].join('\n');
  },

  /**
   * Reports whether `require.resolve` can resolve `importPath`.
   *
   * @returns `true` if resolution succeeds, `false` if it throws.
   */
  validateImportPath(importPath: string): boolean {
    let resolvable = true;
    try {
      require.resolve(importPath);
    } catch (e) {
      resolvable = false;
    }
    return resolvable;
  }
};

/**
 * Maps a backend name to the npm package that implements it.
 *
 * @param backend One of 'cpu', 'webgl' or 'wasm'.
 * @returns The package name for the backend.
 * @throws Error if `backend` is not one of the names above.
 */
function getBackendPath(backend: SupportedBackend) {
  const packageByBackend = new Map<string, string>([
    ['cpu', '@tensorflow/tfjs-backend-cpu'],
    ['webgl', '@tensorflow/tfjs-backend-webgl'],
    ['wasm', '@tensorflow/tfjs-backend-wasm'],
  ]);

  const backendPackage = packageByBackend.get(backend);
  if (backendPackage === undefined) {
    throw new Error(`Unsupported backend ${backend}`);
  }
  return backendPackage;
}
