
Commit

fix some review bug
youth123 committed Sep 9, 2021
1 parent 1461992 commit 3baaa23
Showing 7 changed files with 218 additions and 273 deletions.
2 changes: 1 addition & 1 deletion paddle/fluid/operators/collective/global_gather_op.cc
@@ -40,7 +40,7 @@ class GlobalGatherOp : public framework::OperatorWithKernel {
     PADDLE_ENFORCE_EQ(ndim_input, 2,
                       platform::errors::InvalidArgument(
                           "The input tensor's dimension must be 2. "
-                          "But received input's dimension = [%s].",
+                          "But received input's dimension = %d.",
                           ndim_input));
     framework::DDim out_dims = framework::make_ddim({-1, -1});
     ctx->SetOutputDim("Out", out_dims);
25 changes: 17 additions & 8 deletions paddle/fluid/operators/collective/global_gather_op.cu.cc
@@ -30,29 +30,39 @@ class GlobalGatherOpCUDAKernel : public framework::OpKernel<T> {
     auto x = ctx.Input<framework::LoDTensor>("X");
     auto local_count = ctx.Input<framework::LoDTensor>("local_count");
     auto global_count = ctx.Input<framework::LoDTensor>("global_count");
+    auto local_count_type = local_count->type();
+    auto global_count_type = global_count->type();
+    if (local_count_type != framework::proto::VarType::INT64) {
+      PADDLE_THROW(platform::errors::InvalidArgument(
+          "Please use int64 type in local_count."));
+    }
+    if (global_count_type != framework::proto::VarType::INT64) {
+      PADDLE_THROW(platform::errors::InvalidArgument(
+          "Please use int64 type in global_count."));
+    }
     auto out = ctx.Output<framework::LoDTensor>("Out");
     const int64_t* cpu_local_count_data;
     const int64_t* cpu_global_count_data;
     auto local_count_len = 0;

     framework::Tensor cpu_local_count;
-    if (platform::is_gpu_place(local_count->place())) {
+    if (platform::is_cpu_place(local_count->place())) {
+      cpu_local_count_data = local_count->data<int64_t>();
+      local_count_len = local_count->numel();
+    } else {
       framework::TensorCopySync(*local_count, platform::CPUPlace(),
                                 &cpu_local_count);
       cpu_local_count_data = cpu_local_count.data<int64_t>();
       local_count_len = cpu_local_count.numel();
-    } else {
-      cpu_local_count_data = local_count->data<int64_t>();
-      local_count_len = local_count->numel();
     }

     framework::Tensor cpu_global_count;
-    if (platform::is_gpu_place(global_count->place())) {
+    if (platform::is_cpu_place(global_count->place())) {
+      cpu_global_count_data = global_count->data<int64_t>();
+    } else {
       framework::TensorCopySync(*global_count, platform::CPUPlace(),
                                 &cpu_global_count);
       cpu_global_count_data = cpu_global_count.data<int64_t>();
-    } else {
-      cpu_global_count_data = global_count->data<int64_t>();
     }

     ncclDataType_t dtype = platform::ToNCCLDataType(x->type());
@@ -111,7 +121,6 @@ class GlobalGatherOpCUDAKernel : public framework::OpKernel<T> {
         }
       }
       PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclGroupEnd());
-      PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream));
     }
 #else
     PADDLE_THROW(
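Note: with this change both CUDA kernels reject local_count and global_count tensors that are not int64, instead of reading them with a mismatched dtype. A rough caller-side analogue of that guard, sketched with NumPy (the helper name and count values are illustrative, not Paddle API):

import numpy as np

def check_count_dtype(name, arr):
    # Mirrors the kernel-side guard added above: the count tensors must be int64.
    if arr.dtype != np.int64:
        raise ValueError("Please use int64 type in %s." % name)
    return arr

local_count = check_count_dtype("local_count", np.array([2, 1, 3, 2], dtype=np.int64))
global_count = check_count_dtype("global_count", np.array([1, 2, 2, 1], dtype=np.int64))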
2 changes: 1 addition & 1 deletion paddle/fluid/operators/collective/global_scatter_op.cc
@@ -40,7 +40,7 @@ class GlobalScatterOp : public framework::OperatorWithKernel {
     PADDLE_ENFORCE_EQ(ndim_input, 2,
                       platform::errors::InvalidArgument(
                           "The input tensor's dimension must be 2. "
-                          "But received input's dimension = [%s].",
+                          "But received input's dimension = %d.",
                           ndim_input));

     framework::DDim out_dims = framework::make_ddim({-1, -1});
25 changes: 17 additions & 8 deletions paddle/fluid/operators/collective/global_scatter_op.cu.cc
@@ -30,27 +30,37 @@ class GlobalScatterOpCUDAKernel : public framework::OpKernel<T> {
     auto x = ctx.Input<framework::LoDTensor>("X");
     auto local_count = ctx.Input<framework::LoDTensor>("local_count");
     auto global_count = ctx.Input<framework::LoDTensor>("global_count");
+    auto local_count_type = local_count->type();
+    auto global_count_type = global_count->type();
+    if (local_count_type != framework::proto::VarType::INT64) {
+      PADDLE_THROW(platform::errors::InvalidArgument(
+          "Please use int64 type in local_count."));
+    }
+    if (global_count_type != framework::proto::VarType::INT64) {
+      PADDLE_THROW(platform::errors::InvalidArgument(
+          "Please use int64 type in global_count."));
+    }
     auto out = ctx.Output<framework::LoDTensor>("Out");
     const int64_t* cpu_local_count_data;
     const int64_t* cpu_global_count_data;
     framework::Tensor cpu_local_count;
-    if (platform::is_gpu_place(local_count->place())) {
+    if (platform::is_cpu_place(local_count->place())) {
+      cpu_local_count_data = local_count->data<int64_t>();
+    } else {
       framework::TensorCopy(*local_count, platform::CPUPlace(),
                             &cpu_local_count);
       cpu_local_count_data = cpu_local_count.data<int64_t>();
-    } else {
-      cpu_local_count_data = local_count->data<int64_t>();
     }
     auto global_count_len = 0;
     framework::Tensor cpu_global_count;
-    if (platform::is_gpu_place(global_count->place())) {
+    if (platform::is_cpu_place(global_count->place())) {
+      cpu_global_count_data = global_count->data<int64_t>();
+      global_count_len = global_count->numel();
+    } else {
       framework::TensorCopy(*global_count, platform::CPUPlace(),
                             &cpu_global_count);
       cpu_global_count_data = cpu_global_count.data<int64_t>();
       global_count_len = cpu_global_count.numel();
-    } else {
-      cpu_global_count_data = global_count->data<int64_t>();
-      global_count_len = global_count->numel();
     }

     ncclDataType_t dtype = platform::ToNCCLDataType(x->type());
@@ -110,7 +120,6 @@ class GlobalScatterOpCUDAKernel : public framework::OpKernel<T> {
         }
      }
       PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclGroupEnd());
-      PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream));
     }

 #else
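For orientation: local_count holds one entry per (destination rank, expert) bucket, in the order idx = j * n_expert + i, and the kernel sends the corresponding contiguous row ranges of the input. The updated test below reconstructs those ranges with an exclusive prefix sum (expert_ptr). A minimal NumPy sketch of that bookkeeping, assuming two ranks and two experts with made-up counts:

import numpy as np

world_size, n_expert = 2, 2
tot_expert = world_size * n_expert

# local_count[j * n_expert + i]: rows this rank sends to expert i on rank j.
local_count = np.array([2, 1, 3, 2], dtype=np.int64)

# Exclusive prefix sum: bucket idx occupies rows
# expert_ptr[idx] : expert_ptr[idx] + local_count[idx] of the send buffer.
expert_ptr = np.zeros(tot_expert, dtype=np.int64)
expert_ptr[1:] = np.cumsum(local_count)[:-1]

x = np.random.rand(int(local_count.sum()), 2).astype("float32")
for idx in range(tot_expert):
    bucket = x[expert_ptr[idx]:expert_ptr[idx] + local_count[idx]]
    print("bucket", idx, "rows", bucket.shape[0])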
182 changes: 182 additions & 0 deletions python/paddle/fluid/tests/unittests/test_collective_api_base.py
@@ -292,5 +292,187 @@ def check_with_place(self,
             self.assertTrue(
                 np.allclose(
                     input1, result_data, rtol=1e-05, atol=1e-05))
+        elif col_type == "global_gather":
+            in_feat = 2
+            n_expert = 2
+            world_size = 2
+            tot_expert = n_expert * world_size
+
+            np.random.seed(pid0)
+            local_expert_count1 = np.random.randint(
+                1, 4, size=tot_expert).astype("int")
+            expert_ptr1 = np.ones(tot_expert, dtype=np.int32)
+            expert_ptr1[0] = 0
+            for i in range(1, tot_expert):
+                expert_ptr1[i] = expert_ptr1[i - 1] + local_expert_count1[i - 1]
+
+            np.random.seed(pid1)
+            local_expert_count2 = np.random.randint(
+                1, 4, size=tot_expert).astype("int")
+            expert_ptr2 = np.ones(tot_expert, dtype=np.int32)
+            expert_ptr2[0] = 0
+            for i in range(1, tot_expert):
+                expert_ptr2[i] = expert_ptr2[i - 1] + local_expert_count2[i - 1]
+
+            global_expert_count1 = np.zeros(tot_expert).astype("int")
+            global_expert_count2 = np.zeros(tot_expert).astype("int")
+            global_expert_count1[0:n_expert] = local_expert_count1[0:n_expert]
+            global_expert_count1[n_expert:] = local_expert_count2[0:n_expert]
+            global_expert_count2[0:n_expert] = local_expert_count1[n_expert:]
+            global_expert_count2[n_expert:] = local_expert_count2[n_expert:]
+
+            np.random.seed(pid0)
+            fwd_expert_count = sum(global_expert_count1).astype("int")
+            local_input_buf1 = np.random.rand(fwd_expert_count,
+                                              in_feat).astype("float32")
+            np.random.seed(pid1)
+            fwd_expert_count = sum(global_expert_count2).astype("int")
+            local_input_buf2 = np.random.rand(fwd_expert_count,
+                                              in_feat).astype("float32")
+            output1 = [[], [], [], []]
+            output2 = [[], [], [], []]
+            send_ptr1 = 0
+            send_ptr2 = 0
+
+            for i in range(n_expert):
+                for j in range(world_size):
+                    idx = j * n_expert + i
+                    if j == 0:
+                        output1_part1 = local_input_buf1[send_ptr1: \
+                            send_ptr1 + global_expert_count1[idx], :]
+                        output1_part2 = local_input_buf2[send_ptr2: \
+                            send_ptr2 + global_expert_count2[idx], :]
+                        output1[i].extend(output1_part1)
+                        output1[i + n_expert].extend(output1_part2)
+                    else:
+                        output2_part1 = local_input_buf1[send_ptr1: \
+                            send_ptr1 + global_expert_count1[idx]]
+                        output2_part2 = local_input_buf2[send_ptr2: \
+                            send_ptr2 + global_expert_count2[idx]]
+                        output2[i].extend(output2_part1)
+                        output2[i + n_expert].extend(output2_part2)
+                    send_ptr1 = send_ptr1 + global_expert_count1[idx]
+                    send_ptr2 = send_ptr2 + global_expert_count2[idx]
+            result1 = []
+            result2 = []
+            for i in range(tot_expert):
+                for arr in output1[i]:
+                    if arr == []:
+                        continue
+                    result1.append(arr)
+            for i in range(tot_expert):
+                for arr in output2[i]:
+                    if arr == []:
+                        continue
+                    result2.append(arr)
+            if result1 == []:
+                output1 = np.array([])
+            else:
+                output1 = np.concatenate(
+                    result1, axis=0).reshape(
+                        sum(local_expert_count1), in_feat)
+            if result2 == []:
+                output2 = np.array([])
+            else:
+                output2 = np.concatenate(
+                    result2, axis=0).reshape(
+                        sum(local_expert_count2), in_feat)
+
+            if tr0_out[0] is None or tr0_out[0].shape[0] == 0:
+                tr0_out[0] = np.array([])
+
+            if tr1_out[0] is None or tr1_out[0].shape[0] == 0:
+                tr1_out[0] = np.array([])
+
+            self.assertTrue(
+                np.allclose(
+                    tr0_out[0], output1, rtol=1e-05, atol=1e-05))
+            self.assertTrue(
+                np.allclose(
+                    tr1_out[0], output2, rtol=1e-05, atol=1e-05))
+            if static_mode == 0:
+                self.assertTrue(
+                    np.allclose(
+                        tr0_out[1],
+                        2 * local_input_buf1,
+                        rtol=1e-05,
+                        atol=1e-05))
+                self.assertTrue(
+                    np.allclose(
+                        tr1_out[1],
+                        2 * local_input_buf2,
+                        rtol=1e-05,
+                        atol=1e-05))
+
+        elif col_type == "global_scatter":
+            np.random.seed(pid0)
+            local_expert_count1 = np.random.randint(1, 4, size=4).astype("int")
+            fwd_expert_count = sum(local_expert_count1)
+            local_input_buf1 = np.random.rand(fwd_expert_count,
+                                              2).astype("float32")
+            expert_ptr1 = np.ones(4, dtype=np.int32)
+            expert_ptr1[0] = 0
+            for i in range(1, 4):
+                expert_ptr1[i] = expert_ptr1[i - 1] + local_expert_count1[i - 1]
+            np.random.seed(pid1)
+            local_expert_count2 = np.random.randint(1, 4, size=4).astype("int")
+            fwd_expert_count = sum(local_expert_count2)
+            local_input_buf2 = np.random.rand(fwd_expert_count,
+                                              2).astype("float32")
+            expert_ptr2 = np.ones(4, dtype=np.int32)
+            expert_ptr2[0] = 0
+            for i in range(1, 4):
+                expert_ptr2[i] = expert_ptr2[i - 1] + local_expert_count2[i - 1]
+
+            output1 = []
+            output2 = []
+            for i in range(2):
+                for j in range(2):
+                    idx = j * 2 + i
+                    if j == 0:
+                        # send data to 0 card
+                        output1.append(local_input_buf1[expert_ptr1[idx]: \
+                            expert_ptr1[idx]+local_expert_count1[idx]])
+                        output1.append(local_input_buf2[expert_ptr2[idx]:\
+                            expert_ptr2[idx]+local_expert_count2[idx]])
+                    else:
+                        output2.append(local_input_buf1[expert_ptr1[idx]: \
+                            expert_ptr1[idx]+local_expert_count1[idx]])
+                        output2.append(local_input_buf2[expert_ptr2[idx]:\
+                            expert_ptr2[idx]+local_expert_count2[idx]])
+            if output1 == []:
+                output1 = np.array([])
+            else:
+                output1 = np.concatenate(output1)
+            if output2 == []:
+                output2 = np.array([])
+            else:
+                output2 = np.concatenate(output2)
+
+            if tr0_out[0] is None or tr0_out[0].shape[0] == 0:
+                tr0_out[0] = np.array([])
+
+            if tr1_out[0] is None or tr1_out[0].shape[0] == 0:
+                tr1_out[0] = np.array([])
+
+            self.assertTrue(
+                np.allclose(
+                    tr0_out[0], output1, rtol=1e-05, atol=1e-05))
+            self.assertTrue(
+                np.allclose(
+                    tr1_out[0], output2, rtol=1e-05, atol=1e-05))
+            if static_mode == 0:
+                self.assertTrue(
+                    np.allclose(
+                        tr0_out[1],
+                        2 * local_input_buf1,
+                        rtol=1e-05,
+                        atol=1e-05))
+                self.assertTrue(
+                    np.allclose(
+                        tr1_out[1],
+                        2 * local_input_buf2,
+                        rtol=1e-05,
+                        atol=1e-05))
         else:
             pass
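The explicit slicing above that builds global_expert_count1 and global_expert_count2 from the two local count vectors can also be read as a transpose over the (source rank, destination rank) axes. A small NumPy sketch of that equivalence for the two-rank, two-expert case used in the test (the seed and counts are illustrative):

import numpy as np

world_size, n_expert = 2, 2
np.random.seed(0)
# local_counts[r, j * n_expert + i]: rows rank r sends to expert i on rank j.
local_counts = np.random.randint(
    1, 4, size=(world_size, world_size * n_expert)).astype(np.int64)

# Exchanging counts swaps the source-rank and destination-rank axes:
# global_counts[j, s * n_expert + i] == local_counts[s, j * n_expert + i].
global_counts = (local_counts
                 .reshape(world_size, world_size, n_expert)
                 .transpose(1, 0, 2)
                 .reshape(world_size, world_size * n_expert))

# Same values as the explicit slicing used in the test above.
assert (global_counts[0, :n_expert] == local_counts[0, :n_expert]).all()
assert (global_counts[0, n_expert:] == local_counts[1, :n_expert]).all()
assert (global_counts[1, :n_expert] == local_counts[0, n_expert:]).all()
assert (global_counts[1, n_expert:] == local_counts[1, n_expert:]).all()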
