使用推測解碼讓 Whisper 推理速度翻倍

釋出於 2023 年 12 月 20 日
在 GitHub 上更新
Open In Colab

OpenAI 的 Whisper 是一款通用語音轉錄模型,它在一系列不同的基準測試和音訊條件下均取得了最先進的成果。最新的 large-v3 模型在 OpenASR 排行榜上名列前茅,被評為英語領域最佳的開源語音轉錄模型。該模型還表現出強大的多語言效能,在 Common Voice 15 資料集測試的 58 種語言中,有 42 種語言的詞錯誤率 (WER) 低於 30%。

儘管轉錄準確率非常出色,但其推理速度卻很慢。即使利用了 Flash Attention、半精度和分塊 (chunking) 等推理最佳化技術,在 16GB T4 GPU 上轉錄 1 小時的音訊片段也需要 6 分鐘以上。

在這篇博文中,我們將演示如何利用推測解碼 (Speculative Decoding) 將 Whisper 的推理時間減少 2 倍,同時從數學上確保模型獲得完全相同的輸出。因此,該方法可以完美地替代現有的 Whisper 工作流,因為它在保持相同準確率的同時,免費提供了 2 倍的速度提升。如果想看一個解釋更少但包含所有程式碼的精簡版博文,請參閱配套的 Google Colab

推測解碼

推測解碼由 Google 的 Yaniv Leviathan 等人在論文 《Fast Inference from Transformers via Speculative Decoding》 中提出。其工作原理基於這樣一個前提:一個更快的**輔助模型**通常會生成與一個更大的**主模型**相同的詞元 (token)。

首先,輔助模型自迴歸地生成一個包含 N N 個**候選詞元**的序列,y^1:N \hat{\boldsymbol{y}}_{1:N} 。在下圖中,輔助模型生成了 5 個候選詞元組成的序列:The quick brown sock jumps

雖然這些候選詞元生成得很快,但它們可能與主模型預測的詞元不同。因此,在第二步中,這些候選詞元被傳遞給主模型進行“驗證”。主模型將候選詞元作為輸入,並執行**單次前向傳播**。主模型的輸出是詞元序列 y1:N \boldsymbol{y}_{1:N} 中每一步的“正確”詞元。

在上圖中,我們看到主模型預測的前三個詞元與輔助模型的預測一致:The quick brown。然而,輔助模型的第四個候選詞元 sock 與主模型的正確詞元 fox 不匹配。

我們知道,在第一個不匹配出現之前的所有候選詞元(The quick brown)都是正確的,因為它們與主模型的預測一致。然而,在第一個不匹配之後,候選詞元就與主模型預測的實際詞元產生了分歧。因此,我們可以用主模型的正確詞元(fox)替換掉第一個不正確的候選詞元(sock),並丟棄其後所有預測的詞元,因為它們已經偏離了。修正後的序列 The quick brown fox 現在成為輔助模型的新輸入。

然後推理過程重複,輔助模型生成一組新的 N N 個候選詞元,再由主模型透過一次前向傳播進行驗證。

由於我們使用快速的輔助模型進行自迴歸生成,而只用慢速的主模型進行驗證性的前向傳播,解碼過程得以大幅加速。此外,主模型執行的驗證性前向傳播確保了我們能獲得與單獨使用主模型時**完全相同的輸出**。這使得推測解碼成為現有 Whisper 工作流的完美替代品,因為可以確信能達到同樣的質量。

為了最大程度地減少延遲,輔助模型應該比主模型快得多,同時儘可能頻繁地預測出相同的詞元分佈。實際上,這兩個屬性之間存在一種權衡:模型越快,其準確性就越低。然而,由於 70-80% 的預測詞元往往是“較簡單”的詞元,這種權衡傾向於選擇更快的模型,而不是更準確的模型。因此,輔助模型應該比主模型快至少 3 倍(越快越好),同時能正確預測示例中所有“簡單”的詞元。剩下的 20-30% 更“困難”的詞元則可以由更大的主模型來驗證。

選擇輔助模型的唯一限制是它必須與主模型共享相同的詞彙表。也就是說,輔助模型必須使用與主模型完全相同的分詞器 (tokenizer)。因此,如果我們想對 Whisper 的多語言版本(例如 large-v2 (多語言))使用推測解碼,我們需要選擇一個 Whisper 的多語言版本作為輔助模型,例如 tiny。而如果我們想對 Whisper 的純英文版本(例如 medium.en)使用推測解碼,則需要一個純英文版本的輔助模型,例如 tiny.en。目前,Whisper large-v3 是一個例外,因為它是唯一一個詞彙表大小經過擴充套件的 Whisper 模型檢查點,因此與之前的 Whisper 模型檢查點不相容。

