RWKV 介紹——一個兼具 Transformer 優點的 RNN

釋出於 2023 年 5 月 15 日
在 GitHub 上更新

ChatGPT 和聊天機器人驅動的應用在自然語言處理 (NLP) 領域引起了廣泛關注。社群一直在尋求強大、可靠的開源模型用於其應用和用例。這些強大模型的興起源於 Vaswani 等人於 2017 年首次引入的基於 Transformer 的模型的普及和廣泛採用。這些模型顯著優於此前基於迴圈神經網路 (RNN) 的最先進 NLP 模型,後者在該論文發表後被認為已“過時”。透過這篇博文,我們將介紹 RWKV 這一新架構的整合,它結合了 RNN 和 Transformer 的優點,並已於近期整合到 Hugging Face 的 Transformers 庫中。

RWKV 專案概述

RWKV 專案由 BlinkDL (Bo Peng) 發起並主導,他積極參與並維護該專案。社群在官方 Discord 頻道中組織起來,不斷在效能(RWKV.cpp、量化等)、可擴充套件性(資料集處理和抓取)和研究(聊天微調、多模態微調等)等各個主題上增強專案成果。用於訓練 RWKV 模型的 GPU 由 Stability AI 捐贈。

您可以透過加入 官方 Discord 頻道 來參與,並透過以下兩篇博文了解更多關於 RWKV 背後的通用思想:https://johanwind.github.io/2023/03/23/rwkv_overview.html / https://johanwind.github.io/2023/03/23/rwkv_details.html

Transformer 架構 vs RNN

RNN 架構是首批廣泛用於處理資料序列的神經網路架構之一,與採用固定大小輸入的經典架構不同。它將當前“token”(即資料流的當前資料點)和先前的“狀態”作為輸入,並計算預測的下一個 token 和預測的下一個狀態。然後,新狀態用於計算下一個 token 的預測,依此類推。RNN 還可以以不同的“模式”使用,從而使其能夠應用於不同的場景,正如 Andrej Karpathy 的博文 所述,例如一對一(影像分類)、一對多(影像字幕)、多對一(序列分類)、多對多(序列生成)等。

rnn_diagram
RNN 的可能配置概述。來源:Andrej Karpathy 的博文

由於 RNN 在每個步驟都使用相同的權重來計算預測,因此它們由於梯度消失問題而難以記憶長序列資訊。為了解決這一限制,人們引入了 LSTM 或 GRU 等新架構。然而,Transformer 架構被證明是迄今為止解決此問題最有效的。

在 Transformer 架構中,輸入 token 在自注意力模組中同時進行處理。token 首先使用 Query、Key 和 Value 權重線性投影到不同的空間。生成的矩陣直接用於計算注意力分數(透過 softmax,如下所示),然後乘以 Value 隱藏狀態以獲得最終的隱藏狀態。這種設計使架構能夠有效地緩解長序列問題,並且與 RNN 模型相比,其推理和訓練速度更快。

transformer_diagram
Transformer 模型中注意力分數的公式。來源:Jay Alammar 的博文
rwkv_attention_formula
RWKV 模型中注意力分數的公式。來源:RWKV 博文

在訓練過程中,Transformer 架構相對於傳統的 RNN 和 CNN 具有多項優勢。其中最顯著的優勢是其學習上下文表示的能力。與一次處理一個詞的 RNN 和 CNN 不同,Transformer 架構將輸入序列作為一個整體進行處理。這使得它能夠捕獲序列中詞語之間的長距離依賴關係,這對於語言翻譯和問答等任務特別有用。

在推理過程中,RNN 在速度和記憶體效率方面具有一些優勢。這些優勢包括簡單性(只需矩陣-向量操作)和記憶體效率(記憶體需求在推理過程中不會增加)。此外,由於計算只作用於當前 token 和狀態,計算速度與上下文視窗長度保持不變。

RWKV 架構

RWKV 的靈感來源於 Apple 的 Attention Free Transformer。該架構經過精心簡化和最佳化,使其能夠轉換為 RNN。此外,還添加了一些技巧,例如 TokenShiftSmallInitEmb (技巧列表詳見 官方 GitHub 倉庫的 README),以提升其效能與 GPT 相當。沒有這些技巧,模型的效能將大打折扣。目前,訓練方面已有基礎設施可將訓練擴充套件至 14B 引數,並且 RWKV-4(截至目前的最新版本)中已迭代修復了一些問題,例如數值不穩定性。

