Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
Jintao-Huang committed Oct 4, 2024
1 parent b4c84ac commit 1b873a3
Show file tree
Hide file tree
Showing 8 changed files with 20 additions and 16 deletions.
9 changes: 6 additions & 3 deletions swift/llm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,11 @@
from .export import export_main
from .eval import eval_main
from .train import sft_main, pt_main, rlhf_main
from .argument import EvalArguments, InferArguments, SftArguments, ExportArguments, WebuiArguments, DeployArguments, RLHFArguments, WebuiArguments, AppUIArguments
from .template import TEMPLATE_MAPPING, Template, StopWords, get_template, TemplateType, to_device
from .argument import (
EvalArguments, InferArguments, SftArguments, ExportArguments, DeployArguments,
RLHFArguments, WebuiArguments, AppUIArguments
)
from .template import TEMPLATE_MAPPING, StopWords, get_template, TemplateType, to_device
from .model import MODEL_MAPPING, ModelType, get_model_tokenizer, get_default_template_type
from .dataset import (AlpacaPreprocessor, ClsPreprocessor, ComposePreprocessor, ConversationsPreprocessor,
ListPreprocessor, PreprocessFunc, RenameColumnsPreprocessor, SmartPreprocessor,
Expand All @@ -35,7 +38,7 @@
'train': ['sft_main', 'pt_main', 'rlhf_main'],
'argument': [
'EvalArguments', 'InferArguments', 'SftArguments', 'ExportArguments', 'WebuiArguments', 'DeployArguments',
'RLHFArguments', 'WebuiArguments', 'AppUIArguments'
'RLHFArguments', 'AppUIArguments'
],
'template': ['TEMPLATE_MAPPING', 'Template', 'StopWords', 'get_template', 'TemplateType', 'to_device'],
'model': ['MODEL_MAPPING', 'ModelType', 'get_model_tokenizer', 'get_default_template_type'],
Expand Down
8 changes: 1 addition & 7 deletions swift/llm/argument/data_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,10 @@
from dataclasses import dataclass, field
from typing import List, Literal, Optional, Union

from datasets import Dataset as HfDataset
from datasets import IterableDataset as HfIterableDataset

from swift.llm.dataset import DATASET_MAPPING, register_dataset_info_file
from swift.llm.model import get_default_template_type
from swift.llm.template import TEMPLATE_MAPPING
from swift.llm import DATASET_MAPPING, TEMPLATE_MAPPING, get_default_template_type, register_dataset_info_file
from swift.utils import get_logger

logger = get_logger()
DATASET_TYPE = Union[HfDataset, HfIterableDataset]


@dataclass
Expand Down
11 changes: 7 additions & 4 deletions swift/llm/argument/eval_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,10 @@ def __post_init__(self):
if self.model_type is None:
self.model_type = model.id

def select_dtype(self):
def select_dtype(self) -> None:
"""Override the super one because eval_url does not have a proper model_type"""
if self.eval_url is None:
return super().select_dtype()
return None, None, None
super().select_dtype()

def select_model_type(self) -> None:
"""Override the super one because eval_url does not have a proper model_type"""
Expand All @@ -67,6 +66,10 @@ def handle_infer_backend(self) -> None:
if self.eval_url is None:
super().handle_infer_backend()

@property
def is_multimodal(self) -> bool:
"""Override the super one because eval_url does not have a proper model_type"""
return False if self.eval_url is not None else super().is_multimodal()
if self.eval_url is None:
return super().is_multimodal
else:
return False
2 changes: 1 addition & 1 deletion swift/llm/argument/infer_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from swift.llm.model.loader import MODEL_MAPPING
from swift.llm.template import TEMPLATE_MAPPING
from swift.tuners.utils import swift_to_peft_format
from swift.utils import get_logger, is_vllm_available, is_lmdeploy_available
from swift.utils import get_logger, is_lmdeploy_available, is_vllm_available
from .base_args import BaseArguments

logger = get_logger()
Expand Down
1 change: 1 addition & 0 deletions swift/llm/argument/tuner_args.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from dataclasses import dataclass, field
from typing import List, Literal, Optional

Expand Down
1 change: 1 addition & 0 deletions swift/llm/argument/webui_args.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from dataclasses import dataclass
from typing import Optional

Expand Down
4 changes: 3 additions & 1 deletion swift/llm/template/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
from .base import TEMPLATE_MAPPING, StopWords, Template, get_template
from .base import TEMPLATE_MAPPING, StopWords, get_template
from .utils import to_device
from .template import TemplateType

Empty file.

0 comments on commit 1b873a3

Please sign in to comment.