Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

remove model_type, use model_arch #2195

Merged
merged 38 commits into from
Oct 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
173184c
update
Jintao-Huang Oct 3, 2024
4fc24d3
update
Jintao-Huang Oct 4, 2024
424b068
update
Jintao-Huang Oct 4, 2024
c5c5b51
Merge remote-tracking branch 'refs/remotes/origin/main3_1' into main3_1
Jintao-Huang Oct 4, 2024
31ecef5
Merge branch 'main3' into main3_1
Jintao-Huang Oct 4, 2024
cff243c
update
Jintao-Huang Oct 4, 2024
b4c84ac
update
Jintao-Huang Oct 4, 2024
1b873a3
update
Jintao-Huang Oct 4, 2024
3be5149
update
Jintao-Huang Oct 4, 2024
cd530ed
update
Jintao-Huang Oct 4, 2024
175f820
update
Jintao-Huang Oct 4, 2024
8a6cba6
update
Jintao-Huang Oct 4, 2024
8a906a0
update
Jintao-Huang Oct 4, 2024
309257f
update
Jintao-Huang Oct 4, 2024
6fba2b7
update
Jintao-Huang Oct 4, 2024
71ae04f
update
Jintao-Huang Oct 4, 2024
7ee8304
Merge branch 'main3' into main3_1
Jintao-Huang Oct 4, 2024
9119d99
Merge branch 'main3' into main3_1
Jintao-Huang Oct 4, 2024
5aeb945
update
Jintao-Huang Oct 5, 2024
14c8686
update
Jintao-Huang Oct 5, 2024
4763f01
update
Jintao-Huang Oct 5, 2024
ea36efb
update
Jintao-Huang Oct 5, 2024
4d3a44b
update
Jintao-Huang Oct 5, 2024
709aa7c
update
Jintao-Huang Oct 5, 2024
909d809
update
Jintao-Huang Oct 5, 2024
9d2dd5e
Merge remote-tracking branch 'refs/remotes/origin/main3_1' into main3_1
Jintao-Huang Oct 5, 2024
345b689
update
Jintao-Huang Oct 5, 2024
0867fac
Merge remote-tracking branch 'refs/remotes/origin/main3_1' into main3_1
Jintao-Huang Oct 5, 2024
8dc1a52
update
Jintao-Huang Oct 5, 2024
25bd739
update
Jintao-Huang Oct 5, 2024
2a25fd6
update
Jintao-Huang Oct 5, 2024
eec71a4
update
Jintao-Huang Oct 5, 2024
eac4cf2
update
Jintao-Huang Oct 5, 2024
c4600b1
update
Jintao-Huang Oct 5, 2024
1b6cad3
update
Jintao-Huang Oct 5, 2024
04f1748
update
Jintao-Huang Oct 5, 2024
b13dfc9
update
Jintao-Huang Oct 5, 2024
16cf77d
update
Jintao-Huang Oct 5, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions swift/llm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
from .train import sft_main, pt_main, rlhf_main
from .argument import (EvalArguments, InferArguments, SftArguments, ExportArguments, DeployArguments, RLHFArguments,
WebuiArguments, AppUIArguments)
from .template import TEMPLATE_MAPPING, StopWords, get_template, TemplateType
from .model import MODEL_MAPPING, ModelType, get_model_tokenizer, get_default_template_type, ConfigReader
from .template import TEMPLATE_MAPPING, Template, StopWords, get_template, TemplateType
from .model import MODEL_MAPPING, ModelType, get_model_tokenizer, get_default_template_type, HfConfigFactory
from .dataset import (AlpacaPreprocessor, ClsPreprocessor, ComposePreprocessor, ConversationsPreprocessor,
ListPreprocessor, PreprocessFunc, RenameColumnsPreprocessor, SmartPreprocessor,
TextGenerationPreprocessor, DatasetName, DatasetLoader, HubDatasetLoader, LocalDatasetLoader,
Expand All @@ -41,7 +41,7 @@
'RLHFArguments', 'AppUIArguments'
],
'template': ['TEMPLATE_MAPPING', 'Template', 'StopWords', 'get_template', 'TemplateType'],
'model': ['MODEL_MAPPING', 'ModelType', 'get_model_tokenizer', 'get_default_template_type', 'ConfigReader'],
'model': ['MODEL_MAPPING', 'ModelType', 'get_model_tokenizer', 'get_default_template_type', 'HfConfigFactory'],
'dataset': [
'AlpacaPreprocessor', 'ClsPreprocessor', 'ComposePreprocessor', 'ConversationsPreprocessor',
'ListPreprocessor', 'PreprocessFunc', 'RenameColumnsPreprocessor', 'SmartPreprocessor',
Expand Down
24 changes: 4 additions & 20 deletions swift/llm/argument/model_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
from transformers.utils import is_torch_bf16_gpu_available, is_torch_cuda_available, is_torch_npu_available
from transformers.utils.versions import require_version

from swift.llm import MODEL_KEYS_MAPPING, MODEL_MAPPING, ConfigReader
from swift.llm import MODEL_KEYS_MAPPING, MODEL_MAPPING, RLHFArguments
from swift.llm.model import fix_do_sample_warning
from swift.utils import get_dist_setting, get_logger, use_hf_hub

logger = get_logger()
Expand Down Expand Up @@ -35,12 +36,8 @@ def handle_do_sample(self) -> None:
if self.temperature == 0:
self.do_sample = False
from swift.llm import InferArguments, SftArguments
if self.do_sample is False and (isinstance(self, SftArguments) or
(isinstance(self, InferArguments) and self.infer_backend == 'pt')):
# fix warning
self.temperature = 1.
self.top_p = 1.
self.top_k = 50
if (isinstance(self, SftArguments) or (isinstance(self, InferArguments) and self.infer_backend == 'pt')):
fix_do_sample_warning(self)
logger.info('Due to do_sample=False, the following settings are applied: args.temperature: '
f'{self.temperature}, args.top_p: {self.top_p}, args.top_k: {self.top_k}.')

Expand Down Expand Up @@ -86,19 +83,6 @@ def select_bnb(self) -> None:
self.bnb_4bit_compute_dtype = bnb_4bit_compute_dtype
self.load_in_4bit, self.load_in_8bit = load_in_4bit, load_in_8bit

def is_quant_model(self: Union['SftArguments', 'InferArguments']) -> bool:
"""Judge if the current model has already been a quantized model"""
# Check if the model is gptq, awq, aqlm model. Do not check for other quantization situations such as bnb.
if self.model_type is not None:
for k in ['int4', 'int8', 'awq', 'aqlm']:
if k in self.model_type:
return True

model_path = self.model_id_or_path or self.resume_from_checkpoint or self.ckpt_dir
bits = ConfigReader.read_config('quantization_config.bits', self.model_type, model_path, self.model_revision)
if bits:
return True

def __post_init__(self: Union['SftArguments', 'InferArguments']):
self.select_bnb()

Expand Down
6 changes: 4 additions & 2 deletions swift/llm/dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,9 +177,11 @@ class DatasetName:
def get_dataset_name_list(cls) -> List[str]:
res = []
for k in cls.__dict__.keys():
if k.startswith('__') or k == 'get_dataset_name_list':
if k.startswith('__'):
continue
res.append(cls.__dict__[k])
value = cls.__dict__[k]
if isinstance(value, str):
res.append(value)
return res


Expand Down
4 changes: 2 additions & 2 deletions swift/llm/infer/lmdeploy.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

from swift.llm import InferArguments, InferTemplate
from swift.llm.infer.base import InferFramework
from swift.llm.model import ConfigReader
from swift.llm.model import HfConfigFactory
from swift.llm.model.model import get_model_tokenizer
from swift.llm.template.template import Template, get_template
from swift.utils import get_logger, get_seed
Expand Down Expand Up @@ -411,7 +411,7 @@ def get_lmdeploy_engine(
lmdeploy_engine.is_multimodal = is_multimodal
lmdeploy_engine.hf_tokenizer = tokenizer
lmdeploy_engine.model_config = model_config
lmdeploy_engine.max_model_len = ConfigReader.get_max_model_len(model_config)
lmdeploy_engine.max_model_len = HfConfigFactory.get_max_model_len(model_config)

generation_config_path = os.path.join(model_dir, 'generation_config.json')
if os.path.isfile(generation_config_path):
Expand Down
7 changes: 3 additions & 4 deletions swift/llm/infer/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,9 @@
from transformers.utils import is_torch_npu_available

from swift import get_logger
from swift.llm import DeployArguments, InferArguments, StopWords, Template, get_model_tokenizer, get_template
from swift.llm import (DeployArguments, HfConfigFactory, InferArguments, StopWords, Template, get_model_tokenizer,
get_template, to_device)
from swift.llm.dataset.utils import safe_tokenizer_decode
from swift.llm.model import ConfigReader
from swift.llm.model.utils import to_device
from swift.llm.template.base import StopWordsCriteria
from swift.llm.utils import Messages, set_generation_config
from swift.plugin.tuner import Tuner, extra_tuners
Expand Down Expand Up @@ -493,7 +492,7 @@ def _prepare_inputs(model: PreTrainedModel,
generation_config.bos_token_id = tokenizer.bos_token_id
if generation_config.max_new_tokens is not None:
generation_config.max_length = 20 # fix max_length, max_new_tokens warning
max_length = ConfigReader.get_max_model_len(model.config)
max_length = HfConfigFactory.get_max_model_len(model.config)
if max_length and token_len + generation_config.max_new_tokens > max_length:
generation_config.max_new_tokens = max_length - token_len
if generation_config.max_new_tokens <= 0:
Expand Down
17 changes: 14 additions & 3 deletions swift/llm/model/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,14 @@
from .config import ConfigReader
from .loader import safe_snapshot_download
from .model import MODEL_MAPPING, ModelType, get_default_template_type, get_model_tokenizer
from .constant import LLMModelType, MLLMModelType, ModelType
from .register import (MODEL_MAPPING, Model, ModelGroup, TemplateGroup, fix_do_sample_warning,
get_default_template_type, get_model_tokenizer)
from .utils import HfConfigFactory, safe_snapshot_download


def _register_files():
    """Import the per-model modules for their import-time side effects.

    NOTE(review): importing these modules presumably registers their model
    definitions into MODEL_MAPPING (see .register) — confirm against the
    qwen/llama module bodies. Import order is preserved deliberately.
    """
    from . import qwen
    from . import llama
    # TODO
    # from . import model


_register_files()
67 changes: 0 additions & 67 deletions swift/llm/model/config.py

This file was deleted.

62 changes: 62 additions & 0 deletions swift/llm/model/constant.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
from typing import List


class LLMModelType:
    """String constants naming the supported text-only LLM model types."""
    qwen = 'qwen'
    modelscope_agent = 'modelscope_agent'
    qwen2 = 'qwen2'
    qwen2_5 = 'qwen2_5'
    qwen2_moe = 'qwen2_moe'

    chatglm2 = 'chatglm2'
    chatglm3 = 'chatglm3'
    glm4 = 'glm4'

    llama2 = 'llama2'
    llama3 = 'llama3'
    llama3_1 = 'llama3_1'
    llama3_2 = 'llama3_2'
    reflection_llama3_1 = 'reflection_llama3_1'
    chinese_llama2 = 'chinese_llama2'
    chinese_alpaca2 = 'chinese_alpaca2'
    llama3_chinese = 'llama3_chinese'

    longwriter_glm4 = 'longwriter_glm4'
    longwriter_llama3_1 = 'longwriter_llama3_1'

    atom = 'atom'

    codefuse_qwen = 'codefuse_qwen'


class MLLMModelType:
    """String constants naming the supported multimodal LLM model types."""
    qwen_vl = 'qwen_vl'
    qwen_audio = 'qwen_audio'
    qwen2_vl = 'qwen2_vl'
    qwen2_audio = 'qwen2_audio'

    glm4v = 'glm4v'
    llama3_2_vision = 'llama3_2_vision'
    llama3_1_omni = 'llama3_1_omni'
    idefics3_llama3 = 'idefics3_llama3'

    llava1_5 = 'llava1_5'
    llava1_6_mistral = 'llava1_6_mistral'
    llava1_6_vicuna = 'llava1_6_vicuna'
    llava1_6_yi = 'llava1_6_yi'
    llava1_6_llama3_1 = 'llava1_6_llama3_1'
    llava_next = 'llava_next'


class ModelType(LLMModelType, MLLMModelType):
    """Union of all model type constants (LLM and MLLM)."""

    @classmethod
    def get_model_name_list(cls) -> List[str]:
        """Return every model-type string constant visible on this class.

        Bug fix: the previous implementation scanned only ``cls.__dict__``,
        which for ``ModelType`` contains no constants (they are all inherited
        from the base classes), so it always returned an empty list. Walking
        ``cls.__mro__`` collects the constants from every base class while
        skipping dunder attributes (e.g. ``__module__``, which is also a str)
        and non-string members such as this classmethod. Order follows the
        MRO, and duplicates (if a name were overridden) are kept once.
        """
        res: List[str] = []
        for klass in cls.__mro__:
            for key, value in klass.__dict__.items():
                if key.startswith('__') or not isinstance(value, str):
                    continue
                if value not in res:
                    res.append(value)
        return res
File renamed without changes.
86 changes: 0 additions & 86 deletions swift/llm/model/loader.py

This file was deleted.

Loading
Loading