Bisect FX node asserts on `ValidationException`. #107493
This PR introduces binary search for finding smaller validation errors, when they occur. We do that by bisecting the sequence of `torch._assert` FX nodes recorded as the source expression of the translation validator (TV) by `ShapeEnv.evaluate_expr` calls. Then, we raise the error caused by the earliest node.

In summary, the changes are:
- Call `bisect` on `ValidationException` @ _torch/_dynamo/convert_frame.py_
- Store a reference of the list of `TrackedFake` to `ShapeEnv` @ _torch/_dynamo/output_graph.py_
- Implement the event recording logic @ _torch/fx/experimental/symbolic_shapes.py_
    - Create the `ShapeEnvEvent` class
    - Decorate `ShapeEnv` methods with the `record_shapeenv_event` decorator
- Implement the `ShapeEnv` reconstruction function @ _torch/fx/experimental/symbolic_shapes.py_
- Implement the binary search @ _torch/fx/experimental/symbolic_shapes.py_

[ghstack-poisoned]
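The bisection idea in the description can be sketched in a few lines. This is an illustration only, not the code added by the PR: `first_failing_prefix` and the `fails` callback are hypothetical names, and the sketch assumes failures are monotonic (once a prefix of the recorded events fails validation, every longer prefix fails too).

```python
from typing import Callable, Optional, Sequence


def first_failing_prefix(
    events: Sequence[str],
    fails: Callable[[Sequence[str]], bool],
) -> Optional[int]:
    """Return the index of the earliest event whose inclusion makes `fails` true.

    Assumes monotonicity: if events[:n] fails, events[:m] fails for all m >= n.
    Returns None when even the full sequence passes.
    """
    if not fails(events):
        return None
    left, right = 0, len(events) - 1
    # Invariant: events[:right + 1] fails, events[:left] is not known to fail.
    while left < right:
        mid = (left + right) // 2
        if fails(events[: mid + 1]):
            right = mid
        else:
            left = mid + 1
    return left


# Hypothetical recorded assertions; the third one is the culprit.
recorded = ["s0 > 1", "s1 >= 0", "s1 + s2 + s3 == s0", "s2 >= 0"]
culprit = first_failing_prefix(
    recorded, lambda prefix: "s1 + s2 + s3 == s0" in prefix
)
print(culprit)
```

Each probe re-checks a prefix, so the number of validation runs grows with the logarithm of the number of recorded events rather than linearly.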
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/107493
Note: Links to docs will display an error until the docs builds have been completed.
✅ You can merge normally! (1 Unrelated Failure)
As of commit 007d7e8 with merge base 8ff0036:
FLAKY - The following job failed but was likely due to flakiness present on trunk:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
test/dynamo/test_exc.py
Outdated
ValidationException,
lambda: fn(torch.randn(20), (5, 10, 5)),
"""\
translation validation failed.
How exactly am I supposed to interpret the message here?
The problem that TV tries to solve is: "Given Assertions, is there any Model that satisfies Target Expressions, but doesn't satisfy the source expressions?". The Failed Source Expressions section represents an affirmative answer, showing which source expressions failed.
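To make that question concrete, here is a toy version of the query. It is a sketch under stated assumptions: the real validator hands the query to an SMT solver, whereas this stand-in brute-forces a small integer range, and `find_counterexample` plus the example expressions are hypothetical.

```python
from itertools import product
from typing import Callable, Dict, Iterable, List, Optional, Tuple

Model = Dict[str, int]
Expr = Callable[[Model], bool]


def find_counterexample(
    symbols: List[str],
    candidates: Iterable[int],
    assertions: List[Expr],
    targets: List[Expr],
    sources: Dict[str, Expr],
) -> Optional[Tuple[Model, List[str]]]:
    """Search for a model that satisfies every assertion and target expression
    but violates at least one source expression.

    Returns (model, names of failed source expressions), or None if no such
    model exists in the searched range.
    """
    values = list(candidates)
    for combo in product(values, repeat=len(symbols)):
        model = dict(zip(symbols, combo))
        if not all(expr(model) for expr in assertions + targets):
            continue
        failed = [name for name, expr in sources.items() if not expr(model)]
        if failed:
            return model, failed
    return None


# The source expression says "s0 is even", but the (deliberately wrong)
# target expression only guards "s0 >= 2".
result = find_counterexample(
    symbols=["s0"],
    candidates=range(0, 8),
    assertions=[lambda m: m["s0"] > 1],
    targets=[lambda m: m["s0"] >= 2],
    sources={"s0 % 2 == 0": lambda m: m["s0"] % 2 == 0},
)
print(result)
```

The returned model plays the role of the Model section in the error message, and the returned names play the role of Failed Source Expressions.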
What I mean is, I was expecting bisect to tell me something like "the problem is specifically when you add an equality guard at X point of shape env internal state" or something like that. Otherwise I wouldn't need bisect, right?
Oh, I see. I thought the smaller problem instance would be a nice win already, e.g. easier to read and reason about. But that makes sense.
Another thing I thought was that it would be nice to know under what program FX node (PyTorch operation) we failed. But, is there any way to get this information at this point (inside `ShapeEnv`)?
Yeah, the stack trace on the bad guard would be pretty useful.
What do you think of this?
BisectValidationException
File "/home/ysiraichi/work/pytorch/2/torch/_dynamo/convert_frame.py", line 439, in transform
bisect(tracer.output.shape_env, tracer.output.tracked_fakes)
File "/home/ysiraichi/work/pytorch/2/torch/fx/experimental/symbolic_shapes.py", line 4266, in bisect
raise BisectValidationException(exception[left], shape_env.events[number], running_node)
torch.fx.experimental.validator.BisectValidationException: translation validation failed when evaluating: Eq(s1 + s2 + s3, s0)
Failure ocurred while running node:
%split : [num_users=1] = call_method[target=split](args = (%l_x_, (%l_shape_0_, %l_shape_1_, %l_shape_2_)), kwargs = {})
Model:
==> s3: -9223372036854775807
==> s1: -9223372036854775807
==> s0: 3
==> s2: -9223372036854775807
==> L['x'].size()[0]: 3
==> L['shape'][1]: -9223372036854775807
==> L['shape'][2]: -9223372036854775807
==> L['x'].stride()[0]: 1
==> L['shape'][0]: -9223372036854775807
==> L['x'].storage_offset(): 0
Assertions:
==> (== L['x'].size()[0] s0)
==> (> s0 1)
==> (== L['shape'][2] s3)
==> (== 0 L['x'].storage_offset())
==> (== L['shape'][0] s1)
==> (== L['shape'][1] s2)
==> (== 1 L['x'].stride()[0])
Target Expressions:
==> (== L['x'].size()[0] s0)
==> (!= (+ s1 s2 s3) s0)
==> (<= -9223372036854775808 s1)
==> (>= 9223372036854775807 s1)
==> (== L['shape'][2] s3)
==> (== 0 L['x'].storage_offset())
==> (>= 9223372036854775807 s2)
==> (== L['shape'][0] s1)
==> (> s0 0)
==> (== 1 L['x'].stride()[0])
==> (== L['shape'][1] s2)
==> (<= 2 s0)
==> (>= 9223372036854775807 s3)
==> (>= 9223372036854775806 s0)
==> (<= -9223372036854775808 s3)
==> (<= -9223372036854775808 s2)
Failed Source Expressions:
==> (== (+ L['shape'][0] L['shape'][1] L['shape'][2]) L['x'].size()[0])
Yes, this is better.
torch/_dynamo/output_graph.py
Outdated
@@ -263,6 +269,7 @@ def __init__(
allow_dynamic_output_shape_ops=config.capture_dynamic_output_shape_ops,
frame_id=frame_state["_id"],
co_fields=self.co_fields,
tracked_fakes=self.tracked_fakes,
Mutable list sharing here is very delicate, call it out (and make sure it is tested!)
One alternative I can think of is:
- New `ShapeEnv.tracked_fakes_length` field
- New `ShapeEnv.inc_tracked_fakes_length` function
    - Call it whenever `tracked_fakes` is modified

In order to make it less error-prone, what do you think of creating a class for automating that?
class TrackedFakeList(list):
    def __init__(self, shape_env):
        self.shape_env = shape_env
        super().__init__()

    def append(self, obj):
        self.shape_env.inc_tracked_fakes_length()
        super().append(obj)
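A runnable variant of the proposal, for illustration. `StubShapeEnv` is a hypothetical stand-in for `ShapeEnv`, the optional `n` parameter of `inc_tracked_fakes_length` is an assumption, and `extend` is overridden as well because `list.extend` does not go through `append`.

```python
class StubShapeEnv:
    """Hypothetical stand-in for ShapeEnv; only tracks the fakes-list length."""

    def __init__(self) -> None:
        self.tracked_fakes_length = 0

    def inc_tracked_fakes_length(self, n: int = 1) -> None:
        self.tracked_fakes_length += n


class TrackedFakeList(list):
    """List that reports every growth to its owning shape environment."""

    def __init__(self, shape_env: StubShapeEnv) -> None:
        super().__init__()
        self.shape_env = shape_env

    def append(self, obj) -> None:
        super().append(obj)
        self.shape_env.inc_tracked_fakes_length()

    def extend(self, objs) -> None:
        items = list(objs)
        super().extend(items)
        self.shape_env.inc_tracked_fakes_length(len(items))


env = StubShapeEnv()
fakes = TrackedFakeList(env)
fakes.append("fake_a")
fakes.extend(["fake_b", "fake_c"])
print(env.tracked_fakes_length, len(fakes))
```

Other mutating paths (`insert`, `+=`, slice assignment) would still bypass the counter in this sketch, which is part of why the approach stays delicate.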
I don't have a good alternative, what you have right now is OK, it's just delicate.
Would you rather have this circular reference? I kind of like this alternative idea...
Alternate idea is fine too.
replacearg(index=3, key="fx_node", fn=convert_node)

# Actually call the method with the converted arguments.
return self.f(shape_env, *args, **kwargs)
This seems pretty delicate; in particular, it seems difficult to guarantee that the replay mechanism is going to keep working as people keep monkeying around with methods on ShapeEnv. First off, is this well tested? Second, one way to make it structurally harder to "mess it up" is to force all substantive method calls through this API, so that even regular, non bisecting ShapeEnv calls go through the logic here.
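For readers unfamiliar with the record-and-replay pattern under discussion, a minimal sketch follows. It is not the PR's `ShapeEnvEvent` or `record_shapeenv_event`; `Event`, `record_event`, `ToyEnv` and `replay` are hypothetical, and the toy ignores nested recorded calls and argument conversion.

```python
import functools
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Tuple


@dataclass
class Event:
    """One recorded method call: enough information to run it again."""

    name: str
    args: Tuple[Any, ...]
    kwargs: Dict[str, Any]


def record_event(method: Callable) -> Callable:
    """Append an Event to `self.events` before running the decorated method."""

    @functools.wraps(method)
    def wrapper(self, *args, **kwargs):
        if not self.replaying:
            self.events.append(Event(method.__name__, args, kwargs))
        return method(self, *args, **kwargs)

    return wrapper


class ToyEnv:
    """Hypothetical stand-in for ShapeEnv with two state-mutating methods."""

    def __init__(self) -> None:
        self.events: List[Event] = []
        self.replaying = False
        self.symbols: List[str] = []
        self.guards: List[str] = []

    @record_event
    def create_symbol(self, name: str) -> None:
        self.symbols.append(name)

    @record_event
    def add_guard(self, expr: str) -> None:
        self.guards.append(expr)


def replay(events: List[Event]) -> ToyEnv:
    """Rebuild an environment by re-running the recorded events in order."""
    env = ToyEnv()
    env.replaying = True
    for ev in events:
        getattr(env, ev.name)(*ev.args, **ev.kwargs)
    env.replaying = False
    return env


env = ToyEnv()
env.create_symbol("s0")
env.add_guard("s0 > 1")
env.add_guard("s0 % 2 == 0")

# Replaying a prefix yields the state as it was after the first two calls.
partial = replay(env.events[:2])
full = replay(env.events)
print(partial.symbols, partial.guards)
print(full.guards == env.guards)
```

Comparing the replayed state against the live state, as the last line does, is one way a test could catch a mutating method that was added without the decorator.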
# remember_fakes_length: flags whether we should add the length of the
# ShapeEnv.tracked_fakes list to the event. This is used for calling
# ShapeEnv.produce_guards at an arbitrary time (assuming the fakes list only
# increases monotonically).
If I add a new method which mutates the state of ShapeEnv, but I didn't read this code carefully enough, how will I know to apply this decorator? What test will tell me I did it wrong?
# remember_fakes_length: flags whether we should add the length of the
# ShapeEnv.tracked_fakes list to the event. This is used for calling
# ShapeEnv.produce_guards at an arbitrary time (assuming the fakes list only
# increases monotonically).
What arguments are OK for decorated functions to take? When do I have to add special cases to `ShapeEnvEvent.__call__`?
This PR introduces binary search for finding smaller validation errors, when they occur. We do that by bisecting the sequence of `torch._assert` FX nodes recorded as the source expression of the translation validator (TV) by `ShapeEnv.evaluate_expr` calls. Then, we raise the error caused by the earliest node.

In summary, the changes are:
- Call `bisect` on `ValidationException` @ _torch/_dynamo/convert_frame.py_
- Implement the binary search @ _torch/fx/experimental/symbolic_shapes.py_

Edit: moved `ShapeEnv` replay-recording to #107989

cc voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx ipiszy chenyang78 aakhundov

[ghstack-poisoned]
torch/fx/experimental/recording.py
Outdated
@@ -163,6 +163,13 @@ def is_evaluate_expr(self) -> bool:
    def is_defer_runtime_assert(self) -> bool:
        return self.name == "defer_runtime_assert"

    def getarg(self, *, index: int, key: str) -> Any:
What does this do? Can we get some docs?
This is so we can get an argument from an event. It gets an argument by its positional index or its name.
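A sketch of what such a lookup could look like. Only the signature `getarg(self, *, index: int, key: str) -> Any` comes from the diff; the body, the `Event` container and the "keyword first, then position" order are assumptions made for illustration.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, Tuple


@dataclass
class Event:
    """Hypothetical container for the arguments of one recorded call."""

    args: Tuple[Any, ...] = ()
    kwargs: Dict[str, Any] = field(default_factory=dict)

    def getarg(self, *, index: int, key: str) -> Any:
        """Return the argument passed by keyword `key`, else by position `index`."""
        if key in self.kwargs:
            return self.kwargs[key]
        if index < len(self.args):
            return self.args[index]
        raise LookupError(f"argument {key!r} (position {index}) was not passed")


# The same logical argument, passed positionally in one call and by name in another.
positional = Event(args=("expr", None, "node_7"))
by_keyword = Event(args=("expr",), kwargs={"fx_node": "node_7"})
print(positional.getarg(index=2, key="fx_node"))
print(by_keyword.getarg(index=2, key="fx_node"))
```

The caller has to keep `index` and `key` in sync with the recorded method's signature by hand, which is one reason this kind of helper is fragile.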
I guess we can get rid of it, though. It's not a good solution for getting arguments from events.
@@ -2412,6 +2412,11 @@ def remove_fx_node(self, node: Optional[torch.fx.Node]) -> None:
        self.name_to_node.pop(node.name)
        self.graph.erase_node(node)

    def add_fx_node_metadata(self, node: torch.fx.Node) -> None:
        from torch._dynamo.utils import get_current_node
        node.meta["event"] = self.last_event_index()
This is too generic. Meta dict is used by everyone. Give a more detailed string key
if not last_exception:
    # We don't actually fail due to a produce_guards call.
    # Stop and don't bisect.
maybe worth logging here?
@pytorchbot merge
Merge started
Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team
Merge failed
Reason: 1 jobs have failed, first few of them are: inductor / cuda12.1-py3.10-gcc9-sm86 / test (inductor, 1, 1, linux.g5.4xlarge.nvidia.gpu)
Details for Dev Infra team: Raised by workflow job
@pytorchbot merge -i
Merge started
Your change will be merged while ignoring the following 1 checks: inductor / cuda12.1-py3.10-gcc9-sm86 / test (inductor, 1, 1, linux.g5.4xlarge.nvidia.gpu) Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team
Stack from ghstack (oldest at bottom):
- Bisect FX node asserts on `ValidationException`. #107493

This PR introduces binary search for finding smaller validation errors, when they occur. We do that by bisecting the sequence of `torch._assert` FX nodes recorded as the source expression of the translation validator (TV) by `ShapeEnv.evaluate_expr` calls. Then, we raise the error caused by the earliest node.

In summary, the changes are:
- Call `bisect` on `ValidationException` @ _torch/_dynamo/convert_frame.py_
- Implement the binary search @ _torch/fx/experimental/symbolic_shapes.py_

Edit: moved `ShapeEnv` replay-recording to #107989

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @chenyang78 @aakhundov @kadeng @anijain2305 @ipiszy