Transformers.js 文件

utils/generation

您正在檢視的是需要從原始碼安裝。如果您想進行常規 npm 安裝,請檢視最新的穩定版本 (v3.0.0)。
Hugging Face's logo
加入 Hugging Face 社群

並獲得增強的文件體驗

開始使用

utils/generation

用於生成任務的類、函式和實用工具。

待辦

  • 描述如何建立自定義 GenerationConfig

utils/generation.LogitsProcessorList ⇐ <code> Callable </code>

一個表示 logits 處理器列表的類。logits 處理器是修改語言模型輸出 logits 的函式。該類提供了新增新處理器並將所有處理器應用於一批 logits 的方法。

型別utils/generation 的靜態類
繼承Callable


new LogitsProcessorList()

構造 LogitsProcessorList 的新例項。


logitsProcessorList.push(item)

向列表新增新的 logits 處理器。

型別LogitsProcessorList 的例項方法

引數量型別描述
itemLogitsProcessor

要新增的 logits 處理器函式。


logitsProcessorList.extend(items)

向列表新增多個 logits 處理器。

型別LogitsProcessorList 的例項方法

引數量型別描述
itemsArray.<LogitsProcessor>

要新增的 logits 處理器函式。


logitsProcessorList._call(input_ids, batchedLogits)

將列表中所有的 logits 處理器應用於一批 logits,並就地修改它們。

型別LogitsProcessorList 的例項方法

引數量型別描述
input_idsArray.<number>

語言模型的輸入 ID。

batchedLogitsArray.<Array<number>>

一個二維 logits 陣列,其中每行對應於批處理中的單個輸入序列。


utils/generation.LogitsProcessor ⇐ <code> Callable </code>

logits 處理器的基類。

型別utils/generation 的靜態類
繼承Callable


logitsProcessor._call(input_ids, logits)

將處理器應用於輸入 logits。

型別LogitsProcessor 的例項抽象方法
丟擲:

  • Error 如果子類未實現 `_call`,則丟擲錯誤。
引數量型別描述
input_ids陣列

輸入 ID。

logits張量

要處理的 logits。


utils/generation.ForceTokensLogitsProcessor ⇐ <code> LogitsProcessor </code>

一個強制解碼器生成特定 token 的 logits 處理器。

型別utils/generation 的靜態類
繼承LogitsProcessor


new ForceTokensLogitsProcessor(forced_decoder_ids)

構造 ForceTokensLogitsProcessor 的新例項。

引數量型別描述
forced_decoder_ids陣列

要強制生成的 token ID。


forceTokensLogitsProcessor._call(input_ids, logits) ⇒ <code> Tensor </code>

將處理器應用於輸入 logits。

型別ForceTokensLogitsProcessor 的例項方法
返回Tensor - 處理後的 logits。

引數量型別描述
input_ids陣列

輸入 ID。

logits張量

要處理的 logits。


utils/generation.ForcedBOSTokenLogitsProcessor ⇐ <code> LogitsProcessor </code>

一個 LogitsProcessor,它強制在生成的序列開頭新增 BOS token。

型別utils/generation 的靜態類
繼承LogitsProcessor


new ForcedBOSTokenLogitsProcessor(bos_token_id)

建立一個 ForcedBOSTokenLogitsProcessor。

引數量型別描述
bos_token_id數字

要強制使用的序列開始 token 的 ID。


forcedBOSTokenLogitsProcessor._call(input_ids, logits) ⇒ <code> Object </code>

將 BOS token 強制應用於 logits。

型別ForcedBOSTokenLogitsProcessor 的例項方法
返回Object - 強制使用 BOS token 的 logits。

引數量型別描述
input_ids陣列

輸入 ID。

logitsObject

logits。


utils/generation.ForcedEOSTokenLogitsProcessor ⇐ <code> LogitsProcessor </code>

一個 logits 處理器,它將序列結束 token 的機率強制設定為 1。

型別utils/generation 的靜態類
繼承LogitsProcessor


