Commit

update

hjh0119 committed Aug 18, 2024
1 parent 3081069 commit 7c873e5
Showing 4 changed files with 69 additions and 41 deletions.
33 changes: 21 additions & 12 deletions swift/trainers/cpo_trainer.py
@@ -216,20 +216,29 @@ def concatenated_forward(
         } if self.is_encoder_decoder else {})

         if self.is_vision_model:
             # Here we restore _data; the image information is processed within the forward hook of the model.
             batch_size = concatenated_batch['concatenated_input_ids'].shape[0]
-            if self._data_keys is not None:
-                _data = [dict() for _ in range(batch_size)]
-                for k in self._data_keys:
-                    if k == 'input_ids':
-                        _data = [{**d, k: concatenated_batch['concatenated_input_ids'][i]} for i, d in enumerate(_data)]
-                    elif k == 'pixel_values':
-                        # convert the dtype of the pixel values that may be converted to float32 in tokenize_row
-                        model_dtype = self.accelerator.unwrap_model(model).dtype
-                        # for vision related data, paired response share the same one
-                        _data = [{**d, k: concatenated_batch[k][i // 2].to(model_dtype)} for i, d in enumerate(_data)]
-                    else:
-                        _data = [{**d, k: concatenated_batch[k][i // 2]} for i, d in enumerate(_data)]
-                model_kwargs['_data'] = _data
+            if self._data_keys:
+                _data = [dict() for _ in range(batch_size)]
+                for k in self._data_keys:
+                    if k == 'input_ids':
+                        _data = [{
+                            **d, k: concatenated_batch['concatenated_input_ids'][i]
+                        } for i, d in enumerate(_data)]
+                    elif k == 'pixel_values':
+                        # restore the dtype of the pixel values, which may have been cast to float32 in tokenize_row
+                        model_dtype = self.accelerator.unwrap_model(model).dtype
+                        # each chosen/rejected pair shares the same vision data
+                        _data = [{
+                            **d, k: concatenated_batch[k][i // 2].to(model_dtype)
+                        } for i, d in enumerate(_data)]
+                    else:
+                        _data = [{**d, k: concatenated_batch[k][i // 2]} for i, d in enumerate(_data)]
+                model_kwargs['_data'] = _data
+
+            if 'images' in concatenated_batch:
+                model_kwargs['images'] = concatenated_batch['images']

         if self.aux_loss_enabled:
             model_kwargs['output_router_logits'] = True
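Note on the indexing above: the text rows of the concatenated batch come in chosen/rejected pairs, while the vision tensors are stored once per pair, so row i reads vision entry i // 2 (rows 0 and 1 share entry 0, rows 2 and 3 share entry 1). A minimal sketch of that pairing; the names and shapes are illustrative, not the trainer's real batch:

import torch

concatenated_input_ids = torch.arange(20).reshape(4, 5)  # 4 text rows (2 pairs)
pixel_values = torch.randn(2, 3, 224, 224)               # 1 vision entry per pair

_data = [{'input_ids': concatenated_input_ids[i],
          'pixel_values': pixel_values[i // 2]} for i in range(4)]

# consecutive rows share one vision entry
assert torch.equal(_data[0]['pixel_values'], _data[1]['pixel_values'])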
38 changes: 24 additions & 14 deletions swift/trainers/dpo_trainer.py
@@ -295,18 +295,26 @@ def concatenated_forward(
             # Here we restore _data; the image information is processed within the forward hook of the model.
             batch_size = concatenated_batch['concatenated_input_ids'].shape[0]
-            if self._data_keys is not None:
-                _data = [dict() for _ in range(batch_size)]
-                for k in self._data_keys:
-                    if k == 'input_ids':
-                        _data = [{**d, k: concatenated_batch['concatenated_input_ids'][i]} for i, d in enumerate(_data)]
-                    elif k == 'pixel_values':
-                        # convert the dtype of the pixel values that may be converted to float32 in tokenize_row
-                        model_dtype = self.accelerator.unwrap_model(model).dtype
-                        # for vision related data, paired response share the same one
-                        _data = [{**d, k: concatenated_batch[k][i // 2].to(model_dtype)} for i, d in enumerate(_data)]
-                    else:
-                        _data = [{**d, k: concatenated_batch[k][i // 2]} for i, d in enumerate(_data)]
-                model_kwargs['_data'] = _data
+            if self._data_keys:
+                _data = [dict() for _ in range(batch_size)]
+                for k in self._data_keys:
+                    if k == 'input_ids':
+                        _data = [{
+                            **d, k: concatenated_batch['concatenated_input_ids'][i]
+                        } for i, d in enumerate(_data)]
+                    elif k == 'pixel_values':
+                        # restore the dtype of the pixel values, which may have been cast to float32 in tokenize_row
+                        model_dtype = self.accelerator.unwrap_model(model).dtype
+                        # each chosen/rejected pair shares the same vision data
+                        _data = [{
+                            **d, k: concatenated_batch[k][i // 2].to(model_dtype)
+                        } for i, d in enumerate(_data)]
+                    else:
+                        _data = [{**d, k: concatenated_batch[k][i // 2]} for i, d in enumerate(_data)]
+                model_kwargs['_data'] = _data
+
+            if 'images' in concatenated_batch:
+                model_kwargs['images'] = concatenated_batch['images']

         if self.aux_loss_enabled:
             model_kwargs['output_router_logits'] = True
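For context on the comment above: the trainer never passes the per-sample vision tensors to forward directly; it stashes them under the private '_data' kwarg, and a hook registered on the model unpacks them just before forward runs. A minimal sketch of that pattern using a PyTorch forward pre-hook; the toy module and hook below are illustrative assumptions, not swift's actual hook:

import torch
from torch import nn

class ToyVisionModel(nn.Module):
    # stand-in for a multimodal model whose forward accepts pixel_values
    def forward(self, input_ids, pixel_values=None):
        extra = pixel_values.sum() if pixel_values is not None else 0
        return input_ids.sum() + extra

def restore_data_hook(module, args, kwargs):
    # pop the private '_data' kwarg and rebuild ordinary model inputs from it
    data = kwargs.pop('_data', None)
    if data is not None:
        kwargs['pixel_values'] = torch.stack([d['pixel_values'] for d in data])
    return args, kwargs

model = ToyVisionModel()
model.register_forward_pre_hook(restore_data_hook, with_kwargs=True)
out = model(input_ids=torch.ones(4, 5),
            _data=[{'pixel_values': torch.randn(3)} for _ in range(4)])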
@@ -427,9 +435,8 @@ def concatenated_inputs(
             batch['prompt_attention_mask'].repeat(2, 1).to(device=device))

         # patch here
-        # leave data collector in hook
-
         if is_vision_model:
+            # for keys that appear in _data, we leave the data collation to the hook
             if 'prompt_pixel_values' in batch:
                 pixel_values = [values for values in batch['prompt_pixel_values']]
                 concatenated_batch['pixel_values'] = pixel_values
@@ ... @@
             if 'prompt_image_sizes' in batch:
                 concatenated_batch['image_sizes'] = batch['prompt_image_sizes']

+            if 'prompt_images' in batch:
+                # 'images' is not among the _data keys, so we collate it manually here
+                concatenated_batch['images'] = batch['prompt_images'].squeeze(1).repeat(2, 1, 1, 1).to(device=device)
         return concatenated_batch

     @staticmethod
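The new 'prompt_images' branch duplicates each prompt's image so that the chosen and rejected halves of the concatenated batch both see it. A shape-level sketch of what squeeze(1).repeat(2, 1, 1, 1) does; the (batch, 1, C, H, W) input shape is an assumption for illustration:

import torch

prompt_images = torch.randn(2, 1, 3, 224, 224)  # assumed collated shape (batch, 1, C, H, W)
images = prompt_images.squeeze(1)               # (2, 3, 224, 224): one image per prompt
images = images.repeat(2, 1, 1, 1)              # (4, 3, 224, 224): [img0, img1, img0, img1]
assert images.shape == (4, 3, 224, 224)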
33 changes: 21 additions & 12 deletions swift/trainers/orpo_trainer.py
@@ -215,20 +215,29 @@ def concatenated_forward(
         } if self.is_encoder_decoder else {})

         if self.is_vision_model:
             # Here we restore _data; the image information is processed within the forward hook of the model.
             batch_size = concatenated_batch['concatenated_input_ids'].shape[0]
-            if self._data_keys is not None:
-                _data = [dict() for _ in range(batch_size)]
-                for k in self._data_keys:
-                    if k == 'input_ids':
-                        _data = [{**d, k: concatenated_batch['concatenated_input_ids'][i]} for i, d in enumerate(_data)]
-                    elif k == 'pixel_values':
-                        # convert the dtype of the pixel values that may be converted to float32 in tokenize_row
-                        model_dtype = self.accelerator.unwrap_model(model).dtype
-                        # for vision related data, paired response share the same one
-                        _data = [{**d, k: concatenated_batch[k][i // 2].to(model_dtype)} for i, d in enumerate(_data)]
-                    else:
-                        _data = [{**d, k: concatenated_batch[k][i // 2]} for i, d in enumerate(_data)]
-                model_kwargs['_data'] = _data
+            if self._data_keys:
+                _data = [dict() for _ in range(batch_size)]
+                for k in self._data_keys:
+                    if k == 'input_ids':
+                        _data = [{
+                            **d, k: concatenated_batch['concatenated_input_ids'][i]
+                        } for i, d in enumerate(_data)]
+                    elif k == 'pixel_values':
+                        # restore the dtype of the pixel values, which may have been cast to float32 in tokenize_row
+                        model_dtype = self.accelerator.unwrap_model(model).dtype
+                        # each chosen/rejected pair shares the same vision data
+                        _data = [{
+                            **d, k: concatenated_batch[k][i // 2].to(model_dtype)
+                        } for i, d in enumerate(_data)]
+                    else:
+                        _data = [{**d, k: concatenated_batch[k][i // 2]} for i, d in enumerate(_data)]
+                model_kwargs['_data'] = _data
+
+            if 'images' in concatenated_batch:
+                model_kwargs['images'] = concatenated_batch['images']

         if self.aux_loss_enabled:
             model_kwargs['output_router_logits'] = True
6 changes: 3 additions & 3 deletions swift/trainers/utils.py
@@ -151,7 +151,7 @@ def patch_datacollator():
     def new_call(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
         padded_batch = {}
         for k in features[0].keys():
-            if k.endswith(('_input_ids', '_attention_mask', '_labels', '_pixel_values')):
+            if k.endswith(('_input_ids', '_attention_mask', '_labels', '_pixel_values', '_images')):
                 if self.is_encoder_decoder:
                     to_pad = [torch.LongTensor(ex[k]) for ex in features]
@@ -187,7 +187,7 @@ def new_call(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
                     padding_value = self.label_pad_token_id
                 elif k.endswith('_attention_mask'):
                     padding_value = 0
-                elif k.endswith('_pixel_values'):
+                elif k.endswith(('_pixel_values', '_images')):
                     padding_value = 0
                 else:
                     raise ValueError(f"Unexpected key in batch '{k}'")
@@ ... @@
                     padding_side = 'right'

                 # Set the dtype
                 if k.endswith('_pixel_values'):
+                if k.endswith(('_pixel_values', '_images')):
                     dtype = torch.float32  # will be downcasted if necessary by the Trainer
                 else:
                     dtype = torch.int64
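With this change, keys ending in '_images' take the same padding path as '_pixel_values': zero padding and a float32 dtype. A minimal sketch of that rule, assuming the collator pads variable-length examples with torch.nn.utils.rnn.pad_sequence (the exact routine lives in the patched new_call; the example tensors are illustrative):

import torch
from torch.nn.utils.rnn import pad_sequence

to_pad = [torch.randn(3, 4), torch.randn(5, 4)]  # per-example image tensors of unequal length
padded = pad_sequence(to_pad, batch_first=True, padding_value=0).to(torch.float32)
assert padded.shape == (2, 5, 4)  # the shorter example is zero-padded to the longest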
