Transformers 文件
特徵提取器
並獲得增強的文件體驗
開始使用
特徵提取器
特徵提取器將音訊資料預處理成給定模型所需的正確格式。它接收原始音訊訊號並將其轉換為可饋送到模型的張量。張量形狀取決於模型,但特徵提取器會根據你使用的模型為你正確預處理音訊資料。特徵提取器還包括填充、截斷和重取樣方法。
呼叫 from_pretrained() 從 Hugging Face Hub 或本地目錄載入特徵提取器及其預處理器配置。特徵提取器和預處理器配置儲存在 preprocessor_config.json 檔案中。
將音訊訊號(通常儲存在 `array` 中)傳遞給特徵提取器,並將 `sampling_rate` 引數設定為預訓練音訊模型的取樣率。重要的是,音訊資料的取樣率必須與預訓練音訊模型訓練所用資料的取樣率相匹配。
from transformers import AutoFeatureExtractor
feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base")
processed_sample = feature_extractor(dataset[0]["audio"]["array"], sampling_rate=16000)
processed_sample
{'input_values': [array([ 9.4472744e-05, 3.0777880e-03, -2.8888427e-03, ...,
-2.8888427e-03, 9.4472744e-05, 9.4472744e-05], dtype=float32)]}
特徵提取器返回一個輸入,`input_values`,該輸入已準備好供模型使用。
本指南將引導你瞭解特徵提取器類以及如何預處理音訊資料。
特徵提取器類
Transformers 特徵提取器繼承自基類 SequenceFeatureExtractor,該類是 FeatureExtractionMixin 的子類。
- SequenceFeatureExtractor 提供了一個方法 pad(),用於將序列填充到特定長度,以避免序列長度不一致。
- FeatureExtractionMixin 提供了 from_pretrained() 和 save_pretrained() 來載入和儲存特徵提取器。
有兩種方法可以載入特徵提取器:AutoFeatureExtractor 和模型特定的特徵提取器類。
AutoClass API 會自動為給定模型載入正確的特徵提取器。
使用 from_pretrained() 載入特徵提取器。
from transformers import AutoFeatureExtractor
feature_extractor = AutoFeatureExtractor.from_pretrained("openai/whisper-tiny")
預處理
特徵提取器期望輸入為特定形狀的 PyTorch 張量。確切的輸入形狀可能因你使用的特定音訊模型而異。
例如,Whisper 期望 `input_features` 是形狀為 `(batch_size, feature_size, sequence_length)` 的張量,而 Wav2Vec2 期望 `input_values` 是形狀為 `(batch_size, sequence_length)` 的張量。
特徵提取器會為所使用的任何音訊模型生成正確的輸入形狀。
特徵提取器還設定音訊檔案的取樣率(每秒取樣的音訊訊號值數量)。你的音訊資料的取樣率必須與預訓練模型訓練所用資料集的取樣率匹配。該值通常在模型卡中給出。
使用 from_pretrained() 載入資料集和特徵提取器。
from datasets import load_dataset, Audio
from transformers import AutoFeatureExtractor
dataset = load_dataset("PolyAI/minds14", name="en-US", split="train")
feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base")
檢視資料集中的第一個示例,並訪問包含原始音訊訊號 `array` 的 `audio` 列。
dataset[0]["audio"]["array"]
array([ 0. , 0.00024414, -0.00024414, ..., -0.00024414,
0. , 0. ])
特徵提取器將 `array` 預處理成給定音訊模型的預期輸入格式。使用 `sampling_rate` 引數設定合適的取樣率。
processed_dataset = feature_extractor(dataset[0]["audio"]["array"], sampling_rate=16000)
processed_dataset
{'input_values': [array([ 9.4472744e-05, 3.0777880e-03, -2.8888427e-03, ...,
-2.8888427e-03, 9.4472744e-05, 9.4472744e-05], dtype=float32)]}
填充
音訊序列長度不同是一個問題,因為 Transformers 期望所有序列具有相同的長度,以便進行批處理。長度不等的序列無法進行批處理。
dataset[0]["audio"]["array"].shape
(86699,)
dataset[1]["audio"]["array"].shape
(53248,)
填充會新增一個特殊的*填充標記*,以確保所有序列都具有相同的長度。特徵提取器會將一個 `0`(解釋為靜音)新增到 `array` 中進行填充。設定 `padding=True` 可以將序列填充到批次中最長序列的長度。
def preprocess_function(examples):
audio_arrays = [x["array"] for x in examples["audio"]]
inputs = feature_extractor(
audio_arrays,
sampling_rate=16000,
padding=True,
)
return inputs
processed_dataset = preprocess_function(dataset[:5])
processed_dataset["input_values"][0].shape
(86699,)
processed_dataset["input_values"][1].shape
(86699,)
截斷
模型只能處理到一定長度的序列,否則會崩潰。
截斷是一種從序列中移除多餘標記以確保其不超過最大長度的策略。將 `truncation=True` 設定為截斷序列到 `max_length` 引數指定的長度。
def preprocess_function(examples):
audio_arrays = [x["array"] for x in examples["audio"]]
inputs = feature_extractor(
audio_arrays,
sampling_rate=16000,
max_length=50000,
truncation=True,
)
return inputs
processed_dataset = preprocess_function(dataset[:5])
processed_dataset["input_values"][0].shape
(50000,)
processed_dataset["input_values"][1].shape
(50000,)
重取樣
Datasets 庫也可以重取樣音訊資料,使其與音訊模型預期的取樣率匹配。這種方法在載入音訊資料時即時進行重取樣,這可能比就地重取樣整個資料集更快。
你正在處理的音訊資料集的取樣率為 8kHz,而預訓練模型期望的取樣率為 16kHz。
dataset[0]["audio"]
{'path': '/root/.cache/huggingface/datasets/downloads/extracted/f507fdca7f475d961f5bb7093bcc9d544f16f8cab8608e772a2ed4fbeb4d6f50/en-US~JOINT_ACCOUNT/602ba55abb1e6d0fbce92065.wav',
'array': array([ 0. , 0.00024414, -0.00024414, ..., -0.00024414,
0. , 0. ]),
'sampling_rate': 8000}
對 `audio` 列呼叫 cast_column,將取樣率上取樣到 16kHz。
dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))
當你載入資料集樣本時,它現在被重取樣到 16kHz。
dataset[0]["audio"]
{'path': '/root/.cache/huggingface/datasets/downloads/extracted/f507fdca7f475d961f5bb7093bcc9d544f16f8cab8608e772a2ed4fbeb4d6f50/en-US~JOINT_ACCOUNT/602ba55abb1e6d0fbce92065.wav',
'array': array([ 1.70562416e-05, 2.18727451e-04, 2.28099874e-04, ...,
3.43842403e-05, -5.96364771e-06, -1.76846661e-05]),
'sampling_rate': 16000}