使用 🤗 Transformers 微調 XLS-R 進行多語言 ASR
新 (11/2021):此部落格文章已更新,以介紹 XLSR 的繼任者,名為 XLS-R。
Wav2Vec2 是用於自動語音識別 (ASR) 的預訓練模型,由 Alexei Baevski、Michael Auli 和 Alex Conneau 於 2020 年 9 月 釋出。在 Wav2Vec2 在最流行的英語 ASR 資料集 LibriSpeech 上展示出卓越效能後不久,Facebook AI 釋出了 Wav2Vec2 的多語言版本,稱為 XLSR。XLSR 代表跨語言語音表示,指的是模型學習在多種語言中有用的語音表示的能力。
XLSR 的繼任者,簡稱為 XLS-R(指的是“XLM-R 用於語音”),由 Arun Babu、Changhan Wang、Andros Tjandra 等人於 2021 年 11 月 釋出。XLS-R 使用了近 50 萬小時的 128 種語言的音訊資料進行自監督預訓練,其大小從 3 億引數到 20 億引數不等。您可以在 🤗 Hub 上找到預訓練檢查點。
與 BERT 的掩碼語言建模目標類似,XLS-R 透過在自監督預訓練期間(即下方左圖)將特徵向量隨機掩碼,然後將其傳遞給 Transformer 網路來學習上下文語音表示。
對於微調,在預訓練網路之上新增一個線性層,以便在音訊下游任務(如語音識別、語音翻譯和音訊分類)的帶標籤資料上訓練模型(即下方右圖)。
XLS-R 在語音識別、語音翻譯和說話人/語言識別方面均表現出比以前最先進結果顯著的改進,參見官方論文中的表 3-6、表 7-10 和表 11-12。
設定
在本部落格中,我們將詳細解釋如何微調 XLS-R——更具體地說是預訓練檢查點 Wav2Vec2-XLS-R-300M——用於 ASR。
為了演示目的,我們將模型在 Common Voice 的低資源 ASR 資料集上進行微調,該資料集僅包含約 4 小時已驗證的訓練資料。
XLS-R 使用連線主義時間分類 (CTC) 進行微調,這是一種用於訓練序列到序列問題(如 ASR 和手寫識別)神經網路的演算法。
我強烈推薦閱讀 Awni Hannun 撰寫的精彩部落格文章 Sequence Modeling with CTC (2017)。
在開始之前,我們先安裝 datasets
和 transformers
。此外,我們還需要 torchaudio
來載入音訊檔案,以及 jiwer
來使用 詞錯誤率 (WER) 指標 評估我們微調的模型。
!pip install datasets==1.18.3
!pip install transformers==4.11.3
!pip install huggingface_hub==0.1
!pip install torchaudio
!pip install librosa
!pip install jiwer
我們強烈建議您在訓練期間將訓練檢查點直接上傳到 Hugging Face Hub。Hugging Face Hub 集成了版本控制,因此您可以確保在訓練期間不會丟失任何模型檢查點。
為此,您必須儲存來自 Hugging Face 網站的身份驗證令牌(如果您尚未註冊,請在此處註冊!)。
from huggingface_hub import notebook_login
notebook_login()
列印輸出
Login successful
Your token has been saved to /root/.huggingface/token
然後您需要安裝 Git-LFS 才能上傳您的模型檢查點
apt install git-lfs
在 論文中,模型使用音素錯誤率 (PER) 進行評估,但目前 ASR 中最常見的指標是詞錯誤率 (WER)。為了使本筆記本儘可能通用,我們決定使用 WER 評估模型。
準備資料、分詞器、特徵提取器
ASR 模型將語音轉寫為文字,這意味著我們既需要一個將語音訊號處理成模型輸入格式(例如特徵向量)的特徵提取器,也需要一個將模型輸出格式處理成文字的分詞器。
在 🤗 Transformers 中,XLS-R 模型因此配備了分詞器 Wav2Vec2CTCTokenizer 和特徵提取器 Wav2Vec2FeatureExtractor。
讓我們從建立分詞器開始,用它來將預測的輸出類別解碼為輸出轉寫文字。
建立 Wav2Vec2CTCTokenizer
預訓練的 XLS-R 模型將語音訊號對映到上下文表示序列,如上圖所示。然而,對於語音識別,模型必須將此上下文表示序列對映到其對應的轉錄,這意味著必須在 Transformer 塊之上新增一個線性層(上圖黃色部分所示)。此線性層用於將每個上下文表示分類為一個標記類別,類似於在 BERT 的嵌入之上新增線性層以在預訓練後進行進一步分類的方式(參見以下部落格文章的“BERT”部分)。在預訓練之後,在 BERT 的嵌入之上新增一個線性層以進行進一步分類——參見此部落格文章的“BERT”部分。
此層的輸出大小對應於詞彙表中的標記數量,這**不**取決於 XLS-R 的預訓練任務,而僅取決於用於微調的帶標籤資料集。因此,第一步,我們將檢視選擇的 Common Voice 資料集,並根據轉錄定義一個詞彙表。
首先,我們前往 Common Voice 官方網站並選擇一種語言來微調 XLS-R。在本筆記本中,我們將使用土耳其語。
對於每個特定語言的資料集,您可以找到與您所選語言對應的語言程式碼。在 Common Voice 上,查詢“版本”欄位。語言程式碼對應於下劃線之前的字首。例如,土耳其語的語言程式碼是 "tr"
。
很好!現在我們可以使用 🤗 Datasets 簡單的 API 來下載資料。資料集名稱是 "common_voice"
,配置名稱對應於語言程式碼,在本例中是 "tr"
。
Common Voice 有許多不同的拆分,包括 invalidated
,它指的是未被評為“足夠清晰”而無法被認為有用的資料。在本筆記本中,我們只使用 "train"
、"validation"
和 "test"
拆分。
由於土耳其語資料集很小,我們將驗證資料和訓練資料合併到一個訓練資料集中,只使用測試資料進行驗證。
from datasets import load_dataset, load_metric, Audio
common_voice_train = load_dataset("common_voice", "tr", split="train+validation")
common_voice_test = load_dataset("common_voice", "tr", split="test")
許多 ASR 資料集僅提供每個音訊陣列 'audio'
和檔案 'path'
的目標文字 'sentence'
。Common Voice 實際上提供了每個音訊檔案的更多資訊,例如 'accent'
等。為了使筆記本儘可能通用,我們只考慮轉錄文字進行微調。
common_voice_train = common_voice_train.remove_columns(["accent", "age", "client_id", "down_votes", "gender", "locale", "segment", "up_votes"])
common_voice_test = common_voice_test.remove_columns(["accent", "age", "client_id", "down_votes", "gender", "locale", "segment", "up_votes"])
讓我們寫一個簡短的函式來顯示資料集的一些隨機樣本,並執行幾次以感受一下轉寫文字的特點。
from datasets import ClassLabel
import random
import pandas as pd
from IPython.display import display, HTML
def show_random_elements(dataset, num_examples=10):
assert num_examples <= len(dataset), "Can't pick more elements than there are in the dataset."
picks = []
for _ in range(num_examples):
pick = random.randint(0, len(dataset)-1)
while pick in picks:
pick = random.randint(0, len(dataset)-1)
picks.append(pick)
df = pd.DataFrame(dataset[picks])
display(HTML(df.to_html()))
列印輸出
索引 | 句子 |
---|---|
1 | Jonuz是唯一接受短期任務的候選人。 |
2 | 我們從這場鬥爭中獲得希望。 |
3 | 展覽中展示了五項克羅埃西亞創新。 |
4 | 萬物皆有其名。 |
5 | 該機構已準備好私有化。 |
6 | 定居點的景色很美。 |
7 | 事件的肇事者未能找到。 |
8 | 然而,這些努力都白費了。 |
9 | 該專案價值2.77百萬歐元。 |
10 | 大型重建專案分為四個階段。 |
好的!轉錄看起來相當清晰。翻譯這些轉錄的句子後,似乎這種語言更像書面文字而不是嘈雜的對話。這很合理,考慮到 Common Voice 是一個眾包的朗讀語音語料庫。
我們可以看到轉錄包含一些特殊字元,例如 ,.?!;:
。在沒有語言模型的情況下,將語音片段分類為這些特殊字元要困難得多,因為它們與特徵聲音單元沒有真正的對應關係。例如,字母 "s"
有一個或多或少清晰的聲音,而特殊字元 "."
則沒有。此外,為了理解語音訊號的含義,通常不需要在轉錄中包含特殊字元。
讓我們簡單地移除所有對詞義沒有貢獻且無法真正用聲音表示的字元,並對文字進行規範化。
import re
chars_to_remove_regex = '[\,\?\.\!\-\;\:\"\“\%\‘\”\�\']'
def remove_special_characters(batch):
batch["sentence"] = re.sub(chars_to_remove_regex, '', batch["sentence"]).lower()
return batch
common_voice_train = common_voice_train.map(remove_special_characters)
common_voice_test = common_voice_test.map(remove_special_characters)
讓我們再看一下處理後的文字標籤。
show_random_elements(common_voice_train.remove_columns(["path","audio"]))
列印輸出
索引 | 轉錄 |
---|---|
1 | 他們說其中一個是為白人而戰 |
2 | 馬克圖夫的刑期於六月結束 |
3 | 與原作不同,衣服沒有脫下 |
4 | 這些物品的總價值達到一億歐元。 |
5 | 桌上至少有兩個選項。 |
6 | 這絕非不合理的狂熱。 |
7 | 這種狀況在1990年代隨著國家分裂而改變。 |
8 | 期限是六個月。 |
9 | 但是,成本可能會高得多。 |
10 | 首府費拉坐落在一座小山上。 |
很好!這看起來好多了。我們已經從轉錄中刪除了大部分特殊字元,並將其規範化為全小寫。
在最終確定預處理之前,諮詢目標語言的母語人士總是有利的,以檢視文字是否可以進一步簡化。對於這篇部落格文章,Merve 很好心地快速瀏覽了一下,並指出土耳其語中像 â
這樣的“帶帽”字元已經不再使用了,可以用它們的“不帶帽”對應物(例如 a
)替換。
這意味著我們應該將像 "yargı sistemi hâlâ sağlıksız"
這樣的句子替換為 "yargı sistemi hala sağlıksız"
。
我們再編寫一個簡短的對映函式來進一步簡化文字標籤。
def replace_hatted_characters(batch):
batch["sentence"] = re.sub('[â]', 'a', batch["sentence"])
batch["sentence"] = re.sub('[î]', 'i', batch["sentence"])
batch["sentence"] = re.sub('[ô]', 'o', batch["sentence"])
batch["sentence"] = re.sub('[û]', 'u', batch["sentence"])
return batch
common_voice_train = common_voice_train.map(replace_hatted_characters)
common_voice_test = common_voice_test.map(replace_hatted_characters)
在 CTC 中,通常將語音塊分類為字母,所以我們在這裡也這樣做。讓我們提取訓練和測試資料中所有不同的字母,並從這個字母集合構建我們的詞彙表。
我們編寫了一個對映函式,將所有轉錄連線成一個長轉錄,然後將字串轉換為一組字元。重要的是將引數 batched=True
傳遞給 map(...)
函式,以便對映函式可以同時訪問所有轉錄。
def extract_all_chars(batch):
all_text = " ".join(batch["sentence"])
vocab = list(set(all_text))
return {"vocab": [vocab], "all_text": [all_text]}
vocab_train = common_voice_train.map(extract_all_chars, batched=True, batch_size=-1, keep_in_memory=True, remove_columns=common_voice_train.column_names)
vocab_test = common_voice_test.map(extract_all_chars, batched=True, batch_size=-1, keep_in_memory=True, remove_columns=common_voice_test.column_names)
現在,我們建立訓練集和測試集中所有不同字母的並集,並將結果列表轉換為一個帶索引的字典。
vocab_list = list(set(vocab_train["vocab"][0]) | set(vocab_test["vocab"][0]))
vocab_dict = {v: k for k, v in enumerate(sorted(vocab_list))}
vocab_dict
列印輸出
{
' ': 0,
'a': 1,
'b': 2,
'c': 3,
'd': 4,
'e': 5,
'f': 6,
'g': 7,
'h': 8,
'i': 9,
'j': 10,
'k': 11,
'l': 12,
'm': 13,
'n': 14,
'o': 15,
'p': 16,
'q': 17,
'r': 18,
's': 19,
't': 20,
'u': 21,
'v': 22,
'w': 23,
'x': 24,
'y': 25,
'z': 26,
'ç': 27,
'ë': 28,
'ö': 29,
'ü': 30,
'ğ': 31,
'ı': 32,
'ş': 33,
'̇': 34
}
太棒了,我們看到所有字母都出現在資料集中(這並不奇怪),而且我們還提取了特殊字元 ""
和 '
。請注意,我們沒有排除這些特殊字元,因為:
模型必須學會預測何時一個詞結束,否則模型預測將總是一個字元序列,這將導致無法將詞彼此分離。
人們應該始終牢記,預處理是模型訓練前非常重要的一步。例如,我們不希望模型區分 a
和 A
僅僅因為我們忘記了規範化資料。a
和 A
之間的區別根本不取決於字母的“發音”,而更多地取決於語法規則——例如,在句子開頭使用大寫字母。因此,消除大寫字母和非大寫字母之間的差異是明智的,這樣模型就能更容易地學習轉錄語音。
為了更清楚地表明 " "
擁有自己的標記類別,我們給它一個更明顯的字元 |
。此外,我們還添加了一個“未知”標記,以便模型以後可以處理 Common Voice 訓練集中未遇到的字元。
vocab_dict["|"] = vocab_dict[" "]
del vocab_dict[" "]
最後,我們還添加了一個填充標記,對應於 CTC 的“空白標記”。“空白標記”是 CTC 演算法的核心組成部分。有關更多資訊,請參閱此處的“對齊”部分。
vocab_dict["[UNK]"] = len(vocab_dict)
vocab_dict["[PAD]"] = len(vocab_dict)
len(vocab_dict)
太棒了,現在我們的詞彙表已完成,包含 39 個標記,這意味著我們將新增到預訓練 XLS-R 檢查點之上的線性層將具有 39 的輸出維度。
現在讓我們將詞彙表儲存為 json 檔案。
import json
with open('vocab.json', 'w') as vocab_file:
json.dump(vocab_dict, vocab_file)
最後一步,我們使用 json 檔案將詞彙表載入到 Wav2Vec2CTCTokenizer
類的一個例項中。
from transformers import Wav2Vec2CTCTokenizer
tokenizer = Wav2Vec2CTCTokenizer.from_pretrained("./", unk_token="[UNK]", pad_token="[PAD]", word_delimiter_token="|")
如果想在本筆記本中將剛剛建立的分詞器與微調模型重複使用,強烈建議將 tokenizer
上傳到 Hugging Face Hub。我們將上傳檔案的倉庫命名為 "wav2vec2-large-xlsr-turkish-demo-colab"
。
repo_name = "wav2vec2-large-xls-r-300m-tr-colab"
然後將分詞器上傳到 🤗 Hub。
tokenizer.push_to_hub(repo_name)
太好了,您可以在 https://huggingface.co/<your-username>/wav2vec2-large-xls-r-300m-tr-colab
下檢視剛剛建立的倉庫
建立 Wav2Vec2FeatureExtractor
語音是一種連續訊號,為了能被計算機處理,它首先必須被離散化,這通常被稱為**取樣**。取樣率在這裡起著重要作用,因為它定義了每秒測量多少語音訊號資料點。因此,更高的取樣率會導致對*真實*語音訊號更好的近似,但每秒也需要更多值。
預訓練的檢查點期望其輸入資料以與訓練時所用資料大致相同的分佈進行取樣。以兩種不同速率取樣的相同語音訊號具有非常不同的分佈。例如,取樣率加倍會導致資料點時長加倍。因此,在微調 ASR 模型的預訓練檢查點之前,驗證用於預訓練模型的資料的取樣率是否與用於微調模型的資料集的取樣率匹配至關重要。
XLS-R 在 16kHz 取樣率的 Babel、多語言 LibriSpeech (MLS)、Common Voice、VoxPopuli 和 VoxLingua107 音訊資料上進行了預訓練。Common Voice 的原始取樣率為 48kHz,因此我們接下來需要將微調資料下采樣到 16kHz。
Wav2Vec2FeatureExtractor
物件需要例項化以下引數:
feature_size
: 語音模型將特徵向量序列作為輸入。雖然此序列的長度顯然不同,但特徵大小不應改變。在 Wav2Vec2 的情況下,特徵大小為 1,因為模型是在原始語音訊號上訓練的 。sampling_rate
: 模型訓練時使用的取樣率。padding_value
:對於批次推理,較短的輸入需要用特定值填充。do_normalize
:輸入是否應該進行零均值單位方差歸一化。通常,語音模型在歸一化輸入後表現更好。return_attention_mask
: 模型是否應該使用attention_mask
進行批次推理。通常,XLS-R 模型檢查點**總是**應該使用attention_mask
。
from transformers import Wav2Vec2FeatureExtractor
feature_extractor = Wav2Vec2FeatureExtractor(feature_size=1, sampling_rate=16000, padding_value=0.0, do_normalize=True, return_attention_mask=True)
太好了,XLS-R 的特徵提取管道由此完全定義!
為了提高使用者友好性,特徵提取器和分詞器被封裝在一個 Wav2Vec2Processor
類中,這樣只需要一個 model
和 processor
物件。
from transformers import Wav2Vec2Processor
processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)
接下來,我們可以準備資料集了。
預處理資料
到目前為止,我們還沒有檢視語音訊號的實際值,而只是轉錄。除了 sentence
,我們的資料集中還包含另外兩個列名 path
和 audio
。path
表示音訊檔案的絕對路徑。讓我們看看。
common_voice_train[0]["path"]
XLS-R 期望以 16 kHz 的一維陣列格式輸入。這意味著必須載入並重取樣音訊檔案。
幸運的是,datasets
透過呼叫另一個列 audio
自動完成此操作。讓我們試一試。
common_voice_train[0]["audio"]
{'array': array([ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, ...,
-8.8930130e-05, -3.8027763e-05, -2.9146671e-05], dtype=float32),
'path': '/root/.cache/huggingface/datasets/downloads/extracted/05be0c29807a73c9b099873d2f5975dae6d05e9f7d577458a2466ecb9a2b0c6b/cv-corpus-6.1-2020-12-11/tr/clips/common_voice_tr_21921195.mp3',
'sampling_rate': 48000}
太好了,我們可以看到音訊檔案已自動載入。這要歸功於 datasets == 1.18.3
中引入的新的 "Audio"
特性,它可以在呼叫時動態載入和重取樣音訊檔案。
在上面的示例中,我們可以看到音訊資料以 48kHz 的取樣率載入,而模型期望的取樣率為 16kHz。我們可以透過使用 cast_column
將音訊特徵設定為正確的取樣率
common_voice_train = common_voice_train.cast_column("audio", Audio(sampling_rate=16_000))
common_voice_test = common_voice_test.cast_column("audio", Audio(sampling_rate=16_000))
我們再看看 "audio"
。
common_voice_train[0]["audio"]
{'array': array([ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, ...,
-7.4556941e-05, -1.4621433e-05, -5.7861507e-05], dtype=float32),
'path': '/root/.cache/huggingface/datasets/downloads/extracted/05be0c29807a73c9b099873d2f5975dae6d05e9f7d577458a2466ecb9a2b0c6b/cv-corpus-6.1-2020-12-11/tr/clips/common_voice_tr_21921195.mp3',
'sampling_rate': 16000}
這似乎奏效了!讓我們聽幾段音訊檔案,以便更好地理解資料集並驗證音訊是否正確載入。
import IPython.display as ipd
import numpy as np
import random
rand_int = random.randint(0, len(common_voice_train)-1)
print(common_voice_train[rand_int]["sentence"])
ipd.Audio(data=common_voice_train[rand_int]["audio"]["array"], autoplay=True, rate=16000)
列印輸出
sunulan bütün teklifler i̇ngilizce idi
資料現在似乎已正確載入並重新取樣。
可以聽到,說話人的語速、口音和背景環境等都在變化。然而,總體而言,錄音聽起來清晰可接受,這對於眾包朗讀語音語料庫來說是意料之中的。
讓我們做最後一次檢查,確認資料準備是否正確,透過列印語音輸入的形狀、其轉寫文字以及相應的取樣率。
rand_int = random.randint(0, len(common_voice_train)-1)
print("Target text:", common_voice_train[rand_int]["sentence"])
print("Input array shape:", common_voice_train[rand_int]["audio"]["array"].shape)
print("Sampling rate:", common_voice_train[rand_int]["audio"]["sampling_rate"])
列印輸出
Target text: makedonya bu yıl otuz adet tyetmiş iki tankı aldı
Input array shape: (71040,)
Sampling rate: 16000
好的!一切看起來都沒問題——資料是一維陣列,取樣率總是 16kHz,目標文字也已規範化。
最後,我們可以利用 Wav2Vec2Processor
將資料處理成 Wav2Vec2ForCTC
訓練所需的格式。為此,我們使用 Dataset 的 map(...)
函式。
首先,我們載入並重取樣音訊資料,只需呼叫 batch["audio"]
。其次,我們從載入的音訊檔案中提取 input_values
。在我們的例子中,Wav2Vec2Processor
只對資料進行歸一化。然而,對於其他語音模型,此步驟可能包括更復雜的特徵提取,例如 Log-Mel 特徵提取。第三,我們將轉錄編碼為標籤 ID。
注意:此對映函式是 Wav2Vec2Processor
類應如何使用的良好示例。在“正常”情況下,呼叫 processor(...)
會被重定向到 Wav2Vec2FeatureExtractor
的呼叫方法。但是,當將處理器包裝到 as_target_processor
上下文中時,相同的方法會被重定向到 Wav2Vec2CTCTokenizer
的呼叫方法。有關更多資訊,請查閱文件。
def prepare_dataset(batch):
audio = batch["audio"]
# batched output is "un-batched"
batch["input_values"] = processor(audio["array"], sampling_rate=audio["sampling_rate"]).input_values[0]
batch["input_length"] = len(batch["input_values"])
with processor.as_target_processor():
batch["labels"] = processor(batch["sentence"]).input_ids
return batch
讓我們將資料準備函式應用到所有樣本上。
common_voice_train = common_voice_train.map(prepare_dataset, remove_columns=common_voice_train.column_names)
common_voice_test = common_voice_test.map(prepare_dataset, remove_columns=common_voice_test.column_names)
注意:目前 datasets
利用 torchaudio
和 librosa
進行音訊載入和重取樣。如果您希望實現自己的自定義資料載入/取樣,請隨意使用 "path"
列並忽略 "audio"
列。
長輸入序列需要大量記憶體。XLS-R 基於 self-attention
。對於長輸入序列,記憶體需求與輸入長度呈二次方關係(參見這篇 Reddit 帖子)。如果此演示因“記憶體不足”錯誤而崩潰,您可能需要取消註釋以下行,以過濾所有長度超過 5 秒的訓練序列。
#max_input_length_in_sec = 5.0
#common_voice_train = common_voice_train.filter(lambda x: x < max_input_length_in_sec * processor.feature_extractor.sampling_rate, input_columns=["input_length"])
太棒了,現在我們準備好開始訓練了!
訓練
資料已處理完畢,我們已準備好開始設定訓練管道。我們將使用 🤗 的 Trainer,為此我們主要需要執行以下操作:
定義一個數據整理器。與大多數 NLP 模型不同,XLS-R 的輸入長度遠大於輸出長度。例如,輸入長度為 50000 的樣本的輸出長度不超過 100。鑑於輸入大小較大,動態填充訓練批次效率更高,這意味著所有訓練樣本應僅填充到其批次中最長的樣本,而不是整體最長的樣本。因此,微調 XLS-R 需要一個特殊的填充資料整理器,我們將在下面定義。
評估指標。在訓練期間,模型應以詞錯誤率進行評估。我們應該相應地定義一個
compute_metrics
函式。載入預訓練檢查點。我們需要載入預訓練檢查點並對其進行正確配置以進行訓練。
定義訓練配置。
在微調模型後,我們將在測試資料上對其進行正確評估,並驗證它確實學會了正確轉寫語音。
設定訓練器
我們從定義資料整理器開始。資料整理器的程式碼是從 這個示例 複製的。
不深入細節,與常見的資料整理器不同,此資料整理器對 input_values
和 labels
進行不同的處理,因此對它們應用單獨的填充函式(再次利用 XLS-R 處理器上下文管理器)。這是必要的,因為在語音中,輸入和輸出屬於不同的模態,這意味著它們不應由相同的填充函式處理。與常見的資料整理器類似,標籤中的填充標記用 -100
填充,以便在計算損失時**不**考慮這些標記。
import torch
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union
@dataclass
class DataCollatorCTCWithPadding:
"""
Data collator that will dynamically pad the inputs received.
Args:
processor (:class:`~transformers.Wav2Vec2Processor`)
The processor used for proccessing the data.
padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
among:
* :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
sequence if provided).
* :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
maximum acceptable input length for the model if that argument is not provided.
* :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
different lengths).
"""
processor: Wav2Vec2Processor
padding: Union[bool, str] = True
def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
# split inputs and labels since they have to be of different lengths and need
# different padding methods
input_features = [{"input_values": feature["input_values"]} for feature in features]
label_features = [{"input_ids": feature["labels"]} for feature in features]
batch = self.processor.pad(
input_features,
padding=self.padding,
return_tensors="pt",
)
with self.processor.as_target_processor():
labels_batch = self.processor.pad(
label_features,
padding=self.padding,
return_tensors="pt",
)
# replace padding with -100 to ignore loss correctly
labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
batch["labels"] = labels
return batch
data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)
接下來,定義評估指標。如前所述,ASR 中最主要的指標是詞錯誤率 (WER),因此我們在這個 notebook 中也使用它。
wer_metric = load_metric("wer")
模型將返回一系列 logit 向量:,其中 和 。
一個 logit 向量 包含我們之前定義的詞彙表中每個詞的對數機率,因此 config.vocab_size
。我們對模型最可能的預測感興趣,因此取 logits 的 argmax(...)
。此外,我們將編碼後的標籤透過將 -100
替換為 pad_token_id
來轉換回原始字串,並在解碼 ID 時確保連續的標記在 CTC 樣式中**不**被分組為相同的標記 。
def compute_metrics(pred):
pred_logits = pred.predictions
pred_ids = np.argmax(pred_logits, axis=-1)
pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id
pred_str = processor.batch_decode(pred_ids)
# we do not want to group tokens when computing the metrics
label_str = processor.batch_decode(pred.label_ids, group_tokens=False)
wer = wer_metric.compute(predictions=pred_str, references=label_str)
return {"wer": wer}
現在,我們可以載入 Wav2Vec2-XLS-R-300M 的預訓練檢查點。分詞器的 pad_token_id
必須用於定義模型的 pad_token_id
,或者在 Wav2Vec2ForCTC
的情況下,也定義 CTC 的*空白標記* 。為了節省 GPU 記憶體,我們啟用了 PyTorch 的 梯度檢查點,並將損失減少設定為“mean”。
由於資料集相當小(約 6 小時訓練資料),並且 Common Voice 相當嘈雜,微調 Facebook 的 wav2vec2-xls-r-300m 檢查點似乎需要一些超引數調整。因此,我不得不嘗試不同的 dropout 值、SpecAugment 的掩碼 dropout 率、層 dropout 和學習率,直到訓練看起來足夠穩定。
注意:如果使用本筆記本在 Common Voice 的另一種語言上訓練 XLS-R,這些超引數設定可能效果不佳。請根據您的用例隨意調整這些引數。
from transformers import Wav2Vec2ForCTC
model = Wav2Vec2ForCTC.from_pretrained(
"facebook/wav2vec2-xls-r-300m",
attention_dropout=0.0,
hidden_dropout=0.0,
feat_proj_dropout=0.0,
mask_time_prob=0.05,
layerdrop=0.0,
ctc_loss_reduction="mean",
pad_token_id=processor.tokenizer.pad_token_id,
vocab_size=len(processor.tokenizer),
)
XLS-R 的第一個元件由一堆 CNN 層組成,用於從原始語音訊號中提取具有聲學意義但與上下文無關的特徵。模型的這一部分在預訓練期間已經充分訓練,並且如論文所述,不再需要微調。因此,我們可以將*特徵提取*部分的所有引數的 requires_grad
設定為 False
。
model.freeze_feature_extractor()
最後一步,我們定義所有與訓練相關的引數。對其中一些引數進行更多解釋:
group_by_length
透過將輸入長度相似的訓練樣本分組到一個批次中,使訓練更高效。這可以透過大大減少模型中無用填充標記的總數來顯著加快訓練時間。learning_rate
和weight_decay
經過啟發式調整,直到微調變得穩定。請注意,這些引數強烈依賴於 Common Voice 資料集,並且對於其他語音資料集可能不是最優的。
關於其他引數的更多解釋,可以檢視文件。
在訓練期間,每 400 個訓練步驟將非同步上傳一個檢查點到 Hub。這允許您在模型仍在訓練時也可以使用演示小部件。
注意:如果不想將模型檢查點上傳到 Hub,只需設定 push_to_hub=False
。
from transformers import TrainingArguments
training_args = TrainingArguments(
output_dir=repo_name,
group_by_length=True,
per_device_train_batch_size=16,
gradient_accumulation_steps=2,
evaluation_strategy="steps",
num_train_epochs=30,
gradient_checkpointing=True,
fp16=True,
save_steps=400,
eval_steps=400,
logging_steps=400,
learning_rate=3e-4,
warmup_steps=500,
save_total_limit=2,
push_to_hub=True,
)
現在,所有例項都可以傳遞給 Trainer,我們準備開始訓練了!
from transformers import Trainer
trainer = Trainer(
model=model,
data_collator=data_collator,
args=training_args,
compute_metrics=compute_metrics,
train_dataset=common_voice_train,
eval_dataset=common_voice_test,
tokenizer=processor.feature_extractor,
)
為了使模型獨立於說話人語速,在 CTC 中,連續的相同標記簡單地被分組為單個標記。然而,在解碼時,編碼的標籤不應被分組,因為它們與模型的預測標記不對應,這就是為什麼必須傳遞 group_tokens=False
引數。如果我們不傳遞此引數,像 "hello"
這樣的詞將被錯誤地編碼並解碼為 "helo"
。 空白標記允許模型透過強制在兩個 l 之間插入空白標記來預測像 "hello"
這樣的詞。我們模型的 "hello"
的 CTC 符合預測將是 [PAD] [PAD] "h" "e" "e" "l" "l" [PAD] "l" "o" "o" [PAD]
。
訓練
訓練將花費數小時,具體取決於分配給此筆記本的 GPU。雖然訓練好的模型在土耳其語 Common Voice 的測試資料上取得了令人滿意的結果,但它絕不是一個最優微調的模型。本筆記本的目的是演示如何在 ASR 資料集上微調 XLS-R XLSR-Wav2Vec2。
根據分配給您的 Google Colab 的 GPU,您可能會在這裡看到“記憶體不足”錯誤。在這種情況下,最好將 per_device_train_batch_size
減小到 8 甚至更少,並增加 gradient_accumulation
。
trainer.train()
列印輸出
訓練損失 | 輪次 | 步驟 | 驗證損失 | 詞錯誤率 (Wer) |
---|---|---|---|---|
3.8842 | 3.67 | 400 | 0.6794 | 0.7000 |
0.4115 | 7.34 | 800 | 0.4304 | 0.4548 |
0.1946 | 11.01 | 1200 | 0.4466 | 0.4216 |
0.1308 | 14.68 | 1600 | 0.4526 | 0.3961 |
0.0997 | 18.35 | 2000 | 0.4567 | 0.3696 |
0.0784 | 22.02 | 2400 | 0.4193 | 0.3442 |
0.0633 | 25.69 | 2800 | 0.4153 | 0.3347 |
0.0498 | 29.36 | 3200 | 0.4077 | 0.3195 |
訓練損失和驗證 WER 都在穩步下降。
您現在可以將訓練結果上傳到 Hub,只需執行此指令即可
trainer.push_to_hub()
你現在可以和所有的朋友、家人、心愛的寵物分享這個模型:他們都可以用“your-username/the-name-you-picked”這個識別符號來載入它,例如:
from transformers import AutoModelForCTC, Wav2Vec2Processor
model = AutoModelForCTC.from_pretrained("patrickvonplaten/wav2vec2-large-xls-r-300m-tr-colab")
processor = Wav2Vec2Processor.from_pretrained("patrickvonplaten/wav2vec2-large-xls-r-300m-tr-colab")
有關如何微調 XLS-R 的更多示例,請參閱官方 🤗 Transformers 示例。
評估
最後,我們載入模型並驗證它是否確實學會了轉錄土耳其語語音。
讓我們首先載入預訓練檢查點。
model = Wav2Vec2ForCTC.from_pretrained(repo_name).to("cuda")
processor = Wav2Vec2Processor.from_pretrained(repo_name)
現在,我們只取測試集中的第一個示例,透過模型執行它,並取 logits 的 argmax(...)
來檢索預測的標記 ID。
input_dict = processor(common_voice_test[0]["input_values"], return_tensors="pt", padding=True)
logits = model(input_dict.input_values.to("cuda")).logits
pred_ids = torch.argmax(logits, dim=-1)[0]
強烈建議將 sampling_rate
引數傳遞給此函式。否則可能會導致難以除錯的靜默錯誤。
我們對 common_voice_test
進行了相當大的修改,因此資料集例項不再包含原始句子標籤。因此,我們重新使用原始資料集來獲取第一個示例的標籤。
common_voice_test_transcription = load_dataset("common_voice", "tr", data_dir="./cv-corpus-6.1-2020-12-11", split="test")
最後,我們可以對示例進行解碼。
print("Prediction:")
print(processor.decode(pred_ids))
print("\nReference:")
print(common_voice_test_transcription[0]["sentence"].lower())
列印輸出
預測字串 | 目標文字 |
---|---|
hatta küçük şeyleri için bir büyt bir şeyleri kolluyor veyınıki çuk şeyler için bir bir mizi inciltiyoruz | hayatta küçük şeyleri kovalıyor ve yine küçük şeyler için birbirimizi incitiyoruz. |
好的!轉錄無疑可以從我們的預測中識別出來,但它還不夠完美。模型訓練時間再長一點,在資料預處理上投入更多時間,特別是使用語言模型進行解碼,肯定會提高模型的整體效能。
然而,對於低資源語言的演示模型來說,結果還是相當不錯的🤗。