AWS SageMaker¶
- class torchx.schedulers.aws_sagemaker_scheduler.AWSSageMakerScheduler(session_name: str, client: Optional[Any] = None, docker_client: Optional[DockerClient] = None)[source]¶
Bases:
DockerWorkspaceMixin, Scheduler[AWSSageMakerOpts]

AWSSageMakerScheduler is a TorchX scheduling interface to AWS SageMaker.
$ torchx run -s aws_sagemaker utils.echo --image alpine:latest --msg hello
aws_sagemaker://torchx_user/1234
$ torchx status aws_sagemaker://torchx_user/1234
...
Authentication is loaded from the environment using boto3 credential handling.

Config Options
usage:
    role=ROLE,instance_type=INSTANCE_TYPE,[instance_count=INSTANCE_COUNT],[user=USER],[keep_alive_period_in_seconds=KEEP_ALIVE_PERIOD_IN_SECONDS],[volume_size=VOLUME_SIZE],[volume_kms_key=VOLUME_KMS_KEY],[max_run=MAX_RUN],[input_mode=INPUT_MODE],[output_path=OUTPUT_PATH],[output_kms_key=OUTPUT_KMS_KEY],[base_job_name=BASE_JOB_NAME],[tags=TAGS],[subnets=SUBNETS],[security_group_ids=SECURITY_GROUP_IDS],[model_uri=MODEL_URI],[model_channel_name=MODEL_CHANNEL_NAME],[metric_definitions=METRIC_DEFINITIONS],[encrypt_inter_container_traffic=ENCRYPT_INTER_CONTAINER_TRAFFIC],[use_spot_instances=USE_SPOT_INSTANCES],[max_wait=MAX_WAIT],[checkpoint_s3_uri=CHECKPOINT_S3_URI],[checkpoint_local_path=CHECKPOINT_LOCAL_PATH],[debugger_hook_config=DEBUGGER_HOOK_CONFIG],[enable_sagemaker_metrics=ENABLE_SAGEMAKER_METRICS],[enable_network_isolation=ENABLE_NETWORK_ISOLATION],[disable_profiler=DISABLE_PROFILER],[environment=ENVIRONMENT],[max_retry_attempts=MAX_RETRY_ATTEMPTS],[source_dir=SOURCE_DIR],[git_config=GIT_CONFIG],[hyperparameters=HYPERPARAMETERS],[container_log_level=CONTAINER_LOG_LEVEL],[code_location=CODE_LOCATION],[dependencies=DEPENDENCIES],[training_repository_access_mode=TRAINING_REPOSITORY_ACCESS_MODE],[training_repository_credentials_provider_arn=TRAINING_REPOSITORY_CREDENTIALS_PROVIDER_ARN],[disable_output_compression=DISABLE_OUTPUT_COMPRESSION],[enable_infra_check=ENABLE_INFRA_CHECK],[image_repo=IMAGE_REPO],[quiet=QUIET]

required arguments:
    role=ROLE (str)
        an AWS IAM role (either name or full ARN). The Amazon SageMaker training jobs and APIs that create Amazon SageMaker endpoints use this role to access training data and model artifacts. After the endpoint is created, the inference code might use the IAM role, if it needs to access an AWS resource.
    instance_type=INSTANCE_TYPE (str)
        type of EC2 instance to use for training, for example, 'ml.c4.xlarge'

optional arguments:
    instance_count=INSTANCE_COUNT (int, 1)
        number of Amazon EC2 instances to use for training. Required if instance_groups is not set.
    user=USER (str, ec2-user)
        the username to tag the job with. `getpass.getuser()` if not specified.
    keep_alive_period_in_seconds=KEEP_ALIVE_PERIOD_IN_SECONDS (int, None)
        the duration of time in seconds to retain configured resources in a warm pool for subsequent training jobs.
    volume_size=VOLUME_SIZE (int, None)
        size in GB of the storage volume to use for storing input and output data during training (default: 30).
    volume_kms_key=VOLUME_KMS_KEY (str, None)
        KMS key ID for encrypting the EBS volume attached to the training instance.
    max_run=MAX_RUN (int, None)
        timeout in seconds for training (default: 24 * 60 * 60).
    input_mode=INPUT_MODE (str, None)
        the input mode that the algorithm supports (default: 'File').
    output_path=OUTPUT_PATH (str, None)
        S3 location for saving the training result (model artifacts and output files). If not specified, results are stored to a default bucket. If the bucket with the specific name does not exist, the estimator creates the bucket during the fit() method execution.
    output_kms_key=OUTPUT_KMS_KEY (str, None)
        KMS key ID for encrypting the training output (default: your IAM role's KMS key for Amazon S3).
    base_job_name=BASE_JOB_NAME (str, None)
        prefix for the training job name when the fit() method launches. If not specified, the estimator generates a default job name based on the training image name and current timestamp.
    tags=TAGS (typing.List[typing.Dict[str, str]], None)
        list of tags for labeling a training job.
    subnets=SUBNETS (typing.List[str], None)
        list of subnet ids. If not specified, the training job will be created without a VPC config.
    security_group_ids=SECURITY_GROUP_IDS (typing.List[str], None)
        list of security group ids. If not specified, the training job will be created without a VPC config.
    model_uri=MODEL_URI (str, None)
        URI where a pre-trained model is stored, either locally or in S3.
    model_channel_name=MODEL_CHANNEL_NAME (str, None)
        name of the channel where 'model_uri' will be downloaded (default: 'model').
    metric_definitions=METRIC_DEFINITIONS (typing.List[typing.Dict[str, str]], None)
        list of dictionaries that define the metric(s) used to evaluate the training jobs. Each dictionary contains two keys: 'Name' for the name of the metric, and 'Regex' for the regular expression used to extract the metric from the logs.
    encrypt_inter_container_traffic=ENCRYPT_INTER_CONTAINER_TRAFFIC (bool, None)
        specifies whether traffic between training containers is encrypted for the training job (default: False).
    use_spot_instances=USE_SPOT_INSTANCES (bool, None)
        specifies whether to use SageMaker Managed Spot instances for training. If enabled, the max_wait arg should also be set.
    max_wait=MAX_WAIT (int, None)
        timeout in seconds waiting for a spot training job.
    checkpoint_s3_uri=CHECKPOINT_S3_URI (str, None)
        S3 URI in which to persist checkpoints that the algorithm persists (if any) during training.
    checkpoint_local_path=CHECKPOINT_LOCAL_PATH (str, None)
        local path that the algorithm writes its checkpoints to.
    debugger_hook_config=DEBUGGER_HOOK_CONFIG (bool, None)
        configuration for how debugging information is emitted with SageMaker Debugger. If not specified, a default one is created using the estimator's output_path, unless the region does not support SageMaker Debugger. To disable SageMaker Debugger, set this parameter to False.
    enable_sagemaker_metrics=ENABLE_SAGEMAKER_METRICS (bool, None)
        enable SageMaker Metrics Time Series.
    enable_network_isolation=ENABLE_NETWORK_ISOLATION (bool, None)
        specifies whether the container will run in network isolation mode (default: False).
    disable_profiler=DISABLE_PROFILER (bool, None)
        specifies whether Debugger monitoring and profiling will be disabled (default: False).
    environment=ENVIRONMENT (typing.Dict[str, str], None)
        environment variables to be set for use during the training job.
    max_retry_attempts=MAX_RETRY_ATTEMPTS (int, None)
        number of times to move a job to the STARTING status. You can specify between 1 and 30 attempts.
    source_dir=SOURCE_DIR (str, None)
        absolute, relative, or S3 URI path to a directory with any other training source code dependencies aside from the entry point file (default: current working directory).
    git_config=GIT_CONFIG (typing.Dict[str, str], None)
        git configurations used for cloning files, including repo, branch, commit, 2FA_enabled, username, password, and token.
    hyperparameters=HYPERPARAMETERS (typing.Dict[str, str], None)
        dictionary containing the hyperparameters to initialize this estimator with.
    container_log_level=CONTAINER_LOG_LEVEL (int, None)
        log level to use within the container (default: logging.INFO).
    code_location=CODE_LOCATION (str, None)
        S3 prefix URI where custom code is uploaded.
    dependencies=DEPENDENCIES (typing.List[str], None)
        list of absolute or relative paths to directories with any additional libraries that should be exported to the container.
    training_repository_access_mode=TRAINING_REPOSITORY_ACCESS_MODE (str, None)
        specifies how SageMaker accesses the Docker image that contains the training algorithm.
    training_repository_credentials_provider_arn=TRAINING_REPOSITORY_CREDENTIALS_PROVIDER_ARN (str, None)
        Amazon Resource Name (ARN) of an AWS Lambda function that provides credentials to authenticate to the private Docker registry where your training image is hosted.
    disable_output_compression=DISABLE_OUTPUT_COMPRESSION (bool, None)
        when set to true, the model is uploaded to Amazon S3 without compression after training finishes.
    enable_infra_check=ENABLE_INFRA_CHECK (bool, None)
        specifies whether to run SageMaker built-in infra check jobs.
    image_repo=IMAGE_REPO (str, None)
        (remote jobs) the image repository to use when pushing patched images; must have push access. Ex: example.com/your/container
    quiet=QUIET (bool, False)
        whether to suppress verbose output for image building. Defaults to ``False``.

Compatibility
Feature                 Scheduler Support
Fetch Logs              ❌
Distributed Jobs        ✔️
Cancel Job              ✔️
Describe Job            Partial support. SageMakerScheduler will return job and replica status, but does not provide the complete original AppSpec.
Workspaces / Patching   ✔️
Mounts                  ❌
Elasticity              ❌
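As a usage illustration of the config options above, scheduler arguments are passed as a comma-separated list via `-cfg`. The IAM role ARN and image below are placeholders, not real resources:

```shell
# Launch a job on SageMaker; role and instance_type are required.
# The role ARN is a placeholder -- substitute your own IAM role.
torchx run -s aws_sagemaker \
    -cfg role=arn:aws:iam::123456789012:role/SageMakerRole,instance_type=ml.c4.xlarge,instance_count=2 \
    utils.echo --image alpine:latest --msg hello
```

This is a configuration sketch; the actual run requires valid boto3 credentials in the environment.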
- describe(app_id: str) → Optional[DescribeAppResponse][source]¶
Describes the specified application.
- Returns:
The AppDef description, or
None if the app does not exist.
- list() → List[ListAppResponse][source]¶
For apps launched on the scheduler, this API returns a list of ListAppResponse objects, each of which contains the app id and its status. Note: this API is in prototype phase and is subject to change.
- log_iter(app_id: str, role_name: str, k: int = 0, regex: Optional[str] = None, since: Optional[datetime] = None, until: Optional[datetime] = None, should_tail: bool = False, streams: Optional[Stream] = None) → Iterable[str][source]¶
Returns an iterator to the log lines of the kth replica of the role. The iterator ends when all qualifying log lines have been read.
If the scheduler supports time-based cursors for fetching log lines within custom time ranges, then the
since and until fields are honored, otherwise they are ignored. Not specifying since and until is equivalent to getting all available log lines. If until is empty, the iterator behaves like tail -f, following the log output until the job reaches a terminal state. The exact definition of what constitutes a log is scheduler specific. Some schedulers may consider stderr or stdout as the log; others may read the logs from a log file.
Behaviors and assumptions:
1. Produces undefined behavior if called on an app that does not exist. The caller should check that the app exists using
exists(app_id) prior to calling this method.
2. Is not stateful; calling this method twice with the same parameters returns a new iterator. Prior iteration progress is lost.
3. Does not always support log-tailing. Not all schedulers support live log iteration (e.g. tailing logs while the app is running). Refer to the specific scheduler's documentation for the iterator's behavior.
- 3.1 If the scheduler supports log-tailing, it should be controlled
by the
should_tail parameter.
4. Does not guarantee log retention. It is possible that by the time this method is called, the underlying scheduler may have purged the log records for this application. If so, this method raises an arbitrary exception.
5. If
should_tail is True, the method only raises a StopIteration exception when the accessible log lines have been fully exhausted and the app has reached a final state. For instance, if the app gets stuck and does not produce any log lines, the iterator blocks until the app is eventually killed (either via timeout or manually), at which point it raises a StopIteration exception. If
should_tail is False, the method raises StopIteration when there are no more logs.
6. Need not be supported by all schedulers.
7. Some schedulers may support line cursors by supporting
__getitem__ (e.g. iter[50] seeks to the 50th log line).
8. Whitespace is preserved; each new line should include
\n. To support interactive progress bars, the returned lines don't need to include
\n, but should then be printed without a newline to correctly handle \r carriage returns.
- Parameters:
streams – The IO output streams to select. One of: combined, stdout, stderr. If the selected stream isn't supported by the scheduler it will throw a ValueError.
- Returns:
An
Iterator over the log lines of the specified replica of the role.
- Raises:
NotImplementedError – if the scheduler does not support log iteration
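The whitespace rule above (point 8) can be sketched as follows. `lines` stands in for the output of a `log_iter(...)` call; since this scheduler reports log fetching as unsupported, the values here are purely illustrative:

```python
import sys

# Stand-in for an iterator returned by log_iter(...); the values are
# illustrative -- the '\r' carriage returns simulate a progress bar
# that redraws in place, and the final line ends with '\n' as usual.
lines = ["epoch 1/3\r", "epoch 2/3\r", "epoch 3/3\n"]

for line in lines:
    # Write without appending a newline, so '\r' overwrites in place
    # and lines that already end in '\n' are not double-spaced.
    sys.stdout.write(line)
    sys.stdout.flush()
```

Using `print(line)` here would append an extra `\n` after every chunk, breaking the in-place redraw that `\r` is meant to achieve.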
- schedule(dryrun_info: AppDryRunInfo[AWSSageMakerJob]) → str[source]¶
Same as
submit except that it takes an AppDryRunInfo. Implementers are encouraged to implement this method rather than directly implementing submit, since submit can be trivially implemented by:

    dryrun_info = self.submit_dryrun(app, cfg)
    return schedule(dryrun_info)
- class torchx.schedulers.aws_sagemaker_scheduler.AWSSageMakerJob(job_name: str, job_def: Dict[str, Any], images_to_push: Dict[str, Tuple[str, str]])[source]¶
Job defines the key values required to schedule a job. This will be the value of request in the AppDryRunInfo object.
job_name: defines the job name shown in SageMaker
job_def: defines the job description that will be used to schedule the job on SageMaker
images_to_push: used by torchx to push to image_repo
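A minimal sketch of the shape of these fields. All values are hypothetical: in practice the scheduler's dryrun step constructs them, and the exact `job_def` keys are an assumption based on the config options listed above:

```python
# Hypothetical AWSSageMakerJob field values, for illustration only.
job_name = "utils-echo-abc123"  # name that would be shown in SageMaker

# Assumed job-description keys mirroring the scheduler's config options;
# the real dict is built by submit_dryrun, not by hand.
job_def = {
    "role": "arn:aws:iam::123456789012:role/SageMakerRole",  # placeholder ARN
    "instance_type": "ml.c4.xlarge",
    "instance_count": 1,
}

# local image id -> (remote repo to push to, remote tag); values assumed
images_to_push = {"sha256:abc123": ("example.com/your/container", "abc123")}
```

The `request` field of the resulting `AppDryRunInfo` would carry these three values together as one `AWSSageMakerJob`.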
Reference¶
- torchx.schedulers.aws_sagemaker_scheduler.create_scheduler(session_name: str, **kwargs: object) → AWSSageMakerScheduler[source]¶