Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve split and concat op #8270

Merged
merged 6 commits into from
Feb 11, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions paddle/fluid/framework/ddim.cc
Original file line number Diff line number Diff line change
Expand Up @@ -314,5 +314,15 @@ DDim stride(const DDim& ddim) {
}
return framework::make_ddim(strides);
}

DDim stride_numel(const framework::DDim& ddim) {
std::vector<int64_t> strides(ddim.size());
strides[ddim.size() - 1] = ddim[ddim.size() - 1];
for (int i = ddim.size() - 2; i >= 0; --i) {
strides[i] = strides[i + 1] * ddim[i];
}
return framework::make_ddim(strides);
}

} // namespace framework
} // namespace paddle
2 changes: 2 additions & 0 deletions paddle/fluid/framework/ddim.h
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,8 @@ DDim flatten_to_2d(const DDim& src, int num_col_dims);
DDim flatten_to_1d(const DDim& src);

DDim stride(const DDim& ddim);

DDim stride_numel(const DDim& ddim);
} // namespace framework
} // namespace paddle

Expand Down
38 changes: 19 additions & 19 deletions paddle/fluid/operators/concat_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,17 +28,18 @@ class ConcatKernel : public framework::OpKernel<T> {
auto ins = ctx.MultiInput<framework::Tensor>("X");
auto* out = ctx.Output<framework::Tensor>("Out");
int64_t axis = static_cast<int64_t>(ctx.Attr<int>("axis"));
const size_t n = ins.size();
auto place = ctx.GetPlace();
out->mutable_data<T>(place);

auto out_stride = framework::stride_numel(out->dims());

size_t output_offset = 0;
out->mutable_data<T>(ctx.GetPlace());
auto out_stride = framework::stride(out->dims());
for (size_t i = 0; i < n; i++) {
auto& in = ins[i];
auto axis_dim = in->dims()[axis];
auto in_stride = framework::stride(in->dims());
StridedMemcpy<T>(ctx.device_context(), in->data<T>(), in_stride,
in->dims(), out_stride, out->data<T>() + output_offset);
output_offset += axis_dim * in_stride[axis];
for (auto* in : ins) {
auto in_stride = framework::stride_numel(in->dims());
StridedNumelCopyWithAxis<T>(ctx.device_context(), axis,
out->data<T>() + output_offset, out_stride,
in->data<T>(), in_stride);
output_offset += in_stride[axis];
}
}
};
Expand All @@ -50,17 +51,16 @@ class ConcatGradKernel : public framework::OpKernel<T> {
auto* in = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
auto outs = ctx.MultiOutput<framework::Tensor>(framework::GradVarName("X"));
int64_t axis = static_cast<int64_t>(ctx.Attr<int>("axis"));
const size_t n = outs.size();
size_t input_offset = 0;
auto in_stride = framework::stride(in->dims());
for (size_t i = 0; i < n; i++) {
auto& out = outs[i];
auto in_stride = framework::stride_numel(in->dims());

for (auto& out : outs) {
out->mutable_data<T>(ctx.GetPlace());
size_t axis_dim = out->dims()[axis];
auto out_stride = framework::stride(out->dims());
StridedMemcpy<T>(ctx.device_context(), in->data<T>() + input_offset,
in_stride, out->dims(), out_stride, out->data<T>());
input_offset += axis_dim * in_stride[axis];
auto out_stride = framework::stride_numel(out->dims());
StridedNumelCopyWithAxis<T>(ctx.device_context(), axis, out->data<T>(),
out_stride, in->data<T>() + input_offset,
in_stride);
input_offset += out_stride[axis];
}
}
};
Expand Down
19 changes: 10 additions & 9 deletions paddle/fluid/operators/split_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ limitations under the License. */

#pragma once

