From b4d9a7aae3357a29d6cf3b5ed06757cda321a17d Mon Sep 17 00:00:00 2001
From: Taekyung Heo <7621438+TaekyungHeo@users.noreply.github.com>
Date: Wed, 15 May 2024 19:31:37 -0400
Subject: [PATCH] Apply Brian's changes to comply with Meta's requirements

---
 et_replay/lib/comm/commsTraceParser.py      |   8 +-
 et_replay/lib/comm/comms_utils.py           |   6 +-
 et_replay/lib/comm/pytorch_backend_utils.py |   2 +-
 et_replay/lib/comm/pytorch_dist_backend.py  |   4 +-
 et_replay/lib/comm/pytorch_tpu_backend.py   |   7 +-
 et_replay/lib/et_replay_utils.py            |  12 +-
 et_replay/lib/utils.py                      |   2 +-
 et_replay/pyproject.toml                    |   2 +-
 et_replay/tools/comm_replay.py              |   4 +-
 et_replay/tools/et_replay.py                | 153 ++++++++++----------
 10 files changed, 103 insertions(+), 97 deletions(-)

diff --git a/et_replay/lib/comm/commsTraceParser.py b/et_replay/lib/comm/commsTraceParser.py
index 8207e585..c9a07cd3 100644
--- a/et_replay/lib/comm/commsTraceParser.py
+++ b/et_replay/lib/comm/commsTraceParser.py
@@ -5,11 +5,11 @@
 
 from typing import List, Tuple
 
-from et_replay.comm import comms_utils
-from et_replay.comm.comms_utils import commsArgs
-from et_replay.comm.pytorch_backend_utils import supportedP2pOps
+from et_replay.lib.comm import comms_utils
+from et_replay.lib.comm.comms_utils import commsArgs
+from et_replay.lib.comm.pytorch_backend_utils import supportedP2pOps
 
-from param_bench.train.compute.python.tools.execution_trace import ExecutionTrace
+from et_replay.lib.execution_trace import ExecutionTrace
 
 tensorDtypeMap = {
     "Tensor(int)": "int",
diff --git a/et_replay/lib/comm/comms_utils.py b/et_replay/lib/comm/comms_utils.py
index 240dc619..94ad2acc 100644
--- a/et_replay/lib/comm/comms_utils.py
+++ b/et_replay/lib/comm/comms_utils.py
@@ -38,15 +38,15 @@
 import numpy as np
 import torch
 
-from et_replay.comm.param_profile import paramTimer
-from et_replay.comm.pytorch_backend_utils import (
+from et_replay.lib.comm.param_profile import paramTimer
+from et_replay.lib.comm.pytorch_backend_utils import (
     backendFunctions,
     collectiveArgsHolder,
     customized_backend,
     supportedC10dBackends,
     supportedDevices,
 )
-from torch._C._distributed_c10d import ProcessGroup
+from torch._C._distributed_c10d import ProcessGroup  # @manual
 
 
 random.seed()
diff --git a/et_replay/lib/comm/pytorch_backend_utils.py b/et_replay/lib/comm/pytorch_backend_utils.py
index 748843dc..35d72847 100644
--- a/et_replay/lib/comm/pytorch_backend_utils.py
+++ b/et_replay/lib/comm/pytorch_backend_utils.py
@@ -9,7 +9,7 @@
 
 import torch
 
-from et_replay.comm.param_profile import paramTimer
+from et_replay.lib.comm.param_profile import paramTimer
 
 from torch.distributed import ProcessGroup
 
diff --git a/et_replay/lib/comm/pytorch_dist_backend.py b/et_replay/lib/comm/pytorch_dist_backend.py
index 3eda2649..6f1a0960 100644
--- a/et_replay/lib/comm/pytorch_dist_backend.py
+++ b/et_replay/lib/comm/pytorch_dist_backend.py
@@ -13,8 +13,8 @@
 import torch
 import torch.distributed as dist
 import torch.nn as nn
-from et_replay.comm.param_profile import paramProfile
-from et_replay.comm.pytorch_backend_utils import (
+from et_replay.lib.comm.param_profile import paramProfile
+from et_replay.lib.comm.pytorch_backend_utils import (
     backendFunctions,
     collectiveArgsHolder,
 )
diff --git a/et_replay/lib/comm/pytorch_tpu_backend.py b/et_replay/lib/comm/pytorch_tpu_backend.py
index 6cf4a7ab..7c5675eb 100644
--- a/et_replay/lib/comm/pytorch_tpu_backend.py
+++ b/et_replay/lib/comm/pytorch_tpu_backend.py
@@ -4,9 +4,10 @@
 import numpy as np
 import torch
 import torch.nn as nn
-import torch_xla.core.xla_model as xm
-import torch_xla.distributed.xla_multiprocessing as xmp
-from comms_utils import backendFunctions
+import torch_xla.core.xla_model as xm  # @manual
+import torch_xla.distributed.xla_multiprocessing as xmp  # @manual
+
+from .comms_utils import backendFunctions
 
 
 class PyTorchTPUBackend(backendFunctions):
diff --git a/et_replay/lib/et_replay_utils.py b/et_replay/lib/et_replay_utils.py
index 0be4ca90..0405b835 100644
--- a/et_replay/lib/et_replay_utils.py
+++ b/et_replay/lib/et_replay_utils.py
@@ -2,8 +2,8 @@
 import re
 
 import torch
+from et_replay.lib.execution_trace import NodeType
 from fbgemm_gpu.split_table_batched_embeddings_ops import PoolingMode, WeightDecayMode
-from param_bench.et_replay.lib.execution_trace import NodeType
 
 from param_bench.train.compute.python.lib.pytorch.config_util import create_op_args
 
@@ -469,11 +469,11 @@ def generate_prefix(label, skip_nodes, et_input, cuda, compute_only, tf32, rows)
 
 import os
 import time
 from datetime import datetime
-from et_replay.comm import comms_utils
+from et_replay.lib.comm import comms_utils
 import torch
-from et_replay.comm import commsTraceReplay
-from param_bench.et_replay.lib.et_replay_utils import (
+from et_replay.lib.comm import commsTraceReplay
+from et_replay.lib.et_replay_utils import (
     build_fbgemm_func,
     build_torchscript_func,
     generate_fbgemm_tensors,
@@ -482,8 +482,8 @@ def generate_prefix(label, skip_nodes, et_input, cuda, compute_only, tf32, rows)
     is_qualified,
 )
 
-from param_bench.et_replay.lib.execution_trace import ExecutionTrace
-from param_bench.et_replay.lib.utils import trace_handler
+from et_replay.lib.execution_trace import ExecutionTrace
+from et_replay.lib.utils import trace_handler
 
 
 print("PyTorch version: ", torch.__version__)
diff --git a/et_replay/lib/utils.py b/et_replay/lib/utils.py
index 7bfaff16..5f188666 100644
--- a/et_replay/lib/utils.py
+++ b/et_replay/lib/utils.py
@@ -6,7 +6,7 @@
 import uuid
 from typing import Any, Dict
 
-from param_bench.et_replay.lib.execution_trace import ExecutionTrace
+from et_replay.lib.execution_trace import ExecutionTrace
 
 
 def get_tmp_trace_filename() -> str:
diff --git a/et_replay/pyproject.toml b/et_replay/pyproject.toml
index 2af8391b..206b03dd 100644
--- a/et_replay/pyproject.toml
+++ b/et_replay/pyproject.toml
@@ -7,7 +7,7 @@ name = "et_replay"
 version = "0.5.0"
 
 [tool.setuptools.package-dir]
-"et_replay.comm" = "lib/comm"
+"et_replay.lib.comm" = "lib/comm"
 "et_replay.tools" = "tools"
 
 [project.scripts]
diff --git a/et_replay/tools/comm_replay.py b/et_replay/tools/comm_replay.py
index 4d4c8bcd..6d5e8bd4 100644
--- a/et_replay/tools/comm_replay.py
+++ b/et_replay/tools/comm_replay.py
@@ -1383,9 +1383,7 @@ def initBackend(
         """
         # init backend and corresponding function pointers
         if commsParams.nw_stack == "pytorch-dist":
-            from et_replay.comm.pytorch_dist_backend import (
-                PyTorchDistBackend,
-            )
+            from et_replay.comm.pytorch_dist_backend import PyTorchDistBackend
 
             self.backendFuncs = PyTorchDistBackend(bootstrap_info, commsParams)
         elif commsParams.nw_stack == "pytorch-xla-tpu":
diff --git a/et_replay/tools/et_replay.py b/et_replay/tools/et_replay.py
index dbe06ed7..ea86c850 100644
--- a/et_replay/tools/et_replay.py
+++ b/et_replay/tools/et_replay.py
@@ -12,6 +12,24 @@
 import numpy as np
 import torch
+
+from et_replay.lib.comm import comms_utils
+
+# from et_replay.lib.comm import commsTraceReplay XXX FIXME
+
+from et_replay.lib.execution_trace import ExecutionTrace, NodeType
+
+from et_replay.lib.utils import trace_handler
+
+from param_bench.train.compute.python.lib import pytorch as lib_pytorch
+from param_bench.train.compute.python.lib.init_helper import load_modules
+from param_bench.train.compute.python.workloads import pytorch as workloads_pytorch
+from torch._inductor.codecache import AsyncCompile, TritonFuture
+
+# grid and split_scan_grid are dynamically loaded
+from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid
+from torch.profiler import ExecutionTraceObserver
+
 
 from ..lib.et_replay_utils import (
     build_fbgemm_func,
     build_torchscript_func,
@@ -36,20 +54,6 @@
     TORCH_DTYPES_RNG_str,
 )
-from et_replay.lib.execution_trace import ExecutionTrace, NodeType
-
-from param_bench.et_replay.lib.utils import trace_handler
-from et_replay.comm import comms_utils, commsTraceReplay
-
-from param_bench.train.compute.python.lib import pytorch as lib_pytorch
-from param_bench.train.compute.python.lib.init_helper import load_modules
-from param_bench.train.compute.python.workloads import pytorch as workloads_pytorch
-from torch._inductor.codecache import AsyncCompile, TritonFuture
-
-# grid and split_scan_grid are dynamically loaded
-from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid
-from torch.profiler import ExecutionTraceObserver
-
 
 
 class ExgrReplayManager:
     def __init__(self):
@@ -1092,47 +1096,49 @@ def get_inputs(self, node):
 
     def run_op(self, node, iter):
         if node.name == "record_param_comms" and not self.compute_only:
-            opTensor = self.commsBench.replaySingle(
-                self.commsParams, node.id, self.regenerate_tensors
-            )
-            # Wait, barrier has no output tensor.
-            if "wait" in node.inputs or "barrier" in node.inputs:
-                if self.wait_delay != 0:
-                    time.sleep(self.wait_delay / 1000.0)
-                return
-            if self.args.separate:
-                return
-
-            # # Total dimension of the output tensor should be the same as
-            # # the original in et, reshape if different.
-            # if type(opTensor) is list:
-            #     for t in opTensor:
-            #         print(t)
-
-            original_shape = reduce(lambda x, y: x * y, node.output_shapes[0])
-            op_tensor_shape = reduce(lambda x, y: x * y, list(opTensor.size()))
-            if original_shape != op_tensor_shape:
-                print(
-                    "Comms ops output tensor shape mismatch: ",
-                    node.id,
-                    original_shape,
-                    op_tensor_shape,
-                )
-                exit(1)
-            op_tensor = torch.reshape(opTensor, tuple(node.output_shapes[0]))
-            t_id = tuple(node.outputs[0])
-            if self.tensor_with_device:
-                t_id = tuple(list(t_id)[:5])
-            if (
-                t_id in self.dependency_permanent
-                and self.tensors_mapping[(node.id, t_id, False)]
-                not in self.unchangeable_intermediate_tensors
-            ):
-                if self.tensors_mapping[(node.id, t_id, False)] not in self.instantiate:
-                    self.tensor_registry[
-                        self.tensors_mapping[(node.id, t_id, False)]
-                    ] = op_tensor
             return
+            # XXX FIXME add commsTraceReplay
+            # opTensor = self.commsBench.replaySingle(
+            #     self.commsParams, node.id, self.regenerate_tensors
+            # )
+            # Wait, barrier has no output tensor.
+            # if "wait" in node.inputs or "barrier" in node.inputs:
+            #     if self.wait_delay != 0:
+            #         time.sleep(self.wait_delay / 1000.0)
+            #     return
+            # if self.args.separate:
+            #     return
+
+            # # # Total dimension of the output tensor should be the same as
+            # # # the original in et, reshape if different.
+            # # if type(opTensor) is list:
+            # #     for t in opTensor:
+            # #         print(t)
+
+            # original_shape = reduce(lambda x, y: x * y, node.output_shapes[0])
+            # op_tensor_shape = reduce(lambda x, y: x * y, list(opTensor.size()))
+            # if original_shape != op_tensor_shape:
+            #     print(
+            #         "Comms ops output tensor shape mismatch: ",
+            #         node.id,
+            #         original_shape,
+            #         op_tensor_shape,
+            #     )
+            #     exit(1)
+            # op_tensor = torch.reshape(opTensor, tuple(node.output_shapes[0]))
+            # t_id = tuple(node.outputs[0])
+            # if self.tensor_with_device:
+            #     t_id = tuple(list(t_id)[:5])
+            # if (
+            #     t_id in self.dependency_permanent
+            #     and self.tensors_mapping[(node.id, t_id, False)]
+            #     not in self.unchangeable_intermediate_tensors
+            # ):
+            #     if self.tensors_mapping[(node.id, t_id, False)] not in self.instantiate:
+            #         self.tensor_registry[
+            #             self.tensors_mapping[(node.id, t_id, False)]
+            #         ] = op_tensor
+            # return
 
         if self.debug and iter >= self.numWarmupIters:
             start_ns = time.time_ns()
@@ -1254,30 +1260,31 @@ def init_comms(self):
         comms_env_params = comms_utils.read_comms_env_vars()
         print(comms_env_params, self.cuda)
 
-        self.commsBench = commsTraceReplay.commsTraceReplayBench()
-        self.commsBench.trace_file = self.trace_file
-        if "://" in self.trace_file:
-            self.commsBench.use_remote_trace = True
+        # # XXX FIXME
+        # self.commsBench = commsTraceReplay.commsTraceReplayBench()
+        # self.commsBench.trace_file = self.trace_file
+        # if "://" in self.trace_file:
+        #     self.commsBench.use_remote_trace = True
 
-        parser = argparse.ArgumentParser(description="Execution Trace Comms Replay")
-        comms_args = self.commsBench.readArgs(parser)
+        # parser = argparse.ArgumentParser(description="Execution Trace Comms Replay")
+        # comms_args = self.commsBench.readArgs(parser)
 
-        self.commsBench.checkArgs(comms_args)
+        # self.commsBench.checkArgs(comms_args)
 
-        time.sleep(1)
-        self.bootstrap_info = comms_utils.bootstrap_info_holder(
-            comms_args.master_ip,
-            comms_args.master_port,
-            comms_args.num_tpu_cores,
-            comms_env_params,
-        )
-        self.commsParams = comms_utils.commsParamsHolderBase(comms_args)
+        # time.sleep(1)
+        # self.bootstrap_info = comms_utils.bootstrap_info_holder(
+        #     comms_args.master_ip,
+        #     comms_args.master_port,
+        #     comms_args.num_tpu_cores,
+        #     comms_env_params,
+        # )
+        # self.commsParams = comms_utils.commsParamsHolderBase(comms_args)
 
-        self.commsBench.trace_type = "et"
+        # self.commsBench.trace_type = "et"
 
-        self.commsBench.initBackend(self.bootstrap_info, self.commsParams)
-        self.commsBench.initBench(self.commsParams, comms_args)
-        self.commsBench.replayInit(self.commsParams)
+        # self.commsBench.initBackend(self.bootstrap_info, self.commsParams)
+        # self.commsBench.initBench(self.commsParams, comms_args)
+        # self.commsBench.replayInit(self.commsParams)
 
     def analyze_ops(self):
         fused_cnt = 0
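
As a quick sanity check of the relocated modules above, the following is a minimal sketch; it is not part of the patch and assumes et_replay has been reinstalled so the updated [tool.setuptools.package-dir] mapping takes effect:

    # Hypothetical smoke test: confirm the new et_replay.lib.* import paths
    # introduced by this patch resolve in the current environment.
    from et_replay.lib.comm import comms_utils
    from et_replay.lib.execution_trace import ExecutionTrace, NodeType
    from et_replay.lib.utils import trace_handler

    print("et_replay.lib imports resolve:", comms_utils.__name__, ExecutionTrace.__name__)

If any of these raise ModuleNotFoundError, the environment is likely still using the old et_replay.comm / param_bench.et_replay.lib layout.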