diff --git a/bmtrain/optim/_function.py b/bmtrain/optim/_function.py
index b72ea8e..d458445 100644
--- a/bmtrain/optim/_function.py
+++ b/bmtrain/optim/_function.py
@@ -30,11 +30,15 @@ def adam_cpu(param_fp32: torch.Tensor, param_fp16: torch.Tensor, delta_info: tor
     assert g_fp16.device == torch.device("cpu"), "g_fp16 must be a cpu tensor"
     assert m_fp32.device == torch.device("cpu"), "m_fp32 must be a cpu tensor"
     assert v_fp32.device == torch.device("cpu"), "v_fp32 must be a cpu tensor"
-    #TODO check avg_delta and var_delta
     assert param_fp32.numel() == param_fp16.numel(), "param_fp32 and param_fp16 must have the same number of elements"
     assert param_fp32.numel() == g_fp16.numel(), "param_fp32 and g_fp16 must have the same number of elements"
     assert param_fp32.numel() == m_fp32.numel(), "param_fp32 and m_fp32 must have the same number of elements"
     assert param_fp32.numel() == v_fp32.numel(), "param_fp32 and v_fp32 must have the same number of elements"
+    if delta_info is not None:
+        assert delta_info.is_contiguous(), "delta_info must be contiguous"
+        assert delta_info.dtype == torch.float32, "delta_info must be float32 tensor"
+        assert delta_info.device == torch.device("cpu"), "delta_info must be a cpu tensor"
+        assert delta_info.numel() == 4, "delta_info have a length of 4"
     bias_correction1 = 1 - beta1 ** step
     bias_correction2 = 1 - beta2 ** step
     if g_fp16.dtype == torch.float16:
diff --git a/bmtrain/optim/adam_offload.py b/bmtrain/optim/adam_offload.py
index 7bbc7e4..c088a5e 100644
--- a/bmtrain/optim/adam_offload.py
+++ b/bmtrain/optim/adam_offload.py
@@ -223,10 +223,6 @@ def load_state_dict(self, state_dict: dict) -> None:
                     with torch.no_grad():
                         v["_param_fp32"] = torch.empty(param.size(), dtype=torch.float32, device="cpu")
                         v["_param_fp32"].copy_(param)
-
-                if "_param_fp32" not in v:
-                    v["_param_fp32"] = torch.empty(param.size(), dtype=torch.float32, device="cpu")
-                    v["_param_fp32"].copy_(param)
 
                 for name, dtype in [("exp_avg", torch.float32), ("exp_avg_sq", torch.float32), ("_param_fp32", torch.float32)]:
                     if name in v:
diff --git a/csrc/bind.cpp b/csrc/bind.cpp
index db5aedc..b8f6fa8 100644
--- a/csrc/bind.cpp
+++ b/csrc/bind.cpp
@@ -32,13 +32,4 @@ PYBIND11_MODULE(C, m) {
     m.def("ncclRecv", &pyNCCLRecv, "nccl recv");
     m.def("ncclCommCount", &pyNCCLCommCount, "nccl comm count");
     m.def("ncclCommUserRank", &pyNCCLCommUserRank, "nccl comm user rank");
-
-    py::class_<CUDAEventScope>(m, "CUDAEventScope")
-        .def(py::init(&CUDAEventScope::create))
-        .def("recordStart", &CUDAEventScope::recordStart)
-        .def("recordEnd", &CUDAEventScope::recordEnd);
-
-    py::class_<PyWatchDog>(m, "WatchDog")
-        .def(py::init(&PyWatchDog::create))
-        .def("watch", &PyWatchDog::watch);
 }
diff --git a/csrc/include/adam_cpu.hpp b/csrc/include/adam_cpu.hpp
index 4e3c923..da6bfe8 100644
--- a/csrc/include/adam_cpu.hpp
+++ b/csrc/include/adam_cpu.hpp
@@ -229,10 +229,12 @@ void adam_cpu_0(
             delta_mutex.unlock();
         }
     });
-    delta_info_ptr[0] = sum_delta / n;
-    delta_info_ptr[1] = sum_sq_delta / n - sum_delta * sum_delta / (n * n);// var = E(x^2) - E(x)^2
-    delta_info_ptr[2] = sum_delta;
-    delta_info_ptr[3] = sum_sq_delta;
+    if (delta_info_ptr != NULL){
+        delta_info_ptr[0] = sum_delta / n;
+        delta_info_ptr[1] = sum_sq_delta / n - sum_delta * sum_delta / (n * n);// var = E(x^2) - E(x)^2
+        delta_info_ptr[2] = sum_delta;
+        delta_info_ptr[3] = sum_sq_delta;
+    }
 }
 
 void adam_cpu_bf16_0(
diff --git a/csrc/include/bind.hpp b/csrc/include/bind.hpp
index bcfcb14..3ff967f 100644
--- a/csrc/include/bind.hpp
+++ b/csrc/include/bind.hpp
@@ -10,11 +10,13 @@ void has_nan_inf_bf16_launcher(int32_t n, std::uintptr_t g_bf16, std::uintptr_t
 void fp16_from_fp32_value_launcher(
     int64_t n,
     std::uintptr_t param_fp32,
-    std::uintptr_t param_fp16);
+    std::uintptr_t param_fp16
+);
 void bf16_from_fp32_value_launcher(
     int64_t n,
     std::uintptr_t param_fp32,
-    std::uintptr_t param_bf16);
+    std::uintptr_t param_bf16
+);
 void cross_entropy_forward_fp16_launcher(
     int32_t m, int32_t n,
     std::uintptr_t input,
@@ -22,14 +24,16 @@ void cross_entropy_forward_fp16_launcher(
     std::uintptr_t target,
     std::uintptr_t softmax,
     std::uintptr_t output,
     int32_t ignore_index,
-    std::uintptr_t stream);
+    std::uintptr_t stream
+);
 void cross_entropy_backward_inplace_fp16_launcher(
     int32_t m, int32_t n,
     std::uintptr_t grad_output,
     std::uintptr_t target,
     std::uintptr_t x,
     int32_t ignore_index,
-    std::uintptr_t stream);
+    std::uintptr_t stream
+);
 void cross_entropy_forward_bf16_launcher(
     int32_t m, int32_t n,
     std::uintptr_t input,
@@ -37,7 +41,8 @@ void cross_entropy_forward_bf16_launcher(
     std::uintptr_t softmax,
     std::uintptr_t output,
     int32_t ignore_index,
-    std::uintptr_t stream);
+    std::uintptr_t stream
+);
 void cross_entropy_backward_inplace_bf16_launcher(
     int32_t m, int32_t n,
     std::uintptr_t grad_output,
     std::uintptr_t target,
@@ -87,7 +92,8 @@ void adam_fp16_launcher(
     float weight_decay,
     float bias_correction1,
     float bias_correction2,
-    uintptr_t stream);
+    uintptr_t stream
+);
 void adam_bf16_launcher(
     int n, std::uintptr_t param_fp32,
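For reference, the Python and C++ changes above make delta_info optional: adam_cpu only validates it when it is supplied, and adam_cpu_0 only writes the delta statistics when the pointer is non-NULL. The sketch below illustrates the resulting contract; check_delta_info is a hypothetical helper that simply mirrors the asserts added in bmtrain/optim/_function.py, not code from the patch.

    import torch

    def check_delta_info(delta_info):
        # Illustrative helper mirroring the validation added to adam_cpu:
        # delta_info may be None (statistics are skipped); otherwise it must be a
        # contiguous float32 CPU tensor with 4 elements (avg, var, sum, sum_sq).
        if delta_info is not None:
            assert delta_info.is_contiguous(), "delta_info must be contiguous"
            assert delta_info.dtype == torch.float32, "delta_info must be float32 tensor"
            assert delta_info.device == torch.device("cpu"), "delta_info must be a cpu tensor"
            assert delta_info.numel() == 4, "delta_info have a length of 4"

    check_delta_info(None)            # allowed: no delta statistics are collected
    check_delta_info(torch.zeros(4))  # allowed: filled with avg/var/sum/sum_sq of the deltas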