RWKV:RNN 和 Transformer 的結合

如何將 Transformer 和 RNN 的優點結合起來?基於 Transformer 的模型的主要缺點是,當上下文視窗大於一定值時,執行模型可能變得具有挑戰性,因為注意力分數是針對整個序列同時計算的。

RNN 天然支援非常長的上下文長度——僅受訓練中看到的上下文長度限制,但透過精心編碼,這可以擴充套件到數百萬個 token。目前,RWKV 模型在 8192 (ctx8192) 的上下文長度上進行訓練,它們的執行速度與 ctx1024 模型相同,並且需要相同的 RAM 量。

傳統 RNN 模型的主要缺點以及 RWKV 的不同之處

  1. 傳統的 RNN 模型無法利用很長的上下文(LSTM 在用作 LM 時只能管理大約 100 個 token)。然而,RWKV 可以利用數千甚至更多 token,如下所示
rwkv_loss
不同上下文長度和模型大小下的 LM 損失。來源:RWKV 原始倉庫
  1. 傳統 RNN 模型在訓練時無法並行化。RWKV 類似於“線性化 GPT”,它的訓練速度比 GPT 快。

透過將這兩種優勢結合到單一架構中,我們希望 RWKV 能夠成長為比其各部分之和更強大的模型。

RWKV 注意力機制的公式

該模型架構與經典基於 Transformer 的模型非常相似(即一個嵌入層、多個相同的層、層歸一化以及一個因果語言建模頭來預測下一個 token)。唯一的區別在於注意力層,它與傳統的基於 Transformer 的模型完全不同。

為了更全面地理解注意力層,我們建議深入閱讀 Johan Sokrates Wind 的博文中提供的詳細解釋。

現有檢查點

純語言模型:RWKV-4 模型

最受歡迎的 RWKV 模型引數量從約 1.7 億到 14 億不等。根據 RWKV 概述 博文,這些模型已在 Pile 資料集上進行訓練,並在不同基準上與其他 SoTA 模型進行評估,它們的表現相當不錯,結果非常可比。

rwkv_loss
RWKV-4 與其他常見架構的比較。來源:Johan Wind 的部落格文章

指令微調/聊天版本:RWKV-4 Raven

Bo 還訓練了一個 RWKV 架構的“聊天”版本,即 RWKV-4 Raven 模型。它是一個 RWKV-4 pile 模型(在 The Pile 資料集上預訓練的 RWKV 模型),在 ALPACA、CodeAlpaca、Guanaco、GPT4All 和更多資料集上進行了微調。該模型有多個版本,訓練語言不同(僅英語、英語+中文+日語、英語+日語等)和大小不同(1.5B 引數、7B 引數、14B 引數)。

所有 HF 轉換後的模型都可以在 Hugging Face Hub 的 RWKV 組織中找到。

🤗 Transformers 整合

該架構已透過 此拉取請求 新增到 transformers 庫中。截至本文撰寫之時,您可以透過從原始碼安裝 transformers 或使用庫的 main 分支來使用它。該架構與庫緊密整合,您可以像使用任何其他架構一樣使用它。

下面我們來看一些示例。

文字生成示例

要根據輸入提示生成文字,您可以使用 pipeline 來生成文字。

from transformers import pipeline

model_id = "RWKV/rwkv-4-169m-pile"

prompt = "\nIn a shocking finding, scientist discovered a herd of dragons living in a remote, previously unexplored valley, in Tibet. Even more surprising to the researchers was the fact that the dragons spoke perfect Chinese."

pipe = pipeline("text-generation", model=model_id)
print(pipe(prompt, max_new_tokens=20))
>>> [{'generated_text': '\nIn a shocking finding, scientist discovered a herd of dragons living in a remote, previously unexplored valley, in Tibet. Even more surprising to the researchers was the fact that the dragons spoke perfect Chinese.\n\nThe researchers found that the dragons were able to communicate with each other, and that they were'}]

或者您可以從下面的程式碼片段開始執行

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("RWKV/rwkv-4-169m-pile")
tokenizer = AutoTokenizer.from_pretrained("RWKV/rwkv-4-169m-pile")

prompt = "\nIn a shocking finding, scientist discovered a herd of dragons living in a remote, previously unexplored valley, in Tibet. Even more surprising to the researchers was the fact that the dragons spoke perfect Chinese."