new ForcedEOSTokenLogitsProcessor(max_length, forced_eos_token_id)

建立一個 ForcedEOSTokenLogitsProcessor。

引數量型別描述
max_length數字

序列的最大長度。

forced_eos_token_idnumber | Array<number>

要強制使用的序列結束 token 的 ID。


forcedEOSTokenLogitsProcessor._call(input_ids, logits)

將處理器應用於 input_ids 和 logits。

型別ForcedEOSTokenLogitsProcessor 的例項方法

引數量型別描述
input_idsArray.<number>

輸入 ID。

logits張量

logits 張量。


utils/generation.SuppressTokensAtBeginLogitsProcessor ⇐ <code> LogitsProcessor </code>

一個 LogitsProcessor,它在 generate 函式開始使用 begin_index 個 token 生成時立即抑制 token 列表。這應確保在生成開始時不會取樣由 begin_suppress_tokens 定義的 token。

型別utils/generation 的靜態類
繼承LogitsProcessor


new SuppressTokensAtBeginLogitsProcessor(begin_suppress_tokens, begin_index)

建立一個 SuppressTokensAtBeginLogitsProcessor。

引數量型別描述
begin_suppress_tokensArray.<number>

要抑制的 token ID。

begin_index數字

在抑制 token 之前要生成的 token 數量。


suppressTokensAtBeginLogitsProcessor._call(input_ids, logits) ⇒ <code> Object </code>

將 BOS token 強制應用於 logits。

型別SuppressTokensAtBeginLogitsProcessor 的例項方法
返回Object - 強制使用 BOS token 的 logits。

引數量型別描述
input_ids陣列

輸入 ID。

logitsObject

logits。


utils/generation.WhisperTimeStampLogitsProcessor ⇐ <code> LogitsProcessor </code>

一個 LogitsProcessor,用於處理向生成的文字新增時間戳。

型別utils/generation 的靜態類
繼承LogitsProcessor


new WhisperTimeStampLogitsProcessor(generate_config)

構造一個新的 WhisperTimeStampLogitsProcessor。

引數量型別描述
generate_configObject

傳遞給 transformer 模型 generate() 方法的配置物件。

generate_config.eos_token_id數字

序列結束 token 的 ID。

generate_config.no_timestamps_token_id數字

用於指示 token 不應包含時間戳的 token ID。

[generate_config.forced_decoder_ids]Array.<Array<number>>

一個由兩個元素組成的陣列,表示強制出現在輸出中的解碼器 ID。每個陣列的第二個元素指示該 token 是否為時間戳。

[generate_config.max_initial_timestamp_index]數字

初始時間戳可以出現的最大索引。


whisperTimeStampLogitsProcessor._call(input_ids, logits) ⇒ <code> Tensor </code>

修改 logits 以處理時間戳 token。

型別WhisperTimeStampLogitsProcessor 的例項方法
返回Tensor - 修改後的 logits。

引數量型別描述
input_ids陣列

輸入 token 序列。

logits張量

模型輸出的 logits。


utils/generation.NoRepeatNGramLogitsProcessor ⇐ <code> LogitsProcessor </code>

一個 logits 處理器,它不允許重複一定大小的 n-gram。

型別utils/generation 的靜態類
繼承LogitsProcessor


new NoRepeatNGramLogitsProcessor(no_repeat_ngram_size)

建立一個 NoRepeatNGramLogitsProcessor。

引數量型別描述
no_repeat_ngram_size數字

不重複 n-gram 的大小。所有此大小的 n-gram 只能出現一次。


noRepeatNGramLogitsProcessor.getNgrams(prevInputIds) ⇒ <code> Map. < string, Array < number > > </code>

從 token ID 序列生成 n-gram。

型別NoRepeatNGramLogitsProcessor 的例項方法
返回Map.<string, Array<number>> - 生成的 n-gram 的對映

引數量型別描述
prevInputIdsArray.<number>

上一個輸入 ID 列表


noRepeatNGramLogitsProcessor.getGeneratedNgrams(bannedNgrams, prevInputIds) ⇒ <code> Array. < number > </code>

