Commit: Fix lint

yelite committed Nov 1, 2022
1 parent ac8f72b commit 1e54702
Showing 2 changed files with 8 additions and 7 deletions.
12 changes: 7 additions & 5 deletions python/tvm/meta_schedule/testing/torchbench/run.py
@@ -98,7 +98,7 @@
import warnings
from collections import defaultdict
from enum import Enum
-from typing import Callable, List, Tuple
+from typing import Callable, List, Tuple, Dict

import numpy as np # type: ignore
import torch # type: ignore
@@ -377,7 +377,7 @@ def get_graph_executor_forward(
# It has to lazily import this package, loading the C++ PyTorch integration
# after the transformers package is imported when loading model. Otherwise
# there will be segfault caused by the protobuf library.
-import tvm.contrib.torch  # pylint: disable=import-outside-toplevel, unused-import
+import tvm.contrib.torch  # pylint: disable=import-outside-toplevel, unused-import, redefined-outer-name

save_runtime_mod = get_global_func("tvmtorch.save_runtime_mod", allow_missing=True)
if save_runtime_mod is None:
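
The hunk above only widens a pylint disable, presumably because the function-level import of tvm.contrib.torch re-binds the module-level name tvm, so redefined-outer-name has to be silenced alongside import-outside-toplevel. A minimal, TVM-free sketch of the same deferred-import pattern (the function and modules below are illustrative only, not part of this codebase):

import os  # module-level name that the local import below re-binds


def tasks_dir() -> str:
    # Importing the submodule inside the function defers the import until the
    # function actually runs; the local `import os.path` re-binds `os`, so the
    # same two pylint checks need disabling here.
    import os.path  # pylint: disable=import-outside-toplevel, redefined-outer-name

    return os.path.join(os.getcwd(), "tasks")
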
@@ -442,7 +442,7 @@ def create_tvm_task_collection_backend() -> Tuple[Callable, List[ms.ExtractedTas
os.makedirs(subgraphs_dir, exist_ok=True)

collected_tasks = []
-task_index = defaultdict(list)
+task_index: Dict[int, List[ms.ExtractedTask]] = defaultdict(list)

def collect_task(task):
task_hash = tvm.ir.structural_hash(task.dispatched[0])
@@ -547,7 +547,8 @@ def inspect_output_error(output, expected):
"""
if not isinstance(output, torch.Tensor):
logger.info(
f"Unsupported type for error inspection: {type(output).__name__}. Please manually check output.pt"
f"Unsupported type for error inspection: {type(output).__name__}."
f"Please manually check output.pt"
)
return
output = output.cpu().float()
@@ -596,7 +597,8 @@ def format_error_table(error, bins) -> str:
logger.error(f"Absolute Error\n{format_error_table(abs_error, abs_error_bins)}")
logger.error(f"Relative Error\n{format_error_table(rel_error, rel_error_bins)}")
logger.error(
f"Max absolute error for position with large relative error (> 1): {abs_error_with_large_rel_error.max()}"
f"Max absolute error for position with large relative error (> 1):"
f"{abs_error_with_large_rel_error.max()}"
)


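The remaining run.py edits follow two other common lint-fix patterns: giving a defaultdict an explicit Dict[...] annotation so the type checker knows its value type, and splitting an over-long f-string into adjacent literals, which Python concatenates at compile time (so any separating space must be written explicitly). A small self-contained sketch of both patterns, using illustrative names rather than the TVM ones:

from collections import defaultdict
from typing import Dict, List


def group_by_first_letter(words: List[str]) -> Dict[str, List[str]]:
    # Without the Dict[str, List[str]] annotation, a checker only sees
    # defaultdict(list) and cannot infer the value type.
    index: Dict[str, List[str]] = defaultdict(list)
    for word in words:
        index[word[0]].append(word)
    return index


def unsupported_type_message(type_name: str) -> str:
    # Adjacent string literals are joined at compile time, which keeps each
    # source line within the length limit; note the explicit trailing space.
    return (
        f"Unsupported type for error inspection: {type_name}. "
        "Please manually check output.pt"
    )
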
3 changes: 1 addition & 2 deletions python/tvm/meta_schedule/testing/torchbench/utils.py
@@ -64,8 +64,7 @@ def find_torchdynamo() -> str:


def load_torchdynamo_benchmark_runner(
-is_cuda: bool, cosine_similarity: bool = False,
-float32: bool = False
+is_cuda: bool, cosine_similarity: bool = False, float32: bool = False
) -> TorchBenchmarkRunner:
"""
Load the benchmark runner from TorchDynamo.
