Skip to content

Commit

Permalink
fix: batch size
Browse files Browse the repository at this point in the history
  • Loading branch information
stephantul committed Feb 3, 2025
1 parent 261a9b4 commit e1d53ac
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion model2vec/train/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,9 @@ def fit(
)

if batch_size is None:
batch_size = max(min(32, len(train_texts) // 10), 512)
# Set to a multiple of 32
base_number = int(min(max(1, (len(train_texts) / 30) // 32), 16))
batch_size = int(base_number * 32)
logger.info("Batch size automatically set to %d.", batch_size)

logger.info("Preparing train dataset.")
Expand Down

0 comments on commit e1d53ac

Please sign in to comment.