從 token ID 序列生成 n-gram。

型別NoRepeatNGramLogitsProcessor 的例項方法
返回Array.<number> - 生成的 n-gram 的對映

引數量型別描述
bannedNgramsMap.<string, Array<number>>

停用 n-gram 的對映

prevInputIdsArray.<number>

上一個輸入 ID 列表


noRepeatNGramLogitsProcessor.calcBannedNgramTokens(prevInputIds) ⇒ <code> Array. < number > </code>

計算禁止的 n-gram token

型別NoRepeatNGramLogitsProcessor 的例項方法
返回Array.<number> - 生成的 n-gram 的對映

引數量型別描述
prevInputIdsArray.<number>

上一個輸入 ID 列表


noRepeatNGramLogitsProcessor._call(input_ids, logits) ⇒ <code> Object </code>

將不重複 n-gram 處理器應用於 logits。

型別NoRepeatNGramLogitsProcessor 的例項方法
返回Object - 經過不重複 n-gram 處理的 logits。

引數量型別描述
input_ids陣列

輸入 ID。

logitsObject

logits。


utils/generation.RepetitionPenaltyLogitsProcessor ⇐ <code> LogitsProcessor </code>

一個對重複輸出 token 進行懲罰的 logits 處理器。

型別utils/generation 的靜態類
繼承LogitsProcessor


new RepetitionPenaltyLogitsProcessor(penalty)

建立一個 RepetitionPenaltyLogitsProcessor。

引數量型別描述
penalty數字

對重複 token 應用的懲罰。


repetitionPenaltyLogitsProcessor._call(input_ids, logits) ⇒ <code> Object </code>

將重複懲罰應用於 logits。

型別RepetitionPenaltyLogitsProcessor 的例項方法
返回Object - 經過重複懲罰處理的 logits。

引數量型別描述
input_ids陣列

輸入 ID。

logitsObject

logits。


utils/generation.MinLengthLogitsProcessor ⇐ <code> LogitsProcessor </code>

一個強制最小 token 數量的 logits 處理器。

型別utils/generation 的靜態類
繼承LogitsProcessor


new MinLengthLogitsProcessor(min_length, eos_token_id)

建立一個 MinLengthLogitsProcessor。

引數量型別描述
min_length數字

當長度低於此值時,eos_token_id 的分數將設定為負無窮大。

eos_token_idnumber | Array<number>

序列結束 token 的 ID。


minLengthLogitsProcessor._call(input_ids, logits) ⇒ <code> Object </code>

應用 logits 處理器。

型別MinLengthLogitsProcessor 的例項方法
返回Object - 處理後的 logits。

引數量型別描述
input_ids陣列

輸入 ID。

logitsObject

logits。


utils/generation.MinNewTokensLengthLogitsProcessor ⇐ <code> LogitsProcessor </code>

一個強制最小新 token 數量的 logits 處理器。

型別utils/generation 的靜態類
繼承LogitsProcessor


new MinNewTokensLengthLogitsProcessor(prompt_length_to_skip, min_new_tokens, eos_token_id)

建立一個 MinNewTokensLengthLogitsProcessor。

引數量型別描述
prompt_length_to_skip數字

輸入 token 長度。

min_new_tokens數字

當新 token 長度低於此值時,eos_token_id 的分數將設定為負無窮大。

eos_token_idnumber | Array<number>

序列結束 token 的 ID。


minNewTokensLengthLogitsProcessor._call(input_ids, logits) ⇒ <code> Object </code>

應用 logits 處理器。

型別MinNewTokensLengthLogitsProcessor 的例項方法
返回Object - 處理後的 logits。

引數量型別描述
input_ids陣列

輸入 ID。

logitsObject

logits。


utils/generation.NoBadWordsLogitsProcessor

型別utils/generation 的靜態類


new NoBadWordsLogitsProcessor(bad_words_ids, eos_token_id)

建立一個 NoBadWordsLogitsProcessor

引數量型別描述
bad_words_idsArray.<Array<number>>

