使用 🤗 Transformers 微調 W2V2-Bert 以支援低資源 ASR

釋出於 2024 年 1 月 19 日
在 GitHub 上更新
Open In Colab

最新訊息 (2024年1月)本博文深受《在多語言 ASR 上微調 XLS-R》和《為多語言 ASR 微調 MMS 介面卡模型》的啟發

引言

上個月,MetaAI 釋出了 Wav2Vec2-BERT,作為其 Seamless Communication(一個AI翻譯模型家族)的構建模組。

Wav2Vec2-BERT 是一系列改進的成果,其基礎是原始模型:Wav2Vec2,這是一個用於自動語音識別 (ASR) 的預訓練模型,由 Alexei Baevski、Michael Auli 和 Alex Conneau2020年9月釋出。僅需 10 分鐘的標註音訊資料,Wav2Vec2 就可以透過微調在 LibriSpeech 資料集上達到 5% 的詞錯誤率,首次展示了 ASR 的低資源遷移學習能力。

經過一系列多語言改進(XLSRXLS-RMMS),Wav2Vec2-BERT 是一個擁有 5.8 億引數的多功能音訊模型,它在覆蓋超過 143 種語言450 萬小時無標籤音訊資料上進行了預訓練。相比之下,XLS-R 使用了近 50 萬小時的 128 種語言的音訊資料,而 MMS 檢查點則在超過 1400 種語言50 多萬小時音訊上進行了預訓練。將資料量提升至數百萬小時,使得 Wav2Vec2-BERT 能夠在與語音相關的任務中,無論何種語言,都能取得更具競爭力的結果。

為了將其用於 ASR,Wav2Vec2-BERT 可以使用連線主義時間分類 (Connectionist Temporal Classification, CTC) 進行微調。CTC 是一種用於訓練序列到序列問題(如 ASR 和手寫識別)神經網路的演算法。我們強烈推薦閱讀 Awni Hannun 撰寫的優秀博文 Sequence Modeling with CTC (2017),以深入瞭解 CTC 演算法。

本 notebook 的目的是為您提供訓練 Wav2Vec2-BERT 模型——更具體地說是預訓練檢查點 facebook/w2v-bert-2.0——在 ASR 任務上所需的所有要素,全部使用開源工具和模型。它首先介紹了完整的預處理流程,然後對 W2V2-BERT 進行了少量微調。最後一部分彙集了 Hugging Face 專家關於擴充套件 CTC 訓練的技巧。

出於演示目的,我們在 Common Voice 16.0 的低資源蒙古語 ASR 資料集上對模型進行微調,該資料集包含約 14 小時的已驗證訓練資料。

動機

Whisper 是一套 ASR 模型,被公認為 ASR 任務中表現最佳的模型。它在英語 ASR 方面提供了最先進的效能,同時也非常適合利用有限資源進行多語言微調。

然而,當涉及到像蒙古語這樣的“資源貧乏”語言時,Whisper 的表現不佳,正如 Whisper 論文的 D.2.2 節所示——蒙古語或馬拉雅拉姆語在每個 Whisper 檢查點上的 WER 都超過了 100%。可用的檢查點詞彙量也有限,因此無法在字母表與該詞彙表不重疊的語言上進行微調。

此外,Whisper 是一個序列到序列模型,它以自迴歸方式執行 ASR,這使其天生就“慢”。對於在訓練資料集中不常見的語言,Whisper 的緩慢問題會更加嚴重。在這種情況下,Whisper 平均每個單詞需要生成更多的 token,因此耗時更長。

面對有限的資源——無論是訓練資料可用性還是推理限制——需要更“節儉”的模型。在這種情況下,Wav2Vec2-BERT 恰好滿足了這一需求。

Wav2Vec2-BERT 透過單次前向傳播預測 ASR,使其比 Whisper 快得多。正如本 notebook 將展示的,它需要少量資料即可達到有競爭力的效能易於適應任何字母表,並且資源效率更高

事實上,在經過類似的微調後,它在蒙古語 ASR 上的 WER 效能與 Whisper-large-v3 相當,同時速度快 10 到 30 倍以上,資源效率高 2.5 倍

注意:基準測試是在 Google Colab 上的 16GB V100 GPU 上進行的,在蒙古語 CV16 測試集上使用的批大小從 1 到 8 不等。

Notebook 設定

開始之前,我們先安裝 datasetstransformers。此外,我們還需要 accelerate 用於訓練,torchaudio 用於載入音訊檔案,以及 jiwer 用於使用詞錯誤率 (WER) 指標來評估我們微調後的模型。

%%capture
!pip install datasets
!pip install --upgrade transformers
!pip install torchaudio
!pip install jiwer
!pip install accelerate -U

我們強烈建議在訓練過程中將您的訓練檢查點直接上傳到 🤗 Hub🤗 Hub 提供:

  • 整合的版本控制:你可以確保在訓練過程中不會丟失任何模型檢查點。
  • Tensorboard 日誌:在訓練過程中跟蹤重要指標。
  • 模型卡片:記錄模型的功能及其預期用途。
  • 社群:一種與社群分享和協作的簡便方式!

