/**
 * @license
 * Copyright 2020 Google Inc. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */

import {getCustomConverterOpsModule, getCustomModuleString} from './custom_module';
import {CustomTFJSBundleConfig, ImportProvider} from './types';

/**
 * Fake ImportProvider that emits short, easily searchable placeholder strings
 * instead of real import statements.
 *
 * Any kernel or gradient named 'Invalid' gets an import path containing
 * 'Invalid', and validateImportPath rejects such paths. This lets the specs
 * check that unresolvable imports are left out of the generated module.
 */
const mockImportProvider: ImportProvider = {
  importCoreStr() {
    return 'import CORE';
  },
  importConverterStr() {
    return 'import CONVERTER';
  },
  importBackendStr(name: string) {
    return `import BACKEND ${name}`;
  },
  importKernelStr(kernelName: string, backend: string) {
    const pathPrefix = kernelName === 'Invalid' ? 'BACKEND_Invalid' : 'BACKEND';
    const importPath = `${pathPrefix} ${backend}`;
    return {
      importPath,
      importStatement: `import KERNEL ${kernelName} from ${importPath}`,
      kernelConfigId: `${kernelName}_${backend}`
    };
  },
  importGradientConfigStr(kernel: string) {
    // No backend is passed in here, so the gradient path carries none.
    const importPath = kernel === 'Invalid' ? 'BACKEND_Invalid' : 'BACKEND';
    return {
      importPath,
      importStatement: `import GRADIENT ${kernel} from ${importPath}`,
      gradConfigId: `${kernel}_GRAD_CONFIG`,
    };
  },
  importOpForConverterStr(opSymbol: string) {
    return `export * from ${opSymbol}`;
  },
  importNamespacedOpsForConverterStr(namespace: string, opSymbols: string[]) {
    const joinedSymbols = opSymbols.join(',');
    return `export ${joinedSymbols} as ${namespace} from ${namespace}/`;
  },
  validateImportPath(importPath: string) {
    // Paths built for the 'Invalid' kernel are the only ones rejected.
    return importPath.indexOf('Invalid') === -1;
  }
};

describe('getCustomModuleString forwardModeOnly=true', () => {
  const forwardModeOnly = true;

  /**
   * Generates the custom module for the given kernels/backends/models using
   * this suite's forwardModeOnly setting and the mock import provider.
   * The cast is needed because mock backend names such as FastBcknd are not
   * valid backends per CustomTFJSBundleConfig.
   */
  function generate(
      kernels: string[], backends: string[], models: string[] = []) {
    const config = {kernels, backends, models, forwardModeOnly};
    return getCustomModuleString(
        config as CustomTFJSBundleConfig, mockImportProvider);
  }

  /** Expects `text` to contain each of `snippets`, in order. */
  function expectContainsAll(text: string, snippets: string[]) {
    for (const snippet of snippets) {
      expect(text).toContain(snippet);
    }
  }

  /** Expects `text` to contain none of `snippets`, in order. */
  function expectContainsNone(text: string, snippets: string[]) {
    for (const snippet of snippets) {
      expect(text).not.toContain(snippet);
    }
  }

  it('one kernel, one backend', () => {
    const {tfjs, core} = generate(['MathKrnl', 'Invalid'], ['FastBcknd']);

    expect(core).toContain('import CORE');
    expectContainsAll(tfjs, [
      'import CORE',
      'import BACKEND FastBcknd',
      'import KERNEL MathKrnl from BACKEND FastBcknd',
      'registerKernel(MathKrnl_FastBcknd)',
    ]);
    // The mock gives 'Invalid' an import path that validateImportPath rejects,
    // and forward-only mode should emit no gradient code at all.
    expectContainsNone(tfjs, [
      'import KERNEL Invalid from BACKEND FastBcknd',
      'registerKernel(Invalid_FastBcknd)',
      'GRADIENT',
    ]);
  });

  it('one kernel, one backend, one model', () => {
    const {tfjs, core} =
        generate(['MathKrnl'], ['FastBcknd'], ['model1.json']);

    expect(core).toContain('import CORE');
    // A non-empty model list is expected to pull in the converter import.
    expectContainsAll(tfjs, [
      'import CORE',
      'import CONVERTER',
      'import BACKEND FastBcknd',
      'import KERNEL MathKrnl from BACKEND FastBcknd',
      'registerKernel(MathKrnl_FastBcknd)',
    ]);
    expectContainsNone(tfjs, ['GRADIENT']);
  });

  it('one kernel, two backend', () => {
    const {tfjs} = generate(['MathKrnl'], ['FastBcknd', 'SlowBcknd']);

    expectContainsAll(tfjs, [
      'import CORE',
      'import BACKEND FastBcknd',
      'import KERNEL MathKrnl from BACKEND FastBcknd',
      'registerKernel(MathKrnl_FastBcknd)',
      'import BACKEND SlowBcknd',
      'import KERNEL MathKrnl from BACKEND SlowBcknd',
      'registerKernel(MathKrnl_SlowBcknd)',
    ]);
    expectContainsNone(tfjs, ['GRADIENT']);
  });

  it('two kernels, one backend', () => {
    const {tfjs} = generate(['MathKrnl', 'MathKrn2'], ['FastBcknd']);

    expectContainsAll(tfjs, [
      'import CORE',
      'import BACKEND FastBcknd',
      'import KERNEL MathKrnl from BACKEND FastBcknd',
      'import KERNEL MathKrn2 from BACKEND FastBcknd',
      'registerKernel(MathKrnl_FastBcknd)',
      'registerKernel(MathKrn2_FastBcknd)',
    ]);
    expectContainsNone(tfjs, ['GRADIENT']);
  });

  it('two kernels, two backends', () => {
    const {tfjs} =
        generate(['MathKrnl', 'MathKrn2'], ['FastBcknd', 'SlowBcknd']);

    expectContainsAll(tfjs, [
      'import CORE',
      'import BACKEND FastBcknd',
      'import KERNEL MathKrnl from BACKEND FastBcknd',
      'import KERNEL MathKrn2 from BACKEND FastBcknd',
      'registerKernel(MathKrnl_FastBcknd)',
      'registerKernel(MathKrn2_FastBcknd)',
      'import BACKEND SlowBcknd',
      'import KERNEL MathKrnl from BACKEND SlowBcknd',
      'import KERNEL MathKrn2 from BACKEND SlowBcknd',
      'registerKernel(MathKrnl_SlowBcknd)',
      'registerKernel(MathKrn2_SlowBcknd)',
    ]);
    expectContainsNone(tfjs, ['GRADIENT']);
  });
});

