Skip to content

Commit

Permalink
Refactor flatten kernel (PaddlePaddle#12)
Browse files Browse the repository at this point in the history
* refactor flatten kernel

* update infershape function

* fix compile bugs

* fix bugs when merge

* fix compiler bugs

* fix bugs when run test_flatten_api

* fix bugs when run test
  • Loading branch information
YuanRisheng authored Oct 15, 2021
1 parent 2309149 commit 6ce92e5
Show file tree
Hide file tree
Showing 36 changed files with 1,154 additions and 40 deletions.
24 changes: 22 additions & 2 deletions paddle/fluid/framework/operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1286,9 +1286,23 @@ static bool ContainHostTensor(const proto::OpProto& op_proto,
return false;
}

// TODO(yuanrisheng): enhance the rules for selecting a kernel that contains
// an Intermediate (mid-output) Tensor.
//
// Returns true iff the op declares an intermediate output AND that output
// holds a valid variable in `outputs`. Only the first intermediate output
// found in the proto is checked (same as the original behavior).
static bool ContainMidOutputTensor(const proto::OpProto& op_proto,
                                   const VariableValueMap& outputs) {
  for (int i = 0; i < op_proto.outputs_size(); ++i) {
    // Bind by const reference: indexing a repeated proto field returns a
    // reference, and copying the message per iteration is wasted work.
    const auto& output = op_proto.outputs()[i];
    if (output.has_intermediate() && output.intermediate()) {
      return IsValidVar(output.name(), outputs);
    }
  }
  return false;
}

static pt::KernelName ConstructPtKernelName(const std::string& op_type,
const proto::OpProto& op_proto,
const VariableValueMap& inputs) {
const VariableValueMap& inputs,
const VariableValueMap& outputs) {
std::string overload_name;
// TODO(chenweihang): adapt SelectedRows by xiaowei's design
if (ContainHostTensor(op_proto, inputs)) {
Expand All @@ -1297,6 +1311,12 @@ static pt::KernelName ConstructPtKernelName(const std::string& op_type,
}
overload_name += pt::kContainHostTensorSuffix;
}
if (ContainMidOutputTensor(op_proto, outputs)) {
if (overload_name != "") {
overload_name += ".";
}
overload_name += pt::kContainMidOutputTensorSuffix;
}
return pt::KernelName(op_type, overload_name);
}

Expand All @@ -1305,7 +1325,7 @@ void OperatorWithKernel::ChoosePtKernel(
// 1. construct operation name
// TODO(chenweihang): add rules for construct op name
auto kernel_name =
ConstructPtKernelName(Type(), *(Info().proto_), ctx.inputs);
ConstructPtKernelName(Type(), *(Info().proto_), ctx.inputs, ctx.outputs);

// 2. construct op kernel key
pt_kernel_key_.reset(new pt::KernelKey(
Expand Down
5 changes: 3 additions & 2 deletions paddle/tcmpt/api/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,10 @@
# declare_module(MathCUDA)

set(TCMPT_DEPS convert_utils dense_tensor kernel_factory kernel_context)
set(TCMPT_DEPS ${TCMPT_DEPS} math_cpu linalg_cpu creation_cpu)
set(TCMPT_DEPS ${TCMPT_DEPS} math_cpu linalg_cpu creation_cpu manipulation_cpu)
set(TCMPT_DEPS ${TCMPT_DEPS} unary binary)
if(WITH_GPU OR WITH_ROCM)
set(TCMPT_DEPS ${TCMPT_DEPS} math_cuda linalg_cuda creation_cuda)
set(TCMPT_DEPS ${TCMPT_DEPS} math_cuda linalg_cuda creation_cuda manipulation_cuda)
endif()

cc_library(tcmpt SRCS all.cc DEPS ${TCMPT_DEPS})
1 change: 1 addition & 0 deletions paddle/tcmpt/api/all.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,5 @@ limitations under the License. */
#include "paddle/tcmpt/api/include/creation.h"
#include "paddle/tcmpt/api/include/infershape.h"
#include "paddle/tcmpt/api/include/linalg.h"
#include "paddle/tcmpt/api/include/manipulation.h"
#include "paddle/tcmpt/api/include/math.h"
1 change: 1 addition & 0 deletions paddle/tcmpt/api/include/infershape.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,5 @@ limitations under the License. */
#pragma once

// See Note: [ How do we organize the kernel directory ]
#include "paddle/tcmpt/infershape/binary.h"
#include "paddle/tcmpt/infershape/unary.h"
19 changes: 19 additions & 0 deletions paddle/tcmpt/api/include/manipulation.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

// See Note: [ How do we organize the kernel directory ]
#include "paddle/tcmpt/cpu/manipulation.h"
#include "paddle/tcmpt/cuda/manipulation.h"
6 changes: 6 additions & 0 deletions paddle/tcmpt/core/dtype.cc
Original file line number Diff line number Diff line change
Expand Up @@ -64,4 +64,10 @@ std::ostream& operator<<(std::ostream& os, DataType dtype) {
return os;
}

// Postfix-form increment helper so registration code can iterate over the
// DataType enum range. NOTE(review): unlike a canonical postfix operator++,
// this returns (a reference to) the already-incremented value — sufficient
// for the loops that use it, where the result is discarded.
DataType& operator++(DataType& dtype, int) {
  using UnderlyingT = std::underlying_type<DataType>::type;
  const auto next_value = static_cast<UnderlyingT>(dtype) + 1;
  dtype = static_cast<DataType>(next_value);
  return dtype;
}

} // namespace pt
4 changes: 3 additions & 1 deletion paddle/tcmpt/core/dtype.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,13 @@ enum class DataType {
kFLOAT64,
kCOMPLEX64,
kCOMPLEX128,
kNumDataTypes,
kNumDataTypes
};

std::ostream& operator<<(std::ostream& os, DataType dtype);

DataType& operator++(DataType& dtype, int);

#define PT_FOR_EACH_DATA_TYPE(_) \
_(bool, DataType::kBOOL) \
_(int8_t, DataType::kINT8) \
Expand Down
2 changes: 2 additions & 0 deletions paddle/tcmpt/core/kernel_def.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,4 +37,6 @@ constexpr char kContainHostTensorSuffix[] = "host";
// For kernels with SelectedRowsTensor input and output
constexpr char kContainSelectedRowsSuffix[] = "sr";

// For kernels with intermediate output
constexpr char kContainMidOutputTensorSuffix[] = "mid";
} // namespace pt
73 changes: 73 additions & 0 deletions paddle/tcmpt/core/kernel_registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,58 @@ struct KernelRegistrar {
KernelArgsParseFn args_parse_fn,
KernelArgsDefFn args_def_fn,
KernelFn kernel_fn) {
ConstructKernel(kernel_name_cstr,
backend,
layout,
dtype,
args_parse_fn,
args_def_fn,
kernel_fn);
}

// Registrar overload without an explicit data type: registers the kernel for
// every DataType (kBOOL up to, but excluding, kNumDataTypes).
//
// When `layout` is DataLayout::kAny, the kernel is additionally registered
// for every concrete layout, iterating from kNHWC (i.e. skipping the kAny
// sentinel itself) up to kNumLayouts; otherwise only the given layout is
// used.
KernelRegistrar(const char* kernel_name_cstr,
                Backend backend,
                DataLayout layout,
                KernelArgsParseFn args_parse_fn,
                KernelArgsDefFn args_def_fn,
                KernelFn kernel_fn) {
  if (layout == DataLayout::kAny) {
    for (DataLayout layout_iter = DataLayout::kNHWC;
         layout_iter != DataLayout::kNumLayouts;
         layout_iter++) {
      for (DataType dtype = DataType::kBOOL; dtype != DataType::kNumDataTypes;
           dtype++) {
        ConstructKernel(kernel_name_cstr,
                        backend,
                        layout_iter,
                        dtype,
                        args_parse_fn,
                        args_def_fn,
                        kernel_fn);
      }
    }
  } else {
    for (DataType dtype = DataType::kBOOL; dtype != DataType::kNumDataTypes;
         dtype++) {
      // `dtype` is already a DataType; the previous
      // static_cast<DataType>(dtype) here was a no-op and inconsistent with
      // the kAny branch above, so it has been removed.
      ConstructKernel(kernel_name_cstr,
                      backend,
                      layout,
                      dtype,
                      args_parse_fn,
                      args_def_fn,
                      kernel_fn);
    }
  }
}

private:
void ConstructKernel(const char* kernel_name_cstr,
Backend backend,
DataLayout layout,
DataType dtype,
KernelArgsParseFn args_parse_fn,
KernelArgsDefFn args_def_fn,
KernelFn kernel_fn) {
KernelName kernel_name(kernel_name_cstr);
KernelKey kernel_key(backend, layout, dtype);
Kernel kernel(kernel_fn);
Expand Down Expand Up @@ -549,4 +601,25 @@ struct KernelRegistrar {
void PT_CONCATENATE(__PT_KERNEL_for_test_args_def_FN_, \
func_id)(::pt::Kernel * kernel)

// Registers a kernel whose function carries no template data type: the
// KernelRegistrar overload taking no dtype enumerates every DataType (and,
// for kAny, every layout) on its own.
#define PT_REGISTER_KERNEL_WITH_NO_TYPE(          \
    kernel_name, backend, layout, meta_kernel_fn) \
  _PT_REGISTER_KERNEL_WITH_NO_TYPE(               \
      kernel_name, PT_ID, backend, layout, meta_kernel_fn)

// Implementation detail: takes an explicit `func_id` so the generated
// registrar symbol and args-def function are unique per registration site.
// The trailing statement is the header of the args-def function, whose body
// is supplied by the `{...}` following the macro invocation.
#define _PT_REGISTER_KERNEL_WITH_NO_TYPE(                             \
    kernel_name, func_id, backend, layout, meta_kernel_fn)            \
  PT_STATIC_ASSERT_GLOBAL_NAMESPACE(                                  \
      PT_CONCATENATE(pt_op_kernel_ns_check_, func_id),                \
      "PT_REGISTER_KERNEL_WITH_NO_TYPE must be called in global "     \
      "namespace.");                                                  \
  decltype(meta_kernel_fn) meta_kernel_fn;                            \
  static void PT_CONCATENATE(__PT_KERNEL_args_def_FN_,                \
                             func_id)(::pt::Kernel*);                 \
  static const ::pt::KernelRegistrar __reg_pt_op_kernel_##func_id(    \
      kernel_name,                                                    \
      BACKEND(backend),                                               \
      DATALAYOUT(layout),                                             \
      ::pt::KernelArgsParseFunctor<decltype(&meta_kernel_fn)>::Parse, \
      &PT_CONCATENATE(__PT_KERNEL_args_def_FN_, func_id),             \
      PT_KERNEL(meta_kernel_fn));                                     \
  void PT_CONCATENATE(__PT_KERNEL_args_def_FN_, func_id)(::pt::Kernel * kernel)
} // namespace pt
5 changes: 5 additions & 0 deletions paddle/tcmpt/core/layout.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,4 +40,9 @@ std::ostream& operator<<(std::ostream& os, DataLayout dtype) {
return os;
}

