diff --git a/cmake/operators.cmake b/cmake/operators.cmake index 0cd21c942e03a..e626669d2a73a 100644 --- a/cmake/operators.cmake +++ b/cmake/operators.cmake @@ -26,6 +26,35 @@ function(find_register FILENAME PATTERN OUTPUT) PARENT_SCOPE) endfunction() +function(find_phi_register FILENAME ADD_PATH) + # set op_name to OUTPUT + set(options "") + set(oneValueArgs "") + set(multiValueArgs "") + file(READ ${FILENAME} CONTENT) + + string( + REGEX + MATCH + "PD_REGISTER_KERNEL\\([ \t\r\n]*[a-z0-9_]*,[[ \\\t\r\n\/]*[a-z0-9_]*]?[ \\\t\r\n]*[a-zA-Z]*,[ \\\t\r\n]*[A-Z_]*" + register + "${CONTENT}") + if(NOT register STREQUAL "") + string(REPLACE "PD_REGISTER_KERNEL(" "" register "${register}") + string(REPLACE "," ";" register "${register}") + string(REGEX REPLACE "[ \\\t\r\n]+" "" register "${register}") + string(REGEX REPLACE "//cuda_only" "" register "${register}") + list(GET register 0 kernel_name) + list(GET register 1 kernel_backend) + list(GET register 2 kernel_layout) + + file( + APPEND ${ADD_PATH} + "PD_DECLARE_KERNEL(${kernel_name}, ${kernel_backend}, ${kernel_layout});\n" + ) + endif() +endfunction() + function(op_library TARGET) # op_library is a function to create op library. The interface is same as # cc_library. But it handle split GPU/CPU code and link some common library @@ -371,6 +400,8 @@ function(op_library TARGET) foreach(cc_src ${cc_srcs}) # pybind USE_OP_ITSELF set(op_name "") + # Add PHI Kernel Registry Message + find_phi_register(${cc_src} ${pybind_file}) find_register(${cc_src} "REGISTER_OPERATOR" op_name) if(NOT ${op_name} EQUAL "") file(APPEND ${pybind_file} "USE_OP_ITSELF(${op_name});\n") @@ -408,6 +439,8 @@ function(op_library TARGET) # message("cu_srcs ${cu_srcs}") foreach(cu_src ${cu_srcs}) set(op_name "") + # Add PHI Kernel Registry Message + find_phi_register(${cu_src} ${pybind_file}) find_register(${cu_src} "REGISTER_OP_CUDA_KERNEL" op_name) if(NOT ${op_name} EQUAL "") file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${op_name}, CUDA);\n") diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index e48ec3ad0385b..71288a44c0969 100755 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -115,7 +115,7 @@ proto_library(trainer_desc_proto SRCS trainer_desc.proto DEPS framework_proto cc_library( string_array SRCS string_array.cc - DEPS utf8proc) + DEPS utf8proc phi_enforce) cc_library( data_type @@ -233,7 +233,8 @@ cc_test( cc_library( var_type_traits SRCS var_type_traits.cc - DEPS framework_proto scope tensor_array sparse_coo_tensor sparse_csr_tensor) + DEPS framework_proto scope tensor_array sparse_coo_tensor sparse_csr_tensor + extended_tensor) if(WITH_GPU) target_link_libraries(var_type_traits dynload_cuda) endif() diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index 32ae7f064e9cd..5cb1cca052902 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -24,6 +24,7 @@ limitations under the License. */ #include "paddle/fluid/framework/details/nan_inf_utils.h" #include "paddle/fluid/framework/op_call_stack.h" #include "paddle/fluid/framework/phi_utils.h" +#include "paddle/fluid/framework/raw_tensor.h" #include "paddle/fluid/framework/shape_inference.h" #include "paddle/fluid/framework/transfer_scope_cache.h" #include "paddle/fluid/framework/unused_var_check.h" @@ -3008,6 +3009,9 @@ void OperatorWithKernel::BuildPhiKernelContext( need_prepare_phi_data_ = true; tensor_in = &(var->Get()); phi_kernel_context->EmplaceBackInputWithoutSetRange(tensor_in); + } else if (var->IsType()) { + tensor_in = &(var->Get()); + phi_kernel_context->EmplaceBackInputWithoutSetRange(tensor_in); } else { PADDLE_THROW(platform::errors::Unimplemented( "Unsupported input `%s` type when call pt kernel.", @@ -3057,6 +3061,13 @@ void OperatorWithKernel::BuildPhiKernelContext( // Note: If the input LoDTensorArray size is 0, the output // LoDTensorArray is also 0 phi_kernel_context->EmplaceBackOutputWithoutSetRange(tensor_out); + } else if (var->template IsType()) { + tensor_out = var->template GetMutable(); + phi_kernel_context->EmplaceBackOutputWithoutSetRange(tensor_out); + } else if (!var->IsInitialized()) { + // The following is for RAW type of var + tensor_out = var->template GetMutable(); + phi_kernel_context->EmplaceBackOutputWithoutSetRange(tensor_out); } else { PADDLE_THROW(platform::errors::Unimplemented( "Unsupported output `%s` type when call pt kernel.", @@ -3156,6 +3167,7 @@ void OperatorWithKernel::BuildPhiKernelContext( } } break; + case phi::AttributeType::SCALARS: { PADDLE_ENFORCE_NE( attr_iter, diff --git a/paddle/fluid/framework/raw_tensor.h b/paddle/fluid/framework/raw_tensor.h new file mode 100644 index 0000000000000..dfee5acd14af0 --- /dev/null +++ b/paddle/fluid/framework/raw_tensor.h @@ -0,0 +1,78 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include + +#include "paddle/phi/core/extended_tensor.h" +#include "paddle/utils/any.h" + +namespace paddle { +namespace framework { + +/// \brief Fluid Kernel and PHI Kernel will be unified in the future. +/// So, we need a class in PHI that can represent the RAW type in Fluid. +/// The RawTensor is for PHI Kernel that has RAW type arguments. +class RawTensor : public phi::ExtendedTensor, + public phi::TypeInfoTraits { + public: + RawTensor() = default; + + RawTensor(RawTensor&& other) = default; + + RawTensor(const RawTensor& other) = default; + + RawTensor& operator=(RawTensor&& other) = default; + + /// \brief Destroy the RawTensor and release exclusive resources. + virtual ~RawTensor() = default; + + public: + /// \brief Returns the name of the class for type traits. + /// \return The name of the class. + static const char* name() { return "RawTensor"; } + + template + T* GetMutable() { + if (!data_.empty()) { + try { + return paddle::any_cast(data_); + } catch (paddle::bad_any_cast&) { + PADDLE_THROW(phi::errors::InvalidArgument( + "Invalid data type error, expected %s, actual %s.", + typeid(T).name(), + data_type_.name())); + } + } + T* created_data = new T(); + data_ = created_data; + data_deleter_ = [created_data]() { delete created_data; }; + data_type_ = std::type_index(typeid(T)); + return created_data; + } + + template + bool IsType() const { + return std::type_index(typeid(T)) == data_type_; + } + + private: + paddle::any data_; + std::function data_deleter_; + std::type_index data_type_ = std::type_index(typeid(void)); +}; + +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/string_array.h b/paddle/fluid/framework/string_array.h old mode 100755 new mode 100644 index b874fbac4c9e7..4ac8d89981bee --- a/paddle/fluid/framework/string_array.h +++ b/paddle/fluid/framework/string_array.h @@ -20,13 +20,82 @@ limitations under the License. */ #include #include #include +#include "paddle/phi/core/extended_tensor.h" namespace paddle { namespace framework { +class Vocab : public phi::ExtendedTensor, + public phi::TypeInfoTraits { + public: + Vocab() = default; + + Vocab(Vocab&& other) = default; + + Vocab(const Vocab& other) = default; + + Vocab& operator=(const Vocab& other) = default; + + Vocab& operator=(Vocab&& other) = default; + + Vocab& operator=( + const std::unordered_map& other) { + this->data_ = other; + return *this; + } + + /// \brief Destroy the Vocab and release exclusive resources. + virtual ~Vocab() = default; + + public: + /// \brief Returns the name of the class for type traits. + /// \return The name of the class. + static const char* name() { return "Vocab"; } + + size_t size() const { return data_.size(); } + + void clear() { data_.clear(); } + + void emplace(const std::wstring& key, std::int32_t value) { + data_.emplace(key, value); + } + + std::int32_t at(const std::wstring& key) { return data_.at(key); } + + std::int32_t at(const std::wstring& key) const { return data_.at(key); } + + std::unordered_map::iterator find( + const std::wstring& key) { + return data_.find(key); + } + + std::unordered_map::const_iterator find( + const std::wstring& key) const { + return data_.find(key); + } + + std::unordered_map::iterator begin() { + return data_.begin(); + } + + std::unordered_map::const_iterator begin() const { + return data_.begin(); + } + + std::unordered_map::iterator end() { + return data_.end(); + } + + std::unordered_map::const_iterator end() const { + return data_.end(); + } + + private: + std::unordered_map data_; +}; + using String = std::string; using Strings = std::vector; -using Vocab = std::unordered_map; // Convert the std::string type to the std::string type. bool ConvertStrToWstr(const std::string& src, std::wstring* res); diff --git a/paddle/fluid/framework/var_type_traits.cc b/paddle/fluid/framework/var_type_traits.cc index 2a53e2f88523b..d73c9b7d95957 100644 --- a/paddle/fluid/framework/var_type_traits.cc +++ b/paddle/fluid/framework/var_type_traits.cc @@ -41,6 +41,7 @@ #include "paddle/fluid/platform/device/xpu/bkcl_helper.h" #endif +#include "paddle/fluid/framework/raw_tensor.h" #include "paddle/fluid/operators/cuda_graph_with_in_out.h" namespace paddle { diff --git a/paddle/fluid/framework/var_type_traits.h b/paddle/fluid/framework/var_type_traits.h index eb3f90006a788..1e6c110e86a30 100644 --- a/paddle/fluid/framework/var_type_traits.h +++ b/paddle/fluid/framework/var_type_traits.h @@ -23,6 +23,7 @@ #include "paddle/fluid/framework/feed_fetch_type.h" #include "paddle/fluid/framework/lod_tensor_array.h" +#include "paddle/fluid/framework/raw_tensor.h" #include "paddle/fluid/framework/string_array.h" #include "paddle/fluid/platform/place.h" #ifdef PADDLE_WITH_CUDA @@ -219,7 +220,8 @@ using VarTypeRegistry = detail::VarTypeRegistryImpl< float, Vocab, std::vector, - std::vector>; + std::vector, + RawTensor>; template struct VarTypeTrait { static_assert(VarTypeRegistry::IsRegistered(), "Must be registered type"); diff --git a/paddle/fluid/framework/var_type_traits_test.cc b/paddle/fluid/framework/var_type_traits_test.cc index 48e2a9a9c6278..54a9fa60d9578 100644 --- a/paddle/fluid/framework/var_type_traits_test.cc +++ b/paddle/fluid/framework/var_type_traits_test.cc @@ -38,6 +38,7 @@ #if defined(PADDLE_WITH_XPU_BKCL) #include "paddle/fluid/platform/device/xpu/bkcl_helper.h" #endif +#include "paddle/fluid/framework/raw_tensor.h" namespace paddle { namespace framework { diff --git a/paddle/fluid/framework/variable_test.cc b/paddle/fluid/framework/variable_test.cc index 22af9ae934e0c..358ba3436d01e 100644 --- a/paddle/fluid/framework/variable_test.cc +++ b/paddle/fluid/framework/variable_test.cc @@ -22,10 +22,10 @@ namespace framework { TEST(Variable, GetMutable) { std::unique_ptr v(new Variable()); - auto* t = v->GetMutable(); + auto* t = v->GetMutable(); *t = "1234"; - const auto& tt = v->Get(); + const auto& tt = v->Get(); EXPECT_EQ("1234", tt); try { diff --git a/paddle/fluid/imperative/CMakeLists.txt b/paddle/fluid/imperative/CMakeLists.txt index 70f2830d12067..cc3ed77c391d8 100644 --- a/paddle/fluid/imperative/CMakeLists.txt +++ b/paddle/fluid/imperative/CMakeLists.txt @@ -5,7 +5,7 @@ cc_library( cc_library( var_helper SRCS var_helper.cc - DEPS tensor selected_rows) + DEPS tensor selected_rows extended_tensor) if(WITH_XPU) cc_library( prepared_operator diff --git a/paddle/fluid/jit/layer.cc b/paddle/fluid/jit/layer.cc index 9055120e4bbb7..75a7e282e6be8 100644 --- a/paddle/fluid/jit/layer.cc +++ b/paddle/fluid/jit/layer.cc @@ -89,7 +89,7 @@ std::vector Layer::FunctionNames() const { PD_SPECIALZE_ATTRIBUTE_TYPE(int) PD_SPECIALZE_ATTRIBUTE_TYPE(float) -PD_SPECIALZE_ATTRIBUTE_TYPE(std::string) +PD_SPECIALZE_ATTRIBUTE_TYPE(framework::String) PD_SPECIALZE_ATTRIBUTE_TYPE(std::vector) PD_SPECIALZE_ATTRIBUTE_TYPE(std::vector) PD_SPECIALZE_ATTRIBUTE_TYPE(std::vector) diff --git a/paddle/fluid/jit/layer_test.cc b/paddle/fluid/jit/layer_test.cc index 7007693aa83e0..4e367d8cc1b51 100644 --- a/paddle/fluid/jit/layer_test.cc +++ b/paddle/fluid/jit/layer_test.cc @@ -86,7 +86,7 @@ TEST(CpuLayerTest, Construct) { int ds = layer.Attribute("down_sampling"); EXPECT_EQ(ds, 4); - std::string fstr = layer.Attribute("fstr"); + std::string fstr = layer.Attribute("fstr"); EXPECT_STREQ(fstr.c_str(), "save str property"); std::vector ints = layer.Attribute>("ints"); diff --git a/paddle/fluid/jit/property.cc b/paddle/fluid/jit/property.cc index 1cf303239baed..b0c943b24a6d7 100644 --- a/paddle/fluid/jit/property.cc +++ b/paddle/fluid/jit/property.cc @@ -97,7 +97,7 @@ std::unordered_map> Property::Values() { *var->GetMutable() = static_cast(GetInt64(n)); break; case ValueProto::STRING: - *var->GetMutable() = GetString(n); + *var->GetMutable() = GetString(n); break; case ValueProto::FLOATS: *var->GetMutable>() = GetFloats(n); diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index 67070bb4cb6fa..882706dc8dd62 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -12,7 +12,7 @@ unset(OP_LIBRARY CACHE) set(pybind_file ${PADDLE_BINARY_DIR}/paddle/fluid/pybind/pybind.h.tmp CACHE INTERNAL "pybind.h file") set(pybind_file_prune ${PADDLE_BINARY_DIR}/paddle/fluid/pybind/pybind.h.prune CACHE INTERNAL "pybind.h file") set(pybind_file_final ${PADDLE_BINARY_DIR}/paddle/fluid/pybind/pybind.h) -file(WRITE ${pybind_file} "// Generated by the paddle/fluid/operators/CMakeLists.txt. DO NOT EDIT!\n\n") +file(WRITE ${pybind_file} "#include \"paddle/phi/core/kernel_registry.h\" // Generated by the paddle/fluid/operators/CMakeLists.txt. DO NOT EDIT!\n\n") add_subdirectory(math) add_subdirectory(controlflow) @@ -109,7 +109,7 @@ register_operators(EXCLUDES py_layer_op py_func_op warpctc_op dgc_op load_combin op_library(run_program_op SRCS run_program_op.cc run_program_op.cu.cc run_program_op_npu.cc DEPS executor_cache ${OP_HEADER_DEPS}) target_link_libraries(run_program_op cuda_graph_with_memory_pool) op_library(quantize_linear_op DEPS phi) -op_library(save_combine_op DEPS string_array) +op_library(save_combine_op DEPS string_array phi) op_library(load_combine_op DEPS string_array) if (WITH_GPU OR WITH_ROCM) diff --git a/paddle/fluid/operators/save_combine_op.cc b/paddle/fluid/operators/save_combine_op.cc index 71f78911456f7..6d4e844d03ed8 100644 --- a/paddle/fluid/operators/save_combine_op.cc +++ b/paddle/fluid/operators/save_combine_op.cc @@ -16,6 +16,10 @@ limitations under the License. */ #include +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/common/bfloat16.h" +#include "paddle/phi/core/kernel_registry.h" + namespace paddle { namespace operators { @@ -100,10 +104,22 @@ REGISTER_OPERATOR(save_combine, ops::SaveCombineOpProtoMaker, ops::SaveCombineOpInferVarType); -REGISTER_OP_CPU_KERNEL( - save_combine, - ops::SaveCombineOpKernel, - ops::SaveCombineOpKernel, - ops::SaveCombineOpKernel, - ops::SaveCombineOpKernel, - ops::SaveCombineOpKernel); +PD_REGISTER_KERNEL(save_combine_tensor, + CPU, + ALL_LAYOUT, + paddle::operators::SaveCombineTensorKernel, + int, + int64_t, + float, + double, + phi::dtype::bfloat16) {} + +PD_REGISTER_KERNEL(save_combine_vocab, + CPU, + ALL_LAYOUT, + paddle::operators::SaveCombineVocabKernel, + int, + int64_t, + float, + double, + phi::dtype::bfloat16) {} diff --git a/paddle/fluid/operators/save_combine_op.cu b/paddle/fluid/operators/save_combine_op.cu index e96aafa382978..94fb8f396a5fa 100644 --- a/paddle/fluid/operators/save_combine_op.cu +++ b/paddle/fluid/operators/save_combine_op.cu @@ -1,23 +1,35 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include "paddle/fluid/operators/save_combine_op.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" -namespace ops = paddle::operators; +PD_REGISTER_KERNEL(save_combine_tensor, + GPU, + ALL_LAYOUT, + paddle::operators::SaveCombineTensorKernel, + int, + int64_t, + float, + double) {} -REGISTER_OP_CUDA_KERNEL(save_combine, - ops::SaveCombineOpKernel, - ops::SaveCombineOpKernel, - ops::SaveCombineOpKernel, - ops::SaveCombineOpKernel); +PD_REGISTER_KERNEL(save_combine_vocab, + GPU, + ALL_LAYOUT, + paddle::operators::SaveCombineVocabKernel, + int, + int64_t, + float, + double) {} diff --git a/paddle/fluid/operators/save_combine_op.h b/paddle/fluid/operators/save_combine_op.h index fd54202a75d3f..bf5e2a5e4d90f 100644 --- a/paddle/fluid/operators/save_combine_op.h +++ b/paddle/fluid/operators/save_combine_op.h @@ -27,35 +27,161 @@ limitations under the License. */ #include "paddle/fluid/framework/data_type_transform.h" #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/raw_tensor.h" #include "paddle/fluid/framework/string_array.h" #include "paddle/fluid/platform/device_context.h" #include "paddle/phi/backends/dynload/port.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/serialization.h" namespace paddle { namespace operators { + +inline void SaveToMemory(const std::string& file_path, + const std::ostringstream& ss, + bool save_to_memory, + std::string* output) { + if (save_to_memory) { + PADDLE_ENFORCE_NE(output, + nullptr, + phi::errors::InvalidArgument( + "Cannot find variable Y for save_combine_op")); + *output = ss.str(); + } else { + MkDirRecursively(DirName(file_path).c_str()); + std::ofstream fout(file_path, std::ios::binary); + PADDLE_ENFORCE_EQ(static_cast(fout), + true, + phi::errors::Unavailable( + "Cannot open %s to save variables.", file_path)); + fout << ss.str(); + fout.close(); + } +} + +template +void SaveCombineTensorKernel(const Context& dev_ctx, + const std::vector& x, + const std::string& file_path, + bool overwrite, + bool save_as_fp16, + bool save_to_memory, + phi::ExtendedTensor* out) { + std::string* y = nullptr; + if (out != nullptr) { + auto raw_out = static_cast(out); + y = raw_out->GetMutable(); + } + + bool is_present = FileExists(file_path); + if (is_present && !overwrite) { + PADDLE_THROW(phi::errors::PreconditionNotMet( + "%s exists! Cannot save_combine to it when overwrite is set to " + "false.", + file_path, + overwrite)); + } + + std::ostringstream ss; + PADDLE_ENFORCE_GT(x.size(), + 0UL, + phi::errors::InvalidArgument( + "The number of variables to be saved is %d, expect " + "it to be greater than 0.", + x.size())); + + for (size_t i = 0; i < x.size(); i++) { + auto& tensor = *(x[i]); + PADDLE_ENFORCE_EQ( + tensor.IsInitialized(), + true, + phi::errors::InvalidArgument( + "The Tensor with Index (%d) to be saved is not initialized.", i)); + // Serialize tensors one by one + // Check types to see if a fp16 transformation is required + auto in_dtype = framework::TransToProtoVarType(tensor.dtype()); + auto out_dtype = save_as_fp16 ? framework::proto::VarType::FP16 : in_dtype; + if (in_dtype != out_dtype) { + auto place = dev_ctx.GetPlace(); + auto in_kernel_type = framework::OpKernelType(in_dtype, place); + auto out_kernel_type = framework::OpKernelType(out_dtype, place); + phi::DenseTensor out; + framework::TransDataType(in_kernel_type, out_kernel_type, tensor, &out); + // copy LoD info to the new tensor + out.set_lod(tensor.lod()); + phi::SerializeToStream(ss, out, dev_ctx); + } else { + phi::SerializeToStream(ss, tensor, dev_ctx); + } + } + + SaveToMemory(file_path, ss, save_to_memory, y); +} + +template +void SaveCombineVocabKernel( + const Context& dev_ctx, + const std::vector& inputs, + const std::string& file_path, + bool overwrite, + bool save_as_fp16, + bool save_to_memory, + phi::ExtendedTensor* out) { + std::string* y = nullptr; + if (out != nullptr) { + auto raw_out = static_cast(out); + y = raw_out->GetMutable(); + } + + std::vector x; + x.reserve(inputs.size()); + for (auto input : inputs) { + x.push_back(static_cast(input)); + } + bool is_present = FileExists(file_path); + if (is_present && !overwrite) { + PADDLE_THROW(phi::errors::PreconditionNotMet( + "%s exists! Cannot save_combine to it when overwrite is set to " + "false.", + file_path, + overwrite)); + } + + std::ostringstream ss; + PADDLE_ENFORCE_GT(x.size(), + 0UL, + phi::errors::InvalidArgument( + "The number of variables to be saved is %d, expect " + "it to be greater than 0.", + x.size())); + + for (size_t i = 0; i < x.size(); i++) { + auto& tensor = *(x[i]); + std::unordered_map data; + for (auto it = tensor.begin(); it != tensor.end(); ++it) { + std::string t; + paddle::framework::ConvertWstrToStr(it->first, &t); + data.emplace(t, it->second); + } + paddle::framework::StringMapToStream(ss, data); + } + + SaveToMemory(file_path, ss, save_to_memory, y); +} + template class SaveCombineOpKernel : public framework::OpKernel { public: - void Compute(const framework::ExecutionContext &ctx) const override { + void Compute(const framework::ExecutionContext& ctx) const override { auto place = ctx.GetPlace(); auto filename = ctx.Attr("file_path"); auto overwrite = ctx.Attr("overwrite"); auto save_as_fp16 = ctx.Attr("save_as_fp16"); auto save_to_memory = ctx.Attr("save_to_memory"); - auto output = ctx.Output("Y"); - - bool is_present = FileExists(filename); - if (is_present && !overwrite) { - PADDLE_THROW(platform::errors::PreconditionNotMet( - "%s exists! Cannot save_combine to it when overwrite is set to " - "false.", - filename, - overwrite)); - } - - std::ostringstream ss; + auto output = ctx.Output("Y"); auto inp_var_names = ctx.InputNames("X"); - auto &inp_vars = ctx.MultiInputVar("X"); + auto& inp_vars = ctx.MultiInputVar("X"); + PADDLE_ENFORCE_GT(inp_var_names.size(), 0UL, platform::errors::InvalidArgument( @@ -64,8 +190,8 @@ class SaveCombineOpKernel : public framework::OpKernel { inp_var_names.size())); // get device context from pool - platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); - auto &dev_ctx = *pool.Get(place); + platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); + auto& dev_ctx = *pool.Get(place); for (size_t i = 0; i < inp_var_names.size(); i++) { PADDLE_ENFORCE_NOT_NULL( @@ -81,59 +207,32 @@ class SaveCombineOpKernel : public framework::OpKernel { "phi::DenseTensor or Vocab variable, %s has wrong type.", inp_var_names[i])); - if (inp_vars[i]->IsType()) { - auto &tensor = inp_vars[i]->Get(); - PADDLE_ENFORCE_EQ( - tensor.IsInitialized(), - true, - platform::errors::InvalidArgument( - "The Tensor of Variable(%s) to be saved is not initialized.", - inp_var_names[i])); - // Serialize tensors one by one - // Check types to see if a fp16 transformation is required - auto in_dtype = framework::TransToProtoVarType(tensor.dtype()); - auto out_dtype = - save_as_fp16 ? framework::proto::VarType::FP16 : in_dtype; - - if (in_dtype != out_dtype) { - auto in_kernel_type = framework::OpKernelType(in_dtype, place); - auto out_kernel_type = framework::OpKernelType(out_dtype, place); - phi::DenseTensor out; - // copy LoD info to the new tensor - out.set_lod(tensor.lod()); - framework::TransDataType( - in_kernel_type, out_kernel_type, tensor, &out); - framework::SerializeToStream(ss, out, dev_ctx); - } else { - framework::SerializeToStream(ss, tensor, dev_ctx); + if (inp_vars.size() > 0 && inp_vars[0]->IsType()) { + std::vector x(inp_vars.size()); + for (auto inp_var : inp_vars) { + x.push_back(&(inp_var->Get())); } + SaveCombineTensorKernel(dev_ctx, + x, + filename, + overwrite, + save_as_fp16, + save_to_memory, + output); } else { - auto &tensor = inp_vars[i]->Get(); - std::unordered_map data; - for (auto it = tensor.begin(); it != tensor.end(); ++it) { - std::string t; - framework::ConvertWstrToStr(it->first, &t); - data.emplace(t, it->second); + std::vector x(inp_vars.size()); + for (auto inp_var : inp_vars) { + x.push_back(&(inp_var->Get())); } - framework::StringMapToStream(ss, data); + SaveCombineVocabKernel(dev_ctx, + x, + filename, + overwrite, + save_as_fp16, + save_to_memory, + output); } } - if (save_to_memory) { - PADDLE_ENFORCE_NE(output, - nullptr, - platform::errors::InvalidArgument( - "Cannot find variable Y for save_combine_op")); - *output = ss.str(); - } else { - MkDirRecursively(DirName(filename).c_str()); - std::ofstream fout(filename, std::ios::binary); - PADDLE_ENFORCE_EQ(static_cast(fout), - true, - platform::errors::Unavailable( - "Cannot open %s to save variables.", filename)); - fout << ss.str(); - fout.close(); - } } }; diff --git a/paddle/fluid/operators/save_load_combine_op_test.cc b/paddle/fluid/operators/save_load_combine_op_test.cc index 5f305d2d6c0c5..14ff038376b43 100644 --- a/paddle/fluid/operators/save_load_combine_op_test.cc +++ b/paddle/fluid/operators/save_load_combine_op_test.cc @@ -20,8 +20,10 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/platform/bfloat16.h" #include "paddle/fluid/platform/float16.h" +#include "paddle/phi/core/kernel_registry.h" -USE_CPU_ONLY_OP(save_combine); +USE_OP_ITSELF(save_combine); +PD_DECLARE_KERNEL(save_combine_tensor, CPU, ALL_LAYOUT); USE_CPU_ONLY_OP(load_combine); template diff --git a/paddle/fluid/pybind/eager_method.cc b/paddle/fluid/pybind/eager_method.cc index 8c7b6296eb46e..25d2195c15cf2 100644 --- a/paddle/fluid/pybind/eager_method.cc +++ b/paddle/fluid/pybind/eager_method.cc @@ -29,6 +29,7 @@ typedef SSIZE_T ssize_t; #include "paddle/fluid/eager/hooks.h" #include "paddle/fluid/eager/utils.h" #include "paddle/fluid/framework/convert_utils.h" +#include "paddle/fluid/framework/string_array.h" #include "paddle/fluid/memory/allocation/allocator.h" #include "paddle/fluid/memory/memcpy.h" #include "paddle/fluid/platform/enforce.h" @@ -1493,7 +1494,7 @@ static PyObject* tensor_method_set_vocab(TensorObject* self, PyObject* args, PyObject* kwargs) { EAGER_TRY - using Vocab = std::unordered_map; + using Vocab = paddle::framework::Vocab; auto vocab = CastPyArg2Vocab(PyTuple_GET_ITEM(args, 0), 0); auto var_tensor = std::make_shared(); *var_tensor->GetMutable() = vocab; @@ -1524,7 +1525,7 @@ static PyObject* tensor_method_get_map_tensor(TensorObject* self, true, paddle::platform::errors::Fatal( "this method is only effective for VariableCompatTensor")); - using Vocab = std::unordered_map; + using Vocab = paddle::framework::Vocab; auto* var_tensor = static_cast(self->tensor.impl().get()); return ToPyObject(var_tensor->Get()); diff --git a/paddle/fluid/pybind/eager_utils.cc b/paddle/fluid/pybind/eager_utils.cc index 7c9faf2fd593e..1eca3dd1d91ce 100644 --- a/paddle/fluid/pybind/eager_utils.cc +++ b/paddle/fluid/pybind/eager_utils.cc @@ -590,11 +590,12 @@ paddle::framework::proto::VarType::Type CastPyArg2ProtoType(PyObject* obj, return dtype; } -std::unordered_map CastPyArg2Vocab(PyObject* obj, - ssize_t arg_pos) { +paddle::framework::Vocab CastPyArg2Vocab(PyObject* obj, ssize_t arg_pos) { if (PyDict_Check(obj)) { - return ::pybind11::handle(obj) - .cast>(); + paddle::framework::Vocab vocab; + vocab = ::pybind11::handle(obj) + .cast>(); + return vocab; } else { PADDLE_THROW(platform::errors::InvalidArgument( "argument (position %d) must be dict, but got %s", @@ -887,7 +888,7 @@ PyObject* ToPyObject( return dict; } -PyObject* ToPyObject(const std::unordered_map& value) { +PyObject* ToPyObject(const paddle::framework::Vocab& value) { PyObject* dict = PyDict_New(); for (const auto& map_iter : value) { // Convert Key diff --git a/paddle/fluid/pybind/eager_utils.h b/paddle/fluid/pybind/eager_utils.h index 900b2538ead50..02a8e10dace3d 100644 --- a/paddle/fluid/pybind/eager_utils.h +++ b/paddle/fluid/pybind/eager_utils.h @@ -20,6 +20,7 @@ typedef SSIZE_T ssize_t; #include "paddle/fluid/eager/hooks.h" #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor_array.h" +#include "paddle/fluid/framework/string_array.h" #include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/jit/function.h" #include "paddle/fluid/platform/place.h" @@ -74,8 +75,7 @@ std::vector> CastPyArg2VectorOfVectorOfSize_t( PyObject* obj, size_t arg_pos); framework::proto::VarType::Type CastPyArg2ProtoType(PyObject* obj, ssize_t arg_pos); -std::unordered_map CastPyArg2Vocab(PyObject* obj, - ssize_t arg_pos); +paddle::framework::Vocab CastPyArg2Vocab(PyObject* obj, ssize_t arg_pos); std::vector CastPyArg2VectorOfString(PyObject* obj, ssize_t arg_pos); std::shared_ptr CastPyArg2JitFunction(PyObject* obj, @@ -116,7 +116,7 @@ PyObject* ToPyObject(const paddle::framework::proto::VarType& type); PyObject* ToPyObject(const void* value); PyObject* ToPyObject( const std::unordered_map>& value); -PyObject* ToPyObject(const std::unordered_map& value); +PyObject* ToPyObject(const paddle::framework::Vocab& value); class PyTensorHook : public egr::TensorHook { public: diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index be8355d023d25..4bbf5a33bd137 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -55,6 +55,7 @@ limitations under the License. */ #include "paddle/fluid/framework/parallel_executor.h" #include "paddle/fluid/framework/phi_utils.h" #include "paddle/fluid/framework/prune.h" +#include "paddle/fluid/framework/raw_tensor.h" #include "paddle/fluid/framework/reader.h" #include "paddle/fluid/framework/scope_pool.h" #include "paddle/fluid/framework/selected_rows_utils.h" @@ -942,14 +943,20 @@ All parameter, weight, gradient are variables in Paddle. py::return_value_policy::reference) .def("get_bytes", [](Variable &self) { - return py::bytes(*self.GetMutable()); + if (self.IsType()) { + return py::bytes(*(self.GetMutable())); + } else { + return py::bytes( + *(self.GetMutable()->GetMutable())); + } }) .def("set_string_list", [](Variable &self, Strings str_list) { *self.GetMutable() = str_list; }) .def("set_vocab", - [](Variable &self, Vocab vocab) { + [](Variable &self, + const std::unordered_map &vocab) { *self.GetMutable() = vocab; }) .def( diff --git a/paddle/phi/CMakeLists.txt b/paddle/phi/CMakeLists.txt index 9e13962bd6ac6..aec5c7632a866 100644 --- a/paddle/phi/CMakeLists.txt +++ b/paddle/phi/CMakeLists.txt @@ -38,7 +38,8 @@ set(PHI_DEPS sparse_coo_tensor string_tensor api_scalar - api_int_array) + api_int_array + extended_tensor) get_property(phi_kernels GLOBAL PROPERTY PHI_KERNELS) set(PHI_DEPS ${PHI_DEPS} ${phi_kernels}) diff --git a/paddle/phi/core/CMakeLists.txt b/paddle/phi/core/CMakeLists.txt index 92d87cf79bd70..08004673edfce 100644 --- a/paddle/phi/core/CMakeLists.txt +++ b/paddle/phi/core/CMakeLists.txt @@ -82,6 +82,11 @@ cc_library( SRCS tensor_array.cc DEPS dense_tensor tensor_base) +cc_library( + extended_tensor + SRCS extended_tensor.cc + DEPS tensor_base) + cc_library( meta_tensor SRCS meta_tensor.cc diff --git a/paddle/phi/core/extended_tensor.cc b/paddle/phi/core/extended_tensor.cc new file mode 100644 index 0000000000000..6ffbcf401224a --- /dev/null +++ b/paddle/phi/core/extended_tensor.cc @@ -0,0 +1,61 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/core/extended_tensor.h" + +namespace phi { + +int64_t ExtendedTensor::numel() const { + PADDLE_THROW(phi::errors::Unavailable( + "ExtendedTensor does not support `numel` method.")); +} + +const DDim& ExtendedTensor::dims() const { + PADDLE_THROW(phi::errors::Unavailable( + "ExtendedTensor does not support `dims` method.")); +} + +const Place& ExtendedTensor::place() const { + PADDLE_THROW(phi::errors::Unavailable( + "ExtendedTensor does not support `place` method.")); +} + +DataType ExtendedTensor::dtype() const { + PADDLE_THROW(phi::errors::Unavailable( + "ExtendedTensor does not support `dtype` method.")); +} + +DataLayout ExtendedTensor::layout() const { + PADDLE_THROW(phi::errors::Unavailable( + "ExtendedTensor does not support `dtype` method.")); +} + +bool ExtendedTensor::valid() const { + PADDLE_THROW(phi::errors::Unavailable( + "ExtendedTensor does not support `valid` method.")); +} + +bool ExtendedTensor::initialized() const { + PADDLE_THROW(phi::errors::Unavailable( + "ExtendedTensor does not support `initialized` method.")); +} + +void* ExtendedTensor::AllocateFrom(Allocator* allocator, + DataType dtype, + size_t requested_size) { + PADDLE_THROW(phi::errors::Unavailable( + "ExtendedTensor does not support `AllocateFrom` method.")); +} + +} // namespace phi diff --git a/paddle/phi/core/extended_tensor.h b/paddle/phi/core/extended_tensor.h new file mode 100644 index 0000000000000..404e1014bb328 --- /dev/null +++ b/paddle/phi/core/extended_tensor.h @@ -0,0 +1,55 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/phi/core/allocator.h" +#include "paddle/phi/core/tensor_base.h" +#include "paddle/phi/core/tensor_meta.h" + +namespace phi { + +/// \brief The ExtendedTensor is a interface for custom designed class. +/// If you want to pass some self-designed data as input/output to kernels, +/// you can inherit from this class to store your self-designed data. +class ExtendedTensor : public TensorBase { + public: + ExtendedTensor() = default; + virtual ~ExtendedTensor() = default; + + public: + /// \brief Returns the name of the class for type traits. + /// \return The name of the class. + static const char* name() { return "ExtendedTensor"; } + + int64_t numel() const override; + + const DDim& dims() const override; + + const Place& place() const override; + + DataType dtype() const override; + + DataLayout layout() const override; + + bool valid() const override; + + bool initialized() const override; + + void* AllocateFrom(Allocator* allocator, + DataType dtype, + size_t requested_size = 0) override; +}; + +} // namespace phi diff --git a/paddle/phi/core/kernel_factory.h b/paddle/phi/core/kernel_factory.h index 69baf243e68ea..89ae772f30be4 100644 --- a/paddle/phi/core/kernel_factory.h +++ b/paddle/phi/core/kernel_factory.h @@ -141,7 +141,7 @@ enum class AttributeType { INT_ARRAY, DATA_TYPE, DATA_LAYOUT, - PLACE, + PLACE }; struct AttributeArgDef { diff --git a/paddle/phi/core/kernel_registry.h b/paddle/phi/core/kernel_registry.h index 396b17dd401d5..a195f2ad60cbc 100644 --- a/paddle/phi/core/kernel_registry.h +++ b/paddle/phi/core/kernel_registry.h @@ -23,6 +23,7 @@ #include "paddle/phi/core/custom_kernel.h" #include "paddle/phi/core/enforce.h" +#include "paddle/phi/core/extended_tensor.h" #include "paddle/phi/core/kernel_factory.h" #include "paddle/phi/core/kernel_utils.h" #include "paddle/phi/core/macros.h" @@ -100,6 +101,12 @@ struct KernelArgsParseFunctor { default_tensor_layout, default_key.dtype(), arg_type); + } else if (arg_type == std::type_index(typeid( + const std::vector&))) { + args_def->AppendInput(default_key.backend(), + default_tensor_layout, + default_key.dtype(), + arg_type); } else if (arg_type == std::type_index(typeid( const std::vector&))) { args_def->AppendInput(default_key.backend(), @@ -191,6 +198,11 @@ struct KernelArgsParseFunctor { default_tensor_layout, default_key.dtype(), arg_type); + } else if (arg_type == std::type_index(typeid(ExtendedTensor*))) { + args_def->AppendOutput(default_key.backend(), + default_tensor_layout, + default_key.dtype(), + arg_type); } else if (arg_type == std::type_index(typeid(bool))) { args_def->AppendAttribute(AttributeType::BOOL); } else if (arg_type == std::type_index(typeid(int))) { diff --git a/paddle/phi/core/kernel_utils.h b/paddle/phi/core/kernel_utils.h index 05d8e259cff10..1eb3a52aebad1 100644 --- a/paddle/phi/core/kernel_utils.h +++ b/paddle/phi/core/kernel_utils.h @@ -19,6 +19,7 @@ #include "paddle/phi/common/scalar.h" #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/enforce.h" +#include "paddle/phi/core/extended_tensor.h" #include "paddle/phi/core/kernel_context.h" #include "paddle/phi/core/selected_rows.h" #include "paddle/phi/core/sparse_coo_tensor.h" @@ -264,6 +265,7 @@ struct KernelImpl { PD_SPECIALIZE_KernelCallHelper_FOR_OPTIONAL_INPUT(DenseTensor); PD_SPECIALIZE_KernelCallHelper_FOR_OPTIONAL_INPUT(SelectedRows); PD_SPECIALIZE_KernelCallHelper_FOR_MULTI_INPUT(DenseTensor); + PD_SPECIALIZE_KernelCallHelper_FOR_MULTI_INPUT(ExtendedTensor); PD_SPECIALIZE_KernelCallHelper_FOR_MULTI_INPUT(TensorBase); PD_SPECIALIZE_KernelCallHelper_FOR_MULTI_INPUT(SelectedRows); PD_SPECIALIZE_KernelCallHelper_FOR_INPUT(SelectedRows); @@ -323,6 +325,7 @@ struct KernelImpl { PD_SPECIALIZE_KernelCallHelper_FOR_MULTI_OUTPUT(StringTensor); PD_SPECIALIZE_KernelCallHelper_FOR_OUTPUT(TensorArray); + PD_SPECIALIZE_KernelCallHelper_FOR_OUTPUT(ExtendedTensor); /* End case */ template diff --git a/paddle/phi/ops/compat/save_combine_sig.cc b/paddle/phi/ops/compat/save_combine_sig.cc new file mode 100644 index 0000000000000..8c9760410b35a --- /dev/null +++ b/paddle/phi/ops/compat/save_combine_sig.cc @@ -0,0 +1,38 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/compat/op_utils.h" + +namespace phi { + +KernelSignature SaveCombineOpArgumentMapping( + const ArgumentMappingContext& ctx) { + if (ctx.IsDenseTensorInputs("X")) { + return KernelSignature( + "save_combine_tensor", + {"X"}, + {"file_path", "overwrite", "save_as_fp16", "save_to_memory"}, + {"Y"}); + } else { + return KernelSignature( + "save_combine_vocab", + {"X"}, + {"file_path", "overwrite", "save_as_fp16", "save_to_memory"}, + {"Y"}); + } +} + +} // namespace phi + +PD_REGISTER_ARG_MAPPING_FN(save_combine, phi::SaveCombineOpArgumentMapping);