-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[AutoScheduler] Register workload when deserializing tasks #6927
Conversation
59b7525
to
2aabc3b
Compare
func_name = json.loads(self.workload_key)[0] | ||
except Exception: # pylint: disable=broad-except | ||
raise RuntimeError("Invalid workload key %s" % self.workload_key) | ||
register_workload_tensors(func_name, self.dag.tensors) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should do more checks here.
There are two kinds of workloads.
You can find the note here
https://github.com/apache/incubator-tvm/blob/51d81fb71781eb606f4b563ec28290dbd8d04faf/python/tvm/auto_scheduler/workload_registry.py#L42-L52
We have to handle two cases correctly without incorrect overwriting
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Changed the logic accordingly. Now we only register workloads extracted from Relay programs.
efe9ef6
to
ca1945e
Compare
* [AutoScheduler] Register workload when deserializing tasks * fix name * format * merge * fix test * more checks
* [AutoScheduler] Register workload when deserializing tasks * fix name * format * merge * fix test * more checks
* [AutoScheduler] Register workload when deserializing tasks * fix name * format * merge * fix test * more checks
One issue of the current auto_scheduler task serialization is that the workload won't be registered when deserializing a task so users have to manually call
register_workload_tensors
after deserialization. However,register_workload_tensors
always uses compute DAG hash as the workload key. This could cause workload key mismatching if the task was from a TE function that uses function name as the workload key. This PR makes the following changes to make sureWORKLOAD_FUNC_REGISTRY
is consist after deserializing a task.register_workload_tensors
by moving the ComputeDAG out.cc @merrymercy @jcf94