Multimodal Datasets¶
Multimodal datasets contain more than one data modality, e.g. text + image, and can be used to train transformer-based models. torchtune currently only supports multimodal text + image chat datasets for vision-language models (VLMs).
The primary entry point for fine-tuning with multimodal datasets in torchtune is the multimodal_chat_dataset() builder. This lets you specify a local or Hugging Face dataset that follows the multimodal chat data format directly from the config and train your VLM on it.
Example multimodal dataset¶
Below is an example of a multimodal chat dataset for a visual question answering task. Note the "<image>" placeholder in the text, which marks where the image tokens should be placed. In the example below, it will be replaced with the <|image|> special token.
# data/my_data.json
[
    {
        "dialogue": [
            {
                "from": "human",
                "value": "<image>What time is it on the clock?"
            },
            {
                "from": "gpt",
                "value": "It is 10:00 AM."
            }
        ],
        "image_path": "images/clock.jpg"
    },
    ...
]
from torchtune.models.llama3_2_vision import llama3_2_vision_transform
from torchtune.datasets.multimodal import multimodal_chat_dataset

transform = llama3_2_vision_transform(
    path="/tmp/Meta-Llama-3-8B-Instruct/original/tokenizer.model",
    prompt_template="torchtune.data.QuestionAnswerTemplate",
    max_seq_len=8192,
    image_size=560,
)
ds = multimodal_chat_dataset(
    model_transform=transform,
    source="json",
    data_files="data/my_data.json",
    column_map={
        "dialogue": "conversations",
        "image_path": "image",
    },
    image_dir="/home/user/dataset/",  # /home/user/dataset/images/clock.jpg
    image_tag="<image>",
    split="train",
)
tokenized_dict = ds[0]
print(transform.decode(tokenized_dict["tokens"], skip_special_tokens=False))
# '<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nQuestion:<|image|>What time is it on the clock?Answer:<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nIt is 10:00AM.<|eot_id|>'
print(tokenized_dict["encoder_input"]["images"][0].shape)  # (num_tiles, num_channels, tile_height, tile_width)
# torch.Size([4, 3, 224, 224])
# In config - model_transform takes the place of the tokenizer
model_transform:
  _component_: torchtune.models.llama3_2_vision.llama3_2_vision_transform
  path: /tmp/Meta-Llama-3-8B-Instruct/original/tokenizer.model
  prompt_template: torchtune.data.QuestionAnswerTemplate
  max_seq_len: 8192
  image_size: 560

dataset:
  _component_: torchtune.datasets.multimodal.multimodal_chat_dataset
  source: json
  data_files: data/my_data.json
  split: train
  column_map:
    dialogue: conversations
    image_path: image
  image_dir: /home/user/dataset/
  image_tag: "<image>"
Multimodal dataset format¶
Currently, multimodal datasets are expected to follow the "sharegpt" chat format, where the image path is in one column and the user-assistant conversation is in another.

| conversations                                                      | image        |
|--------------------------------------------------------------------|--------------|
| [{"from": "human", "value": "Q1"}, {"from": "gpt", "value": "A1"}]  | images/1.jpg |

As an example, you can see the schema of the ShareGPT4V dataset.
Currently, multimodal_chat_dataset() only supports a single image path per conversation sample.
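As a sketch, a single sample in this format can be represented as a plain Python dict (the image path and question/answer text are illustrative; the column names are the defaults):

```python
# A single "sharegpt"-format sample as a plain dict. The column names
# "conversations" and "image" are the defaults expected by
# multimodal_chat_dataset(); the image path is a made-up example.
sample = {
    "conversations": [
        {"from": "human", "value": "<image>Q1"},
        {"from": "gpt", "value": "A1"},
    ],
    "image": "images/1.jpg",  # exactly one image path per sample
}

# The user turn carries the "<image>" placeholder marking where the
# image tokens are inserted during preprocessing.
print(sample["conversations"][0]["value"].startswith("<image>"))  # True
```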
Loading multimodal datasets from Hugging Face¶
You simply need to pass the dataset repo name to source, which is then passed into Hugging Face's load_dataset. For most datasets, you will also need to specify the split and/or the subset via name.
# In code
from torchtune.models.llama3_2_vision import llama3_2_vision_transform
from torchtune.datasets.multimodal import multimodal_chat_dataset

transform = llama3_2_vision_transform(
    path="/tmp/Meta-Llama-3-8B-Instruct/original/tokenizer.model",
    max_seq_len=8192,
    image_size=560,
)
ds = multimodal_chat_dataset(
    model_transform=transform,
    source="Lin-Chen/ShareGPT4V",
    split="train",
    name="ShareGPT4V",
    image_dir="/home/user/dataset/",
    image_tag="<image>",
)
# In config
model_transform:
  _component_: torchtune.models.llama3_2_vision.llama3_2_vision_transform
  path: /tmp/Meta-Llama-3-8B-Instruct/original/tokenizer.model
  max_seq_len: 8192
  image_size: 560

# Tokenizer is passed into the dataset in the recipe
dataset:
  _component_: torchtune.datasets.multimodal.multimodal_chat_dataset
  source: Lin-Chen/ShareGPT4V
  split: train
  name: ShareGPT4V
  image_dir: /home/user/dataset/
  image_tag: "<image>"
This will use the default column names "conversations" and "image". To change the column names, use the column_map argument (see Renaming columns).
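For instance, if a copy of the dataset stored its conversation turns under "dialogue" and its image paths under "image_path" (hypothetical names), the remapping would mirror the column_map used in the local-JSON example earlier:

```python
# Hypothetical remapping: the dataset's actual columns "dialogue" and
# "image_path" are mapped in place of the defaults, following the same
# column_map convention as the local-JSON example above.
column_map = {
    "dialogue": "conversations",
    "image_path": "image",
}
print(sorted(column_map.values()))  # ['conversations', 'image']
```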
Loading local and remote multimodal datasets¶
To load a local or remote dataset via https that follows the multimodal chat format, you need to specify the source, data_files, and split arguments. See Hugging Face's load_dataset documentation for more details on loading local or remote files, and see the example multimodal dataset above.
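As a sketch, a remote file can be pointed at the same way by giving data_files an https URL (the URL below is a placeholder, not a real file):

```yaml
dataset:
  _component_: torchtune.datasets.multimodal.multimodal_chat_dataset
  source: json
  data_files: https://example.com/my_data.json  # placeholder URL
  split: train
```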
Loading images¶
In many cases, your dataset will contain paths to images rather than the raw images themselves. multimodal_chat_dataset() handles this for you automatically, but if you are writing a custom message transform for a custom multimodal dataset (see Custom message transforms), you can use the load_image() utility directly.
from pathlib import Path

from torchtune.data import load_image

sample = {
    "conversations": [
        {
            "from": "human",
            "value": "What time is it on the clock?",
        },
        {
            "from": "gpt",
            "value": "It is 10:00 AM.",
        },
    ],
    "image": "images/clock.jpg",
}
image_dir = "/home/user/dataset/"
pil_image = load_image(Path(image_dir) / Path(sample["image"]))
print(pil_image)
# <PIL.Image.Image>
You can then add the PIL image directly to the content of the relevant message. Only PIL images are supported as image content in Message, not image paths or URLs.
from torchtune.data import Message

user_message = None
for msg in sample["conversations"]:
    if msg["from"] == "human":
        user_message = Message(
            role="user",
            content=[
                {"type": "image", "content": pil_image},
                {"type": "text", "content": msg["value"]},
            ],
        )
print(user_message.contains_media)
# True
print(user_message.get_media())
# [<PIL.Image.Image>]
print(user_message.text_content)
# What time is it on the clock?
If the image paths in your dataset are relative paths, you can use the image_dir argument in multimodal_chat_dataset() to prepend the full path to where the images are downloaded locally.
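The prepending behavior can be sketched with plain pathlib (the paths are illustrative):

```python
from pathlib import Path

# Sketch of how a relative image path from the dataset is resolved
# against image_dir before the image is loaded. Paths are illustrative.
image_dir = "/home/user/dataset/"
relative_path = "images/clock.jpg"  # as stored in the dataset

full_path = Path(image_dir) / relative_path
print(full_path)  # /home/user/dataset/images/clock.jpg
```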
Interleaving images in text¶
torchtune supports adding multiple images at any position in the text, as long as your model supports it.
import PIL

from torchtune.data import Message

image_dog = PIL.Image.new(mode="RGB", size=(4, 4))
image_cat = PIL.Image.new(mode="RGB", size=(4, 4))
image_bird = PIL.Image.new(mode="RGB", size=(4, 4))
user_message = Message(
    role="user",
    content=[
        {"type": "image", "content": image_dog},
        {"type": "text", "content": "This is an image of a dog. "},
        {"type": "image", "content": image_cat},
        {"type": "text", "content": "This is an image of a cat. "},
        {"type": "image", "content": image_bird},
        {"type": "text", "content": "This is a bird, the best pet of the three."},
    ],
)
print(user_message.contains_media)
# True
print(user_message.get_media())
# [<PIL.Image.Image>, <PIL.Image.Image>, <PIL.Image.Image>]
print(user_message.text_content)
# This is an image of a dog. This is an image of a cat. This is a bird, the best pet of the three.
Your dataset may contain image placeholder tags that indicate where in the text the image should be referenced; for example, ShareGPT4V (https://huggingface.co/datasets/Lin-Chen/ShareGPT4V) uses "<image>". You can easily create interleaved message content similar to the above with the format_content_with_images() utility, which replaces the image placeholder tags with the images passed in.
import PIL

from torchtune.data import Message, format_content_with_images

image_dog = PIL.Image.new(mode="RGB", size=(4, 4))
image_cat = PIL.Image.new(mode="RGB", size=(4, 4))
image_bird = PIL.Image.new(mode="RGB", size=(4, 4))
text = "[img]This is an image of a dog. [img]This is an image of a cat. [img]This is a bird, the best pet of the three."
user_message = Message(
    role="user",
    content=format_content_with_images(
        content=text,
        image_tag="[img]",
        images=[image_dog, image_cat, image_bird],
    ),
)
print(user_message.contains_media)
# True
print(user_message.get_media())
# [<PIL.Image.Image>, <PIL.Image.Image>, <PIL.Image.Image>]
print(user_message.text_content)
# This is an image of a dog. This is an image of a cat. This is a bird, the best pet of the three.
This is handled automatically for you in multimodal_chat_dataset() when you pass in image_tag.
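To illustrate the idea, the tag replacement amounts to splitting the text on the placeholder tag and interleaving the images with the text pieces. The sketch below is not the actual torchtune implementation, and uses stand-in strings instead of PIL images so it stays dependency-free:

```python
# Illustrative sketch of the tag-splitting idea behind placeholder
# replacement -- not the actual torchtune code. Images are stand-in
# strings here to keep the example dependency-free.
def interleave(text, tag, images):
    pieces = text.split(tag)
    if len(pieces) != len(images) + 1:
        raise ValueError("expected exactly one image per tag")
    content = []
    for i, piece in enumerate(pieces):
        if i > 0:
            # Each tag boundary is replaced by the corresponding image.
            content.append({"type": "image", "content": images[i - 1]})
        if piece:
            content.append({"type": "text", "content": piece})
    return content

content = interleave("[img]a dog. [img]a cat.", "[img]", ["dog", "cat"])
print([c["type"] for c in content])
# ['image', 'text', 'image', 'text']
```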