设置参数:
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--batch_size', type=int, default=64, help='Batch size for training')
parser.add_argument('--local_rank', type=int, default=0, help='Local rank of the process')
parser.add_argument('--device', type=str, default='0', help='Local rank of the process')
args = parser.parse_args()
torch.distributed.init_process_group(backend='nccl', init_method='env://')
local_rank = args.local_rank
device = torch.device('cuda', local_rank)
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
epoch_num = 500
batch_size = args.batch_size
if torch.cuda.is_available():
net.to(device)
net = torch.nn.parallel.DistributedDataParallel(net, device_ids=
[local_rank],find_unused_parameters=True)
训练时参数:
# 在训练循环中使用本地 GPU 设备
for batch in dataloader:
inputs, labels = batch
inputs = inputs.to(device)
labels = labels.to(device)
# 在这里进行训练
...
调用命令:
CUDA_VISIBLE_DEVICES=3,4 /data3/lbg/envs//aimet_3.8/bin/python3.8 -m torch.distributed.launch --master_port 49998 --nproc_per_node 2 train.py --device '' --batch_size 256
原文地址:https://blog.csdn.net/jacke121/article/details/134702645
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.7code.cn/show_34500.html
如若内容造成侵权/违法违规/事实不符,请联系代码007邮箱:suwngjj01@126.com进行投诉反馈,一经查实,立即删除!
声明:本站所有文章,如无特殊说明或标注,均为本站原创发布。任何个人或组织,在未征得本站同意时,禁止复制、盗用、采集、发布本站内容到任何网站、书籍等各类媒体平台。如若本站内容侵犯了原著者的合法权益,可联系我们进行处理。