注意
转到末尾下载完整的示例代码。
将 TensorDict 用于数据集¶
在本教程中,我们将演示如何使用 TensorDict 在训练管道中高效、透明地加载和管理数据。
本教程在很大程度上基于 PyTorch 快速入门教程,但进行了修改以演示 TensorDict 的用法。
import torch
import torch.nn as nn
from tensordict import MemoryMappedTensor, TensorDict
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
# Run on the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")
Using device: cpu
torchvision.datasets 模块包含许多方便的预制数据集。在本教程中,我们将使用相对简单的 FashionMNIST 数据集。
每张图像是一件衣物,目标是对图像中的衣物类型进行分类(例如“包”、“运动鞋”等)。
# The two splits share the storage root, the download flag and the per-sample
# transform; only the `train` flag differs. The training split is built first,
# then the test split, each with its own ToTensor() instance.
training_data, test_data = (
    datasets.FashionMNIST(
        root="data",
        train=is_train_split,
        download=True,
        transform=ToTensor(),
    )
    for is_train_split in (True, False)
)
我们将创建两个 tensordict,分别用于训练数据和测试数据。我们创建内存映射张量(memory-mapped tensor)来保存数据。这使我们能够高效地从磁盘批量加载已转换的数据,而不必反复加载并转换单张图像。
def _empty_memmap_td(dataset):
    """Allocate an unfilled, memory-mapped TensorDict sized for ``dataset``.

    The result has one entry per sample (``batch_size=[len(dataset)]``) and two
    fields: ``"images"`` (float32, shaped like the first sample's image after
    ``squeeze()``) and ``"targets"`` (int64, one value per sample). The storage
    is only allocated here; it is not populated.
    """
    n_samples = len(dataset)
    # Shape of a single image, taken from the first sample.
    image_shape = dataset[0][0].squeeze().shape
    images = MemoryMappedTensor.empty((n_samples, *image_shape), dtype=torch.float32)
    targets = MemoryMappedTensor.empty((n_samples,), dtype=torch.int64)
    return TensorDict(
        {"images": images, "targets": targets},
        batch_size=[n_samples],
        device=device,
    )


training_data_td = _empty_memmap_td(training_data)
test_data_td = _empty_memmap_td(test_data)
然后,我们可以遍历数据以填充内存映射张量。这需要一些时间,但预先执行转换可以避免在之后的训练中重复这些工作。
数据加载器¶
我们将分别基于 torchvision 提供的 Dataset 和我们的内存映射 TensorDict 创建 DataLoader。
由于 TensorDict 实现了 __len__ 和 __getitem__(以及 __getitems__),我们可以像使用 map 风格的 Dataset 一样使用它,并直接用它创建 DataLoader。
注意,因为 TensorDict 已经可以处理批量索引,所以不需要整理(collation),因此我们将恒等函数作为 collate_fn 传入。
def _identity_collate(batch):
    """Return ``batch`` unchanged; used as a no-op ``collate_fn``."""
    return batch


batch_size = 64

# Loaders over the torchvision datasets, using the default collation.
train_dataloader = DataLoader(dataset=training_data, batch_size=batch_size)  # noqa: TOR401
test_dataloader = DataLoader(dataset=test_data, batch_size=batch_size)  # noqa: TOR401

