Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/improvement #100

Merged
merged 18 commits into from
Sep 30, 2024
30 changes: 21 additions & 9 deletions angle_emb/angle.py
Original file line number Diff line number Diff line change
Expand Up @@ -821,7 +821,7 @@ def compute_mlm_loss(self, logits, mask_target_labels):
ignore_index=self.pad_token_id,
)

def compute_loss(self, model, inputs, return_outputs=False):
def compute_loss(self, model, inputs, return_outputs: bool = False):
""" Compute loss for AnglE.

:param model: Huggingface model.
Expand Down Expand Up @@ -859,6 +859,11 @@ def compute_loss(self, model, inputs, return_outputs=False):

return (loss, outputs) if return_outputs else loss

@torch.no_grad()
def prediction_step(self, model, inputs, *args, **kwargs):
eval_loss = self.compute_loss(model, inputs, return_outputs=False)
return eval_loss, None, None


class AngleESETrainer(AngleTrainer):
"""
Expand Down Expand Up @@ -1412,13 +1417,15 @@ def detect_dataset_format(self, ds: Dataset):
def fit(self,
train_ds: Dataset,
valid_ds: Optional[Dataset] = None,
valid_ds_for_callback: Optional[Dataset] = None,
batch_size: int = 32,
output_dir: Optional[str] = None,
epochs: int = 1,
learning_rate: float = 1e-5,
warmup_steps: int = 1000,
logging_steps: int = 10,
eval_steps: Optional[int] = None,
eval_steps: int = 1000,
evaluation_strategy: str = 'steps',
save_steps: int = 100,
save_strategy: str = 'steps',
save_total_limit: int = 10,
Expand All @@ -1439,13 +1446,17 @@ def fit(self,

:param train_ds: Dataset. tokenized train dataset. Required.
:param valid_ds: Optional[Dataset]. tokenized valid dataset. Default None.
:param valid_ds_for_callback: Optional[Dataset]. tokenized valid dataset for callback use.
The dataset format should be `DatasetFormats.A`. The spearmans' correlation will be computed
after each epoch training and the best model will be saved. Default None.
:param batch_size: int. Default 32.
:param output_dir: Optional[str]. save dir. Default None.
:param epochs: int. Default 1.
:param learning_rate: float. Default 1e-5.
:param warmup_steps: int. Default 1000.
:param logging_steps: int. Default 10.
:param eval_steps: Optional[int]. Default None.
:param eval_steps: int. Default 1000.
:param evaluation_strategy: str. Default 'steps'.
:param save_steps: int. Default 100.
:param save_strategy: str. Default steps.
:param save_total_limit: int. Default 10.
Expand Down Expand Up @@ -1491,16 +1502,16 @@ def fit(self,
trainer_kwargs = {}

callbacks = None
if valid_ds is not None:
if valid_ds_for_callback is not None:
# check format
for obj in valid_ds:
for obj in valid_ds_for_callback:
if obj['extra']['dataset_format'] != DatasetFormats.A:
raise ValueError('Currently only support evaluation for DatasetFormats.A.')
break
best_ckpt_dir = None
if output_dir is not None:
best_ckpt_dir = os.path.join(output_dir, 'best-checkpoint')
evaluate_callback = EvaluateCallback(self, valid_ds,
evaluate_callback = EvaluateCallback(self, valid_ds_for_callback,
partial(self.evaluate, batch_size=batch_size),
save_dir=best_ckpt_dir,
push_to_hub=push_to_hub,
Expand All @@ -1519,7 +1530,7 @@ def fit(self,
model=self.backbone,
dataset_format=self.detect_dataset_format(train_ds),
train_dataset=train_ds,
eval_dataset=None,
eval_dataset=valid_ds,
loss_kwargs=loss_kwargs,
tokenizer=self.tokenizer,
args=TrainingArguments(
Expand All @@ -1530,14 +1541,15 @@ def fit(self,
learning_rate=learning_rate,
fp16=fp16,
logging_steps=logging_steps,
save_steps=save_steps,
save_strategy=save_strategy,
evaluation_strategy=evaluation_strategy if valid_ds is not None else 'no',
eval_steps=eval_steps,
save_steps=save_steps,
output_dir=output_dir,
save_total_limit=save_total_limit,
load_best_model_at_end=False,
ddp_find_unused_parameters=False if self.gpu_count > 1 else None,
label_names=AnglE.special_columns,
remove_unused_columns=False,
**argument_kwargs,
),
callbacks=callbacks,
Expand Down
38 changes: 38 additions & 0 deletions angle_emb/angle_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,20 @@
help='Specify huggingface datasets subset name for valid set, default None')
parser.add_argument('--valid_split_name', type=str, default='train',
help='Specify huggingface datasets split name for valid set, default `train`')
parser.add_argument('--valid_name_or_path_for_callback', type=str, default=None,
help='Specify huggingface datasets name or local file path for callback valid set. '
'The dataset format should be `DatasetFormats.A`. Default None.')
parser.add_argument('--valid_subset_name_for_callback', type=str, default=None,
help='Specify huggingface datasets subset name for valid set for callback use, default None')
parser.add_argument('--valid_split_name_for_callback', type=str, default='train',
help='Specify huggingface datasets split name for valid set for callback use, default `train`')
parser.add_argument('--prompt_template', type=str, default=None,
help='Specify prompt_template like "xxx: {text}", default None.'
'This prompt will be applied for all text columns.'
'If you want to specify different prompts for different text columns,'
'please handle it in the preprocessing step.')
parser.add_argument('--filter_duplicate', type=int, default=1, choices=[0, 1],
                    help='Specify filter_duplicate, choices [0, 1], default 1')
parser.add_argument('--save_dir', type=str, default=None,
help='Specify save dir, default None')
parser.add_argument('--seed', type=int, default=-1,
Expand Down Expand Up @@ -84,6 +93,11 @@
parser.add_argument('--max_steps', type=int, default=-1,
help='Specify max steps, default -1 (Automatically calculated from epochs)')
parser.add_argument('--save_steps', type=int, default=100, help='Specify save_steps, default 100')
parser.add_argument('--save_strategy', type=str, default='steps', choices=['steps', 'epoch'],
help='Specify save_strategy, default steps')
parser.add_argument('--eval_steps', type=int, default=1000, help='Specify eval_steps, default 1000')
parser.add_argument('--evaluation_strategy', type=str, default='steps', choices=['steps', 'epoch'],
help='Specify evaluation_strategy, default steps')
parser.add_argument('--batch_size', type=int, default=32, help='Specify batch size, default 32')
parser.add_argument('--maxlen', type=int, default=512, help='Specify max length, default 512')
parser.add_argument('--streaming', action='store_true', default=False,
Expand Down Expand Up @@ -227,6 +241,25 @@ def main():
AngleDataTokenizer(model.tokenizer, model.max_length, prompt_template=args.prompt_template),
num_proc=args.workers)

valid_ds_for_callback = None
if valid_ds_for_callback is None and args.valid_name_or_path_for_callback is not None:
logger.info('Validation for callback detected, processing validation...')
if os.path.exists(args.valid_name_or_path_for_callback):
valid_ds_for_callback = load_dataset(
'json', data_files=[args.valid_name_or_path_for_callback], num_proc=args.workers)
else:
if args.valid_subset_name_for_callback is not None:
valid_ds_for_callback = load_dataset(
args.valid_name_or_path_for_callback,
args.valid_subset_name_for_callback,
num_proc=args.workers)
else:
valid_ds_for_callback = load_dataset(
args.valid_name_or_path_for_callback, num_proc=args.workers)
valid_ds_for_callback = valid_ds_for_callback[args.valid_split_name_for_callback or 'train'].map(
AngleDataTokenizer(model.tokenizer, model.max_length, prompt_template=args.prompt_template),
num_proc=args.workers)

argument_kwargs = {}
if args.push_to_hub:
    assert args.hub_model_id is not None, 'Please specify hub_model_id via --hub_model_id xxx'
Expand Down Expand Up @@ -254,11 +287,15 @@ def main():
model.fit(
train_ds=train_ds,
valid_ds=valid_ds,
valid_ds_for_callback=valid_ds_for_callback,
output_dir=args.save_dir,
batch_size=args.batch_size,
epochs=args.epochs,
learning_rate=args.learning_rate,
save_steps=args.save_steps,
save_strategy=args.save_strategy,
eval_steps=args.eval_steps,
evaluation_strategy=args.evaluation_strategy,
warmup_steps=args.warmup_steps,
logging_steps=args.logging_steps,
gradient_accumulation_steps=args.gradient_accumulation_steps,
Expand All @@ -271,6 +308,7 @@ def main():
'angle_tau': args.angle_tau,
},
fp16=args.fp16,
filter_duplicate=args.filter_duplicate,
argument_kwargs=argument_kwargs,
apply_ese=args.apply_ese,
trainer_kwargs=trainer_kwargs,
Expand Down
Loading