/**
 * @license
 * Copyright 2018 Google LLC
 *
 * Use of this source code is governed by an MIT-style
 * license that can be found in the LICENSE file or at
 * https://opensource.org/licenses/MIT.
 * =============================================================================
 */
/// <amd-module name="@tensorflow/tfjs-layers/dist/initializers" />
import { DataType, serialization, Tensor } from '@tensorflow/tfjs-core';
import { Shape } from './keras_format/common';
import { Distribution, FanMode } from './keras_format/initializer_config';
/**
 * Checks a fan-mode string.
 *
 * NOTE(review): only the signature is visible in this declaration file;
 * presumably throws when `value` is not a recognized `FanMode` — confirm
 * against the implementation.
 *
 * @param value Fan-mode name to check; may be omitted.
 */
export declare function checkFanMode(value?: string): void;
/**
 * Checks a distribution string.
 *
 * NOTE(review): only the signature is visible in this declaration file;
 * presumably throws when `value` is not a recognized `Distribution` — confirm
 * against the implementation.
 *
 * @param value Distribution name to check; may be omitted.
 */
export declare function checkDistribution(value?: string): void;
/**
 * Initializer base class.
 *
 * Abstract: concrete initializers must implement `apply`. Extends
 * `serialization.Serializable`.
 *
 * @doc {
 *   heading: 'Initializers', subheading: 'Classes', namespace: 'initializers'}
 */
export declare abstract class Initializer extends serialization.Serializable {
    /**
     * NOTE(review): presumably reports whether deserializing this initializer
     * from a config needs custom objects — confirm against the implementation.
     */
    fromConfigUsesCustomObjects(): boolean;
    /**
     * Generate an initial value.
     * @param shape Requested shape of the initial value.
     * @param dtype Optional data type of the initial value.
     * @return The init value.
     */
    abstract apply(shape: Shape, dtype?: DataType): Tensor;
    /** Returns this initializer's configuration as a `ConfigDict`. */
    getConfig(): serialization.ConfigDict;
}
/**
 * Concrete `Initializer` with no constructor arguments.
 *
 * NOTE(review): the name suggests `apply` yields an all-zeros tensor; the
 * implementation is not visible in this declaration file — confirm.
 */
export declare class Zeros extends Initializer {
    /** @nocollapse */
    static className: string;
    /**
     * Generates the initial value.
     * @param shape Requested shape of the initial value.
     * @param dtype Optional data type of the initial value.
     */
    apply(shape: Shape, dtype?: DataType): Tensor;
}
/**
 * Concrete `Initializer` with no constructor arguments.
 *
 * NOTE(review): the name suggests `apply` yields an all-ones tensor; the
 * implementation is not visible in this declaration file — confirm.
 */
export declare class Ones extends Initializer {
    /** @nocollapse */
    static className: string;
    /**
     * Generates the initial value.
     * @param shape Requested shape of the initial value.
     * @param dtype Optional data type of the initial value.
     */
    apply(shape: Shape, dtype?: DataType): Tensor;
}
/** Constructor arguments for the `Constant` initializer. */
export interface ConstantArgs {
    /** The value for each element in the variable. Required. */
    value: number;
}
/**
 * Initializer configured with a single numeric `value` (see `ConstantArgs`,
 * which documents it as the value for each element in the variable).
 */
export declare class Constant extends Initializer {
    /** @nocollapse */
    static className: string;
    /** Private state; presumably holds `args.value` — confirm. */
    private value;
    /**
     * @param args Configuration; `value` is required.
     */
    constructor(args: ConstantArgs);
    /**
     * Generates the initial value.
     * @param shape Requested shape of the initial value.
     * @param dtype Optional data type of the initial value.
     */
    apply(shape: Shape, dtype?: DataType): Tensor;
    /** Returns this initializer's configuration as a `ConfigDict`. */
    getConfig(): serialization.ConfigDict;
}
/** Constructor arguments for the `RandomUniform` initializer; all fields are optional. */
export interface RandomUniformArgs {
    /** Lower bound of the range of random values to generate. */
    minval?: number;
    /** Upper bound of the range of random values to generate. */
    maxval?: number;
    /** Used to seed the random generator. */
    seed?: number;
}
/**
 * Initializer configured by a `[minval, maxval]` range and an optional seed
 * (see `RandomUniformArgs`).
 */
