From 42a0947ef3912c28272fe945944cd80c670b54c3 Mon Sep 17 00:00:00 2001 From: Weilong Wu Date: Wed, 26 Jan 2022 16:58:55 +0800 Subject: [PATCH] [Move selected_rows PR #5] VisitDataType use Pten::DataType (#39236) * Added selected_rows and rw_lock to pten * Renamed the unit test target to fix CI * Removed Class SelectedRows in Fluid, changed include/cmake relationship, use pten::SelectedRows in Fluid * Remove rw_lock.h,rw_lock_test.cc in fluid * Use pten::RWLock and pten::AutoRDLock, fix CI * Use pten::SelectedRows * Use pten::SelectedRows * Fix to pass NPU CI * Selected_Rows inherits from TensorBase * Use pten::SelectedRows, to pass NPU CI * To fix NPU CI * To fix NPU CI again * Use paddle/pten/core/enforce and polish code * Use pten::DataType instead of using proto_type * Move part of data_type to pten * Polish Code --- paddle/pten/core/selected_rows.cc | 20 +++++----- paddle/pten/core/utils/data_type.h | 63 ++++++++++++++++++++++++++++++ 2 files changed, 72 insertions(+), 11 deletions(-) create mode 100644 paddle/pten/core/utils/data_type.h diff --git a/paddle/pten/core/selected_rows.cc b/paddle/pten/core/selected_rows.cc index 6f64602bdcf4d..1dfcfa49347b5 100644 --- a/paddle/pten/core/selected_rows.cc +++ b/paddle/pten/core/selected_rows.cc @@ -13,9 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/pten/core/selected_rows.h" - -// See Note [ Why still include the fluid headers? ] -#include "paddle/fluid/framework/data_type.h" +#include "paddle/pten/core/utils/data_type.h" namespace pten { @@ -191,16 +189,16 @@ void SelectedRows::Get(const pten::DenseTensor& ids, int64_t index = AutoGrownIndex(id, auto_grown, is_test); if (index < 0) { VLOG(5) << "id " << id << " not in the table, return 0"; - paddle::framework::VisitDataType( - value_->type(), + pten::VisitDataType( + value_->dtype(), TensorFillVisitor(value, i * value_width, value_width, 0.0)); } else { - paddle::framework::VisitDataType(value_->type(), - TensorCopyVisitor(value, - i * value_width, - *value_.get(), - index * value_width, - value_width)); + pten::VisitDataType(value_->dtype(), + TensorCopyVisitor(value, + i * value_width, + *value_.get(), + index * value_width, + value_width)); } } } diff --git a/paddle/pten/core/utils/data_type.h b/paddle/pten/core/utils/data_type.h new file mode 100644 index 0000000000000..ee223afb3b03c --- /dev/null +++ b/paddle/pten/core/utils/data_type.h @@ -0,0 +1,63 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include +#include +#include + +#include "paddle/pten/common/data_type.h" +#include "paddle/pten/core/enforce.h" +#include "paddle/pten/kernels/funcs/eigen/extensions.h" + +namespace pten { + +#define _PtenForEachDataTypeHelper_(callback, cpp_type, data_type) \ + callback(cpp_type, data_type); + +#define _PtenForEachDataType_(callback) \ + _PtenForEachDataTypeHelper_(callback, float, DataType::FLOAT32); \ + _PtenForEachDataTypeHelper_( \ + callback, ::paddle::platform::float16, DataType::FLOAT16); \ + _PtenForEachDataTypeHelper_( \ + callback, ::paddle::platform::bfloat16, DataType::BFLOAT16); \ + _PtenForEachDataTypeHelper_(callback, double, DataType::FLOAT64); \ + _PtenForEachDataTypeHelper_(callback, int, DataType::INT32); \ + _PtenForEachDataTypeHelper_(callback, int64_t, DataType::INT64); \ + _PtenForEachDataTypeHelper_(callback, bool, DataType::BOOL); \ + _PtenForEachDataTypeHelper_(callback, uint8_t, DataType::UINT8); \ + _PtenForEachDataTypeHelper_(callback, int16_t, DataType::INT16); \ + _PtenForEachDataTypeHelper_(callback, int8_t, DataType::INT8); \ + _PtenForEachDataTypeHelper_( \ + callback, ::paddle::platform::complex, DataType::COMPLEX64); \ + _PtenForEachDataTypeHelper_( \ + callback, ::paddle::platform::complex, DataType::COMPLEX128); + +template +inline void VisitDataType(pten::DataType type, Visitor visitor) { +#define PtenVisitDataTypeCallback(cpp_type, data_type) \ + do { \ + if (type == data_type) { \ + visitor.template apply(); \ + return; \ + } \ + } while (0) + + _PtenForEachDataType_(PtenVisitDataTypeCallback); +#undef PtenVisitDataTypeCallback + PADDLE_THROW(pten::errors::Unimplemented( + "Not supported proto::VarType::Type(%d) as data type.", + static_cast(type))); +} +} // namespace pten