Merge pull request #1 from RahulVadisetty91/RahulVadisetty91-patch-1
Enhancements to BERT Training Script: AI Features Integration and Bug Fixes
RahulVadisetty91 authored Aug 23, 2024
2 parents d10dc4f + df2937e commit fe4a043
Showing 1 changed file with 104 additions and 0 deletions.
104 changes: 104 additions & 0 deletions BERT_Training_Enhanced.py
@@ -0,0 +1,104 @@
import argparse
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.cuda.amp import GradScaler, autocast
import torch

from .model import BERT
from .trainer import BERTTrainer
from .dataset import BERTDataset, WordVocab

# Import EarlyStopping if available; otherwise fall back to the minimal sketch below.
try:
    from your_module_name import EarlyStopping  # Replace 'your_module_name' with the actual module name
except ImportError:
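    # Minimal fallback implementing the EarlyStopping interface this script assumes:
    # patience/verbose constructor arguments, __call__(val_loss, model), and an
    # `early_stop` flag. This sketch is an assumption, not the external module's code.
    class EarlyStopping:
        def __init__(self, patience=3, verbose=False):
            self.patience = patience
            self.verbose = verbose
            self.best_loss = None
            self.counter = 0
            self.early_stop = False

        def __call__(self, val_loss, model):
            # Reset the counter when the validation loss improves; otherwise count
            # toward the patience limit and set the stop flag once it is exceeded.
            if self.best_loss is None or val_loss < self.best_loss:
                self.best_loss = val_loss
                self.counter = 0
            else:
                self.counter += 1
                if self.verbose:
                    print(f"EarlyStopping counter: {self.counter} out of {self.patience}")
                if self.counter >= self.patience:
                    self.early_stop = True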

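# argparse's type=bool treats any non-empty string (including "false") as True, so the
# boolean command-line options below are parsed with this small helper instead. The
# name `str2bool` is local to this script rather than part of the upstream trainer.
def str2bool(value):
    if isinstance(value, bool):
        return value
    if value.lower() in ("yes", "true", "t", "1"):
        return True
    if value.lower() in ("no", "false", "f", "0"):
        return False
    raise argparse.ArgumentTypeError(f"boolean value expected, got {value!r}")
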
def train():
    parser = argparse.ArgumentParser()

    parser.add_argument("-c", "--train_dataset", required=True, type=str, help="train dataset for training BERT")
    parser.add_argument("-t", "--test_dataset", type=str, default=None, help="test set for evaluating the trained model")
    parser.add_argument("-v", "--vocab_path", required=True, type=str, help="path to the vocabulary model")
    parser.add_argument("-o", "--output_path", required=True, type=str, help="output path for the BERT model")

    parser.add_argument("-hs", "--hidden", type=int, default=256, help="hidden size of transformer model")
    parser.add_argument("-l", "--layers", type=int, default=8, help="number of layers")
    parser.add_argument("-a", "--attn_heads", type=int, default=8, help="number of attention heads")
    parser.add_argument("-s", "--seq_len", type=int, default=20, help="maximum sequence length")

    parser.add_argument("-b", "--batch_size", type=int, default=64, help="batch size")
    parser.add_argument("-e", "--epochs", type=int, default=10, help="number of epochs")
    parser.add_argument("-w", "--num_workers", type=int, default=5, help="number of dataloader workers")

    parser.add_argument("--with_cuda", type=str2bool, default=True, help="train with CUDA: true or false")
    parser.add_argument("--log_freq", type=int, default=10, help="print loss every n iterations")
    parser.add_argument("--corpus_lines", type=int, default=None, help="total number of lines in the corpus")
    parser.add_argument("--cuda_devices", type=int, nargs='+', default=None, help="CUDA device IDs")
    parser.add_argument("--on_memory", type=str2bool, default=True, help="load data into memory: true or false")

    parser.add_argument("--lr", type=float, default=1e-3, help="learning rate of Adam")
    parser.add_argument("--adam_weight_decay", type=float, default=0.01, help="weight decay for Adam")
    parser.add_argument("--adam_beta1", type=float, default=0.9, help="Adam's first beta value")
    parser.add_argument("--adam_beta2", type=float, default=0.999, help="Adam's second beta value")

    # New features
    parser.add_argument("--dynamic_lr", type=str2bool, default=True, help="use dynamic learning rate adjustment")
    parser.add_argument("--early_stopping", type=str2bool, default=True, help="enable early stopping")
    parser.add_argument("--patience", type=int, default=3, help="patience for early stopping")
    parser.add_argument("--mixed_precision", type=str2bool, default=True, help="use mixed precision training")
    parser.add_argument("--grad_accumulation_steps", type=int, default=1, help="steps for gradient accumulation")
    parser.add_argument("--data_augmentation", type=str2bool, default=False, help="apply data augmentation techniques")

    args = parser.parse_args()

    print("Loading Vocab", args.vocab_path)
    vocab = WordVocab.load_vocab(args.vocab_path)
    print("Vocab Size: ", len(vocab))

    print("Loading Train Dataset", args.train_dataset)
    train_dataset = BERTDataset(args.train_dataset, vocab, seq_len=args.seq_len,
                                corpus_lines=args.corpus_lines, on_memory=args.on_memory,
                                data_augmentation=args.data_augmentation)
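    # `data_augmentation` is assumed to be handled inside BERTDataset itself (for
    # example, light token-level perturbations applied when an example is read);
    # this script only forwards the flag and does not implement the augmentation.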

    print("Loading Test Dataset", args.test_dataset)
    test_dataset = BERTDataset(args.test_dataset, vocab, seq_len=args.seq_len, on_memory=args.on_memory) \
        if args.test_dataset is not None else None

    print("Creating Dataloader")
    train_data_loader = DataLoader(train_dataset, batch_size=args.batch_size, num_workers=args.num_workers)
    test_data_loader = DataLoader(test_dataset, batch_size=args.batch_size, num_workers=args.num_workers) \
        if test_dataset is not None else None

    print("Building BERT model")
    bert = BERT(len(vocab), hidden=args.hidden, n_layers=args.layers, attn_heads=args.attn_heads)

    print("Creating BERT Trainer")
    trainer = BERTTrainer(bert, len(vocab), train_dataloader=train_data_loader, test_dataloader=test_data_loader,
                          lr=args.lr, betas=(args.adam_beta1, args.adam_beta2), weight_decay=args.adam_weight_decay,
                          with_cuda=args.with_cuda, cuda_devices=args.cuda_devices, log_freq=args.log_freq,
                          mixed_precision=args.mixed_precision, grad_accumulation_steps=args.grad_accumulation_steps)
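    # The mixed_precision and grad_accumulation_steps options are assumed to be
    # consumed inside BERTTrainer via the torch.cuda.amp utilities imported above.
    # A typical inner training step for that combination (illustrative sketch, not
    # necessarily the trainer's actual implementation) looks like:
    #
    #     scaler = GradScaler(enabled=args.mixed_precision)
    #     for step, batch in enumerate(train_data_loader):
    #         with autocast(enabled=args.mixed_precision):
    #             loss = compute_loss(batch) / args.grad_accumulation_steps
    #         scaler.scale(loss).backward()
    #         if (step + 1) % args.grad_accumulation_steps == 0:
    #             scaler.step(optimizer)
    #             scaler.update()
    #             optimizer.zero_grad()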

    # Dynamic Learning Rate Adjustment
    if args.dynamic_lr:
        scheduler = ReduceLROnPlateau(trainer.optimizer, mode='min', factor=0.5, patience=args.patience, verbose=True)

    # Early Stopping
    early_stopping = None
    if args.early_stopping:
        early_stopping = EarlyStopping(patience=args.patience, verbose=True)

    print("Training Start")
    for epoch in range(args.epochs):
        trainer.train(epoch)

        if test_data_loader is not None:
            test_loss = trainer.test(epoch)

            if args.dynamic_lr:
                scheduler.step(test_loss)

            if early_stopping is not None:
                early_stopping(test_loss, trainer.model)
                if early_stopping.early_stop:
                    print("Early stopping")
                    break

        trainer.save(epoch, args.output_path)
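
if __name__ == "__main__":
    # Entry-point guard; because of the relative imports at the top of this file,
    # the script is assumed to be run as a module from inside its package
    # (e.g. `python -m <package_name>.BERT_Training_Enhanced`).
    train()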