// Postfix-form increment helper so registration code can iterate over the
// DataLayout enum range. NOTE(review): it returns a reference to the
// already-incremented value (not the old one, as canonical postfix would);
// the loops that use it discard the result, so this is equivalent there.
DataLayout& operator++(DataLayout& layout, int) {
  using UnderlyingT = std::underlying_type<DataLayout>::type;
  const auto next_value = static_cast<UnderlyingT>(layout) + 1;
  layout = static_cast<DataLayout>(next_value);
  return layout;
}
} // namespace pt
2 changes: 2 additions & 0 deletions paddle/tcmpt/core/layout.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,6 @@ enum class DataLayout {

std::ostream& operator<<(std::ostream& os, DataLayout dtype);

DataLayout& operator++(DataLayout& layout, int);

} // namespace pt
2 changes: 1 addition & 1 deletion paddle/tcmpt/core/tensor_meta.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ namespace pt {
*/
// using LoD = std::vector<paddle::framework::Vector<size_t>>;
using LoD = std::vector<std::vector<size_t>>;

using DDim = paddle::framework::DDim;
/**
* The Meta data member of DenseTensor.
*
Expand Down
2 changes: 2 additions & 0 deletions paddle/tcmpt/cpu/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,5 @@ endif()
cc_library(math_cpu SRCS math.cc DEPS dense_tensor kernel_context kernel_factory eigen_function)
cc_library(linalg_cpu SRCS linalg.cc DEPS dense_tensor kernel_context kernel_factory)
cc_library(creation_cpu SRCS creation.cc DEPS dense_tensor kernel_context kernel_factory eigen_function)
cc_library(utils_cpu SRCS utils.cc DEPS dense_tensor kernel_context kernel_factory memory)
cc_library(manipulation_cpu SRCS manipulation.cc DEPS dense_tensor kernel_context kernel_factory utils_cpu unary)
81 changes: 81 additions & 0 deletions paddle/tcmpt/cpu/manipulation.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/tcmpt/cpu/manipulation.h"
#include "paddle/tcmpt/cpu/utils.h"
#include "paddle/tcmpt/infershape/unary.h"

namespace pt {

// Flatten kernel (CPU): collapses the dimensions of `x` in the inclusive
// range [start_axis, stop_axis] into a single dimension, writing the result
// into `out`. The element data is copied; only the metadata changes shape.
template <typename T>
void Flatten(const CPUContext& dev_ctx,
             const DenseTensor& x,
             int start_axis,
             int stop_axis,
             DenseTensor* out) {
  // Derive the flattened meta (dims and lod) from the input meta.
  auto out_meta = FlattenInferShape(x.meta(), start_axis, stop_axis);
  // Copy the data first, then overwrite the meta: Copy presumably sets
  // `out`'s meta from `x`, so the Resize must come after it — do not
  // reorder these statements.
  pt::Copy(dev_ctx, x, out);
  out->mutable_meta()->lod = out_meta.lod;
  out->Resize(out_meta.dims);
}

// TODO(yuanrisheng): this kernel is for training and xshape is a Intermediate
// Output Tensor,
// is there a more flexible way to deal with this case?
template <typename T>
void FlattenWithXShape(const CPUContext& dev_ctx,
const DenseTensor& x,
int start_axis,
int stop_axis,
DenseTensor* out,
DenseTensor* xshape) {
Flatten<T>(dev_ctx, x, start_axis, stop_axis, out);
const auto& in_dims = x.meta().dims;
std::vector<int64_t> xshape_dims(in_dims.size() + 1);
xshape_dims[0] = 0;
for (int i = 0; i < in_dims.size(); ++i) {
xshape_dims[i + 1] = in_dims[i];
}
xshape->mutable_meta()->dims = paddle::framework::make_ddim(xshape_dims);
xshape->mutable_meta()->lod = x.meta().lod;
}

} // namespace pt

// TODO(chenweihang): replace by better impl
PT_REGISTER_MODULE(ManipulationCPU);

// TODO(yuanrisheng): "flatten_contiguous_range" is compatible with old kernel
// architecture, kernel_name should be "flatten".
// Base flatten kernel: produces only `out`.
PT_REGISTER_KERNEL("flatten_contiguous_range",
                   CPU,
                   NCHW,
                   pt::Flatten,
                   float,
                   double,
                   uint8_t,
                   int8_t,
                   int,
                   int64_t) {}

// ".mid" variant: additionally produces the intermediate XShape output used
// for training; the suffix matches pt::kContainMidOutputTensorSuffix ("mid"),
// which the framework appends when the op has a live intermediate output.
PT_REGISTER_KERNEL("flatten_contiguous_range.mid",
                   CPU,
                   NCHW,
                   pt::FlattenWithXShape,
                   float,
                   double,
                   uint8_t,
                   int8_t,
                   int,
                   int64_t) {}
34 changes: 34 additions & 0 deletions paddle/tcmpt/cpu/manipulation.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/tcmpt/core/dense_tensor.h"
#include "paddle/tcmpt/core/kernel_registry.h"

// See Note [ Why still include the fluid headers? ]
#include "paddle/fluid/platform/device_context.h"

namespace pt {

using CPUContext = paddle::platform::CPUDeviceContext;

template <typename T>
void Flatten(const CPUContext& dev_ctx,
const DenseTensor& x,
int start_axis,
int stop_axis,
DenseTensor* out);

} // namespace pt
Loading

0 comments on commit 6ce92e5

Please sign in to comment.