diff --git a/exllamav2/architecture.py b/exllamav2/architecture.py index ea7e2941..b2a62803 100644 --- a/exllamav2/architecture.py +++ b/exllamav2/architecture.py @@ -402,6 +402,8 @@ class Params: self.mmp.mlp_bias = True self.mmp.norm = "layernorm" + self.standard_calib_noise = (5, 30) + # Gemma if arch_string == "GemmaForCausalLM": diff --git a/exllamav2/embedding.py b/exllamav2/embedding.py index 9fd67997..f5571f0a 100644 --- a/exllamav2/embedding.py +++ b/exllamav2/embedding.py @@ -186,16 +186,40 @@ def forward( if self.archparams.normalize_embeddings: hidden_states *= cfg.hidden_size ** 0.5 - # Negative tokens during quantization are noise tokens + # Rows with negative tokens during quantization are noise tokens if kwargs.get("negative_ids_noise"): - mask = (input_ids < 0).unsqueeze(-1) - unmasked_values = hidden_states[~mask.expand_as(hidden_states)].float() - mean, std = unmasked_values.mean(), unmasked_values.std() - noise = torch.randn_like(hidden_states, dtype = torch.float) - noise = noise * std + mean - noise = noise.half() - hidden_states = torch.where(mask, noise, hidden_states) + + n = 0 + mean = torch.tensor([0.0], dtype = torch.float, device = hidden_states.device) + M2 = torch.tensor([0.0], dtype = torch.float, device = hidden_states.device) + + for i in range(input_ids.shape[0]): + if input_ids[i][0] < 0: + continue + + er = hidden_states[i].float() + n += er.numel() + delta = er - mean + mean += delta.sum() / n + delta2 = er - mean + M2 += (delta * delta2).sum() + del er + del delta + del delta2 + + if n > 1: + std = torch.sqrt(M2 / (n - 1)) + + for i in range(input_ids.shape[0]): + if input_ids[i][0] >= 0: + continue + + er = hidden_states[i] + noise = torch.randn(er.size(), dtype = torch.float, device = hidden_states.device) * std + mean + er.copy_(noise.half()) + del er + del noise # Move to pinned temp buffer for TP diff --git a/exllamav2/exllamav2_ext/ext_rope.cpp b/exllamav2/exllamav2_ext/ext_rope.cpp index d79a22b8..c7d58025 100644 --- 
a/exllamav2/exllamav2_ext/ext_rope.cpp +++ b/exllamav2/exllamav2_ext/ext_rope.cpp @@ -58,50 +58,50 @@ void rope_ ); } -long gen_mrope_pos_ids +int64_t gen_mrope_pos_ids ( torch::Tensor mrope_pos_ids, torch::Tensor ids, int merge_size, - const std::vector<std::tuple<long, long>> &spans, - const std::vector<std::tuple<long, long, long>> &grids + const std::vector<std::tuple<int64_t, int64_t>> &spans, + const std::vector<std::tuple<int64_t, int64_t, int64_t>> &grids ) { int max_length = mrope_pos_ids.size(1); int in_length = ids.size(0); - long* in_ids = (long*) ids.data_ptr(); - long* pos_ids = (long*) mrope_pos_ids.data_ptr(); + int64_t* in_ids = (int64_t*) ids.data_ptr(); + int64_t* pos_ids = (int64_t*) mrope_pos_ids.data_ptr(); - long* out_t = pos_ids; - long* out_h = pos_ids + max_length; - long* out_w = pos_ids + 2 * max_length; + int64_t* out_t = pos_ids; + int64_t* out_h = pos_ids + max_length; + int64_t* out_w = pos_ids + 2 * max_length; - long base_t = 0; - long next_base_t = 0; + int64_t base_t = 0; + int64_t next_base_t = 0; for (int i = 0; i < max_length; ++i) { bool is_emb = false; if (i < in_length) { - long id = in_ids[i]; + int64_t id = in_ids[i]; for (int j = 0; j < spans.size(); ++j) { - long span_start = std::get<0>(spans[j]); - long span_end = std::get<1>(spans[j]); - long span = span_end - span_start; + int64_t span_start = std::get<0>(spans[j]); + int64_t span_end = std::get<1>(spans[j]); + int64_t span = span_end - span_start; if (id >= span_start && id < span_end) { is_emb = true; - long k = id - span_start; - long grid_t = std::get<0>(grids[j]); - long grid_h = std::get<1>(grids[j]) / (long)merge_size; - long grid_w = std::get<2>(grids[j]) / (long)merge_size; - long k_t = base_t + (k / grid_w / grid_h) % grid_t; - long k_h = base_t + (k / grid_w) % grid_h; - long k_w = base_t + k % grid_w; + int64_t k = id - span_start; + int64_t grid_t = std::get<0>(grids[j]); + int64_t grid_h = std::get<1>(grids[j]) / (int64_t)merge_size; + int64_t grid_w = std::get<2>(grids[j]) / (int64_t)merge_size; + int64_t k_t = base_t + (k / grid_w / grid_h) % grid_t; + int64_t k_h = 
base_t + (k / grid_w) % grid_h; + int64_t k_w = base_t + k % grid_w; *out_t++ = k_t; *out_h++ = k_h; *out_w++ = k_w; diff --git a/exllamav2/exllamav2_ext/ext_rope.h b/exllamav2/exllamav2_ext/ext_rope.h index 17adebd4..2a41b22c 100644 --- a/exllamav2/exllamav2_ext/ext_rope.h +++ b/exllamav2/exllamav2_ext/ext_rope.h @@ -11,11 +11,11 @@ void rope_ bool neox_style ); -long gen_mrope_pos_ids +int64_t gen_mrope_pos_ids ( torch::Tensor mrope_pos_ids, torch::Tensor ids, int merge_size, - const std::vector<std::tuple<long, long>> &spans, - const std::vector<std::tuple<long, long, long>> &grids + const std::vector<std::tuple<int64_t, int64_t>> &spans, + const std::vector<std::tuple<int64_t, int64_t, int64_t>> &grids ); \ No newline at end of file diff --git a/exllamav2/generator/dynamic.py b/exllamav2/generator/dynamic.py index 602420da..86dcf35f 100644 --- a/exllamav2/generator/dynamic.py +++ b/exllamav2/generator/dynamic.py @@ -2589,8 +2589,9 @@ def deallocate_pages(self): self.generator.all_pages[0].backup() for seq in self.sequences: - for page in seq.allocated_pages: - page.sub_ref() - seq.allocated_pages = [] + if seq.allocated_pages is not None: + for page in seq.allocated_pages: + page.sub_ref() + seq.allocated_pages = [] self.generator.validate_cache() diff --git a/exllamav2/mrope.py b/exllamav2/mrope.py index 16ef31de..3b925314 100644 --- a/exllamav2/mrope.py +++ b/exllamav2/mrope.py @@ -36,7 +36,7 @@ def gen_mrope_embed( # Create 3D position IDs - ids = input_ids.squeeze(0) + ids = input_ids.squeeze(0).contiguous() mrope_pos_ids = torch.zeros((3, max_length), dtype = torch.long).contiguous() merge_size = 1 if not embeddings else embeddings[0].model.config.vision_spatial_merge_size spans = [] diff --git a/exllamav2/version.py b/exllamav2/version.py index 845be453..d1eb7428 100644 --- a/exllamav2/version.py +++ b/exllamav2/version.py @@ -1 +1 @@ -__version__ = "0.2.5" \ No newline at end of file +__version__ = "0.2.6" \ No newline at end of file