/**
 * @license
 * Copyright 2023 Google LLC.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
/// <amd-module name="@tensorflow/tfjs-layers/dist/layers/nlp/models/gpt2/gpt2_causal_lm" />
/**
 * GPT2 Causal LM (Language Model).
 */
import { NamedTensorMap, Tensor, serialization } from '@tensorflow/tfjs-core';
import { GPT2Preprocessor } from './gpt2_preprocessor';
import { GenerativeTask } from '../generative_task';
import { GPT2Backbone } from './gpt2_backbone';
import { PipelineModelArgs } from '../../utils';
export declare interface GPT2CausalLMArgs extends PipelineModelArgs {
    /**
     * A `GPT2Backbone` instance.
     */
    backbone: GPT2Backbone;
    /**
     * Optional `GPT2Preprocessor` (for example a `GPT2CausalLMPreprocessor`).
     * If omitted (`undefined`), this model will not apply preprocessing, and
     * inputs should be preprocessed before calling the model.
     */
    preprocessor?: GPT2Preprocessor;
}
/**
 * An end-to-end GPT2 model for causal language modeling.
 *
 * A causal language model (LM) predicts the next token based on previous
 * tokens. This task setup can be used to train the model unsupervised on
 * plain text input, or to autoregressively generate plain text similar to
 * the data used for training. This task can be used for pre-training or
 * fine-tuning a GPT-2 model, simply by calling `fit()`.
 *
 * This model has a `generate()` method, which generates text based on a
 * prompt. The generation strategy used is controlled by an additional
 * `sampler` argument on `compile()`.
 * By default, top-k sampling is used.
 *
 * This model can optionally be configured with a `preprocessor` layer, in
 * which case it will automatically apply preprocessing to string inputs during
 * `fit()`, `predict()`, `evaluate()` and `generate()`. This is done by default
 * when creating the model with `fromPreset()`.
 *
 * Disclaimer: Pre-trained models are provided on an "as is" basis, without
 * warranties or conditions of any kind. The underlying model is provided by a
 * third party and subject to a separate license, available
 * [here](https://github.com/openai/gpt-2).
 *
 * Use `generate()` to do text generation.
 * ```js
 * const gpt2LM = GPT2CausalLM.fromPreset('gpt2_base_en');
 * gpt2LM.generate("I want to say", max_length=30);
 * // Generate with batched prompts.
 * gpt2LM.generate(["This is a", "Where are you"], max_length=30);
 * ```
 *
 * Use `generate()` without preprocessing.
 * ```js
 * // Prompt the model with `5338, 318` (the token ids for `"Who is"`).
 * // Use `"paddingMask"` to indicate values that should not be overridden.
 * const prompt = {
 *  tokenIds: tf.tensor([[5338, 318, 0, 0, 0], [5338, 318, 0, 0, 0]]),
 *  paddingMask: tf.tensor([[1, 1, 0, 0, 0], [1, 1, 0, 0, 0]]),
 * };
 * const gpt2LM = GPT2CausalLM.fromPreset('gpt2_base_en', null);
 * gpt2LM.generate(prompt);
 * ```
 *
 * Call `fit()` on a single batch.
 * ```js
 * const features = ['The quick brown fox jumped.', 'I forgot my homework.'];
 * const gpt2LM = GPT2CausalLM.fromPreset('gpt2_base_en');
 * gpt2LM.fit(features, {batchSize: 2});
 * ```
 *
 * Call `fit()` without preprocessing.
 * ```js
 * const x = {
 *  tokenIds: tf.tensor([[50256, 1, 2, 3, 4], [50256, 1, 2, 3, 4]]),
 *  paddingMask: tf.tensor([[1, 1, 1, 1, 1], [1, 1, 1, 1, 1]]),
 * };
 * const y = tf.tensor([[1, 2, 3, 4, 50256], [1, 2, 3, 4, 50256]]);
 * const sw = tf.tensor([[1, 1, 1, 1, 1], [1, 1, 1, 1, 1]]);
 * const gpt2LM = GPT2CausalLM.fromPreset('gpt2_base_en', null);
 * gpt2LM.fit(x, y, {sampleWeight: sw, batchSize: 2});
 * ```
 *
 * Custom backbone and vocabulary.
 * ```js
 * const features = ["a quick fox.", "a fox quick."];
 * const vocab = {"<|endoftext|>": 0, "a": 4, "Ġquick": 5, "Ġfox": 6};
 * const merges = [
 *  "Ġ q", "u i", "c k", "ui ck", "Ġq uick", "Ġ f", "o x", "Ġf ox"
 * ];
 * const tokenizer = new GPT2Tokenizer({vocabulary: vocab, merges});
 * const preprocessor = new GPT2CausalLMPreprocessor({
 *  tokenizer,
 *  sequenceLength: 128,
 * });
 * const backbone = new GPT2Backbone({
 *  vocabularySize: 30552,
 *  numLayers: 4,
 *  numHeads: 4,
 *  hiddenDim: 256,
 *  intermediateDim: 512,
 *  maxSequenceLength: 128,
 * });
 * const gpt2LM = new GPT2CausalLM({backbone, preprocessor});
 * gpt2LM.fit(features, {batchSize: 2});
 * ```
 */
