[TIR] Support PrimFunc-to-PrimFunc calls with primitive arguments #14862

Open · wants to merge 20 commits into base: main

Conversation

Lunderberg
Contributor

This PR allows TIR PrimFuncs to call subroutines located within the same IRModule. The immediate goal is to make it easier to hand-write TIR that represents the output of possible optimizations, in order to determine whether those optimizations would be worth implementing. A longer-term goal is to enable multi-device scheduling, such as tensor-parallel compute, in which a single TIR function may be scheduled to delegate portions of the compute across multiple devices.

The use cases enabled by this PR are shown in the unit tests in tests/python/unittest/test_tir_subroutine_call.py, which is the best place to start for reviewing. (And both thanks and apologies in advance to reviewers for the larger PR!) The changes enabling these use cases are summarized below:
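As a rough sketch of the feature (illustrative only, not copied from those tests; the typed `T.handle("float32")` parameter and the `Module.subroutine` reference assume the unity-style TVMScript syntax mentioned in the list below):

```python
import tvm
from tvm.script import ir as I, tir as T

@I.ir_module
class Module:
    @T.prim_func
    def main(A: T.Buffer((16,), "float32")):
        # `Module.subroutine` parses as a GlobalVar referring to the
        # callee within the same IRModule.
        for i in range(16):
            Module.subroutine(A.data, i)

    @T.prim_func
    def subroutine(A_data: T.handle("float32"), i: T.int32):
        # The callee receives primitive arguments (a data pointer and
        # an index) rather than a full PackedFunc-style signature.
        A = T.decl_buffer((16,), "float32", data=A_data)
        A[i] = A[i] * 2.0
```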

  • Cherry-pick TVMScript changes from the unity branch to parse `module.func_name` as a GlobalVar

  • Update the TIR build flow to use only IRModule-to-IRModule transforms, without the intermediate `Map<Target, IRModule>`. This ensures that each pass can identify the callee within the same IRModule. External functions that accept a `Map<Target, IRModule>` as input normalize it before lowering.

  • Refactor SplitHostDevice to output PrimFunc-to-PrimFunc calls, with a subsequent LowerDeviceKernelLaunch pass that generates compute kernel launches and handles any user-written host-to-device calls (see the sketch after this list).

  • Update several passes to accept input PrimFuncs without a kGlobalSymbol attribute

  • Update LLVM codegen to handle GlobalVar as the CallNode::op, producing calls to the subroutine.
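As a sketch of how the refactored flow fits together (assumes these passes are exposed under `tvm.tir.transform`, with a purely illustrative module):

```python
import tvm
from tvm.script import ir as I, tir as T

@I.ir_module
class Module:
    @T.prim_func
    def main(A: T.Buffer((16,), "float32")):
        T.func_attr({"global_symbol": "main", "target": T.target("llvm")})
        for i in range(16):
            A[i] = 0.0

# SplitHostDevice extracts device regions into separate PrimFuncs and
# emits PrimFunc-to-PrimFunc calls to them from the host function.
mod = tvm.tir.transform.SplitHostDevice()(Module)
# LowerDeviceKernelLaunch then rewrites calls to externally exposed
# device kernels into actual kernel launches.
mod = tvm.tir.transform.LowerDeviceKernelLaunch()(mod)
```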

@tvm-bot
Collaborator

tvm-bot commented May 16, 2023

Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from Reviewers by @-ing them in a comment.

Generated by tvm-bot

@Lunderberg
Contributor Author

The majority of this PR has now been split out and is waiting for the components to pass CI. Of the initial ~2.5k lines of changes, about 1k have landed so far and 1.6k are currently making their way through CI, leaving about 500 lines in the current PR. The increased line count when splitting out the changes was largely due to increased test coverage, and was well worth the effort.

Commits 721da73 and b4025fd probably deserve to be their own PRs as well, but they depend on the current batch so they can't easily be split out yet.

Prior to this commit, a build that used multiple targets needed to
provide `tvm::build` with a `Map<Target, IRModule>` specifying which
target should be used to compile each `IRModule`.  As a result,
lowering passes could not introduce new targets based on a PrimFunc's
content (e.g. a `with T.target()` frame to delegate out to another
device), nor simplify based on cross-device subroutines (e.g. simplify
a host-side conditional based on the known output of a device-side
internal subroutine).

This commit makes the `tvm::attr::kTarget` attribute (`"target"`) the single source of truth for where a `PrimFunc` will be executed. Other existing methods for specifying the target (the `target` parameter of `tvm.build`, the keys of a `Map<Target, IRModule>`, the parameter to the `tir::transform::BindTarget` pass) are still accepted as inputs and may provide a default value for `tvm::attr::kTarget` when the attribute is missing, but they may not overwrite an existing target attribute.

This is part of a series of commits to simplify the handling of
multi-target builds.
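A sketch of the intended precedence (illustrative only; assumes `BindTarget` behaves as described above):

```python
import tvm
from tvm.script import ir as I, tir as T

@I.ir_module
class Module:
    @T.prim_func
    def has_target(A: T.Buffer((1,), "float32")):
        # Explicit attribute: the single source of truth.
        T.func_attr({"target": T.target("cuda")})
        A[0] = 0.0

    @T.prim_func
    def no_target(A: T.Buffer((1,), "float32")):
        A[0] = 0.0

# Supplies a default target for `no_target`, but may not overwrite the
# existing "cuda" target on `has_target`.
mod = tvm.tir.transform.BindTarget(tvm.target.Target("llvm"))(Module)
```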
Otherwise, in the case of a custom codegen, the device specification may be dropped entirely.

Currently, the Target serves two independent purposes: (1) defining which device owns the buffers that are passed as input to a PrimFunc, and (2) defining which codegen will be used for a PrimFunc.

Prior to this commit, the "ext_dev" target was required to define the
device ownership, but did not provide the `"target.build.ext_dev"`
function that is required for codegen. This worked because
`SplitHostDevice` would remove the `"ext_dev"` target without making a
device-side function.  With the single-module lowering flow, the
separate device-side function is required to support UMA codegen.

To resolve this issue, `"ext_dev"` now provides a codegen function,
which is identical to the LLVM codegen.  This may be improved in the
future by allowing the buffer device and the codegen to be specified
independently.
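The registration can be checked from Python; `tvm.build` dispatches codegen through global functions named `target.build.<kind>` (a sketch):

```python
import tvm

# Prior to this commit, no build function was registered for the
# "ext_dev" target kind, so codegen for it was impossible.
build_fn = tvm.get_global_func("target.build.ext_dev", allow_missing=True)
print("ext_dev codegen registered:", build_fn is not None)
```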
Previously, if no device-specific attribute was found, the entire function was assumed to execute on the device. Now, host-specific calls (e.g. `builtin::call_packed()`) are identified and required to remain on the host.
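For instance, a packed call like the one below must stay on the host even if the surrounding function is otherwise device code (a sketch; the callee name `example_logger` is purely illustrative):

```python
from tvm.script import tir as T

@T.prim_func
def func(A: T.Buffer((16,), "float32")):
    for i in range(16):
        A[i] = A[i] + 1.0
    # builtin::call_packed can only execute on the host, so this
    # statement must remain in the host-side function after splitting.
    T.evaluate(T.call_packed("example_logger", A.data))
```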
Since the ext_dev target may be compiled separately.

Because VTA's codegen is allowed to call host-side functions, it uses the `"cpu"` tag. Therefore, the allocations that are already handled with `VTABufferCPUPtr` should opt out of using the device API from `LowerTVMBuiltin`.

May be used by kernels to call device-specific intrinsics (e.g. for cmsis-nn).

The functionality tested in this commit was added across several recent PRs, each of which tested its features in isolation. This PR adds unit tests to validate the end-to-end behavior of TIR subroutine calls.

PRs building up to this point:

- TVMScript
  - apache#14889
  - apache#14915
  - apache#14919
  - apache#14941

- Functionality improvements of existing TIR passes
  - apache#14913
  - apache#14914
  - apache#14918
  - apache#14951

- Changes to the TIR lowering flow
  - apache#14942
  - apache#14985

- Codegen updates
  - apache#14958
  - apache#14901

- Compatibility updates/fixes
  - apache#14892
  - apache#14950
  - apache#14943
  - apache#14944
  - apache#14945
  - apache#14952
  - apache#14982
  - apache#14949

Now that the function return type is handled by `CodeGenC`, the docstring is updated to describe a usage other than the return type.

Calling a function annotated with `__global__` can be done from the GPU (see https://stackoverflow.com/a/39448797), but it requires a different calling convention.

An externally exposed function is lowered into a PackedFunc call, and calling a PackedFunc requires the caller to be on the host. In the future, this could be improved by a pass that identifies internal callers of an externally exposed callee and extracts an internal method that both the internal callers and the exposed wrapper can call.

Previously, all functions required the `tvm::attr::kCallingConv` attribute to be set to `CallingConv::kDeviceKernelLaunch` (2). Now, this is only required for externally exposed functions.
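In Python terms, a minimal sketch (the kernel body is a stand-in):

```python
import tvm
from tvm.script import tir as T

@T.prim_func
def kernel(A: T.Buffer((16,), "float32")):
    A[0] = 0.0

# CallingConv.DEVICE_KERNEL_LAUNCH is the Python-side spelling of
# CallingConv::kDeviceKernelLaunch (2); after this change it is needed
# only on externally exposed kernels, not on internal callees.
kernel = kernel.with_attr(
    "calling_conv", tvm.ir.CallingConv.DEVICE_KERNEL_LAUNCH
)
```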