scan and scan_layers #7901
Conversation
you can't just import it, you need to set up the import dir correctly. Take a look at https://github.com/pytorch/xla/blob/master/test/dynamo/test_dynamo_dynamic_shape.py#L1-L6
@JackCaoG ty, I followed your example and got it working.
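For context, a minimal sketch of the "set up the import dir" pattern referenced above. The directory layout and the helper module name are assumptions for illustration; the authoritative version is in the linked test file.

```python
# Hypothetical sketch: make sibling test helpers importable by adding their
# directory to sys.path before importing them. The exact lines used in the
# repo are in the linked test_dynamo_dynamic_shape.py.
import os
import sys

# Assumption: the helpers live one directory above this test file.
parent_folder = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(parent_folder)

import test_utils  # noqa: E402  (hypothetical helper module made importable above)
```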
at::Tensor input = MakeTensorFromXlaLiteral(literal, dtype);
results[param_ids[i]] = input;
std::optional param_id = lowering_ctx.GetParameterId(device_data[i]);
XLA_CHECK(param_id.has_value());
when would it not have a value?
When GetParameterId receives a BackendData that is not a parameter in this lowering context, it will return std::nullopt. However, this loop only iterates over parameters (line 1071, const std::vector<torch::lazy::BackendDataPtr>& device_data = lowering_ctx.GetParametersData();), so we expect every BackendData there to have an ID. It seems good to enforce this invariant.
example_layer = deepcopy(next(iter(layers)))

# Hollow out the weights and biases in the example layer.
example_layer = example_layer.to_empty(device=None)
is this not going to impact the cloned arg?
Could you clarify this question -- I thought to_empty is going to destroy the values inside example_layer, so I deepcopy it beforehand as a backup.
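For illustration, a small self-contained sketch of the backup-then-hollow-out pattern discussed here; the layer sizes and names are made up:

```python
# Sketch: deepcopy backs up a layer before to_empty() replaces its parameters
# with uninitialized storage; the original layers keep their values.
from copy import deepcopy

import torch.nn as nn

layers = [nn.Linear(4, 4) for _ in range(3)]  # stand-in for the scanned layers
example_layer = deepcopy(next(iter(layers)))  # backup taken first
example_layer = example_layer.to_empty(device=None)  # hollow out the copy only
assert layers[0].weight is not example_layer.weight  # originals are untouched
```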
torch_xla/experimental/scan.py
Outdated
def step_fn(grad_carry, pytree: Tuple[torch.Tensor, torch.Tensor,
                                      torch.Tensor]):
  grad_y, carry, x = pytree
is this a typo?
I don't think so -- pytree is a tuple of the output grad at the current step (grad_y), the carry at the current step (carry), and the input at the current step (x).
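To illustrate that convention, a hedged sketch of the backward step's shape; the arithmetic is a placeholder, not the PR's actual backward computation:

```python
# The backward scan step receives the carry gradient plus a pytree of
# (grad_y, carry, x) for the current step. The math below is a dummy
# stand-in for the real AOTAutograd-generated backward graph.
from typing import Tuple

import torch


def step_fn(grad_carry: torch.Tensor,
            pytree: Tuple[torch.Tensor, torch.Tensor, torch.Tensor]):
  grad_y, carry, x = pytree  # output grad, carry, and input at this step
  new_grad_carry = grad_carry + grad_y  # placeholder accumulation
  grad_x = new_grad_carry * torch.ones_like(x)  # placeholder input gradient
  return new_grad_carry, grad_x
```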
@tengyifei is this PR a 2.5 candidate?
@miladm yes, I'd like to backport this to 2.5 after addressing the comments etc.
with torch.enable_grad():
  fw_compiler, get_fwd = _make_get_graph_compiler()
  bw_compiler, get_bwd = _make_get_graph_compiler()
  fn_compiled = aot_function(
Will this fail if there are tensors within fn that were not provided as parameters, since they are not 'fake' tensors? Say, for model parameters if we have a module inside fn that we wish to also trace fwd/bwd on. Is this targeted as a follow-up?
Yes, it will. That's why I added scan_layers (previously named apply_layers) in this PR to extract module parameters and functionalize the module.
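For reference, a hedged sketch of the functionalization idea described here, using torch.func.functional_call; the layer, shapes, and helper name are made up for illustration:

```python
# Turn a module into a pure function whose parameters and buffers are explicit
# inputs, so AOTAutograd sees them as traced (fake) tensor arguments instead
# of tensors captured from inside fn.
import torch
import torch.nn as nn
from torch.func import functional_call

layer = nn.Linear(8, 8)  # stand-in for one layer of the scanned stack
state = {**dict(layer.named_parameters()), **dict(layer.named_buffers())}


def pure_layer_fn(state, x):
  # No hidden state: everything the layer needs is passed in explicitly.
  return functional_call(layer, state, (x,))


out = pure_layer_fn(state, torch.randn(2, 8))
```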
This is ready for another look. For ease of review I added two additional commits for the stuff we talked about offline:
Please move the SPMD stuff to a different PR, I will try to finish reviewing this PR today (GitHub needs a better way to stack PRs...).
sure, done
This commit adds the lowering of scan to the HLO While op. It also introduces apply_layers, which can sequentially apply a bunch of layers using scan underneath. In this milestone we use AOTAutograd to obtain the backward of the function being scanned. Users can either save the activations in fn or recompute them by passing different graph partitioners to AOTAutograd. Also give the lowered fn computation a more meaningful name.
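For context, a hedged usage sketch of the scan described in this commit message, assuming an interface modeled on jax.lax.scan where fn(carry, x) returns (new_carry, y); the import path and exact signature are assumptions, not verified against the PR's final code:

```python
# Hedged sketch: run a cumulative-sum step under scan and backprop through it.
# Assumes scan(fn, init, xs); the backward graph is the one AOTAutograd
# derives, as described in the commit message above.
import torch
import torch_xla.core.xla_model as xm
from torch_xla.experimental.scan import scan

device = xm.xla_device()


def step(carry, x):
  new_carry = carry + x
  return new_carry, new_carry  # emit the running sum as the per-step output


init = torch.zeros(4, device=device, requires_grad=True)
xs = torch.randn(10, 4, device=device, requires_grad=True)
final_carry, ys = scan(step, init, xs)
ys.sum().backward()
```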
Thanks for the review. I'll merge this first. I'm looking at caching the fn computation properly, but that'll take some time. I might be able to send that PR over (and the SPMD one) by next Monday.
Hi! PyTorch also has a scan operator: https://github.com/pytorch/pytorch/blob/main/torch/_higher_order_ops/scan.py. Wondering if we want to consolidate the efforts and what the plan is going forward?
@ydwu4 does the PyTorch scan operator support autograd? I remember there's an issue tracking autograd support. It would be great to use the upstream op, but without autograd support we can't use it in training.
Sounds good. On the PyTorch side, we've been prioritizing getting inference working e2e and will get to autograd next half. It would be great if we could reduce fragmentation with a single front-end op.
That's a great idea. If you look at the current scan impl in PyTorch/XLA, it uses
Add the lowering of scan to the HLO While op.
Introduce scan_layers, which can sequentially apply a bunch of layers using scan underneath.
Beef up unit tests, including linear layers and decoders.
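Finally, a hedged end-to-end sketch of scan_layers over a stack of linear layers, mirroring the unit tests mentioned above; the import path and the scan_layers(layers, input) signature are assumptions:

```python
# Hedged sketch: apply eight identically shaped Linear layers sequentially via
# scan_layers, so the layer is traced once and looped with an HLO While op
# instead of being unrolled eight times.
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm
from torch_xla.experimental.scan_layers import scan_layers

device = xm.xla_device()
layers = nn.ModuleList([nn.Linear(16, 16) for _ in range(8)]).to(device)
x = torch.randn(2, 16, device=device)

out = scan_layers(layers, x)
out.sum().backward()  # gradients flow back to every layer's parameters
```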