為此,您需要從 Hugging Face 網站儲存您的身份驗證令牌(如果還沒有,請在此處註冊!)。在下方提示時輸入您的 Hub 身份驗證令牌即可。在此處查詢您的 Hub 身份驗證令牌

from huggingface_hub import notebook_login

notebook_login()

準備資料、分詞器、特徵提取器

ASR 模型將語音轉寫為文字,這意味著我們既需要一個將語音訊號處理成模型輸入格式(例如特徵向量)的特徵提取器,也需要一個將模型輸出格式處理成文字的分詞器。

在 🤗 Transformers 中,Wav2Vec2-BERT 模型因此配備了一個分詞器,名為 Wav2Vec2CTCTokenizer,和一個特徵提取器,名為 SeamlessM4TFeatureExtractor。該特徵提取器與 第一版第二版的 Seamless-M4T 共享,因為它們都以相同的方式處理音訊。

讓我們從建立分詞器開始,用它來將預測的輸出類別解碼為輸出轉寫文字。

建立 Wav2Vec2CTCTokenizer

請記住,在 CTC 上微調的 Wav2Vec2-like 模型透過單次前向傳播來轉寫音訊檔案,首先將音訊輸入處理成一系列處理過的上下文表示,然後使用最終的詞彙表輸出層將每個上下文表示分類為一個代表轉寫文字的字元。

該層的輸出大小對應於詞彙表中的 token 數量,因此只取決於用於微調的帶標籤資料集。所以第一步,我們將檢視選定的 Common Voice 資料集,並根據轉寫文字定義一個詞彙表。

對於本 notebook,我們將使用 Common Voice 16.0 資料集的蒙古語部分。蒙古語對應的語言程式碼是 "mn"

現在我們可以使用 🤗 Datasets 的簡單 API 來下載資料。資料集名稱是 "mozilla-foundation/common_voice_16_0",配置名稱對應於語言程式碼,在我們的例子中是 "mn"

注意:在能夠下載資料集之前,您必須先登入您的 Hugging Face 賬戶,訪問資料集倉庫頁面,然後點選“同意並訪問倉庫”來獲取許可權。

Common Voice 有許多不同的資料劃分,包括 invalidated,指的是那些被評為不夠“乾淨”而無法使用的資料。在本 notebook 中,我們只使用 "train""validation""test" 這幾個劃分。

因為蒙古語資料集非常小,我們將把驗證集和訓練集合併為一個訓練集,並只使用測試集進行驗證。

from datasets import load_dataset, load_metric, Audio

common_voice_train = load_dataset("mozilla-foundation/common_voice_16_0", "mn", split="train+validation", use_auth_token=True)
common_voice_test = load_dataset("mozilla-foundation/common_voice_16_0", "mn", split="test", use_auth_token=True)

許多 ASR 資料集僅為每個音訊陣列 'audio' 和檔案 'path' 提供目標文字 'sentence'。Common Voice 實際上提供了關於每個音訊檔案的更多資訊,例如 'accent' 等。為了使 notebook 儘可能通用,我們只考慮轉寫文字進行微調。

common_voice_train = common_voice_train.remove_columns(["accent", "age", "client_id", "down_votes", "gender", "locale", "segment", "up_votes"])
common_voice_test = common_voice_test.remove_columns(["accent", "age", "client_id", "down_votes", "gender", "locale", "segment", "up_votes"])

讓我們寫一個簡短的函式來顯示資料集的一些隨機樣本,並執行幾次以感受一下轉寫文字的特點。

from datasets import ClassLabel
import random
import pandas as pd
from IPython.display import display, HTML

def show_random_elements(dataset, num_examples=10):
    assert num_examples <= len(dataset), "Can't pick more elements than there are in the dataset."
    picks = []
    for _ in range(num_examples):
        pick = random.randint(0, len(dataset)-1)
        while pick in picks:
            pick = random.randint(0, len(dataset)-1)
        picks.append(pick)

    df = pd.DataFrame(dataset[picks])
    display(HTML(df.to_html()))

show_random_elements(common_voice_train.remove_columns(["path", "audio"]), num_examples=10)

好的!轉寫文字看起來相當乾淨。翻譯了這些轉寫句子後,似乎語言更像是書面文字,而不是嘈雜的對話。考慮到 Common Voice 是一個眾包的朗讀語音語料庫,這很合理。

我們可以看到轉寫文字包含一些特殊字元,比如 ,.?!;:。在沒有語言模型的情況下,將語音塊分類為這些特殊字元要困難得多,因為它們並不真正對應一個特徵性的聲音單元。例如,字母 "s" 有一個或多或少清晰的發音,而特殊字元 "." 則沒有。此外,為了理解語音訊號的含義,通常沒有必要在轉寫中包含特殊字元。

讓我們簡單地移除所有對詞義沒有貢獻且無法真正用聲音表示的字元,並對文字進行規範化。