describe('getCustomModuleString forwardModeOnly=false', () => {
  // With forward-only mode off, the generated module is expected to import and
  // register gradient configs in addition to kernels.
  const forwardModeOnly = false;

  it('one kernel, one backend', () => {
    const config = {
      kernels: ['MathKrnl', 'Invalid'],
      backends: ['FastBcknd'],
      models: [] as string[],
      forwardModeOnly
    };

    const {tfjs} = getCustomModuleString(
        config as CustomTFJSBundleConfig, mockImportProvider);

    expect(tfjs).toContain('import CORE');

    expect(tfjs).toContain('import BACKEND FastBcknd');
    expect(tfjs).toContain('import KERNEL MathKrnl from BACKEND FastBcknd');
    expect(tfjs).toContain('registerKernel(MathKrnl_FastBcknd)');
    // The mock gives 'Invalid' a BACKEND_Invalid import path, which its
    // validateImportPath rejects; neither its kernel nor its gradient should
    // appear in the output.
    expect(tfjs).not.toContain('import KERNEL Invalid from BACKEND FastBcknd');
    expect(tfjs).not.toContain('registerKernel(Invalid_FastBcknd)');

    expect(tfjs).toContain('import GRADIENT MathKrnl');
    expect(tfjs).toContain('registerGradient(MathKrnl_GRAD_CONFIG)');
    expect(tfjs).not.toContain('import GRADIENT Invalid');
    // Gradients are registered with registerGradient (as asserted for MathKrnl
    // above). Checking for registerKernel here would pass vacuously.
    expect(tfjs).not.toContain('registerGradient(Invalid_GRAD_CONFIG)');
  });

  it('one kernel, two backend', () => {
    const config = {
      kernels: ['MathKrnl'],
      backends: ['FastBcknd', 'SlowBcknd'],
      models: [] as string[],
      forwardModeOnly
    };

    const {tfjs} = getCustomModuleString(
        config as CustomTFJSBundleConfig, mockImportProvider);

    expect(tfjs).toContain('import GRADIENT MathKrnl');
    expect(tfjs).toContain('registerGradient(MathKrnl_GRAD_CONFIG)');

    // The gradient import path has no backend component, so listing two
    // backends must not produce a second gradient import.
    const gradIndex = tfjs.indexOf('GRADIENT');
    expect(tfjs.indexOf('GRADIENT', gradIndex + 1))
        .toBe(-1, `Gradient import appears twice in:\n ${tfjs}`);
  });

  it('two kernels, one backend', () => {
    const config = {
      kernels: ['MathKrnl', 'MathKrn2'],
      backends: ['FastBcknd'],
      models: [] as string[],
      forwardModeOnly
    };

    const {tfjs} = getCustomModuleString(
        config as CustomTFJSBundleConfig, mockImportProvider);

    expect(tfjs).toContain('import GRADIENT MathKrnl');
    expect(tfjs).toContain('registerGradient(MathKrnl_GRAD_CONFIG)');

    expect(tfjs).toContain('import GRADIENT MathKrn2');
    expect(tfjs).toContain('registerGradient(MathKrn2_GRAD_CONFIG)');
  });

  it('two kernels, two backends', () => {
    const config = {
      kernels: ['MathKrnl', 'MathKrn2'],
      backends: ['FastBcknd', 'SlowBcknd'],
      models: [] as string[],
      forwardModeOnly
    };

    const {tfjs} = getCustomModuleString(
        config as CustomTFJSBundleConfig, mockImportProvider);

    expect(tfjs).toContain('import GRADIENT MathKrnl');
    expect(tfjs).toContain('registerGradient(MathKrnl_GRAD_CONFIG)');

    expect(tfjs).toContain('import GRADIENT MathKrn2');
    expect(tfjs).toContain('registerGradient(MathKrn2_GRAD_CONFIG)');
  });
});

describe('getCustomConverterOpsModule', () => {
  it('non namespaced ops', () => {
    const opsModule =
        getCustomConverterOpsModule(['add', 'sub'], mockImportProvider);

    // Each plain op symbol is expected to get its own re-export line.
    for (const expectedExport of ['export * from add', 'export * from sub']) {
      expect(opsModule).toContain(expectedExport);
    }
  });

  it('namespaced ops', () => {
    const opsModule = getCustomConverterOpsModule(
        ['image.resizeBilinear', 'image.resizeNearestNeighbor'],
        mockImportProvider);

    // Both 'image.'-prefixed ops are expected to be combined into a single
    // namespaced export of the form produced by the mock provider.
    const expectedExport =
        'export resizeBilinear,resizeNearestNeighbor as image from image/';
    expect(opsModule).toContain(expectedExport);
  });
});
