萬事通,亦有所精:一個多功能 Transformer 智慧體
引言
我們很高興能與大家分享“萬事通”(Jack of All Trades, JAT)專案,該專案旨在朝著通用智慧體的方向邁進。這個專案始於對 Gato (Reed et al., 2022) 工作的開源復現,Gato 提出訓練一個能夠同時執行視覺語言和決策任務的 Transformer 模型。因此,我們首先構建了一個 Gato 資料集的開源版本。然後,我們在此基礎上訓練了多模態 Transformer 模型,並引入了多項對 Gato 的改進,以更好地處理序列資料和連續值。
總的來說,該專案取得了以下成果:
- 在多種多樣的任務上釋出了大量的**專家級強化學習 (RL) 智慧體**。
- 釋出了 **JAT 資料集**,這是首個用於通用智慧體訓練的資料集。它包含了數十萬條由專家智慧體收集的專家軌跡。
- 釋出了 **JAT 模型**,這是一個基於 Transformer 的智慧體,能夠玩影片遊戲、控制機器人執行各種任務、在簡單的導航環境中理解並執行指令等等!

資料集與專家策略
專家策略
傳統強化學習 (RL) 通常是在單一環境中訓練策略。利用這些專家策略是構建一個多功能智慧體的有效途徑。我們選擇了多種不同性質和難度的環境,包括 Atari、BabyAI、Meta-World 和 MuJoCo。對於每種環境,我們都訓練一個智慧體,直到其達到業界頂尖水平。(對於 BabyAI,我們使用了 BabyAI bot)。這些訓練出的智慧體被稱為專家智慧體,並已釋出在 🤗 Hub 上。您可以在 JAT 資料集卡片中找到所有智慧體的列表。
JAT 資料集
我們釋出了 JAT 資料集,這是首個用於通用智慧體訓練的資料集。JAT 資料集包含由上述專家智慧體收集的數十萬條專家軌跡。要使用此資料集,只需像從 🤗 Hub 載入任何其他資料集一樣載入它即可。
>>> from datasets import load_dataset
>>> dataset = load_dataset("jat-project/jat-dataset", "metaworld-assembly")
>>> first_episode = dataset["train"][0]
>>> first_episode.keys()
dict_keys(['continuous_observations', 'continuous_actions', 'rewards'])
>>> len(first_episode["rewards"])
500
>>> first_episode["continuous_actions"][0]
[6.459120273590088, 2.2422609329223633, -5.914587020874023, -19.799840927124023]
除了強化學習資料外,我們還加入了文字資料集,以便為使用者提供獨特的互動介面。因此,您還會找到用於 Wikipedia、Oscar、OK-VQA 和 Conceptual-Captions 的子集。
JAT 智慧體架構
JAT 的架構基於 Transformer,使用了 EleutherAI 的 GPT-Neo 實現。JAT 的獨特之處在於其嵌入機制,該機制被設計為能內在地處理序列決策任務。我們將觀測值嵌入與動作嵌入以及相應的獎勵交錯排列。
因此,每個嵌入要麼對應一個觀測值(與獎勵相關),要麼對應一個動作。但 JAT 是如何編碼這些資訊的呢?這取決於資料的型別。如果資料(觀測值或動作)是影像(如 Atari 遊戲),那麼 JAT 使用 CNN。如果是連續向量,JAT 使用線性層。最後,如果是離散值,JAT 使用線性投影層。同樣的原理也用於模型輸出,具體取決於要預測的資料型別。預測是因果的,將觀測值移動一個時間步。這樣,智慧體必須根據之前所有的觀測值和動作來預測下一個動作。
此外,我們覺得訓練我們的智慧體執行自然語言處理 (NLP) 和計算機視覺 (CV) 任務會很有趣。為此,我們還讓編碼器能夠接受文字和影像資料作為輸入。對於文字資料,我們使用 GPT-2 的分詞策略進行分詞;對於影像,我們使用 ViT 型別的編碼器。
考慮到資料的模態會因環境不同而改變,JAT 是如何計算損失的呢?它會為每種模態分別計算損失。對於影像和連續值,它使用 MSE 損失。對於離散值,它使用交叉熵損失。最終的損失是序列中每個元素損失的平均值。等等,這是否意味著我們對預測動作和預測觀測值賦予了相同的權重?實際上不是,但我們將在下面詳細討論這一點。
實驗與結果
我們在所有 157 個訓練任務上評估了 JAT。我們收集了 10 個回合的資料並記錄了總獎勵。為了便於閱讀,我們按領域彙總了結果。
如果用一個數字來總結這些結果,那就是 65.8%,這是與 4 個領域的 JAT 專家相比的平均效能。這表明 JAT 能夠在各種各樣的任務上模仿專家的表現。讓我們更深入地瞭解一下細節:
- 對於 Atari 57,智慧體達到了專家分數的 14.1%,相當於人類水平的 37.6%。它在 21 個遊戲中超過了人類水平。
- 對於 BabyAI,智慧體達到了專家分數的 99.0%,並且僅在 1 個任務上未能超過專家分數的 50%。
- 對於 Meta-World,智慧體達到了專家分數的 65.5%。
- 對於 MuJoCo,智慧體達到了專家分數的 84.8%。
最令人印象深刻的是,JAT 僅使用**單一網路**就在所有領域取得了這樣的效能。為了衡量這一效能,讓我們看看 JAT 在幾個任務上的表現:
想試試嗎?你可以的!JAT 模型已在 🤗 Hub 上提供!
對於文字任務,我們的模型表現出初步的能力,我們建議讀者參閱論文以獲取更多細節。
預測觀測值的驚人好處
在訓練強化學習 (RL) 智慧體時,主要目標是最大化未來獎勵。但是,如果我們還要求智慧體預測它未來會觀測到什麼呢?這個額外的任務會幫助還是妨礙學習過程?
關於這個問題,有兩種對立的觀點。一方面,學習預測觀測值可以提供對環境更深入的理解,從而實現更好、更快的學習。另一方面,這可能會分散智慧體對其主要目標的注意力,導致在觀測值和動作預測方面的表現都平平無奇。
為了解決這個爭論,我們進行了一項實驗,使用了一個結合了觀測值損失和動作損失的損失函式,並用一個加權引數 來平衡這兩個目標。
結果非常值得注意。當 太高(0.5)時,預測觀測值的額外目標似乎妨礙了學習過程。但當 較低時,對學習的影響可以忽略不計,智慧體的效能與不將觀測值預測作為目標時獲得的效能相似。
然而,我們發現在 附近有一個最佳點,在這裡,學習預測觀測值實際上提高了智慧體的學習效率。我們的研究表明,在學習過程中加入觀測值預測是有益的,只要平衡得當。這一發現對這類智慧體的設計具有重要意義,凸顯了輔助目標在提高學習效率方面的潛在價值。
所以,下次你訓練強化學習智慧體時,不妨考慮讓它預測未來會觀測到什麼。這可能會帶來更好的效能和更快的學習!
結論
在這項工作中,我們介紹了 JAT,一個多功能 Transformer 智慧體,能夠掌握多種序列決策任務,並在自然語言處理 (NLP) 和計算機視覺 (CV) 任務中展現出初步的能力。對於所有這些任務,JAT 都使用單一網路。我們的貢獻包括髮布了專家級強化學習智慧體、JAT 資料集和 JAT 模型。我們希望這項工作能激勵通用智慧體領域的未來研究,併為開發更多功能、更強大的 AI 系統做出貢獻。
下一步是什麼?對研究的展望
我們相信 JAT 專案為通用智慧體領域的研究開闢了一個新方向,而我們僅僅觸及了皮毛。以下是一些未來工作的想法:
改進資料:儘管具有開創性,JAT 資料集仍處於早期階段。專家軌跡僅來自每個環境的一個專家智慧體,這可能會導致一些偏差。儘管我們已盡最大努力達到業界頂尖水平,但一些環境仍然具有挑戰性。我們相信收集更多資料和訓練更多專家智慧體將**大有幫助**。
使用離線強化學習 (Offline RL):JAT 智慧體使用基礎的行為克隆進行訓練。這意味著兩件事:(1)我們無法利用次優軌跡;(2)JAT 智慧體無法超越專家。我們選擇這種方法是為了簡單,但我們相信使用離線強化學習可以**真正幫助**提高智慧體的效能,而且實現起來不會太複雜。
釋放更智慧的多工取樣策略的全部潛力:目前,JAT 智慧體從所有任務中均勻取樣資料,但這種方法可能會限制其發展。透過動態調整取樣率以專注於最具挑戰性的任務,我們可以極大地加速智慧體的學習過程,並解鎖**顯著的效能提升**。
連結
引用
@article{gallouedec2024jack,
title = {{Jack of All Trades, Master of Some, a Multi-Purpose Transformer Agent}},
author = {Gallouédec, Quentin and Beeching, Edward and Romac, Clément and Dellandréa, Emmanuel},
journal = {arXiv preprint arXiv:2402.09844},
year = {2024},
url = {https://arxiv.org/abs/2402.09844}
}