export declare class RandomUniform extends Initializer {
    /** @nocollapse */
    static className: string;
    /** Constant -0.05; presumably the fallback when `minval` is omitted — confirm. */
    readonly DEFAULT_MINVAL = -0.05;
    /** Constant 0.05; presumably the fallback when `maxval` is omitted — confirm. */
    readonly DEFAULT_MAXVAL = 0.05;
    private minval;
    private maxval;
    private seed;
    /**
     * @param args Configuration; every field is optional.
     */
    constructor(args: RandomUniformArgs);
    /**
     * Generates the initial value.
     * @param shape Requested shape of the initial value.
     * @param dtype Optional data type of the initial value.
     */
    apply(shape: Shape, dtype?: DataType): Tensor;
    /** Returns this initializer's configuration as a `ConfigDict`. */
    getConfig(): serialization.ConfigDict;
}
/** Constructor arguments for the `RandomNormal` initializer; all fields are optional. */
export interface RandomNormalArgs {
    /** Mean of the random values to generate. */
    mean?: number;
    /** Standard deviation of the random values to generate. */
    stddev?: number;
    /** Used to seed the random generator. */
    seed?: number;
}
/**
 * Initializer configured by a mean, a standard deviation and an optional seed
 * (see `RandomNormalArgs`).
 */
export declare class RandomNormal extends Initializer {
    /** @nocollapse */
    static className: string;
    /** Constant 0; presumably the fallback when `mean` is omitted — confirm. */
    readonly DEFAULT_MEAN = 0;
    /** Constant 0.05; presumably the fallback when `stddev` is omitted — confirm. */
    readonly DEFAULT_STDDEV = 0.05;
    private mean;
    private stddev;
    private seed;
    /**
     * @param args Configuration; every field is optional.
     */
    constructor(args: RandomNormalArgs);
    /**
     * Generates the initial value.
     * @param shape Requested shape of the initial value.
     * @param dtype Optional data type of the initial value.
     */
    apply(shape: Shape, dtype?: DataType): Tensor;
    /** Returns this initializer's configuration as a `ConfigDict`. */
    getConfig(): serialization.ConfigDict;
}
/** Constructor arguments for the `TruncatedNormal` initializer; all fields are optional. */
export interface TruncatedNormalArgs {
    /** Mean of the random values to generate. */
    mean?: number;
    /** Standard deviation of the random values to generate. */
    stddev?: number;
    /** Used to seed the random generator. */
    seed?: number;
}
/**
 * Initializer configured by a mean, a standard deviation and an optional seed
 * (see `TruncatedNormalArgs`).
 *
 * NOTE(review): how values are truncated is not visible in this declaration
 * file — consult the implementation.
 */
export declare class TruncatedNormal extends Initializer {
    /** @nocollapse */
    static className: string;
    /** Constant 0; presumably the fallback when `mean` is omitted — confirm. */
    readonly DEFAULT_MEAN = 0;
    /** Constant 0.05; presumably the fallback when `stddev` is omitted — confirm. */
    readonly DEFAULT_STDDEV = 0.05;
    private mean;
    private stddev;
    private seed;
    /**
     * @param args Configuration; every field is optional.
     */
    constructor(args: TruncatedNormalArgs);
    /**
     * Generates the initial value.
     * @param shape Requested shape of the initial value.
     * @param dtype Optional data type of the initial value.
     */
    apply(shape: Shape, dtype?: DataType): Tensor;
    /** Returns this initializer's configuration as a `ConfigDict`. */
    getConfig(): serialization.ConfigDict;
}
/** Constructor arguments for the `Identity` initializer. */
export interface IdentityArgs {
    /**
     * Multiplicative factor to apply to the identity matrix. Optional.
     */
    gain?: number;
}
/**
 * Initializer configured by an optional `gain`, documented in `IdentityArgs`
 * as a multiplicative factor applied to the identity matrix.
 *
 * NOTE(review): any restriction on the accepted `shape` (e.g. 2D only) is not
 * visible in this declaration file — confirm against the implementation.
 */
