def train_batch(model, loss_fn, optimizer, batch):
    """Run one optimization step on a single batch and
    return that batch's loss as a Python float.
    """
    # Clear gradients accumulated from the previous step.
    optimizer.zero_grad()

    # Move the inputs to the model's device instead of hard-coding
    # .cuda(), so the code also runs on CPU-only machines.
    device = next(model.parameters()).device
    x = batch[0].to(device)

    # Reconstruction-style objective: the input itself is the target.
    pred = model(x)
    loss_batch = loss_fn(pred, x)

    # Backpropagate and update the parameters.
    loss_batch.backward()
    optimizer.step()

    # .item() detaches the scalar from the graph so callers can
    # accumulate it without retaining the computation graph.
    return loss_batch.item()


def train(model, loss_fn, optimizer, loader, logger):
    """Train the model over all batches in a given dataset,
    and return the mean loss per batch.
    """
    # Put the model in training mode (enables dropout, batch-norm updates, etc.).
    model.train()

    loss_acc = 0.0
    batches = 0
    for idx, batch in enumerate(loader):
        loss = train_batch(model, loss_fn, optimizer, batch)

        logger.batch(idx, loss)

        loss_acc += loss
        batches += 1

    # Average over the number of batches actually seen; this also works
    # for iterable datasets where len(loader) is unavailable.
    return loss_acc / batches
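

# A minimal usage sketch, assuming an autoencoder-style setup. The model,
# data, and logger below are hypothetical stand-ins, not part of this
# module: any object with a `batch(idx, loss)` method satisfies the logger
# interface that `train` expects.
if __name__ == "__main__":
    import torch
    from torch import nn
    from torch.utils.data import DataLoader, TensorDataset

    class PrintLogger:
        """Hypothetical logger matching the interface used above."""

        def batch(self, idx, loss):
            print(f"batch {idx}: loss={loss:.4f}")

    # Tiny autoencoder over 32-dimensional random vectors.
    model = nn.Sequential(nn.Linear(32, 8), nn.ReLU(), nn.Linear(8, 32))
    loader = DataLoader(TensorDataset(torch.randn(256, 32)), batch_size=16)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    mean_loss = train(model, nn.MSELoss(), optimizer, loader, PrintLogger())
    print(f"mean loss per batch: {mean_loss:.4f}")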