Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

混合并行无法收敛 #218

Closed
Ldpe2G opened this issue Mar 23, 2022 · 7 comments
Closed

混合并行无法收敛 #218

Ldpe2G opened this issue Mar 23, 2022 · 7 comments
Assignees
Labels
bug Something isn't working

Comments

@Ldpe2G
Copy link
Collaborator

Ldpe2G commented Mar 23, 2022

问题描述

swin 设置数据+流水并行,发现训练无法收敛

实验分支:#215

对比实验

以下实验,将总的 batch size 固定为 32

实验1,单卡

swin_cifar100.py 中的dist配置改为:

train.dist.data_parallel_size=1
train.dist.tensor_parallel_size=1
train.dist.pipeline_parallel_size=1

第一个 epoch top1 acc: 3.49

实验2,2卡

swin_cifar100.py 中的dist配置改为:

train.dist.data_parallel_size=1
train.dist.tensor_parallel_size=1
train.dist.pipeline_parallel_size=2

第一个 epoch top1 acc: 3.14

实验3,4卡

swin_cifar100.py 中的dist配置改为:

train.dist.data_parallel_size=1
train.dist.tensor_parallel_size=1
train.dist.pipeline_parallel_size=4

第一个 epoch top1 acc: 3.81

实验4,8卡

swin_cifar100.py 中的dist配置改为:

train.dist.data_parallel_size=2
train.dist.tensor_parallel_size=1
train.dist.pipeline_parallel_size=4

第一个 epoch top1 acc: 1.04

实验5,4卡

swin_cifar100.py 中的dist配置改为:

train.dist.data_parallel_size=2
train.dist.tensor_parallel_size=1
train.dist.pipeline_parallel_size=2

第一个 epoch top1 acc: 0.98

实验结论

单纯的朴素流水并行,1,2,4卡都能收敛,但是当数据加流水并行一起跑,就不会收敛。

在实验5基础上,做实验把 eager_trainer 中的 optimizer.step 注释掉,跑出来的精度也是1.多。

@Ldpe2G Ldpe2G added the bug Something isn't working label Mar 23, 2022
@Ldpe2G
Copy link
Collaborator Author

Ldpe2G commented Mar 28, 2022

alexnet 数据+流水混合并行实验

实验分支: #227

2卡数据并行 对比 4卡数据+流水(2个stage,每个stage 2卡)

实验设置:

  • 训练集 shuffle = False
  • train_aug 改成 test_aug
  • mixup_func=None
  • loss_func=None
  • graph.enabled = False

首先2卡数据并行训几个epoch,模型保存下来,然后重新加载

  • 2卡数据并行推理能达到相同精度
  • 2卡流水推理精度也没问题
  • 4卡数据+流水加载推理结果就很差

@L1aoXingyu
Copy link
Collaborator

L1aoXingyu commented Mar 28, 2022

只要数据并行+流水并行效果就很差吗,只做推理效果都很差,那感觉问题又清楚了一点

我正在跑其他并行的条件,看看是不是只有这种 case 有问题,别的 case loss 是 ok 的

@Ldpe2G
Copy link
Collaborator Author

Ldpe2G commented Mar 28, 2022

只要数据并行+流水并行效果就很差吗,只做推理效果都很差,那感觉问题又清楚了一点

是的,只加载模型上来推理就很差

@leaves-zwx
Copy link
Collaborator

这个 pipeline parallel 是朴素接力还是带上了 grad acc + stage id + checkpointing 的?

@Ldpe2G
Copy link
Collaborator Author

Ldpe2G commented Mar 28, 2022

这个 pipeline parallel 是朴素接力还是带上了 grad acc + stage id + checkpointing 的?

都没有,是朴素接力,上面那个实验是在 eager global 下跑的

@Ldpe2G
Copy link
Collaborator Author

Ldpe2G commented Mar 29, 2022

经过多次实验和代码分析,定位到 libai 中 dist 模块实现有 bug 导致了混合并行收敛问题,修复之后 8卡 3d 并行(2流水+2数据+2张量)能训收敛了,修复方式见:20bf385

其实和中兴用户之前遇到的问题是类似的,都是只用了部分数据集来训练和测试。

原因分析

首先看原来 dist 模块中 model_parallel_sizeget_data_parallel_rank 的实现

def model_parallel_size(self):
      return self._tensor_parallel_size * self._pipeline_parallel_size

def get_data_parallel_rank():
    dist_util = get_dist_util()
    return flow.env.get_rank() // dist_util.model_parallel_size

再来看 dataloader 中的 Sampler,有两个关键的参数 data_parallel_rankdata_parallel_size

sampler = CyclicSampler(
    dataset=dataset,
    micro_batch_size=train_batch_size,
    shuffle=True,
    consumed_samples=consumed_samples,
    data_parallel_rank=dist.get_data_parallel_rank(),
    data_parallel_size=dist.get_data_parallel_size(),
    seed=seed,
)