import re
chars_to_remove_regex = '[\,\?\.\!\-\;\:\"\“\%\‘\”\�\'\»\«]'

def remove_special_characters(batch):
    # remove special characters
    batch["sentence"] = re.sub(chars_to_remove_regex, '', batch["sentence"]).lower()

    return batch

common_voice_train = common_voice_train.map(remove_special_characters)
common_voice_test = common_voice_test.map(remove_special_characters)

讓我們再看一下處理後的文字標籤。

show_random_elements(common_voice_train.remove_columns(["path","audio"]))
Хойч үе юуны төлөө тэмцэлдэхийг би мэдэхгүй.	
Тэр өвдгөн дээрээ толгойгоо тавиад сулхан гиншинэ.	
Эхнэргүй ганц бие хүн гэсэн санагдана.	
Дамиран хотод төрж өссөн хээнцэр залуусын нэг билээ.	
Мөн судлаачид шинжлэх ухааны үндэстэй тайлбар хайдаг.	
Судалгааны ажил нь бүтэлгүй болсонд л гутарч маргааш илүү ажиллах тухай бодсон бололтой.	
Ийм зөрчлөөс гэтлэх гарц "Оноосон нэрийн сан"-г үүсгэснээр шийдвэрлэгдэнэ.	
Үүлтэй тэнгэрийн доогуур үзүүртэй моддын дээгүүр дүүлэн нисэх сэн.	
Та нар ямар юмаа ингэж булаацалдаа вэ?	
Тэд амьд хэлтрээ болов уу яагаа бол гэхээс одоо ч дотор арзганан бачуурдаг юм.	

在 CTC 中,通常將語音塊分類為字母,所以我們在這裡也這樣做。讓我們提取訓練和測試資料中所有不同的字母,並從這個字母集合構建我們的詞彙表。

我們編寫一個對映函式,將所有轉寫文字連線成一個長轉寫文字,然後將字串轉換為一個字元集合。將引數 batched=True 傳遞給 map(...) 函式非常重要,這樣對映函式就可以一次性訪問所有轉寫文字。

def extract_all_chars(batch):
  all_text = " ".join(batch["sentence"])
  vocab = list(set(all_text))
  return {"vocab": [vocab], "all_text": [all_text]}

vocab_train = common_voice_train.map(extract_all_chars, batched=True, batch_size=-1, keep_in_memory=True, remove_columns=common_voice_train.column_names)
vocab_test = common_voice_test.map(extract_all_chars, batched=True, batch_size=-1, keep_in_memory=True, remove_columns=common_voice_test.column_names)

現在,我們建立訓練集和測試集中所有不同字母的並集,並將結果列表轉換為一個帶索引的字典。

vocab_list = list(set(vocab_train["vocab"][0]) | set(vocab_test["vocab"][0]))

vocab_dict = {v: k for k, v in enumerate(sorted(vocab_list))}
vocab_dict
{' ': 0,
 'a': 1,
 'b': 2,
 'c': 3,
 'd': 4,
 'e': 5,
 'g': 6,
 'h': 7,
 'i': 8,
 'l': 9,
 'n': 10,
 'o': 11,
 'r': 12,
 't': 13,
 'x': 14,
 'а': 15,
 'б': 16,
 'в': 17,
 'г': 18,
 'д': 19,
 'е': 20,
 'ж': 21,
 'з': 22,
 'и': 23,
 'й': 24,
 'к': 25,
 'л': 26,
 'м': 27,
 'н': 28,
 'о': 29,
 'п': 30,
 'р': 31,
 'с': 32,
 'т': 33,
 'у': 34,
 'ф': 35,
 'х': 36,
 'ц': 37,
 'ч': 38,
 'ш': 39,
 'ъ': 40,
 'ы': 41,
 'ь': 42,
 'э': 43,
 'ю': 44,
 'я': 45,
 'ё': 46,
 'ү': 47,
 'ө': 48}

清理資料集是一個需要謹慎進行的反覆過程。

檢視訓練集和測試集中的單個字母,我們發現既有拉丁字母,也有蒙古語西裡爾字母。在與一位母語為目標語言的人士(感謝 Mishig 的審閱)討論後,我們將移除拉丁字母,原因有二:

  1. CTC 演算法受益於較小的詞彙量,因此建議移除多餘的字元。
  2. 在這個例子中,我們完全專注於蒙古語字母表。
def remove_latin_characters(batch):
    batch["sentence"] = re.sub(r'[a-z]+', '', batch["sentence"])
    return batch

# remove latin characters
common_voice_train = common_voice_train.map(remove_latin_characters)
common_voice_test = common_voice_test.map(remove_latin_characters)

# extract unique characters again
vocab_train = common_voice_train.map(extract_all_chars, batched=True, batch_size=-1, keep_in_memory=True, remove_columns=common_voice_train.column_names)
vocab_test = common_voice_test.map(extract_all_chars, batched=True, batch_size=-1, keep_in_memory=True, remove_columns=common_voice_test.column_names)
vocab_list = list(set(vocab_train["vocab"][0]) | set(vocab_test["vocab"][0]))