# Loaders over the TensorDict-backed data. The batch is passed through
# untouched instead of being collated.
train_dataloader_td = DataLoader(  # noqa: TOR401
    dataset=training_data_td,
    batch_size=batch_size,
    collate_fn=_identity_collate,
)
test_dataloader_td = DataLoader(  # noqa: TOR401
    dataset=test_data_td,
    batch_size=batch_size,
    collate_fn=_identity_collate,
)
模型¶
我们使用与快速入门教程中相同的模型。
class Net(nn.Module):
    """Small fully connected classifier.

    The input is flattened to ``28 * 28`` features, passed through two hidden
    layers of 512 units with ReLU activations, and projected to 10 output
    logits.
    """

    def __init__(self):
        super().__init__()
        # Collapses every dimension after the batch dimension into one.
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28 * 28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
        )

    def forward(self, x):
        """Return the raw class scores (logits) for the batch ``x``."""
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits
# Two separately initialised instances of the same architecture: `model` is
# trained with the plain DataLoader, `model_td` with the TensorDict-backed one.
model, model_td = (Net().to(device) for _ in range(2))
model, model_td  # notebook-style echo of both modules
(Net(
(flatten): Flatten(start_dim=1, end_dim=-1)
(linear_relu_stack): Sequential(
(0): Linear(in_features=784, out_features=512, bias=True)
(1): ReLU()
(2): Linear(in_features=512, out_features=512, bias=True)
(3): ReLU()
(4): Linear(in_features=512, out_features=10, bias=True)
)
), Net(
(flatten): Flatten(start_dim=1, end_dim=-1)
(linear_relu_stack): Sequential(
(0): Linear(in_features=784, out_features=512, bias=True)
(1): ReLU()
(2): Linear(in_features=512, out_features=512, bias=True)
(3): ReLU()
(4): Linear(in_features=512, out_features=10, bias=True)
)
))
优化参数¶
我们将使用随机梯度下降和交叉熵损失来优化模型参数。
# A single cross-entropy criterion is shared by both training runs.
loss_fn = nn.CrossEntropyLoss()

# One SGD optimizer per model, both with the same learning rate.
optimizer, optimizer_td = (
    torch.optim.SGD(net.parameters(), lr=1e-3) for net in (model, model_td)
)
def train(dataloader, model, loss_fn, optimizer):
    """Run one training epoch over ``dataloader``.

    Args:
        dataloader: Iterable yielding ``(X, y)`` batches; must expose
            ``.dataset`` with a length (used only for progress output).
        model: Module to optimise; it is switched to training mode.
        loss_fn: Callable mapping ``(pred, y)`` to a scalar loss tensor.
        optimizer: Optimizer bound to ``model``'s parameters.

    Each batch is moved to the module-level ``device`` before the forward
    pass. The loss is printed every 100 batches.
    """
    size = len(dataloader.dataset)
    model.train()

    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        # Forward pass and loss.
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backward pass and parameter update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            # Kept in a separate name so `loss` stays a tensor.
            loss_value, current = loss.item(), batch * len(X)
            print(f"loss: {loss_value:>7f} [{current:>5d}/{size:>5d}]")
基于 TensorDict 的 DataLoader 的训练循环非常相似,我们只是调整了解包数据的方式,改用 TensorDict 提供的更显式的按键取值。
.contiguous() 方法会加载存储在 memmap 张量中的数据。
def train_td(dataloader, model, loss_fn, optimizer):
    """Run one training epoch over a loader that yields keyed batches.

    Args:
        dataloader: Iterable yielding batches indexable by ``"images"`` and
            ``"targets"``; must expose ``.dataset`` with a length (used only
            for progress output).
        model: Module to optimise; it is switched to training mode.
        loss_fn: Callable mapping ``(pred, y)`` to a scalar loss tensor.
        optimizer: Optimizer bound to ``model``'s parameters.

    The loss is printed every 100 batches.
    """
    size = len(dataloader.dataset)
    model.train()

    for batch, data in enumerate(dataloader):
        # Per the accompanying tutorial text, `.contiguous()` loads the data
        # stored in the memmap tensors.
        X, y = data["images"].contiguous(), data["targets"].contiguous()

        # Forward pass and loss.
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backward pass and parameter update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            # Kept in a separate name so `loss` stays a tensor.
            loss_value, current = loss.item(), batch * len(X)
            print(f"loss: {loss_value:>7f} [{current:>5d}/{size:>5d}]")
def test(dataloader, model, loss_fn):
    """Evaluate ``model`` on ``dataloader`` and print accuracy and mean loss.

    Args:
        dataloader: Sized iterable yielding ``(X, y)`` batches; must expose
            ``.dataset`` with a length.
        model: Module to evaluate; it is switched to evaluation mode.
        loss_fn: Callable mapping ``(pred, y)`` to a scalar loss tensor.

    Each batch is moved to the module-level ``device``. The loss is averaged
    over the number of batches; accuracy is the number of correct predictions
    divided by the number of samples in ``dataloader.dataset``.
    """
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0

    # No gradients are needed for evaluation.
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            # Count samples whose highest-scoring class equals the target.
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()

    test_loss /= num_batches
    correct /= size
    print(
        f"Test Error: \n Accuracy: {(100 * correct):>0.1f}%, Avg loss: {test_loss:>8f} \n"
    )
