Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Unity][Transform] Implement relax.transform.AdjustMatmulOrder #16314

Merged
merged 7 commits into from
Jan 12, 2024

Conversation

Lunderberg
Copy link
Contributor

Reorder x*(A*B) to (x*A)*B. Intended for optimization of LoRA models, for which (x*A)*B has a much smaller memory footprint.

Whether an optimizations should be performed may depend on when the
variables in an expression are known.

For example, consider a LoRA-adjusted model, with base weights `W` of
shape `[m,n]`, LoRA components `A` and `B` with shapes `[r,n]` and
`[m,r]` respectively, and activations `x` with shape `[n,1]`.  The
LoRA-adjusted matmul could be computed either as `(W + B*A)*x` or as
`(W*x + B*(A*x))`.

If `A` and `B` are provided at run-time, then computing `(W +
B*(A*x))` requires significantly fewer computations.

* `(W + B*A)*x`: `m*n*(2*r + 3)` operations
  1. `B*A`: `2*m*n*r` operations using a naive matmul
  2. Adding `W` to (1): `m*n` operations
  3. Multiplying `x` by (2): `2*m*n` operations

* `(W*x + B*(A*x))`: (2*m*n + r*(2*n + 2*m + 1))
  1. `W*x`: `2*m*n` operations
  2. `A*x`: `2*r*n` operations
  3. Multiplying `B` by (2): `2*m*r` operations
  4. Adding (1) and (3)`: `m` operations

However, if `A` and `B` are known at compile-time, then computing `(W
+ B*A)*x` groups all compile-time values together, allowing them to be
computed earlier (i.e. using `LiftTransformParams`)

* `(W + B*A)*x`: `2*m*n` operations
  1. `B*A`: 0 operations, computed at compile-time
  2. Adding `W` to (1): 0 operations, computed at compile-time
  3. Multiplying `x` by (2): `2*m*n` operations

Since the choice of optimized expression depends on which parameters
can be computed at compile-time, it is useful to have a utility that
identifies values that can be computed at compile-time.
- Update the zero-parameter `WildcardPattern` constructor to produce a
  valid instance.  Previously, the zero-parameter constructor produced
  a null instance of `WildcardPattern`, which resulted in an error
  when used.  The `WildcardPattern` was expected to be constructed
  through the `Wildcard` function instead.  Since all other
  `DFPattern` child classes could be constructed explicitly, this
  could lead to unexpected outcomes.

- Check for `pattern.defined()` when performing a pattern-match.  If
  a null instance of a pattern is provided, this gives an error
  message with more context than the one raised by `DFPatternFunctor`.

- Expose `RewriteCall` for use in C++.  Previously, it had only been
  exposed through the FFI registry, and had no declaration in a header
  file.
Reorder `x*(A*B)` to `(x*A)*B`.  Intended for optimization of LoRA
models, for which `(x*A)*B` has a much smaller memory footprint.
@Lunderberg
Copy link
Contributor Author

This PR is currently a draft, as it depends on functionality introduced in #16312.

@Lunderberg Lunderberg marked this pull request as ready for review January 4, 2024 18:22
@Lunderberg
Copy link
Contributor Author

The prerequisite #16312 has now landed, and this PR is ready for review.

Copy link
Member

@Hzfengsy Hzfengsy left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the PR, which would help on the specific case. However, I think it's necessary to analyze the memory footprint before rewriting it.

it would be good to rewrite x * (A * B) -> (x * A) * B, but not reasonable to rewrite `A * (B * x) -> (A * B) * x

@Lunderberg
Copy link
Contributor Author

Good point. I've updated the pattern and the rewrite rule with two changes:

  1. Check if (A*B)*C should be re-ordered into A*(B*C). Previously, only the reverse was checked.
  2. Only reorder the matmuls in either case when there's a provable benefit to doing so, based on the shapes of the three matrices.

I also added several additional unit tests to validate the behavior, for cases where the initial matmul is left-to-right, and where the initial matmul is right-to-left.

@Lunderberg
Copy link
Contributor Author

@Hzfengsy All changes based on your earlier comments have been made, and the PR is ready for re-review.

Copy link
Member

@Hzfengsy Hzfengsy left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. Thanks for the great work!

@Hzfengsy Hzfengsy merged commit 4c7c010 into apache:unity Jan 12, 2024
2 checks passed
@Lunderberg Lunderberg deleted the unity_adjust_matmul_order branch January 12, 2024 14:14
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants