diff --git a/swift/llm/rlhf.py b/swift/llm/rlhf.py index f6ad596f0b..8d1503d318 100644 --- a/swift/llm/rlhf.py +++ b/swift/llm/rlhf.py @@ -209,6 +209,8 @@ def llm_rlhf(args: RLHFArguments) -> Dict[str, Any]: trainer_kwargs['is_vision'] = args.is_vision model.config.model_type += '_' # add suffix to avoid checks in hfDPOTrainer + trainer_kwargs['streaming'] = streaming + trainer = trainer_cls( model=model, train_dataset=train_dataset, diff --git a/swift/llm/utils/argument.py b/swift/llm/utils/argument.py index 74e38bcb8c..beb23d99b6 100644 --- a/swift/llm/utils/argument.py +++ b/swift/llm/utils/argument.py @@ -389,7 +389,7 @@ def _handle_dataset_compat(self: Union['SftArguments', 'InferArguments'], train_ train_idxs = random_state.permutation(train_dataset_sample) train_dataset = train_dataset.select(train_idxs) else: - # train_dataset = train_dataset.shuffle(seed=self.dataset_seed, buffer_size=self.streaming_buffer_size) + train_dataset = train_dataset.shuffle(seed=self.dataset_seed, buffer_size=self.streaming_buffer_size) train_dataset = train_dataset.take(train_dataset_sample) if val_dataset_sample is None: @@ -400,7 +400,7 @@ def _handle_dataset_compat(self: Union['SftArguments', 'InferArguments'], train_ val_idxs = random_state.permutation(val_dataset_sample) val_dataset = val_dataset.select(val_idxs) elif streaming: - # val_dataset.shuffle(seed=self.dataset_seed, buffer_size=self.streaming_buffer_size) + val_dataset.shuffle(seed=self.dataset_seed, buffer_size=self.streaming_buffer_size) val_dataset = val_dataset.take(val_dataset_sample) if (train_dataset is None or not hasattr(self, 'train_dataset_mix_ratio') or self.train_dataset_mix_ratio <= 0 @@ -552,36 +552,6 @@ def load_from_ckpt_dir(self, is_sft: bool = False) -> None: if self.val_dataset is None: self.val_dataset = [] - def _handle_streaming_args(self: Union['SftArguments', 'InferArguments']) -> None: - if not self.streaming: - return - if hasattr(self, 'packing') and self.packing: - self.packing = False - logger.warning('Packing is not supported for streaming dataset, set to False') - - if hasattr(self, 'test_oom_error') and self.test_oom_error: - self.test_oom_error = False - logger.warning('test_oom_error is not supported for streaming dataset, set to False') - - if hasattr(self, 'lazy_tokenize') and self.lazy_tokenize: - self.lazy_tokenize = False - logger.info('lazy_tokenize set to False in streaming dataset') - - if hasattr(self, 'train_dataset_mix_ratio') and self.train_dataset_mix_ratio > 0: - logger.warning('train_dataset_mix_ratio is not supported for streaming dataset, set to 0') - self.train_dataset_mix_ratio = 0 - - if self.dataset_test_ratio > 0: - logger.warning('Since the length of streaming data cannot be estimated,' - 'set dataset_test_ratio to 0. You can manually set val_dataset_sample.') - self.dataset_test_ratio = 0 - - if self.train_dataset_sample > 0 or self.val_dataset_sample: - logger.warning('The final data size in streaming data may be smaller than train_dataset_sample') - - if self.max_steps == -1: - raise ValueError('Please specify `max_steps` in streaming mode.') - @dataclass class SftArguments(ArgumentsBase): @@ -1068,10 +1038,9 @@ def __post_init__(self) -> None: if self.lazy_tokenize is None: self.lazy_tokenize = template_info.get('lazy_tokenize', False) logger.info(f'Setting args.lazy_tokenize: {self.lazy_tokenize}') + self._handle_streaming_args() if self.dataloader_num_workers is None: - if self.streaming: - self.dataloader_num_workers = 0 - elif 'dataloader_num_workers' in template_info: + if 'dataloader_num_workers' in template_info: self.dataloader_num_workers = template_info['dataloader_num_workers'] elif platform.system() == 'Windows': self.dataloader_num_workers = 0 @@ -1224,6 +1193,40 @@ def _handle_pai_compat(self) -> None: self.add_output_dir_suffix = False logger.info(f'Setting args.add_output_dir_suffix: {self.add_output_dir_suffix}') + def _handle_streaming_args(self) -> None: + if not self.streaming: + return + if self.packing: + self.packing = False + logger.warning('Packing is not supported for streaming dataset, set to False') + + if self.test_oom_error: + self.test_oom_error = False + logger.warning('test_oom_error is not supported for streaming dataset, set to False') + + if self.lazy_tokenize: + self.lazy_tokenize = False + logger.info('lazy_tokenize set to False in streaming dataset') + + if self.train_dataset_mix_ratio > 0: + logger.warning('train_dataset_mix_ratio is not supported for streaming dataset, set to 0') + self.train_dataset_mix_ratio = 0 + + if self.dataset_test_ratio > 0: + logger.warning('Since the length of streaming data cannot be estimated,' + 'set dataset_test_ratio to 0. You can manually set val_dataset_sample.') + self.dataset_test_ratio = 0 + + if self.train_dataset_sample > 0 or self.val_dataset_sample: + logger.warning('The final data size in streaming data may be smaller than train_dataset_sample') + + if self.max_steps == -1: + raise ValueError('Please specify `max_steps` in streaming mode.') + + if self.dataloader_num_workers is None or self.dataloader_num_workers > 0: + logger.info('dataloader_num_workers is not supported in streaming mode, set to 0') + self.dataloader_num_workers = 0 + @dataclass class InferArguments(ArgumentsBase): diff --git a/swift/llm/utils/dataset.py b/swift/llm/utils/dataset.py index cceefba255..f26a6bd9f6 100644 --- a/swift/llm/utils/dataset.py +++ b/swift/llm/utils/dataset.py @@ -312,7 +312,7 @@ def register_dataset_info(dataset_name: str, d_info: Dict[str, Any], **kwargs) - def load_ms_dataset(dataset_id: str, subset_split_list: Optional[List[SubsetSplit]], use_hf: bool = False, - streaming: bool = False) -> Optional[HfDataset]: + streaming: bool = False) -> DATASET_TYPE: if not use_hf: from modelscope import MsDataset @@ -352,6 +352,10 @@ def load_ms_dataset(dataset_id: str, continue except Exception: raise + if streaming and hasattr(dataset, '_hf_ds'): + dataset = dataset._hf_ds + if not isinstance(dataset, HfIterableDataset): + dataset = dataset.to_iterable_dataset() if hasattr(dataset, 'to_hf_dataset'): dataset = dataset.to_hf_dataset() dataset_list.append(dataset) @@ -361,14 +365,8 @@ def load_ms_dataset(dataset_id: str, return interleave_datasets(dataset_list, stopping_strategy='all_exhausted') # TODO: set arg stopping_strategy? -def sample_dataset(dataset: HfDataset, - dataset_sample: int, - random_state: Optional[RandomState] = None, - streaming=False, - train=False) -> HfDataset: - if dataset_sample in {None, -1}: - return dataset - elif not streaming and dataset_sample == len(dataset): +def sample_dataset(dataset: HfDataset, dataset_sample: int, random_state: Optional[RandomState] = None) -> HfDataset: + if dataset_sample in {None, -1, len(dataset)}: return dataset if random_state is None: random_state = RandomState() @@ -388,7 +386,7 @@ def _post_preprocess( dataset_test_ratio: float = 0., remove_useless_columns: bool = True, streaming: bool = False, -) -> Tuple[HfDataset, Optional[HfDataset]]: +) -> Tuple[DATASET_TYPE, Optional[DATASET_TYPE]]: # process train/val dataset and remove useless columns assert train_dataset is not None if not streaming: @@ -422,8 +420,8 @@ def _post_preprocess( if dataset_test_ratio == 1: train_dataset, val_dataset = None, train_dataset if dataset_sample > 0: - # train_dataset = train_dataset.shuffle( - # seed=get_seed(random_state), buffer_size=16384) # TODO: set buffer_size + train_dataset = train_dataset.shuffle( + seed=get_seed(random_state), buffer_size=16384) # TODO: set buffer_size train_dataset = train_dataset.take(dataset_sample) res = [] @@ -446,7 +444,7 @@ def get_dataset_from_repo(dataset_id: str, dataset_test_ratio: float = 0., remove_useless_columns: bool = True, use_hf: bool = False, - streaming: bool = False) -> Tuple[HfDataset, Optional[HfDataset]]: + streaming: bool = False) -> Tuple[DATASET_TYPE, Optional[DATASET_TYPE]]: if subsets is None: subsets = [] assert len(split) > 0 @@ -481,7 +479,7 @@ def _concat_inst_inp_alpaca_zh(inst: str, inp: str) -> str: hf_dataset_id='llm-wizard/alpaca-gpt4-data-zh') -def preprocess_sharegpt_4o_images(dataset): +def preprocess_sharegpt_4o_images(dataset: DATASET_TYPE): url = 'https://www.modelscope.cn/api/v1/datasets/AI-ModelScope/ShareGPT-4o/repo?Revision=master&FilePath=images.zip' local_dir = MediaCache.download(url, 'sharegpt_4o_images') prefix_path = os.path.join(local_dir, 'mnt', 'petrelfs', 'wangwenhai', 'workspace_cef', '4o', 'image') @@ -495,8 +493,10 @@ def preprocess_row(row): return {'image': [], 'conversations': []} return {'image': [image]} - dataset = dataset.map( - preprocess_row, load_from_cache_file=dataset_enable_cache).filter(lambda row: row['conversations']) + kwargs = {} + if not isinstance(dataset, HfIterableDataset): + kwargs['load_from_cache_file'] = dataset_enable_cache + dataset = dataset.map(preprocess_row, **kwargs).filter(lambda row: row['conversations']) return ConversationsPreprocessor( user_role='human', assistant_role='gpt', media_type='image', error_strategy='delete')( dataset) @@ -512,7 +512,7 @@ def preprocess_row(row): hf_dataset_id='OpenGVLab/ShareGPT-4o') -def preprocess_sa1b_paired_caption(dataset: HfDataset): +def preprocess_sa1b_paired_caption(dataset: DATASET_TYPE): prompt = ['图片中展示了什么', '讲述一下图片中内容', '告诉我里面有什么', '图片内容是啥'] @@ -524,8 +524,10 @@ def preprocess_row(row): 'response': response, } - return dataset.map( - preprocess_row, load_from_cache_file=dataset_enable_cache).rename_column('opensource_url', 'images') + kwargs = {} + if not isinstance(dataset, HfIterableDataset): + kwargs['load_from_cache_file'] = dataset_enable_cache + return dataset.map(preprocess_row, **kwargs).rename_column('opensource_url', 'images') register_dataset( @@ -539,7 +541,7 @@ def preprocess_row(row): tags=['zh', 'multi-modal', 'vqa']) -def preprocess_sa1b_dense_caption(dataset: HfDataset): +def preprocess_sa1b_dense_caption(dataset: DATASET_TYPE): prompt = ['图片中展示了什么', '讲述一下图片中内容', '告诉我里面有什么', '图片内容是啥'] @@ -552,10 +554,10 @@ def preprocess_row(row): 'response': response, } - return dataset.map( - preprocess_row, - load_from_cache_file=dataset_enable_cache).filter(lambda row: row.get('response')).rename_column( - 'url', 'images') + kwargs = {} + if not isinstance(dataset, HfIterableDataset): + kwargs['load_from_cache_file'] = dataset_enable_cache + return dataset.map(preprocess_row, **kwargs).filter(lambda row: row.get('response')).rename_column('url', 'images') register_dataset( @@ -598,7 +600,10 @@ def preprocess_row(row): else: return {'images': []} - return dataset.map(preprocess_row, load_from_cache_file=dataset_enable_cache).filter(lambda row: row['images']) + kwargs = {} + if not isinstance(dataset, HfIterableDataset): + kwargs['load_from_cache_file'] = dataset_enable_cache + return dataset.map(preprocess_row, **kwargs).filter(lambda row: row['images']) def get_mantis_dataset(dataset_id: str, @@ -651,7 +656,7 @@ def get_mantis_dataset(dataset_id: str, hf_dataset_id='TIGER-Lab/Mantis-Instruct') -def preprocess_llava_data(dataset: HfDataset) -> HfDataset: +def preprocess_llava_data(dataset: DATASET_TYPE) -> DATASET_TYPE: all_folders = {} for media_type in ['coco', 'gqa', 'ocr_vqa', 'textvqa', 'VG_100K', 'VG_100K_2']: @@ -683,7 +688,10 @@ def preprocess_image(example): example['images'] = [] return example - dataset = dataset.map(preprocess_image, load_from_cache_file=dataset_enable_cache).filter(lambda row: row['images']) + kwargs = {} + if not isinstance(dataset, HfIterableDataset): + kwargs['load_from_cache_file'] = dataset_enable_cache + dataset = dataset.map(preprocess_image, **kwargs).filter(lambda row: row['images']) return ConversationsPreprocessor( user_role='user', assistant_role='assistant', @@ -723,20 +731,33 @@ def preprocess_image(example): is_main=False) -def _preprocess_vision_dataset2(dataset: HfDataset) -> HfDataset: +def _preprocess_vision_dataset2(dataset: DATASET_TYPE) -> DATASET_TYPE: query = 'please describe the image.' image_key = 'image' response_key = 'caption' dataset._info.features._column_requires_decoding['image'] = False - response = [] - images = [] - for d in tqdm(dataset): - images.append([d[image_key]['path']]) - if '&&' in d[response_key]: - d[response_key] = d[response_key].split('&&')[0] - response.append(d[response_key]) - return HfDataset.from_dict({'query': [query] * len(response), 'response': response, 'images': images}) + + if isinstance(dataset, HfIterableDataset): + + def _preprocess_vision_dataset2_generator(dataset): + for d in tqdm(dataset): + d[image_key] = d[image_key]['path'] # TODO + if '&&' in d[response_key]: + d[response_key] = d[response_key].split('&&')[0] + response = d[response_key] + yield {'query': query, 'response': response, 'images': images} + + return HfIterableDataset.from_generator(_preprocess_vision_dataset2_generator, gen_kwargs={'dataset': dataset}) + else: + response = [] + images = [] + for d in tqdm(dataset): + images.append([d[image_key]['path']]) + if '&&' in d[response_key]: + d[response_key] = d[response_key].split('&&')[0] + response.append(d[response_key]) + return HfDataset.from_dict({'query': [query] * len(response), 'response': response, 'images': images}) register_dataset( @@ -758,7 +779,7 @@ def _preprocess_vision_dataset2(dataset: HfDataset) -> HfDataset: is_main=False) -def _preprocess_pixelprose(dataset: HfDataset): +def _preprocess_pixelprose(dataset: DATASET_TYPE): caption_prompt = [ 'Give the description of this image.', 'Describe this picture', 'What is the proper title of this image?' @@ -774,7 +795,10 @@ def preprocess(row): 'query': np.random.choice(caption_prompt), } - return dataset.map(preprocess, load_from_cache_file=dataset_enable_cache) + kwargs = {} + if not isinstance(dataset, HfIterableDataset): + kwargs['load_from_cache_file'] = dataset_enable_cache + return dataset.map(preprocess, **kwargs) register_dataset( @@ -890,7 +914,7 @@ def _repair_ms_bench(conversations: str) -> Optional[List[Dict[str, str]]]: return conversations -def long_alpaca_preprocessor(dataset: HfDataset): +def long_alpaca_preprocessor(dataset: DATASET_TYPE): def map_row(row): response = row['response'] @@ -900,7 +924,10 @@ def map_row(row): return response dataset = AlpacaPreprocessor()(dataset) - return dataset.map(map_row, load_from_cache_file=dataset_enable_cache) + kwargs = {} + if not isinstance(dataset, HfIterableDataset): + kwargs['load_from_cache_file'] = dataset_enable_cache + return dataset.map(map_row, **kwargs) register_dataset( @@ -913,7 +940,7 @@ def map_row(row): hf_dataset_id='Yukang/LongAlpaca-12k') -def _preprocess_ruozhiba(dataset: HfDataset): +def _preprocess_ruozhiba(dataset: DATASET_TYPE): def map_row(row): title = row['title'] if row.get('title', None) is not None else row['content'] @@ -927,7 +954,10 @@ def map_row(row): title = match.group(1) return {'response': title} - return dataset.map(map_row, load_from_cache_file=dataset_enable_cache).filter(lambda row: row['response']) + kwargs = {} + if not isinstance(dataset, HfIterableDataset): + kwargs['load_from_cache_file'] = dataset_enable_cache + return dataset.map(map_row, **kwargs).filter(lambda row: row['response']) register_dataset( @@ -1067,7 +1097,7 @@ def _preprocess_dureader_robust(dataset: HfDataset) -> HfDataset: tags=['text-generation', '🔥']) -def process_hh_rlhf(dataset): +def process_hh_rlhf(dataset: DATASET_TYPE): def reorganize_row(row): import re @@ -1108,8 +1138,10 @@ def reorganize_row(row): 'history': history, } - return dataset.map( - reorganize_row, load_from_cache_file=dataset_enable_cache).filter(lambda row: row['query'] is not None) + kwargs = {} + if not isinstance(dataset, HfIterableDataset): + kwargs['load_from_cache_file'] = dataset_enable_cache + return dataset.map(reorganize_row, **kwargs).filter(lambda row: row['query'] is not None) register_dataset( @@ -1121,7 +1153,7 @@ def reorganize_row(row): tags=['rlhf', 'dpo', 'pairwise']) -def process_hh_rlhf_cn(dataset): +def process_hh_rlhf_cn(dataset: DATASET_TYPE): def reorganize_row(row): history = [] @@ -1173,8 +1205,10 @@ def row_can_be_parsed(row): except: # noqa return False - return dataset.filter(row_can_be_parsed).map( - reorganize_row, load_from_cache_file=dataset_enable_cache).filter(lambda row: row['query']) + kwargs = {} + if not isinstance(dataset, HfIterableDataset): + kwargs['load_from_cache_file'] = dataset_enable_cache + return dataset.filter(row_can_be_parsed).map(reorganize_row, **kwargs).filter(lambda row: row['query']) register_dataset( @@ -1202,7 +1236,7 @@ def _preprocess_m3it(dataset: HfDataset) -> HfDataset: return dataset -def _preprocess_sharegpt4v(dataset: HfDataset) -> HfDataset: +def _preprocess_sharegpt4v(dataset: DATASET_TYPE) -> DATASET_TYPE: split = ['ShareGPT4V', 'ShareGPT4V-PT'] if dataset.config_name is None else dataset.config_name IMAGE_DATASET_REQUIREMENTS = { 'ShareGPT4V': ['coco', 'sam', 'llava', 'wikiart', 'share_textvqa', 'web-celebrity', 'web-landmark'], @@ -1243,9 +1277,10 @@ def preprocess_image(example): example['images'] = None return example - dataset = dataset.map( - preprocess_image, - load_from_cache_file=dataset_enable_cache).filter(lambda example: example['images'] is not None) + kwargs = {} + if not isinstance(dataset, HfIterableDataset): + kwargs['load_from_cache_file'] = dataset_enable_cache + dataset = dataset.map(preprocess_image, **kwargs).filter(lambda example: example['images'] is not None) processer = ConversationsPreprocessor( user_role='human', assistant_role='gpt', media_type='image', media_key='images', error_strategy='delete') return processer(dataset) @@ -1279,7 +1314,7 @@ def preprocess_image(example): tags=['chat', 'multi-modal', 'vision']) -def preprocess_text_caps(dataset): +def preprocess_text_caps(dataset: DATASET_TYPE): def preprocess(row): try: @@ -1289,12 +1324,13 @@ def preprocess(row): except Exception: return {'response': '', 'image': None} - return dataset.map( - preprocess, load_from_cache_file=dataset_enable_cache).filter(lambda row: row.get('response')).rename_columns( - {'image': 'images'}) + kwargs = {} + if not isinstance(dataset, HfIterableDataset): + kwargs['load_from_cache_file'] = dataset_enable_cache + return dataset.map(preprocess, **kwargs).filter(lambda row: row.get('response')).rename_columns({'image': 'images'}) -def preprocess_refcoco_unofficial_caption(dataset): +def preprocess_refcoco_unofficial_caption(dataset: DATASET_TYPE): cache_dir = MediaCache.download( 'https://www.modelscope.cn/api/v1/datasets/we_dont_produce_water/' @@ -1322,7 +1358,10 @@ def preprocess(row): res['response'] = '' return res - return dataset.map(preprocess, load_from_cache_file=dataset_enable_cache).filter(lambda row: row.get('response')) + kwargs = {} + if not isinstance(dataset, HfIterableDataset): + kwargs['load_from_cache_file'] = dataset_enable_cache + return dataset.map(preprocess, **kwargs).filter(lambda row: row.get('response')) register_dataset( @@ -1344,7 +1383,7 @@ def preprocess(row): tags=['multi-modal', 'en', 'caption']) -def preprocess_refcoco_unofficial_grounding(dataset): +def preprocess_refcoco_unofficial_grounding(dataset: DATASET_TYPE): cache_dir = MediaCache.download( 'https://www.modelscope.cn/api/v1/datasets/we_dont_produce_water/' @@ -1372,7 +1411,10 @@ def preprocess(row): res['response'] = '' return res - return dataset.map(preprocess, load_from_cache_file=dataset_enable_cache).filter(lambda row: row.get('response')) + kwargs = {} + if not isinstance(dataset, HfIterableDataset): + kwargs['load_from_cache_file'] = dataset_enable_cache + return dataset.map(preprocess, **kwargs).filter(lambda row: row.get('response')) register_dataset( @@ -1414,7 +1456,7 @@ def preprocess(row): tags=['multi-modal', 'en', 'ocr-vqa', 'quality']) -def _preprocess_llava_instruct_images(dataset: HfDataset) -> HfDataset: +def _preprocess_llava_instruct_images(dataset: DATASET_TYPE) -> DATASET_TYPE: all_folders = {} for media_type in ['coco', 'gqa', 'ocr_vqa', 'textvqa', 'VG_100K', 'VG_100K_2']: all_folders[media_type] = MediaCache.download(media_type) @@ -1440,9 +1482,10 @@ def preprocess_image(example): example['images'] = None return example - dataset = dataset.map( - preprocess_image, - load_from_cache_file=dataset_enable_cache).filter(lambda example: example['images'] is not None) + kwargs = {} + if not isinstance(dataset, HfIterableDataset): + kwargs['load_from_cache_file'] = dataset_enable_cache + dataset = dataset.map(preprocess_image, **kwargs).filter(lambda example: example['images'] is not None) processer = ConversationsPreprocessor( user_role='human', assistant_role='gpt', media_type='image', media_key='images', error_strategy='delete') return processer(dataset) @@ -1490,7 +1533,7 @@ def repair_conversations(s: Union[str, Any]) -> Any: tags=['chat', 'em']) -def _preprocess_llava_pretrain(dataset): +def _preprocess_llava_pretrain(dataset: DATASET_TYPE): media_dir = MediaCache.download( 'https://www.modelscope.cn/api/v1/datasets/AI-ModelScope/LLaVA-Pretrain/repo?Revision=master&FilePath=images.zip', # noqa 'llava_pretrain') @@ -1505,7 +1548,10 @@ def preprocess(row): else: return {'image': ''} - dataset = dataset.map(preprocess, load_from_cache_file=dataset_enable_cache).filter(lambda row: row['image']) + kwargs = {} + if not isinstance(dataset, HfIterableDataset): + kwargs['load_from_cache_file'] = dataset_enable_cache + dataset = dataset.map(preprocess, **kwargs).filter(lambda row: row['image']) return ConversationsPreprocessor( user_role='human', assistant_role='gpt', media_type='image', error_strategy='delete')( dataset) @@ -1522,7 +1568,7 @@ def preprocess(row): tags=['vqa', 'multi-modal', 'quality']) -def process_shareai_dpo(dataset): +def process_shareai_dpo(dataset: DATASET_TYPE): def reorganize_row(row): return { @@ -1531,7 +1577,10 @@ def reorganize_row(row): 'rejected_response': row['answer_en'], } - return dataset.map(reorganize_row, load_from_cache_file=dataset_enable_cache) + kwargs = {} + if not isinstance(dataset, HfIterableDataset): + kwargs['load_from_cache_file'] = dataset_enable_cache + return dataset.map(reorganize_row, **kwargs) def process_ultrafeedback_kto(dataset: HfDataset): @@ -1550,7 +1599,7 @@ def process_ultrafeedback_kto(dataset: HfDataset): tags=['rlhf', 'kto']) -def process_zhihu_kol(dataset: HfDataset): +def process_zhihu_kol(dataset: DATASET_TYPE): def reorganize_row(row): return { @@ -1558,7 +1607,10 @@ def reorganize_row(row): 'response': row['RESPONSE'], } - return dataset.map(reorganize_row, load_from_cache_file=dataset_enable_cache) + kwargs = {} + if not isinstance(dataset, HfIterableDataset): + kwargs['load_from_cache_file'] = dataset_enable_cache + return dataset.map(reorganize_row, **kwargs) register_dataset( @@ -1579,7 +1631,7 @@ def reorganize_row(row): tags=['zhihu', 'qa']) -def preprocess_guanaco(dataset): +def preprocess_guanaco(dataset: DATASET_TYPE): from swift.utils.utils import split_str_parts_by def preprocess_row(row): @@ -1609,8 +1661,10 @@ def preprocess_row(row): 'response': output, } - return dataset.map( - preprocess_row, load_from_cache_file=dataset_enable_cache).filter(lambda row: row['query'] and row['response']) + kwargs = {} + if not isinstance(dataset, HfIterableDataset): + kwargs['load_from_cache_file'] = dataset_enable_cache + return dataset.map(preprocess_row, **kwargs).filter(lambda row: row['query'] and row['response']) register_dataset( @@ -1622,7 +1676,7 @@ def preprocess_row(row): tags=['chat', 'zh']) -def preprocess_dolly_15k(dataset): +def preprocess_dolly_15k(dataset: DATASET_TYPE): def preprocess_row(row): instruction = row['instruction'] @@ -1639,7 +1693,10 @@ def preprocess_row(row): 'response': response, } - return dataset.map(preprocess_row, load_from_cache_file=dataset_enable_cache) + kwargs = {} + if not isinstance(dataset, HfIterableDataset): + kwargs['load_from_cache_file'] = dataset_enable_cache + return dataset.map(preprocess_row, **kwargs) register_dataset( @@ -1671,7 +1728,7 @@ def preprocess_row(row): tags=['medical', 'en', 'vqa']) -def preprocess_okvqa(dataset): +def preprocess_okvqa(dataset: DATASET_TYPE): def preprocess(row): query = row['question'] @@ -1681,7 +1738,10 @@ def preprocess(row): 'query': query, } - return dataset.map(preprocess, load_from_cache_file=dataset_enable_cache).rename_column('image', 'images') + kwargs = {} + if not isinstance(dataset, HfIterableDataset): + kwargs['load_from_cache_file'] = dataset_enable_cache + return dataset.map(preprocess, **kwargs).rename_column('image', 'images') register_dataset( @@ -1694,7 +1754,7 @@ def preprocess(row): tags=['multi-modal', 'en', 'vqa', 'quality']) -def preprocess_a_okvqa(dataset): +def preprocess_a_okvqa(dataset: DATASET_TYPE): def preprocess(row): query = row['question'] @@ -1704,7 +1764,10 @@ def preprocess(row): 'query': query, } - return dataset.map(preprocess, load_from_cache_file=dataset_enable_cache).rename_column('image', 'images') + kwargs = {} + if not isinstance(dataset, HfIterableDataset): + kwargs['load_from_cache_file'] = dataset_enable_cache + return dataset.map(preprocess, **kwargs).rename_column('image', 'images') register_dataset( @@ -1717,7 +1780,7 @@ def preprocess(row): tags=['multi-modal', 'en', 'vqa', 'quality']) -def preprocess_ocr_vqa(dataset): +def preprocess_ocr_vqa(dataset: DATASET_TYPE): def preprocess(row): idx = np.random.choice(range(len(row['questions']))) @@ -1728,7 +1791,10 @@ def preprocess(row): 'query': query, } - return dataset.map(preprocess, load_from_cache_file=dataset_enable_cache).rename_column('image', 'images') + kwargs = {} + if not isinstance(dataset, HfIterableDataset): + kwargs['load_from_cache_file'] = dataset_enable_cache + return dataset.map(preprocess, **kwargs).rename_column('image', 'images') register_dataset( @@ -1741,7 +1807,7 @@ def preprocess(row): tags=['multi-modal', 'en', 'ocr-vqa']) -def preprocess_science_qa(dataset): +def preprocess_science_qa(dataset: DATASET_TYPE): def preprocess_row(row): query = row['question'] @@ -1749,9 +1815,10 @@ def preprocess_row(row): solution = row['solution'] return {'query': query, 'response': f'{solution}\nSo the final answer is:{response}'} - return dataset.map( - preprocess_row, - load_from_cache_file=dataset_enable_cache).filter(lambda row: row['image']).rename_columns({'image': 'images'}) + kwargs = {} + if not isinstance(dataset, HfIterableDataset): + kwargs['load_from_cache_file'] = dataset_enable_cache + return dataset.map(preprocess_row, **kwargs).filter(lambda row: row['image']).rename_columns({'image': 'images'}) register_dataset( @@ -1764,7 +1831,7 @@ def preprocess_row(row): tags=['multi-modal', 'science', 'vqa', 'quality']) -def preprocess_grit(dataset): +def preprocess_grit(dataset: DATASET_TYPE): def has_overlap(start_ends): for i in range(1, len(start_ends)): @@ -1805,7 +1872,10 @@ def preprocess_row(row): return {'images': images, 'response': response, 'objects': json.dumps(objects or [], ensure_ascii=False)} - return dataset.map(preprocess_row, load_from_cache_file=dataset_enable_cache).filter(lambda row: row['objects']) + kwargs = {} + if not isinstance(dataset, HfIterableDataset): + kwargs['load_from_cache_file'] = dataset_enable_cache + return dataset.map(preprocess_row, **kwargs).filter(lambda row: row['objects']) register_dataset( @@ -1819,7 +1889,7 @@ def preprocess_row(row): tags=['multi-modal', 'en', 'caption-grounding', 'quality']) -def preprocess_gqa(dataset): +def preprocess_gqa(dataset: DATASET_TYPE): local_cache = MediaCache.download('gqa') def preprocess_row(row): @@ -1832,7 +1902,10 @@ def preprocess_row(row): else: return {'query': '', 'response': '', 'images': ''} - return dataset.map(preprocess_row, load_from_cache_file=dataset_enable_cache).filter(lambda row: row['query']) + kwargs = {} + if not isinstance(dataset, HfIterableDataset): + kwargs['load_from_cache_file'] = dataset_enable_cache + return dataset.map(preprocess_row, **kwargs).filter(lambda row: row['query']) register_dataset( @@ -1845,7 +1918,7 @@ def preprocess_row(row): tags=['multi-modal', 'en', 'vqa', 'quality']) -def preprocess_llava_mix_sft(dataset): +def preprocess_llava_mix_sft(dataset: DATASET_TYPE): def preprocess_row(row): messages = row['messages'] @@ -1864,18 +1937,19 @@ def preprocess_row(row): return {'messages': rounds} - dataset = dataset.map( - preprocess_row, load_from_cache_file=dataset_enable_cache).map( - ConversationsPreprocessor( - user_role='user', - assistant_role='assistant', - conversations_key='messages', - from_key='role', - value_key='content', - media_key='images', - media_type='image', - ).preprocess, - load_from_cache_file=dataset_enable_cache) + kwargs = {} + if not isinstance(dataset, HfIterableDataset): + kwargs['load_from_cache_file'] = dataset_enable_cache + dataset = dataset.map(preprocess_row, **kwargs).map( + ConversationsPreprocessor( + user_role='user', + assistant_role='assistant', + conversations_key='messages', + from_key='role', + value_key='content', + media_key='images', + media_type='image', + ).preprocess, **kwargs) return dataset @@ -1889,7 +1963,7 @@ def preprocess_row(row): tags=['multi-modal', 'en', 'vqa', 'quality']) -def orpo_dpo_mix_40k_preprocessor(dataset: HfDataset): +def orpo_dpo_mix_40k_preprocessor(dataset: DATASET_TYPE): def preprocess(row): chosen_history = row['chosen'] @@ -1931,9 +2005,11 @@ def preprocess(row): 'rejected_response': rejected_response, } - return dataset.map( - preprocess, load_from_cache_file=dataset_enable_cache).filter( - lambda r: r['source'] != 'toxic-dpo-v0.2' and r['query'] is not None) + kwargs = {} + if not isinstance(dataset, HfIterableDataset): + kwargs['load_from_cache_file'] = dataset_enable_cache + return dataset.map(preprocess, + **kwargs).filter(lambda r: r['source'] != 'toxic-dpo-v0.2' and r['query'] is not None) register_dataset( @@ -1945,7 +2021,7 @@ def preprocess(row): tags=['dpo', 'orpo', 'en', 'quality']) -def synthetic_text_to_sql_preprocesser(dataset: HfDataset): +def synthetic_text_to_sql_preprocesser(dataset: DATASET_TYPE): def preprocess(row): sql_prompt = row['sql_prompt'] @@ -1959,7 +2035,10 @@ def preprocess(row): 'response': response, } - return dataset.map(preprocess, load_from_cache_file=dataset_enable_cache) + kwargs = {} + if not isinstance(dataset, HfIterableDataset): + kwargs['load_from_cache_file'] = dataset_enable_cache + return dataset.map(preprocess, **kwargs) register_dataset( @@ -2241,7 +2320,7 @@ def process_conversation(conv): tags=['chat', 'agent', 'multi-round', 'role-play', 'multi-agent']) -def _preprocess_toolbench(dataset: HfDataset) -> HfDataset: +def _preprocess_toolbench(dataset: DATASET_TYPE) -> DATASET_TYPE: def reorganize_row(row): convs = row['conversations'] @@ -2261,7 +2340,10 @@ def reorganize_row(row): 'response': convs[-1]['value'] } - return dataset.map(reorganize_row, load_from_cache_file=dataset_enable_cache) + kwargs = {} + if not isinstance(dataset, HfIterableDataset): + kwargs['load_from_cache_file'] = dataset_enable_cache + return dataset.map(reorganize_row, **kwargs) register_dataset( diff --git a/swift/llm/utils/preprocess.py b/swift/llm/utils/preprocess.py index c763aa8938..fd82c8edc3 100644 --- a/swift/llm/utils/preprocess.py +++ b/swift/llm/utils/preprocess.py @@ -113,7 +113,7 @@ def preprocess(self, d: Dict[str, Any]) -> Dict[str, Any]: class SwiftPreprocessor: - def __call__(self, dataset: HfDataset) -> HfDataset: + def __call__(self, dataset: DATASET_TYPE) -> DATASET_TYPE: if 'history' in dataset.features: old_history = dataset['history'] has_history = False @@ -274,10 +274,11 @@ def preprocess(self, d: Dict[str, Any]) -> Dict[str, Any]: else: return self.empty_row - def __call__(self, dataset: HfDataset) -> HfDataset: - dataset = dataset.map( - self.preprocess, - load_from_cache_file=dataset_enable_cache).filter(lambda row: row.get('response') is not None) + def __call__(self, dataset: DATASET_TYPE) -> DATASET_TYPE: + kwargs = {} + if not isinstance(dataset, HfIterableDataset): + kwargs['load_from_cache_file'] = dataset_enable_cache + dataset = dataset.map(self.preprocess, **kwargs).filter(lambda row: row.get('response') is not None) if self.media_type and isinstance(self.media_key, str) and self.media_key != self.media_name: dataset = dataset.rename_columns({self.media_key: self.media_name}) return dataset @@ -332,9 +333,11 @@ def preprocess(self, d: Dict[str, Any]) -> Dict[str, Any]: return self.empty_row return row - def __call__(self, dataset: HfDataset): - dataset = dataset.map( - self.preprocess, load_from_cache_file=dataset_enable_cache).filter(lambda d: d.get('response')) + def __call__(self, dataset: DATASET_TYPE): + kwargs = {} + if not isinstance(dataset, HfIterableDataset): + kwargs['load_from_cache_file'] = dataset_enable_cache + dataset = dataset.map(self.preprocess, **kwargs).filter(lambda d: d.get('response')) if self.media_type and isinstance(self.media_key, str) and self.media_key != self.media_name: dataset = dataset.rename_columns({self.media_key: self.media_name}) return dataset diff --git a/swift/trainers/dpo_trainer.py b/swift/trainers/dpo_trainer.py index 7a5bd3ad54..12e901ad17 100644 --- a/swift/trainers/dpo_trainer.py +++ b/swift/trainers/dpo_trainer.py @@ -21,19 +21,23 @@ class DPOTrainer(PushToMsHubMixin, SwiftMixin, HFDPOTrainer): def __init__(self, *args, template: Template, sft_beta=0., test_oom_error=False, **kwargs): self.template = template self.sft_beta = sft_beta + self.streaming = kwargs.pop('streaming') is_vision = kwargs.pop('is_vision') self.keys = [] + self.need_filter: bool = False super().__init__(*args, **kwargs) - self.train_dataset = self.train_dataset.filter(lambda x: x['prompt_input_ids'] is not None) - if self.eval_dataset is not None: - self.eval_dataset = self.eval_dataset.filter(lambda x: x['prompt_input_ids'] is not None) - train_ds_info = self.stat_dataset(self.train_dataset, self.is_encoder_decoder) - - if self.eval_dataset is not None: - val_ds_info = self.stat_dataset(self.eval_dataset, self.is_encoder_decoder) - self.dataset_info = {'train_dataset': train_ds_info, 'val_dataset': val_ds_info} - else: - self.dataset_info = {'train_dataset': train_ds_info} + if self.need_filter: + self.train_dataset = self.train_dataset.filter(lambda x: x['prompt_input_ids'] is not None) + if self.eval_dataset is not None: + self.eval_dataset = self.eval_dataset.filter(lambda x: x['prompt_input_ids'] is not None) + if not self.streaming: + train_ds_info = self.stat_dataset(self.train_dataset, self.is_encoder_decoder) + + if self.eval_dataset is not None: + val_ds_info = self.stat_dataset(self.eval_dataset, self.is_encoder_decoder) + self.dataset_info = {'train_dataset': train_ds_info, 'val_dataset': val_ds_info} + else: + self.dataset_info = {'train_dataset': train_ds_info} if test_oom_error: self.train_dataset = sort_by_max_length(self.train_dataset, 20000) # performance @@ -43,6 +47,7 @@ def __init__(self, *args, template: Template, sft_beta=0., test_oom_error=False, 'memory': {}, 'model': self.model.get_trainable_parameters() if hasattr(self.model, 'get_trainable_parameters') else None, } + # modify after init self.is_vision_model = is_vision self.model.config.model_type = self.model.config.model_type[:-1] # remove suffix @@ -63,6 +68,7 @@ def tokenize_row(self, feature, model: Union[PreTrainedModel, nn.Module] = None) # Skip examples that do not contain 'input_ids' if 'input_ids' not in prompt_tokens: + self.need_filter = True return {k: None for k in self.keys} # resolve conflict in data_collator when labels are None, pop it afterwards diff --git a/swift/trainers/utils.py b/swift/trainers/utils.py index 29a4e31028..2c810e5ceb 100644 --- a/swift/trainers/utils.py +++ b/swift/trainers/utils.py @@ -106,15 +106,20 @@ def build_tokenized_answer(answer, template: Template): def patch_trl(): from .callback import DefaultFlowCallbackNew, PrinterCallbackNew, ProgressCallbackNew from transformers import trainer - import torch - from typing import Any, Dict, List - from trl.trainer.utils import DPODataCollatorWithPadding, pad trainer.DEFAULT_PROGRESS_CALLBACK = ProgressCallbackNew trainer.DEFAULT_CALLBACKS = [DefaultFlowCallbackNew] trainer.PrinterCallback = PrinterCallbackNew # fix encoder-decoder error + patch_datacollator() + patch_itds_map() + + +def patch_datacollator(): + import torch + from typing import Any, Dict, List + from trl.trainer.utils import DPODataCollatorWithPadding, pad if not hasattr(DPODataCollatorWithPadding, '_old_call'): # Avoid double patching from torch.nn.utils.rnn import pad_sequence from functools import wraps @@ -191,3 +196,25 @@ def new_call(self, features: List[Dict[str, Any]]) -> Dict[str, Any]: DPODataCollatorWithPadding.__call__ = new_call DPODataCollatorWithPadding._old_call = old_call + + +def patch_itds_map(): + # resolve conflict with `num_proc` in iterable_dataset map func + from datasets import IterableDataset + from functools import wraps + + def _patch_ids_map(map_func): + pass + + if not hasattr(IterableDataset, '_old_map'): # Avoid double patching + old_map = IterableDataset.map + + @wraps(old_map) + def new_map(self, *args, **kwargs): + kwargs.pop('num_proc', None) + kwargs.pop('writer_batch_size', None) + return old_map(self, *args, **kwargs) + + IterableDataset.map = new_map + IterableDataset._old_map = old_map + # model.forward = MethodType(_patch_ids_map(map_func), IterableDataset)