inputs = tokenizer(prompt, return_tensors="pt")
output = model.generate(inputs["input_ids"], max_new_tokens=20)
print(tokenizer.decode(output[0].tolist()))
>>> In a shocking finding, scientist discovered a herd of dragons living in a remote, previously unexplored valley, in Tibet. Even more surprising to the researchers was the fact that the dragons spoke perfect Chinese.\n\nThe researchers found that the dragons were able to communicate with each other, and that they were

使用 Raven 模型(聊天模型)

您可以像使用 Alpaca 風格一樣提示聊天模型,下面是一個示例

from transformers import AutoTokenizer, AutoModelForCausalLM

model_id = "RWKV/rwkv-raven-1b5"

model = AutoModelForCausalLM.from_pretrained(model_id).to(0)
tokenizer = AutoTokenizer.from_pretrained(model_id)

question = "Tell me about ravens"
prompt = f"### Instruction: {question}\n### Response:"

inputs = tokenizer(prompt, return_tensors="pt").to(0)
output = model.generate(inputs["input_ids"], max_new_tokens=100)

print(tokenizer.decode(output[0].tolist(), skip_special_tokens=True))
>>> ### Instruction: Tell me about ravens
### Response: RAVENS are a type of bird that is native to the Middle East and North Africa. They are known for their intelligence, adaptability, and their ability to live in a variety of environments. RAVENS are known for their intelligence, adaptability, and their ability to live in a variety of environments. They are known for their intelligence, adaptability, and their ability to live in a variety of environments.

根據 Bo 的說法,更佳的指令技術詳見 此 Discord 訊息(點選前請務必加入頻道)

| discord_message |

權重轉換

任何使用者都可以透過簡單地執行 transformers 庫中提供的轉換指令碼,將原始 RWKV 權重轉換為 HF 格式。首先,將“原始”權重推送到 Hugging Face Hub(我們將其倉庫命名為 RAW_HUB_REPO,原始檔案命名為 RAW_FILE),然後執行轉換指令碼

python convert_rwkv_checkpoint_to_hf.py --repo_id RAW_HUB_REPO --checkpoint_file RAW_FILE --output_dir OUTPUT_DIR

如果您想將轉換後的模型推送到 Hub(例如,在 dummy_user/converted-rwkv 下),請先不要登入 huggingface-cli login,然後再推送模型,然後執行

python convert_rwkv_checkpoint_to_hf.py --repo_id RAW_HUB_REPO --checkpoint_file RAW_FILE --output_dir OUTPUT_DIR --push_to_hub --model_name dummy_user/converted-rwkv

未來工作

多語言 RWKV

Bo 目前正在研究一個多語言語料庫以訓練 RWKV 模型。最近 釋出了一個新的多語言分詞器

社群導向和研究專案

RWKV 社群非常活躍,並致力於多個後續方向的研究,一個酷炫的專案列表可以在 Discord 上的一個專用頻道中找到(點選連結前請務必加入頻道)。還有一個專門針對此架構研究的頻道,歡迎加入並貢獻!

模型壓縮和加速

由於 RWKV 只需進行矩陣-向量運算,因此它是非標準和實驗性計算硬體(例如光子處理器/加速器)的理想選擇。

因此,該架構也可以自然地受益於經典的加速和壓縮技術(例如 ONNX、4 位/8 位量化等),我們希望這能隨著架構與 Transformer 的整合而普及給開發人員和從業者。

RWKV 在不久的將來也可以受益於 optimum 庫中提出的加速技術。其中一些技術在 rwkv.cpp 倉庫rwkv-cpp-cuda 倉庫中有所強調。

致謝

Hugging Face 團隊衷心感謝 Bo 和 RWKV 社群抽出時間並回答我們關於該架構的問題。我們還要感謝他們的幫助和支援,並期待在 HF 生態系統中看到更多 RWKV 模型的應用。我們還要感謝 Johan Wind 在 RWKV 方面撰寫的博文,這極大地幫助我們理解了該架構及其潛力。最後,我們還要特別感謝 ArEnSc 發起了最初的 transformers PR。此外,還要特別感謝 Merve NoyanMaria KhalusovaPedro Cuenca 慷慨地審閱了這篇博文,使其變得更好!

引用

如果您在工作中使用了 RWKV,請使用 以下 cff 引用

社群

註冊登入 發表評論

© . This site is unofficial and not affiliated with Hugging Face, Inc.