def test_td(dataloader, model, loss_fn):
    """Evaluate ``model`` on a loader that yields keyed batches.

    Args:
        dataloader: Sized iterable yielding batches indexable by ``"images"``
            and ``"targets"``; must expose ``.dataset`` with a length.
        model: Module to evaluate; it is switched to evaluation mode.
        loss_fn: Callable mapping ``(pred, y)`` to a scalar loss tensor.

    Prints accuracy and mean loss. The loss is averaged over the number of
    batches; accuracy is the number of correct predictions divided by the
    number of samples in ``dataloader.dataset``.
    """
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0

    # No gradients are needed for evaluation.
    with torch.no_grad():
        for batch in dataloader:
            # Per the accompanying tutorial text, `.contiguous()` loads the
            # data stored in the memmap tensors.
            X, y = batch["images"].contiguous(), batch["targets"].contiguous()
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            # Count samples whose highest-scoring class equals the target.
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()

    test_loss /= num_batches
    correct /= size
    print(
        f"Test Error: \n Accuracy: {(100 * correct):>0.1f}%, Avg loss: {test_loss:>8f} \n"
    )
# Print a single batch to show what the TensorDict-backed loader yields.
for d in train_dataloader_td:
    print(d)
    break

import time

# Train and evaluate with the TensorDict-backed loaders, timing the whole run.
# perf_counter() is monotonic, so the elapsed time cannot be skewed by
# wall-clock adjustments.
t0 = time.perf_counter()
epochs = 5
for t in range(epochs):
    print(f"Epoch {t + 1}\n-------------------------")
    train_td(train_dataloader_td, model_td, loss_fn, optimizer_td)
    test_td(test_dataloader_td, model_td, loss_fn)
print(f"TensorDict training done! time: {time.perf_counter() - t0: 4.4f} s")

# Same number of epochs with the plain torchvision-backed loaders.
t0 = time.perf_counter()
epochs = 5
for t in range(epochs):
    print(f"Epoch {t + 1}\n-------------------------")
    train(train_dataloader, model, loss_fn, optimizer)
    test(test_dataloader, model, loss_fn)
