[Distributed] all_reduce op and distributed info in graphs #284
@yaoyaoding this PR is ready for review :)
Merely setting environment variables is no longer sufficient to set up the dev environment. We now need to run pip to install the hidet package in develop mode. Users still need to build the C++ source files manually. Should we consider integrating that into `setup.py` in the future?
Thanks @soodoshll !
I left some suggestions on the data organization and implementation.
python/hidet/cuda/nccl/comm.py
Outdated
def init_unique_id(unqie_id: NcclUniqueId) -> None:
    if not nccl_available():
        raise RuntimeError("NCCL is not available")
    nccl_runtime_api.get_unique_id(unqie_id)
Can we define `init_unique_id(...)` as
    def create_unique_id() -> NcclUniqueId:
        ...
I feel the current API is not very intuitive.
The point here is that the NcclUniqueId needs to be shared by all processes. The current solution is:
- Create a shared NcclUniqueId object;
- Launch multiple processes with the shared unique-id object as one argument;
- Initialize the shared unique-id object in process 0, which needs a reference to the shared object.
If we create the NcclUniqueId in process 0 after the processes have been launched, it is not so easy to broadcast it (if there's an elegant way of broadcasting, please let me know).
A workaround is to 1) create the shared object; 2) launch the processes; 3) create a unique id object; 4) copy its value back to the shared object.
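The four-step workaround above can be sketched roughly as follows. This is only an illustration under assumptions: `NcclUniqueId` here is a stand-in ctypes struct (an opaque 128-byte blob, matching the size of `ncclUniqueId` in NCCL's header), `create_unique_id` fills it with random bytes instead of calling NCCL, and `publish`, `worker`, and `launch` are hypothetical names, not hidet APIs.

```python
import ctypes
import multiprocessing as mp
import os
from multiprocessing.sharedctypes import RawValue

NCCL_UNIQUE_ID_BYTES = 128  # size of ncclUniqueId in NCCL's public header


class NcclUniqueId(ctypes.Structure):
    """Stand-in for the real binding: an opaque 128-byte blob."""
    _fields_ = [("internal", ctypes.c_ubyte * NCCL_UNIQUE_ID_BYTES)]


def create_unique_id() -> NcclUniqueId:
    """Hypothetical factory; random bytes stand in for the NCCL call."""
    uid = NcclUniqueId()
    ctypes.memmove(ctypes.byref(uid), os.urandom(NCCL_UNIQUE_ID_BYTES), NCCL_UNIQUE_ID_BYTES)
    return uid


def publish(shared: NcclUniqueId, fresh: NcclUniqueId) -> None:
    """Step 4: copy the freshly created id into the shared object."""
    ctypes.memmove(ctypes.byref(shared), ctypes.byref(fresh), ctypes.sizeof(NcclUniqueId))


def worker(rank: int, shared: NcclUniqueId, barrier) -> None:
    if rank == 0:
        publish(shared, create_unique_id())  # steps 3 and 4, rank 0 only
    barrier.wait()  # other ranks must not read the id before rank 0 wrote it
    # From here on, every rank sees the same bytes in `shared`.


def launch(world_size: int) -> None:
    shared = RawValue(NcclUniqueId)  # step 1: shared, zero-initialised storage
    barrier = mp.Barrier(world_size)
    procs = [mp.Process(target=worker, args=(rank, shared, barrier)) for rank in range(world_size)]
    for p in procs:  # step 2: launch the processes
        p.start()
    for p in procs:
        p.join()
```

The barrier is an addition of this sketch: without some synchronization, ranks other than 0 could read the shared id before it has been written.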
python/hidet/graph/flow_graph.py
Outdated
# For distributed graphs
self.nrank = nrank
self.rank = rank
self.groups = groups
Let's define a new class called `FlowGraphAttrs` and define these attributes in that class. Then add a field of type `FlowGraphAttrs` to `FlowGraph`.
Something like:
    class FlowGraph:
        def __init__(..., attrs=None):
            ...
            self.attrs: FlowGraphAttrs = attrs if attrs else FlowGraphAttrs()
fixed
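For readers following this thread, the `FlowGraphAttrs` suggestion could look roughly like the sketch below. It is an assumption-laden illustration, not the code that was merged: the field names are taken from the outdated diff above (`nrank`, `rank`, `groups`), and the `FlowGraph` constructor is reduced to a hypothetical `outputs`/`inputs` signature.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class FlowGraphAttrs:
    """Distributed attributes grouped in one place (fields from the diff above)."""
    nrank: Optional[int] = None
    rank: Optional[int] = None
    groups: Optional[List[List[int]]] = None


class FlowGraph:
    """Simplified stand-in; the real constructor takes more arguments."""

    def __init__(self, outputs, inputs=None, attrs: Optional[FlowGraphAttrs] = None):
        self.outputs = outputs
        self.inputs = inputs
        # Each graph gets its own attrs object unless one is passed in.
        self.attrs: FlowGraphAttrs = attrs if attrs is not None else FlowGraphAttrs()
```

Grouping the fields this way keeps the `FlowGraph` constructor stable if more distributed attributes are added later.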
python/hidet/graph/flow_graph.py
Outdated
def is_distributed(self):
    return self.nrank is not None or self.rank is not None

def set_dist_attrs(self, nrank: int, rank: int, groups: Optional[List[List[int]]] = None):
    self.nrank = nrank
    self.rank = rank
    self.groups = groups
Let's define these functions in the module that uses this functionality, instead of defining them as FlowGraph methods.
I have replaced them with `set_attrs`.
self.comm_id = comm_id
self.op = op

super().__init__('all_reduce', inputs=[x], outputs=[y], attributes={})
Better to also add `comm_id` and `op` to `attributes`, so that the user can see the `comm_id` and `op` when compiling the task.
    return f"all_reduce_{self.op}_{self.comm_id}"

def implement(self, target: Union[Target, str], working_dir: str) -> List[IRModule]:
    # we may need current rank here to avoid duplicated working_dirs
Could you clarify the problem here? Thanks.
If we add the `comm_id` to `attributes`, then the op hash would be different.
If we run the compilation concurrently in multiple processes for the same op, there might be race conditions on the local filesystem.
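The two concerns in this thread (attributes changing the op hash, and concurrent compilations racing on the same directory) can be illustrated with a small sketch. Everything here is hypothetical: `task_key` is not how hidet computes its hash, and `write_artifact_atomically` is just one common way to make concurrent writers safe, namely writing to a temporary file and renaming it into place.

```python
import hashlib
import json
import os
import tempfile


def task_key(name: str, attributes: dict) -> str:
    """Hypothetical cache key: any change in the attributes changes the key."""
    payload = json.dumps({"name": name, "attrs": attributes}, sort_keys=True)
    return hashlib.sha256(payload.encode("utf-8")).hexdigest()[:16]


def write_artifact_atomically(cache_root: str, key: str, filename: str, content: bytes) -> str:
    """Write into a private temp file, then rename it over the final path.

    os.replace is atomic on the same filesystem, so a concurrent reader sees
    either the old file or the complete new one, never a partial write.
    """
    work_dir = os.path.join(cache_root, key)
    os.makedirs(work_dir, exist_ok=True)  # safe if another process created it first
    fd, tmp_path = tempfile.mkstemp(dir=work_dir)
    with os.fdopen(fd, "wb") as f:
        f.write(content)
    final_path = os.path.join(work_dir, filename)
    os.replace(tmp_path, final_path)
    return final_path
```

With a key like this, two ranks compiling `all_reduce` with the same `op` and `comm_id` would map to the same directory, which is exactly the situation where the atomic rename (or a file lock) matters.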
comms_array = comms_to_array(self.nccl_comms)
runtime_api.set_nccl_comms(comms_array)
Let's create this when initializing the dist-related info, to avoid repeatedly creating the comm Array.
@@ -105,6 +114,10 @@ def __init__(
self.cuda_workspace: Optional[Storage] = None
self.cpu_workspace: Optional[Storage] = None

# distributed properties
self.dist_info: Optional[GraphDistributedInfo] = dist_info
Better to put this in GraphMetaData.
I think a better idea is to put the FlowGraphAttr in the GraphMetaData as a whole, instead of repeating all of its attributes. But then where should we put FlowGraphAttr? Putting it in flow_graph.py will cause a circular import.
@@ -105,6 +114,10 @@ def __init__(
self.cuda_workspace: Optional[Storage] = None
self.cpu_workspace: Optional[Storage] = None

# distributed properties
self.dist_info: Optional[GraphDistributedInfo] = dist_info
self.nccl_comms: List[NcclCommunicator] = []
Store it as an Array of NcclCommunicator directly, to avoid repeatedly creating the Array in `run_async`.
An Array of NcclCommunicator cannot be passed directly into C++. C++ needs an array of ncclComm_t, which is basically the handle of a NcclCommunicator. To keep the NcclCommunicators from being released by the GC, we also need to maintain the list of NcclCommunicator objects. If we additionally maintain the ncclComm_t array, we will have two redundant arrays that store almost the same values.
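One way to reconcile the two positions in this thread is to keep the Python objects alive in a list and cache the handle array, rebuilding it only when the list changes. The sketch below assumes a hypothetical `CommRegistry` class and a stand-in `NcclCommunicator` whose `handle` is a plain integer; the real class wraps an `ncclComm_t`.

```python
import ctypes
from typing import List, Optional


class NcclCommunicator:
    """Stand-in: the real class owns an ncclComm_t and frees it when collected."""

    def __init__(self, handle: int):
        self.handle = handle


class CommRegistry:
    """Owns the communicators and lazily builds the handle array for C++."""

    def __init__(self) -> None:
        self._comms: List[NcclCommunicator] = []  # keeps the objects alive
        self._array: Optional[ctypes.Array] = None  # cached array of raw handles

    def add(self, comm: NcclCommunicator) -> None:
        self._comms.append(comm)
        self._array = None  # invalidate the cache; rebuilt on next use

    def as_array(self) -> ctypes.Array:
        if self._array is None:
            handles = [comm.handle for comm in self._comms]
            self._array = (ctypes.c_void_p * len(handles))(*handles)
        return self._array
```

The handle array is still a second copy of the same values, as noted above, but it is derived data that is rebuilt at most once per change rather than on every `run_async` call.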
def _recursive_find(root: Stmt):
    if isinstance(root, BlackBoxStmt):
        if root.template_string.startswith('nccl'):
            return True
    for child in dir(root):
        if isinstance(child, Stmt):
            if _recursive_find(child):
                return True
    return False

ret = _recursive_find(func.body)
Use `hidet.ir.tools.collect` to collect all `BlackBoxStmt` nodes.
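Note that the quoted `_recursive_find` iterates over `dir(root)`, which yields attribute names (strings), so the `isinstance(child, Stmt)` check never succeeds and nested statements are never visited. A generic collector avoids that class of bug. The sketch below is self-contained and does not use hidet: `Stmt`, `BlackBoxStmt`, `SeqStmt`, and `collect` are simplified stand-ins, and the real `hidet.ir.tools.collect` may have a different signature.

```python
from typing import Any, List, Type


class Stmt:
    """Stand-in base class for IR statements."""


class BlackBoxStmt(Stmt):
    def __init__(self, template_string: str):
        self.template_string = template_string


class SeqStmt(Stmt):
    def __init__(self, seq):
        self.seq = list(seq)


def collect(root: Any, node_type: Type) -> List[Any]:
    """Return every node of `node_type` reachable from `root`."""
    found: List[Any] = []

    def visit(node: Any) -> None:
        if isinstance(node, node_type):
            found.append(node)
        if isinstance(node, (list, tuple)):
            for item in node:
                visit(item)
        elif isinstance(node, Stmt):
            # Walk attribute values, not attribute names as dir() would give.
            for value in vars(node).values():
                visit(value)

    visit(root)
    return found


def uses_nccl(body: Stmt) -> bool:
    return any(s.template_string.startswith('nccl') for s in collect(body, BlackBoxStmt))
```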
python/hidet/transforms/__init__.py
Outdated
@@ -80,5 +81,6 @@ def lower(ir_module: IRModule) -> IRModule:
rule_based_simplify_pass(),
inline_let_stmt_pass(),
simplify_stmt_pass(),
include_nccl_pass(),
Later, we will use a pass to add the header information. Let's make this pass a general one and give it a name like "annotate_headers", "annotate_include_headers", or "annotate_header_and_libs".
Previously, if a primitive function called another primitive function, the `instantiate_symbols` pass would update the corresponding `hidet.ir.primitives.func.PrimitiveFunctionRegistry.function` in place (I am not sure exactly how it's done, but this is what I observed), adding symbol variables to its parameters. The primitive function pool is a global variable, so this effect accumulates across tuning candidates: candidate 0 has no problem, candidate 1 has two extra copies of the symbol params, and so on, leading to compile errors. Since primitive functions do not need symbol vars, a quick fix is simply not to instantiate any symbols for them.
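The accumulation described above can be reproduced with a toy model. This is not hidet code: `PRIMITIVE_POOL` is a plain dictionary standing in for the global registry, parameters are just strings, and both `instantiate_*` functions are hypothetical.

```python
from typing import Callable, Dict, List

# Toy global pool: function name -> parameter names.
PRIMITIVE_POOL: Dict[str, List[str]] = {"prim_add": ["a", "b"]}


def instantiate_symbols_in_place(func_name: str, symbols: List[str]) -> List[str]:
    """Buggy variant: mutates the shared entry, so every call appends again."""
    params = PRIMITIVE_POOL[func_name]
    params.extend(symbols)
    return params


def instantiate_symbols_skip_primitives(func_name: str, symbols: List[str]) -> List[str]:
    """Quick fix from the comment: leave primitive functions untouched."""
    if func_name in PRIMITIVE_POOL:
        return PRIMITIVE_POOL[func_name]
    raise NotImplementedError("non-primitive functions are outside this sketch")


def tune(instantiate: Callable[[str, List[str]], List[str]], num_candidates: int) -> List[List[str]]:
    """Return a snapshot of the parameter list seen by each tuning candidate."""
    return [list(instantiate("prim_add", ["n"])) for _ in range(num_candidates)]
```

Running `tune` with the in-place variant yields a parameter list that grows by one symbol per candidate, while the skipping variant returns the same list every time.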
Thanks @soodoshll !
I left some comments.
python/hidet/distributed/group.py
Outdated
NCCL_COMMS = []
_NCCL_ARRAY = None
- NCCL_COMMS = []
- _NCCL_ARRAY = None
+ NCCL_COMMS: List[NcclCommunicator] = []
+ _NCCL_ARRAY: 'Array' = None
python/hidet/distributed/store.py
Outdated
self._filename = filename
self._lock_filename = filename + '.lock'
self._world_size = world_size

self._lock = filelock.FileLock(self._lock_filename)
self._cache = {}
self._timeout = None
Better to add some type annotations to save the code reader's time.
python/hidet/distributed/store.py
Outdated
key = self.REGULAR_PREFIX + key
with self._lock:
    with open(self._filename, "ab+") as f:
        f.seek(0)
f.seek(0)
python/hidet/distributed/store.py
Outdated
f.seek(0)
self._update(f)
has_key = key in self._cache
print(has_key, self._cache[key])
print(has_key, self._cache[key])
python/hidet/distributed/store.py
Outdated
if k is None:
    return
v = self._read(f)
k = str(k, encoding='raw_unicode_escape')
May I ask why we chose this encoding instead of something like 'utf-8'?
Also better to add the reason to the comments.
No special reason other than that `pickle` uses it. Switching to `utf-8`, since it is the default for encoding/decoding.
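To see why the codec choice matters mainly for consistency, the sketch below round-trips a key through both codecs. The helper names `encode_key` and `decode_key` are hypothetical. For plain ASCII keys without backslash escapes the two codecs agree; they only diverge once a key contains non-ASCII characters and the writer and reader use different codecs.

```python
def encode_key(key: str, encoding: str = "utf-8") -> bytes:
    """Encode a store key to bytes before writing it to the file."""
    return key.encode(encoding)


def decode_key(raw: bytes, encoding: str = "utf-8") -> str:
    """Decode a key read from the file; must match the writer's encoding."""
    return str(raw, encoding=encoding)


ascii_key = "rank_0"
accented_key = "rank_é"

# Same codec on both sides: always a clean round trip.
same_codec = decode_key(encode_key(accented_key))

# Mixed codecs: utf-8 bytes read back as raw_unicode_escape are misdecoded.
mixed_codec = decode_key(encode_key(accented_key), encoding="raw_unicode_escape")
```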
tests/unit_tests/test_store.py
Outdated
@@ -0,0 +1,137 @@
# Licensed under the Apache License, Version 2.0 (the "License");
Better to place this test in hidet/tests/distributed/test_file_store.py.
python/hidet/distributed/store.py
Outdated
manually if required.

We use a 4-byte integer to record the length of each (encoded) key and value. So do not insert
more than 32768 bytes for each entry.
A 4-byte integer can represent up to 2^31-1?
Oops. let me fix it
Deletion of an entry is done by adding a new entry with a suffix '-' (DELETE_PREFIX). It will
overwrite the insertion of the given entry when we scanning the file.
"""
Thanks for the comments, now the design is very clear!
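The file layout described in the docstring above (length-prefixed records in an append-only file, with deletions recorded as extra entries) can be sketched as follows. This is a simplified illustration, not the store's actual implementation: the value of `REGULAR_PREFIX` is an assumption ('+'), the helper names are hypothetical, locking is omitted, and an in-memory `io.BytesIO` can stand in for the real file.

```python
import io
import struct
from typing import BinaryIO, Dict, Optional

REGULAR_PREFIX = '+'  # assumed value; only DELETE_PREFIX = '-' appears in the docstring
DELETE_PREFIX = '-'
LENGTH = struct.Struct('<i')  # signed 4-byte length, so at most 2**31 - 1 bytes


def write_field(f: BinaryIO, data: bytes) -> None:
    f.write(LENGTH.pack(len(data)))  # raises struct.error if the length overflows
    f.write(data)


def read_field(f: BinaryIO) -> Optional[bytes]:
    header = f.read(LENGTH.size)
    if len(header) < LENGTH.size:
        return None  # end of file
    (size,) = LENGTH.unpack(header)
    return f.read(size)


def append_set(f: BinaryIO, key: str, value: bytes) -> None:
    write_field(f, (REGULAR_PREFIX + key).encode('utf-8'))
    write_field(f, value)


def append_delete(f: BinaryIO, key: str) -> None:
    write_field(f, (DELETE_PREFIX + key).encode('utf-8'))
    write_field(f, b'')


def replay(f: BinaryIO) -> Dict[str, bytes]:
    """Scan the whole file; later records override earlier ones."""
    cache: Dict[str, bytes] = {}
    f.seek(0)
    while True:
        raw_key = read_field(f)
        if raw_key is None:
            return cache
        value = read_field(f)
        key = raw_key.decode('utf-8')
        prefix, name = key[0], key[1:]
        if prefix == REGULAR_PREFIX:
            cache[name] = value
        elif prefix == DELETE_PREFIX:
            cache.pop(name, None)
```

Because the file is only ever appended to, a reader can rebuild the current state by replaying it from the start, which is what makes a simple file lock sufficient for coordination.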
Thanks @soodoshll ! Looks good to me now. Good job! There seems to be a typo in the comment. Feel free to merge this PR yourself after fixing that.
`all_reduce` op
`all_reduce(relu(x * w))` in `./examples/distributed/test.py`