vocab_dict = {v: k for k, v in enumerate(sorted(vocab_list))}
vocab_dict
{' ': 0,
 'а': 1,
 'б': 2,
 'в': 3,
 'г': 4,
 'д': 5,
 'е': 6,
 'ж': 7,
 'з': 8,
 'и': 9,
 'й': 10,
 'к': 11,
 'л': 12,
 'м': 13,
 'н': 14,
 'о': 15,
 'п': 16,
 'р': 17,
 'с': 18,
 'т': 19,
 'у': 20,
 'ф': 21,
 'х': 22,
 'ц': 23,
 'ч': 24,
 'ш': 25,
 'ъ': 26,
 'ы': 27,
 'ь': 28,
 'э': 29,
 'ю': 30,
 'я': 31,
 'ё': 32,
 'ү': 33,
 'ө': 34}

太好了,我們看到蒙古語字母表中的所有字母都出現在資料集中(這並不意外),我們還提取了特殊字元 " "。注意,我們沒有排除這個特殊字元,因為:模型必須學會預測單詞何時結束,否則模型預測將永遠是一串字元,無法將單詞彼此分開。

應該始終記住,預處理是訓練模型前非常重要的一步。例如,我們不希望模型僅僅因為我們忘記了資料規範化而去區分 aAaA 之間的區別根本不取決於字母的“發音”,而更多地取決於語法規則——例如,在句子開頭使用大寫字母。因此,消除大小寫字母之間的差異是明智的,這樣模型就能更容易地學習轉寫語音。您可以在音訊 Transformers 課程中閱讀更多關於預處理對 ASR 任務影響的內容。

為了更清楚地表明 " " 有自己的 token 類別,我們給它一個更顯眼的字元 |。此外,我們還添加了一個“未知”token,以便模型以後可以處理 Common Voice 訓練集中未遇到的字元。

vocab_dict["|"] = vocab_dict[" "]
del vocab_dict[" "]

最後,我們還添加了一個與 CTC 的“空白 token”相對應的填充 token。“空白 token”是 CTC 演算法的核心組成部分。更多資訊,請參閱這篇博文的“對齊”部分。

vocab_dict["[UNK]"] = len(vocab_dict)
vocab_dict["[PAD]"] = len(vocab_dict)
len(vocab_dict)
37

太好了,現在我們的詞彙表已經完成,包含 37 個 token,這意味著我們將在預訓練的 Wav2Vec2-BERT 檢查點之上新增的線性層將具有 37 的輸出維度。

現在讓我們將詞彙表儲存為 json 檔案。

import json
with open('vocab.json', 'w') as vocab_file:
    json.dump(vocab_dict, vocab_file)

最後一步,我們使用這個 json 檔案將詞彙表載入到 Wav2Vec2CTCTokenizer 類的一個例項中。

from transformers import Wav2Vec2CTCTokenizer

tokenizer = Wav2Vec2CTCTokenizer.from_pretrained("./", unk_token="[UNK]", pad_token="[PAD]", word_delimiter_token="|")

如果想要將剛剛建立的分詞器與本 notebook 中微調的模型一起重用,強烈建議將 tokenizer 上傳到 🤗 Hub。讓我們將要上傳檔案的倉庫命名為 "w2v-bert-2.0-mongolian-colab-CV16.0"

repo_name = "w2v-bert-2.0-mongolian-colab-CV16.0"

然後將分詞器上傳到 🤗 Hub

tokenizer.push_to_hub(repo_name)

太好了,您可以在 https://huggingface.co/<your-username>/w2v-bert-2.0-mongolian-colab-CV16.0 下看到剛剛建立的倉庫。

建立 SeamlessM4TFeatureExtractor

SeamlessM4TFeatureExtractor 的作用是將原始音訊輸入準備成模型能夠“理解”的格式。因此,它將一維振幅值序列(即原始音訊輸入)對映到一個二維的對數梅爾頻譜圖矩陣。後者將訊號的頻率資訊編碼為時間的函式。請參閱音訊 Transformers 課程的這一節以瞭解更多關於頻譜圖及其重要性的資訊。

與分詞器不同,特徵提取器不需要從資料中“學習”,因此我們可以直接從初始模型檢查點載入它。

from transformers import SeamlessM4TFeatureExtractor

feature_extractor = SeamlessM4TFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0")

太好了,Wav2Vec2-BERT 的特徵提取流程就此完全定義好了!

為了提高使用者友好性,特徵提取器和分詞器被封裝在一個名為 Wav2Vec2BertProcessor 的類中,這樣一來,我們只需要一個 model 和一個 processor 物件即可。

from transformers import Wav2Vec2BertProcessor

processor = Wav2Vec2BertProcessor(feature_extractor=feature_extractor, tokenizer=tokenizer)
processor.push_to_hub(repo_name)

接下來,我們可以準備資料集了。

預處理資料

到目前為止,我們還沒有檢視語音訊號的實際數值,只看了轉寫文字。除了 sentence,我們的資料集還包括另外兩個列名 pathaudiopath 指明瞭音訊檔案的絕對路徑。我們來看看。

