目录

设置参数:

训练时参数:

调用命令:


设置参数

    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch_size', type=int, default=64, help='Batch size for training')
    parser.add_argument('--local_rank', type=int, default=0, help='Local rank of the process')
    parser.add_argument('--device', type=str, default='0', help='Local rank of the process')
    args = parser.parse_args()

    torch.distributed.init_process_group(backend='nccl', init_method='env://')

    local_rank = args.local_rank
    device = torch.device('cuda', local_rank)
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"
    epoch_num = 500
    batch_size = args.batch_size


    if torch.cuda.is_available():
        net.to(device)
    net = torch.nn.parallel.DistributedDataParallel(net, device_ids= 
    [local_rank],find_unused_parameters=True)


训练参数

# 在训练循环使用本地 GPU 设备
for batch in dataloader:
    inputs, labels = batch
    inputs = inputs.to(device)
    labels = labels.to(device)
    # 在这里进行训练
    ...

调用命令

CUDA_VISIBLE_DEVICES=3,4 /data3/lbg/envs//aimet_3.8/bin/python3.8 -m torch.distributed.launch --master_port 49998 --nproc_per_node 2 train.py --device '' --batch_size 256

原文地址:https://blog.csdn.net/jacke121/article/details/134702645

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任

如若转载,请注明出处:http://www.7code.cn/show_34500.html

如若内容造成侵权/违法违规/事实不符,请联系代码007邮箱suwngjj01@126.com进行投诉反馈,一经查实,立即删除

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注