不允許生成的 token ID 列表的列表。

eos_token_idnumber | Array<number>

“序列結束”標記的 ID。可選地,使用列表來設定多個“序列結束”標記。


noBadWordsLogitsProcessor._call(input_ids, logits) ⇒ <code> Object </code>

應用 logits 處理器。

型別NoBadWordsLogitsProcessor 的例項方法
返回Object - 處理後的 logits。

引數量型別描述
input_ids陣列

輸入 ID。

logitsObject

logits。


utils/generation.Sampler

Sampler 是所有用於文字生成的取樣方法的基類。

型別utils/generation 的靜態類


new Sampler(generation_config)

建立具有指定生成配置的新 Sampler 物件。

引數量型別描述
generation_configGenerationConfigType

生成配置。


sampler._call(logits, index) ⇒ <code> void </code>

執行取樣器,使用指定的 logits。

型別Sampler 的例項方法

引數量型別
logits張量
索引數字

sampler.sample(logits, index)

用於取樣 logits 的抽象方法。

型別Sampler 的例項方法
丟擲:

  • 錯誤
引數量型別
logits張量
索引數字

sampler.getLogits(logits, index) ⇒ <code> Float32Array </code>

將指定的 logits 作為陣列返回,並應用了溫度。

型別Sampler 的例項方法

引數量型別
logits張量
索引數字

sampler.randomSelect(probabilities) ⇒ <code> number </code>

根據指定的機率隨機選擇一個專案。

型別Sampler 的例項方法
返回number - 所選專案的索引。

引數量型別描述
probabilities陣列

用於選擇的機率陣列。


Sampler.getSampler(generation_config) ⇒ <code> Sampler </code>

根據指定選項返回一個 Sampler 物件。

型別Sampler 的靜態方法Sampler
返回Sampler - 一個 Sampler 物件。

引數量型別描述
generation_configGenerationConfigType

包含取樣器選項的物件。


utils/generation.GenerationConfig : <code> * </code>

儲存生成任務配置的類。

型別utils/generation 的靜態常量utils/generation


utils/generation~GenerationConfig

型別utils/generation 的內部類utils/generation


new GenerationConfig(kwargs)

建立一個新的 GenerationConfig 物件。

引數量型別
kwargsGenerationConfigType

utils/generation~GreedySampler ⇐ <code> Sampler </code>

表示貪婪取樣器的類。

型別utils/generation 的內部類utils/generation
繼承Sampler


greedySampler.sample(logits, [index]) ⇒ <code> Array </code>

對給定 logits 張量的最大機率進行取樣。

型別GreedySampler 的例項方法GreedySampler
返回Array - 包含一個元組的陣列,該元組包含最大值的索引和無意義的分數(因為這是貪婪搜尋)。

引數量型別預設
logits張量
[index]數字-1

utils/generation~MultinomialSampler ⇐ <code> Sampler </code>

表示多項式取樣器的類。

型別utils/generation 的內部類utils/generation
繼承Sampler


multinomialSampler.sample(logits, index) ⇒ <code> Array </code>

從 logits 中取樣。

型別MultinomialSampler 的例項方法MultinomialSampler

引數量型別
logits張量
索引數字

utils/generation~BeamSearchSampler ⇐ <code> Sampler </code>

表示 BeamSearchSampler 的類。

型別utils/generation 的內部類utils/generation
繼承Sampler


beamSearchSampler.sample(logits, index) ⇒ <code> Array </code>

從 logits 中取樣。

型別BeamSearchSampler 的例項方法BeamSearchSampler

引數量型別
logits張量
索引數字

utils/generation~GenerationConfigType : <code> Object </code>

預設配置引數。

型別utils/generation 的內部型別定義utils/generation
屬性

名稱型別預設描述
[max_length]數字20

生成的令牌可以具有的最大長度。對應於輸入提示的長度 + max_new_tokens。如果也設定了 max_new_tokens,其效果將被覆蓋。

[max_new_tokens]數字

