Skip to content

Commit

Permalink
rename confusing variable name in LM tutorial (#1953)
Browse files Browse the repository at this point in the history
The tutorial "Language Modeling with nn.Transformer and TorchText" contains
code snippets with variables named `batch_size`. The issue is that in some
places, `batch_size` means the number of sequences in a batch, and in other
places it means the number of tokens in each batch sequence. This inconsistency
was solved in this commit: `batch_size` was replaced with `seq_len` in the two
places where it has the latter meaning.

Co-authored-by: Svetlana Karslioglu <svekars@fb.com>
  • Loading branch information
sliorde and Svetlana Karslioglu authored Aug 3, 2022
1 parent ae22720 commit d5f7a40
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions beginner_source/transformer_tutorial.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,9 +297,9 @@ def train(model: nn.Module) -> None:
num_batches = len(train_data) // bptt
for batch, i in enumerate(range(0, train_data.size(0) - 1, bptt)):
data, targets = get_batch(train_data, i)
batch_size = data.size(0)
if batch_size != bptt: # only on last batch
src_mask = src_mask[:batch_size, :batch_size]
seq_len = data.size(0)
if seq_len != bptt: # only on last batch
src_mask = src_mask[:seq_len, :seq_len]
output = model(data, src_mask)
loss = criterion(output.view(-1, ntokens), targets)

Expand Down Expand Up @@ -327,12 +327,12 @@ def evaluate(model: nn.Module, eval_data: Tensor) -> float:
with torch.no_grad():
for i in range(0, eval_data.size(0) - 1, bptt):
data, targets = get_batch(eval_data, i)
batch_size = data.size(0)
if batch_size != bptt:
src_mask = src_mask[:batch_size, :batch_size]
seq_len = data.size(0)
if seq_len != bptt:
src_mask = src_mask[:seq_len, :seq_len]
output = model(data, src_mask)
output_flat = output.view(-1, ntokens)
total_loss += batch_size * criterion(output_flat, targets).item()
total_loss += seq_len * criterion(output_flat, targets).item()
return total_loss / (len(eval_data) - 1)

######################################################################
Expand Down

0 comments on commit d5f7a40

Please sign in to comment.