Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Clean up prefetched parameters #6557

Open
wants to merge 37 commits into
base: master
Choose a base branch
from
Open
Changes from 35 commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
5825104
add apis to offload states of model, optimizer, and engine
tohtana Aug 16, 2024
600c822
update api doc
tohtana Aug 16, 2024
153a482
Merge branch 'master' into tohtana/offload_zero_buffers
tohtana Aug 16, 2024
126d9b7
reduce global reference to buffer
tohtana Aug 16, 2024
05df37c
loosen type hint
tohtana Aug 16, 2024
837c06c
Merge branch 'master' into tohtana/offload_zero_buffers
tohtana Aug 19, 2024
3f8179d
add option for pin_memory and non blocking copy
tohtana Aug 20, 2024
37ffa02
fix offloading of lp grad
tohtana Aug 20, 2024
93c5a90
add verification in test
tohtana Aug 20, 2024
512e9c9
improve offloading of lp params
tohtana Aug 20, 2024
de2a894
Merge branch 'master' into tohtana/offload_zero_buffers
tohtana Aug 21, 2024
c749b05
fix pinning
tohtana Aug 22, 2024
36d6e10
Merge branch 'master' into tohtana/offload_zero_buffers
tohtana Aug 22, 2024
af95a37
resolve conflict
tohtana Aug 28, 2024
1ca3a7f
Merge branch 'master' into tohtana/offload_zero_buffers
tohtana Sep 3, 2024
2a4733e
fix method name and enum key
tohtana Sep 3, 2024
e9a499e
elimitate duplicated buffer for lp param
tohtana Sep 3, 2024
b8f47c6
simplified offloding of adam states
tohtana Sep 4, 2024
d338079
validate devcies of offload states
tohtana Sep 4, 2024
15ff7b3
add document
tohtana Sep 4, 2024
40427c1
Merge branch 'master' into tohtana/offload_zero_buffers
loadams Sep 4, 2024
e20d827
fix usage example
tohtana Sep 4, 2024
031464d
Merge branch 'tohtana/offload_zero_buffers' of github.com:microsoft/D…
tohtana Sep 4, 2024
60deaf1
Merge branch 'master' into tohtana/offload_zero_buffers
tohtana Sep 6, 2024
3f001b6
Merge branch 'master' into tohtana/offload_zero_buffers
tohtana Sep 9, 2024
8f81634
Merge branch 'master' into tohtana/offload_zero_buffers
tohtana Sep 12, 2024
7bc3c66
clean up inflight status params
tohtana Sep 21, 2024
3c5089f
Merge branch 'master' into tohtana/clean_up_prefetch_param
tohtana Sep 27, 2024
45fea6e
Merge branch 'master' into tohtana/clean_up_prefetch_param
tohtana Sep 27, 2024
1963137
Merge branch 'master' into tohtana/clean_up_prefetch_param
tjruwase Sep 30, 2024
ce8636a
Merge branch 'master' into tohtana/clean_up_prefetch_param
tjruwase Oct 7, 2024
c89d576
clean inflight registry when invalidating trace
tohtana Oct 8, 2024
70ea412
Merge branch 'master' into tohtana/clean_up_prefetch_param
tohtana Oct 8, 2024
e92f549
Merge branch 'master' into tohtana/clean_up_prefetch_param
loadams Oct 8, 2024
9b7e0b8
Merge branch 'master' into tohtana/clean_up_prefetch_param
tjruwase Oct 8, 2024
2a2b3ee
Merge branch 'master' into tohtana/clean_up_prefetch_param
loadams Oct 8, 2024
47af705
Merge branch 'master' into tohtana/clean_up_prefetch_param
tjruwase Oct 9, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions deepspeed/runtime/zero/partitioned_param_coordinator.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,11 +153,18 @@ def is_invalid_trace(self) -> bool:
def is_record_trace(self) -> bool:
return self.__trace_mode == ZeRoTraceMode.RECORD

def _clean_inflight_param_registry(self) -> None:
for param, handle in self.__inflight_param_registry.items():
handle.wait()
self.__release_param(param)
self.__inflight_param_registry.clear()

def _invalidate_trace(self) -> None:
if self.is_invalid_trace():
raise RuntimeError("attempted to invalidate already invalid trace")
self.__trace_mode = ZeRoTraceMode.INVALID
self._clear_trace_structures()
self._clean_inflight_param_registry()

def trace_prologue(self, sub_module: Module) -> None:
if self.is_complete_trace():
Expand Down Expand Up @@ -204,9 +211,7 @@ def construct_parameter_trace_from_module_trace(self):

def reset_step(self) -> None:
"""indicate that we have completed one fwd+bwd for the model"""
if self.__inflight_param_registry:
raise RuntimeError(f"still have inflight params "
f"{[p.ds_summary() for p in self.__inflight_param_registry.keys()]}")
self._clean_inflight_param_registry()

if not self.is_complete_trace(): # not self.trace_complete:
# Make sure that recorded submodule orders are identical across ranks
Expand Down Expand Up @@ -409,7 +414,7 @@ def release_and_reset_all(self, module: Module) -> None:
"""release all module parameters"""
for param in iter_params(module, recurse=True):
if param in self.__inflight_param_registry:
raise RuntimeError(f"param {param.ds_summary()} still in flight")
self.__inflight_param_registry.pop(param).wait()

# TODO. make this throw if if there are still active submodules. currently
# there's a hook execution issue
Expand Down
Loading