目录

torchtune.utils

Checkpointing

torchtune 提供了检查点程序,允许在训练和与其他生态系统组件互操作之间无缝转换检查点格式。有关检查点的全面概述,请参阅 检查点深度解析

FullModelHFCheckpointer

在HF格式中读取和写入检查点的检查点程序。

FullModelMetaCheckpointer

检查点保存器,用于以 Meta 的格式读取和写入检查点文件。

分布式

用于启用分布式训练并支持其工作的工具。

init_distributed

初始化 torch.distributed。

get_world_size_and_rank

获取当前世界大小(即总排名数)和当前训练器的排名号的函数。

降低精度

用于在低精度环境下工作的工具。

get_dtype

获取与给定精度字符串对应的 torch.dtype。

list_dtypes

返回用于微调的支持数据类型列表。

内存管理

在训练期间减少内存消耗的实用工具。

set_activation_checkpointing

用于设置激活检查点并将模型包装以进行检查点的工具。

性能与分析

TorchTune 提供了用于分析和调试微调任务性能的工具。

profiler

用于包装torch.profiler以对模型的操作符进行剖析的实用组件。

指标日志记录

各种日志工具。

metric_logging.WandBLogger

用于Weights and Biases应用程序的记录器 (https://wandb.ai/)。

metric_logging.TensorBoardLogger

用于 PyTorch 实现的 TensorBoard 的日志记录器 (https://pytorch.org/docs/stable/tensorboard.html)。

metric_logging.StdoutLogger

记录到标准输出。

metric_logging.DiskLogger

记录到磁盘。

数据

用于处理数据和数据集的工具。

padded_collate

将批次中的序列填充至该批次中最长序列的长度,并将整数列表转换为张量。

其他

TuneRecipeArgumentParser

一个有用的工具子类,它为argparse.ArgumentParser添加了一个内置参数“config”。

get_logger

获取一个带有流处理器的日志记录器。

get_device

接受设备或设备字符串的函数,验证其在给定机器和分布式设置下是否正确且可用,并返回一个 torch.device。

set_seed

设置常用库中伪随机数生成器种子的函数。

文档

访问 PyTorch 的全面开发人员文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并解答您的问题

查看资源