diff --git a/tuning/compile_candidate.sh b/tuning/compile_candidate.sh index 9b9b74c..9ec696d 100755 --- a/tuning/compile_candidate.sh +++ b/tuning/compile_candidate.sh @@ -8,8 +8,6 @@ readonly DIR="$(dirname "$INPUT")" readonly BASENAME="$(basename "$INPUT" .mlir)" readonly OUT="${DIR}/compiled/${BASENAME}.vmfb" -mkdir -p "${DIR}/compiled" "${DIR}/failed" "${DIR}/specs" - timeout 4s ./punet.sh "$INPUT" -o "$OUT" --compile-from=executable-sources 2>/dev/null || (mv "$INPUT" "$DIR/failed" && exit 1) tools/iree-dump-module "$OUT" | grep -q 'rocm-hsaco-fb' || (mv "$INPUT" "$DIR/failed" && rm -f "$OUT" && exit 1) if [ -f "${DIR}/${BASENAME}_config.mlir" ]; then diff --git a/tuning/libtuner.py b/tuning/libtuner.py index 8661354..7fb86e3 100755 --- a/tuning/libtuner.py +++ b/tuning/libtuner.py @@ -42,6 +42,9 @@ device_id = None """Do not need to change""" +# Declare special symbols for libtuner to search and locate +DEVICE_ID_PLACEHOLDER = "!DEVICE_ID!" + @dataclass class CandidateTracker: @@ -53,12 +56,12 @@ class CandidateTracker: compiled_dispatch_path: Optional[Path] = None compiled_dispatch_hash: Optional[str] = None first_benchmark_time: Optional[float] = None - first_benchmark_device_id: Optional[int] = None + first_benchmark_device_id: Optional[str] = None spec_path: Optional[Path] = None - model_path: Optional[Path] = None + compiled_model_path: Optional[Path] = None compiled_model_hash: Optional[str] = None model_benchmark_time: Optional[float] = None - model_benchmark_device_id: Optional[int] = None + model_benchmark_device_id: Optional[str] = None baseline_benchmark_time: Optional[float] = None calibrated_benchmark_diff: Optional[float] = None @@ -79,7 +82,7 @@ class PathConfig: candidate_configs_pkl: Path = field(init=False) compiled_dir: Path = field(init=False) compile_failed_dir: Path = field(init=False) - spec_dir: Path = field(init=False) + specs_dir: Path = field(init=False) output_unilog: Path = field(init=False) result_summary_log: Path = field(init=False) @@ -103,7 +106,7 @@ def __post_init__(self): ) object.__setattr__(self, "compiled_dir", self.candidates_dir / "compiled") object.__setattr__(self, "compile_failed_dir", self.candidates_dir / "failed") - object.__setattr__(self, "spec_dir", self.candidates_dir / "specs") + object.__setattr__(self, "specs_dir", self.candidates_dir / "specs") object.__setattr__(self, "output_unilog", self.base_dir / "output.log") object.__setattr__( self, "result_summary_log", self.base_dir / "result_summary.log" @@ -147,32 +150,39 @@ def get_dispatch_compile_command( pass @abstractmethod - def get_dispatch_benchmark_command(self, candidate_tracker) -> list[str]: + def get_dispatch_benchmark_command( + self, candidate_tracker: CandidateTracker + ) -> list[str]: pass @abstractmethod - def get_model_compile_command(self, candidate_tracker) -> list[str]: + def get_model_compile_command( + self, candidate_tracker: CandidateTracker + ) -> list[str]: pass @abstractmethod - def get_model_benchmark_command(self, candidate_tracker) -> list[str]: + def get_model_benchmark_command( + self, candidate_tracker: CandidateTracker + ) -> list[str]: pass @dataclass -class TaskTuple: +class TaskPack: args: argparse.Namespace + candidate_id: int command: list[str] check: bool = True command_need_device_id: bool = False cooling_time: int = 0 - result_need_device_id: bool = False @dataclass class TaskResult: result: subprocess.CompletedProcess - device_id: Optional[int] = None + candidate_id: int + device_id: str @dataclass @@ -184,113 +194,47 @@ class 
ParsedDisptachBenchmarkResult: @dataclass -class DispatchBenchmarkResult: - result_str: Optional[str] = None - - def get_tokens(self) -> list[str]: - # e.g. ['0', 'Mean', 'Time:', '694.0'] - if self.result_str is None: - return [] - try: - return self.result_str.split() - except: - return [] - - def get_candidate_id(self) -> Optional[int]: - if len(self.get_tokens()) < 1: - return None - try: - return int(self.get_tokens()[0]) - except ValueError: - return None - - def get_benchmark_time(self) -> Optional[float]: - if len(self.get_tokens()) < 4: - return None - try: - return float(self.get_tokens()[3]) - except ValueError: - return None - - def generate_sample_result( - self, candidate_id: int = 0, mean_time: float = random.uniform(100.0, 500.0) - ) -> str: - # time unit is implicit and dependent on the output of iree-benchmark-module - return f"{candidate_id}\tMean Time: {mean_time:.1f}\n" - - -@dataclass -class ModelBenchmarkResult: - result_str: Optional[str] = None - - def get_tokens(self) -> list[str]: - # e.g. ['Benchmarking:', '/sdxl-scripts/tuning/tuning_2024_07_19_08_55/unet_candidate_12.vmfb', 'on', 'device', '4', 'BM_main/process_time/real_time_median', '65.3', 'ms', '66.7', 'ms', '5', 'items_per_second=15.3201/s'] - if self.result_str is None: - return [] - try: - return self.result_str.split() - except: - return [] +class IREEBenchmarkResult: + # Default format follows output of iree-benchmark-module + candidate_id: int + result_str: str - def get_model_candidate_path(self) -> Optional[str]: - if len(self.get_tokens()) < 2: + def get_mean_time(self) -> Optional[float]: + if not self.result_str: return None - return self.get_tokens()[1] - - def get_candidate_id(self) -> Optional[int]: - if self.get_model_candidate_path(): - try: - path_str = self.get_model_candidate_path() - return int(path_str.split("_")[-1].split(".")[0]) if path_str else None - except ValueError: - return None - return None - - def get_device_id(self) -> Optional[int]: - if len(self.get_tokens()) < 5: + pattern = r"process_time/real_time_mean\s+([\d.]+)\s\w{2}" + match = re.search(pattern, self.result_str) + if not match: return None try: - return int(self.get_tokens()[4]) + return float(match.group(1)) except ValueError: return None - def get_benchmark_time(self) -> Optional[int | float]: - if len(self.get_tokens()) < 7: - return None - try: - return float(self.get_tokens()[6]) - except ValueError: - return None - def get_calibrated_result_str(self, change: float) -> Optional[str]: - if self.result_str is None: - return self.result_str +def generate_display_DBR( + candidate_id: int = 0, mean_time: float = random.uniform(100.0, 500.0) +) -> str: + """Generate dispatch_benchmark_result string for displaying""" + return f"{candidate_id}\tMean Time: {mean_time:.1f}" - benchmark_time = self.get_benchmark_time() - if benchmark_time is None: - return self.result_str - # Format the change to be added to the string - percentage_change = change * 100 +def generate_display_MBR( + candidate_vmfb_path_str: str = "baseline.vmfb", + device_id: str = "0", + t1: float = random.uniform(100.0, 500.0), + calibrated_diff: Optional[float] = None, +) -> str: + """Generate model_benchmark_result string for displaying""" + if calibrated_diff: + percentage_change = calibrated_diff * 100 change_str = f"({percentage_change:+.3f}%)" - - # Use regex to find and replace the old benchmark time with the new one - new_result_str = re.sub( - r"(\d+(\.\d+)?)\s*ms", - lambda m: f"{self.get_benchmark_time()} ms {change_str}", - self.result_str, 
- count=1, + res_str = f"Benchmarking: {candidate_vmfb_path_str} on device {device_id}: {t1:.3g} {change_str}" + else: + res_str = ( + f"Benchmarking: {candidate_vmfb_path_str} on device {device_id}: {t1:.3g}" ) - - return new_result_str - - def generate_sample_result( - self, - candidate_vmfb_path_str: str = "unet_baseline.vmfb", - device_id: int = 0, - t1: float = random.uniform(100.0, 500.0), # time in ms - ) -> str: - return f"Benchmarking: {candidate_vmfb_path_str} on device {device_id}\nBM_run_forward/process_time/real_time_median\t {t1:.3g} ms\t {(t1+1):.3g} ms\t 5 items_per_second={t1/200:5f}/s\n\n" + return res_str def extract_driver_names(user_devices: list[str]) -> set[str]: @@ -555,11 +499,9 @@ def run_command( if result.stdout: logging.info(f"stdout: {result.stdout}") if result.stderr: - logging.error(f"stderr: {result.stderr}") + logging.info(f"stderr: {result.stderr}") except subprocess.CalledProcessError as e: - print(f"Command '{e.cmd}' returned non-zero exit status {e.returncode}.") print(e.output) - logging.error( f"Command '{command_str}' returned non-zero exit status {e.returncode}." ) @@ -572,18 +514,24 @@ def run_command( return result -def run_command_wrapper(task_tuple: TaskTuple) -> TaskResult: +def run_command_wrapper(task_tuple: TaskPack) -> TaskResult: """pool.imap_unordered can't iterate an iterable of iterables input, this function helps dividing arguments""" if task_tuple.command_need_device_id: - # worker add its device_id to the end of command list - task_tuple.command.append(str(device_id)) + # Worker searches for the special symbol and substitutes it with the actual device_id + pattern = re.compile(re.escape(DEVICE_ID_PLACEHOLDER)) + task_tuple.command = [ + pattern.sub(str(device_id), s) for s in task_tuple.command + ] res = run_command(task_tuple.args, task_tuple.command, task_tuple.check) if res is None: raise - task_result = TaskResult(res) - task_result.device_id = device_id if task_tuple.result_need_device_id else None + task_result = TaskResult( + res, task_tuple.candidate_id, device_id=str(-1) + ) # Main process + if device_id: + task_result = TaskResult(res, task_tuple.candidate_id, device_id) # Subprocess time.sleep(task_tuple.cooling_time) @@ -702,7 +650,6 @@ def generate_candidates( args: argparse.Namespace, path_config: PathConfig, candidate_trackers: list[CandidateTracker], - tuning_client: TuningClient, ) -> list[int]: """Generate candidate files for tuning. Returns the list of candidate indexes""" logging.info("generate_candidates()") @@ -803,17 +750,21 @@ def compile_dispatches( candidate_trackers: list[CandidateTracker], tuning_client: TuningClient, ) -> list[int]: - """Compile candidate files for tuning and record in candidate_vmfbs.txt. 
Returns the list of compiled candidate indexes.""" - logging.info("compile_candidates()") + logging.info("compile_dispatches()") if not candidates: logging.info("No candidates to compile.") return [] + path_config.compiled_dir.mkdir(parents=True, exist_ok=True) + path_config.compile_failed_dir.mkdir(parents=True, exist_ok=True) + path_config.specs_dir.mkdir(parents=True, exist_ok=True) + task_list = [ - TaskTuple( + TaskPack( args, - tuning_client.get_dispatch_compile_command(candidate_trackers[i]), + candidate_id=i, + command=tuning_client.get_dispatch_compile_command(candidate_trackers[i]), check=False, ) for i in candidates @@ -843,7 +794,7 @@ def compile_dispatches( compiled_candidates = [] compiled_candidates_hash_list = [] for compiled_file in compiled_files: - index = path_config.get_compiled_dispatch_index(failed_file) + index = path_config.get_compiled_dispatch_index(compiled_file) compiled_candidates.append(index) candidate_trackers[index].compilation_successful = True candidate_trackers[index].compiled_dispatch_path = compiled_file @@ -875,27 +826,29 @@ def parse_dispatch_benchmark_results( path_config: PathConfig, benchmark_results: list[TaskResult], candidate_trackers: list[CandidateTracker], - tuning_client: TuningClient, ) -> tuple[list[ParsedDisptachBenchmarkResult], list[str]]: benchmark_result_configs = [] dump_list = [] + incomplete_list = [] for benchmark_result in benchmark_results: res_str = benchmark_result.result.stdout - if res_str is None: + candidate_id = benchmark_result.candidate_id + res = IREEBenchmarkResult(candidate_id, res_str) + benchmark_time = res.get_mean_time() + if benchmark_time is None: + incomplete_list.append(candidate_id) continue - res = DispatchBenchmarkResult(res_str) - candidate_id = res.get_candidate_id() - benchmark_time = res.get_benchmark_time() - assert candidate_id is not None and benchmark_time is not None + assert benchmark_time is not None candidate_trackers[candidate_id].first_benchmark_time = benchmark_time candidate_trackers[candidate_id].spec_path = ( - path_config.spec_dir / path_config.get_candidate_spec_filename(candidate_id) + path_config.specs_dir + / path_config.get_candidate_spec_filename(candidate_id) ) mlir_path = candidate_trackers[candidate_id].dispatch_mlir_path spec_path = candidate_trackers[candidate_id].spec_path assert mlir_path is not None and spec_path is not None - dump_list.append(res_str) + dump_list.append(generate_display_DBR(candidate_id, benchmark_time) + "\n") benchmark_result_configs.append( ( @@ -907,26 +860,61 @@ ) ) ) + + if incomplete_list: + dump_list += [f"Candidate {i} not completed" for i in incomplete_list] + return benchmark_result_configs, dump_list +def generate_sample_task_result( + stdout: str, candidate_id: int, device_id: str +) -> TaskResult: + res = subprocess.CompletedProcess( + args=[""], + stdout=stdout, + returncode=0, + ) + return TaskResult(result=res, candidate_id=candidate_id, device_id=device_id) + + def generate_dryrun_dispatch_benchmark_results( compiled_candidates: list[int], ) -> list[TaskResult]: - task_results = [] - for candidate_id in compiled_candidates: - task_result = subprocess.CompletedProcess( - args=[""], - returncode=0, - stdout=DispatchBenchmarkResult().generate_sample_result( - candidate_id, mean_time=random.uniform(100.0, 500.0) - ), - stderr="", + logging.info("generate_dryrun_dispatch_benchmark_results") + + task_results = [ + generate_sample_task_result( + f"process_time/real_time_mean {random.uniform(100.0, 
500.0):.3g} ms", + i, + str(0), ) - task_results.append(TaskResult(task_result)) + for i in compiled_candidates + ] + return task_results +def generate_dryrun_model_benchmark_results( + model_candidates: list[int], +) -> tuple[list[TaskResult], list[TaskResult]]: + candidate_results = [] + for i, j in enumerate(model_candidates): + stdout = f"process_time/real_time_mean {random.uniform(100.0, 500.0):.3g} ms" + candidate_results.append(generate_sample_task_result(stdout, j, str(i % 3))) + + baseline_results = [ + generate_sample_task_result( + f"process_time/real_time_mean {random.uniform(100.0, 500.0):.3g} ms", + 0, + str(i), + ) + for i in range(3) + ] + + return candidate_results, baseline_results + + def benchmark_dispatches( args: argparse.Namespace, path_config: PathConfig, @@ -934,20 +922,21 @@ def benchmark_dispatches( candidate_trackers: list[CandidateTracker], tuning_client: TuningClient, ): - """Benchmark the candidate files and store the topN results in file (best.log).""" - logging.info("benchmark_top_candidates()") + logging.info("benchmark_dispatches()") if args.dry_run: - logging.info("generate_dryrun_dispatch_benchmark_results") benchmark_results = generate_dryrun_dispatch_benchmark_results( compiled_candidates ) else: # Benchmarking dispatch candidates task_list = [ - TaskTuple( + TaskPack( args, - tuning_client.get_dispatch_benchmark_command(candidate_trackers[i]), + candidate_id=i, + command=tuning_client.get_dispatch_benchmark_command( + candidate_trackers[i] + ), check=False, command_need_device_id=True, ) @@ -966,7 +955,7 @@ def benchmark_dispatches( parsed_benchmark_results, dispatch_benchmark_dump_list, ) = parse_dispatch_benchmark_results( - path_config, benchmark_results, candidate_trackers, tuning_client + path_config, benchmark_results, candidate_trackers ) append_to_file( dispatch_benchmark_dump_list, @@ -1008,10 +997,13 @@ def compile_models( candidate_trackers: list[CandidateTracker], tuning_client: TuningClient, ) -> list[int]: - """Compile U-Net candidates stored in best.log. 
Return the list of U-Net candidate files.""" logging.info("compile_models()") + candidate_trackers[0].compiled_model_path = path_config.model_baseline_vmfb + if args.dry_run: + for i in candidates: + candidate_trackers[i].compiled_model_path = Path(f"model_{i}.vmfb") return candidates if not candidates: @@ -1019,7 +1011,12 @@ def compile_models( return [] task_list = [ - TaskTuple(args, tuning_client.get_model_compile_command(candidate_trackers[i])) + TaskPack( + args, + candidate_id=i, + command=tuning_client.get_model_compile_command(candidate_trackers[i]), + check=False, + ) for i in candidates if i != 0 ] @@ -1037,7 +1034,7 @@ def compile_models( for model_candidate in model_candidates_files: assert model_candidate is not None index = path_config.get_compiled_model_index(model_candidate) - candidate_trackers[index].model_path = model_candidate + candidate_trackers[index].compiled_model_path = model_candidate hash_val = calculate_md5(model_candidate) candidate_trackers[index].compiled_model_hash = hash_val model_candidates_hash_list.append((index, hash_val)) @@ -1060,22 +1057,6 @@ def compile_models( ) -def sort_candidates_by_first_benchmark_times( - candidate_indexes: list[int], candidate_trackers: list[CandidateTracker] -) -> list[int]: - """Sorts candidate indexes based on their first benchmark times in ascending order""" - # Get the first benchmark times, defaulting to a large number if None - first_benchmark_times = [ - candidate_trackers[index].first_benchmark_time or float("inf") - for index in candidate_indexes - ] - combined = list(zip(candidate_indexes, first_benchmark_times)) - combined_sorted = sorted(combined, key=lambda x: x[1]) - sorted_indexes, _ = zip(*combined_sorted) - - return list(sorted_indexes) - - def group_benchmark_results_by_device_id( benchmark_results: list[TaskResult], ) -> list[list[TaskResult]]: @@ -1087,7 +1068,7 @@ def group_benchmark_results_by_device_id( -----> [ [TaskResult(res1, device_1), TaskResult(res3, device_1)], [TaskResult(res2, device_2)] ] """ - grouped_results: dict[int, list[TaskResult]] = {} + grouped_results: dict[str, list[TaskResult]] = {} for result in benchmark_results: assert result.device_id is not None if result.device_id not in grouped_results: @@ -1101,63 +1082,95 @@ def group_benchmark_results_by_device_id( return grouped_benchmark_results -def parse_grouped_benchmark_results( - path_config: PathConfig, - grouped_benchmark_results: list[list[TaskResult]], +def parse_model_benchmark_results( candidate_trackers: list[CandidateTracker], -) -> list[str]: - """Update candidate_trackers and collect strings""" + candidate_results: list[TaskResult], + baseline_results: list[TaskResult], +): + """Update candidate_tracker and format a list of result strings to be saved later.""" + candidate_results = sorted(candidate_results, key=lambda br: br.device_id) + baseline_results = sorted(baseline_results, key=lambda tr: tr.device_id) + + # Assign candidates to the same groups by device_id + grouped_candidate_results = group_benchmark_results_by_device_id(candidate_results) + + # Insert baseline results to the head of each list + grouped_benchmark_results = [ + [x] + y for x, y in zip(baseline_results, grouped_candidate_results) + ] + dump_list = [] - incomplete_list: list[tuple[int, Optional[int]]] = ( + incomplete_list: list[tuple[int, Optional[str]]] = ( [] - ) # format: [(candidate_id, device_id)], baseline will have candidate_id=0 + ) # format: [(candidate_id, device_id)] + baseline_time = None for same_device_results in 
grouped_benchmark_results: dump_unsort_list: list[tuple[float, str]] = [] - for model_candidate_result in same_device_results: - # Skip if benchmark failed. - result_str = model_candidate_result.result.stdout + for task_result in same_device_results: + result_str = task_result.result.stdout + candidate_id = task_result.candidate_id + device_id = task_result.device_id + + # Check if benchmarking has completed if result_str is None: + incomplete_list.append((candidate_id, device_id)) + if candidate_id == 0: + baseline_time = None continue - res = ModelBenchmarkResult(result_str) - device_id = res.get_device_id() - - # Record baseline benchmarking result. - model_candidate_path = res.get_model_candidate_path() - if ( - model_candidate_path is not None - and str(path_config.model_baseline_vmfb) in model_candidate_path - ): - baseline_time = res.get_benchmark_time() - if baseline_time is None: - incomplete_list.append((0, device_id)) - continue - dump_list.append(result_str) + res = IREEBenchmarkResult(candidate_id, result_str) + benchmark_time = res.get_mean_time() + assert benchmark_time is not None + + # Record baseline benchmarking result and skip rest processes + if candidate_id == 0: + baseline_time = benchmark_time + baseline_vmfb_path = candidate_trackers[ + candidate_id + ].compiled_model_path + assert baseline_vmfb_path is not None + dump_str = ( + generate_display_MBR( + candidate_vmfb_path_str=baseline_vmfb_path.as_posix(), + device_id=device_id, + t1=benchmark_time, + ) + + "\n\n" + ) + dump_list.append(dump_str) continue - # Record candidate benchmarking result. - c_id = res.get_candidate_id() - assert c_id is not None - candidate_time = res.get_benchmark_time() - if candidate_time is None: - incomplete_list.append((c_id, device_id)) - continue - candidate_trackers[c_id].model_benchmark_time = candidate_time - candidate_trackers[c_id].model_benchmark_device_id = device_id - # Skip improvement calculation if no baseline data. - if baseline_time is None: - dump_unsort_list.append((candidate_time, result_str)) - continue - # Calculate candidate improvement based baseline. - candidate_trackers[c_id].baseline_benchmark_time = baseline_time - calibrated_benchmark_diff = (candidate_time - baseline_time) / baseline_time - candidate_trackers[c_id].calibrated_benchmark_diff = ( - calibrated_benchmark_diff + # Update candidate_tracker + candidate_trackers[candidate_id].model_benchmark_time = benchmark_time + candidate_trackers[candidate_id].model_benchmark_device_id = device_id + + # Calculate candidate improvement based on baseline. + if baseline_time: + candidate_trackers[candidate_id].baseline_benchmark_time = baseline_time + calibrated_benchmark_diff = ( + benchmark_time - baseline_time + ) / baseline_time + candidate_trackers[candidate_id].calibrated_benchmark_diff = ( + calibrated_benchmark_diff + ) + else: + calibrated_benchmark_diff = None + + # Collect candidate dump str + candidate_vmfb_path = candidate_trackers[candidate_id].compiled_model_path + assert candidate_vmfb_path is not None + dump_str = ( + generate_display_MBR( + candidate_vmfb_path_str=candidate_vmfb_path.as_posix(), + device_id=device_id, + t1=benchmark_time, + calibrated_diff=calibrated_benchmark_diff, + ) + + "\n\n" ) - dump_str = res.get_calibrated_result_str(calibrated_benchmark_diff) - assert dump_str is not None - dump_unsort_list.append((candidate_time, dump_str)) + + dump_unsort_list.append((benchmark_time, dump_str)) # Sort model candidate benchmarking result str in ascending time order. 
dump_list = dump_list + [ @@ -1165,63 +1178,16 @@ def parse_grouped_benchmark_results( ] # Store incomplete .vmfb file at the end of dump_list. - for index, device_id in incomplete_list: - index_to_path = lambda index: ( - f"{path_config.model_baseline_vmfb.as_posix()}" - if index == 0 - else f"{candidate_trackers[index].model_path}" - ) - error_msg = f"Benchmarking result of {index_to_path(index)} on deivce {device_id} is incomplete" + for index, device in incomplete_list: + file_path = candidate_trackers[index].compiled_model_path + assert file_path is not None + error_msg = f"Benchmarking result of {file_path.as_posix()} on device {device} is incomplete" handle_error(condition=True, msg=error_msg, level=logging.WARNING) dump_list.append(error_msg + "\n") return dump_list -def generate_dryrun_unet_benchmark_results( - unet_vmfb_paths: list[Path], -) -> list[TaskResult]: - logging.info("generate_dryrun_unet_benchmark_results") - task_results = [] - start = random.uniform(100.0, 500.0) - device_id = 0 - for candidate_vmfb_path in unet_vmfb_paths: - task_result = subprocess.CompletedProcess( - args=[""], - returncode=0, - stdout=ModelBenchmarkResult().generate_sample_result( - candidate_vmfb_path_str=candidate_vmfb_path.as_posix(), - device_id=device_id, - t1=start, - ), - stderr="", - ) - start += random.uniform(-5.0, 8.0) - task_results.append(TaskResult(task_result, device_id)) - return task_results - - -def dryrun_benchmark_unet( - path_config: PathConfig, - unet_candidates: list[int], - candidate_trackers: list[CandidateTracker], -): - - unet_vmfb_paths = [path_config.model_baseline_vmfb] + [ - Path(f"unet_candidate_{index}.vmfb") for index in unet_candidates - ] - benchmark_results = generate_dryrun_unet_benchmark_results(unet_vmfb_paths) - grouped_benchmark_results = group_benchmark_results_by_device_id(benchmark_results) - - # Update candidate_tracker and extract strings which will be stored in output.log. 
- dump_list = parse_grouped_benchmark_results( - path_config, grouped_benchmark_results, candidate_trackers - ) - append_to_file( - dump_list, filepath=path_config.output_unilog, title="Unet Benchmark Results" - ) - - def benchmark_models( args: argparse.Namespace, path_config: PathConfig, @@ -1233,61 +1199,57 @@ def benchmark_models( logging.info("benchmark_models()") if args.dry_run: - dryrun_benchmark_unet(path_config, model_candidates, candidate_trackers) - return - - # Benchmarking model candidates - worker_context_queue = create_worker_context_queue(args.devices) - benchmark_task_list = [ - TaskTuple( - args, - tuning_client.get_model_benchmark_command(candidate_trackers[i]), - check=False, - command_need_device_id=True, - cooling_time=10, - result_need_device_id=True, + candidate_results, baseline_results = generate_dryrun_model_benchmark_results( + model_candidates ) - for i in model_candidates - ] - benchmark_results = multiprocess_progress_wrapper( - num_worker=len(args.devices), - task_list=benchmark_task_list, - function=run_command_wrapper, - initializer=init_worker_context, - initializer_inputs=(worker_context_queue,), - ) - benchmark_results = sorted(benchmark_results, key=lambda br: br.device_id) - grouped_benchmark_results = group_benchmark_results_by_device_id(benchmark_results) - - # Benchmarking baselines on each involved device - candidate_trackers[0].model_path = path_config.model_baseline_vmfb - worker_context_queue = create_worker_context_queue(args.devices) - baseline_task_list = [ - TaskTuple( - args, - tuning_client.get_model_benchmark_command(candidate_trackers[0]), - check=False, - command_need_device_id=True, - result_need_device_id=True, + else: + # Benchmarking model candidates + worker_context_queue = create_worker_context_queue(args.devices) + benchmark_task_list = [ + TaskPack( + args, + candidate_id=i, + command=tuning_client.get_model_benchmark_command( + candidate_trackers[i] + ), + check=False, + command_need_device_id=True, + cooling_time=10, + ) + for i in model_candidates + ] + candidate_results = multiprocess_progress_wrapper( + num_worker=len(args.devices), + task_list=benchmark_task_list, + function=run_command_wrapper, + initializer=init_worker_context, + initializer_inputs=(worker_context_queue,), ) - ] * len(grouped_benchmark_results) - baseline_results = multiprocess_progress_wrapper( - num_worker=len(args.devices), - task_list=baseline_task_list, - function=run_command_wrapper, - initializer=init_worker_context, - initializer_inputs=(worker_context_queue,), - ) - baseline_results = sorted(baseline_results, key=lambda tr: tr.device_id) - # Insert baseline results to the head of each list - grouped_benchmark_results = [ - [x] + y for x, y in zip(baseline_results, grouped_benchmark_results) - ] + # Benchmarking baselines on each involved device + candidate_trackers[0].compiled_model_path = path_config.model_baseline_vmfb + worker_context_queue = create_worker_context_queue(args.devices) + baseline_task_list = [ + TaskPack( + args, + candidate_id=0, + command=tuning_client.get_model_benchmark_command( + candidate_trackers[0] + ), + check=False, + command_need_device_id=True, + ) + ] * len(group_benchmark_results_by_device_id(candidate_results)) + baseline_results = multiprocess_progress_wrapper( + num_worker=len(args.devices), + task_list=baseline_task_list, + function=run_command_wrapper, + initializer=init_worker_context, + initializer_inputs=(worker_context_queue,), + ) - # Update candidate_tracker and extract strings which will be stored 
later - dump_list = parse_grouped_benchmark_results( - path_config, grouped_benchmark_results, candidate_trackers + dump_list = parse_model_benchmark_results( + candidate_trackers, candidate_results, baseline_results ) append_to_file( @@ -1322,3 +1284,9 @@ def summerize_top_candidates( with open(path_config.result_summary_log, "w") as file: file.writelines(dump_list) + + +def sanitize_filename(filename: str) -> str: + # Replace invalid characters by an underscore + sanitized = re.sub(r"[^\w\.-]", "_", filename) + return sanitized diff --git a/tuning/punet_autotune.py b/tuning/punet_autotune.py index 014c76d..c281bbf 100644 --- a/tuning/punet_autotune.py +++ b/tuning/punet_autotune.py @@ -41,14 +41,24 @@ def get_dispatch_compile_command( return command def get_dispatch_benchmark_command( - self, candidate_tracker: libtuner.CandidateTracker + self, + candidate_tracker: libtuner.CandidateTracker, ) -> list[str]: compiled_vmfb_path = candidate_tracker.compiled_dispatch_path assert compiled_vmfb_path is not None + command = [ - "./benchmark_dispatch.sh", - compiled_vmfb_path.as_posix(), + "timeout", + "16s", + "./tools/iree-benchmark-module", + f"--device={libtuner.DEVICE_ID_PLACEHOLDER}", + f"--module={compiled_vmfb_path.resolve()}", + "--hip_use_streams=true", + "--hip_allow_inline_execution=true", + "--batch_size=1000", + "--benchmark_repetitions=3", ] + return command def get_model_compile_command( @@ -56,21 +66,46 @@ def get_model_compile_command( ) -> list[str]: mlir_spec_path = candidate_tracker.spec_path assert mlir_spec_path is not None + script_dir = Path(__file__).resolve().parent + target_dir = mlir_spec_path.resolve().parent.parent.parent + output_name = f"unet_candidate_{candidate_tracker.candidate_id}.vmfb" command = [ - "./compile_unet_candidate.sh", - "winograd", - mlir_spec_path.as_posix(), + "timeout", + "300s", + (script_dir / "../int8-model/compile-punet-base.sh").as_posix(), + "./tools/iree-compile", + "gfx942", + f"{mlir_spec_path.resolve()}", + "./punet.mlir", + "-o", + (target_dir / output_name).as_posix(), ] return command def get_model_benchmark_command( self, candidate_tracker: libtuner.CandidateTracker ) -> list[str]: - unet_candidate_path = candidate_tracker.model_path + unet_candidate_path = candidate_tracker.compiled_model_path assert unet_candidate_path is not None + command = [ - "./benchmark_unet_candidate.sh", - unet_candidate_path.as_posix(), + "timeout", + "180s", + "tools/iree-benchmark-module", + f"--device={libtuner.DEVICE_ID_PLACEHOLDER}", + "--hip_use_streams=true", + "--hip_allow_inline_execution=true", + "--device_allocator=caching", + f"--module={unet_candidate_path.resolve()}", + "--parameters=model=punet.irpa", + "--function=main", + "--input=1x4x128x128xf16", + "--input=1xsi32", + "--input=2x64x2048xf16", + "--input=2x1280xf16", + "--input=2x6xf16", + "--input=1xf16", + "--benchmark_repetitions=5", ] return command @@ -93,10 +128,8 @@ def main(): print("Validation successful!\n") print("Generating candidates...") - candidates = libtuner.generate_candidates( - args, path_config, candidate_trackers, punet_client - ) - print(f"Generated [{len(candidates)}] candidates in {path_config.candidates_dir}\n") + candidates = libtuner.generate_candidates(args, path_config, candidate_trackers) + print(f"Stored candidates in {path_config.candidates_dir}\n") if stop_after_phase == libtuner.ExecutionPhases.generate_candidates: return diff --git a/tuning/test_libtuner.py b/tuning/test_libtuner.py index 394208c..7b68533 100644 --- a/tuning/test_libtuner.py +++ 
b/tuning/test_libtuner.py @@ -15,53 +15,28 @@ def test_group_benchmark_results_by_device_id(): - def generate_res(res_arg: str, device_id: int) -> libtuner.TaskResult: - result: libtuner.subprocess.CompletedProcess = ( - libtuner.subprocess.CompletedProcess( - args=[res_arg], - returncode=0, - ) - ) - return libtuner.TaskResult(result=result, device_id=device_id) - - test_input = [ - generate_res("str1", 3), - generate_res("str7", 4), - generate_res("str2", 1), - generate_res("str5", 3), - generate_res("str5", 7), - generate_res("str3", 4), - ] - expect_output = [ - [generate_res("str2", 1)], - [generate_res("str1", 3), generate_res("str5", 3)], - [generate_res("str7", 4), generate_res("str3", 4)], - [generate_res("str5", 7)], + # Create mock TaskResult objects with device_id attributes + task_result_1 = MagicMock() + task_result_1.device_id = "device_1" + + task_result_2 = MagicMock() + task_result_2.device_id = "device_2" + + task_result_3 = MagicMock() + task_result_3.device_id = "device_1" + + benchmark_results = [task_result_1, task_result_2, task_result_3] + + expected_grouped_results = [ + [task_result_1, task_result_3], # Grouped by device_1 + [task_result_2], # Grouped by device_2 ] - actual_output = libtuner.group_benchmark_results_by_device_id(test_input) - - for a, e in zip(actual_output, expect_output): - for res1, res2 in zip(a, e): - assert res1.result.args == res2.result.args - assert res1.device_id == res2.device_id - - -def test_sort_candidates_by_first_benchmark_times(): - candidate_trackers = [libtuner.CandidateTracker(i) for i in range(5)] - candidate_trackers[0].first_benchmark_time = 35 - candidate_trackers[1].first_benchmark_time = 2141 - candidate_trackers[2].first_benchmark_time = 231 - candidate_trackers[3].first_benchmark_time = 231.23 - candidate_trackers[4].first_benchmark_time = 58 - test_input = [i for i in range(5)] - expect_output = [0, 4, 2, 3, 1] - assert ( - libtuner.sort_candidates_by_first_benchmark_times( - test_input, candidate_trackers - ) - == expect_output - ) + grouped_results = libtuner.group_benchmark_results_by_device_id(benchmark_results) + + assert grouped_results == expected_grouped_results + assert grouped_results[0][0].device_id == "device_1" + assert grouped_results[1][0].device_id == "device_2" def test_find_collisions(): @@ -81,320 +56,219 @@ def test_collision_handler(): assert libtuner.collision_handler(input) == (False, []) -def test_DispatchBenchmarkResult_get(): - normal_str = "2 Mean Time: 586.0" - res = libtuner.DispatchBenchmarkResult(normal_str) - assert res.result_str == normal_str - assert res.get_tokens() == ["2", "Mean", "Time:", "586.0"] - assert res.get_candidate_id() == 2 - assert res.get_benchmark_time() == 586.0 - - incomplete_str = "2 Mean Time:" - res = libtuner.DispatchBenchmarkResult(incomplete_str) - assert res.get_tokens() == ["2", "Mean", "Time:"] - assert res.get_candidate_id() == 2 - assert res.get_benchmark_time() == None - incomplete_str = "" - res = libtuner.DispatchBenchmarkResult(incomplete_str) - assert res.get_tokens() == [] - assert res.get_candidate_id() == None - assert res.get_benchmark_time() == None - - bad_str = 12345 - res = libtuner.DispatchBenchmarkResult(bad_str) - assert res.get_tokens() == [] - assert res.get_candidate_id() == None - assert res.get_benchmark_time() == None - - -def test_ModelBenchmarkResult_get(): - normal_str = "Benchmarking: unet_candidate_12.vmfb on device 24\nBM_main/process_time/real_time_median 182 ms 183 ms 5 items_per_second=5.50302/s" - res = 
libtuner.ModelBenchmarkResult(normal_str) - assert res.result_str == normal_str - assert res.get_tokens() == [ - "Benchmarking:", - "unet_candidate_12.vmfb", - "on", - "device", - "24", - "BM_main/process_time/real_time_median", - "182", - "ms", - "183", - "ms", - "5", - "items_per_second=5.50302/s", - ] - assert res.get_model_candidate_path() == "unet_candidate_12.vmfb" - assert res.get_candidate_id() == 12 - assert res.get_device_id() == 24 - assert res.get_benchmark_time() == 182.0 - - incomplete_str = "Benchmarking: baseline.vmfb on device 24\n" - res = libtuner.ModelBenchmarkResult(incomplete_str) - assert res.get_tokens() == [ - "Benchmarking:", - "baseline.vmfb", - "on", - "device", - "24", - ] - assert res.get_model_candidate_path() == "baseline.vmfb" - assert res.get_candidate_id() == None - assert res.get_device_id() == 24 - assert res.get_benchmark_time() == None - incomplete_str = "" - res = libtuner.ModelBenchmarkResult(incomplete_str) - assert res.get_tokens() == [] - assert res.get_model_candidate_path() == None - assert res.get_candidate_id() == None - assert res.get_device_id() == None - assert res.get_benchmark_time() == None - - bad_str = 12345 - res = libtuner.ModelBenchmarkResult(bad_str) - assert res.get_tokens() == [] - assert res.get_model_candidate_path() == None - assert res.get_candidate_id() == None - assert res.get_device_id() == None - assert res.get_benchmark_time() == None - - -def test_generate_sample_result(): - res = libtuner.DispatchBenchmarkResult() - output = res.generate_sample_result(1, 3.14) - expected = f"1\tMean Time: 3.1\n" - assert output == expected, "DispatchBenchmarkResult generates invalid sample string" - - res = libtuner.ModelBenchmarkResult() - output = res.generate_sample_result( - 1, "some_dir/tuning_2024_07_24_20_06/unet_candidate_60.vmfb.vmfb", 576.89 +def test_IREEBenchmarkResult_get(): + # Time is int + normal_str = r""" + ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ + Benchmark Time CPU Iterations UserCounters... 
+ ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ + BM_main$async_dispatch_311_rocm_hsaco_fb_main$async_dispatch_311_matmul_like_2x1024x1280x5120_i8xi8xi32/process_time/real_time 271 us 275 us 3000 items_per_second=3.65611k/s + BM_main$async_dispatch_311_rocm_hsaco_fb_main$async_dispatch_311_matmul_like_2x1024x1280x5120_i8xi8xi32/process_time/real_time 274 us 275 us 3000 items_per_second=3.65481k/s + BM_main$async_dispatch_311_rocm_hsaco_fb_main$async_dispatch_311_matmul_like_2x1024x1280x5120_i8xi8xi32/process_time/real_time 273 us 275 us 3000 items_per_second=3.65671k/s + BM_main$async_dispatch_311_rocm_hsaco_fb_main$async_dispatch_311_matmul_like_2x1024x1280x5120_i8xi8xi32/process_time/real_time_mean 274 us 275 us 3 items_per_second=3.65587k/s + BM_main$async_dispatch_311_rocm_hsaco_fb_main$async_dispatch_311_matmul_like_2x1024x1280x5120_i8xi8xi32/process_time/real_time_mean 275 us 275 us 3 items_per_second=3.65611k/s + BM_main$async_dispatch_311_rocm_hsaco_fb_main$async_dispatch_311_matmul_like_2x1024x1280x5120_i8xi8xi32/process_time/real_time_stddev 0.073 us 0.179 us 3 items_per_second=0.971769/s + BM_main$async_dispatch_311_rocm_hsaco_fb_main$async_dispatch_311_matmul_like_2x1024x1280x5120_i8xi8xi32/process_time/real_time_cv 0.03 % 0.07 % 3 items_per_second=0.03% + """ + res = libtuner.IREEBenchmarkResult(candidate_id=1, result_str=normal_str) + assert res.get_mean_time() == float(274) + + # Time is float + res = libtuner.IREEBenchmarkResult( + candidate_id=2, + result_str="process_time/real_time_mean 123.45 us, process_time/real_time_mean 246.78 us", ) - expected = f"Benchmarking: 1 on device some_dir/tuning_2024_07_24_20_06/unet_candidate_60.vmfb.vmfb\nBM_run_forward/process_time/real_time_median\t 577 ms\t 578 ms\t 5 items_per_second=2.884450/s\n\n" - assert output == expected, "UnetBenchmarkResult generates invalid sample string" + assert res.get_mean_time() == 123.45 - -def test_ModelBenchmarkResult_get_calibrated_result_str(): - baseline_time = 423 - res_time = 304 - result_str = f"Benchmarking: tuning_2024_07_22_16_29/unet_candidate_16.vmfb on device 0\nBM_run_forward/process_time/real_time_median {float(res_time)} ms 305 ms 5 items_per_second=1.520000/s" - change = (res_time - baseline_time) / baseline_time - output_str = libtuner.ModelBenchmarkResult(result_str).get_calibrated_result_str( - change - ) - expect_str = f"Benchmarking: tuning_2024_07_22_16_29/unet_candidate_16.vmfb on device 0\nBM_run_forward/process_time/real_time_median\t {float(res_time)} ms (-28.132%)\t 305 ms\t 5 items_per_second=1.520000/s" - assert output_str == expect_str - - baseline_time = 218 - res_time = 218 - result_str = f"Benchmarking: tuning_2024_07_22_16_29/unet_candidate_16.vmfb on device 0\nBM_run_forward/process_time/real_time_median {float(res_time)} ms 305 ms 5 items_per_second=1.520000/s" - change = (res_time - baseline_time) / baseline_time - output_str = libtuner.ModelBenchmarkResult(result_str).get_calibrated_result_str( - change - ) - expect_str = f"Benchmarking: tuning_2024_07_22_16_29/unet_candidate_16.vmfb on device 0\nBM_run_forward/process_time/real_time_median\t {float(res_time)} ms (+0.000%)\t 305 ms\t 5 items_per_second=1.520000/s" - assert output_str == expect_str - - baseline_time = 123 - res_time = 345 - result_str = f"Benchmarking: tuning_2024_07_22_16_29/unet_candidate_16.vmfb on device 
0\nBM_run_forward/process_time/real_time_median {float(res_time)} ms 305 ms 5 items_per_second=1.520000/s" - change = (res_time - baseline_time) / baseline_time - output_str = libtuner.ModelBenchmarkResult(result_str).get_calibrated_result_str( - change - ) - expect_str = f"Benchmarking: tuning_2024_07_22_16_29/unet_candidate_16.vmfb on device 0\nBM_run_forward/process_time/real_time_median\t {float(res_time)} ms (+180.488%)\t 305 ms\t 5 items_per_second=1.520000/s" - assert output_str == expect_str + # Invalid str + res = libtuner.IREEBenchmarkResult(candidate_id=3, result_str="hello world") + assert res.get_mean_time() == None + res = libtuner.IREEBenchmarkResult(candidate_id=4, result_str="") + assert res.get_mean_time() == None -def test_parse_dispatch_benchmark_results(): - def generate_res(stdout: str) -> libtuner.TaskResult: - result = libtuner.subprocess.CompletedProcess( - args=[""], - stdout=stdout, - returncode=0, - ) - return libtuner.TaskResult(result) - - def generate_parsed_disptach_benchmark_result( - time: float, i: int - ) -> libtuner.ParsedDisptachBenchmarkResult: - return libtuner.ParsedDisptachBenchmarkResult( - i, - time, - path_config.get_candidate_mlir_path(i), - path_config.get_candidate_spec_mlir_path(i), - ) +def test_generate_display_BR(): + output = libtuner.generate_display_DBR(1, 3.14) + expected = f"1\tMean Time: 3.1" + assert output == expected, "DispatchBenchmarkResult generates invalid sample string" - test_list = [(0, 369.0), (1, 301.0), (2, 457.0), (3, 322.0), (4, 479.0)] - random_order = [2, 0, 3, 1, 4] - total = 5 + output = libtuner.generate_display_MBR("baseline.vmfb", str(1), 567.89) + expected = "Benchmarking: baseline.vmfb on device 1: 568" + assert output == expected, "ModelBenchmarkResult generates invalid sample string" + output = libtuner.generate_display_MBR("baseline.vmfb", str(1), 567.89, 0.0314) + expected = "Benchmarking: baseline.vmfb on device 1: 568 (+3.140%)" + assert output == expected, "ModelBenchmarkResult generates invalid sample string" + output = libtuner.generate_display_MBR("baseline.vmfb", str(1), 567.89, -3.14) + expected = "Benchmarking: baseline.vmfb on device 1: 568 (-314.000%)" + assert output == expected, "ModelBenchmarkResult generates invalid sample string" - benchmark_results = [ - generate_res(f"{test_list[i][0]} Mean Time: {test_list[i][1]}") - for i in random_order - ] +def test_parse_dispatch_benchmark_results(): + base_path = libtuner.Path("/mock/base/dir") + spec_dir = base_path / "specs" path_config = libtuner.PathConfig() - + object.__setattr__(path_config, "specs_dir", spec_dir) + + mock_result_1 = MagicMock() + mock_result_1.result.stdout = "process_time/real_time_mean 100.0 us" + mock_result_1.candidate_id = 1 + mock_result_2 = MagicMock() + mock_result_2.result.stdout = "process_time/real_time_mean 200.0 us" + mock_result_2.candidate_id = 2 + mock_result_3 = MagicMock() + mock_result_3.result.stdout = "" # Incomplete result + mock_result_3.candidate_id = 3 + benchmark_results = [mock_result_1, mock_result_2, mock_result_3] + + candidate_tracker_0 = libtuner.CandidateTracker(candidate_id=0) + candidate_tracker_0.dispatch_mlir_path = libtuner.Path("/mock/mlir/path/0.mlir") + candidate_tracker_1 = libtuner.CandidateTracker(candidate_id=1) + candidate_tracker_1.dispatch_mlir_path = libtuner.Path("/mock/mlir/path/1.mlir") + candidate_tracker_2 = libtuner.CandidateTracker(candidate_id=2) + candidate_tracker_2.dispatch_mlir_path = libtuner.Path("/mock/mlir/path/2.mlir") + candidate_tracker_3 = 
libtuner.CandidateTracker(candidate_id=3) + candidate_tracker_3.dispatch_mlir_path = libtuner.Path("/mock/mlir/path/3.mlir") candidate_trackers = [ - libtuner.CandidateTracker( - i, dispatch_mlir_path=path_config.get_candidate_mlir_path(i) - ) - for i in range(total) - ] - candidate_trackers_before = [ - libtuner.CandidateTracker( - i, dispatch_mlir_path=path_config.get_candidate_mlir_path(i) - ) - for i in range(total) + candidate_tracker_0, + candidate_tracker_1, + candidate_tracker_2, + candidate_tracker_3, ] - expect_candidate_trackers = [ - libtuner.CandidateTracker( - i, - dispatch_mlir_path=path_config.get_candidate_mlir_path(i), - spec_path=path_config.get_candidate_spec_mlir_path(i), - ) - for i in range(total) + expected_parsed_results = [ + libtuner.ParsedDisptachBenchmarkResult( + candidate_id=1, + benchmark_time_in_seconds=100.0, + candidate_mlir=libtuner.Path("/mock/mlir/path/1.mlir"), + candidate_spec_mlir=libtuner.Path("/mock/base/dir/specs/1_spec.mlir"), + ), + libtuner.ParsedDisptachBenchmarkResult( + candidate_id=2, + benchmark_time_in_seconds=200.0, + candidate_mlir=libtuner.Path("/mock/mlir/path/2.mlir"), + candidate_spec_mlir=libtuner.Path("/mock/base/dir/specs/2_spec.mlir"), + ), ] - for i in range(total): - expect_candidate_trackers[test_list[i][0]].first_benchmark_time = test_list[i][ - 1 - ] - - tmp = [generate_parsed_disptach_benchmark_result(t, i) for i, t in test_list] - expect_parsed_results = [tmp[i] for i in random_order] - expect_dump_list = [ - f"{test_list[i][0]} Mean Time: {test_list[i][1]}" for i in random_order + expected_dump_list = [ + "1\tMean Time: 100.0\n", + "2\tMean Time: 200.0\n", + "Candidate 3 not completed", ] - mock_tuning_client = MagicMock() - mock_tuning_client.get_candidate_spec_filename.side_effect = ( - lambda i: f"{i}_spec.mlir" - ) parsed_results, dump_list = libtuner.parse_dispatch_benchmark_results( - path_config, benchmark_results, candidate_trackers, mock_tuning_client + path_config, benchmark_results, candidate_trackers + ) + + assert parsed_results == expected_parsed_results + assert dump_list == expected_dump_list + assert candidate_trackers[1].first_benchmark_time == 100.0 + assert candidate_trackers[1].spec_path == libtuner.Path( + "/mock/base/dir/specs/1_spec.mlir" + ) + assert candidate_trackers[2].first_benchmark_time == 200.0 + assert candidate_trackers[2].spec_path == libtuner.Path( + "/mock/base/dir/specs/2_spec.mlir" ) - assert parsed_results == expect_parsed_results - assert dump_list == expect_dump_list - assert candidate_trackers != candidate_trackers_before - assert candidate_trackers == expect_candidate_trackers +def test_parse_model_benchmark_results(): + # Setup mock data for candidate_trackers + tracker0 = libtuner.CandidateTracker(0) + tracker0.compiled_model_path = libtuner.Path("/path/to/baseline.vmfb") -def test_parse_grouped_benchmark_results(): - def generate_res(stdout: str, device_id: int) -> libtuner.TaskResult: - result = libtuner.subprocess.CompletedProcess( - args=[""], - stdout=stdout, - returncode=0, - ) - return libtuner.TaskResult(result=result, device_id=device_id) - - def set_tracker( - tracker: libtuner.CandidateTracker, - model_benchmark_time: float, - model_benchmark_device_id: int, - baseline_benchmark_time: float, - calibrated_benchmark_diff=float, - ): - tracker.model_benchmark_time = model_benchmark_time - tracker.model_benchmark_device_id = model_benchmark_device_id - tracker.baseline_benchmark_time = baseline_benchmark_time - tracker.calibrated_benchmark_diff = 
calibrated_benchmark_diff - - b1 = "Benchmarking: some_dir/baseline.vmfb on device 0 BM_main/process_time/real_time_median 60.7 ms 13.5 ms 5 items_per_second=16.4733/s" - b2 = "Benchmarking: baseline.vmfb on device 1 BM_main/process_time/real_time_median 59.8 ms 15.1 ms 5 items_per_second=16.7114/s" - s1 = "Benchmarking: unet_candidate_1.vmfb on device 0 BM_main/process_time/real_time_median 62.4 ms 15.4 ms 5 items_per_second=16.0223/s" - s2 = "Benchmarking: some_dir/unet_candidate_2.vmfb on device 1 BM_main/process_time/real_time_median 61.4 ms 11.0 ms 5 items_per_second=16.2958/s" - s3 = "Benchmarking: unet_candidate_4.vmfb on device 1 BM_main/process_time/real_time_median 57.4 ms 11.0 ms 5 items_per_second=16.2958/s" - - grouped_benchmark_results = [ - [generate_res(b1, 0), generate_res(s1, 0)], - [ - generate_res(b2, 1), - generate_res(None, 1), - generate_res(s2, 1), - generate_res(s3, 1), - ], - ] + tracker1 = libtuner.CandidateTracker(1) + tracker1.compiled_model_path = libtuner.Path("/path/to/model_1.vmfb") - path_config = libtuner.PathConfig() + tracker2 = libtuner.CandidateTracker(2) + tracker2.compiled_model_path = libtuner.Path("/path/to/model_2.vmfb") - candidate_trackers = [libtuner.CandidateTracker(i) for i in range(5)] - - candidate_trackers_before = [libtuner.CandidateTracker(i) for i in range(5)] - expect_candidate_trackers = [libtuner.CandidateTracker(i) for i in range(5)] - set_tracker(expect_candidate_trackers[1], 62.4, 0, 60.7, 0.028006589785831888) - set_tracker(expect_candidate_trackers[2], 61.4, 1, 59.8, 0.02675585284280939) - set_tracker(expect_candidate_trackers[4], 57.4, 1, 59.8, -0.04013377926421403) - - expect_dump_list = [ - "Benchmarking: some_dir/baseline.vmfb on device 0 " - "BM_main/process_time/real_time_median 60.7 ms 13.5 ms 5 items_per_second=16.4733/s", - "Benchmarking: unet_candidate_1.vmfb on device 0 " - "BM_main/process_time/real_time_median 62.4 ms (+2.801%) 15.4 ms 5 items_per_second=16.0223/s", - "Benchmarking: baseline.vmfb on device 1 " - "BM_main/process_time/real_time_median 59.8 ms 15.1 ms 5 items_per_second=16.7114/s", - "Benchmarking: unet_candidate_4.vmfb on device 1 " - "BM_main/process_time/real_time_median 57.4 ms (-4.013%) 11.0 ms 5 items_per_second=16.2958/s", - "Benchmarking: some_dir/unet_candidate_2.vmfb on device 1 " - "BM_main/process_time/real_time_median 61.4 ms (+2.676%) 11.0 ms 5 items_per_second=16.2958/s", - ] + tracker3 = libtuner.CandidateTracker(3) + tracker3.compiled_model_path = libtuner.Path("/path/to/model_3.vmfb") - dump_list = libtuner.parse_grouped_benchmark_results( - path_config, grouped_benchmark_results, candidate_trackers - ) + candidate_trackers = [tracker0, tracker1, tracker2, tracker3] - assert dump_list == expect_dump_list, "basic parsing is incorrect" - assert ( - candidate_trackers != candidate_trackers_before - ), "candidate_trackers should be modified" - assert ( - candidate_trackers == expect_candidate_trackers - ), "candidate_trackers did not change as expected" - - b1 = "Benchmarking: baseline.vmfb on device 0" - s1 = "Benchmarking: unet_candidate_1.vmfb on device 0 BM_main/process_time/real_time_median 62.4 ms 15.4 ms 5 items_per_second=16.0223/s" - grouped_benchmark_results = [[generate_res(b1, 0), generate_res(s1, 0)]] - dump_list = libtuner.parse_grouped_benchmark_results( - path_config, grouped_benchmark_results, candidate_trackers - ) - expect_dump_list = [ - "Benchmarking: unet_candidate_1.vmfb on device 0 " - "BM_main/process_time/real_time_median 62.4 ms 15.4 ms 5 
items_per_second=16.0223/s", - "Benchmarking result of baseline.vmfb on deivce 0 is incomplete\n", - ] - assert dump_list == expect_dump_list, "fail to parse incomplete baselines" - - b1 = "Benchmarking: some_dir/baseline.vmfb on device 0 BM_main/process_time/real_time_median 60.7 ms 13.5 ms 5 items_per_second=16.4733/s" - s1 = "Benchmarking: unet_candidate_1.vmfb on device 0" - grouped_benchmark_results = [[generate_res(b1, 0), generate_res(s1, 0)]] - candidate_trackers[1].model_path = "unet_candidate_1.vmfb" - dump_list = libtuner.parse_grouped_benchmark_results( - path_config, grouped_benchmark_results, candidate_trackers - ) - expect_dump_list = [ - "Benchmarking: some_dir/baseline.vmfb on device 0 " - "BM_main/process_time/real_time_median 60.7 ms 13.5 ms 5 items_per_second=16.4733/s", - "Benchmarking result of unet_candidate_1.vmfb on deivce 0 is incomplete\n", - ] - assert dump_list == expect_dump_list, "fail to parse incomplete candidates" - - b1 = "Benchmarking: baseline.vmfb on device 0" - s1 = "Benchmarking: unet_candidate_1.vmfb on device 0" - grouped_benchmark_results = [[generate_res(b1, 0), generate_res(s1, 0)]] - candidate_trackers[1].model_path = "unet_candidate_1.vmfb" - dump_list = libtuner.parse_grouped_benchmark_results( - path_config, grouped_benchmark_results, candidate_trackers - ) - expect_dump_list = [ - "Benchmarking result of baseline.vmfb on deivce 0 is incomplete\n", - "Benchmarking result of unet_candidate_1.vmfb on deivce 0 is incomplete\n", - ] - assert ( - dump_list == expect_dump_list - ), "fail to parse incomplete baseline and candidates" + # Setup mock data for task results + result1 = MagicMock(spec=libtuner.TaskResult) + result1.result = MagicMock(stdout="1.23") + result1.candidate_id = 1 + result1.device_id = "device1" + + result2 = MagicMock(spec=libtuner.TaskResult) + result2.result = MagicMock(stdout="4.56") + result2.candidate_id = 2 + result2.device_id = "device2" + + result3 = MagicMock(spec=libtuner.TaskResult) + result3.result = MagicMock(stdout="0.98") + result3.candidate_id = 0 + result3.device_id = "device1" + + result4 = MagicMock(spec=libtuner.TaskResult) + result4.result = MagicMock(stdout="4.13") + result4.candidate_id = 0 + result4.device_id = "device2" + + # Incomplete baseline on device3 + result5 = MagicMock(spec=libtuner.TaskResult) + result5.result = MagicMock(stdout=None) + result5.candidate_id = 0 + result5.device_id = "device3" + + result6 = MagicMock(spec=libtuner.TaskResult) + result6.result = MagicMock(stdout="3.38") + result6.candidate_id = 3 + result6.device_id = "device3" + + candidate_results = [result1, result2, result6] + baseline_results = [result3, result4, result5] + + # Skip real benchmark extraction, directly use given values from above + def mock_get_mean_time(self): + return float(self.result_str) if self.result_str else None + + # Mock IREEBenchmarkResult to return wanted benchmark times + with patch("libtuner.IREEBenchmarkResult.get_mean_time", new=mock_get_mean_time): + # Mock handle_error to avoid actual logging during tests + with patch("libtuner.handle_error") as mock_handle_error: + dump_list = libtuner.parse_model_benchmark_results( + candidate_trackers, candidate_results, baseline_results + ) + + # Verify interactions with candidate_trackers + assert tracker1.model_benchmark_time == 1.23 + assert tracker1.model_benchmark_device_id == "device1" + assert tracker1.baseline_benchmark_time == 0.98 + assert tracker1.calibrated_benchmark_diff == pytest.approx( + (1.23 - 0.98) / 0.98, rel=1e-6 + ) + + assert 
tracker2.model_benchmark_time == 4.56 + assert tracker2.model_benchmark_device_id == "device2" + assert tracker2.baseline_benchmark_time == 4.13 + assert tracker2.calibrated_benchmark_diff == pytest.approx( + (4.56 - 4.13) / 4.13, rel=1e-6 + ) + + assert tracker3.model_benchmark_time == 3.38 + assert tracker3.model_benchmark_device_id == "device3" + + assert dump_list == [ + "Benchmarking: /path/to/baseline.vmfb on device device1: 0.98\n" "\n", + "Benchmarking: /path/to/model_1.vmfb on device device1: 1.23 (+25.510%)\n" + "\n", + "Benchmarking: /path/to/baseline.vmfb on device device2: 4.13\n" "\n", + "Benchmarking: /path/to/model_2.vmfb on device device2: 4.56 (+10.412%)\n" + "\n", + "Benchmarking: /path/to/model_3.vmfb on device device3: 3.38\n" "\n", + "Benchmarking result of /path/to/baseline.vmfb on device device3 is incomplete\n", + ] + + # Verify handle_error was called correctly + mock_handle_error.assert_called_once_with( + condition=True, + msg="Benchmarking result of /path/to/baseline.vmfb on device device3 is incomplete", + level=libtuner.logging.WARNING, + ) def test_extract_driver_names():