diff --git a/bmtrain/block_layer.py b/bmtrain/block_layer.py
index ab027fc..9820046 100644
--- a/bmtrain/block_layer.py
+++ b/bmtrain/block_layer.py
@@ -210,11 +210,15 @@ def init_param_storage(self):
                 param.data = torch.tensor([], dtype=param.dtype, device=param.device).set_(self._storage_params[kw_name].storage(), to_offset_st, (to_offset_end - to_offset_st,))
                 self._param_info[-1]["begin"] = to_offset_st
                 self._param_info[-1]["end"] = (to_offset_end - to_offset_st,)
+                setattr(param, "_start_partition", offset_st)
+                setattr(param, "_end_partition", offset_end)
                 param.data[:] = \
                     torch.tensor([], dtype=d_dtype, device=d_device).set_(contiguous_param.storage(), offset_st, (offset_end - offset_st,))[:]
                 del contiguous_param
             else:
                 param.data = torch.tensor([], dtype=param.dtype, device=param.device)
+                setattr(param, "_start_partition", None)
+                setattr(param, "_end_partition", 0)
             # clear parameter data, but keep the dtype and device
             setattr(param, "_in_block", True)
 
diff --git a/bmtrain/optim/_distributed.py b/bmtrain/optim/_distributed.py
new file mode 100644
index 0000000..11daa2b
--- /dev/null
+++ b/bmtrain/optim/_distributed.py
@@ -0,0 +1,29 @@
+import torch
+from ..distributed import all_reduce, all_gather
+
+def state_dict_gather(state_dict):
+    param_key = [p for param_group in state_dict['param_groups'] for p in param_group['params']]
+    for k, v in state_dict['state'].items():
+        if "step" in v:
+            step = v['step']
+
+    for k in param_key:
+        if k not in state_dict['state']:
+            state_dict['state'][k] = {
+                'exp_avg': torch.tensor([], device="cuda", dtype=torch.float32),
+                'exp_avg_sq': torch.tensor([], device="cuda", dtype=torch.float32),
+                '_param_fp32': torch.tensor([], device="cuda", dtype=torch.float32),
+                'step': step
+            }
+        v = state_dict['state'][k]
+        for name, dtype in [("exp_avg", torch.float32), ("exp_avg_sq", torch.float32), ("_param_fp32", torch.float32)]:
+            if name in v:
+                with torch.no_grad():
+                    numel = torch.tensor(v[name].numel(), device="cuda", dtype=torch.long)
+                    max_numel = all_reduce(numel, op="max")
+                    v_p = torch.nn.functional.pad(v[name], (0, max_numel - numel), value=-1e15)
+                    if max_numel > 0:
+                        whole_state = all_gather(v_p.cuda()).flatten()
+                        whole_state = whole_state[whole_state != -1e15]
+                        v[name] = whole_state.contiguous().cpu()
+    return state_dict
\ No newline at end of file
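state_dict_gather reassembles each sharded optimizer-state tensor by padding every rank's slice to the longest length, all-gathering the padded buffers, and stripping the -1e15 sentinel afterwards. A minimal sketch of that same pattern on a standalone tensor (the shard below is hypothetical; real shards come from the optimizer state, and bmt.init_distributed() is assumed to have been called):

    import torch
    import bmtrain as bmt
    from bmtrain.distributed import all_reduce, all_gather

    # hypothetical per-rank shard; ranks may hold different numbers of elements
    shard = torch.randn(3 + bmt.rank(), device="cuda", dtype=torch.float32)

    numel = torch.tensor(shard.numel(), device="cuda", dtype=torch.long)
    max_numel = all_reduce(numel, op="max")                    # longest shard across all ranks
    pad_len = int(max_numel.item() - numel.item())
    padded = torch.nn.functional.pad(shard, (0, pad_len), value=-1e15)
    whole = all_gather(padded).flatten()                       # world_size * max_numel elements
    whole = whole[whole != -1e15]                              # drop the padding sentinel

The sentinel is what lets ranks whose shard is empty participate in the collective without contributing garbage values to the gathered state.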
diff --git a/bmtrain/optim/_function.py b/bmtrain/optim/_function.py
index f04f9ca..d458445 100644
--- a/bmtrain/optim/_function.py
+++ b/bmtrain/optim/_function.py
@@ -1,7 +1,18 @@
 from .. import C
 import torch
 CHECK_INPUT = lambda x: x.is_contiguous() and x.is_cuda
-def adam_cpu(param_fp32: torch.Tensor, param_fp16: torch.Tensor, g_fp16: torch.Tensor, m_fp32: torch.Tensor,
+
+def bf16_from_fp32(param_fp32):
+    param_bf16 = torch.empty_like(param_fp32, dtype=torch.bfloat16)
+    C.to_bf16_from_fp32(param_fp32.numel(), param_fp32.data_ptr(), param_bf16.data_ptr())
+    return param_bf16
+
+def fp16_from_fp32(param_fp32):
+    param_fp16 = torch.empty_like(param_fp32, dtype=torch.float16)
+    C.to_fp16_from_fp32(param_fp32.numel(), param_fp32.data_ptr(), param_fp16.data_ptr())
+    return param_fp16
+
+def adam_cpu(param_fp32: torch.Tensor, param_fp16: torch.Tensor, delta_info: torch.Tensor, g_fp16: torch.Tensor, m_fp32: torch.Tensor,
              v_fp32: torch.Tensor, beta1: float, beta2: float, eps: float, lr: float, scale: float, weight_decay: float,
              step: int) -> None:
     assert param_fp32.is_contiguous(), "param_fp32 must be contiguous"
@@ -23,6 +34,11 @@ def adam_cpu(param_fp32: torch.Tensor, param_fp16: torch.Tensor, g_fp16: torch.T
     assert param_fp32.numel() == g_fp16.numel(), "param_fp32 and g_fp16 must have the same number of elements"
     assert param_fp32.numel() == m_fp32.numel(), "param_fp32 and m_fp32 must have the same number of elements"
     assert param_fp32.numel() == v_fp32.numel(), "param_fp32 and v_fp32 must have the same number of elements"
+    if delta_info is not None:
+        assert delta_info.is_contiguous(), "delta_info must be contiguous"
+        assert delta_info.dtype == torch.float32, "delta_info must be float32 tensor"
+        assert delta_info.device == torch.device("cpu"), "delta_info must be a cpu tensor"
+        assert delta_info.numel() == 4, "delta_info must have a length of 4"
     bias_correction1 = 1 - beta1 ** step
     bias_correction2 = 1 - beta2 ** step
     if g_fp16.dtype == torch.float16:
@@ -35,6 +51,7 @@ def adam_cpu(param_fp32: torch.Tensor, param_fp16: torch.Tensor, g_fp16: torch.T
             param_fp32.numel(),
             param_fp32.data_ptr(),
             param_fp16.data_ptr(),
+            delta_info.data_ptr() if delta_info is not None else 0,
             g_fp16.data_ptr(),
             m_fp32.data_ptr(),
             v_fp32.data_ptr(),
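With the new delta_info argument, adam_cpu fills a 4-element CPU buffer with [avg_delta, var_delta, sum_delta, sum_sq_delta] for the update it just applied; passing None skips the bookkeeping. A hypothetical standalone call, assuming the C extension has been rebuilt with these changes (the real caller is the offload optimizer below):

    import torch
    from bmtrain.optim import _function as F

    n = 1024
    param_fp32 = torch.randn(n, dtype=torch.float32)
    param_fp16 = param_fp32.to(torch.float16)
    grad_fp16  = torch.randn(n, dtype=torch.float16)
    m_fp32 = torch.zeros(n, dtype=torch.float32)
    v_fp32 = torch.zeros(n, dtype=torch.float32)

    # slots filled by the kernel: [avg_delta, var_delta, sum_delta, sum_sq_delta]
    delta_info = torch.zeros(4, dtype=torch.float32)

    F.adam_cpu(param_fp32, param_fp16, delta_info, grad_fp16, m_fp32, v_fp32,
               beta1=0.9, beta2=0.999, eps=1e-8, lr=1e-3, scale=1.0,
               weight_decay=0.0, step=1)
    print(delta_info[0].item(), delta_info[1].item())  # mean and variance of this step's delta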
diff --git a/bmtrain/optim/adam.py b/bmtrain/optim/adam.py
index a313898..d412b80 100644
--- a/bmtrain/optim/adam.py
+++ b/bmtrain/optim/adam.py
@@ -131,6 +131,14 @@ def step(self, closure=None, scale=1):
 
         return loss
 
+    def get_avg_delta(self):
+
+        raise NotImplementedError("get delta info is not supported in the Adam optimizer, try bmt.optim.AdamOffloadOptimizer")
+
+    def get_var_delta(self):
+
+        raise NotImplementedError("get delta info is not supported in the Adam optimizer, try bmt.optim.AdamOffloadOptimizer")
+
     def load_state_dict(self, state_dict: dict) -> None:
         r"""Loads the optimizer state.
 
diff --git a/bmtrain/optim/adam_offload.py b/bmtrain/optim/adam_offload.py
index d7910ae..c088a5e 100644
--- a/bmtrain/optim/adam_offload.py
+++ b/bmtrain/optim/adam_offload.py
@@ -7,6 +7,7 @@
 from copy import deepcopy
 from itertools import chain
 from collections import defaultdict
+from ._distributed import state_dict_gather
 
 class AdamOffloadOptimizer(torch.optim.Optimizer):
     """
@@ -14,7 +15,7 @@ class AdamOffloadOptimizer(torch.optim.Optimizer):
     """
     _bmtrain_optimizer = True
 
-    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, hold_steps=0):
+    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, hold_steps=0, record_delta=False):
         if not 0.0 <= lr:
             raise ValueError("Invalid learning rate: {}".format(lr))
         if not 0.0 <= eps:
@@ -25,12 +26,17 @@ def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0
             raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
         if not 0.0 <= weight_decay:
             raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
-
+        self.avg_delta = 0
+        self.var_delta = 0
         defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
         super().__init__(params, defaults)
-
         self._hold_steps = hold_steps
         self._events = {}
+        self.record_delta = record_delta
+        if self.record_delta:
+            for group in self.param_groups:
+                for p in group['params']:
+                    setattr(p, "_delta_info", torch.tensor([0 for i in range(4)], dtype=torch.float32, device="cpu"))
 
     @torch.no_grad()
     def step(self, closure=None, scale=1):
@@ -92,7 +98,9 @@ def step(self, closure=None, scale=1):
                 else:
                     state["_grad_fp16"].copy_(param.grad, non_blocking=True)
                 torch.cuda.current_stream().record_event(event)
-
+        sum_delta = 0
+        sum_sq_delta = 0
+        total_numel = 0
         for param, state, event, beta1, beta2, eps, lr, weight_decay in update_params:
             # wait for transfer to host
             event.synchronize()
@@ -135,6 +143,7 @@
                     F.adam_cpu(
                         state["_param_fp32"].view(-1),
                         state["_param_fp16"].view(-1),
+                        param._delta_info if self.record_delta else None,
                         grad.view(-1),
                         state["exp_avg"].view(-1),
                        state["exp_avg_sq"].view(-1),
@@ -144,12 +153,25 @@
                         weight_decay,
                         state["step"]
                     )
+                total_numel += state["_param_fp16"].numel()
+                if self.record_delta:
+                    sum_delta += param._delta_info[2].item()
+                    sum_sq_delta += param._delta_info[3].item()
                 # transfer parameters back to device asynchronously
                 param.copy_(state["_param_fp16"], non_blocking=True)
+        if self.record_delta:
+            self.avg_delta = sum_delta / total_numel
+            self.var_delta = sum_sq_delta / total_numel - self.avg_delta ** 2
 
         return loss
 
+    def get_avg_delta(self) -> float:
+        return self.avg_delta if self.record_delta else 0
+
+    def get_var_delta(self) -> float:
+        return self.var_delta if self.record_delta else 0
+
     def load_state_dict(self, state_dict: dict) -> None:
         r"""Loads the optimizer state.
 
@@ -158,6 +180,9 @@ def load_state_dict(self, state_dict: dict) -> None:
                 from a call to :meth:`state_dict`.
         """
         # deepcopy, to be consistent with module API
+
+
+
         state_dict = deepcopy(state_dict)
         # Validate the state_dict
         groups = self.param_groups
@@ -177,17 +202,27 @@
                   zip(chain.from_iterable((g['params'] for g in saved_groups)),
                       chain.from_iterable((g['params'] for g in groups)))}
+        # _param_start_end = chain.from_iterable((g["params_start_end"] for g in saved_groups))
 
         # Copy state assigned to params (and cast tensors to appropriate types).
         # State that is not assigned to params is copied as is (needed for
         # backward compatibility).
         state = defaultdict(dict)
+        is_whole = False if "is_whole" not in state_dict else state_dict['is_whole']
+        pop_key = []
         for k, v in state_dict['state'].items():
             if k in id_map:
                 param = id_map[k]
+                if is_whole and param._start_partition is not None:
+                    for key in ['_param_fp32', 'exp_avg_sq', 'exp_avg']:
+                        if key in v:
+                            v[key] = v[key][param._start_partition:param._end_partition]
+                elif is_whole and param._start_partition is None:
+                    pop_key.append(param)
 
                 if "_param_fp32" not in v:
-                    v["_param_fp32"] = torch.empty(param.size(), dtype=torch.float32, device="cpu")
-                    v["_param_fp32"].copy_(param)
+                    with torch.no_grad():
+                        v["_param_fp32"] = torch.empty(param.size(), dtype=torch.float32, device="cpu")
+                        v["_param_fp32"].copy_(param)
 
                 for name, dtype in [("exp_avg", torch.float32), ("exp_avg_sq", torch.float32), ("_param_fp32", torch.float32)]:
                     if name in v:
@@ -204,7 +239,8 @@
                 state[param]["_grad_fp16"] = torch.empty(param.size(), dtype=param.dtype, pin_memory=True)   # on host
             else:
                 state[k] = v
-
+        for k in pop_key:
+            state.pop(k)
         # Update parameter groups, setting their 'params' value
         def update_group(group, new_group):
             new_group['params'] = group['params']
@@ -212,8 +248,10 @@
         param_groups = [
             update_group(g, ng) for g, ng in zip(groups, saved_groups)]
         self.__setstate__({'state': state, 'param_groups': param_groups})
+
+
 
-    def state_dict(self) -> dict:
+    def state_dict(self, gather=False) -> dict:
         r"""Returns the state of the optimizer as a :class:`dict`.
 
         It contains two entries:
@@ -223,6 +261,7 @@
         * param_groups - a list containing all parameter groups where each
             parameter group is a dict
         """
+
         # Save order indices instead of Tensors
         param_mappings = {}
         start_index = 0
@@ -247,11 +286,18 @@ def cut_states(state):
         # Remap state to use order indices as keys
         packed_state = {(param_mappings[id(k)] if isinstance(k, torch.Tensor) else k): cut_states(v)
                         for k, v in self.state.items()}
-        return {
+        states = {
             'state': packed_state,
             'param_groups': param_groups,
         }
+        if gather:
+            states = state_dict_gather(states)
+            states['is_whole'] = True
+        else:
+            states['is_whole'] = False
+
+        return states
 
     #TODO zero_grad(set_to_none=True) makes optimizer crashed, maybe the reason of grad accu
     def zero_grad(self, set_to_none: bool = False):
-        super().zero_grad(set_to_none=set_to_none)
\ No newline at end of file
+        super().zero_grad(set_to_none=set_to_none)
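Taken together, the adam_offload.py changes expose two features: per-step update statistics (record_delta with get_avg_delta / get_var_delta) and a gathered, unsharded state_dict(gather=True). A usage sketch, assuming model is a bmtrain-wrapped module and stepping is driven through an OptimManager as usual:

    import torch
    import bmtrain as bmt

    optimizer = bmt.optim.AdamOffloadOptimizer(model.parameters(), lr=1e-3, record_delta=True)

    # ... forward / backward and optimizer steps via bmt.optim.OptimManager ...

    avg = optimizer.get_avg_delta()    # mean of the last step's parameter updates
    var = optimizer.get_var_delta()    # variance of those updates (E[x^2] - E[x]^2)

    # gather the rank-local shards into one full optimizer state before saving
    full_state = optimizer.state_dict(gather=True)
    if bmt.rank() == 0:
        torch.save(full_state, "optimizer.pt")

On restore, load_state_dict sees the is_whole flag and slices each gathered tensor back down to this rank's partition using the _start_partition / _end_partition attributes recorded in block_layer.py; parameters with no local partition are dropped from the rebuilt state.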
diff --git a/bmtrain/optim/optim_manager.py b/bmtrain/optim/optim_manager.py
index 088f0e7..7aa1bb8 100644
--- a/bmtrain/optim/optim_manager.py
+++ b/bmtrain/optim/optim_manager.py
@@ -203,9 +203,9 @@ def _justify_scale(self, scale):
         self.loss_scale = scale
         self.steps_since_last_scale = 0
 
-    def state_dict(self) -> dict:
+    def state_dict(self, gather_opt=False) -> dict:
         return {
-            "optimizers": [opt.state_dict() for opt in self.optimizers],
+            "optimizers": [opt.state_dict(gather_opt) for opt in self.optimizers],
             "lr_schedulers": [lrs.state_dict() if lrs else None for lrs in self.lr_schedulers],
             "loss_scale": self.loss_scale,
             "loss_scale_enabled": self.loss_scale_enabled,
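The same switch is surfaced one level up: OptimManager.state_dict(gather_opt=True) forwards the flag to every wrapped optimizer, so a whole checkpoint can be produced in one call. A sketch, assuming optim_manager was set up as above, that the checkpoint path is visible to every rank, and that OptimManager.load_state_dict is used for restore:

    ckpt = optim_manager.state_dict(gather_opt=True)   # each optimizer returns its gathered state
    if bmt.rank() == 0:
        torch.save(ckpt, "optim_manager.pt")
    bmt.synchronize()

    # every rank loads the same whole checkpoint; each optimizer slices out its own partition
    optim_manager.load_state_dict(torch.load("optim_manager.pt"))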
diff --git a/csrc/bind.cpp b/csrc/bind.cpp
index 047da89..b8f6fa8 100644
--- a/csrc/bind.cpp
+++ b/csrc/bind.cpp
@@ -1,9 +1,11 @@
 #include "include/bind.hpp"
 
 PYBIND11_MODULE(C, m) {
-    m.def("is_bf16_supported",&is_bf16_supported,"whether bf16 supported");
-    m.def("has_nan_inf_fp16_launcher",&has_nan_inf_fp16_launcher,"has nan inf");
-    m.def("has_nan_inf_bf16_launcher",&has_nan_inf_bf16_launcher,"has nan inf bf16");
+    m.def("to_fp16_from_fp32", &fp16_from_fp32_value_launcher, "convert");
+    m.def("to_bf16_from_fp32", &bf16_from_fp32_value_launcher, "convert");
+    m.def("is_bf16_supported", &is_bf16_supported, "whether bf16 supported");
+    m.def("has_nan_inf_fp16_launcher", &has_nan_inf_fp16_launcher, "has nan inf");
+    m.def("has_nan_inf_bf16_launcher", &has_nan_inf_bf16_launcher, "has nan inf bf16");
     m.def("adam_fp16_launcher", &adam_fp16_launcher, "adam function cpu");
     m.def("adam_bf16_launcher", &adam_bf16_launcher, "adam function cpu");
     m.def("adam_cpu_fp16_launcher", &adam_cpu_fp16_launcher, "adam function cpu");
@@ -26,8 +28,8 @@ PYBIND11_MODULE(C, m) {
     m.def("ncclReduceScatter", &pyNCCLReduceScatter, "nccl reduce scatter");
     m.def("ncclGroupStart", &pyNCCLGroupStart, "nccl group start");
     m.def("ncclGroupEnd", &pyNCCLGroupEnd, "nccl group end");
-    m.def("ncclSend",&pyNCCLSend,"nccl send");
-    m.def("ncclRecv",&pyNCCLRecv,"nccl recv");
-    m.def("ncclCommCount",&pyNCCLCommCount,"nccl comm count");
-    m.def("ncclCommUserRank",&pyNCCLCommUserRank,"nccl comm user rank");
+    m.def("ncclSend", &pyNCCLSend, "nccl send");
+    m.def("ncclRecv", &pyNCCLRecv, "nccl recv");
+    m.def("ncclCommCount", &pyNCCLCommCount, "nccl comm count");
+    m.def("ncclCommUserRank", &pyNCCLCommUserRank, "nccl comm user rank");
 }
diff --git a/csrc/include/adam_cpu.hpp b/csrc/include/adam_cpu.hpp
index 81a8ec9..52575d6 100644
--- a/csrc/include/adam_cpu.hpp
+++ b/csrc/include/adam_cpu.hpp
@@ -5,12 +5,23 @@
 #include
 #include
 #include
+#include <mutex>
 #include
 #include
 #include
 #include
 #include "cpu_info.h"
 #define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
 
+static inline float _mm256_reduce_add_ps(__m256 x) {
+    /* ( x3+x7, x2+x6, x1+x5, x0+x4 ) */
+    const __m128 x128 = _mm_add_ps(_mm256_extractf128_ps(x, 1), _mm256_castps256_ps128(x));
+    /* ( -, -, x1+x3+x5+x7, x0+x2+x4+x6 ) */
+    const __m128 x64 = _mm_add_ps(x128, _mm_movehl_ps(x128, x128));
+    /* ( -, -, -, x0+x1+x2+x3+x4+x5+x6+x7 ) */
+    const __m128 x32 = _mm_add_ss(x64, _mm_shuffle_ps(x64, x64, 0x55));
+    /* Conversion to float is a no-op on x86-64 */
+    return _mm_cvtss_f32(x32);
+}
 inline float fp32_from_bits(uint32_t w) {
     union {
@@ -121,11 +132,47 @@ inline float fp16_ieee_to_fp32_value(uint16_t h) {
     return fp32_from_bits(result);
 }
 
-// fp32 -> bf16
 inline uint16_t bf16_from_fp32_value(float f){
     return *reinterpret_cast<uint32_t*>(&f) >> 16;
 }
 
+// fp32 -> bf16
+void bf16_from_fp32_value_launcher(
+    int64_t n,
+    std::uintptr_t param_fp32,
+    std::uintptr_t param_bf16
+){
+    int span = 1;
+    auto param_fp32_ptr = reinterpret_cast<float*>(param_fp32);
+    auto param_bf16_ptr = reinterpret_cast<uint16_t*>(param_bf16);
+    parallel_for(0, n, 0, [&](int64_t start, int64_t end) {
+        for (int64_t j = start; j < end; j += span) {
+            for (int64_t i = j; i < end; i++) {
+                float p = param_fp32_ptr[i];
+                param_bf16_ptr[i] = bf16_from_fp32_value(p);
+            }
+            break; // must break here
+        }
+    });
+}
+void fp16_from_fp32_value_launcher(
+    int64_t n,
+    std::uintptr_t param_fp32,
+    std::uintptr_t param_fp16
+){
+    int span = 1;
+    auto param_fp32_ptr = reinterpret_cast<float*>(param_fp32);
+    auto param_fp16_ptr = reinterpret_cast<uint16_t*>(param_fp16);
+    parallel_for(0, n, 0, [&](int64_t start, int64_t end) {
+        for (int64_t j = start; j < end; j += span) {
+            for (int64_t i = j; i < end; i++) {
+                float p = param_fp32_ptr[i];
+                param_fp16_ptr[i] = fp16_ieee_from_fp32_value(p);
+            }
+            break; // must break here
+        }
+    });
+}
 // bf16 -> fp32
 inline float bf16_to_fp32_value(uint16_t h){
     uint32_t src = h;
@@ -137,6 +184,7 @@ void adam_cpu_0(
     int64_t n,
     float* param_fp32_ptr,
     uint16_t* param_fp16_ptr,
+    float* delta_info_ptr,
     uint16_t* g_fp16_ptr,
     float* m_fp32_ptr,
     float* v_fp32_ptr,
@@ -148,7 +196,12 @@
     float bias_correction2
 ){
     int64_t span = 1;
+    float sum_sq_delta = 0;
+    float sum_delta = 0;
+    std::mutex delta_mutex;
     parallel_for(0, n, 0, [&](int64_t start, int64_t end) {
+        float sum_sq_delta_i = 0;
+        float sum_delta_i = 0;
         for (int64_t j = start; j < end; j += span) {
             for (int64_t i = j; i < end; i++) {
                 float g = fp16_ieee_to_fp32_value(g_fp16_ptr[i]) / scale;
@@ -157,6 +210,11 @@
                 float p = param_fp32_ptr[i];
                 m = beta1 * m + (1 - beta1) * g;
                 v = beta2 * v + (1 - beta2) * g * g;
+                if (delta_info_ptr != NULL){
+                    float delta = m / bias_correction1 / (sqrtf(v / bias_correction2) + eps) + weight_decay * p;
+                    sum_delta_i += delta;
+                    sum_sq_delta_i += delta * delta;
+                }
                 p = p - lr * m / bias_correction1 / (sqrtf(v / bias_correction2) + eps) - lr * weight_decay * p;
                 param_fp32_ptr[i] = p;
                 param_fp16_ptr[i] = fp16_ieee_from_fp32_value(p);
@@ -165,13 +223,26 @@
             }
             break; // must break here
         }
+        if (delta_info_ptr != NULL){
+            delta_mutex.lock();
+            sum_delta += sum_delta_i;
+            sum_sq_delta += sum_sq_delta_i;
+            delta_mutex.unlock();
+        }
     });
+    if (delta_info_ptr != NULL){
+        delta_info_ptr[0] = sum_delta / n;
+        delta_info_ptr[1] = sum_sq_delta / n - sum_delta * sum_delta / (n * n); // var = E(x^2) - E(x)^2
+        delta_info_ptr[2] = sum_delta;
+        delta_info_ptr[3] = sum_sq_delta;
+    }
 }
 
 void adam_cpu_bf16_0(
     int64_t n,
     float* param_fp32_ptr,
     uint16_t* param_bf16_ptr,
+    float* delta_info_ptr,
     uint16_t* g_bf16_ptr,
     float* m_fp32_ptr,
     float* v_fp32_ptr,
@@ -183,7 +254,12 @@
     float bias_correction2
 ){
     int64_t span = 1;
+    float sum_sq_delta = 0;
+    float sum_delta = 0;
+    std::mutex delta_mutex;
     parallel_for(0, n, 0, [&](int64_t start, int64_t end) {
+        float sum_sq_delta_i = 0;
+        float sum_delta_i = 0;
         for (int64_t j = start; j < end; j += span) {
             for (int64_t i = j; i < end; i++) {
                 float g = bf16_to_fp32_value(g_bf16_ptr[i]) / scale;
@@ -192,6 +268,11 @@
                 float p = param_fp32_ptr[i];
                 m = beta1 * m + (1 - beta1) * g;
                 v = beta2 * v + (1 - beta2) * g * g;
+                if (delta_info_ptr != NULL){
+                    float delta = m / bias_correction1 / (sqrtf(v / bias_correction2) + eps) + weight_decay * p;
+                    sum_delta_i += delta;
+                    sum_sq_delta_i += delta * delta;
+                }
                 p = p - lr * m / bias_correction1 / (sqrtf(v / bias_correction2) + eps) - lr * weight_decay * p;
                 param_fp32_ptr[i] = p;
                 param_bf16_ptr[i] = bf16_from_fp32_value(p);
@@ -200,13 +281,26 @@
             }
             break; // must break here
         }
+        if (delta_info_ptr != NULL){
+            delta_mutex.lock();
+            sum_delta += sum_delta_i;
+            sum_sq_delta += sum_sq_delta_i;
+            delta_mutex.unlock();
+        }
     });
+    if (delta_info_ptr != NULL){
+        delta_info_ptr[0] = sum_delta / n;
+        delta_info_ptr[1] = sum_sq_delta / n - sum_delta * sum_delta / (n * n); // var = E(x^2) - E(x)^2
+        delta_info_ptr[2] = sum_delta;
+        delta_info_ptr[3] = sum_sq_delta;
+    }
 }
 
 static void __attribute__ ((__target__ ("avx,fma,f16c"))) adam_cpu_1(
     int64_t n,
     float* param_fp32_ptr,
     uint16_t* param_fp16_ptr,
+    float* delta_info_ptr,
     uint16_t* g_fp16_ptr,
     float* m_fp32_ptr,
     float* v_fp32_ptr,
@@ -217,6 +311,9 @@
     float bias_correction1,
     float bias_correction2
 ){
+    float sum_sq_delta = 0;
+    float sum_delta = 0;
+    std::mutex delta_mutex;
     auto avx_beta1 = _mm256_set1_ps(beta1);
     auto avx_beta2 = _mm256_set1_ps(beta2);
     auto avx_beta1_1 = _mm256_set1_ps(1 - beta1);
@@ -229,6 +326,8 @@
     auto avx_bias_correction2 = _mm256_set1_ps(bias_correction2);
     int64_t span = 8;
     parallel_for(0, n, 0, [&](int64_t start, int64_t end) {
+        float sum_sq_delta_i = 0;
+        float sum_delta_i = 0;
         for (int64_t j = start; j < end; j += span) {
             if (j + span > end) {
                 for (int64_t i = j; i < end; i++) {
@@ -238,6 +337,11 @@
                     float p = param_fp32_ptr[i];
                     m = beta1 * m + (1 - beta1) * g;
                     v = beta2 * v + (1 - beta2) * g * g;
+                    if (delta_info_ptr != NULL){
+                        float delta = m / bias_correction1 / (sqrtf(v / bias_correction2) + eps) + weight_decay * p;
+                        sum_delta_i += delta;
+                        sum_sq_delta_i += delta * delta;
+                    }
                     p = p - lr * m / bias_correction1 / (sqrtf(v / bias_correction2) + eps) - lr * weight_decay * p;
                     param_fp32_ptr[i] = p;
                     param_fp16_ptr[i] = fp16_ieee_from_fp32_value(p);
@@ -252,6 +356,17 @@
                 auto p = _mm256_loadu_ps(&param_fp32_ptr[j]);
                 m = _mm256_fmadd_ps(avx_beta1, m, _mm256_mul_ps(avx_beta1_1, g));
                 v = _mm256_fmadd_ps(avx_beta2, v, _mm256_mul_ps(avx_beta2_1, _mm256_mul_ps(g, g)));
+                if (delta_info_ptr != NULL){
+                    auto delta_256 = _mm256_add_ps(
+                        _mm256_div_ps(
+                            _mm256_div_ps(m, avx_bias_correction1), // m / bias_correction1
+                            _mm256_add_ps(_mm256_sqrt_ps(_mm256_div_ps(v, avx_bias_correction2)), avx_eps) // sqrt(v / bias_correction2) + eps
+                        ), // m / bias_correction1 / (sqrt(v / bias_correction2) + eps)
+                        _mm256_mul_ps(avx_weight_decay, p) // weight_decay * p
+                    ); // delta = m / bias_correction1 / (sqrt(v / bias_correction2) + eps) + weight_decay * p
+                    sum_delta_i += _mm256_reduce_add_ps(delta_256);
+                    sum_sq_delta_i += _mm256_reduce_add_ps(_mm256_mul_ps(delta_256, delta_256));
+                }
                 p = _mm256_fmadd_ps(avx_neg_lr, _mm256_mul_ps(avx_weight_decay, p), p); // p = p - lr * weight_decay * p
                 p = _mm256_fmadd_ps(
                     avx_neg_lr,
@@ -267,13 +382,26 @@
                 _mm256_storeu_ps(&v_fp32_ptr[j], v);
             }
         }
+        if (delta_info_ptr != NULL){
+            delta_mutex.lock();
+            sum_delta += sum_delta_i;
+            sum_sq_delta += sum_sq_delta_i;
+            delta_mutex.unlock();
+        }
     });
+    if (delta_info_ptr != NULL){
+        delta_info_ptr[0] = sum_delta / n;
+        delta_info_ptr[1] = sum_sq_delta / n - sum_delta * sum_delta / (n * n); // var = E(x^2) - E(x)^2
+        delta_info_ptr[2] = sum_delta;
+        delta_info_ptr[3] = sum_sq_delta;
+    }
 }
 
 static void __attribute__ ((__target__ ("avx512f"))) adam_cpu_2(
     int64_t n,
     float* param_fp32_ptr,
     uint16_t* param_fp16_ptr,
+    float* delta_info_ptr,
     uint16_t* g_fp16_ptr,
     float* m_fp32_ptr,
     float* v_fp32_ptr,
@@ -284,6 +412,9 @@
     float bias_correction1,
     float bias_correction2
 ){
+    float sum_sq_delta = 0;
+    float sum_delta = 0;
+    std::mutex delta_mutex;
     auto avx_beta1 = _mm512_set1_ps(beta1);
     auto avx_beta2 = _mm512_set1_ps(beta2);
     auto avx_beta1_1 = _mm512_set1_ps(1 - beta1);
@@ -296,6 +427,8 @@
     auto avx_bias_correction2 = _mm512_set1_ps(bias_correction2);
     int64_t span = 16;
     parallel_for(0, n, 0, [&](int64_t start, int64_t end) {
+        float sum_sq_delta_i = 0;
+        float sum_delta_i = 0;
         for (int64_t j = start; j < end; j += span) {
             if (j + span > end) {
                 for (int64_t i = j; i < end; i++) {
@@ -305,6 +438,11 @@
                     float p = param_fp32_ptr[i];
                     m = beta1 * m + (1 - beta1) * g;
                     v = beta2 * v + (1 - beta2) * g * g;
+                    if (delta_info_ptr != NULL){
+                        float delta = m / bias_correction1 / (sqrtf(v / bias_correction2) + eps) + weight_decay * p;
+                        sum_delta_i += delta;
+                        sum_sq_delta_i += delta * delta;
+                    }
                     p = p - lr * m / bias_correction1 / (sqrtf(v / bias_correction2) + eps) - lr * weight_decay * p;
                     param_fp32_ptr[i] = p;
                     param_fp16_ptr[i] = fp16_ieee_from_fp32_value(p);
@@ -319,6 +457,17 @@
                 auto p = _mm512_loadu_ps(&param_fp32_ptr[j]);
                 m = _mm512_fmadd_ps(avx_beta1, m, _mm512_mul_ps(avx_beta1_1, g));
                 v = _mm512_fmadd_ps(avx_beta2, v, _mm512_mul_ps(avx_beta2_1, _mm512_mul_ps(g, g)));
+                if (delta_info_ptr != NULL){
+                    auto delta_512 = _mm512_add_ps(
+                        _mm512_div_ps(
+                            _mm512_div_ps(m, avx_bias_correction1), // m / bias_correction1
+                            _mm512_add_ps(_mm512_sqrt_ps(_mm512_div_ps(v, avx_bias_correction2)), avx_eps) // sqrt(v / bias_correction2) + eps
+                        ), // m / bias_correction1 / (sqrt(v / bias_correction2) + eps)
+                        _mm512_mul_ps(avx_weight_decay, p) // weight_decay * p
+                    ); // delta = m / bias_correction1 / (sqrt(v / bias_correction2) + eps) + weight_decay * p
+                    sum_delta_i += _mm512_reduce_add_ps(delta_512);
+                    sum_sq_delta_i += _mm512_reduce_add_ps(_mm512_mul_ps(delta_512, delta_512));
+                }
                 p = _mm512_fmadd_ps(avx_neg_lr, _mm512_mul_ps(avx_weight_decay, p), p); // p = p - lr * weight_decay * p
                 p = _mm512_fmadd_ps(
                     avx_neg_lr,
@@ -337,13 +486,26 @@
                 _mm512_storeu_ps(&v_fp32_ptr[j], v);
             }
         }
+        if (delta_info_ptr != NULL){
+            delta_mutex.lock();
+            sum_delta += sum_delta_i;
+            sum_sq_delta += sum_sq_delta_i;
+            delta_mutex.unlock();
+        }
     });
+    if (delta_info_ptr != NULL){
+        delta_info_ptr[0] = sum_delta / n;
+        delta_info_ptr[1] = sum_sq_delta / n - sum_delta * sum_delta / (n * n); // var = E(x^2) - E(x)^2
+        delta_info_ptr[2] = sum_delta;
+        delta_info_ptr[3] = sum_sq_delta;
+    }
 }
 
 void adam_cpu_fp16_launcher(
     int64_t n,
     std::uintptr_t param_fp32,
     std::uintptr_t param_fp16,
+    std::uintptr_t delta_info,
     std::uintptr_t g_fp16,
     std::uintptr_t m_fp32,
     std::uintptr_t v_fp32,
@@ -354,7 +516,7 @@
     float bias_correction1,
     float bias_correction2
 ) {
-
+    auto delta_info_ptr = reinterpret_cast<float*>(delta_info);
     auto param_fp32_ptr = reinterpret_cast<float*>(param_fp32);
     auto m_fp32_ptr = reinterpret_cast<float*>(m_fp32);
     auto v_fp32_ptr = reinterpret_cast<float*>(v_fp32);
@@ -362,11 +524,11 @@ void adam_cpu_fp16_launcher(
     auto param_fp16_ptr = reinterpret_cast<uint16_t*>(param_fp16);
     auto g_fp16_ptr = reinterpret_cast<uint16_t*>(g_fp16);
     int cpu_level = get_cpu_level();
     if (cpu_level == 0 ){
-        adam_cpu_0(n, param_fp32_ptr, param_fp16_ptr, g_fp16_ptr, m_fp32_ptr, v_fp32_ptr, beta1, beta2, eps, lr, scale, weight_decay, bias_correction1, bias_correction2);
+        adam_cpu_0(n, param_fp32_ptr, param_fp16_ptr, delta_info_ptr, g_fp16_ptr, m_fp32_ptr, v_fp32_ptr, beta1, beta2, eps, lr, scale, weight_decay, bias_correction1, bias_correction2);
     }else if(cpu_level == 1){
-        adam_cpu_1(n, param_fp32_ptr, param_fp16_ptr, g_fp16_ptr, m_fp32_ptr, v_fp32_ptr, beta1, beta2, eps, lr, scale, weight_decay, bias_correction1, bias_correction2);
+        adam_cpu_1(n, param_fp32_ptr, param_fp16_ptr, delta_info_ptr, g_fp16_ptr, m_fp32_ptr, v_fp32_ptr, beta1, beta2, eps, lr, scale, weight_decay, bias_correction1, bias_correction2);
     }else{
-        adam_cpu_2(n, param_fp32_ptr, param_fp16_ptr, g_fp16_ptr, m_fp32_ptr, v_fp32_ptr, beta1, beta2, eps, lr, scale, weight_decay, bias_correction1, bias_correction2);
+        adam_cpu_2(n, param_fp32_ptr, param_fp16_ptr, delta_info_ptr, g_fp16_ptr, m_fp32_ptr, v_fp32_ptr, beta1, beta2, eps, lr, scale, weight_decay, bias_correction1, bias_correction2);
     }
 }
@@ -374,6 +536,7 @@ void adam_cpu_bf16_launcher(
     int64_t n,
     std::uintptr_t param_fp32,
     std::uintptr_t param_bf16,
+    std::uintptr_t delta_info,
     std::uintptr_t g_bf16,
     std::uintptr_t m_fp32,
     std::uintptr_t v_fp32,
@@ -384,10 +547,11 @@
     float bias_correction1,
     float bias_correction2
 ) {
-    auto param_fp32_ptr = reinterpret_cast<float*>(param_fp32);
+    auto delta_info_ptr = reinterpret_cast<float*>(delta_info);
     auto m_fp32_ptr = reinterpret_cast<float*>(m_fp32);
     auto v_fp32_ptr = reinterpret_cast<float*>(v_fp32);
+    auto param_fp32_ptr = reinterpret_cast<float*>(param_fp32);
     auto param_bf16_ptr = reinterpret_cast<uint16_t*>(param_bf16);
     auto g_bf16_ptr = reinterpret_cast<uint16_t*>(g_bf16);
-    adam_cpu_bf16_0(n, param_fp32_ptr, param_bf16_ptr, g_bf16_ptr, m_fp32_ptr, v_fp32_ptr, beta1, beta2, eps, lr, scale, weight_decay, bias_correction1, bias_correction2);
+    adam_cpu_bf16_0(n, param_fp32_ptr, param_bf16_ptr, delta_info_ptr, g_bf16_ptr, m_fp32_ptr, v_fp32_ptr, beta1, beta2, eps, lr, scale, weight_decay, bias_correction1, bias_correction2);
 }
diff --git a/csrc/include/bind.hpp b/csrc/include/bind.hpp
index a9067a2..3ff967f 100644
--- a/csrc/include/bind.hpp
+++ b/csrc/include/bind.hpp
@@ -4,9 +4,19 @@
 
 int is_bf16_supported();
 
-void has_nan_inf_fp16_launcher(int32_t n,std::uintptr_t g_fp16,std::uintptr_t mid,std::uintptr_t out,std::uintptr_t stream);
-void has_nan_inf_bf16_launcher(int32_t n,std::uintptr_t g_bf16,std::uintptr_t mid,std::uintptr_t out,std::uintptr_t stream);
+void has_nan_inf_fp16_launcher(int32_t n, std::uintptr_t g_fp16, std::uintptr_t mid, std::uintptr_t out, std::uintptr_t stream);
+void has_nan_inf_bf16_launcher(int32_t n, std::uintptr_t g_bf16, std::uintptr_t mid, std::uintptr_t out, std::uintptr_t stream);
+void fp16_from_fp32_value_launcher(
+    int64_t n,
+    std::uintptr_t param_fp32,
+    std::uintptr_t param_fp16
+);
+void bf16_from_fp32_value_launcher(
+    int64_t n,
+    std::uintptr_t param_fp32,
+    std::uintptr_t param_bf16
+);
 
 void cross_entropy_forward_fp16_launcher(
     int32_t m, int32_t n,
     std::uintptr_t input,