common_voice_train[0]["path"]
/root/.cache/huggingface/datasets/downloads/extracted/276aa682ce2b6a24934bc401b1f30e004c3fb178dd41d6295b273329f592844a/mn_train_0/common_voice_mn_18578097.mp3

Wav2Vec2-BERT 期望輸入格式為 16 kHz 的一維陣列。這意味著音訊檔案必須被載入和重取樣。

幸運的是,datasets 透過呼叫另一列 audio 自動完成這個過程。讓我們來試試。

common_voice_train[0]["audio"]
{'path': '/root/.cache/huggingface/datasets/downloads/extracted/276aa682ce2b6a24934bc401b1f30e004c3fb178dd41d6295b273329f592844a/mn_train_0/common_voice_mn_18578097.mp3',
 'array': array([ 0.00000000e+00, -1.64773251e-14,  1.81765166e-13, ...,
        -3.23167333e-05,  2.20304846e-05,  3.26883201e-05]),
 'sampling_rate': 48000}

太好了,我們可以看到音訊檔案已自動載入。這要歸功於 datasets == 4.13.3 中引入的新的"Audio"特性,它在呼叫時會即時載入和重取樣音訊檔案。

在上面的例子中,我們可以看到音訊資料以 48kHz 的取樣率載入,而 Wav2Vec2-BERT 是在 16kHz 的取樣率下進行預訓練的。取樣率起著重要作用,因為它定義了每秒測量多少語音訊號的資料點。因此,以更高的取樣率進行取樣可以更好地逼近真實的語音訊號,但每秒也需要更多的數值。

一個預訓練的檢查點期望其輸入資料與它訓練時所用的資料大致來自相同的分佈。以兩種不同速率取樣的相同語音訊號具有非常不同的分佈,例如,將取樣率加倍會導致資料點長度增加一倍。因此,在微調一個 ASR 模型的預訓練檢查點之前,驗證用於預訓練模型的資料取樣率與用於微調模型的資料集取樣率是否匹配至關重要。

幸運的是,我們可以透過使用 cast_column 將音訊特徵設定為正確的取樣率。

common_voice_train = common_voice_train.cast_column("audio", Audio(sampling_rate=16_000))
common_voice_test = common_voice_test.cast_column("audio", Audio(sampling_rate=16_000))

我們再來看看 "audio"

common_voice_train[0]["audio"]
{'path': '/root/.cache/huggingface/datasets/downloads/extracted/276aa682ce2b6a24934bc401b1f30e004c3fb178dd41d6295b273329f592844a/mn_train_0/common_voice_mn_18578097.mp3',
 'array': array([ 9.09494702e-12, -2.27373675e-13,  5.45696821e-12, ...,
        -5.22854862e-06, -1.21556368e-05, -9.76262163e-06]),
 'sampling_rate': 16000}

這似乎起作用了!讓我們聽幾個音訊檔案,以更好地瞭解資料集並驗證音訊是否已正確載入。

import IPython.display as ipd
import numpy as np
import random

rand_int = random.randint(0, len(common_voice_train)-1)

print(common_voice_train[rand_int]["sentence"])
ipd.Audio(data=common_voice_train[rand_int]["audio"]["array"], autoplay=True, rate=16000)

看起來資料現在已經正確載入和重取樣了。

可以聽出,說話者在變化,他們的語速、口音和背景環境等也隨之改變。不過,總的來說,錄音聽起來足夠清晰,這對於一個眾包的朗讀語音語料庫來說是意料之中的。

讓我們做最後一次檢查,確認資料準備是否正確,透過列印語音輸入的形狀、其轉寫文字以及相應的取樣率。

rand_int = random.randint(0, len(common_voice_train)-1)

print("Target text:", common_voice_train[rand_int]["sentence"])
print("Input array shape:", common_voice_train[rand_int]["audio"]["array"].shape)
print("Sampling rate:", common_voice_train[rand_int]["audio"]["sampling_rate"])
Target text: энэ бол тэдний амжилтын бодит нууц
Input array shape: (74496,)
Sampling rate: 16000

好的!一切看起來都沒問題——資料是一維陣列,取樣率總是 16kHz,目標文字也已規範化。

最後,我們可以利用 Wav2Vec2BertProcessor 將資料處理成 Wav2Vec2BertForCTC 訓練時期望的格式。為此,讓我們使用 Dataset 的 map(...) 函式。

首先,我們載入並重取樣音訊資料,只需呼叫 batch["audio"]。其次,我們從載入的音訊檔案中提取 input_features。在我們的例子中,Wav2Vec2BertProcessor 建立了一個比原始波形更復雜的表示,即對數梅爾特徵提取。第三,我們將轉寫文字編碼為標籤 ID。

def prepare_dataset(batch):
    audio = batch["audio"]
    batch["input_features"] = processor(audio["array"], sampling_rate=audio["sampling_rate"]).input_features[0]
    batch["input_length"] = len(batch["input_features"])

    batch["labels"] = processor(text=batch["sentence"]).input_ids
    return batch

