arr = [[self.vocab.stoi[x] for x in ex] for ex in arr] KeyError: None #618

Open
TinaChen95 opened this issue Oct 12, 2019 · 13 comments

🐛 Bug

Describe the bug
I came across this error when using data.Field. It only happens when I define my own unk_token and set min_freq > 1 at the same time.

To Reproduce
The code I use:

from torchtext import data, datasets

SRC = data.Field(lower=True, unk_token="my_unk_token")
TGT = data.Field(lower=True)

train, val, test = datasets.IWSLT.splits(exts=('.de', '.en'), fields=(SRC, TGT))

SRC.build_vocab(train, min_freq=10)

train_iter = data.BucketIterator(dataset=train, batch_size=64,
                                 sort_key=lambda x: data.interleave_keys(len(x.src), len(x.trg)))

batch = next(iter(train_iter))

mttk (Contributor) commented Oct 24, 2019

This happens due to redefining the unk token which is hardcoded at one point. It's a pretty big bug, thanks for spotting it. I will issue a fix soon.
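Until the fix lands, a possible stop-gap (just a sketch, assuming the reproduction above and that numericalize only does plain stoi lookups) is to wrap the built vocab's stoi in a defaultdict so that anything missing falls back to the custom unk index:

from collections import defaultdict

# Stop-gap sketch, not the actual fix: after build_vocab, make lookups of
# tokens missing from stoi fall back to the index of the custom unk token.
unk_index = SRC.vocab.stoi[SRC.unk_token]
SRC.vocab.stoi = defaultdict(lambda: unk_index, SRC.vocab.stoi)

This only papers over the missing fallback, so drop it once the proper fix is released.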

@icmpnorequest

When running on my own custom dataset using torchtext, it runs successfully on train_dataset, but during evaluation it shows this error:

Traceback (most recent call last):
  File "/Users/baseline_model_torchtext.py", line 201, in <module>
    for i, batch in enumerate(valid_iter):
  File "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/torchtext/data/iterator.py", line 156, in __iter__
    yield Batch(minibatch, self.dataset, self.device)
  File "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/torchtext/data/batch.py", line 34, in __init__
    setattr(self, name, field.process(batch, device=device))
  File "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/torchtext/data/field.py", line 237, in process
    tensor = self.numericalize(padded, device=device)
  File "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/torchtext/data/field.py", line 336, in numericalize
    arr = [[self.vocab.stoi[x] for x in ex] for ex in arr]
  File "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/torchtext/data/field.py", line 336, in <listcomp>
    arr = [[self.vocab.stoi[x] for x in ex] for ex in arr]
  File "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/torchtext/data/field.py", line 336, in <listcomp>
    arr = [[self.vocab.stoi[x] for x in ex] for ex in arr]
KeyError: "'meal',"

Could somebody help me fix this issue?
Thanks in advance!

mttk (Contributor) commented Nov 18, 2019

Your vocab is likely initialized without an UNK token. Can you paste the initialization of the Fields for the dataset?

@icmpnorequest

@mttk
Here is my initialization of the Fields:

# 1. data.Field()
TEXT = data.Field(include_lengths=True, pad_token='<pad>', unk_token='<unk>')
TAG_LABEL = data.LabelField()
AGE_LABEL = data.LabelField()
GENDER_LABEL = data.LabelField()

# 2. data.TabularDataset
train_data, test_data = data.TabularDataset.splits(path=TrustPilot_processed_dataset_path,
                                                   train="train_data.csv",
                                                   test="test_data.csv",
                                                   fields=[('text', TEXT), ('tag_label', TAG_LABEL),
                                                           ('age_label', AGE_LABEL), ('gender_label', GENDER_LABEL)],
                                                   format="csv")

# 3. Split train_data to train_data, valid_data
train_data, valid_data = train_data.split(random_state=random.seed(SEED))
print("Number of train_data = {}".format(len(train_data)))
print("Number of valid_data = {}".format(len(valid_data)))
print("Number of test_data = {}\n".format(len(test_data)))

# 4. data.BucketIterator
train_iter, valid_iter, test_iter = data.BucketIterator.splits((train_data, valid_data, test_data),
                                                               batch_size=BATCH_SIZE,
                                                               device=device,
                                                               sort_key=lambda x: len(x.text))

# 5. Build vocab
TEXT.build_vocab(train_data)
TAG_LABEL.build_vocab(train_data)
AGE_LABEL.build_vocab(train_data)
GENDER_LABEL.build_vocab(train_data)

It runs successfully on the training dataset; however, on the validation / test dataset it shows this error:

Traceback (most recent call last):
  File "baseline_model_torchtext.py", line 202, in <module>
  File "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/torchtext/data/iterator.py", line 156, in __iter__
    yield Batch(minibatch, self.dataset, self.device)
  File "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/torchtext/data/batch.py", line 34, in __init__
    setattr(self, name, field.process(batch, device=device))
  File "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/torchtext/data/field.py", line 237, in process
    tensor = self.numericalize(padded, device=device)
  File "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/torchtext/data/field.py", line 338, in numericalize
    arr = [self.vocab.stoi[x] for x in arr]
  File "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/torchtext/data/field.py", line 338, in <listcomp>
    arr = [self.vocab.stoi[x] for x in arr]
KeyError: "('DET', 'NOUN', 'ADP', 'DET', 'NOUN', 'NOUN', 'ADP', 'DET', 'NOUN', 'NOUN', 'VERB', 'ADJ', 'NOUN')"

The testing code is below:

########## Evaluation ##########
model.eval()
total_correct = 0
avg_loss = 0.0
for i, batch in enumerate(valid_iter):
    text, text_lengths = batch.text

    y = batch.tag_label

    # Forward pass
    y_pred = model(text, text_lengths)
    loss = criterion(y_pred, y)
    avg_loss += loss.item()

    # _, pred = torch.max(output.data, 1)
    pred = torch.argmax(y_pred.data, dim=1)
    total_correct += (pred == y).sum().item()

avg_loss = avg_loss / len(valid_data)
print("Test Avg. Loss: {}, Accuracy: {}%"
      .format(avg_loss, 100 * total_correct / len(valid_data)))

mttk (Contributor) commented Nov 18, 2019

From what I see, you have token-wise labels (the POS tags). Since the LabelField assumes there is no tokenization (a single label), it treats this "('DET', 'NOUN', 'ADP', 'DET', 'NOUN', 'NOUN', 'ADP', 'DET', 'NOUN', 'NOUN', 'VERB', 'ADJ', 'NOUN')" as a single string.
Since that exact sequence of POS tags wasn't seen in the training data, and LabelFields don't use unk tokens, this error occurs.

To fix this, for every output Field that contains sequential data (in this case, I assume that is the TAG_LABEL), define it instead as:

def my_tokenize_function(string):
    # Tokenize a line of POS label data; going by the string in the traceback this means:
    # 1. strip the surrounding brackets, 2. split on commas, 3. strip quotes and whitespace.
    return [tok.strip(" '\"") for tok in string.strip("()[]").split(",") if tok.strip(" '\"")]

POS_LABEL = data.Field(unk_token=None, tokenize=my_tokenize_function, is_target=True)

@icmpnorequest

@mttk
Thank you so much for the detailed guidance on LabelField.
I tried building my_tokenize_function, but it still had bugs. I finally solved it by reformatting the POS-tagging dataset into a proper CSV format:

word, tag, attribute1, attribute2
...

And using LabelField, it works!
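For later readers, a hypothetical sketch of such a per-token setup (the field and file names here are made up, not taken from the actual project):

# Hypothetical sketch: one word per CSV row, so each label is a single string
# and LabelField needs no tokenization.
WORD = data.Field(lower=True)
TAG = data.LabelField()
fields = [('word', WORD), ('tag', TAG), ('attribute1', None), ('attribute2', None)]
train_data = data.TabularDataset(path='train_data.csv', format='csv',
                                 skip_header=True, fields=fields)
WORD.build_vocab(train_data)
TAG.build_vocab(train_data)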

mttk (Contributor) commented Nov 20, 2019

Great, I'm glad that I inadvertently helped!

@icmpnorequest

Hi @mttk,
I've met the same issue again.

Error:

 Traceback (most recent call last):
  File "multiview_sep_v1.py", line 1177, in <module>
    = train_valid(genderModel=genderModel, ageModel=None, task=args.task)
  File "multiview_sep_v1.py", line 738, in train_valid
    for i, batch in enumerate(valid_iter):
  File "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/torchtext/data/iterator.py", line 156, in __iter__
    yield Batch(minibatch, self.dataset, self.device)
  File "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/torchtext/data/batch.py", line 34, in __init__
    setattr(self, name, field.process(batch, device=device))
  File "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/torchtext/data/field.py", line 237, in process
    tensor = self.numericalize(padded, device=device)
  File "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/torchtext/data/field.py", line 336, in numericalize
    arr = [[self.vocab.stoi[x] for x in ex] for ex in arr]
  File "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/torchtext/data/field.py", line 336, in <listcomp>
    arr = [[self.vocab.stoi[x] for x in ex] for ex in arr]
  File "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/torchtext/data/field.py", line 336, in <listcomp>
    arr = [[self.vocab.stoi[x] for x in ex] for ex in arr]
KeyError: 4521

I used the BertTokenizer from Hugging Face's transformers. Here is my Field code:


tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
max_input_length = 300
init_token_idx = tokenizer.cls_token_id
eos_token_idx = tokenizer.sep_token_id
pad_token_idx = tokenizer.pad_token_id
unk_token_idx = tokenizer.unk_token_id


def tokenize_and_cut(sentence):
    tokens = tokenizer.tokenize(sentence)
    tokens = tokens[:max_input_length - 2]
    return tokens


# 1. data.Field()
TEXT = data.Field(batch_first=True,
                  tokenize=tokenize_and_cut,
                  preprocessing=tokenizer.convert_tokens_to_ids,
                  init_token=init_token_idx,
                  eos_token=eos_token_idx,
                  pad_token=pad_token_idx,
                  unk_token=unk_token_idx)

SENTI_TOKENS = data.Field(batch_first=True,
                          tokenize=tokenize_and_cut,
                          # preprocessing=tokenizer.convert_tokens_to_ids,
                          init_token=init_token_idx,
                          eos_token=eos_token_idx,
                          pad_token=pad_token_idx,
                          unk_token=unk_token_idx,
                          fix_length=FIX_LENGTH
                          )
TOPIC_TOKENS = data.Field(batch_first=True,
                          tokenize=tokenize_and_cut,
                          # preprocessing=tokenizer.convert_tokens_to_ids,
                          init_token=init_token_idx,
                          eos_token=eos_token_idx,
                          pad_token=pad_token_idx,
                          unk_token=unk_token_idx,
                          fix_length=FIX_LENGTH
                          )

Any ideas to fix this issue?

mttk (Contributor) commented Feb 21, 2020

Can you try setting use_vocab = False in the Fields where you use the HF tokenizer?

Right now, you use the HF tokenizer to convert tokens to IDs, and torchtext Fields by default construct a vocabulary (and expect strings as keys). You don't need a vocab because you're already using the pretrained one from HF, so you can just disable it in torchtext.
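A minimal sketch of that change, reusing the Field names from the snippet above (only the use_vocab=False line is new):

TEXT = data.Field(batch_first=True,
                  use_vocab=False,  # ids already come from the HF tokenizer, so skip torchtext's vocab
                  tokenize=tokenize_and_cut,
                  preprocessing=tokenizer.convert_tokens_to_ids,
                  init_token=init_token_idx,
                  eos_token=eos_token_idx,
                  pad_token=pad_token_idx,
                  unk_token=unk_token_idx)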

@icmpnorequest

@mttk
Thank you so much, I fixed it successfully by adding use_vocab=False when using Hugging Face's transformers models.

vection commented Jul 13, 2020

Hi @mttk
I see this issue is still open and I still have the same problem.
It happens when a word from the validation set is not in the training vocab.
It seems like unknown words are not being replaced by the unk token.

Any ideas?

KeyError:

262 with torch.no_grad():
--> 263 for batch_idx, batch_data in enumerate(data_loader):
264 text, text_lengths = batch_data.content
265 logits = self.model(text, text_lengths)

~/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torchtext/data/iterator.py in __iter__(self)
154 else:
155 minibatch.sort(key=self.sort_key, reverse=True)
--> 156 yield Batch(minibatch, self.dataset, self.device)
157 if not self.repeat:
158 return

~/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torchtext/data/batch.py in __init__(self, data, dataset, device)
32 if field is not None:
33 batch = [getattr(x, name) for x in data]
---> 34 setattr(self, name, field.process(batch, device=device))
35
36 @classmethod

~/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torchtext/data/field.py in process(self, batch, device)
235 """
236 padded = self.pad(batch)
--> 237 tensor = self.numericalize(padded, device=device)
238 return tensor
239

~/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torchtext/data/field.py in numericalize(self, arr, device)
336 arr = [[self.vocab.stoi[x] for x in ex] for ex in arr]
337 else:
--> 338 arr = [self.vocab.stoi[x] for x in arr]
339
340 if self.postprocessing is not None:

~/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torchtext/data/field.py in <listcomp>(.0)
336 arr = [[self.vocab.stoi[x] for x in ex] for ex in arr]
337 else:
--> 338 arr = [self.vocab.stoi[x] for x in arr]
339
340 if self.postprocessing is not None:

KeyError: 'poly fill'

Preprocessing:

    TEXT = data.Field(sequential=True,
                      tokenize='moses',
                      include_lengths=True)  # necessary for packed_padded_sequence

    LABEL = data.LabelField(dtype=torch.float)

    fields = [('content', TEXT), ('classlabel', LABEL)]

    # Dataset
    train_dataset = data.TabularDataset(
        path=self.train_csv, format='csv',
        skip_header=True, fields=fields)

    valid_dataset = data.TabularDataset(
        path=self.valid_csv, format='csv',
        skip_header=True, fields=fields)

    # Building vocab
    TEXT.build_vocab(train_dataset,
                     vectors='glove.840B.300d',
                     max_size=self.vocab_size,
                     unk_init=torch.Tensor.normal_)

    LABEL.build_vocab(train_dataset)

    # Data loaders
    train_loader, valid_loader = data.BucketIterator.splits(
        (train_dataset, valid_dataset),
        batch_size=self.bs,
        sort_within_batch=True,
        sort_key=lambda x: len(x.content),
        device=self.device)

binhna commented Sep 18, 2020

I think there is something odd in the numericalize function of the Field object. Instead of calling [[self.vocab[x] for x in ex] for ex in arr], you replaced self.vocab[x] with self.vocab.stoi[x]. With self.vocab[x], it would return the UNK index if the token doesn't appear in the vocab, but with self.vocab.stoi[x] it raises a KeyError.

@zhangguanheng66 (Contributor)

Will be retired soon. #985
