diff --git a/CMakeLists.txt b/CMakeLists.txt index 4f00d1b8682243..0b6ea0be611caf 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -599,6 +599,9 @@ if(WITH_PROFILER) endif() include_directories("${PADDLE_SOURCE_DIR}") +include_directories("${PADDLE_SOURCE_DIR}/paddle/phi/api/include/compat/") +include_directories( + "${PADDLE_SOURCE_DIR}/paddle/phi/api/include/compat/torch/csrc/api/include/") if(WITH_NV_JETSON) set(WITH_ARM diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 231d47dab14444..a4898d76fed9ee 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -80,6 +80,7 @@ limitations under the License. */ #include "paddle/fluid/imperative/amp_auto_cast.h" #include "paddle/fluid/imperative/layer.h" #include "paddle/fluid/prim/utils/utils.h" +#include "paddle/fluid/pybind/torch_compat.h" #include "paddle/phi/common/bfloat16.h" #include "paddle/phi/common/float16.h" #include "paddle/phi/common/int_array.h" @@ -4139,6 +4140,9 @@ All parameter, weight, gradient are variables in Paddle. BindVjp(&m); BindDecompRule(&m); BindDecompVjp(&m); + py::module torch_compat = m.def_submodule( + "torch_compat", "Compatibility layer for PyTorch-like APIs"); + BindTorchCompat(&torch_compat); #ifdef PADDLE_WITH_DISTRIBUTE BindDistApi(&m); #endif diff --git a/paddle/fluid/pybind/torch_compat.h b/paddle/fluid/pybind/torch_compat.h new file mode 100644 index 00000000000000..7466edf9451226 --- /dev/null +++ b/paddle/fluid/pybind/torch_compat.h @@ -0,0 +1,380 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
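+
+// [Editor's note] A brief, hedged sketch of how this binding is meant to be
+// reached from Python. `BindTorchCompat` below is mounted as the
+// `torch_compat` submodule of the core pybind module (see the pybind.cc hunk
+// above); the exact import path is an assumption and depends on how the core
+// module is packaged (typically paddle.base.core):
+//
+//   core = paddle.base.core                        # assumed import path
+//   relu = core.torch_compat._get_operation("my_ops::relu")
+//   out = relu(x)                                  # x: paddle.Tensor
+//
+// "my_ops::relu" is a hypothetical operator name; it must already be present
+// in the compat OperatorRegistry (only the CPU implementation is dispatched,
+// see get_op_with_args below).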
+ +#pragma once +#include +#include +#include + +#include "paddle/common/exception.h" +#include "paddle/fluid/pybind/eager_utils.h" +#include "paddle/fluid/pybind/op_function_common.h" +#include "paddle/phi/api/include/compat/utils/scalar_type_conversion.h" +#include "paddle/utils/pybind.h" + +namespace py = pybind11; + +namespace torch { + +class OperationInvoker { + public: + static py::object invoke_operator_from_python( + const std::string& qualified_name, + const py::args& args, + const py::kwargs& kwargs); + + static std::pair get_op_with_args( + const std::string& qualified_name, + const py::args& args, + const py::kwargs& kwargs); + + static py::object to_py_object(const torch::IValue& value); + + static torch::IValue to_ivalue(py::handle obj); + + static py::object create_python_callable(const std::string& qualified_name); + + static FunctionArgs convert_args_kwargs_to_function_args( + const py::args& args, const py::kwargs& kwargs); + + static py::object convert_result_to_python(const FunctionResult& result); +}; + +inline py::object OperationInvoker::invoke_operator_from_python( + const std::string& qualified_name, + const py::args& args, + const py::kwargs& kwargs) { + try { + auto [found_op, function_args] = + get_op_with_args(qualified_name, args, kwargs); + + FunctionResult result; + { + py::gil_scoped_release no_gil_guard; + result = found_op->call_with_args(function_args); + } + + return convert_result_to_python(result); + } catch (const std::exception& e) { + PADDLE_THROW(common::errors::PreconditionNotMet( + "Error in operator '%s': %s", qualified_name.c_str(), e.what())); + } +} + +inline std::pair +OperationInvoker::get_op_with_args(const std::string& qualified_name, + const py::args& args, + const py::kwargs& kwargs) { + auto* op = OperatorRegistry::instance().find_operator(qualified_name); + if (!op) { + PADDLE_THROW(common::errors::NotFound( + "Operator '%s' not found in the registry", qualified_name.c_str())); + } + + auto impl_it = op->implementations.find(DispatchKey::CPU); + if (impl_it == op->implementations.end()) { + PADDLE_THROW(common::errors::NotFound( + "No CPU implementation found for operator '%s'", + qualified_name.c_str())); + } + + FunctionArgs function_args = + convert_args_kwargs_to_function_args(args, kwargs); + + return std::make_pair(&impl_it->second, std::move(function_args)); +} + +inline py::object OperationInvoker::to_py_object(const torch::IValue& value) { + if (value.is_none()) { + return py::none(); + } else if (value.is_bool()) { + return py::cast(value.to_bool()); + } else if (value.is_int()) { + return py::cast(value.to_int()); + } else if (value.is_double()) { + return py::cast(value.to_double()); + } else if (value.is_string()) { + return py::cast(value.to_string()); + } else if (value.is_tensor()) { + return py::reinterpret_borrow( + paddle::pybind::ToPyObject(value.to_tensor()._PD_GetInner())); + } else { + PADDLE_THROW(common::errors::Unimplemented( + "Conversion of torch::IValue to Python object for this type is not " + "implemented yet.")); + } +} + +inline torch::IValue OperationInvoker::to_ivalue(py::handle obj) { + if (obj.is_none()) { + return torch::IValue(); // None + } else if (py::isinstance(obj)) { + return torch::IValue(py::cast(obj)); + } else if (py::isinstance(obj)) { + return torch::IValue(py::cast(obj)); + } else if (py::isinstance(obj)) { + return torch::IValue(py::cast(obj)); + } else if (py::isinstance(obj)) { + return torch::IValue(py::cast(obj)); + } else if (paddle::pybind::PyCheckTensor(obj.ptr())) { + 
return torch::IValue(paddle::pybind::CastPyArg2Tensor(obj.ptr(), 0)); + } else if (paddle::pybind::PyObject_CheckDataType(obj.ptr())) { + return torch::IValue(compat::_PD_PhiDataTypeToAtenScalarType( + paddle::pybind::CastPyArg2DataType(obj.ptr(), "to_ivalue", 0))); + } else if (py::isinstance(obj)) { + auto list = obj.cast(); + std::vector ivalue_list; + ivalue_list.reserve(list.size()); + for (auto item : list) { + ivalue_list.push_back(to_ivalue(item)); + } + return torch::IValue(ivalue_list); + } else { + try { + auto val = py::cast(obj); + return torch::IValue(val); + } catch (...) { + try { + auto val = py::cast(obj); + return torch::IValue(val); + } catch (...) { + try { + auto val = py::cast(obj); + return torch::IValue(val); + } catch (...) { + PADDLE_THROW(common::errors::Unimplemented( + "Conversion of Python object to torch::IValue for type %s is not " + "implemented yet.", + std::string(py::str(py::type::of(obj))).c_str())); + } + } + } + } +} + +inline FunctionArgs OperationInvoker::convert_args_kwargs_to_function_args( + const py::args& args, const py::kwargs& kwargs) { + FunctionArgs function_args; + + for (const auto& arg : args) { + torch::IValue value = to_ivalue(arg); + function_args.add_arg(std::move(value)); + } + + for (auto item : kwargs) { + py::str key = item.first.cast(); + py::object value_obj = item.second.cast(); + + torch::IValue value = to_ivalue(value_obj); + function_args.add_arg(std::move(value)); + } + + return function_args; +} + +inline py::object OperationInvoker::convert_result_to_python( + const FunctionResult& result) { + if (!result.has_value()) { + return py::none(); + } + + const torch::IValue& value = result.get_value(); + return to_py_object(value); +} + +inline py::object OperationInvoker::create_python_callable( + const std::string& qualified_name) { + return py::cpp_function( + [qualified_name](py::args args, py::kwargs kwargs) -> py::object { + return invoke_operator_from_python(qualified_name, args, kwargs); + }, + py::name(qualified_name.c_str()), + py::is_method(py::none())); +} + +class CustomClassProxyInstance { + public: + CustomClassProxyInstance(const std::string& qualified_name, + const IValue& instance) + : qualified_name_(qualified_name), instance_(instance) {} + + // Get instance method + py::object __getattr__(const std::string& method_name) { + if (ClassRegistry::instance().has_method(qualified_name_, method_name)) { + return py::cpp_function( + [this, method_name](py::args args, py::kwargs kwargs) -> py::object { + FunctionArgs function_args; + function_args.add_arg(instance_); // this pointer + for (auto arg : + OperationInvoker::convert_args_kwargs_to_function_args( + args, kwargs)) { + function_args.add_arg(std::move(arg)); + } + + auto result = ClassRegistry::instance().call_method_with_args( + qualified_name_, method_name, function_args); + + return OperationInvoker::convert_result_to_python(result); + }, + py::name(method_name.c_str())); + } + + PADDLE_THROW(common::errors::NotFound("Method '%s' not found in class %s", + method_name.c_str(), + qualified_name_.c_str())); + } + + const IValue& get_instance() const { return instance_; } + + private: + std::string qualified_name_; + IValue instance_; +}; + +class CustomClassProxy { + public: + CustomClassProxy(const std::string& qualified_name) // NOLINT + : qualified_name_(qualified_name) {} + + // Create a new instance of the class + py::object __call__(const py::args& args, const py::kwargs& kwargs) { + try { + FunctionArgs function_args = + 
OperationInvoker::convert_args_kwargs_to_function_args(args, kwargs); + + // Call the constructor + auto result = ClassRegistry::instance().call_constructor_with_args( + qualified_name_, function_args); + + // Wrap the result in a CustomClassProxyInstance + if (result.has_value()) { + const IValue& value = result.get_value(); + // Create proxy object for the custom class instance + return py::cast(CustomClassProxyInstance(qualified_name_, value)); + } else { + PADDLE_THROW(common::errors::PreconditionNotMet( + "Constructor did not return an instance")); + } + } catch (const std::exception& e) { + PADDLE_THROW(common::errors::PreconditionNotMet( + "Failed to construct %s: %s", qualified_name_.c_str(), e.what())); + } + } + + // Get static method + py::object __getattr__(const std::string& method_name) { + // Check if the method name is a dunder method + if (method_name.size() >= 2 && method_name.substr(0, 2) == "__") { + PADDLE_THROW(common::errors::InvalidArgument( + "Dunder methods are not supported: %s", method_name.c_str())); + } + + // Check if the class has the static method + if (ClassRegistry::instance().has_static_method(qualified_name_, + method_name)) { + return py::cpp_function( + [this, method_name](py::args args, py::kwargs kwargs) -> py::object { + // Convert args and kwargs to FunctionArgs + FunctionArgs function_args = + OperationInvoker::convert_args_kwargs_to_function_args(args, + kwargs); + + // Call the static method + auto result = + ClassRegistry::instance().call_static_method_with_args( + qualified_name_, method_name, function_args); + + return OperationInvoker::convert_result_to_python(result); + }, + py::name(method_name.c_str())); + } + + PADDLE_THROW( + common::errors::NotFound("Static method '%s' not found in class %s", + method_name.c_str(), + qualified_name_.c_str())); + } + + private: + std::string qualified_name_; +}; + +inline py::object get_custom_class_python_wrapper( + const std::string& namespace_name, const std::string& class_name) { + std::string qualified_name = namespace_name + "::" + class_name; + + if (!ClassRegistry::instance().has_class(qualified_name)) { + PADDLE_THROW(common::errors::NotFound( + "Class '%s' not found in the registry", qualified_name.c_str())); + } + + return py::cast(CustomClassProxy(qualified_name)); +} + +inline py::object get_operation(const std::string& qualified_name) { + return OperationInvoker::create_python_callable(qualified_name); +} +} // namespace torch + +namespace paddle::pybind { + +void BindTorchCompat(pybind11::module* m) { + py::class_(*m, "IValue") + .def(py::init<>()) + .def(py::init()) + .def(py::init()) + .def(py::init()) + .def(py::init()) + .def("is_none", &torch::IValue::is_none) + .def("is_int", &torch::IValue::is_int) + .def("is_double", &torch::IValue::is_double) + .def("is_bool", &torch::IValue::is_bool) + .def("is_string", &torch::IValue::is_string) + .def("to_int", &torch::IValue::to_int) + .def("to_double", &torch::IValue::to_double) + .def("to_bool", &torch::IValue::to_bool) + .def("to_string", &torch::IValue::to_string) + .def("__repr__", [](const torch::IValue& v) { + if (v.is_none()) return std::string("IValue(None)"); + if (v.is_int()) + return std::string("IValue(") + std::to_string(v.to_int()) + ")"; + if (v.is_double()) + return std::string("IValue(") + std::to_string(v.to_double()) + ")"; + if (v.is_bool()) + return std::string("IValue(") + (v.to_bool() ? 
"True" : "False") + + ")"; + if (v.is_string()) + return std::string("IValue(\"") + v.to_string() + "\")"; + return std::string("IValue(unknown)"); + }); + + py::class_(*m, "CustomClassProxy") + .def("__call__", &torch::CustomClassProxy::__call__) + .def("__getattr__", &torch::CustomClassProxy::__getattr__); + + py::class_(*m, "CustomClassProxyInstance") + .def("__getattr__", &torch::CustomClassProxyInstance::__getattr__); + + m->def("_get_operation", + &torch::get_operation, + "Get a callable for the specified operation", + py::arg("qualified_name")); + + m->def("_get_custom_class_python_wrapper", + &torch::get_custom_class_python_wrapper, + "Get a Python wrapper for the specified custom class", + py::arg("namespace_name"), + py::arg("class_name")); +} +} // namespace paddle::pybind diff --git a/paddle/phi/api/CMakeLists.txt b/paddle/phi/api/CMakeLists.txt index 1827dfbeb7f642..a3984ec1fc33bc 100644 --- a/paddle/phi/api/CMakeLists.txt +++ b/paddle/phi/api/CMakeLists.txt @@ -1,2 +1,3 @@ add_subdirectory(profiler) add_subdirectory(lib) +add_subdirectory(include/compat) diff --git a/paddle/phi/api/include/compat/ATen/ATen.h b/paddle/phi/api/include/compat/ATen/ATen.h new file mode 100644 index 00000000000000..18e9d2c9d62458 --- /dev/null +++ b/paddle/phi/api/include/compat/ATen/ATen.h @@ -0,0 +1,32 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) +#include +#include +#include +#endif diff --git a/paddle/phi/api/include/compat/ATen/AccumulateType.cpp b/paddle/phi/api/include/compat/ATen/AccumulateType.cpp new file mode 100644 index 00000000000000..174eac6a8a6b6f --- /dev/null +++ b/paddle/phi/api/include/compat/ATen/AccumulateType.cpp @@ -0,0 +1,49 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+// #The file has been adapted from pytorch project
+// #Licensed under BSD-style license -
+// https://github.com/pytorch/pytorch/blob/main/LICENSE
+
+#include <ATen/AccumulateType.h>
+
+namespace at {
+
+c10::ScalarType toAccumulateType(c10::ScalarType type, c10::DeviceType device) {
+  switch (type) {
+#define DEFINE_CASE(scalar_t, TypeNum)                                    \
+  case ScalarType::TypeNum:                                               \
+    switch (device) {                                                     \
+      case DeviceType::CUDA:                                              \
+        return CppTypeToScalarType<                                       \
+            at::acc_type_device<scalar_t, c10::DeviceType::CUDA>>::value; \
+      default:                                                            \
+        return CppTypeToScalarType<                                       \
+            at::acc_type_device<scalar_t, c10::DeviceType::CPU>>::value;  \
+    }
+
+    AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF_F8NZ(DEFINE_CASE)
+#undef DEFINE_CASE
+
+    default:
+      TORCH_INTERNAL_ASSERT(false, "Unrecognized ScalarType: ", type);
+  }
+}
+
+c10::ScalarType toAccumulateType(c10::ScalarType type, bool is_cuda) {
+  return is_cuda ? toAccumulateType(type, c10::DeviceType::CUDA)
+                 : toAccumulateType(type, c10::DeviceType::CPU);
+}
+
+}  // namespace at
diff --git a/paddle/phi/api/include/compat/ATen/AccumulateType.h b/paddle/phi/api/include/compat/ATen/AccumulateType.h
new file mode 100644
index 00000000000000..29b7bf33adcb69
--- /dev/null
+++ b/paddle/phi/api/include/compat/ATen/AccumulateType.h
@@ -0,0 +1,115 @@
+// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
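+
+// [Editor's note] A minimal usage sketch of the accumulate-type utilities
+// declared in this header, derived from the mapping tables below (Half and
+// BFloat16 accumulate in float on CUDA; float accumulates in double on CPU):
+//
+//   using cuda_acc_t = at::acc_type<at::Half, /*is_cuda=*/true>;   // float
+//   using cpu_acc_t  = at::acc_type<float, /*is_cuda=*/false>;     // double
+//   auto st = at::toAccumulateType(at::ScalarType::Half, /*is_cuda=*/true);
+//   // st == at::ScalarType::Float
+//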
+ +// #The file has been adapted from pytorch project +// #Licensed under BSD-style license - +// https://github.com/pytorch/pytorch/blob/main/LICENSE + +#pragma once + +#include +#include +#include +#include +// #include +#include +// #include +#include + +#if defined(__CUDACC__) +#include +#include +#elif defined(__HIPCC__) +#include +#include +#endif + +namespace at { + +template +struct AccumulateTypeDevice {}; + +template +struct AccumulateType {}; + +template +struct AccumulateType { + using type = typename AccumulateTypeDevice::type; +}; + +template +struct AccumulateType { + using type = typename AccumulateTypeDevice::type; +}; + +template +using acc_type_device = typename AccumulateTypeDevice::type; + +template +using acc_type = typename AccumulateType::type; + +#define ACC_TYPE(t, acc_t, device_type) \ + template <> \ + struct AccumulateTypeDevice { \ + using type = acc_t; \ + }; + +#define CUDA_ACC_TYPE(t, acc_t) ACC_TYPE(t, acc_t, c10::DeviceType::CUDA) +#define CPU_ACC_TYPE(t, acc_t) ACC_TYPE(t, acc_t, c10::DeviceType::CPU) + +#if defined(__CUDACC__) || defined(__HIPCC__) +CUDA_ACC_TYPE(half, float) +#endif +CUDA_ACC_TYPE(BFloat16, float) +CUDA_ACC_TYPE(Half, float) +CUDA_ACC_TYPE(Float8_e5m2, float) +CUDA_ACC_TYPE(Float8_e4m3fn, float) +// CUDA_ACC_TYPE(Float8_e5m2fnuz, float) +// CUDA_ACC_TYPE(Float8_e4m3fnuz, float) +CUDA_ACC_TYPE(float, float) +CUDA_ACC_TYPE(double, double) +CUDA_ACC_TYPE(int8_t, int64_t) +CUDA_ACC_TYPE(uint8_t, int64_t) +CUDA_ACC_TYPE(char, int64_t) +CUDA_ACC_TYPE(int16_t, int64_t) +CUDA_ACC_TYPE(int32_t, int64_t) +CUDA_ACC_TYPE(int64_t, int64_t) +CUDA_ACC_TYPE(bool, bool) +CUDA_ACC_TYPE(c10::complex, c10::complex) +CUDA_ACC_TYPE(c10::complex, c10::complex) +CUDA_ACC_TYPE(c10::complex, c10::complex) + +CPU_ACC_TYPE(BFloat16, float) +CPU_ACC_TYPE(Half, float) +CPU_ACC_TYPE(Float8_e5m2, float) +CPU_ACC_TYPE(Float8_e4m3fn, float) +// CPU_ACC_TYPE(Float8_e5m2fnuz, float) +// CPU_ACC_TYPE(Float8_e4m3fnuz, float) +CPU_ACC_TYPE(float, double) +CPU_ACC_TYPE(double, double) +CPU_ACC_TYPE(int8_t, int64_t) +CPU_ACC_TYPE(uint8_t, int64_t) +CPU_ACC_TYPE(char, int64_t) +CPU_ACC_TYPE(int16_t, int64_t) +CPU_ACC_TYPE(int32_t, int64_t) +CPU_ACC_TYPE(int64_t, int64_t) +CPU_ACC_TYPE(bool, bool) +CPU_ACC_TYPE(c10::complex, c10::complex) +CPU_ACC_TYPE(c10::complex, c10::complex) +CPU_ACC_TYPE(c10::complex, c10::complex) + +c10::ScalarType toAccumulateType(c10::ScalarType type, c10::DeviceType device); +c10::ScalarType toAccumulateType(c10::ScalarType type, bool is_cuda); + +} // namespace at diff --git a/paddle/phi/api/include/compat/ATen/Device.h b/paddle/phi/api/include/compat/ATen/Device.h new file mode 100644 index 00000000000000..7970c1ba5f22a4 --- /dev/null +++ b/paddle/phi/api/include/compat/ATen/Device.h @@ -0,0 +1,16 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
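+
+// [Editor's note] This header only forwards to the compat c10::Device
+// (c10/core/Device.h). Elsewhere in this layer a Device is built from a
+// paddle place, e.g. (as in at::TensorBase::device() / get_device()):
+//
+//   c10::Device d(tensor.place());   // wrap a phi::Place
+//   auto idx = d.index();            // its DeviceIndex
+//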
+ +#pragma once +#include diff --git a/paddle/phi/api/include/compat/ATen/DeviceGuard.h b/paddle/phi/api/include/compat/ATen/DeviceGuard.h new file mode 100644 index 00000000000000..78d8d1b9470250 --- /dev/null +++ b/paddle/phi/api/include/compat/ATen/DeviceGuard.h @@ -0,0 +1,35 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include + +namespace at { + +inline std::optional device_of(const Tensor& t) { + if (t.defined()) { + return t.device(); + } else { + return std::nullopt; + } +} + +inline std::optional device_of(const std::optional& t) { + return t.has_value() ? device_of(t.value()) : std::nullopt; +} + +} // namespace at diff --git a/paddle/phi/api/include/compat/ATen/Functions.h b/paddle/phi/api/include/compat/ATen/Functions.h new file mode 100644 index 00000000000000..5f77150510e750 --- /dev/null +++ b/paddle/phi/api/include/compat/ATen/Functions.h @@ -0,0 +1,27 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include diff --git a/paddle/phi/api/include/compat/ATen/Tensor.h b/paddle/phi/api/include/compat/ATen/Tensor.h new file mode 100644 index 00000000000000..aaaa6501cd0b09 --- /dev/null +++ b/paddle/phi/api/include/compat/ATen/Tensor.h @@ -0,0 +1,17 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include diff --git a/paddle/phi/api/include/compat/ATen/Utils.h b/paddle/phi/api/include/compat/ATen/Utils.h new file mode 100644 index 00000000000000..30a417cd6f61ec --- /dev/null +++ b/paddle/phi/api/include/compat/ATen/Utils.h @@ -0,0 +1,23 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include + +#include diff --git a/paddle/phi/api/include/compat/ATen/core/Scalar.h b/paddle/phi/api/include/compat/ATen/core/Scalar.h new file mode 100644 index 00000000000000..3136613467502e --- /dev/null +++ b/paddle/phi/api/include/compat/ATen/core/Scalar.h @@ -0,0 +1,15 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once +#include diff --git a/paddle/phi/api/include/compat/ATen/core/Tensor.h b/paddle/phi/api/include/compat/ATen/core/Tensor.h new file mode 100644 index 00000000000000..fc8587c08078d1 --- /dev/null +++ b/paddle/phi/api/include/compat/ATen/core/Tensor.h @@ -0,0 +1,17 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include diff --git a/paddle/phi/api/include/compat/ATen/core/TensorBase.h b/paddle/phi/api/include/compat/ATen/core/TensorBase.h new file mode 100644 index 00000000000000..18949c2909bae4 --- /dev/null +++ b/paddle/phi/api/include/compat/ATen/core/TensorBase.h @@ -0,0 +1,176 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
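+
+// [Editor's note] at::TensorBase below is a thin by-value wrapper over
+// paddle::Tensor exposing a subset of the ATen API. A minimal sketch,
+// assuming `pd_t` is an existing dense paddle::Tensor:
+//
+//   at::TensorBase t(pd_t);             // wraps; tensor data is not copied
+//   float* p = t.data_ptr<float>();     // typed pointer (checked overloads
+//                                       // are instantiated in TensorMethods.cpp)
+//   auto sizes = t.sizes();             // c10::IntArrayRef view of dims()
+//   paddle::Tensor back = t._PD_GetInner();  // recover the wrapped tensor
+//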
+ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include "paddle/common/layout.h" +#include "paddle/phi/api/include/api.h" +#include "paddle/phi/api/include/tensor.h" +#include "paddle/phi/common/place.h" + +namespace at { +using PaddleTensor = paddle::Tensor; + +class PADDLE_API TensorBase { + public: + TensorBase() = default; + TensorBase(const PaddleTensor& tensor) : tensor_(tensor){}; // NOLINT + + void* data_ptr() const { return const_cast(tensor_.data()); } + template + T* data_ptr() const { + return const_cast(tensor_.data()); + } + + const void* const_data_ptr() const { + return const_cast(tensor_.data()); + } + + template , int> = 0> + const T* const_data_ptr() const; + + template , int> = 0> + const std::remove_const_t* const_data_ptr() const; + + void* mutable_data_ptr() const { return const_cast(tensor_.data()); } + + template + T* mutable_data_ptr() const; + + int64_t stride(int64_t dim) const { + if (dim < 0) { + dim += tensor_.strides().size(); + } + return tensor_.strides()[static_cast(dim)]; + } + c10::IntArrayRef strides() const { + return compat::_PD_PhiDDimToIntArrayRef(tensor_.strides()); + } + + int64_t size(int64_t dim) const { + return tensor_.dims()[static_cast(dim)]; + } + + c10::IntArrayRef sizes() const { + return compat::_PD_PhiDDimToIntArrayRef(tensor_.dims()); + } + + int64_t numel() const { return tensor_.numel(); } + + c10::ScalarType dtype() const { // Should we use `TypeMeta` here? + return compat::_PD_PhiDataTypeToAtenScalarType(tensor_.dtype()); + } + + c10::Device device() const { return c10::Device(tensor_.place()); } + c10::DeviceIndex get_device() const { + return c10::Device(tensor_.place()).index(); + } + + int64_t dim() const { return tensor_.dims().size(); } + int64_t ndimension() const { return dim(); } + + at::TensorBase contiguous( + c10::MemoryFormat memory_format = c10::MemoryFormat::Contiguous) const { + PD_CHECK(memory_format == c10::MemoryFormat::Contiguous, + "`MemoryFormat` other than Contiguous"); + + return tensor_.contiguous(); + } + + bool is_contiguous( + at::MemoryFormat memory_format = at::MemoryFormat::Contiguous) const { + PD_CHECK(memory_format == c10::MemoryFormat::Contiguous, + "`MemoryFormat` other than Contiguous"); + + return tensor_.is_contiguous(); + } + + c10::ScalarType scalar_type() const { + return compat::_PD_PhiDataTypeToAtenScalarType(tensor_.dtype()); + } + + c10::TensorOptions options() const { + // TODO(SigureMo): Implement layout + return c10::TensorOptions().dtype(dtype()).device(device()); + } + + const TensorBase& fill_(const at::Scalar& scalar) const { + paddle::experimental::fill_(const_cast(tensor_), scalar); + return *this; + } + + const TensorBase& zero_() const { + paddle::experimental::fill_(const_cast(tensor_), 0.0); + return *this; + } + + bool is_cpu() const { return phi::is_cpu_place(tensor_.place()); } + bool is_cuda() const { return phi::is_gpu_place(tensor_.place()); } + + at::TensorBase reshape(at::IntArrayRef shape) const { + return TensorBase( + paddle::experimental::reshape(tensor_, shape._PD_ToPaddleIntArray())); + } + + at::TensorBase& copy_(const at::TensorBase& src, + bool non_blocking = false) const { + const_cast(tensor_).copy_( + src._PD_GetInner(), tensor_.place(), /*blocking=*/!non_blocking); + return const_cast(*this); + } + + at::TensorBase view(at::IntArrayRef size) const { + return TensorBase(paddle::experimental::view_shape(tensor_, size.vec())); + } + + at::TensorBase view(at::ScalarType dtype) const { + return 
TensorBase(paddle::experimental::view_dtype( + tensor_, compat::_PD_AtenScalarTypeToPhiDataType(dtype))); + } + + inline size_t nbytes() const { + PD_CHECK( + ((tensor_.layout() != common::DataLayout::SPARSE_COO) && + (tensor_.layout() != common::DataLayout::SPARSE_CSR)), + "nbytes is not defined for sparse tensors. If you want the size of " + "the constituent " + "tensors, add the nbytes of the indices and values. If you want the " + "size of the " + "equivalent dense tensor, multiply numel() by element_size()"); + return tensor_.numel() * SizeOf(tensor_.dtype()); + } + + size_t itemsize() const { return SizeOf(tensor_.dtype()); } + + int64_t element_size() const { + return static_cast(SizeOf(tensor_.dtype())); + } + + bool defined() const { return tensor_.defined(); } + + PaddleTensor _PD_GetInner() const { return tensor_; } + PaddleTensor& _PD_GetInner() { return tensor_; } + + protected: + PaddleTensor tensor_; +}; + +} // namespace at diff --git a/paddle/phi/api/include/compat/ATen/core/TensorBody.h b/paddle/phi/api/include/compat/ATen/core/TensorBody.h new file mode 100644 index 00000000000000..9db93db832f497 --- /dev/null +++ b/paddle/phi/api/include/compat/ATen/core/TensorBody.h @@ -0,0 +1,175 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include "paddle/phi/api/include/tensor.h" + +namespace at { +using PaddleTensor = paddle::Tensor; + +class Tensor : public TensorBase { + public: + Tensor() = default; + Tensor(const PaddleTensor& tensor) : TensorBase(tensor){}; // NOLINT + + void* data_ptr() const { return const_cast(tensor_.data()); } + template + T* data_ptr() const { + return const_cast(tensor_.data()); + } + + const void* const_data_ptr() const { + return const_cast(tensor_.data()); + } + + template , int> = 0> + const T* const_data_ptr() const; + + template , int> = 0> + const std::remove_const_t* const_data_ptr() const; + + void* mutable_data_ptr() const { return const_cast(tensor_.data()); } + + template + T* mutable_data_ptr() const; + + using TensorBase::stride; + + c10::IntArrayRef strides() const { + return compat::_PD_PhiDDimToIntArrayRef(tensor_.strides()); + } + + using TensorBase::size; + // int64_t size(int64_t dim) const { + // return tensor_.dims()[static_cast(dim)]; + // } + + c10::IntArrayRef sizes() const { + return compat::_PD_PhiDDimToIntArrayRef(tensor_.dims()); + } + + Tensor toType(ScalarType t) const { + return Tensor(paddle::experimental::cast( + tensor_, compat::_PD_AtenScalarTypeToPhiDataType(t))); + } + + int64_t numel() const { return tensor_.numel(); } + + c10::ScalarType dtype() const { // Should we use `TypeMeta` here? 
+ return compat::_PD_PhiDataTypeToAtenScalarType(tensor_.dtype()); + } + + c10::Device device() const { return c10::Device(tensor_.place()); } + c10::DeviceIndex get_device() const { + return c10::Device(tensor_.place()).index(); + } + + int64_t dim() const { return tensor_.dims().size(); } + int64_t ndimension() const { return dim(); } + + at::Tensor contiguous( + c10::MemoryFormat memory_format = c10::MemoryFormat::Contiguous) const { + PD_CHECK(memory_format == c10::MemoryFormat::Contiguous, + "`MemoryFormat` other than Contiguous"); + + return tensor_.contiguous(); + } + + bool is_contiguous( + at::MemoryFormat memory_format = at::MemoryFormat::Contiguous) const { + PD_CHECK(memory_format == c10::MemoryFormat::Contiguous, + "`MemoryFormat` other than Contiguous"); + + return tensor_.is_contiguous(); + } + + c10::ScalarType scalar_type() const { + return compat::_PD_PhiDataTypeToAtenScalarType(tensor_.dtype()); + } + + const Tensor& fill_(const at::Scalar& scalar) const { + paddle::experimental::fill_(const_cast(tensor_), scalar); + return *this; + } + + const Tensor& zero_() const { + paddle::experimental::fill_(const_cast(tensor_), 0.0); + return *this; + } + + bool is_cpu() const { return phi::is_cpu_place(tensor_.place()); } + bool is_cuda() const { return phi::is_gpu_place(tensor_.place()); } + + at::Tensor reshape(at::IntArrayRef shape) const { + return Tensor( + paddle::experimental::reshape(tensor_, shape._PD_ToPaddleIntArray())); + } + + at::Tensor transpose(int64_t dim0, int64_t dim1) const { + return Tensor(paddle::experimental::transpose( + tensor_, {static_cast(dim0), static_cast(dim1)})); + } + + at::Tensor& copy_(const at::Tensor& src, bool non_blocking = false) const { + const_cast(tensor_).copy_( + src._PD_GetInner(), tensor_.place(), /*blocking=*/!non_blocking); + return const_cast(*this); + } + + at::Tensor view(at::IntArrayRef size) const { + return Tensor(paddle::experimental::view_shape(tensor_, size.vec())); + } + + at::Tensor view(at::ScalarType dtype) const { + return Tensor(paddle::experimental::view_dtype( + tensor_, compat::_PD_AtenScalarTypeToPhiDataType(dtype))); + } + + // Paddle Tensor has no storage_offset, so we add it here, and it is always + // 0. + // int64_t storage_offset() const { return storage_offset_; } + + inline size_t nbytes() const { + PD_CHECK( + ((tensor_.layout() != common::DataLayout::SPARSE_COO) && + (tensor_.layout() != common::DataLayout::SPARSE_CSR)), + "nbytes is not defined for sparse tensors. If you want the size of " + "the constituent " + "tensors, add the nbytes of the indices and values. 
If you want the " + "size of the " + "equivalent dense tensor, multiply numel() by element_size()"); + return tensor_.numel() * SizeOf(tensor_.dtype()); + } + + size_t itemsize() const { return SizeOf(tensor_.dtype()); } + + int64_t element_size() const { + return static_cast(SizeOf(tensor_.dtype())); + } + + inline Tensor clone() const { + PaddleTensor cloned_tensor = paddle::experimental::assign(tensor_); + return Tensor(cloned_tensor); + } + + PaddleTensor _PD_GetInner() const { return tensor_; } + PaddleTensor& _PD_GetInner() { return tensor_; } +}; + +} // namespace at +namespace torch { +using at::Tensor; +} // namespace torch diff --git a/paddle/phi/api/include/compat/ATen/core/TensorMethods.cpp b/paddle/phi/api/include/compat/ATen/core/TensorMethods.cpp new file mode 100644 index 00000000000000..b452493b22aa3d --- /dev/null +++ b/paddle/phi/api/include/compat/ATen/core/TensorMethods.cpp @@ -0,0 +1,66 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// #The file has been adapted from pytorch project +// #Licensed under BSD-style license - +// https://github.com/pytorch/pytorch/blob/main/LICENSE + +#include +#include +#include + +namespace at { + +void check_type(const TensorBase& tensor, + ScalarType type, + std::string_view type_name) { + PD_CHECK(tensor.scalar_type() == type, + "expected scalar type ", + type_name, + " but found ", + compat::_PD_AtenScalarTypeToPhiDataType(tensor.scalar_type())); +} + +#define DEFINE_CAST(T, name) \ + template <> \ + PADDLE_API const T* TensorBase::const_data_ptr() const { \ + check_type(*this, ScalarType::name, #name); \ + return const_cast(tensor_.data()); \ + } \ + \ + template <> \ + PADDLE_API const T* TensorBase::const_data_ptr() const { \ + check_type(*this, ScalarType::name, #name); \ + return const_cast(tensor_.data>()); \ + } \ + \ + template <> \ + PADDLE_API T* TensorBase::mutable_data_ptr() const { \ + check_type(*this, ScalarType::name, #name); \ + return const_cast(tensor_).mutable_data(); \ + } \ + \ + template <> \ + PADDLE_API T* TensorBase::data_ptr() const { \ + return const_cast(tensor_.data()); \ + } + +AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_CAST) // missing half and float16 +// AT_FORALL_QINT_TYPES(DEFINE_CAST) // missing qint +DEFINE_CAST(uint16_t, UInt16) +DEFINE_CAST(uint32_t, UInt32) +DEFINE_CAST(uint64_t, UInt64) +#undef DEFINE_CAST + +} // namespace at diff --git a/paddle/phi/api/include/compat/ATen/core/ivalue.h b/paddle/phi/api/include/compat/ATen/core/ivalue.h new file mode 100644 index 00000000000000..4e161cdc5060ca --- /dev/null +++ b/paddle/phi/api/include/compat/ATen/core/ivalue.h @@ -0,0 +1,583 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// #The file has been adapted from pytorch project +// #Licensed under BSD-style license - +// https://github.com/pytorch/pytorch/blob/main/LICENSE + +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace torch { + +class CustomClassHolder { + public: + virtual ~CustomClassHolder() = default; +}; + +template +class intrusive_ptr { + public: + using element_type = T; + using pointer = T*; + + intrusive_ptr() : ptr_(nullptr) {} + intrusive_ptr(T* ptr) : ptr_(std::shared_ptr(ptr)) {} // NOLINT + intrusive_ptr(std::shared_ptr ptr) : ptr_(ptr) {} // NOLINT + + template + static intrusive_ptr make(Args&&... args) { + return intrusive_ptr(std::make_shared(std::forward(args)...)); + } + + T* get() const { return ptr_.get(); } + T& operator*() const { return *ptr_; } + T* operator->() const { return ptr_.get(); } + + // For IValue + std::shared_ptr get_shared() const { return ptr_; } + + explicit operator bool() const { return ptr_ != nullptr; } + + private: + std::shared_ptr ptr_; +}; + +template +intrusive_ptr make_intrusive(Args&&... args) { + return intrusive_ptr::make(std::forward(args)...); +} + +template +struct _fake_type {}; + +enum class TypeTag { + None = 0, + Bool, + Int, + Double, + String, + Tensor, + GenericList, + CustomClass, + Tuple +}; + +class IValue; // Forward declaration + +// Forward declaration of generic_to template function +template +T generic_to(const IValue& ivalue, _fake_type); + +using GenericList = std::vector; + +// Separate tuple wrapper to avoid ambiguity with GenericList +struct GenericTuple { + std::vector elements; + + GenericTuple() = default; + GenericTuple(std::vector elems) // NOLINT + : elements(std::move(elems)) {} + + size_t size() const { return elements.size(); } + IValue& operator[](size_t idx) { return elements[idx]; } + const IValue& operator[](size_t idx) const { return elements[idx]; } +}; + +class IValue { + private: + struct CustomClassWrapper { + std::shared_ptr ptr; + std::string class_name; + + CustomClassWrapper(std::shared_ptr p, + const std::string& name) + : ptr(std::move(p)), class_name(name) {} + }; + + public: + IValue() : tag_(TypeTag::None), value_(std::monostate{}) {} + + IValue(bool val) : tag_(TypeTag::Bool), value_(val) {} // NOLINT + IValue(int val) // NOLINT + : tag_(TypeTag::Int), value_(static_cast(val)) {} + IValue(int64_t val) : tag_(TypeTag::Int), value_(val) {} // NOLINT + IValue(double val) : tag_(TypeTag::Double), value_(val) {} // NOLINT + IValue(const std::string& val) // NOLINT + : tag_(TypeTag::String), value_(val) {} + IValue(std::string&& val) // NOLINT + : tag_(TypeTag::String), value_(std::move(val)) {} + IValue(const char* val) // NOLINT + : tag_(TypeTag::String), value_(std::string(val)) {} + IValue(at::Tensor val) : tag_(TypeTag::Tensor), value_(val) {} // NOLINT + IValue(ScalarType val) // NOLINT + : tag_(TypeTag::Int), + value_(static_cast( + static_cast>(val))) {} + template + IValue(intrusive_ptr ptr) // NOLINT + : tag_(TypeTag::CustomClass), + value_(CustomClassWrapper{ptr.get_shared(), 
typeid(T).name()}) {} + + template >> + IValue(const std::vector& vec) // NOLINT + : tag_(TypeTag::GenericList) { + GenericList generic_list; + generic_list.reserve(vec.size()); + for (const auto& item : vec) { + generic_list.emplace_back(IValue(item)); + } + value_ = std::move(generic_list); + } + + template >> + IValue(std::vector&& vec) // NOLINT + : tag_(TypeTag::GenericList) { + GenericList generic_list; + generic_list.reserve(vec.size()); + for (auto&& item : vec) { + generic_list.emplace_back(IValue(std::move(item))); + } + value_ = std::move(generic_list); + } + + template >> + IValue(ArrayRef arr) : IValue(arr.vec()) {} // NOLINT + + template + IValue(const std::optional& opt) { // NOLINT + if (opt.has_value()) { + *this = IValue(*opt); + } else { + tag_ = TypeTag::None; + value_ = std::monostate{}; + } + } + + template + IValue(std::optional&& opt) { // NOLINT + if (opt.has_value()) { + *this = IValue(std::move(*opt)); + } else { + tag_ = TypeTag::None; + value_ = std::monostate{}; + } + } + + // Variadic template constructor for tuple of any number of tensors or + // IValue-convertible types + template + IValue(const std::tuple& tuple_val) // NOLINT + : tag_(TypeTag::Tuple) { + static_assert(sizeof...(Args) > 0, "Tuple must have at least one element"); + std::vector elements; + elements.reserve(sizeof...(Args)); + tuple_to_ivalue_vector( + tuple_val, elements, std::index_sequence_for{}); + value_ = GenericTuple(std::move(elements)); + } + + // Helper function to convert tuple elements to IValue vector using index + // sequence + template + void tuple_to_ivalue_vector(const Tuple& tuple_val, + std::vector& elements, // NOLINT + std::index_sequence) { + (elements.emplace_back(std::get(tuple_val)), ...); + } + + IValue(const IValue& other) = default; + IValue(IValue&& other) = default; + IValue& operator=(const IValue& other) = default; + IValue& operator=(IValue&& other) = default; + + bool is_none() const { return tag_ == TypeTag::None; } + bool is_bool() const { return tag_ == TypeTag::Bool; } + bool is_int() const { return tag_ == TypeTag::Int; } + bool is_double() const { return tag_ == TypeTag::Double; } + bool is_string() const { return tag_ == TypeTag::String; } + bool is_list() const { return tag_ == TypeTag::GenericList; } + bool is_tensor() const { return tag_ == TypeTag::Tensor; } + bool is_custom_class() const { return tag_ == TypeTag::CustomClass; } + bool is_tuple() const { return tag_ == TypeTag::Tuple; } + + bool to_bool() const { + if (!is_bool()) throw std::runtime_error("Not a bool"); + return std::get(value_); + } + + int64_t to_int() const { + if (!is_int()) throw std::runtime_error("Not an int"); + return std::get(value_); + } + + double to_double() const { + if (!is_double()) throw std::runtime_error("Not a double"); + return std::get(value_); + } + + const std::string& to_string() const { + if (!is_string()) throw std::runtime_error("Not a string"); + return std::get(value_); + } + + const GenericList& to_list() const { + if (!is_list()) throw std::runtime_error("Not a list"); + return std::get(value_); + } + + GenericList& to_list() { + if (!is_list()) throw std::runtime_error("Not a list"); + return std::get(value_); + } + + at::Tensor to_tensor() const { + if (!is_tensor()) throw std::runtime_error("Not a tensor"); + return std::get(value_); + } + + const GenericTuple& to_tuple() const { + if (!is_tuple()) throw std::runtime_error("Not a tuple"); + return std::get(value_); + } + + GenericTuple& to_tuple() { + if (!is_tuple()) throw 
std::runtime_error("Not a tuple"); + return std::get(value_); + } + + at::ScalarType to_scalar_type() const { + if (!is_int()) throw std::runtime_error("Not an int"); + return static_cast(std::get(value_)); + } + + template + intrusive_ptr to_custom_class() const { + if (!is_custom_class()) throw std::runtime_error("Not a custom class"); + const auto& wrapper = std::get(value_); + auto casted = std::dynamic_pointer_cast(wrapper.ptr); + if (!casted) { + throw std::runtime_error("Cannot cast custom class to requested type"); + } + return intrusive_ptr(casted); + } + + private: + template + struct is_intrusive_ptr : std::false_type {}; + + template + struct is_intrusive_ptr> : std::true_type {}; + + template + static constexpr bool is_intrusive_ptr_v = is_intrusive_ptr::value; + + public: + bool try_to_bool(bool& out) const { // NOLINT + if (is_bool()) { + out = std::get(value_); + return true; + } else if (is_int()) { + out = (std::get(value_) != 0); + return true; + } else if (is_double()) { + out = (std::get(value_) != 0.0); + return true; + } + return false; + } + + bool try_to_int(int& out) const { // NOLINT + if (is_int()) { + out = static_cast(std::get(value_)); + return true; + } else if (is_double()) { + double val = std::get(value_); + if (val != static_cast(val)) { + std::cout << "Warning: Converting double(" << val + << ") to int (precision loss)" << std::endl; + } + out = static_cast(val); + return true; + } + return false; + } + + bool try_to_double(double& out) const { // NOLINT + if (is_double()) { + out = std::get(value_); + return true; + } else if (is_int()) { + out = static_cast(std::get(value_)); + return true; + } + return false; + } + + bool try_to_string(std::string& out) const { // NOLINT + if (is_string()) { + out = std::get(value_); + return true; + } + return false; + } + + bool try_to_tensor(at::Tensor& out) const { // NOLINT + if (is_tensor()) { + out = std::get(value_); + return true; + } + return false; + } + + bool try_to_scalar_type(at::ScalarType& out) const { // NOLINT + if (is_int()) { + out = static_cast(std::get(value_)); + return true; + } + return false; + } + + template + bool try_to_optional_type(std::optional& out) const { // NOLINT + if (is_none()) { + out = std::nullopt; + return true; + } else { + T value; + if (try_convert_to(value)) { + out = value; + return true; + } + } + return false; + } + + bool try_to_custom_class(std::shared_ptr& out, // NOLINT + const std::string& expected_class_name) const { + if (is_custom_class()) { + const auto& wrapper = std::get(value_); + if (wrapper.class_name == expected_class_name) { + out = wrapper.ptr; + return true; + } + } + return false; + } + + template + bool try_convert_to(T& out) const { // NOLINT + // Remove reference and cv-qualifiers from T + using BaseType = std::remove_cv_t>; + + if constexpr (std::is_same_v) { + return try_to_bool(const_cast(reinterpret_cast(out))); + } else if constexpr (std::is_same_v) { + return try_to_int(const_cast(reinterpret_cast(out))); + } else if constexpr (std::is_same_v) { + return try_to_double( + const_cast(reinterpret_cast(out))); + } else if constexpr (std::is_same_v) { + return try_to_string( + const_cast(reinterpret_cast(out))); + } else if constexpr (std::is_same_v) { + return try_to_tensor( + const_cast(reinterpret_cast(out))); + } else if constexpr (std::is_same_v) { + return try_to_scalar_type(const_cast( + reinterpret_cast(out))); + } else { + try { + // Handle const types by removing const and using const_cast + using NonConstType = std::remove_const_t; + 
NonConstType temp = this->to(); + const_cast(out) = std::move(temp); + return true; + } catch (const std::exception&) { + return false; + } + } + } + + std::string get_custom_class_name() const { + if (!is_custom_class()) throw std::runtime_error("Not a custom class"); + const auto& wrapper = std::get(value_); + return wrapper.class_name; + } + + template + T to() && { + return generic_to(std::move(*this), _fake_type{}); + } + + template + T to() const& { + return generic_to(*this, _fake_type{}); + } + + std::string type_string() const { + switch (tag_) { + case TypeTag::None: + return "None"; + case TypeTag::Bool: + return "Bool"; + case TypeTag::Int: + return "Int"; + case TypeTag::Double: + return "Double"; + case TypeTag::String: + return "String"; + case TypeTag::Tensor: + return "Tensor"; + case TypeTag::GenericList: + return "List"; + case TypeTag::CustomClass: + return "CustomClass(" + get_custom_class_name() + ")"; + default: + return "Unknown"; + } + } + + std::string to_repr() const { + switch (tag_) { + case TypeTag::None: + return "None"; + case TypeTag::Bool: + return std::get(value_) ? "true" : "false"; + case TypeTag::Int: + return std::to_string(std::get(value_)); + case TypeTag::Double: + return std::to_string(std::get(value_)); + case TypeTag::String: + return "\"" + std::get(value_) + "\""; + case TypeTag::Tensor: { + const auto& tensor = std::get(value_); + return "Tensor(" + std::to_string(tensor.numel()) + " elements)"; + } + case TypeTag::GenericList: { + const auto& list = std::get(value_); + std::string result = "["; + for (size_t i = 0; i < list.size(); ++i) { + if (i > 0) result += ", "; + result += list[i].to_repr(); + } + result += "]"; + return result; + } + case TypeTag::CustomClass: { + const auto& wrapper = std::get(value_); + return "CustomClass(" + wrapper.class_name + ")"; + } + default: + return "Unknown"; + } + } + + friend std::ostream& operator<<(std::ostream& os, const IValue& val) { + return os << val.to_repr(); + } + + private: + TypeTag tag_; + std::variant + value_; + template + friend T generic_to(const IValue& ivalue, _fake_type); +}; + +template <> +inline bool generic_to(const IValue& ivalue, _fake_type) { + return ivalue.to_bool(); +} + +template <> +inline int generic_to(const IValue& ivalue, _fake_type) { + return static_cast(ivalue.to_int()); +} + +template <> +inline int64_t generic_to(const IValue& ivalue, _fake_type) { + return ivalue.to_int(); +} + +template <> +inline double generic_to(const IValue& ivalue, _fake_type) { + return ivalue.to_double(); +} + +template <> +inline std::string generic_to(const IValue& ivalue, _fake_type) { + return ivalue.to_string(); +} + +template <> +inline at::Tensor generic_to(const IValue& ivalue, _fake_type) { + return ivalue.to_tensor(); +} + +template +std::vector generic_to(const IValue& ivalue, _fake_type>) { + auto list = ivalue.to_list(); + std::vector result; + result.reserve(list.size()); + for (const auto& item : list) { + result.push_back(item.to()); + } + return result; +} + +template +ArrayRef generic_to(const IValue& ivalue, _fake_type>) { + static thread_local std::vector temp_storage; + temp_storage = ivalue.to>(); + return ArrayRef(temp_storage); +} + +template +std::optional generic_to(const IValue& ivalue, + _fake_type>) { + if (ivalue.is_none()) { + return std::nullopt; + } + return std::optional(ivalue.to()); +} + +template +intrusive_ptr generic_to(const IValue& ivalue, + _fake_type>) { + return ivalue.to_custom_class(); +} + +} // namespace torch diff --git 
a/paddle/phi/api/include/compat/ATen/cuda/CUDAContext.h b/paddle/phi/api/include/compat/ATen/cuda/CUDAContext.h new file mode 100644 index 00000000000000..27503784e71209 --- /dev/null +++ b/paddle/phi/api/include/compat/ATen/cuda/CUDAContext.h @@ -0,0 +1,18 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include diff --git a/paddle/phi/api/include/compat/ATen/cuda/EmptyTensor.cpp b/paddle/phi/api/include/compat/ATen/cuda/EmptyTensor.cpp new file mode 100644 index 00000000000000..1b78e29095fd80 --- /dev/null +++ b/paddle/phi/api/include/compat/ATen/cuda/EmptyTensor.cpp @@ -0,0 +1,42 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "paddle/phi/api/include/api.h" +#include "paddle/phi/common/place.h" + +namespace at::detail { + +at::Tensor empty_cuda(IntArrayRef size, + ScalarType dtype, + std::optional device_opt, + std::optional memory_format_opt) { + PD_CHECK(!(memory_format_opt.has_value() && + memory_format_opt.value() != c10::MemoryFormat::Contiguous), + "`MemoryFormat` other than Contiguous is not supported now."); + return paddle::experimental::empty( + size._PD_ToPaddleIntArray(), + compat::_PD_AtenScalarTypeToPhiDataType(dtype), + phi::GPUPlace()); +} + +at::Tensor empty_cuda(IntArrayRef size, const TensorOptions &options) { + return paddle::experimental::empty( + size._PD_ToPaddleIntArray(), + compat::_PD_AtenScalarTypeToPhiDataType(options.dtype_opt().value()), + phi::GPUPlace()); +} + +} // namespace at::detail diff --git a/paddle/phi/api/include/compat/ATen/cuda/EmptyTensor.h b/paddle/phi/api/include/compat/ATen/cuda/EmptyTensor.h new file mode 100644 index 00000000000000..080f355994c781 --- /dev/null +++ b/paddle/phi/api/include/compat/ATen/cuda/EmptyTensor.h @@ -0,0 +1,28 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include + +namespace at::detail { + +using at::Tensor; +at::Tensor empty_cuda(IntArrayRef size, + ScalarType dtype, + std::optional device_opt, + std::optional memory_format_opt); + +at::Tensor empty_cuda(IntArrayRef size, const TensorOptions &options); + +} // namespace at::detail diff --git a/paddle/phi/api/include/compat/ATen/cuda/Exceptions.h b/paddle/phi/api/include/compat/ATen/cuda/Exceptions.h new file mode 100644 index 00000000000000..e8c0c76b803643 --- /dev/null +++ b/paddle/phi/api/include/compat/ATen/cuda/Exceptions.h @@ -0,0 +1,16 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include diff --git a/paddle/phi/api/include/compat/ATen/indexing.h b/paddle/phi/api/include/compat/ATen/indexing.h new file mode 100644 index 00000000000000..169e9e9f329b34 --- /dev/null +++ b/paddle/phi/api/include/compat/ATen/indexing.h @@ -0,0 +1,72 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include + +#include +#include + +namespace at::indexing { + +constexpr int64_t INDEX_MIN = std::numeric_limits::min(); +constexpr int64_t INDEX_MAX = std::numeric_limits::max(); + +enum class TensorIndexType { None, Ellipsis, SymInt, Boolean, Slice, Tensor }; + +constexpr std::nullopt_t None = std::nullopt; + +struct EllipsisIndexType final { + EllipsisIndexType() = default; +}; + +const EllipsisIndexType Ellipsis = EllipsisIndexType(); + +struct Slice final { + public: + Slice(std::optional start_index = std::nullopt, + std::optional stop_index = std::nullopt, + std::optional step_index = std::nullopt) { + if (!step_index.has_value()) { + step_ = c10::SymInt(1); + } else { + step_ = std::move(step_index).value(); + } + + if (!start_index.has_value()) { + start_ = c10::SymInt(step_ < 0 ? INDEX_MAX : 0); + } else { + start_ = std::move(start_index).value(); + } + + if (!stop_index.has_value()) { + stop_ = c10::SymInt(step_ < 0 ? 
INDEX_MIN : INDEX_MAX); + } else { + stop_ = std::move(stop_index).value(); + } + } + + inline c10::SymInt start() const { return start_; } + + inline c10::SymInt stop() const { return stop_; } + + inline c10::SymInt step() const { return step_; } + + private: + c10::SymInt start_; + c10::SymInt stop_; + c10::SymInt step_; +}; + +} // namespace at::indexing diff --git a/paddle/phi/api/include/compat/ATen/native/cuda/Resize.h b/paddle/phi/api/include/compat/ATen/native/cuda/Resize.h new file mode 100644 index 00000000000000..e065c7dfc0df76 --- /dev/null +++ b/paddle/phi/api/include/compat/ATen/native/cuda/Resize.h @@ -0,0 +1,19 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) +#include +#endif diff --git a/paddle/phi/api/include/compat/ATen/ops/abs.h b/paddle/phi/api/include/compat/ATen/ops/abs.h new file mode 100644 index 00000000000000..a0b889126d4411 --- /dev/null +++ b/paddle/phi/api/include/compat/ATen/ops/abs.h @@ -0,0 +1,33 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include + +#include "paddle/phi/api/include/api.h" + +namespace at { + +inline at::Tensor abs(const at::Tensor& self) { + return paddle::experimental::abs(self._PD_GetInner()); +} + +} // namespace at + +namespace torch { +using at::abs; +} // namespace torch diff --git a/paddle/phi/api/include/compat/ATen/ops/empty.h b/paddle/phi/api/include/compat/ATen/ops/empty.h new file mode 100644 index 00000000000000..3aee3c4dddcef9 --- /dev/null +++ b/paddle/phi/api/include/compat/ATen/ops/empty.h @@ -0,0 +1,64 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
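// Usage sketch (editorial, not part of the patch): how the Slice defaults in
// ATen/indexing.h above resolve. The include path assumes the compat directory
// is on the include path; c10::SymInt is a plain int64_t in this layer, so the
// accessors compare directly against integers.
#include <ATen/indexing.h>

#include <cassert>
#include <optional>

inline void slice_defaults_example() {
  using at::indexing::Slice;
  Slice forward;                                    // step = 1, start = 0, stop = INDEX_MAX
  Slice reversed(std::nullopt, std::nullopt, -1);   // step = -1, start = INDEX_MAX, stop = INDEX_MIN
  assert(forward.step() == 1 && forward.start() == 0);
  assert(forward.stop() == at::indexing::INDEX_MAX);
  assert(reversed.start() == at::indexing::INDEX_MAX);
  assert(reversed.stop() == at::indexing::INDEX_MIN);
}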
+ +#pragma once + +#include +#include +#include + +#include "paddle/phi/api/include/api.h" + +namespace at { + +inline at::Tensor empty( + at::IntArrayRef size, + at::TensorOptions options = {}, + ::std::optional memory_format = ::std::nullopt) { + PD_CHECK(!(memory_format.has_value() && + memory_format.value() != c10::MemoryFormat::Contiguous), + "`MemoryFormat` other than Contiguous is not supported now."); + return paddle::experimental::empty( + size._PD_ToPaddleIntArray(), + compat::_PD_AtenScalarTypeToPhiDataType(options.dtype()), + options._PD_GetPlace()); +} + +inline at::Tensor empty(at::IntArrayRef size, + ::std::optional dtype, + ::std::optional layout, + ::std::optional device, + ::std::optional pin_memory, + ::std::optional memory_format) { + PD_CHECK(!layout.has_value(), "`layout` is not supported now."); + PD_CHECK(!(pin_memory.has_value() && pin_memory.value() != false), + "`pin_memory` other than False is not supported now."); + PD_CHECK(!(memory_format.has_value() && + memory_format.value() != c10::MemoryFormat::Contiguous), + "`MemoryFormat` other than Contiguous is not supported now."); + + return paddle::experimental::empty( + size._PD_ToPaddleIntArray(), + compat::_PD_AtenScalarTypeToPhiDataType( + dtype.value_or(c10::get_default_dtype())), + device.value_or(at::kCPU)._PD_GetInner()); +} + +#define empty_symint empty // SymIntArrayRef is same as IntArrayRef + +} // namespace at + +namespace torch { +using at::empty; +} // namespace torch diff --git a/paddle/phi/api/include/compat/ATen/ops/empty_like.h b/paddle/phi/api/include/compat/ATen/ops/empty_like.h new file mode 100644 index 00000000000000..a42c3606574cb6 --- /dev/null +++ b/paddle/phi/api/include/compat/ATen/ops/empty_like.h @@ -0,0 +1,64 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
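// Usage sketch (editorial, not part of the patch): exercising the two at::empty
// overloads defined above. Both forward to paddle::experimental::empty; layout,
// pinned memory, and non-contiguous memory formats are rejected by the PD_CHECKs.
// The brace-initialized sizes assume the compat IntArrayRef accepts an
// initializer list, as the upstream ArrayRef does.
#include <ATen/ops/empty.h>

#include <optional>

inline void empty_example() {
  // TensorOptions-based overload: dtype and device come from the options.
  at::Tensor a = at::empty(
      {2, 3}, at::TensorOptions(at::ScalarType::Float).device(at::Device(at::kCUDA)));
  // Unpacked-optional overload: unsupported knobs must stay std::nullopt.
  at::Tensor b = at::empty({4},
                           at::ScalarType::Int,
                           std::nullopt,              // layout
                           at::Device(at::kCPU),
                           std::nullopt,              // pin_memory
                           std::nullopt);             // memory_format
  (void)a;
  (void)b;
}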
+ +#pragma once + +#include +#include +#include + +#include "paddle/phi/api/include/api.h" + +namespace at { + +inline at::Tensor empty_like( + const at::Tensor& self, + at::TensorOptions options = {}, + ::std::optional memory_format = ::std::nullopt) { + PD_CHECK(!(memory_format.has_value() && + memory_format.value() != c10::MemoryFormat::Contiguous), + "`MemoryFormat` other than Contiguous is not supported now."); + + auto dtype = options.dtype_opt().value_or(self.dtype()); + auto place = options.device_opt().value_or(self.device()); + return paddle::experimental::empty_like( + self._PD_GetInner(), + compat::_PD_AtenScalarTypeToPhiDataType(dtype), + place._PD_GetInner()); +} + +inline at::Tensor empty_like(const at::Tensor& self, + ::std::optional dtype, + ::std::optional layout, + ::std::optional device, + ::std::optional pin_memory, + ::std::optional memory_format) { + PD_CHECK(!layout.has_value(), "`layout` is not supported now."); + PD_CHECK(!(pin_memory.has_value() && pin_memory.value() != false), + "`pin_memory` other than False is not supported now."); + PD_CHECK(!(memory_format.has_value() && + memory_format.value() != c10::MemoryFormat::Contiguous), + "`MemoryFormat` other than Contiguous is not supported now."); + + return paddle::experimental::empty_like( + self._PD_GetInner(), + compat::_PD_AtenScalarTypeToPhiDataType(dtype.value_or(self.dtype())), + device.value_or(self.device())._PD_GetInner()); +} + +} // namespace at + +namespace torch { +using at::empty_like; +} // namespace torch diff --git a/paddle/phi/api/include/compat/ATen/ops/from_blob.h b/paddle/phi/api/include/compat/ATen/ops/from_blob.h new file mode 100644 index 00000000000000..4e3f958dd5e4b0 --- /dev/null +++ b/paddle/phi/api/include/compat/ATen/ops/from_blob.h @@ -0,0 +1,101 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
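// Usage sketch (editorial, not part of the patch): empty_like above defaults the
// result dtype and device to those of `self`, and only overrides them when the
// options (or the explicit optionals) provide values.
#include <ATen/ops/empty_like.h>

inline at::Tensor empty_like_example(const at::Tensor& self) {
  at::Tensor same_meta = at::empty_like(self);   // dtype/device taken from self
  at::Tensor as_double =
      at::empty_like(self, at::TensorOptions(at::ScalarType::Double));
  (void)same_meta;
  return as_double;
}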
+ +#pragma once +#include + +#include "paddle/phi/api/include/tensor_utils.h" +namespace at { + +inline Tensor from_blob( + void* data, + IntArrayRef sizes, + IntArrayRef strides, + const std::function& deleter, + const TensorOptions& options = {}, + const std::optional target_device = std::nullopt) { + return paddle::from_blob( + data, + sizes._PD_ToPaddleIntArray(), + strides._PD_ToPaddleIntArray(), + compat::_PD_AtenScalarTypeToPhiDataType(options.dtype()), + phi::DataLayout::NCHW, + device_or_default(target_device)._PD_GetInner(), + deleter); +} + +inline Tensor from_blob( + void* data, + IntArrayRef sizes, + IntArrayRef strides, + int64_t storage_offset, + const std::function& deleter, + const TensorOptions& options = {}, + const std::optional target_device = std::nullopt) { + PD_CHECK(storage_offset == 0, "`storage_offset` should be zero."); + + return paddle::from_blob( + data, + sizes._PD_ToPaddleIntArray(), + strides._PD_ToPaddleIntArray(), + compat::_PD_AtenScalarTypeToPhiDataType(options.dtype()), + phi::DataLayout::NCHW, + device_or_default(target_device)._PD_GetInner(), + deleter); +} + +inline Tensor from_blob( + void* data, + IntArrayRef sizes, + std::function deleter, + const TensorOptions& options = {}, + const std::optional target_device = std::nullopt) { + return paddle::from_blob( + data, + sizes._PD_ToPaddleIntArray(), + compat::_PD_AtenScalarTypeToPhiDataType(options.dtype()), + phi::DataLayout::NCHW, + device_or_default(target_device)._PD_GetInner(), + deleter); +} + +inline Tensor from_blob(void* data, + IntArrayRef sizes, + IntArrayRef strides, + const TensorOptions& options = {}) { + return paddle::from_blob( + data, + sizes._PD_ToPaddleIntArray(), + strides._PD_ToPaddleIntArray(), + compat::_PD_AtenScalarTypeToPhiDataType(options.dtype()), + phi::DataLayout::NCHW, + options._PD_GetPlace()); +} + +inline Tensor from_blob(void* data, + IntArrayRef sizes, + const TensorOptions& options = {}) { + return paddle::from_blob( + data, + sizes._PD_ToPaddleIntArray(), + compat::_PD_AtenScalarTypeToPhiDataType(options.dtype()), + phi::DataLayout::NCHW, + options._PD_GetPlace(), + nullptr); +} + +} // namespace at +namespace torch { +using at::from_blob; +} // namespace torch diff --git a/paddle/phi/api/include/compat/ATen/ops/full.h b/paddle/phi/api/include/compat/ATen/ops/full.h new file mode 100644 index 00000000000000..69fd60be30ed80 --- /dev/null +++ b/paddle/phi/api/include/compat/ATen/ops/full.h @@ -0,0 +1,83 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
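// Usage sketch (editorial, not part of the patch): wrapping externally owned
// memory with the from_blob overloads above. The buffer is not copied; the
// deleter is expected to run when the tensor releases it. The sizes-only
// overload is used here, so strides default to contiguous.
#include <ATen/ops/from_blob.h>

#include <cstdint>
#include <vector>

inline at::Tensor wrap_external_buffer(std::vector<float>* owned) {
  return at::from_blob(
      owned->data(),
      {static_cast<int64_t>(owned->size())},
      [owned](void* /*ptr*/) { delete owned; },
      at::TensorOptions(at::ScalarType::Float));
}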
+ +#pragma once + +#include +#include +#include + +#include "paddle/phi/api/include/api.h" + +namespace at { + +inline at::Tensor full(at::IntArrayRef size, + const at::Scalar& fill_value, + at::TensorOptions options = {}) { + return paddle::experimental::full( + size._PD_ToPaddleIntArray(), + fill_value, + compat::_PD_AtenScalarTypeToPhiDataType(options.dtype()), + options._PD_GetPlace()); +} + +inline at::Tensor full(at::IntArrayRef size, + const at::Scalar& fill_value, + ::std::optional dtype, + ::std::optional layout, + ::std::optional device, + ::std::optional pin_memory) { + PD_CHECK(!layout.has_value(), "`layout` is not supported now."); + PD_CHECK(!(pin_memory.has_value() && pin_memory.value() != false), + "`pin_memory` other than False is not supported now."); + return paddle::experimental::full( + size._PD_ToPaddleIntArray(), + fill_value, + compat::_PD_AtenScalarTypeToPhiDataType( + dtype.value_or(c10::get_default_dtype())), + device.value_or(at::kCPU)._PD_GetInner()); +} + +inline at::Tensor full_symint(c10::SymIntArrayRef size, + const at::Scalar& fill_value, + at::TensorOptions options = {}) { + return paddle::experimental::full( + size._PD_ToPaddleIntArray(), + fill_value, + compat::_PD_AtenScalarTypeToPhiDataType(options.dtype()), + options._PD_GetPlace()); +} + +inline at::Tensor full_symint(c10::SymIntArrayRef size, + const at::Scalar& fill_value, + ::std::optional dtype, + ::std::optional layout, + ::std::optional device, + ::std::optional pin_memory) { + PD_CHECK(!layout.has_value(), "`layout` is not supported now."); + PD_CHECK(!(pin_memory.has_value() && pin_memory.value() != false), + "`pin_memory` other than False is not supported now."); + return paddle::experimental::full( + size._PD_ToPaddleIntArray(), + fill_value, + compat::_PD_AtenScalarTypeToPhiDataType( + dtype.value_or(c10::get_default_dtype())), + device.value_or(at::kCPU)._PD_GetInner()); +} + +} // namespace at +namespace torch { +using at::full; +using at::full_symint; +} // namespace torch diff --git a/paddle/phi/api/include/compat/ATen/ops/ones.h b/paddle/phi/api/include/compat/ATen/ops/ones.h new file mode 100644 index 00000000000000..0624faa3bf2e3e --- /dev/null +++ b/paddle/phi/api/include/compat/ATen/ops/ones.h @@ -0,0 +1,74 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
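// Usage sketch (editorial, not part of the patch): at::full above takes its fill
// value as a c10::Scalar (aliased to paddle::experimental::Scalar), so both
// integral and floating-point fills go through the same overloads.
#include <ATen/ops/full.h>

#include <optional>

inline void full_example() {
  at::Tensor f = at::full({3, 3}, 1.5, at::TensorOptions(at::ScalarType::Float));
  at::Tensor i = at::full({8},
                          1,
                          at::ScalarType::Int,
                          std::nullopt,          // layout
                          at::Device(at::kCPU),
                          std::nullopt);         // pin_memory
  (void)f;
  (void)i;
}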
+ +#pragma once + +#include +#include +#include + +#include "paddle/phi/api/include/api.h" + +namespace at { + +inline at::Tensor ones(at::IntArrayRef size, at::TensorOptions options = {}) { + return paddle::experimental::ones( + size._PD_ToPaddleIntArray(), + compat::_PD_AtenScalarTypeToPhiDataType(options.dtype()), + options._PD_GetPlace()); +} + +inline at::Tensor ones(at::IntArrayRef size, + ::std::optional dtype, + ::std::optional layout, + ::std::optional device, + ::std::optional pin_memory) { + PD_CHECK(!layout.has_value(), "`layout` is not supported now."); + PD_CHECK(!(pin_memory.has_value() && pin_memory.value() != false), + "`pin_memory` other than False is not supported now."); + return paddle::experimental::ones( + size._PD_ToPaddleIntArray(), + compat::_PD_AtenScalarTypeToPhiDataType( + dtype.value_or(c10::get_default_dtype())), + device.value_or(at::kCPU)._PD_GetInner()); +} + +inline at::Tensor ones_symint(c10::SymIntArrayRef size, + at::TensorOptions options = {}) { + return paddle::experimental::ones( + size._PD_ToPaddleIntArray(), + compat::_PD_AtenScalarTypeToPhiDataType(options.dtype()), + options._PD_GetPlace()); +} + +inline at::Tensor ones_symint(c10::SymIntArrayRef size, + ::std::optional dtype, + ::std::optional layout, + ::std::optional device, + ::std::optional pin_memory) { + PD_CHECK(!layout.has_value(), "`layout` is not supported now."); + PD_CHECK(!(pin_memory.has_value() && pin_memory.value() != false), + "`pin_memory` other than False is not supported now."); + return paddle::experimental::ones( + size._PD_ToPaddleIntArray(), + compat::_PD_AtenScalarTypeToPhiDataType( + dtype.value_or(c10::get_default_dtype())), + device.value_or(at::kCPU)._PD_GetInner()); +} + +} // namespace at +namespace torch { +using at::ones; +using at::ones_symint; +} // namespace torch diff --git a/paddle/phi/api/include/compat/ATen/ops/reshape.h b/paddle/phi/api/include/compat/ATen/ops/reshape.h new file mode 100644 index 00000000000000..4048109b422176 --- /dev/null +++ b/paddle/phi/api/include/compat/ATen/ops/reshape.h @@ -0,0 +1,39 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include +#include +#include + +#include "paddle/phi/api/include/api.h" +namespace at { + +inline at::Tensor reshape(const at::Tensor& self, at::IntArrayRef shape) { + return paddle::experimental::reshape(self._PD_GetInner(), + shape._PD_ToPaddleIntArray()); +} + +inline at::Tensor reshape_symint(const at::Tensor& self, + c10::SymIntArrayRef shape) { + return paddle::experimental::reshape(self._PD_GetInner(), + shape._PD_ToPaddleIntArray()); +} + +} // namespace at +namespace torch { +using at::reshape; +using at::reshape_symint; +} // namespace torch diff --git a/paddle/phi/api/include/compat/ATen/ops/sum.h b/paddle/phi/api/include/compat/ATen/ops/sum.h new file mode 100644 index 00000000000000..d264a2f42c7251 --- /dev/null +++ b/paddle/phi/api/include/compat/ATen/ops/sum.h @@ -0,0 +1,75 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include + +#include "paddle/phi/api/include/api.h" + +namespace at { + +inline at::Tensor sum(const at::Tensor& self, + ::std::optional dtype = ::std::nullopt) { + return paddle::experimental::sum( + self._PD_GetInner(), + {}, + compat::_PD_AtenScalarTypeToPhiDataType( + dtype.value_or(c10::get_default_dtype())), + /*keepdim=*/false); +} + +inline at::Tensor sum(const at::Tensor& self, + at::OptionalIntArrayRef dim, + bool keepdim = false, + ::std::optional dtype = ::std::nullopt) { + return paddle::experimental::sum( + self._PD_GetInner(), + dim.has_value() ? dim.value()._PD_ToPaddleIntArray() + : paddle::experimental::IntArray(), + compat::_PD_AtenScalarTypeToPhiDataType( + dtype.value_or(c10::get_default_dtype())), + keepdim); +} + +inline at::Tensor& sum_out( + at::Tensor& + out, // NOLINT: intentional non-const reference for output parameter + const at::Tensor& self, + at::OptionalIntArrayRef dim, + bool keepdim = false, + ::std::optional dtype = ::std::nullopt) { + auto res = sum(self, dim, keepdim, dtype); + paddle::experimental::assign_out_(res._PD_GetInner(), out._PD_GetInner()); + return out; +} + +inline at::Tensor& sum_out( + at::Tensor& + out, // NOLINT: intentional non-const reference for output parameter + const at::Tensor& self, + ::std::optional dtype = ::std::nullopt) { + auto res = sum(self, dtype); + paddle::experimental::assign_out_(res._PD_GetInner(), out._PD_GetInner()); + return out; +} + +} // namespace at + +namespace torch { +using at::sum; +using at::sum_out; +} // namespace torch diff --git a/paddle/phi/api/include/compat/ATen/ops/tensor.h b/paddle/phi/api/include/compat/ATen/ops/tensor.h new file mode 100644 index 00000000000000..4f95f3aa82cd2d --- /dev/null +++ b/paddle/phi/api/include/compat/ATen/ops/tensor.h @@ -0,0 +1,45 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// #The file has been adapted from pytorch project +// #Licensed under BSD-style license - +// https://github.com/pytorch/pytorch/blob/main/LICENSE + +#pragma once +#include +#include + +namespace at { + +#define TENSOR(T, S) \ + Tensor tensor(ArrayRef values, const TensorOptions& options); \ + inline Tensor tensor(std::initializer_list values, \ + const TensorOptions& options) { \ + return at::tensor(ArrayRef(values), options); \ + } \ + inline Tensor tensor(T value, const TensorOptions& options) { \ + return at::tensor(ArrayRef(value), options); \ + } \ + inline Tensor tensor(ArrayRef values) { \ + return at::tensor(std::move(values), at::dtype(k##S)); \ + } \ + inline Tensor tensor(std::initializer_list values) { \ + return at::tensor(ArrayRef(values)); \ + } \ + inline Tensor tensor(T value) { return at::tensor(ArrayRef(value)); } +AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TENSOR) +AT_FORALL_COMPLEX_TYPES(TENSOR) +#undef TENSOR + +} // namespace at diff --git a/paddle/phi/api/include/compat/ATen/ops/zeros.h b/paddle/phi/api/include/compat/ATen/ops/zeros.h new file mode 100644 index 00000000000000..04c4edbf17eac0 --- /dev/null +++ b/paddle/phi/api/include/compat/ATen/ops/zeros.h @@ -0,0 +1,74 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
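// Usage sketch (editorial, not part of the patch): the TENSOR macro above stamps
// out at::tensor overloads per scalar type. The initializer-list and single-value
// forms forward to the ArrayRef form, which is only declared here and presumably
// defined in ATen/core/TensorMethods.cpp (collected by the compat CMake file), so
// these calls link only once that definition is built.
#include <ATen/ops/tensor.h>

inline void tensor_factory_example() {
  at::Tensor f = at::tensor({1.0f, 2.0f, 3.0f});   // float values, default options
  at::Tensor l = at::tensor({1, 2, 3}, at::dtype(at::ScalarType::Long));
  at::Tensor s = at::tensor(42);                   // single value
  (void)f;
  (void)l;
  (void)s;
}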
+ +#pragma once + +#include +#include +#include + +#include "paddle/phi/api/include/api.h" + +namespace at { + +inline at::Tensor zeros(at::IntArrayRef size, at::TensorOptions options = {}) { + return paddle::experimental::zeros( + size._PD_ToPaddleIntArray(), + compat::_PD_AtenScalarTypeToPhiDataType(options.dtype()), + options._PD_GetPlace()); +} + +inline at::Tensor zeros(at::IntArrayRef size, + ::std::optional dtype, + ::std::optional layout, + ::std::optional device, + ::std::optional pin_memory) { + PD_CHECK(!layout.has_value(), "`layout` is not supported now."); + PD_CHECK(!(pin_memory.has_value() && pin_memory.value() != false), + "`pin_memory` other than False is not supported now."); + return paddle::experimental::zeros( + size._PD_ToPaddleIntArray(), + compat::_PD_AtenScalarTypeToPhiDataType( + dtype.value_or(c10::get_default_dtype())), + device.value_or(at::kCPU)._PD_GetInner()); +} + +inline at::Tensor zeros_symint(c10::SymIntArrayRef size, + at::TensorOptions options = {}) { + return paddle::experimental::zeros( + size._PD_ToPaddleIntArray(), + compat::_PD_AtenScalarTypeToPhiDataType(options.dtype()), + options._PD_GetPlace()); +} + +inline at::Tensor zeros_symint(c10::SymIntArrayRef size, + ::std::optional dtype, + ::std::optional layout, + ::std::optional device, + ::std::optional pin_memory) { + PD_CHECK(!layout.has_value(), "`layout` is not supported now."); + PD_CHECK(!(pin_memory.has_value() && pin_memory.value() != false), + "`pin_memory` other than False is not supported now."); + return paddle::experimental::zeros( + size._PD_ToPaddleIntArray(), + compat::_PD_AtenScalarTypeToPhiDataType( + dtype.value_or(c10::get_default_dtype())), + device.value_or(at::kCPU)._PD_GetInner()); +} + +} // namespace at +namespace torch { +using at::zeros; +using at::zeros_symint; +} // namespace torch diff --git a/paddle/phi/api/include/compat/ATen/ops/zeros_like.h b/paddle/phi/api/include/compat/ATen/ops/zeros_like.h new file mode 100644 index 00000000000000..e614d87543cffb --- /dev/null +++ b/paddle/phi/api/include/compat/ATen/ops/zeros_like.h @@ -0,0 +1,62 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include +#include +#include + +#include "paddle/phi/api/include/api.h" + +namespace at { + +inline at::Tensor zeros_like( + const at::Tensor& self, + at::TensorOptions options = {}, + ::std::optional memory_format = ::std::nullopt) { + PD_CHECK(!(memory_format.has_value() && + memory_format.value() != c10::MemoryFormat::Contiguous), + "`MemoryFormat` other than Contiguous is not supported now."); + + return paddle::experimental::zeros_like( + self._PD_GetInner(), + compat::_PD_AtenScalarTypeToPhiDataType(options.dtype()), + options._PD_GetPlace()); +} + +inline at::Tensor zeros_like(const at::Tensor& self, + ::std::optional dtype, + ::std::optional layout, + ::std::optional device, + ::std::optional pin_memory, + ::std::optional memory_format) { + PD_CHECK(!layout.has_value(), "`layout` is not supported now."); + PD_CHECK(!(pin_memory.has_value() && pin_memory.value() != false), + "`pin_memory` other than False is not supported now."); + PD_CHECK(!(memory_format.has_value() && + memory_format.value() != c10::MemoryFormat::Contiguous), + "`MemoryFormat` other than Contiguous is not supported now."); + + return paddle::experimental::zeros_like( + self._PD_GetInner(), + compat::_PD_AtenScalarTypeToPhiDataType( + dtype.value_or(c10::get_default_dtype())), + device.value_or(at::kCPU)._PD_GetInner()); +} + +} // namespace at +namespace torch { +using at::zeros_like; +} // namespace torch diff --git a/paddle/phi/api/include/compat/CMakeLists.txt b/paddle/phi/api/include/compat/CMakeLists.txt new file mode 100644 index 00000000000000..8099b2cb9e78a4 --- /dev/null +++ b/paddle/phi/api/include/compat/CMakeLists.txt @@ -0,0 +1,4 @@ +collect_srcs(api_srcs SRCS ATen/cuda/EmptyTensor.cpp) +collect_srcs(api_srcs SRCS ATen/core/TensorMethods.cpp) +collect_srcs(api_srcs SRCS ATen/AccumulateType.cpp) +collect_srcs(api_srcs SRCS torch/csrc/api/include/torch/cuda.cpp) diff --git a/paddle/phi/api/include/compat/README.md b/paddle/phi/api/include/compat/README.md new file mode 100644 index 00000000000000..9a45775526e49b --- /dev/null +++ b/paddle/phi/api/include/compat/README.md @@ -0,0 +1,4 @@ +# Paddle <> PyTorch Compat API + +This folder contains an implementation of (most of) the PyTorch public API on top of the Paddle API. +Note that this folder does not depend on PyTorch in any way; it is a standalone implementation. diff --git a/paddle/phi/api/include/compat/c10/core/DefaultDtype.h b/paddle/phi/api/include/compat/c10/core/DefaultDtype.h new file mode 100644 index 00000000000000..5ff76298cd507d --- /dev/null +++ b/paddle/phi/api/include/compat/c10/core/DefaultDtype.h @@ -0,0 +1,32 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
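// Usage sketch (editorial, not part of the patch): the intent of the README
// above is that torch-style C++ compiles against these headers with no PyTorch
// dependency. A minimal sketch, assuming the compat directories are on the
// include path and the torch:: aliases re-exported in the headers above:
#include <ATen/ops/ones.h>
#include <ATen/ops/sum.h>

inline void compat_smoke_test() {
  auto x = torch::ones({2, 2}, torch::ScalarType::Float);  // paddle::experimental::ones underneath
  auto total = torch::sum(x);                              // paddle::experimental::sum underneath
  (void)total;
}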
+ +#pragma once + +#include + +namespace c10 { +static auto default_dtype = ScalarType::Float; +static auto default_complex_dtype = ScalarType::ComplexFloat; + +void inline set_default_dtype(ScalarType dtype) { default_dtype = dtype; } + +const ScalarType inline get_default_dtype() { return default_dtype; } + +ScalarType inline get_default_dtype_as_scalartype() { return default_dtype; } + +const ScalarType inline get_default_complex_dtype() { + return default_complex_dtype; +} +} // namespace c10 diff --git a/paddle/phi/api/include/compat/c10/core/Device.h b/paddle/phi/api/include/compat/c10/core/Device.h new file mode 100644 index 00000000000000..f361b598e246cd --- /dev/null +++ b/paddle/phi/api/include/compat/c10/core/Device.h @@ -0,0 +1,47 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include + +namespace c10 { +using DeviceIndex = int8_t; + +struct Device final { + using Type = DeviceType; + Device(phi::Place place) : inner_(place) {} + Device(DeviceType type, DeviceIndex index = 0) + : inner_(phi::Place(type, index)) {} // NOLINT + + DeviceIndex index() const noexcept { return inner_.GetDeviceId(); } + + DeviceType type() const { return inner_.GetType(); } + + phi::Place _PD_GetInner() const { return inner_; } + + private: + phi::Place inner_; +}; + +} // namespace c10 + +namespace at { +using c10::Device; +using c10::DeviceIndex; +} // namespace at + +namespace torch { +using c10::Device; +using c10::DeviceIndex; +} // namespace torch diff --git a/paddle/phi/api/include/compat/c10/core/DeviceType.h b/paddle/phi/api/include/compat/c10/core/DeviceType.h new file mode 100644 index 00000000000000..713da22d706c7c --- /dev/null +++ b/paddle/phi/api/include/compat/c10/core/DeviceType.h @@ -0,0 +1,43 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
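// Usage sketch (editorial, not part of the patch): DefaultDtype.h keeps the
// default ScalarType as a header-scope static (so it is per translation unit),
// and c10::Device is a thin wrapper over phi::Place.
#include <c10/core/DefaultDtype.h>
#include <c10/core/Device.h>

#include <cassert>

#include "paddle/phi/common/place.h"

inline void defaults_and_device_example() {
  c10::set_default_dtype(c10::ScalarType::Double);
  assert(c10::get_default_dtype() == c10::ScalarType::Double);
  c10::set_default_dtype(c10::ScalarType::Float);  // restore

  c10::Device dev(phi::GPUPlace(0));       // construct from an existing place
  assert(dev.index() == 0);
  phi::Place place = dev._PD_GetInner();   // what the factory functions consume
  (void)place;
}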
+ +#pragma once + +#include + +#include "paddle/phi/common/place.h" + +namespace c10 { + +using DeviceType = phi::AllocationType; + +constexpr DeviceType kCUDA = DeviceType::GPU; +constexpr DeviceType kCPU = DeviceType::CPU; +constexpr DeviceType kCUSTOM = DeviceType::CUSTOM; + +} // namespace c10 + +namespace at { +using c10::DeviceType; +using c10::kCPU; +using c10::kCUDA; +using c10::kCUSTOM; +} // namespace at + +namespace torch { +using c10::DeviceType; +using c10::kCPU; +using c10::kCUDA; +using c10::kCUSTOM; +} // namespace torch diff --git a/paddle/phi/api/include/compat/c10/core/Layout.h b/paddle/phi/api/include/compat/c10/core/Layout.h new file mode 100644 index 00000000000000..4916dd768be1a5 --- /dev/null +++ b/paddle/phi/api/include/compat/c10/core/Layout.h @@ -0,0 +1,96 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// #The file has been adapted from pytorch project +// #Licensed under BSD-style license - +// https://github.com/pytorch/pytorch/blob/main/LICENSE + +#pragma once + +#include + +#include +#include + +namespace c10 { +enum class Layout : int8_t { + Strided, + Sparse, + SparseCsr, + Mkldnn, + SparseCsc, + SparseBsr, + SparseBsc, + Jagged, + NumOptions +}; + +constexpr auto kStrided = Layout::Strided; +constexpr auto kSparse = Layout::Sparse; +constexpr auto kSparseCsr = Layout::SparseCsr; +constexpr auto kMkldnn = Layout::Mkldnn; +constexpr auto kSparseCsc = Layout::SparseCsc; +constexpr auto kSparseBsr = Layout::SparseBsr; +constexpr auto kSparseBsc = Layout::SparseBsc; +constexpr auto kJagged = Layout::Jagged; + +inline std::ostream& operator<<(std::ostream& stream, c10::Layout layout) { + switch (layout) { + case c10::kStrided: + return stream << "Strided"; + case c10::kSparse: + return stream << "Sparse"; + case c10::kSparseCsr: + return stream << "SparseCsr"; + case c10::kSparseCsc: + return stream << "SparseCsc"; + case c10::kSparseBsr: + return stream << "SparseBsr"; + case c10::kSparseBsc: + return stream << "SparseBsc"; + case c10::kMkldnn: + return stream << "Mkldnn"; + case c10::kJagged: + return stream << "Jagged"; + default: + TORCH_CHECK(false, "Unknown layout"); + } +} + +} // namespace c10 + +namespace at { +using c10::kJagged; +using c10::kMkldnn; +using c10::kSparse; +using c10::kSparseBsc; +using c10::kSparseBsr; +using c10::kSparseCsc; +using c10::kSparseCsr; +using c10::kStrided; + +using c10::Layout; +} // namespace at +namespace torch { +using c10::kJagged; +using c10::kMkldnn; +using c10::kSparse; +using c10::kSparseBsc; +using c10::kSparseBsr; +using c10::kSparseCsc; +using c10::kSparseCsr; +using c10::kStrided; + +using c10::Layout; +} // namespace torch diff --git a/paddle/phi/api/include/compat/c10/core/MemoryFormat.h b/paddle/phi/api/include/compat/c10/core/MemoryFormat.h new file mode 100644 index 00000000000000..d3fcfc3063a497 --- /dev/null +++ b/paddle/phi/api/include/compat/c10/core/MemoryFormat.h @@ -0,0 +1,40 @@ +// Copyright (c) 2025 PaddlePaddle 
Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// #The file has been adapted from pytorch project +// #Licensed under BSD-style license - +// https://github.com/pytorch/pytorch/blob/main/LICENSE + +#pragma once + +#include + +namespace c10 { +enum class PADDLE_API MemoryFormat : int8_t { + Contiguous, + Preserve, + ChannelsLast, + ChannelsLast3d, + NumOptions +}; + +} + +namespace at { +using c10::MemoryFormat; +} // namespace at + +namespace torch { +using c10::MemoryFormat; +} // namespace torch diff --git a/paddle/phi/api/include/compat/c10/core/Scalar.h b/paddle/phi/api/include/compat/c10/core/Scalar.h new file mode 100644 index 00000000000000..d1f287f6341654 --- /dev/null +++ b/paddle/phi/api/include/compat/c10/core/Scalar.h @@ -0,0 +1,28 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/common/scalar.h" + +namespace c10 { +using Scalar = paddle::experimental::Scalar; +} +namespace at { +using c10::Scalar; +} // namespace at + +namespace torch { +using c10::Scalar; +} // namespace torch diff --git a/paddle/phi/api/include/compat/c10/core/ScalarType.h b/paddle/phi/api/include/compat/c10/core/ScalarType.h new file mode 100644 index 00000000000000..6c8867eb530511 --- /dev/null +++ b/paddle/phi/api/include/compat/c10/core/ScalarType.h @@ -0,0 +1,304 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// #The file has been adapted from pytorch project +// #Licensed under BSD-style license - +// https://github.com/pytorch/pytorch/blob/main/LICENSE + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include "paddle/common/macros.h" + +namespace c10 { + +// dummy struct for uint1 to uint7, actual functionality +// of these dtypes will be implemented in python with Tensor subclass +template +struct dummy_uint1_7_t {}; + +// dummy struct for int1 to int7, actual functionality +// of these dtypes will be implemented in python with Tensor subclass +template +struct dummy_int1_7_t {}; + +#define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(_) \ + _(uint8_t, UINT8, Byte) /* 0 */ \ + _(int8_t, INT8, Char) /* 1 */ \ + _(int16_t, INT16, Short) /* 2 */ \ + _(int, INT32, Int) /* 3 */ \ + _(int64_t, INT64, Long) /* 4 */ \ + _(at::Half, FLOAT16, Half) \ + _(float, FLOAT32, Float) /* 6 */ \ + _(double, FLOAT64, Double) /* 7 */ \ + _(c10::complex, COMPLEX64, ComplexFloat) /* 9 */ \ + _(c10::complex, COMPLEX128, ComplexDouble) /* 10 */ \ + _(bool, BOOL, Bool) /* 11 */ \ + _(at::BFloat16, BFLOAT16, BFloat16) /* 15 */ \ + _(c10::Float8_e5m2, FLOAT8_E5M2, Float8_e5m2) /* 23 */ \ + _(c10::Float8_e4m3fn, FLOAT8_E4M3FN, Float8_e4m3fn) /* 24 */ \ + _(uint16_t, UINT16, UInt16) /* 27 */ \ + _(uint32_t, UINT32, UInt32) /* 28 */ \ + _(uint64_t, UINT64, UInt64) /* 29 */ \ + _(c10::dummy_uint1_7_t<1>, UInt1, UInt1) /* 30 */ \ + _(c10::dummy_uint1_7_t<2>, UInt2, UInt2) /* 31 */ \ + _(c10::dummy_uint1_7_t<3>, UInt3, UInt3) /* 32 */ \ + _(c10::dummy_uint1_7_t<4>, UInt4, UInt4) /* 33 */ \ + _(c10::dummy_uint1_7_t<5>, UInt5, UInt5) /* 34 */ \ + _(c10::dummy_uint1_7_t<6>, UInt6, UInt6) /* 35 */ \ + _(c10::dummy_uint1_7_t<7>, UInt7, UInt7) /* 36 */ \ + _(c10::dummy_int1_7_t<1>, Int1, Int1) /* 37 */ \ + _(c10::dummy_int1_7_t<2>, Int2, Int2) /* 38 */ \ + _(c10::dummy_int1_7_t<3>, Int3, Int3) /* 39 */ \ + _(c10::dummy_int1_7_t<4>, Int4, Int4) /* 40 */ \ + _(c10::dummy_int1_7_t<5>, Int5, Int5) /* 41 */ \ + _(c10::dummy_int1_7_t<6>, Int6, Int6) /* 42 */ \ + _(c10::dummy_int1_7_t<7>, Int7, Int7) /* 43 */ + +#define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF_F8NZ(_) \ + _(uint8_t, Byte) \ + _(int8_t, Char) \ + _(int16_t, Short) \ + _(int, Int) \ + _(int64_t, Long) \ + _(at::Half, Half) \ + _(float, Float) \ + _(double, Double) \ + _(c10::complex, ComplexFloat) \ + _(c10::complex, ComplexDouble) \ + _(bool, Bool) \ + _(at::BFloat16, BFloat16) \ + _(at::Float8_e5m2, Float8_e5m2) + +#define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(_) \ + _(uint8_t, Byte) \ + _(int8_t, Char) \ + _(int16_t, Short) \ + _(int, Int) \ + _(int64_t, Long) \ + _(float, Float) \ + _(double, Double) \ + _(c10::complex, ComplexFloat) \ + _(c10::complex, ComplexDouble) \ + _(bool, Bool) \ + _(at::BFloat16, BFloat16) \ + _(c10::Float8_e5m2, Float8_e5m2) \ + _(c10::Float8_e4m3fn, Float8_e4m3fn) + +#define AT_FORALL_QINT_TYPES(_) \ + _(c10::qint8, QInt8) \ + _(c10::quint8, QUInt8) \ + _(c10::qint32, QInt32) \ + _(c10::quint4x2, QUInt4x2) \ + _(c10::quint2x4, QUInt2x4) + +#define FOREACH_PADDLE_AND_TORCH_DTYPES(_) \ + _(uint8_t, UINT8, Byte) \ + _(int8_t, INT8, Char) \ + _(int16_t, INT16, Short) \ + _(int32_t, INT32, Int) \ + _(int64_t, INT64, Long) \ + _(at::Half, FLOAT16, Half) \ + _(float, FLOAT32, Float) \ + _(double, FLOAT64, Double) \ + _(c10::complex, COMPLEX64, ComplexFloat) \ + _(c10::complex, COMPLEX128, ComplexDouble) \ + _(bool, BOOL, Bool) \ + _(at::BFloat16, BFLOAT16, BFloat16) \ + _(c10::Float8_e5m2, 
FLOAT8_E5M2, Float8_e5m2) \ + _(c10::Float8_e4m3fn, FLOAT8_E4M3FN, Float8_e4m3fn) \ + _(uint16_t, UINT16, UInt16) \ + _(uint32_t, UINT32, UInt32) + +enum class PADDLE_API ScalarType : int8_t { +#define DEFINE_ST_ENUM_VAL_(_1, _2, n) n, + AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_ST_ENUM_VAL_) +#undef DEFINE_ENUM_ST_ENUM_VAL_ +#define DEFINE_ST_ENUM_VAL_FOR_QINTS_(_1, n) n, + AT_FORALL_QINT_TYPES(DEFINE_ST_ENUM_VAL_FOR_QINTS_) +#undef DEFINE_ST_ENUM_VAL_FOR_QINTS_ + Undefined, + NumOptions +}; +namespace impl { + +// These are used to map ScalarTypes to C++ types. + +template +struct ScalarTypeToCPPType; + +#define SPECIALIZE_ScalarTypeToCPPType(cpp_type, _2, scalar_type) \ + template <> \ + struct ScalarTypeToCPPType { \ + using type = cpp_type; \ + \ + static type t; \ + }; + +AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SPECIALIZE_ScalarTypeToCPPType) + +#undef SPECIALIZE_ScalarTypeToCPPType + +template +using ScalarTypeToCPPTypeT = typename ScalarTypeToCPPType::type; + +} // namespace impl + +template +struct CppTypeToScalarType; + +#define SPECIALIZE_CppTypeToScalarType(cpp_type, _2, scalar_type) \ + template <> \ + struct CppTypeToScalarType \ + : std::integral_constant {}; + +AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SPECIALIZE_CppTypeToScalarType) + +#undef SPECIALIZE_CppTypeToScalarType + +#define DEFINE_CONSTANT(_1, _2, name) \ + constexpr ScalarType k##name = ScalarType::name; + +AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CONSTANT) +#undef DEFINE_CONSTANT + +#define AT_FORALL_SCALAR_TYPES_AND(SCALARTYPE, _) \ + _(uint8_t, Byte) \ + _(int8_t, Char) \ + _(int16_t, Short) \ + _(int, Int) \ + _(int64_t, Long) \ + _(float, Float) \ + _(double, Double) \ + _(decltype(::c10::impl::ScalarTypeToCPPType< \ + ::c10::ScalarType::SCALARTYPE>::t), \ + SCALARTYPE) + +#define AT_FORALL_SCALAR_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, _) \ + _(uint8_t, Byte) \ + _(int8_t, Char) \ + _(int16_t, Short) \ + _(int, Int) \ + _(int64_t, Long) \ + _(float, Float) \ + _(double, Double) \ + _(decltype(::c10::impl::ScalarTypeToCPPType< \ + ::c10::ScalarType::SCALARTYPE1>::t), \ + SCALARTYPE1) \ + _(decltype(::c10::impl::ScalarTypeToCPPType< \ + ::c10::ScalarType::SCALARTYPE2>::t), \ + SCALARTYPE2) + +#define AT_FORALL_SCALAR_TYPES_AND3(SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, _) \ + _(uint8_t, Byte) \ + _(int8_t, Char) \ + _(int16_t, Short) \ + _(int, Int) \ + _(int64_t, Long) \ + _(float, Float) \ + _(double, Double) \ + _(decltype(::c10::impl::ScalarTypeToCPPType< \ + ::c10::ScalarType::SCALARTYPE1>::t), \ + SCALARTYPE1) \ + _(decltype(::c10::impl::ScalarTypeToCPPType< \ + ::c10::ScalarType::SCALARTYPE2>::t), \ + SCALARTYPE2) \ + _(decltype(::c10::impl::ScalarTypeToCPPType< \ + ::c10::ScalarType::SCALARTYPE3>::t), \ + SCALARTYPE3) + +#define AT_FORALL_COMPLEX_TYPES(_) \ + _(c10::complex, ComplexFloat) \ + _(c10::complex, ComplexDouble) + +inline const char* toString(ScalarType t) { +#define DEFINE_CASE(_1, _2, name) \ + case ScalarType::name: \ + return #name; + + switch (t) { + AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CASE) + default: + return "UNKNOWN_SCALAR"; + } +#undef DEFINE_CASE +} + +inline size_t elementSize(ScalarType t) { +#define CASE_ELEMENTSIZE_CASE(ctype, _2, name) \ + case ScalarType::name: \ + return sizeof(ctype); + + switch (t) { + AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(CASE_ELEMENTSIZE_CASE) + default: + TORCH_CHECK(false, "Unknown ScalarType"); + } +#undef CASE_ELEMENTSIZE_CASE +} + +inline bool isIntegralType(ScalarType t, bool includeBool) { + 
bool isIntegral = (t == ScalarType::Byte || t == ScalarType::Char || + t == ScalarType::Int || t == ScalarType::Long || + t == ScalarType::Short || t == ScalarType::UInt16 || + t == ScalarType::UInt32 || t == ScalarType::UInt64); + + return isIntegral || (includeBool && t == ScalarType::Bool); +} + +inline bool isFloat8Type(ScalarType t) { + return t == ScalarType::Float8_e5m2 || t == ScalarType::Float8_e4m3fn; + // || t == ScalarType::Float8_e5m2fnuz + // || t == ScalarType::Float8_e4m3fnuz + // || t == ScalarType::Float8_e8m0fnu +} + +inline bool isReducedFloatingType(ScalarType t) { + return t == ScalarType::Half || t == ScalarType::BFloat16 || isFloat8Type(t); + //|| t == ScalarType::Float4_e2m1fn_x2 +} + +inline bool isFloatingType(ScalarType t) { + return t == ScalarType::Double || t == ScalarType::Float || + isReducedFloatingType(t); +} + +inline bool isComplexType(ScalarType t) { + return ( + /* t == ScalarType::ComplexHalf || */ t == ScalarType::ComplexFloat || + t == ScalarType::ComplexDouble); +} + +inline std::ostream& operator<<(std::ostream& stream, ScalarType scalar_type) { + return stream << toString(scalar_type); +} + +} // namespace c10 + +namespace at { +using c10::CppTypeToScalarType; +using c10::ScalarType; +} // namespace at +namespace torch { +using c10::CppTypeToScalarType; +using c10::ScalarType; +} // namespace torch diff --git a/paddle/phi/api/include/compat/c10/core/SymInt.h b/paddle/phi/api/include/compat/c10/core/SymInt.h new file mode 100644 index 00000000000000..d0e01b2d7469da --- /dev/null +++ b/paddle/phi/api/include/compat/c10/core/SymInt.h @@ -0,0 +1,22 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include + +namespace c10 { +using SymInt = int64_t; + +} // namespace c10 diff --git a/paddle/phi/api/include/compat/c10/core/SymIntArrayRef.h b/paddle/phi/api/include/compat/c10/core/SymIntArrayRef.h new file mode 100644 index 00000000000000..11204851ec1621 --- /dev/null +++ b/paddle/phi/api/include/compat/c10/core/SymIntArrayRef.h @@ -0,0 +1,29 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
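// Usage sketch (editorial, not part of the patch): the classification and size
// helpers above mirror their c10 counterparts, so dtype-dispatch code keeps
// working unchanged.
#include <c10/core/ScalarType.h>

#include <cassert>
#include <cstring>

inline void scalar_type_helpers_example() {
  assert(c10::isFloatingType(c10::ScalarType::BFloat16));
  assert(c10::isIntegralType(c10::ScalarType::Bool, /*includeBool=*/true));
  assert(!c10::isIntegralType(c10::ScalarType::Bool, /*includeBool=*/false));
  assert(c10::elementSize(c10::ScalarType::Float) == sizeof(float));
  assert(std::strcmp(c10::toString(c10::ScalarType::Half), "Half") == 0);
}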
+ +#pragma once + +#include +#include + +namespace c10 { +using SymIntArrayRef = IntArrayRef; // SymIntArrayRef is same as ArrayRef +} // namespace c10 + +namespace at { +using c10::SymIntArrayRef; +} // namespace at +namespace torch { +using c10::SymIntArrayRef; +} // namespace torch diff --git a/paddle/phi/api/include/compat/c10/core/Symfloat.h b/paddle/phi/api/include/compat/c10/core/Symfloat.h new file mode 100644 index 00000000000000..3fc11c6c1abd53 --- /dev/null +++ b/paddle/phi/api/include/compat/c10/core/Symfloat.h @@ -0,0 +1,26 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +namespace c10 { +using SymFloat = double; +} // namespace c10 + +namespace at { +using c10::SymFloat; +} // namespace at +namespace torch { +using c10::SymFloat; +} // namespace torch diff --git a/paddle/phi/api/include/compat/c10/core/TensorOptions.h b/paddle/phi/api/include/compat/c10/core/TensorOptions.h new file mode 100644 index 00000000000000..7bae10ac338b51 --- /dev/null +++ b/paddle/phi/api/include/compat/c10/core/TensorOptions.h @@ -0,0 +1,322 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// #The file has been adapted from pytorch project +// #Licensed under BSD-style license - +// https://github.com/pytorch/pytorch/blob/main/LICENSE + +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "paddle/common/macros.h" +#include "paddle/phi/common/place.h" + +namespace c10 { +inline Layout layout_or_default(std::optional layout) { + return layout.value_or(kStrided); +} + +inline Device device_or_default(std::optional device) { + return device.value_or(Device(kCPU)); +} +inline ScalarType dtype_or_default(std::optional dtype) { + return dtype.value_or(get_default_dtype()); +} + +inline bool pinned_memory_or_default(std::optional pinned_memory) { + return pinned_memory.value_or(false); +} + +struct PADDLE_API TensorOptions { + TensorOptions() + : requires_grad_(false), + pinned_memory_(false), + has_device_(false), + has_dtype_(false), + has_layout_(false), + has_requires_grad_(false), + has_pinned_memory_(false), + has_memory_format_(false) {} + + /* implicit */ explicit TensorOptions(Layout layout) // NOLINT + : TensorOptions() { + this->set_layout(layout); + } + + template < + typename T, + typename = std::enable_if_t, Device>>> + /* implicit */ explicit TensorOptions(T&& device) // NOLINT + : TensorOptions() { + this->set_device(std::forward(device)); + } + + /* implicit */ TensorOptions(c10::ScalarType dtype) // NOLINT + : TensorOptions() { + this->set_dtype(dtype); + } + + /* implicit */ TensorOptions(MemoryFormat memory_format) // NOLINT + : TensorOptions() { + set_memory_format(memory_format); + } + + [[nodiscard]] TensorOptions device( + std::optional device) const noexcept { + TensorOptions r = *this; + r.set_device(device); + return r; + } + + [[nodiscard]] TensorOptions device_index( + c10::DeviceIndex device_index) const noexcept { + return device(Device(kCUDA, device_index)); + } + + [[nodiscard]] TensorOptions dtype( + std::optional dtype) const noexcept { + TensorOptions r = *this; + r.set_dtype(dtype); + return r; + } + + template + TensorOptions& dtype() { + has_dtype_ = true; + return *this; + } + + [[nodiscard]] TensorOptions layout( + std::optional layout) const noexcept { + TensorOptions r = *this; + r.set_layout(layout); + return r; + } + + [[nodiscard]] TensorOptions requires_grad( + std::optional requires_grad) const noexcept { + TensorOptions r = *this; + r.set_requires_grad(requires_grad); + return r; + } + + [[nodiscard]] TensorOptions pinned_memory( + std::optional pinned_memory) const noexcept { + TensorOptions r = *this; + r.set_pinned_memory(pinned_memory); + return r; + } + + [[nodiscard]] TensorOptions memory_format( + std::optional memory_format) const noexcept { + TensorOptions r = *this; + r.set_memory_format(memory_format); + return r; + } + + Device device() const noexcept { return device_or_default(device_opt()); } + + bool has_device() const noexcept { return has_device_; } + + std::optional device_opt() const noexcept { + return has_device_ ? std::make_optional(device_) : std::nullopt; + } + + c10::DeviceIndex device_index() const noexcept { return device().index(); } + + ScalarType dtype() const noexcept { return dtype_or_default(dtype_opt()); } + + bool has_dtype() const noexcept { return has_dtype_; } + + std::optional dtype_opt() const noexcept { + return has_dtype_ ? 
std::make_optional(dtype_) : std::nullopt; + } + + Layout layout() const noexcept { return layout_or_default(layout_opt()); } + + bool has_layout() const noexcept { return has_layout_; } + + std::optional layout_opt() const noexcept { + return has_layout_ ? std::make_optional(layout_) : std::nullopt; + } + + bool requires_grad() const noexcept { + return has_requires_grad_ ? requires_grad_ : false; + } + + bool has_requires_grad() const noexcept { return has_requires_grad_; } + + std::optional requires_grad_opt() const noexcept { + return has_requires_grad_ ? std::make_optional(requires_grad_) + : std::nullopt; + } + + bool pinned_memory() const noexcept { + return pinned_memory_or_default(pinned_memory_opt()); + } + + bool has_pinned_memory() const noexcept { return has_pinned_memory_; } + + bool is_sparse() const { return layout_ == c10::Layout::Sparse; } + + bool is_sparse_csr() const { return layout_ == c10::Layout::SparseCsr; } + + bool is_sparse_compressed() const { + return layout_ == c10::Layout::SparseCsr || + layout_ == c10::Layout::SparseCsc || + layout_ == c10::Layout::SparseBsr || + layout_ == c10::Layout::SparseBsc; + } + + std::optional pinned_memory_opt() const noexcept { + return has_pinned_memory_ ? std::make_optional(pinned_memory_) + : std::nullopt; + } + + bool has_memory_format() const noexcept { return has_memory_format_; } + + std::optional memory_format_opt() const noexcept { + return has_memory_format_ ? std::make_optional(memory_format_) + : std::nullopt; + } + + TensorOptions merge_memory_format( + std::optional optional_memory_format) const noexcept { + TensorOptions merged = *this; + if (optional_memory_format.has_value()) { + merged.set_memory_format(optional_memory_format); + } + return merged; + } + + ::phi::Place _PD_GetPlace() const { return device_._PD_GetInner(); } + + private: + void set_device(std::optional device) & noexcept { + if (device) { + device_ = *device; + has_device_ = true; + } else { + has_device_ = false; + } + } + + void set_dtype(std::optional dtype) & noexcept { + if (dtype) { + dtype_ = *dtype; + has_dtype_ = true; + } else { + has_dtype_ = false; + } + } + + void set_layout(std::optional layout) & noexcept { + if (layout) { + layout_ = *layout; + has_layout_ = true; + } else { + has_layout_ = false; + } + } + + void set_requires_grad(std::optional requires_grad) & noexcept { + if (requires_grad) { + requires_grad_ = *requires_grad; + has_requires_grad_ = true; + } else { + has_requires_grad_ = false; + } + } + + void set_pinned_memory(std::optional pinned_memory) & noexcept { + if (pinned_memory) { + pinned_memory_ = *pinned_memory; + has_pinned_memory_ = true; + } else { + has_pinned_memory_ = false; + } + } + + void set_memory_format(std::optional memory_format) & noexcept { + if (memory_format) { + memory_format_ = *memory_format; + has_memory_format_ = true; + } else { + has_memory_format_ = false; + } + } + + Device device_ = c10::kCPU; + c10::ScalarType dtype_ = c10::ScalarType::Float; + Layout layout_ = at::kStrided; // 8-bit + MemoryFormat memory_format_ = MemoryFormat::Contiguous; // 8-bit + + bool requires_grad_ : 1; + bool pinned_memory_ : 1; + + bool has_device_ : 1; + bool has_dtype_ : 1; + bool has_layout_ : 1; + bool has_requires_grad_ : 1; + bool has_pinned_memory_ : 1; + bool has_memory_format_ : 1; +}; + +inline TensorOptions dtype(ScalarType dtype) { + return TensorOptions().dtype(dtype); +} + +inline TensorOptions layout(Layout layout) { + return TensorOptions().layout(layout); +} + +inline TensorOptions 
device(Device device) { + return TensorOptions().device(device); +} + +inline TensorOptions device_index(c10::DeviceIndex device_index) { + return TensorOptions().device_index(device_index); +} + +inline TensorOptions requires_grad(bool requires_grad = true) { + return TensorOptions().requires_grad(requires_grad); +} + +inline TensorOptions memory_format(MemoryFormat memory_format) { + return TensorOptions().memory_format(memory_format); +} + +std::ostream& operator<<(std::ostream& stream, const TensorOptions& options); + +inline std::string toString(const TensorOptions& options) { + std::ostringstream stream; + stream << options; + return stream.str(); +} + +} // namespace c10 + +namespace at { +using namespace c10; // NOLINT +} // namespace at + +namespace torch { +using namespace c10; // NOLINT +} // namespace torch diff --git a/paddle/phi/api/include/compat/c10/cuda/CUDAException.h b/paddle/phi/api/include/compat/c10/cuda/CUDAException.h new file mode 100644 index 00000000000000..e2cca2445d04ae --- /dev/null +++ b/paddle/phi/api/include/compat/c10/cuda/CUDAException.h @@ -0,0 +1,22 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#define C10_CUDA_CHECK(expr) \ + do { \ + } while (0); // TODO(SigureMo): impl this +#define C10_CUDA_KERNEL_LAUNCH_CHECK(expr) \ + do { \ + } while (0); // TODO(SigureMo): impl this diff --git a/paddle/phi/api/include/compat/c10/cuda/CUDAFunctions.h b/paddle/phi/api/include/compat/c10/cuda/CUDAFunctions.h new file mode 100644 index 00000000000000..82fce0a440af99 --- /dev/null +++ b/paddle/phi/api/include/compat/c10/cuda/CUDAFunctions.h @@ -0,0 +1,56 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
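The CUDAFunctions compat header below wraps Paddle's GPU runtime helpers with c10-style names. A rough usage sketch, assuming a CUDA/HIP build and a stream taken from the CUDAStream compat header introduced later in this patch:

  // Block the host until all queued work on the current device finishes.
  c10::cuda::device_synchronize();

  // Wait for a single stream only (gpuStream_t aliases cudaStream_t / hipStream_t).
  gpuStream_t stream = at::cuda::getCurrentCUDAStream().stream();
  at::cuda::stream_synchronize(stream);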
+
+#pragma once
+
+#include
+#ifdef PADDLE_WITH_CUDA
+#include
+using gpuStream_t = cudaStream_t;
+#endif
+
+#ifdef PADDLE_WITH_HIP
+#include
+using gpuStream_t = hipStream_t;
+#endif
+
+#include "paddle/phi/core/platform/device/gpu/gpu_info.h"
+#include "paddle/phi/core/platform/device_event_base.h"
+
+namespace c10::cuda {
+
+void __inline__ device_synchronize() {
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
+  int curr_device_id = paddle::platform::GetCurrentDeviceId();
+  paddle::platform::SetDeviceId(curr_device_id);
+#ifdef PADDLE_WITH_HIP
+  PADDLE_ENFORCE_GPU_SUCCESS(hipDeviceSynchronize());
+#else
+  PADDLE_ENFORCE_GPU_SUCCESS(cudaDeviceSynchronize());
+#endif
+#else
+  PADDLE_THROW(common::errors::Unavailable(
+      "Paddle is not compiled with CUDA. Cannot synchronize the device."));
+#endif
+}
+
+void __inline__ stream_synchronize(gpuStream_t stream) {
+  phi::backends::gpu::GpuStreamSync(stream);
+}
+}  // namespace c10::cuda
+
+namespace at::cuda {
+using c10::cuda::device_synchronize;
+using c10::cuda::stream_synchronize;
+}  // namespace at::cuda
diff --git a/paddle/phi/api/include/compat/c10/cuda/CUDAGuard.h b/paddle/phi/api/include/compat/c10/cuda/CUDAGuard.h
new file mode 100644
index 00000000000000..cdce54630aaa6f
--- /dev/null
+++ b/paddle/phi/api/include/compat/c10/cuda/CUDAGuard.h
@@ -0,0 +1,120 @@
+// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
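CUDAGuard below is an RAII shim over paddle::platform::CUDADeviceGuard. A sketch of the intended pattern, assuming the wrapped Paddle guard restores the previously active device on destruction:

  {
    c10::cuda::CUDAGuard guard(/*device_index=*/1);  // switch to device 1
    // ... enqueue kernels on device 1 ...
  }  // leaving the scope restores the prior device

  // OptionalCUDAGuard switches only when a device is actually supplied.
  std::optional<c10::DeviceIndex> maybe_idx;  // nullopt: no switch happens
  c10::cuda::OptionalCUDAGuard maybe_guard(maybe_idx);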
+ +// #The file has been adapted from pytorch project +// #Licensed under BSD-style license - +// https://github.com/pytorch/pytorch/blob/main/LICENSE + +#pragma once + +#include +#include "paddle/phi/core/platform/cuda_device_guard.h" + +namespace c10::cuda { +struct CUDAGuard { + explicit CUDAGuard() = delete; // NOLINT + + explicit CUDAGuard(DeviceIndex device_index) : guard_(device_index) {} + + explicit CUDAGuard(Device device) : guard_(device._PD_GetInner()) {} + + CUDAGuard(const CUDAGuard&) = delete; + CUDAGuard& operator=(const CUDAGuard&) = delete; + + CUDAGuard(CUDAGuard&& other) = delete; + CUDAGuard& operator=(CUDAGuard&& other) = delete; + ~CUDAGuard() = default; + + void set_device(Device device) { guard_.SetDevice(device._PD_GetInner()); } + + void reset_device(Device device) { set_device(device); } + + void set_index(DeviceIndex device_index) { + guard_.SetDeviceIndex(device_index); + } + + Device current_device() const { + return c10::Device(c10::kCUDA, phi::backends::gpu::GetCurrentDeviceId()); + } + + private: + paddle::platform::CUDADeviceGuard guard_; +}; + +struct OptionalCUDAGuard { + OptionalCUDAGuard() = default; + + explicit OptionalCUDAGuard(std::optional device_opt) : guard_() { + if (device_opt.has_value()) { + guard_.emplace(device_opt.value()._PD_GetInner()); + } + } + + explicit OptionalCUDAGuard(std::optional device_index_opt) + : guard_() { + if (device_index_opt.has_value()) { + guard_.emplace(device_index_opt.value()); + } + } + + // Copy is not allowed + OptionalCUDAGuard(const OptionalCUDAGuard&) = delete; + OptionalCUDAGuard& operator=(const OptionalCUDAGuard&) = delete; + + OptionalCUDAGuard(OptionalCUDAGuard&& other) = delete; + + OptionalCUDAGuard& operator=(OptionalCUDAGuard&& other) = delete; + ~OptionalCUDAGuard() = default; + + void set_device(Device device) { + if (!guard_.has_value()) { + guard_.emplace(device._PD_GetInner()); + } else { + guard_->SetDevice(device._PD_GetInner()); + } + } + + void reset_device(Device device) { + if (!guard_.has_value()) { + guard_.emplace(device._PD_GetInner()); + } else { + guard_->SetDevice(device._PD_GetInner()); + } + } + + void set_index(DeviceIndex device_index) { + if (!guard_.has_value()) { + guard_.emplace(device_index); + } else { + guard_->SetDeviceIndex(device_index); + } + } + + std::optional current_device() const { + return guard_.has_value() + ? std::make_optional(c10::Device( + c10::kCUDA, phi::backends::gpu::GetCurrentDeviceId())) + : std::nullopt; + } + + private: + std::optional guard_; +}; + +} // namespace c10::cuda + +namespace at::cuda { +using c10::cuda::CUDAGuard; +using c10::cuda::OptionalCUDAGuard; +} // namespace at::cuda diff --git a/paddle/phi/api/include/compat/c10/cuda/CUDAStream.h b/paddle/phi/api/include/compat/c10/cuda/CUDAStream.h new file mode 100644 index 00000000000000..84ae56fac4f9c4 --- /dev/null +++ b/paddle/phi/api/include/compat/c10/cuda/CUDAStream.h @@ -0,0 +1,56 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include
+#include "paddle/phi/api/include/context_pool.h"
+#include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/backends/gpu/gpu_info.h"
+#include "paddle/phi/core/cuda_stream.h"
+
+namespace at::cuda {
+
+using StreamId = int64_t;
+
+class CUDAStream {
+ public:
+  CUDAStream() = delete;
+  explicit CUDAStream(const gpuStream_t& stream) : raw_stream_(stream) {}
+  StreamId id() const { return reinterpret_cast<StreamId>(raw_stream_); }
+
+  operator gpuStream_t() const { return raw_stream_; }
+
+  // operator Stream() const { return unwrap(); }
+
+  DeviceType device_type() const { return DeviceType::CUDA; }
+
+  const gpuStream_t& stream() const { return raw_stream_; }
+
+ private:
+  gpuStream_t raw_stream_;
+};
+
+inline CUDAStream getCurrentCUDAStream(c10::DeviceIndex device_index = -1) {
+  if (device_index == -1) {
+    device_index = phi::backends::gpu::GetCurrentDeviceId();
+  }
+
+  return CUDAStream(
+      paddle::GetCurrentCUDAStream(phi::GPUPlace(device_index))->raw_stream());
+}
+
+#define getDefaultCUDAStream getCurrentCUDAStream
+
+}  // namespace at::cuda
diff --git a/paddle/phi/api/include/compat/c10/cuda/PhiloxCudaState.h b/paddle/phi/api/include/compat/c10/cuda/PhiloxCudaState.h
new file mode 100644
index 00000000000000..c920708e536353
--- /dev/null
+++ b/paddle/phi/api/include/compat/c10/cuda/PhiloxCudaState.h
@@ -0,0 +1,60 @@
+// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
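PhiloxCudaState below carries the Philox seed/offset pair that CUDA RNG kernels consume. A small sketch, assuming a GPU build so the helper can reach the current device's generator (IncrementOffset is assumed to advance the generator offset by inc):

  // Fixed seed/offset, e.g. for a deterministic test path.
  at::PhiloxCudaState fixed_state(/*seed=*/42, /*offset=*/0);

  // Or draw seed/offset from Paddle's per-device generator.
  at::PhiloxCudaState state = at::_PD_Internal_GetDefaultPhiloxCudaState(/*inc=*/4);
  uint64_t seed = state.seed_.val;      // .val is valid when capture is not underway
  uint64_t offset = state.offset_.val;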
+ +#pragma once + +#include "paddle/phi/backends/context_pool.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/common/place.h" + +namespace at { + +struct PhiloxCudaState { + PhiloxCudaState() = default; + // Called if graph capture is not underway + PhiloxCudaState(uint64_t seed, uint64_t offset) { + seed_.val = seed; + offset_.val = offset; + } + // Called if graph capture is underway + PhiloxCudaState(int64_t* seed, + int64_t* offset_extragraph, + uint32_t offset_intragraph) { + seed_.ptr = seed; + offset_.ptr = offset_extragraph; + offset_intragraph_ = offset_intragraph; + captured_ = true; + } + + union Payload { + uint64_t val; + int64_t* ptr; + }; + + Payload seed_{}; + Payload offset_{}; + uint32_t offset_intragraph_ = 0; + bool captured_ = false; +}; + +inline PhiloxCudaState _PD_Internal_GetDefaultPhiloxCudaState(int64_t inc) { + auto dev_ctx = phi::DeviceContextPool::Instance().Get(phi::GPUPlace()); + auto cuda_ctx = static_cast(dev_ctx); + // auto gen = phi::GetRandomSeedGenerator(""); + auto* gen = cuda_ctx->GetGenerator(); + auto seed_offset_pair = gen->IncrementOffset(inc); + return PhiloxCudaState(seed_offset_pair.first, seed_offset_pair.second); +} + +} // namespace at diff --git a/paddle/phi/api/include/compat/c10/macros/Macros.h b/paddle/phi/api/include/compat/c10/macros/Macros.h new file mode 100644 index 00000000000000..7f40a0b1cf18c8 --- /dev/null +++ b/paddle/phi/api/include/compat/c10/macros/Macros.h @@ -0,0 +1,35 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// #The file has been adapted from pytorch project +// #Licensed under BSD-style license - +// https://github.com/pytorch/pytorch/blob/main/LICENSE + +#pragma once + +#define C10_CONCATENATE_IMPL(s1, s2) s1##s2 +#define C10_CONCATENATE(s1, s2) C10_CONCATENATE_IMPL(s1, s2) + +#define C10_MACRO_EXPAND(args) args + +#define C10_STRINGIZE_IMPL(x) #x +#define C10_STRINGIZE(x) C10_STRINGIZE_IMPL(x) + +#ifdef __COUNTER__ +#define C10_UID __COUNTER__ +#define C10_ANONYMOUS_VARIABLE(str) C10_CONCATENATE(str, __COUNTER__) +#else +#define C10_UID __LINE__ +#define C10_ANONYMOUS_VARIABLE(str) C10_CONCATENATE(str, __LINE__) +#endif diff --git a/paddle/phi/api/include/compat/c10/util/ArrayRef.h b/paddle/phi/api/include/compat/c10/util/ArrayRef.h new file mode 100644 index 00000000000000..9cf38a4dbb1dc9 --- /dev/null +++ b/paddle/phi/api/include/compat/c10/util/ArrayRef.h @@ -0,0 +1,200 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +// #The file has been adapted from pytorch project +// #Licensed under BSD-style license - +// https://github.com/pytorch/pytorch/blob/main/LICENSE + +#pragma once + +#include +#include +#include +#include +#include + +#include "paddle/phi/common/int_array.h" + +namespace c10 { + +#define TORCH_CHECK_CONSTEXPR(COND, MSG) \ + ((COND) ? void(0) : throw std::runtime_error(MSG)) + +template +class ArrayRef { + private: + /// The start of the array, in an external buffer. + const T* Data; + + /// The number of elements. + size_t Length; + + public: + using iterator = const T*; + using const_iterator = const T*; + using size_type = size_t; + using value_type = T; + + using reverse_iterator = std::reverse_iterator; + + /* implicit */ constexpr ArrayRef() : Data(nullptr), Length(0) {} + + constexpr ArrayRef(const T& OneElt) : Data(&OneElt), Length(1) {} // NOLINT + + constexpr ArrayRef(const T* data, size_t length) + : Data(data), Length(length) {} + + constexpr ArrayRef(const T* begin, const T* end) + : Data(begin), Length(end - begin) {} + + template ().data()), + typename = std::enable_if_t<(std::is_same_v || + std::is_same_v)>> + /* implicit */ ArrayRef(const Container& container) // NOLINT + : Data(container.data()), Length(container.size()) {} + + template + /* implicit */ ArrayRef(const std::vector& Vec) // NOLINT + : Data(Vec.data()), Length(Vec.size()) { + static_assert(!std::is_same_v, + "ArrayRef cannot be constructed from a " + "std::vector bitfield."); + } + + template + /* implicit */ constexpr ArrayRef(const std::array& Arr) // NOLINT + : Data(Arr.data()), Length(N) {} + + template + /* implicit */ constexpr ArrayRef(const T (&Arr)[N]) // NOLINT + : Data(Arr), Length(N) {} + + /* implicit */ constexpr ArrayRef(const std::initializer_list& Vec) + : Data(std::begin(Vec) == std::end(Vec) ? static_cast(nullptr) + : std::begin(Vec)), + Length(Vec.size()) {} + + constexpr iterator begin() const { return Data; } + constexpr iterator end() const { return Data + Length; } + + constexpr const_iterator cbegin() const { return Data; } + constexpr const_iterator cend() const { return Data + Length; } + + constexpr reverse_iterator rbegin() const { return reverse_iterator(end()); } + constexpr reverse_iterator rend() const { return reverse_iterator(begin()); } + + constexpr bool allMatch(const std::function& pred) const { + return std::all_of(cbegin(), cend(), pred); + } + + constexpr bool empty() const { return Length == 0; } + + constexpr const T* data() const { return Data; } + + constexpr size_t size() const { return Length; } + + constexpr const T& front() const { + TORCH_CHECK_CONSTEXPR( + !empty(), "ArrayRef: attempted to access front() of empty list"); + return Data[0]; + } + + constexpr const T& back() const { + TORCH_CHECK_CONSTEXPR(!empty(), + "ArrayRef: attempted to access back() of empty list"); + return Data[Length - 1]; + } + + constexpr bool equals(ArrayRef RHS) const { + return Length == RHS.Length && std::equal(begin(), end(), RHS.begin()); + } + + /// slice(n, m) - Take M elements of the array starting at element N + constexpr ArrayRef slice(size_t N, size_t M) const { + TORCH_CHECK_CONSTEXPR(N + M <= size(), "ArrayRef: invalid slice"); + return ArrayRef(data() + N, M); + } + + /// slice(n) - Chop off the first N elements of the array. 
+ constexpr ArrayRef slice(size_t N) const { + TORCH_CHECK_CONSTEXPR(N <= size(), "ArrayRef: invalid slice"); + return slice(N, size() - N); + } + + constexpr const T& operator[](size_t Index) const { return Data[Index]; } + + /// Vector compatibility + constexpr const T& at(size_t Index) const { + TORCH_CHECK_CONSTEXPR(Index < Length, "ArrayRef: invalid index"); + return Data[Index]; + } + + template + std::enable_if_t, ArrayRef>& operator=( + U&& Temporary) = delete; + + template + std::enable_if_t, ArrayRef>& operator=( + std::initializer_list) = delete; + + std::vector vec() const { return std::vector(Data, Data + Length); } + + const paddle::experimental::IntArray _PD_ToPaddleIntArray() const { + return paddle::experimental::IntArray(Data, Length); + } +}; + +template +bool operator==(c10::ArrayRef a1, c10::ArrayRef a2) { + return a1.equals(a2); +} + +template +bool operator!=(c10::ArrayRef a1, c10::ArrayRef a2) { + return !a1.equals(a2); +} + +template +bool operator==(const std::vector& a1, c10::ArrayRef a2) { + return c10::ArrayRef(a1).equals(a2); +} + +template +bool operator!=(const std::vector& a1, c10::ArrayRef a2) { + return !c10::ArrayRef(a1).equals(a2); +} + +template +bool operator==(c10::ArrayRef a1, const std::vector& a2) { + return a1.equals(c10::ArrayRef(a2)); +} + +template +bool operator!=(c10::ArrayRef a1, const std::vector& a2) { + return !a1.equals(c10::ArrayRef(a2)); +} +using IntArrayRef = ArrayRef; + +} // namespace c10 + +namespace at { +using c10::ArrayRef; +using c10::IntArrayRef; +} // namespace at + +namespace torch { +using c10::ArrayRef; +using c10::IntArrayRef; +} // namespace torch diff --git a/paddle/phi/api/include/compat/c10/util/BFloat16.h b/paddle/phi/api/include/compat/c10/util/BFloat16.h new file mode 100644 index 00000000000000..77f8524e13a7d9 --- /dev/null +++ b/paddle/phi/api/include/compat/c10/util/BFloat16.h @@ -0,0 +1,29 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/common/bfloat16.h" + +namespace c10 { +using BFloat16 = ::phi::dtype::bfloat16; +} // namespace c10 + +namespace at { +using c10::BFloat16; +} // namespace at + +namespace torch { +using c10::BFloat16; +} // namespace torch diff --git a/paddle/phi/api/include/compat/c10/util/Exception.h b/paddle/phi/api/include/compat/c10/util/Exception.h new file mode 100644 index 00000000000000..fb2465a3a95c25 --- /dev/null +++ b/paddle/phi/api/include/compat/c10/util/Exception.h @@ -0,0 +1,59 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// #The file has been adapted from pytorch project +// #Licensed under BSD-style license - +// https://github.com/pytorch/pytorch/blob/main/LICENSE + +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "paddle/common/enforce.h" +#include "paddle/common/errors.h" +#include "paddle/common/exception.h" +#include "paddle/common/macros.h" + +namespace c10 { +#define TORCH_CHECK(COND, ...) PD_CHECK(COND, ##__VA_ARGS__); +#define TORCH_INTERNAL_ASSERT(COND, ...) PD_CHECK(COND, ##__VA_ARGS__); +} // namespace c10 + +enum class C10ErrorType { + NotImplementedError, + Error, +}; + +constexpr auto NotImplementedError = C10ErrorType::NotImplementedError; +constexpr auto Error = C10ErrorType::Error; + +inline void C10ThrowImpl(C10ErrorType err_type, const std::string& msg) { + switch (err_type) { + case C10ErrorType::NotImplementedError: + PADDLE_THROW(common::errors::Unimplemented(msg)); + break; + case C10ErrorType::Error: + PADDLE_THROW(common::errors::InvalidArgument(msg)); + break; + default: + PADDLE_THROW(common::errors::Fatal("Unknown error type: " + msg)); + } +} + +#define C10_THROW_ERROR(err_type, msg) C10ThrowImpl(err_type, msg) diff --git a/paddle/phi/api/include/compat/c10/util/Float8_e4m3fn.h b/paddle/phi/api/include/compat/c10/util/Float8_e4m3fn.h new file mode 100644 index 00000000000000..24a81fae9ae544 --- /dev/null +++ b/paddle/phi/api/include/compat/c10/util/Float8_e4m3fn.h @@ -0,0 +1,27 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/common/float8_e4m3fn.h" + +namespace c10 { +using Float8_e4m3fn = ::phi::dtype::float8_e4m3fn; +} // namespace c10 +namespace at { +using c10::Float8_e4m3fn; +} // namespace at +namespace torch { +using c10::Float8_e4m3fn; +} // namespace torch diff --git a/paddle/phi/api/include/compat/c10/util/Float8_e5m2.h b/paddle/phi/api/include/compat/c10/util/Float8_e5m2.h new file mode 100644 index 00000000000000..65d830a5799048 --- /dev/null +++ b/paddle/phi/api/include/compat/c10/util/Float8_e5m2.h @@ -0,0 +1,28 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/common/float8_e5m2.h" + +namespace c10 { +using Float8_e5m2 = ::phi::dtype::float8_e5m2; +} // namespace c10 + +namespace at { +using c10::Float8_e5m2; +} // namespace at +namespace torch { +using c10::Float8_e5m2; +} // namespace torch diff --git a/paddle/phi/api/include/compat/c10/util/Half.h b/paddle/phi/api/include/compat/c10/util/Half.h new file mode 100644 index 00000000000000..b45433a08f748a --- /dev/null +++ b/paddle/phi/api/include/compat/c10/util/Half.h @@ -0,0 +1,29 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/common/float16.h" + +namespace c10 { +using Half = ::phi::dtype::float16; +} // namespace c10 + +namespace at { +using c10::Half; +} // namespace at + +namespace torch { +using c10::Half; +} // namespace torch diff --git a/paddle/phi/api/include/compat/c10/util/Optional.h b/paddle/phi/api/include/compat/c10/util/Optional.h new file mode 100644 index 00000000000000..db8da3d282e9e6 --- /dev/null +++ b/paddle/phi/api/include/compat/c10/util/Optional.h @@ -0,0 +1,26 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +namespace c10 { +// Aliases from C++17 std::optional +using std::bad_optional_access; +using std::make_optional; +using std::nullopt; +using std::nullopt_t; +using std::optional; +} // namespace c10 diff --git a/paddle/phi/api/include/compat/c10/util/OptionalArrayRef.h b/paddle/phi/api/include/compat/c10/util/OptionalArrayRef.h new file mode 100644 index 00000000000000..8a25aa359e0ccd --- /dev/null +++ b/paddle/phi/api/include/compat/c10/util/OptionalArrayRef.h @@ -0,0 +1,234 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// #The file has been adapted from pytorch project +// #Licensed under BSD-style license - +// https://github.com/pytorch/pytorch/blob/main/LICENSE + +#pragma once +#include +#include +#include +#include + +namespace c10 { +template +class OptionalArrayRef final { + public: + // Constructors + + constexpr OptionalArrayRef() noexcept = default; + + constexpr OptionalArrayRef(std::nullopt_t) noexcept {} + + OptionalArrayRef(const OptionalArrayRef& other) = default; + + OptionalArrayRef(OptionalArrayRef&& other) noexcept = default; + + constexpr OptionalArrayRef(const std::optional>& other) noexcept + : wrapped_opt_array_ref(other) {} + + constexpr OptionalArrayRef(std::optional>&& other) noexcept + : wrapped_opt_array_ref(std::move(other)) {} + + constexpr OptionalArrayRef(const T& value) noexcept + : wrapped_opt_array_ref(value) {} + + template < + typename U = ArrayRef, + std::enable_if_t, OptionalArrayRef> && + !std::is_same_v, std::in_place_t> && + std::is_constructible_v, U&&> && + std::is_convertible_v> && + !std::is_convertible_v, + bool> = false> + constexpr OptionalArrayRef(U&& value) noexcept( + std::is_nothrow_constructible_v, U&&>) + : wrapped_opt_array_ref(std::forward(value)) {} + + template < + typename U = ArrayRef, + std::enable_if_t, OptionalArrayRef> && + !std::is_same_v, std::in_place_t> && + std::is_constructible_v, U&&> && + !std::is_convertible_v>, + bool> = false> + constexpr explicit OptionalArrayRef(U&& value) noexcept( + std::is_nothrow_constructible_v, U&&>) + : wrapped_opt_array_ref(std::forward(value)) {} + + template + constexpr explicit OptionalArrayRef(std::in_place_t ip, + Args&&... args) noexcept + : wrapped_opt_array_ref(ip, std::forward(args)...) {} + + template + constexpr explicit OptionalArrayRef(std::in_place_t ip, + std::initializer_list il, + Args&&... args) + : wrapped_opt_array_ref(ip, il, std::forward(args)...) 
{} + + constexpr OptionalArrayRef(const std::initializer_list& Vec) + : wrapped_opt_array_ref(ArrayRef(Vec)) {} + + // Destructor + + ~OptionalArrayRef() = default; + + // Assignment + + constexpr OptionalArrayRef& operator=(std::nullopt_t) noexcept { + wrapped_opt_array_ref = std::nullopt; + return *this; + } + + OptionalArrayRef& operator=(const OptionalArrayRef& other) = default; + + OptionalArrayRef& operator=(OptionalArrayRef&& other) noexcept = default; + + constexpr OptionalArrayRef& operator=( + const std::optional>& other) noexcept { + wrapped_opt_array_ref = other; + return *this; + } + + constexpr OptionalArrayRef& operator=( + std::optional>&& other) noexcept { + wrapped_opt_array_ref = std::move(other); + return *this; + } + + template , + typename = std::enable_if_t< + !std::is_same_v, OptionalArrayRef> && + std::is_constructible_v, U&&> && + std::is_assignable_v&, U&&>>> + constexpr OptionalArrayRef& operator=(U&& value) noexcept( + std::is_nothrow_constructible_v, U&&>&& + std::is_nothrow_assignable_v&, U&&>) { + wrapped_opt_array_ref = std::forward(value); + return *this; + } + + // Observers + + constexpr ArrayRef* operator->() noexcept { + return &wrapped_opt_array_ref.value(); + } + + constexpr const ArrayRef* operator->() const noexcept { + return &wrapped_opt_array_ref.value(); + } + + constexpr ArrayRef& operator*() & noexcept { + return wrapped_opt_array_ref.value(); + } + + constexpr const ArrayRef& operator*() const& noexcept { + return wrapped_opt_array_ref.value(); + } + + constexpr ArrayRef&& operator*() && noexcept { + return std::move(wrapped_opt_array_ref.value()); + } + + constexpr const ArrayRef&& operator*() const&& noexcept { + return std::move(wrapped_opt_array_ref.value()); + } + + constexpr explicit operator bool() const noexcept { + return wrapped_opt_array_ref.has_value(); + } + + constexpr bool has_value() const noexcept { + return wrapped_opt_array_ref.has_value(); + } + + constexpr ArrayRef& value() & { return wrapped_opt_array_ref.value(); } + + constexpr const ArrayRef& value() const& { + return wrapped_opt_array_ref.value(); + } + + constexpr ArrayRef&& value() && { + return std::move(wrapped_opt_array_ref.value()); + } + + constexpr const ArrayRef&& value() const&& { + return std::move(wrapped_opt_array_ref.value()); + } + + template + constexpr std::enable_if_t>, + ArrayRef> + value_or(U&& default_value) const& { + return wrapped_opt_array_ref.value_or(std::forward(default_value)); + } + + template + constexpr std::enable_if_t>, + ArrayRef> + value_or(U&& default_value) && { + return wrapped_opt_array_ref.value_or(std::forward(default_value)); + } + + // Modifiers + + constexpr void swap(OptionalArrayRef& other) noexcept { + std::swap(wrapped_opt_array_ref, other.wrapped_opt_array_ref); + } + + constexpr void reset() noexcept { wrapped_opt_array_ref.reset(); } + + template + constexpr std::enable_if_t, Args&&...>, + ArrayRef&> + emplace(Args&&... args) noexcept( + std::is_nothrow_constructible_v, Args&&...>) { + return wrapped_opt_array_ref.emplace(std::forward(args)...); + } + + template + constexpr ArrayRef& emplace(std::initializer_list il, + Args&&... 
args) noexcept { + return wrapped_opt_array_ref.emplace(il, std::forward(args)...); + } + + private: + std::optional> wrapped_opt_array_ref; +}; + +using OptionalIntArrayRef = OptionalArrayRef; + +inline bool operator==(const OptionalIntArrayRef& a1, + const IntArrayRef& other) { + if (!a1.has_value()) { + return false; + } + return a1.value() == other; +} + +inline bool operator==(const c10::IntArrayRef& a1, + const c10::OptionalIntArrayRef& a2) { + return a2 == a1; +} + +} // namespace c10 +namespace at { +using c10::OptionalIntArrayRef; +} // namespace at + +namespace torch { +using c10::OptionalIntArrayRef; +} // namespace torch diff --git a/paddle/phi/api/include/compat/c10/util/accumulate.h b/paddle/phi/api/include/compat/c10/util/accumulate.h new file mode 100644 index 00000000000000..9e9a3bc1e78f08 --- /dev/null +++ b/paddle/phi/api/include/compat/c10/util/accumulate.h @@ -0,0 +1,106 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// #The file has been adapted from pytorch project +// #Licensed under BSD-style license - +// https://github.com/pytorch/pytorch/blob/main/LICENSE + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +namespace c10 { + +template , int> = 0> +inline int64_t sum_integers(const C& container) { + return std::accumulate( + container.begin(), container.end(), static_cast(0)); +} + +template ::value_type>, + int> = 0> +inline int64_t sum_integers(Iter begin, Iter end) { + return std::accumulate(begin, end, static_cast(0)); +} + +template , int> = 0> +inline int64_t multiply_integers(const C& container) { + return std::accumulate(container.begin(), + container.end(), + static_cast(1), + std::multiplies<>()); +} + +template ::value_type>, + int> = 0> +inline int64_t multiply_integers(Iter begin, Iter end) { + return std::accumulate( + begin, end, static_cast(1), std::multiplies<>()); +} + +template , int> = 0> +inline int64_t numelements_from_dim(const int k, const C& dims) { + if (k > static_cast(dims.size())) { + return 1; + } else { + auto cbegin = dims.cbegin(); + std::advance(cbegin, k); + return multiply_integers(cbegin, dims.cend()); + } +} + +template , int> = 0> +inline int64_t numelements_to_dim(const int k, const C& dims) { + TORCH_INTERNAL_ASSERT(0 <= k); + TORCH_INTERNAL_ASSERT((unsigned)k <= dims.size()); + + auto cend = dims.cbegin(); + std::advance(cend, k); + return multiply_integers(dims.cbegin(), cend); +} + +template , int> = 0> +inline int64_t numelements_between_dim(int k, int l, const C& dims) { + TORCH_INTERNAL_ASSERT(0 <= k); + TORCH_INTERNAL_ASSERT(0 <= l); + + if (k > l) { + std::swap(k, l); + } + + TORCH_INTERNAL_ASSERT((unsigned)l < dims.size()); + + auto cbegin = dims.cbegin(); + auto cend = dims.cbegin(); + std::advance(cbegin, k); + std::advance(cend, l); + return multiply_integers(cbegin, cend); +} + +} // namespace c10 diff --git a/paddle/phi/api/include/compat/c10/util/complex.h 
b/paddle/phi/api/include/compat/c10/util/complex.h new file mode 100644 index 00000000000000..debef7b45f958a --- /dev/null +++ b/paddle/phi/api/include/compat/c10/util/complex.h @@ -0,0 +1,29 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/common/complex.h" + +namespace c10 { +template +using complex = ::phi::dtype::complex; +} // namespace c10 + +namespace at { +using c10::complex; +} // namespace at +namespace torch { +using c10::complex; +} // namespace torch diff --git a/paddle/phi/api/include/compat/torch/csrc/api/include/torch/all.h b/paddle/phi/api/include/compat/torch/csrc/api/include/torch/all.h new file mode 100644 index 00000000000000..81092387002b28 --- /dev/null +++ b/paddle/phi/api/include/compat/torch/csrc/api/include/torch/all.h @@ -0,0 +1,20 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include diff --git a/paddle/phi/api/include/compat/torch/csrc/api/include/torch/cuda.cpp b/paddle/phi/api/include/compat/torch/csrc/api/include/torch/cuda.cpp new file mode 100644 index 00000000000000..e13f017e35c88a --- /dev/null +++ b/paddle/phi/api/include/compat/torch/csrc/api/include/torch/cuda.cpp @@ -0,0 +1,55 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "paddle/phi/backends/gpu/gpu_info.h" +#include "paddle/phi/core/platform/device/gpu/gpu_info.h" +#include "paddle/phi/core/platform/device_event_base.h" + +namespace torch::cuda { + +c10::DeviceIndex device_count() { +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) + return phi::backends::gpu::GetGPUDeviceCount(); +#else + PADDLE_THROW(common::errors::Unavailable( + "Paddle is not compiled with CUDA. 
Cannot visit device count.")); +#endif +} + +bool is_available() { return cuda::device_count() > 0; } + +void synchronize(int64_t device_index) { + TORCH_CHECK(is_available(), "No CUDA GPUs are available"); + auto num_gpus = cuda::device_count(); + TORCH_CHECK(device_index < 0 || device_index < num_gpus, + "Device index out of range: ", + device_index); +// TODO(yongqiang) need using DeviceGuard +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) + paddle::platform::SetDeviceId(device_index); +#ifdef PADDLE_WITH_HIP + PADDLE_ENFORCE_GPU_SUCCESS(hipDeviceSynchronize()); +#else + PADDLE_ENFORCE_GPU_SUCCESS(cudaDeviceSynchronize()); +#endif +#else + PADDLE_THROW(common::errors::Unavailable( + "Paddle is not compiled with CUDA. Cannot visit device synchronize.")); +#endif +} + +} // namespace torch::cuda diff --git a/paddle/phi/api/include/compat/torch/csrc/api/include/torch/cuda.h b/paddle/phi/api/include/compat/torch/csrc/api/include/torch/cuda.h new file mode 100644 index 00000000000000..3cf18fd4f22574 --- /dev/null +++ b/paddle/phi/api/include/compat/torch/csrc/api/include/torch/cuda.h @@ -0,0 +1,34 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include + +namespace torch::cuda { + +c10::DeviceIndex device_count(); + +bool is_available(); + +void synchronize(int64_t device_index = -1); + +} // namespace torch::cuda +namespace at::cuda { +using torch::cuda::device_count; +using torch::cuda::is_available; +using torch::cuda::synchronize; +} // namespace at::cuda diff --git a/paddle/phi/api/include/compat/torch/csrc/api/include/torch/sparse.h b/paddle/phi/api/include/compat/torch/csrc/api/include/torch/sparse.h new file mode 100644 index 00000000000000..ac97da4ccaad6f --- /dev/null +++ b/paddle/phi/api/include/compat/torch/csrc/api/include/torch/sparse.h @@ -0,0 +1,17 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include diff --git a/paddle/phi/api/include/compat/torch/csrc/api/include/torch/types.h b/paddle/phi/api/include/compat/torch/csrc/api/include/torch/types.h new file mode 100644 index 00000000000000..36faaec0920e14 --- /dev/null +++ b/paddle/phi/api/include/compat/torch/csrc/api/include/torch/types.h @@ -0,0 +1,60 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// #The file has been adapted from pytorch project +// #Licensed under BSD-style license - +// https://github.com/pytorch/pytorch/blob/main/LICENSE + +#pragma once + +#include +#include +#include +#include +#include + +namespace torch { + +using namespace at; // NOLINT + +using std::nullopt; // NOLINT +using std::optional; // NOLINT + +using Dtype = at::ScalarType; + +constexpr auto kUInt8 = at::kByte; +constexpr auto kInt8 = at::kChar; +constexpr auto kInt16 = at::kShort; +constexpr auto kInt32 = at::kInt; +constexpr auto kInt64 = at::kLong; +constexpr auto kUInt16 = at::kUInt16; +constexpr auto kUInt32 = at::kUInt32; + +constexpr auto kFloat16 = at::kHalf; +constexpr auto kFloat32 = at::kFloat; +constexpr auto kFloat64 = at::kDouble; +constexpr auto kBFloat16 = at::kBFloat16; + +constexpr auto kU8 = kUInt8; +constexpr auto kU16 = kUInt16; +constexpr auto kU32 = kUInt32; +constexpr auto kI8 = kInt8; +constexpr auto kI16 = kInt16; +constexpr auto kI32 = kInt32; +constexpr auto kI64 = kInt64; +constexpr auto kF16 = kFloat16; +constexpr auto kF32 = kFloat32; +constexpr auto kF64 = kFloat64; + +} // namespace torch diff --git a/paddle/phi/api/include/compat/torch/library.h b/paddle/phi/api/include/compat/torch/library.h new file mode 100644 index 00000000000000..4d2982ac6f0764 --- /dev/null +++ b/paddle/phi/api/include/compat/torch/library.h @@ -0,0 +1,1282 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// #The file has been adapted from pytorch project +// #Licensed under BSD-style license - +// https://github.com/pytorch/pytorch/blob/main/LICENSE + +#pragma once + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace torch { +class Library; +class FunctionArgs; +class FunctionResult; + +struct arg { + explicit arg(std::string name) + : name_(std::move(name)), value_(std::nullopt) {} + + arg& operator=(const IValue& rhs) { + value_ = rhs; + return *this; + } + + static IValue none() { return IValue(); } + + std::string name_; + std::optional value_; +}; + +template +struct types { + using type = types; +}; + +template +struct init_types { + using type = init_types; +}; + +template +init_types init() { + return init_types{}; +} + +class FunctionArgs { + public: + FunctionArgs() = default; + + template + FunctionArgs(Args&&... 
args) { // NOLINT + (add_arg(std::forward(args)), ...); + } + + static FunctionArgs from_vector(const std::vector& args_vec) { + FunctionArgs args; + args.args_ = args_vec; + return args; + } + + template + void add_arg(T&& arg) { + if constexpr (std::is_same_v, const char*> || + (std::is_array_v> && + std::is_same_v>, + char>)) { + args_.emplace_back(torch::IValue(std::string(arg))); + } else if constexpr (std::is_arithmetic_v>) { + args_.emplace_back(torch::IValue(std::forward(arg))); + } else if constexpr (std::is_same_v, std::string>) { + args_.emplace_back(torch::IValue(std::forward(arg))); + } else if constexpr (std::is_same_v, torch::IValue>) { + args_.emplace_back(std::forward(arg)); + } else { + args_.emplace_back(torch::IValue(std::forward(arg))); + } + } + + template + auto get(size_t index) const -> std:: + conditional_t, std::remove_reference_t, T> { + if (index >= args_.size()) { + throw std::out_of_range("Argument index out of range"); + } + + const torch::IValue& arg = args_[index]; + + using ReturnType = std:: + conditional_t, std::remove_reference_t, T>; + + // Handle const references by creating a temporary object + if constexpr (std::is_const_v> && + std::is_reference_v) { + using NonConstType = std::remove_const_t>; + NonConstType temp_result; + if (arg.template try_convert_to(temp_result)) { + return temp_result; + } + } else if constexpr (std::is_const_v>) { + // Handle const types by using underlying non-const type for conversion + using NonConstType = std::remove_const_t; + NonConstType temp_result; + if (arg.template try_convert_to(temp_result)) { + return static_cast(temp_result); + } + } else { + ReturnType result; + if (arg.template try_convert_to(result)) { + return result; + } + } + + std::ostringstream oss; + oss << "Cannot convert argument " << index << " from " << arg.type_string() + << " to " << typeid(T).name(); + throw std::runtime_error(oss.str()); + } + + // Convert to a tuple of specified types + template + std::tuple to_tuple() const { + if (sizeof...(Types) != args_.size()) { + throw std::runtime_error("Argument count mismatch: expected " + + std::to_string(sizeof...(Types)) + ", got " + + std::to_string(args_.size())); + } + return to_tuple_impl( + std::make_index_sequence{}); + } + + size_t size() const { return args_.size(); } + + bool empty() const { return args_.empty(); } + + const IValue& operator[](size_t index) const { return args_[index]; } + IValue& operator[](size_t index) { return args_[index]; } + + const torch::IValue& get_value(size_t index) const { + if (index >= args_.size()) { + throw std::out_of_range("Argument index out of range"); + } + return args_[index]; + } + + auto begin() const { return args_.begin(); } + auto end() const { return args_.end(); } + + std::string to_string() const { + std::ostringstream oss; + oss << "FunctionArgs["; + for (size_t i = 0; i < args_.size(); ++i) { + if (i > 0) oss << ", "; + oss << args_[i]; + } + oss << "]"; + return oss.str(); + } + + private: + template + std::tuple to_tuple_impl(std::index_sequence) const { + return std::make_tuple(get(I)...); + } + std::vector args_; +}; + +class FunctionResult { + public: + FunctionResult() : value_(torch::IValue()) {} + + template + FunctionResult(T&& value) // NOLINT + : value_(torch::IValue(std::forward(value))) {} + + FunctionResult(const torch::IValue& value) : value_(value) {} // NOLINT + FunctionResult(torch::IValue&& value) : value_(std::move(value)) {} // NOLINT + + template + T get() const { + if (value_.is_none()) { + throw 
std::runtime_error("No return value (void function)"); + } + + T result; + if (value_.try_convert_to(result)) { + return result; + } + + throw std::runtime_error("Cannot convert result from " + + value_.type_string() + " to " + typeid(T).name()); + } + + bool has_value() const { return !value_.is_none(); } + + const torch::IValue& get_value() const { return value_; } + + static FunctionResult void_result() { return FunctionResult(); } + + std::string to_string() const { + return "FunctionResult(" + value_.to_repr() + ")"; + } + + private: + torch::IValue value_; +}; + +template +struct function_traits; + +// Basic function type +template +struct function_traits { + using return_type = R; + static constexpr size_t arity = sizeof...(Args); + using ArgsTuple = std::tuple; + + template + struct arg { + using type = typename std::tuple_element>::type; + }; + + // Generic function call interface + template + static IValue call_function(F&& func, const FunctionArgs& args) { + if (args.size() != sizeof...(Args)) { + throw std::runtime_error( + "Function expects " + std::to_string(sizeof...(Args)) + + " arguments, got " + std::to_string(args.size())); + } + return call_function_impl(std::forward(func), + args, + std::make_index_sequence{}); + } + + private: + template + static IValue call_function_impl(F&& func, + const FunctionArgs& args, + std::index_sequence) { + auto args_without_ref = + std::make_tuple(args.template get>(I)...); + if constexpr (std::is_void_v) { + func(std::get(args_without_ref)...); + return IValue(); + } else { + auto result = func(std::get(args_without_ref)...); + return IValue(result); + } + } +}; + +// Function pointer specialization +template +struct function_traits : public function_traits {}; + +// Reference to function type specialization +template +struct function_traits : public function_traits {}; + +// Const function type specialization +template +struct function_traits : public function_traits { +}; + +// Const function pointer specialization +template +struct function_traits + : public function_traits {}; + +// Common Reference and Pointer types +template +struct function_traits + : public function_traits> {}; + +template +struct function_traits : public function_traits {}; + +// Member function pointer specialization +template +struct function_traits + : public function_traits { + using class_type = C; + + static IValue call_method(R (C::*func)(Args...), + C* instance, + const FunctionArgs& args) { + if (args.size() != sizeof...(Args) + 1) { // +1 for this pointer + throw std::runtime_error( + "Method expects " + std::to_string(sizeof...(Args)) + + " arguments (plus this), got " + std::to_string(args.size() - 1)); + } + return call_method_impl( + func, instance, args, std::make_index_sequence{}); + } + + private: + template + static IValue call_method_impl(R (C::*func)(Args...), + C* instance, + const FunctionArgs& args, + std::index_sequence) { + // Skip args[0] which is 'this' + auto args_without_ref = std::make_tuple( + args.template get>(I + 1)...); + if constexpr (std::is_void_v) { + (instance->*func)(std::get(args_without_ref)...); + return IValue(); + } else { + auto result = (instance->*func)(std::get(args_without_ref)...); + return IValue(result); + } + } +}; + +// Const member function pointer specialization +template +struct function_traits + : public function_traits { + using class_type = C; + + static IValue call_method(R (C::*func)(Args...) 
const, + C* instance, + const FunctionArgs& args) { + if (args.size() != sizeof...(Args) + 1) { // +1 for this pointer + throw std::runtime_error( + "Method expects " + std::to_string(sizeof...(Args)) + + " arguments (plus this), got " + std::to_string(args.size() - 1)); + } + return call_method_impl( + func, instance, args, std::make_index_sequence{}); + } + + private: + template + static IValue call_method_impl(R (C::*func)(Args...) const, + C* instance, + const FunctionArgs& args, + std::index_sequence) { + if constexpr (std::is_void_v) { + (instance->*func)( + args.get(I + 1)...); // Skip args[0] which is 'this' + return IValue(); + } else { + auto result = (instance->*func)(args.get(I + 1)...); + return IValue(result); + } + } +}; + +template +IValue invoke_function(Func&& func, const FunctionArgs& args) { + using traits = + function_traits>>; + return traits::call_function(std::forward(func), args); +} + +template +IValue invoke_member_function(Func&& func, + Class* instance, + const FunctionArgs& args) { + using traits = + function_traits>>; + return traits::call_method(func, instance, args); +} + +class CppFunction { + public: + using CallableFunction = std::function; + + CppFunction() : func_(nullptr) {} + + // Constructor for lambda or function object + explicit CppFunction(std::function func) + : func_([func](const FunctionArgs& args) -> FunctionResult { + try { + auto result = func(args); + return FunctionResult(result); + } catch (const std::exception& e) { + throw std::runtime_error("Constructor failed: " + + std::string(e.what())); + } + }) {} + + // Common function pointer or member function pointer constructor + template + explicit CppFunction( + Func&& f, + typename std::enable_if_t< + std::is_function_v>> || + (std::is_pointer_v> && + std::is_function_v>>)>* = + nullptr) + : func_([f = std::forward(f)]( + const FunctionArgs& args) -> FunctionResult { + try { + auto result = invoke_function(f, args); + return FunctionResult(result); + } catch (const std::exception& e) { + throw std::runtime_error("Function call failed: " + + std::string(e.what())); + } + }) {} + + // Common member function pointer constructor + template + explicit CppFunction( + Func&& f, + typename std::enable_if_t< + !std::is_function_v>> && + !std::is_pointer_v> && + std::is_invocable_v>* = nullptr) + : func_([f = std::forward(f)]( + const FunctionArgs& args) -> FunctionResult { + try { + auto result = f(args); + return FunctionResult(result); + } catch (const std::exception& e) { + throw std::runtime_error("Lambda execution failed: " + + std::string(e.what())); + } + }) {} + + CppFunction(CppFunction&& other) noexcept : func_(std::move(other.func_)) {} + + CppFunction& operator=(CppFunction&& other) noexcept { + if (this != &other) { + func_ = std::move(other.func_); + } + return *this; + } + + CppFunction(const CppFunction&) = delete; + CppFunction& operator=(const CppFunction&) = delete; + + FunctionResult call() const { + if (!func_) { + throw std::runtime_error("CppFunction is not initialized"); + } + return func_(FunctionArgs{}); + } + + template + FunctionResult call(Args&&... 
args) const { + if (!func_) { + throw std::runtime_error("CppFunction is not initialized"); + } + return func_(FunctionArgs{std::forward(args)...}); + } + + FunctionResult call_with_args(const FunctionArgs& args) const { + if (!func_) { + throw std::runtime_error("CppFunction is not initialized"); + } + return func_(args); + } + + bool valid() const { return func_ != nullptr; } + + private: + CallableFunction func_; +}; + +struct ClassRegistration { + std::string namespace_name; + std::string class_name; + std::string qualified_name; + std::vector> constructors; + std::unordered_map> methods; + std::unordered_map> static_methods; + + ClassRegistration() = default; + ClassRegistration(const std::string& ns, const std::string& name) + : namespace_name(ns), + class_name(name), + qualified_name(ns + "::" + name) {} +}; + +// Global class registry +class ClassRegistry { + public: + static ClassRegistry& instance() { + static ClassRegistry registry; + return registry; + } + + void register_class(const std::string& namespace_name, + const std::string& class_name) { + std::string qualified_name = namespace_name + "::" + class_name; + classes_[qualified_name] = + std::make_unique(namespace_name, class_name); + // TODO(SigureMo): Use vlog for debug logging + // std::cout << "Registered class: " << qualified_name << std::endl; + } + + void register_constructor(const std::string& qualified_name, + CppFunction&& func) { + auto it = classes_.find(qualified_name); + if (it == classes_.end()) { + throw std::runtime_error("Class " + qualified_name + " not found"); + } + it->second->constructors.push_back( + std::make_shared(std::move(func))); + // std::cout << "Registered constructor for: " << qualified_name + // << " (total: " << it->second->constructors.size() << ")" + // << std::endl; + } + + void register_method(const std::string& qualified_name, + const std::string& method_name, + CppFunction&& func) { + auto it = classes_.find(qualified_name); + if (it == classes_.end()) { + throw std::runtime_error("Class " + qualified_name + " not found"); + } + it->second->methods[method_name] = + std::make_shared(std::move(func)); + // std::cout << "Registered method: " << qualified_name << "::" << + // method_name + // << std::endl; + } + + void register_static_method(const std::string& qualified_name, + const std::string& method_name, + CppFunction&& func) { + auto it = classes_.find(qualified_name); + if (it == classes_.end()) { + throw std::runtime_error("Class " + qualified_name + " not found"); + } + it->second->static_methods[method_name] = + std::make_shared(std::move(func)); + // std::cout << "Registered static method: " << qualified_name + // << "::" << method_name << std::endl; + } + + bool has_class(const std::string& qualified_name) const { + return classes_.find(qualified_name) != classes_.end(); + } + + bool has_method(const std::string& qualified_name, + const std::string& method_name) const { + auto it = classes_.find(qualified_name); + if (it == classes_.end()) return false; + return it->second->methods.find(method_name) != it->second->methods.end(); + } + + bool has_static_method(const std::string& qualified_name, + const std::string& method_name) const { + auto it = classes_.find(qualified_name); + if (it == classes_.end()) return false; + return it->second->static_methods.find(method_name) != + it->second->static_methods.end(); + } + + FunctionResult call_method_with_args(const std::string& qualified_name, + const std::string& method_name, + const FunctionArgs& args) { + auto it = 
classes_.find(qualified_name); + if (it == classes_.end()) { + throw std::runtime_error("Class " + qualified_name + " not found!"); + } + + auto& class_reg = it->second; + auto method_it = class_reg->methods.find(method_name); + if (method_it == class_reg->methods.end()) { + throw std::runtime_error("Method " + method_name + " not found in " + + qualified_name + "!"); + } + + try { + // std::cout << "Executing " << qualified_name << "::" << method_name + // << " (instance) with " << args.size() << " args" << + // std::endl; + auto result = method_it->second->call_with_args(args); + + if (result.has_value()) { + // std::cout << "Instance method executed successfully with return + // value" + // << std::endl; + } else { + // std::cout << "Instance method executed successfully (void)" + // << std::endl; + } + return result; + } catch (const std::exception& e) { + // std::cout << "Instance method execution failed: " << e.what() + // << std::endl; + throw; + } + } + + FunctionResult call_constructor_with_args(const std::string& qualified_name, + const FunctionArgs& args) const { + auto it = classes_.find(qualified_name); + if (it == classes_.end()) { + throw std::runtime_error("Class " + qualified_name + " not found!"); + } + + auto& class_reg = it->second; + if (class_reg->constructors.empty()) { + throw std::runtime_error("No constructor registered for " + + qualified_name); + } + + // std::cout << "Creating instance of " << qualified_name << " with " + // << args.size() << " args" << std::endl; + // std::cout << "Available constructors: " << class_reg->constructors.size() + // << std::endl; + + for (size_t i = 0; i < class_reg->constructors.size(); ++i) { + try { + // std::cout << "Trying constructor " << (i + 1) << "..." << std::endl; + auto result = class_reg->constructors[i]->call_with_args(args); + // std::cout << "Constructor " << (i + 1) << " executed successfully" + // << std::endl; + return result; + } catch (const std::exception& e) { + // std::cout << "Constructor " << (i + 1) << " failed: " << e.what() + // << std::endl; + } + } + + throw std::runtime_error("No suitable constructor found for " + + qualified_name); + } + + FunctionResult call_static_method_with_args(const std::string& qualified_name, + const std::string& method_name, + const FunctionArgs& args) const { + auto it = classes_.find(qualified_name); + if (it == classes_.end()) { + throw std::runtime_error("Class " + qualified_name + " not found!"); + } + + auto& class_reg = it->second; + auto method_it = class_reg->static_methods.find(method_name); + if (method_it == class_reg->static_methods.end()) { + throw std::runtime_error("Static method " + method_name + + " not found in " + qualified_name + "!"); + } + + try { + // std::cout << "Executing " << qualified_name << "::" << method_name + // << " (static) with " << args.size() << " args" << std::endl; + auto result = method_it->second->call_with_args(args); + + if (result.has_value()) { + // std::cout << "Static method executed successfully with return value" + // << std::endl; + } else { + // std::cout << "Static method executed successfully (void return)" + // << std::endl; + } + return result; + } catch (const std::exception& e) { + // std::cout << "Error executing static method: " << e.what() << + // std::endl; + throw; + } + } + + FunctionResult call_method_with_args(const std::string& qualified_name, + const std::string& method_name, + const IValue& instance, + const FunctionArgs& args) const { + auto it = classes_.find(qualified_name); + if (it == 
classes_.end()) { + throw std::runtime_error("Class " + qualified_name + " not found!"); + } + + auto& class_reg = it->second; + auto method_it = class_reg->methods.find(method_name); + if (method_it == class_reg->methods.end()) { + throw std::runtime_error("Instance method " + method_name + + " not found in " + qualified_name + "!"); + } + + try { + // std::cout << "Executing " << qualified_name << "::" << method_name + // << " (instance) with " << args.size() << " args" << + // std::endl; + + // Create a FunctionArgs object with the instance as the first argument + FunctionArgs method_args; + method_args.add_arg(instance); // Add the instance as the first arg + for (size_t i = 0; i < args.size(); ++i) { + method_args.add_arg(args.get_value(i)); + } + + auto result = method_it->second->call_with_args(method_args); + + if (result.has_value()) { + // std::cout << "Instance method executed successfully with return + // value" + // << std::endl; + } else { + // std::cout << "Instance method executed successfully (void return)" + // << std::endl; + } + return result; + } catch (const std::exception& e) { + // std::cout << "Error executing instance method: " << e.what() << + // std::endl; + throw; + } + } + + void print_all_classes() const { + std::cout << "\n=== Registered Classes ===" << std::endl; + for (const auto& [qualified_name, registration] : classes_) { + std::cout << "Class: " << qualified_name << std::endl; + + if (!registration->constructors.empty()) { + std::cout << " Constructors: " << registration->constructors.size() + << " available" << std::endl; + } + + if (!registration->methods.empty()) { + std::cout << " Methods: "; + for (const auto& [method_name, _] : registration->methods) { + std::cout << method_name << " "; + } + std::cout << std::endl; + } + + if (!registration->static_methods.empty()) { + std::cout << " Static Methods: "; + for (const auto& [method_name, _] : registration->static_methods) { + std::cout << method_name << " "; + } + std::cout << std::endl; + } + } + std::cout << "==========================" << std::endl << std::endl; + } + + private: + std::unordered_map> classes_; +}; + +// Class registration API +template +class class_ { + static_assert( + std::is_base_of_v, + "torch::class_ requires T to inherit from CustomClassHolder"); + + public: + class_(const std::string& namespaceName, const std::string& className) + : namespace_name_(namespaceName), + class_name_(className), + qualified_name_(namespaceName + "::" + className) { + ClassRegistry::instance().register_class(namespaceName, className); + } + + // Register constructor + template + class_& def(torch::init_types) { + // std::cout << "def() called with " << sizeof...(Types) + // << " template parameters" << std::endl; + + // Create a lambda for the constructor + auto constructor_func = [](const FunctionArgs& args) -> torch::IValue { + // std::cout << "Constructor lambda called with " << args.size() + // << " arguments" << std::endl; + // std::cout << "Expected parameter count: " << sizeof...(Types) + // << std::endl; + + if constexpr (sizeof...(Types) == 0) { + // Default constructor + if (args.size() != 0) { + throw std::runtime_error( + "Default constructor expects 0 arguments, got " + + std::to_string(args.size())); + } + auto instance = torch::make_intrusive(); + return torch::IValue(instance); + } else { + // Parameterized constructor + if (args.size() != sizeof...(Types)) { + throw std::runtime_error( + "Constructor argument count mismatch: expected " + + std::to_string(sizeof...(Types)) + ", 
got " + + std::to_string(args.size())); + } + // Use std::apply to unpack the arguments + auto tuple_args = args.to_tuple(); + auto instance = std::apply( + [](Types... args) { + return torch::make_intrusive( + std::forward(args)...); + }, + tuple_args); + return torch::IValue(instance); + } + }; + + ClassRegistry::instance().register_constructor( + qualified_name_, CppFunction(constructor_func)); + return *this; + } + + // Register instance method + template + class_& def(const std::string& name, Func&& f) { + // Check if Func is a member function pointer + if constexpr (std::is_member_function_pointer_v>) { + // Use function_traits to extract class type and method signature + auto method_func = [f](const FunctionArgs& args) -> torch::IValue { + if (args.size() < 1) { + throw std::runtime_error( + "Instance method requires at least 1 argument (this pointer)"); + } + + // Get the instance (first argument) + auto instance = args.get>(0); + + // Invoke the member function + return invoke_member_function(f, instance.get(), args); + }; + + ClassRegistry::instance().register_method( + qualified_name_, name, CppFunction(method_func)); + // std::cout << "Instance method " << name << " registered successfully" + // << std::endl; + } else { + // Handle generic callable (e.g., lambda, std::function) + // std::cout << "Method registration for " << name + // << " (generic callable not yet implemented)" << std::endl; + } + + return *this; + } + + // Register static method + template + class_& def_static(const std::string& name, Func&& f) { + ClassRegistry::instance().register_static_method( + qualified_name_, name, CppFunction(std::forward(f))); + return *this; + } + + private: + std::string namespace_name_; + std::string class_name_; + std::string qualified_name_; +}; + +enum class DispatchKey { + Undefined = 0, + CPU, + CUDA, +}; + +inline std::string dispatch_key_to_string(DispatchKey key) { + switch (key) { + case DispatchKey::CPU: + return "CPU"; + case DispatchKey::CUDA: + return "CUDA"; + default: + return "Undefined"; + } +} + +// Operator Registration +struct OperatorRegistration { + std::string qualified_name; // namespace::op_name + std::string schema; + std::unordered_map implementations; + + OperatorRegistration(const std::string& name, + const std::string& schema_str = "") + : qualified_name(name), schema(schema_str) {} +}; + +class OperatorRegistry { + public: + static OperatorRegistry& instance() { + static OperatorRegistry registry; + return registry; + } + + void register_schema(const std::string& qualified_name, + const std::string& schema) { + auto& op = get_or_create_operator(qualified_name); + op.schema = schema; + // std::cout << "Registered schema: " << qualified_name << " -> " << schema + // << std::endl; + } + + void register_implementation(const std::string& qualified_name, + DispatchKey key, + CppFunction&& func) { + auto& op = get_or_create_operator(qualified_name); + op.implementations[key] = std::move(func); + // std::cout << "Registered implementation: " << qualified_name << " for " + // << dispatch_key_to_string(key) << std::endl; + } + + OperatorRegistration* find_operator(const std::string& qualified_name) { + auto it = operators_.find(qualified_name); + return (it != operators_.end()) ? 
&it->second : nullptr; + } + + std::vector list_all_operators() const { + std::vector ops; + for (const auto& pair : operators_) { + ops.push_back(pair.first); + } + return ops; + } + + bool execute_operator(const std::string& qualified_name, + DispatchKey key = DispatchKey::CPU) { + auto* op = find_operator(qualified_name); + if (!op) { + // std::cout << "Error: Operator " << qualified_name << " not found!" + // << std::endl; + return false; + } + + auto impl_it = op->implementations.find(key); + if (impl_it != op->implementations.end()) { + try { + // std::cout << "Executing " << qualified_name << " with " + // << dispatch_key_to_string(key) << std::endl; + auto result = impl_it->second.call(); + if (result.has_value()) { + // std::cout << "Operator executed successfully with return value" + // << std::endl; + } else { + // std::cout << "Operator executed successfully (void return)" + // << std::endl; + } + return true; + } catch (const std::exception& e) { + // std::cout << "Error executing operator: " << e.what() << std::endl; + return false; + } + } + + // try fallback to CPU + if (key != DispatchKey::CPU) { + auto cpu_it = op->implementations.find(DispatchKey::CPU); + if (cpu_it != op->implementations.end()) { + // std::cout << "Fallback to CPU for " << qualified_name << std::endl; + try { + auto result = cpu_it->second.call(); + if (result.has_value()) { + // std::cout << "Operator executed successfully with return value " + // "(CPU fallback)" + // << std::endl; + } else { + // std::cout + // << "Operator executed successfully (void return, CPU + // fallback)" + // << std::endl; + } + return true; + } catch (const std::exception& e) { + // std::cout << "Error executing operator (CPU fallback): " << + // e.what() + // << std::endl; + return false; + } + } + } + + // std::cout << "Error: No implementation found for " << qualified_name + // << " with " << dispatch_key_to_string(key) << std::endl; + return false; + } + + template + FunctionResult execute_operator_with_args(const std::string& qualified_name, + DispatchKey key, + Args&&... 
args) { + auto* op = find_operator(qualified_name); + if (!op) { + throw std::runtime_error("Operator " + qualified_name + " not found!"); + } + + auto impl_it = op->implementations.find(key); + if (impl_it != op->implementations.end()) { + try { + // std::cout << "Executing " << qualified_name << " with " + // << dispatch_key_to_string(key) << std::endl; + auto result = impl_it->second.call(std::forward(args)...); + if (result.has_value()) { + // std::cout << "Operator executed successfully with return value" + // << std::endl; + } else { + // std::cout << "Operator executed successfully (void return)" + // << std::endl; + } + return result; + } catch (const std::exception& e) { + throw std::runtime_error("Error executing operator: " + + std::string(e.what())); + } + } + + // try fallback to CPU + if (key != DispatchKey::CPU) { + auto cpu_it = op->implementations.find(DispatchKey::CPU); + if (cpu_it != op->implementations.end()) { + // std::cout << "Fallback to CPU for " << qualified_name << std::endl; + try { + auto result = cpu_it->second.call(std::forward(args)...); + if (result.has_value()) { + // std::cout << "Operator executed successfully with return value " + // "(CPU fallback)" + // << std::endl; + } else { + // std::cout + // << "Operator executed successfully (void return, CPU + // fallback)" + // << std::endl; + } + return result; + } catch (const std::exception& e) { + throw std::runtime_error("Error executing operator (CPU fallback): " + + std::string(e.what())); + } + } + } + + throw std::runtime_error("No implementation found for " + qualified_name + + " with " + dispatch_key_to_string(key)); + } + + const std::unordered_map& get_operators() + const { + return operators_; + } + + void print_all_operators() const { + std::cout << "\n=== Registered Operators ===" << std::endl; + for (const auto& [name, op] : operators_) { + std::cout << "Operator: " << name << std::endl; + if (!op.schema.empty()) { + std::cout << " Schema: " << op.schema << std::endl; + } + std::cout << " Implementations: "; + for (const auto& [key, impl] : op.implementations) { + std::cout << dispatch_key_to_string(key) << " "; + } + std::cout << std::endl; + } + std::cout << "=========================" << std::endl; + } + + private: + std::unordered_map operators_; + + OperatorRegistration& get_or_create_operator( + const std::string& qualified_name) { + auto it = operators_.find(qualified_name); + if (it == operators_.end()) { + auto [new_it, inserted] = operators_.emplace( + qualified_name, OperatorRegistration(qualified_name)); + return new_it->second; + } + return it->second; + } +}; + +class Library { + public: + enum Kind { + DEF, // TORCH_LIBRARY + IMPL, // TORCH_LIBRARY_IMPL + FRAGMENT // TORCH_LIBRARY_FRAGMENT + }; + + Library(Kind kind, + const std::string& ns, + std::optional dispatch_key = std::nullopt, + const char* file = nullptr, + uint32_t line = 0) + : kind_(kind), + ns_(ns), + dispatch_key_(dispatch_key), + file_(file), + line_(line) { + // std::cout << "Created Library: kind=" << kind_to_string(kind) + // << ", namespace=" << ns; + if (dispatch_key) { + // std::cout << ", dispatch_key=" << + // dispatch_key_to_string(*dispatch_key); + } + // std::cout << std::endl; + } + + Library(const std::string& ns) // NOLINT + : kind_(DEF), ns_(ns), file_(nullptr), line_(0) { + // std::cout << "Created Library: namespace=" << ns << std::endl; + } + + // Define an operator schema (for TORCH_LIBRARY and TORCH_LIBRARY_FRAGMENT) + Library& def(const std::string& schema) & { + if (kind_ == IMPL) { + // 
std::cout + // << "Warning: def() should not be called in TORCH_LIBRARY_IMPL + // block" + // << std::endl; + return *this; + } + + // Simple schema extraction: if it contains '(', extract the part before '(' + auto op_name = extract_op_name(schema); + auto qualified_name = ns_ + "::" + op_name; + + OperatorRegistry::instance().register_schema(qualified_name, schema); + return *this; + } + + // Define an operator implementation + template + Library& def(const std::string& name_or_schema, Func&& f) & { + auto op_name = extract_op_name(name_or_schema); + auto qualified_name = ns_ + "::" + op_name; + + // If name_or_schema contains '(', treat it as a schema + if (name_or_schema.find('(') != std::string::npos) { + OperatorRegistry::instance().register_schema(qualified_name, + name_or_schema); + } + + // Register implementation + auto dispatch_key = dispatch_key_.value_or(DispatchKey::CPU); + OperatorRegistry::instance().register_implementation( + qualified_name, dispatch_key, CppFunction(std::forward(f))); + + return *this; + } + + // Implementation of an operator + template + Library& impl(const std::string& op_name, Func&& f) & { + auto qualified_name = ns_ + "::" + op_name; + auto dispatch_key = dispatch_key_.value_or(DispatchKey::CPU); + + OperatorRegistry::instance().register_implementation( + qualified_name, dispatch_key, CppFunction(std::forward(f))); + + return *this; + } + + template + ::torch::class_ class_(const std::string& className) { + return ::torch::class_(ns_, className); + } + + // Print current library info + void print_info() const { + // std::cout << "Library Info: " << kind_to_string(kind_) + // << ", namespace=" << ns_; + if (dispatch_key_) { + // std::cout << ", dispatch_key=" << + // dispatch_key_to_string(*dispatch_key_); + } + // std::cout << std::endl; + } + + private: + Kind kind_; + std::string ns_; + std::optional dispatch_key_; + const char* file_; + uint32_t line_; + + std::string extract_op_name(const std::string& name_or_schema) const { + // Extract the operator name from the schema string + auto pos = name_or_schema.find('('); + if (pos != std::string::npos) { + return name_or_schema.substr(0, pos); + } + return name_or_schema; + } + + std::string kind_to_string(Kind kind) const { + switch (kind) { + case DEF: + return "DEF"; + case IMPL: + return "IMPL"; + case FRAGMENT: + return "FRAGMENT"; + default: + return "UNKNOWN"; + } + } +}; + +namespace detail { + +class TorchLibraryInit { + public: + using InitFn = void(Library&); + + TorchLibraryInit(Library::Kind kind, + InitFn* fn, + const char* ns, + std::optional dispatch_key, + const char* file, + uint32_t line) { + Library lib(kind, ns, dispatch_key, file, line); + fn(lib); + } +}; + +} // namespace detail + +// TORCH_LIBRARY +#define TORCH_LIBRARY(ns, m) \ + static void TORCH_LIBRARY_init_##ns(torch::Library&); \ + static const torch::detail::TorchLibraryInit TORCH_LIBRARY_static_init_##ns( \ + torch::Library::DEF, \ + &TORCH_LIBRARY_init_##ns, \ + #ns, \ + std::nullopt, \ + __FILE__, \ + __LINE__); \ + void TORCH_LIBRARY_init_##ns(torch::Library& m) // NOLINT + +// TORCH_LIBRARY_FRAGMENT +#define TORCH_LIBRARY_FRAGMENT(ns, m) _TORCH_LIBRARY_FRAGMENT(ns, m, C10_UID) +#define _TORCH_LIBRARY_FRAGMENT(ns, m, uid) \ + static void C10_CONCATENATE(TORCH_LIBRARY_FRAGMENT_init_##ns##_, \ + uid)(torch::Library&); \ + static const torch::detail::TorchLibraryInit C10_CONCATENATE( \ + TORCH_LIBRARY_FRAGMENT_static_init_##ns##_, uid)( \ + torch::Library::FRAGMENT, \ + 
&C10_CONCATENATE(TORCH_LIBRARY_FRAGMENT_init_##ns##_, uid), \ + #ns, \ + std::nullopt, \ + __FILE__, \ + __LINE__); \ + void C10_CONCATENATE(TORCH_LIBRARY_FRAGMENT_init_##ns##_, \ + uid)(torch::Library & m) // NOLINT + +// TORCH_LIBRARY_IMPL +#define TORCH_LIBRARY_IMPL(ns, k, m) _TORCH_LIBRARY_IMPL(ns, k, m, C10_UID) +#define _TORCH_LIBRARY_IMPL(ns, k, m, uid) \ + static void C10_CONCATENATE(TORCH_LIBRARY_IMPL_init_##ns##_##k##_, \ + uid)(torch::Library&); \ + static const torch::detail::TorchLibraryInit C10_CONCATENATE( \ + TORCH_LIBRARY_IMPL_static_init_##ns##_##k##_, uid)( \ + torch::Library::IMPL, \ + &C10_CONCATENATE(TORCH_LIBRARY_IMPL_init_##ns##_##k##_, uid), \ + #ns, \ + std::make_optional(torch::DispatchKey::k), \ + __FILE__, \ + __LINE__); \ + void C10_CONCATENATE(TORCH_LIBRARY_IMPL_init_##ns##_##k##_, \ + uid)(torch::Library & m) // NOLINT + +} // namespace torch diff --git a/paddle/phi/api/include/compat/utils/int_array_ref_conversion.h b/paddle/phi/api/include/compat/utils/int_array_ref_conversion.h new file mode 100644 index 00000000000000..83afd90fb1b615 --- /dev/null +++ b/paddle/phi/api/include/compat/utils/int_array_ref_conversion.h @@ -0,0 +1,24 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include "paddle/phi/core/ddim.h" + +namespace compat { +inline c10::IntArrayRef _PD_PhiDDimToIntArrayRef(const phi::DDim& ddim) { + return c10::IntArrayRef(ddim.Get(), ddim.size()); +} +} // namespace compat diff --git a/paddle/phi/api/include/compat/utils/macros.h b/paddle/phi/api/include/compat/utils/macros.h new file mode 100644 index 00000000000000..c88949220e142f --- /dev/null +++ b/paddle/phi/api/include/compat/utils/macros.h @@ -0,0 +1,25 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
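+// Note (sketch of intent, inferred from the macro definitions below):
+// TORCH_EXTENSION_NAME is redirected to PADDLE_EXTENSION_NAME so that
+// PyTorch-style extension sources can keep using the macro name while the
+// Paddle extension build is expected to supply the actual name via
+// PADDLE_EXTENSION_NAME.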
+ +#pragma once + +namespace compat { +#ifndef TORCH_EXTENSION_NAME +#define _EXPAND(x) x +#define TORCH_EXTENSION_NAME _EXPAND(PADDLE_EXTENSION_NAME) +#undef _EXPAND +#endif +#define UNSUPPORTED_FEATURE_IN_PADDLE(feature) \ + std::cerr << "Unsupported feature in Paddle: " << feature << std::endl; +} // namespace compat diff --git a/paddle/phi/api/include/compat/utils/scalar_type_conversion.h b/paddle/phi/api/include/compat/utils/scalar_type_conversion.h new file mode 100644 index 00000000000000..09a55b28686443 --- /dev/null +++ b/paddle/phi/api/include/compat/utils/scalar_type_conversion.h @@ -0,0 +1,52 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include +#include "paddle/phi/common/data_type.h" + +namespace compat { +inline phi::DataType _PD_AtenScalarTypeToPhiDataType(c10::ScalarType dtype) { + switch (dtype) { +#define DEFINE_ST_TO_DT_CASE_(_1, _dt, _st) \ + case c10::ScalarType::_st: \ + return phi::DataType::_dt; + FOREACH_PADDLE_AND_TORCH_DTYPES(DEFINE_ST_TO_DT_CASE_) +#undef DEFINE_ST_TO_DT_CASE_ + case c10::ScalarType::Undefined: + return phi::DataType::UNDEFINED; + default: + UNSUPPORTED_FEATURE_IN_PADDLE("Unsupported ScalarType") + return phi::DataType::UNDEFINED; // to avoid compile warning + } +} + +inline c10::ScalarType _PD_PhiDataTypeToAtenScalarType(phi::DataType dtype) { + switch (dtype) { +#define DEFINE_DT_TO_ST_CASE_(_1, _dt, _st) \ + case phi::DataType::_dt: \ + return c10::ScalarType::_st; + FOREACH_PADDLE_AND_TORCH_DTYPES(DEFINE_DT_TO_ST_CASE_) +#undef DEFINE_DT_TO_ST_CASE_ + case phi::DataType::UNDEFINED: + return c10::ScalarType::Undefined; + default: + UNSUPPORTED_FEATURE_IN_PADDLE("Unsupported DataType") + return c10::ScalarType::Undefined; // to avoid compile warning + } +} + +} // namespace compat diff --git a/paddle/phi/common/place.h b/paddle/phi/common/place.h index 7de1b33b90b4b3..117b46b5aa6c73 100644 --- a/paddle/phi/common/place.h +++ b/paddle/phi/common/place.h @@ -32,6 +32,7 @@ enum class AllocationType : int8_t { UNDEFINED = 0, CPU = 1, GPU = 2, + CUDA = GPU, GPUPINNED = 3, XPU = 4, XPUPINNED = 5, diff --git a/paddle/utils/pybind.h b/paddle/utils/pybind.h index 07ad8462f968ac..16318d84464de2 100644 --- a/paddle/utils/pybind.h +++ b/paddle/utils/pybind.h @@ -14,6 +14,9 @@ #pragma once +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) +#include +#endif #include "paddle/phi/api/include/tensor.h" #ifdef PADDLE_WITH_DISTRIBUTE #include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h" @@ -128,6 +131,40 @@ struct optional_caster> { const_name("Optional[paddle::Tensor]")); }; +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) +template <> +struct type_caster { + public: + PYBIND11_TYPE_CASTER(at::Tensor, _("at::Tensor")); + + bool load(handle src, bool) { + paddle::pybind::EnableTensorOperantsToPhiMode(); + PyObject* obj = src.ptr(); + if (paddle::pybind::PyCheckTensor(obj)) { + value = 
paddle::pybind::CastPyArg2Tensor(obj, 0); + return true; + } + return false; + } + + static handle cast(const at::Tensor& src, + return_value_policy /* policy */, + handle /* parent */) { + const auto& src_pd_tensor = src._PD_GetInner(); + +#ifdef PADDLE_WITH_DISTRIBUTE + bool return_none = + phi::distributed::DistTensor::classof(src_pd_tensor.impl().get()) + ? false + : true; +#else + bool return_none = true; +#endif + return handle(paddle::pybind::ToPyObject( + src_pd_tensor, return_none /* return_py_none_if_not_initialize */)); + } +}; +#endif // Pybind11 bindings for optional types. // http://pybind11.readthedocs.io/en/stable/advanced/cast/stl.html#c-17-library-containers template diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index e14bf8dc3c58de..6467a6880c43ce 100644 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -12,10 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. # -# Compatibility Note: The design of certain PaddlePaddle public APIs -# incorporates principles from PyTorch and NumPy, maintaining compatibility -# with PyTorch's API conventions in terms of function signatures and -# parameter semantics. It is important to clarify that these APIs are +# Compatibility Note: The design of certain PaddlePaddle public APIs +# incorporates principles from PyTorch and NumPy, maintaining compatibility +# with PyTorch's API conventions in terms of function signatures and +# parameter semantics. It is important to clarify that these APIs are # implemented as independent modules with no runtime dependency on PyTorch. import math @@ -200,6 +200,8 @@ def new_init(self, *args, **kwargs): tensor as tensor, utils as utils, ) +from ._classes import classes as classes +from ._ops import ops as ops from .amp import ( get_autocast_cpu_dtype, get_autocast_dtype, diff --git a/python/paddle/_classes.py b/python/paddle/_classes.py new file mode 100644 index 00000000000000..6d7bd5d9db13e9 --- /dev/null +++ b/python/paddle/_classes.py @@ -0,0 +1,59 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import types +from typing import Any + +import paddle + +from ._ops import import_module, load_library + +PADDLE_CLASSES_MODULE_NAME = "paddle.classes" + + +class ClassesNameSpace(types.ModuleType): + def __init__(self, name: str): + super().__init__(f"{PADDLE_CLASSES_MODULE_NAME}.{name}") + self.name = name + + def __getattr__(self, name: str) -> Any: + if name == "__file__": + return PADDLE_CLASSES_MODULE_NAME # type: ignore + return paddle.base.core.torch_compat._get_custom_class_python_wrapper( + self.name, name + ) + + +class PaddleClassesModule(types.ModuleType): + __file__ = "_classes.py" + + def __init__(self): + super().__init__(PADDLE_CLASSES_MODULE_NAME) + + def __getattr__(self, name: str): + namespace = ClassesNameSpace(name) + # Insert to __dict__ to avoid repeatedly __getattr__ overhead + setattr(self, name, namespace) + return namespace + + def import_module(self, module): + return import_module(module) + + def load_library(self, path): + return load_library(path) + + +classes = PaddleClassesModule() diff --git a/python/paddle/_ops.py b/python/paddle/_ops.py new file mode 100644 index 00000000000000..5e31689d0dd8f3 --- /dev/null +++ b/python/paddle/_ops.py @@ -0,0 +1,123 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import contextlib +import ctypes +import importlib +import os +import sys +import types +from functools import cached_property +from typing import Any, Callable, Generic, TypeVar + +from typing_extensions import ParamSpec + +import paddle + +_InputT = ParamSpec("_InputT") +_RetT = TypeVar("_RetT") + +PADDLE_OPS_MODULE_NAME = "paddle.ops" + +# Query `hasattr` only once. +_SET_GLOBAL_FLAGS = hasattr(sys, "getdlopenflags") and hasattr( + sys, "setdlopenflags" +) + + +@contextlib.contextmanager +def dl_open_guard(): + """ + Context manager to set the RTLD_GLOBAL dynamic linker flag while we open a + shared library to load custom operators. + """ + if not _SET_GLOBAL_FLAGS: + yield + return + old_flags = sys.getdlopenflags() + sys.setdlopenflags(old_flags | ctypes.RTLD_GLOBAL) + try: + yield + finally: + sys.setdlopenflags(old_flags) + + +def import_module(module: str): + return importlib.import_module(module) + + +def load_library(path: str): + """ + Load a shared library at the specified path. 
+ """ + path = os.path.realpath(path) + with dl_open_guard(): + ctypes.CDLL(path) + + +class OverloadedOpFunction(Generic[_InputT, _RetT]): + def __init__(self, namespace: str, name: str): + self.namespace = namespace + self.name = name + + @cached_property + def callable_fn(self) -> Callable[_InputT, _RetT]: + return paddle.base.core.torch_compat._get_operation( + f"{self.namespace}::{self.name}" + ) + + def __getattr__(self, name: str) -> Callable[_InputT, _RetT]: + if name == "default": + return self.callable_fn + raise AttributeError( + f"'{self.namespace}.{self.name}' has no attribute '{name}'" + ) + + def __call__(self, *args: _InputT.args, **kwargs: _InputT.kwargs) -> _RetT: + return self.callable_fn(*args, **kwargs) + + +class OpNameSpace(types.ModuleType): + def __init__(self, name): + super().__init__(f"{PADDLE_OPS_MODULE_NAME}.{name}") + self.name = name + + def __getattr__(self, name: str) -> OverloadedOpFunction[..., Any]: + if name == "__file__": + return PADDLE_OPS_MODULE_NAME # type: ignore + return OverloadedOpFunction(self.name, name) + + +class PaddleOpsModule(types.ModuleType): + __file__ = "_ops.py" + + def __init__(self): + super().__init__(PADDLE_OPS_MODULE_NAME) + + def __getattr__(self, name: str): + namespace = OpNameSpace(name) + # Insert to __dict__ to avoid repeatedly __getattr__ overhead + setattr(self, name, namespace) + return namespace + + def import_module(self, module): + return import_module(module) + + def load_library(self, path): + return load_library(path) + + +ops = PaddleOpsModule() diff --git a/python/paddle/compat.py b/python/paddle/compat.py index 4b981a4f45cd0b..389f1a81cea7c9 100644 --- a/python/paddle/compat.py +++ b/python/paddle/compat.py @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. # -# This file implements most of the public API compatible with PyTorch. -# Note that this file does not depend on PyTorch in any way. +# This file implements most of the public API compatible with PyTorch. +# Note that this file does not depend on PyTorch in any way. # This is a standalone implementation. 
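# Illustrative usage of the lazy namespaces defined in _ops.py/_classes.py above
# (a sketch only; "my_ns", "my_op" and "MyClass" are placeholder names, not
# operators or classes shipped with Paddle):
#
#   import paddle
#   paddle.ops.load_library("libmy_extension.so")   # dlopen with RTLD_GLOBAL set
#   out = paddle.ops.my_ns.my_op(x, y)              # resolves "my_ns::my_op" on first use
#   MyClass = paddle.classes.my_ns.MyClass          # wrapper for a registered custom class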
from .tensor.compat import ( diff --git a/python/paddle/utils/cpp_extension/extension_utils.py b/python/paddle/utils/cpp_extension/extension_utils.py index 785016143dbf6e..1a13fad34b1db3 100644 --- a/python/paddle/utils/cpp_extension/extension_utils.py +++ b/python/paddle/utils/cpp_extension/extension_utils.py @@ -829,6 +829,15 @@ def find_paddle_includes(use_cuda=False): paddle_include_dir = get_include() third_party_dir = os.path.join(paddle_include_dir, 'third_party') include_dirs = [paddle_include_dir, third_party_dir] + if not IS_WINDOWS: + compat_dir_root = os.path.join( + paddle_include_dir, 'paddle/phi/api/include/compat' + ) + compat_dir_api_include = os.path.join( + paddle_include_dir, + 'paddle/phi/api/include/compat/torch/csrc/api/include', + ) + include_dirs.extend([compat_dir_root, compat_dir_api_include]) if use_cuda: if core.is_compiled_with_rocm(): diff --git a/python/setup.py.in b/python/setup.py.in index 736ed7e9301964..dcea2ff5fe0b06 100644 --- a/python/setup.py.in +++ b/python/setup.py.in @@ -1345,6 +1345,8 @@ headers = ( list(find_files('*.h', '@PADDLE_SOURCE_DIR@/paddle/phi/api/ext')) + # custom op api list(find_files('*.h', '@PADDLE_SOURCE_DIR@/paddle/phi/api/include')) + # phi api list(find_files('*.h', '@PADDLE_SOURCE_DIR@/paddle/phi/common')) + # phi common headers + # torch compatible apis + list(find_files('*.h', '@PADDLE_SOURCE_DIR@/paddle/phi/api/include/compat', recursive=True)) + # phi level api headers (low level api, for training only) list(find_files('*.h', '@PADDLE_SOURCE_DIR@/paddle/phi')) + # phi extension header list(find_files('*.h', '@PADDLE_SOURCE_DIR@/paddle/phi/include', recursive=True)) + # phi include headers diff --git a/setup.py b/setup.py index 80026993e178ab..108e4f6052216d 100644 --- a/setup.py +++ b/setup.py @@ -1871,6 +1871,14 @@ def get_headers(): + list( # common api find_files('*.h', paddle_source_dir + '/paddle/common') ) + # torch compatible apis + + list( + find_files( + '*.h', + paddle_source_dir + '/paddle/phi/api/include/compat', + recursive=True, + ) + ) # phi level api headers (low level api, for training only) + list( # phi extension header find_files('*.h', paddle_source_dir + '/paddle/phi') diff --git a/test/auto_parallel/custom_op/utils.py b/test/auto_parallel/custom_op/utils.py index e6bc403e512a74..05047c168fc29b 100644 --- a/test/auto_parallel/custom_op/utils.py +++ b/test/auto_parallel/custom_op/utils.py @@ -13,8 +13,11 @@ # limitations under the License. 
import os +from pathlib import Path from site import getsitepackages +from paddle.utils.cpp_extension.extension_utils import IS_WINDOWS + # Test for extra compile args extra_cc_args = ['-w', '-g'] extra_nvcc_args = ['-O3'] @@ -34,12 +37,19 @@ def get_paddle_includes(): paddle_includes.append(f"{env_dict.get('PYBIND_INCLUDE_DIR')}") for site_packages_path in getsitepackages(): - paddle_includes.append( - os.path.join(site_packages_path, 'paddle', 'include') - ) - paddle_includes.append( - os.path.join(site_packages_path, 'paddle', 'include', 'third_party') - ) + paddle_include_dir = Path(site_packages_path) / "paddle/include" + paddle_includes.append(str(paddle_include_dir)) + paddle_includes.append(str(paddle_include_dir / 'third_party')) + if not IS_WINDOWS: + paddle_includes.append( + str(paddle_include_dir / 'paddle/phi/api/include/compat') + ) + paddle_includes.append( + str( + paddle_include_dir + / 'paddle/phi/api/include/compat/torch/csrc/api/include' + ) + ) return paddle_includes diff --git a/test/auto_parallel/semi_auto_parallel_for_custom_relu.py b/test/auto_parallel/semi_auto_parallel_for_custom_relu.py index 56c5f593fe594f..78845789f713ee 100644 --- a/test/auto_parallel/semi_auto_parallel_for_custom_relu.py +++ b/test/auto_parallel/semi_auto_parallel_for_custom_relu.py @@ -13,6 +13,7 @@ # limitations under the License. import os +from pathlib import Path from site import getsitepackages import numpy as np @@ -28,12 +29,19 @@ # PaddlePaddle whl. So here we specific `include_dirs` to avoid errors in CI. paddle_includes = [] for site_packages_path in getsitepackages(): - paddle_includes.append( - os.path.join(site_packages_path, 'paddle', 'include') - ) - paddle_includes.append( - os.path.join(site_packages_path, 'paddle', 'include', 'third_party') - ) + paddle_include_dir = Path(site_packages_path) / "paddle/include" + paddle_includes.append(str(paddle_include_dir)) + paddle_includes.append(str(paddle_include_dir / 'third_party')) + if not IS_WINDOWS: + paddle_includes.append( + str(paddle_include_dir / 'paddle/phi/api/include/compat') + ) + paddle_includes.append( + str( + paddle_include_dir + / 'paddle/phi/api/include/compat/torch/csrc/api/include' + ) + ) # Test for extra compile args extra_cc_args = ['-w', '-g'] if not IS_WINDOWS else ['/w'] diff --git a/test/auto_parallel/semi_auto_parallel_simple_net_custom_relu.py b/test/auto_parallel/semi_auto_parallel_simple_net_custom_relu.py index df63a91e6f0bf0..54b2452bced96c 100644 --- a/test/auto_parallel/semi_auto_parallel_simple_net_custom_relu.py +++ b/test/auto_parallel/semi_auto_parallel_simple_net_custom_relu.py @@ -13,6 +13,7 @@ # limitations under the License. import os +from pathlib import Path from site import getsitepackages from semi_auto_parallel_simple_net import TestSimpleNetForSemiAutoParallel @@ -30,12 +31,19 @@ # PaddlePaddle whl. So here we specific `include_dirs` to avoid errors in CI. 
paddle_includes = [] for site_packages_path in getsitepackages(): - paddle_includes.append( - os.path.join(site_packages_path, 'paddle', 'include') - ) - paddle_includes.append( - os.path.join(site_packages_path, 'paddle', 'include', 'third_party') - ) + paddle_include_dir = Path(site_packages_path) / "paddle/include" + paddle_includes.append(str(paddle_include_dir)) + paddle_includes.append(str(paddle_include_dir / 'third_party')) + if not IS_WINDOWS: + paddle_includes.append( + str(paddle_include_dir / 'paddle/phi/api/include/compat') + ) + paddle_includes.append( + str( + paddle_include_dir + / 'paddle/phi/api/include/compat/torch/csrc/api/include' + ) + ) # Test for extra compile args extra_cc_args = ['-w', '-g'] if not IS_WINDOWS else ['/w'] diff --git a/test/cpp/CMakeLists.txt b/test/cpp/CMakeLists.txt index 267de9cda59a77..736364f9cb0415 100644 --- a/test/cpp/CMakeLists.txt +++ b/test/cpp/CMakeLists.txt @@ -12,6 +12,7 @@ add_subdirectory(inference) add_subdirectory(eager) add_subdirectory(fluid) add_subdirectory(utils) +add_subdirectory(compat) if(WITH_CINN) add_subdirectory(cinn) endif() diff --git a/test/cpp/compat/CMakeLists.txt b/test/cpp/compat/CMakeLists.txt new file mode 100644 index 00000000000000..34d8147ca30dc6 --- /dev/null +++ b/test/cpp/compat/CMakeLists.txt @@ -0,0 +1,8 @@ +if(NOT WIN32) + if(WITH_GPU) + paddle_test(compat_basic_test SRCS compat_basic_test.cc) + paddle_test(torch_library_test SRCS torch_library_test.cc) + target_link_libraries(compat_basic_test ${CUDA_LIBRARIES} + ${CUDA_CUDART_LIBRARY}) + endif() +endif() diff --git a/test/cpp/compat/compat_basic_test.cc b/test/cpp/compat/compat_basic_test.cc new file mode 100644 index 00000000000000..601ac5b540f518 --- /dev/null +++ b/test/cpp/compat/compat_basic_test.cc @@ -0,0 +1,260 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
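+// These tests exercise the ATen-style surface provided by the compat headers
+// on top of Paddle tensors: at::ones/at::empty, sizes()/strides(), data_ptr,
+// fill_/zero_/copy_, clone, from_blob, TensorOptions, and (when built with
+// CUDA) CUDA guards and streams.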
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
+#include
+#include
+#endif
+#include "ATen/ATen.h"
+#include "gtest/gtest.h"
+#include "paddle/phi/common/float16.h"
+#include "torch/all.h"
+
+TEST(TensorBaseTest, DataPtrAPIs) {
+  // Test data_ptr() and const_data_ptr() APIs
+  at::TensorBase tensor = at::ones({2, 3}, at::kFloat);
+
+  // Test void* data_ptr()
+  void* void_ptr = tensor.data_ptr();
+  ASSERT_NE(void_ptr, nullptr);
+
+  // Test typed data_ptr()
+  float* float_ptr = tensor.data_ptr<float>();
+  ASSERT_NE(float_ptr, nullptr);
+  ASSERT_EQ(float_ptr, void_ptr);
+
+  // Test const_data_ptr()
+  const float* const_float_ptr = tensor.const_data_ptr<float>();
+  ASSERT_NE(const_float_ptr, nullptr);
+  ASSERT_EQ(const_float_ptr, float_ptr);
+
+  // Test mutable_data_ptr()
+  void* mutable_ptr = tensor.mutable_data_ptr();
+  ASSERT_NE(mutable_ptr, nullptr);
+  ASSERT_EQ(mutable_ptr, void_ptr);
+}
+TEST(TensorBaseTest, DimensionAPIs) {
+  // Test dimension related APIs
+  at::TensorBase tensor = at::ones({2, 3, 4}, at::kFloat);
+
+  // Test sizes()
+  auto sizes = tensor.sizes();
+  ASSERT_EQ(sizes.size(), 3);
+  ASSERT_EQ(sizes[0], 2);
+  ASSERT_EQ(sizes[1], 3);
+  ASSERT_EQ(sizes[2], 4);
+
+  // Test size(dim)
+  ASSERT_EQ(tensor.size(0), 2);
+  ASSERT_EQ(tensor.size(1), 3);
+  ASSERT_EQ(tensor.size(2), 4);
+
+  // Test strides()
+  auto strides = tensor.strides();
+  ASSERT_EQ(strides.size(), 3);
+  ASSERT_EQ(strides[0], 12);  // 3*4
+  ASSERT_EQ(strides[1], 4);   // 4
+  ASSERT_EQ(strides[2], 1);   // contiguous
+
+  // Test stride(dim)
+  ASSERT_EQ(tensor.stride(0), 12);
+  ASSERT_EQ(tensor.stride(1), 4);
+  ASSERT_EQ(tensor.stride(2), 1);
+
+  // Test numel()
+  ASSERT_EQ(tensor.numel(), 24);  // 2*3*4
+
+  // Test dim()/ndimension()
+  ASSERT_EQ(tensor.dim(), 3);
+  ASSERT_EQ(tensor.ndimension(), 3);
+}
+TEST(TensorBaseTest, TypeDeviceAPIs) {
+  // Test type and device related APIs
+  at::TensorBase cpu_tensor = at::ones({2, 3}, at::kFloat);
+
+  // Test dtype()/scalar_type()
+  ASSERT_EQ(cpu_tensor.dtype(), at::kFloat);
+  ASSERT_EQ(cpu_tensor.scalar_type(), at::kFloat);
+
+  // Test device()
+  ASSERT_EQ(cpu_tensor.device().type(), at::DeviceType::CPU);
+
+  // Test get_device(): this layer reports index 0 for CPU tensors
+  // (PyTorch itself returns -1 for CPU)
+  ASSERT_EQ(cpu_tensor.get_device(), 0);
+
+  // Test is_cpu()/is_cuda()
+  ASSERT_TRUE(cpu_tensor.is_cpu());
+  ASSERT_FALSE(cpu_tensor.is_cuda());
+
+  // Test options()
+  auto options = cpu_tensor.options();
+  ASSERT_EQ(options.device().type(), at::DeviceType::CPU);
+}
+
+TEST(TensorBaseTest, ModifyOperationAPIs) {
+  // Test modify operation related APIs
+  at::TensorBase tensor = at::ones({2, 3}, at::kFloat);
+
+  // Test is_contiguous()
+  ASSERT_TRUE(tensor.is_contiguous());
+
+  // Test fill_()
+  tensor.fill_(2.0);
+  float* data = tensor.data_ptr<float>();
+  for (int i = 0; i < tensor.numel(); i++) {
+    ASSERT_EQ(data[i], 2.0f);
+  }
+
+  // Test zero_()
+  tensor.zero_();
+  for (int i = 0; i < tensor.numel(); i++) {
+    ASSERT_EQ(data[i], 0.0f);
+  }
+
+  // Test copy_()
+  at::TensorBase src = at::ones({2, 3}, at::kFloat);
+  tensor.copy_(src);
+  for (int i = 0; i < tensor.numel(); i++) {
+    ASSERT_EQ(data[i], 1.0f);
+  }
+
+  // Test view()
+  at::TensorBase viewed = tensor.view({6});
+  ASSERT_EQ(viewed.sizes(), std::vector<int64_t>{6});
+  ASSERT_EQ(viewed.strides(), std::vector<int64_t>{1});
+}
+
+TEST(tensor_clone_test, BasicClone) {
+  at::Tensor a = at::ones({2, 3}, at::kFloat);
+
+  at::Tensor b = a.clone();
+
+  ASSERT_EQ(a.sizes(), b.sizes());
+  ASSERT_EQ(a.dtype(), b.dtype());
ASSERT_EQ(a.device().type(), b.device().type()); +} + +TEST(compat_basic_test, BasicCase) { + at::Tensor a = + at::ones({2, 3}, at::TensorOptions().dtype(at::kFloat).device(at::kCPU)); + at::Tensor b = at::full({2, 3}, 2, at::kFloat); + double c = 10; + + TORCH_CHECK(a.sizes() == b.sizes()); + TORCH_CHECK(a.dtype() == at::kFloat); + TORCH_CHECK(b.dtype() == at::kFloat); + TORCH_INTERNAL_ASSERT(a.device().type() == at::DeviceType::CPU); + TORCH_INTERNAL_ASSERT(b.device().type() == at::DeviceType::CPU); + at::Tensor a_contig = a.contiguous(); + at::Tensor b_contig = b.contiguous(); + at::Tensor result = at::empty(a_contig.sizes(), a_contig.options()); + const float* a_ptr = a_contig.data_ptr(); + const float* b_ptr = b_contig.data_ptr(); + float* result_ptr = result.data_ptr(); + for (int64_t i = 0; i < a_contig.numel(); i++) { + result_ptr[i] = a_ptr[i] * b_ptr[i] + c; + } + // Show result + for (int64_t i = 0; i < a_contig.numel(); i++) { + std::cout << "Result[" << i << "] = " << a_ptr[i] * b_ptr[i] + c + << std::endl; + ASSERT_EQ(result_ptr[i], 12); + } +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) + + { + // for test empty_cuda: + at::Tensor bb = + at::detail::empty_cuda(12, at::kFloat, at::kCUDA, std::nullopt); + + // for test sizoof(at::Half): + std::cout << sizeof(at::Half) << std::endl; + at::Tensor num_non_exiting_ctas = at::empty( + {}, at::TensorOptions().device(a.device()).dtype(at::ScalarType::Int)); + } + { + std::vector shape = {2, 3, 4, 5}; + size_t size_ = + c10::elementSize(at::ScalarType::Float) * c10::multiply_integers(shape); + std::cout << "multiply_integers out: " << size_ << std::endl; + } + { + std::vector shape = {2, 3, 4, 5}; + size_t size_ = + c10::elementSize(at::ScalarType::Float) * c10::sum_integers(shape); + std::cout << "sum_integers out: " << size_ << std::endl; + } + { + auto stream = at::cuda::getCurrentCUDAStream(); + std::cout << "stream num: " << stream.stream() << std::endl; + at::cuda::stream_synchronize(stream); + at::Tensor bb = + at::detail::empty_cuda(12, at::kFloat, at::kCUDA, std::nullopt); + } + { + at::Tensor a = at::ones( + {2, 3}, at::TensorOptions().dtype(at::kFloat).device(at::kCUDA)); + std::cout << "a.device() is at::kCUDA: " << (a.device().type() == at::kCUDA) + << std::endl; + const c10::cuda::CUDAGuard device_guard(a.device()); + std::cout << "device_guard is at::kCUDA: " + << (device_guard.current_device().type() == at::kCUDA) + << std::endl; + const c10::cuda::OptionalCUDAGuard device_guard_opt(a.device()); + std::cout << "device_guard is at::kCUDA: " + << (device_guard_opt.current_device().value().type() == at::kCUDA) + << std::endl; + } + + { + std::cout << "num_tokens_per_rank.device() is at::kCUDA: " << std::endl; + // for test empty: + auto num_tokens_per_rank = + torch::empty({3}, + dtype(torch::kInt32).device(torch::kCUDA), + c10::MemoryFormat::Contiguous); + std::cout << "num_tokens_per_rank.device() is at::kCUDA: " + << (num_tokens_per_rank.device().type() == at::kCUDA) + << std::endl; + } + { + auto num_tokens_per_rank = torch::empty( + {3}, dtype(torch::kInt32).device(torch::kCUDA), std::nullopt); + std::cout << "num_tokens_per_rank.device() is at::kCUDA: " + << (num_tokens_per_rank.device().type() == at::kCUDA) + << std::endl; + } +#endif + { + int a = 10, b = 20, c = 30; + int* p[] = {&a, &b, &c}; // int* array[3] + int** pp = p; + + torch::Tensor t = + torch::from_blob(pp, {3}, torch::TensorOptions().dtype(torch::kInt64)); + + // Get original int** + int** restored = reinterpret_cast(t.data_ptr()); + 
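+  // from_blob does not copy: the tensor aliases the p[] array above, so the
+  // restored pointers are still &a, &b and &c and dereference to 10, 20, 30.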
std::cout << *restored[0] << ", " << *restored[1] << ", " << *restored[2] + << std::endl; + } +} diff --git a/test/cpp/compat/torch_library_test.cc b/test/cpp/compat/torch_library_test.cc new file mode 100644 index 00000000000000..945e9433d1207c --- /dev/null +++ b/test/cpp/compat/torch_library_test.cc @@ -0,0 +1,585 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "gtest/gtest.h" + +at::Tensor mymuladd_cpu(at::Tensor a, const at::Tensor& b, double c) { + TORCH_CHECK(a.sizes() == b.sizes()); + TORCH_CHECK(a.dtype() == at::kFloat); + TORCH_CHECK(b.dtype() == at::kFloat); + TORCH_INTERNAL_ASSERT(a.device().type() == at::DeviceType::CPU); + TORCH_INTERNAL_ASSERT(b.device().type() == at::DeviceType::CPU); + at::Tensor a_contig = a.contiguous(); + at::Tensor b_contig = b.contiguous(); + at::Tensor result = torch::empty(a_contig.sizes(), a_contig.options()); + const float* a_ptr = a_contig.data_ptr(); + const float* b_ptr = b_contig.data_ptr(); + float* result_ptr = result.data_ptr(); + for (int64_t i = 0; i < result.numel(); i++) { + result_ptr[i] = a_ptr[i] * b_ptr[i] + c; + } + return result; +} + +template +T generic_add(T a, T b) { + return a + b; +} + +class TestClass : public torch::CustomClassHolder { + public: + int value; + std::string name; + + TestClass() : value(0), name("default") { + std::cout << "TestClass::TestClass() - Default constructor" << std::endl; + } + + TestClass(int v) : value(v), name("single_param") { // NOLINT + std::cout << "TestClass::TestClass(int) - Single parameter constructor" + << std::endl; + } + + TestClass(int v, const std::string& n) : value(v), name(n) { + std::cout + << "TestClass::TestClass(int, string) - Double parameters constructor" + << std::endl; + } + + int getValue() const { + std::cout << "TestClass::getValue() - getter" << std::endl; + return value; + } + + const std::string& getName() const { + std::cout << "TestClass::getName() - getter" << std::endl; + return name; + } + + void setValue(int v) { + std::cout << "TestClass::setValue(int) - setter (int)" << std::endl; + value = v; + } + + void setName(const std::string& n) { + std::cout << "TestClass::setName(string) - setter (string)" << std::endl; + name = n; + } + + static int getDefaultValue() { + std::cout << "TestClass::getDefaultValue() - static method" << std::endl; + return 42; + } + + static int addValues(int a, int b) { + std::cout << "TestClass::addValues(int, int) - static method" << std::endl; + return a + b; + } +}; + +TORCH_LIBRARY(example_library, m) { + // Note that "float" in the schema corresponds to the C++ double type + // and the Python float type. 
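+  // Likewise, "int" in a schema corresponds to C++ int64_t and "Tensor" to
+  // at::Tensor, following the PyTorch schema convention this layer mirrors.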
+ m.def("mymuladd(Tensor a, Tensor b, float c) -> Tensor"); + m.class_("TestClass") + .def(torch::init<>()) + .def(torch::init()) + .def(torch::init()) + .def("getValue", &TestClass::getValue) + .def("getName", &TestClass::getName) + .def("setValue", &TestClass::setValue) + .def("setName", &TestClass::setName) + .def_static("getDefaultValue", &TestClass::getDefaultValue) + .def_static("addValues", &TestClass::addValues); +} + +TEST(test_torch_library, TestLibraryOperators) { + auto qualified_name = "example_library::mymuladd"; + auto* op = torch::OperatorRegistry::instance().find_operator(qualified_name); + ASSERT_NE(op, nullptr); + auto impl_it = op->implementations.find(torch::DispatchKey::CPU); + ASSERT_NE(impl_it, op->implementations.end()); + torch::FunctionArgs function_args; + function_args.add_arg(torch::IValue(at::ones({2, 2}, at::kFloat))); + function_args.add_arg(torch::IValue(at::ones({2, 2}, at::kFloat))); + function_args.add_arg(torch::IValue(2.0)); + auto result = impl_it->second.call_with_args(function_args); + ASSERT_TRUE(result.get_value().is_tensor()); + auto result_tensor = result.get_value().to_tensor(); +} + +TEST(test_torch_library, TestLibraryClasses) { + auto qualified_name = "example_library::TestClass"; + const auto& class_registry = torch::ClassRegistry::instance(); + bool has_class = class_registry.has_class(qualified_name); + ASSERT_TRUE(has_class); + torch::FunctionArgs constructor_args; + constructor_args.add_arg(torch::IValue(10)); + constructor_args.add_arg(torch::IValue("example")); + + // Call constructor + auto instance = class_registry.call_constructor_with_args(qualified_name, + constructor_args); + ASSERT_TRUE(instance.get_value().is_custom_class()); + + // Call getValue + auto get_value_result = class_registry.call_method_with_args( + qualified_name, "getValue", instance.get_value(), torch::FunctionArgs()); + ASSERT_TRUE(get_value_result.get_value().is_int()); + int value = get_value_result.get_value().to_int(); + ASSERT_EQ(value, 10); + + // Call setValue + torch::FunctionArgs set_value_args; + set_value_args.add_arg(torch::IValue(20)); + class_registry.call_method_with_args( + qualified_name, "setValue", instance.get_value(), set_value_args); + ASSERT_EQ(instance.get_value().to_custom_class()->value, 20); + auto get_value_after_set = class_registry.call_method_with_args( + qualified_name, "getValue", instance.get_value(), torch::FunctionArgs()); + ASSERT_EQ(get_value_after_set.get_value().to_int(), 20); + + // Call getName + auto get_name_result = class_registry.call_method_with_args( + qualified_name, "getName", instance.get_value(), torch::FunctionArgs()); + ASSERT_TRUE(get_name_result.get_value().is_string()); + std::string name = get_name_result.get_value().to_string(); + ASSERT_EQ(name, "example"); + + // Call setName + torch::FunctionArgs set_name_args; + set_name_args.add_arg(torch::IValue("new_example")); + class_registry.call_method_with_args( + qualified_name, "setName", instance.get_value(), set_name_args); + ASSERT_EQ(instance.get_value().to_custom_class()->name, + "new_example"); + auto get_name_after_set = class_registry.call_method_with_args( + qualified_name, "getName", instance.get_value(), torch::FunctionArgs()); + ASSERT_EQ(get_name_after_set.get_value().to_string(), "new_example"); + + // Call static method getDefaultValue + auto get_default_value_result = class_registry.call_static_method_with_args( + qualified_name, "getDefaultValue", torch::FunctionArgs()); + ASSERT_TRUE(get_default_value_result.get_value().is_int()); + int 
default_value = get_default_value_result.get_value().to_int(); + ASSERT_EQ(default_value, 42); + + // Call static method addValues + torch::FunctionArgs add_values_args; + add_values_args.add_arg(torch::IValue(5)); + add_values_args.add_arg(torch::IValue(7)); + auto add_values_result = class_registry.call_static_method_with_args( + qualified_name, "addValues", add_values_args); + ASSERT_TRUE(add_values_result.get_value().is_int()); + int sum = add_values_result.get_value().to_int(); + ASSERT_EQ(sum, 12); +} + +TORCH_LIBRARY_IMPL(example_library, CPU, m) { + m.impl("mymuladd", &mymuladd_cpu); +} + +TORCH_LIBRARY_FRAGMENT(example_library_fragment, m) { + m.def("int_add", &generic_add); +} + +TORCH_LIBRARY_FRAGMENT(example_library_fragment, m) { + m.def("string_concat", &generic_add); +} + +TEST(test_torch_library, TestFragmentOperators) { + auto qualified_name_int_add = "example_library_fragment::int_add"; + auto* op_int_add = + torch::OperatorRegistry::instance().find_operator(qualified_name_int_add); + ASSERT_NE(op_int_add, nullptr); + auto impl_it_int_add = + op_int_add->implementations.find(torch::DispatchKey::CPU); + ASSERT_NE(impl_it_int_add, op_int_add->implementations.end()); + torch::FunctionArgs function_args; + function_args.add_arg(torch::IValue(3)); + function_args.add_arg(torch::IValue(4)); + auto result = impl_it_int_add->second.call_with_args(function_args); + ASSERT_TRUE(result.get_value().is_int()); + int sum = result.get_value().to_int(); + ASSERT_EQ(sum, 7); + + auto qualified_name_string_concat = "example_library_fragment::string_concat"; + auto* op_string_concat = torch::OperatorRegistry::instance().find_operator( + qualified_name_string_concat); + ASSERT_NE(op_string_concat, nullptr); + auto impl_it_string_concat = + op_string_concat->implementations.find(torch::DispatchKey::CPU); + ASSERT_NE(impl_it_string_concat, op_string_concat->implementations.end()); + torch::FunctionArgs string_args; + string_args.add_arg(torch::IValue(std::string("Hello, "))); + string_args.add_arg(torch::IValue(std::string("World!"))); + auto string_result = + impl_it_string_concat->second.call_with_args(string_args); + ASSERT_TRUE(string_result.get_value().is_string()); + std::string concatenated_string = string_result.get_value().to_string(); + ASSERT_EQ(concatenated_string, "Hello, World!"); +} + +at::Tensor cast_with_scalar_type(at::Tensor input, c10::ScalarType dtype) { + return input.toType(dtype); +} + +TORCH_LIBRARY(example_library_with_scalar_type_input, m) { + m.def("cast_with_scalar_type", &cast_with_scalar_type); +} + +TEST(test_torch_library, TestScalarTypeInput) { + auto qualified_name = + "example_library_with_scalar_type_input::cast_with_scalar_type"; + auto* op = torch::OperatorRegistry::instance().find_operator(qualified_name); + ASSERT_NE(op, nullptr); + auto impl_it = op->implementations.find(torch::DispatchKey::CPU); + ASSERT_NE(impl_it, op->implementations.end()); + torch::FunctionArgs function_args; + function_args.add_arg(torch::IValue(at::ones({2, 2}, at::kFloat))); + function_args.add_arg(torch::IValue(at::kDouble)); + auto result = impl_it->second.call_with_args(function_args); + ASSERT_TRUE(result.get_value().is_tensor()); + auto result_tensor = result.get_value().to_tensor(); + ASSERT_EQ(result_tensor.dtype(), at::kDouble); +} + +int fn_with_int_const(int const x) { return x + 1; } + +TORCH_LIBRARY(example_library_with_int_const, m) { + m.def("fn_with_int_const", &fn_with_int_const); +} + +TEST(test_torch_library, TestIntConst) { + auto qualified_name = 
"example_library_with_int_const::fn_with_int_const"; + auto* op = torch::OperatorRegistry::instance().find_operator(qualified_name); + ASSERT_NE(op, nullptr); + auto impl_it = op->implementations.find(torch::DispatchKey::CPU); + ASSERT_NE(impl_it, op->implementations.end()); + torch::FunctionArgs function_args; + function_args.add_arg(torch::IValue(3)); + auto result = impl_it->second.call_with_args(function_args); + ASSERT_TRUE(result.get_value().is_int()); + int value = result.get_value().to_int(); + ASSERT_EQ(value, 4); +} + +int fn_with_optional_input(torch::optional x) { + if (x.has_value()) { + return x.value() + 1; + } else { + return -1; + } +} + +TORCH_LIBRARY(example_library_with_optional_input, m) { + m.def("fn_with_optional_input", &fn_with_optional_input); +} + +TEST(test_torch_library, TestOptionalInput) { + auto qualified_name = + "example_library_with_optional_input::fn_with_optional_input"; + auto* op = torch::OperatorRegistry::instance().find_operator(qualified_name); + ASSERT_NE(op, nullptr); + auto impl_it = op->implementations.find(torch::DispatchKey::CPU); + ASSERT_NE(impl_it, op->implementations.end()); + + // Test with value + torch::FunctionArgs function_args_with_value; + function_args_with_value.add_arg(torch::IValue(int64_t(5))); + auto result_with_value = + impl_it->second.call_with_args(function_args_with_value); + ASSERT_TRUE(result_with_value.get_value().is_int()); + int value_with_value = result_with_value.get_value().to_int(); + ASSERT_EQ(value_with_value, 6); + + // Test without value (nullopt) + torch::FunctionArgs function_args_without_value; + function_args_without_value.add_arg(torch::IValue()); + auto result_without_value = + impl_it->second.call_with_args(function_args_without_value); + ASSERT_TRUE(result_without_value.get_value().is_int()); + int value_without_value = result_without_value.get_value().to_int(); + ASSERT_EQ(value_without_value, -1); +} + +int fn_with_arrayref_input(c10::ArrayRef x) { + int sum = 0; + for (const auto& val : x) { + sum += val; + } + return sum; +} + +TORCH_LIBRARY(example_library_with_arrayref_input, m) { + m.def("fn_with_arrayref_input", &fn_with_arrayref_input); +} + +TEST(test_torch_library, TestArrayRefInput) { + auto qualified_name = + "example_library_with_arrayref_input::fn_with_arrayref_input"; + auto* op = torch::OperatorRegistry::instance().find_operator(qualified_name); + ASSERT_NE(op, nullptr); + auto impl_it = op->implementations.find(torch::DispatchKey::CPU); + ASSERT_NE(impl_it, op->implementations.end()); + + torch::FunctionArgs function_args; + function_args.add_arg(torch::IValue(std::vector({1, 2, 3, 4}))); + auto result = impl_it->second.call_with_args(function_args); + ASSERT_TRUE(result.get_value().is_int()); + int value = result.get_value().to_int(); + ASSERT_EQ(value, 10); +} + +int fn_with_mix_optional_arrayref_input( + c10::optional> x) { + if (x.has_value()) { + int sum = 0; + for (const auto& val : x.value()) { + sum += val; + } + return sum; + } else { + return -1; + } +} + +TORCH_LIBRARY(example_library_with_mix_optional_arrayref_input, m) { + m.def("fn_with_mix_optional_arrayref_input", + &fn_with_mix_optional_arrayref_input); +} + +TEST(test_torch_library, TestMixOptionalArrayRefInput) { + auto qualified_name = + "example_library_with_mix_optional_arrayref_input::" + "fn_with_mix_optional_arrayref_input"; + auto* op = torch::OperatorRegistry::instance().find_operator(qualified_name); + ASSERT_NE(op, nullptr); + auto impl_it = op->implementations.find(torch::DispatchKey::CPU); + 
ASSERT_NE(impl_it, op->implementations.end()); + + // Test with value + torch::FunctionArgs function_args_with_value; + function_args_with_value.add_arg( + torch::IValue(std::vector({1, 2, 3, 4}))); + auto result_with_value = + impl_it->second.call_with_args(function_args_with_value); + ASSERT_TRUE(result_with_value.get_value().is_int()); + int value_with_value = result_with_value.get_value().to_int(); + ASSERT_EQ(value_with_value, 10); + + // Test without value (nullopt) + torch::FunctionArgs function_args_without_value; + function_args_without_value.add_arg(torch::IValue()); + auto result_without_value = + impl_it->second.call_with_args(function_args_without_value); + ASSERT_TRUE(result_without_value.get_value().is_int()); + int value_without_value = result_without_value.get_value().to_int(); + ASSERT_EQ(value_without_value, -1); +} + +void fn_with_optional_tensor_const_ref_input( + torch::optional const& x) {} + +TORCH_LIBRARY(example_library_with_optional_tensor_const_ref_input, m) { + m.def("fn_with_optional_tensor_const_ref_input", + &fn_with_optional_tensor_const_ref_input); +} + +TEST(test_torch_library, TestOptionalTensorConstRefInput) { + auto qualified_name = + "example_library_with_optional_tensor_const_ref_input::" + "fn_with_optional_tensor_const_ref_input"; + auto* op = torch::OperatorRegistry::instance().find_operator(qualified_name); + ASSERT_NE(op, nullptr); + auto impl_it = op->implementations.find(torch::DispatchKey::CPU); + ASSERT_NE(impl_it, op->implementations.end()); + + // Test with value + torch::FunctionArgs function_args_with_value; + function_args_with_value.add_arg(torch::IValue(at::ones({2, 2}, at::kFloat))); + impl_it->second.call_with_args(function_args_with_value); + + // Test without value (nullopt) + torch::FunctionArgs function_args_without_value; + function_args_without_value.add_arg(torch::IValue()); + impl_it->second.call_with_args(function_args_without_value); +} + +// Function that returns a list of two tensors (instead of tuple) +std::vector return_tensor_list(const at::Tensor& input, int dim) { + // Simply create two tensors of different sizes as demonstration + auto first_part = at::ones({2}, input.options()); + auto second_part = at::ones({2}, input.options()); + + return {first_part, second_part}; +} + +// Function that actually returns std::tuple +std::tuple return_tensor_tuple(const at::Tensor& input, + int dim) { + // Create two tensors and return as tuple + auto first_part = at::ones({2}, input.options()); + auto second_part = + at::ones({3}, input.options()); // Different size to verify + + return std::make_tuple(first_part, second_part); +} + +// Function that actually returns std::tuple +std::tuple return_tensor_tuple_3( + const at::Tensor& input, int dim) { + // Create two tensors and return as tuple + auto first_part = at::ones({2}, input.options()); + auto second_part = + at::ones({3}, input.options()); // Different size to verify + auto third_part = at::ones({4}, input.options()); + + return std::make_tuple(first_part, second_part, third_part); +} + +TORCH_LIBRARY(example_library_with_tuple_return, m) { + m.def("split_tensor_list", &return_tensor_list); + m.def("split_tensor_tuple", &return_tensor_tuple); + m.def("split_tensor_tuple_3", &return_tensor_tuple_3); +} + +TEST(test_torch_library, TestTupleReturn) { + // Test vector return (list) + auto qualified_name_list = + "example_library_with_tuple_return::split_tensor_list"; + auto* op_list = + torch::OperatorRegistry::instance().find_operator(qualified_name_list); + 
ASSERT_NE(op_list, nullptr); + auto impl_it_list = op_list->implementations.find(torch::DispatchKey::CPU); + ASSERT_NE(impl_it_list, op_list->implementations.end()); + + // Create a test tensor [0, 1, 2, 3] with shape [4] + std::vector data = {0.0f, 1.0f, 2.0f, 3.0f}; + auto input_tensor = at::from_blob(data.data(), {4}, at::kFloat).clone(); + + torch::FunctionArgs function_args_list; + function_args_list.add_arg(torch::IValue(input_tensor)); + function_args_list.add_arg(torch::IValue(0)); // split along dimension 0 + + auto result_list = impl_it_list->second.call_with_args(function_args_list); + + // Verify the result is a GenericList (vector of tensors) + ASSERT_TRUE(result_list.get_value().is_list()); + + auto list_val = result_list.get_value().to_list(); + ASSERT_EQ(list_val.size(), 2); + + // Check first tensor should have size [2] + auto first_tensor_list = list_val[0].to_tensor(); + ASSERT_EQ(first_tensor_list.size(0), 2); + + // Check second tensor should have size [2] + auto second_tensor_list = list_val[1].to_tensor(); + ASSERT_EQ(second_tensor_list.size(0), 2); + + // Test std::tuple return (tuple) + auto qualified_name_tuple = + "example_library_with_tuple_return::split_tensor_tuple"; + auto* op_tuple = + torch::OperatorRegistry::instance().find_operator(qualified_name_tuple); + ASSERT_NE(op_tuple, nullptr); + auto impl_it_tuple = op_tuple->implementations.find(torch::DispatchKey::CPU); + ASSERT_NE(impl_it_tuple, op_tuple->implementations.end()); + + torch::FunctionArgs function_args_tuple; + function_args_tuple.add_arg(torch::IValue(input_tensor)); + function_args_tuple.add_arg(torch::IValue(0)); // split along dimension 0 + + auto result_tuple = impl_it_tuple->second.call_with_args(function_args_tuple); + + // Verify the result is a tuple + ASSERT_TRUE(result_tuple.get_value().is_tuple()); + + auto tuple_val = result_tuple.get_value().to_tuple(); + ASSERT_EQ(tuple_val.size(), 2); + + // Check first tensor should have size [2] + auto first_tensor_tuple = tuple_val[0].to_tensor(); + ASSERT_EQ(first_tensor_tuple.size(0), 2); + + // Check second tensor should have size [3] (different from first) + auto second_tensor_tuple = tuple_val[1].to_tensor(); + ASSERT_EQ(second_tensor_tuple.size(0), 3); + + // Test std::tuple return (tuple) + auto qualified_name_tuple_3 = + "example_library_with_tuple_return::split_tensor_tuple_3"; + auto* op_tuple_3 = + torch::OperatorRegistry::instance().find_operator(qualified_name_tuple_3); + ASSERT_NE(op_tuple_3, nullptr); + auto impl_it_tuple_3 = + op_tuple_3->implementations.find(torch::DispatchKey::CPU); + ASSERT_NE(impl_it_tuple_3, op_tuple_3->implementations.end()); + + torch::FunctionArgs function_args_tuple_3; + function_args_tuple_3.add_arg(torch::IValue(input_tensor)); + function_args_tuple_3.add_arg(torch::IValue(0)); // split along dimension 0 + + auto result_tuple_3 = + impl_it_tuple_3->second.call_with_args(function_args_tuple_3); + + // Verify the result is a tuple + ASSERT_TRUE(result_tuple_3.get_value().is_tuple()); + + auto tuple_val_3 = result_tuple_3.get_value().to_tuple(); + ASSERT_EQ(tuple_val_3.size(), 3); + + // Check first tensor should have size [2] + auto first_tensor_tuple_3 = tuple_val_3[0].to_tensor(); + ASSERT_EQ(first_tensor_tuple_3.size(0), 2); + + // Check second tensor should have size [3] (different from first) + auto second_tensor_tuple_3 = tuple_val_3[1].to_tensor(); + ASSERT_EQ(second_tensor_tuple_3.size(0), 3); + + // Check third tensor should have size [4] (different from first and second) + auto 
third_tensor_tuple_3 = tuple_val_3[2].to_tensor(); + ASSERT_EQ(third_tensor_tuple_3.size(0), 4); +} + +// Test for const reference parameters fix +void fn_with_const_ref_param(const int& x, const std::string& str) { + // Simple function to test const reference parameter handling +} + +TORCH_LIBRARY(example_library_const_ref_fix, m) { + m.def("fn_with_const_ref_param", &fn_with_const_ref_param); +} + +TEST(test_torch_library, TestConstRefParameterFix) { + auto qualified_name = + "example_library_const_ref_fix::fn_with_const_ref_param"; + auto* op = torch::OperatorRegistry::instance().find_operator(qualified_name); + ASSERT_NE(op, nullptr); + auto impl_it = op->implementations.find(torch::DispatchKey::CPU); + ASSERT_NE(impl_it, op->implementations.end()); + + // Test with const reference parameters + torch::FunctionArgs function_args; + function_args.add_arg(torch::IValue(42)); + function_args.add_arg(torch::IValue(std::string("test"))); + + // This should not throw compilation errors + auto result = impl_it->second.call_with_args(function_args); + ASSERT_TRUE(result.get_value().is_none()); // void function returns None +} diff --git a/test/cpp_extension/cpp_extension_setup.py b/test/cpp_extension/cpp_extension_setup.py index ebede6aa5a6ab9..f9d168f7a346a4 100644 --- a/test/cpp_extension/cpp_extension_setup.py +++ b/test/cpp_extension/cpp_extension_setup.py @@ -13,21 +13,30 @@ # limitations under the License. import os +from pathlib import Path from site import getsitepackages from utils import extra_compile_args import paddle from paddle.utils.cpp_extension import CppExtension, CUDAExtension, setup +from paddle.utils.cpp_extension.extension_utils import IS_WINDOWS paddle_includes = [] for site_packages_path in getsitepackages(): - paddle_includes.append( - os.path.join(site_packages_path, 'paddle', 'include') - ) - paddle_includes.append( - os.path.join(site_packages_path, 'paddle', 'include', 'third_party') - ) + paddle_include_dir = Path(site_packages_path) / "paddle/include" + paddle_includes.append(str(paddle_include_dir)) + paddle_includes.append(str(paddle_include_dir / 'third_party')) + if not IS_WINDOWS: + paddle_includes.append( + str(paddle_include_dir / 'paddle/phi/api/include/compat') + ) + paddle_includes.append( + str( + paddle_include_dir + / 'paddle/phi/api/include/compat/torch/csrc/api/include' + ) + ) # Add current dir, search custom_power.h paddle_includes.append(os.path.dirname(os.path.abspath(__file__))) diff --git a/test/cpp_extension/test_cpp_extension_jit.py b/test/cpp_extension/test_cpp_extension_jit.py index 3a32acdf81f5de..dfedce266354a9 100644 --- a/test/cpp_extension/test_cpp_extension_jit.py +++ b/test/cpp_extension/test_cpp_extension_jit.py @@ -15,6 +15,7 @@ import os import sys import unittest +from pathlib import Path from site import getsitepackages import numpy as np @@ -22,6 +23,7 @@ import paddle from paddle.utils.cpp_extension import load +from paddle.utils.cpp_extension.extension_utils import IS_WINDOWS if os.name == 'nt' or sys.platform.startswith('darwin'): # only support Linux now @@ -34,12 +36,19 @@ paddle_includes = [] for site_packages_path in getsitepackages(): - paddle_includes.append( - os.path.join(site_packages_path, 'paddle', 'include') - ) - paddle_includes.append( - os.path.join(site_packages_path, 'paddle', 'include', 'third_party') - ) + paddle_include_dir = Path(site_packages_path) / "paddle/include" + paddle_includes.append(str(paddle_include_dir)) + paddle_includes.append(str(paddle_include_dir / 'third_party')) + if not IS_WINDOWS: 
+ paddle_includes.append( + str(paddle_include_dir / 'paddle/phi/api/include/compat') + ) + paddle_includes.append( + str( + paddle_include_dir + / 'paddle/phi/api/include/compat/torch/csrc/api/include' + ) + ) # include "custom_power.h" paddle_includes.append(os.path.dirname(os.path.abspath(__file__))) diff --git a/test/cpp_extension/utils.py b/test/cpp_extension/utils.py index 76502792f3f25b..eb1aab0d0f5205 100644 --- a/test/cpp_extension/utils.py +++ b/test/cpp_extension/utils.py @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os import sys +from pathlib import Path from site import getsitepackages import numpy as np @@ -28,12 +28,19 @@ # PaddlePaddle whl. So here we specific `include_dirs` to avoid errors in CI. paddle_includes = [] for site_packages_path in getsitepackages(): - paddle_includes.append( - os.path.join(site_packages_path, 'paddle', 'include') - ) - paddle_includes.append( - os.path.join(site_packages_path, 'paddle', 'include', 'third_party') - ) + paddle_include_dir = Path(site_packages_path) / "paddle/include" + paddle_includes.append(str(paddle_include_dir)) + paddle_includes.append(str(paddle_include_dir / 'third_party')) + if not IS_WINDOWS: + paddle_includes.append( + str(paddle_include_dir / 'paddle/phi/api/include/compat') + ) + paddle_includes.append( + str( + paddle_include_dir + / 'paddle/phi/api/include/compat/torch/csrc/api/include' + ) + ) # Test for extra compile args extra_cc_args = ['-w', '-g'] if not IS_WINDOWS else ['/w'] diff --git a/test/custom_op/utils.py b/test/custom_op/utils.py index 9b36887455b1ff..06f81768d10c98 100644 --- a/test/custom_op/utils.py +++ b/test/custom_op/utils.py @@ -14,6 +14,7 @@ import os import sys +from pathlib import Path from site import getsitepackages import numpy as np @@ -29,12 +30,19 @@ paddle_includes = [] paddle_libraries = [] for site_packages_path in getsitepackages(): - paddle_includes.append( - os.path.join(site_packages_path, 'paddle', 'include') - ) - paddle_includes.append( - os.path.join(site_packages_path, 'paddle', 'include', 'third_party') - ) + paddle_include_dir = Path(site_packages_path) / "paddle/include" + paddle_includes.append(str(paddle_include_dir)) + paddle_includes.append(str(paddle_include_dir / 'third_party')) + if not IS_WINDOWS: + paddle_includes.append( + str(paddle_include_dir / 'paddle/phi/api/include/compat') + ) + paddle_includes.append( + str( + paddle_include_dir + / 'paddle/phi/api/include/compat/torch/csrc/api/include' + ) + ) paddle_libraries.append(os.path.join(site_packages_path, 'paddle', 'libs')) # Test for extra compile args diff --git a/test/custom_runtime/test_custom_op_setup.py b/test/custom_runtime/test_custom_op_setup.py index 25965d7963265e..51834e114654f7 100644 --- a/test/custom_runtime/test_custom_op_setup.py +++ b/test/custom_runtime/test_custom_op_setup.py @@ -16,10 +16,13 @@ import sys import tempfile import unittest +from pathlib import Path from site import getsitepackages import numpy as np +from paddle.utils.cpp_extension.extension_utils import IS_WINDOWS + def custom_relu_dynamic(func, device, dtype, np_x, use_func=True): import paddle @@ -136,14 +139,19 @@ def setUp(self): # please refer to the comments in `paddle/tests/custom_op/utils.py`` paddle_includes = [] for site_packages_path in getsitepackages(): - paddle_includes.append( - os.path.join(site_packages_path, 'paddle', 'include') - ) - paddle_includes.append( - os.path.join( - site_packages_path, 'paddle', 'include', 
'third_party' + paddle_include_dir = Path(site_packages_path) / "paddle/include" + paddle_includes.append(str(paddle_include_dir)) + paddle_includes.append(str(paddle_include_dir / 'third_party')) + if not IS_WINDOWS: + paddle_includes.append( + str(paddle_include_dir / 'paddle/phi/api/include/compat') + ) + paddle_includes.append( + str( + paddle_include_dir + / 'paddle/phi/api/include/compat/torch/csrc/api/include' + ) ) - ) custom_module = paddle.utils.cpp_extension.load( name='custom_device', diff --git a/test/deprecated/custom_op/utils.py b/test/deprecated/custom_op/utils.py index 9b36887455b1ff..06f81768d10c98 100644 --- a/test/deprecated/custom_op/utils.py +++ b/test/deprecated/custom_op/utils.py @@ -14,6 +14,7 @@ import os import sys +from pathlib import Path from site import getsitepackages import numpy as np @@ -29,12 +30,19 @@ paddle_includes = [] paddle_libraries = [] for site_packages_path in getsitepackages(): - paddle_includes.append( - os.path.join(site_packages_path, 'paddle', 'include') - ) - paddle_includes.append( - os.path.join(site_packages_path, 'paddle', 'include', 'third_party') - ) + paddle_include_dir = Path(site_packages_path) / "paddle/include" + paddle_includes.append(str(paddle_include_dir)) + paddle_includes.append(str(paddle_include_dir / 'third_party')) + if not IS_WINDOWS: + paddle_includes.append( + str(paddle_include_dir / 'paddle/phi/api/include/compat') + ) + paddle_includes.append( + str( + paddle_include_dir + / 'paddle/phi/api/include/compat/torch/csrc/api/include' + ) + ) paddle_libraries.append(os.path.join(site_packages_path, 'paddle', 'libs')) # Test for extra compile args diff --git a/tools/print_signatures.py b/tools/print_signatures.py index 0a15a390f54a4d..d74519dee01f88 100644 --- a/tools/print_signatures.py +++ b/tools/print_signatures.py @@ -78,7 +78,7 @@ def md5(doc): ErrorSet = set() IdSet = set() -skiplist = [] +skiplist = ["paddle.ops", "paddle.classes"] def visit_all_module(mod):
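
One closing, hedged illustration of what the extra compat include directories wired into the cpp_extension utilities above are meant to enable for out-of-tree code. The file name, operator name, and the assumption that a torch/extension.h-style header resolves through the new compat paths are all illustrative; the body simply mirrors the mymuladd_cpu test earlier in this patch and is a sketch, not part of the change:

// my_muladd_ext.cc -- hypothetical out-of-tree extension source
#include <torch/extension.h>  // assumption: resolved via the new compat include dirs

at::Tensor my_muladd(const at::Tensor& a, const at::Tensor& b, double c) {
  TORCH_CHECK(a.sizes() == b.sizes());
  TORCH_CHECK(a.dtype() == at::kFloat && b.dtype() == at::kFloat);
  at::Tensor ac = a.contiguous();
  at::Tensor bc = b.contiguous();
  at::Tensor out = at::empty(ac.sizes(), ac.options());
  const float* ap = ac.data_ptr<float>();
  const float* bp = bc.data_ptr<float>();
  float* outp = out.data_ptr<float>();
  for (int64_t i = 0; i < ac.numel(); ++i) {
    outp[i] = ap[i] * bp[i] + static_cast<float>(c);  // element-wise a*b + c
  }
  return out;
}

// Register the operator so it is discoverable through the compat registry.
TORCH_LIBRARY(my_ext, m) { m.def("my_muladd", &my_muladd); }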