[Unity][Transform] Implement relax.transform.AdjustMatmulOrder #16314
Whether an optimization should be performed may depend on when the variables in an expression are known. For example, consider a LoRA-adjusted model, with base weights `W` of shape `[m,n]`, LoRA components `A` and `B` with shapes `[r,n]` and `[m,r]` respectively, and activations `x` with shape `[n,1]`. The LoRA-adjusted matmul could be computed either as `(W + B*A)*x` or as `(W*x + B*(A*x))`.

If `A` and `B` are provided at run-time, then computing `(W*x + B*(A*x))` requires significantly fewer computations.

* `(W + B*A)*x`: `m*n*(2*r + 3)` operations
  1. `B*A`: `2*m*n*r` operations using a naive matmul
  2. Adding `W` to (1): `m*n` operations
  3. Multiplying `x` by (2): `2*m*n` operations
* `(W*x + B*(A*x))`: `2*m*n + 2*r*(n + m) + m` operations
  1. `W*x`: `2*m*n` operations
  2. `A*x`: `2*r*n` operations
  3. Multiplying `B` by (2): `2*m*r` operations
  4. Adding (1) and (3): `m` operations

However, if `A` and `B` are known at compile-time, then computing `(W + B*A)*x` groups all compile-time values together, allowing them to be computed earlier (i.e. using `LiftTransformParams`).

* `(W + B*A)*x`: `2*m*n` operations
  1. `B*A`: 0 operations, computed at compile-time
  2. Adding `W` to (1): 0 operations, computed at compile-time
  3. Multiplying `x` by (2): `2*m*n` operations

Since the choice of optimized expression depends on which parameters can be computed at compile-time, it is useful to have a utility that identifies values that can be computed at compile-time.
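The operation counts above can be checked with a small calculator. This is only an illustrative sketch in plain Python, not part of the PR: the function names are hypothetical, and the counts assume a naive matmul costing two operations per multiply-accumulate, as in the description.

```python
def ops_merged_runtime(m, n, r):
    """(W + B*A)*x when A and B are only known at run-time."""
    b_times_a = 2 * m * n * r   # B*A: [m,r] x [r,n], naive matmul
    add_w = m * n               # W + (B*A): elementwise add over [m,n]
    times_x = 2 * m * n         # (W + B*A)*x: [m,n] x [n,1]
    return b_times_a + add_w + times_x


def ops_split_runtime(m, n, r):
    """W*x + B*(A*x) when A and B are only known at run-time."""
    w_times_x = 2 * m * n       # W*x: [m,n] x [n,1]
    a_times_x = 2 * r * n       # A*x: [r,n] x [n,1]
    b_times_ax = 2 * m * r      # B*(A*x): [m,r] x [r,1]
    add = m                     # adding two [m,1] vectors
    return w_times_x + a_times_x + b_times_ax + add


def ops_merged_compile_time(m, n):
    """(W + B*A)*x when W + B*A is precomputed at compile-time."""
    return 2 * m * n            # only the final matmul with x remains


if __name__ == "__main__":
    # Hypothetical sizes, chosen only for illustration.
    m, n, r = 4096, 4096, 16
    print("merged, run-time A/B:    ", ops_merged_runtime(m, n, r))
    print("split, run-time A/B:     ", ops_split_runtime(m, n, r))
    print("merged, compile-time A/B:", ops_merged_compile_time(m, n))
```

For small `r`, the split form is much cheaper than the merged form when `A` and `B` arrive at run-time, while the merged form is cheapest of all once `W + B*A` can be precomputed.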
- Update the zero-parameter `WildcardPattern` constructor to produce a valid instance. Previously, the zero-parameter constructor produced a null instance of `WildcardPattern`, which resulted in an error when used. The `WildcardPattern` was expected to be constructed through the `Wildcard` function instead. Since all other `DFPattern` child classes could be constructed explicitly, this could lead to unexpected outcomes.
- Check for `pattern.defined()` when performing a pattern-match. If a null instance of a pattern is provided, this gives an error message with more context than the one raised by `DFPatternFunctor`.
- Expose `RewriteCall` for use in C++. Previously, it had only been exposed through the FFI registry, and had no declaration in a header file.
Reorder `x*(A*B)` to `(x*A)*B`. Intended for optimization of LoRA models, for which `(x*A)*B` has a much smaller memory footprint.
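To illustrate why the reordering is safe and why it shrinks the intermediate, here is a minimal pure-Python sketch. It is not the Relax transform itself: the `matmul` and `numel` helpers and the shapes `x: [1,n]`, `A: [n,r]`, `B: [r,m]` are assumptions made for the example.

```python
def matmul(p, q):
    """Naive matmul on lists of lists: p is [a,b], q is [b,c]."""
    return [[sum(u * v for u, v in zip(row, col)) for col in zip(*q)]
            for row in p]


def numel(mat):
    """Number of elements in a 2-D list-of-lists matrix."""
    return len(mat) * len(mat[0])


# Hypothetical small shapes; integer entries keep the comparison exact.
n, r, m = 6, 2, 5
x = [[i + 1 for i in range(n)]]                               # [1, n]
A = [[(i + j) % 3 for j in range(r)] for i in range(n)]       # [n, r]
B = [[(i * j + 1) % 4 for j in range(m)] for i in range(r)]   # [r, m]

AB = matmul(A, B)            # intermediate of x*(A*B): shape [n, m]
right_first = matmul(x, AB)

xA = matmul(x, A)            # intermediate of (x*A)*B: shape [1, r]
left_first = matmul(xA, B)

# Matmul is associative, so both orders give the same result.
assert right_first == left_first

print("intermediate elements, x*(A*B):", numel(AB))
print("intermediate elements, (x*A)*B:", numel(xA))
```

The first order materializes an `[n, m]` intermediate, while the second only materializes a `[1, r]` intermediate, which is where the smaller footprint comes from when `r` is small.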
This PR is currently a draft, as it depends on functionality introduced in #16312.
The prerequisite #16312 has now landed, and this PR is ready for review.
Thanks for the PR, which would help on the specific case. However, I think it's necessary to analyze the memory footprint before rewriting it.
It would be good to rewrite `x * (A * B) -> (x * A) * B`, but not reasonable to rewrite `A * (B * x) -> (A * B) * x`.
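The asymmetry between these two rewrites can be seen by comparing the size of the intermediate each order produces. The sketch below is a shape-only illustration of that reasoning under assumed static shapes; the function names are hypothetical and it is not a description of how the PR decides.

```python
def intermediate_sizes(shape_p, shape_q, shape_s):
    """Element counts of the intermediate for each order of P*Q*S.

    shape_p = (a, b), shape_q = (b, c), shape_s = (c, d).
    """
    (a, b), (b2, c), (c2, d) = shape_p, shape_q, shape_s
    if b != b2 or c != c2:
        raise ValueError("incompatible matmul shapes")
    return {
        "left_first": a * c,   # (P*Q)*S materializes an [a, c] matrix
        "right_first": b * d,  # P*(Q*S) materializes a [b, d] matrix
    }


def preferred_order(shape_p, shape_q, shape_s):
    """Pick the order whose intermediate is smaller (ties go left-first)."""
    sizes = intermediate_sizes(shape_p, shape_q, shape_s)
    return min(sizes, key=sizes.get)


# Hypothetical LoRA-like sizes with a small rank r.
n, r, m = 4096, 16, 4096

# x * A * B with x: [1,n], A: [n,r], B: [r,m]
print(preferred_order((1, n), (n, r), (r, m)))

# A * B * x with A: [m,r], B: [r,n], x: [n,1]
print(preferred_order((m, r), (r, n), (n, 1)))
```

With the activation on the left, evaluating left-to-right keeps the intermediate small; with the activation on the right, the right-to-left order should be preserved instead.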
Good point. I've updated the pattern and the rewrite rule with two changes:
I also added several additional unit tests to validate the behavior, for cases where the initial matmul is left-to-right, and where the initial matmul is right-to-left.
@Hzfengsy All changes based on your earlier comments have been made, and the PR is ready for re-review.
LGTM. Thanks for the great work!