[RFC] scan operator and scan_layers #8620

@tengyifei

RFC: PyTorch/XLA scan operator and scan_layers

Problem statement

Many LLMs have a few dozen decoder layers that are applied in a for loop. When we trace the forward function of the model, the decoder layers are unrolled and inlined into one large computation. Compilation time may then scale linearly or super-linearly with the number of decoder layers, degrading the user experience.

We hope to introduce a mechanism that compiles a single layer and reuses it inside an XLA While op. The hypothesis is that compilation time will stay constant as the number of decoder layers increases. The proposed implementation will also reduce the number of times the decoder layer is traced, which reduces tracing overhead and improves the performance of PyTorch/XLA on TPUs with small per-device batch sizes. There is prior art in JAX called scan (https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html).

We also need to support backward (gradient) propagation through the scan operator, so that people can use scan during training.

Example usage

We have prototyped two functions, scan and scan_layers, to be elaborated in the next section. Users would generally use scan_layers, which wraps scan under the hood.

Given a sequence of layers self.layers, such as a torch.nn.ModuleList, rather than doing

hidden_states = inputs
for decoder_layer in self.layers:
  hidden_states = decoder_layer(hidden_states)
outputs = hidden_states

Users would write

# self.layers: Iterable[torch.nn.Module] (compatible with `torch.nn.ModuleList` and lists of modules etc.)
# inputs: torch.Tensor
outputs = scan_layers(self.layers, inputs)

where the scan_layers function will trace a single layer, then apply that computation sequentially over the layers, filling in different weights and biases, passing the output from the previous layer as input into the next layer. The scan_layers function will check that the layers have identical structure in terms of weights and biases:

  • They all have the same set of dictionary keys (parameter names)
  • The parameters have the same shapes.
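
As a rough illustration of the check above, the sketch below compares parameter names and shapes across layers. It is dependency-free: each layer is represented by a hypothetical dict from parameter name to shape tuple (with real modules one could build it from `named_parameters()`), and `check_same_structure` is an illustrative name, not part of the proposal.

```python
from typing import Dict, Sequence, Tuple

# Assumption: a layer is summarized as {parameter name: shape tuple}.
Shapes = Dict[str, Tuple[int, ...]]


def check_same_structure(layer_shapes: Sequence[Shapes]) -> None:
    """Raise ValueError unless every layer has the same parameter names and shapes."""
    if not layer_shapes:
        raise ValueError("expected at least one layer")
    first = layer_shapes[0]
    for i, shapes in enumerate(layer_shapes[1:], start=1):
        if shapes.keys() != first.keys():
            raise ValueError(f"layer {i} has different parameter names")
        for name, shape in shapes.items():
            if shape != first[name]:
                raise ValueError(
                    f"layer {i}: {name} has shape {shape}, expected {first[name]}")


linear = {"weight": (64, 64), "bias": (64,)}
check_same_structure([linear, dict(linear), dict(linear)])  # passes silently
```

A check like this only inspects parameters; it cannot tell whether two layers with matching parameters run different computations.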

If we want to further ensure that the layers perform identically structured computations, we can trace each layer and compare the resulting HLO. Tracing should be faster than end-to-end compilation, but this check would trace once per layer, i.e. n times for a sequence of n layers.

High level design

We factor this problem down into two operations: scan and scan_layers.

scan_layers

We have a sequence of layers with identical structure (such as a bunch of nn.Linear(128, 128) or decoder layers). We would like to use the XLA While op to loop over and apply the layers. Specifically, in every iteration we:

  • Index into the layers and obtain layer i.
  • Pass the hidden state from the previous layer as input to run the layer.
  • Obtain the output and use that as the hidden state/input in the next iteration.

The challenge is figuring out how to express this in terms of XLA ops. The type system of XLA as specified in its operation semantics supports tensors and tuples. One cannot for example define a C-style structure with named fields and define HLO operations on those. We need to find an appropriate representation of a sequence of torch.nn.Modules in terms of these data types so that the body computation inside the While op can obtain a specific layer using a scalar index.

We propose to stack the weights and biases of these layers into larger tensors where the index of the layer is given by the first dimension. The module.named_parameters() method lets us obtain all the parameters of a module as (name, parameter) pairs, which we can collect into a dict. Let's say we have 3 linear layers and their state is:

[
  { "weight": torch.empty(64, 64), "bias": torch.empty(64) },
  { "weight": torch.empty(64, 64), "bias": torch.empty(64) },
  { "weight": torch.empty(64, 64), "bias": torch.empty(64) },
]

We'll stack them into:

{
  "weight": torch.empty(3, 64, 64), "bias": torch.empty(3, 64),
}
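
A minimal sketch of the stacking step, assuming nested Python lists stand in for tensors so that the example runs without PyTorch. `stack_params` and `slice_params` are hypothetical helper names; with real tensors the stacking would presumably be one `torch.stack` per parameter name.

```python
from typing import Any, Dict, List


def stack_params(per_layer: List[Dict[str, Any]]) -> Dict[str, List[Any]]:
    """Stack per-layer parameters so that stacked[name][i] belongs to layer i."""
    names = per_layer[0].keys()
    return {name: [params[name] for params in per_layer] for name in names}


def slice_params(stacked: Dict[str, List[Any]], i: int) -> Dict[str, Any]:
    """Recover the parameters of layer i by indexing the leading dimension."""
    return {name: leaf[i] for name, leaf in stacked.items()}


# Three tiny "linear layers" with 2x2 weights and length-2 biases.
layers = [
    {"weight": [[float(k), 0.0], [0.0, float(k)]], "bias": [float(k), float(k)]}
    for k in range(3)
]
stacked = stack_params(layers)
```

Indexing the leading dimension of every leaf with the same `i` gives back exactly the parameters of layer `i`, which is what the loop body needs.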

If you're familiar with the JAX scan function, you may notice that this data structure matches the format expected by jax.lax.scan, which takes a PyTree where leaves are tensors, and calls a user supplied function with a slice of the PyTree, indexing into the leading dimension of each leaf. So e.g. if the dict above is param, it would call the user function fn with

fn({
  "weight": param["weight"][i],
  "bias": param["bias"][i],
})

etc. We'll implement scan to similarly support generalizing to arbitrary PyTrees. Because the XLA type system only consists of tensors and tuples, scan needs to flatten the dictionary into a list of tensors and supply those as parameters to the XLA computation, and similarly for the output. This way scan_layers can pass this stacked tensor dictionary as input. The fn supplied by scan_layers will rebuild a Linear, plugging in the parameters at layer i, and invoke the layer on the inputs. Fortunately, PyTorch already provides enough utilities around module parameter handling that we can implement this with minimal hassle. See the scan_layers implementation.
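
The flattening step can be pictured with a small hand-rolled PyTree helper. This is only a sketch under the assumption that dicts and tuples are containers and everything else is a leaf; `tree_flatten` and `tree_unflatten` here are local illustrative functions (PyTorch ships comparable utilities in `torch.utils._pytree`), not necessarily the ones the prototype uses.

```python
from typing import Any, List, Tuple


def tree_flatten(tree: Any) -> Tuple[List[Any], tuple]:
    """Return (leaves, spec); dict keys are visited in sorted order."""
    if isinstance(tree, dict):
        keys = sorted(tree)
        parts = [tree_flatten(tree[k]) for k in keys]
        leaves = [leaf for part_leaves, _ in parts for leaf in part_leaves]
        return leaves, ("dict", keys, [spec for _, spec in parts])
    if isinstance(tree, tuple):
        parts = [tree_flatten(child) for child in tree]
        leaves = [leaf for part_leaves, _ in parts for leaf in part_leaves]
        return leaves, ("tuple", [spec for _, spec in parts])
    return [tree], ("leaf",)  # anything else (e.g. a tensor) is a leaf


def tree_unflatten(leaves: List[Any], spec: tuple) -> Any:
    """Rebuild the original structure from leaves produced by tree_flatten."""
    remaining = iter(leaves)

    def build(node: tuple) -> Any:
        if node[0] == "leaf":
            return next(remaining)
        if node[0] == "dict":
            _, keys, child_specs = node
            return {k: build(s) for k, s in zip(keys, child_specs)}
        _, child_specs = node
        return tuple(build(s) for s in child_specs)

    return build(spec)


# Lists stand in for tensors here.
params = {"weight": [[1.0, 2.0]], "bias": [0.5]}
leaves, spec = tree_flatten(params)
```

The flat `leaves` list is what would be handed to the XLA computation as parameters, and `spec` is kept on the Python side to rebuild the dict for `fn`.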

scan

The design of scan follows the JAX version very closely. I'll describe the noteworthy things when interfacing with HLO/XLA. At a high level this function will:

  • Trace the user supplied fn using fake inputs (torch.empty tensors) to obtain an XlaComputation object.

  • Inspect the computation to get all the referenced tensors (xla::device_data nodes) and their ordering when supplied as parameters. There may be more parameters than in the method arguments of fn, because fn internally may capture more tensors in the function closure, or it may create more tensors during the execution. Example:

      def fn(carry, x):
        foo = torch.zeros(8)
        return carry, x + foo
    

    Gets lowered into an HLO like this:

      HloModule FnComputation, entry_computation_layout={
        (f32[8], f32[8], f32[8]) -> (f32[8], f32[8])
      }

    One of the `f32[8]` parameters corresponds to the `foo` tensor created within the body of `fn`. We need to identify that parameter and supply the correct tensor value when building the computation.

  • Build a While op with a cond_fn and a body_fn. The cond_fn checks whether the remaining iteration count (represented as another parameter to the computation) has reached zero, at which point the loop exits. The body_fn calls the fn computation with the correct parameter ordering and supplies additional tensor parameters as necessary.
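
The loop structure above can be modeled in plain Python. In this sketch `while_loop` is a stand-in for the XLA While op and the loop state is a tuple, mirroring XLA's tensor-and-tuple type system; the names and the count-down counter are assumptions made for illustration. A real While op requires the loop state to keep a fixed shape, so an actual lowering would write outputs into a preallocated buffer instead of growing a list.

```python
from typing import Any, Callable, List, Tuple

# Loop state: (remaining iterations, carry, xs, ys collected so far).
State = Tuple[int, Any, List[Any], List[Any]]


def while_loop(cond_fn: Callable[[State], bool],
               body_fn: Callable[[State], State],
               state: State) -> State:
    """Plain-Python stand-in for the XLA While op."""
    while cond_fn(state):
        state = body_fn(state)
    return state


def scan_via_while(fn, init, xs):
    n = len(xs)

    def cond_fn(state: State) -> bool:
        remaining, _, _, _ = state
        return remaining > 0  # exit once the counter reaches zero

    def body_fn(state: State) -> State:
        remaining, carry, xs, ys = state
        i = n - remaining  # index of the slice for this iteration
        carry, y = fn(carry, xs[i])
        return remaining - 1, carry, xs, ys + [y]

    _, carry, _, ys = while_loop(cond_fn, body_fn, (n, init, xs, []))
    return carry, ys


result = scan_via_while(lambda c, x: (c + x, c * x), 0, [1, 2, 3])
# result == (6, [0, 2, 9])
```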

The backward pass of scan is also implemented using scan. At a high level, we scan the backward version of fn in reverse order, from the last input to the first. There are several ways to extract the backward of fn. We'll start with AOTAutograd and explore using Dynamo to extract the backward.
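
To illustrate the idea of "backward as a reversed scan", here is a small scalar sketch. The backward function `bwd` is derived by hand for one specific `fwd`, whereas the proposal would extract it automatically (e.g. via AOTAutograd); the reference `scan` loop, the choice to emit the incoming carry as the per-step output, and the loss (the final carry only) are all assumptions made to keep the example short.

```python
def scan(fn, init, xs):
    """Reference loop with the same semantics as the proposed scan."""
    carry, ys = init, []
    for x in xs:
        carry, y = fn(carry, x)
        ys.append(y)
    return carry, ys


def fwd(carry, x):
    new_carry = carry * x
    # Emit the incoming carry so the backward pass can reuse it.
    return new_carry, carry


def bwd(grad_carry, saved):
    carry_in, x = saved
    # d(new_carry)/d(carry_in) = x and d(new_carry)/dx = carry_in.
    return grad_carry * x, grad_carry * carry_in


def grad_of_final_carry(init, xs):
    final, carries_in = scan(fwd, init, xs)
    saved = list(zip(carries_in, xs))
    # Scan the backward function from the last step to the first.
    grad_init, grad_xs_reversed = scan(bwd, 1.0, saved[::-1])
    return final, grad_init, grad_xs_reversed[::-1]


final, grad_init, grad_xs = grad_of_final_carry(2.0, [3.0, 4.0, 5.0])
# final == 120.0, grad_init == 60.0, grad_xs == [40.0, 30.0, 24.0]
```

Since the final carry is the product `init * x0 * x1 * x2`, each gradient equals the product divided by the corresponding input, which matches the values above.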

Interface design

def scan(
    fn: Callable[[Carry, X], tuple[Carry, Y]],
    init: Carry,
    xs: X,
) -> tuple[Carry, Y]:
  """Apply a function over leading dimension of tensors while carrying along state.
  
  This is similar to the JAX `jax.lax.scan` function found in [1].
  
  You may use it to loop over the leading dimension of tensors efficiently. If `xs`
  is a single tensor, this function is roughly equal to the following Python code:

    def scan(fn, init, xs):
      ys = []
      carry = init
      for i in range(xs.size(0)):
        carry, y = fn(carry, xs[i])
        ys.append(y)
      return carry, torch.stack(ys, dim=0)
  
  In the general case, `Carry`, `X`, and `Y` can be arbitrary PyTrees. This function
  will iterate through the leading dimension of every leaf element of `xs` simultaneously,
  and pass a slice of those elements to `fn` as another PyTree. This means you may
  scan over multiple tensors and produce multiple output tensors at once.
  
  Args:

    fn: a Python callable that accepts two PyTrees of tensors: the carry object and the
        slices of `xs` along its leading dimension. It should return two PyTrees: the carry
        object and the slices of the output. The returned carry object will be passed to
        the next invocation of `fn`.

    init: the initial carry object passed to the first invocation of `fn`.
    
    xs: the input PyTree to scan over. If `xs` is a tensor, then `fn` will get slices along
        the leading dimension (`xs[i]`). If `xs` is some other PyTree (e.g. tuple of
        tensor), `fn` will get PyTrees of slices. In that case the leading dimension size
        of the leaves in the PyTree must be the same.

  Returns:

    (carry, ys): A tuple where `carry` is the last carry object returned by `fn`, and
    `ys` is a PyTree with the same structure as the second value returned by `fn`,
    but where the leaves are formed by stacking the per-iteration leaf outputs of `fn`
    along a new leading dimension. This means if your `fn` returns
    `(carry, (y1, y2))` then this function will return
    `(carry, (torch.stack(all_y1), torch.stack(all_y2)))`.

  [1]: https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html
  """
  ...
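
To make the PyTree behavior concrete, here is a reference loop for the case where `xs` is a tuple of sequences and `fn` returns a tuple of per-step outputs. Plain lists stand in for tensors and `scan_tuple` is a hypothetical name; it only models the semantics described in the docstring, not the XLA lowering.

```python
def scan_tuple(fn, init, xs):
    """Reference loop where `xs` is a tuple of equal-length sequences."""
    lengths = {len(leaf) for leaf in xs}
    if len(lengths) != 1:
        raise ValueError("all leaves of xs must share the same leading dimension")
    carry, outputs = init, []
    for slices in zip(*xs):  # one slice of every leaf per iteration
        carry, y = fn(carry, slices)
        outputs.append(y)
    # Transpose the per-step tuples into a tuple of stacked lists.
    ys = tuple(list(column) for column in zip(*outputs))
    return carry, ys


def step(carry, slices):
    a, b = slices
    return carry + a * b, (a + b, a - b)


result = scan_tuple(step, 0, ([1, 2, 3], [4, 5, 6]))
# result == (32, ([5, 7, 9], [-3, -3, -3]))
```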


init = torch.tensor(0)
xs = torch.tensor([1, 2, 3])
fn = lambda carry, x: (carry + 1, x + carry)
scan(fn, init, xs)
# Returns:
# (3, [1, 3, 5])


def scan_layers(layers: Iterable[torch.nn.Module], input_data: torch.Tensor):
  """Applies each layer in `layers` to `input_data` sequentially.

  `input_data` is provided as input to the first layer in `layers`. The output of one
  layer is provided as input to next layer. This function is equivalent to

    sequential = torch.nn.Sequential(*layers)
    sequential(input_data)

  This function can be faster to compile since it reuses the XLA computation of the
  first layer to perform the computation of all other layers.
  """
  ...
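
As a sketch of how scan_layers could be expressed on top of scan, the example below stacks per-layer parameters and scans a functional layer over them, with the hidden state as the carry. Everything here is a stand-in: layers are plain parameter dicts rather than `torch.nn.Module` objects, `linear` is a hypothetical list-based layer, and `scan` is a reference loop. With real modules, something like `torch.func.functional_call` is one possible way to invoke a layer with substituted parameters.

```python
def scan(fn, init, xs):
    """Reference loop; `xs` is a dict whose leaves share a leading dimension."""
    n = len(next(iter(xs.values())))
    carry, ys = init, []
    for i in range(n):
        carry, y = fn(carry, {k: v[i] for k, v in xs.items()})
        ys.append(y)
    return carry, ys


def linear(params, x):
    """Compute W @ x + b for a nested-list W and a list b."""
    return [sum(w * xi for w, xi in zip(row, x)) + b
            for row, b in zip(params["weight"], params["bias"])]


def scan_layers(per_layer_params, input_data):
    # Stack parameters so the leading dimension is the layer index.
    stacked = {name: [p[name] for p in per_layer_params]
               for name in per_layer_params[0]}

    def fn(carry, layer_params):
        # Apply the current layer to the hidden state; no per-step output.
        return linear(layer_params, carry), None

    output, _ = scan(fn, input_data, stacked)
    return output


layers = [
    {"weight": [[1.0, 1.0], [0.0, 1.0]], "bias": [0.0, 1.0]},
    {"weight": [[2.0, 0.0], [0.0, 2.0]], "bias": [1.0, 0.0]},
]
out = scan_layers(layers, [1.0, 2.0])
# out == [7.0, 6.0], the same as applying the two layers in a for loop
```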

Alternatives

An alternative implementation of scan_layers is to combine the weights and biases of different layers into an XLA Tuple, as opposed to stacking them into a Tensor. The body computation of the While op will index into the Tuple using get_tuple_element as opposed to indexing into the Tensor using dynamic_slice. This implies that scan_layers won't use scan, which is designed around indexing into tensors.

I have not prototyped this approach as I'm not sure if XLA supports nested tuples of distinctly shaped tensors. One may suspect that the Tuple approach performs better than the Tensor approach because it avoids stacking tensors into a larger tensor, but given that XLA backends aggressively optimize memory layouts, I'm not sure how much this will pay off.

One advantage of the main proposal is that we'll expose a familiar scan operator that has near-identical semantics to the scan operator found in JAX, lowering the learning barrier.

In any case, these are things we can optimize in future versions of PyTorch/XLA without changing the signature of scan_layers.