要生成的最大令牌數,忽略提示中的令牌數。

[min_length]數字0

要生成的序列的最小長度。對應於輸入提示的長度 + min_new_tokens。如果也設定了 min_new_tokens,其效果將被覆蓋。

[min_new_tokens]數字

要生成的最小令牌數,忽略提示中的令牌數。

[early_stopping]boolean | "never"false

控制基於束的方法(如束搜尋)的停止條件。它接受以下值

  • true,表示一旦有 num_beams 個完整候選,生成即停止;
  • false,表示應用啟發式方法,當不太可能找到更好的候選時,生成停止;
  • "never",表示束搜尋過程只有在無法找到更好候選時才停止(經典束搜尋演算法)。
[max_time]數字

允許計算執行的最大時間(秒)。生成將在分配時間過後完成當前輪次。

[do_sample]booleanfalse

是否使用取樣;否則使用貪婪解碼。

[num_beams]數字1

束搜尋的束數。1 表示沒有束搜尋。

[num_beam_groups]數字1

num_beams 分成組的數量,以確保不同束組之間的多樣性。有關更多詳細資訊,請參閱本文

[penalty_alpha]數字

這些值平衡了模型置信度和對比搜尋解碼中的退化懲罰。

[use_cache]booleantrue

模型是否應使用過去的最後鍵/值注意力(如果適用於模型)以加快解碼速度。

[temperature]數字1.0

用於調節下一個 token 機率的值。

[top_k]數字50

保留用於 top-k 過濾的最高機率詞彙 token 數量。

[top_p]數字1.0

如果設定為小於 1 的浮點數,則僅保留機率總和達到 top_p 或更高的最小機率令牌集進行生成。

[typical_p]數字1.0

區域性典型性衡量預測下一個目標令牌的條件機率與預測下一個隨機令牌的預期條件機率的相似程度,給定已生成的文字片段。如果設定為小於 1 的浮點數,則保留區域性典型性最高且機率總和達到 typical_p 或更高的最小令牌集進行生成。有關更多詳細資訊,請參閱本文

[epsilon_cutoff]數字0.0

如果設定為嚴格介於 0 和 1 之間的浮點數,則只採樣條件機率大於 epsilon_cutoff 的令牌。在論文中,建議的值範圍從 3e-4 到 9e-4,具體取決於模型的規模。有關更多詳細資訊,請參閱《截斷取樣作為語言模型去平滑》

[eta_cutoff]數字0.0

Eta 取樣是區域性典型性取樣和 epsilon 取樣的混合。如果設定為嚴格介於 0 和 1 之間的浮點數,只有當令牌大於 eta_cutoffsqrt(eta_cutoff) * exp(-entropy(softmax(next_token_logits))) 時才考慮該令牌。後者直觀地是預期的下一個令牌機率,由 sqrt(eta_cutoff) 進行縮放。在論文中,建議的值範圍從 3e-4 到 2e-3,具體取決於模型的規模。有關更多詳細資訊,請參閱《截斷取樣作為語言模型去平滑》

[diversity_penalty]數字0.0

如果一個束在某個特定時間生成的令牌與來自其他組的任何束相同,則從該束的分數中減去此值。請注意,diversity_penalty 僅在啟用 group beam search 時才有效。

[repetition_penalty]數字1.0

重複懲罰的引數。1.0 表示沒有懲罰。有關更多詳細資訊,請參閱本文

[encoder_repetition_penalty]數字1.0

encoder_repetition_penalty 的引數。對原始輸入中不存在的序列施加指數懲罰。1.0 表示沒有懲罰。

[length_penalty]數字1.0

對基於束的生成中使用的長度施加指數懲罰。它作為序列長度的指數應用,然後用於除以序列的分數。由於分數是序列的對數似然(即負數),length_penalty > 0.0 促進較長的序列,而 length_penalty < 0.0 鼓勵較短的序列。

[no_repeat_ngram_size]數字0

如果設定為大於 0 的整數,則該大小的所有 ngrams 只能出現一次。

[bad_words_ids]Array.<Array<number>>

