Skip to content

Commit

Permalink
Merge pull request #53 from SeanLee97/bugfix/trainer
Browse files Browse the repository at this point in the history
Bugfix/trainer
  • Loading branch information
SeanLee97 authored Feb 26, 2024
2 parents a2be3f0 + d4a91a8 commit 7beedd2
Show file tree
Hide file tree
Showing 5 changed files with 10 additions and 6 deletions.
2 changes: 1 addition & 1 deletion .bumpversion.cfg
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
 [bumpversion]
 files = setup.py angle_emb/__init__.py
-current_version = 0.3.4
+current_version = 0.3.5
 commit = True
 tag = True
2 changes: 1 addition & 1 deletion angle_emb/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@
 from .angle import *


-__version__ = '0.3.4'
+__version__ = '0.3.5'
1 change: 1 addition & 0 deletions angle_emb/angle.py
Original file line number Diff line number Diff line change
Expand Up @@ -744,6 +744,7 @@ def __init__(self,
 pooling_strategy='all',
 padding_strategy=self.pooler.padding_strategy,
 is_llm=False)
+self.kl_loss_fct = nn.KLDivLoss(reduction='batchmean')
 logger.info(f'Train with alignment, teacher={fixed_teacher_name_or_path}')

 def compute_loss(self, model, inputs, return_outputs=False):
Expand Down
9 changes: 6 additions & 3 deletions angle_emb/train_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,14 +166,17 @@ def main():
 argument_kwargs['report_to'] = 'wandb'

 trainer_kwargs = None
-if args.apply_tdmse:
+if args.fixed_teacher_name_or_path is not None:
 trainer_kwargs = {
+'fixed_teacher_name_or_path': args.fixed_teacher_name_or_path
+}
+if args.apply_tdmse:
+trainer_kwargs = dict(trainer_kwargs, **{
 'apply_tdmse_kl': args.apply_tdmse_kl,
 'tdmse_kl_temperature': args.tdmse_kl_temperature,
 'tdmse_teacher_lambda': args.tdmse_teacher_lambda,
 'tdmse_student_lambda': args.tdmse_student_lambda,
-'fixed_teacher_name_or_path': args.fixed_teacher_name_or_path,
-}
+})

 model.fit(
 train_ds=train_ds,
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

 setup(
 name='angle_emb',
-version='0.3.4',
+version='0.3.5',
 description='AnglE-optimize Text Embeddings',
 long_description=readme,
 long_description_content_type="text/markdown",
Expand Down

0 comments on commit 7beedd2

Please sign in to comment.