讓我們將資料準備函式應用到所有樣本上。

common_voice_train = common_voice_train.map(prepare_dataset, remove_columns=common_voice_train.column_names)
common_voice_test = common_voice_test.map(prepare_dataset, remove_columns=common_voice_test.column_names)

注意**:`datasets` 自動處理音訊載入和重取樣。如果您希望實現自己定製的資料載入/取樣,可以隨時使用 `“path”` 列,而忽略 `“audio”` 列。

太棒了,現在我們準備好開始訓練了!

訓練

資料已經處理完畢,我們可以開始設定訓練流程了。我們將使用 🤗 Transformer 的 Trainer 類,為此我們主要需要做以下幾件事:

  • 定義一個數據整理器。與大多數 NLP 模型不同,Wav2Vec2-BERT 的輸入長度遠大於輸出長度。鑑於輸入尺寸較大,動態填充訓練批次效率更高,這意味著所有訓練樣本只應填充到其批次中最長樣本的長度,而不是整個資料集中最長樣本的長度。因此,微調 Wav2Vec2-BERT 需要一個特殊的填充資料整理器,我們將在下面定義它。

  • 評估指標。在訓練期間,模型應以詞錯誤率進行評估。我們應該相應地定義一個 compute_metrics 函式。

  • 載入預訓練的檢查點。我們需要載入一個預訓練的檢查點,併為其進行正確的訓練配置。

  • 定義訓練配置。

在微調模型後,我們將在測試資料上對其進行正確評估,並驗證它確實學會了正確轉寫語音。

設定 Trainer

讓我們從定義資料整理器開始。資料整理器的程式碼是從這個例子中複製的。

不深入太多細節,與常見的資料整理器相比,這個資料整理器對 input_featureslabels 的處理方式不同,因此對它們應用了不同的填充函式。這是必要的,因為在語音中,輸入和輸出是不同模態的,這意味著它們不應該用相同的填充函式來處理。與常見的資料整理器類似,標籤中的填充 token 用 -100 替換,這樣這些 token 在計算損失時就會被考慮在內。

import torch

from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union

@dataclass
class DataCollatorCTCWithPadding:

    processor: Wav2Vec2BertProcessor
    padding: Union[bool, str] = True

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lengths and need
        # different padding methods
        input_features = [{"input_features": feature["input_features"]} for feature in features]
        label_features = [{"input_ids": feature["labels"]} for feature in features]

        batch = self.processor.pad(
            input_features,
            padding=self.padding,
            return_tensors="pt",
        )

        labels_batch = self.processor.pad(
            labels=label_features,
            padding=self.padding,
            return_tensors="pt",
        )
        # replace padding with -100 to ignore loss correctly
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        batch["labels"] = labels

        return batch

data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)

接下來,定義評估指標。如前所述,ASR 中最主要的指標是詞錯誤率 (WER),因此我們在這個 notebook 中也使用它。

wer_metric = load_metric("wer")

模型將返回一個 logit 向量序列: y1,,ym \mathbf{y}_1, \ldots, \mathbf{y}_m ,其中 y1=fθ(x1,,xn)[0] \mathbf{y}_1 = f_{\theta}(x_1, \ldots, x_n)[0] n>>m n >> m

一個 logit 向量 y1 \mathbf{y}_1 包含了我們之前定義的詞彙表中每個單詞的對數機率,因此 len(yi)= \text{len}(\mathbf{y}_i) = `config.vocab_size`。我們對模型最可能的預測感興趣,因此取 logits 的 `argmax(...)`。此外,我們透過將 `-100` 替換為 `pad_token_id` 並解碼這些 ID,將編碼後的標籤轉換回原始字串,同時確保連續的 token 在 CTC 風格下**不**被分組為同一個 token1 {}^1

def compute_metrics(pred):
    pred_logits = pred.predictions
    pred_ids = np.argmax(pred_logits, axis=-1)

    pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id

    pred_str = processor.batch_decode(pred_ids)
    # we do not want to group tokens when computing the metrics
    label_str = processor.batch_decode(pred.label_ids, group_tokens=False)

    wer = wer_metric.compute(predictions=pred_str, references=label_str)

    return {"wer": wer}

現在,我們可以載入主要的預訓練檢查點了。必須將分詞器的 `pad_token_id` 定義為模型的 `pad_token_id`,或者在 `Wav2Vec2BertForCTC` 的情況下,也定義為 CTC 的*空白 token*2 {}^2 。為了節省 GPU 記憶體,我們啟用了 PyTorch 的梯度檢查點,並將損失縮減設定為“mean”。

由於我們只訓練一小部分權重,模型不容易過擬合。因此,我們確保停用所有 dropout 層。

注意:當使用此 notebook 在 Common Voice 的另一種語言上訓練 Wav2Vec2-BERT 時,這些超引數設定可能效果不佳。請根據您的用例隨意調整。

from transformers import Wav2Vec2BertForCTC