export declare class GPT2CausalLM extends GenerativeTask {
    // Static class name for this task class. NOTE(review): presumably the key
    // used by tfjs serialization registration — confirm against the
    // implementation.
    /** @nocollapse */
    static className: string;
    /**
     * Creates a causal LM task from a backbone and an optional preprocessor.
     *
     * @param args `backbone` is required; `preprocessor` is optional (see
     *  `GPT2CausalLMArgs`).
     */
    constructor(args: GPT2CausalLMArgs);
    /**
     * Preset configurations for `cls`.
     *
     * NOTE(review): the declared return type is `{}`, so the shape of the
     * returned presets is not visible from this declaration — confirm against
     * the implementation.
     */
    static presets<T extends serialization.Serializable>(cls: serialization.SerializableConstructor<T>): {};
    /**
     * Forward pass of `GPT2CausalLM` with cache.
     *
     * `callWithCache` adds an additional forward pass for the model for
     * autoregressive inference. Unlike calling the model directly, this method
     * allows caching previous key/value Tensors in the multi-head attention
     * layers, and avoids recomputing the outputs of already-seen tokens.
     *
     * @param tokenIds a dense int Tensor with shape `[batchSize, maxLength]`.
     * @param cache a dense float Tensor, the cache of keys and values.
     * @param cacheUpdateIndex Integer. The index of the current inputs in the
     *  whole sequence.
     * @returns A tuple `[logits, hiddenStates, cache]`, where `logits` is the
     *  language model logits for the input tokenIds, `hiddenStates` is
     *  the final hidden representation of the input tokens, and `cache` is
     *  the decoding cache.
     */
    callWithCache(tokenIds: Tensor, cache: Tensor, cacheUpdateIndex: number): [Tensor, Tensor, Tensor];
    /**
     * Build an empty cache for use with `callWithCache()`.
     *
     * This member is private, so its signature is omitted from this
     * declaration.
     */
    private buildCache;
    /**
     * A compilable generation function for a single batch of inputs.
     *
     * This function represents the inner generation function for a single batch
     *  of inputs.
     *
     * @param inputs An object with two keys `tokenIds` and `paddingMask` and
     *  batched tensor values.
     * @param endTokenId The id of the end token to stop on. If all
     *  sequences have produced a new `endTokenId`, generation will stop.
     * @returns A `NamedTensorMap`. NOTE(review): presumably it has the same
     *  `tokenIds`/`paddingMask` keys, with generated tokens filled in — confirm
     *  against the implementation.
     */
    generateStep(inputs: NamedTensorMap, endTokenId: number): NamedTensorMap;
}
