AWS Trainium & Inferentia 文件

在 AWS Inferentia 上使用 llama-2-13B 建立自己的聊天機器人

Hugging Face's logo
加入 Hugging Face 社群

並獲得增強的文件體驗

開始使用

在 AWS Inferentia 上使用 llama-2-13B 建立自己的聊天機器人

此教程的筆記本版本請點選此處.

本指南將詳細介紹如何在 AWS Inferentia 上匯出、部署和執行 LLama-2 13B 聊天模型。

您將學習如何

  • 將 Llama-2 模型匯出為 Neuron 格式,
  • 將匯出的模型推送到 Hugging Face Hub,
  • 部署模型並在聊天應用程式中使用它。

注意:本教程是在 inf2.48xlarge AWS EC2 例項上建立的。

1. 將 Llama 2 模型匯出到 Neuron

在本指南中,我們將使用非門控的 NousResearch/Llama-2-13b-chat-hf 模型,它在功能上等同於原始的 meta-llama/Llama-2-13b-chat-hf

該模型是 Llama 2 模型家族的一部分,並已針對識別使用者助手之間的聊天互動進行了調整(稍後會詳細介紹)。

正如 optimum-neuron 文件 中所述,模型需要先編譯並匯出為序列化格式,然後才能在 Neuron 裝置上執行。

匯出模型時,我們將指定兩組引數

  • 使用 compiler_args,我們指定模型要部署在多少個核心上(每個神經元裝置有兩個核心),以及精度(此處為 float16),
  • 使用 input_shapes,我們設定模型的靜態輸入和輸出維度。所有模型編譯器都需要靜態形狀,Neuron 也不例外。請注意,sequence_length 不僅限制了輸入上下文的長度,還限制了鍵/值快取的長度,因此也限制了輸出長度。

根據您選擇的引數和 Inferentia 主機,這可能需要幾分鐘到一小時以上。

為了您的方便,我們在 Hugging Face hub 上託管了該模型的預編譯版本,因此您可以跳過匯出,直接在第 2 節開始使用該模型。

from optimum.neuron import NeuronModelForCausalLM

compiler_args = {"num_cores": 24, "auto_cast_type": 'fp16'}
input_shapes = {"batch_size": 1, "sequence_length": 2048}
model = NeuronModelForCausalLM.from_pretrained(
        "NousResearch/Llama-2-13b-chat-hf",
        export=True,
        **compiler_args,
        **input_shapes)

這可能需要一段時間。

幸運的是,您只需執行一次此操作,因為您可以儲存模型並在以後重新載入它。

model.save_pretrained("llama-2-13b-chat-neuron")

更棒的是,您可以將其推送到 Hugging Face hub

為此,您需要登入 HuggingFace 帳戶

在終端中,只需輸入以下命令並在請求時貼上您的 Hugging Face 令牌

huggingface-cli login

預設情況下,模型將上傳到您的帳戶(組織等於您的使用者名稱)。

如果您想將模型上傳到特定的 Hugging Face 組織,請隨意編輯以下程式碼。

from huggingface_hub import whoami

org = whoami()['name']

repo_id = f"{org}/llama-2-13b-chat-neuron"

model.push_to_hub("llama-2-13b-chat-neuron", repository_id=repo_id)

關於匯出引數的更多說明。

載入模型所需的最小記憶體可透過以下方式計算

   memory = bytes per parameter * number of parameters

Llama 2 13B 模型使用 float16 權重(儲存為 2 位元組)並具有 130 億個引數,這意味著它需要至少 2 * 13B 或約 26GB 記憶體來儲存其權重。

每個 NeuronCore 有 16GB 記憶體,這意味著 26GB 模型無法容納在單個 NeuronCore 上。

實際上,所需的總空間遠大於引數數量,這是由於快取注意力層投影(KV 快取)造成的。這種快取機制會使記憶體分配隨序列長度和批次大小線性增長。

在這裡,我們將 batch_size 設定為 1,這意味著我們只能並行處理一個輸入提示。我們將 sequence_length 設定為 2048,這相當於模型最大容量(4096)的一半。

評估 KV 快取大小的公式更為複雜,因為它還取決於與模型架構相關的引數,例如嵌入的寬度和解碼器塊的數量。

底線是,為了適應非常大的語言模型,會使用張量並行性將權重、資料和計算拆分到多個 NeuronCore 上,同時記住每個核心上的記憶體不能超過 16GB。

請注意,將核心數量增加到最低要求以上幾乎總是能使模型更快。增加張量並行度可以改善記憶體頻寬,從而提高模型效能。