model = Wav2Vec2BertForCTC.from_pretrained(
    "facebook/w2v-bert-2.0",
    attention_dropout=0.0,
    hidden_dropout=0.0,
    feat_proj_dropout=0.0,
    mask_time_prob=0.0,
    layerdrop=0.0,
    ctc_loss_reduction="mean",
    add_adapter=True,
    pad_token_id=processor.tokenizer.pad_token_id,
    vocab_size=len(processor.tokenizer),
)

最後一步,我們定義所有與訓練相關的引數。對其中一些引數進行更多解釋:

  • group_by_length 透過將相似輸入長度的訓練樣本分組到一個批次中,使訓練更加高效。這可以透過大量減少透過模型的無用填充 token 的總數來顯著加快訓練時間。
  • learning_rate 是透過啟發式調整得到的,直到微調變得穩定。請注意,這些引數很大程度上取決於 Common Voice 資料集,對於其他語音資料集可能不是最優的。

關於其他引數的更多解釋,可以檢視文件

在訓練期間,每 600 個訓練步驟就會有一個檢查點非同步上傳到 Hub。這讓您即使在模型仍在訓練時,也能試用演示小部件。

注意:如果不想將模型檢查點上傳到 Hub,只需設定 push_to_hub=False

from transformers import TrainingArguments

training_args = TrainingArguments(
  output_dir=repo_name,
  group_by_length=True,
  per_device_train_batch_size=16,
  gradient_accumulation_steps=2,
  evaluation_strategy="steps",
  num_train_epochs=10,
  gradient_checkpointing=True,
  fp16=True,
  save_steps=600,
  eval_steps=300,
  logging_steps=300,
  learning_rate=5e-5,
  warmup_steps=500,
  save_total_limit=2,
  push_to_hub=True,
)

現在,所有例項都可以傳遞給 Trainer,我們準備開始訓練了!

from transformers import Trainer

trainer = Trainer(
    model=model,
    data_collator=data_collator,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=common_voice_train,
    eval_dataset=common_voice_test,
    tokenizer=processor.feature_extractor,
)

1 {}^1 為了讓模型能夠獨立於說話者的語速,在 CTC 中,連續相同的 token 會被簡單地歸為一個 token。然而,在解碼時,編碼後的標籤不應該被分組,因為它們不對應於模型的預測 token,這就是為什麼必須傳遞 `group_tokens=False` 引數的原因。如果我們不傳遞這個引數,像 `“hello”` 這樣的詞就會被錯誤地編碼和解碼為 `“helo”`。 2 {}^2 空白 token 允許模型預測像 `“hello”` 這樣的詞,透過強制它在兩個 l 之間插入空白 token。我們的模型對 `“hello”` 的一個符合 CTC 規範的預測會是 `[PAD] [PAD] "h" "e" "e" "l" "l" [PAD] "l" "o" "o" [PAD]`。

訓練

訓練將花費數小時,具體取決於分配給此 notebook 的 GPU。雖然訓練後的模型在 Common Voice 的蒙古語測試資料上取得了尚可的結果,但它絕不是一個最優微調的模型。本 notebook 的目的僅僅是演示如何在一個 ASR 資料集上微調 Wav2Vec2-BERT。

trainer.train()
步驟 訓練損失 驗證損失 詞錯誤率 (Wer)
300 1.712700 0.647740 0.517892
600 0.349300 0.615849 0.442027
900 0.180500 0.525088 0.367305
1200 0.075400 0.528768 0.324016

訓練損失和驗證 WER 都很好地下降了。相比之下,使用 whisper-large-v3(公認的 OpenAI 最先進的 ASR 模型)進行相同的訓練,最終的 WER 為 33.3%。您可以在這裡找到最終的 Whisper 檢查點。這表明 Wav2Vec2-Bert 在低資源語言上可以達到接近或等同於最先進水平的效能

你現在可以把訓練結果上傳到 🤗 Hub,只需執行這條指令即可。

trainer.push_to_hub()

你現在可以和所有的朋友、家人、心愛的寵物分享這個模型:他們都可以用“your-username/the-name-you-picked”這個識別符號來載入它,例如:

from transformers import AutoModelForCTC, Wav2Vec2BertProcessor

model = AutoModelForCTC.from_pretrained("ylacombe/w2v-bert-2.0-mongolian-colab-CV16.0")
processor = Wav2Vec2BertProcessor.from_pretrained("ylacombe/w2v-bert-2.0-mongolian-colab-CV16.0")

關於如何微調 Wav2Vec2-BERT 的更多示例,請參閱官方語音識別示例

評估

作為最後一次檢查,我們載入模型並驗證它確實學會了轉寫蒙古語語音。

我們先載入預訓練的檢查點。

model = Wav2Vec2BertForCTC.from_pretrained(repo_name).to("cuda")
processor = Wav2Vec2BertProcessor.from_pretrained(repo_name)

讓我們處理音訊,進行一次前向傳播並預測 ID。

