Skip to content

Commit

Permalink
changes to include the distributed operations in the aten_ops lib
Browse files Browse the repository at this point in the history
  • Loading branch information
apbose committed Dec 23, 2024
1 parent fc74cff commit d1576da
Show file tree
Hide file tree
Showing 9 changed files with 312 additions and 204 deletions.
2 changes: 1 addition & 1 deletion examples/distributed_inference/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ torchrun --nproc_per_node=2 tensor_parallel_llama2.py

pip install tensorrt-llm

For other python versions, you need to load the libnvinfer_plugin_tensorrt_llm.so. Please set that in the environment variable export trtllm_env={lib_path}. For example, we have already set the variable in initialize_distributed_env(). Note that won't work while running example, since it needs to be preset for the converter library to get.
For other python versions, you need to load the libnvinfer_plugin_tensorrt_llm.so. Please set that in the environment variable export trtllm_env={lib_path}. For example, we have already set the variable in initialize_distributed_env(). You can replace this with your trtllm_env and unset it there

#then pip install the tensorrt and torch version compatible with installed torchTRT

Expand Down
67 changes: 67 additions & 0 deletions examples/distributed_inference/tensor_parallel_initialize_dist.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import logging
import os
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

import numpy as np
import tensorrt as trt
import torch
import torch.distributed as dist
from torch.distributed._tensor.device_mesh import init_device_mesh


def find_repo_root(max_depth=10):
dir_path = os.path.dirname(os.path.realpath(__file__))
for i in range(max_depth):
files = os.listdir(dir_path)
if "MODULE.bazel" in files:
return dir_path
else:
dir_path = os.path.dirname(dir_path)

raise RuntimeError("Could not find repo root")


def initialize_logger(rank, logger_file_name):
logger = logging.getLogger()
logger.setLevel(logging.INFO)
fh = logging.FileHandler(logger_file_name + f"_{rank}.log", mode="w")
fh.setLevel(logging.INFO)
logger.addHandler(fh)
return logger


# This is required for env initialization since we use mpirun
def initialize_distributed_env(logger_file_name, rank=0, world_size=1, port=29500):
local_rank = int(
os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK", rank % torch.cuda.device_count())
)
world_size = int(os.environ.get("OMPI_COMM_WORLD_SIZE", world_size))

# Set up environment variable to run with mpirun
os.environ["RANK"] = str(local_rank)
os.environ["WORLD_SIZE"] = str(world_size)
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = str(port)
os.environ["trtllm_env"] = (
find_repo_root() + "/lib/libnvinfer_plugin_tensorrt_llm.so"
)

# Necessary to assign a device to each rank.
torch.cuda.set_device(local_rank)

# We use nccl backend
dist.init_process_group("nccl")

# set a manual seed for reproducibility
torch.manual_seed(1111)

device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(world_size,))
rank = device_mesh.get_rank()
assert rank == local_rank
logger = initialize_logger(rank, logger_file_name)
device_id = (
rank % torch.cuda.device_count()
) # Ensure each rank gets a unique device
torch.cuda.set_device(device_id)

return device_mesh, world_size, rank, logger
9 changes: 6 additions & 3 deletions examples/distributed_inference/tensor_parallel_llama3.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,20 @@
import time

import torch
import torch_tensorrt
from llama3_model import ModelArgs, ParallelTransformer
from tensor_parallel_nccl_ops import register_nccl_ops
from tensor_parallel_initialize_dist import initialize_distributed_env
from torch.distributed._composable.fsdp import MixedPrecisionPolicy
from torch.distributed._composable.fsdp.fully_shard import fully_shard
from torch.distributed._tensor import Replicate, Shard
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
checkpoint_wrapper,
)

device_mesh, _world_size, _rank, logger = register_nccl_ops("./tensor_parallel_llama3")
device_mesh, _world_size, _rank, logger = initialize_distributed_env(
"./tensor_parallel_llama3"
)
# Import should be after initialization of the TRT-LLM plugin .so path
import tensorrt_llm

logger.info(f"Starting PyTorch TP example on rank {_rank}.")
assert (
Expand Down
197 changes: 0 additions & 197 deletions examples/distributed_inference/tensor_parallel_nccl_ops.py

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,21 +1,21 @@
import time

import tensorrt as trt
import tensorrt_llm
import torch
import torch.nn as nn
import torch_tensorrt
from tensor_parallel_nccl_ops import register_nccl_ops
from tensor_parallel_initialize_dist import initialize_distributed_env
from torch.distributed._tensor import Shard
from torch.distributed.tensor.parallel import (
ColwiseParallel,
RowwiseParallel,
parallelize_module,
)

device_mesh, _world_size, _rank, logger = register_nccl_ops(
device_mesh, _world_size, _rank, logger = initialize_distributed_env(
"./tensor_parallel_simple_example"
)
import tensorrt_llm

"""
This example copies some code from https://github.com/pytorch/examples/blob/main/distributed/tensor_parallelism/tensor_parallel_example.py
Expand Down
44 changes: 44 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union

import numpy as np
import tensorrt as trt
import torch
from torch.fx.node import Argument, Node, Target
from torch_tensorrt.dynamo._settings import CompilationSettings
Expand All @@ -20,6 +21,11 @@
enforce_tensor_types,
get_positive_dim,
is_only_operator_on_placeholder,
load_tensorrt_llm,
)
from torch_tensorrt.dynamo.lowering.passes.fuse_distributed_ops import (
tensorrt_fused_nccl_all_gather_op,
tensorrt_fused_nccl_reduce_scatter_op,
)
from torch_tensorrt.dynamo.types import TRTTensor

Expand Down Expand Up @@ -3585,3 +3591,41 @@ def aten_ops_full(
fill_value=args[1],
dtype=kwargs.get("dtype", None),
)


if load_tensorrt_llm():

@dynamo_tensorrt_converter(tensorrt_fused_nccl_all_gather_op)
def insert_nccl_gather_op(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.distributed.gather_op(
ctx,
target,
SourceIR.ATEN,
name,
[args[0]],
)

@dynamo_tensorrt_converter(tensorrt_fused_nccl_reduce_scatter_op)
def insert_nccl_reduce_scatter_plugin(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.distributed.reduce_scatter_op(
ctx,
target,
SourceIR.ATEN,
name,
[args[0]],
)

else:
_LOGGER.warning("Unable to load the TRT-LLM plugins")
Loading

0 comments on commit d1576da

Please sign in to comment.