This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Scaled Dot Product Attention matmul error #5345

Closed
gabeorlanski opened this issue Aug 8, 2021 · 2 comments · Fixed by #5360

@gabeorlanski (Contributor)

Checklist

  • I have verified that the issue exists against the main branch of AllenNLP.
  • I have read the relevant section in the contribution guide on reporting bugs.
  • I have checked the issues list for similar or identical bug reports.
  • I have checked the pull requests list for existing proposed fixes.
  • I have checked the CHANGELOG and the commit log to find out if the bug was already fixed in the main branch.
  • I have included in the "Description" section below a traceback from any exceptions related to this bug.
  • I have included in the "Related issues or possible duplicates" section below all related issues and possible duplicate issues (If there are none, check this box anyway).
  • I have included in the "Environment" section below the name of the operating system and Python version that I was using when I discovered this bug.
  • I have included in the "Environment" section below the output of pip freeze.
  • I have included in the "Steps to reproduce" section below a minimally reproducible example.

Description

When trying to use the ScaledDotProductAttention with the AutoRegressiveSeqDecoder from allennlp-models, a matmul error is raised stating that the dimensions do not align.

Python traceback:

Traceback (most recent call last):
  File "C:\ProgramData\Miniconda3\envs\nn-semparse\lib\site-packages\click\core.py", line 782, in main
    rv = self.invoke(ctx)
  File "C:\ProgramData\Miniconda3\envs\nn-semparse\lib\site-packages\click\core.py", line 1066, in invoke
    return ctx.invoke(self.callback, **ctx.params)
  File "C:\ProgramData\Miniconda3\envs\nn-semparse\lib\site-packages\click\core.py", line 610, in invoke
    return callback(*args, **kwargs)
  File "C:/Coding/merlin_labs/nn-semparse/debug_scripts/debug_train.py", line 51, in debug_train
    disable_tracking=disable_tracking
  File "C:\Coding\merlin_labs\nn-semparse\src\commands\train_extended.py", line 359, in train_model
    file_friendly_logging=file_friendly_logging,
  File "C:\Coding\merlin_labs\nn-semparse\src\commands\train_extended.py", line 586, in _train_worker
    metrics = train_loop.run()
  File "C:\Coding\merlin_labs\nn-semparse\src\commands\train_extended.py", line 658, in run
    return self.trainer.train()
  File "C:\ProgramData\Miniconda3\envs\nn-semparse\lib\site-packages\allennlp\training\gradient_descent_trainer.py", line 706, in train
    metrics, epoch = self._try_train()
  File "C:\ProgramData\Miniconda3\envs\nn-semparse\lib\site-packages\allennlp\training\gradient_descent_trainer.py", line 727, in _try_train
    train_metrics = self._train_epoch(epoch)
  File "C:\ProgramData\Miniconda3\envs\nn-semparse\lib\site-packages\allennlp\training\gradient_descent_trainer.py", line 458, in _train_epoch
    batch_outputs = self.batch_outputs(batch, for_training=True)
  File "C:\ProgramData\Miniconda3\envs\nn-semparse\lib\site-packages\allennlp\training\gradient_descent_trainer.py", line 351, in batch_outputs
    output_dict = self._pytorch_model(**batch)
  File "C:\ProgramData\Miniconda3\envs\nn-semparse\lib\site-packages\torch\nn\modules\module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "C:\Coding\merlin_labs\nn-semparse\src\models\simple_seq2seq.py", line 126, in forward
    outputs = self._decoder(state, target_tokens)
  File "C:\ProgramData\Miniconda3\envs\nn-semparse\lib\site-packages\torch\nn\modules\module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "C:\Coding\merlin_labs\nn-semparse\src\models\modules\extendable_auto_regressive.py", line 451, in forward
    output_dict = self._forward_loss(state_forward_loss, target_tokens)
  File "C:\Coding\merlin_labs\nn-semparse\src\models\modules\extendable_auto_regressive.py", line 245, in _forward_loss
    effective_last_prediction, state)
  File "C:\Coding\merlin_labs\nn-semparse\src\models\modules\extendable_auto_regressive.py", line 321, in _prepare_output_projections
    previous_steps_predictions=previous_steps_predictions,
  File "C:\ProgramData\Miniconda3\envs\nn-semparse\lib\site-packages\torch\nn\modules\module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "C:\ProgramData\Miniconda3\envs\nn-semparse\lib\site-packages\allennlp_models\generation\modules\decoder_nets\lstm_cell.py", line 118, in forward
    decoder_hidden, encoder_outputs, source_mask
  File "C:\ProgramData\Miniconda3\envs\nn-semparse\lib\site-packages\allennlp_models\generation\modules\decoder_nets\lstm_cell.py", line 71, in _prepare_attended_input
    input_weights = self._attention(decoder_hidden_state, encoder_outputs, encoder_outputs_mask)
  File "C:\ProgramData\Miniconda3\envs\nn-semparse\lib\site-packages\torch\nn\modules\module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "C:\ProgramData\Miniconda3\envs\nn-semparse\lib\site-packages\allennlp\modules\attention\attention.py", line 45, in forward
    similarities = self._forward_internal(vector, matrix)
  File "C:\ProgramData\Miniconda3\envs\nn-semparse\lib\site-packages\allennlp\modules\attention\scaled_dot_product_attention.py", line 31, in _forward_internal
    scores = torch.matmul(vector, matrix)
RuntimeError: mat1 dim 1 must match mat2 dim 0

Upon looking into the code for ScaledDotProductAttention, I noticed that it is missing the transpose from Equation 1 of Attention Is All You Need. This appears to have been fixed by this PR, but it was closed without merging. That PR addresses the issue by computing matrix.bmm(vector.unsqueeze(-1)).squeeze(-1) instead of the torch.matmul(vector, matrix) that is currently present; notably, the regular DotProductAttention already uses the matrix.bmm(vector.unsqueeze(-1)).squeeze(-1) form.
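For illustration, here is a minimal sketch of the shape mismatch. The shapes are assumptions based on what Attention.forward receives from LstmCellDecoderNet: a (batch, decoding_dim) query vector and a (batch, seq_len, encoder_dim) matrix of encoder outputs.

import torch

batch_size, seq_len, dim = 10, 7, 50

# Assumed shapes: decoder hidden state and encoder outputs.
vector = torch.randn(batch_size, dim)
matrix = torch.randn(batch_size, seq_len, dim)

# Current ScaledDotProductAttention: contracts the embedding dim of `vector`
# against the sequence dim of `matrix`, which raises the RuntimeError above.
try:
    scores = torch.matmul(vector, matrix)
except RuntimeError as err:
    print(err)

# The DotProductAttention form: a batched matmul that contracts over the
# embedding dim, giving one score per source position, shape (batch, seq_len).
scores = matrix.bmm(vector.unsqueeze(-1)).squeeze(-1)
assert scores.shape == (batch_size, seq_len)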

I can open a PR with the fix from the original pull request, but I am not sure that would be easier than simply reopening and merging the original PR.

Related issues or possible duplicates

Environment

OS: Windows

Python version: 3.7.10

Output of pip freeze:

allennlp==2.5.0
allennlp-models==2.5.0
argon2-cffi==20.1.0
async-generator==1.10
atomicwrites==1.4.0
attrs==21.2.0
backcall==0.2.0
backports.csv==1.0.7
beautifulsoup4==4.9.3
bleach==3.3.0
blis==0.7.4
boto3==1.17.101
botocore==1.20.101
cached-property==1.5.2
cachetools==4.2.2
catalogue==2.0.4
certifi==2021.5.30
cffi==1.14.5
chardet==4.0.0
checklist==0.0.11
cheroot==8.5.2
CherryPy==18.6.0
click==7.1.2
colorama==0.4.4
configparser==5.0.2
conllu==4.4
coverage @ file:///C:/ci/coverage_1614614910274/work
cryptography==3.4.7
cymem==2.0.5
decorator==5.0.9
defusedxml==0.7.1
dill==0.3.4
docker-pycreds==0.4.0
en-core-web-sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.0.0/en_core_web_sm-3.0.0-py3-none-any.whl
entrypoints==0.3
feedparser==6.0.8
filelock==3.0.12
ftfy==6.0.3
future==0.18.2
gitdb==4.0.7
GitPython==3.1.18
google-api-core==1.30.0
google-auth==1.32.0
google-cloud-core==1.7.1
google-cloud-storage==1.38.0
google-crc32c==1.1.2
google-resumable-media==1.3.1
googleapis-common-protos==1.53.0
h5py==3.3.0
huggingface-hub==0.0.8
idna==2.10
importlib-metadata==4.6.0
iniconfig==1.1.1
ipykernel==5.5.5
ipython==7.25.0
ipython-genutils==0.2.0
ipywidgets==7.6.3
iso-639==0.4.5
jaraco.classes==3.2.1
jaraco.collections==3.3.0
jaraco.functools==3.3.0
jaraco.text==3.5.0
jedi==0.18.0
Jinja2==3.0.1
jmespath==0.10.0
joblib==1.0.1
jsonschema==3.2.0
jupyter==1.0.0
jupyter-client==6.1.12
jupyter-console==6.4.0
jupyter-core==4.7.1
jupyterlab-pygments==0.1.2
jupyterlab-widgets==1.0.0
lmdb==1.2.1
lxml==4.6.3
MarkupSafe==2.0.1
matplotlib-inline==0.1.2
mistune==0.8.4
more-itertools==8.8.0
munch==2.5.0
murmurhash==1.0.5
nbclient==0.5.3
nbconvert==6.1.0
nbformat==5.1.3
nest-asyncio==1.5.1
nltk==3.6.2
notebook==6.4.0
numpy==1.21.0
overrides==3.1.0
packaging==20.9
pandocfilters==1.4.3
parso==0.8.2
pathtools==0.1.2
pathy==0.6.0
patternfork-nosql==3.6
pdfminer.six==20201018
pickleshare==0.7.5
Pillow==8.2.0
pluggy==0.13.1
portalocker==2.0.0
portend==2.7.1
preshed==3.0.5
prometheus-client==0.11.0
promise==2.3
prompt-toolkit==3.0.19
protobuf==3.17.3
psutil==5.8.0
py==1.10.0
py-rouge==1.1
pyasn1==0.4.8
pyasn1-modules==0.2.8
pycparser==2.20
pydantic==1.7.4
Pygments==2.9.0
pyparsing==2.4.7
pyrsistent==0.17.3
pytest==6.2.4
python-dateutil==2.8.1
python-docx==0.8.11
pytz==2021.1
pywin32==301
pywinpty==1.1.3
PyYAML==5.4.1
pyzmq==22.1.0
qtconsole==5.1.0
QtPy==1.9.0
regex==2021.4.4
requests==2.25.1
rsa==4.7.2
s3transfer==0.4.2
sacrebleu==1.5.1
sacremoses==0.0.45
scikit-learn==0.24.2
scipy==1.7.0
Send2Trash==1.7.1
sentencepiece==0.1.96
sentry-sdk==1.1.0
sgmllib3k==1.0.0
shortuuid==1.0.1
six==1.16.0
sklearn==0.0
smart-open==5.1.0
smmap==4.0.0
sortedcontainers==2.4.0
soupsieve==2.2.1
spacy==3.0.6
spacy-legacy==3.0.6
srsly==2.4.1
subprocess32==3.5.4
tempora==4.1.1
tensorboardX==2.3
termcolor==1.1.0
terminado==0.10.1
testpath==0.5.0
thinc==8.0.6
threadpoolctl==2.1.0
tokenizers==0.10.3
toml==0.10.2
torch==1.8.1+cu102
torchaudio==0.8.1
torchvision==0.9.1+cu102
tornado==6.1
tqdm==4.61.1
traitlets==5.0.5
transformers==4.6.1
typer==0.3.2
typing-extensions==3.10.0.0
urllib3==1.25.11
wandb==0.10.33
wasabi==0.8.2
wcwidth==0.2.5
webencodings==0.5.1
widgetsnbextension==3.5.1
wincertstore==0.2
word2number==1.1
yapf==0.31.0
zc.lockfile==2.0
zipp==3.4.1

Steps to reproduce

To reproduce, use the same setup as the semantic parsing section in Part 3 of the guide, but make a small modification to the training config so that it uses scaled dot product attention.

Example source:

{
  "dataset_reader": {
    "type": "seq2seq",
    "source_tokenizer": {
      "type": "whitespace"
    },
    "target_tokenizer": {
      "type": "whitespace"
    },
    "source_token_indexers": {
      "tokens": {
        "type": "single_id",
        "namespace": "source_tokens"
      }
    },
    "target_token_indexers": {
      "tokens": {
        "namespace": "target_tokens"
      }
    }
  },
  "train_data_path": "data/nla_with_meaning_rep_train.tsv",
  "validation_data_path": "data/nla_with_meaning_rep_dev.tsv",
  "model": {
    "type": "composed_seq2seq",
    "source_text_embedder": {
      "token_embedders": {
        "tokens": {
          "type": "embedding",
          "vocab_namespace": "source_tokens",
          "embedding_dim": 100,
          "trainable": true
        }
      }
    },
    "encoder": {
      "type": "lstm",
      "input_size": 100,
      "hidden_size": 50,
      "num_layers": 1
    },
    "decoder": {
      "decoder_net": {
         "type": "lstm_cell",
         "decoding_dim": 50,
         "target_embedding_dim": 50,
         "attention": {
                    "type": "scaled_dot_product",
                    "scaling_factor": 3
         },
      },
      "max_decoding_steps": 50,
      "target_namespace": "target_tokens",
      "target_embedder": {
        "vocab_namespace": "target_tokens",
        "embedding_dim": 50
      },
      "scheduled_sampling_ratio": 0.5,
      "beam_size": 10,
      "token_based_metric": "nla_metric"
    }
  },
  "data_loader": {
    "batch_sampler": {
        "type": "bucket",
        "batch_size": 10,
        "padding_noise": 0.0
    }
},
  "trainer": {
    "num_epochs": 20,
    "patience": 10,
    "validation_metric": "+sequence_accuracy",
    "cuda_device": -1,
    "optimizer": {
      "type": "adam",
      "lr": 0.01
    }
  }
}
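
The full training run is not required to hit the error; a minimal sketch that calls the attention module directly reproduces the same RuntimeError (the shapes and the scaling_factor argument are assumptions matching the config above and the lstm_cell call site):

import torch
from allennlp.modules.attention.scaled_dot_product_attention import ScaledDotProductAttention

attention = ScaledDotProductAttention(scaling_factor=3)
decoder_hidden = torch.randn(10, 50)      # (batch, decoding_dim), as in the config
encoder_outputs = torch.randn(10, 7, 50)  # (batch, source_seq_len, encoder_dim)
source_mask = torch.ones(10, 7, dtype=torch.bool)

# Raises RuntimeError: mat1 dim 1 must match mat2 dim 0
attention(decoder_hidden, encoder_outputs, source_mask)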

@dirkgr (Member) commented Aug 16, 2021

@JohnGiorgi's PR is different in some other aspects as well, so we can't just merge it. But he got the dimensions right, so I'll make a PR that takes that part of the code.

@dirkgr (Member) commented Aug 17, 2021

There was a rat's tail of dependencies on this issue, since this class was used in the transformer toolkit. The transformer toolkit should have been using MatrixAttention, so I implemented scaled dot product matrix attention, converted the toolkit to use it, and then fixed this implementation to match the other Attention classes.
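
For context, the two abstractions have different shape contracts. A rough sketch using the pre-existing dot-product variants (the shapes below are my reading of the Attention and MatrixAttention APIs, not code from the fix PR):

import torch
from allennlp.modules.attention.dot_product_attention import DotProductAttention
from allennlp.modules.matrix_attention.dot_product_matrix_attention import DotProductMatrixAttention

# Attention: one query vector per batch item against a matrix -> (batch, seq_len) weights.
vector_attention = DotProductAttention()
weights = vector_attention(torch.randn(4, 16), torch.randn(4, 9, 16))
assert weights.shape == (4, 9)

# MatrixAttention: every row of one matrix against every row of another
# -> (batch, rows_1, rows_2) scores, which is what the transformer toolkit needs.
matrix_attention = DotProductMatrixAttention()
scores = matrix_attention(torch.randn(4, 9, 16), torch.randn(4, 11, 16))
assert scores.shape == (4, 9, 11)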

dirkgr self-assigned this Aug 17, 2021