sample = common_voice_test[0]
input_features = torch.tensor(sample["input_features"]).to("cuda").unsqueeze(0)

with torch.no_grad():
    logits = model(input_features).logits

pred_ids = torch.argmax(logits, dim=-1)[0]

最後,我們可以從預測的 token 解碼出示例,並與參考轉寫文字進行比較。

print(processor.decode(pred_ids))
print(processor.decode(sample["labels"]).lower())
эрчүүдийн ганцаардлыг эмэхтэйчүүд ойлгох нь ховор юм
эрчүдийн ганцардлыг эмэгтэйчүд ойлгох нь ховор юм

好的!從我們的預測中絕對可以辨認出轉寫內容,但還不夠完美。將模型訓練更長時間,在資料預處理上花費更多時間,尤其是使用語言模型進行解碼,肯定會提高模型的整體效能。

然而,對於一個低資源語言的演示模型來說,這個結果已經相當可以接受了 🤗。

擴充套件訓練規模

我們在這篇博文中展示了 Meta 的 w2v-bert-2.0 微調如何在低資源語言上取得接近最先進水平的效能。

為了更進一步,我整理了一套由我在 Hugging Face 的同事們提供的關於如何擴充套件此模型訓練的技巧和要點。這些技巧是在我向他們展示這篇博文的訓練執行以及其他訓練嘗試(這裡這裡)時浮出水面的。

非常感謝 PatrickSanchitPablo 提供的寶貴專業知識和幫助 🤗

請注意,Common Voice 最新版本 (CV16) 為許多語言提供了更多小時的資料,從而為在許多低資源語言中構建更高效的模型提供了肥沃的土壤。

資料集相關技巧

CTC ASR 通常使用小寫、無標點的轉寫文字。這簡化了 CTC 任務,因為模型被視為“純聲學”模型,意味著它的預測主要基於音訊的語音聲音,而不是口語句子的任何語言模型上下文。

頻率極低的字元會透過錯誤的目標導致損失激增,從而顯著影響學習過程中的損失。預設情況下,本博文建立的 CTC 分詞器會將它們新增到詞彙表中,即使它們的頻率與更常見的字元相比可以忽略不計。我們可以將這些字元視為資料集標註中的“錯誤”,以便將它們從詞彙表中移除,並在訓練期間簡單地分類為 `"[UNK]"`。

因此,絕對有必要重新檢查分詞器詞彙表,並移除所有低頻字元,就像我們建立分詞器時移除拉丁字元一樣。

請注意,Common Voice 資料集特別容易出現這類“錯誤”字元,例如來自其他語言的字元(阪)。

訓練相關技巧

每個 CTC token 看到的平均時長: 透過實驗,我們發現每個 CTC token 看到的理想時長比例是 10 到 35 毫秒。換句話說,為了能夠正確學習和預測,CTC token 需要看到的聲學資訊時長既不能太低也不能太高。實際上,它應該大致對應於我們人類發一個音素所需時間的一小部分。

我的一次訓練執行的損失曲線最初如預期般平穩下降,但在某個點開始爆炸。我意識到我一直使用的是一個沒有架構改動的基本檢查點,每個 CTC token 看到的訊號時長為 30 到 60 毫秒。新增一個卷積介面卡層來對編碼器隱藏狀態沿時間維度進行子取樣,足以將訊號塊取樣減少到期望的時長,並防止這種損失曲線的出現。

訓練不足:我的同事們在檢視我的訓練執行時很快注意到模型嚴重訓練不足,這一點可以從損失曲線上看出來,它看起來像是在陡峭下降的中間被停止了。這也指出了其他問題,特別是損失曲線不夠平滑,這是超引數設定不當的跡象。

這裡有幾種解決我們案例中訓練不足的方法:

  • 預熱率可能太高,導致學習率下降過快。一個解決方法是保持預熱率在 5% 到 15% 之間,並增加訓練輪數。預熱步驟對於逐漸將新的語言模型頭權重與預訓練模型對齊至關重要。
  • 損失曲線不平滑的問題可以透過調整 AdamWβ2 \beta_2 來解決,該引數通常可以預設設定為 0.95 到 0.98。

相關文章和附加連結列於此處:

社群

這簡直是一篇低水平、複製貼上、程式碼寫得差的帖子。顯然作者在寫這些程式碼時對任何事情都毫無理解。

·

你好 @nicccobb
我看到您覺得這篇帖子和程式碼有所欠缺。
您能具體指出哪些部分的程式碼看起來不正確嗎?我願意在需要的地方進行修改並提供更多深度。另外,我已經發現一個明顯的錯誤:我計算每個 CTC token 覆蓋的時間跨度(毫秒/token)是錯的。
我很樂意聽取您認為應該修正或擴充的任何其他觀點。建設性的反饋總是受歡迎的——我樂於糾正和學習。

編輯:好文章!謝謝!!

我有一個同樣的專案,但是是關於阿拉伯語轉寫的。我可以遵循同樣的步驟嗎?

註冊登入 以發表評論

© . This site is unofficial and not affiliated with Hugging Face, Inc.