Skip to content

Commit

Permalink
add the tensor base class, test=develop (PaddlePaddle#17)
Browse files Browse the repository at this point in the history
* update the tensor base class, test=develop

* remove two funcs, test=develop

* update the error msg, test=develop

Co-authored-by: Chen Weihang <chenweihang@baidu.com>
  • Loading branch information
Shixiaowei02 and chenwhql authored Oct 19, 2021
1 parent 1dd0145 commit b77d1ee
Show file tree
Hide file tree
Showing 38 changed files with 1,242 additions and 571 deletions.
4 changes: 2 additions & 2 deletions paddle/fluid/framework/operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1826,7 +1826,7 @@ pt::KernelContext OperatorWithKernel::BuildPtKernelContext(
<< in_def.layout;

auto ins_vector = ctx.inputs.at(input_names[i]);
std::vector<std::shared_ptr<pt::TensorInterface>> tmp_inputs;
std::vector<std::shared_ptr<tcmpt::TensorBase>> tmp_inputs;

for (auto var : ins_vector) {
auto pt_in = framework::InputVariableToPtTensor(*var, in_def);
Expand All @@ -1839,7 +1839,7 @@ pt::KernelContext OperatorWithKernel::BuildPtKernelContext(
auto out_def = output_defs.at(i);
auto outs_vector = ctx.outputs.at(output_names[i]);

std::vector<std::shared_ptr<pt::TensorInterface>> tmp_outputs;
std::vector<std::shared_ptr<tcmpt::TensorBase>> tmp_outputs;
for (auto var : outs_vector) {
auto pt_out = framework::OutputVariableToPtTensor(var, out_def);
tmp_outputs.emplace_back(pt_out);
Expand Down
4 changes: 2 additions & 2 deletions paddle/fluid/framework/tcmpt_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ std::shared_ptr<pt::DenseTensor> MakeTensorImpl<pt::DenseTensor>(
pt::TransToPtDataLayout(tensor.layout()));
}

std::shared_ptr<pt::TensorInterface> InputVariableToPtTensor(
std::shared_ptr<tcmpt::TensorBase> InputVariableToPtTensor(
const framework::Variable& variable, const pt::TensorArgDef& arg_def) {
auto expected_place = pt::TransToFluidPlace(arg_def.backend);

Expand Down Expand Up @@ -122,7 +122,7 @@ std::shared_ptr<pt::TensorInterface> InputVariableToPtTensor(
return nullptr;
}

std::shared_ptr<pt::TensorInterface> OutputVariableToPtTensor(
std::shared_ptr<tcmpt::TensorBase> OutputVariableToPtTensor(
framework::Variable* variable, const pt::TensorArgDef& arg_def) {
// mutable_data before run kernel, to avoid share output form
// KernelContext to original tensor
Expand Down
10 changes: 8 additions & 2 deletions paddle/fluid/framework/tcmpt_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,15 @@ std::shared_ptr<PtTensorImplT> MakeTensorImpl(const Tensor& tensor,
const platform::Place& place,
proto::VarType::Type type);

std::shared_ptr<pt::TensorInterface> InputVariableToPtTensor(
template <typename PtTensorImplT>
void ShareTensorImpl(PtTensorImplT* tensor_impl, LoDTensor* out);

template <typename PtTensorImplT>
void ShareTensorImpl(PtTensorImplT* tensor_impl, Tensor* out);

std::shared_ptr<tcmpt::TensorBase> InputVariableToPtTensor(
const framework::Variable& variable, const pt::TensorArgDef& arg_def);
std::shared_ptr<pt::TensorInterface> OutputVariableToPtTensor(
std::shared_ptr<tcmpt::TensorBase> OutputVariableToPtTensor(
framework::Variable* variable, const pt::TensorArgDef& arg_def);

/* Kernel Key translate */
Expand Down
4 changes: 2 additions & 2 deletions paddle/fluid/framework/tcmpt_utils_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ TEST(TcmptUtils, MakeTensor) {
ASSERT_EQ(dense_x->data<float>()[0], expect_value[0]);
ASSERT_EQ(dense_x->data<float>()[1], expect_value[1]);
ASSERT_EQ(dense_x->backend(), pt::Backend::kCPU);
ASSERT_EQ(dense_x->type(), pt::DataType::kFLOAT32);
ASSERT_EQ(dense_x->data_type(), pt::DataType::kFLOAT32);
}

TEST(TcmptUtils, VarToPtTensor) {
Expand All @@ -60,7 +60,7 @@ TEST(TcmptUtils, VarToPtTensor) {
auto tensor_x = InputVariableToPtTensor(v, tensor_def);
// 3. check result
ASSERT_EQ(tensor_x->backend(), expect_backend);
ASSERT_EQ(tensor_x->type(), pt::DataType::kINT32);
ASSERT_EQ(tensor_x->data_type(), pt::DataType::kINT32);
}

} // namespace framework
Expand Down
4 changes: 2 additions & 2 deletions paddle/fluid/imperative/prepared_operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ static pt::KernelContext BuildDygraphPtKernelContext(
auto& in_def = input_defs.at(i);
auto& ins_vector = ins.at(input_names[i]);

std::vector<std::shared_ptr<pt::TensorInterface>> tmp_inputs;
std::vector<std::shared_ptr<tcmpt::TensorBase>> tmp_inputs;
for (auto var : ins_vector) {
const auto& variable = var->Var();

Expand All @@ -302,7 +302,7 @@ static pt::KernelContext BuildDygraphPtKernelContext(
auto& out_def = output_defs.at(i);
auto& outs_vector = outs.at(output_names[i]);

std::vector<std::shared_ptr<pt::TensorInterface>> tmp_outputs;
std::vector<std::shared_ptr<tcmpt::TensorBase>> tmp_outputs;
for (auto var : outs_vector) {
auto* variable = var->MutableVar();

Expand Down
181 changes: 181 additions & 0 deletions paddle/tcmpt/common/data_type.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

// See Note [ Why still include the fluid headers? ]
#include "paddle/fluid/platform/bfloat16.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/errors.h"
#include "paddle/fluid/platform/float16.h"

namespace paddle {
namespace experimental {

using complex64 = ::paddle::platform::complex<float>;
using complex128 = ::paddle::platform::complex<double>;
using float16 = ::paddle::platform::float16;
using bfloat16 = ::paddle::platform::bfloat16;

enum class DataType {
kUndef = 0,
kBOOL,
kINT8, // Char
kUINT8, // BYte
kINT16,
kINT32,
kUINT32,
kINT64,
kUINT64,
kBFLOAT16,
kFLOAT16,
kUINT16,
kFLOAT32,
kFLOAT64,
kCOMPLEX64,
kCOMPLEX128,
kNumDataTypes
};

inline size_t SizeOf(DataType data_type) {
switch (data_type) {
case DataType::kBOOL:
case DataType::kUINT8:
case DataType::kINT8:
return 1;
case DataType::kFLOAT16:
case DataType::kINT16:
case DataType::kUINT16:
return 2;
case DataType::kFLOAT32:
case DataType::kINT32:
case DataType::kUINT32:
return 4;
case DataType::kFLOAT64:
case DataType::kINT64:
case DataType::kUINT64:
return 8;
case DataType::kUndef:
case DataType::kBFLOAT16:
case DataType::kCOMPLEX64:
case DataType::kCOMPLEX128:
case DataType::kNumDataTypes:
PADDLE_THROW(platform::errors::Unimplemented(
"Data type %d is not supported by tensor.",
static_cast<int>(data_type)));
return 0;
}
}

#define PT_FOR_EACH_DATA_TYPE(_) \
_(bool, DataType::kBOOL) \
_(int8_t, DataType::kINT8) \
_(uint8_t, DataType::kUINT8) \
_(int16_t, DataType::kINT16) \
_(int, DataType::kINT32) \
_(int64_t, DataType::kINT64) \
_(bfloat16, DataType::kBFLOAT16) \
_(float16, DataType::kFLOAT16) \
_(float, DataType::kFLOAT32) \
_(double, DataType::kFLOAT64) \
_(complex64, DataType::kCOMPLEX64) \
_(complex128, DataType::kCOMPLEX128)

template <DataType T>
struct DataTypeToCppType;

template <typename T>
struct CppTypeToDataType;

#define PT_SPECIALIZE_DataTypeToCppType(cpp_type, data_type) \
template <> \
struct DataTypeToCppType<data_type> { \
using type = cpp_type; \
};

PT_FOR_EACH_DATA_TYPE(PT_SPECIALIZE_DataTypeToCppType)

#undef PT_SPECIALIZE_DataTypeToCppType

#define PT_SPECIALIZE_CppTypeToDataType(cpp_type, data_type) \
template <> \
struct CppTypeToDataType<cpp_type> { \
constexpr static DataType Type() { return data_type; } \
};

PT_FOR_EACH_DATA_TYPE(PT_SPECIALIZE_CppTypeToDataType)

#undef PT_SPECIALIZE_CppTypeToDataType

inline std::ostream& operator<<(std::ostream& os, DataType dtype) {
switch (dtype) {
case DataType::kUndef:
os << "Undefined";
break;
case DataType::kBOOL:
os << "bool";
break;
case DataType::kINT8:
os << "int8";
break;
case DataType::kUINT8:
os << "uint8";
break;
case DataType::kINT16:
os << "int16";
break;
case DataType::kINT32:
os << "int32";
break;
case DataType::kINT64:
os << "int64";
break;
case DataType::kBFLOAT16:
os << "bfloat16";
break;
case DataType::kFLOAT16:
os << "float16";
break;
case DataType::kFLOAT32:
os << "float32";
break;
case DataType::kFLOAT64:
os << "float64";
break;
case DataType::kCOMPLEX64:
os << "complex64";
break;
case DataType::kCOMPLEX128:
os << "complex128";
break;
default:
// TODO(chenweihang): change to enforce later
throw std::runtime_error("Invalid DataType type.");
}
return os;
}

inline DataType& operator++(DataType& dtype, int) {
dtype =
DataType(static_cast<std::underlying_type<DataType>::type>(dtype) + 1);
return dtype;
}

} // namespace experimental
} // namespace paddle

namespace pt {
using DataType = paddle::experimental::DataType;
}
26 changes: 21 additions & 5 deletions paddle/tcmpt/core/layout.cc → paddle/tcmpt/common/layout.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,21 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/tcmpt/core/layout.h"
#pragma once

namespace pt {
namespace paddle {
namespace experimental {

enum class DataLayout {
kUndef = 0,
kAny,
kNHWC,
kNCHW,
kMKLDNN,
kNumLayouts,
};

std::ostream& operator<<(std::ostream& os, DataLayout dtype) {
inline std::ostream& operator<<(std::ostream& os, DataLayout dtype) {
switch (dtype) {
case DataLayout::kUndef:
os << "Undefined";
Expand All @@ -40,9 +50,15 @@ std::ostream& operator<<(std::ostream& os, DataLayout dtype) {
return os;
}

DataLayout& operator++(DataLayout& layout, int) {
inline DataLayout& operator++(DataLayout& layout, int) {
layout = DataLayout(
static_cast<std::underlying_type<DataLayout>::type>(layout) + 1);
return layout;
}
} // namespace pt

} // namespace experimental
} // namespace paddle

namespace pt {
using DataLayout = paddle::experimental::DataLayout;
}
10 changes: 4 additions & 6 deletions paddle/tcmpt/core/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,15 @@ ELSE()
ENDIF()

cc_library(backend SRCS backend.cc)
cc_library(dtype SRCS dtype.cc)
cc_library(layout SRCS layout.cc)

if(WITH_GPU)
cc_library(convert_utils SRCS convert_utils.cc DEPS data_type place backend dtype layout gpu_info)
cc_library(convert_utils SRCS convert_utils.cc DEPS data_type place backend gpu_info)
elseif(WITH_ROCM)
cc_library(convert_utils SRCS convert_utils.cc DEPS data_type place backend dtype layout gpu_info)
cc_library(convert_utils SRCS convert_utils.cc DEPS data_type place backend gpu_info)
else()
cc_library(convert_utils SRCS convert_utils.cc DEPS data_type place backend dtype layout)
cc_library(convert_utils SRCS convert_utils.cc DEPS data_type place backend)
endif()
cc_library(dense_tensor SRCS dense_tensor.cc DEPS enforce data_type ddim allocator place convert_utils ${MKLDNN_CTX_DEPS})

cc_library(kernel_factory SRCS kernel_factory.cc DEPS enforce backend dtype layout)
cc_library(kernel_factory SRCS kernel_factory.cc DEPS enforce backend)
cc_library(kernel_context SRCS kernel_context.cc DEPS enforce device_context)
19 changes: 19 additions & 0 deletions paddle/tcmpt/core/allocator.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/tcmpt/core/allocator.h"

namespace paddle {
namespace tcmpt {} // namespace tcmpt
} // namespace paddle
Loading

0 comments on commit b77d1ee

Please sign in to comment.