From 343b71dd5d09aa21a1a3b7f76ba172f4746eac83 Mon Sep 17 00:00:00 2001 From: Qidong <soodoshll@gmail.com> Date: Mon, 19 Jun 2023 12:08:51 -0400 Subject: [PATCH 01/40] init --- python/hidet/graph/ops/distributed.py | 44 +++++++++++++++++++++++++++ 1 file changed, 44 insertions(+) create mode 100644 python/hidet/graph/ops/distributed.py diff --git a/python/hidet/graph/ops/distributed.py b/python/hidet/graph/ops/distributed.py new file mode 100644 index 000000000..9e7182d2e --- /dev/null +++ b/python/hidet/graph/ops/distributed.py @@ -0,0 +1,44 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import List, Union, Tuple + +from hidet.ir.type import DataType +from hidet.ir.expr import Expr +from hidet.ir.module import IRModule +from hidet.ir.task import Target +from hidet.utils import prod +from hidet.runtime.device import Device, instantiate_device +from .utils import Task, TensorNode, Operator, Tensor, compute, input_like + +from hidet.cuda.nccl import NcclRedOp + +def all_reduce(comm_id: int, x: Tensor, op: NcclRedOp) -> Tensor: + raise NotImplementedError() + +def broadcast(comm_id: int, x: Tensor, root:int) -> Tensor: + raise NotImplementedError() + +def reduce(comm_id: int, x: Tensor, root:int) -> Tensor: + raise NotImplementedError() + +def all_gather(comm_id: int, x: Tensor) -> Tensor: + raise NotImplementedError() + +def reduce_scatter(comm_id: int, x: Tensor) -> Tensor: + raise NotImplementedError() + +def send(comm_id: int, x: Tensor, peer: int) -> None: + raise NotImplementedError() + +# Recv is a little bit tricky since we need to pass the metadata of the recv buffer +def recv(comm_id: int, peer: int) -> Tensor: + raise NotImplementedError() \ No newline at end of file From 29ef7f531a7b2552438fca521ec7c881354b2c17 Mon Sep 17 00:00:00 2001 From: Qidong <soodoshll@gmail.com> Date: Mon, 19 Jun 2023 16:22:32 -0400 Subject: [PATCH 02/40] op --- examples/distributed/test.py | 29 ++---------------- python/hidet/cuda/nccl/__init__.py | 2 +- python/hidet/drivers/build_module.py | 17 +++++++++++ python/hidet/graph/ops/__init__.py | 1 + python/hidet/graph/ops/distributed.py | 42 +++++++++++++++++++++++++-- python/hidet/ir/func.py | 18 +++++++++++- python/hidet/ir/module.py | 3 ++ 7 files changed, 80 insertions(+), 32 deletions(-) diff --git a/examples/distributed/test.py b/examples/distributed/test.py index bad43ef23..139176cf9 100644 --- a/examples/distributed/test.py +++ b/examples/distributed/test.py @@ -8,6 +8,7 @@ import numpy import argparse +import hidet import hidet.cuda.nccl from hidet.cuda import nccl from hidet.cuda.nccl import NcclUniqueId, NcclDataType, NcclRedOp, nccl_library_filename @@ -47,34 +48,8 @@ def run(world_size, rank, shared_id, barrier): # Initialize send and receive buffer device = f"cuda:{rank}" send = hidet.randn([2, 2], device=device) - recv = hidet.empty([2, 2], device=device) - - print(rank, send) - - dtype = data_type('float32') - shape = [2, 2] - nbytes = dtype.nbytes * prod(shape) - - # Define IRModule - with hidet.script_module() as script_module: - @hidet.script - def launch(send: dtype[shape], recv: dtype[shape]): - attrs.func_kind = 'public' - all_reduce(0, send, recv, nbytes, dtype, getattr(NcclRedOp, args.reduce_op)) - - # Build - ir_module = script_module.ir_module() - ir_module.target = 'cuda' - ir_module.include_dirs.extend(get_nccl_include_dirs()) - ir_module.linking_dirs.extend(get_nccl_library_search_dirs()) - ir_module.include_headers.append(["nccl.h"]) - ir_module.linking_libs.append(":" + nccl_library_filename()) - out_dir = f'./.cache/all_reduce_{rank}' - - build_ir_module(ir_module, out_dir, target='cuda') - compiled_module = load_compiled_module(out_dir) + recv = hidet.ops.all_reduce(0, send, NcclRedOp.sum) - compiled_module(send, recv) s = hidet.cuda.current_stream() s.synchronize() print(rank, recv) diff --git a/python/hidet/cuda/nccl/__init__.py b/python/hidet/cuda/nccl/__init__.py index 6ff476d71..2b3ebf96e 100644 --- a/python/hidet/cuda/nccl/__init__.py +++ b/python/hidet/cuda/nccl/__init__.py @@ -9,5 +9,5 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from .ffi import nccl_available, nccl_version, nccl_library_filename from .comm import create_comm, NcclUniqueId, NcclDataType, NcclRedOp, comms_to_array, init_unique_id, dtype_to_nccl -from .ffi import nccl_version, nccl_library_filename diff --git a/python/hidet/drivers/build_module.py b/python/hidet/drivers/build_module.py index 05671bec8..40ba891da 100644 --- a/python/hidet/drivers/build_module.py +++ b/python/hidet/drivers/build_module.py @@ -55,6 +55,23 @@ def build_ir_module(ir_module: IRModule, output_dir: str, *, target: str, output with PassContext(instruments=instruments): ir_module = lower(ir_module) + # nccl-related + print(ir_module.use_distributed()) + if ir_module.use_distributed(): + if target != 'cuda': + raise RuntimeError("IRModules using NCCL must be targeted for cuda") + from hidet.cuda.nccl.libinfo import get_nccl_include_dirs, get_nccl_library_search_dirs + from hidet.cuda.nccl import nccl_available, nccl_library_filename + + if not nccl_available(): + raise RuntimeError("NCCL is not available") + + ir_module.include_dirs.extend(get_nccl_include_dirs()) + ir_module.linking_dirs.extend(get_nccl_library_search_dirs()) + ir_module.include_headers.append(["nccl.h"]) + ir_module.linking_libs.append(":" + nccl_library_filename()) + + # code generation codegen(ir_module, src_out_path=src_path, target=target) diff --git a/python/hidet/graph/ops/__init__.py b/python/hidet/graph/ops/__init__.py index 901ede4d2..af3063245 100644 --- a/python/hidet/graph/ops/__init__.py +++ b/python/hidet/graph/ops/__init__.py @@ -44,5 +44,6 @@ from .fusion import fused_operator from .transfer import transfer from .special import barrier +from .distributed import all_reduce from . import utils diff --git a/python/hidet/graph/ops/distributed.py b/python/hidet/graph/ops/distributed.py index 9e7182d2e..dbf355e2b 100644 --- a/python/hidet/graph/ops/distributed.py +++ b/python/hidet/graph/ops/distributed.py @@ -21,19 +21,55 @@ from hidet.cuda.nccl import NcclRedOp +class AllReduceTask(Task): + def __init__(self, comm_id: int, x: TensorNode, op: NcclRedOp): + y = compute('out', x.shape, lambda *indices: x[indices]) + self.comm_id = comm_id + self.op = op + + super().__init__('all_reduce', inputs=[x], outputs=[y], attributes={}) + + def implement(self, target: Union[Target, str], working_dir: str) -> List[IRModule]: + # we may need current rank here to avoid duplicated working_dirs + import hidet + from hidet.ir.primitives.cuda.nccl import all_reduce + from hidet.lang import attrs + + dtype: DataType = self.inputs[0].type.dtype + shape: Tuple[Expr, ...] = self.inputs[0].shape + nbytes = dtype.nbytes * prod(shape) + + with hidet.script_module() as script_module: + @hidet.script + def launch(x: dtype[shape], y: dtype[shape]): + attrs.func_kind = 'public' + all_reduce(self.comm_id, x, y, nbytes, dtype, self.op) + + return [script_module.ir_module()] + +class AllReduceOp(Operator): + def __init__(self, comm_id: int, x: Tensor, op: NcclRedOp): + super().__init__( + inputs=[x], + attributes={'comm_id': comm_id}, + task=AllReduceTask(comm_id, input_like(x, 'x'), op) + ) + def all_reduce(comm_id: int, x: Tensor, op: NcclRedOp) -> Tensor: - raise NotImplementedError() + if x.device.kind != 'cuda': + raise RuntimeError("NCCL only supports CUDA tensors") + return AllReduceOp(comm_id, x, op).outputs[0] def broadcast(comm_id: int, x: Tensor, root:int) -> Tensor: raise NotImplementedError() -def reduce(comm_id: int, x: Tensor, root:int) -> Tensor: +def reduce(comm_id: int, x: Tensor, root:int, op: NcclRedOp) -> Tensor: raise NotImplementedError() def all_gather(comm_id: int, x: Tensor) -> Tensor: raise NotImplementedError() -def reduce_scatter(comm_id: int, x: Tensor) -> Tensor: +def reduce_scatter(comm_id: int, x: Tensor, op: NcclRedOp) -> Tensor: raise NotImplementedError() def send(comm_id: int, x: Tensor, peer: int) -> None: diff --git a/python/hidet/ir/func.py b/python/hidet/ir/func.py index 674879ec8..fb5cfd40a 100644 --- a/python/hidet/ir/func.py +++ b/python/hidet/ir/func.py @@ -14,7 +14,7 @@ from hidet.ir.node import Node from hidet.ir.type import BaseType from hidet.ir.expr import Var, Call -from hidet.ir.stmt import Stmt +from hidet.ir.stmt import Stmt, BlackBoxStmt def check_func_name(name: str): @@ -94,3 +94,19 @@ def get_attr(self, attr_name, default=None, allow_missing=False): return default else: raise KeyError('Attribute {} is not found in function {}'.format(attr_name, self.name)) + + def use_distributed(self) -> bool: + """ + Return true if this function involves any distributed primitives + """ + def _recursive_find(root: Stmt): + if isinstance(root, BlackBoxStmt): + if root.template_string.startswith('nccl'): + return True + for child in dir(root): + if isinstance(child, Stmt): + if _recursive_find(child): + return True + return False + ret = _recursive_find(self.body) + return ret \ No newline at end of file diff --git a/python/hidet/ir/module.py b/python/hidet/ir/module.py index 751c05a24..c40f32e43 100644 --- a/python/hidet/ir/module.py +++ b/python/hidet/ir/module.py @@ -115,3 +115,6 @@ def build(self): build_ir_module(self, output_dir, target=target) return load_compiled_module(output_dir) + + def use_distributed(self): + return any([func.use_distributed() for func in self.functions.values()]) \ No newline at end of file From 8e065609b40e25586dee37ea13e2b610cf4cdfd6 Mon Sep 17 00:00:00 2001 From: Qidong <soodoshll@gmail.com> Date: Mon, 19 Jun 2023 16:58:20 -0400 Subject: [PATCH 03/40] update --- examples/distributed/test.py | 8 ++++++-- python/hidet/drivers/build_module.py | 2 -- python/hidet/graph/ops/distributed.py | 26 ++++++++++++------------- python/hidet/ir/primitives/cuda/nccl.py | 2 +- 4 files changed, 20 insertions(+), 18 deletions(-) diff --git a/examples/distributed/test.py b/examples/distributed/test.py index 139176cf9..6ea97aa2f 100644 --- a/examples/distributed/test.py +++ b/examples/distributed/test.py @@ -47,8 +47,12 @@ def run(world_size, rank, shared_id, barrier): # Initialize send and receive buffer device = f"cuda:{rank}" - send = hidet.randn([2, 2], device=device) - recv = hidet.ops.all_reduce(0, send, NcclRedOp.sum) + send = hidet.randn([3, 2], device=device) + send_symb = hidet.symbol_like(send) + recv_symb = hidet.ops.all_reduce(send_symb, NcclRedOp.sum, 0) + graph = hidet.trace_from(recv_symb) + opt_graph = hidet.graph.optimize(graph) + recv = opt_graph(send) s = hidet.cuda.current_stream() s.synchronize() diff --git a/python/hidet/drivers/build_module.py b/python/hidet/drivers/build_module.py index 40ba891da..71a0d9110 100644 --- a/python/hidet/drivers/build_module.py +++ b/python/hidet/drivers/build_module.py @@ -56,7 +56,6 @@ def build_ir_module(ir_module: IRModule, output_dir: str, *, target: str, output ir_module = lower(ir_module) # nccl-related - print(ir_module.use_distributed()) if ir_module.use_distributed(): if target != 'cuda': raise RuntimeError("IRModules using NCCL must be targeted for cuda") @@ -71,7 +70,6 @@ def build_ir_module(ir_module: IRModule, output_dir: str, *, target: str, output ir_module.include_headers.append(["nccl.h"]) ir_module.linking_libs.append(":" + nccl_library_filename()) - # code generation codegen(ir_module, src_out_path=src_path, target=target) diff --git a/python/hidet/graph/ops/distributed.py b/python/hidet/graph/ops/distributed.py index dbf355e2b..401a7c246 100644 --- a/python/hidet/graph/ops/distributed.py +++ b/python/hidet/graph/ops/distributed.py @@ -22,7 +22,7 @@ from hidet.cuda.nccl import NcclRedOp class AllReduceTask(Task): - def __init__(self, comm_id: int, x: TensorNode, op: NcclRedOp): + def __init__(self, x: TensorNode, op: NcclRedOp, comm_id: int): y = compute('out', x.shape, lambda *indices: x[indices]) self.comm_id = comm_id self.op = op @@ -43,38 +43,38 @@ def implement(self, target: Union[Target, str], working_dir: str) -> List[IRModu @hidet.script def launch(x: dtype[shape], y: dtype[shape]): attrs.func_kind = 'public' - all_reduce(self.comm_id, x, y, nbytes, dtype, self.op) + all_reduce(x, y, nbytes, dtype, self.op, self.comm_id) return [script_module.ir_module()] class AllReduceOp(Operator): - def __init__(self, comm_id: int, x: Tensor, op: NcclRedOp): + def __init__(self, x: Tensor, op: NcclRedOp, comm_id: int): super().__init__( inputs=[x], - attributes={'comm_id': comm_id}, - task=AllReduceTask(comm_id, input_like(x, 'x'), op) + attributes={'op': op, 'comm_id': comm_id}, + task=AllReduceTask(input_like(x, 'x'), op, comm_id) ) -def all_reduce(comm_id: int, x: Tensor, op: NcclRedOp) -> Tensor: +def all_reduce(x: Tensor, op: NcclRedOp, comm_id: int) -> Tensor: if x.device.kind != 'cuda': raise RuntimeError("NCCL only supports CUDA tensors") - return AllReduceOp(comm_id, x, op).outputs[0] + return AllReduceOp(x, op, comm_id).outputs[0] -def broadcast(comm_id: int, x: Tensor, root:int) -> Tensor: +def broadcast(x: Tensor, root:int, comm_id: int) -> Tensor: raise NotImplementedError() -def reduce(comm_id: int, x: Tensor, root:int, op: NcclRedOp) -> Tensor: +def reduce(x: Tensor, root:int, op: NcclRedOp, comm_id: int) -> Tensor: raise NotImplementedError() -def all_gather(comm_id: int, x: Tensor) -> Tensor: +def all_gather(x: Tensor, comm_id: int) -> Tensor: raise NotImplementedError() -def reduce_scatter(comm_id: int, x: Tensor, op: NcclRedOp) -> Tensor: +def reduce_scatter(x: Tensor, op: NcclRedOp, comm_id: int) -> Tensor: raise NotImplementedError() -def send(comm_id: int, x: Tensor, peer: int) -> None: +def send(x: Tensor, peer: int, comm_id: int) -> None: raise NotImplementedError() # Recv is a little bit tricky since we need to pass the metadata of the recv buffer -def recv(comm_id: int, peer: int) -> Tensor: +def recv(peer: int, comm_id: int) -> Tensor: raise NotImplementedError() \ No newline at end of file diff --git a/python/hidet/ir/primitives/cuda/nccl.py b/python/hidet/ir/primitives/cuda/nccl.py index 96b2b2a2d..669d72322 100644 --- a/python/hidet/ir/primitives/cuda/nccl.py +++ b/python/hidet/ir/primitives/cuda/nccl.py @@ -18,7 +18,7 @@ from hidet.cuda.nccl import NcclRedOp, dtype_to_nccl -def all_reduce(comm_id: int, sendbuff: Expr, recvbuff: Expr, count: Expr, dtype: DataType, op: NcclRedOp): +def all_reduce(sendbuff: Expr, recvbuff: Expr, count: Expr, dtype: DataType, op: NcclRedOp, comm_id: int): from hidet.ir.primitives.runtime import get_cuda_stream, get_nccl_comm comm = get_nccl_comm(comm_id) From c8559d1037cc606a1306533dc34f2af2e6f0a226 Mon Sep 17 00:00:00 2001 From: Qidong <soodoshll@gmail.com> Date: Mon, 19 Jun 2023 18:10:49 -0400 Subject: [PATCH 04/40] graph --- examples/distributed/test.py | 30 +++++++++++------------------- python/hidet/cuda/nccl/__init__.py | 2 +- python/hidet/graph/__init__.py | 1 + 3 files changed, 13 insertions(+), 20 deletions(-) diff --git a/examples/distributed/test.py b/examples/distributed/test.py index 6ea97aa2f..0196845c5 100644 --- a/examples/distributed/test.py +++ b/examples/distributed/test.py @@ -11,15 +11,7 @@ import hidet import hidet.cuda.nccl from hidet.cuda import nccl -from hidet.cuda.nccl import NcclUniqueId, NcclDataType, NcclRedOp, nccl_library_filename -from hidet.ffi import runtime_api -from hidet.lang import attrs -from hidet.ir.primitives.cuda.nccl import all_reduce -from hidet.ir.type import data_type -from hidet.utils import prod -from hidet.drivers import build_ir_module -from hidet.cuda.nccl.libinfo import get_nccl_include_dirs, get_nccl_library_search_dirs -from hidet.runtime import load_compiled_module +from hidet.cuda.nccl import NcclUniqueId,NcclRedOp print("NCCL version:", nccl.nccl_version()) @@ -38,21 +30,21 @@ def run(world_size, rank, shared_id, barrier): barrier.wait() hidet.cuda.set_device(rank) - print('initialize', rank) - # Create NcclCommunicator and set the cuda context - # this part should be moved into CompiledGraph in the future - comm = nccl.create_comm(world_size, shared_id, rank) - comms_array = nccl.comms_to_array([comm]) - runtime_api.set_nccl_comms(comms_array) - - # Initialize send and receive buffer device = f"cuda:{rank}" - send = hidet.randn([3, 2], device=device) + send = hidet.randn([3, 3], device=device) + + # Create Computation Graph send_symb = hidet.symbol_like(send) recv_symb = hidet.ops.all_reduce(send_symb, NcclRedOp.sum, 0) graph = hidet.trace_from(recv_symb) opt_graph = hidet.graph.optimize(graph) - recv = opt_graph(send) + + # Create Distributed Graph + dist_graph = hidet.graph.DistributedFlowGraph(graph, world_size, rank) + dist_graph.initialize(shared_id) + + recv = dist_graph(send) + print(opt_graph) s = hidet.cuda.current_stream() s.synchronize() diff --git a/python/hidet/cuda/nccl/__init__.py b/python/hidet/cuda/nccl/__init__.py index 2b3ebf96e..34e8bf03f 100644 --- a/python/hidet/cuda/nccl/__init__.py +++ b/python/hidet/cuda/nccl/__init__.py @@ -10,4 +10,4 @@ # See the License for the specific language governing permissions and # limitations under the License. from .ffi import nccl_available, nccl_version, nccl_library_filename -from .comm import create_comm, NcclUniqueId, NcclDataType, NcclRedOp, comms_to_array, init_unique_id, dtype_to_nccl +from .comm import create_comm, NcclUniqueId, NcclDataType, NcclRedOp, comms_to_array, init_unique_id, dtype_to_nccl, NcclCommunicator diff --git a/python/hidet/graph/__init__.py b/python/hidet/graph/__init__.py index dc105b96e..d1eaa7b9d 100644 --- a/python/hidet/graph/__init__.py +++ b/python/hidet/graph/__init__.py @@ -29,3 +29,4 @@ from .tensor import from_numpy, from_dlpack, from_torch from .flow_graph import trace_from, load_graph, save_graph, forward_context from .transforms import optimize +from .distributed import DistributedFlowGraph \ No newline at end of file From d97a7f8c078d20e9b4a7b3b225a8b3e3ee9cfe6c Mon Sep 17 00:00:00 2001 From: Qidong <soodoshll@gmail.com> Date: Mon, 19 Jun 2023 18:22:38 -0400 Subject: [PATCH 05/40] update --- examples/distributed/test.py | 18 ++++++++++-------- python/hidet/graph/ops/distributed.py | 3 +++ 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/examples/distributed/test.py b/examples/distributed/test.py index 0196845c5..260ac6e15 100644 --- a/examples/distributed/test.py +++ b/examples/distributed/test.py @@ -31,24 +31,26 @@ def run(world_size, rank, shared_id, barrier): hidet.cuda.set_device(rank) device = f"cuda:{rank}" - send = hidet.randn([3, 3], device=device) + x = hidet.randn([2, 3], device=device) + w = hidet.randn([3, 2], device=device) # Create Computation Graph - send_symb = hidet.symbol_like(send) - recv_symb = hidet.ops.all_reduce(send_symb, NcclRedOp.sum, 0) - graph = hidet.trace_from(recv_symb) + x_symb = hidet.symbol_like(x) + w_symb = hidet.symbol_like(w) + y_local = hidet.ops.relu(x_symb @ w_symb) + y_sync = hidet.ops.all_reduce(y_local, getattr(NcclRedOp, args.reduce_op), 0) + graph = hidet.trace_from([y_local, y_sync], inputs=[x_symb, w_symb]) opt_graph = hidet.graph.optimize(graph) # Create Distributed Graph - dist_graph = hidet.graph.DistributedFlowGraph(graph, world_size, rank) + dist_graph = hidet.graph.DistributedFlowGraph(opt_graph, world_size, rank) dist_graph.initialize(shared_id) - recv = dist_graph(send) - print(opt_graph) + y_local, y_sync = dist_graph(x, w) s = hidet.cuda.current_stream() s.synchronize() - print(rank, recv) + print(f"process {rank}\nbefore allreduce:{y_local}\nafter allreduce:{y_sync}") world_size = args.n_gpus diff --git a/python/hidet/graph/ops/distributed.py b/python/hidet/graph/ops/distributed.py index 401a7c246..2c7fac799 100644 --- a/python/hidet/graph/ops/distributed.py +++ b/python/hidet/graph/ops/distributed.py @@ -28,6 +28,9 @@ def __init__(self, x: TensorNode, op: NcclRedOp, comm_id: int): self.op = op super().__init__('all_reduce', inputs=[x], outputs=[y], attributes={}) + + def __str__(self): + return f"all_reduce_{int(self.op)}_{self.comm_id}" def implement(self, target: Union[Target, str], working_dir: str) -> List[IRModule]: # we may need current rank here to avoid duplicated working_dirs From 70f3a917c23f702dfd9f0e8bc36e448187f0d09c Mon Sep 17 00:00:00 2001 From: Qidong <soodoshll@gmail.com> Date: Mon, 19 Jun 2023 18:28:17 -0400 Subject: [PATCH 06/40] format --- python/hidet/cuda/nccl/__init__.py | 11 ++++++++- python/hidet/drivers/build_module.py | 2 +- python/hidet/graph/__init__.py | 2 +- python/hidet/graph/ops/distributed.py | 32 ++++++++++++++++----------- python/hidet/ir/func.py | 4 +++- python/hidet/ir/module.py | 2 +- 6 files changed, 35 insertions(+), 18 deletions(-) diff --git a/python/hidet/cuda/nccl/__init__.py b/python/hidet/cuda/nccl/__init__.py index 34e8bf03f..aec4fe718 100644 --- a/python/hidet/cuda/nccl/__init__.py +++ b/python/hidet/cuda/nccl/__init__.py @@ -10,4 +10,13 @@ # See the License for the specific language governing permissions and # limitations under the License. from .ffi import nccl_available, nccl_version, nccl_library_filename -from .comm import create_comm, NcclUniqueId, NcclDataType, NcclRedOp, comms_to_array, init_unique_id, dtype_to_nccl, NcclCommunicator +from .comm import ( + create_comm, + NcclUniqueId, + NcclDataType, + NcclRedOp, + comms_to_array, + init_unique_id, + dtype_to_nccl, + NcclCommunicator, +) diff --git a/python/hidet/drivers/build_module.py b/python/hidet/drivers/build_module.py index 71a0d9110..cfee97a3e 100644 --- a/python/hidet/drivers/build_module.py +++ b/python/hidet/drivers/build_module.py @@ -64,7 +64,7 @@ def build_ir_module(ir_module: IRModule, output_dir: str, *, target: str, output if not nccl_available(): raise RuntimeError("NCCL is not available") - + ir_module.include_dirs.extend(get_nccl_include_dirs()) ir_module.linking_dirs.extend(get_nccl_library_search_dirs()) ir_module.include_headers.append(["nccl.h"]) diff --git a/python/hidet/graph/__init__.py b/python/hidet/graph/__init__.py index d1eaa7b9d..a7d5aaf10 100644 --- a/python/hidet/graph/__init__.py +++ b/python/hidet/graph/__init__.py @@ -29,4 +29,4 @@ from .tensor import from_numpy, from_dlpack, from_torch from .flow_graph import trace_from, load_graph, save_graph, forward_context from .transforms import optimize -from .distributed import DistributedFlowGraph \ No newline at end of file +from .distributed import DistributedFlowGraph diff --git a/python/hidet/graph/ops/distributed.py b/python/hidet/graph/ops/distributed.py index 2c7fac799..4fce12027 100644 --- a/python/hidet/graph/ops/distributed.py +++ b/python/hidet/graph/ops/distributed.py @@ -16,26 +16,25 @@ from hidet.ir.module import IRModule from hidet.ir.task import Target from hidet.utils import prod -from hidet.runtime.device import Device, instantiate_device +from hidet.cuda.nccl import NcclRedOp from .utils import Task, TensorNode, Operator, Tensor, compute, input_like -from hidet.cuda.nccl import NcclRedOp class AllReduceTask(Task): def __init__(self, x: TensorNode, op: NcclRedOp, comm_id: int): y = compute('out', x.shape, lambda *indices: x[indices]) self.comm_id = comm_id self.op = op - + super().__init__('all_reduce', inputs=[x], outputs=[y], attributes={}) - + def __str__(self): return f"all_reduce_{int(self.op)}_{self.comm_id}" def implement(self, target: Union[Target, str], working_dir: str) -> List[IRModule]: # we may need current rank here to avoid duplicated working_dirs import hidet - from hidet.ir.primitives.cuda.nccl import all_reduce + from hidet.ir.primitives.cuda.nccl import all_reduce as _all_reduce from hidet.lang import attrs dtype: DataType = self.inputs[0].type.dtype @@ -43,41 +42,48 @@ def implement(self, target: Union[Target, str], working_dir: str) -> List[IRModu nbytes = dtype.nbytes * prod(shape) with hidet.script_module() as script_module: + @hidet.script def launch(x: dtype[shape], y: dtype[shape]): attrs.func_kind = 'public' - all_reduce(x, y, nbytes, dtype, self.op, self.comm_id) + _all_reduce(x, y, nbytes, dtype, self.op, self.comm_id) + + return [script_module.ir_module()] - return [script_module.ir_module()] class AllReduceOp(Operator): def __init__(self, x: Tensor, op: NcclRedOp, comm_id: int): super().__init__( - inputs=[x], - attributes={'op': op, 'comm_id': comm_id}, - task=AllReduceTask(input_like(x, 'x'), op, comm_id) + inputs=[x], attributes={'op': op, 'comm_id': comm_id}, task=AllReduceTask(input_like(x, 'x'), op, comm_id) ) + def all_reduce(x: Tensor, op: NcclRedOp, comm_id: int) -> Tensor: if x.device.kind != 'cuda': raise RuntimeError("NCCL only supports CUDA tensors") return AllReduceOp(x, op, comm_id).outputs[0] -def broadcast(x: Tensor, root:int, comm_id: int) -> Tensor: + +def broadcast(x: Tensor, root: int, comm_id: int) -> Tensor: raise NotImplementedError() -def reduce(x: Tensor, root:int, op: NcclRedOp, comm_id: int) -> Tensor: + +def reduce(x: Tensor, root: int, op: NcclRedOp, comm_id: int) -> Tensor: raise NotImplementedError() + def all_gather(x: Tensor, comm_id: int) -> Tensor: raise NotImplementedError() + def reduce_scatter(x: Tensor, op: NcclRedOp, comm_id: int) -> Tensor: raise NotImplementedError() + def send(x: Tensor, peer: int, comm_id: int) -> None: raise NotImplementedError() + # Recv is a little bit tricky since we need to pass the metadata of the recv buffer def recv(peer: int, comm_id: int) -> Tensor: - raise NotImplementedError() \ No newline at end of file + raise NotImplementedError() diff --git a/python/hidet/ir/func.py b/python/hidet/ir/func.py index fb5cfd40a..124ab9d8b 100644 --- a/python/hidet/ir/func.py +++ b/python/hidet/ir/func.py @@ -99,6 +99,7 @@ def use_distributed(self) -> bool: """ Return true if this function involves any distributed primitives """ + def _recursive_find(root: Stmt): if isinstance(root, BlackBoxStmt): if root.template_string.startswith('nccl'): @@ -108,5 +109,6 @@ def _recursive_find(root: Stmt): if _recursive_find(child): return True return False + ret = _recursive_find(self.body) - return ret \ No newline at end of file + return ret diff --git a/python/hidet/ir/module.py b/python/hidet/ir/module.py index c40f32e43..bcecb6725 100644 --- a/python/hidet/ir/module.py +++ b/python/hidet/ir/module.py @@ -117,4 +117,4 @@ def build(self): return load_compiled_module(output_dir) def use_distributed(self): - return any([func.use_distributed() for func in self.functions.values()]) \ No newline at end of file + return any((func.use_distributed() for func in self.functions.values())) From 6819ab8cc45065a1c66fea43346cf7191c949f0a Mon Sep 17 00:00:00 2001 From: Qidong <soodoshll@gmail.com> Date: Mon, 19 Jun 2023 18:33:40 -0400 Subject: [PATCH 07/40] add distributed graph --- python/hidet/graph/distributed.py | 47 +++++++++++++++++++++++++++++++ 1 file changed, 47 insertions(+) create mode 100644 python/hidet/graph/distributed.py diff --git a/python/hidet/graph/distributed.py b/python/hidet/graph/distributed.py new file mode 100644 index 000000000..89575abbe --- /dev/null +++ b/python/hidet/graph/distributed.py @@ -0,0 +1,47 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# pylint: disable=protected-access + +from typing import List, Union + +from hidet.cuda.nccl import NcclCommunicator, NcclUniqueId, create_comm, comms_to_array +from hidet.ffi import runtime_api +from .flow_graph import FlowGraph +from .tensor import Tensor + + +class DistributedFlowGraph: + def __init__(self, g: FlowGraph, nranks: int, rank: int): + self._g = g + self._nranks = nranks + self._rank = rank + self._comms: List[NcclCommunicator] = [] + + def initialize(self, unique_id: NcclUniqueId): + """ + This is the default initialization function. + Should be replaced by a customized one if the compiler gives non-trivial schedule. + """ + comm = create_comm(self._nranks, unique_id, self._rank) + self._comms = [comm] + + def forward(self, inputs: List[Tensor]) -> List[Tensor]: + comms_array = comms_to_array(self._comms) + # We need an explicit variable to ensure the comms_array will not be garbage-collected. + runtime_api.set_nccl_comms(comms_array) + return self._g.forward(inputs) + + def __call__(self, *inputs: Tensor) -> Union[List[Tensor], Tensor]: + comms_array = comms_to_array(self._comms) + # We need an explicit variable to ensure the comms_array will not be garbage-collected. + runtime_api.set_nccl_comms(comms_array) + return self._g(*inputs) From 95cec0c96b99bc75b2849db75699984b8795b32d Mon Sep 17 00:00:00 2001 From: Qidong <soodoshll@gmail.com> Date: Wed, 21 Jun 2023 17:45:07 -0400 Subject: [PATCH 08/40] update --- examples/distributed/test.py | 20 ++++++--- python/hidet/cuda/nccl/__init__.py | 2 + python/hidet/cuda/nccl/comm.py | 23 ++++++++-- python/hidet/cuda/nccl/ffi.py | 21 +++++----- python/hidet/drivers/build_graph.py | 13 +++++- python/hidet/graph/__init__.py | 1 - python/hidet/graph/distributed.py | 47 --------------------- python/hidet/graph/flow_graph.py | 18 +++++++- python/hidet/graph/graph_utils/functors.py | 2 +- python/hidet/graph/ops/distributed.py | 28 ++++++------- python/hidet/runtime/compiled_graph.py | 49 +++++++++++++++++++++- 11 files changed, 137 insertions(+), 87 deletions(-) delete mode 100644 python/hidet/graph/distributed.py diff --git a/examples/distributed/test.py b/examples/distributed/test.py index 260ac6e15..549964f8c 100644 --- a/examples/distributed/test.py +++ b/examples/distributed/test.py @@ -1,6 +1,14 @@ """ Testing script for distributed components for hidet To debug, set the environment variable NCCL_DEBUG=INFO + +To install nccl, run + + pip install nvidia-nccl-cu11==2.18.3 + +Or + + pip install nvidia-nccl-cu12==2.18.3 """ import hidet import multiprocessing @@ -11,7 +19,8 @@ import hidet import hidet.cuda.nccl from hidet.cuda import nccl -from hidet.cuda.nccl import NcclUniqueId,NcclRedOp +from hidet.cuda.nccl import NcclUniqueId +from hidet.runtime.compiled_graph import GraphDistributedInfo print("NCCL version:", nccl.nccl_version()) @@ -38,15 +47,16 @@ def run(world_size, rank, shared_id, barrier): x_symb = hidet.symbol_like(x) w_symb = hidet.symbol_like(w) y_local = hidet.ops.relu(x_symb @ w_symb) - y_sync = hidet.ops.all_reduce(y_local, getattr(NcclRedOp, args.reduce_op), 0) + y_sync = hidet.ops.all_reduce(y_local, args.reduce_op) graph = hidet.trace_from([y_local, y_sync], inputs=[x_symb, w_symb]) opt_graph = hidet.graph.optimize(graph) + opt_graph.set_dist_attrs(nrank=world_size, rank=rank) + compiled = opt_graph.build() # Create Distributed Graph - dist_graph = hidet.graph.DistributedFlowGraph(opt_graph, world_size, rank) - dist_graph.initialize(shared_id) + compiled.init_dist(shared_id) - y_local, y_sync = dist_graph(x, w) + y_local, y_sync = compiled(x, w) s = hidet.cuda.current_stream() s.synchronize() diff --git a/python/hidet/cuda/nccl/__init__.py b/python/hidet/cuda/nccl/__init__.py index aec4fe718..38dd766e5 100644 --- a/python/hidet/cuda/nccl/__init__.py +++ b/python/hidet/cuda/nccl/__init__.py @@ -19,4 +19,6 @@ init_unique_id, dtype_to_nccl, NcclCommunicator, + str_to_nccl_op, + NCCL_SPLIT_NOCOLOR ) diff --git a/python/hidet/cuda/nccl/comm.py b/python/hidet/cuda/nccl/comm.py index 241ffc18f..a74c340a3 100644 --- a/python/hidet/cuda/nccl/comm.py +++ b/python/hidet/cuda/nccl/comm.py @@ -15,7 +15,12 @@ from hidet.ffi.utils import Array from hidet.ir.type import void_p, DataType -from .ffi import nccl_runtime_api, NcclUniqueId +from .ffi import nccl_available, NcclUniqueId + +NCCL_SPLIT_NOCOLOR = -1 + +if nccl_available: + from .ffi import nccl_runtime_api class NcclDataType(IntEnum): @@ -43,6 +48,10 @@ class NcclRedOp(IntEnum): min = 3 avg = 4 +def str_to_nccl_op(name: str) -> NcclRedOp: + if name not in ('sum', 'prod', 'max', 'min', 'avg'): + raise RuntimeError(f"'{name}' is not a supported reduce op") + return getattr(NcclRedOp, name) class NcclCommunicator: def __init__(self, handle: int): @@ -50,7 +59,8 @@ def __init__(self, handle: int): Users should not call this constructor directly. Because there are two ways of creating a new communicator: 1) using unique_id and rank ; 2) using split. """ - + if not nccl_available(): + raise RuntimeError("NCCL is not available") self._handle = handle def __del__(self): @@ -60,11 +70,14 @@ def __del__(self): def handle(self): return self._handle - def split(self): - raise NotImplementedError() + def split(self, key, color): + new_handle = nccl_runtime_api.comm_split(self._handle, color, key) + return NcclCommunicator(new_handle) def create_comm(nranks: int, unique_id: NcclUniqueId, rank: int) -> NcclCommunicator: + if not nccl_available(): + raise RuntimeError("NCCL is not available") handle = nccl_runtime_api.comm_init_rank(nranks, unique_id, rank) return NcclCommunicator(handle) @@ -77,6 +90,8 @@ def comms_to_array(comms: List[NcclCommunicator]) -> Array: def init_unique_id(unqie_id: NcclUniqueId) -> None: + if not nccl_available(): + raise RuntimeError("NCCL is not available") nccl_runtime_api.get_unique_id(unqie_id) diff --git a/python/hidet/cuda/nccl/ffi.py b/python/hidet/cuda/nccl/ffi.py index 66a7dbc74..6d2197d3b 100644 --- a/python/hidet/cuda/nccl/ffi.py +++ b/python/hidet/cuda/nccl/ffi.py @@ -49,21 +49,12 @@ def load_nccl_library(): _LIB_NCCL = ctypes.cdll.LoadLibrary(lib_nccl_paths[0]) nccl_library_path = lib_nccl_paths[0] break - if _LIB_NCCL is None: - raise OSError('Can not find nccl library in the following directory: \n' + '\n'.join(library_dirs)) - load_nccl_library() - def nccl_library_filename(): return os.path.basename(nccl_library_path) - -if not nccl_available(): - raise RuntimeError("NCCL Library not found.") - - class NCCLRuntimeAPI: """ Runtime APIs regarding NCCL @@ -78,6 +69,8 @@ class NCCLRuntimeAPI: _comm_user_rank = get_func('ncclCommUserRank', [c_void_p, POINTER(c_int)], c_int, lib=_LIB_NCCL) _comm_count = get_func('ncclCommCount', [c_void_p, POINTER(c_int)], c_int, lib=_LIB_NCCL) + _comm_split = get_func('ncclCommSplit', [c_void_p, c_int, c_int, c_void_p, c_void_p], c_int, lib=_LIB_NCCL) + @staticmethod def get_version() -> int: version = c_int(0) @@ -104,5 +97,13 @@ def comm_destroy(comm_handle) -> None: ret = NCCLRuntimeAPI._comm_destroy(comm_handle) assert ret == 0 + @staticmethod + def comm_split(comm_handle: int, color: int, key: int) -> int: + comm = c_void_p() + ret = NCCLRuntimeAPI._comm_split(comm_handle, color, key, pointer(comm), None) + assert ret == 0 + return comm.value + -nccl_runtime_api = NCCLRuntimeAPI() +if nccl_available(): + nccl_runtime_api = NCCLRuntimeAPI() diff --git a/python/hidet/drivers/build_graph.py b/python/hidet/drivers/build_graph.py index 994d13d88..05b3b8987 100644 --- a/python/hidet/drivers/build_graph.py +++ b/python/hidet/drivers/build_graph.py @@ -22,7 +22,7 @@ from hidet.graph.tensor import Tensor from hidet.graph.flow_graph import FlowGraph from hidet.runtime.compiled_module import CompiledModule -from hidet.runtime.compiled_graph import CompiledGraph, GraphMetaData, GraphExecution, GraphExecutionInstruction +from hidet.runtime.compiled_graph import CompiledGraph, GraphMetaData, GraphExecution, GraphExecutionInstruction, GraphDistributedInfo from hidet.runtime.compiled_task import CompiledTask, TensorSignature from hidet.graph.operator import Operator from hidet.ir import primitives @@ -141,6 +141,13 @@ def get_graph_meta_data(graph: FlowGraph, num_kernels, space: int) -> GraphMetaD inputs=inputs, outputs=outputs, hidet_version=hidet.__version__, num_kernels=num_kernels, graph_hash=graph_hash ) +def get_graph_dist_info(graph: FlowGraph) -> GraphDistributedInfo: + if not graph.is_distributed(): + return None + return GraphDistributedInfo( + nrank = graph._nrank, + rank = graph._rank, + groups = graph._groups) def build_graph_module(graph: FlowGraph, graph_weights: List[Tensor], node2kernel: List[int]) -> CompiledModule: from hidet.lang import void_p, attrs, int32, int64, meta, cast @@ -329,6 +336,9 @@ def build_flow_graph(graph, *, space=0) -> CompiledGraph: # get the graph meta data graph_meta_data = get_graph_meta_data(graph, len(graph_kernels), space) + # get distributed information + graph_dist_info = get_graph_dist_info(graph) + # build the compiled graph compiled_graph = CompiledGraph( meta=graph_meta_data, @@ -337,6 +347,7 @@ def build_flow_graph(graph, *, space=0) -> CompiledGraph: compiled_tasks=graph_kernels, graph_execution=graph_execution, graph_string=str(graph), + dist_info=graph_dist_info ) # save the compiled graph to cache diff --git a/python/hidet/graph/__init__.py b/python/hidet/graph/__init__.py index a7d5aaf10..dc105b96e 100644 --- a/python/hidet/graph/__init__.py +++ b/python/hidet/graph/__init__.py @@ -29,4 +29,3 @@ from .tensor import from_numpy, from_dlpack, from_torch from .flow_graph import trace_from, load_graph, save_graph, forward_context from .transforms import optimize -from .distributed import DistributedFlowGraph diff --git a/python/hidet/graph/distributed.py b/python/hidet/graph/distributed.py deleted file mode 100644 index 89575abbe..000000000 --- a/python/hidet/graph/distributed.py +++ /dev/null @@ -1,47 +0,0 @@ -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# pylint: disable=protected-access - -from typing import List, Union - -from hidet.cuda.nccl import NcclCommunicator, NcclUniqueId, create_comm, comms_to_array -from hidet.ffi import runtime_api -from .flow_graph import FlowGraph -from .tensor import Tensor - - -class DistributedFlowGraph: - def __init__(self, g: FlowGraph, nranks: int, rank: int): - self._g = g - self._nranks = nranks - self._rank = rank - self._comms: List[NcclCommunicator] = [] - - def initialize(self, unique_id: NcclUniqueId): - """ - This is the default initialization function. - Should be replaced by a customized one if the compiler gives non-trivial schedule. - """ - comm = create_comm(self._nranks, unique_id, self._rank) - self._comms = [comm] - - def forward(self, inputs: List[Tensor]) -> List[Tensor]: - comms_array = comms_to_array(self._comms) - # We need an explicit variable to ensure the comms_array will not be garbage-collected. - runtime_api.set_nccl_comms(comms_array) - return self._g.forward(inputs) - - def __call__(self, *inputs: Tensor) -> Union[List[Tensor], Tensor]: - comms_array = comms_to_array(self._comms) - # We need an explicit variable to ensure the comms_array will not be garbage-collected. - runtime_api.set_nccl_comms(comms_array) - return self._g(*inputs) diff --git a/python/hidet/graph/flow_graph.py b/python/hidet/graph/flow_graph.py index 2286d99d0..cce662973 100644 --- a/python/hidet/graph/flow_graph.py +++ b/python/hidet/graph/flow_graph.py @@ -107,13 +107,18 @@ def forward_context() -> GraphForwardContext: class FlowGraph: """The computation graph representation.""" - def __init__(self, outputs: Sequence[Tensor], inputs: Optional[Sequence[Tensor]] = None, nodes=None): + def __init__(self, outputs: Sequence[Tensor], inputs: Optional[Sequence[Tensor]] = None, nodes=None, nrank=None, rank=None, groups=None): self.outputs: List[Tensor] = list(outputs) self.inputs: Optional[List[Tensor]] = list(inputs) if inputs is not None else None self._nodes: Optional[List[Operator]] = nodes self._usage_count: Optional[Dict[Tensor, int]] = None self.update_nodes() + # For distributed graphs + self._nrank = nrank + self._rank = rank + self._groups = groups + def __call__(self, *inputs: Tensor) -> Union[List[Tensor], Tensor]: """ Run the computation graph. @@ -179,6 +184,14 @@ def _build_nodes(self): hidet.option.parallel_build(False) hidet.drivers.build_task_batch(tunable_tasks) # build tunable tasks one by one + def is_distributed(self): + return self._nrank is not None or self._rank is not None + + def set_dist_attrs(self, nrank: int, rank: int, groups: Optional[List[List[int]]] = None): + self._nrank = nrank + self._rank = rank + self._groups = groups + def forward(self, inputs: List[Tensor]) -> List[Tensor]: """Run the computation graph. @@ -193,6 +206,9 @@ def forward(self, inputs: List[Tensor]) -> List[Tensor]: output: List[Tensor] The output tensors of the computation graph. """ + if self.is_distributed(): + raise RuntimeError("Running Distributed FlowGraph is not supported. Please compile it first.") + from hidet.ffi import runtime_api inputs: List[Tensor] = list(inputs) diff --git a/python/hidet/graph/graph_utils/functors.py b/python/hidet/graph/graph_utils/functors.py index e23be6a12..c81084b25 100644 --- a/python/hidet/graph/graph_utils/functors.py +++ b/python/hidet/graph/graph_utils/functors.py @@ -119,7 +119,7 @@ def visit_Sequence(self, seq: Union[list, tuple]): class GraphCloneRewriter(GraphRewriter): def visit_FlowGraph(self, graph: FlowGraph): outputs = [self.visit(output) for output in graph.outputs] - return FlowGraph(outputs, graph.inputs) + return FlowGraph(outputs, graph.inputs, nrank=graph._nrank, rank=graph._rank, groups=graph._groups) def visit_Operator(self, op: Operator): inputs = [self(x) for x in op.inputs] diff --git a/python/hidet/graph/ops/distributed.py b/python/hidet/graph/ops/distributed.py index 4fce12027..0cddc9d98 100644 --- a/python/hidet/graph/ops/distributed.py +++ b/python/hidet/graph/ops/distributed.py @@ -16,12 +16,12 @@ from hidet.ir.module import IRModule from hidet.ir.task import Target from hidet.utils import prod -from hidet.cuda.nccl import NcclRedOp +from hidet.cuda.nccl import str_to_nccl_op from .utils import Task, TensorNode, Operator, Tensor, compute, input_like class AllReduceTask(Task): - def __init__(self, x: TensorNode, op: NcclRedOp, comm_id: int): + def __init__(self, x: TensorNode, op: str, comm_id: int=0): y = compute('out', x.shape, lambda *indices: x[indices]) self.comm_id = comm_id self.op = op @@ -29,7 +29,7 @@ def __init__(self, x: TensorNode, op: NcclRedOp, comm_id: int): super().__init__('all_reduce', inputs=[x], outputs=[y], attributes={}) def __str__(self): - return f"all_reduce_{int(self.op)}_{self.comm_id}" + return f"all_reduce_{self.op}_{self.comm_id}" def implement(self, target: Union[Target, str], working_dir: str) -> List[IRModule]: # we may need current rank here to avoid duplicated working_dirs @@ -46,44 +46,40 @@ def implement(self, target: Union[Target, str], working_dir: str) -> List[IRModu @hidet.script def launch(x: dtype[shape], y: dtype[shape]): attrs.func_kind = 'public' - _all_reduce(x, y, nbytes, dtype, self.op, self.comm_id) + _all_reduce(x, y, nbytes, dtype, str_to_nccl_op(self.op), self.comm_id) return [script_module.ir_module()] class AllReduceOp(Operator): - def __init__(self, x: Tensor, op: NcclRedOp, comm_id: int): + def __init__(self, x: Tensor, op: str, comm_id: int): super().__init__( inputs=[x], attributes={'op': op, 'comm_id': comm_id}, task=AllReduceTask(input_like(x, 'x'), op, comm_id) ) - -def all_reduce(x: Tensor, op: NcclRedOp, comm_id: int) -> Tensor: +def all_reduce(x: Tensor, op: str, comm_id: int=0) -> Tensor: if x.device.kind != 'cuda': raise RuntimeError("NCCL only supports CUDA tensors") return AllReduceOp(x, op, comm_id).outputs[0] - -def broadcast(x: Tensor, root: int, comm_id: int) -> Tensor: +def broadcast(x: Tensor, root: int, comm_id: int=0) -> Tensor: raise NotImplementedError() - -def reduce(x: Tensor, root: int, op: NcclRedOp, comm_id: int) -> Tensor: +def reduce(x: Tensor, root: int, op: str, comm_id: int=0) -> Tensor: raise NotImplementedError() - -def all_gather(x: Tensor, comm_id: int) -> Tensor: +def all_gather(x: Tensor, comm_id: int=0) -> Tensor: raise NotImplementedError() -def reduce_scatter(x: Tensor, op: NcclRedOp, comm_id: int) -> Tensor: +def reduce_scatter(x: Tensor, op: str, comm_id: int=0) -> Tensor: raise NotImplementedError() -def send(x: Tensor, peer: int, comm_id: int) -> None: +def send(x: Tensor, peer: int, comm_id: int=0) -> None: raise NotImplementedError() # Recv is a little bit tricky since we need to pass the metadata of the recv buffer -def recv(peer: int, comm_id: int) -> Tensor: +def recv(peer: int, comm_id: int=0) -> Tensor: raise NotImplementedError() diff --git a/python/hidet/runtime/compiled_graph.py b/python/hidet/runtime/compiled_graph.py index 739c9ccd3..020127fa4 100644 --- a/python/hidet/runtime/compiled_graph.py +++ b/python/hidet/runtime/compiled_graph.py @@ -28,6 +28,7 @@ from hidet.runtime.storage import Storage from hidet.ffi import runtime_api from hidet.utils import prod +from hidet.cuda.nccl import NcclCommunicator, NcclUniqueId, create_comm, NCCL_SPLIT_NOCOLOR, comms_to_array ModelExecutionHook = Callable[[int, List['Tensor'], List['Tensor']], None] @@ -63,6 +64,12 @@ class GraphExecution: outputs_index: List[int] tensor_device: List[str] +@dataclass +class GraphDistributedInfo: + nrank: int + rank: int + groups: List[List[int]] + class CompiledGraph: def __init__( @@ -73,6 +80,7 @@ def __init__( compiled_tasks: List[CompiledTask], graph_execution: GraphExecution, graph_string: str, + dist_info: Optional[GraphDistributedInfo]=None ): from hidet.graph.tensor import Tensor @@ -104,6 +112,10 @@ def __init__( self.dispatch_table: Dict[Tuple[int, ...], Array] = {} self.cuda_workspace: Optional[Storage] = None self.cpu_workspace: Optional[Storage] = None + + # distributed properties + self.dist_info: Optional[GraphDistributedInfo] = dist_info + self.nccl_comms: List[NcclCommunicator] = [] self._init_compiled_graph() @@ -169,6 +181,26 @@ def _init_compiled_graph(self): kernel_array[task_idx] = ctypes_func_pointer(compiled_task.candidates[sch_idx].ctypes_func) self.dispatch_table[tuple(symbol_dims)] = kernel_array + def init_dist(self, unique_id: NcclUniqueId): + if self.dist_info is None: + raise RuntimeError("Distributed information is not set.") + self.nccl_comms = [] + + # Initialize the default group + nrank = self.dist_info.nrank + rank = self.dist_info.rank + default_comm = create_comm(nrank, unique_id, rank) + self.nccl_comms.append(default_comm) + + # Create communicators according to groups + if self.dist_info.groups is not None: + for group in self.dist_info.groups: + in_group = rank in group + color = 0 if in_group else NCCL_SPLIT_NOCOLOR + key = group.index(rank) if in_group else 0 + self.nccl_comms.append(default_comm.split(key, color)) + + def _update_symbol_table(self, symbol_dims: Tuple[int, ...], best_candidates: List[int]): kernel_array = Array(void_p, len(self.compiled_tasks)) for task_idx, best_candidate in enumerate(best_candidates): @@ -277,6 +309,10 @@ def run_async(self, inputs): ret: List[hidet.Tensor] The output tensors. """ + if self.dist_info is not None: + comms_array = comms_to_array(self.nccl_comms) + runtime_api.set_nccl_comms(comms_array) + if hidet.option.get_runtime_check(): _check_inputs(self.meta.inputs, inputs) @@ -363,6 +399,12 @@ def _save_under(dir_path: str, dir_in_zip: str, exclude: Optional[List[str]] = N # save graph string with zf.open('graph_string.txt', 'w') as f: f.write(model.graph_string.encode('utf-8')) + + # save distibuted information + if model.dist_info is not None: + with zf.open('dist_info.json', 'w') as f: + dist_info_bytes = json.dumps(asdict(model.dist_info), indent=4).encode('utf-8') + f.write(dist_info_bytes) def load_compiled_graph(path: str) -> CompiledGraph: @@ -380,6 +422,11 @@ def load_compiled_graph(path: str) -> CompiledGraph: with zf.open('graph_execution.json', 'r') as f: graph_execution: GraphExecution = from_dict(GraphExecution, json.load(f)) + # load dist info + if zipfile.Path(zf, 'dist_info.json').exists(): + with zf.open('dist_info.json', 'r') as f: + dist_info: GraphDistributedInfo = from_dict(GraphDistributedInfo, json.load(f)) + # load weights as numpy arrays with zf.open('weights.npz', 'r') as f: with zipfile.ZipFile(f, 'r') as npz: @@ -411,6 +458,6 @@ def load_compiled_graph(path: str) -> CompiledGraph: graph_string = f.read() # construct the compiled graph - ret = CompiledGraph(meta_data, graph_module, weights, compiled_tasks, graph_execution, graph_string) + ret = CompiledGraph(meta_data, graph_module, weights, compiled_tasks, graph_execution, graph_string, dist_info=dist_info) return ret From 961e99bc917060d998aca590c35240a26115aef2 Mon Sep 17 00:00:00 2001 From: Qidong <soodoshll@gmail.com> Date: Wed, 21 Jun 2023 20:06:24 -0400 Subject: [PATCH 09/40] support split --- examples/distributed/test.py | 13 ++++++++++++- python/hidet/cuda/nccl/comm.py | 2 ++ python/hidet/drivers/build_module.py | 15 --------------- python/hidet/ir/func.py | 17 ----------------- python/hidet/ir/module.py | 3 --- python/hidet/runtime/compiled_graph.py | 6 +++++- python/hidet/transforms/__init__.py | 3 ++- 7 files changed, 21 insertions(+), 38 deletions(-) diff --git a/examples/distributed/test.py b/examples/distributed/test.py index 549964f8c..63e5ebba3 100644 --- a/examples/distributed/test.py +++ b/examples/distributed/test.py @@ -27,6 +27,7 @@ parser = argparse.ArgumentParser() parser.add_argument("n_gpus", type=int) parser.add_argument("reduce_op", choices=['sum', 'prod', 'max', 'min', 'avg']) +parser.add_argument("--group_size", type=int, default=0) args = parser.parse_args() def run(world_size, rank, shared_id, barrier): @@ -39,6 +40,16 @@ def run(world_size, rank, shared_id, barrier): barrier.wait() hidet.cuda.set_device(rank) + use_group = args.group_size > 1 + if use_group: + gs = args.group_size + gn = world_size // gs + assert world_size % gs == 0 + groups = [list(range(i * gs, (i + 1) * gs)) for i in range(gn)] + else: + groups = [] + + device = f"cuda:{rank}" x = hidet.randn([2, 3], device=device) w = hidet.randn([3, 2], device=device) @@ -50,7 +61,7 @@ def run(world_size, rank, shared_id, barrier): y_sync = hidet.ops.all_reduce(y_local, args.reduce_op) graph = hidet.trace_from([y_local, y_sync], inputs=[x_symb, w_symb]) opt_graph = hidet.graph.optimize(graph) - opt_graph.set_dist_attrs(nrank=world_size, rank=rank) + opt_graph.set_dist_attrs(nrank=world_size, rank=rank, groups=groups) compiled = opt_graph.build() # Create Distributed Graph diff --git a/python/hidet/cuda/nccl/comm.py b/python/hidet/cuda/nccl/comm.py index a74c340a3..60fd94d5d 100644 --- a/python/hidet/cuda/nccl/comm.py +++ b/python/hidet/cuda/nccl/comm.py @@ -72,6 +72,8 @@ def handle(self): def split(self, key, color): new_handle = nccl_runtime_api.comm_split(self._handle, color, key) + if color == NCCL_SPLIT_NOCOLOR: + return None return NcclCommunicator(new_handle) diff --git a/python/hidet/drivers/build_module.py b/python/hidet/drivers/build_module.py index cfee97a3e..05671bec8 100644 --- a/python/hidet/drivers/build_module.py +++ b/python/hidet/drivers/build_module.py @@ -55,21 +55,6 @@ def build_ir_module(ir_module: IRModule, output_dir: str, *, target: str, output with PassContext(instruments=instruments): ir_module = lower(ir_module) - # nccl-related - if ir_module.use_distributed(): - if target != 'cuda': - raise RuntimeError("IRModules using NCCL must be targeted for cuda") - from hidet.cuda.nccl.libinfo import get_nccl_include_dirs, get_nccl_library_search_dirs - from hidet.cuda.nccl import nccl_available, nccl_library_filename - - if not nccl_available(): - raise RuntimeError("NCCL is not available") - - ir_module.include_dirs.extend(get_nccl_include_dirs()) - ir_module.linking_dirs.extend(get_nccl_library_search_dirs()) - ir_module.include_headers.append(["nccl.h"]) - ir_module.linking_libs.append(":" + nccl_library_filename()) - # code generation codegen(ir_module, src_out_path=src_path, target=target) diff --git a/python/hidet/ir/func.py b/python/hidet/ir/func.py index 124ab9d8b..fb7d9cc77 100644 --- a/python/hidet/ir/func.py +++ b/python/hidet/ir/func.py @@ -95,20 +95,3 @@ def get_attr(self, attr_name, default=None, allow_missing=False): else: raise KeyError('Attribute {} is not found in function {}'.format(attr_name, self.name)) - def use_distributed(self) -> bool: - """ - Return true if this function involves any distributed primitives - """ - - def _recursive_find(root: Stmt): - if isinstance(root, BlackBoxStmt): - if root.template_string.startswith('nccl'): - return True - for child in dir(root): - if isinstance(child, Stmt): - if _recursive_find(child): - return True - return False - - ret = _recursive_find(self.body) - return ret diff --git a/python/hidet/ir/module.py b/python/hidet/ir/module.py index bcecb6725..751c05a24 100644 --- a/python/hidet/ir/module.py +++ b/python/hidet/ir/module.py @@ -115,6 +115,3 @@ def build(self): build_ir_module(self, output_dir, target=target) return load_compiled_module(output_dir) - - def use_distributed(self): - return any((func.use_distributed() for func in self.functions.values())) diff --git a/python/hidet/runtime/compiled_graph.py b/python/hidet/runtime/compiled_graph.py index 020127fa4..ae33bad54 100644 --- a/python/hidet/runtime/compiled_graph.py +++ b/python/hidet/runtime/compiled_graph.py @@ -198,7 +198,11 @@ def init_dist(self, unique_id: NcclUniqueId): in_group = rank in group color = 0 if in_group else NCCL_SPLIT_NOCOLOR key = group.index(rank) if in_group else 0 - self.nccl_comms.append(default_comm.split(key, color)) + comm = default_comm.split(key, color) + if in_group: + self.nccl_comms.append(comm) + + print(self.nccl_comms) def _update_symbol_table(self, symbol_dims: Tuple[int, ...], best_candidates: List[int]): diff --git a/python/hidet/transforms/__init__.py b/python/hidet/transforms/__init__.py index f91be59ef..ce6a48bf9 100644 --- a/python/hidet/transforms/__init__.py +++ b/python/hidet/transforms/__init__.py @@ -36,7 +36,7 @@ from .propagate_launch_bound import propagate_launch_bound_pass from .check_launch_configuration import check_launch_configuration_pass from .lower_special_cast import lower_special_cast_pass - +from .include_nccl import include_nccl_pass def lower_with(ir_module: IRModule, transforms: Sequence[Pass]) -> IRModule: ctx = PassContext.current() @@ -80,5 +80,6 @@ def lower(ir_module: IRModule) -> IRModule: rule_based_simplify_pass(), inline_let_stmt_pass(), simplify_stmt_pass(), + include_nccl_pass() ] return lower_with(ir_module, transforms) From 334a1eb57ca1442d7c281555c06fd10dbeb5a123 Mon Sep 17 00:00:00 2001 From: Qidong <soodoshll@gmail.com> Date: Wed, 21 Jun 2023 20:13:12 -0400 Subject: [PATCH 10/40] update --- examples/distributed/test.py | 6 +-- python/hidet/cuda/nccl/__init__.py | 2 +- python/hidet/cuda/nccl/comm.py | 4 +- python/hidet/cuda/nccl/ffi.py | 3 ++ python/hidet/drivers/build_graph.py | 17 ++++--- python/hidet/graph/flow_graph.py | 24 ++++++--- python/hidet/graph/graph_utils/functors.py | 2 +- python/hidet/graph/ops/distributed.py | 20 +++++--- python/hidet/ir/func.py | 3 +- python/hidet/runtime/compiled_graph.py | 14 ++--- python/hidet/transforms/__init__.py | 3 +- python/hidet/transforms/include_nccl.py | 59 ++++++++++++++++++++++ 12 files changed, 119 insertions(+), 38 deletions(-) create mode 100644 python/hidet/transforms/include_nccl.py diff --git a/examples/distributed/test.py b/examples/distributed/test.py index 63e5ebba3..37c45686c 100644 --- a/examples/distributed/test.py +++ b/examples/distributed/test.py @@ -51,14 +51,14 @@ def run(world_size, rank, shared_id, barrier): device = f"cuda:{rank}" - x = hidet.randn([2, 3], device=device) + x = hidet.randn([1, 3], device=device) w = hidet.randn([3, 2], device=device) # Create Computation Graph x_symb = hidet.symbol_like(x) w_symb = hidet.symbol_like(w) y_local = hidet.ops.relu(x_symb @ w_symb) - y_sync = hidet.ops.all_reduce(y_local, args.reduce_op) + y_sync = hidet.ops.all_reduce(y_local, args.reduce_op, comm_id=int(use_group)) graph = hidet.trace_from([y_local, y_sync], inputs=[x_symb, w_symb]) opt_graph = hidet.graph.optimize(graph) opt_graph.set_dist_attrs(nrank=world_size, rank=rank, groups=groups) @@ -71,7 +71,7 @@ def run(world_size, rank, shared_id, barrier): s = hidet.cuda.current_stream() s.synchronize() - print(f"process {rank}\nbefore allreduce:{y_local}\nafter allreduce:{y_sync}") + print(f"process {rank}\nbefore allreduce:{y_local}\nafter allreduce:{y_sync}\n", end='') world_size = args.n_gpus diff --git a/python/hidet/cuda/nccl/__init__.py b/python/hidet/cuda/nccl/__init__.py index 38dd766e5..4c7a0090a 100644 --- a/python/hidet/cuda/nccl/__init__.py +++ b/python/hidet/cuda/nccl/__init__.py @@ -20,5 +20,5 @@ dtype_to_nccl, NcclCommunicator, str_to_nccl_op, - NCCL_SPLIT_NOCOLOR + NCCL_SPLIT_NOCOLOR, ) diff --git a/python/hidet/cuda/nccl/comm.py b/python/hidet/cuda/nccl/comm.py index 60fd94d5d..48e232787 100644 --- a/python/hidet/cuda/nccl/comm.py +++ b/python/hidet/cuda/nccl/comm.py @@ -19,7 +19,7 @@ NCCL_SPLIT_NOCOLOR = -1 -if nccl_available: +if nccl_available(): from .ffi import nccl_runtime_api @@ -48,11 +48,13 @@ class NcclRedOp(IntEnum): min = 3 avg = 4 + def str_to_nccl_op(name: str) -> NcclRedOp: if name not in ('sum', 'prod', 'max', 'min', 'avg'): raise RuntimeError(f"'{name}' is not a supported reduce op") return getattr(NcclRedOp, name) + class NcclCommunicator: def __init__(self, handle: int): """ diff --git a/python/hidet/cuda/nccl/ffi.py b/python/hidet/cuda/nccl/ffi.py index 6d2197d3b..cb279b82a 100644 --- a/python/hidet/cuda/nccl/ffi.py +++ b/python/hidet/cuda/nccl/ffi.py @@ -50,11 +50,14 @@ def load_nccl_library(): nccl_library_path = lib_nccl_paths[0] break + load_nccl_library() + def nccl_library_filename(): return os.path.basename(nccl_library_path) + class NCCLRuntimeAPI: """ Runtime APIs regarding NCCL diff --git a/python/hidet/drivers/build_graph.py b/python/hidet/drivers/build_graph.py index 05b3b8987..5d83f5acf 100644 --- a/python/hidet/drivers/build_graph.py +++ b/python/hidet/drivers/build_graph.py @@ -22,7 +22,13 @@ from hidet.graph.tensor import Tensor from hidet.graph.flow_graph import FlowGraph from hidet.runtime.compiled_module import CompiledModule -from hidet.runtime.compiled_graph import CompiledGraph, GraphMetaData, GraphExecution, GraphExecutionInstruction, GraphDistributedInfo +from hidet.runtime.compiled_graph import ( + CompiledGraph, + GraphMetaData, + GraphExecution, + GraphExecutionInstruction, + GraphDistributedInfo, +) from hidet.runtime.compiled_task import CompiledTask, TensorSignature from hidet.graph.operator import Operator from hidet.ir import primitives @@ -141,13 +147,12 @@ def get_graph_meta_data(graph: FlowGraph, num_kernels, space: int) -> GraphMetaD inputs=inputs, outputs=outputs, hidet_version=hidet.__version__, num_kernels=num_kernels, graph_hash=graph_hash ) + def get_graph_dist_info(graph: FlowGraph) -> GraphDistributedInfo: if not graph.is_distributed(): return None - return GraphDistributedInfo( - nrank = graph._nrank, - rank = graph._rank, - groups = graph._groups) + return GraphDistributedInfo(nrank=graph.nrank, rank=graph.rank, groups=graph.groups) + def build_graph_module(graph: FlowGraph, graph_weights: List[Tensor], node2kernel: List[int]) -> CompiledModule: from hidet.lang import void_p, attrs, int32, int64, meta, cast @@ -347,7 +352,7 @@ def build_flow_graph(graph, *, space=0) -> CompiledGraph: compiled_tasks=graph_kernels, graph_execution=graph_execution, graph_string=str(graph), - dist_info=graph_dist_info + dist_info=graph_dist_info, ) # save the compiled graph to cache diff --git a/python/hidet/graph/flow_graph.py b/python/hidet/graph/flow_graph.py index cce662973..884a369b4 100644 --- a/python/hidet/graph/flow_graph.py +++ b/python/hidet/graph/flow_graph.py @@ -107,7 +107,15 @@ def forward_context() -> GraphForwardContext: class FlowGraph: """The computation graph representation.""" - def __init__(self, outputs: Sequence[Tensor], inputs: Optional[Sequence[Tensor]] = None, nodes=None, nrank=None, rank=None, groups=None): + def __init__( + self, + outputs: Sequence[Tensor], + inputs: Optional[Sequence[Tensor]] = None, + nodes=None, + nrank=None, + rank=None, + groups=None, + ): self.outputs: List[Tensor] = list(outputs) self.inputs: Optional[List[Tensor]] = list(inputs) if inputs is not None else None self._nodes: Optional[List[Operator]] = nodes @@ -115,9 +123,9 @@ def __init__(self, outputs: Sequence[Tensor], inputs: Optional[Sequence[Tensor]] self.update_nodes() # For distributed graphs - self._nrank = nrank - self._rank = rank - self._groups = groups + self.nrank = nrank + self.rank = rank + self.groups = groups def __call__(self, *inputs: Tensor) -> Union[List[Tensor], Tensor]: """ @@ -185,12 +193,12 @@ def _build_nodes(self): hidet.drivers.build_task_batch(tunable_tasks) # build tunable tasks one by one def is_distributed(self): - return self._nrank is not None or self._rank is not None + return self.nrank is not None or self.rank is not None def set_dist_attrs(self, nrank: int, rank: int, groups: Optional[List[List[int]]] = None): - self._nrank = nrank - self._rank = rank - self._groups = groups + self.nrank = nrank + self.rank = rank + self.groups = groups def forward(self, inputs: List[Tensor]) -> List[Tensor]: """Run the computation graph. diff --git a/python/hidet/graph/graph_utils/functors.py b/python/hidet/graph/graph_utils/functors.py index c81084b25..ce14a7d6a 100644 --- a/python/hidet/graph/graph_utils/functors.py +++ b/python/hidet/graph/graph_utils/functors.py @@ -119,7 +119,7 @@ def visit_Sequence(self, seq: Union[list, tuple]): class GraphCloneRewriter(GraphRewriter): def visit_FlowGraph(self, graph: FlowGraph): outputs = [self.visit(output) for output in graph.outputs] - return FlowGraph(outputs, graph.inputs, nrank=graph._nrank, rank=graph._rank, groups=graph._groups) + return FlowGraph(outputs, graph.inputs, nrank=graph.nrank, rank=graph.rank, groups=graph.groups) def visit_Operator(self, op: Operator): inputs = [self(x) for x in op.inputs] diff --git a/python/hidet/graph/ops/distributed.py b/python/hidet/graph/ops/distributed.py index 0cddc9d98..41a93a8b9 100644 --- a/python/hidet/graph/ops/distributed.py +++ b/python/hidet/graph/ops/distributed.py @@ -21,7 +21,7 @@ class AllReduceTask(Task): - def __init__(self, x: TensorNode, op: str, comm_id: int=0): + def __init__(self, x: TensorNode, op: str, comm_id: int = 0): y = compute('out', x.shape, lambda *indices: x[indices]) self.comm_id = comm_id self.op = op @@ -57,29 +57,33 @@ def __init__(self, x: Tensor, op: str, comm_id: int): inputs=[x], attributes={'op': op, 'comm_id': comm_id}, task=AllReduceTask(input_like(x, 'x'), op, comm_id) ) -def all_reduce(x: Tensor, op: str, comm_id: int=0) -> Tensor: + +def all_reduce(x: Tensor, op: str, comm_id: int = 0) -> Tensor: if x.device.kind != 'cuda': raise RuntimeError("NCCL only supports CUDA tensors") return AllReduceOp(x, op, comm_id).outputs[0] -def broadcast(x: Tensor, root: int, comm_id: int=0) -> Tensor: + +def broadcast(x: Tensor, root: int, comm_id: int = 0) -> Tensor: raise NotImplementedError() -def reduce(x: Tensor, root: int, op: str, comm_id: int=0) -> Tensor: + +def reduce(x: Tensor, root: int, op: str, comm_id: int = 0) -> Tensor: raise NotImplementedError() -def all_gather(x: Tensor, comm_id: int=0) -> Tensor: + +def all_gather(x: Tensor, comm_id: int = 0) -> Tensor: raise NotImplementedError() -def reduce_scatter(x: Tensor, op: str, comm_id: int=0) -> Tensor: +def reduce_scatter(x: Tensor, op: str, comm_id: int = 0) -> Tensor: raise NotImplementedError() -def send(x: Tensor, peer: int, comm_id: int=0) -> None: +def send(x: Tensor, peer: int, comm_id: int = 0) -> None: raise NotImplementedError() # Recv is a little bit tricky since we need to pass the metadata of the recv buffer -def recv(peer: int, comm_id: int=0) -> Tensor: +def recv(peer: int, comm_id: int = 0) -> Tensor: raise NotImplementedError() diff --git a/python/hidet/ir/func.py b/python/hidet/ir/func.py index fb7d9cc77..674879ec8 100644 --- a/python/hidet/ir/func.py +++ b/python/hidet/ir/func.py @@ -14,7 +14,7 @@ from hidet.ir.node import Node from hidet.ir.type import BaseType from hidet.ir.expr import Var, Call -from hidet.ir.stmt import Stmt, BlackBoxStmt +from hidet.ir.stmt import Stmt def check_func_name(name: str): @@ -94,4 +94,3 @@ def get_attr(self, attr_name, default=None, allow_missing=False): return default else: raise KeyError('Attribute {} is not found in function {}'.format(attr_name, self.name)) - diff --git a/python/hidet/runtime/compiled_graph.py b/python/hidet/runtime/compiled_graph.py index ae33bad54..2b26b61e4 100644 --- a/python/hidet/runtime/compiled_graph.py +++ b/python/hidet/runtime/compiled_graph.py @@ -64,6 +64,7 @@ class GraphExecution: outputs_index: List[int] tensor_device: List[str] + @dataclass class GraphDistributedInfo: nrank: int @@ -80,7 +81,7 @@ def __init__( compiled_tasks: List[CompiledTask], graph_execution: GraphExecution, graph_string: str, - dist_info: Optional[GraphDistributedInfo]=None + dist_info: Optional[GraphDistributedInfo] = None, ): from hidet.graph.tensor import Tensor @@ -112,7 +113,7 @@ def __init__( self.dispatch_table: Dict[Tuple[int, ...], Array] = {} self.cuda_workspace: Optional[Storage] = None self.cpu_workspace: Optional[Storage] = None - + # distributed properties self.dist_info: Optional[GraphDistributedInfo] = dist_info self.nccl_comms: List[NcclCommunicator] = [] @@ -201,9 +202,6 @@ def init_dist(self, unique_id: NcclUniqueId): comm = default_comm.split(key, color) if in_group: self.nccl_comms.append(comm) - - print(self.nccl_comms) - def _update_symbol_table(self, symbol_dims: Tuple[int, ...], best_candidates: List[int]): kernel_array = Array(void_p, len(self.compiled_tasks)) @@ -403,7 +401,7 @@ def _save_under(dir_path: str, dir_in_zip: str, exclude: Optional[List[str]] = N # save graph string with zf.open('graph_string.txt', 'w') as f: f.write(model.graph_string.encode('utf-8')) - + # save distibuted information if model.dist_info is not None: with zf.open('dist_info.json', 'w') as f: @@ -462,6 +460,8 @@ def load_compiled_graph(path: str) -> CompiledGraph: graph_string = f.read() # construct the compiled graph - ret = CompiledGraph(meta_data, graph_module, weights, compiled_tasks, graph_execution, graph_string, dist_info=dist_info) + ret = CompiledGraph( + meta_data, graph_module, weights, compiled_tasks, graph_execution, graph_string, dist_info=dist_info + ) return ret diff --git a/python/hidet/transforms/__init__.py b/python/hidet/transforms/__init__.py index ce6a48bf9..5d40bab05 100644 --- a/python/hidet/transforms/__init__.py +++ b/python/hidet/transforms/__init__.py @@ -38,6 +38,7 @@ from .lower_special_cast import lower_special_cast_pass from .include_nccl import include_nccl_pass + def lower_with(ir_module: IRModule, transforms: Sequence[Pass]) -> IRModule: ctx = PassContext.current() for instrument in ctx.instruments: @@ -80,6 +81,6 @@ def lower(ir_module: IRModule) -> IRModule: rule_based_simplify_pass(), inline_let_stmt_pass(), simplify_stmt_pass(), - include_nccl_pass() + include_nccl_pass(), ] return lower_with(ir_module, transforms) diff --git a/python/hidet/transforms/include_nccl.py b/python/hidet/transforms/include_nccl.py new file mode 100644 index 000000000..b0dcb3640 --- /dev/null +++ b/python/hidet/transforms/include_nccl.py @@ -0,0 +1,59 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from hidet.ir.module import IRModule +from hidet.ir import Stmt +from hidet.ir.stmt import BlackBoxStmt +from hidet.transforms import Pass + + +def _use_distributed(func) -> bool: + def _recursive_find(root: Stmt): + if isinstance(root, BlackBoxStmt): + if root.template_string.startswith('nccl'): + return True + for child in dir(root): + if isinstance(child, Stmt): + if _recursive_find(child): + return True + return False + + ret = _recursive_find(func.body) + return ret + + +class IncludeNCCLPass(Pass): + def process_module(self, ir_module: IRModule) -> IRModule: + use_dist = False + for func in ir_module.functions.values(): + if _use_distributed(func): + use_dist = True + break + + if not use_dist: + return ir_module + + from hidet.cuda.nccl.libinfo import get_nccl_include_dirs, get_nccl_library_search_dirs + from hidet.cuda.nccl import nccl_available, nccl_library_filename + + if not nccl_available(): + raise RuntimeError("NCCL is not available") + + new_module = ir_module.copy() + new_module.include_dirs.extend(get_nccl_include_dirs()) + new_module.linking_dirs.extend(get_nccl_library_search_dirs()) + new_module.include_headers.append(["nccl.h"]) + new_module.linking_libs.append(":" + nccl_library_filename()) + return new_module + + +def include_nccl_pass(): + return IncludeNCCLPass() From 7dd55c450303f3dcd2d926049116e8252a98ee79 Mon Sep 17 00:00:00 2001 From: Qidong <soodoshll@gmail.com> Date: Wed, 21 Jun 2023 20:22:07 -0400 Subject: [PATCH 11/40] update --- examples/distributed/test.py | 5 +++++ python/hidet/graph/graph_utils/functors.py | 2 +- python/hidet/runtime/compiled_graph.py | 2 +- 3 files changed, 7 insertions(+), 2 deletions(-) diff --git a/examples/distributed/test.py b/examples/distributed/test.py index 37c45686c..36e8fa695 100644 --- a/examples/distributed/test.py +++ b/examples/distributed/test.py @@ -64,6 +64,11 @@ def run(world_size, rank, shared_id, barrier): opt_graph.set_dist_attrs(nrank=world_size, rank=rank, groups=groups) compiled = opt_graph.build() + # test save and load + compiled_dir = f"./outs/graph_{rank}.zip" + compiled.save(compiled_dir) + compiled = hidet.runtime.load_compiled_graph(compiled_dir) + # Create Distributed Graph compiled.init_dist(shared_id) diff --git a/python/hidet/graph/graph_utils/functors.py b/python/hidet/graph/graph_utils/functors.py index ce14a7d6a..e23be6a12 100644 --- a/python/hidet/graph/graph_utils/functors.py +++ b/python/hidet/graph/graph_utils/functors.py @@ -119,7 +119,7 @@ def visit_Sequence(self, seq: Union[list, tuple]): class GraphCloneRewriter(GraphRewriter): def visit_FlowGraph(self, graph: FlowGraph): outputs = [self.visit(output) for output in graph.outputs] - return FlowGraph(outputs, graph.inputs, nrank=graph.nrank, rank=graph.rank, groups=graph.groups) + return FlowGraph(outputs, graph.inputs) def visit_Operator(self, op: Operator): inputs = [self(x) for x in op.inputs] diff --git a/python/hidet/runtime/compiled_graph.py b/python/hidet/runtime/compiled_graph.py index 2b26b61e4..0a38cf854 100644 --- a/python/hidet/runtime/compiled_graph.py +++ b/python/hidet/runtime/compiled_graph.py @@ -425,7 +425,7 @@ def load_compiled_graph(path: str) -> CompiledGraph: graph_execution: GraphExecution = from_dict(GraphExecution, json.load(f)) # load dist info - if zipfile.Path(zf, 'dist_info.json').exists(): + if 'dist_info.json' in zf.namelist(): with zf.open('dist_info.json', 'r') as f: dist_info: GraphDistributedInfo = from_dict(GraphDistributedInfo, json.load(f)) From 0c57cff10f9f30006a55d1adb36c8dd7a9dd5eec Mon Sep 17 00:00:00 2001 From: Qidong <soodoshll@gmail.com> Date: Wed, 21 Jun 2023 20:27:09 -0400 Subject: [PATCH 12/40] relaunch test --- examples/distributed/test.py | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/distributed/test.py b/examples/distributed/test.py index 36e8fa695..b3d12d3ba 100644 --- a/examples/distributed/test.py +++ b/examples/distributed/test.py @@ -20,7 +20,6 @@ import hidet.cuda.nccl from hidet.cuda import nccl from hidet.cuda.nccl import NcclUniqueId -from hidet.runtime.compiled_graph import GraphDistributedInfo print("NCCL version:", nccl.nccl_version()) From 047ea87f79b88d08b4fbb4c5e60165f1830ca68a Mon Sep 17 00:00:00 2001 From: Qidong <soodoshll@gmail.com> Date: Wed, 21 Jun 2023 20:55:21 -0400 Subject: [PATCH 13/40] update --- python/hidet/cuda/nccl/ffi.py | 93 +++++++++++++++++------------------ 1 file changed, 46 insertions(+), 47 deletions(-) diff --git a/python/hidet/cuda/nccl/ffi.py b/python/hidet/cuda/nccl/ffi.py index cb279b82a..45fc585ae 100644 --- a/python/hidet/cuda/nccl/ffi.py +++ b/python/hidet/cuda/nccl/ffi.py @@ -58,55 +58,54 @@ def nccl_library_filename(): return os.path.basename(nccl_library_path) -class NCCLRuntimeAPI: - """ - Runtime APIs regarding NCCL - TODO: Exception handling - """ - - _get_version = get_func('ncclGetVersion', [c_void_p], c_int, lib=_LIB_NCCL) - _get_unique_id = get_func('ncclGetUniqueId', [c_void_p], c_int, lib=_LIB_NCCL) - _comm_init_rank = get_func('ncclCommInitRank', [c_void_p, c_int, NcclUniqueId, c_int], c_int, lib=_LIB_NCCL) - _comm_destroy = get_func('ncclCommDestroy', [c_void_p], c_int, lib=_LIB_NCCL) - - _comm_user_rank = get_func('ncclCommUserRank', [c_void_p, POINTER(c_int)], c_int, lib=_LIB_NCCL) - _comm_count = get_func('ncclCommCount', [c_void_p, POINTER(c_int)], c_int, lib=_LIB_NCCL) - - _comm_split = get_func('ncclCommSplit', [c_void_p, c_int, c_int, c_void_p, c_void_p], c_int, lib=_LIB_NCCL) - - @staticmethod - def get_version() -> int: - version = c_int(0) - NCCLRuntimeAPI._get_version(pointer(version)) - return version.value - - @staticmethod - def get_unique_id(comm_id: NcclUniqueId) -> None: +if nccl_available(): + class NCCLRuntimeAPI: """ - In-place initialization of the NcclUniqueId object + Runtime APIs regarding NCCL + TODO: Exception handling """ - ret = NCCLRuntimeAPI._get_unique_id(pointer(comm_id)) - assert ret == 0, ret - - @staticmethod - def comm_init_rank(ndev: int, comm_id: NcclUniqueId, rank: int) -> int: - comm = c_void_p() - ret = NCCLRuntimeAPI._comm_init_rank(pointer(comm), ndev, comm_id, rank) - assert ret == 0, ret - return comm.value - @staticmethod - def comm_destroy(comm_handle) -> None: - ret = NCCLRuntimeAPI._comm_destroy(comm_handle) - assert ret == 0 + _get_version = get_func('ncclGetVersion', [c_void_p], c_int, lib=_LIB_NCCL) + _get_unique_id = get_func('ncclGetUniqueId', [c_void_p], c_int, lib=_LIB_NCCL) + _comm_init_rank = get_func('ncclCommInitRank', [c_void_p, c_int, NcclUniqueId, c_int], c_int, lib=_LIB_NCCL) + _comm_destroy = get_func('ncclCommDestroy', [c_void_p], c_int, lib=_LIB_NCCL) + + _comm_user_rank = get_func('ncclCommUserRank', [c_void_p, POINTER(c_int)], c_int, lib=_LIB_NCCL) + _comm_count = get_func('ncclCommCount', [c_void_p, POINTER(c_int)], c_int, lib=_LIB_NCCL) + + _comm_split = get_func('ncclCommSplit', [c_void_p, c_int, c_int, c_void_p, c_void_p], c_int, lib=_LIB_NCCL) + + @staticmethod + def get_version() -> int: + version = c_int(0) + NCCLRuntimeAPI._get_version(pointer(version)) + return version.value + + @staticmethod + def get_unique_id(comm_id: NcclUniqueId) -> None: + """ + In-place initialization of the NcclUniqueId object + """ + ret = NCCLRuntimeAPI._get_unique_id(pointer(comm_id)) + assert ret == 0, ret + + @staticmethod + def comm_init_rank(ndev: int, comm_id: NcclUniqueId, rank: int) -> int: + comm = c_void_p() + ret = NCCLRuntimeAPI._comm_init_rank(pointer(comm), ndev, comm_id, rank) + assert ret == 0, ret + return comm.value + + @staticmethod + def comm_destroy(comm_handle) -> None: + ret = NCCLRuntimeAPI._comm_destroy(comm_handle) + assert ret == 0 + + @staticmethod + def comm_split(comm_handle: int, color: int, key: int) -> int: + comm = c_void_p() + ret = NCCLRuntimeAPI._comm_split(comm_handle, color, key, pointer(comm), None) + assert ret == 0 + return comm.value - @staticmethod - def comm_split(comm_handle: int, color: int, key: int) -> int: - comm = c_void_p() - ret = NCCLRuntimeAPI._comm_split(comm_handle, color, key, pointer(comm), None) - assert ret == 0 - return comm.value - - -if nccl_available(): nccl_runtime_api = NCCLRuntimeAPI() From dba85a51eb688d272f89d5fbb9732a178a7bd457 Mon Sep 17 00:00:00 2001 From: Qidong <soodoshll@gmail.com> Date: Wed, 21 Jun 2023 21:09:45 -0400 Subject: [PATCH 14/40] fix --- python/hidet/cuda/nccl/ffi.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/python/hidet/cuda/nccl/ffi.py b/python/hidet/cuda/nccl/ffi.py index 45fc585ae..62a2d2a98 100644 --- a/python/hidet/cuda/nccl/ffi.py +++ b/python/hidet/cuda/nccl/ffi.py @@ -73,7 +73,11 @@ class NCCLRuntimeAPI: _comm_user_rank = get_func('ncclCommUserRank', [c_void_p, POINTER(c_int)], c_int, lib=_LIB_NCCL) _comm_count = get_func('ncclCommCount', [c_void_p, POINTER(c_int)], c_int, lib=_LIB_NCCL) - _comm_split = get_func('ncclCommSplit', [c_void_p, c_int, c_int, c_void_p, c_void_p], c_int, lib=_LIB_NCCL) + # Early versions of NCCL do not have split + try: + _comm_split = get_func('ncclCommSplit', [c_void_p, c_int, c_int, c_void_p, c_void_p], c_int, lib=_LIB_NCCL) + except ValueError: + _comm_split = None @staticmethod def get_version() -> int: @@ -103,6 +107,8 @@ def comm_destroy(comm_handle) -> None: @staticmethod def comm_split(comm_handle: int, color: int, key: int) -> int: + if NCCLRuntimeAPI._comm_split is None: + raise RuntimeError("split is not supported on this version of NCCL. Please install a newer version.") comm = c_void_p() ret = NCCLRuntimeAPI._comm_split(comm_handle, color, key, pointer(comm), None) assert ret == 0 From 2c6e5b13b860b8e15fd41847e23615d26c544042 Mon Sep 17 00:00:00 2001 From: Qidong <soodoshll@gmail.com> Date: Wed, 21 Jun 2023 21:11:05 -0400 Subject: [PATCH 15/40] format --- python/hidet/cuda/nccl/ffi.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/hidet/cuda/nccl/ffi.py b/python/hidet/cuda/nccl/ffi.py index 62a2d2a98..13c941702 100644 --- a/python/hidet/cuda/nccl/ffi.py +++ b/python/hidet/cuda/nccl/ffi.py @@ -59,6 +59,7 @@ def nccl_library_filename(): if nccl_available(): + class NCCLRuntimeAPI: """ Runtime APIs regarding NCCL From 5d51ed4533c05ce33dcef452aa78415fea19f449 Mon Sep 17 00:00:00 2001 From: Qidong <soodoshll@gmail.com> Date: Wed, 21 Jun 2023 22:30:08 -0400 Subject: [PATCH 16/40] fix --- python/hidet/runtime/compiled_graph.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/hidet/runtime/compiled_graph.py b/python/hidet/runtime/compiled_graph.py index 0a38cf854..c11dc6500 100644 --- a/python/hidet/runtime/compiled_graph.py +++ b/python/hidet/runtime/compiled_graph.py @@ -428,6 +428,8 @@ def load_compiled_graph(path: str) -> CompiledGraph: if 'dist_info.json' in zf.namelist(): with zf.open('dist_info.json', 'r') as f: dist_info: GraphDistributedInfo = from_dict(GraphDistributedInfo, json.load(f)) + else: + dist_info = None # load weights as numpy arrays with zf.open('weights.npz', 'r') as f: From f4bf8653b9a15cf188e6b23a634f12e3eea6883b Mon Sep 17 00:00:00 2001 From: Qidong Su <soodoshll@gmail.com> Date: Thu, 22 Jun 2023 15:45:09 -0400 Subject: [PATCH 17/40] [Document] fix installation guide (#288) Merely assigning environment variables is insufficient for setting up dev environment now. We need to run pip to install hidet package in develop mode. Users still need to build source files written in C++ manually. Consider integrating that into `setup.py` in the future? --- .../source/getting-started/build-from-source.rst | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/docs/source/getting-started/build-from-source.rst b/docs/source/getting-started/build-from-source.rst index 147729248..8f5a6529c 100644 --- a/docs/source/getting-started/build-from-source.rst +++ b/docs/source/getting-started/build-from-source.rst @@ -2,7 +2,7 @@ Build from source ------------------- .. _Build-from-source: -If you want to contribute to Hidet, or you encountered any problem installing hidet via pip, it is better to install +If you want to contribute to Hidet, or you encountered any problem directly installing hidet via pip, it is better to install hidet from source. Clone the code @@ -32,21 +32,15 @@ shared library: After building, you could find two libraries ``libhidet.so`` and ``libhidet_runtime.so`` under ``build/lib`` directory. -Update environment variables +Install the Hidet Python package ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -To allow Python interpreter to find hidet package under ``python`` directory of the repository, we should append the -directory to ``PYTHONPATH`` variable. To allow the system find the shared libraries we built in the previous step, -we should append ``build/lib`` directory to ``LD_LIBRARY_PATH`` variable. +Next we will install the Python package of Hidet in the develop mode via pip: .. code-block:: console - $ export HIDET_HOME=<The Path to Hidet Repo> - $ export PYTHONPATH=$PYTHONPATH:$HIDET_HOME/python - $ export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$HIDET_HOME/build/lib - -To avoid repeating above commands, it is recommended to put above commands to your shell's initialization script -(e.g., ``~/.bashrc`` for Bash and ``~/.zshrc`` for Zsh). + $ cd .. # return to the root directory of Hidet + $ pip install -e . Validation ~~~~~~~~~~ From 64b9f0356835bc400225a03aea5bae89815ccab6 Mon Sep 17 00:00:00 2001 From: Hanjie <50634613+hjjq@users.noreply.github.com> Date: Thu, 22 Jun 2023 15:45:30 -0400 Subject: [PATCH 18/40] [Runtime] Check for input tensor device (#287) --- python/hidet/runtime/compiled_task.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/python/hidet/runtime/compiled_task.py b/python/hidet/runtime/compiled_task.py index fbb009996..eda490f19 100644 --- a/python/hidet/runtime/compiled_task.py +++ b/python/hidet/runtime/compiled_task.py @@ -190,6 +190,10 @@ def _check_inputs(traced_inputs: Iterable[TensorSignature], inputs): symbol_map = {} for i, (traced, new) in enumerate(zip(traced_inputs, inputs)): + if traced.device.partition(':')[0] != new.device.kind: + raise RuntimeError( + f"device mismatch at arg {i} between original: {traced.device} and new: {new.device.kind}" + ) if ir.data_type(traced.dtype) != new.dtype: raise RuntimeError(f"dtype mismatch at arg {i} between original: {traced.dtype} and new: {new.dtype}") traced_shape = traced.shape From 57ae2a949f4cec9ff9744ac86b1b1f48977be3f4 Mon Sep 17 00:00:00 2001 From: Qidong <soodoshll@gmail.com> Date: Thu, 22 Jun 2023 23:00:05 -0400 Subject: [PATCH 19/40] fix --- examples/distributed/test.py | 5 ++- python/hidet/cuda/nccl/__init__.py | 2 +- python/hidet/cuda/nccl/comm.py | 7 +-- python/hidet/drivers/build_graph.py | 15 +------ python/hidet/graph/flow_graph.py | 26 +++++------ python/hidet/graph/ops/distributed.py | 4 +- python/hidet/runtime/compiled_graph.py | 30 ++----------- python/hidet/transforms/__init__.py | 4 +- python/hidet/transforms/include_nccl.py | 59 ------------------------- 9 files changed, 27 insertions(+), 125 deletions(-) delete mode 100644 python/hidet/transforms/include_nccl.py diff --git a/examples/distributed/test.py b/examples/distributed/test.py index b3d12d3ba..3104c4a6e 100644 --- a/examples/distributed/test.py +++ b/examples/distributed/test.py @@ -34,7 +34,8 @@ def run(world_size, rank, shared_id, barrier): # Initialize unique id if rank == 0: - nccl.init_unique_id(shared_id) + _id = nccl.create_unique_id() + shared_id.internal = _id.internal barrier.wait() hidet.cuda.set_device(rank) @@ -60,7 +61,7 @@ def run(world_size, rank, shared_id, barrier): y_sync = hidet.ops.all_reduce(y_local, args.reduce_op, comm_id=int(use_group)) graph = hidet.trace_from([y_local, y_sync], inputs=[x_symb, w_symb]) opt_graph = hidet.graph.optimize(graph) - opt_graph.set_dist_attrs(nrank=world_size, rank=rank, groups=groups) + opt_graph.set_attrs(nrank=world_size, rank=rank, groups=groups) compiled = opt_graph.build() # test save and load diff --git a/python/hidet/cuda/nccl/__init__.py b/python/hidet/cuda/nccl/__init__.py index 4c7a0090a..ec8d1dfed 100644 --- a/python/hidet/cuda/nccl/__init__.py +++ b/python/hidet/cuda/nccl/__init__.py @@ -16,7 +16,7 @@ NcclDataType, NcclRedOp, comms_to_array, - init_unique_id, + create_unique_id, dtype_to_nccl, NcclCommunicator, str_to_nccl_op, diff --git a/python/hidet/cuda/nccl/comm.py b/python/hidet/cuda/nccl/comm.py index 48e232787..0188c6373 100644 --- a/python/hidet/cuda/nccl/comm.py +++ b/python/hidet/cuda/nccl/comm.py @@ -93,11 +93,12 @@ def comms_to_array(comms: List[NcclCommunicator]) -> Array: return array -def init_unique_id(unqie_id: NcclUniqueId) -> None: +def create_unique_id() -> NcclUniqueId: if not nccl_available(): raise RuntimeError("NCCL is not available") - nccl_runtime_api.get_unique_id(unqie_id) - + unique_id = NcclUniqueId() + nccl_runtime_api.get_unique_id(unique_id) + return unique_id def dtype_to_nccl(dtype: DataType) -> NcclDataType: sname_dict = { diff --git a/python/hidet/drivers/build_graph.py b/python/hidet/drivers/build_graph.py index 5d83f5acf..673704368 100644 --- a/python/hidet/drivers/build_graph.py +++ b/python/hidet/drivers/build_graph.py @@ -27,7 +27,6 @@ GraphMetaData, GraphExecution, GraphExecutionInstruction, - GraphDistributedInfo, ) from hidet.runtime.compiled_task import CompiledTask, TensorSignature from hidet.graph.operator import Operator @@ -144,16 +143,10 @@ def get_graph_meta_data(graph: FlowGraph, num_kernels, space: int) -> GraphMetaD graph_hash = sha256('\n'.join(lines).encode('utf-8')).hexdigest()[:16] return GraphMetaData( - inputs=inputs, outputs=outputs, hidet_version=hidet.__version__, num_kernels=num_kernels, graph_hash=graph_hash + inputs=inputs, outputs=outputs, hidet_version=hidet.__version__, num_kernels=num_kernels, graph_hash=graph_hash, + attrs=asdict(graph.attrs) ) - -def get_graph_dist_info(graph: FlowGraph) -> GraphDistributedInfo: - if not graph.is_distributed(): - return None - return GraphDistributedInfo(nrank=graph.nrank, rank=graph.rank, groups=graph.groups) - - def build_graph_module(graph: FlowGraph, graph_weights: List[Tensor], node2kernel: List[int]) -> CompiledModule: from hidet.lang import void_p, attrs, int32, int64, meta, cast from hidet.ir.primitives.runtime import memory_planner_init, memory_planner_allocate, memory_planner_free @@ -341,9 +334,6 @@ def build_flow_graph(graph, *, space=0) -> CompiledGraph: # get the graph meta data graph_meta_data = get_graph_meta_data(graph, len(graph_kernels), space) - # get distributed information - graph_dist_info = get_graph_dist_info(graph) - # build the compiled graph compiled_graph = CompiledGraph( meta=graph_meta_data, @@ -352,7 +342,6 @@ def build_flow_graph(graph, *, space=0) -> CompiledGraph: compiled_tasks=graph_kernels, graph_execution=graph_execution, graph_string=str(graph), - dist_info=graph_dist_info, ) # save the compiled graph to cache diff --git a/python/hidet/graph/flow_graph.py b/python/hidet/graph/flow_graph.py index 884a369b4..28b589403 100644 --- a/python/hidet/graph/flow_graph.py +++ b/python/hidet/graph/flow_graph.py @@ -16,6 +16,7 @@ import os import pickle from collections import defaultdict +from dataclasses import dataclass, field import hidet.graph.operator import hidet.cuda @@ -103,6 +104,11 @@ def benchmark(self, output_dir='./outs/benchmark', print_summary: bool = False, def forward_context() -> GraphForwardContext: return GraphForwardContext() +@dataclass +class FlowGraphAttrs: + nrank: int = 0 + rank: int = 0 + groups: List[List[int]] = field(default_factory=list) class FlowGraph: """The computation graph representation.""" @@ -112,9 +118,7 @@ def __init__( outputs: Sequence[Tensor], inputs: Optional[Sequence[Tensor]] = None, nodes=None, - nrank=None, - rank=None, - groups=None, + attrs: Optional[FlowGraphAttrs] = None ): self.outputs: List[Tensor] = list(outputs) self.inputs: Optional[List[Tensor]] = list(inputs) if inputs is not None else None @@ -122,10 +126,7 @@ def __init__( self._usage_count: Optional[Dict[Tensor, int]] = None self.update_nodes() - # For distributed graphs - self.nrank = nrank - self.rank = rank - self.groups = groups + self.attrs: FlowGraphAttrs = attrs if attrs else FlowGraphAttrs() def __call__(self, *inputs: Tensor) -> Union[List[Tensor], Tensor]: """ @@ -192,13 +193,8 @@ def _build_nodes(self): hidet.option.parallel_build(False) hidet.drivers.build_task_batch(tunable_tasks) # build tunable tasks one by one - def is_distributed(self): - return self.nrank is not None or self.rank is not None - - def set_dist_attrs(self, nrank: int, rank: int, groups: Optional[List[List[int]]] = None): - self.nrank = nrank - self.rank = rank - self.groups = groups + def set_attrs(self, *args, **kwargs): + self.attrs = FlowGraphAttrs(*args, **kwargs) def forward(self, inputs: List[Tensor]) -> List[Tensor]: """Run the computation graph. @@ -214,8 +210,6 @@ def forward(self, inputs: List[Tensor]) -> List[Tensor]: output: List[Tensor] The output tensors of the computation graph. """ - if self.is_distributed(): - raise RuntimeError("Running Distributed FlowGraph is not supported. Please compile it first.") from hidet.ffi import runtime_api diff --git a/python/hidet/graph/ops/distributed.py b/python/hidet/graph/ops/distributed.py index 41a93a8b9..a72e7a3a1 100644 --- a/python/hidet/graph/ops/distributed.py +++ b/python/hidet/graph/ops/distributed.py @@ -26,10 +26,10 @@ def __init__(self, x: TensorNode, op: str, comm_id: int = 0): self.comm_id = comm_id self.op = op - super().__init__('all_reduce', inputs=[x], outputs=[y], attributes={}) + super().__init__('all_reduce', inputs=[x], outputs=[y], attributes={'comm_id': comm_id, 'op': op}) def __str__(self): - return f"all_reduce_{self.op}_{self.comm_id}" + return f"all_reduce" def implement(self, target: Union[Target, str], working_dir: str) -> List[IRModule]: # we may need current rank here to avoid duplicated working_dirs diff --git a/python/hidet/runtime/compiled_graph.py b/python/hidet/runtime/compiled_graph.py index c11dc6500..161502556 100644 --- a/python/hidet/runtime/compiled_graph.py +++ b/python/hidet/runtime/compiled_graph.py @@ -9,7 +9,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional, Tuple, Dict, Any, Callable +from typing import List, Optional, Tuple, Dict, Any, Callable, Union import zipfile import os import json @@ -30,7 +30,6 @@ from hidet.utils import prod from hidet.cuda.nccl import NcclCommunicator, NcclUniqueId, create_comm, NCCL_SPLIT_NOCOLOR, comms_to_array - ModelExecutionHook = Callable[[int, List['Tensor'], List['Tensor']], None] @@ -46,6 +45,7 @@ class GraphMetaData: hidet_version: str num_kernels: int graph_hash: str + attrs: Dict[str, Union[int, List[List[int]]]] @dataclass @@ -64,14 +64,6 @@ class GraphExecution: outputs_index: List[int] tensor_device: List[str] - -@dataclass -class GraphDistributedInfo: - nrank: int - rank: int - groups: List[List[int]] - - class CompiledGraph: def __init__( self, @@ -81,7 +73,6 @@ def __init__( compiled_tasks: List[CompiledTask], graph_execution: GraphExecution, graph_string: str, - dist_info: Optional[GraphDistributedInfo] = None, ): from hidet.graph.tensor import Tensor @@ -113,9 +104,6 @@ def __init__( self.dispatch_table: Dict[Tuple[int, ...], Array] = {} self.cuda_workspace: Optional[Storage] = None self.cpu_workspace: Optional[Storage] = None - - # distributed properties - self.dist_info: Optional[GraphDistributedInfo] = dist_info self.nccl_comms: List[NcclCommunicator] = [] self._init_compiled_graph() @@ -402,11 +390,6 @@ def _save_under(dir_path: str, dir_in_zip: str, exclude: Optional[List[str]] = N with zf.open('graph_string.txt', 'w') as f: f.write(model.graph_string.encode('utf-8')) - # save distibuted information - if model.dist_info is not None: - with zf.open('dist_info.json', 'w') as f: - dist_info_bytes = json.dumps(asdict(model.dist_info), indent=4).encode('utf-8') - f.write(dist_info_bytes) def load_compiled_graph(path: str) -> CompiledGraph: @@ -424,13 +407,6 @@ def load_compiled_graph(path: str) -> CompiledGraph: with zf.open('graph_execution.json', 'r') as f: graph_execution: GraphExecution = from_dict(GraphExecution, json.load(f)) - # load dist info - if 'dist_info.json' in zf.namelist(): - with zf.open('dist_info.json', 'r') as f: - dist_info: GraphDistributedInfo = from_dict(GraphDistributedInfo, json.load(f)) - else: - dist_info = None - # load weights as numpy arrays with zf.open('weights.npz', 'r') as f: with zipfile.ZipFile(f, 'r') as npz: @@ -463,7 +439,7 @@ def load_compiled_graph(path: str) -> CompiledGraph: # construct the compiled graph ret = CompiledGraph( - meta_data, graph_module, weights, compiled_tasks, graph_execution, graph_string, dist_info=dist_info + meta_data, graph_module, weights, compiled_tasks, graph_execution, graph_string ) return ret diff --git a/python/hidet/transforms/__init__.py b/python/hidet/transforms/__init__.py index 5d40bab05..0662ce207 100644 --- a/python/hidet/transforms/__init__.py +++ b/python/hidet/transforms/__init__.py @@ -36,7 +36,7 @@ from .propagate_launch_bound import propagate_launch_bound_pass from .check_launch_configuration import check_launch_configuration_pass from .lower_special_cast import lower_special_cast_pass -from .include_nccl import include_nccl_pass +from .annotate_include_headers import annotate_include_headers_pass def lower_with(ir_module: IRModule, transforms: Sequence[Pass]) -> IRModule: @@ -81,6 +81,6 @@ def lower(ir_module: IRModule) -> IRModule: rule_based_simplify_pass(), inline_let_stmt_pass(), simplify_stmt_pass(), - include_nccl_pass(), + annotate_include_headers_pass(), ] return lower_with(ir_module, transforms) diff --git a/python/hidet/transforms/include_nccl.py b/python/hidet/transforms/include_nccl.py deleted file mode 100644 index b0dcb3640..000000000 --- a/python/hidet/transforms/include_nccl.py +++ /dev/null @@ -1,59 +0,0 @@ -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from hidet.ir.module import IRModule -from hidet.ir import Stmt -from hidet.ir.stmt import BlackBoxStmt -from hidet.transforms import Pass - - -def _use_distributed(func) -> bool: - def _recursive_find(root: Stmt): - if isinstance(root, BlackBoxStmt): - if root.template_string.startswith('nccl'): - return True - for child in dir(root): - if isinstance(child, Stmt): - if _recursive_find(child): - return True - return False - - ret = _recursive_find(func.body) - return ret - - -class IncludeNCCLPass(Pass): - def process_module(self, ir_module: IRModule) -> IRModule: - use_dist = False - for func in ir_module.functions.values(): - if _use_distributed(func): - use_dist = True - break - - if not use_dist: - return ir_module - - from hidet.cuda.nccl.libinfo import get_nccl_include_dirs, get_nccl_library_search_dirs - from hidet.cuda.nccl import nccl_available, nccl_library_filename - - if not nccl_available(): - raise RuntimeError("NCCL is not available") - - new_module = ir_module.copy() - new_module.include_dirs.extend(get_nccl_include_dirs()) - new_module.linking_dirs.extend(get_nccl_library_search_dirs()) - new_module.include_headers.append(["nccl.h"]) - new_module.linking_libs.append(":" + nccl_library_filename()) - return new_module - - -def include_nccl_pass(): - return IncludeNCCLPass() From a3d0a71d4873cbed227f924b897d13d89d6d8c41 Mon Sep 17 00:00:00 2001 From: Qidong <soodoshll@gmail.com> Date: Thu, 22 Jun 2023 23:00:23 -0400 Subject: [PATCH 20/40] fix --- .../transforms/annotate_include_headers.py | 43 +++++++++++++++++++ 1 file changed, 43 insertions(+) create mode 100644 python/hidet/transforms/annotate_include_headers.py diff --git a/python/hidet/transforms/annotate_include_headers.py b/python/hidet/transforms/annotate_include_headers.py new file mode 100644 index 000000000..20b68b9ee --- /dev/null +++ b/python/hidet/transforms/annotate_include_headers.py @@ -0,0 +1,43 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import hidet.ir +from hidet.ir.module import IRModule +from hidet.ir import Stmt +from hidet.ir.stmt import BlackBoxStmt +from hidet.transforms import Pass + +def _use_distributed(func) -> bool: + black_stmts = hidet.ir.tools.collect(func.body, [BlackBoxStmt]) + return any(stmt.template_string.startswith('nccl') for stmt in black_stmts) + +class AnnotateIncludeHeadersPass(Pass): + def process_module(self, ir_module: IRModule) -> IRModule: + use_dist = any(_use_distributed(func) for func in ir_module.functions.values()) + if not use_dist: + return ir_module + + from hidet.cuda.nccl.libinfo import get_nccl_include_dirs, get_nccl_library_search_dirs + from hidet.cuda.nccl import nccl_available, nccl_library_filename + + if not nccl_available(): + raise RuntimeError("NCCL is not available") + + new_module = ir_module.copy() + new_module.include_dirs.extend(get_nccl_include_dirs()) + new_module.linking_dirs.extend(get_nccl_library_search_dirs()) + new_module.include_headers.append(["nccl.h"]) + new_module.linking_libs.append(":" + nccl_library_filename()) + return new_module + + +def annotate_include_headers_pass(): + return AnnotateIncludeHeadersPass() From 2ffcfe39e85fb7c3c11b821231229fab9ed50c8f Mon Sep 17 00:00:00 2001 From: Qidong <soodoshll@gmail.com> Date: Thu, 22 Jun 2023 23:14:08 -0400 Subject: [PATCH 21/40] fix --- python/hidet/runtime/compiled_graph.py | 4 +- python/hidet/transforms/__init__.py | 4 +- .../transforms/annotate_include_headers.py | 43 ------------------- 3 files changed, 4 insertions(+), 47 deletions(-) delete mode 100644 python/hidet/transforms/annotate_include_headers.py diff --git a/python/hidet/runtime/compiled_graph.py b/python/hidet/runtime/compiled_graph.py index 161502556..9ad761310 100644 --- a/python/hidet/runtime/compiled_graph.py +++ b/python/hidet/runtime/compiled_graph.py @@ -176,9 +176,9 @@ def init_dist(self, unique_id: NcclUniqueId): self.nccl_comms = [] # Initialize the default group - nrank = self.dist_info.nrank + nranks = self.dist_info.nrank rank = self.dist_info.rank - default_comm = create_comm(nrank, unique_id, rank) + default_comm = create_comm(nranks, unique_id, rank) self.nccl_comms.append(default_comm) # Create communicators according to groups diff --git a/python/hidet/transforms/__init__.py b/python/hidet/transforms/__init__.py index 0662ce207..6d21b0f6f 100644 --- a/python/hidet/transforms/__init__.py +++ b/python/hidet/transforms/__init__.py @@ -36,7 +36,7 @@ from .propagate_launch_bound import propagate_launch_bound_pass from .check_launch_configuration import check_launch_configuration_pass from .lower_special_cast import lower_special_cast_pass -from .annotate_include_headers import annotate_include_headers_pass +from .annotate_header_and_libs import annotate_header_and_libs_pass def lower_with(ir_module: IRModule, transforms: Sequence[Pass]) -> IRModule: @@ -81,6 +81,6 @@ def lower(ir_module: IRModule) -> IRModule: rule_based_simplify_pass(), inline_let_stmt_pass(), simplify_stmt_pass(), - annotate_include_headers_pass(), + annotate_header_and_libs_pass(), ] return lower_with(ir_module, transforms) diff --git a/python/hidet/transforms/annotate_include_headers.py b/python/hidet/transforms/annotate_include_headers.py deleted file mode 100644 index 20b68b9ee..000000000 --- a/python/hidet/transforms/annotate_include_headers.py +++ /dev/null @@ -1,43 +0,0 @@ -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import hidet.ir -from hidet.ir.module import IRModule -from hidet.ir import Stmt -from hidet.ir.stmt import BlackBoxStmt -from hidet.transforms import Pass - -def _use_distributed(func) -> bool: - black_stmts = hidet.ir.tools.collect(func.body, [BlackBoxStmt]) - return any(stmt.template_string.startswith('nccl') for stmt in black_stmts) - -class AnnotateIncludeHeadersPass(Pass): - def process_module(self, ir_module: IRModule) -> IRModule: - use_dist = any(_use_distributed(func) for func in ir_module.functions.values()) - if not use_dist: - return ir_module - - from hidet.cuda.nccl.libinfo import get_nccl_include_dirs, get_nccl_library_search_dirs - from hidet.cuda.nccl import nccl_available, nccl_library_filename - - if not nccl_available(): - raise RuntimeError("NCCL is not available") - - new_module = ir_module.copy() - new_module.include_dirs.extend(get_nccl_include_dirs()) - new_module.linking_dirs.extend(get_nccl_library_search_dirs()) - new_module.include_headers.append(["nccl.h"]) - new_module.linking_libs.append(":" + nccl_library_filename()) - return new_module - - -def annotate_include_headers_pass(): - return AnnotateIncludeHeadersPass() From ee6024904c72315f91535f7bc96e37229958fff5 Mon Sep 17 00:00:00 2001 From: Qidong <soodoshll@gmail.com> Date: Fri, 23 Jun 2023 01:49:44 -0400 Subject: [PATCH 22/40] update --- .../transforms/annotate_header_and_libs.py | 43 +++++++++++++++++++ 1 file changed, 43 insertions(+) create mode 100644 python/hidet/transforms/annotate_header_and_libs.py diff --git a/python/hidet/transforms/annotate_header_and_libs.py b/python/hidet/transforms/annotate_header_and_libs.py new file mode 100644 index 000000000..20b68b9ee --- /dev/null +++ b/python/hidet/transforms/annotate_header_and_libs.py @@ -0,0 +1,43 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import hidet.ir +from hidet.ir.module import IRModule +from hidet.ir import Stmt +from hidet.ir.stmt import BlackBoxStmt +from hidet.transforms import Pass + +def _use_distributed(func) -> bool: + black_stmts = hidet.ir.tools.collect(func.body, [BlackBoxStmt]) + return any(stmt.template_string.startswith('nccl') for stmt in black_stmts) + +class AnnotateIncludeHeadersPass(Pass): + def process_module(self, ir_module: IRModule) -> IRModule: + use_dist = any(_use_distributed(func) for func in ir_module.functions.values()) + if not use_dist: + return ir_module + + from hidet.cuda.nccl.libinfo import get_nccl_include_dirs, get_nccl_library_search_dirs + from hidet.cuda.nccl import nccl_available, nccl_library_filename + + if not nccl_available(): + raise RuntimeError("NCCL is not available") + + new_module = ir_module.copy() + new_module.include_dirs.extend(get_nccl_include_dirs()) + new_module.linking_dirs.extend(get_nccl_library_search_dirs()) + new_module.include_headers.append(["nccl.h"]) + new_module.linking_libs.append(":" + nccl_library_filename()) + return new_module + + +def annotate_include_headers_pass(): + return AnnotateIncludeHeadersPass() From f3aad899404ea58f5bfff85981fc08ab968df926 Mon Sep 17 00:00:00 2001 From: Hanjie <50634613+hjjq@users.noreply.github.com> Date: Sun, 25 Jun 2023 22:33:51 -0400 Subject: [PATCH 23/40] [FixBug] Don't instantiate symbol for primitive functions (#291) Previously, if a primitive function calls a primitive function, the `instantiate_symbols` pass will update the corresponding `hidet.ir.primitives.func.PrimitiveFunctionRegistry.function` in-place (I am not sure exactly how it's done, but this is what I observed), adding symbol variables to its parameters. The primitive function pool is a global variable, therefore this effect is cumulative across tuning candidates. So while candidate 0 will have no problem, candidate 1 will have two extra copies of symbol params, and so on, leading to compile errors. Since primitive functions do not need symbol vars, a quick fix is just to not instantiate any symbols for them. --- python/hidet/transforms/instantiate_symbols.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/python/hidet/transforms/instantiate_symbols.py b/python/hidet/transforms/instantiate_symbols.py index ea8e7e9a3..fed71728f 100644 --- a/python/hidet/transforms/instantiate_symbols.py +++ b/python/hidet/transforms/instantiate_symbols.py @@ -15,6 +15,7 @@ from hidet.ir.func import Function from hidet.ir.module import IRModule from hidet.ir.functors import IRRewriter +from hidet.ir.primitives import is_primitive_function from hidet.ir.primitives.runtime import get_symbol_value from hidet.ir.stmt import LetStmt, LaunchKernelStmt from hidet.ir.tools import collect @@ -59,6 +60,9 @@ def visit_Function(self, func: Function): else: assert False + if is_primitive_function(func.name): + return func + ordered_symbols: List[SymbolVar] = list(symbols) symbol_params: List[Var] = [Var(symbol.name, symbol.type) for symbol in ordered_symbols] self.func_symbols[func.name] = FuncSymbols( From 64a632aff4d71f3ff7a84c40f1300e7a1e78cf3d Mon Sep 17 00:00:00 2001 From: su <soodoshll@gmail.com> Date: Tue, 27 Jun 2023 01:27:17 -0400 Subject: [PATCH 24/40] file store --- requirements.txt | 3 +++ 1 file changed, 3 insertions(+) diff --git a/requirements.txt b/requirements.txt index 563176e72..5af4f488c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -31,3 +31,6 @@ packaging # for cuda runtime api and runtime compilation api cuda-python + +# for filestore +filelock From c028827c4092fd8a38f11d4c9e22ca48b91fe51c Mon Sep 17 00:00:00 2001 From: su <soodoshll@gmail.com> Date: Tue, 27 Jun 2023 11:53:35 -0400 Subject: [PATCH 25/40] file store --- python/hidet/distributed/__init__.py | 11 ++ python/hidet/distributed/distributed.py | 67 +++++++++ python/hidet/distributed/group.py | 46 ++++++ python/hidet/distributed/store.py | 188 ++++++++++++++++++++++++ 4 files changed, 312 insertions(+) create mode 100644 python/hidet/distributed/__init__.py create mode 100644 python/hidet/distributed/distributed.py create mode 100644 python/hidet/distributed/group.py create mode 100644 python/hidet/distributed/store.py diff --git a/python/hidet/distributed/__init__.py b/python/hidet/distributed/__init__.py new file mode 100644 index 000000000..2808da5c8 --- /dev/null +++ b/python/hidet/distributed/__init__.py @@ -0,0 +1,11 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. \ No newline at end of file diff --git a/python/hidet/distributed/distributed.py b/python/hidet/distributed/distributed.py new file mode 100644 index 000000000..fdfa0543b --- /dev/null +++ b/python/hidet/distributed/distributed.py @@ -0,0 +1,67 @@ +from typing import Optional +from datetime import timedelta +from .store import Store, FileStore + +DEFAULT_TIMEOUT = timedelta(seconds=1800) + +def init_process_group( + backend: str = 'nccl', + init_method: Optional[str] = None, + store: Optional[Store] = None, + timeout: timedelta = DEFAULT_TIMEOUT, + world_size: int = -1, + rank: int = -1, +): + """ + We ues the same api as PyTorch. + Currently we only support FileStore. There are two ways to initialize via FileStore. + 1. Manually create a FileStore object and pass it as ``store``; + 2. Specify ``init_method`` with ``files://path-to-file``` + Now world_size and rank still need to be specified manually. + """ + + if world_size <= 0 or rank < 0: + raise RuntimeError("'world_size' and 'rank' must be specified.") + + if rank >= world_size: + raise RuntimeError("'rank' must be smaller than 'world_size'") + + if store is None: + if init_method is None: + raise RuntimeError("One of 'init_method' and 'store' must be specified.") + else: + if not init_method.startswith('file://'): + raise RuntimeError("Currently only FileStore is supported. Please speficy the path to the filestore with 'file://path-to-file'") + path_to_file = init_method[len('file://'):] + store = FileStore(path_to_file) + else: + if init_method is not None: + raise RuntimeError("'init_method' and 'store' are mutually exclusive.") + + store.set_timeout(timeout) + + + +def is_initialized(): + pass + +def is_nccl_available(): + pass + +def broadcast(): + pass + +def all_reduce(): + pass + +def reduce(): + pass + +def all_gather_into_tensor(): + pass + +def scatter(): + pass + +def reduce_scatter_tensor(): + pass \ No newline at end of file diff --git a/python/hidet/distributed/group.py b/python/hidet/distributed/group.py new file mode 100644 index 000000000..57711dc35 --- /dev/null +++ b/python/hidet/distributed/group.py @@ -0,0 +1,46 @@ +import hidet +from hidet import Tensor +from typing import Optional, List + +from .store import Store + +class ProcessGroup: + def backend(self) -> str: + raise NotImplementedError() + + def rank(self) -> int: + raise NotImplementedError() + + def size(self) -> int: + raise NotImplementedError() + + def broadcast(self, tensor: Tensor, src: int): + raise NotImplementedError() + + def all_reduce(self, tensor: Tensor, op: str): + raise NotImplementedError() + + def reduce(self, tensor: Tensor, dst:int, op:str): + raise NotImplementedError() + + def all_gather(self, tensor_list: List[Tensor], tensor: Tensor): + raise NotImplementedError() + + def all_gather_into_tensor(self, output_tensor: Tensor, input_tensor: Tensor): + raise NotImplementedError() + + def gather(self, tensor: Tensor, gather_list: Optional[List[Tensor]]=None, dst: int=0): + raise NotImplementedError() + + def scatter(self, tensor: Tensor, scattler_list: Optional[List[Tensor]]=None): + raise NotImplementedError() + + def reduce_scatter(self, output: Tensor, input_list: List[Tensor], op: str): + raise NotImplementedError() + + def reduce_scatter_tensor(self, output: Tensor, input: Tensor, op: str): + raise NotImplementedError() + +class NCCLProcessGroup(ProcessGroup): + def __init__(self, store: Store, world_size: int, rank: int): + if rank == 0 \ No newline at end of file diff --git a/python/hidet/distributed/store.py b/python/hidet/distributed/store.py new file mode 100644 index 000000000..c744b04a3 --- /dev/null +++ b/python/hidet/distributed/store.py @@ -0,0 +1,188 @@ +from typing import List, Optional +from datetime import timedelta, datetime +import filelock +import time +import struct +from functools import partial +import os + +class Store: + def set(self, key: str, value: bytes) -> None: + raise NotImplementedError() + + def get(self, key: str) -> bytes: + raise NotImplementedError() + + def add(self, key: str, amount: int) -> int: + raise NotImplementedError() + + def compare_set(self, key: str, expected: bytes, desired: bytes) -> bytes: + raise NotImplementedError() + + def wait(self, keys: List[str], timeout: Optional[timedelta]=None) -> None: + raise NotImplementedError() + + def num_keys(self) -> int: + raise NotImplementedError() + + def delete_key(self, key: str) -> bool: + raise NotImplementedError() + + def set_timeout(self, timeout: timedelta): + raise NotImplementedError() + + def close(self): + raise NotImplementedError() + +class FileStore(Store): + REGULAR_PREFIX = '+' + DELETE_PREFIX = '-' + def __init__(self, filename: str, world_size: Optional[int] = -1): + self._filename = filename + self._lock_filename = filename + '.lock' + self._world_size = world_size + + self._lock = filelock.FileLock(self._lock_filename) + self._cache = {} + self._timeout = None + + num_peers = self._add('cnt', 1) + if num_peers > self._world_size: + raise RuntimeError("Warning: more peers than world size.") + + def _write(self, f, content): + f.write(struct.pack('i', len(content))) + f.write(content) + + def _read(self, f): + len_str = f.read(4) + if len_str == b'': + return + l = struct.unpack('i', len_str)[0] + return f.read(l) + + def _file_size(self, f): + origin_pos = f.tell() + f.seek(0, 2) # 2 means the file's end + size = f.tell() + f.seek(origin_pos, 0) + return size + + def _update(self, f): + self._cache = {} + while True: + k = self._read(f) + if k is None: + return + v = self._read(f) + k = str(k, encoding='raw_unicode_escape') + if k.startswith(self.DELETE_PREFIX): + del self._cache[k] + self._cache[k] = v + + def _add(self, key: str, amount: int) -> int: + key = key + with self._lock: + with open(self._filename, "ab+") as f: + f.seek(0) + self._update(f) + value = int(self._cache.get(key, '0')) + amount + with open(self._filename, "ab+") as f: + self._write(f, bytes(key, encoding='raw_unicode_escape')) + self._write(f, bytes(str(value), encoding='raw_unicode_escape')) + return value + + def _check(self, keys: List[str]): + with self._lock: + with open(self._filename, "ab+") as f: + f.seek(0) + self._update(f) + print(self._cache.keys()) + return all((key in self._cache.keys() for key in keys)) + + def set(self, key: str, value: bytes) -> None: + with self._lock: + with open(self._filename, "ab+") as f: + self._write(f, bytes(self.REGULAR_PREFIX + key, encoding='raw_unicode_escape')) + self._write(f, value) + + def get(self, key: str) -> bytes: + last_file_size = None + key = self.REGULAR_PREFIX + key + start_t = datetime.now() + while True: + self._lock.acquire() + with open(self._filename, "ab+") as f: + f.seek(0) + file_size = self._file_size(f) + if key not in self._cache.keys() and file_size == last_file_size: + # No new entries + last_file_size = file_size + self._lock.release() + if self._timeout is not None and datetime.now() - start_t > self._timeout: + raise TimeoutError() + time.sleep(0.01) + continue + last_file_size = file_size + self._update(f) + self._lock.release() + value = self._cache.get(key) + if value is not None: + return value + + def add(self, key: str, amount: int) -> int: + return self._add(self.REGULAR_PREFIX + key, amount) + + def compare_set(self, key: str, expected: bytes, desired: bytes) -> bytes: + key = self.REGULAR_PREFIX + key + with self._lock: + with open(self._filename, "ab+") as f: + f.seek(0) + self._update() + has_key = key in self._cache.keys() + if (not has_key and expected == b'') or (has_key and self._cache[key] == expected): + f.seek(0, 2) + self._write(f, bytes(key, encoding='raw_unicode_escape')) + self._write(f, desired) + return desired + elif not has_key: + return expected + return self._cache[key] + + def wait(self, keys: List[str], timeout: Optional[timedelta]=None) -> None: + timeout = self._timeout if self._timeout is not None else timeout + start_t = datetime.now() + keys = [self.REGULAR_PREFIX + key for key in keys] + while not self._check(keys): + if timeout is not None and datetime.now() - start_t > timeout: + raise TimeoutError() + time.sleep(0.01) + + def num_keys(self): + with self._lock(): + with open(self._filename, "rb") as f: + self._update(f) + return len(self._cache) + + def delete_key(self, key: str): + self.set(self.DELETE_PREFIX + key, b'') + + def set_timeout(self, timeout: timedelta): + self._timeout = timeout + + def close(self): + rest = self._add('cnt', -1) + if rest == 0: + os.remove(self._filename) + +if __name__ == '__main__': + store = FileStore('tmp') + store.set_timeout(timedelta(seconds=30)) + ret = store.add('baga', 2) + store.set('yarou', b'haha') + store.wait(['baga', 'yarou']) + print(ret) + ret = store.add('baga', 5) + print(ret) + print(store.get('baga')) + store.close() From 56a96cad2b2c0331941dbe755e40590239cf8fbb Mon Sep 17 00:00:00 2001 From: su <soodoshll@gmail.com> Date: Tue, 27 Jun 2023 13:40:49 -0400 Subject: [PATCH 26/40] update --- examples/distributed/test.py | 56 ++----------------- python/hidet/__init__.py | 1 + python/hidet/cuda/nccl/comm.py | 8 ++- python/hidet/cuda/nccl/ffi.py | 17 +++++- python/hidet/distributed/__init__.py | 5 +- python/hidet/distributed/distributed.py | 43 ++++++++++---- python/hidet/distributed/group.py | 50 +++++++++++++++-- python/hidet/distributed/store.py | 23 ++++---- python/hidet/drivers/build_graph.py | 3 +- python/hidet/graph/flow_graph.py | 20 +------ python/hidet/graph/ops/distributed.py | 1 - python/hidet/runtime/compiled_graph.py | 28 ---------- .../transforms/annotate_header_and_libs.py | 6 +- 13 files changed, 129 insertions(+), 132 deletions(-) diff --git a/examples/distributed/test.py b/examples/distributed/test.py index 3104c4a6e..a15d53c9e 100644 --- a/examples/distributed/test.py +++ b/examples/distributed/test.py @@ -18,75 +18,29 @@ import hidet import hidet.cuda.nccl -from hidet.cuda import nccl -from hidet.cuda.nccl import NcclUniqueId - -print("NCCL version:", nccl.nccl_version()) parser = argparse.ArgumentParser() parser.add_argument("n_gpus", type=int) parser.add_argument("reduce_op", choices=['sum', 'prod', 'max', 'min', 'avg']) -parser.add_argument("--group_size", type=int, default=0) args = parser.parse_args() -def run(world_size, rank, shared_id, barrier): +def run(world_size, rank): numpy.random.seed(rank) - # Initialize unique id - if rank == 0: - _id = nccl.create_unique_id() - shared_id.internal = _id.internal - - barrier.wait() hidet.cuda.set_device(rank) - - use_group = args.group_size > 1 - if use_group: - gs = args.group_size - gn = world_size // gs - assert world_size % gs == 0 - groups = [list(range(i * gs, (i + 1) * gs)) for i in range(gn)] - else: - groups = [] - + hidet.distributed.init_process_group(init_method='file://tmp', world_size=world_size, rank=rank) + hidet.distributed.set_nccl_comms() device = f"cuda:{rank}" x = hidet.randn([1, 3], device=device) w = hidet.randn([3, 2], device=device) - # Create Computation Graph - x_symb = hidet.symbol_like(x) - w_symb = hidet.symbol_like(w) - y_local = hidet.ops.relu(x_symb @ w_symb) - y_sync = hidet.ops.all_reduce(y_local, args.reduce_op, comm_id=int(use_group)) - graph = hidet.trace_from([y_local, y_sync], inputs=[x_symb, w_symb]) - opt_graph = hidet.graph.optimize(graph) - opt_graph.set_attrs(nrank=world_size, rank=rank, groups=groups) - compiled = opt_graph.build() - - # test save and load - compiled_dir = f"./outs/graph_{rank}.zip" - compiled.save(compiled_dir) - compiled = hidet.runtime.load_compiled_graph(compiled_dir) - - # Create Distributed Graph - compiled.init_dist(shared_id) - - y_local, y_sync = compiled(x, w) - s = hidet.cuda.current_stream() s.synchronize() - print(f"process {rank}\nbefore allreduce:{y_local}\nafter allreduce:{y_sync}\n", end='') + print(f"process {rank}") world_size = args.n_gpus - -# Barrier to ensure unique id is created -barrier = multiprocessing.Barrier(world_size) - -# Create a unique id object in shared memory -shared_id = multiprocessing.Value(NcclUniqueId, lock=False) - -processes = [Process(target=run, args=(world_size, i, shared_id, barrier)) for i in range(world_size)] +processes = [Process(target=run, args=(world_size, i)) for i in range(world_size)] for p in processes: p.start() diff --git a/python/hidet/__init__.py b/python/hidet/__init__.py index 9c3206e25..94efae1f2 100644 --- a/python/hidet/__init__.py +++ b/python/hidet/__init__.py @@ -21,6 +21,7 @@ from . import drivers from . import logging from . import cuda +from . import distributed from .version import __version__ diff --git a/python/hidet/cuda/nccl/comm.py b/python/hidet/cuda/nccl/comm.py index 0188c6373..16879304f 100644 --- a/python/hidet/cuda/nccl/comm.py +++ b/python/hidet/cuda/nccl/comm.py @@ -10,11 +10,12 @@ # See the License for the specific language governing permissions and # limitations under the License. from enum import IntEnum -from typing import List +from typing import List, Optional import struct from hidet.ffi.utils import Array from hidet.ir.type import void_p, DataType +from hidet.cuda import Stream, current_stream from .ffi import nccl_available, NcclUniqueId NCCL_SPLIT_NOCOLOR = -1 @@ -77,6 +78,11 @@ def split(self, key, color): if color == NCCL_SPLIT_NOCOLOR: return None return NcclCommunicator(new_handle) + + def all_reduce(self, sendbuff:int, recvbuff:int, count:int, datatype:DataType, op:str, s:Optional[Stream]=None): + if s is None: + s = current_stream() + nccl_runtime_api.all_reduce(sendbuff, recvbuff, count, dtype_to_nccl(datatype), str_to_nccl_op(op), self._handle, s) def create_comm(nranks: int, unique_id: NcclUniqueId, rank: int) -> NcclCommunicator: diff --git a/python/hidet/cuda/nccl/ffi.py b/python/hidet/cuda/nccl/ffi.py index 13c941702..0f49617a2 100644 --- a/python/hidet/cuda/nccl/ffi.py +++ b/python/hidet/cuda/nccl/ffi.py @@ -12,11 +12,12 @@ from typing import Optional import ctypes -from ctypes import c_void_p, c_int, pointer, Structure, c_byte, POINTER +from ctypes import c_void_p, c_int, pointer, Structure, c_byte, POINTER, c_uint64 import glob import os from hidet.ffi.ffi import get_func +from hidet.cuda import Stream from .libinfo import get_nccl_library_search_dirs _LIB_NCCL: Optional[ctypes.CDLL] = None @@ -74,6 +75,12 @@ class NCCLRuntimeAPI: _comm_user_rank = get_func('ncclCommUserRank', [c_void_p, POINTER(c_int)], c_int, lib=_LIB_NCCL) _comm_count = get_func('ncclCommCount', [c_void_p, POINTER(c_int)], c_int, lib=_LIB_NCCL) + _all_reduce = get_func('ncclAllReduce', [c_void_p, c_void_p, c_uint64, c_int, c_int, c_void_p, c_int], c_int, lib=_LIB_NCCL) + _broadcast = get_func('ncclBroadcast', [c_void_p, c_void_p, c_uint64, c_int, c_int, c_void_p, c_int], c_int, lib=_LIB_NCCL) + _reduce = get_func('ncclReduce', [c_void_p, c_void_p, c_uint64, c_int, c_int, c_int, c_void_p, c_int], c_int, lib=_LIB_NCCL) + _all_gather = get_func('ncclAllGather', [c_void_p, c_void_p, c_uint64, c_int, c_void_p, c_int], c_int, lib=_LIB_NCCL) + _reduce_scatter = get_func('ncclReduceScatter', [c_void_p, c_void_p, c_uint64, c_int, c_int, c_void_p, c_int], c_int, lib=_LIB_NCCL) + # Early versions of NCCL do not have split try: _comm_split = get_func('ncclCommSplit', [c_void_p, c_int, c_int, c_void_p, c_void_p], c_int, lib=_LIB_NCCL) @@ -114,5 +121,13 @@ def comm_split(comm_handle: int, color: int, key: int) -> int: ret = NCCLRuntimeAPI._comm_split(comm_handle, color, key, pointer(comm), None) assert ret == 0 return comm.value + + # TODO: Currently only support all_reduce + @staticmethod + def all_reduce(sendbuff:int, recvbuff:int, count:int, datatype:int, op:int, comm_handle:int, s:Stream) -> None: + ret = NCCLRuntimeAPI._all_reduce( + sendbuff, recvbuff, count, datatype, op, comm_handle, s._handle + ) + assert ret == 0 nccl_runtime_api = NCCLRuntimeAPI() diff --git a/python/hidet/distributed/__init__.py b/python/hidet/distributed/__init__.py index 2808da5c8..1a748b6f3 100644 --- a/python/hidet/distributed/__init__.py +++ b/python/hidet/distributed/__init__.py @@ -8,4 +8,7 @@ # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and -# limitations under the License. \ No newline at end of file +# limitations under the License. + +from .distributed import init_process_group +from .group import set_nccl_comms \ No newline at end of file diff --git a/python/hidet/distributed/distributed.py b/python/hidet/distributed/distributed.py index fdfa0543b..d06b5cf82 100644 --- a/python/hidet/distributed/distributed.py +++ b/python/hidet/distributed/distributed.py @@ -1,9 +1,17 @@ from typing import Optional from datetime import timedelta from .store import Store, FileStore +from .group import create_nccl_group, ProcessGroup + +import hidet +from hidet.graph import Tensor +from hidet.cuda.nccl import nccl_available + DEFAULT_TIMEOUT = timedelta(seconds=1800) +DEFAULT_GROUP = None + def init_process_group( backend: str = 'nccl', init_method: Optional[str] = None, @@ -19,6 +27,7 @@ def init_process_group( 2. Specify ``init_method`` with ``files://path-to-file``` Now world_size and rank still need to be specified manually. """ + global DEFAULT_GROUP if world_size <= 0 or rank < 0: raise RuntimeError("'world_size' and 'rank' must be specified.") @@ -39,29 +48,41 @@ def init_process_group( raise RuntimeError("'init_method' and 'store' are mutually exclusive.") store.set_timeout(timeout) - - + if backend == 'nccl': + if not is_nccl_available(): + raise RuntimeError("NCCL is not found.") + DEFAULT_GROUP = create_nccl_group(store, world_size, rank) def is_initialized(): - pass + return DEFAULT_GROUP is not None def is_nccl_available(): - pass + return nccl_available() def broadcast(): - pass + raise NotImplementedError() -def all_reduce(): - pass +def all_reduce(tensor: Tensor, op:str, group:Optional[ProcessGroup]=None): + if group is None: + group = DEFAULT_GROUP + group.all_reduce(tensor, op) def reduce(): - pass + raise NotImplementedError() def all_gather_into_tensor(): - pass + raise NotImplementedError() def scatter(): - pass + raise NotImplementedError() def reduce_scatter_tensor(): - pass \ No newline at end of file + raise NotImplementedError() + +if __name__ == '__main__': + init_process_group(init_method='file://tmp', world_size=1, rank=0) + print(is_initialized()) + test = hidet.randn((2, 2), device='cuda') + print(test) + all_reduce(test, 'sum') + print(test) \ No newline at end of file diff --git a/python/hidet/distributed/group.py b/python/hidet/distributed/group.py index 57711dc35..9bb3194ce 100644 --- a/python/hidet/distributed/group.py +++ b/python/hidet/distributed/group.py @@ -1,8 +1,9 @@ import hidet -from hidet import Tensor +from hidet.graph import Tensor from typing import Optional, List -from .store import Store +from .store import Store, FileStore +from hidet.cuda.nccl import create_unique_id, NcclUniqueId, create_comm, NcclCommunicator, comms_to_array class ProcessGroup: def backend(self) -> str: @@ -41,6 +42,47 @@ def reduce_scatter(self, output: Tensor, input_list: List[Tensor], op: str): def reduce_scatter_tensor(self, output: Tensor, input: Tensor, op: str): raise NotImplementedError() + def barrier(self): + raise NotImplementedError() + +NCCL_COMMS = [] + class NCCLProcessGroup(ProcessGroup): - def __init__(self, store: Store, world_size: int, rank: int): - if rank == 0 \ No newline at end of file + def __init__(self, comm: NcclCommunicator, world_size: int, rank: int): + global NCCL_COMMS + self._comm: NcclCommunicator = comm + self._world_size: int = world_size + self._rank: int = rank + NCCL_COMMS.append(comm) + + def rank(self) -> int: + return self._rank + + def size(self) -> int: + return self._world_size + + def all_reduce(self, tensor: Tensor, op:str): + assert not tensor.is_symbolic() + assert tensor.device.is_cuda() + addr = tensor.storage.addr + self._comm.all_reduce(addr, addr, tensor.nbytes, tensor.dtype, op) + +def create_nccl_group(store: Store, world_size: int, rank: int): + if rank == 0: + unique_id = create_unique_id() + store.set('unique_id', unique_id.internal) + else: + unique_id = store.get('unique_id') + unique_id = NcclUniqueId(unique_id) + comm = create_comm(world_size, unique_id, rank) + return NCCLProcessGroup(comm, world_size, rank) + +def set_nccl_comms(): + from hidet.ffi.runtime_api import runtime_api + comm_array = comms_to_array(NCCL_COMMS) + runtime_api.set_nccl_comms(comm_array) + + +if __name__ == '__main__': + store = FileStore('tmp') + group = create_nccl_group(store, 1, 0) \ No newline at end of file diff --git a/python/hidet/distributed/store.py b/python/hidet/distributed/store.py index c744b04a3..3db423ad5 100644 --- a/python/hidet/distributed/store.py +++ b/python/hidet/distributed/store.py @@ -5,6 +5,7 @@ import struct from functools import partial import os +import atexit class Store: def set(self, key: str, value: bytes) -> None: @@ -31,9 +32,6 @@ def delete_key(self, key: str) -> bool: def set_timeout(self, timeout: timedelta): raise NotImplementedError() - def close(self): - raise NotImplementedError() - class FileStore(Store): REGULAR_PREFIX = '+' DELETE_PREFIX = '-' @@ -47,8 +45,19 @@ def __init__(self, filename: str, world_size: Optional[int] = -1): self._timeout = None num_peers = self._add('cnt', 1) - if num_peers > self._world_size: + if world_size >= 0 and num_peers > world_size: raise RuntimeError("Warning: more peers than world size.") + + # We cannot operate files in __del__, and we don't want to call close explicitly + # So we register a atexit function doing cleanup when python interpreter exits + @atexit.register + def cleanup(): + with self._lock: + if os.path.exists(self._filename): + rest = self._add('cnt', -1) + if rest == 0: + os.remove(self._filename) + def _write(self, f, content): f.write(struct.pack('i', len(content))) @@ -170,11 +179,6 @@ def delete_key(self, key: str): def set_timeout(self, timeout: timedelta): self._timeout = timeout - def close(self): - rest = self._add('cnt', -1) - if rest == 0: - os.remove(self._filename) - if __name__ == '__main__': store = FileStore('tmp') store.set_timeout(timedelta(seconds=30)) @@ -185,4 +189,3 @@ def close(self): ret = store.add('baga', 5) print(ret) print(store.get('baga')) - store.close() diff --git a/python/hidet/drivers/build_graph.py b/python/hidet/drivers/build_graph.py index 673704368..34a376325 100644 --- a/python/hidet/drivers/build_graph.py +++ b/python/hidet/drivers/build_graph.py @@ -143,8 +143,7 @@ def get_graph_meta_data(graph: FlowGraph, num_kernels, space: int) -> GraphMetaD graph_hash = sha256('\n'.join(lines).encode('utf-8')).hexdigest()[:16] return GraphMetaData( - inputs=inputs, outputs=outputs, hidet_version=hidet.__version__, num_kernels=num_kernels, graph_hash=graph_hash, - attrs=asdict(graph.attrs) + inputs=inputs, outputs=outputs, hidet_version=hidet.__version__, num_kernels=num_kernels, graph_hash=graph_hash ) def build_graph_module(graph: FlowGraph, graph_weights: List[Tensor], node2kernel: List[int]) -> CompiledModule: diff --git a/python/hidet/graph/flow_graph.py b/python/hidet/graph/flow_graph.py index 28b589403..b724c254f 100644 --- a/python/hidet/graph/flow_graph.py +++ b/python/hidet/graph/flow_graph.py @@ -16,7 +16,6 @@ import os import pickle from collections import defaultdict -from dataclasses import dataclass, field import hidet.graph.operator import hidet.cuda @@ -104,30 +103,16 @@ def benchmark(self, output_dir='./outs/benchmark', print_summary: bool = False, def forward_context() -> GraphForwardContext: return GraphForwardContext() -@dataclass -class FlowGraphAttrs: - nrank: int = 0 - rank: int = 0 - groups: List[List[int]] = field(default_factory=list) - class FlowGraph: """The computation graph representation.""" - def __init__( - self, - outputs: Sequence[Tensor], - inputs: Optional[Sequence[Tensor]] = None, - nodes=None, - attrs: Optional[FlowGraphAttrs] = None - ): + def __init__(self, outputs: Sequence[Tensor], inputs: Optional[Sequence[Tensor]] = None, nodes=None): self.outputs: List[Tensor] = list(outputs) self.inputs: Optional[List[Tensor]] = list(inputs) if inputs is not None else None self._nodes: Optional[List[Operator]] = nodes self._usage_count: Optional[Dict[Tensor, int]] = None self.update_nodes() - self.attrs: FlowGraphAttrs = attrs if attrs else FlowGraphAttrs() - def __call__(self, *inputs: Tensor) -> Union[List[Tensor], Tensor]: """ Run the computation graph. @@ -193,9 +178,6 @@ def _build_nodes(self): hidet.option.parallel_build(False) hidet.drivers.build_task_batch(tunable_tasks) # build tunable tasks one by one - def set_attrs(self, *args, **kwargs): - self.attrs = FlowGraphAttrs(*args, **kwargs) - def forward(self, inputs: List[Tensor]) -> List[Tensor]: """Run the computation graph. diff --git a/python/hidet/graph/ops/distributed.py b/python/hidet/graph/ops/distributed.py index a72e7a3a1..511fd112e 100644 --- a/python/hidet/graph/ops/distributed.py +++ b/python/hidet/graph/ops/distributed.py @@ -32,7 +32,6 @@ def __str__(self): return f"all_reduce" def implement(self, target: Union[Target, str], working_dir: str) -> List[IRModule]: - # we may need current rank here to avoid duplicated working_dirs import hidet from hidet.ir.primitives.cuda.nccl import all_reduce as _all_reduce from hidet.lang import attrs diff --git a/python/hidet/runtime/compiled_graph.py b/python/hidet/runtime/compiled_graph.py index 9ad761310..05f866a40 100644 --- a/python/hidet/runtime/compiled_graph.py +++ b/python/hidet/runtime/compiled_graph.py @@ -28,7 +28,6 @@ from hidet.runtime.storage import Storage from hidet.ffi import runtime_api from hidet.utils import prod -from hidet.cuda.nccl import NcclCommunicator, NcclUniqueId, create_comm, NCCL_SPLIT_NOCOLOR, comms_to_array ModelExecutionHook = Callable[[int, List['Tensor'], List['Tensor']], None] @@ -45,7 +44,6 @@ class GraphMetaData: hidet_version: str num_kernels: int graph_hash: str - attrs: Dict[str, Union[int, List[List[int]]]] @dataclass @@ -104,7 +102,6 @@ def __init__( self.dispatch_table: Dict[Tuple[int, ...], Array] = {} self.cuda_workspace: Optional[Storage] = None self.cpu_workspace: Optional[Storage] = None - self.nccl_comms: List[NcclCommunicator] = [] self._init_compiled_graph() @@ -170,27 +167,6 @@ def _init_compiled_graph(self): kernel_array[task_idx] = ctypes_func_pointer(compiled_task.candidates[sch_idx].ctypes_func) self.dispatch_table[tuple(symbol_dims)] = kernel_array - def init_dist(self, unique_id: NcclUniqueId): - if self.dist_info is None: - raise RuntimeError("Distributed information is not set.") - self.nccl_comms = [] - - # Initialize the default group - nranks = self.dist_info.nrank - rank = self.dist_info.rank - default_comm = create_comm(nranks, unique_id, rank) - self.nccl_comms.append(default_comm) - - # Create communicators according to groups - if self.dist_info.groups is not None: - for group in self.dist_info.groups: - in_group = rank in group - color = 0 if in_group else NCCL_SPLIT_NOCOLOR - key = group.index(rank) if in_group else 0 - comm = default_comm.split(key, color) - if in_group: - self.nccl_comms.append(comm) - def _update_symbol_table(self, symbol_dims: Tuple[int, ...], best_candidates: List[int]): kernel_array = Array(void_p, len(self.compiled_tasks)) for task_idx, best_candidate in enumerate(best_candidates): @@ -299,10 +275,6 @@ def run_async(self, inputs): ret: List[hidet.Tensor] The output tensors. """ - if self.dist_info is not None: - comms_array = comms_to_array(self.nccl_comms) - runtime_api.set_nccl_comms(comms_array) - if hidet.option.get_runtime_check(): _check_inputs(self.meta.inputs, inputs) diff --git a/python/hidet/transforms/annotate_header_and_libs.py b/python/hidet/transforms/annotate_header_and_libs.py index 20b68b9ee..da1cf6977 100644 --- a/python/hidet/transforms/annotate_header_and_libs.py +++ b/python/hidet/transforms/annotate_header_and_libs.py @@ -19,7 +19,7 @@ def _use_distributed(func) -> bool: black_stmts = hidet.ir.tools.collect(func.body, [BlackBoxStmt]) return any(stmt.template_string.startswith('nccl') for stmt in black_stmts) -class AnnotateIncludeHeadersPass(Pass): +class AnnotateHeaderAndLibsPass(Pass): def process_module(self, ir_module: IRModule) -> IRModule: use_dist = any(_use_distributed(func) for func in ir_module.functions.values()) if not use_dist: @@ -39,5 +39,5 @@ def process_module(self, ir_module: IRModule) -> IRModule: return new_module -def annotate_include_headers_pass(): - return AnnotateIncludeHeadersPass() +def annotate_header_and_libs_pass(): + return AnnotateHeaderAndLibsPass() From a39c1991f6600f223e28f0f6da0847cf5ab6fce3 Mon Sep 17 00:00:00 2001 From: su <soodoshll@gmail.com> Date: Tue, 27 Jun 2023 13:42:47 -0400 Subject: [PATCH 27/40] update --- examples/distributed/test.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/examples/distributed/test.py b/examples/distributed/test.py index a15d53c9e..6f81d6da7 100644 --- a/examples/distributed/test.py +++ b/examples/distributed/test.py @@ -34,10 +34,19 @@ def run(world_size, rank): device = f"cuda:{rank}" x = hidet.randn([1, 3], device=device) w = hidet.randn([3, 2], device=device) - - s = hidet.cuda.current_stream() - s.synchronize() - print(f"process {rank}") + + # Create Computation Graph + x_symb = hidet.symbol_like(x) + w_symb = hidet.symbol_like(w) + y_local = hidet.ops.relu(x_symb @ w_symb) + y_sync = hidet.ops.all_reduce(y_local, args.reduce_op, comm_id=0) + graph = hidet.trace_from([y_local, y_sync], inputs=[x_symb, w_symb]) + opt_graph = hidet.graph.optimize(graph) + compiled = opt_graph.build() + y_local, y_sync = compiled(x, w) + + s = hidet.cuda.current_stream().synchronize() + print(f"process {rank}\nbefore allreduce:{y_local}\nafter allreduce:{y_sync}\n", end='') world_size = args.n_gpus processes = [Process(target=run, args=(world_size, i)) for i in range(world_size)] From 0a04b828d203883aed55059faece98465abd0565 Mon Sep 17 00:00:00 2001 From: su <soodoshll@gmail.com> Date: Tue, 27 Jun 2023 13:55:08 -0400 Subject: [PATCH 28/40] update --- examples/distributed/test.py | 5 +++-- python/hidet/distributed/group.py | 6 ++++-- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/examples/distributed/test.py b/examples/distributed/test.py index 6f81d6da7..761e91423 100644 --- a/examples/distributed/test.py +++ b/examples/distributed/test.py @@ -39,14 +39,15 @@ def run(world_size, rank): x_symb = hidet.symbol_like(x) w_symb = hidet.symbol_like(w) y_local = hidet.ops.relu(x_symb @ w_symb) - y_sync = hidet.ops.all_reduce(y_local, args.reduce_op, comm_id=0) + y_sync = hidet.ops.all_reduce(y_local, args.reduce_op) graph = hidet.trace_from([y_local, y_sync], inputs=[x_symb, w_symb]) opt_graph = hidet.graph.optimize(graph) compiled = opt_graph.build() y_local, y_sync = compiled(x, w) - s = hidet.cuda.current_stream().synchronize() + hidet.cuda.current_stream().synchronize() print(f"process {rank}\nbefore allreduce:{y_local}\nafter allreduce:{y_sync}\n", end='') + print("sss") world_size = args.n_gpus processes = [Process(target=run, args=(world_size, i)) for i in range(world_size)] diff --git a/python/hidet/distributed/group.py b/python/hidet/distributed/group.py index 9bb3194ce..cc1837d2e 100644 --- a/python/hidet/distributed/group.py +++ b/python/hidet/distributed/group.py @@ -46,6 +46,7 @@ def barrier(self): raise NotImplementedError() NCCL_COMMS = [] +_NCCL_ARRAY = None class NCCLProcessGroup(ProcessGroup): def __init__(self, comm: NcclCommunicator, world_size: int, rank: int): @@ -78,9 +79,10 @@ def create_nccl_group(store: Store, world_size: int, rank: int): return NCCLProcessGroup(comm, world_size, rank) def set_nccl_comms(): + global _NCCL_ARRAY from hidet.ffi.runtime_api import runtime_api - comm_array = comms_to_array(NCCL_COMMS) - runtime_api.set_nccl_comms(comm_array) + _NCCL_ARRAY = comms_to_array(NCCL_COMMS) + runtime_api.set_nccl_comms(_NCCL_ARRAY) if __name__ == '__main__': From eedaf84bc2fdf0ff6247f2e6a22ac836ce31e58b Mon Sep 17 00:00:00 2001 From: su <soodoshll@gmail.com> Date: Tue, 27 Jun 2023 14:39:46 -0400 Subject: [PATCH 29/40] add test --- tests/unit_tests/test_store.py | 120 +++++++++++++++++++++++++++++++++ 1 file changed, 120 insertions(+) create mode 100644 tests/unit_tests/test_store.py diff --git a/tests/unit_tests/test_store.py b/tests/unit_tests/test_store.py new file mode 100644 index 000000000..e386a9dc7 --- /dev/null +++ b/tests/unit_tests/test_store.py @@ -0,0 +1,120 @@ +import pytest +import multiprocessing +from multiprocessing import Process, Queue +import os +import time +from datetime import timedelta +import random + +from hidet.distributed import FileStore + +TMP_PATH='./tmp' + +def test_filestore_get_hold(): + if os.path.exists(TMP_PATH): + os.remove(TMP_PATH) + + def subproc(): + store = FileStore(TMP_PATH) + store.get('non-existing-key') + + p = Process(target=subproc) + p.start() + store = FileStore(TMP_PATH) + store.set('key', b'value') + time.sleep(1) + assert p.is_alive() + p.terminate() + +def test_filestore_set_get(): + if os.path.exists(TMP_PATH): + os.remove(TMP_PATH) + + def subproc(q): + store = FileStore(TMP_PATH) + store.set_timeout(timedelta(seconds=10)) + b = store.get('key') + q.put(b) + + store = FileStore(TMP_PATH) + store.set('key', random.randbytes(8)) + new_value = random.randbytes(8) + store.set('key', new_value) + q = Queue() + p = Process(target=subproc, args=(q, )) + p.start() + ret = q.get() + assert ret == new_value + p.join() + +def test_filestore_add(): + if os.path.exists(TMP_PATH): + os.remove(TMP_PATH) + + def subproc(): + store = FileStore(TMP_PATH) + store.add('cnt', 1) + store.add('cnt', 2) + + store = FileStore(TMP_PATH) + store.add('cnt', 1) + p = Process(target=subproc) + p.start() + p.join() + ret = store.add('cnt', 2) + assert ret == 6 + +def test_filestore_del(): + if os.path.exists(TMP_PATH): + os.remove(TMP_PATH) + + def subproc(): + store = FileStore(TMP_PATH) + store.get('key') + + p = Process(target=subproc) + p.start() + store = FileStore(TMP_PATH) + store.set('key', b'value') + store.delete_key('key') + time.sleep(1) + assert p.is_alive() + p.terminate() + +def test_filestore_wait(): + if os.path.exists(TMP_PATH): + os.remove(TMP_PATH) + + def subproc(): + store = FileStore(TMP_PATH) + store.wait(['key'], timeout=timedelta(seconds=10)) + + p = Process(target=subproc) + p.start() + store = FileStore(TMP_PATH) + time.sleep(1) + assert p.is_alive() + store.set('key', b'test') + p.join() + assert not p.is_alive() + +def test_filestore_compare_set(): + if os.path.exists(TMP_PATH): + os.remove(TMP_PATH) + + def subproc(): + store = FileStore(TMP_PATH) + store.compare_set("key", b"first", b"second") + + store = FileStore(TMP_PATH) + store.set("key", b"random") + p = Process(target=subproc) + p.start() + p.join() + assert store.get("key") == b"random" + store.set("key", b"first") + store.compare_set("key", b"first", b"second") + p = Process(target=subproc) + p.start() + p.join() + assert store.get("key") == b"second" From 37c86547d1bcd9e9732b89487714c0014a30831d Mon Sep 17 00:00:00 2001 From: su <soodoshll@gmail.com> Date: Tue, 27 Jun 2023 14:51:56 -0400 Subject: [PATCH 30/40] format & copyright --- examples/distributed/test.py | 3 +- python/hidet/cuda/nccl/comm.py | 11 ++- python/hidet/cuda/nccl/ffi.py | 30 ++++--- python/hidet/distributed/__init__.py | 3 +- python/hidet/distributed/distributed.py | 41 +++++++-- python/hidet/distributed/group.py | 67 +++++++++------ python/hidet/distributed/store.py | 85 ++++++++++++------- python/hidet/drivers/build_graph.py | 8 +- python/hidet/graph/flow_graph.py | 1 + python/hidet/graph/ops/distributed.py | 2 +- python/hidet/runtime/compiled_graph.py | 8 +- .../transforms/annotate_header_and_libs.py | 3 +- tests/unit_tests/test_store.py | 32 ++++--- 13 files changed, 185 insertions(+), 109 deletions(-) diff --git a/examples/distributed/test.py b/examples/distributed/test.py index 761e91423..72caafb74 100644 --- a/examples/distributed/test.py +++ b/examples/distributed/test.py @@ -15,6 +15,7 @@ from multiprocessing import Process import numpy import argparse +import atexit import hidet import hidet.cuda.nccl @@ -47,7 +48,7 @@ def run(world_size, rank): hidet.cuda.current_stream().synchronize() print(f"process {rank}\nbefore allreduce:{y_local}\nafter allreduce:{y_sync}\n", end='') - print("sss") + atexit._run_exitfuncs() world_size = args.n_gpus processes = [Process(target=run, args=(world_size, i)) for i in range(world_size)] diff --git a/python/hidet/cuda/nccl/comm.py b/python/hidet/cuda/nccl/comm.py index 16879304f..801549f8e 100644 --- a/python/hidet/cuda/nccl/comm.py +++ b/python/hidet/cuda/nccl/comm.py @@ -78,11 +78,15 @@ def split(self, key, color): if color == NCCL_SPLIT_NOCOLOR: return None return NcclCommunicator(new_handle) - - def all_reduce(self, sendbuff:int, recvbuff:int, count:int, datatype:DataType, op:str, s:Optional[Stream]=None): + + def all_reduce( + self, sendbuff: int, recvbuff: int, count: int, datatype: DataType, op: str, s: Optional[Stream] = None + ): if s is None: s = current_stream() - nccl_runtime_api.all_reduce(sendbuff, recvbuff, count, dtype_to_nccl(datatype), str_to_nccl_op(op), self._handle, s) + nccl_runtime_api.all_reduce( + sendbuff, recvbuff, count, dtype_to_nccl(datatype), str_to_nccl_op(op), self._handle, s + ) def create_comm(nranks: int, unique_id: NcclUniqueId, rank: int) -> NcclCommunicator: @@ -106,6 +110,7 @@ def create_unique_id() -> NcclUniqueId: nccl_runtime_api.get_unique_id(unique_id) return unique_id + def dtype_to_nccl(dtype: DataType) -> NcclDataType: sname_dict = { 'f64': NcclDataType.float64, diff --git a/python/hidet/cuda/nccl/ffi.py b/python/hidet/cuda/nccl/ffi.py index 0f49617a2..a479a147b 100644 --- a/python/hidet/cuda/nccl/ffi.py +++ b/python/hidet/cuda/nccl/ffi.py @@ -75,11 +75,21 @@ class NCCLRuntimeAPI: _comm_user_rank = get_func('ncclCommUserRank', [c_void_p, POINTER(c_int)], c_int, lib=_LIB_NCCL) _comm_count = get_func('ncclCommCount', [c_void_p, POINTER(c_int)], c_int, lib=_LIB_NCCL) - _all_reduce = get_func('ncclAllReduce', [c_void_p, c_void_p, c_uint64, c_int, c_int, c_void_p, c_int], c_int, lib=_LIB_NCCL) - _broadcast = get_func('ncclBroadcast', [c_void_p, c_void_p, c_uint64, c_int, c_int, c_void_p, c_int], c_int, lib=_LIB_NCCL) - _reduce = get_func('ncclReduce', [c_void_p, c_void_p, c_uint64, c_int, c_int, c_int, c_void_p, c_int], c_int, lib=_LIB_NCCL) - _all_gather = get_func('ncclAllGather', [c_void_p, c_void_p, c_uint64, c_int, c_void_p, c_int], c_int, lib=_LIB_NCCL) - _reduce_scatter = get_func('ncclReduceScatter', [c_void_p, c_void_p, c_uint64, c_int, c_int, c_void_p, c_int], c_int, lib=_LIB_NCCL) + _all_reduce = get_func( + 'ncclAllReduce', [c_void_p, c_void_p, c_uint64, c_int, c_int, c_void_p, c_int], c_int, lib=_LIB_NCCL + ) + _broadcast = get_func( + 'ncclBroadcast', [c_void_p, c_void_p, c_uint64, c_int, c_int, c_void_p, c_int], c_int, lib=_LIB_NCCL + ) + _reduce = get_func( + 'ncclReduce', [c_void_p, c_void_p, c_uint64, c_int, c_int, c_int, c_void_p, c_int], c_int, lib=_LIB_NCCL + ) + _all_gather = get_func( + 'ncclAllGather', [c_void_p, c_void_p, c_uint64, c_int, c_void_p, c_int], c_int, lib=_LIB_NCCL + ) + _reduce_scatter = get_func( + 'ncclReduceScatter', [c_void_p, c_void_p, c_uint64, c_int, c_int, c_void_p, c_int], c_int, lib=_LIB_NCCL + ) # Early versions of NCCL do not have split try: @@ -121,13 +131,13 @@ def comm_split(comm_handle: int, color: int, key: int) -> int: ret = NCCLRuntimeAPI._comm_split(comm_handle, color, key, pointer(comm), None) assert ret == 0 return comm.value - + # TODO: Currently only support all_reduce @staticmethod - def all_reduce(sendbuff:int, recvbuff:int, count:int, datatype:int, op:int, comm_handle:int, s:Stream) -> None: - ret = NCCLRuntimeAPI._all_reduce( - sendbuff, recvbuff, count, datatype, op, comm_handle, s._handle - ) + def all_reduce( + sendbuff: int, recvbuff: int, count: int, datatype: int, op: int, comm_handle: int, s: Stream + ) -> None: + ret = NCCLRuntimeAPI._all_reduce(sendbuff, recvbuff, count, datatype, op, comm_handle, s.handle()) assert ret == 0 nccl_runtime_api = NCCLRuntimeAPI() diff --git a/python/hidet/distributed/__init__.py b/python/hidet/distributed/__init__.py index 1a748b6f3..665c77c5c 100644 --- a/python/hidet/distributed/__init__.py +++ b/python/hidet/distributed/__init__.py @@ -11,4 +11,5 @@ # limitations under the License. from .distributed import init_process_group -from .group import set_nccl_comms \ No newline at end of file +from .group import set_nccl_comms +from .store import FileStore diff --git a/python/hidet/distributed/distributed.py b/python/hidet/distributed/distributed.py index d06b5cf82..619752b9e 100644 --- a/python/hidet/distributed/distributed.py +++ b/python/hidet/distributed/distributed.py @@ -1,17 +1,30 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from typing import Optional from datetime import timedelta -from .store import Store, FileStore -from .group import create_nccl_group, ProcessGroup import hidet from hidet.graph import Tensor from hidet.cuda.nccl import nccl_available +from .store import Store, FileStore +from .group import create_nccl_group, ProcessGroup DEFAULT_TIMEOUT = timedelta(seconds=1800) DEFAULT_GROUP = None + def init_process_group( backend: str = 'nccl', init_method: Optional[str] = None, @@ -31,7 +44,7 @@ def init_process_group( if world_size <= 0 or rank < 0: raise RuntimeError("'world_size' and 'rank' must be specified.") - + if rank >= world_size: raise RuntimeError("'rank' must be smaller than 'world_size'") @@ -40,49 +53,61 @@ def init_process_group( raise RuntimeError("One of 'init_method' and 'store' must be specified.") else: if not init_method.startswith('file://'): - raise RuntimeError("Currently only FileStore is supported. Please speficy the path to the filestore with 'file://path-to-file'") - path_to_file = init_method[len('file://'):] + raise RuntimeError( + "Currently only FileStore is supported. " + "Please speficy the path to the filestore with 'file://path-to-file'" + ) + path_to_file = init_method[len('file://') :] store = FileStore(path_to_file) else: if init_method is not None: raise RuntimeError("'init_method' and 'store' are mutually exclusive.") - + store.set_timeout(timeout) if backend == 'nccl': if not is_nccl_available(): raise RuntimeError("NCCL is not found.") DEFAULT_GROUP = create_nccl_group(store, world_size, rank) + def is_initialized(): return DEFAULT_GROUP is not None + def is_nccl_available(): return nccl_available() + def broadcast(): raise NotImplementedError() -def all_reduce(tensor: Tensor, op:str, group:Optional[ProcessGroup]=None): + +def all_reduce(tensor: Tensor, op: str, group: Optional[ProcessGroup] = None): if group is None: group = DEFAULT_GROUP group.all_reduce(tensor, op) + def reduce(): raise NotImplementedError() + def all_gather_into_tensor(): raise NotImplementedError() + def scatter(): raise NotImplementedError() + def reduce_scatter_tensor(): raise NotImplementedError() + if __name__ == '__main__': init_process_group(init_method='file://tmp', world_size=1, rank=0) print(is_initialized()) test = hidet.randn((2, 2), device='cuda') print(test) all_reduce(test, 'sum') - print(test) \ No newline at end of file + print(test) diff --git a/python/hidet/distributed/group.py b/python/hidet/distributed/group.py index cc1837d2e..c57fdc7ab 100644 --- a/python/hidet/distributed/group.py +++ b/python/hidet/distributed/group.py @@ -1,73 +1,89 @@ -import hidet -from hidet.graph import Tensor +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +#pylint: disable=W0223 + from typing import Optional, List -from .store import Store, FileStore +from hidet.graph import Tensor from hidet.cuda.nccl import create_unique_id, NcclUniqueId, create_comm, NcclCommunicator, comms_to_array +from .store import Store + class ProcessGroup: def backend(self) -> str: raise NotImplementedError() - + def rank(self) -> int: raise NotImplementedError() - + def size(self) -> int: raise NotImplementedError() - + def broadcast(self, tensor: Tensor, src: int): raise NotImplementedError() - + def all_reduce(self, tensor: Tensor, op: str): raise NotImplementedError() - - def reduce(self, tensor: Tensor, dst:int, op:str): + + def reduce(self, tensor: Tensor, dst: int, op: str): raise NotImplementedError() - + def all_gather(self, tensor_list: List[Tensor], tensor: Tensor): raise NotImplementedError() - + def all_gather_into_tensor(self, output_tensor: Tensor, input_tensor: Tensor): raise NotImplementedError() - - def gather(self, tensor: Tensor, gather_list: Optional[List[Tensor]]=None, dst: int=0): + + def gather(self, tensor: Tensor, gather_list: Optional[List[Tensor]] = None, dst: int = 0): raise NotImplementedError() - - def scatter(self, tensor: Tensor, scattler_list: Optional[List[Tensor]]=None): + + def scatter(self, tensor: Tensor, scattler_list: Optional[List[Tensor]] = None): raise NotImplementedError() - + def reduce_scatter(self, output: Tensor, input_list: List[Tensor], op: str): raise NotImplementedError() - + def reduce_scatter_tensor(self, output: Tensor, input: Tensor, op: str): raise NotImplementedError() def barrier(self): raise NotImplementedError() + NCCL_COMMS = [] _NCCL_ARRAY = None + class NCCLProcessGroup(ProcessGroup): def __init__(self, comm: NcclCommunicator, world_size: int, rank: int): - global NCCL_COMMS self._comm: NcclCommunicator = comm self._world_size: int = world_size self._rank: int = rank NCCL_COMMS.append(comm) - + def rank(self) -> int: return self._rank - + def size(self) -> int: return self._world_size - - def all_reduce(self, tensor: Tensor, op:str): + + def all_reduce(self, tensor: Tensor, op: str): assert not tensor.is_symbolic() assert tensor.device.is_cuda() addr = tensor.storage.addr self._comm.all_reduce(addr, addr, tensor.nbytes, tensor.dtype, op) + def create_nccl_group(store: Store, world_size: int, rank: int): if rank == 0: unique_id = create_unique_id() @@ -78,13 +94,10 @@ def create_nccl_group(store: Store, world_size: int, rank: int): comm = create_comm(world_size, unique_id, rank) return NCCLProcessGroup(comm, world_size, rank) + def set_nccl_comms(): global _NCCL_ARRAY from hidet.ffi.runtime_api import runtime_api + _NCCL_ARRAY = comms_to_array(NCCL_COMMS) runtime_api.set_nccl_comms(_NCCL_ARRAY) - - -if __name__ == '__main__': - store = FileStore('tmp') - group = create_nccl_group(store, 1, 0) \ No newline at end of file diff --git a/python/hidet/distributed/store.py b/python/hidet/distributed/store.py index 3db423ad5..98d9d1570 100644 --- a/python/hidet/distributed/store.py +++ b/python/hidet/distributed/store.py @@ -1,40 +1,54 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from typing import List, Optional from datetime import timedelta, datetime -import filelock -import time +import time import struct -from functools import partial import os import atexit +import filelock + class Store: def set(self, key: str, value: bytes) -> None: raise NotImplementedError() - + def get(self, key: str) -> bytes: raise NotImplementedError() - + def add(self, key: str, amount: int) -> int: raise NotImplementedError() - + def compare_set(self, key: str, expected: bytes, desired: bytes) -> bytes: raise NotImplementedError() - - def wait(self, keys: List[str], timeout: Optional[timedelta]=None) -> None: + + def wait(self, keys: List[str], timeout: Optional[timedelta] = None) -> None: raise NotImplementedError() - + def num_keys(self) -> int: raise NotImplementedError() - + def delete_key(self, key: str) -> bool: raise NotImplementedError() - + def set_timeout(self, timeout: timedelta): raise NotImplementedError() - + + class FileStore(Store): REGULAR_PREFIX = '+' DELETE_PREFIX = '-' + def __init__(self, filename: str, world_size: Optional[int] = -1): self._filename = filename self._lock_filename = filename + '.lock' @@ -45,9 +59,9 @@ def __init__(self, filename: str, world_size: Optional[int] = -1): self._timeout = None num_peers = self._add('cnt', 1) - if world_size >= 0 and num_peers > world_size: + if 0 <= world_size < num_peers: raise RuntimeError("Warning: more peers than world size.") - + # We cannot operate files in __del__, and we don't want to call close explicitly # So we register a atexit function doing cleanup when python interpreter exits @atexit.register @@ -57,7 +71,6 @@ def cleanup(): rest = self._add('cnt', -1) if rest == 0: os.remove(self._filename) - def _write(self, f, content): f.write(struct.pack('i', len(content))) @@ -66,19 +79,20 @@ def _write(self, f, content): def _read(self, f): len_str = f.read(4) if len_str == b'': - return + return None l = struct.unpack('i', len_str)[0] return f.read(l) - + def _file_size(self, f): origin_pos = f.tell() - f.seek(0, 2) # 2 means the file's end + f.seek(0, 2) # 2 means the file's end size = f.tell() f.seek(origin_pos, 0) return size - + def _update(self, f): self._cache = {} + f.seek(0) while True: k = self._read(f) if k is None: @@ -86,11 +100,12 @@ def _update(self, f): v = self._read(f) k = str(k, encoding='raw_unicode_escape') if k.startswith(self.DELETE_PREFIX): + k = k[len(self.DELETE_PREFIX) :] del self._cache[k] - self._cache[k] = v + else: + self._cache[k] = v def _add(self, key: str, amount: int) -> int: - key = key with self._lock: with open(self._filename, "ab+") as f: f.seek(0) @@ -106,15 +121,17 @@ def _check(self, keys: List[str]): with open(self._filename, "ab+") as f: f.seek(0) self._update(f) - print(self._cache.keys()) - return all((key in self._cache.keys() for key in keys)) + return all((key in self._cache for key in keys)) - def set(self, key: str, value: bytes) -> None: + def _set(self, key: str, value: bytes): with self._lock: with open(self._filename, "ab+") as f: - self._write(f, bytes(self.REGULAR_PREFIX + key, encoding='raw_unicode_escape')) + self._write(f, bytes(key, encoding='raw_unicode_escape')) self._write(f, value) - + + def set(self, key: str, value: bytes) -> None: + self._set(self.REGULAR_PREFIX + key, value) + def get(self, key: str) -> bytes: last_file_size = None key = self.REGULAR_PREFIX + key @@ -124,7 +141,7 @@ def get(self, key: str) -> bytes: with open(self._filename, "ab+") as f: f.seek(0) file_size = self._file_size(f) - if key not in self._cache.keys() and file_size == last_file_size: + if key not in self._cache and file_size == last_file_size: # No new entries last_file_size = file_size self._lock.release() @@ -141,14 +158,15 @@ def get(self, key: str) -> bytes: def add(self, key: str, amount: int) -> int: return self._add(self.REGULAR_PREFIX + key, amount) - + def compare_set(self, key: str, expected: bytes, desired: bytes) -> bytes: key = self.REGULAR_PREFIX + key with self._lock: with open(self._filename, "ab+") as f: f.seek(0) - self._update() - has_key = key in self._cache.keys() + self._update(f) + has_key = key in self._cache + print(has_key, self._cache[key]) if (not has_key and expected == b'') or (has_key and self._cache[key] == expected): f.seek(0, 2) self._write(f, bytes(key, encoding='raw_unicode_escape')) @@ -158,7 +176,7 @@ def compare_set(self, key: str, expected: bytes, desired: bytes) -> bytes: return expected return self._cache[key] - def wait(self, keys: List[str], timeout: Optional[timedelta]=None) -> None: + def wait(self, keys: List[str], timeout: Optional[timedelta] = None) -> None: timeout = self._timeout if self._timeout is not None else timeout start_t = datetime.now() keys = [self.REGULAR_PREFIX + key for key in keys] @@ -172,13 +190,14 @@ def num_keys(self): with open(self._filename, "rb") as f: self._update(f) return len(self._cache) - + def delete_key(self, key: str): - self.set(self.DELETE_PREFIX + key, b'') + self._set(self.DELETE_PREFIX + self.REGULAR_PREFIX + key, b'') def set_timeout(self, timeout: timedelta): self._timeout = timeout + if __name__ == '__main__': store = FileStore('tmp') store.set_timeout(timedelta(seconds=30)) diff --git a/python/hidet/drivers/build_graph.py b/python/hidet/drivers/build_graph.py index 34a376325..994d13d88 100644 --- a/python/hidet/drivers/build_graph.py +++ b/python/hidet/drivers/build_graph.py @@ -22,12 +22,7 @@ from hidet.graph.tensor import Tensor from hidet.graph.flow_graph import FlowGraph from hidet.runtime.compiled_module import CompiledModule -from hidet.runtime.compiled_graph import ( - CompiledGraph, - GraphMetaData, - GraphExecution, - GraphExecutionInstruction, -) +from hidet.runtime.compiled_graph import CompiledGraph, GraphMetaData, GraphExecution, GraphExecutionInstruction from hidet.runtime.compiled_task import CompiledTask, TensorSignature from hidet.graph.operator import Operator from hidet.ir import primitives @@ -146,6 +141,7 @@ def get_graph_meta_data(graph: FlowGraph, num_kernels, space: int) -> GraphMetaD inputs=inputs, outputs=outputs, hidet_version=hidet.__version__, num_kernels=num_kernels, graph_hash=graph_hash ) + def build_graph_module(graph: FlowGraph, graph_weights: List[Tensor], node2kernel: List[int]) -> CompiledModule: from hidet.lang import void_p, attrs, int32, int64, meta, cast from hidet.ir.primitives.runtime import memory_planner_init, memory_planner_allocate, memory_planner_free diff --git a/python/hidet/graph/flow_graph.py b/python/hidet/graph/flow_graph.py index b724c254f..996d02bff 100644 --- a/python/hidet/graph/flow_graph.py +++ b/python/hidet/graph/flow_graph.py @@ -103,6 +103,7 @@ def benchmark(self, output_dir='./outs/benchmark', print_summary: bool = False, def forward_context() -> GraphForwardContext: return GraphForwardContext() + class FlowGraph: """The computation graph representation.""" diff --git a/python/hidet/graph/ops/distributed.py b/python/hidet/graph/ops/distributed.py index 511fd112e..3cc5a964b 100644 --- a/python/hidet/graph/ops/distributed.py +++ b/python/hidet/graph/ops/distributed.py @@ -29,7 +29,7 @@ def __init__(self, x: TensorNode, op: str, comm_id: int = 0): super().__init__('all_reduce', inputs=[x], outputs=[y], attributes={'comm_id': comm_id, 'op': op}) def __str__(self): - return f"all_reduce" + return "all_reduce" def implement(self, target: Union[Target, str], working_dir: str) -> List[IRModule]: import hidet diff --git a/python/hidet/runtime/compiled_graph.py b/python/hidet/runtime/compiled_graph.py index 05f866a40..54205d381 100644 --- a/python/hidet/runtime/compiled_graph.py +++ b/python/hidet/runtime/compiled_graph.py @@ -9,7 +9,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional, Tuple, Dict, Any, Callable, Union +from typing import List, Optional, Tuple, Dict, Any, Callable import zipfile import os import json @@ -62,6 +62,7 @@ class GraphExecution: outputs_index: List[int] tensor_device: List[str] + class CompiledGraph: def __init__( self, @@ -363,7 +364,6 @@ def _save_under(dir_path: str, dir_in_zip: str, exclude: Optional[List[str]] = N f.write(model.graph_string.encode('utf-8')) - def load_compiled_graph(path: str) -> CompiledGraph: from hidet.utils.dataclass import from_dict @@ -410,8 +410,6 @@ def load_compiled_graph(path: str) -> CompiledGraph: graph_string = f.read() # construct the compiled graph - ret = CompiledGraph( - meta_data, graph_module, weights, compiled_tasks, graph_execution, graph_string - ) + ret = CompiledGraph(meta_data, graph_module, weights, compiled_tasks, graph_execution, graph_string) return ret diff --git a/python/hidet/transforms/annotate_header_and_libs.py b/python/hidet/transforms/annotate_header_and_libs.py index da1cf6977..53a974a8e 100644 --- a/python/hidet/transforms/annotate_header_and_libs.py +++ b/python/hidet/transforms/annotate_header_and_libs.py @@ -11,14 +11,15 @@ # limitations under the License. import hidet.ir from hidet.ir.module import IRModule -from hidet.ir import Stmt from hidet.ir.stmt import BlackBoxStmt from hidet.transforms import Pass + def _use_distributed(func) -> bool: black_stmts = hidet.ir.tools.collect(func.body, [BlackBoxStmt]) return any(stmt.template_string.startswith('nccl') for stmt in black_stmts) + class AnnotateHeaderAndLibsPass(Pass): def process_module(self, ir_module: IRModule) -> IRModule: use_dist = any(_use_distributed(func) for func in ir_module.functions.values()) diff --git a/tests/unit_tests/test_store.py b/tests/unit_tests/test_store.py index e386a9dc7..5e37584bb 100644 --- a/tests/unit_tests/test_store.py +++ b/tests/unit_tests/test_store.py @@ -8,7 +8,8 @@ from hidet.distributed import FileStore -TMP_PATH='./tmp' +TMP_PATH = './tmp' + def test_filestore_get_hold(): if os.path.exists(TMP_PATH): @@ -17,7 +18,7 @@ def test_filestore_get_hold(): def subproc(): store = FileStore(TMP_PATH) store.get('non-existing-key') - + p = Process(target=subproc) p.start() store = FileStore(TMP_PATH) @@ -26,36 +27,38 @@ def subproc(): assert p.is_alive() p.terminate() + def test_filestore_set_get(): if os.path.exists(TMP_PATH): os.remove(TMP_PATH) - + def subproc(q): store = FileStore(TMP_PATH) store.set_timeout(timedelta(seconds=10)) b = store.get('key') q.put(b) - + store = FileStore(TMP_PATH) store.set('key', random.randbytes(8)) new_value = random.randbytes(8) store.set('key', new_value) q = Queue() - p = Process(target=subproc, args=(q, )) + p = Process(target=subproc, args=(q,)) p.start() ret = q.get() assert ret == new_value p.join() - + + def test_filestore_add(): if os.path.exists(TMP_PATH): os.remove(TMP_PATH) - + def subproc(): store = FileStore(TMP_PATH) store.add('cnt', 1) store.add('cnt', 2) - + store = FileStore(TMP_PATH) store.add('cnt', 1) p = Process(target=subproc) @@ -63,7 +66,8 @@ def subproc(): p.join() ret = store.add('cnt', 2) assert ret == 6 - + + def test_filestore_del(): if os.path.exists(TMP_PATH): os.remove(TMP_PATH) @@ -71,7 +75,7 @@ def test_filestore_del(): def subproc(): store = FileStore(TMP_PATH) store.get('key') - + p = Process(target=subproc) p.start() store = FileStore(TMP_PATH) @@ -79,7 +83,8 @@ def subproc(): store.delete_key('key') time.sleep(1) assert p.is_alive() - p.terminate() + p.terminate() + def test_filestore_wait(): if os.path.exists(TMP_PATH): @@ -88,7 +93,7 @@ def test_filestore_wait(): def subproc(): store = FileStore(TMP_PATH) store.wait(['key'], timeout=timedelta(seconds=10)) - + p = Process(target=subproc) p.start() store = FileStore(TMP_PATH) @@ -98,6 +103,7 @@ def subproc(): p.join() assert not p.is_alive() + def test_filestore_compare_set(): if os.path.exists(TMP_PATH): os.remove(TMP_PATH) @@ -106,7 +112,7 @@ def subproc(): store = FileStore(TMP_PATH) store.compare_set("key", b"first", b"second") - store = FileStore(TMP_PATH) + store = FileStore(TMP_PATH) store.set("key", b"random") p = Process(target=subproc) p.start() From 3fd749127bf88f3330903563cd0b84ee723d369e Mon Sep 17 00:00:00 2001 From: su <soodoshll@gmail.com> Date: Tue, 27 Jun 2023 15:18:39 -0400 Subject: [PATCH 31/40] update --- examples/distributed/test.py | 4 ++++ python/hidet/cuda/nccl/comm.py | 2 +- python/hidet/cuda/nccl/ffi.py | 2 ++ python/hidet/distributed/__init__.py | 2 +- python/hidet/distributed/group.py | 1 + 5 files changed, 9 insertions(+), 2 deletions(-) diff --git a/examples/distributed/test.py b/examples/distributed/test.py index 72caafb74..57a33c1cb 100644 --- a/examples/distributed/test.py +++ b/examples/distributed/test.py @@ -35,6 +35,10 @@ def run(world_size, rank): device = f"cuda:{rank}" x = hidet.randn([1, 3], device=device) w = hidet.randn([3, 2], device=device) + + # test runtime distributed op + hidet.distributed.all_reduce(w, 'avg') + # print(w) # Create Computation Graph x_symb = hidet.symbol_like(x) diff --git a/python/hidet/cuda/nccl/comm.py b/python/hidet/cuda/nccl/comm.py index 801549f8e..416d0f0cf 100644 --- a/python/hidet/cuda/nccl/comm.py +++ b/python/hidet/cuda/nccl/comm.py @@ -85,7 +85,7 @@ def all_reduce( if s is None: s = current_stream() nccl_runtime_api.all_reduce( - sendbuff, recvbuff, count, dtype_to_nccl(datatype), str_to_nccl_op(op), self._handle, s + sendbuff, recvbuff, count, int(dtype_to_nccl(datatype)), int(str_to_nccl_op(op)), self._handle, s ) diff --git a/python/hidet/cuda/nccl/ffi.py b/python/hidet/cuda/nccl/ffi.py index a479a147b..5605d38cd 100644 --- a/python/hidet/cuda/nccl/ffi.py +++ b/python/hidet/cuda/nccl/ffi.py @@ -137,7 +137,9 @@ def comm_split(comm_handle: int, color: int, key: int) -> int: def all_reduce( sendbuff: int, recvbuff: int, count: int, datatype: int, op: int, comm_handle: int, s: Stream ) -> None: + print("get all_reduce request", sendbuff, recvbuff, count, datatype, op, comm_handle, s.handle()) ret = NCCLRuntimeAPI._all_reduce(sendbuff, recvbuff, count, datatype, op, comm_handle, s.handle()) + print(ret) assert ret == 0 nccl_runtime_api = NCCLRuntimeAPI() diff --git a/python/hidet/distributed/__init__.py b/python/hidet/distributed/__init__.py index 665c77c5c..a5b06eb29 100644 --- a/python/hidet/distributed/__init__.py +++ b/python/hidet/distributed/__init__.py @@ -10,6 +10,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .distributed import init_process_group +from .distributed import init_process_group, all_reduce from .group import set_nccl_comms from .store import FileStore diff --git a/python/hidet/distributed/group.py b/python/hidet/distributed/group.py index c57fdc7ab..8198af3a1 100644 --- a/python/hidet/distributed/group.py +++ b/python/hidet/distributed/group.py @@ -81,6 +81,7 @@ def all_reduce(self, tensor: Tensor, op: str): assert not tensor.is_symbolic() assert tensor.device.is_cuda() addr = tensor.storage.addr + print(addr, tensor.nbytes, tensor, tensor.dtype) self._comm.all_reduce(addr, addr, tensor.nbytes, tensor.dtype, op) From 8bc856fa00da6d04da3d776151de1458ac949960 Mon Sep 17 00:00:00 2001 From: Qidong <soodoshll@gmail.com> Date: Tue, 27 Jun 2023 15:38:50 -0400 Subject: [PATCH 32/40] update --- examples/distributed/test.py | 6 +++++- python/hidet/cuda/nccl/ffi.py | 2 -- python/hidet/distributed/distributed.py | 10 ---------- python/hidet/distributed/group.py | 6 +++--- 4 files changed, 8 insertions(+), 16 deletions(-) diff --git a/examples/distributed/test.py b/examples/distributed/test.py index 57a33c1cb..7fc1acb81 100644 --- a/examples/distributed/test.py +++ b/examples/distributed/test.py @@ -16,6 +16,7 @@ import numpy import argparse import atexit +import os import hidet import hidet.cuda.nccl @@ -38,7 +39,7 @@ def run(world_size, rank): # test runtime distributed op hidet.distributed.all_reduce(w, 'avg') - # print(w) + print(w) # Create Computation Graph x_symb = hidet.symbol_like(x) @@ -54,6 +55,9 @@ def run(world_size, rank): print(f"process {rank}\nbefore allreduce:{y_local}\nafter allreduce:{y_sync}\n", end='') atexit._run_exitfuncs() +if os.path.exists('tmp'): + os.remove('tmp') + world_size = args.n_gpus processes = [Process(target=run, args=(world_size, i)) for i in range(world_size)] diff --git a/python/hidet/cuda/nccl/ffi.py b/python/hidet/cuda/nccl/ffi.py index 5605d38cd..a479a147b 100644 --- a/python/hidet/cuda/nccl/ffi.py +++ b/python/hidet/cuda/nccl/ffi.py @@ -137,9 +137,7 @@ def comm_split(comm_handle: int, color: int, key: int) -> int: def all_reduce( sendbuff: int, recvbuff: int, count: int, datatype: int, op: int, comm_handle: int, s: Stream ) -> None: - print("get all_reduce request", sendbuff, recvbuff, count, datatype, op, comm_handle, s.handle()) ret = NCCLRuntimeAPI._all_reduce(sendbuff, recvbuff, count, datatype, op, comm_handle, s.handle()) - print(ret) assert ret == 0 nccl_runtime_api = NCCLRuntimeAPI() diff --git a/python/hidet/distributed/distributed.py b/python/hidet/distributed/distributed.py index 619752b9e..6852629ba 100644 --- a/python/hidet/distributed/distributed.py +++ b/python/hidet/distributed/distributed.py @@ -13,7 +13,6 @@ from typing import Optional from datetime import timedelta -import hidet from hidet.graph import Tensor from hidet.cuda.nccl import nccl_available from .store import Store, FileStore @@ -102,12 +101,3 @@ def scatter(): def reduce_scatter_tensor(): raise NotImplementedError() - - -if __name__ == '__main__': - init_process_group(init_method='file://tmp', world_size=1, rank=0) - print(is_initialized()) - test = hidet.randn((2, 2), device='cuda') - print(test) - all_reduce(test, 'sum') - print(test) diff --git a/python/hidet/distributed/group.py b/python/hidet/distributed/group.py index 8198af3a1..e8404f9ec 100644 --- a/python/hidet/distributed/group.py +++ b/python/hidet/distributed/group.py @@ -81,7 +81,6 @@ def all_reduce(self, tensor: Tensor, op: str): assert not tensor.is_symbolic() assert tensor.device.is_cuda() addr = tensor.storage.addr - print(addr, tensor.nbytes, tensor, tensor.dtype) self._comm.all_reduce(addr, addr, tensor.nbytes, tensor.dtype, op) @@ -90,8 +89,9 @@ def create_nccl_group(store: Store, world_size: int, rank: int): unique_id = create_unique_id() store.set('unique_id', unique_id.internal) else: - unique_id = store.get('unique_id') - unique_id = NcclUniqueId(unique_id) + _id = store.get('unique_id') + unique_id = NcclUniqueId() + unique_id.internal[:] = _id[:] comm = create_comm(world_size, unique_id, rank) return NCCLProcessGroup(comm, world_size, rank) From bb4d6d1ff0433455184ee7fe430ef8f39567e9a5 Mon Sep 17 00:00:00 2001 From: Qidong <soodoshll@gmail.com> Date: Tue, 27 Jun 2023 15:39:04 -0400 Subject: [PATCH 33/40] format --- python/hidet/distributed/group.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/hidet/distributed/group.py b/python/hidet/distributed/group.py index e8404f9ec..c62542b35 100644 --- a/python/hidet/distributed/group.py +++ b/python/hidet/distributed/group.py @@ -10,7 +10,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -#pylint: disable=W0223 +# pylint: disable=W0223 from typing import Optional, List From 8518e9e43e8f5cd7c8a34703c41bb35a4cfe8586 Mon Sep 17 00:00:00 2001 From: su <soodoshll@gmail.com> Date: Tue, 27 Jun 2023 17:49:36 -0400 Subject: [PATCH 34/40] update --- python/hidet/cuda/nccl/ffi.py | 12 ++++++------ python/hidet/distributed/group.py | 1 + 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/python/hidet/cuda/nccl/ffi.py b/python/hidet/cuda/nccl/ffi.py index a479a147b..f415cb006 100644 --- a/python/hidet/cuda/nccl/ffi.py +++ b/python/hidet/cuda/nccl/ffi.py @@ -76,19 +76,19 @@ class NCCLRuntimeAPI: _comm_count = get_func('ncclCommCount', [c_void_p, POINTER(c_int)], c_int, lib=_LIB_NCCL) _all_reduce = get_func( - 'ncclAllReduce', [c_void_p, c_void_p, c_uint64, c_int, c_int, c_void_p, c_int], c_int, lib=_LIB_NCCL + 'ncclAllReduce', [c_void_p, c_void_p, c_uint64, c_int, c_int, c_void_p, c_void_p], c_int, lib=_LIB_NCCL ) _broadcast = get_func( - 'ncclBroadcast', [c_void_p, c_void_p, c_uint64, c_int, c_int, c_void_p, c_int], c_int, lib=_LIB_NCCL + 'ncclBroadcast', [c_void_p, c_void_p, c_uint64, c_int, c_int, c_void_p, c_void_p], c_int, lib=_LIB_NCCL ) _reduce = get_func( - 'ncclReduce', [c_void_p, c_void_p, c_uint64, c_int, c_int, c_int, c_void_p, c_int], c_int, lib=_LIB_NCCL + 'ncclReduce', [c_void_p, c_void_p, c_uint64, c_int, c_int, c_int, c_void_p, c_void_p], c_int, lib=_LIB_NCCL ) _all_gather = get_func( - 'ncclAllGather', [c_void_p, c_void_p, c_uint64, c_int, c_void_p, c_int], c_int, lib=_LIB_NCCL + 'ncclAllGather', [c_void_p, c_void_p, c_uint64, c_int, c_void_p, c_void_p], c_int, lib=_LIB_NCCL ) _reduce_scatter = get_func( - 'ncclReduceScatter', [c_void_p, c_void_p, c_uint64, c_int, c_int, c_void_p, c_int], c_int, lib=_LIB_NCCL + 'ncclReduceScatter', [c_void_p, c_void_p, c_uint64, c_int, c_int, c_void_p, c_void_p], c_int, lib=_LIB_NCCL ) # Early versions of NCCL do not have split @@ -137,7 +137,7 @@ def comm_split(comm_handle: int, color: int, key: int) -> int: def all_reduce( sendbuff: int, recvbuff: int, count: int, datatype: int, op: int, comm_handle: int, s: Stream ) -> None: - ret = NCCLRuntimeAPI._all_reduce(sendbuff, recvbuff, count, datatype, op, comm_handle, s.handle()) + ret = NCCLRuntimeAPI._all_reduce(sendbuff, recvbuff, count, datatype, op, comm_handle, c_void_p(int(s))) assert ret == 0 nccl_runtime_api = NCCLRuntimeAPI() diff --git a/python/hidet/distributed/group.py b/python/hidet/distributed/group.py index c62542b35..601d99930 100644 --- a/python/hidet/distributed/group.py +++ b/python/hidet/distributed/group.py @@ -78,6 +78,7 @@ def size(self) -> int: return self._world_size def all_reduce(self, tensor: Tensor, op: str): + print(tensor, op) assert not tensor.is_symbolic() assert tensor.device.is_cuda() addr = tensor.storage.addr From dcb87aaec9a937c58e3e769a28736e6b1be7efbd Mon Sep 17 00:00:00 2001 From: su <soodoshll@gmail.com> Date: Tue, 27 Jun 2023 17:54:28 -0400 Subject: [PATCH 35/40] fix --- python/hidet/distributed/group.py | 1 - tests/unit_tests/test_store.py | 15 +++++++++++++-- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/python/hidet/distributed/group.py b/python/hidet/distributed/group.py index 601d99930..c62542b35 100644 --- a/python/hidet/distributed/group.py +++ b/python/hidet/distributed/group.py @@ -78,7 +78,6 @@ def size(self) -> int: return self._world_size def all_reduce(self, tensor: Tensor, op: str): - print(tensor, op) assert not tensor.is_symbolic() assert tensor.device.is_cuda() addr = tensor.storage.addr diff --git a/tests/unit_tests/test_store.py b/tests/unit_tests/test_store.py index 5e37584bb..941b9a0dc 100644 --- a/tests/unit_tests/test_store.py +++ b/tests/unit_tests/test_store.py @@ -1,3 +1,14 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import pytest import multiprocessing from multiprocessing import Process, Queue @@ -39,8 +50,8 @@ def subproc(q): q.put(b) store = FileStore(TMP_PATH) - store.set('key', random.randbytes(8)) - new_value = random.randbytes(8) + store.set('key', b'u98guj89ks') + new_value = b'32894728934798' store.set('key', new_value) q = Queue() p = Process(target=subproc, args=(q,)) From a2d8be612f4016e67f3027a2b5dbf7c7718bf23f Mon Sep 17 00:00:00 2001 From: su <soodoshll@gmail.com> Date: Tue, 27 Jun 2023 18:00:04 -0400 Subject: [PATCH 36/40] format --- tests/unit_tests/test_store.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit_tests/test_store.py b/tests/unit_tests/test_store.py index 941b9a0dc..8ae8ede1f 100644 --- a/tests/unit_tests/test_store.py +++ b/tests/unit_tests/test_store.py @@ -51,7 +51,7 @@ def subproc(q): store = FileStore(TMP_PATH) store.set('key', b'u98guj89ks') - new_value = b'32894728934798' + new_value = b'32894728934798' store.set('key', new_value) q = Queue() p = Process(target=subproc, args=(q,)) From 917d24f3b451d8885ba80edc099d406684554540 Mon Sep 17 00:00:00 2001 From: Qidong <soodoshll@gmail.com> Date: Wed, 28 Jun 2023 19:32:44 -0400 Subject: [PATCH 37/40] fix --- python/hidet/distributed/store.py | 12 ------------ .../test_store.py => distributed/test_file_store.py} | 0 2 files changed, 12 deletions(-) rename tests/{unit_tests/test_store.py => distributed/test_file_store.py} (100%) diff --git a/python/hidet/distributed/store.py b/python/hidet/distributed/store.py index 98d9d1570..e4bb35f8f 100644 --- a/python/hidet/distributed/store.py +++ b/python/hidet/distributed/store.py @@ -196,15 +196,3 @@ def delete_key(self, key: str): def set_timeout(self, timeout: timedelta): self._timeout = timeout - - -if __name__ == '__main__': - store = FileStore('tmp') - store.set_timeout(timedelta(seconds=30)) - ret = store.add('baga', 2) - store.set('yarou', b'haha') - store.wait(['baga', 'yarou']) - print(ret) - ret = store.add('baga', 5) - print(ret) - print(store.get('baga')) diff --git a/tests/unit_tests/test_store.py b/tests/distributed/test_file_store.py similarity index 100% rename from tests/unit_tests/test_store.py rename to tests/distributed/test_file_store.py From fdf749f7f82a8f2c4794de233ddd58433b37dc04 Mon Sep 17 00:00:00 2001 From: Qidong <soodoshll@gmail.com> Date: Wed, 28 Jun 2023 19:33:07 -0400 Subject: [PATCH 38/40] fix --- python/hidet/distributed/distributed.py | 2 + python/hidet/distributed/group.py | 4 +- python/hidet/distributed/store.py | 51 +++++++++++++++++-------- 3 files changed, 39 insertions(+), 18 deletions(-) diff --git a/python/hidet/distributed/distributed.py b/python/hidet/distributed/distributed.py index 6852629ba..b60e433dc 100644 --- a/python/hidet/distributed/distributed.py +++ b/python/hidet/distributed/distributed.py @@ -67,6 +67,8 @@ def init_process_group( if not is_nccl_available(): raise RuntimeError("NCCL is not found.") DEFAULT_GROUP = create_nccl_group(store, world_size, rank) + else: + raise ValueError(f"Backend {backend} is not supported.") def is_initialized(): diff --git a/python/hidet/distributed/group.py b/python/hidet/distributed/group.py index c62542b35..d60de97dc 100644 --- a/python/hidet/distributed/group.py +++ b/python/hidet/distributed/group.py @@ -60,8 +60,8 @@ def barrier(self): raise NotImplementedError() -NCCL_COMMS = [] -_NCCL_ARRAY = None +NCCL_COMMS: List[NcclCommunicator] = [] +_NCCL_ARRAY: 'Array' = None class NCCLProcessGroup(ProcessGroup): diff --git a/python/hidet/distributed/store.py b/python/hidet/distributed/store.py index e4bb35f8f..631b4c25c 100644 --- a/python/hidet/distributed/store.py +++ b/python/hidet/distributed/store.py @@ -10,7 +10,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional +from typing import List, Optional, Dict from datetime import timedelta, datetime import time import struct @@ -46,17 +46,40 @@ def set_timeout(self, timeout: timedelta): class FileStore(Store): + """ + A shared KV-store based on the local filesystem. + + It will create a binary file (specified by the filename argument) and a locking file. + Each time an new entry (key, value) is requested to be inserted, it will be inserted to + the end of the file. Only the newest is effective among all entries with the same key. + So when scanning the file from beginning, we can get the up-to-date status of the KV-store. + + All keys requested by public methods will be given a prefix '+' (REGULAR_PREFIX) to be + distinguished from some internal keys used by the store itself. For example, we have an + internal entry 'cnt' to maintain how many clients are using this store currently. + + Keys will be converted from Python strings to bytes automatically, while values won't since + values can be arbitary bytes arrays that might not be decodable. So please do the conversion + manually if required. + + We use a 4-byte integer to record the length of each (encoded) key and value. So do not insert + more than 32768 bytes for each entry. + + Deletion of an entry is done by adding a new entry with a suffix '-' (DELETE_PREFIX). It will + overwrite the insertion of the given entry when we scanning the file. + """ + REGULAR_PREFIX = '+' DELETE_PREFIX = '-' def __init__(self, filename: str, world_size: Optional[int] = -1): - self._filename = filename - self._lock_filename = filename + '.lock' - self._world_size = world_size + self._filename: str = filename + self._lock_filename: str = filename + '.lock' + self._world_size: int = world_size - self._lock = filelock.FileLock(self._lock_filename) - self._cache = {} - self._timeout = None + self._lock: filelock.FileLock = filelock.FileLock(self._lock_filename) + self._cache: Dict[str, bytes] = {} + self._timeout: timedelta = None num_peers = self._add('cnt', 1) if 0 <= world_size < num_peers: @@ -98,7 +121,7 @@ def _update(self, f): if k is None: return v = self._read(f) - k = str(k, encoding='raw_unicode_escape') + k = k.decode() if k.startswith(self.DELETE_PREFIX): k = k[len(self.DELETE_PREFIX) :] del self._cache[k] @@ -112,8 +135,8 @@ def _add(self, key: str, amount: int) -> int: self._update(f) value = int(self._cache.get(key, '0')) + amount with open(self._filename, "ab+") as f: - self._write(f, bytes(key, encoding='raw_unicode_escape')) - self._write(f, bytes(str(value), encoding='raw_unicode_escape')) + self._write(f, bytes(key, encoding='utf-8')) + self._write(f, bytes(str(value), encoding='utf-8')) return value def _check(self, keys: List[str]): @@ -126,7 +149,7 @@ def _check(self, keys: List[str]): def _set(self, key: str, value: bytes): with self._lock: with open(self._filename, "ab+") as f: - self._write(f, bytes(key, encoding='raw_unicode_escape')) + self._write(f, bytes(key, encoding='utf-8')) self._write(f, value) def set(self, key: str, value: bytes) -> None: @@ -139,11 +162,9 @@ def get(self, key: str) -> bytes: while True: self._lock.acquire() with open(self._filename, "ab+") as f: - f.seek(0) file_size = self._file_size(f) if key not in self._cache and file_size == last_file_size: # No new entries - last_file_size = file_size self._lock.release() if self._timeout is not None and datetime.now() - start_t > self._timeout: raise TimeoutError() @@ -163,13 +184,11 @@ def compare_set(self, key: str, expected: bytes, desired: bytes) -> bytes: key = self.REGULAR_PREFIX + key with self._lock: with open(self._filename, "ab+") as f: - f.seek(0) self._update(f) has_key = key in self._cache - print(has_key, self._cache[key]) if (not has_key and expected == b'') or (has_key and self._cache[key] == expected): f.seek(0, 2) - self._write(f, bytes(key, encoding='raw_unicode_escape')) + self._write(f, bytes(key, encoding='utf-8')) self._write(f, desired) return desired elif not has_key: From 816da19cd09e0d112a28a18f50d5be679e07cc45 Mon Sep 17 00:00:00 2001 From: Qidong <soodoshll@gmail.com> Date: Wed, 28 Jun 2023 19:39:07 -0400 Subject: [PATCH 39/40] remove redundant seek --- python/hidet/distributed/store.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/python/hidet/distributed/store.py b/python/hidet/distributed/store.py index 631b4c25c..5a7053f06 100644 --- a/python/hidet/distributed/store.py +++ b/python/hidet/distributed/store.py @@ -131,7 +131,6 @@ def _update(self, f): def _add(self, key: str, amount: int) -> int: with self._lock: with open(self._filename, "ab+") as f: - f.seek(0) self._update(f) value = int(self._cache.get(key, '0')) + amount with open(self._filename, "ab+") as f: @@ -142,7 +141,6 @@ def _add(self, key: str, amount: int) -> int: def _check(self, keys: List[str]): with self._lock: with open(self._filename, "ab+") as f: - f.seek(0) self._update(f) return all((key in self._cache for key in keys)) From c3eee0dd54dfaa87037fa59cb4989452dede8c3e Mon Sep 17 00:00:00 2001 From: Qidong <soodoshll@gmail.com> Date: Wed, 28 Jun 2023 22:02:53 -0400 Subject: [PATCH 40/40] fix --- python/hidet/distributed/store.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/hidet/distributed/store.py b/python/hidet/distributed/store.py index 5a7053f06..5e53f2c2e 100644 --- a/python/hidet/distributed/store.py +++ b/python/hidet/distributed/store.py @@ -63,7 +63,7 @@ class FileStore(Store): manually if required. We use a 4-byte integer to record the length of each (encoded) key and value. So do not insert - more than 32768 bytes for each entry. + more than 2^31 - 1 bytes for each entry. Deletion of an entry is done by adding a new entry with a suffix '-' (DELETE_PREFIX). It will overwrite the insertion of the given entry when we scanning the file.