Datasets 文件
Dataset 和 IterableDataset 之間的區別
並獲得增強的文件體驗
開始使用
Dataset 和 IterableDataset 之間的區別
有兩種型別的資料集物件:Dataset 和 IterableDataset。您選擇使用或建立哪種型別的資料集取決於資料集的大小。總的來說,由於其惰性行為和速度優勢,IterableDataset 非常適合處理大型資料集(想想數百 GB!),而 Dataset 則適用於其他所有情況。本頁將比較 Dataset 和 IterableDataset 之間的區別,以幫助您選擇適合自己的資料集物件。
下載與流式傳輸
當您有一個常規的 Dataset 時,您可以使用 `my_dataset[0]` 來訪問它。這提供了對行的隨機訪問。這類資料集也稱為“對映式”(map-style)資料集。例如,您可以這樣下載 ImageNet-1k 並訪問任何行:
from datasets import load_dataset
imagenet = load_dataset("timm/imagenet-1k-wds", split="train") # downloads the full dataset
print(imagenet[0])
但一個缺點是,您必須將整個資料集儲存在磁碟或記憶體中,這會阻止您訪問比磁碟更大的資料集。由於這對大型資料集可能帶來不便,因此存在另一種型別的資料集,即 IterableDataset。當您有一個 `IterableDataset` 時,您可以使用 `for` 迴圈在迭代資料集時逐步載入資料。這樣,只有一小部分樣本被載入到記憶體中,並且您不會在磁碟上寫入任何內容。
例如,您可以流式傳輸 ImageNet-1k 資料集而無需將其下載到磁碟:
from datasets import load_dataset
imagenet = load_dataset("timm/imagenet-1k-wds", split="train", streaming=True) # will start loading the data when iterated over
for example in imagenet:
print(example)
break
流式傳輸可以讀取線上資料而無需向磁碟寫入任何檔案。例如,您可以流式傳輸由多個分片(shard)組成的資料集,每個分片都可能有數百 GB,例如 C4 或 LAION-2B。有關如何流式傳輸資料集的更多資訊,請參閱資料集流式傳輸指南。
但這並不是唯一的區別,因為 `IterableDataset` 的“惰性”行為在資料集的建立和處理方面也同樣存在。
建立對映式資料集和可迭代資料集
您可以使用列表或字典建立 Dataset,資料會完全轉換為 Arrow 格式,以便您輕鬆訪問任何行:
my_dataset = Dataset.from_dict({"col_1": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]})
print(my_dataset[0])
另一方面,要建立 `IterableDataset`,您必須提供一種“惰性”載入資料的方式。在 Python 中,我們通常使用生成器函式。這些函式一次 `yield` 一個樣本,這意味著您不能像常規 `Dataset` 那樣透過切片來訪問行:
def my_generator(n):
for i in range(n):
yield {"col_1": i}
my_iterable_dataset = IterableDataset.from_generator(my_generator, gen_kwargs={"n": 10})
for example in my_iterable_dataset:
print(example)
break
完全載入與逐步載入本地檔案
可以使用 load_dataset() 將本地或遠端資料檔案轉換為 Arrow 格式的 Dataset:
data_files = {"train": ["path/to/data.csv"]}
my_dataset = load_dataset("csv", data_files=data_files, split="train")
print(my_dataset[0])
然而,這需要一個從 CSV 到 Arrow 格式的轉換步驟,如果您的資料集很大,這會消耗時間和磁碟空間。
為了節省磁碟空間並跳過轉換步驟,您可以透過直接從本地檔案流式傳輸來定義一個 `IterableDataset`。這樣,當您迭代資料集時,資料會逐步從本地檔案中讀取:
data_files = {"train": ["path/to/data.csv"]}
my_iterable_dataset = load_dataset("csv", data_files=data_files, split="train", streaming=True)
for example in my_iterable_dataset: # this reads the CSV file progressively as you iterate over the dataset
print(example)
break
支援多種檔案格式,如 CSV、JSONL 和 Parquet,以及影像和音訊檔案。您可以在相應的指南中找到更多關於載入表格、文字、視覺和音訊資料集的資訊。
即時資料處理與惰性資料處理
當您使用 Dataset.map() 處理一個 Dataset 物件時,整個資料集會立即被處理並返回。這與 `pandas` 的工作方式類似。
my_dataset = my_dataset.map(process_fn) # process_fn is applied on all the examples of the dataset
print(my_dataset[0])
另一方面,由於 `IterableDataset` 的“惰性”特性,呼叫 IterableDataset.map() 並不會將您的 `map` 函式應用於整個資料集。相反,您的 `map` 函式是即時應用的。
因此,您可以連結多個處理步驟,當您開始迭代資料集時,它們會一次性全部執行:
my_iterable_dataset = my_iterable_dataset.map(process_fn_1)
my_iterable_dataset = my_iterable_dataset.filter(filter_fn)
my_iterable_dataset = my_iterable_dataset.map(process_fn_2)
# process_fn_1, filter_fn and process_fn_2 are applied on-the-fly when iterating over the dataset
for example in my_iterable_dataset:
print(example)
break
精確洗牌與快速近似洗牌
當您使用 Dataset.shuffle() 對 Dataset 進行洗牌時,您應用的是對資料集的精確洗牌。它的工作原理是獲取一個索引列表 `[0, 1, 2, ... len(my_dataset) - 1]` 並對這個列表進行洗牌。然後,訪問 `my_dataset[0]` 會返回由洗牌後的索引對映的第一個元素定義的行和索引:
my_dataset = my_dataset.shuffle(seed=42)
print(my_dataset[0])
由於在 `IterableDataset` 的情況下我們沒有對行的隨機訪問許可權,因此我們不能使用一個洗牌後的索引列表來訪問任意位置的行。這使得精確洗牌無法實現。取而代之的是,在 IterableDataset.shuffle() 中使用了一種快速的近似洗牌方法。它使用一個洗牌緩衝區來迭代地從資料集中抽樣隨機樣本。由於資料集仍然是迭代讀取的,它提供了出色的速度效能。
my_iterable_dataset = my_iterable_dataset.shuffle(seed=42, buffer_size=100)
for example in my_iterable_dataset:
print(example)
break
但僅使用洗牌緩衝區不足以為機器學習模型訓練提供令人滿意的洗牌效果。因此,如果您的資料集由多個檔案或來源組成,IterableDataset.shuffle() 也會對資料集的分片進行洗牌。
# Stream from the internet
my_iterable_dataset = load_dataset("deepmind/code_contests", split="train", streaming=True)
my_iterable_dataset.num_shards # 39
# Stream from local files
data_files = {"train": [f"path/to/data_{i}.csv" for i in range(1024)]}
my_iterable_dataset = load_dataset("csv", data_files=data_files, split="train", streaming=True)
my_iterable_dataset.num_shards # 1024
# From a generator function
def my_generator(n, sources):
for source in sources:
for example_id_for_current_source in range(n):
yield {"example_id": f"{source}_{example_id_for_current_source}"}
gen_kwargs = {"n": 10, "sources": [f"path/to/data_{i}" for i in range(1024)]}
my_iterable_dataset = IterableDataset.from_generator(my_generator, gen_kwargs=gen_kwargs)
my_iterable_dataset.num_shards # 1024
速度差異
常規的 Dataset 物件基於 Arrow,它提供了對行的快速隨機訪問。得益於記憶體對映以及 Arrow 是一種記憶體格式的事實,從磁碟讀取資料不會進行昂貴的系統呼叫和反序列化。透過在連續的 Arrow 記錄批次上迭代,使用 `for` 迴圈迭代時,它提供了更快的資料載入速度。
然而,一旦您的 Dataset 有了索引對映(例如透過 Dataset.shuffle()),速度可能會慢 10 倍。這是因為有一個額外的步驟來使用索引對映獲取要讀取的行索引,更重要的是,您不再讀取連續的資料塊。要恢復速度,您需要使用 Dataset.flatten_indices() 再次將整個資料集重寫到磁碟上,這會移除索引對映。不過,這可能需要很長時間,具體取決於您的資料集大小。
my_dataset[0] # fast
my_dataset = my_dataset.shuffle(seed=42)
my_dataset[0] # up to 10x slower
my_dataset = my_dataset.flatten_indices() # rewrite the shuffled dataset on disk as contiguous chunks of data
my_dataset[0] # fast again
在這種情況下,我們建議切換到 IterableDataset 並利用其快速的近似洗牌方法 IterableDataset.shuffle()。它只對分片順序進行洗牌,併為您的資料集新增一個洗牌緩衝區,這能保持資料集的最佳速度。您也可以輕鬆地重新洗牌資料集。
for example in enumerate(my_iterable_dataset): # fast
pass
shuffled_iterable_dataset = my_iterable_dataset.shuffle(seed=42, buffer_size=100)
for example in enumerate(shuffled_iterable_dataset): # as fast as before
pass
shuffled_iterable_dataset = my_iterable_dataset.shuffle(seed=1337, buffer_size=100) # reshuffling using another seed is instantaneous
for example in enumerate(shuffled_iterable_dataset): # still as fast as before
pass
如果您在多個週期(epoch)上使用資料集,用於洗牌分片順序和洗牌緩衝區的有效種子是 `seed + epoch`。這使得在不同週期之間輕鬆地重新洗牌資料整合為可能。
for epoch in range(n_epochs):
my_iterable_dataset.set_epoch(epoch)
for example in my_iterable_dataset: # fast + reshuffled at each epoch using `effective_seed = seed + epoch`
pass
要重新開始迭代一個對映式資料集,您只需跳過前面的樣本即可:
my_dataset = my_dataset.select(range(start_index, len(dataset)))
但如果您使用帶 `Sampler` 的 `DataLoader`,您應該儲存您的取樣器狀態(您可能編寫了一個允許恢復的自定義取樣器)。
另一方面,可迭代資料集不提供對特定樣本索引的隨機訪問以供恢復。但您可以使用 IterableDataset.state_dict() 和 IterableDataset.load_state_dict() 從檢查點恢復,類似於您對模型和最佳化器可以做的那樣:
>>> iterable_dataset = Dataset.from_dict({"a": range(6)}).to_iterable_dataset(num_shards=3)
>>> # save in the middle of training
>>> state_dict = iterable_dataset.state_dict()
>>> # and resume later
>>> iterable_dataset.load_state_dict(state_dict)
在底層,可迭代資料集會跟蹤當前正在讀取的分片和當前分片中的樣本索引,並將此資訊儲存在 `state_dict` 中。
要從檢查點恢復,資料集會跳過所有先前讀取過的分片,以從當前分片重新開始。然後它會讀取該分片並跳過樣本,直到達到檢查點中的確切樣本位置。
因此,重新啟動資料集相當快,因為它不會重新讀取已經迭代過的分片。儘管如此,恢復資料集通常不是瞬時的,因為它必須從當前分片的開頭重新開始讀取並跳過樣本,直到達到檢查點位置。
這可以與 `torchdata` 中的 `StatefulDataLoader` 一起使用,請參閱使用 PyTorch DataLoader 進行流式傳輸。
從對映式切換到可迭代式
如果您想受益於 IterableDataset 的“惰性”行為或其速度優勢,您可以將您的對映式 Dataset 切換為 IterableDataset。
my_iterable_dataset = my_dataset.to_iterable_dataset()
如果您想對資料集進行洗牌或將其與 PyTorch DataLoader 一起使用,我們建議生成一個分片的 IterableDataset。
my_iterable_dataset = my_dataset.to_iterable_dataset(num_shards=1024)
my_iterable_dataset.num_shards # 1024