Add method converting Tensor to Eigen TensorMap #2805

Merged: 19 commits, Jul 18, 2017
3 changes: 2 additions & 1 deletion paddle/framework/CMakeLists.txt
@@ -1,6 +1,7 @@
# ddim lib
cc_library(enforce SRCS enforce.cc DEPS glog)
cc_test(enforce_test SRCS enforce_test.cc DEPS enforce)
cc_library(ddim SRCS ddim.cc)
cc_library(ddim SRCS ddim.cc DEPS eigen3)
cc_test(ddim_test SRCS ddim_test.cc DEPS ddim)
nv_test(dim_test SRCS dim_test.cu DEPS ddim)
cc_library(tensor SRCS tensor.cc DEPS ddim place enforce paddle_memory)
28 changes: 27 additions & 1 deletion paddle/framework/ddim.h
@@ -1,11 +1,26 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <boost/variant.hpp>
#include <initializer_list>
#include <stdexcept>
#include <vector>

#include "paddle/framework/dim.h"
#include "paddle/framework/enforce.h"
#include "unsupported/Eigen/CXX11/Tensor"

namespace paddle {
namespace framework {
@@ -102,6 +117,17 @@ int arity(const DDim& ddim);

std::ostream& operator<<(std::ostream&, const DDim&);

template <int NDIMS>
Eigen::DSizes<Eigen::DenseIndex, NDIMS> ToEigenDSizes(const DDim& dims) {
int rank = arity(dims);
PADDLE_ENFORCE(rank == NDIMS, "DDim and NDIMS must be same");
Eigen::DSizes<Eigen::DenseIndex, NDIMS> dsizes;
for (int d = 0; d < rank; d++) {
dsizes[d] = dims[d];
}
return dsizes;
}

} // namespace framework
} // namespace paddle

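As an aside, here is a minimal sketch of how the new ToEigenDSizes helper might be used; the shape values and the function name are made up for illustration, and NDIMS must match arity(dims) or the PADDLE_ENFORCE fires.

```cpp
// Sketch: turning a DDim into an Eigen::DSizes suitable for a TensorMap.
#include "paddle/framework/ddim.h"

void ToEigenDSizesSketch() {
  paddle::framework::DDim dims = paddle::framework::make_ddim({2, 3, 4});
  // The template rank must equal arity(dims); here both are 3.
  Eigen::DSizes<Eigen::DenseIndex, 3> dsizes =
      paddle::framework::ToEigenDSizes<3>(dims);
  // dsizes[0] == 2, dsizes[1] == 3, dsizes[2] == 4
}
```
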
14 changes: 14 additions & 0 deletions paddle/framework/operator.cc
@@ -19,6 +19,20 @@ limitations under the License. */
namespace paddle {
namespace framework {

template <>
Eigen::DefaultDevice* KernelContext::GetEigenDevice<
platform::CPUPlace, Eigen::DefaultDevice>() const {
return device_context_.get_eigen_device<Eigen::DefaultDevice>();
}

#ifndef PADDLE_ONLY_CPU
template <>
Eigen::GpuDevice*
KernelContext::GetEigenDevice<platform::GPUPlace, Eigen::GpuDevice>() const {
return device_context_.get_eigen_device<Eigen::GpuDevice>();
}
#endif

void OperatorBase::CreateInOutOffsetMap(const OpProto& proto) {
PADDLE_ENFORCE(in_out_idxs_.empty(), "duplicate call CreateInOutOffsetMap");
for (int i = 0; i < proto.inputs_size(); i++) {
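With these two explicit specializations in place, a kernel can fetch its Eigen device through a single templated call. A hedged sketch of that lookup follows; the free-function name is invented for illustration.

```cpp
// Sketch: resolving the Eigen device from a KernelContext for a given Place.
#include "paddle/framework/operator.h"

void DeviceLookupSketch(const paddle::framework::KernelContext& ctx) {
  // Resolves to the CPU specialization above.
  Eigen::DefaultDevice* cpu_dev =
      ctx.GetEigenDevice<paddle::platform::CPUPlace>();
  (void)cpu_dev;
#ifndef PADDLE_ONLY_CPU
  // Resolves to the GPU specialization when CUDA support is compiled in.
  Eigen::GpuDevice* gpu_dev =
      ctx.GetEigenDevice<paddle::platform::GPUPlace>();
  (void)gpu_dev;
#endif
}
```
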
23 changes: 23 additions & 0 deletions paddle/framework/operator.h
@@ -31,6 +31,21 @@ limitations under the License. */
namespace paddle {
namespace framework {

template <typename T>
struct EigenDeviceConverter;

template <>
struct EigenDeviceConverter<platform::CPUPlace> {
using EigenDeviceType = Eigen::DefaultDevice;
};

#ifndef PADDLE_ONLY_CPU
template <>
struct EigenDeviceConverter<platform::GPUPlace> {
using EigenDeviceType = Eigen::GpuDevice;
};
#endif

class OperatorBase;
using OperatorPtr = std::shared_ptr<OperatorBase>;
/**
@@ -134,6 +149,13 @@ class KernelContext {
return res;
}

template <typename PlaceType,
typename DeviceType =
typename EigenDeviceConverter<PlaceType>::EigenDeviceType>
DeviceType* GetEigenDevice() const;

platform::Place GetPlace() const { return device_context_.GetPlace(); }

const OperatorBase& op_;
const std::shared_ptr<Scope>& scope_;
const platform::DeviceContext& device_context_;
@@ -147,6 +169,7 @@ class OpKernel {
* device resource such as CUDA stream, cublas handle, etc. from
* KernelContext. User should construct it before run the Operator.
*/

virtual void Compute(const KernelContext& context) const = 0;

virtual ~OpKernel() {}
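EigenDeviceConverter is a pure compile-time mapping from a Place type to an Eigen device type. The static_assert below is only an illustration of that mapping, not part of the patch.

```cpp
// Illustration: CPUPlace maps to Eigen::DefaultDevice at compile time.
#include <type_traits>
#include "paddle/framework/operator.h"

static_assert(
    std::is_same<paddle::framework::EigenDeviceConverter<
                     paddle::platform::CPUPlace>::EigenDeviceType,
                 Eigen::DefaultDevice>::value,
    "CPUPlace is expected to map to Eigen::DefaultDevice");
```
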
69 changes: 69 additions & 0 deletions paddle/framework/tensor.h
@@ -19,8 +19,10 @@ limitations under the License. */
#include <memory>
#include "paddle/framework/ddim.h"
#include "paddle/framework/enforce.h"
#include "paddle/framework/tensor_types.h"
#include "paddle/memory/memory.h"
#include "paddle/platform/place.h"
#include "unsupported/Eigen/CXX11/Tensor"

namespace paddle {
namespace framework {
@@ -36,6 +38,13 @@ class Tensor {
reinterpret_cast<uintptr_t>(holder_->ptr()) + offset_);
}

template <typename T>
T* raw_data() const {
Review comment (Collaborator): Since we always use raw_data in OpKernel and mutable_data is only used for allocation, maybe we could rename mutable_data to allocation.
CheckDims<T>();
return reinterpret_cast<T*>(reinterpret_cast<uintptr_t>(holder_->ptr()) +
offset_);
}

template <typename T>
T* mutable_data(DDim dims, platform::Place place) {
set_dims(dims);
@@ -70,6 +79,66 @@ class Tensor {
offset_);
}

template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::Tensor shaped(DDim new_dims) {
Eigen::array<Eigen::DenseIndex, NDIMS> dims =
paddle::framework::ToEigenDSizes<NDIMS>(new_dims);
return typename TTypes<T, NDIMS>::Tensor(raw_data<T>(), dims);
}

template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::Tensor tensor() {
return typename TTypes<T, NDIMS>::Tensor(
raw_data<T>(), paddle::framework::ToEigenDSizes<NDIMS>(dims_));
}

// flat to rank = 1
template <typename T>
typename TTypes<T>::Flat flat() {
return shaped<T, 1>(make_ddim({static_cast<int>(product(dims_))}));
}

// to TensorType Vec
template <typename T>
typename TTypes<T>::Vec vec() {
return tensor<T, 1>();
}

// to TensorType Matrix
template <typename T>
typename TTypes<T>::Matrix matrix() {
return tensor<T, 2>();
}

// const versions of all the methods above.
template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::Tensor shaped(DDim new_dims) const {
Eigen::array<Eigen::DenseIndex, NDIMS> dims =
paddle::framework::ToEigenDSizes<NDIMS>(new_dims);
return typename TTypes<T, NDIMS>::Tensor(data<T>(), dims);
}

template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::ConstantTensor tensor() const {
return typename TTypes<T, NDIMS>::Tensor(
data<T>(), paddle::framework::ToEigenDSizes<NDIMS>(dims_));
}

template <typename T>
typename TTypes<T>::ConstFlat flat() const {
return shaped<T, 1>(make_ddim({static_cast<int>(product(dims_))}));
}

template <typename T>
typename TTypes<T>::ConstVec vec() const {
return tensor<T, 1>();
}

template <typename T>
typename TTypes<T>::ConstMatrix matrix() const {
return tensor<T, 2>();
}

template <typename T>
void ShareDataFrom(const Tensor& src) {
src.CheckDims<T>();
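A rough usage sketch of the new accessors, assuming a small float tensor allocated on the CPU; the function name, shapes, and values are invented for illustration.

```cpp
// Sketch: viewing one allocation as Eigen TensorMaps of different ranks.
#include "paddle/framework/tensor.h"

void TensorViewSketch() {
  paddle::framework::Tensor t;
  t.mutable_data<float>(paddle::framework::make_ddim({3, 4}),
                        paddle::platform::CPUPlace());

  auto mat = t.matrix<float>();  // rank-2 TensorMap over the {3, 4} buffer
  auto vec = t.flat<float>();    // rank-1 TensorMap over all 12 elements
  auto alt = t.shaped<float, 2>(
      paddle::framework::make_ddim({4, 3}));  // same data, viewed as {4, 3}

  mat(0, 0) = 1.0f;  // writes go straight through to the allocation
  (void)vec;
  (void)alt;
}
```
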
67 changes: 67 additions & 0 deletions paddle/framework/tensor_types.h
@@ -0,0 +1,67 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "unsupported/Eigen/CXX11/Tensor"

namespace paddle {
namespace framework {

// Helper to define Tensor types given that the scalar is of type T.
template <typename T, int NDIMS = 1, typename IndexType = Eigen::DenseIndex>
struct TTypes {
// Rank-<NDIMS> tensor of scalar type T.
typedef Eigen::TensorMap<Eigen::Tensor<T, NDIMS, Eigen::RowMajor, IndexType>,
Eigen::Aligned>
Tensor;
typedef Eigen::TensorMap<
Eigen::Tensor<const T, NDIMS, Eigen::RowMajor, IndexType>, Eigen::Aligned>
ConstTensor;

// Scalar tensor (implemented as a rank-0 tensor) of scalar type T.
typedef Eigen::TensorMap<
Eigen::TensorFixedSize<T, Eigen::Sizes<>, Eigen::RowMajor, IndexType>,
Eigen::Aligned>
Scalar;
typedef Eigen::TensorMap<Eigen::TensorFixedSize<const T, Eigen::Sizes<>,
Eigen::RowMajor, IndexType>,
Eigen::Aligned>
ConstScalar;

// Rank-1 tensor (vector) of scalar type T.
typedef Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor, IndexType>,
Eigen::Aligned>
Flat;
typedef Eigen::TensorMap<
Eigen::Tensor<const T, 1, Eigen::RowMajor, IndexType>, Eigen::Aligned>
ConstFlat;
typedef Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor, IndexType>,
Eigen::Aligned>
Vec;
typedef Eigen::TensorMap<
Eigen::Tensor<const T, 1, Eigen::RowMajor, IndexType>, Eigen::Aligned>
ConstVec;

// Rank-2 tensor (matrix) of scalar type T.
typedef Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor, IndexType>,
Eigen::Aligned>
Matrix;
typedef Eigen::TensorMap<
Eigen::Tensor<const T, 2, Eigen::RowMajor, IndexType>, Eigen::Aligned>
ConstMatrix;
};

} // namespace framework
} // namespace paddle
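
For reference, these aliases are plain Eigen TensorMap types; below is a hedged illustration of what, for example, TTypes<float>::Matrix expands to. It is not part of the patch.

```cpp
// Illustration only: TTypes<float>::Matrix is a rank-2, RowMajor, aligned TensorMap.
#include <type_traits>
#include "paddle/framework/tensor_types.h"

static_assert(
    std::is_same<
        paddle::framework::TTypes<float>::Matrix,
        Eigen::TensorMap<
            Eigen::Tensor<float, 2, Eigen::RowMajor, Eigen::DenseIndex>,
            Eigen::Aligned>>::value,
    "TTypes<float>::Matrix should be a rank-2 RowMajor TensorMap");
```
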
32 changes: 16 additions & 16 deletions paddle/operators/add_op.cc
@@ -1,20 +1,20 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0
http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <paddle/framework/op_registry.h>
#include <paddle/framework/tensor.h>
#include <paddle/operators/add_op.h>
#include "paddle/operators/add_op.h"
#include "paddle/framework/op_registry.h"
#include "paddle/framework/tensor.h"

namespace paddle {
namespace operators {
@@ -31,8 +31,7 @@ class AddOp : public framework::OperatorWithKernel {
"Inputs/Outputs of AddOp must all be set");
PADDLE_ENFORCE(inputs[0]->dims() == inputs[1]->dims(),
"Two input of Add Op's dimension must be same.");
// Need set dims in Tensor
// outputs[0]->set_dims(inputs[0]->dims())
outputs[0]->set_dims(inputs[0]->dims());
}
};

@@ -54,5 +53,6 @@ The equation is: Out = X + Y
} // namespace paddle

REGISTER_OP(add_two, paddle::operators::AddOp, paddle::operators::AddOpMaker);
REGISTER_OP_CPU_KERNEL(
add_two, ::paddle::operators::AddKernel<::paddle::platform::CPUPlace>);
typedef paddle::operators::AddKernel<::paddle::platform::CPUPlace, float>
AddKernel_CPU_float;
REGISTER_OP_CPU_KERNEL(add_two, AddKernel_CPU_float);
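
The intermediate typedef is presumably there because the comma inside AddKernel<CPUPlace, float> would otherwise be treated as a macro-argument separator by REGISTER_OP_CPU_KERNEL. A minimal, self-contained illustration of that preprocessor behaviour follows; the macro and type names are made up.

```cpp
// Illustration only: the preprocessor splits macro arguments on top-level
// commas, so passing a two-argument template-id would be seen as two macro
// arguments. Hiding the template-id behind a typedef avoids that.
#define REGISTER_DEMO(name, kernel) /* registration would go here */

template <typename Place, typename T> struct DemoKernel {};
struct CPU {};

typedef DemoKernel<CPU, float> DemoKernelCPUFloat;  // single token for the macro
REGISTER_DEMO(demo_add, DemoKernelCPUFloat)          // compiles cleanly
// REGISTER_DEMO(demo_add, DemoKernel<CPU, float>)   // three macro args: error
```
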
7 changes: 4 additions & 3 deletions paddle/operators/add_op.cu
@@ -1,5 +1,6 @@
#include <paddle/operators/add_op.h>
#include <paddle/framework/op_registry.h>
#include "paddle/operators/add_op.h"
#include "paddle/framework/op_registry.h"

typedef paddle::operators::AddKernel<::paddle::platform::GPUPlace, float> AddKernel_GPU_float;
REGISTER_OP_GPU_KERNEL(add_two,
paddle::operators::AddKernel<paddle::platform::GPUPlace>);
AddKernel_GPU_float);
31 changes: 26 additions & 5 deletions paddle/operators/add_op.h
@@ -1,15 +1,36 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <glog/logging.h>
#include <paddle/framework/operator.h>
#include "glog/logging.h"
#include "paddle/framework/operator.h"

namespace paddle {
namespace operators {

template <typename Place>
template <typename Place, typename T>
class AddKernel : public framework::OpKernel {
public:
void Compute(const framework::KernelContext &context) const override {
LOG(INFO) << "Add kernel in " << typeid(Place).name();
void Compute(const framework::KernelContext& context) const override {
auto input0 = context.Input(0)->Get<framework::Tensor>();
auto input1 = context.Input(1)->Get<framework::Tensor>();
auto* output = context.Output(0)->GetMutable<framework::Tensor>();
Review comment (Collaborator): Maybe output->mutable_data could be hidden inside context.GetOutputTensor<>().

output->mutable_data<T>(context.GetPlace());

output->flat<T>().device(*(context.GetEigenDevice<Place>())) =
input0.flat<T>() + input1.flat<T>();
}
};

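The assignment in Compute relies on Eigen's device-scoped evaluation: assigning to flat<T>().device(dev) evaluates the element-wise sum on the chosen device. Below is a self-contained sketch of the same pattern using raw Eigen types; the buffers, sizes, and function name are invented for illustration.

```cpp
// Standalone Eigen sketch of `out.device(dev) = a + b`, the pattern AddKernel uses.
#include "unsupported/Eigen/CXX11/Tensor"

void EigenDeviceAddSketch() {
  float a_buf[4] = {1.f, 2.f, 3.f, 4.f};
  float b_buf[4] = {10.f, 20.f, 30.f, 40.f};
  float out_buf[4] = {};

  Eigen::TensorMap<Eigen::Tensor<float, 1>> a(a_buf, 4);
  Eigen::TensorMap<Eigen::Tensor<float, 1>> b(b_buf, 4);
  Eigen::TensorMap<Eigen::Tensor<float, 1>> out(out_buf, 4);

  Eigen::DefaultDevice cpu;
  // Evaluate the element-wise sum on the given device; with an Eigen::GpuDevice
  // the same expression would run on the associated CUDA stream.
  out.device(cpu) = a + b;  // out_buf becomes {11, 22, 33, 44}
}
```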