目录

torchtune.training

检查点

Torchtune 提供检查点程序,允许在检查点格式之间无缝转换,以便进行训练并与生态系统的其他部分进行互操作性。有关 checkpointing,请参阅 checkpointing 深入探讨

FullModelHFCheckpointer

Checkpointer 读取和写入 HF 格式的 checkpoint。

FullModelMetaCheck指针

Checkpointer 的 Checkpointer,它以 Meta 的格式读取和写入 checkpoint。

FullModelTorchTuneCheck指针

Checkpointer,它以与 torchtune 兼容的格式读取和写入检查点。

ModelType

ModelType 被 checkpointer 用来区分不同的模型架构。

格式化检查点文件

此类提供了一种更简洁的方式来表示格式为 .file_{i}_of_{n_files}.pth

update_state_dict_for_classifier

验证分类器模型的检查点加载的状态字典。

精度降低

用于在降低精度设置下工作的实用程序。

get_dtype

获取与给定精度字符串对应的 torch.dtype。

set_default_dtype

上下文管理器来设置 torch 的默认 dtype。

validate_expected_param_dtype

验证所有输入参数是否都具有预期的 dtype。

get_quantizer_mode

给定一个量化器对象,返回一个指定量化类型的字符串。

分散式

用于启用和使用分布式训练的实用程序。

init_distributed

初始化 所需的进程组。torch.distributed

is_distributed

检查是否设置了初始化 torch.distributed 所需的所有环境变量,并且是否正确安装了 distributed。

get_world_size_and_rank

获取默认进程组中当前进程的当前世界大小(又名总排名数)和排名编号的函数。

gather_cpu_state_dict

在 CPU 上将分片状态 dict 转换为完整状态 dict仅在 rank0 上返回非空结果以避免 CPU 内存峰值

内存管理

用于减少训练期间内存消耗的实用程序。

apply_selective_activation_checkpointing

用于设置激活 checkpointing 并包装模型以进行 checkpointing 的实用程序。

set_activation_checkpointing

将激活检查点应用于传入模型的实用程序。

OptimizerInBackwardWrapper 中

一个基本类,用于向后运行的优化器的 checkpoint save 和 load。

create_optim_in_bwd_wrapper

为向后运行的优化器步骤创建包装器。

register_optim_in_bwd_hooks

为向后运行的优化器步骤注册钩子。

调度程序

在训练过程中控制 lr 的实用程序。

get_cosine_schedule_with_warmup

创建一个学习率计划,该计划将学习率从 0.0 线性增加到 lr ,然后在余弦计划中减少到 0.0(假设 = 0.5)。num_warmup_stepsnum_training_steps-num_warmup_stepsnum_cycles

get_lr

Full_finetune_distributed 和 full_finetune_single_device 假设所有优化器具有相同的 LR,以验证所有 LR 是否相同,如果返回 True。

指标日志记录

各种日志记录实用程序。

metric_logging。彗星记录器

与 Comet (https://www.comet.com/site/) 一起使用的记录器。

metric_logging。WandBLogger

用于 Weights and Biases 应用程序 (https://wandb.ai/) 的记录器。

metric_logging。张量板记录器

与 PyTorch 的 TensorBoard 实现 (https://pytorch.org/docs/stable/tensorboard.html) 一起使用的记录器。

metric_logging。标准输出记录器

记录器到标准输出。

metric_logging。磁盘记录器

记录器到磁盘。

性能和分析

Torchtune 提供了用于分析和调试内存和性能的实用程序 的微调工作。

get_memory_stats

计算传入设备的内存摘要。

log_memory_stats

将包含内存统计信息的 dict 记录到 Logger 中。

setup_torch_profiler

设置并返回包含设置后更新的分析器配置。

杂项

get_unmasked_sequence_lengths

返回每个 batch 元素的序列长度,不包括掩码标记。

set_seed

为常用库中的伪随机数生成器设置种子的函数。

文档

访问 PyTorch 的全面开发人员文档

查看文档

教程

获取面向初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并解答您的问题

查看资源