Petals support #3784

Closed
wants to merge 3 commits
37 changes: 37 additions & 0 deletions modules/loaders.py
@@ -27,6 +27,10 @@
'disable_exllama',
'transformers_info'
],
'petals': [
'trust_remote_code',
'gpu_split',
],
'ExLlama_HF': [
'gpu_split',
'max_seq_len',
@@ -176,6 +180,39 @@
'skip_special_tokens',
'auto_max_new_tokens',
},
'petals': {
'temperature',
'top_p',
'top_k',
'typical_p',
'epsilon_cutoff',
'eta_cutoff',
'tfs',
'top_a',
'repetition_penalty',
'repetition_penalty_range',
'encoder_repetition_penalty',
'no_repeat_ngram_size',
'min_length',
'seed',
'do_sample',
'penalty_alpha',
'num_beams',
'length_penalty',
'early_stopping',
'mirostat_mode',
'mirostat_tau',
'mirostat_eta',
'grammar_file_row',
'grammar_string',
'guidance_scale',
'negative_prompt',
'ban_eos_token',
'custom_token_bans',
'add_bos_token',
'skip_special_tokens',
'auto_max_new_tokens',
},
'ExLlama_HF': {
'temperature',
'top_p',
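The two hunks above register 'petals' in loaders.py's per-loader mappings: the first lists the UI/config fields shown for the loader (here trust_remote_code and gpu_split), the second lists the generation parameters it accepts. A minimal sketch of how the new entries could be checked, assuming the dicts are named loaders_and_params and loaders_samplers as elsewhere in modules/loaders.py (the names are not visible in this diff):

from modules import loaders  # assumes the webui package layout

# Hypothetical sanity check for the entries added above.
print('gpu_split' in loaders.loaders_and_params['petals'])    # expected: True
print('mirostat_mode' in loaders.loaders_samplers['petals'])  # expected: True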
33 changes: 26 additions & 7 deletions modules/models.py
@@ -2,6 +2,7 @@
import os
import re
import time
import json
import traceback
from pathlib import Path

@@ -64,6 +65,7 @@ def load_model(model_name, loader=None):
'ExLlamav2_HF': ExLlamav2_HF_loader,
'ctransformers': ctransformers_loader,
'AutoAWQ': AutoAWQ_loader,
'petals': huggingface_loader,
}

if loader is None:
@@ -99,12 +101,14 @@ def load_tokenizer(model_name, model):
path_to_model = Path(f"{shared.args.model_dir}/{model_name}/")
if any(s in model_name.lower() for s in ['gpt-4chan', 'gpt4chan']) and Path(f"{shared.args.model_dir}/gpt-j-6B/").exists():
tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/gpt-j-6B/"))
elif path_to_model.exists():
else:
model_id = path_to_model if path_to_model.exists() else model_name

if shared.args.use_fast:
logger.info('Loading the tokenizer with use_fast=True.')

tokenizer = AutoTokenizer.from_pretrained(
path_to_model,
model_id,
trust_remote_code=shared.args.trust_remote_code,
use_fast=shared.args.use_fast
)
@@ -113,24 +117,43 @@ def load_tokenizer(model_name, model):


def huggingface_loader(model_name):
if shared.args.loader == "petals":
path_to_model = model_name
import logging
httpx_logger = logging.getLogger('httpx')
httpx_logger.setLevel(logging.WARNING)
else:
path_to_model = Path(f'{shared.args.model_dir}/{model_name}')

path_to_model = Path(f'{shared.args.model_dir}/{model_name}')
params = {
'low_cpu_mem_usage': True,
'trust_remote_code': shared.args.trust_remote_code,
'torch_dtype': torch.bfloat16 if shared.args.bf16 else torch.float16
}

if shared.args.loader == "petals" and shared.args.gpu_split:
model_config = json.loads(shared.args.gpu_split)
for key in model_config.keys():
params[key] = model_config[key]

config = AutoConfig.from_pretrained(path_to_model, trust_remote_code=params['trust_remote_code'])

if 'chatglm' in model_name.lower():
LoaderClass = AutoModel
elif shared.args.loader == "petals":
from petals import AutoDistributedModelForCausalLM
LoaderClass = AutoDistributedModelForCausalLM
else:
if config.to_dict().get('is_encoder_decoder', False):
LoaderClass = AutoModelForSeq2SeqLM
shared.is_seq2seq = True
else:
LoaderClass = AutoModelForCausalLM

if not any((shared.args.cpu, shared.args.deepspeed, torch.cuda.is_available(), torch.backends.mps.is_available())):
logger.warning("torch.cuda.is_available() returned False. This means that no GPU has been detected. Falling back to CPU mode.")
shared.args.cpu = True

# Load the model in simple 16-bit mode by default
if not any([shared.args.cpu, shared.args.load_in_8bit, shared.args.load_in_4bit, shared.args.auto_devices, shared.args.disk, shared.args.deepspeed, shared.args.gpu_memory is not None, shared.args.cpu_memory is not None, shared.args.compress_pos_emb > 1, shared.args.alpha_value > 1, shared.args.disable_exllama]):
model = LoaderClass.from_pretrained(path_to_model, **params)
@@ -149,10 +172,6 @@ def huggingface_loader(model_name):

# Load with quantization and/or offloading
else:
if not any((shared.args.cpu, torch.cuda.is_available(), torch.backends.mps.is_available())):
logger.warning('torch.cuda.is_available() returned False. This means that no GPU has been detected. Falling back to CPU mode.')
shared.args.cpu = True

if shared.args.cpu:
params['torch_dtype'] = torch.float32
else:
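For reference, a minimal standalone sketch of the code path the petals branch above takes: the model name is used as-is (no model_dir prefix), any JSON placed in the gpu_split field is parsed and merged into the from_pretrained() kwargs, and AutoDistributedModelForCausalLM from the petals package replaces the usual transformers loader class. The model name and the torch_dtype value below are illustrative assumptions, not part of this PR:

import json
from transformers import AutoTokenizer
from petals import AutoDistributedModelForCausalLM

model_name = "petals-team/StableBeluga2"  # example model served on the public swarm

# Stand-in for shared.args.gpu_split: its keys are merged into the from_pretrained()
# kwargs, mirroring the json.loads() handling in the hunk above.
extra_kwargs = json.loads('{"torch_dtype": "auto"}')

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoDistributedModelForCausalLM.from_pretrained(model_name, **extra_kwargs)

inputs = tokenizer("Distributed inference with Petals: ", return_tensors="pt")["input_ids"]
outputs = model.generate(inputs, max_new_tokens=16)
print(tokenizer.decode(outputs[0]))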
2 changes: 2 additions & 0 deletions modules/shared.py
@@ -234,6 +234,8 @@ def fix_loader_name(name):
return 'ctransformers'
elif name in ['autoawq', 'awq', 'auto-awq']:
return 'AutoAWQ'
else:
return name


def add_extension(name):
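A short sketch of the behavioural change in fix_loader_name: unrecognized loader names such as 'petals' previously fell off the end of the if/elif chain and implicitly returned None, while the new else branch passes them through unchanged:

from modules.shared import fix_loader_name

print(fix_loader_name('auto-awq'))  # 'AutoAWQ', normalized as before
print(fix_loader_name('petals'))    # 'petals', returned unchanged by the new else branch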
3 changes: 2 additions & 1 deletion server.py
@@ -152,7 +152,8 @@ def create_interface():
auth=auth or None,
ssl_verify=False if (shared.args.ssl_keyfile or shared.args.ssl_certfile) else True,
ssl_keyfile=shared.args.ssl_keyfile,
ssl_certfile=shared.args.ssl_certfile
ssl_certfile=shared.args.ssl_certfile,
debug=True if gr.utils.colab_check() else False
)


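The server.py change only enables Gradio's blocking debug mode when running inside Colab, which keeps the cell (and therefore the webui process) running while streaming server logs. A minimal sketch, assuming gr is gradio and demo is any Interface or Blocks object:

import gradio as gr

demo = gr.Interface(fn=lambda s: s[::-1], inputs="text", outputs="text")

# gr.utils.colab_check() is the same helper used in the diff above; debug=True makes
# launch() block and stream logs inside a notebook session.
demo.launch(debug=gr.utils.colab_check())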