Datasets 文件
與 JAX 配合使用
並獲得增強的文件體驗
開始使用
與 JAX 配合使用
本文件是關於如何將 datasets
與 JAX 結合使用的快速入門,特別關注如何從我們的資料集中獲取 jax.Array
物件,以及如何使用它們來訓練 JAX 模型。
復現上述程式碼需要 jax
和 jaxlib
,因此請確保透過 pip install datasets[jax]
安裝它們。
資料集格式
預設情況下,資料集返回常規的 Python 物件:整數、浮點數、字串、列表等,並且字串和二進位制物件保持不變,因為 JAX 僅支援數字。
要獲得 JAX 陣列(類似 numpy),您可以將資料集的格式設定為 jax
>>> from datasets import Dataset
>>> data = [[1, 2], [3, 4]]
>>> ds = Dataset.from_dict({"data": data})
>>> ds = ds.with_format("jax")
>>> ds[0]
{'data': DeviceArray([1, 2], dtype=int32)}
>>> ds[:2]
{'data': DeviceArray([
[1, 2],
[3, 4]], dtype=int32)}
Dataset 物件是 Arrow 表的包裝器,它允許將資料集中的陣列快速讀取為 JAX 陣列。
請注意,完全相同的過程也適用於 DatasetDict
物件,因此當將 DatasetDict
的格式設定為 jax
時,其中的所有 Dataset
都會被格式化為 jax
>>> from datasets import DatasetDict
>>> data = {"train": {"data": [[1, 2], [3, 4]]}, "test": {"data": [[5, 6], [7, 8]]}}
>>> dds = DatasetDict.from_dict(data)
>>> dds = dds.with_format("jax")
>>> dds["train"][:2]
{'data': DeviceArray([
[1, 2],
[3, 4]], dtype=int32)}
您需要考慮的另一件事是,格式化操作在您實際訪問資料之前不會應用。因此,如果您想從資料集中獲取 JAX 陣列,您需要先訪問資料,否則格式將保持不變。
最後,要將資料載入到您選擇的裝置上,您可以指定 device
引數,但請注意,不支援 jaxlib.xla_extension.Device
,因為它既不能用 pickle
也不能用 dill
序列化,因此您需要改用其字串識別符號。
>>> import jax
>>> from datasets import Dataset
>>> data = [[1, 2], [3, 4]]
>>> ds = Dataset.from_dict({"data": data})
>>> device = str(jax.devices()[0]) # Not casting to `str` before passing it to `with_format` will raise a `ValueError`
>>> ds = ds.with_format("jax", device=device)
>>> ds[0]
{'data': DeviceArray([1, 2], dtype=int32)}
>>> ds[0]["data"].device()
TFRT_CPU_0
>>> assert ds[0]["data"].device() == jax.devices()[0]
True
請注意,如果未向 with_format
提供 device
引數,它將使用預設裝置,即 jax.devices()[0]
。
N 維陣列
如果您的資料集由N維陣列組成,您會發現如果形狀固定,它們預設被視為相同的張量
>>> from datasets import Dataset
>>> data = [[[1, 2],[3, 4]], [[5, 6],[7, 8]]] # fixed shape
>>> ds = Dataset.from_dict({"data": data})
>>> ds = ds.with_format("jax")
>>> ds[0]
{'data': Array([[1, 2],
[3, 4]], dtype=int32)}
>>> from datasets import Dataset
>>> data = [[[1, 2],[3]], [[4, 5, 6],[7, 8]]] # varying shape
>>> ds = Dataset.from_dict({"data": data})
>>> ds = ds.with_format("jax")
>>> ds[0]
{'data': [Array([1, 2], dtype=int32), Array([3], dtype=int32)]}
然而,這種邏輯通常需要慢速的形狀比較和資料複製。為了避免這種情況,您必須明確使用 `Array` 特徵型別並指定張量的形狀
>>> from datasets import Dataset, Features, Array2D
>>> data = [[[1, 2],[3, 4]],[[5, 6],[7, 8]]]
>>> features = Features({"data": Array2D(shape=(2, 2), dtype='int32')})
>>> ds = Dataset.from_dict({"data": data}, features=features)
>>> ds = ds.with_format("jax")
>>> ds[0]
{'data': Array([[1, 2],
[3, 4]], dtype=int32)}
>>> ds[:2]
{'data': Array([[[1, 2],
[3, 4]],
[[5, 6],
[7, 8]]], dtype=int32)}
其他特徵型別
ClassLabel 資料可以正確轉換為陣列。
>>> from datasets import Dataset, Features, ClassLabel
>>> labels = [0, 0, 1]
>>> features = Features({"label": ClassLabel(names=["negative", "positive"])})
>>> ds = Dataset.from_dict({"label": labels}, features=features)
>>> ds = ds.with_format("jax")
>>> ds[:3]
{'label': DeviceArray([0, 0, 1], dtype=int32)}
字串和二進位制物件保持不變,因為 JAX 僅支援數字。
要使用 Image 特徵型別,您需要安裝 `vision` 額外依賴:`pip install datasets[vision]`。
>>> from datasets import Dataset, Features, Image
>>> images = ["path/to/image.png"] * 10
>>> features = Features({"image": Image()})
>>> ds = Dataset.from_dict({"image": images}, features=features)
>>> ds = ds.with_format("jax")
>>> ds[0]["image"].shape
(512, 512, 3)
>>> ds[0]
{'image': DeviceArray([[[ 255, 255, 255],
[ 255, 255, 255],
...,
[ 255, 255, 255],
[ 255, 255, 255]]], dtype=uint8)}
>>> ds[:2]["image"].shape
(2, 512, 512, 3)
>>> ds[:2]
{'image': DeviceArray([[[[ 255, 255, 255],
[ 255, 255, 255],
...,
[ 255, 255, 255],
[ 255, 255, 255]]]], dtype=uint8)}
要使用 Audio 特徵型別,您需要安裝 `audio` 額外依賴:`pip install datasets[audio]`。
>>> from datasets import Dataset, Features, Audio
>>> audio = ["path/to/audio.wav"] * 10
>>> features = Features({"audio": Audio()})
>>> ds = Dataset.from_dict({"audio": audio}, features=features)
>>> ds = ds.with_format("jax")
>>> ds[0]["audio"]["array"]
DeviceArray([-0.059021 , -0.03894043, -0.00735474, ..., 0.0133667 ,
0.01809692, 0.00268555], dtype=float32)
>>> ds[0]["audio"]["sampling_rate"]
DeviceArray(44100, dtype=int32, weak_type=True)
資料載入
JAX 沒有任何內建的資料載入功能,因此您需要使用像 PyTorch 這樣的庫,透過 DataLoader
來載入資料,或者使用 TensorFlow 的 tf.data.Dataset
。引用 JAX 文件 中關於此主題的內容:“JAX 專注於程式轉換和加速器支援的 NumPy,所以我們不將資料載入或整理包含在 JAX 庫中。已經有很多優秀的資料載入器了,所以我們直接使用它們,而不是重新發明輪子。我們將使用 PyTorch 的資料載入器,並建立一個小小的墊片使其能夠處理 NumPy 陣列。”
這就是為什麼在 datasets
中進行 JAX 格式化如此有用的原因,因為它讓您可以使用 HuggingFace Hub 中的任何模型與 JAX,而無需擔心資料載入部分。
使用 with_format('jax')
從資料集中獲取 JAX 陣列最簡單的方法是使用 with_format('jax')
方法。假設我們想在 MNIST 資料集 上訓練一個神經網路,該資料集可在 HuggingFace Hub 的 https://huggingface.co/datasets/mnist 找到。
>>> from datasets import load_dataset
>>> ds = load_dataset("mnist")
>>> ds = ds.with_format("jax")
>>> ds["train"][0]
{'image': DeviceArray([[ 0, 0, 0, ...],
[ 0, 0, 0, ...],
...,
[ 0, 0, 0, ...],
[ 0, 0, 0, ...]], dtype=uint8),
'label': DeviceArray(5, dtype=int32)}
設定格式後,我們可以使用 Dataset.iter()
方法將資料集分批送入 JAX 模型。
>>> for epoch in range(epochs):
... for batch in ds["train"].iter(batch_size=32):
... x, y = batch["image"], batch["label"]
... ...