diff --git a/paddle/fluid/operators/einsum_op.cc b/paddle/fluid/operators/einsum_op.cc index 0e33efab90a85..5f169e20e3dc3 100644 --- a/paddle/fluid/operators/einsum_op.cc +++ b/paddle/fluid/operators/einsum_op.cc @@ -106,7 +106,7 @@ namespace ops = paddle::operators; DECLARE_INFER_SHAPE_FUNCTOR(einsum, EinsumInferShapeFunctor, - PD_INFER_META(phi::EinsumInferMeta)); + PD_INFER_META(phi::EinsumRawInferMeta)); REGISTER_OPERATOR(einsum, ops::EinsumOp, diff --git a/paddle/fluid/operators/squeeze_op.cc b/paddle/fluid/operators/squeeze_op.cc index f532a429b49e2..b3c70e2fe9988 100644 --- a/paddle/fluid/operators/squeeze_op.cc +++ b/paddle/fluid/operators/squeeze_op.cc @@ -347,7 +347,7 @@ namespace ops = paddle::operators; DECLARE_INFER_SHAPE_FUNCTOR(squeeze2, SqueezeInferShapeFunctor, - PD_INFER_META(phi::SqueezeInferMeta)); + PD_INFER_META(phi::SqueezeWithXShapeInferMeta)); REGISTER_OPERATOR(squeeze, ops::SqueezeOp, diff --git a/paddle/fluid/operators/unsqueeze_op.cc b/paddle/fluid/operators/unsqueeze_op.cc index 53de6440f1f61..f01ae5f142d28 100644 --- a/paddle/fluid/operators/unsqueeze_op.cc +++ b/paddle/fluid/operators/unsqueeze_op.cc @@ -347,7 +347,7 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERER(UnsqueezeGradOpNoNeedBufferVarInferer, "X"); DECLARE_INFER_SHAPE_FUNCTOR(unsqueeze2, Unsqueeze2InferShapeFunctor, - PD_INFER_META(phi::UnsqueezeInferMeta)); + PD_INFER_META(phi::UnsqueezeWithXShapeInferMeta)); namespace ops = paddle::operators; REGISTER_OPERATOR(unsqueeze, diff --git a/paddle/phi/api/lib/CMakeLists.txt b/paddle/phi/api/lib/CMakeLists.txt index 05d27571b8795..cb7f439690619 100644 --- a/paddle/phi/api/lib/CMakeLists.txt +++ b/paddle/phi/api/lib/CMakeLists.txt @@ -325,8 +325,8 @@ add_custom_command( ${dygraph_api_header_file} COMMAND ${CMAKE_COMMAND} -E copy_if_different ${dygraph_api_source_file_tmp} ${dygraph_api_source_file} - DEPENDS ${api_yaml_file} ${sparse_api_yaml_file} ${im_api_gen_file} - ${api_gen_base} ${api_gen_file} + DEPENDS ${api_yaml_file} ${legacy_api_yaml_file} ${sparse_api_yaml_file} + ${im_api_gen_file} ${api_gen_base} ${api_gen_file} VERBATIM) # generate wrapped infermeta diff --git a/paddle/phi/api/yaml/legacy_api.yaml b/paddle/phi/api/yaml/legacy_api.yaml index a562db94745c9..0d0fd74c17aa7 100644 --- a/paddle/phi/api/yaml/legacy_api.yaml +++ b/paddle/phi/api/yaml/legacy_api.yaml @@ -582,10 +582,10 @@ args : (Tensor[] x, str equation) output : Tensor, Tensor[]{x.size()}, Tensor[]{x.size()} infer_meta : - func : EinsumInferMeta + func : EinsumRawInferMeta param : [x, equation] kernel : - func : einsum + func : einsum_raw backward : einsum_grad - api : elementwise_pow @@ -2047,9 +2047,9 @@ args : (Tensor x, int[] axes) output : Tensor(out), Tensor(xshape) infer_meta : - func : SqueezeInferMeta + func : SqueezeWithXShapeInferMeta kernel : - func : squeeze + func : squeeze_with_xshape view: (x -> out) intermediate : xshape backward : squeeze_grad @@ -2290,9 +2290,9 @@ args : (Tensor x, IntArray axis) output : Tensor(out), Tensor(xshape) infer_meta : - func : UnsqueezeInferMeta + func : UnsqueezeWithXShapeInferMeta kernel : - func : unsqueeze + func : unsqueeze_with_xshape view: (x -> out) intermediate : xshape backward : unsqueeze_grad diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index 35cada2c325e5..c7699c34cc546 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -570,9 +570,7 @@ void EigvalsInferMeta(const MetaTensor& x, MetaTensor* out, MetaConfig config) { void EinsumInferMeta(const std::vector& inputs, const std::string& equation, - MetaTensor* out, - std::vector inner_cache, - std::vector xshape) { + MetaTensor* out) { // collect the following informations to prepare einsum. LabelMap labelshape(0); LabelMap labeltype(LabelType::Reduction); @@ -609,6 +607,14 @@ void EinsumInferMeta(const std::vector& inputs, VLOG(3) << "Label Shape is : " << label_to_string(all_labels, labelshape); out->set_dims(make_ddim(output_dims)); out->set_dtype(inputs[0]->dtype()); +} + +void EinsumRawInferMeta(const std::vector& inputs, + const std::string& equation, + MetaTensor* out, + std::vector inner_cache, + std::vector xshape) { + EinsumInferMeta(inputs, equation, out); for (size_t i = 0; i < xshape.size(); ++i) { if (xshape[i] != nullptr) { xshape[i]->set_dims(inputs[i]->dims()); @@ -2448,8 +2454,7 @@ void SplitInferMeta(const MetaTensor& x, void SqueezeInferMeta(const MetaTensor& x, const std::vector& axes, - MetaTensor* out, - MetaTensor* xshape) { + MetaTensor* out) { const auto& x_dims = x.dims(); // Check input tensor dims (<6) Eigen limit. PADDLE_ENFORCE_LE(x_dims.size(), @@ -2469,15 +2474,25 @@ void SqueezeInferMeta(const MetaTensor& x, out->share_lod(x); } + out->set_dtype(x.dtype()); +} + +void SqueezeWithXShapeInferMeta(const MetaTensor& x, + const std::vector& axes, + MetaTensor* out, + MetaTensor* xshape) { + SqueezeInferMeta(x, axes, out); + const auto& x_dims = x.dims(); std::vector xshape_dims(x_dims.size() + 1); xshape_dims[0] = 0; for (int i = 0; i < x_dims.size(); ++i) { xshape_dims[i + 1] = x_dims[i]; } - xshape->set_dims(phi::make_ddim(xshape_dims)); - xshape->share_lod(x); - xshape->set_dtype(x.dtype()); - out->set_dtype(x.dtype()); + if (xshape) { + xshape->set_dims(phi::make_ddim(xshape_dims)); + xshape->share_lod(x); + xshape->set_dtype(x.dtype()); + } } void StridedSliceRawInferMeta(const MetaTensor& x, @@ -3310,7 +3325,6 @@ void UniqueRawInferMeta(const MetaTensor& x, void UnsqueezeInferMeta(const MetaTensor& x, const IntArray& axes, MetaTensor* out, - MetaTensor* xshape, MetaConfig config) { const auto& x_dims = x.dims(); // Validity Check: input tensor dims (<6). @@ -3339,14 +3353,22 @@ void UnsqueezeInferMeta(const MetaTensor& x, } out->set_dtype(x.dtype()); } - if (xshape) { - // set xshape dims. - std::vector xshape_dims(x_dims.size() + 1); - xshape_dims[0] = 0; - for (int i = 0; i < x_dims.size(); ++i) { - xshape_dims[i + 1] = x_dims[i]; - } +} +void UnsqueezeWithXShapeInferMeta(const MetaTensor& x, + const IntArray& axes, + MetaTensor* out, + MetaTensor* xshape, + MetaConfig config) { + const auto& x_dims = x.dims(); + UnsqueezeInferMeta(x, axes, out, config); + // set xshape dims. + std::vector xshape_dims(x_dims.size() + 1); + xshape_dims[0] = 0; + for (int i = 0; i < x_dims.size(); ++i) { + xshape_dims[i + 1] = x_dims[i]; + } + if (xshape) { xshape->set_dims(phi::make_ddim(xshape_dims)); xshape->share_lod(x); xshape->set_dtype(x.dtype()); diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h index 1a0da23600339..ea7364e643960 100644 --- a/paddle/phi/infermeta/unary.h +++ b/paddle/phi/infermeta/unary.h @@ -97,9 +97,13 @@ void EigvalsInferMeta(const MetaTensor& x, void EinsumInferMeta(const std::vector& inputs, const std::string& equation, - MetaTensor* out, - std::vector inner_cache, - std::vector xshape); + MetaTensor* out); + +void EinsumRawInferMeta(const std::vector& inputs, + const std::string& equation, + MetaTensor* out, + std::vector inner_cache, + std::vector xshape); void ExpandInferMeta(const MetaTensor& x, const IntArray& shape, @@ -341,8 +345,12 @@ void SplitInferMeta(const MetaTensor& x_meta, void SqueezeInferMeta(const MetaTensor& x, const std::vector& axes, - MetaTensor* out, - MetaTensor* xshape); + MetaTensor* out); + +void SqueezeWithXShapeInferMeta(const MetaTensor& x, + const std::vector& axes, + MetaTensor* out, + MetaTensor* xshape); void StridedSliceRawInferMeta(const MetaTensor& x, const std::vector& axes, @@ -470,9 +478,14 @@ void UniqueRawInferMeta(const MetaTensor& x, void UnsqueezeInferMeta(const MetaTensor& x, const IntArray& axes, MetaTensor* out, - MetaTensor* xshape, MetaConfig config = MetaConfig()); +void UnsqueezeWithXShapeInferMeta(const MetaTensor& x, + const IntArray& axes, + MetaTensor* out, + MetaTensor* xshape, + MetaConfig config = MetaConfig()); + void UnStackInferMeta(const MetaTensor& x, int axis, int num, diff --git a/paddle/phi/kernels/cpu/einsum_kernel.cc b/paddle/phi/kernels/cpu/einsum_kernel.cc index 901c1fed628d3..7ef85a942e435 100644 --- a/paddle/phi/kernels/cpu/einsum_kernel.cc +++ b/paddle/phi/kernels/cpu/einsum_kernel.cc @@ -18,7 +18,7 @@ #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/impl/einsum_impl.h" -PD_REGISTER_KERNEL(einsum, +PD_REGISTER_KERNEL(einsum_raw, CPU, ALL_LAYOUT, phi::EinsumKernelRaw, @@ -26,3 +26,12 @@ PD_REGISTER_KERNEL(einsum, double, phi::dtype::complex, phi::dtype::complex) {} + +PD_REGISTER_KERNEL(einsum, + CPU, + ALL_LAYOUT, + phi::EinsumKernel, + float, + double, + phi::dtype::complex, + phi::dtype::complex) {} diff --git a/paddle/phi/kernels/cpu/squeeze_kernel.cc b/paddle/phi/kernels/cpu/squeeze_kernel.cc index 7d5a6ca4e884e..d22efdf969440 100644 --- a/paddle/phi/kernels/cpu/squeeze_kernel.cc +++ b/paddle/phi/kernels/cpu/squeeze_kernel.cc @@ -32,3 +32,18 @@ PD_REGISTER_KERNEL(squeeze, int64_t, phi::dtype::complex, phi::dtype::complex) {} + +PD_REGISTER_KERNEL(squeeze_with_xshape, + CPU, + ALL_LAYOUT, + phi::SqueezeWithXShapeKernel, + float, + double, + phi::dtype::bfloat16, + bool, + int, + uint8_t, + int8_t, + int64_t, + phi::dtype::complex, + phi::dtype::complex) {} diff --git a/paddle/phi/kernels/cpu/unsqueeze_kernel.cc b/paddle/phi/kernels/cpu/unsqueeze_kernel.cc index 0152a31f80ba8..612e1a78cc5bb 100644 --- a/paddle/phi/kernels/cpu/unsqueeze_kernel.cc +++ b/paddle/phi/kernels/cpu/unsqueeze_kernel.cc @@ -33,3 +33,19 @@ PD_REGISTER_KERNEL(unsqueeze, int64_t, phi::dtype::complex, phi::dtype::complex) {} + +PD_REGISTER_KERNEL(unsqueeze_with_xshape, + CPU, + ALL_LAYOUT, + phi::UnsqueezeWithXShapeKernel, + float, + double, + phi::dtype::bfloat16, + bool, + int, + int16_t, + uint8_t, + int8_t, + int64_t, + phi::dtype::complex, + phi::dtype::complex) {} diff --git a/paddle/phi/kernels/gpu/einsum_kernel.cu b/paddle/phi/kernels/gpu/einsum_kernel.cu index b3706710c40e3..99a9c58995c1f 100644 --- a/paddle/phi/kernels/gpu/einsum_kernel.cu +++ b/paddle/phi/kernels/gpu/einsum_kernel.cu @@ -18,7 +18,7 @@ #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/impl/einsum_impl.h" -PD_REGISTER_KERNEL(einsum, +PD_REGISTER_KERNEL(einsum_raw, GPU, ALL_LAYOUT, phi::EinsumKernelRaw, @@ -28,3 +28,14 @@ PD_REGISTER_KERNEL(einsum, phi::dtype::bfloat16, phi::dtype::complex, phi::dtype::complex) {} + +PD_REGISTER_KERNEL(einsum, + GPU, + ALL_LAYOUT, + phi::EinsumKernel, + float, + double, + phi::dtype::float16, + phi::dtype::bfloat16, + phi::dtype::complex, + phi::dtype::complex) {} diff --git a/paddle/phi/kernels/gpu/squeeze_kernel.cu b/paddle/phi/kernels/gpu/squeeze_kernel.cu index ae15e210a02e7..06ddba2ef1c2b 100644 --- a/paddle/phi/kernels/gpu/squeeze_kernel.cu +++ b/paddle/phi/kernels/gpu/squeeze_kernel.cu @@ -33,3 +33,19 @@ PD_REGISTER_KERNEL(squeeze, int64_t, phi::dtype::complex, phi::dtype::complex) {} + +PD_REGISTER_KERNEL(squeeze_with_xshape, + GPU, + ALL_LAYOUT, + phi::SqueezeWithXShapeKernel, + float, + double, + phi::dtype::bfloat16, + phi::dtype::float16, + bool, + int, + uint8_t, + int8_t, + int64_t, + phi::dtype::complex, + phi::dtype::complex) {} diff --git a/paddle/phi/kernels/gpu/unsqueeze_kernel.cu b/paddle/phi/kernels/gpu/unsqueeze_kernel.cu index 86b4462254637..2e7bae8666d24 100644 --- a/paddle/phi/kernels/gpu/unsqueeze_kernel.cu +++ b/paddle/phi/kernels/gpu/unsqueeze_kernel.cu @@ -34,3 +34,20 @@ PD_REGISTER_KERNEL(unsqueeze, int64_t, phi::dtype::complex, phi::dtype::complex) {} + +PD_REGISTER_KERNEL(unsqueeze_with_xshape, + GPU, + ALL_LAYOUT, + phi::UnsqueezeWithXShapeKernel, + float, + double, + phi::dtype::float16, + phi::dtype::bfloat16, + bool, + int, + int16_t, + uint8_t, + int8_t, + int64_t, + phi::dtype::complex, + phi::dtype::complex) {} diff --git a/paddle/phi/kernels/impl/solve_kernel_impl.h b/paddle/phi/kernels/impl/solve_kernel_impl.h index 09c9e74dd207a..4120823a9d2e9 100644 --- a/paddle/phi/kernels/impl/solve_kernel_impl.h +++ b/paddle/phi/kernels/impl/solve_kernel_impl.h @@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#pragma once + #include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/kernels/expand_as_kernel.h" #include "paddle/phi/kernels/funcs/matrix_solve.h" @@ -77,7 +79,7 @@ static std::vector get_broadcast_batch_portion( static inline std::vector convert_to_int_vec(std::vector a) { std::vector ret; for (size_t i = 0; i < a.size(); i++) { - ret.emplace_back(int(a[i])); + ret.emplace_back(static_cast(a[i])); } return ret; @@ -167,7 +169,7 @@ static void linalg_solve(const Context& dev_ctx, out_tmp.Resize(out->dims()); out_tmp = *out; - phi::SqueezeKernel(dev_ctx, out_tmp, {-1}, out, nullptr); + phi::SqueezeKernel(dev_ctx, out_tmp, {-1}, out); } else { PADDLE_ENFORCE_EQ( x_dim[x_dim_size - 1], diff --git a/paddle/phi/kernels/impl/squeeze_kernel_impl.h b/paddle/phi/kernels/impl/squeeze_kernel_impl.h index b4c94d619cc2a..156a71973a794 100644 --- a/paddle/phi/kernels/impl/squeeze_kernel_impl.h +++ b/paddle/phi/kernels/impl/squeeze_kernel_impl.h @@ -22,8 +22,7 @@ template void SqueezeKernel(const Context& dev_ctx, const DenseTensor& x, const std::vector& axes, - DenseTensor* out, - DenseTensor* xshape) { + DenseTensor* out) { auto x_dims = x.dims(); auto out_dims = funcs::GetOutputSqueezeShape(axes, x_dims, true); @@ -31,4 +30,14 @@ void SqueezeKernel(const Context& dev_ctx, phi::Copy(dev_ctx, x, dev_ctx.GetPlace(), false, out); out->Resize(out_dims); } + +template +void SqueezeWithXShapeKernel(const Context& dev_ctx, + const DenseTensor& x, + const std::vector& axes, + DenseTensor* out, + DenseTensor* xshape) { + SqueezeKernel(dev_ctx, x, axes, out); +} + } // namespace phi diff --git a/paddle/phi/kernels/impl/unsqueeze_kernel_impl.h b/paddle/phi/kernels/impl/unsqueeze_kernel_impl.h index 4f81fa6c42341..5bef856d19b72 100644 --- a/paddle/phi/kernels/impl/unsqueeze_kernel_impl.h +++ b/paddle/phi/kernels/impl/unsqueeze_kernel_impl.h @@ -22,8 +22,7 @@ template void UnsqueezeKernel(const Context& dev_ctx, const DenseTensor& x, const IntArray& axes, - DenseTensor* out, - DenseTensor* xshape) { + DenseTensor* out) { auto x_dims = x.dims(); auto out_dims = out->dims(); if (axes.FromTensor()) { @@ -39,4 +38,13 @@ void UnsqueezeKernel(const Context& dev_ctx, phi::Copy(dev_ctx, x, dev_ctx.GetPlace(), false, out); out->Resize(out_dims); // copy will reset the dims. } + +template +void UnsqueezeWithXShapeKernel(const Context& dev_ctx, + const DenseTensor& x, + const IntArray& axes, + DenseTensor* out, + DenseTensor* xshape) { + UnsqueezeKernel(dev_ctx, x, axes, out); +} } // namespace phi diff --git a/paddle/phi/kernels/squeeze_kernel.h b/paddle/phi/kernels/squeeze_kernel.h index bd8f508cbb1db..1c6aeedbe5161 100644 --- a/paddle/phi/kernels/squeeze_kernel.h +++ b/paddle/phi/kernels/squeeze_kernel.h @@ -23,6 +23,13 @@ template void SqueezeKernel(const Context& dev_ctx, const DenseTensor& x, const std::vector& axes, - DenseTensor* out, - DenseTensor* xshape); + DenseTensor* out); + +template +void SqueezeWithXShapeKernel(const Context& dev_ctx, + const DenseTensor& x, + const std::vector& axes, + DenseTensor* out, + DenseTensor* xshape); + } // namespace phi diff --git a/paddle/phi/kernels/unsqueeze_kernel.h b/paddle/phi/kernels/unsqueeze_kernel.h index 62ba878c056cb..35a0515c92da3 100644 --- a/paddle/phi/kernels/unsqueeze_kernel.h +++ b/paddle/phi/kernels/unsqueeze_kernel.h @@ -25,8 +25,14 @@ template void UnsqueezeKernel(const Context& dev_ctx, const DenseTensor& x, const IntArray& axes, - DenseTensor* out, - DenseTensor* xshape); + DenseTensor* out); + +template +void UnsqueezeWithXShapeKernel(const Context& dev_ctx, + const DenseTensor& x, + const IntArray& axes, + DenseTensor* out, + DenseTensor* xshape); template void Unsqueeze(const Context& dev_ctx, @@ -35,8 +41,8 @@ void Unsqueeze(const Context& dev_ctx, DenseTensor* out, DenseTensor* xshape) { MetaTensor meta_out(out); - UnsqueezeInferMeta(x, axes, &meta_out, nullptr, MetaConfig()); - UnsqueezeKernel(dev_ctx, x, axes, out, nullptr); + UnsqueezeInferMeta(x, axes, &meta_out); + UnsqueezeKernel(dev_ctx, x, axes, out); } } // namespace phi diff --git a/paddle/phi/ops/compat/einsum_sig.cc b/paddle/phi/ops/compat/einsum_sig.cc index 4fd31c1a2d842..e5aa570985596 100644 --- a/paddle/phi/ops/compat/einsum_sig.cc +++ b/paddle/phi/ops/compat/einsum_sig.cc @@ -17,8 +17,14 @@ limitations under the License. */ namespace phi { KernelSignature EinsumOpArgumentMapping(const ArgumentMappingContext& ctx) { - return KernelSignature( - "einsum", {"Operands"}, {"equation"}, {"Out", "InnerCache", "XShape"}); + if (ctx.OutputSize("XShape") > 0 && ctx.OutputSize("InnerCache") > 0) { + return KernelSignature("einsum_raw", + {"Operands"}, + {"equation"}, + {"Out", "InnerCache", "XShape"}); + } else { + return KernelSignature("einsum", {"Operands"}, {"equation"}, {"Out"}); + } } KernelSignature EinsumGradOpArgumentMapping(const ArgumentMappingContext& ctx) { diff --git a/paddle/phi/ops/compat/squeeze_sig.cc b/paddle/phi/ops/compat/squeeze_sig.cc index cd6d5fc7253df..a251b9f537ccf 100644 --- a/paddle/phi/ops/compat/squeeze_sig.cc +++ b/paddle/phi/ops/compat/squeeze_sig.cc @@ -18,7 +18,12 @@ namespace phi { KernelSignature SqueezeOpArgumentMapping(const ArgumentMappingContext& ctx) { - return KernelSignature("squeeze", {"X"}, {"axes"}, {"Out", "XShape"}); + if (ctx.HasOutput("XShape")) { + return KernelSignature( + "squeeze_with_xshape", {"X"}, {"axes"}, {"Out", "XShape"}); + } else { + return KernelSignature("squeeze", {"X"}, {"axes"}, {"Out"}); + } } KernelSignature SqueezeGradOpArgumentMapping( diff --git a/paddle/phi/ops/compat/unsqueeze_sig.cc b/paddle/phi/ops/compat/unsqueeze_sig.cc index aee83933e5b97..a2f184e7150b8 100644 --- a/paddle/phi/ops/compat/unsqueeze_sig.cc +++ b/paddle/phi/ops/compat/unsqueeze_sig.cc @@ -18,17 +18,33 @@ namespace phi { KernelSignature UnsqueezeOpArgumentMapping(const ArgumentMappingContext& ctx) { - if (ctx.InputSize("AxesTensorList") > 0) { - VLOG(2) << "unsqueeze2 in AxesTensorList"; - return KernelSignature( - "unsqueeze", {"X"}, {"AxesTensorList"}, {"Out", "XShape"}); - } else if (ctx.InputSize("AxesTensor") > 0) { - VLOG(2) << "unsqueeze2 in AxesTensor"; - return KernelSignature( - "unsqueeze", {"X"}, {"AxesTensor"}, {"Out", "XShape"}); + if (ctx.HasOutput("XShape")) { + if (ctx.InputSize("AxesTensorList") > 0) { + VLOG(2) << "unsqueeze2 in AxesTensorList"; + return KernelSignature("unsqueeze_with_xshape", + {"X"}, + {"AxesTensorList"}, + {"Out", "XShape"}); + } else if (ctx.InputSize("AxesTensor") > 0) { + VLOG(2) << "unsqueeze2 in AxesTensor"; + return KernelSignature( + "unsqueeze_with_xshape", {"X"}, {"AxesTensor"}, {"Out", "XShape"}); + } else { + VLOG(2) << "unsqueeze2 in axes"; + return KernelSignature( + "unsqueeze_with_xshape", {"X"}, {"axes"}, {"Out", "XShape"}); + } } else { - VLOG(2) << "unsqueeze2 in axes"; - return KernelSignature("unsqueeze", {"X"}, {"axes"}, {"Out", "XShape"}); + if (ctx.InputSize("AxesTensorList") > 0) { + VLOG(2) << "unsqueeze2 in AxesTensorList"; + return KernelSignature("unsqueeze", {"X"}, {"AxesTensorList"}, {"Out"}); + } else if (ctx.InputSize("AxesTensor") > 0) { + VLOG(2) << "unsqueeze2 in AxesTensor"; + return KernelSignature("unsqueeze", {"X"}, {"AxesTensor"}, {"Out"}); + } else { + VLOG(2) << "unsqueeze2 in axes"; + return KernelSignature("unsqueeze", {"X"}, {"axes"}, {"Out"}); + } } }