#include <chrono>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/strided_memcpy.h"
Expand All @@ -27,18 +28,18 @@ class SplitOpKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& ctx) const override {
auto* in = ctx.Input<framework::Tensor>("X");
auto outs = ctx.MultiOutput<framework::Tensor>("Out");
auto in_stride = framework::stride(in->dims());
auto in_stride = framework::stride_numel(in->dims());
int64_t axis = static_cast<int64_t>(ctx.Attr<int>("axis"));
const size_t n = outs.size();
auto place = ctx.GetPlace();

size_t input_offset = 0;
for (size_t i = 0; i < n; i++) {
auto& out = outs[i];
for (auto& out : outs) {
out->mutable_data<T>(ctx.GetPlace());
size_t axis_dim = out->dims()[axis];
auto out_stride = framework::stride(out->dims());
StridedMemcpy<T>(ctx.device_context(), in->data<T>() + input_offset,
in_stride, out->dims(), out_stride, out->data<T>());
input_offset += axis_dim * in_stride[axis];
auto out_stride = framework::stride_numel(out->dims());
StridedNumelCopyWithAxis<T>(ctx.device_context(), axis, out->data<T>(),
out_stride, in->data<T>() + input_offset,
in_stride);
input_offset += out_stride[axis];
}
}
};
Expand Down
57 changes: 57 additions & 0 deletions paddle/fluid/operators/strided_memcpy.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,5 +41,62 @@ inline void StridedMemcpy(const platform::DeviceContext& dev_ctx, const T* src,
StridedCopyDimVisitor<T> func(dev_ctx, src, src_stride, dst_stride, dst);
boost::apply_visitor(func, dst_dim);
}

// Strided numel memory copy from src to dst by the specified axis
//
// For example, for a tensor dims [4, 20, 100], the strieded numel is
// [8000, 2000, 100]
//
// NOTE: The src and dst tensor should have the same elements
// except the specified axis.
template <typename T>
inline void StridedNumelCopyWithAxis(const platform::DeviceContext& ctx,
int64_t axis, T* dst,
const framework::DDim& dst_stride_numel,
const T* src,
const framework::DDim& src_stride_numel) {
int64_t before = dst_stride_numel[0] / dst_stride_numel[axis];
int64_t src_after = src_stride_numel[axis];
int64_t dst_after = dst_stride_numel[axis];
auto place = ctx.GetPlace();

PADDLE_ENFORCE_EQ(src_stride_numel.size(), dst_stride_numel.size(),
"src and dst tensor should have the same dims size.");

for (int64_t i = 0; i < axis; ++i) {
if (i < axis) {
PADDLE_ENFORCE_EQ(src_stride_numel[i] / src_stride_numel[axis],
dst_stride_numel[i] / dst_stride_numel[axis],
"src and dst should have the same elements "
"except the specified axis.");
} else if (i == axis) {
continue;
} else {
PADDLE_ENFORCE_EQ(src_stride_numel[i], dst_stride_numel[i],
"src and dst should have the same elements "
"except the specified axis.");
}
}

for (int64_t i = 0; i < before; ++i) {
if (platform::is_cpu_place(place)) {
auto& cpu_place = boost::get<platform::CPUPlace>(place);
memory::Copy(cpu_place, dst + i * dst_after, cpu_place,
src + i * src_after, sizeof(T) * src_after);
} else {
#ifdef PADDLE_WITH_CUDA
auto& gpu_place = boost::get<platform::CUDAPlace>(place);
auto& cuda_ctx =
reinterpret_cast<const platform::CUDADeviceContext&>(ctx);
memory::Copy(gpu_place, dst + i * dst_after, gpu_place,
src + i * src_after, sizeof(T) * src_after,
cuda_ctx.stream());
#else
PADDLE_THROW("Paddle is not compiled with GPU");
#endif
}
}
}

} // namespace operators
} // namespace paddle
8 changes: 4 additions & 4 deletions python/paddle/v2/fluid/tests/test_split_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@
class TestSplitOp(OpTest):
def setUp(self):
self.op_type = "split"
axis = 0
x = np.random.random((4, 2, 5)).astype('float32')
out = np.split(x, [1, 3], axis)
axis = 1
x = np.random.random((4, 5, 6)).astype('float32')
out = np.split(x, [2, 3], axis)
self.inputs = {'X': x}
self.attrs = {'axis': axis, 'sections': [1, 2, 1]}
self.attrs = {'axis': axis, 'sections': [2, 1, 2]}
self.outputs = {'Out': [('out%d' % i, out[i]) \
for i in xrange(len(out))]}

Expand Down