Skip to content

Commit

Permalink
[AutoScheduler] Fix task scheduler after 8478 (#8984)
Browse files Browse the repository at this point in the history
  • Loading branch information
comaniac authored Sep 11, 2021
1 parent 8b59f99 commit 02f885a
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions python/tvm/auto_scheduler/task_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -480,7 +480,9 @@ def _tune_task(self, task_idx):

def _compute_score(self, costs):
"""compute the objective function"""
return self.objective_func(costs)
# Make sure to return float.
score = self.objective_func(costs)
return score.value if hasattr(score, "value") else score

def _adjust_similarity_group(self, task_idx):
"""adjust the similarity group for the selected task"""
Expand Down Expand Up @@ -598,7 +600,7 @@ def pre_tune(self, task_scheduler, task_id):

# overall info
if all(cost < 1e9 for cost in task_scheduler.best_costs):
total_latency_str = "%.3f" % (task_scheduler.cur_score.value * 1e3)
total_latency_str = "%.3f" % (task_scheduler.cur_score * 1e3)
else:
total_latency_str = "-"
print(
Expand Down Expand Up @@ -629,7 +631,7 @@ def __init__(self, log_file):

def post_tune(self, task_scheduler, task_id):
if all(cost < 1e9 for cost in task_scheduler.best_costs):
total_latency_str = "%.3f" % (task_scheduler.cur_score.value * 1e3)
total_latency_str = "%.3f" % (task_scheduler.cur_score * 1e3)
else:
total_latency_str = "N/A"

Expand Down

0 comments on commit 02f885a

Please sign in to comment.