
Converting MaxPool2D with dynamic spatial dimensions crashes #270

Open
sc-aharri opened this issue Sep 30, 2024 · 1 comment

Description of the bug:

When I convert a PyTorch model containing a MaxPool2d module, ai_edge_torch.convert crashes. This can be reproduced on my setup with the following minimal repro:

import ai_edge_torch
import torch

if __name__ == "__main__":
    print("creating dummy image and model which is just a maxpool op")
    model = torch.nn.MaxPool2d(2)
    sample_image = torch.zeros((1, 1, 320, 320))

    print("define dynamic sizing, specifying that the sizes must be divisible by 2")
    height = 2 * torch.export.Dim("height", min=1, max=320)
    width = 2 * torch.export.Dim("width", min=1, max=320)
    dynamic_shapes = ({2: height, 3: width},)

    print("converting model")
    exported_model = ai_edge_torch.convert(
        model,
        (sample_image,),
        dynamic_shapes=dynamic_shapes,
    )

The relevant stack trace can be found below.

Actual vs expected behavior:

The script crashes with the following callstack:

2024-09-30 13:05:37.114853: E external/xla/xla/status_macros.cc:56] INTERNAL: RET_CHECK failure (external/xla/xla/service/shape_inference.cc:3508) !inferred_shape.is_unbounded_dynamic() Reshaping with unbounded result shape is not supported.
*** Begin stack trace ***
        tsl::CurrentStackTrace()

        xla::status_macros::MakeErrorStream::Impl::GetStatus()
        xla::ShapeInference::InferReshapeShape(xla::Shape const&, absl::lts_20230802::Span<long const>, absl::lts_20230802::Span<long const>, long)


        xla::XlaBuilder::ReportErrorOrReturn(absl::lts_20230802::FunctionRef<absl::lts_20230802::StatusOr<xla::XlaOp> ()>)
        xla::XlaBuilder::Reshape(xla::XlaOp, absl::lts_20230802::Span<long const>, absl::lts_20230802::Span<long const>, long)

        xla::XlaBuilder::ReportErrorOrReturn(absl::lts_20230802::FunctionRef<absl::lts_20230802::StatusOr<xla::XlaOp> ()>)
        xla::XlaBuilder::Reshape(xla::XlaOp, absl::lts_20230802::Span<long const>, long)


        torch_xla::BuildMaxPoolNd(xla::XlaOp, long, absl::lts_20230802::Span<long const>, absl::lts_20230802::Span<long const>, absl::lts_20230802::Span<long const>, bool)
        torch_xla::MaxPoolNd::Lower(torch_xla::LoweringContext*) const
        torch_xla::LoweringContext::LowerNode(torch::lazy::Node const*)
        torch_xla::LoweringContext::GetOutputOp(torch::lazy::Output const&)
        torch_xla::LoweringContext::AddResult(torch::lazy::Output const&)
        torch_xla::DumpUtil::ToHlo(c10::ArrayRef<torch::lazy::Value>, torch::lazy::BackendDevice const&, torch_xla::EmitMode)
        torch_xla::XLAGraphExecutor::DumpHloComputation(std::vector<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> >, std::allocator<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> > > > const&, torch_xla::EmitMode)




        _PyObject_MakeTpCall
        _PyEval_EvalFrameDefault
        PyEval_EvalCode

        _PyRun_SimpleFileObject
        _PyRun_AnyFileObject
        Py_RunMain
        Py_BytesMain
        __libc_start_main
        _start
*** End stack trace ***

Traceback (most recent call last):
  File "basic-repro.py", line 13, in <module>
    exported_model = ai_edge_torch.convert(
                     ^^^^^^^^^^^^^^^^^^^^^^
  File ".../ai_edge_torch/convert/converter.py", line 195, in convert
    return Converter().convert(
           ^^^^^^^^^^^^^^^^^^^^
  File ".../ai_edge_torch/convert/converter.py", line 134, in convert
    return conversion.convert_signatures(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../ai_edge_torch/convert/conversion.py", line 97, in convert_signatures
    shlo_bundles: list[stablehlo.StableHLOModelBundle] = [
                                                         ^
  File ".../ai_edge_torch/convert/conversion.py", line 98, in <listcomp>
    cutils.exported_program_to_stablehlo_bundle(exported, sig.flat_args)
  File ".../ai_edge_torch/convert/conversion_utils.py", line 133, in exported_program_to_stablehlo_bundle
    return stablehlo.exported_program_to_stablehlo(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../torch_xla/stablehlo.py", line 618, in exported_program_to_stablehlo
    bundle = _exported_program_to_stablehlo_bundle(exported_model, options)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../torch_xla/stablehlo.py", line 397, in _exported_program_to_stablehlo_bundle
    stablehlo_content = xm.get_stablehlo_bytecode(res)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../torch_xla/core/xla_model.py", line 1112, in get_stablehlo_bytecode
    return torch_xla._XLAC._get_stablehlo(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Error while lowering: [] aten::max_pool2d, num_outputs=2, xla_shape=(f32[1,1,160,160]{3,2,1,0}, u32[1,1,160,160]{3,2,1,0}), dynamic_dims: (), spatial_dim_count=2, kernel_size=(2, 2), stride=(2, 2), padding=(0, 0), ceil_mode=0
XLA builder error: INTERNAL: RET_CHECK failure (external/xla/xla/service/shape_inference.cc:3508) !inferred_shape.is_unbounded_dynamic() Reshaping with unbounded result shape is not supported.: 

It seems that the RET_CHECK on !inferred_shape.is_unbounded_dynamic() is failing, i.e. the inferred shape is being treated as unbounded dynamic. There are certainly dynamic dimensions, but I'm suspicious of the word unbounded, since I am defining bounds. Should I expect those bounds to be propagated to torch_xla?
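
One way to check whether the bounds survive the torch.export step at all (before the torch_xla lowering) is to export the module directly and print its range constraints. This is just a diagnostic sketch using the same model and dynamic_shapes as the repro above; it relies only on standard torch.export attributes, nothing ai_edge_torch specific:

import torch

# Same setup as the repro above.
model = torch.nn.MaxPool2d(2)
sample_image = torch.zeros((1, 1, 320, 320))
height = 2 * torch.export.Dim("height", min=1, max=320)
width = 2 * torch.export.Dim("width", min=1, max=320)
dynamic_shapes = ({2: height, 3: width},)

exported = torch.export.export(model, (sample_image,), dynamic_shapes=dynamic_shapes)

# range_constraints maps each symbolic size to its value range; if the
# min/max bounds show up here, they are being lost later, in the
# torch_xla / StableHLO lowering rather than in torch.export itself.
print(exported.range_constraints)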

Any other information you'd like to share?

My setup is an Ubuntu Docker container with a CPU-only build of torch.

sc-aharri added the type:bug label on Sep 30, 2024
pkgoogle self-assigned this on Sep 30, 2024
pkgoogle (Contributor) commented Sep 30, 2024

Hi @sc-aharri, if you are able to staticize as many dimensions as you can, that'll probably be the easiest workaround for now. I was able to reproduce on nightly with the exact same code... it might be a torch_xla issue at root, assuming we are passing in the right values here: torch_xla_utils.py
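
As a concrete illustration of that suggestion (not from the thread, just a sketch assuming the same convert API as the repro above), converting with fully static shapes, or keeping the spatial dimensions static and only marking the batch dimension dynamic, sidesteps the unbounded-reshape path:

import ai_edge_torch
import torch

model = torch.nn.MaxPool2d(2)
sample_image = torch.zeros((1, 1, 320, 320))

# Fully static conversion: omit dynamic_shapes entirely.
static_model = ai_edge_torch.convert(model, (sample_image,))

# Or keep height/width static and only mark the batch dimension dynamic
# (whether this particular combination converts cleanly is untested here).
batch = torch.export.Dim("batch", min=1, max=8)
dynamic_batch_model = ai_edge_torch.convert(
    model,
    (sample_image,),
    dynamic_shapes=({0: batch},),
)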
