TVM Vertical Integration with PyTorch #11911
Conversation
I suggest moving tutorials to a separate PR. Ideally, tutorials should demonstrate more realistic examples than vector add or matmul, i.e. something for which PyTorch users would reach for "custom op" authoring. For example, I think demonstrating the equivalent of the Triton fused softmax tutorial https://triton-lang.org/master/getting-started/tutorials/02-fused-softmax.html in this workflow would be very interesting.
After discussing with @yelite, we will drop the how-to guides and resubmit them in a separate PR afterward.
* The basic forward function calling TVM's runtime is provided.
* The TVM module can be serialized/deserialized as a Torch module.
*/
class GraphExecutorFactoryWrapper : public torch::jit::CustomClassHolder {
This looks similar to TvmGraphModulePack:

class TvmGraphModulePack {

Why do we need this?
One reason is that we don't want to use temp files to transmit data, as ByteDance's approach does, but use TVM's FFI instead. @yelite
Hi @masahi, there are several reasons we don't plan to use code from tvm_class.cc:

1. tvm_class.cc is complex while our code is more natural. For example, they maintain their own conversion from a torch tensor to DLPack, while we use torch's built-in library.
2. Our code is more readable. We have fewer functions but can cover tvm_class.cc's functionality. For example, we don't need an extra initialization function such as init or loadTVMmodule.
3. tvm_class.cc uses a temp file and an absolute path to transmit the TVM module, while we use TVM's FFI, which I believe is a better practice.
4. The most significant difference is the save/load functions. I tested that if we save a torch model via tvm_class.cc and then restart the Python kernel, we cannot load the model back successfully because of (3). Our code can save/load models anywhere, anytime, because we serialize/deserialize the whole runtime module (see the sketch after this list).
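A rough sketch of the save/load round trip claimed in point 4, assuming optimize_torch from this PR as the entry point; the torch.jit.script/save/load path, SimpleModel, and example_input are illustrative assumptions rather than the PR's actual tests.

```python
# Hedged sketch: save a TVM-optimized module, restart Python, and load it back.
import torch
from tvm.contrib.torch import optimize_torch

class SimpleModel(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.relu(x) + 1.0

example_input = torch.randn(8, 16)
# Compile/tune with TVM; the result behaves like a torch.nn.Module whose state
# carries the serialized TVM runtime module (no temp files or absolute paths).
tvm_module = optimize_torch(SimpleModel(), example_input)
torch.jit.save(torch.jit.script(tvm_module), "tvm_model.pt")

# ... restart the Python kernel ...
restored = torch.jit.load("tvm_model.pt")
out = restored(example_input)
```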
If GraphExecutorFactoryWrapper is strictly better than the existing one, I want to see the existing one removed or reimplemented in terms of GraphExecutorFactoryWrapper. But this can be done in a follow-up.
save_runtime_mod = get_global_func("tvmtorch.save_runtime_mod")
save_runtime_mod(executor_factory.module)

return GraphExecutorFactoryWrapper(torch.classes.tvm_tuning.GraphExecutorFactoryWrapper())
This looks strange... Why doesn't torch.classes.tvm_tuning.GraphExecutorFactoryWrapper() take any arguments?
The class GraphExecutorFactoryWrapper is a subclass of Torch's module, and Torch's FFI cannot recognize TVM's data structures, so we transmit the runtime module via TVM's FFI.
Concretely, in line 185 we store the module in memory.
When the constructor of GraphExecutorFactoryWrapper is called, it gets the TVM runtime module from memory.
The Python class GraphExecutorFactoryWrapper is just a wrapper of the output, because C++ doesn't support tuple unpacking but we do need this function in Python.
I see now that the compiled module is passed between Python and C++ PyTorch via thread-local storage (stored by save_runtime_mod).
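For context, the handoff described here might look like the following Python-side sketch; "tvmtorch.save_runtime_mod" and the argument-free constructor come from the diff quoted above, while the wrap_executor_factory helper name is hypothetical.

```python
# Illustrative sketch of passing the compiled module through TVM's FFI and
# thread-local storage instead of Torch's FFI (which cannot represent TVM objects).
import torch
import tvm

def wrap_executor_factory(executor_factory):
    # Hand the compiled tvm.runtime.Module to the C++ side, which stashes it
    # in thread-local storage ...
    save_runtime_mod = tvm.get_global_func("tvmtorch.save_runtime_mod")
    save_runtime_mod(executor_factory.module)
    # ... so the Torch custom class can be constructed with no arguments: its
    # C++ constructor reads the module back from that storage.
    return torch.classes.tvm_tuning.GraphExecutorFactoryWrapper()
```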
return self.rt_module.forward(torch_inputs)

def as_torch(func: Union[tvm.ir.module.IRModule, tvm.tir.function.PrimFunc, Callable]):
So as_torch doesn't provide tuning facilities? I noticed that all tuning tests in this PR are done via optimize_torch, which involves Relay. If a user wants to tune a TVMScript-written op and use the @as_torch decorator, how can tuning be done?
My understanding is that as_torch is just used to convert TVMScript to Torch.
Need to confirm with @yelite to see if we need to do more.
It's still possible for PT users to write TVMScript, use tune_tir to tune, and use as_torch to convert the tuned prim func to PT. We are offering optimize_torch to wrap tune_relay, so it would be nice if as_torch also wrapped tune_tir and automatically did tuning.
Current examples only show the usage of as_torch as a decorator on top of a manually written TVMScript without tuning.
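For illustration, the decorator usage mentioned in the last sentence might look like the sketch below. as_torch and its acceptance of a PrimFunc come from this PR; the vector_add kernel and the destination-passing calling convention (output tensor passed as the last argument) are assumptions.

```python
# Minimal sketch of @as_torch on a hand-written TVMScript PrimFunc, without tuning.
import torch
import tvm
from tvm.script import tir as T
from tvm.contrib.torch import as_torch

@as_torch
@T.prim_func
def vector_add(a: T.handle, b: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (128,), "float32")
    B = T.match_buffer(b, (128,), "float32")
    C = T.match_buffer(c, (128,), "float32")
    for i in T.serial(128):
        with T.block("add"):
            vi = T.axis.spatial(128, i)
            C[vi] = A[vi] + B[vi]

# vector_add now behaves like a torch.nn.Module wrapping the compiled kernel;
# here we assume the output buffer is passed explicitly (destination-passing style).
x, y, z = torch.rand(128), torch.rand(128), torch.zeros(128)
vector_add(x, y, z)
```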
I have added the tune_tir in as_torch.
How about this: we add a tune(config) method (with an optional config param like optimize_torch) on OperatorModuleWrapper, which does tuning and rebuilds the mod, and remove tune_tir from build(...). So by default tuning won't happen, but the user can explicitly ask to tune.
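A brief sketch of how this opt-in API might read from the user side; tune() and its optional config are only the suggestion above, not the merged implementation, and vector_add refers to the earlier sketch.

```python
# Hypothetical opt-in tuning on an OperatorModuleWrapper produced by @as_torch.
rt_mod = vector_add       # OperatorModuleWrapper from the @as_torch sketch above
rt_mod.tune()             # explicitly tune with a default configuration
# rt_mod.tune(config)     # or pass a tuning config, mirroring optimize_torch
rt_mod(x, y, z)           # later calls use the tuned, rebuilt module
```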
python/tvm/contrib/torch/as_torch.py (Outdated)
return sch

def build(self, target=None):
    tuned_module = self.tune_tir_auto(self.ir_module)
If the input TVMScript module doesn't have any tunable knobs, does this tune_tir_auto finish instantly? Tuning should be an opt-in feature.
python/tvm/contrib/torch/as_torch.py (Outdated)
mod = default_config.mod(mod)
target = default_config.target(target)

extracted_task = ExtractedTask(
I think there is always only one task, since it is tuning a single op.
"For optimal performance, it is recommended to provide", | ||
"the `tuning_config` argument with a bigger number of trials.", | ||
) | ||
warnings.warn(" ".join(warning_msg), stacklevel=2) |
It seems the default tuning config is dropped? @juda
It is moved to line 111 because we need to get extracted_tasks in advance.
* optimize_torch & as_torch * split files * code formatting * optimizing optimized_torch * scrap your boilerplate * as_torch polished * configuration fixed * Apply suggestions from code review Co-authored-by: Lite Ye <liteye859@gmail.com> * more document * file deleter * optimize deleter * drop how-to guides * clang-format-10 * formatter changes * reformat * reformat * reformat * reformatting * fixed * auto setting * fixed * split long string * tune_tir * upgrade as_torch * optimize as_torch * as_torch * fixed typo Co-authored-by: juda <yzhou@octoml.ai> Co-authored-by: Lite Ye <liteye859@gmail.com>
The pull request contains two functions:

- optimize_torch, a function similar to torch.jit.trace, which is used to optimize a torch.nn.Module with TVM MetaSchedule and returns a custom TorchScript operator.
- as_torch, a decorator, which is used to wrap TVMScript code into a torch.nn.Module.

The files consist of:
@yelite @junrushao1994 @masahi
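To summarize the two entry points, a compact, hedged sketch (exact signatures are assumptions; fuller examples appear earlier in the thread):

```python
# optimize_torch: trace-like API for whole torch.nn.Modules, tuned with MetaSchedule
# through Relay, returning a TorchScript-compatible module.
# as_torch: decorator that wraps a TVMScript PrimFunc/IRModule as a torch.nn.Module
# (see the vector_add sketch above).
import torch
from tvm.contrib.torch import optimize_torch, as_torch

compiled = optimize_torch(torch.nn.Linear(16, 16), torch.randn(4, 16))
```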