export declare class Identity extends Initializer {
    /** @nocollapse */
    static className: string;
    private gain;
    /**
     * @param args Configuration; `gain` is optional.
     */
    constructor(args: IdentityArgs);
    /**
     * Generates the initial value.
     * @param shape Requested shape of the initial value.
     * @param dtype Optional data type of the initial value.
     */
    apply(shape: Shape, dtype?: DataType): Tensor;
    /** Returns this initializer's configuration as a `ConfigDict`. */
    getConfig(): serialization.ConfigDict;
}
/** Constructor arguments for the `VarianceScaling` initializer; all fields are optional. */
export interface VarianceScalingArgs {
    /** Scaling factor (positive float). */
    scale?: number;
    /** Fanning mode for inputs and outputs. */
    mode?: FanMode;
    /** Probabilistic distribution of the values. */
    distribution?: Distribution;
    /** Random number generator seed. */
    seed?: number;
}
/**
 * Initializer configured by a scale factor, a fan mode, a distribution and an
 * optional seed (see `VarianceScalingArgs`). Base class of the Glorot, He and
 * LeCun initializers declared in this file.
 */
export declare class VarianceScaling extends Initializer {
    /** @nocollapse */
    static className: string;
    private scale;
    private mode;
    private distribution;
    private seed;
    /**
     * Constructor of VarianceScaling.
     * @param args Configuration; every field is optional.
     * @throws ValueError for invalid value in scale.
     */
    constructor(args: VarianceScalingArgs);
    /**
     * Generates the initial value.
     * @param shape Requested shape of the initial value.
     * @param dtype Optional data type of the initial value.
     */
    apply(shape: Shape, dtype?: DataType): Tensor;
    /** Returns this initializer's configuration as a `ConfigDict`. */
    getConfig(): serialization.ConfigDict;
}
/**
 * Constructor arguments for initializers whose only configurable option is the
 * random seed. Used by the Glorot, He and LeCun initializers in this file and
 * extended by `OrthogonalArgs`.
 */
export interface SeedOnlyInitializerArgs {
    /** Random number generator seed. */
    seed?: number;
}
/**
 * `VarianceScaling` subclass that accepts only an optional seed.
 *
 * NOTE(review): the scale, mode and distribution it passes to
 * `VarianceScaling` are fixed by the implementation and not visible in this
 * declaration file.
 */
export declare class GlorotUniform extends VarianceScaling {
    /** @nocollapse */
    static className: string;
    /**
     * Constructor of GlorotUniform
     * @param args Optional configuration; `seed` is the only field accepted
     *     (see `SeedOnlyInitializerArgs`).
     */
    constructor(args?: SeedOnlyInitializerArgs);
    /** Returns the class name of this initializer. */
    getClassName(): string;
}
/**
 * `VarianceScaling` subclass that accepts only an optional seed.
 *
 * NOTE(review): the scale, mode and distribution it passes to
 * `VarianceScaling` are fixed by the implementation and not visible in this
 * declaration file.
 */
export declare class GlorotNormal extends VarianceScaling {
    /** @nocollapse */
    static className: string;
    /**
     * Constructor of GlorotNormal.
     * @param args Optional configuration; `seed` is the only field accepted
     *     (see `SeedOnlyInitializerArgs`).
     */
    constructor(args?: SeedOnlyInitializerArgs);
    /** Returns the class name of this initializer. */
    getClassName(): string;
}
/**
 * `VarianceScaling` subclass that accepts only an optional seed.
 *
 * NOTE(review): the scale, mode and distribution it passes to
 * `VarianceScaling` are fixed by the implementation and not visible in this
 * declaration file.
 */
export declare class HeNormal extends VarianceScaling {
    /** @nocollapse */
    static className: string;
    /**
     * @param args Optional configuration; `seed` is the only field accepted
     *     (see `SeedOnlyInitializerArgs`).
     */
    constructor(args?: SeedOnlyInitializerArgs);
    /** Returns the class name of this initializer. */
    getClassName(): string;
}
/**
 * `VarianceScaling` subclass that accepts only an optional seed.
 *
 * NOTE(review): the scale, mode and distribution it passes to
 * `VarianceScaling` are fixed by the implementation and not visible in this
 * declaration file.
 */
export declare class HeUniform extends VarianceScaling {
    /** @nocollapse */
    static className: string;
    /**
     * @param args Optional configuration; `seed` is the only field accepted
     *     (see `SeedOnlyInitializerArgs`).
     */
    constructor(args?: SeedOnlyInitializerArgs);
    /** Returns the class name of this initializer. */
    getClassName(): string;
}
/**
 * `VarianceScaling` subclass that accepts only an optional seed.
 *
 * NOTE(review): the scale, mode and distribution it passes to
 * `VarianceScaling` are fixed by the implementation and not visible in this
 * declaration file.
 */