print(f"Training done! time: {time.perf_counter() - t0: 4.4f} s")
TensorDict(
fields={
images: Tensor(shape=torch.Size([64, 28, 28]), device=cpu, dtype=torch.float32, is_shared=False),
targets: Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.int64, is_shared=False)},
batch_size=torch.Size([64]),
device=cpu,
is_shared=False)
Epoch 1
-------------------------
loss: 2.293629 [ 0/60000]
loss: 2.283033 [ 6400/60000]
loss: 2.256484 [12800/60000]
loss: 2.256073 [19200/60000]
loss: 2.243077 [25600/60000]
loss: 2.200774 [32000/60000]
loss: 2.217680 [38400/60000]
loss: 2.179253 [44800/60000]
loss: 2.169554 [51200/60000]
loss: 2.139296 [57600/60000]
Test Error:
Accuracy: 44.9%, Avg loss: 2.136127
Epoch 2
-------------------------
loss: 2.144999 [ 0/60000]
loss: 2.140007 [ 6400/60000]
loss: 2.071116 [12800/60000]
loss: 2.092876 [19200/60000]
loss: 2.046267 [25600/60000]
loss: 1.972650 [32000/60000]
loss: 2.009594 [38400/60000]
loss: 1.925318 [44800/60000]
loss: 1.919691 [51200/60000]
loss: 1.849112 [57600/60000]
Test Error:
Accuracy: 59.9%, Avg loss: 1.853141
Epoch 3
-------------------------
loss: 1.884452 [ 0/60000]
loss: 1.857988 [ 6400/60000]
loss: 1.732184 [12800/60000]
loss: 1.780074 [19200/60000]
loss: 1.670048 [25600/60000]
loss: 1.618595 [32000/60000]
loss: 1.644744 [38400/60000]
loss: 1.546432 [44800/60000]
loss: 1.562902 [51200/60000]
loss: 1.465493 [57600/60000]
Test Error:
Accuracy: 61.8%, Avg loss: 1.488911
Epoch 4
-------------------------
loss: 1.553364 [ 0/60000]
loss: 1.522695 [ 6400/60000]
loss: 1.372123 [12800/60000]
loss: 1.449625 [19200/60000]
loss: 1.332421 [25600/60000]
loss: 1.329863 [32000/60000]
loss: 1.347445 [38400/60000]
loss: 1.271675 [44800/60000]
loss: 1.303136 [51200/60000]
loss: 1.211887 [57600/60000]
Test Error:
Accuracy: 63.5%, Avg loss: 1.239002
Epoch 5
-------------------------
loss: 1.314920 [ 0/60000]
loss: 1.296802 [ 6400/60000]
loss: 1.132275 [12800/60000]
loss: 1.240491 [19200/60000]
loss: 1.118523 [25600/60000]
loss: 1.144619 [32000/60000]
loss: 1.168913 [38400/60000]
loss: 1.102033 [44800/60000]
loss: 1.141350 [51200/60000]
loss: 1.064545 [57600/60000]
Test Error:
Accuracy: 64.8%, Avg loss: 1.083915
TensorDict training done! time: 8.6170 s
Epoch 1
-------------------------
loss: 2.290858 [ 0/60000]
loss: 2.286275 [ 6400/60000]
loss: 2.267298 [12800/60000]
loss: 2.265695 [19200/60000]
loss: 2.254766 [25600/60000]
loss: 2.215138 [32000/60000]
loss: 2.229925 [38400/60000]
loss: 2.191596 [44800/60000]
loss: 2.188072 [51200/60000]
loss: 2.166435 [57600/60000]
Test Error:
Accuracy: 43.9%, Avg loss: 2.155688
Epoch 2
-------------------------
loss: 2.158259 [ 0/60000]
loss: 2.155049 [ 6400/60000]
loss: 2.093325 [12800/60000]
loss: 2.112539 [19200/60000]
loss: 2.072581 [25600/60000]
loss: 1.997474 [32000/60000]
loss: 2.036950 [38400/60000]
loss: 1.950100 [44800/60000]
loss: 1.957567 [51200/60000]
loss: 1.895268 [57600/60000]
Test Error:
Accuracy: 55.7%, Avg loss: 1.885085
Epoch 3
-------------------------
loss: 1.911978 [ 0/60000]
loss: 1.891730 [ 6400/60000]
loss: 1.763546 [12800/60000]
loss: 1.807613 [19200/60000]
loss: 1.707991 [25600/60000]
loss: 1.642813 [32000/60000]
loss: 1.683022 [38400/60000]
loss: 1.573177 [44800/60000]
loss: 1.606999 [51200/60000]
loss: 1.511539 [57600/60000]
Test Error:
Accuracy: 61.3%, Avg loss: 1.516370
Epoch 4
-------------------------
loss: 1.578484 [ 0/60000]
loss: 1.552396 [ 6400/60000]
loss: 1.389732 [12800/60000]
loss: 1.469645 [19200/60000]
loss: 1.355540 [25600/60000]
loss: 1.340846 [32000/60000]
loss: 1.369824 [38400/60000]
loss: 1.286581 [44800/60000]
loss: 1.330394 [51200/60000]
loss: 1.239374 [57600/60000]
Test Error:
Accuracy: 63.4%, Avg loss: 1.253140
Epoch 5
-------------------------
loss: 1.325610 [ 0/60000]
loss: 1.314812 [ 6400/60000]
loss: 1.139354 [12800/60000]
loss: 1.251529 [19200/60000]
loss: 1.128213 [25600/60000]
loss: 1.147657 [32000/60000]
loss: 1.176389 [38400/60000]
loss: 1.109371 [44800/60000]
loss: 1.156932 [51200/60000]
loss: 1.078585 [57600/60000]
Test Error:
Accuracy: 64.4%, Avg loss: 1.089056
Training done! time: 34.5017 s
脚本总运行时间:(0 分 55.621 秒)