[nr * args.np + lr for lr in range(args.np)]
# NOTE(review): the expression above starts before this chunk; presumably it
# is the right-hand side of `wr_list = ...` inside a loop over `nr` -- confirm
# against the preceding lines and indent this `if`/`else` to match that loop.
if nr == args.node_rank:
    sharding_group = dist.new_group(ranks=wr_list, backend=args.dist_backend)
else:
    dist.new_group(ranks=wr_list, backend=args.dist_backend)

# for inter-node PowerSGD
# One group is created per local index `lr`; each group lists the rank with
# that local index for every node index in range(args.nn). The creation call
# is issued for every `lr`, and only the group whose `lr` equals
# distenv.local_rank is kept as `averaging_group`.
for lr in range(args.np):
    wr_list = [node * args.np + lr for node in range(args.nn)]
    # Decide ownership before creating the group so the order of evaluation
    # matches a plain if/else around the two creation calls.
    owns_group = lr == distenv.local_rank
    new_pg = dist.new_group(ranks=wr_list, backend=args.dist_backend)
    if owns_group:
        averaging_group = new_pg