Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix dataset in pipeline #4968

Merged
merged 1 commit into from
Jul 1, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 44 additions & 35 deletions docs/guides/06_distributed_training/pipeline_parallel_cn.rst
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,38 @@
import paddle.nn.functional as F
import paddle.distributed as dist
import random
from paddle.io import Dataset, BatchSampler, DataLoader


创建数据集

.. code-block:: python
BATCH_NUM = 20
BATCH_SIZE = 16
EPOCH_NUM = 4

IMAGE_SIZE = 784
CLASS_NUM = 10
MICRO_BATCH_SIZE = 2

class RandomDataset(Dataset):
def __init__(self, num_samples):
self.num_samples = num_samples

def __getitem__(self, idx):
image = np.random.random([1, 28, 28]).astype('float32')
label = np.random.randint(0, CLASS_NUM - 1, (1, )).astype('int64')
return image, label

def __len__(self):
return self.num_samples

dataset = RandomDataset(BATCH_NUM * BATCH_SIZE)
train_reader = DataLoader(dataset,
batch_size=BATCH_SIZE,
shuffle=True,
drop_last=True,
num_workers=2)


构建一个可以运行流水线的模型,模型的layer需要被LayerDesc或者继承了LayerDesc的SharedLayerDesc包裹,这里因为不需要共享参数,所以就使用LayerDesc
Expand All @@ -77,8 +109,9 @@
def forward(self, x):
return x.reshape(shape=self.shape)


class AlexNetPipeDesc(PipelineLayer):
def __init__(self, num_classes=10, **kwargs):
def __init__(self, num_classes=CLASS_NUM, **kwargs):
self.num_classes = num_classes
decs = [
LayerDesc(
Expand Down Expand Up @@ -108,14 +141,11 @@
]
super(AlexNetPipeDesc, self).__init__(
layers=decs, loss_fn=nn.CrossEntropyLoss(), **kwargs)

然后初始化分布式环境,这一步主要是构建流水线通信组的拓扑

.. code-block:: python

batch_size = 4
micro_batch_size = 2

strategy = fleet.DistributedStrategy()
model_parallel_size = 1
data_parallel_size = 1
Expand All @@ -126,12 +156,11 @@
"pp_degree": pipeline_parallel_size
}
strategy.pipeline_configs = {
"accumulate_steps": batch_size // micro_batch_size,
"micro_batch_size": micro_batch_size
"accumulate_steps": BATCH_SIZE // MICRO_BATCH_SIZE,
"micro_batch_size": MICRO_BATCH_SIZE
}


fleet.init(is_collective=True, strategy=strategy)

fleet.init(is_collective=True, strategy=strategy)

为了保证流水线并行参数初始化和普通模型初始化一致,需要在不同卡间设置不同的seed。

Expand Down Expand Up @@ -162,7 +191,6 @@ fleet.distributed_optimizer(...):这一步则是为优化器添加分布式属

.. code-block:: python


class ReshapeHelp(Layer):
def __init__(self, shape):
super(ReshapeHelp, self).__init__()
Expand Down Expand Up @@ -214,35 +242,16 @@ fleet.distributed_optimizer(...):这一步则是为优化器添加分布式属
optimizer = fleet.distributed_optimizer(optimizer)


创建mnist数据集

.. code-block:: python

train_reader = paddle.batch(
paddle.dataset.mnist.train(), batch_size=batch_size, drop_last=True
)

开始训练

model.train_batch(...):这一步主要就是执行1F1B的流水线并行方式

.. code-block:: python

for step_id, data in enumerate(train_reader()):
x_data = np.array([x[0] for x in data]).astype("float32").reshape(
batch_size, 1, 28, 28
)
y_data = np.array([x[1] for x in data]).astype("int64").reshape(
batch_size, 1
)
img = paddle.to_tensor(x_data)
label = paddle.to_tensor(y_data)
img.stop_gradient = True
label.stop_gradient = True
if step_id >= 5:
break

loss = model.train_batch([img, label], optimizer, scheduler)
for i, (image, label) in enumerate(train_reader()):
if i >= 5:
break
loss = model.train_batch([image, label], optimizer, scheduler)
print("pp_loss: ", loss.numpy())

运行方式(需要保证当前机器有两张GPU):
Expand All @@ -252,7 +261,7 @@ model.train_batch(...):这一步主要就是执行1F1B的流水线并行方式
export CUDA_VISIBLE_DEVICES=0,1
python -m paddle.distributed.launch alexnet_dygraph_pipeline.py # alexnet_dygraph_pipeline.py是用户运行动态图流水线的python文件

基于AlexNet的流水线并行动态图代码:`alex <https://github.com/PaddlePaddle/FleetX/tree/develop/examples/pipeline>`_。
基于AlexNet的完整的流水线并行动态图代码:`alex <https://github.com/PaddlePaddle/FleetX/tree/develop/examples/pipeline>`_。

控制台输出信息如下:

Expand Down