Commit

update

hjh0119 committed Aug 5, 2024
1 parent b496f2d commit 297ca72

Showing 6 changed files with 205 additions and 100 deletions.
23 changes: 14 additions & 9 deletions swift/llm/sft.py
@@ -36,7 +36,8 @@ def _get_train_val_dataset(args: SftArguments) -> Tuple[HfDataset, Optional[HfDa
args.dataset_seed,
check_dataset_strategy=args.check_dataset_strategy,
model_name=args.model_name,
model_author=args.model_author)
model_author=args.model_author,
streaming=args.streaming)
if len(args.val_dataset) > 0:
# Loading val dataset
_, val_dataset = get_dataset(
@@ -45,7 +46,8 @@ def _get_train_val_dataset(args: SftArguments) -> Tuple[HfDataset, Optional[HfDa
args.dataset_seed,
check_dataset_strategy=args.check_dataset_strategy,
model_name=args.model_name,
model_author=args.model_author)
model_author=args.model_author,
streaming=args.streaming)

train_dataset, val_dataset = args._handle_dataset_compat(train_dataset, val_dataset)
logger.info(f'train_dataset: {train_dataset}')
@@ -111,6 +113,7 @@ def llm_sft_megatron(args: SftArguments) -> Dict[str, Any]:
def llm_sft(args: SftArguments) -> Dict[str, Any]:
logger.info(f'args: {args}')
is_generation = TEMPLATE_MAPPING[args.template_type].get('is_generation', False)
streaming = args.streaming
if is_generation and type(args) is SftArguments:
logger.warning(f"Please check if args.template_type: '{args.template_type}' is correct. "
'Currently, SFT is in progress, but the template is used for PT.')
@@ -267,7 +270,8 @@ def llm_sft(args: SftArguments) -> Dict[str, Any]:
fsdp_flatten_parameters=False)

train_dataset, val_dataset = _get_train_val_dataset(args)
training_args.train_dataset_sample = train_dataset.shape[0] if train_dataset is not None else 0 # torchacc
training_args.train_dataset_sample = train_dataset.shape[0] if train_dataset is not None and not streaming else 0  # torchacc
template_kwargs = {}
template_kwargs['use_loss_scale'] = args.use_loss_scale
if args.loss_scale_config_path is not None:
@@ -307,10 +311,11 @@ def llm_sft(args: SftArguments) -> Dict[str, Any]:
dataset_info['val_dataset'] = stat_dataset(val_dataset)
elif not args.lazy_tokenize:
dataset_info = {}
logger.info(f'Using num_proc: {args.preprocess_num_proc}')
train_dataset = dataset_map(train_dataset, template.encode, args.preprocess_num_proc)
if not streaming:
logger.info(f'Using num_proc: {args.preprocess_num_proc}')
train_dataset = dataset_map(train_dataset, template.encode, args.preprocess_num_proc, streaming=streaming)
if val_dataset is not None:
val_dataset = dataset_map(val_dataset, template.encode, args.preprocess_num_proc)
val_dataset = dataset_map(val_dataset, template.encode, args.preprocess_num_proc, streaming=streaming)
if args.test_oom_error:
train_dataset = sort_by_max_length(train_dataset, 20000)
# Data analysis
@@ -321,11 +326,11 @@ def llm_sft(args: SftArguments) -> Dict[str, Any]:
raise AttributeError('Failed to access dataset attributes, train_dataset is None. This might be because:\n'
'(1) The dataset contains None for input or labels;\n'
"(2) The 'max_length' setting is too short causing data truncation.")
td0, tkwargs0 = train_dataset.data[0]
td0, tkwargs0 = train_dataset.data[0] if not streaming else next(iter(train_dataset.data)) # TODO
print_example(td0, tokenizer, tkwargs0)
dataset_info['train_dataset'] = stat_dataset(train_dataset)
dataset_info['train_dataset'] = stat_dataset(train_dataset) if not streaming else None
if val_dataset is not None:
dataset_info['val_dataset'] = stat_dataset(val_dataset)
dataset_info['val_dataset'] = stat_dataset(val_dataset) if not streaming else None
else:
dataset_info = None
td0, tkwargs0 = template.encode(train_dataset[0])
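The streaming-specific guards above exist because a streamed dataset is a Hugging Face `IterableDataset`: it has no known length, no random access and no `.shape`, so dataset statistics and `data[0]`-style peeking have to be skipped or replaced with `next(iter(...))`. A minimal standalone sketch of the difference, using the `datasets` API directly rather than swift's wrappers ('train.jsonl' is a placeholder file):

```python
from datasets import load_dataset

# Map-style dataset: fully materialized, supports len(), indexing and .shape.
ds = load_dataset('json', data_files='train.jsonl', split='train')
print(ds.shape[0], ds[0])

# Streaming dataset: an IterableDataset consumed lazily; no len(), .shape or indexing.
ds_stream = load_dataset('json', data_files='train.jsonl', split='train', streaming=True)
first = next(iter(ds_stream))             # peek at a single example, as the diff does for print_example
ds_stream = ds_stream.map(lambda ex: ex)  # transforms are only applied while iterating
```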
2 changes: 1 addition & 1 deletion swift/llm/utils/__init__.py
@@ -3,7 +3,7 @@
RLHFArguments, RomeArguments, SftArguments, WebuiArguments, is_adapter, swift_to_peft_format)
from .client_utils import (compat_openai, convert_to_base64, decode_base64, get_model_list_client, inference_client,
inference_client_async)
from .dataset import (DATASET_MAPPING, DatasetName, HfDataset, get_dataset, get_dataset_from_repo,
from .dataset import (DATASET_MAPPING, DATASET_TYPE, DatasetName, HfDataset, get_dataset, get_dataset_from_repo,
load_dataset_from_local, load_ms_dataset, register_dataset, register_dataset_info,
register_local_dataset, sample_dataset)
from .media import MediaCache, MediaTag
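The only change here re-exports `DATASET_TYPE`, which the argument-handling code below uses so that signatures accept either a map-style or a streaming dataset. A plausible definition is sketched below as an assumption — the alias is imported from `swift/llm/utils/dataset.py`, whose diff is not expanded on this page:

```python
from typing import Union

from datasets import Dataset as HfDataset
from datasets import IterableDataset as HfIterableDataset

# Assumed alias: a dataset may be map-style (HfDataset) or streamed (HfIterableDataset).
DATASET_TYPE = Union[HfDataset, HfIterableDataset]
```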
90 changes: 67 additions & 23 deletions swift/llm/utils/argument.py
@@ -13,7 +13,7 @@
import torch.distributed as dist
import transformers
from datasets import Dataset as HfDataset
from datasets import concatenate_datasets
from datasets import concatenate_datasets, interleave_datasets
from packaging import version
from torch import dtype as Dtype
from transformers.utils import is_torch_bf16_gpu_available, is_torch_cuda_available, is_torch_npu_available, strtobool
@@ -25,7 +25,7 @@
from swift.utils import (add_version_to_work_dir, get_dist_setting, get_logger, get_pai_tensorboard_dir, is_dist,
is_local_master, is_mp, is_pai_training_job, use_torchacc)
from .client_utils import get_model_list_client
from .dataset import (DATASET_MAPPING, _dataset_name_exists, get_dataset, parse_dataset_name,
from .dataset import (DATASET_MAPPING, DATASET_TYPE, _dataset_name_exists, get_dataset, parse_dataset_name,
register_dataset_info_file, sample_dataset)
from .media import MediaTag
from .model import (MODEL_MAPPING, dtype_mapping, get_additional_saved_files, get_default_lora_target_modules,
@@ -374,42 +374,59 @@ def _register_self_cognition(self: Union['SftArguments', 'InferArguments']) -> N
'Representing the model name and model author in Chinese and English.')
setattr(self, k, v)

def _handle_dataset_compat(self: Union['SftArguments', 'InferArguments'], train_dataset: Optional[HfDataset],
val_dataset: Optional[HfDataset]) -> Tuple[Optional[HfDataset], Optional[HfDataset]]:
def _handle_dataset_compat(self: Union['SftArguments', 'InferArguments'], train_dataset: DATASET_TYPE,
val_dataset: DATASET_TYPE) -> Tuple[DATASET_TYPE, DATASET_TYPE]:
# compatibility. (Deprecated)
streaming = self.streaming
random_state = np.random.RandomState(self.dataset_seed)
val_dataset_sample = self.val_dataset_sample

if train_dataset is not None and self.train_dataset_sample >= 0:
train_dataset_sample = min(self.train_dataset_sample, train_dataset.shape[0])
if train_dataset.shape[0] > train_dataset_sample:
logger.info(f'train_dataset_sample: {train_dataset_sample}')
train_idxs = random_state.permutation(train_dataset_sample)
train_dataset = train_dataset.select(train_idxs)
if not streaming:
train_dataset_sample = min(self.train_dataset_sample, train_dataset.shape[0])
if train_dataset.shape[0] > train_dataset_sample:
logger.info(f'train_dataset_sample: {train_dataset_sample}')
train_idxs = random_state.permutation(train_dataset_sample)
train_dataset = train_dataset.select(train_idxs)
else:
train_dataset_sample = self.train_dataset_sample
train_dataset = train_dataset.shuffle(seed=self.dataset_seed, buffer_size=self.streaming_buffer_size)
train_dataset = train_dataset.take(train_dataset_sample)

if val_dataset_sample is None:
val_dataset_sample = max(int(train_dataset_sample * self.dataset_test_ratio), 1)
if val_dataset is not None and val_dataset_sample is not None and val_dataset_sample >= 0:
if val_dataset.shape[0] > val_dataset_sample:
if not streaming and val_dataset.shape[0] > val_dataset_sample:
logger.info(f'val_dataset_sample: {val_dataset_sample}')
val_idxs = random_state.permutation(val_dataset_sample)
val_dataset = val_dataset.select(val_idxs)
elif streaming:
val_dataset = val_dataset.shuffle(seed=self.dataset_seed, buffer_size=self.streaming_buffer_size)
val_dataset = val_dataset.take(val_dataset_sample)

if (train_dataset is None or not hasattr(self, 'train_dataset_mix_ratio') or self.train_dataset_mix_ratio <= 0
or len(self.train_dataset_mix_ds) == 0):
return train_dataset, val_dataset

mix_dataset_sample = int(len(train_dataset) * self.train_dataset_mix_ratio)
logger.info(f'train_dataset_mix_ds: {self.train_dataset_mix_ds}')
logger.info(f'len(train_dataset): {len(train_dataset)}, mix_dataset_sample: {mix_dataset_sample}')
mixed_dataset = get_dataset(
self.train_dataset_mix_ds, 0.0, random_state, check_dataset_strategy=self.check_dataset_strategy)[0]
if len(mixed_dataset) < mix_dataset_sample:
logger.warn(f'The length of dataset used for mixin: {self.train_dataset_mix_ds} are '
'lesser than the ratio required by the `train_dataset_mix_ratio` '
f'argument: {self.train_dataset_mix_ratio}. '
f'the actual ratio is: {len(mixed_dataset) / len(train_dataset):.6}.')
if streaming:
logger.warning('`train_dataset_mix_ds` is not supported in streaming mode.')
else:
mixed_dataset = sample_dataset(mixed_dataset, mix_dataset_sample, random_state)
train_dataset = concatenate_datasets([train_dataset, mixed_dataset])
mix_dataset_sample = int(len(train_dataset) * self.train_dataset_mix_ratio)
logger.info(f'train_dataset_mix_ds: {self.train_dataset_mix_ds}')
logger.info(f'len(train_dataset): {len(train_dataset)}, mix_dataset_sample: {mix_dataset_sample}')
mixed_dataset = get_dataset(
self.train_dataset_mix_ds,
0.0,
random_state,
check_dataset_strategy=self.check_dataset_strategy,
streaming=streaming)[0]
if len(mixed_dataset) < mix_dataset_sample:
logger.warning(f'The length of the dataset used for mixing: {self.train_dataset_mix_ds} is '
'less than required by the `train_dataset_mix_ratio` '
f'argument: {self.train_dataset_mix_ratio}. '
f'The actual ratio is: {len(mixed_dataset) / len(train_dataset):.6}.')
else:
mixed_dataset = sample_dataset(mixed_dataset, mix_dataset_sample, random_state)
train_dataset = concatenate_datasets([train_dataset, mixed_dataset])
return train_dataset, val_dataset

def prepare_template(self: Union['SftArguments', 'InferArguments']):
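The two sampling branches in `_handle_dataset_compat` mirror the two dataset types: a map-style dataset can be subsampled exactly through a random permutation of indices and `select`, while a streamed dataset can only be shuffled approximately through a bounded buffer and then truncated with `take`. A standalone sketch of the same pattern (the sample count, seed and buffer size are illustrative values, and 'train.jsonl' is a placeholder):

```python
import numpy as np
from datasets import Dataset, load_dataset

sample, seed, buffer_size = 1000, 42, 16384

# Map-style: exact subsample via a random permutation of row indices.
ds = Dataset.from_dict({'text': [f'row {i}' for i in range(10_000)]})
idxs = np.random.RandomState(seed).permutation(sample)
ds = ds.select(idxs)

# Streaming: approximate shuffle through a fixed-size buffer, then keep the first N examples.
ds_stream = load_dataset('json', data_files='train.jsonl', split='train', streaming=True)
ds_stream = ds_stream.shuffle(seed=seed, buffer_size=buffer_size)
ds_stream = ds_stream.take(sample)
```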
@@ -538,6 +555,29 @@ def load_from_ckpt_dir(self, is_sft: bool = False) -> None:
if self.val_dataset is None:
self.val_dataset = []

def _handle_streaming_args(self: Union['SftArguments', 'InferArguments']) -> None:
if not self.streaming:
return
if hasattr(self, 'packing') and self.packing:
self.packing = False
logger.warning('Packing is not supported for streaming datasets; setting it to False.')

if hasattr(self, 'test_oom_error') and self.test_oom_error:
self.test_oom_error = False
logger.warning('test_oom_error is not supported for streaming datasets; setting it to False.')

if hasattr(self, 'lazy_tokenize') and self.lazy_tokenize:
self.lazy_tokenize = False
logger.info('lazy_tokenize is set to False for streaming datasets.')

if self.dataset_test_ratio > 0:
logger.warning('Since the length of a streaming dataset cannot be estimated, '
'dataset_test_ratio is set to 0. You can set val_dataset_sample manually.')
self.dataset_test_ratio = 0

if self.train_dataset_sample > 0 or self.val_dataset_sample:
logger.warning('The final number of samples from a streaming dataset may be smaller than train_dataset_sample.')


@dataclass
class SftArguments(ArgumentsBase):
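`dataset_test_ratio` is forced to 0 because a ratio split needs the dataset length, which a stream never exposes; a validation split therefore has to be carved out by an explicit count. One way to do that with the `datasets` API is sketched below — this is not what the commit implements (swift expects `val_dataset_sample` or a separate `val_dataset` instead), just an illustration of the constraint:

```python
from datasets import load_dataset

val_sample = 200  # an explicit count; a ratio cannot be computed without knowing the length

stream = load_dataset('json', data_files='train.jsonl', split='train', streaming=True)
val_stream = stream.take(val_sample)    # the first N examples become the validation split
train_stream = stream.skip(val_sample)  # the remaining examples are used for training
```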
@@ -587,7 +627,10 @@ class SftArguments(ArgumentsBase):
max_length: int = 2048 # -1: no limit
truncation_strategy: Literal['delete', 'truncation_left'] = 'delete'
check_dataset_strategy: Literal['none', 'discard', 'error', 'warning'] = 'none'

# streaming dataset
streaming: bool = False
streaming_buffer_size: int = 16384
stopping_strategy: Literal['first_exhausted', 'all_exhausted'] = 'all_exhausted' # do we need?
# Chinese name and English name
model_name: List[str] = field(default_factory=lambda: [None, None], metadata={'help': "e.g. ['小黄', 'Xiao Huang']"})
model_author: List[str] = field(
@@ -1069,6 +1112,7 @@ def __post_init__(self) -> None:
self.logging_dir = f'{self.output_dir}/runs'
if self.train_backend == 'transformers':
self.training_args.logging_dir = self.logging_dir
self._handle_streaming_args()

def _init_training_args(self) -> None:
additional_saved_files = []
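The new `interleave_datasets` import and the `stopping_strategy` field (note the author's own '# do we need?') are presumably there to mix several streamed sources, where concatenating fully materialized datasets is no longer possible; the hunk that would use them is not expanded on this page, so the sketch below only shows how the two `stopping_strategy` values behave in the `datasets` API:

```python
from datasets import Dataset, interleave_datasets

short = Dataset.from_dict({'x': [0, 1]}).to_iterable_dataset()
long = Dataset.from_dict({'x': [10, 11, 12, 13]}).to_iterable_dataset()

# 'first_exhausted': stop as soon as the shortest source runs out of examples.
mixed_first = interleave_datasets([short, long], stopping_strategy='first_exhausted')

# 'all_exhausted': keep cycling shorter sources until every source has been seen in full.
mixed_all = interleave_datasets([short, long], stopping_strategy='all_exhausted')

print([ex['x'] for ex in mixed_first])  # alternates sources until `short` is exhausted
print([ex['x'] for ex in mixed_all])    # longer output; `short` is repeated as needed
```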