Conversation

@guan404ming (Member) commented Nov 10, 2025

Related Issue

related to #17818

Why

  • Add support for PyTorch range constraints in the Relax frontend
  • Enable proper handling of dynamic shapes with min/max bounds

How

  • Extract range constraints from the PyTorch ExportedProgram and store them in the shape_var_constraints function attribute (see the usage sketch after this list)
  • Add a test, test_dynamic_shape_with_range_constraints, to verify that constraint extraction works correctly
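
Below is a minimal usage sketch of the intended flow. The model, tensor shapes, and bounds here are illustrative rather than the exact test code, and `run_ep_decomposition=True` follows the reviewer's suggestion later in this thread:

```python
import torch
from torch.export import Dim, export

from tvm.relax.frontend.torch import from_exported_program


class DynamicModel(torch.nn.Module):
    def forward(self, x1, x2):
        return x1 + x2


# Constrain the symbolic batch dimension to the range [2, 16].
batch = Dim("batch", min=2, max=16)
example_args = (torch.randn(4, 8), torch.randn(4, 8))
dynamic_shapes = {"x1": {0: batch}, "x2": {0: batch}}

exported_program = export(DynamicModel(), args=example_args, dynamic_shapes=dynamic_shapes)
mod = from_exported_program(exported_program, run_ep_decomposition=True)

# The extracted bounds surface as attributes on the "main" Relax function.
print(mod["main"].attrs)
```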

@guan404ming force-pushed the fix-relax-pytorch-constraints branch from 05fd731 to 77685df (November 10, 2025 13:52)
@guan404ming marked this pull request as ready for review (November 10, 2025 16:30)
@guan404ming (Member Author)

cc @mshr-h

@mshr-h self-requested a review (November 11, 2025 01:07)
@mshr-h (Contributor) left a comment

Thanks! A few changes are needed.

    else s
    for s in torch_shape
]
# Create TIR variables for symbolic dimensions
@mshr-h (Contributor)

Looks like there are no functional changes here. Any reason for the change?

@guan404ming (Member Author)

Sorry about this; it looks like a leftover change from #17898. Let me revert it.

if range_constraints:
    if func_attrs is None:
        func_attrs = {}
    func_attrs["shape_var_constraints"] = range_constraints
@mshr-h (Contributor) commented Nov 11, 2025

Please use tir_var_upper_bound to annotate the upper bound.
I grepped the TVM code base and found no lower-bound annotation, so I don't think we need to keep one at the moment. If a real use case comes up, it's fine to keep it in the Relax module.
https://github.com/apache/tvm/blob/main/src/relax/transform/static_plan_block_memory.cc#L62-L66
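
For reference, a minimal TVMScript sketch of the attribute the memory planner consumes (the identity function and the bound of 16 are illustrative):

```python
from tvm.script import ir as I, relax as R


@I.ir_module
class Module:
    @R.function
    def main(x: R.Tensor(("batch", 8), "float32")) -> R.Tensor(("batch", 8), "float32"):
        # StaticPlanBlockMemory reads this attribute to bound the symbolic "batch" var.
        R.func_attr({"tir_var_upper_bound": {"batch": 16}})
        return x
```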

@guan404ming (Member Author)

Thanks for letting me know. I've updated the code to use tir_var_upper_bound to annotate the upper bound.

mod = from_exported_program(exported_program)

main_func = mod["main"]
assert hasattr(main_func, "attrs"), "Function should have attributes"
@mshr-h (Contributor)

Please use structural equality instead of manually checking the attributes.
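
For example, a self-contained sketch of such a check (both modules here are illustrative identity functions; in the actual test the expected module would carry the tir_var_upper_bound attribute):

```python
import tvm
from tvm.script import ir as I, relax as R


@I.ir_module
class Actual:
    @R.function
    def main(x: R.Tensor((2, 8), "float32")) -> R.Tensor((2, 8), "float32"):
        return x


@I.ir_module
class Expected:
    @R.function
    def main(x: R.Tensor((2, 8), "float32")) -> R.Tensor((2, 8), "float32"):
        return x


# Structural equality compares the IR structure, including function
# attributes, rather than Python object identity.
tvm.ir.assert_structural_equal(Actual, Expected)
```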

@guan404ming (Member Author)

Sure, I've updated the test to follow the convention.

@gemini-code-assist (bot) left a comment

Code Review

This pull request adds support for PyTorch range constraints in the Relax frontend. The changes correctly extract range constraints from the ExportedProgram and attach them as function attributes. A new test case is included to verify this functionality. My review includes a couple of suggestions to improve code conciseness and test robustness.

dynamic_shapes = {"x1": {0: batch}, "x2": {0: batch}}
exported_program = export(DynamicModel(), args=example_args, dynamic_shapes=dynamic_shapes)

mod = from_exported_program(exported_program)
@mshr-h (Contributor) commented Nov 11, 2025

Should be from_exported_program(exported_program, run_ep_decomposition=True)
See #18399

@guan404ming (Member Author)

Thanks for the useful info!

(Member)

Of course you can help update it!

@guan404ming (Member Author)

Sure! I'll open another PR to handle this. Thanks for your reply.

guan404ming and others added 2 commits November 11, 2025 23:39
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
@guan404ming force-pushed the fix-relax-pytorch-constraints branch from 93b699f to 34476a4 (November 11, 2025 16:32)
@guan404ming requested a review from mshr-h (November 12, 2025 04:27)
@mshr-h (Contributor) left a comment

LGTM. Thanks!

@mshr-h merged commit 394f668 into apache:main on Nov 12, 2025
14 checks passed
@guan404ming (Member Author)

Thanks for your detailed comments and reviews!

@guan404ming deleted the fix-relax-pytorch-constraints branch (November 12, 2025 08:54)
@guan404ming changed the title from "[Relax][Pytorch] Support basic range constraints" to "[Relax][PyTorch] Support basic range constraints" (Nov 14, 2025)