Apply Brian's changes to comply with Meta's requirements
TaekyungHeo committed May 15, 2024
1 parent b79b86a commit b4d9a7a
Showing 10 changed files with 103 additions and 97 deletions.
8 changes: 4 additions & 4 deletions et_replay/lib/comm/commsTraceParser.py
@@ -5,11 +5,11 @@

from typing import List, Tuple

-from et_replay.comm import comms_utils
-from et_replay.comm.comms_utils import commsArgs
-from et_replay.comm.pytorch_backend_utils import supportedP2pOps
+from et_replay.lib.comm import comms_utils
+from et_replay.lib.comm.comms_utils import commsArgs
+from et_replay.lib.comm.pytorch_backend_utils import supportedP2pOps

-from param_bench.train.compute.python.tools.execution_trace import ExecutionTrace
+from et_replay.lib.execution_trace import ExecutionTrace

tensorDtypeMap = {
"Tensor(int)": "int",
6 changes: 3 additions & 3 deletions et_replay/lib/comm/comms_utils.py
@@ -38,15 +38,15 @@

import numpy as np
import torch
-from et_replay.comm.param_profile import paramTimer
-from et_replay.comm.pytorch_backend_utils import (
+from et_replay.lib.comm.param_profile import paramTimer
+from et_replay.lib.comm.pytorch_backend_utils import (
backendFunctions,
collectiveArgsHolder,
customized_backend,
supportedC10dBackends,
supportedDevices,
)
-from torch._C._distributed_c10d import ProcessGroup
+from torch._C._distributed_c10d import ProcessGroup # @manual

random.seed()

2 changes: 1 addition & 1 deletion et_replay/lib/comm/pytorch_backend_utils.py
@@ -9,7 +9,7 @@

import torch

-from et_replay.comm.param_profile import paramTimer
+from et_replay.lib.comm.param_profile import paramTimer

from torch.distributed import ProcessGroup

4 changes: 2 additions & 2 deletions et_replay/lib/comm/pytorch_dist_backend.py
@@ -13,8 +13,8 @@
import torch
import torch.distributed as dist
import torch.nn as nn
-from et_replay.comm.param_profile import paramProfile
-from et_replay.comm.pytorch_backend_utils import (
+from et_replay.lib.comm.param_profile import paramProfile
+from et_replay.lib.comm.pytorch_backend_utils import (
backendFunctions,
collectiveArgsHolder,
)
7 changes: 4 additions & 3 deletions et_replay/lib/comm/pytorch_tpu_backend.py
@@ -4,9 +4,10 @@
import numpy as np
import torch
import torch.nn as nn
-import torch_xla.core.xla_model as xm
-import torch_xla.distributed.xla_multiprocessing as xmp
-from comms_utils import backendFunctions
+import torch_xla.core.xla_model as xm # @manual
+import torch_xla.distributed.xla_multiprocessing as xmp # @manual
+
+from .comms_utils import backendFunctions


class PyTorchTPUBackend(backendFunctions):
12 changes: 6 additions & 6 deletions et_replay/lib/et_replay_utils.py
@@ -2,8 +2,8 @@
import re

import torch
+from et_replay.lib.execution_trace import NodeType
from fbgemm_gpu.split_table_batched_embeddings_ops import PoolingMode, WeightDecayMode
-from param_bench.et_replay.lib.execution_trace import NodeType

from param_bench.train.compute.python.lib.pytorch.config_util import create_op_args

@@ -469,11 +469,11 @@ def generate_prefix(label, skip_nodes, et_input, cuda, compute_only, tf32, rows)
import os
import time
from datetime import datetime
-from et_replay.comm import comms_utils
+from et_replay.lib.comm import comms_utils
import torch
-from et_replay.comm import commsTraceReplay
-from param_bench.et_replay.lib.et_replay_utils import (
+from et_replay.lib.comm import commsTraceReplay
+from et_replay.lib.et_replay_utils import (
build_fbgemm_func,
build_torchscript_func,
generate_fbgemm_tensors,
@@ -482,8 +482,8 @@ def generate_prefix(label, skip_nodes, et_input, cuda, compute_only, tf32, rows)
is_qualified,
)
-from param_bench.et_replay.lib.execution_trace import ExecutionTrace
-from param_bench.et_replay.lib.utils import trace_handler
+from et_replay.lib.execution_trace import ExecutionTrace
+from et_replay.lib.utils import trace_handler
print("PyTorch version: ", torch.__version__)
2 changes: 1 addition & 1 deletion et_replay/lib/utils.py
@@ -6,7 +6,7 @@
import uuid
from typing import Any, Dict

-from param_bench.et_replay.lib.execution_trace import ExecutionTrace
+from et_replay.lib.execution_trace import ExecutionTrace


def get_tmp_trace_filename() -> str:
2 changes: 1 addition & 1 deletion et_replay/pyproject.toml
@@ -7,7 +7,7 @@ name = "et_replay"
version = "0.5.0"

[tool.setuptools.package-dir]
"et_replay.comm" = "lib/comm"
"et_replay.lib.comm" = "lib/comm"
"et_replay.tools" = "tools"

[project.scripts]
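Note: the [tool.setuptools.package-dir] change above is what the rewritten imports elsewhere in this commit rely on. Mapping "et_replay.lib.comm" = "lib/comm" tells setuptools to expose the modules under lib/comm as the et_replay.lib.comm package, so only the import prefix changes, not the on-disk layout. A minimal sketch of the resulting import paths (module and symbol names taken from the files in this diff):

    # With the new package-dir mapping installed, modules under lib/comm
    # resolve through the et_replay.lib.comm package:
    from et_replay.lib.comm import comms_utils            # lib/comm/comms_utils.py
    from et_replay.lib.comm.comms_utils import commsArgs  # symbol imported in commsTraceParser.py above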
4 changes: 1 addition & 3 deletions et_replay/tools/comm_replay.py
@@ -1383,9 +1383,7 @@ def initBackend(
"""
# init backend and corresponding function pointers
if commsParams.nw_stack == "pytorch-dist":
-from et_replay.comm.pytorch_dist_backend import (
-PyTorchDistBackend,
-)
+from et_replay.comm.pytorch_dist_backend import PyTorchDistBackend

self.backendFuncs = PyTorchDistBackend(bootstrap_info, commsParams)
elif commsParams.nw_stack == "pytorch-xla-tpu":
153 changes: 80 additions & 73 deletions et_replay/tools/et_replay.py
@@ -12,6 +12,24 @@

import numpy as np
import torch

+from et_replay.lib.comm import comms_utils
+
+# from et_replay.lib.comm import commsTraceReplay XXX FIXME
+
+from et_replay.lib.execution_trace import ExecutionTrace, NodeType
+
+from et_replay.lib.utils import trace_handler
+
+from param_bench.train.compute.python.lib import pytorch as lib_pytorch
+from param_bench.train.compute.python.lib.init_helper import load_modules
+from param_bench.train.compute.python.workloads import pytorch as workloads_pytorch
+from torch._inductor.codecache import AsyncCompile, TritonFuture
+
+# grid and split_scan_grid are dynamically loaded
+from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid
+from torch.profiler import ExecutionTraceObserver
+
from ..lib.et_replay_utils import (
build_fbgemm_func,
build_torchscript_func,
@@ -36,20 +54,6 @@
TORCH_DTYPES_RNG_str,
)

-from et_replay.lib.execution_trace import ExecutionTrace, NodeType
-
-from param_bench.et_replay.lib.utils import trace_handler
-from et_replay.comm import comms_utils, commsTraceReplay
-
-from param_bench.train.compute.python.lib import pytorch as lib_pytorch
-from param_bench.train.compute.python.lib.init_helper import load_modules
-from param_bench.train.compute.python.workloads import pytorch as workloads_pytorch
-from torch._inductor.codecache import AsyncCompile, TritonFuture
-
-# grid and split_scan_grid are dynamically loaded
-from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid
-from torch.profiler import ExecutionTraceObserver
-

class ExgrReplayManager:
def __init__(self):
@@ -1092,47 +1096,49 @@ def get_inputs(self, node):

def run_op(self, node, iter):
if node.name == "record_param_comms" and not self.compute_only:
-opTensor = self.commsBench.replaySingle(
-self.commsParams, node.id, self.regenerate_tensors
-)
-# Wait, barrier has no output tensor.
-if "wait" in node.inputs or "barrier" in node.inputs:
-if self.wait_delay != 0:
-time.sleep(self.wait_delay / 1000.0)
-return
-if self.args.separate:
-return
-
-# # Total dimension of the output tensor should be the same as
-# # the original in et, reshape if different.
-# if type(opTensor) is list:
-# for t in opTensor:
-# print(t)
-
-original_shape = reduce(lambda x, y: x * y, node.output_shapes[0])
-op_tensor_shape = reduce(lambda x, y: x * y, list(opTensor.size()))
-if original_shape != op_tensor_shape:
-print(
-"Comms ops output tensor shape mismatch: ",
-node.id,
-original_shape,
-op_tensor_shape,
-)
-exit(1)
-op_tensor = torch.reshape(opTensor, tuple(node.output_shapes[0]))
-t_id = tuple(node.outputs[0])
-if self.tensor_with_device:
-t_id = tuple(list(t_id)[:5])
-if (
-t_id in self.dependency_permanent
-and self.tensors_mapping[(node.id, t_id, False)]
-not in self.unchangeable_intermediate_tensors
-):
-if self.tensors_mapping[(node.id, t_id, False)] not in self.instantiate:
-self.tensor_registry[
-self.tensors_mapping[(node.id, t_id, False)]
-] = op_tensor
-return
+# XXX FIXME add commsTraceReplay
+# opTensor = self.commsBench.replaySingle(
+# self.commsParams, node.id, self.regenerate_tensors
+# )
+# Wait, barrier has no output tensor.
+# if "wait" in node.inputs or "barrier" in node.inputs:
+# if self.wait_delay != 0:
+# time.sleep(self.wait_delay / 1000.0)
+# return
+# if self.args.separate:
+# return
+
+# # # Total dimension of the output tensor should be the same as
+# # # the original in et, reshape if different.
+# # if type(opTensor) is list:
+# # for t in opTensor:
+# # print(t)
+
+# original_shape = reduce(lambda x, y: x * y, node.output_shapes[0])
+# op_tensor_shape = reduce(lambda x, y: x * y, list(opTensor.size()))
+# if original_shape != op_tensor_shape:
+# print(
+# "Comms ops output tensor shape mismatch: ",
+# node.id,
+# original_shape,
+# op_tensor_shape,
+# )
+# exit(1)
+# op_tensor = torch.reshape(opTensor, tuple(node.output_shapes[0]))
+# t_id = tuple(node.outputs[0])
+# if self.tensor_with_device:
+# t_id = tuple(list(t_id)[:5])
+# if (
+# t_id in self.dependency_permanent
+# and self.tensors_mapping[(node.id, t_id, False)]
+# not in self.unchangeable_intermediate_tensors
+# ):
+# if self.tensors_mapping[(node.id, t_id, False)] not in self.instantiate:
+# self.tensor_registry[
+# self.tensors_mapping[(node.id, t_id, False)]
+# ] = op_tensor
+# return

if self.debug and iter >= self.numWarmupIters:
start_ns = time.time_ns()
@@ -1254,30 +1260,31 @@ def init_comms(self):
comms_env_params = comms_utils.read_comms_env_vars()
print(comms_env_params, self.cuda)

-self.commsBench = commsTraceReplay.commsTraceReplayBench()
-self.commsBench.trace_file = self.trace_file
-if "://" in self.trace_file:
-self.commsBench.use_remote_trace = True
+# # XXX FIXME
+# self.commsBench = commsTraceReplay.commsTraceReplayBench()
+# self.commsBench.trace_file = self.trace_file
+# if "://" in self.trace_file:
+# self.commsBench.use_remote_trace = True

-parser = argparse.ArgumentParser(description="Execution Trace Comms Replay")
-comms_args = self.commsBench.readArgs(parser)
+# parser = argparse.ArgumentParser(description="Execution Trace Comms Replay")
+# comms_args = self.commsBench.readArgs(parser)

-self.commsBench.checkArgs(comms_args)
+# self.commsBench.checkArgs(comms_args)

-time.sleep(1)
-self.bootstrap_info = comms_utils.bootstrap_info_holder(
-comms_args.master_ip,
-comms_args.master_port,
-comms_args.num_tpu_cores,
-comms_env_params,
-)
-self.commsParams = comms_utils.commsParamsHolderBase(comms_args)
+# time.sleep(1)
+# self.bootstrap_info = comms_utils.bootstrap_info_holder(
+# comms_args.master_ip,
+# comms_args.master_port,
+# comms_args.num_tpu_cores,
+# comms_env_params,
+# )
+# self.commsParams = comms_utils.commsParamsHolderBase(comms_args)

-self.commsBench.trace_type = "et"
+# self.commsBench.trace_type = "et"

-self.commsBench.initBackend(self.bootstrap_info, self.commsParams)
-self.commsBench.initBench(self.commsParams, comms_args)
-self.commsBench.replayInit(self.commsParams)
+# self.commsBench.initBackend(self.bootstrap_info, self.commsParams)
+# self.commsBench.initBench(self.commsParams, comms_args)
+# self.commsBench.replayInit(self.commsParams)

def analyze_ops(self):
fused_cnt = 0
