Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FIX][TVMC] Fix the mixed precision conversion pipeline #17520

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions python/tvm/driver/tvmc/autotuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -672,7 +672,7 @@ def autotvm_get_tuning_tasks(
"""
target, target_host = Target.canon_target_and_host(target, target_host)

mod = apply_graph_transforms(mod, transform_args)
mod = apply_graph_transforms(mod, transform_args, params)

tasks = autotvm.task.extract_from_program(
mod["main"],
Expand Down Expand Up @@ -718,7 +718,7 @@ def autoscheduler_get_tuning_tasks(
"""
target, target_host = Target.canon_target_and_host(target, target_host)

mod = apply_graph_transforms(mod, transform_args)
mod = apply_graph_transforms(mod, transform_args, params)

# Extract the tasks
tasks, task_weights = auto_scheduler.extract_tasks(
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/driver/tvmc/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,7 @@ def compile_model(
instruments=instruments,
):
transform_args = parse_graph_transform_args(locals())
mod = apply_graph_transforms(mod, transform_args)
mod = apply_graph_transforms(mod, transform_args, params)

for partition_function, opts in zip(partition_functions, partition_opts):
mod = partition_function(mod, params, mod_name=mod_name, **opts)
Expand Down
6 changes: 5 additions & 1 deletion python/tvm/driver/tvmc/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ def layout_helper(layout):
relay.transform.RemoveUnusedFunctions(),
relay.transform.ConvertLayout(desired_layouts),
relay.transform.FoldConstant(),
relay.transform.FoldScaleAxis(),
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just curious, why this change was made?

]
)

Expand All @@ -162,7 +163,7 @@ def layout_helper(layout):
raise TVMCException("Error converting layouts: {}".format(str(err)))


def apply_graph_transforms(mod, args):
def apply_graph_transforms(mod, args, params=None):
"""Alter the layout of the input graph.

Parameters
Expand All @@ -171,6 +172,8 @@ def apply_graph_transforms(mod, args):
The relay module to convert.
args : dict
The transform arguments.
params: dict
Module params

Returns
-------
Expand All @@ -188,6 +191,7 @@ def apply_graph_transforms(mod, args):

# ToMixedPrecision
if args.get("mixed_precision", False):
mod = relay.quantize.prerequisite_optimize(mod, params)
mod = convert_to_mixed_precision(
mod,
args.get("mixed_precision_ops"),
Expand Down
2 changes: 2 additions & 0 deletions tests/python/driver/tvmc/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,7 @@ def check(self, func):
"mixed_precision_calculation_type": "float16",
"mixed_precision_acc_type": "float16",
},
params,
)
ret = CheckOpMutator("float16", "float16", "nn.conv2d").check(mod["main"])
assert ret
Expand All @@ -240,6 +241,7 @@ def check(self, func):
"mixed_precision_calculation_type": "float16",
"mixed_precision_acc_type": "float32",
},
params,
)
ret = CheckOpMutator("float16", "float32", "nn.conv2d").check(mod["main"])
assert ret
Expand Down
1 change: 1 addition & 0 deletions tests/python/relay/opencl_texture/test_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def _test_mobilenet_v1(remote, target, calc_dtype, executor_type, acc_dtype):
"mixed_precision_calculation_type": calc_dtype,
"mixed_precision_acc_type": acc_dtype,
},
params,
)

if executor_type == "ge":
Expand Down
Loading