From 6ffc284b3ee52fc778ed95464e3e5bfaa3179813 Mon Sep 17 00:00:00 2001 From: apbose Date: Tue, 17 Dec 2024 13:57:16 -0800 Subject: [PATCH] changes to include the distributed operations in the aten_ops lib --- .../tensor_parallel_initialize_dist.py | 69 ++++++ .../tensor_parallel_llama3.py | 6 +- .../tensor_parallel_nccl_ops.py | 197 ------------------ .../tensor_parallel_simple_example.py | 4 +- .../dynamo/conversion/aten_ops_converters.py | 80 +++++++ .../dynamo/conversion/impl/__init__.py | 1 + .../dynamo/conversion/impl/nccl_ops.py | 120 +++++++++++ 7 files changed, 276 insertions(+), 201 deletions(-) create mode 100644 examples/distributed_inference/tensor_parallel_initialize_dist.py delete mode 100644 examples/distributed_inference/tensor_parallel_nccl_ops.py create mode 100644 py/torch_tensorrt/dynamo/conversion/impl/nccl_ops.py diff --git a/examples/distributed_inference/tensor_parallel_initialize_dist.py b/examples/distributed_inference/tensor_parallel_initialize_dist.py new file mode 100644 index 0000000000..a85f570e09 --- /dev/null +++ b/examples/distributed_inference/tensor_parallel_initialize_dist.py @@ -0,0 +1,69 @@ +import logging +import os +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union + +import numpy as np +import tensorrt as trt +import torch +import torch.distributed as dist +from torch.distributed._tensor.device_mesh import init_device_mesh + + +def find_repo_root(max_depth=10): + dir_path = os.path.dirname(os.path.realpath(__file__)) + for i in range(max_depth): + files = os.listdir(dir_path) + if "MODULE.bazel" in files: + return dir_path + else: + dir_path = os.path.dirname(dir_path) + + raise RuntimeError("Could not find repo root") + + +def initialize_logger(rank, logger_file_name): + logger = logging.getLogger() + logger.setLevel(logging.INFO) + fh = logging.FileHandler(logger_file_name + f"_{rank}.log", mode="w") + fh.setLevel(logging.INFO) + logger.addHandler(fh) + return logger + + +# This is required for env initialization since we use mpirun +def initialize_distributed_env(logger_file_name, rank=0, world_size=1, port=29500): + local_rank = int( + os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK", rank % torch.cuda.device_count()) + ) + world_size = int(os.environ.get("OMPI_COMM_WORLD_SIZE", world_size)) + + # Set up environment variable to run with mpirun + os.environ["RANK"] = str(local_rank) + os.environ["WORLD_SIZE"] = str(world_size) + os.environ["MASTER_ADDR"] = "127.0.0.1" + os.environ["MASTER_PORT"] = str(port) + # Note this will not work in the initialization here + # You would need to set it externally as a user + os.environ["trtllm_env"] = ( + find_repo_root() + "/lib/libnvinfer_plugin_tensorrt_llm.so" + ) + + # Necessary to assign a device to each rank. + torch.cuda.set_device(local_rank) + + # We use nccl backend + dist.init_process_group("nccl") + + # set a manual seed for reproducibility + torch.manual_seed(1111) + + device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(world_size,)) + rank = device_mesh.get_rank() + assert rank == local_rank + logger = initialize_logger(rank, logger_file_name) + device_id = ( + rank % torch.cuda.device_count() + ) # Ensure each rank gets a unique device + torch.cuda.set_device(device_id) + + return device_mesh, world_size, rank, logger diff --git a/examples/distributed_inference/tensor_parallel_llama3.py b/examples/distributed_inference/tensor_parallel_llama3.py index 0e4dd8b6bc..a0533b1488 100644 --- a/examples/distributed_inference/tensor_parallel_llama3.py +++ b/examples/distributed_inference/tensor_parallel_llama3.py @@ -7,7 +7,7 @@ import torch import torch_tensorrt from llama3_model import ModelArgs, ParallelTransformer -from tensor_parallel_nccl_ops import register_nccl_ops +from tensor_parallel_initialize_dist import initialize_distributed_env from torch.distributed._composable.fsdp import MixedPrecisionPolicy from torch.distributed._composable.fsdp.fully_shard import fully_shard from torch.distributed._tensor import Replicate, Shard @@ -15,7 +15,9 @@ checkpoint_wrapper, ) -device_mesh, _world_size, _rank, logger = register_nccl_ops("./tensor_parallel_llama3") +device_mesh, _world_size, _rank, logger = initialize_distributed_env( + "./tensor_parallel_llama3" +) logger.info(f"Starting PyTorch TP example on rank {_rank}.") assert ( diff --git a/examples/distributed_inference/tensor_parallel_nccl_ops.py b/examples/distributed_inference/tensor_parallel_nccl_ops.py deleted file mode 100644 index 26e03b70db..0000000000 --- a/examples/distributed_inference/tensor_parallel_nccl_ops.py +++ /dev/null @@ -1,197 +0,0 @@ -import ctypes -import logging -import os -import site -from enum import IntEnum, IntFlag, auto -from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union - -import numpy as np -import tensorrt as trt -import tensorrt_llm -import torch -import torch.distributed as dist -import torch_tensorrt -from torch.distributed._tensor.device_mesh import init_device_mesh -from torch.fx import GraphModule, Node -from torch.fx.node import Argument, Target -from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext -from torch_tensorrt.dynamo.conversion._ConverterRegistry import ( - dynamo_tensorrt_converter, -) -from torch_tensorrt.dynamo.lowering.passes.fuse_distributed_ops import ( - tensorrt_fused_nccl_all_gather_op, - tensorrt_fused_nccl_reduce_scatter_op, -) -from torch_tensorrt.dynamo.types import TRTTensor -from torch_tensorrt.fx.converters.converter_utils import set_layer_name - - -# class for AllReduce -class AllReduceStrategy(IntEnum): - """Warning: actual definition is in kernels/customAllReduceKernels.h. - - They must be kept in sync. - """ - - NCCL = 0 - ONESHOT = 1 - TWOSHOT = 2 - AUTO = 3 - - -class AllReduceConfig(IntFlag): - """Warning: actual definition is in kernels/customAllReduceKernels.h. - - They must be kept in sync - """ - - USE_MEMCPY = auto() - PUSH_MODE = auto() - - -def initialize_logger(rank, logger_file_name): - logger = logging.getLogger() - logger.setLevel(logging.INFO) - fh = logging.FileHandler(logger_file_name + f"_{rank}.log", mode="w") - fh.setLevel(logging.INFO) - logger.addHandler(fh) - return logger - - -# This is required for env initialization since we use mpirun -def initialize_distributed_env(rank=0, world_size=1, port=29500): - local_rank = int( - os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK", rank % torch.cuda.device_count()) - ) - world_size = int(os.environ.get("OMPI_COMM_WORLD_SIZE", world_size)) - - # Set up environment variable to run with mpirun - os.environ["RANK"] = str(local_rank) - os.environ["WORLD_SIZE"] = str(world_size) - os.environ["MASTER_ADDR"] = "127.0.0.1" - os.environ["MASTER_PORT"] = str(port) - - # Necessary to assign a device to each rank. - torch.cuda.set_device(local_rank) - - # We use nccl backend - dist.init_process_group("nccl") - - # set a manual seed for reproducibility - torch.manual_seed(1111) - - return local_rank, world_size - - -def register_nccl_ops(logger_file_name): - # Initialization - initialize_distributed_env() - # create a device mesh based on the given world_size. - _world_size = int(os.environ["WORLD_SIZE"]) - - device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(_world_size,)) - _rank = device_mesh.get_rank() - logger = initialize_logger(_rank, logger_file_name) - device_id = ( - _rank % torch.cuda.device_count() - ) # Ensure each rank gets a unique device - torch.cuda.set_device(device_id) - - # TensorRT NCCL plugins - # Iterate over all registered plugin creators - plugin_registry = trt.get_plugin_registry() - for plugin_creator in plugin_registry.plugin_creator_list: - logger.info( - f"Plugin Name: {plugin_creator.name}, Namespace: {plugin_creator.plugin_namespace}, Version: {plugin_creator.plugin_version}" - ) - - @dynamo_tensorrt_converter(tensorrt_fused_nccl_all_gather_op) - def insert_nccl_gather_op( - ctx: ConversionContext, - target: Target, - args: Tuple[Argument, ...], - kwargs: Dict[str, Argument], - name: str, - ) -> Union[TRTTensor, Sequence[TRTTensor]]: - plug_inputs = [args[0]] - allgather_plg_creator = trt.get_plugin_registry().get_plugin_creator( - "AllGather", "1", "tensorrt_llm" - ) - assert allgather_plg_creator is not None - _world_size = os.environ.get("WORLD_SIZE") - if _world_size is not None: - _world_size = int(_world_size) - else: - raise RuntimeError( - f"The WORLD_SIZE env variable is not set in distributed environment" - ) - group = list(range(_world_size)) - group = trt.PluginField( - "group", np.array(group, dtype=np.int32), trt.PluginFieldType.INT32 - ) - p_dtype = trt.float32 - pf_type = trt.PluginField( - "type_id", np.array([int(p_dtype)], np.int32), trt.PluginFieldType.INT32 - ) - pfc = trt.PluginFieldCollection([group, pf_type]) - allgather = allgather_plg_creator.create_plugin("allgather", pfc) - layer = ctx.net.add_plugin_v2(plug_inputs, allgather) - set_layer_name(layer, target, name) - return layer.get_output(0) - - @dynamo_tensorrt_converter(tensorrt_fused_nccl_reduce_scatter_op) - def insert_nccl_reduce_scatter_plugin( - ctx: ConversionContext, - target: Target, - args: Tuple[Argument, ...], - kwargs: Dict[str, Argument], - name: str, - ) -> Union[TRTTensor, Sequence[TRTTensor]]: - plug_inputs = [args[0]] - allreduce_plg_creator = trt.get_plugin_registry().get_plugin_creator( - "ReduceScatter", "1", "tensorrt_llm" - ) - - assert allreduce_plg_creator is not None - - counter = 0 - strategy = AllReduceStrategy.NCCL - config = AllReduceConfig(0) - _world_size = os.environ.get("WORLD_SIZE") - if _world_size is not None: - _world_size = int(_world_size) - else: - raise RuntimeError( - f"The WORLD_SIZE env variable is not set in distributed environment" - ) - group = list(range(_world_size)) - group = trt.PluginField( - "group", np.array(group, dtype=np.int32), trt.PluginFieldType.INT32 - ) - - p_dtype = trt.float16 - pf_dtype = trt.PluginField( - "type_id", np.array([int(p_dtype)], np.int32), trt.PluginFieldType.INT32 - ) - pfc = [group, pf_dtype] - p_strategy = trt.PluginField( - "strategy", np.array([int(strategy)], np.int8), trt.PluginFieldType.INT8 - ) - pfc.append(p_strategy) - p_config = trt.PluginField( - "config", np.array([int(config)], np.int8), trt.PluginFieldType.INT8 - ) - pfc.append(p_config) - p_counter = trt.PluginField( - "counter", np.array([counter], np.int32), trt.PluginFieldType.INT32 - ) - pfc.append(p_counter) - - pfc = trt.PluginFieldCollection(pfc) - ar_plug = allreduce_plg_creator.create_plugin("allreduce", pfc) - - layer = ctx.net.add_plugin_v2(plug_inputs, ar_plug) - set_layer_name(layer, target, name) - return layer.get_output(0) - - return device_mesh, _world_size, _rank, logger diff --git a/examples/distributed_inference/tensor_parallel_simple_example.py b/examples/distributed_inference/tensor_parallel_simple_example.py index c8b93ac6ca..e397df8e5a 100755 --- a/examples/distributed_inference/tensor_parallel_simple_example.py +++ b/examples/distributed_inference/tensor_parallel_simple_example.py @@ -5,7 +5,7 @@ import torch import torch.nn as nn import torch_tensorrt -from tensor_parallel_nccl_ops import register_nccl_ops +from tensor_parallel_initialize_dist import initialize_distributed_env from torch.distributed._tensor import Shard from torch.distributed.tensor.parallel import ( ColwiseParallel, @@ -13,7 +13,7 @@ parallelize_module, ) -device_mesh, _world_size, _rank, logger = register_nccl_ops( +device_mesh, _world_size, _rank, logger = initialize_distributed_env( "./tensor_parallel_simple_example" ) diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index ee7fb219aa..fb22cb7411 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -1,10 +1,13 @@ # mypy: disallow-untyped-decorators=False +import ctypes import logging import operator +import os from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union import numpy as np +import tensorrt as trt import torch from torch.fx.node import Argument, Node, Target from torch_tensorrt.dynamo._settings import CompilationSettings @@ -21,6 +24,10 @@ get_positive_dim, is_only_operator_on_placeholder, ) +from torch_tensorrt.dynamo.lowering.passes.fuse_distributed_ops import ( + tensorrt_fused_nccl_all_gather_op, + tensorrt_fused_nccl_reduce_scatter_op, +) from torch_tensorrt.dynamo.types import TRTTensor _LOGGER: logging.Logger = logging.getLogger(__name__) @@ -3585,3 +3592,76 @@ def aten_ops_full( fill_value=args[1], dtype=kwargs.get("dtype", None), ) + + +try: + import tensorrt_llm as trt_llm +except (ImportError, AssertionError) as e_import_error: + _LOGGER.warning( + "TensorRT_LLM is not installed. Please install TensorRT_LLM or set trtllm_env", + e_import_error, + ) + # note this is for Linux only + plugin_lib_path = os.environ.get("trtllm_env") + if plugin_lib_path is None: + _LOGGER.warning( + "Please specify a valid path for trtllm_env libnvinfer_plugin_tensorrt_llm.so when using distributed examples in examples/distributed_inference" + ) + else: + _LOGGER.info(f"Plugin lib path found: {plugin_lib_path}") + try: + handle = ctypes.CDLL(plugin_lib_path) + _LOGGER.info(f"Successfully loaded plugin library: {plugin_lib_path}") + except OSError as e_os_error: + _LOGGER.error( + f"Failed to load the shared library at {plugin_lib_path}. " + f"Ensure the path is correct and the library is compatible.", + e_os_error, + ) + try: + handle.initTrtLlmPlugins.argtypes = [ctypes.c_void_p, ctypes.c_char_p] + handle.initTrtLlmPlugins.restype = ctypes.c_bool + except AttributeError as e_plugin_unavailable: + _LOGGER.warning("TensorRT-LLM Plugin is unavailable") + try: + TRT_LLM_PLUGIN_NAMESPACE = "tensorrt_llm" + assert handle.initTrtLlmPlugins( + None, TRT_LLM_PLUGIN_NAMESPACE.encode("utf-8") + ) + except Exception as e_initialization_error: + _LOGGER.warning( + "Exception happened in initializing TensorRT-LLM plugins", e + ) + else: + + @dynamo_tensorrt_converter(tensorrt_fused_nccl_all_gather_op) + def insert_nccl_gather_op( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, + ) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.nccl_ops.gather_op( + ctx, + target, + SourceIR.ATEN, + name, + [args[0]], + ) + + @dynamo_tensorrt_converter(tensorrt_fused_nccl_reduce_scatter_op) + def insert_nccl_reduce_scatter_plugin( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, + ) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.nccl_ops.reduce_scatter_op( + ctx, + target, + SourceIR.ATEN, + name, + [args[0]], + ) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/__init__.py b/py/torch_tensorrt/dynamo/conversion/impl/__init__.py index c1187f0dd9..75f7492591 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/__init__.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/__init__.py @@ -13,6 +13,7 @@ full, grid, matmul, + nccl_ops, normalization, pad, permutation, diff --git a/py/torch_tensorrt/dynamo/conversion/impl/nccl_ops.py b/py/torch_tensorrt/dynamo/conversion/impl/nccl_ops.py new file mode 100644 index 0000000000..64b8cbfda5 --- /dev/null +++ b/py/torch_tensorrt/dynamo/conversion/impl/nccl_ops.py @@ -0,0 +1,120 @@ +import os +from enum import IntEnum, IntFlag, auto +from typing import Optional, Sequence, Union + +import numpy as np +import tensorrt as trt +import torch +from torch.fx.node import Target +from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext +from torch_tensorrt.fx.converters.converter_utils import SourceIR, set_layer_name + + +# class for AllReduce +class AllReduceStrategy(IntEnum): + """Warning: actual definition is in kernels/customAllReduceKernels.h. + + They must be kept in sync. + """ + + NCCL = 0 + ONESHOT = 1 + TWOSHOT = 2 + AUTO = 3 + + +class AllReduceConfig(IntFlag): + """Warning: actual definition is in kernels/customAllReduceKernels.h. + + They must be kept in sync + """ + + USE_MEMCPY = auto() + PUSH_MODE = auto() + + +def gather_op( + ctx: ConversionContext, + target: Union[Target, str], + source_ir: Optional[SourceIR], + name: str, + plug_inputs, +): + allgather_plg_creator = trt.get_plugin_registry().get_plugin_creator( + "AllGather", "1", "tensorrt_llm" + ) + assert allgather_plg_creator is not None + _world_size = os.environ.get("WORLD_SIZE") + if _world_size is not None: + _world_size = int(_world_size) + else: + raise RuntimeError( + f"The WORLD_SIZE env variable is not set in distributed environment" + ) + group = list(range(_world_size)) + group = trt.PluginField( + "group", np.array(group, dtype=np.int32), trt.PluginFieldType.INT32 + ) + p_dtype = trt.float32 + pf_type = trt.PluginField( + "type_id", np.array([int(p_dtype)], np.int32), trt.PluginFieldType.INT32 + ) + pfc = trt.PluginFieldCollection([group, pf_type]) + allgather = allgather_plg_creator.create_plugin("allgather", pfc) + layer = ctx.net.add_plugin_v2(plug_inputs, allgather) + set_layer_name(layer, target, name, source_ir) + return layer.get_output(0) + + +def reduce_scatter_op( + ctx: ConversionContext, + target: Union[Target, str], + source_ir: Optional[SourceIR], + name: str, + plug_inputs, +): + allreduce_plg_creator = trt.get_plugin_registry().get_plugin_creator( + "ReduceScatter", "1", "tensorrt_llm" + ) + + assert allreduce_plg_creator is not None + + counter = 0 + strategy = AllReduceStrategy.NCCL + config = AllReduceConfig(0) + _world_size = os.environ.get("WORLD_SIZE") + if _world_size is not None: + _world_size = int(_world_size) + else: + raise RuntimeError( + f"The WORLD_SIZE env variable is not set in distributed environment" + ) + group = list(range(_world_size)) + group = trt.PluginField( + "group", np.array(group, dtype=np.int32), trt.PluginFieldType.INT32 + ) + + p_dtype = trt.float16 + pf_dtype = trt.PluginField( + "type_id", np.array([int(p_dtype)], np.int32), trt.PluginFieldType.INT32 + ) + pfc = [group, pf_dtype] + p_strategy = trt.PluginField( + "strategy", np.array([int(strategy)], np.int8), trt.PluginFieldType.INT8 + ) + pfc.append(p_strategy) + p_config = trt.PluginField( + "config", np.array([int(config)], np.int8), trt.PluginFieldType.INT8 + ) + pfc.append(p_config) + p_counter = trt.PluginField( + "counter", np.array([counter], np.int32), trt.PluginFieldType.INT32 + ) + pfc.append(p_counter) + + pfc = trt.PluginFieldCollection(pfc) + ar_plug = allreduce_plg_creator.create_plugin("allreduce", pfc) + + layer = ctx.net.add_plugin_v2(plug_inputs, ar_plug) + set_layer_name(layer, target, name, source_ir) + return layer.get_output(0)