Skip to content

Commit 1b873a3

Browse files
committed
update
1 parent b4c84ac commit 1b873a3

File tree

8 files changed

+20
-16
lines changed

8 files changed

+20
-16
lines changed

swift/llm/__init__.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,11 @@
1111
from .export import export_main
1212
from .eval import eval_main
1313
from .train import sft_main, pt_main, rlhf_main
14-
from .argument import EvalArguments, InferArguments, SftArguments, ExportArguments, WebuiArguments, DeployArguments, RLHFArguments, WebuiArguments, AppUIArguments
15-
from .template import TEMPLATE_MAPPING, Template, StopWords, get_template, TemplateType, to_device
14+
from .argument import (
15+
EvalArguments, InferArguments, SftArguments, ExportArguments, DeployArguments,
16+
RLHFArguments, WebuiArguments, AppUIArguments
17+
)
18+
from .template import TEMPLATE_MAPPING, StopWords, get_template, TemplateType, to_device
1619
from .model import MODEL_MAPPING, ModelType, get_model_tokenizer, get_default_template_type
1720
from .dataset import (AlpacaPreprocessor, ClsPreprocessor, ComposePreprocessor, ConversationsPreprocessor,
1821
ListPreprocessor, PreprocessFunc, RenameColumnsPreprocessor, SmartPreprocessor,
@@ -35,7 +38,7 @@
3538
'train': ['sft_main', 'pt_main', 'rlhf_main'],
3639
'argument': [
3740
'EvalArguments', 'InferArguments', 'SftArguments', 'ExportArguments', 'WebuiArguments', 'DeployArguments',
38-
'RLHFArguments', 'WebuiArguments', 'AppUIArguments'
41+
'RLHFArguments', 'AppUIArguments'
3942
],
4043
'template': ['TEMPLATE_MAPPING', 'Template', 'StopWords', 'get_template', 'TemplateType', 'to_device'],
4144
'model': ['MODEL_MAPPING', 'ModelType', 'get_model_tokenizer', 'get_default_template_type'],

swift/llm/argument/data_args.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,10 @@
44
from dataclasses import dataclass, field
55
from typing import List, Literal, Optional, Union
66

7-
from datasets import Dataset as HfDataset
8-
from datasets import IterableDataset as HfIterableDataset
9-
10-
from swift.llm.dataset import DATASET_MAPPING, register_dataset_info_file
11-
from swift.llm.model import get_default_template_type
12-
from swift.llm.template import TEMPLATE_MAPPING
7+
from swift.llm import DATASET_MAPPING, TEMPLATE_MAPPING, get_default_template_type, register_dataset_info_file
138
from swift.utils import get_logger
149

1510
logger = get_logger()
16-
DATASET_TYPE = Union[HfDataset, HfIterableDataset]
1711

1812

1913
@dataclass

swift/llm/argument/eval_args.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,11 +41,10 @@ def __post_init__(self):
4141
if self.model_type is None:
4242
self.model_type = model.id
4343

44-
def select_dtype(self):
44+
def select_dtype(self) -> None:
4545
"""Override the super one because eval_url does not have a proper model_type"""
4646
if self.eval_url is None:
47-
return super().select_dtype()
48-
return None, None, None
47+
super().select_dtype()
4948

5049
def select_model_type(self) -> None:
5150
"""Override the super one because eval_url does not have a proper model_type"""
@@ -67,6 +66,10 @@ def handle_infer_backend(self) -> None:
6766
if self.eval_url is None:
6867
super().handle_infer_backend()
6968

69+
@property
7070
def is_multimodal(self) -> bool:
7171
"""Override the super one because eval_url does not have a proper model_type"""
72-
return False if self.eval_url is not None else super().is_multimodal()
72+
if self.eval_url is None:
73+
return super().is_multimodal
74+
else:
75+
return False

swift/llm/argument/infer_args.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from swift.llm.model.loader import MODEL_MAPPING
1313
from swift.llm.template import TEMPLATE_MAPPING
1414
from swift.tuners.utils import swift_to_peft_format
15-
from swift.utils import get_logger, is_vllm_available, is_lmdeploy_available
15+
from swift.utils import get_logger, is_lmdeploy_available, is_vllm_available
1616
from .base_args import BaseArguments
1717

1818
logger = get_logger()

swift/llm/argument/tuner_args.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
# Copyright (c) Alibaba, Inc. and its affiliates.
12
from dataclasses import dataclass, field
23
from typing import List, Literal, Optional
34

swift/llm/argument/webui_args.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
# Copyright (c) Alibaba, Inc. and its affiliates.
12
from dataclasses import dataclass
23
from typing import Optional
34

swift/llm/template/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,4 @@
1-
from .base import TEMPLATE_MAPPING, StopWords, Template, get_template
1+
from .base import TEMPLATE_MAPPING, StopWords, get_template
22
from .utils import to_device
3+
from .template import TemplateType
4+

swift/llm/template/ms_template.py

Whitespace-only changes.

0 commit comments

Comments
 (0)