
Conversation

@tengyifei
Collaborator

@tengyifei tengyifei commented Aug 23, 2024

Add the lowering of scan to the HLO While op.

Introduce scan_layers, which can sequentially apply a stack of layers using scan underneath.

Expand the unit tests to cover linear layers and decoders.
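
As a rough usage sketch of the two entry points (the module paths and exact signatures below are assumptions and may differ from the merged code):

import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm
from torch_xla.experimental.scan import scan
from torch_xla.experimental.scan_layers import scan_layers

device = xm.xla_device()

# scan threads a carry through a step function over the leading dim of xs.
def step(carry, x):
  new_carry = carry + x
  y = new_carry * 2
  return new_carry, y

init = torch.zeros(4, device=device)
xs = torch.randn(10, 4, device=device)
final_carry, ys = scan(step, init, xs)

# scan_layers applies a stack of structurally identical layers sequentially,
# lowered to a single HLO While loop instead of an unrolled graph.
layers = [nn.Linear(4, 4).to(device) for _ in range(8)]
out = scan_layers(layers, torch.randn(2, 4, device=device))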

@JackCaoG
Collaborator

======================================================================
ERROR: test_decoder_model (__main__.ApplyLayersTest)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/__w/xla/xla/pytorch/xla/test/test_apply_layers.py", line 77, in test_decoder_model
    from decoder_only_model import DecoderOnlyConfig, DecoderOnlyModel  # type:ignore
ModuleNotFoundError: No module named 'decoder_only_model'

You can't just import it; you need to set up the import directory correctly. Take a look at https://github.com/pytorch/xla/blob/master/test/dynamo/test_dynamo_dynamic_shape.py#L1-L6
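
A hedged sketch of the import-path setup pattern in the linked file (the exact lines there may differ): add the directory that holds the helper module to sys.path before importing it.

import os
import sys

# Make the directory that holds decoder_only_model.py importable. The path here
# is illustrative; adjust the number of dirname() hops to the test layout.
example_folder = os.path.dirname(os.path.abspath(__file__))
sys.path.append(example_folder)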

@tengyifei
Collaborator Author

@JackCaoG Thanks, I followed your example and got it working.

at::Tensor input = MakeTensorFromXlaLiteral(literal, dtype);
results[param_ids[i]] = input;
std::optional param_id = lowering_ctx.GetParameterId(device_data[i]);
XLA_CHECK(param_id.has_value());
Collaborator

When would it not have a value?

Collaborator Author

When GetParameterId receives a BackendData that is not a parameter in this lowering context, it returns std::nullopt. However, this loop only iterates over parameters (line 1071, const std::vector<torch::lazy::BackendDataPtr>& device_data = lowering_ctx.GetParametersData();), so we expect every BackendData there to have an ID. It seems good to enforce this invariant.

example_layer = deepcopy(next(iter(layers)))

# Hollow out the weights and biases in the example layer.
example_layer = example_layer.to_empty(device=None)
Collaborator

is this not going to impact the cloned arg?

Collaborator Author

Could you clarify this question -- I thought to_empty is going to destroy the values inside example_layer, so I deepcopy it beforehand as a backup.
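
A minimal sketch (layer shapes are illustrative) of why the deepcopy comes first: to_empty() swaps the parameters for uninitialized storage, so the copy is hollowed out while the original layer keeps its weights.

import copy
import torch
import torch.nn as nn

layers = [nn.Linear(4, 4) for _ in range(3)]
original_weight = layers[0].weight.detach().clone()

example_layer = copy.deepcopy(next(iter(layers)))
example_layer = example_layer.to_empty(device=None)

# Hollowing out the copy leaves the original layer's weights untouched.
assert torch.equal(layers[0].weight, original_weight)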


def step_fn(grad_carry, pytree: Tuple[torch.Tensor, torch.Tensor,
                                      torch.Tensor]):
  grad_y, carry, x = pytree
Collaborator

is this a typo?

Collaborator Author

I don't think so -- pytree is a tuple of the output gradient at the current step (grad_y), the carry at the current step (carry), and the input at the current step (x).
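
For illustration, a self-contained sketch (not the PR's implementation, and assuming the step returns the updated gradient carry plus that step's input gradient) of how a backward step with this signature is typically driven: walk the saved per-step values in reverse, threading the gradient carry through.

def reverse_scan(step_fn, init_grad_carry, grad_ys, carries, xs):
  grad_carry = init_grad_carry
  grad_xs = []
  for grad_y, carry, x in zip(reversed(grad_ys), reversed(carries), reversed(xs)):
    grad_carry, grad_x = step_fn(grad_carry, (grad_y, carry, x))
    grad_xs.append(grad_x)
  grad_xs.reverse()
  return grad_carry, grad_xs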

@miladm
Collaborator

miladm commented Sep 20, 2024

@tengyifei is this PR a 2.5 candidate?

@tengyifei
Collaborator Author

@miladm yes, I'd like to backport this to 2.5 after addressing the comments etc.

with torch.enable_grad():
  fw_compiler, get_fwd = _make_get_graph_compiler()
  bw_compiler, get_bwd = _make_get_graph_compiler()
  fn_compiled = aot_function(
Collaborator

Will this fail if there are tensors within fn that were not provided as parameters, since they are not 'fake' tensors? Say, for model parameters if we have a module inside fn that we wish to also trace fwd/bwd on. Is this targeted as a follow-up?

Collaborator Author

Yes, it will. That's why I added scan_layers (previously named apply_layers) in this PR to extract module parameters and functionalize the module.
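
A minimal sketch, assuming torch.func.functional_call, of that idea: pull each layer's parameters out and call the layer functionally, so the scanned fn only sees explicit tensor arguments that AOTAutograd can substitute with fake tensors.

import torch
import torch.nn as nn
from torch.func import functional_call

layer = nn.Linear(8, 8)
params = dict(layer.named_parameters())

def fn(x, params):
  # The layer's weights arrive as explicit arguments rather than being read
  # from the module's state.
  return functional_call(layer, params, (x,))

y = fn(torch.randn(2, 8), params)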

@tengyifei tengyifei force-pushed the yifeit/scan branch 5 times, most recently from 7f457b9 to 13b40f5 Compare November 21, 2024 01:22
@tengyifei tengyifei force-pushed the yifeit/scan branch 2 times, most recently from 41acb50 to 4020bb5 Compare November 21, 2024 01:33
@tengyifei tengyifei requested a review from JackCaoG November 21, 2024 06:23
@tengyifei tengyifei changed the title scan and apply_layers scan and scan_layers Nov 21, 2024
@tengyifei tengyifei force-pushed the yifeit/scan branch 3 times, most recently from edf8c12 to 98222b2 Compare November 22, 2024 18:39
@tengyifei
Collaborator Author

This is ready for another look. For ease of review I added two additional commits for stuff we talked about offline:

  1. Test that the XLA compiler can propagate SPMD sharding annotations just fine through the While op and the Body computations.
  2. Change the name of the PyLoweringContext to FnComputation when it appears in the HLO.

@JackCaoG
Collaborator

This is ready for another look. For ease of review I added two additional commits for stuff we talked about offline:

  1. Test that the XLA compiler can propagate SPMD sharding annotations just fine through the While op and the Body computations.
  2. Change the name of the PyLoweringContext to FnComputation when it appears in the HLO.

Please move the SPMD stuff to a different PR; I will try to finish reviewing this PR today (GitHub needs a better way to stack PRs...).

@tengyifei tengyifei force-pushed the yifeit/scan branch 2 times, most recently from 60ef219 to 3006c8e Compare November 22, 2024 19:06
@tengyifei
Collaborator Author

tengyifei commented Nov 22, 2024

Please move the SPMD stuff to a different PR

Sure, done.

This commit adds the lowering of scan to the HLO While op. It also
introduces apply_layers, which can sequentially apply a stack of layers
using scan underneath.

In this milestone we use AOTAutograd to obtain the backward of the
function being scanned. Users can either save the activations in
fn or recompute them by passing different graph partitioners to
AOTAutograd.

Also give the lowered fn computation a more meaningful name.
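
For context, a hedged sketch of the AOTAutograd mechanism described above: default_partition saves the forward activations, while a rematerializing partitioner such as min_cut_rematerialization_partition can recompute them in the backward instead. The identity "compilers" simply return the traced FX graphs.

import torch
from functorch.compile import aot_function, default_partition

def fn(carry, x):
  return torch.sin(carry) + x

def identity_compiler(gm, example_inputs):
  return gm  # hand back the torch.fx.GraphModule unchanged

fn_compiled = aot_function(
    fn,
    fw_compiler=identity_compiler,
    bw_compiler=identity_compiler,
    partition_fn=default_partition,
)
out = fn_compiled(torch.randn(3, requires_grad=True), torch.randn(3))
out.sum().backward()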
@tengyifei
Collaborator Author

Thanks for the review. I'll merge this first. I'm looking at caching the fn computation properly, but that will take some time. I might be able to send that PR over (and the SPMD one) by next Monday.

@tengyifei tengyifei merged commit 51575db into master Nov 23, 2024
12 checks passed
@ydwu4
Contributor

ydwu4 commented Dec 16, 2024

Hi! PyTorch also has a scan operator https://github.com/pytorch/pytorch/blob/main/torch/_higher_order_ops/scan.py. Wondering if we want to consolidate the efforts and what's the plan going forward?

@tengyifei
Collaborator Author

@ydwu4 does the PyTorch scan operator support autograd? I remember there's an issue tracking autograd support. It would be great to use the upstream op, but without autograd support we can't use it for training.

@ydwu4
Contributor

ydwu4 commented Dec 16, 2024

Sounds good. On the PyTorch side, we've been prioritizing getting inference working e2e and will get to autograd next half. It would be great if we could reduce fragmentation with a single front-end op.

@tengyifei
Collaborator Author

That's a great idea.

If you look at the current scan impl in PyTorch/XLA, it uses AOTAutograd to derive a backward graph to implement the backward pass of scan. That API has a lot of limitations, and IIUC Dynamo is the well-supported frontend.