為了最佳化效能,建議使用例項上所有可用的核心。

在本指南中,我們使用了 inf2.48xlarge 的所有 24 個核心,但如果您使用的是 inf2.24xlarge 例項,則應將其更改為 12。

2. 在 AWS Inferentia2 上使用 Llama 2 生成文字

模型匯出後,您可以使用 transformers 庫生成文字,如此帖子中詳細介紹

如果如建議您跳過了第一節,請不要擔心:我們將使用集線器上已有的預編譯模型。

from optimum.neuron import NeuronModelForCausalLM

try:
    model
except NameError:
    # Edit this to use another base model
    model = NeuronModelForCausalLM.from_pretrained('aws-neuron/Llama-2-13b-chat-hf-neuron-latency')

我們將需要一個 Llama 2 分詞器將提示字串轉換為文字標記。

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("NousResearch/Llama-2-13b-chat-hf")

支援以下生成策略

  • 貪婪搜尋,
  • 帶有 top-k 和 top-p(帶溫度)的多項式取樣。

大多數 logits 預處理/過濾器(例如重複懲罰)都受支援。

inputs = tokenizer("What is deep-learning ?", return_tensors="pt")
outputs = model.generate(**inputs,
                         max_new_tokens=128,
                         do_sample=True,
                         temperature=0.9,
                         top_k=50,
                         top_p=0.9)
tokenizer.batch_decode(outputs, skip_special_tokens=True)

3. 在 AWS Inferentia2 上使用 llama 建立聊天應用程式

我們特意選擇了一個 Llama 2 聊天變體,以說明匯出模型在編碼上下文長度增長時表現出的出色行為。

模型要求提示遵循特定模板,該模板對應於使用者角色和助手角色之間的互動。

每個聊天模型都有其自己的編碼此類內容的約定,我們不會在本指南中詳細介紹,因為我們將直接使用與我們模型對應的Hugging Face 聊天模板

下面的實用函式將使用者和模型之間的一系列對話轉換為格式正確的聊天提示。

def format_chat_prompt(message, history, max_tokens):
    """ Convert a history of messages to a chat prompt


    Args:
        message(str): the new user message.
        history (List[str]): the list of user messages and assistant responses.
        max_tokens (int): the maximum number of input tokens accepted by the model.

    Returns:
        a `str` prompt.
    """
    chat = []
    # Convert all messages in history to chat interactions
    for interaction in history:
        chat.append({"role": "user", "content" : interaction[0]})
        chat.append({"role": "assistant", "content" : interaction[1]})
    # Add the new message
    chat.append({"role": "user", "content" : message})
    # Generate the prompt, verifying that we don't go beyond the maximum number of tokens
    for i in range(0, len(chat), 2):
        # Generate candidate prompt with the last n-i entries
        prompt = tokenizer.apply_chat_template(chat[i:], tokenize=False)
        # Tokenize to check if we're over the limit
        tokens = tokenizer(prompt)
        if len(tokens.input_ids) <= max_tokens:
            # We're good, stop here
            return prompt
    # We shall never reach this line
    raise SystemError

我們現在可以構建一個簡單的聊天應用程式了。

我們只需將使用者和助手之間的互動儲存在一個列表中,然後用它來生成輸入提示。

history = []
max_tokens = 1024

def chat(message, history, max_tokens):
    prompt = format_chat_prompt(message, history, max_tokens)
    # Uncomment the line below to see what the formatted prompt looks like
    #print(prompt)
    inputs = tokenizer(prompt, return_tensors="pt")
    outputs = model.generate(**inputs,
                             max_length=2048,
                             do_sample=True,
                             temperature=0.9,
                             top_k=50,
                             repetition_penalty=1.2)
    # Do not include the input tokens
    outputs = outputs[0, inputs.input_ids.size(-1):]
    response = tokenizer.decode(outputs, skip_special_tokens=True)
    history.append([message, response])
    return response

要測試聊天應用程式,您可以例如使用以下提示序列

print(chat("My favorite color is blue. My favorite fruit is strawberry.", history, max_tokens))
print(chat("Name a fruit that is on my favorite colour.", history, max_tokens))
print(chat("What is the colour of my favorite fruit ?", history, max_tokens))
<警告>

雖然非常強大,但大型語言模型有時會“幻覺”。我們稱“幻覺”為生成的無關或編造的內容,但模型將其呈現為準確的。這是 LLM 的一個缺陷,並非在 Trainium/Inferentia 上使用它們的副作用。

</警告>

© . This site is unofficial and not affiliated with Hugging Face, Inc.