🐛 Bug
When using XLM-R, the representations change depending on the batch size.
Code sample
```python
from fairseq.models.roberta import XLMRModel
from torchnlp.encoders.text import stack_and_pad_tensors
import torch

torch.set_printoptions(precision=10)


def batch_encoder(samples, tokenizer):
    batch = []
    for sequence in samples:
        batch.append(tokenizer.encode(sequence))
    # pad with the index of the "<pad>" symbol from the model's source dictionary
    return stack_and_pad_tensors(batch, tokenizer.task.source_dictionary.__dict__["indices"]["<pad>"])


xlmr = XLMRModel.from_pretrained(
    "pretrained/xlmr.base", checkpoint_file="model.pt"
)
xlmr.eval()

samples = [
    'the part of the regular expression within the forward slashes defines the pattern.',
    'discards the current state and temporarily replaces it with the previous state.',
    'to convert a smooth point to a corner point without direction lines, click the smooth point.'
]

with torch.no_grad():
    big_batch_tokens, bb_lengths = batch_encoder(samples, xlmr)
    small_batch_tokens, sb_lengths = batch_encoder(samples[:2], xlmr)
    first_sample_tokens = xlmr.encode(samples[0])

    # last-layer features for the first sample alone, in a batch of 2, and in a batch of 3
    first_sample_last_layer = xlmr.extract_features(first_sample_tokens)
    print(first_sample_last_layer[:, 0, :][0][:5])

    small_batch_last_layer = xlmr.extract_features(tokens=small_batch_tokens)
    print(small_batch_last_layer[:, 0, :][0][:5])

    big_batch_last_layer = xlmr.extract_features(tokens=big_batch_tokens)
    print(big_batch_last_layer[:, 0, :][0][:5])
```
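As an aside, the padding step can also be done with fairseq's own collate_tokens helper instead of torchnlp's stack_and_pad_tensors. This is only a sketch of an equivalent batching path, assuming the xlmr and samples objects defined above; it should produce the same padded token matrices, although collate_tokens does not return the lengths:

```python
# Sketch (not part of the original report): batching via fairseq's collate helper.
# Assumes `xlmr` and `samples` are defined as in the script above.
from fairseq.data.data_utils import collate_tokens

pad_idx = xlmr.task.source_dictionary.pad()  # index of the <pad> symbol

# Encode each sentence to a 1-D LongTensor and right-pad to a common length.
big_batch_tokens = collate_tokens([xlmr.encode(s) for s in samples], pad_idx)
small_batch_tokens = collate_tokens([xlmr.encode(s) for s in samples[:2]], pad_idx)
```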
Looking at the printed values, this seems to be a floating-point math issue.
I get a similar range of differences when I run it on CPU, but on GPU the values appear identical up to the 10th digit.
There is some discussion in a PyTorch thread: pytorch/pytorch#4914 (although that one concerns floating-point issues on CUDA rather than on the CPU).
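To put a rough number on that, here is a minimal sketch of how the discrepancy could be measured, assuming the tensors from the code sample above are still in scope; the slicing and the 1e-5 tolerance are my own choices, not something from the original report:

```python
# Compare the first sample's features when encoded alone vs. inside the batch of 3.
# first_sample_last_layer has shape (1, len_1, dim); big_batch_last_layer has shape
# (3, max_len, dim), and the first len_1 positions of row 0 are the same tokens.
alone = first_sample_last_layer[0]                       # (len_1, dim)
in_batch = big_batch_last_layer[0, : alone.shape[0], :]  # (len_1, dim)

max_abs_diff = (alone - in_batch).abs().max().item()
print("max abs difference:", max_abs_diff)

# With float32 math, small mismatches are typically accumulated rounding error from
# different reduction/kernel paths, so a tolerant comparison is more meaningful
# than exact equality.
print("allclose (atol=1e-5):", torch.allclose(alone, in_batch, atol=1e-5))
```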
Expected behavior
The representations for a given sentence should be the same regardless of the batch size it is encoded in (up to negligible numerical noise).
Additional context
If I decide to average pool over all the embeddings, or if I max pool, these differences are even bigger.
Am I doing something wrong? Is this behaviour expected?
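For what it's worth, this is roughly how I would compare the pooled vectors across batch sizes. It is only a sketch that reuses the feature tensors and lengths returned by the code sample above and pools over the non-pad positions; the mean_max_pool helper name is my own:

```python
# Sketch: mean/max pooling over the real (non-pad) positions of each sequence,
# then a comparison of the pooled vectors for sample 0 across the two batch sizes.
def mean_max_pool(features, lengths):
    pooled_mean, pooled_max = [], []
    for row, length in zip(features, lengths):
        valid = row[: int(length)]            # drop pad positions before pooling
        pooled_mean.append(valid.mean(dim=0))
        pooled_max.append(valid.max(dim=0).values)
    return torch.stack(pooled_mean), torch.stack(pooled_max)

small_mean, small_max = mean_max_pool(small_batch_last_layer, sb_lengths)
big_mean, big_max = mean_max_pool(big_batch_last_layer, bb_lengths)

# Sample 0 appears in both batches, so its pooled vectors can be compared directly.
print("mean-pool max abs diff:", (small_mean[0] - big_mean[0]).abs().max().item())
print("max-pool  max abs diff:", (small_max[0] - big_max[0]).abs().max().item())
```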