現在我們瞭解了推測解碼的背景知識,可以開始進行實際操作了。在 🤗 Transformers 庫中,推測解碼被實現為“輔助生成”(assisted generation) 推理策略。有關該實現的更多詳細資訊,建議讀者閱讀 Joao Gante 撰寫的關於輔助生成的精彩博文。

英語語音轉錄

基準實現

我們首先對 Whisper large-v2 進行基準測試,以獲取推理速度的基準資料。我們可以透過便捷的 AutoModelForSpeechSeq2SeqAutoProcessor 類來載入主模型及其對應的處理器。我們將以 float16 精度載入模型,並透過傳遞 low_cpu_mem_usage=True 來確保載入時間儘可能短。此外,我們希望確保模型以 safetensors 格式載入,因此傳遞 use_safetensors=True。最後,我們將傳遞引數 attn_implementation="sdpa",以透過 PyTorch 的 SDPA 注意力核來利用 Flash Attention 的加速效果。

import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor

device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

model_id = "openai/whisper-large-v2"

model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_id,
    torch_dtype=torch_dtype,
    low_cpu_mem_usage=True,
    use_safetensors=True,
    attn_implementation="sdpa",
)
model.to(device)

processor = AutoProcessor.from_pretrained(model_id)

讓我們載入用於基準測試的英語語音轉錄資料集。我們將從 LibriSpeech ASR validation-clean 資料集中載入一個包含 73 個樣本的小型資料集。這大約相當於 9MB 的資料,因此非常輕量級,可以快速下載到裝置上。

from datasets import load_dataset

dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")

對於基準測試,我們只想測量生成時間,所以讓我們編寫一個簡短的輔助函式來測量這一步。以下函式將返回解碼後的詞元和執行模型所花費的時間。

import time

def generate_with_time(model, inputs, **kwargs):
    start_time = time.time()
    outputs = model.generate(**inputs, **kwargs)
    generation_time = time.time() - start_time
    return outputs, generation_time

我們現在可以遍歷資料集中的音訊樣本,並累加總生成時間。

from tqdm import tqdm

all_time = 0
predictions = []
references = []

for sample in tqdm(dataset):
    audio = sample["audio"]
    inputs = processor(audio["array"], sampling_rate=audio["sampling_rate"], return_tensors="pt")
    inputs = inputs.to(device=device, dtype=torch.float16)
    
    output, gen_time = generate_with_time(model, inputs)
    all_time += gen_time
    predictions.append(processor.batch_decode(output, skip_special_tokens=True, normalize=True)[0])
    references.append(processor.tokenizer._normalize(sample["text"]))

print(all_time)

輸出

100%|██████████| 73/73 [01:37<00:00,  1.33s/it]
72.99542546272278

好的!我們看到轉錄 73 個樣本花費了 73 秒。現在來檢查一下預測結果的詞錯誤率 (WER)。

from evaluate import load

wer = load("wer")
print(wer.compute(predictions=predictions, references=references))

輸出

0.03507271171941831

我們的最終基準資料是 73 秒的執行時間和 3.5% 的詞錯誤率。

推測解碼

現在讓我們載入用於推測解碼的輔助模型。在這個例子中,我們將使用 Whisper 的一個蒸餾變體,distil-large-v2。這個蒸餾模型複製了 Whisper 的整個編碼器,但解碼器層數從 32 層減少到了 2 層。因此,它的執行速度比 Whisper 快 6 倍,而在非分佈測試集上的詞錯誤率 (WER) 僅相差 1% 以內。這使其成為輔助模型的完美選擇,因為它兼具高轉錄準確率和快速生成速度的優點1{}^1

由於 Distil-Whisper 使用與 Whisper 模型完全相同的編碼器,我們可以在主模型和輔助模型之間共享編碼器。這樣,我們只需要將 Distil-Whisper 的 2 層解碼器作為“僅解碼器”模型載入。我們可以透過便捷的 AutoModelForCausalLM auto 類來完成此操作。實際上,這隻會比單獨使用主模型增加 8% 的視訊記憶體佔用。

from transformers import AutoModelForCausalLM

assistant_model_id = "distil-whisper/distil-large-v2"

assistant_model = AutoModelForCausalLM.from_pretrained(
    assistant_model_id,
    torch_dtype=torch_dtype,
    low_cpu_mem_usage=True,
    use_safetensors=True,
    attn_implementation="sdpa",
)

assistant_model.to(device)

