Complete tuning process in punet_autotune.py
RattataKing committed Aug 19, 2024
1 parent ae05cd3 commit d420c46
Showing 2 changed files with 48 additions and 26 deletions.
33 changes: 7 additions & 26 deletions tuning/autotune.py
@@ -178,27 +178,6 @@ def get_model_benchmark_command(self, candidate_tracker) -> list[str]:
        pass


@dataclass
class DefaultTuningClient(TuningClient):
    def get_dispatch_compile_command(
        self, candidate_tracker: CandidateTracker
    ) -> list[str]:
        command = [""]
        return command

    def get_dispatch_benchmark_command(self, candidate_tracker) -> list[str]:
        command = [""]
        return command

    def get_model_compile_command(self, candidate_tracker) -> list[str]:
        command = [""]
        return command

    def get_model_benchmark_command(self, candidate_tracker) -> list[str]:
        command = [""]
        return command


@dataclass
class TaskTuple:
    args: argparse.Namespace
@@ -724,6 +703,11 @@ def load_pickle(file_path: Path) -> list[Any]:
    return loaded_array


def save_pickle(file_path: Path, input_list: list[Any]) -> None:
    with open(file_path, "wb") as file:
        pickle.dump(input_list, file)


def append_to_file(lines: list[str], filepath: Path, title: str = "") -> None:
"""Appends new content to the end of the output.log."""
title_str = "=" * 5 + f" {title} " + "=" * 5 + "\n" if title != "" else ""
Expand Down Expand Up @@ -1359,13 +1343,12 @@ def summerize_top_candidates(
file.writelines(dump_list)


def autotune(args: argparse.Namespace) -> None:
def autotune(args: argparse.Namespace, tuning_client: TuningClient) -> None:
    path_config = PathConfig()
    path_config.base_dir.mkdir(parents=True, exist_ok=True)
    path_config.output_unilog.touch()

    candidate_trackers: list[CandidateTracker] = []
    tuning_client = DefaultTuningClient()
    stop_after_phase: str = args.stop_after

print("Setup logging")
Expand Down Expand Up @@ -1397,7 +1380,6 @@ def autotune(args: argparse.Namespace) -> None:
args, path_config, compiled_candidates, candidate_trackers, tuning_client
)
print(f"Stored results in {path_config.output_unilog}\n")

    if stop_after_phase == ExecutionPhases.benchmark_dispatches:
        return

@@ -1420,8 +1402,7 @@ def autotune(args: argparse.Namespace) -> None:
    summerize_top_candidates(path_config, candidate_trackers)
    print(f"Stored top candidates info in {path_config.result_summary_log}\n")

    with open(path_config.candidate_trackers_pkl, "wb") as file:
        pickle.dump(candidate_trackers, file)
    save_pickle(path_config.candidate_trackers_pkl, candidate_trackers)
    print(f"Candidate trackers are saved in {path_config.candidate_trackers_pkl}\n")

    print("Check the detailed execution logs in:")
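With DefaultTuningClient removed, autotune() no longer builds its own client: the caller now passes a TuningClient whose four methods return the command lines used to compile and benchmark dispatch and model candidates. A minimal sketch of such a client, following the method signatures of the removed DefaultTuningClient; the class name, the echo placeholder commands, and the example invocation are illustrative assumptions, not the actual PunetClient used by punet_autotune.py:

@dataclass
class MyModelClient(TuningClient):
    """Hypothetical client: each method returns the argv list the tuner runs."""

    def get_dispatch_compile_command(
        self, candidate_tracker: CandidateTracker
    ) -> list[str]:
        # A real client would build something like an iree-compile invocation
        # from the paths recorded in candidate_tracker.
        return ["echo", "compile-dispatch"]

    def get_dispatch_benchmark_command(self, candidate_tracker) -> list[str]:
        return ["echo", "benchmark-dispatch"]

    def get_model_compile_command(self, candidate_tracker) -> list[str]:
        return ["echo", "compile-model"]

    def get_model_benchmark_command(self, candidate_tracker) -> list[str]:
        return ["echo", "benchmark-model"]

autotune(args, MyModelClient())

punet_autotune.py below fills this role with its PunetClient and, instead of calling autotune() directly, drives the individual phase functions itself.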
41 changes: 41 additions & 0 deletions tuning/punet_autotune.py
@@ -66,28 +66,69 @@ def main():
    path_config.output_unilog.touch()
    candidate_trackers: list[autotune.CandidateTracker] = []
    punet_client = PunetClient()
    stop_after_phase: str = args.stop_after

    print("Setup logging")
    autotune.setup_logging(args, path_config)
    print(path_config.run_log, end="\n\n")

    print("Validating devices")
    autotune.validate_devices(args.devices)
    print("Validation successful!\n")

    print("Generating candidates...")
    candidates = autotune.generate_candidates(
        args, path_config, candidate_trackers, punet_client
    )
    print(f"Generated [{len(candidates)}] candidates in {path_config.candidates_dir}\n")
    if stop_after_phase == autotune.ExecutionPhases.generate_candidates:
        return

    print("Compiling candidates...")
    compiled_candidates = autotune.compile_dispatches(
        args, path_config, candidates, candidate_trackers, punet_client
    )
    print(f"Compiled files are stored in {path_config.compiled_dir}\n")
    if stop_after_phase == autotune.ExecutionPhases.compile_dispatches:
        return

    print("Benchmarking compiled candidates...")
    top_candidates = autotune.benchmark_dispatches(
        args, path_config, compiled_candidates, candidate_trackers, punet_client
    )
    print(f"Stored results in {path_config.output_unilog}\n")
    if stop_after_phase == autotune.ExecutionPhases.benchmark_dispatches:
        return

print(f"Compiling top model candidates...")
punet_candidates = autotune.compile_models(
args, path_config, top_candidates, candidate_trackers, punet_client
)
print(f"Model candidates compiled in {path_config.base_dir}\n")
if stop_after_phase == autotune.ExecutionPhases.compile_models:
return

print("Benchmarking model candidates...")
autotune.benchmark_models(
args, path_config, punet_candidates, candidate_trackers, punet_client
)
print(f"Stored results in {path_config.output_unilog}")
if stop_after_phase == autotune.ExecutionPhases.benchmark_models:
return

autotune.summerize_top_candidates(path_config, candidate_trackers)
print(f"Stored top candidates info in {path_config.result_summary_log}\n")

autotune.save_pickle(path_config.candidate_trackers_pkl, candidate_trackers)
print(f"Candidate trackers are saved in {path_config.candidate_trackers_pkl}\n")

print("Check the detailed execution logs in:")
print(path_config.run_log)

for candidate in candidate_trackers:
autotune.logging.debug(candidate)
if args.verbose:
print(candidate)


if __name__ == "__main__":
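Because the trackers are now persisted with save_pickle at the end of a run, they can be reloaded later with the matching load_pickle helper for offline inspection. A small sketch, assuming the tuning/ directory is on the import path; the pickle path below is a hypothetical example, since the real location comes from PathConfig.candidate_trackers_pkl at run time:

from pathlib import Path

import autotune  # also brings CandidateTracker into scope for unpickling

# Point this at the candidate_trackers_pkl file written by a previous run.
trackers = autotune.load_pickle(Path("tuning_run/candidate_trackers.pkl"))
for tracker in trackers:
    print(tracker)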
