From 87bf12fccc9d9e1458b11dae26266adcb9a6d55e Mon Sep 17 00:00:00 2001
From: Achazwl <323163497@qq.com>
Date: Fri, 24 Feb 2023 15:41:31 +0800
Subject: [PATCH] fix: make load stream wait default stream after init_parameters

---
 bmtrain/param_init.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/bmtrain/param_init.py b/bmtrain/param_init.py
index 125283e3..8b74c580 100644
--- a/bmtrain/param_init.py
+++ b/bmtrain/param_init.py
@@ -43,6 +43,9 @@ def init_parameters(model : torch.nn.Module):
             module.init_parameters()
         else:
             init_distributed_parameter( iterate_parameters(module) )
+
+    current_stream = torch.cuda.current_stream()
+    config['load_stream'].wait_stream(current_stream)
 
 def grouped_parameters(model : torch.nn.Module) -> Generator[Tuple[str, List[torch.nn.Parameter]], None, None]:
     """
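
For context, the lines added by this patch use the standard PyTorch stream-ordering pattern: parameter initialization is queued on the default (current) stream, and the side CUDA stream kept in config['load_stream'] is made to wait on it before any later loading work runs. Below is a minimal, self-contained sketch of that pattern; the names load_stream and params are stand-ins for illustration, not BMTrain internals.

    # Sketch of the stream-ordering pattern introduced by the patch.
    # Assumption: `load_stream` plays the role of config['load_stream'];
    # `params` is a placeholder tensor. Not the BMTrain implementation itself.
    import torch

    if torch.cuda.is_available():
        load_stream = torch.cuda.Stream()

        # Parameter initialization work is enqueued on the default stream.
        params = torch.empty(1024, device="cuda")
        params.normal_()

        # Order the load stream after everything queued on the default stream
        # so far; this inserts a dependency without blocking the host.
        load_stream.wait_stream(torch.cuda.current_stream())

        with torch.cuda.stream(load_stream):
            # Safe to consume `params` here: this work runs after the init work.
            checksum = params.sum()

Without the wait_stream call, work later submitted on the load stream could race with the initialization kernels still running on the default stream, which is the hazard this commit addresses.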