不允許生成的令牌 ID 列表。要獲取不應出現在生成文字中的單詞的令牌 ID,請使用 (await tokenizer(bad_words, {add_prefix_space: true, add_special_tokens: false})).input_ids

[force_words_ids]Array<Array<number>> | Array<Array<Array<number>>>

必須生成的令牌 ID 列表。如果給定 number[][],則將其視為必須包含的簡單單詞列表,與 bad_words_ids 相反。如果給定 number[][][],則觸發不相容約束,允許每種單詞的不同形式。

[renormalize_logits]booleanfalse

在應用所有 logits 處理器或扭曲器(包括自定義的)之後是否重新歸一化 logits。強烈建議將此標誌設定為 true,因為搜尋演算法假定分數 logits 已歸一化,但某些 logit 處理器或扭曲器會破壞歸一化。

[constraints]Array.<Object>

可以新增到生成中的自定義約束,以確保輸出將以最合理的方式包含由 Constraint 物件定義的某些令牌的使用。

[forced_bos_token_id]數字

decoder_start_token_id 之後強制作為第一個生成的令牌的令牌 ID。對於像 mBART 這樣的多語言模型很有用,其中第一個生成的令牌需要是目標語言令牌。

[forced_eos_token_id]number | Array<number>

當達到 max_length 時強制作為最後一個生成的令牌的令牌 ID。可選地,使用列表設定多個*序列結束*令牌。

[remove_invalid_values]booleanfalse

是否刪除模型可能產生的*NaN*和*inf*輸出,以防止生成方法崩潰。請注意,使用 remove_invalid_values 可能會減慢生成速度。

[exponential_decay_length_penalty]Array.<number>

此元組在生成一定數量的令牌後新增一個指數增長的長度懲罰。該元組應包含:(start_index, decay_factor),其中 start_index 表示懲罰開始的位置,decay_factor 表示指數衰減的因子。

[suppress_tokens]Array.<number>

在生成時將被抑制的令牌列表。SupressTokens logit 處理器將它們的對數機率設定為 -inf,以便它們不被取樣。

[begin_suppress_tokens]Array.<number>

在生成開始時將被抑制的令牌列表。SupressBeginTokens logit 處理器將它們的對數機率設定為 -inf,以便它們不被取樣。

[forced_decoder_ids]Array.<Array<number>>

整數對的列表,表示在取樣之前將被強制的生成索引到令牌索引的對映。例如,[[1, 123]] 表示第二個生成的令牌將始終是索引為 123 的令牌。

[num_return_sequences]數字1

批處理中每個元素獨立計算的返回序列數。

[output_attentions]booleanfalse

是否返回所有注意力層的注意力張量。有關更多詳細資訊,請參閱返回張量下的 attentions

[output_hidden_states]booleanfalse

是否返回所有層的隱藏狀態。有關更多詳細資訊,請參閱返回張量下的 hidden_states

[output_scores]booleanfalse

是否返回預測分數。有關更多詳細資訊,請參閱返回張量下的 scores

[return_dict_in_generate]booleanfalse

是否返回 ModelOutput 而不是普通元組。

[pad_token_id]數字

*填充*令牌的 ID。

[bos_token_id]數字

*序列開始*令牌的 ID。

[eos_token_id]number | Array<number>

“序列結束”標記的 ID。可選地,使用列表來設定多個“序列結束”標記。

[encoder_no_repeat_ngram_size]數字0

如果設定為大於 0 的整數,則 encoder_input_ids 中出現的所有相同大小的 ngrams 不能在 decoder_input_ids 中出現。

[decoder_start_token_id]數字

如果編碼器-解碼器模型使用與*bos*不同的令牌開始解碼,則該令牌的 ID。

[generation_kwargs]Object{}

額外的生成 kwargs 將轉發到模型的 generate 函式。不在 generate 簽名中的 kwargs 將在模型的前向傳播中使用。


< > 在 GitHub 上更新

© . This site is unofficial and not affiliated with Hugging Face, Inc.