Skip to content

Commit

Permalink
[pten] remove in_type arg in cast kernel (#38486)
Browse files Browse the repository at this point in the history
* remove in_dtype arg in cast kernel

* move conj config in api.yaml so that entries are in alphabetical order

* remove unused code in cast_kernel.cu
  • Loading branch information
MingMingShangTian authored Dec 28, 2021
1 parent 78836bb commit 0637b9a
Show file tree
Hide file tree
Showing 11 changed files with 24 additions and 30 deletions.
5 changes: 5 additions & 0 deletions paddle/fluid/operators/cast_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,11 @@ class CastOp : public framework::OperatorWithKernel {
#endif
return framework::OpKernelType(tensor->type(), tensor_place);
}

// Returns the KernelSignature named "cast" for this operator.
// The signature lists "X", "out_dtype" and "Out"; "in_dtype" is not part of
// it. `ctx` is not consulted, so the result is the same for every call.
// NOTE(review): the three brace lists presumably name the op's inputs,
// attributes and outputs, in that order -- confirm against the
// framework::KernelSignature constructor.
framework::KernelSignature GetExpectedPtenKernelArgs(
const framework::ExecutionContext &ctx) const override {
return framework::KernelSignature("cast", {"X"}, {"out_dtype"}, {"Out"});
}
};

} // namespace operators
Expand Down
7 changes: 1 addition & 6 deletions paddle/fluid/operators/cast_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,6 @@ class CastOpKernel : public framework::OpKernel<InT> {
auto* out = context.Output<framework::Tensor>("Out");

auto out_dtype = context.Attr<int>("out_dtype");
// todo: not used in_dtype
auto in_dtype = context.Attr<int>("in_dtype");

auto& dev_ctx = context.device_context<DeviceContext>();
out->mutable_data(dev_ctx.GetPlace(),
Expand All @@ -71,12 +69,9 @@ class CastOpKernel : public framework::OpKernel<InT> {

auto pt_out_dtype = pten::TransToPtenDataType(
static_cast<framework::proto::VarType::Type>(out_dtype));
auto pt_in_dtype = pten::TransToPtenDataType(
static_cast<framework::proto::VarType::Type>(in_dtype));

// call new kernel
pten::Cast<InT>(dev_ctx, *pt_x.get(), pt_out_dtype, pt_in_dtype,
pt_out.get());
pten::Cast<InT>(dev_ctx, *pt_x.get(), pt_out_dtype, pt_out.get());
}
};

Expand Down
6 changes: 4 additions & 2 deletions paddle/pten/api/include/kernel_signature.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,10 @@ using add_kernel = void (*)(const DeviceContext&,
int,
DenseTensor*);

using cast_kernel = void (*)(
const DeviceContext&, const DenseTensor&, DataType, DataType, DenseTensor*);
using cast_kernel = void (*)(const DeviceContext&,
const DenseTensor&,
DataType,
DenseTensor*);

using divide_kernel = void (*)(const DeviceContext&,
const DenseTensor&,
Expand Down
5 changes: 2 additions & 3 deletions paddle/pten/include/manipulation.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,13 @@ DenseTensor Flatten(const ContextT& dev_ctx,
template <typename T, typename ContextT>
DenseTensor Cast(const ContextT& dev_ctx,
const DenseTensor& x,
DataType out_dtype,
DataType in_dtype) {
DataType out_dtype) {
auto out_meta = CastInferMeta(x.meta(), out_dtype);
pten::DenseTensor dense_out(
pten::make_intrusive<paddle::experimental::SharedStorage>(
dev_ctx.GetPlace()),
std::move(out_meta));
Cast<T, ContextT>(dev_ctx, x, out_dtype, in_dtype, &dense_out);
Cast<T, ContextT>(dev_ctx, x, out_dtype, &dense_out);
return dense_out;
}

Expand Down
1 change: 0 additions & 1 deletion paddle/pten/kernels/cast_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ template <typename T, typename ContextT>
void Cast(const ContextT& dev_ctx,
const DenseTensor& x,
DataType out_dtype,
DataType in_dtype,
DenseTensor* out);

} // namespace pten
1 change: 0 additions & 1 deletion paddle/pten/kernels/cpu/cast_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@ template <typename T, typename ContextT>
void Cast(const ContextT& dev_ctx,
const DenseTensor& x,
DataType out_dtype,
DataType in_dtype,
DenseTensor* out) {
PD_VISIT_ALL_TYPES(out_dtype, "CastKernelImpl", ([&] {
CastKernelImpl<T, data_t>(dev_ctx, x, out);
Expand Down
3 changes: 0 additions & 3 deletions paddle/pten/kernels/gpu/cast_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/pten/kernels/cast_kernel.h"

#include "paddle/pten/api/ext/dispatch.h"
Expand Down Expand Up @@ -84,7 +82,6 @@ template <typename T, typename ContextT>
void Cast(const ContextT& dev_ctx,
const DenseTensor& x,
DataType out_dtype,
DataType in_dtype,
DenseTensor* out) {
PD_VISIT_ALL_TYPES(out_dtype, "CastCUDAKernelImpl", ([&] {
CastCUDAKernelImpl<T, data_t>(dev_ctx, x, out);
Expand Down
2 changes: 1 addition & 1 deletion paddle/pten/kernels/hybird/cuda/reduce/reduce_cuda_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -1112,7 +1112,7 @@ void TensorReduceFunctorImpl(const pten::DenseTensor& x,
AsyncCopy(x, y);
y->Resize(out_dims);
} else {
pten::Cast<Tx>(*dev_ctx, x, y->dtype(), x.dtype(), y);
pten::Cast<Tx>(*dev_ctx, x, y->dtype(), y);
}
return;
}
Expand Down
2 changes: 1 addition & 1 deletion paddle/pten/kernels/hybird/general/reduce_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ void Reduce(const DeviceContext& dev_ctx,
pten::DenseTensorMeta(out_dtype, x.dims(), x.layout()));

// cast x tensor to out_dtype
pten::Cast<T, DeviceContext>(dev_ctx, x, out_dtype, x.dtype(), &tmp_tensor);
pten::Cast<T, DeviceContext>(dev_ctx, x, out_dtype, &tmp_tensor);

// do reduce sum
PD_VISIT_ALL_TYPES(
Expand Down
4 changes: 1 addition & 3 deletions paddle/pten/tests/kernels/test_cast_dev_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,13 +49,11 @@ TEST(DEV_API, cast) {
auto* dev_ctx = pool.Get(paddle::platform::CPUPlace());

pten::DataType out_dtype = pten::DataType::FLOAT64;
pten::DataType in_dtype = pten::DataType::FLOAT32;
// 2. test API
auto out = pten::Cast<float>(
*(static_cast<paddle::platform::CPUDeviceContext*>(dev_ctx)),
dense_x,
out_dtype,
in_dtype);
out_dtype);

// 3. check result
ASSERT_EQ(out.dims().size(), 2);
Expand Down
18 changes: 9 additions & 9 deletions python/paddle/utils/code_gen/api.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,17 @@
func : CastInferMeta
kernel :
func : cast
param : [x, out_dtype, x.dtype()]
param : [x, out_dtype]
data_type : x

- api : conj
args : (const Tensor& x)
output : Tensor
infer_meta :
func : UnchangedInferMeta
kernel :
func : conj

- api : divide
args : (const Tensor& x, const Tensor& y)
output : Tensor
Expand Down Expand Up @@ -171,11 +179,3 @@
args : (const Tensor& x, DataType dtype=DataType::UNDEFINED, Backend place=Backend::UNDEFINED, DataLayout layout=DataLayout::UNDEFINED)
output : Tensor
invoke : full_like(x, 0, dtype, place, layout)

- api : conj
args : (const Tensor& x)
output : Tensor
infer_meta :
func : UnchangedInferMeta
kernel :
func : conj

0 comments on commit 0637b9a

Please sign in to comment.