sampler 的实现逻辑,首先确定当前 rank 所要读取的数据段,计算本rank所负责读取的长度:

data_size_per_epoch = dataset_size / data_parallel_size

然后计算本rank的起始偏移:

start_idx = data_parallel_rank * data_size_per_epoch

逻辑很简单,所以每个rank在读取数据集的时候,所读取的数据就取决于 data_parallel_rankdata_parallel_size 这两个参数。

然后回到 4卡(数据+流水并行) 这个例子,配置文件中对应的配置如下:

train.dist.data_parallel_size=2
train.dist.tensor_parallel_size=1
train.dist.pipeline_parallel_size=2

所以根据原来 dist 模块中的实现,每个 rank 调用 dist.get_data_parallel_rank()dist.get_data_parallel_size() 返回的结果如下:

model_parallel_size = tensor_parallel_size * pipeline_parallel_size = 2

def get_data_parallel_rank():
    return flow.env.get_rank() // model_parallel_size
    

rank 0
	get_data_parallel_rank = 0 // 2 = 0
	get_data_parallel_size = 2

rank 1
	get_data_parallel_rank = 1 // 2 = 0
	get_data_parallel_size = 2

rank 2
	get_data_parallel_rank = 2 // 2 = 1
	get_data_parallel_size = 2

rank 3
	get_data_parallel_rank = 3 // 2 = 1
	get_data_parallel_size = 2

所以就是 rank0 和 rank1 读取的数据是一样的,而且只读了一半的数据,然后 在上诉并行配置下,网络输入在 to_global 的时候的 sbp 和 placement 为 [oneflow.sbp.split(axis=0)]oneflow.placement(type="cuda", ranks=[0, 1])。而这样to_global之后,就丢掉了 rank 2 和 rank 3 上另外一半有用的数据,而 rank 0 和 rank1 上虽然是 split(0) 但是都是一样的数据重复了。

这就导致了训练和测试都只用了一半的数据,但是为啥会导致收敛不了?

然后修复之后的结果:

model_parallel_size = tensor_parallel_size = 1

def get_data_parallel_rank():
    return (flow.env.get_rank() // model_parallel_size) % data_parallel_size
    

rank 0
	get_data_parallel_rank = (0 // 1) % 2 = 0
	get_data_parallel_size = 2

rank 1
	get_data_parallel_rank = (1 // 1) % 2 = 1
	get_data_parallel_size = 2

rank 2
	get_data_parallel_rank = (2 // 1) % 2 = 0
	get_data_parallel_size = 2

rank 3
	get_data_parallel_rank = (3 // 1) % 2 = 1
	get_data_parallel_size = 2

这样的话 rank0 和 rank1 读取的数据加起来就是完整的数据集了,这才符合 数据并行的设置。

在来看下修复后 8卡 3d 并行(2流水+2数据+2张量)下的结果:

train.dist.data_parallel_size=2
train.dist.tensor_parallel_size=2
train.dist.pipeline_parallel_size=2

model_parallel_size = tensor_parallel_size = 2

def get_data_parallel_rank():
    return (flow.env.get_rank() // model_parallel_size) % data_parallel_size
    

rank 0
	get_data_parallel_rank = (0 // 2) % 2 = 0
	get_data_parallel_size = 2

rank 1
	get_data_parallel_rank = (1 // 2) % 2 = 0
	get_data_parallel_size = 2

rank 2
	get_data_parallel_rank = (2 // 2) % 2 = 1
	get_data_parallel_size = 2

rank 3
	get_data_parallel_rank = (3 // 2) % 2 = 1
	get_data_parallel_size = 2
    
rank 4
	get_data_parallel_rank = (4 // 2) % 2 = 0
	get_data_parallel_size = 2

rank 5
	get_data_parallel_rank = (5 // 2) % 2 = 0
	get_data_parallel_size = 2

rank 6
	get_data_parallel_rank = (6 // 2) % 2 = 1
	get_data_parallel_size = 2

rank 7
	get_data_parallel_rank = (7 // 2) % 2 = 1
	get_data_parallel_size = 2

然后上述并行设置下,网络输入的 sbp 和 placement 是 [oneflow.sbp.split(axis=0), oenflow.sbp.broadcast] oneflow.placement(type="cuda", ranks=[[0, 1], [2, 3]]),而 rank 0 和 rank 1 读的是同样的一半数据,rank 2 和 rank3 读的是同样的另一半数据,而 sbp 是 [s(0), b] ,所以也符合并行的要求。

@L1aoXingyu
Copy link
Collaborator

流水并行的 stage 读了数据却没用上,所以流水的 stage 越多,读到的数据越少,收敛越差。
如果是数据+tensor 并行的情况,应该是处理了这个问题,所以可以收敛。

@Ldpe2G Ldpe2G closed this as completed Mar 29, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
bug Something isn't working
Projects
None yet
Development

No branches or pull requests

5 participants