Trouble tracing why convergence is slower in Lightning #2643

Closed · rbracco opened this issue Jul 18, 2020 · 6 comments
Labels: bug (Something isn't working), help wanted (Open to be worked on)

rbracco (Contributor) commented Jul 18, 2020

I recently refactored some code from [this tutorial](https://www.assemblyai.com/blog/end-to-end-speech-recognition-pytorch) (it trains speech-to-text on LibriSpeech 100hr) into Lightning and found it converges more slowly and never reaches the same level of loss. I made a lot of changes when I refactored into PyTorch Lightning, and I slowly undid them until I was left with the original code and the Lightning version. I even set Comet as the logger so that both had the same logging, and the results are the same. I can't figure out what's going on. As far as I can tell, the code is identical (except for the training loop) and I don't know how to proceed. (Edit: I've now added the training-loop code from both notebooks at the bottom of this post for ease of access.)

Here are the notebooks; the final few cells (the train/test/main functions) are the relevant, distinct portion:
  • Non-Lightning notebook - converges faster
  • Lightning notebook - converges slower

  • I am using torch 1.4.0 and torchaudio 0.4.0.
  • Train loss vs. step is shown below for both versions; other metrics (character error rate and word error rate) are also worse in Lightning.
  • The most recent comparison (visualized below) didn't use a seed, but I have run many variations, including ones where torch/np use the same seed, and the results were similar.

[image: train loss vs. step for the non-Lightning and Lightning runs]

Non-Lightning Version:

def train(model, device, train_loader, criterion, optimizer, scheduler, epoch, iter_meter, experiment):
    model.train()
    data_len = len(train_loader.dataset)
    with experiment.train():
        for batch_idx, _data in enumerate(train_loader):
            spectrograms, labels, input_lengths, label_lengths = _data 
            spectrograms, labels = spectrograms.to(device), labels.to(device)

            optimizer.zero_grad()

            output = model(spectrograms)  # (batch, time, n_class)
            output = F.log_softmax(output, dim=2)
            output = output.transpose(0, 1) # (time, batch, n_class)

            loss = criterion(output, labels, input_lengths, label_lengths)
            loss.backward()

            experiment.log_metric('loss', loss.item(), step=iter_meter.get())
            experiment.log_metric('learning_rate', scheduler.get_lr(), step=iter_meter.get())

            optimizer.step()
            scheduler.step()  # OneCycleLR is stepped once per batch here, not once per epoch
            iter_meter.step()
            if batch_idx % 100 == 0 or batch_idx == data_len:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * len(spectrograms), data_len,
                    100. * batch_idx / len(train_loader), loss.item()))
                # if(batch_idx >= 100):
                #   print("Exiting Early")
                #   break


def test(model, device, test_loader, criterion, epoch, iter_meter, experiment):
    print('\nevaluating...')
    model.eval()
    test_loss = 0
    test_cer, test_wer = [], []
    with experiment.test():
        with torch.no_grad():
            for i, _data in enumerate(test_loader):
                spectrograms, labels, input_lengths, label_lengths = _data 
                spectrograms, labels = spectrograms.to(device), labels.to(device)

                output = model(spectrograms)  # (batch, time, n_class)
                output = F.log_softmax(output, dim=2)
                output = output.transpose(0, 1) # (time, batch, n_class)

                loss = criterion(output, labels, input_lengths, label_lengths)
                test_loss += loss.item() / len(test_loader)

                decoded_preds, decoded_targets = GreedyDecoder(output.transpose(0, 1), labels, label_lengths)
                for j in range(len(decoded_preds)):
                    print("\nTarget:", decoded_targets[j])
                    print("\nPredicted:", decoded_preds[j])
                    test_cer.append(cer(decoded_targets[j], decoded_preds[j]))
                    test_wer.append(wer(decoded_targets[j], decoded_preds[j]))


    avg_cer = sum(test_cer)/len(test_cer)
    avg_wer = sum(test_wer)/len(test_wer)
    experiment.log_metric('test_loss', test_loss, step=iter_meter.get())
    experiment.log_metric('cer', avg_cer, step=iter_meter.get())
    experiment.log_metric('wer', avg_wer, step=iter_meter.get())

    print('Test set: Average loss: {:.4f}, Average CER: {:4f} Average WER: {:.4f}\n'.format(test_loss, avg_cer, avg_wer))


def main(learning_rate=5e-4, batch_size=20, epochs=10,
        train_url="train-clean-100", test_url="test-clean",
        experiment=Experiment(api_key='dummy_key', disabled=True)):

    hparams = {
        "n_cnn_layers": 3,
        "n_rnn_layers": 5,
        "rnn_dim": 512,
        "n_class": 29,
        "n_feats": 128,
        "stride":2,
        "dropout": 0.1,
        "learning_rate": learning_rate,
        "batch_size": batch_size,
        "epochs": epochs
    }

    experiment.log_parameters(hparams)

    use_cuda = torch.cuda.is_available()
    torch.manual_seed(7)
    device = torch.device("cuda" if use_cuda else "cpu")

    if not os.path.isdir("./data"):
        os.makedirs("./data")

    train_dataset = torchaudio.datasets.LIBRISPEECH("./data", url=train_url, download=True)
    test_dataset = torchaudio.datasets.LIBRISPEECH("./data", url=test_url, download=True)

    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
    train_loader = data.DataLoader(dataset=train_dataset,
                                batch_size=hparams['batch_size'],
                                shuffle=True,
                                collate_fn=lambda x: data_processing(x, 'train'),
                                **kwargs)
    test_loader = data.DataLoader(dataset=test_dataset,
                                batch_size=hparams['batch_size'],
                                shuffle=False,
                                collate_fn=lambda x: data_processing(x, 'valid'),
                                **kwargs)

    model = SpeechRecognitionModel(
        hparams['n_cnn_layers'], hparams['n_rnn_layers'], hparams['rnn_dim'],
        hparams['n_class'], hparams['n_feats'], hparams['stride'], hparams['dropout']
        ).to(device)

    print(model)
    print('Num Model Parameters', sum([param.nelement() for param in model.parameters()]))

    optimizer = optim.AdamW(model.parameters(), hparams['learning_rate'])
    criterion = nn.CTCLoss(blank=28).to(device)
    scheduler = optim.lr_scheduler.OneCycleLR(optimizer, max_lr=hparams['learning_rate'], 
                                            steps_per_epoch=int(len(train_loader)),
                                            epochs=hparams['epochs'],
                                            anneal_strategy='linear')
    
    iter_meter = IterMeter()
    for epoch in range(1, epochs + 1):
        train(model, device, train_loader, criterion, optimizer, scheduler, epoch, iter_meter, experiment)
        test(model, device, test_loader, criterion, epoch, iter_meter, experiment)

learning_rate = 5e-4
batch_size = 20
epochs = 10
libri_train_set = "train-clean-100"
libri_test_set = "test-clean"

main(learning_rate, batch_size, epochs, libri_train_set, libri_test_set, experiment)

Lightning Version:


class SpeechTrainModel(pl.LightningModule):
    def __init__(self, hparams):
        super().__init__()
        self.hparams = hparams
        self.model = SpeechRecognitionModel(
        hparams['n_cnn_layers'], hparams['n_rnn_layers'], hparams['rnn_dim'],
        hparams['n_class'], hparams['n_feats'], hparams['stride'], hparams['dropout']
        )
        self.criterion = nn.CTCLoss(blank=28)

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_nb):
        spectrograms, labels, input_lengths, label_lengths = batch 
        output = self(spectrograms)  # (batch, time, n_class)
        output = F.log_softmax(output, dim=2)
        output = output.transpose(0, 1) # (time, batch, n_class)

        loss = self.criterion(output, labels, input_lengths, label_lengths)
        tensorboard_logs = {'train_loss': loss}
        return {'loss': loss, 'log': tensorboard_logs}

    def validation_step(self, batch, batch_nb):

        spectrograms, labels, input_lengths, label_lengths = batch 

        output = self(spectrograms)  # (batch, time, n_class)
        output = F.log_softmax(output, dim=2)
        output = output.transpose(0, 1) # (time, batch, n_class)

        loss = self.criterion(output, labels, input_lengths, label_lengths)

        # decoded_preds, decoded_targets = GreedyDecoder(output.transpose(0, 1), labels, label_lengths)
        # for j in range(len(decoded_preds)):
        #     test_cer.append(cer(decoded_targets[j], decoded_preds[j]))
        #     test_wer.append(wer(decoded_targets[j], decoded_preds[j]))
        return {'val_loss': loss}
    
    def validation_epoch_end(self, outputs):
        # OPTIONAL
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        tensorboard_logs = {'val_loss': avg_loss}
        return {'val_loss': avg_loss, 'log': tensorboard_logs}

    
    def prepare_data(self):
        if not os.path.isdir("./data"):
            os.makedirs("./data")
        self.train_dataset = torchaudio.datasets.LIBRISPEECH("./data", url="train-clean-100", download=True)
        self.test_dataset = torchaudio.datasets.LIBRISPEECH("./data", url="test-clean", download=True)

    def configure_optimizers(self):
        hparams = self.hparams
        optimizer = optim.AdamW(self.parameters(), hparams['learning_rate'])
        
        scheduler = optim.lr_scheduler.OneCycleLR(optimizer, max_lr=hparams['learning_rate'], 
                        steps_per_epoch=int(len(self.train_dataset)),
                        epochs=hparams['epochs'],
                        anneal_strategy='linear')

        scheduler_dict = {'scheduler': scheduler, 'interval': 'step'}

        return [optimizer], [scheduler_dict]


    def train_dataloader(self):
        self.train_loader = data.DataLoader(dataset=self.train_dataset,
                    batch_size=self.hparams['batch_size'],
                    shuffle=True,
                    collate_fn=lambda x: data_processing(x, 'train'),
                                )
        return self.train_loader
        
    def val_dataloader(self):
        return data.DataLoader(dataset=self.test_dataset,
                    batch_size=self.hparams['batch_size'],
                    shuffle=False,
                    collate_fn=lambda x: data_processing(x, 'valid'),
                                )

learning_rate=5e-4
batch_size=20
epochs=10

hparams = {
    "n_cnn_layers": 3,
    "n_rnn_layers": 5,
    "rnn_dim": 512,
    "n_class": 29,
    "n_feats": 128,
    "stride":2,
    "dropout": 0.1,
    "learning_rate": learning_rate,
    "batch_size": batch_size,
    "epochs": epochs
}

mdl = SpeechTrainModel(hparams)
# comet_logger is presumably a pl.loggers.CometLogger defined earlier in the notebook
trainer = pl.Trainer(gpus=1, max_epochs=epochs, logger=comet_logger)
trainer.fit(mdl)

rohitgr7 (Contributor) commented

For the convergence problem, try:

def configure_optimizers(self):
        hparams = self.hparams
        optimizer = optim.AdamW(self.parameters(), hparams['learning_rate'])
        
        scheduler = optim.lr_scheduler.OneCycleLR(optimizer, max_lr=hparams['learning_rate'], 
                        steps_per_epoch=int(len(self.train_dataset)),
                        epochs=hparams['epochs'],
                        anneal_strategy='linear')

        scheduler_dict = {'scheduler': scheduler, 'interval': 'step'}

        return [optimizer], [scheduler_dict]
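
By default, Lightning steps LR schedulers once per epoch; returning the scheduler as a dict with 'interval': 'step' tells Lightning to call scheduler.step() after every optimizer step, matching the per-batch scheduler.step() in the non-Lightning training loop. (Note that this snippet still uses len(self.train_dataset) for steps_per_epoch; that turns out to be a separate problem, identified two comments below.)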

rbracco (Contributor, Author) commented Jul 19, 2020

Thank you for the quick reply and suggestion. I changed that and tracked the results; I don't believe it fixed the problem. It appears to have improved things slightly, but that may just be variance, since I'm currently using Colab (differing hardware) and not seeding. If there is anything you'd suggest I run to ensure reproducibility, please let me know. Here are the results:

Train loss vs. step
[image: train loss vs. step]

Train loss vs. duration
[image: train loss vs. duration]

As you can see, train loss vs. step seems to be about the same as before. Train loss vs. duration improves and more closely matches the non-Lightning version, but it could also be that Colab allocated better hardware for that session (or that in the previous training I was running both at once; I'm not sure what the impact of that is).

rohitgr7 (Contributor) commented

Ok, found 2 more mistakes.

  1. In the non-Lightning version, your iter_meter only steps in the training loop, but in Lightning the logged step gets updated during the val loop too. Maybe this is creating a mismatch in the graph above, since step is on the x-axis.
  2. In the scheduler, the non-Lightning version uses steps_per_epoch=int(len(train_loader)), but the Lightning version uses steps_per_epoch=int(len(self.train_dataset)). (A corrected sketch follows below.)
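
To make point 2 concrete, here is a minimal sketch of a corrected configure_optimizers (one way to wire it up using the hparams keys from the code above, not the exact code from the thread):

    def configure_optimizers(self):
        optimizer = optim.AdamW(self.parameters(), self.hparams['learning_rate'])

        # steps_per_epoch must be the number of *batches* per epoch, i.e.
        # len(dataloader), not len(dataset). Using the dataset length stretches
        # the OneCycleLR schedule by a factor of batch_size, so the learning
        # rate never gets through its cycle and convergence looks much slower.
        steps_per_epoch = len(self.train_dataloader())

        scheduler = optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=self.hparams['learning_rate'],
            steps_per_epoch=steps_per_epoch,
            epochs=self.hparams['epochs'],
            anneal_strategy='linear')

        # 'interval': 'step' makes Lightning step the scheduler every batch,
        # matching the per-batch scheduler.step() in the non-Lightning loop.
        return [optimizer], [{'scheduler': scheduler, 'interval': 'step'}]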

rbracco (Contributor, Author) commented Jul 19, 2020

Thank you Rohit, you are a lifesaver. It was #2: updating the Lightning version by changing steps_per_epoch in the scheduler to the length of the train loader fixed the problem. I really appreciate you taking the time to help.

[image: train loss after the steps_per_epoch fix]

rbracco closed this as completed Jul 19, 2020
rohitgr7 (Contributor) commented

No problem. Also, there is pl.seed_everything in Lightning that you can use to reproduce your results:
https://github.com/PyTorchLightning/pytorch-lightning/blob/ed581eb64f105a39c1516527cd733a93f033edcb/pytorch_lightning/utilities/seed.py#L13-L18
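
A minimal usage sketch (seeding with 7 to mirror the torch.manual_seed(7) call in the non-Lightning main above):

import pytorch_lightning as pl

pl.seed_everything(7)  # seeds Python's random module, numpy and torch (CPU and CUDA)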

BTW, thanks for the blog link above. I was looking for something similar to get started with speech; it's quite good.
