
Commit

Merge pull request oobabooga#4475 from oobabooga/dev
Merge dev branch
oobabooga authored Nov 4, 2023
2 parents 262f8ae + 40f7f37 commit b5c5304
Showing 26 changed files with 177 additions and 59 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -26,6 +26,7 @@
.DS_Store
.eslintrc.js
.idea
.env
.venv
venv
.vscode
3 changes: 3 additions & 0 deletions README.md
@@ -300,6 +300,7 @@ Optionally, you can use the following command-line flags:
| `--sdp-attention` | Use PyTorch 2.0's SDP attention. Same as above. |
| `--trust-remote-code` | Set `trust_remote_code=True` while loading the model. Necessary for some models. |
| `--use_fast` | Set `use_fast=True` while loading the tokenizer. |
| `--use_flash_attention_2` | Set `use_flash_attention_2=True` while loading the model. |

#### Accelerate 4-bit

@@ -336,6 +337,8 @@ Optionally, you can use the following command-line flags:
|`--gpu-split` | Comma-separated list of VRAM (in GB) to use per GPU device for model layers. Example: 20,7,7. |
|`--max_seq_len MAX_SEQ_LEN` | Maximum sequence length. |
|`--cfg-cache` | ExLlama_HF: Create an additional cache for CFG negative prompts. Necessary to use CFG with that loader, but not necessary for CFG with base ExLlama. |
|`--no_flash_attn` | Force flash-attention to not be used. |
|`--cache_8bit` | Use 8-bit cache to save VRAM. |

#### AutoGPTQ

2 changes: 2 additions & 0 deletions extensions/api/util.py
@@ -25,7 +25,9 @@ def build_parameters(body, chat=False):
'max_tokens_second': int(body.get('max_tokens_second', 0)),
'do_sample': bool(body.get('do_sample', True)),
'temperature': float(body.get('temperature', 0.5)),
'temperature_last': bool(body.get('temperature_last', False)),
'top_p': float(body.get('top_p', 1)),
'min_p': float(body.get('min_p', 0)),
'typical_p': float(body.get('typical_p', body.get('typical', 1))),
'epsilon_cutoff': float(body.get('epsilon_cutoff', 0)),
'eta_cutoff': float(body.get('eta_cutoff', 0)),
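The API now forwards `temperature_last` and `min_p` from the request body. A minimal sketch of a request that exercises the new fields, assuming the legacy blocking API is running locally on its default port (the URL, port, and response shape are assumptions here, not part of this diff):

```python
import requests

response = requests.post(
    "http://127.0.0.1:5000/api/v1/generate",  # assumed default of the blocking API extension
    json={
        "prompt": "Write a haiku about autumn.",
        "max_new_tokens": 64,
        "temperature": 0.7,
        "temperature_last": True,  # apply temperature after the other samplers
        "top_p": 0.9,
        "min_p": 0.05,             # drop tokens below 5% of the top token's probability
    },
    timeout=60,
)
# The blocking API is expected to answer with {"results": [{"text": ...}]}.
print(response.json()["results"][0]["text"])
```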
6 changes: 3 additions & 3 deletions models/config.yaml
@@ -45,9 +45,6 @@
.*starchat-beta:
instruction_template: 'Starchat-Beta'
custom_stopping_strings: '"<|end|>"'
.*(openorca-platypus2):
instruction_template: 'OpenOrca-Platypus2'
custom_stopping_strings: '"### Instruction:", "### Response:"'
(?!.*v0)(?!.*1.1)(?!.*1_1)(?!.*stable)(?!.*chinese).*vicuna:
instruction_template: 'Vicuna-v0'
.*vicuna.*v0:
@@ -152,6 +149,9 @@
instruction_template: 'Orca Mini'
.*(platypus|gplatty|superplatty):
instruction_template: 'Alpaca'
.*(openorca-platypus2):
instruction_template: 'OpenOrca-Platypus2'
custom_stopping_strings: '"### Instruction:", "### Response:"'
.*longchat:
instruction_template: 'Vicuna-v1.1'
.*vicuna-33b:
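Note on the `openorca-platypus2` move: it only changes behaviour if later entries in `config.yaml` override earlier ones when several patterns match the same model name. A hedged sketch of that resolution order (illustrative only, not the webui's actual loader code):

```python
import re

def resolve_settings(model_name: str, patterns: dict) -> dict:
    """Apply every matching pattern in file order; later matches overwrite earlier ones."""
    settings = {}
    for pattern, values in patterns.items():
        if re.match(pattern.lower(), model_name.lower()):
            settings.update(values)
    return settings

patterns = {
    r".*(platypus|gplatty|superplatty)": {"instruction_template": "Alpaca"},
    r".*(openorca-platypus2)": {"instruction_template": "OpenOrca-Platypus2"},
}
print(resolve_settings("OpenOrca-Platypus2-13B", patterns))
# {'instruction_template': 'OpenOrca-Platypus2'}
```

Placing the specific pattern after the generic one lets it win, which is presumably why the block was moved down.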
2 changes: 1 addition & 1 deletion modules/GPTQ_loader.py
@@ -62,7 +62,7 @@ def noop(*args, **kwargs):
from safetensors.torch import load_file as safe_load
model.load_state_dict(safe_load(checkpoint), strict=False)
else:
model.load_state_dict(torch.load(checkpoint), strict=False)
model.load_state_dict(torch.load(checkpoint, weights_only=True), strict=False)

model.seqlen = 2048
return model
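`weights_only=True` restricts `torch.load` to tensors and plain containers instead of arbitrary pickled objects, so a malicious checkpoint can no longer execute code during unpickling (the argument requires a reasonably recent PyTorch). A minimal sketch of the safer pattern with a placeholder path:

```python
import torch

# With weights_only=True, a pickle payload that tries to import and run code
# fails with an error instead of executing.
state_dict = torch.load("checkpoint.pt", weights_only=True, map_location="cpu")
model.load_state_dict(state_dict, strict=False)  # `model` is assumed to be built already
```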
8 changes: 7 additions & 1 deletion modules/exllamav2.py
@@ -6,6 +6,7 @@
from exllamav2 import (
ExLlamaV2,
ExLlamaV2Cache,
ExLlamaV2Cache_8bit,
ExLlamaV2Config,
ExLlamaV2Tokenizer
)
@@ -46,6 +47,7 @@ def from_pretrained(self, path_to_model):
config.max_seq_len = shared.args.max_seq_len
config.scale_pos_emb = shared.args.compress_pos_emb
config.scale_alpha_value = shared.args.alpha_value
config.no_flash_attn = shared.args.no_flash_attn

model = ExLlamaV2(config)

@@ -56,7 +58,11 @@ def from_pretrained(self, path_to_model):
model.load(split)

tokenizer = ExLlamaV2Tokenizer(config)
cache = ExLlamaV2Cache(model)
if shared.args.cache_8bit:
cache = ExLlamaV2Cache_8bit(model)
else:
cache = ExLlamaV2Cache(model)

generator = ExLlamaV2BaseGenerator(model, cache, tokenizer)

result = self()
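The new `--cache_8bit` option roughly halves the VRAM spent on the ExLlamaV2 KV cache relative to the default FP16 cache. A back-of-the-envelope estimator, assuming the usual `2 * layers * seq_len * kv_heads * head_dim * bytes` sizing; the model dimensions below are illustrative, not read from any real config:

```python
def kv_cache_bytes(n_layers: int, max_seq_len: int, n_kv_heads: int,
                   head_dim: int, bytes_per_value: int) -> int:
    # Keys and values (factor 2), one slot per layer, position, KV head and head dimension.
    return 2 * n_layers * max_seq_len * n_kv_heads * head_dim * bytes_per_value

# Illustrative 70B-class shape: 80 layers, 8 KV heads of size 128, 4096-token context.
fp16_cache = kv_cache_bytes(80, 4096, 8, 128, 2)
int8_cache = kv_cache_bytes(80, 4096, 8, 128, 1)
print(f"FP16: {fp16_cache / 2**30:.2f} GiB, 8-bit: {int8_cache / 2**30:.2f} GiB")
# The 8-bit cache needs about half the memory (roughly 0.6 GiB vs 1.25 GiB here).
```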
21 changes: 17 additions & 4 deletions modules/exllamav2_hf.py
@@ -4,7 +4,12 @@
from typing import Any, Dict, Optional, Union

import torch
from exllamav2 import ExLlamaV2, ExLlamaV2Cache, ExLlamaV2Config
from exllamav2 import (
ExLlamaV2,
ExLlamaV2Cache,
ExLlamaV2Cache_8bit,
ExLlamaV2Config
)
from torch.nn import CrossEntropyLoss
from transformers import GenerationConfig, PretrainedConfig, PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast
@@ -40,11 +45,18 @@ def __init__(self, config: ExLlamaV2Config):
self.generation_config = GenerationConfig()
self.loras = None

self.ex_cache = ExLlamaV2Cache(self.ex_model)
self.past_seq = None
if shared.args.cache_8bit:
self.ex_cache = ExLlamaV2Cache_8bit(self.ex_model)
else:
self.ex_cache = ExLlamaV2Cache(self.ex_model)

self.past_seq = None
if shared.args.cfg_cache:
self.ex_cache_negative = ExLlamaV2Cache(self.ex_model)
if shared.args.cache_8bit:
self.ex_cache_negative = ExLlamaV2Cache_8bit(self.ex_model)
else:
self.ex_cache_negative = ExLlamaV2Cache(self.ex_model)

self.past_seq_negative = None

def _validate_model_class(self):
@@ -152,5 +164,6 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
config.max_seq_len = shared.args.max_seq_len
config.scale_pos_emb = shared.args.compress_pos_emb
config.scale_alpha_value = shared.args.alpha_value
config.no_flash_attn = shared.args.no_flash_attn

return Exllamav2HF(config)
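The FP16-versus-8-bit branch now appears three times across the two ExLlamaV2 loaders. One possible follow-up (untested; the helper name is invented here) is a small factory shared by all call sites:

```python
from exllamav2 import ExLlamaV2Cache, ExLlamaV2Cache_8bit

def make_cache(model, use_8bit: bool):
    """Pick the ExLlamaV2 cache implementation once instead of branching at every call site."""
    cache_cls = ExLlamaV2Cache_8bit if use_8bit else ExLlamaV2Cache
    return cache_cls(model)

# e.g. in Exllamav2HF.__init__:
#     self.ex_cache = make_cache(self.ex_model, shared.args.cache_8bit)
#     if shared.args.cfg_cache:
#         self.ex_cache_negative = make_cache(self.ex_model, shared.args.cache_8bit)
```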
20 changes: 19 additions & 1 deletion modules/loaders.py
@@ -9,7 +9,6 @@
'Transformers': [
'cpu_memory',
'gpu_memory',
'trust_remote_code',
'load_in_8bit',
'bf16',
'cpu',
@@ -21,6 +20,7 @@
'compute_dtype',
'trust_remote_code',
'use_fast',
'use_flash_attention_2',
'alpha_value',
'rope_freq_base',
'compress_pos_emb',
@@ -41,6 +41,8 @@
'gpu_split',
'max_seq_len',
'cfg_cache',
'no_flash_attn',
'cache_8bit',
'alpha_value',
'compress_pos_emb',
'use_fast',
@@ -56,6 +58,8 @@
'ExLlamav2': [
'gpu_split',
'max_seq_len',
'no_flash_attn',
'cache_8bit',
'alpha_value',
'compress_pos_emb',
],
@@ -144,7 +148,9 @@
loaders_samplers = {
'Transformers': {
'temperature',
'temperature_last',
'top_p',
'min_p',
'top_k',
'typical_p',
'epsilon_cutoff',
@@ -179,7 +185,9 @@
},
'ExLlama_HF': {
'temperature',
'temperature_last',
'top_p',
'min_p',
'top_k',
'typical_p',
'epsilon_cutoff',
@@ -239,7 +247,9 @@
},
'ExLlamav2_HF': {
'temperature',
'temperature_last',
'top_p',
'min_p',
'top_k',
'typical_p',
'epsilon_cutoff',
@@ -270,7 +280,9 @@
},
'AutoGPTQ': {
'temperature',
'temperature_last',
'top_p',
'min_p',
'top_k',
'typical_p',
'epsilon_cutoff',
@@ -305,7 +317,9 @@
},
'GPTQ-for-LLaMa': {
'temperature',
'temperature_last',
'top_p',
'min_p',
'top_k',
'typical_p',
'epsilon_cutoff',
@@ -356,7 +370,9 @@
},
'llamacpp_HF': {
'temperature',
'temperature_last',
'top_p',
'min_p',
'top_k',
'typical_p',
'epsilon_cutoff',
@@ -394,7 +410,9 @@
},
'AutoAWQ': {
'temperature',
'temperature_last',
'top_p',
'min_p',
'top_k',
'typical_p',
'epsilon_cutoff',
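`loaders_samplers` maps each loader to the sampling parameters it understands, and the new `temperature_last` and `min_p` entries are added per loader. A hedged sketch of how such a table can gate request parameters (the helper and the truncated table below are illustrative, not the module's actual API):

```python
loaders_samplers = {
    'ExLlamav2_HF': {'temperature', 'temperature_last', 'top_p', 'min_p', 'top_k'},
    # ... remaining loaders and keys elided ...
}

def filter_params(loader: str, params: dict) -> dict:
    """Keep only the sampling parameters the selected loader supports."""
    supported = loaders_samplers.get(loader, set())
    return {k: v for k, v in params.items() if k in supported}

print(filter_params('ExLlamav2_HF', {'temperature': 0.7, 'min_p': 0.05, 'hypothetical_param': 1}))
# {'temperature': 0.7, 'min_p': 0.05}
```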
7 changes: 6 additions & 1 deletion modules/models.py
@@ -123,8 +123,13 @@ def huggingface_loader(model_name):
params = {
'low_cpu_mem_usage': True,
'trust_remote_code': shared.args.trust_remote_code,
'torch_dtype': torch.bfloat16 if shared.args.bf16 else torch.float16
'torch_dtype': torch.bfloat16 if shared.args.bf16 else torch.float16,
'use_safetensors': True if shared.args.force_safetensors else None
}

if shared.args.use_flash_attention_2:
params['use_flash_attention_2'] = True

config = AutoConfig.from_pretrained(path_to_model, trust_remote_code=params['trust_remote_code'])

if 'chatglm' in model_name.lower():
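`--use_flash_attention_2` is forwarded to Transformers as `use_flash_attention_2=True`. A standalone sketch of the equivalent call, assuming a Transformers release from this period that still accepts the argument (newer versions prefer `attn_implementation="flash_attention_2"`) and a working `flash-attn` install; the model name is a placeholder:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-2-7b-hf"  # placeholder
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,    # flash-attention needs fp16 or bf16
    low_cpu_mem_usage=True,
    use_flash_attention_2=True,   # mirrors the new --use_flash_attention_2 flag
    device_map="auto",
)
```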
2 changes: 2 additions & 0 deletions modules/presets.py
@@ -8,7 +8,9 @@ def default_preset():
return {
'do_sample': True,
'temperature': 1,
'temperature_last': False,
'top_p': 1,
'min_p': 0,
'top_k': 0,
'typical_p': 1,
'epsilon_cutoff': 0,
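With `temperature_last` and `min_p` in the default preset, individual presets can opt into the new samplers explicitly. A tiny sketch of overriding those defaults in Python (real presets are YAML files; the values are arbitrary examples and assume the webui package is importable):

```python
from modules.presets import default_preset

preset = default_preset()
preset.update({
    'temperature': 0.8,
    'temperature_last': True,  # apply temperature after the other filters
    'min_p': 0.05,             # keep tokens with at least 5% of the top token's probability
})
```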
61 changes: 56 additions & 5 deletions modules/sampler_hijack.py
@@ -13,6 +13,36 @@
global_scores = None


class MinPLogitsWarper(LogitsWarper):
def __init__(self, min_p: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
if min_p < 0 or min_p > 1.0:
raise ValueError(f"`min_p` has to be a float >= 0 and <= 1, but is {min_p}")
self.min_p = min_p
self.filter_value = filter_value
self.min_tokens_to_keep = min_tokens_to_keep

def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
# Convert logits to probabilities
probs = torch.softmax(scores, dim=-1)
# Get the probability of the top token for each sequence in the batch
top_probs, _ = probs.max(dim=-1, keepdim=True)
# Calculate the actual min_p threshold by scaling min_p with the top token's probability
scaled_min_p = self.min_p * top_probs
# Create a mask for tokens that have a probability less than the scaled min_p
tokens_to_remove = probs < scaled_min_p

sorted_indices = torch.argsort(scores, descending=True, dim=-1)
sorted_indices_to_remove = torch.gather(tokens_to_remove, dim=-1, index=sorted_indices)

if self.min_tokens_to_keep > 1:
# Keep at least min_tokens_to_keep
sorted_indices_to_remove[..., : self.min_tokens_to_keep] = False

indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
scores = scores.masked_fill(indices_to_remove, self.filter_value)
return scores


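The `min_p` cutoff scales with the top token's probability: with `min_p = 0.1` and a top probability of about 0.64, everything below roughly 0.064 is masked. A small self-check of the warper above on toy logits (values chosen purely for illustration):

```python
import torch

# Toy scores for a 5-token vocabulary; softmax is roughly [0.65, 0.24, 0.09, 0.02, 0.01].
scores = torch.tensor([[3.0, 2.0, 1.0, -0.5, -1.2]])
warper = MinPLogitsWarper(min_p=0.1)    # cutoff = 0.1 * top probability, about 0.065
filtered = warper(None, scores)         # input_ids are not used by this warper
print(torch.softmax(filtered, dim=-1))  # only the first three tokens keep non-zero probability
```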
class TailFreeLogitsWarper(LogitsWarper):
def __init__(self, tfs: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
tfs = float(tfs)
@@ -186,17 +216,36 @@ def get_logits_warper_patch(self, generation_config):
if not isinstance(warper, TemperatureLogitsWarper):
warpers.remove(warper)
else:
if generation_config.tfs is not None and 0.0 <= generation_config.tfs <= 1.0:
if generation_config.tfs is not None and 0.0 <= generation_config.tfs < 1.0:
warpers_to_add.append(TailFreeLogitsWarper(tfs=generation_config.tfs, min_tokens_to_keep=min_tokens_to_keep))
if generation_config.top_a is not None and 0.0 <= generation_config.top_a <= 1.0:
if generation_config.top_a is not None and 0.0 < generation_config.top_a <= 1.0:
warpers_to_add.append(TopALogitsWarper(top_a=generation_config.top_a, min_tokens_to_keep=min_tokens_to_keep))
if generation_config.min_p is not None and 0.0 < generation_config.min_p <= 1.0:
warpers_to_add.append(MinPLogitsWarper(min_p=generation_config.min_p, min_tokens_to_keep=min_tokens_to_keep))

if warpers and isinstance(warpers[-1], LogitNormalization):
warpers = warpers[:-1] + warpers_to_add + [warpers[-1]]
if len(warpers) > 0 and isinstance(warpers[-1], LogitNormalization):
normalize = warpers.pop(-1)
else:
warpers += warpers_to_add
normalize = None

warpers += warpers_to_add
if generation_config.temperature_last:
temperature_idx = None
for i in range(len(warpers)):
if warpers[i].__class__.__name__ == 'TemperatureLogitsWarper':
temperature_idx = i
break

if temperature_idx is not None:
warpers = warpers[:temperature_idx] + warpers[temperature_idx + 1:] + [warpers[temperature_idx]]
warpers = LogitsProcessorList(warpers)

if normalize is not None:
warpers.append(normalize)

warpers.append(SpyLogitsWarper())
# for i in range(len(warpers)):
# print(warpers[i].__class__.__name__)
return warpers

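The reordering above means that with `temperature_last`, filters such as `top_p` see the untempered distribution and temperature only reshapes whatever survives. A standalone illustration with stock Transformers warpers (the toy logits and the resulting counts are illustrative only):

```python
import torch
from transformers import LogitsProcessorList, TemperatureLogitsWarper, TopPLogitsWarper

logits = torch.tensor([[3.0, 1.0, 0.5, 0.0]])

# Default order: temperature (T=1.8) flattens the distribution before top_p filters it.
default_order = LogitsProcessorList([TemperatureLogitsWarper(1.8), TopPLogitsWarper(0.9)])
# temperature_last: top_p filters the sharper, untempered distribution first.
temperature_last = LogitsProcessorList([TopPLogitsWarper(0.9), TemperatureLogitsWarper(1.8)])

def surviving_tokens(processors, scores):
    """Count tokens that still have non-zero probability after the warpers run."""
    return int((torch.softmax(processors(None, scores.clone()), dim=-1) > 0).sum())

print(surviving_tokens(default_order, logits))     # 4 with these toy logits
print(surviving_tokens(temperature_last, logits))  # 3: the weakest token is cut before tempering
```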

@@ -222,6 +271,7 @@ def get_logits_processor_patch(self, **kwargs):

def generation_config_init_patch(self, **kwargs):
self.__init___old(**kwargs)
self.min_p = kwargs.pop("min_p", 0.0)
self.tfs = kwargs.pop("tfs", 1.0)
self.top_a = kwargs.pop("top_a", 0.0)
self.mirostat_mode = kwargs.pop("mirostat_mode", 0)
@@ -230,6 +280,7 @@ def generation_config_init_patch(self, **kwargs):
self.repetition_penalty_range = kwargs.pop("repetition_penalty_range", 0)
self.presence_penalty = kwargs.pop("presence_penalty", 0)
self.frequency_penalty = kwargs.pop("frequency_penalty", 0)
self.temperature_last = kwargs.pop("temperature_last", False)


def hijack_samplers():
4 changes: 4 additions & 0 deletions modules/shared.py
@@ -91,7 +91,9 @@
parser.add_argument('--xformers', action='store_true', help='Use xformer\'s memory efficient attention. This is really old and probably doesn\'t do anything.')
parser.add_argument('--sdp-attention', action='store_true', help='Use PyTorch 2.0\'s SDP attention. Same as above.')
parser.add_argument('--trust-remote-code', action='store_true', help='Set trust_remote_code=True while loading the model. Necessary for some models.')
parser.add_argument('--force-safetensors', action='store_true', help='Set use_safetensors=True while loading the model. This prevents arbitrary code execution.')
parser.add_argument('--use_fast', action='store_true', help='Set use_fast=True while loading the tokenizer.')
parser.add_argument('--use_flash_attention_2', action='store_true', help='Set use_flash_attention_2=True while loading the model.')

# Accelerate 4-bit
parser.add_argument('--load-in-4bit', action='store_true', help='Load the model with 4-bit precision (using bitsandbytes).')
@@ -117,6 +119,8 @@
parser.add_argument('--gpu-split', type=str, help='Comma-separated list of VRAM (in GB) to use per GPU device for model layers. Example: 20,7,7.')
parser.add_argument('--max_seq_len', type=int, default=2048, help='Maximum sequence length.')
parser.add_argument('--cfg-cache', action='store_true', help='ExLlama_HF: Create an additional cache for CFG negative prompts. Necessary to use CFG with that loader, but not necessary for CFG with base ExLlama.')
parser.add_argument('--no_flash_attn', action='store_true', help='Force flash-attention to not be used.')
parser.add_argument('--cache_8bit', action='store_true', help='Use 8-bit cache to save VRAM.')

# AutoGPTQ
parser.add_argument('--triton', action='store_true', help='Use triton.')
2 changes: 1 addition & 1 deletion modules/text_generation.py
@@ -274,7 +274,7 @@ def apply_stopping_strings(reply, all_stop_strings):

def generate_reply_HF(question, original_question, seed, state, stopping_strings=None, is_chat=False):
generate_params = {}
for k in ['max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'presence_penalty', 'frequency_penalty', 'repetition_penalty_range', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'tfs', 'top_a', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'guidance_scale']:
for k in ['max_new_tokens', 'do_sample', 'temperature', 'temperature_last', 'top_p', 'min_p', 'typical_p', 'repetition_penalty', 'presence_penalty', 'frequency_penalty', 'repetition_penalty_range', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'tfs', 'top_a', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'guidance_scale']:
generate_params[k] = state[k]

if state['negative_prompt'] != '':
2 changes: 1 addition & 1 deletion modules/training.py
@@ -544,7 +544,7 @@ def generate_and_tokenize_prompt(data_point):
lora_model = get_peft_model(shared.model, config)
if not always_override and Path(f"{lora_file_path}/adapter_model.bin").is_file():
logger.info("Loading existing LoRA data...")
state_dict_peft = torch.load(f"{lora_file_path}/adapter_model.bin")
state_dict_peft = torch.load(f"{lora_file_path}/adapter_model.bin", weights_only=True)
set_peft_model_state_dict(lora_model, state_dict_peft)
except:
yield traceback.format_exc().replace('\n', '\n\n')