[PERF] Increase accuracy of picking the best candidate (#269)
In this PR, I rework how the best candidate is picked.

A statistical t-test is now used to decide which schedule is better. This code was moved to `python/hidet/utils/benchmark/bench.py`, together with other similar benchmarking code.
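For intuition, here is a minimal sketch of the decision rule with made-up latency samples (the real implementation is in `find_best_candidate` below): a one-sided two-sample t-test checks whether the first schedule's mean latency is smaller than the second's.

```python
from scipy import stats

# hypothetical latency samples (ms) for two candidate schedules
lat_a = [0.771, 0.770, 0.768, 0.772, 0.769, 0.771, 0.770]
lat_b = [0.785, 0.780, 0.781, 0.784, 0.779, 0.788, 0.783]

# alternative='less' tests H1: mean(lat_a) < mean(lat_b)
_, p_value = stats.ttest_ind(lat_a, lat_b, alternative='less')
if p_value < 0.01:  # the threshold this PR uses
    print('schedule A is faster with high confidence')
```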

**Performance improvement for bs=1, A10G**

`python tests/benchmarks/bench_vision.py resnet50 --params 1x3x224x224 --dtype float16`

_**Before**_
0.7848
0.7803
0.7808
0.7839
0.7821
0.7887
0.7785
0.7843
0.7857
0.7939
median = 0.7841
stddev = 0.45%

_**After**_
0.7717
0.7708
0.7679
0.7662
0.7717
0.7715
0.7698
0.7692
0.7706
0.7720
median = 0.7707 (**improvement 1.7%**)
stddev = 0.19% (**improvement 2.37x**)


**Compilation time improvement**
g5.16xlarge instance (64 threads / 32 cores).
`time python tests/benchmarks/bench_op.py batch_matmul --params 1x4096x4096,1x4096x4096 --dtype float16`

**_Before_**
real    5m9s

_**After**_
real    2m27s
vadiklyutiy committed Jul 22, 2024
1 parent d0877d5 commit 8add3b7
Showing 7 changed files with 104 additions and 35 deletions.
1 change: 1 addition & 0 deletions .github/scripts/start_instances.py
@@ -18,6 +18,7 @@ def run_command(cmd):

# e.g., ' 1, 2, ,3,,' -> ['1', '2', '3']
hw_config_ids = os.environ.get('HW_CONFIG').replace(' ', '')
+hw_config_ids = '2'
repo_org = os.environ.get('REPO_NAME').split('/')[0]
if hw_config_ids == 'all':
query = (
7 changes: 6 additions & 1 deletion python/hidet/drivers/build_module.py
@@ -236,8 +236,13 @@ def check_function_singular(module_list: Union[Sequence[IRModule], Sequence[Sequ
num_workers = min(len(ir_modules), 128)
else:
num_workers = get_parallel_num_workers(max_num_worker, mem_for_worker)
-# shuffle the candidates to avoid grouping long-compilation time candidates together
+# Shuffle the candidates to avoid grouping long-compilation-time candidates together
+# Make compilation deterministic
+random.seed(42)
+# Shuffle
random.shuffle(ir_modules)
+# Re-seed from system entropy so that later randomness is not fixed
+random.seed()
if num_workers > 1 and len(ir_modules) > 1:
lazy_initialize_cuda()
per_worker_jobs = 1 if len(ir_modules) < num_workers else len(ir_modules) // num_workers
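As a standalone illustration (not part of the diff), the seed/shuffle/re-seed pattern can be wrapped in a small helper. A stricter variant would save and restore the RNG state with `random.getstate()`/`random.setstate()` instead of re-seeding from entropy.

```python
import random
from contextlib import contextmanager

@contextmanager
def deterministic_rng(seed: int = 42):
    """Temporarily fix Python's global RNG so the shuffle below is reproducible."""
    random.seed(seed)
    try:
        yield
    finally:
        random.seed()  # re-seed from system entropy afterwards

items = list(range(10))
with deterministic_rng():
    random.shuffle(items)  # same order on every run
```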
21 changes: 3 additions & 18 deletions python/hidet/runtime/compiled_task.py
@@ -13,7 +13,6 @@
from dataclasses import dataclass
import os
import json
-import time
from collections import namedtuple
import tabulate
from hidet.runtime.compiled_module import CompiledModule, CompiledFunction, load_compiled_module
@@ -165,27 +164,13 @@ def create_outputs(self, inputs):
return outputs

def pick_best_candidate(self, inputs, outputs) -> int:
-    import hidet
+    from hidet.utils.benchmark.bench import find_best_candidate

    key = self._get_symbol_values()
    if key not in self.dispatch_table:
        if len(self.candidates) > 1:
-            warmup, number, repeat = hidet.option.get_bench_config()
-            latencies = []
-            for idx, candidate in enumerate(self.candidates):
-                for _ in range(warmup):
-                    candidate(*inputs, *outputs)
-                candidate_latency = 0.0
-                for _ in range(repeat):
-                    hidet.cuda.synchronize()
-                    t1 = time.time()
-                    for _ in range(number):
-                        candidate(*inputs, *outputs)
-                    hidet.cuda.synchronize()
-                    t2 = time.time()
-                    candidate_latency += (t2 - t1) * 1000 / number
-                latencies.append(candidate_latency / repeat)
-            self.dispatch_table[key] = latencies.index(min(latencies))
+            best_idx, latencies = find_best_candidate(self.candidates, *inputs, *outputs)
+            self.dispatch_table[key] = best_idx

# write a benchmark report
report_name = '_'.join('{}_{}'.format(a, b) for a, b in zip(self.meta_data.symbols, key))
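Reduced to a hypothetical standalone sketch, the dispatch-table pattern above benchmarks once per distinct symbol-value key, caches the winner's index, and reuses it on later calls:

```python
from typing import Callable, Dict, List, Tuple

dispatch_table: Dict[Tuple[int, ...], int] = {}

def pick(key: Tuple[int, ...], candidates: List[Callable], find_best) -> Callable:
    # benchmark only on the first call for a given key
    if key not in dispatch_table:
        best_idx, _latencies = find_best(candidates)
        dispatch_table[key] = best_idx
    return candidates[dispatch_table[key]]
```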
6 changes: 4 additions & 2 deletions python/hidet/testing/torch_utils.py
@@ -54,7 +54,7 @@ def check_module(model: torch.nn.Module, args: Sequence[torch.Tensor], atol=1e-4

# Class to initialise backend, run compilation
class Backend:
-def __init__(self, backend, dtype, search_space=2) -> None:
+def __init__(self, backend, dtype, cache='', search_space=2) -> None:
assert backend in [
'hidet',
'max-autotune',
@@ -64,6 +64,7 @@ def __init__(self, backend, dtype, search_space=2) -> None:
self.backend = backend
self.dtype = dtype
self.search_space = search_space
+self.cache = cache
if self.backend == 'hidet':
self.init_hidet()

@@ -75,6 +76,7 @@ def init_hidet(self):
hidet.torch.dynamo_config.use_tensor_core(True)
hidet.torch.dynamo_config.use_cuda_graph(True)
hidet.option.search_space(self.search_space)
+hidet.option.cache_dir(hidet.option.get_cache_dir() + self.cache)

# hidet.option.cache_dir(hidet.option.get_cache_dir() + '/regression')
# hidet.option.parallel_tune(max_parallel_jobs=1)
@@ -93,7 +95,7 @@

def compile(self, model):
if self.backend == 'hidet':
-model = torch.compile(model, backend=self.backend)
+model = torch.compile(model, backend='hidet', mode='max-autotune')
elif self.backend == 'eager':
pass
else:
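Hypothetical usage of the extended `Backend` (the model and the `'/my_run'` value are examples). Note that `cache` is concatenated directly onto the cache path, so any separator is the caller's responsibility:

```python
import torch
import torchvision
from hidet.testing.torch_utils import Backend

model = torchvision.models.resnet50().cuda().to(torch.float16).eval()
backend = Backend('hidet', 'float16', cache='/my_run')  # compiles under <cache_dir>/my_run
compiled_model = backend.compile(model)
```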
90 changes: 80 additions & 10 deletions python/hidet/utils/benchmark/bench.py
@@ -12,8 +12,11 @@
from typing import List, Optional, Callable, Tuple, Any, Dict, Union
import time
from dataclasses import dataclass

+from scipy import stats
import numpy as np
+import nvtx
+import hidet
+import hidet.cuda


# copied from: https://github.com/openai/triton/blob/main/python/triton/testing.py
@@ -33,7 +36,6 @@ def do_bench(fn, warmup=25, rep=100, percentiles=(0.2, 0.5, 0.8)):
"""

# Estimate the runtime of the function
-import hidet

fn()
hidet.cuda.synchronize()
@@ -69,7 +71,7 @@ def do_bench(fn, warmup=25, rep=100, percentiles=(0.2, 0.5, 0.8)):
return np.mean(times).item()


-def benchmark_func(run_func, warmup=1, number=5, repeat=5, median=True) -> Union[List[float], float]:
+def benchmark_func(run_func, *args, warmup=1, number=5, repeat=5, median=True) -> Union[List[float], float]:
"""Benchmark given function.
The given function ``run_func`` will be executed :math:`warmup + repeat * number` times. Each :math:`number` times
@@ -98,29 +100,97 @@ def benchmark_func(run_func, *args, warmup=1, number=5, repeat=5, median=True) -> Union
- When median == True, a single latency number is returned.
- When median == False, the latency of each repeat is returned, as a list of floats.
"""
-import nvtx
-import hidet.cuda

results = []
with nvtx.annotate('warmup'):
for _ in range(warmup):
-run_func()
+run_func(*args)
hidet.cuda.synchronize()

for i in range(repeat):
with nvtx.annotate(f'repeat {i}'):
hidet.cuda.synchronize()
-start_time = time.time()
+start_time = time.time_ns()
for _ in range(number):
-run_func()
+run_func(*args)
hidet.cuda.synchronize()
-end_time = time.time()
-results.append((end_time - start_time) * 1000 / number)
+end_time = time.time_ns()
+results.append((end_time - start_time) / 10**6 / number)
if median:
return float(np.median(results))
else:
return results


@dataclass
class CandidateData:
    idx: int
    latencies: Optional[List[float]] = None
    median: float = 0.0
    in_game: bool = True


def find_best_candidate(candidates: List[Callable[..., None]], *args):
    P_VALUE_THRESHOLD = 0.01
    num_candidates = len(candidates)
    candidates_data = [CandidateData(idx=idx) for idx, _ in enumerate(candidates)]
    repeats = (7, 31)
    for cur_repeat in repeats:
        for idx, cand in enumerate(candidates):
            if candidates_data[idx].in_game:
                lats = benchmark_func(cand, *args, warmup=5, number=1, repeat=cur_repeat, median=False)
                candidates_data[idx].latencies = lats

        for cand in candidates_data:
            if cand.in_game:
                cand.median = np.median(cand.latencies)

        # We now have samples for every candidate.
        # Start with the candidate with the minimum median; it will likely drop a lot of
        # slower candidates. This is just an optimisation: the next loop alone is enough
        # for correctness.
        min_lat_cand = min((cand for cand in candidates_data if cand.in_game), key=lambda cand: cand.median)
        min_idx = min_lat_cand.idx
        for i in range(num_candidates):
            if i == min_idx or not candidates_data[i].in_game:
                continue
            _, p_value = stats.ttest_ind(
                candidates_data[min_idx].latencies, candidates_data[i].latencies, alternative='less'
            )
            if p_value < P_VALUE_THRESHOLD:
                candidates_data[i].in_game = False
        # If only one candidate is left, we have found the best
        left_candidates = [cand for cand in candidates_data if cand.in_game]
        if len(left_candidates) == 1:
            return (left_candidates[0].idx, [cand.median for cand in candidates_data])

        # Compare the remaining candidates against each other, again with the t-test
        for i in range(num_candidates):
            if not candidates_data[i].in_game:
                continue
            for j in range(num_candidates):
                if not candidates_data[j].in_game or i == j:
                    continue
                _, p_value = stats.ttest_ind(
                    candidates_data[i].latencies, candidates_data[j].latencies, alternative='less'
                )
                if p_value < P_VALUE_THRESHOLD:
                    candidates_data[j].in_game = False

        # If only one candidate is left, we have found the best
        left_candidates = [cand for cand in candidates_data if cand.in_game]
        if len(left_candidates) == 1:
            return (left_candidates[0].idx, [cand.median for cand in candidates_data])

    # We could not show that a single candidate is statistically significantly faster than
    # all the others, and the method above cannot order the remaining ones.
    # Choose the one with the minimal median.
    best = min((cand for cand in candidates_data if cand.in_game), key=lambda cand: cand.median)
    best_idx = best.idx
    latencies = [cand.median for cand in candidates_data]
    return (best_idx, latencies)


@dataclass
class BenchData:
    x_vals: List[Any]
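A hypothetical usage of `find_best_candidate` with two toy callables (in hidet the candidates are compiled kernel launches; a CUDA-capable environment is assumed, since `benchmark_func` synchronizes the device):

```python
import time
from hidet.utils.benchmark.bench import find_best_candidate

def slow():
    time.sleep(0.002)

def fast():
    time.sleep(0.001)

# extra positional args after the candidate list are forwarded to each candidate
best_idx, medians = find_best_candidate([slow, fast])
print(best_idx)  # expected: 1, the faster candidate
print(medians)   # per-candidate median latency (ms)
```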
4 changes: 4 additions & 0 deletions requirements.txt
@@ -42,3 +42,7 @@ tomlkit

# for parser
lark

# for performance measurements
scipy

10 changes: 6 additions & 4 deletions tests/benchmarks/bench_vision.py
@@ -4,8 +4,8 @@
from hidet.testing.torch_utils import bench_torch_model, Backend


-def bench_torchvision(model_name, shape, dtype, backend):
-    comp_backend = Backend(backend, dtype)
+def bench_torchvision(model_name, shape, dtype, backend, cache):
+    comp_backend = Backend(backend, dtype, cache)

dtype = getattr(torch, dtype)
if any(name in model_name for name in ['deeplab', 'fcn', 'lraspp']):
@@ -37,9 +37,11 @@ def bench_torchvision(model_name, shape, dtype, backend):
parser.add_argument('--params', type=str, default='1x3x224x224', help='Specify Input Size. E.g., 1x3x224x224')
parser.add_argument('--dtype', type=str, default='float16', help='Specify precision. E.g., float32')
parser.add_argument('--backend', type=str, default='hidet', help='torch.compile backend: hidet or max-autotune')
+parser.add_argument('--cache', type=str, default='', help='')

args = parser.parse_args()

-model, dtype, backend = args.model, args.dtype, args.backend
+model, dtype, backend, cache = args.model, args.dtype, args.backend, args.cache
shape = [int(d) for d in args.params.split('x')]
-latency = bench_torchvision(model, shape, dtype, backend)
+latency = bench_torchvision(model, shape, dtype, backend, cache)
print(latency)
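With the new flag, a run can be pointed at its own cache subdirectory (the value below is an example; it is appended verbatim to hidet's cache path):

`python tests/benchmarks/bench_vision.py resnet50 --params 1x3x224x224 --dtype float16 --cache /run1`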
