Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

fixing minor issues rebased #4593

Merged
merged 4 commits into from
Jun 17, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -208,26 +208,26 @@ commands:
- setupcuda
- fixgit
- restore_cache:
key: deps-20220519-<< parameters.cachename >>-{{ checksum "requirements.txt" }}
key: deps-20220615-<< parameters.cachename >>-{{ checksum "requirements.txt" }}
- setup
- installdeps
- << parameters.more_installs >>
- save_cache:
key: deps-20220519-<< parameters.cachename >>-{{ checksum "requirements.txt" }}
key: deps-20220615-<< parameters.cachename >>-{{ checksum "requirements.txt" }}
paths:
- "~/venv/bin"
- "~/venv/lib"
- findtests:
marker: << parameters.marker >>
- restore_cache:
key: data-20220519-<< parameters.cachename >>-{{ checksum "teststorun.txt" }}
key: data-20220615-<< parameters.cachename >>-{{ checksum "teststorun.txt" }}
- run:
name: Run tests
no_output_timeout: 60m
command: |
coverage run -m pytest -m << parameters.marker >> << parameters.pytest_flags >> --junitxml=test-results/junit.xml
- save_cache:
key: data-20220519-<< parameters.cachename >>-{{ checksum "teststorun.txt" }}
key: data-20220615-<< parameters.cachename >>-{{ checksum "teststorun.txt" }}
paths:
- "~/ParlAI/data"
- codecov
Expand All @@ -244,12 +244,12 @@ commands:
- checkout
- fixgit
- restore_cache:
key: deps-20220519-bw-{{ checksum "requirements.txt" }}
key: deps-20220615-bw-{{ checksum "requirements.txt" }}
- setup
- installdeps
- installtorchgpu
- save_cache:
key: deps-20220519-bw-{{ checksum "requirements.txt" }}
key: deps-20220615-bw-{{ checksum "requirements.txt" }}
paths:
- "~/venv/bin"
- "~/venv/lib"
Expand Down
3 changes: 3 additions & 0 deletions parlai/agents/bert_classifier/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,6 @@ Train a classifier on the SNLI tas.
```bash
parlai train_model -m bert_classifier -t snli --classes 'entailment' 'contradiction' 'neutral' -mf /tmp/BERT_snli -bs 20
```

In the example above, tokenized input sentence will look as following:
`[CLS] premise : motor ##cy ##cl ##ists racing on a track . hypothesis : people are racing . [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]`
4 changes: 2 additions & 2 deletions parlai/agents/bert_ranker/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,17 +81,17 @@ def add_common_args(parser):

class BertWrapper(torch.nn.Module):
"""
Adds a optional transformer layer and a linear layer on top of BERT.
Adds a optional transformer layer and classification layers on top of BERT.
Args:
bert_model: pretrained BERT model
output_dim: dimension of the output layer for defult 1 linear layer classifier. Either output_dim or classifier_layer must be specified
classifier_layer: classification layers, can be a signle layer, or list of layers (for ex, ModuleList)
add_transformer_layer: if additional transformer layer should be added on top of the pretrained model
layer_pulled: which layer should be pulled from pretrained model
aggregation: embeddings aggregation (pooling) strategy. Available options are:
(default)"first" - [CLS] representation,
"mean" - average of all embeddings except CLS,
"max" - max of all embeddings except CLS
classifier_layer: classification layers, can be a signle layer, or list of layers (for ex, torch.nn.Sequential)
"""

def __init__(
Expand Down
2 changes: 1 addition & 1 deletion parlai/core/torch_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -2251,7 +2251,7 @@ def batch_act(self, observations):
for k, values in self._local_metrics.items():
if len(values) != len(batch.valid_indices):
raise IndexError(
f"Batchsize mismatch on metric {k} (got {len(values)}, "
f"Batchsize mismatch on metric {k} got {len(values)}, "
f"expected {len(batch.valid_indices)}"
)
for i, value in zip(batch.valid_indices, values):
Expand Down
2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ sphinx_rtd_theme==0.4.3
sphinx-autodoc-typehints~=1.10.3
Sphinx~=2.2.0
subword-nmt==0.3.7
tensorboard>2.9.0
tensorboardX==2.1
tokenizers>=0.8.0
tomli<2.0.0
Expand All @@ -56,3 +57,4 @@ jsonlines==1.2.0
numpy<=1.21 # Used to be `==1.17.5` before but tests -- pulling in latest at 1.22 not happy
markdown<=3.3.2 # Pin to something that works so tests are happy
jinja2==3.0.3
protobuf<3.20,>=3.9.2 # required by {'tensorboard'}