1{}^1 我們計劃釋出一個改進版的 Distil-Whisper,它在詞元分佈上具有更強的一致性,這將進一步提高推測解碼的效能。請關注 Distil-Whisper 程式碼庫以獲取更新。


我們可以為我們的推測解碼基準測試定義一個修改過的函式。與之前的函式唯一的區別是,我們在呼叫 .generate 時傳遞了輔助模型。

def assisted_generate_with_time(model, inputs, **kwargs):
    start_time = time.time()
    outputs = model.generate(**inputs, assistant_model=assistant_model, **kwargs)
    generation_time = time.time() - start_time
    return outputs, generation_time

讓我們使用 Distil-Whisper 作為 Whisper 的輔助模型來執行推測解碼的基準測試。

all_time = 0
predictions = []
references = []

for sample in tqdm(dataset):
    audio = sample["audio"]
    inputs = processor(audio["array"], sampling_rate=audio["sampling_rate"], return_tensors="pt")
    inputs = inputs.to(device=device, dtype=torch.float16)
    
    output, gen_time = assisted_generate_with_time(model, inputs)
    all_time += gen_time
    predictions.append(processor.batch_decode(output, skip_special_tokens=True, normalize=True)[0])
    references.append(processor.tokenizer._normalize(sample["text"]))

print(all_time)

輸出

100%|██████████| 73/73 [00:38<00:00,  1.88it/s]
32.69683289527893

使用推測解碼後,推理時間僅為 33 秒,比之前快了 2.2 倍!讓我們驗證一下詞錯誤率 (WER) 是否相同。

print(wer.compute(predictions=predictions, references=references))

輸出

0.03507271171941831

完美!詞錯誤率 (WER) 仍然是 3.5%,因為我們得到了與單獨使用主模型時完全相同的輸出。

推測解碼也可以與簡單易用的 🤗 Transformers pipeline API 一起使用進行推理。下面,我們使用模型和處理器例項化 pipeline,然後用它來轉錄玩具資料集中的第一個樣本。這種方法可以擴充套件到轉錄任意長度的音訊樣本,包括使用批處理。

from transformers import pipeline

pipe = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    max_new_tokens=128,
    chunk_length_s=15,
    batch_size=4,
    generate_kwargs={"assistant_model": assistant_model},
    torch_dtype=torch_dtype,
    device=device,
)

sample = dataset[0]["audio"]
result = pipe(sample)
print(result["text"])

輸出

 Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.

Distil-Whisper 模型卡上可以找到一個端到端的程式碼片段,用於執行 Whisper 和 Distil-Whisper 的推測解碼。它將本筆記本中涵蓋的推理階段整合到一個程式碼示例中。

多語言語音轉錄

Distil-Whisper 是英語語音轉錄的理想輔助模型,因為它在短音訊和長音訊樣本上的效能與原始 Whisper 模型的詞錯誤率 (WER) 相差不到 1%,而速度卻快了 6 倍。然而,官方的 Distil-Whisper 模型檢查點僅支援英語,這意味著它們不能用於多語言語音轉錄。

要將推測解碼用於多語言語音轉錄,可以使用官方的多語言 Whisper 模型檢查點之一,或者一個經過微調的 Whisper 變體。在撰寫本文時,Hugging Face Hub 上有超過 5000 個經過微調的 Whisper 模型檢查點,涵蓋 100 多種語言。這些為選擇在單一語言上表現出色的輔助 Whisper 模型檢查點提供了絕佳的起點。在本例中,我們將使用最小的官方多語言模型檢查點,即 Whisper tiny。歡迎你嘗試在你所用語言上微調過的不同模型檢查點!

讓我們載入新輔助模型 Whisper tiny 的權重。由於 Whisper tiny 中的編碼器與 large-v2 中的不同,這次我們將使用 AutoModelForSpeechSeq2Seq 類同時載入編碼器和解碼器。

assistant_model_id = "openai/whisper-tiny"

assistant_model = AutoModelForSpeechSeq2Seq.from_pretrained(
    assistant_model_id,
    torch_dtype=torch_dtype,
    low_cpu_mem_usage=True,
    use_safetensors=True,
    attn_implementation="sdpa",
)

assistant_model.to(device);

對於我們的基準測試資料集,我們將從 VoxPopuli 資料集的荷蘭語 ("nl") 部分載入 73 個樣本。

dataset = load_dataset("sanchit-gandhi/voxpopuli_dummy", "nl", split="validation")

太棒了!我們現在可以像之前一樣,為我們的基準 Whisper large-v2 模型重新執行基準測試。唯一的變化是,我們向 generate 函式傳遞了語言和任務引數,以確保我們執行的是語音轉錄(而不是語音翻譯)。推測解碼與語音轉錄和語音翻譯任務完全相容。只需根據需要設定任務引數即可,如下所示。

