diff --git a/tools/customization_dataset_preparation/customization_dataset_preparation.py b/tools/customization_dataset_preparation/customization_dataset_preparation.py index 54f74bfa2c48..9a83f61f60e6 100644 --- a/tools/customization_dataset_preparation/customization_dataset_preparation.py +++ b/tools/customization_dataset_preparation/customization_dataset_preparation.py @@ -119,13 +119,31 @@ def recommend_hyperparameters(df, model=None): # every token is around 4 chars + 100 for extra capacity max_seq_length = max_char_length // 4 + 100 + + if len(df) <= 100: + encoder_hidden_size = 1024 + elif len(df) <= 1000: + encoder_hidden_size = 2048 + else: + encoder_hidden_size = 4096 + + if len(df) <= 100: + lr = 5e-3 + elif len(df) <= 1000: + lr = 1e-3 + elif len(df) <= 10000: + lr = 5e-4 + else: + lr = 1e-4 + return { 'batch_size': bs, 'max_batch_size': max_bs, 'num_virtual_tokens': 10, - 'lr': 0.0001, - 'epochs': 25, + 'lr': lr, + 'epochs': 10, 'max_seq_length': max_seq_length, + 'encoder_hidden_size': encoder_hidden_size, }