使用聊天数据微调 Llama3¶
Llama3 Instruct 引入了一个新的提示模板,用于对聊天数据进行微调。在本教程中, 我们将介绍您需要了解的内容,以便您快速开始准备自己的 用于微调 Llama3 Instruct 的自定义聊天数据集。
Llama3 Instruct 格式与 Llama2 有何不同
关于提示模板和特殊令牌的所有信息
如何使用自己的聊天数据集来微调 Llama3 Instruct
熟悉配置数据集
注意
本教程需要 torchtune > 0.1.1 版本
模板从 Llama2 更改为 Llama3¶
Llama2 聊天模型在提示预训练 型。由于聊天模型是使用此提示模板进行预训练的,因此如果要运行 推理,您需要使用相同的模板以获得最佳性能 关于聊天数据。否则,模型将只执行标准文本补全,即 可能符合也可能不一致。
来自 Llama2 官方提示 template 指南中,我们可以看到添加了特殊的标签:
<s>[INST] <<SYS>>
You are a helpful, respectful, and honest assistant.
<</SYS>>
Hi! I am a human. [/INST] Hello there! Nice to meet you! I'm Meta AI, your friendly AI assistant </s>
Llama3 Instruct 对 Llama2 中的模板进行了全面修改,以更好地支持多轮对话。相同的文本 在 Llama3 Instruct 格式中,将如下所示:
<|begin_of_text|><|start_header_id|>system<|end_header_id|>
You are a helpful, respectful, and honest assistant.<|eot_id|><|start_header_id|>user<|end_header_id|>
Hi! I am a human.<|eot_id|><|start_header_id|>assistant<|end_header_id|>
Hello there! Nice to meet you! I'm Meta AI, your friendly AI assistant<|eot_id|>
这些标签完全不同,它们实际上的编码方式与 美洲驼2.让我们来演练一下使用 Llama2 模板和 llama3 模板来了解如何操作。
注意
Llama3 Base 模型使用与 Llama3 Instruct 不同的提示模板 因为它还没有被 imdirect tuned 并且额外的特殊标记是未经训练的。如果你 在 Llama3 Base 模型上运行推理而不进行微调,我们建议使用 Base 模板以获得最佳性能。通常,对于 instruct 和 chat 数据,我们建议使用 Llama3 使用其提示模板进行指示。本教程的其余部分假定您使用的是 Llama3 指示。
标记提示模板和特殊令牌¶
假设我有一个用户-助手轮次的样本,并附有一个系统 提示:
sample = [
{
"role": "system",
"content": "You are a helpful, respectful, and honest assistant.",
},
{
"role": "user",
"content": "Who are the most influential hip-hop artists of all time?",
},
{
"role": "assistant",
"content": "Here is a list of some of the most influential hip-hop "
"artists of all time: 2Pac, Rakim, N.W.A., Run-D.M.C., and Nas.",
},
]
现在,让我们使用 class 和
了解它是如何被标记化的。Llama2ChatFormat 是提示模板的一个示例,
它只是使用风味文本构建一个提示来指示某个任务。
from torchtune.data import Llama2ChatFormat, Message
messages = [Message.from_dict(msg) for msg in sample]
formatted_messages = Llama2ChatFormat.format(messages)
print(formatted_messages)
# [
# Message(
# role='user',
# content='[INST] <<SYS>>\nYou are a helpful, respectful, and honest assistant.\n<</SYS>>\n\nWho are the most influential hip-hop artists of all time? [/INST] ',
# ...,
# ),
# Message(
# role='assistant',
# content='Here is a list of some of the most influential hip-hop artists of all time: 2Pac, Rakim, N.W.A., Run-D.M.C., and Nas.',
# ...,
# ),
# ]
还有一些 Llama2 使用的特殊令牌,这些令牌不在提示模板中。
如果你看一下我们的类,你会注意到
我们不包括 and 令牌。这些是序列的开头
(BOS) 和序列结束 (EOS) 令牌,它们在分词器中以不同的方式表示
而不是提示模板的其余部分。让我们用 Llama2 使用的 对这个例子进行标记,看看
为什么。
<s>
</s>
from torchtune.models.llama2 import llama2_tokenizer
tokenizer = llama2_tokenizer("/tmp/Llama-2-7b-hf/tokenizer.model")
user_message = formatted_messages[0].content
tokens = tokenizer.encode(user_message, add_bos=True, add_eos=True)
print(tokens)
# [1, 518, 25580, 29962, 3532, 14816, 29903, 6778, ..., 2]
我们在对示例文本进行编码时添加了 BOS 和 EOS 令牌。这会显示出来 作为 ID 1 和 2。我们可以验证这些是我们的 BOS 和 EOS 代币。
print(tokenizer._spm_model.spm_model.piece_to_id("<s>"))
# 1
print(tokenizer._spm_model.spm_model.piece_to_id("</s>"))
# 2
BOS 和 EOS 代币就是我们所说的特殊代币,因为它们有自己的
预留令牌 ID。这意味着它们将在
模型的 learnt embedding 表。其余的 prompt 模板标记,并被标记为普通文本,而不是它们自己的 ID。[INST]
<<SYS>>
print(tokenizer.decode(518))
# '['
print(tokenizer.decode(25580))
# 'INST'
print(tokenizer.decode(29962))
# ']'
print(tokenizer.decode([3532, 14816, 29903, 6778]))
# '<<SYS>>'
请务必注意,您不应将特殊的预留代币放在 input 提示,因为它将被视为普通文本,而不是特殊文本 令 牌。
print(tokenizer.encode("<s>", add_bos=False, add_eos=False))
# [529, 29879, 29958]
现在让我们看一下 Llama3 的格式,看看它是如何以不同的方式进行标记化的 比 Llama2.
from torchtune.models.llama3 import llama3_tokenizer
tokenizer = llama3_tokenizer("/tmp/Meta-Llama-3-8B/original/tokenizer.model")
messages = [Message.from_dict(msg) for msg in sample]
tokens, mask = tokenizer.tokenize_messages(messages)
print(tokenizer.decode(tokens))
# '<|start_header_id|>system<|end_header_id|>\n\nYou are a helpful, respectful,
# and honest assistant.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWho
# are the most influential hip-hop artists of all time?<|eot_id|><|start_header_id|>
# assistant<|end_header_id|>\n\nHere is a list of some of the most influential hip-hop
# artists of all time: 2Pac, Rakim, N.W.A., Run-D.M.C., and Nas.<|eot_id|>'
注意
我们使用了 Llama3 的 API,它与
编码。它只是管理以正确的
places 在对单个消息进行编码之后。tokenize_messages
我们可以看到 tokenizer 处理了所有格式设置,而无需我们指定提示 模板。事实证明,所有额外的标签都是特殊的标记,我们不需要 单独的提示模板。我们可以通过检查标签是否被编码来验证这一点 作为他们自己的令牌 ID 来获取。
print(tokenizer.special_tokens["<|begin_of_text|>"])
# 128000
print(tokenizer.special_tokens["<|eot_id|>"])
# 128009
最好的部分是 - 所有这些特殊令牌都完全由 tokenizer 处理。 这意味着您不必担心弄乱任何必需的提示模板!
何时应使用提示模板?¶
是否使用提示模板取决于您所需的推理 行为是。如果您在 base 模型,并且它是使用提示模板进行预训练的,或者您想要准备一个 微调模型,以期望特定任务的推理具有特定的提示结构。
使用提示模板进行微调并非绝对必要,但通常
特定任务将需要特定模板。例如,它提供了一个轻量级结构,以便为要求汇总文本的提示准备微调模型。
这将环绕用户消息,而助手消息保持不变。
f"Summarize this dialogue:\n{dialogue}\n---\nSummary:\n"
您可以使用此模板微调 Llama2,即使该模型最初是预先训练的
使用 ,只要这是模型
在推理期间看到。该模型应该足够健壮,以适应新模板。
对自定义聊天数据集进行微调¶
让我们尝试使用自定义 chat 数据集。我们将介绍如何设置数据,以便对其进行标记化 正确地输入到我们的模型中。
假设我们有一个保存为 CSV 文件的本地数据集,其中包含问题 以及来自在线论坛的答案。我们怎么能把这样的东西变成一个格式 Llama3 正确理解和标记化?
import pandas as pd
df = pd.read_csv('your_file.csv', nrows=1)
print("Header:", df.columns.tolist())
# ['input', 'output']
print("First row:", df.iloc[0].tolist())
# [
# "How do GPS receivers communicate with satellites?",
# "The first thing to know is the communication is one-way...",
# ]
Llama3 分词器类 ,
期望输入采用
format.让我们
快速编写一个函数,可以将 CSV 文件中的单行解析为
Message 数据类。该函数还需要具有 train_on_input 参数。
def message_converter(sample: Mapping[str, Any], train_on_input: bool) -> List[Message]:
input_msg = sample["input"]
output_msg = sample["output"]
user_message = Message(
role="user",
content=input_msg,
masked=not train_on_input, # Mask if not training on prompt
)
assistant_message = Message(
role="assistant",
content=output_msg,
masked=False,
)
# A single turn conversation
messages = [user_message, assistant_message]
return messages
由于我们正在微调 Llama3,因此分词器将处理
我们。但是,如果我们要微调一个需要模板的模型,例如
Mistral-7B 模型,该模型使用 、
我们需要使用聊天格式,例如
format
所有消息都根据他们的建议。
现在让我们为数据集创建一个 builder 函数,该函数加载到我们的本地文件中。
使用我们的函数转换为 Messages 列表,并创建一个对象。
def custom_dataset(
*,
tokenizer: ModelTokenizer,
max_seq_len: int = 2048, # You can expose this if you want to experiment
) -> ChatDataset:
return ChatDataset(
tokenizer=tokenizer,
# For local csv files, we specify "csv" as the source, just like in
# load_dataset
source="csv",
# Default split of "train" is required for local files
split="train",
convert_to_messages=message_converter,
# Llama3 does not need a chat format
chat_format=None,
max_seq_len=max_seq_len,
# To load a local file we specify it as data_files just like in
# load_dataset
data_files="your_file.csv",
)
注意
你可以将任何关键字参数 for 传入到我们所有的
Dataset 类,它们将遵循这些类。这对于常用参数很有用
例如,使用 或 configure with 指定数据分割load_dataset
split
name
现在我们准备好开始微调了!我们将使用内置的 LoRA 单个设备配方。
使用命令获取配置的副本并更新它以使用您的新数据集。为您的项目创建新文件夹
并确保 dataset builder 和 message converter 保存在该目录中,
然后在 config 中指定它。tune cp
8B_lora_single_device.yaml
dataset:
_component_: path.to.my.custom_dataset
max_seq_len: 2048
启动 fine-tune!
$ tune run lora_finetune_single_device --config custom_8B_lora_single_device.yaml epochs=15