all_time = 0
predictions = []
references = []

for sample in tqdm(dataset):
    audio = sample["audio"]
    inputs = processor(audio["array"], sampling_rate=audio["sampling_rate"], return_tensors="pt")
    inputs = inputs.to(device=device, dtype=torch.float16)
    
    output, gen_time = generate_with_time(model, inputs, language="nl", task="transcribe")
    all_time += gen_time
    predictions.append(processor.batch_decode(output, skip_special_tokens=True, normalize=True)[0])
    references.append(processor.tokenizer._normalize(sample["normalized_text"]))

wer_result = wer.compute(predictions=predictions, references=references)

print("Time:", all_time)
print("WER:", wer_result)

輸出

100%|██████████| 73/73 [02:05<00:00,  1.72s/it]
Time: 116.50992178916931
WER: 0.127190136275146

好的!我們的基準時間是 117 秒,詞錯誤率 (WER) 為 12.8%。現在讓我們使用推測解碼重新執行生成過程。

all_time = 0
predictions = []
references = []

for sample in tqdm(dataset):
    audio = sample["audio"]
    inputs = processor(audio["array"], sampling_rate=audio["sampling_rate"], return_tensors="pt")
    inputs = inputs.to(device=device, dtype=torch.float16)

    output, gen_time = assisted_generate_with_time(model, inputs, language="nl", task="transcribe")
    all_time += gen_time
    predictions.append(processor.batch_decode(output, skip_special_tokens=True, normalize=True)[0])
    references.append(processor.tokenizer._normalize(sample["normalized_text"]))

wer_result = wer.compute(predictions=predictions, references=references)

print("Time:", all_time)
print("WER:", wer_result)

輸出

100%|██████████| 73/73 [01:08<00:00,  1.06it/s]
Time: 62.10229682922363
WER: 0.127190136275146

我們再次達到了 12.8% 的詞錯誤率 (WER),但這次推理時間僅為 62 秒,速度提升了 1.9 倍。考慮到載入輔助模型的開銷很低,並且在數學上能保證獲得完全相同的輸出,推測解碼為現有的 Whisper 工作流提供了一個完美的替代方案。

高效推測解碼策略

在最後一部分,我們將介紹兩種策略,以確保使用推測解碼時能獲得最快的推理速度。

輔助模型

我們的目標是選擇一個比主模型快至少 3 倍**並且**能正確轉錄至少 70-80% 預測詞元的輔助模型,這些詞元通常是示例中“較簡單”的詞元。如果你有特定語言的轉錄需求,一個有效的策略是訓練兩個不同大小的 Whisper 模型,並用一個作為另一個的輔助模型。

  • 首先,微調 Whisper large-v3 作為你的主模型。
  • 其次,在相同的資料集上蒸餾 Whisper large-v3,以作為快速的輔助模型。

微調和蒸餾可以提高主模型和輔助模型在你所選語言上的詞錯誤率 (WER) 效能,同時最大化詞元分佈的一致性。關於 Whisper 微調的完整指南可以在這裡找到,蒸餾的指南在這裡

批處理大小

值得注意的是,推測解碼在批處理大小為 1 時能獲得最大的速度提升。對於批處理的推測解碼,**批次中所有**候選詞元都必須與驗證詞元匹配,這些詞元才會被接受。如果批次中某個位置的詞元不一致,那麼該位置之後的所有候選詞元都將被丟棄。因此,推測解碼更適合較小的批處理大小。在實踐中,我們發現推測解碼在批處理大小達到 4 之前都能提供加速效果。當批處理大小超過 4 時,推測解碼的推理速度會比單獨使用主模型還要慢。完整結果請參考 Distil-Whisper 論文的 D.3 節。

結論

在這篇博文中,我們介紹了應用於 Whisper 模型進行語音轉錄的推測解碼推理策略。我們展示瞭如何實現 2 倍的速度提升,同時在數學上確保輸出與單獨使用原始模型完全相同。我們鼓勵你嘗試使用推測解碼作為現有 Whisper 工作流的直接替代方案,因為它使用額外輔助模型的開銷很低,並且能保證獲得相同的轉錄結果。

致謝

博文作者 Sanchit Gandhi。非常感謝 Patrick von PlatenPedro Cuenca 提出的建設性意見,以及 Joao Gante 在 🤗 Transformers 中實現的輔助生成功能。

社群

註冊登入以發表評論

© . This site is unofficial and not affiliated with Hugging Face, Inc.