export declare class LeCunNormal extends VarianceScaling {
    /** @nocollapse */
    static className: string;
    /**
     * @param args Optional configuration; `seed` is the only field accepted
     *     (see `SeedOnlyInitializerArgs`).
     */
    constructor(args?: SeedOnlyInitializerArgs);
    /** Returns the class name of this initializer. */
    getClassName(): string;
}
/**
 * `VarianceScaling` subclass that accepts only an optional seed.
 *
 * NOTE(review): the scale, mode and distribution it passes to
 * `VarianceScaling` are fixed by the implementation and not visible in this
 * declaration file.
 */
export declare class LeCunUniform extends VarianceScaling {
    /** @nocollapse */
    static className: string;
    /**
     * @param args Optional configuration; `seed` is the only field accepted
     *     (see `SeedOnlyInitializerArgs`).
     */
    constructor(args?: SeedOnlyInitializerArgs);
    /** Returns the class name of this initializer. */
    getClassName(): string;
}
/**
 * Constructor arguments for the `Orthogonal` initializer. Inherits the
 * optional `seed` field from `SeedOnlyInitializerArgs`.
 */
export interface OrthogonalArgs extends SeedOnlyInitializerArgs {
    /**
     * Multiplicative factor to apply to the orthogonal matrix. Defaults to 1.
     */
    gain?: number;
}
/**
 * Initializer configured by an optional `gain` and an optional `seed` (see
 * `OrthogonalArgs`, which documents `gain` as a multiplicative factor applied
 * to the orthogonal matrix).
 */
export declare class Orthogonal extends Initializer {
    /** @nocollapse */
    static className: string;
    /** Constant 1, matching the default documented for `OrthogonalArgs.gain`. */
    readonly DEFAULT_GAIN = 1;
    /**
     * Constant 2000. NOTE(review): presumably the element count above which a
     * slowness warning is emitted — confirm against the implementation.
     */
    readonly ELEMENTS_WARN_SLOW = 2000;
    /** Gain in effect for this instance; readable by subclasses. */
    protected readonly gain: number;
    /** Seed in effect for this instance; readable by subclasses. */
    protected readonly seed: number;
    /**
     * @param args Optional configuration; both `gain` and `seed` are optional.
     */
    constructor(args?: OrthogonalArgs);
    /**
     * Generates the initial value.
     * @param shape Requested shape of the initial value.
     * @param dtype Optional data type of the initial value.
     */
    apply(shape: Shape, dtype?: DataType): Tensor;
    /** Returns this initializer's configuration as a `ConfigDict`. */
    getConfig(): serialization.ConfigDict;
}
/**
 * String identifier for an initializer.
 *
 * The literal members list the known identifiers, but the trailing `| string`
 * widens the union so that any string type-checks; the literals document the
 * expected values and do not restrict callers at compile time.
 *
 * @docinline
 */
export type InitializerIdentifier = 'constant' | 'glorotNormal' | 'glorotUniform' | 'heNormal' | 'heUniform' | 'identity' | 'leCunNormal' | 'leCunUniform' | 'ones' | 'orthogonal' | 'randomNormal' | 'randomUniform' | 'truncatedNormal' | 'varianceScaling' | 'zeros' | string;
/**
 * Map from `InitializerIdentifier` keys to string values.
 *
 * NOTE(review): per the name, the values are presumably the registered class
 * names used to look up initializers — confirm against the implementation.
 */
export declare const INITIALIZER_IDENTIFIER_REGISTRY_SYMBOL_MAP: {
    [identifier in InitializerIdentifier]: string;
};
/**
 * Converts an `Initializer` instance into a serializable config value.
 *
 * @param initializer The initializer to serialize.
 * @returns A `serialization.ConfigDictValue` representing the initializer.
 */
export declare function serializeInitializer(initializer: Initializer): serialization.ConfigDictValue;
/**
 * Resolves an `Initializer` from one of three input forms.
 *
 * NOTE(review): how each form is resolved (e.g. whether an `Initializer`
 * instance is returned unchanged, or what happens for an unknown identifier)
 * is not visible in this declaration file — confirm against the
 * implementation.
 *
 * @param identifier A string identifier, an existing `Initializer` instance,
 *     or a `serialization.ConfigDict` describing an initializer.
 * @returns The resolved `Initializer`.
 */
export declare function getInitializer(identifier: InitializerIdentifier | Initializer | serialization.ConfigDict): Initializer;
