XLNet support and overhaul/cleanup of BERT support #845

Merged
64 commits merged on Aug 7, 2019
Changes from 51 commits
Commits
0cf7b23
Rename namespaces to suppress warnings.
sleepinyourhat Jul 12, 2019
38c5581
Revert "Rename namespaces to suppress warnings."
sleepinyourhat Jul 12, 2019
0c4546b
Merge branch 'master' of https://github.com/nyu-mll/jiant into nyu-ma…
sleepinyourhat Jul 15, 2019
e881c19
Initial working-ish attempt.
sleepinyourhat Jul 16, 2019
6d4ff7f
Intermediate check-in...
sleepinyourhat Jul 17, 2019
0d64ff2
More partial progress.
sleepinyourhat Jul 17, 2019
4d8c125
Another pass...
sleepinyourhat Jul 17, 2019
8f98adf
Fix sep/cls handling, cleanup.
sleepinyourhat Jul 17, 2019
3f4c434
Further cleanup.
sleepinyourhat Jul 17, 2019
fec6d36
Merge branch 'master' of https://github.com/nyu-mll/jiant into pytorc…
sleepinyourhat Jul 17, 2019
283d23a
Keyword name fix.
sleepinyourhat Jul 17, 2019
1563ef0
Another flag fix.
sleepinyourhat Jul 17, 2019
d98915b
Pull debug print.
sleepinyourhat Jul 17, 2019
7b04a03
Line length cleanup.
sleepinyourhat Jul 17, 2019
f6b09cd
WiC fix.
sleepinyourhat Jul 18, 2019
e933ecb
Two task setup bugs.
sleepinyourhat Jul 18, 2019
b6166a2
BoolQ typo
sleepinyourhat Jul 18, 2019
3ad8471
Improved segment handling.
sleepinyourhat Jul 18, 2019
4d31430
Delete unused is_pair_task, other cleanup/fixes.
sleepinyourhat Jul 18, 2019
cf994c7
Merge branch 'master' into pytorch-transformers
sleepinyourhat Jul 18, 2019
9cd7555
Merge branch 'pytorch-transformers' of https://github.com/nyu-mll/jia…
sleepinyourhat Jul 18, 2019
25b2d29
Fix deleted path from merge.
sleepinyourhat Jul 18, 2019
baba4b0
Fix cache path.
sleepinyourhat Jul 18, 2019
e2f06cb
Merge branch 'master' of https://github.com/nyu-mll/jiant into pytorc…
sleepinyourhat Jul 20, 2019
4e2734b
Merge branch 'master' of https://github.com/nyu-mll/jiant into nyu-ma…
sleepinyourhat Jul 21, 2019
033d24c
Merge branch 'master' into pytorch-transformers
sleepinyourhat Jul 22, 2019
9f81708
Merge branch 'pytorch-transformers' of https://github.com/nyu-mll/jia…
sleepinyourhat Jul 22, 2019
6a8459e
Address (spurious?) tokenization warning.
sleepinyourhat Jul 22, 2019
3aca0d5
Select pool_type automatically to match model.
sleepinyourhat Jul 22, 2019
df3a271
Merge branch 'master' of https://github.com/nyu-mll/jiant into nyu-ma…
sleepinyourhat Jul 22, 2019
f81108b
Merge branch 'nyu-master' into pytorch-transformers
sleepinyourhat Jul 22, 2019
01e4bd7
Config updates.
sleepinyourhat Jul 22, 2019
6f7e9e6
Path fix
sleepinyourhat Jul 22, 2019
437d4e4
Fix XLNet UNK handling.
sleepinyourhat Jul 23, 2019
455792a
Internal temporary MNLI alternate.
sleepinyourhat Jul 23, 2019
8e819a6
Revert "Internal temporary MNLI alternate."
sleepinyourhat Jul 23, 2019
8767db2
Add helper fn tests
sleepinyourhat Jul 23, 2019
ab84323
Merge branch 'master' of https://github.com/nyu-mll/jiant into pytorc…
sleepinyourhat Jul 23, 2019
5dfe1df
Finish merge
sleepinyourhat Jul 23, 2019
96a7c37
Remove unused argument.
sleepinyourhat Jul 23, 2019
fdf7677
Possible ReCoRD bug fix
sleepinyourhat Jul 24, 2019
062cd72
Merge branch 'master' into pytorch-transformers
sleepinyourhat Jul 24, 2019
d0e28ee
Cleanup
sleepinyourhat Jul 24, 2019
606da40
Merge branch 'pytorch-transformers' of https://github.com/nyu-mll/jia…
sleepinyourhat Jul 24, 2019
69f54bd
Fix merge issues.
sleepinyourhat Jul 24, 2019
721e825
Revert "Remove unused argument."
sleepinyourhat Jul 24, 2019
8ed3996
Merge branch 'master' of https://github.com/nyu-mll/jiant into pytorc…
sleepinyourhat Jul 24, 2019
a4a4231
Assorted responses to Alex's comments.
sleepinyourhat Jul 24, 2019
cea18bd
Further ReCoRD fix.
sleepinyourhat Jul 24, 2019
a9fd95d
@iftenney's comments.
sleepinyourhat Jul 26, 2019
6649393
Merge branch 'master' into pytorch-transformers
sleepinyourhat Jul 26, 2019
b61e07f
Fix/simplify segment logic.
sleepinyourhat Jul 26, 2019
205486f
Merge branch 'pytorch-transformers' of https://github.com/nyu-mll/jia…
sleepinyourhat Jul 26, 2019
7d5b2d2
@W4ngatang's comments
sleepinyourhat Jul 26, 2019
657b2c9
Cleanup.
sleepinyourhat Jul 27, 2019
237214d
Cleanup
sleepinyourhat Aug 4, 2019
a9ee48e
Fix issues with alternative embeddings_mode settings, max_layer.
sleepinyourhat Aug 4, 2019
89e426c
More mix cleanup.
sleepinyourhat Aug 4, 2019
1da2753
Masking fix.
sleepinyourhat Aug 4, 2019
b616bbd
Address (most of) @iftenney's comments
sleepinyourhat Aug 6, 2019
281c45c
Tidying.
sleepinyourhat Aug 6, 2019
12021ea
Misc cleanup.
sleepinyourhat Aug 7, 2019
65a9963
Comment.
sleepinyourhat Aug 7, 2019
7db9704
Merge branch 'master' into pytorch-transformers
sleepinyourhat Aug 7, 2019
1 change: 1 addition & 0 deletions .gitignore
Expand Up @@ -7,3 +7,4 @@ user_config.sh
.idea
.ipynb_checkpoints/
perluniprops/
.DS_Store
6 changes: 3 additions & 3 deletions README.md
Expand Up @@ -10,7 +10,7 @@ A few things you might want to know about `jiant`:
- `jiant` is configuration-driven. You can run an enormous variety of experiments by simply writing configuration files. Of course, if you need to add any major new features, you can also easily edit or extend the code.
- `jiant` contains implementations of strong baselines for the [GLUE](https://gluebenchmark.com) and [SuperGLUE](https://super.gluebenchmark.com/) benchmarks, and it's the recommended starting point for work on these benchmarks.
- `jiant` was developed at [the 2018 JSALT Workshop](https://www.clsp.jhu.edu/workshops/18-workshop/) by [the General-Purpose Sentence Representation Learning](https://jsalt18-sentence-repl.github.io/) team and is maintained by [the NYU Machine Learning for Language Lab](https://wp.nyu.edu/ml2/people/), with help from [many outside collaborators](https://github.com/nyu-mll/jiant/graphs/contributors) (especially Google AI Language's [Ian Tenney](https://ai.google/research/people/IanTenney)).
- `jiant` is built on [PyTorch](https://pytorch.org). It also uses many components from [AllenNLP](https://github.com/allenai/allennlp) and the HuggingFace PyTorch [implementations](https://github.com/huggingface/pytorch-pretrained-BERT) of BERT and GPT.
- `jiant` is built on [PyTorch](https://pytorch.org). It also uses many components from [AllenNLP](https://github.com/allenai/allennlp) and the HuggingFace PyTorch [implementations](https://github.com/huggingface/pytorch-transformers) of GPT, BERT, and XLNet.
- The name `jiant` doesn't mean much. The 'j' stands for JSALT. That's all the acronym we have.

## Getting Started
Expand Down Expand Up @@ -84,10 +84,10 @@ This package is released under the [MIT License](LICENSE.md). The material in th

## Acknowledgments

- Part of the development of `jiant` took at the 2018 Frederick Jelinek Memorial Summer Workshop on Speech and Language Technologies, and was supported by Johns Hopkins University with unrestricted gifts from Amazon, Facebook, Google, Microsoft and Mitsubishi Electric Research Laboratories.
- Part of the development of `jiant` took at the 2018 Frederick Jelinek Memorial Summer Workshop on Speech and Language Technologies, and was supported by Johns Hopkins University with unrestricted gifts from Amazon, Facebook, Google, Microsoft and Mitsubishi Electric Research Laboratories.
- This work was made possible in part by a donation to NYU from Eric and Wendy Schmidt made
by recommendation of the Schmidt Futures program.
- We gratefully acknowledge the support of NVIDIA Corporation with the donation of a Titan V GPU used at NYU in this work.
- We gratefully acknowledge the support of NVIDIA Corporation with the donation of a Titan V GPU used at NYU in this work.
- Developer Alex Wang is supported by the National Science Foundation Graduate Research Fellowship Program under Grant
No. DGE 1342536. Any opinions, findings, and conclusions or recommendations expressed in this
material are those of the author(s) and do not necessarily reflect the views of the National Science
14 changes: 10 additions & 4 deletions cola_inference.py
Expand Up @@ -47,10 +47,10 @@

from jiant.models import build_model
from jiant.preprocess import build_indexers, build_tasks
from jiant.tasks.tasks import process_sentence, sentence_to_text_field
from jiant.tasks.tasks import tokenize_and_truncate, sentence_to_text_field
from jiant.utils import config
from jiant.utils.data_loaders import load_tsv
from jiant.utils.utils import check_arg_name, load_model_state
from jiant.utils.utils import check_arg_name, load_model_state, select_pool_type

log.basicConfig(format="%(asctime)s: %(message)s", datefmt="%m/%d %I:%M:%S %p", level=log.INFO)

Expand Down Expand Up @@ -121,6 +121,7 @@ def main(cl_arguments):
cl_args = handle_arguments(cl_arguments)
args = config.params_from_file(cl_args.config_file, cl_args.overrides)
check_arg_name(args)

assert args.target_tasks == "cola", "Currently only supporting CoLA. ({})".format(
args.target_tasks
)
Expand All @@ -138,6 +139,11 @@ def main(cl_arguments):
)
args.cuda = -1

if args.tokenizer == "auto":
args.tokenizer = tokenizers.select_tokenizer(args)
if args.pool_type == "auto":
args.pool_type = select_pool_type(args)

# Prepare data #
_, target_tasks, vocab, word_embs = build_tasks(args)
tasks = sorted(set(target_tasks), key=lambda x: x.name)
Expand Down Expand Up @@ -185,7 +191,7 @@ def run_repl(model, vocab, indexers, task, args):
if input_string == "QUIT":
break

tokens = process_sentence(
tokens = tokenize_and_truncate(
tokenizer_name=task.tokenizer_name, sent=input_string, max_seq_len=args.max_seq_len
)
print("TOKENS:", " ".join("[{}]".format(tok) for tok in tokens))
Expand Down Expand Up @@ -282,7 +288,7 @@ def load_cola_data(input_path, task, input_format, max_seq_len):
with open(input_path, "r") as f_in:
sentences = f_in.readlines()
tokens = [
process_sentence(
tokenize_and_truncate(
tokenizer_name=task.tokenizer_name, sent=sentence, max_seq_len=max_seq_len
)
for sentence in sentences
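For readers following the cola_inference.py hunks above: the script now resolves the new "auto" settings before building tasks and calls the renamed tokenization helper. A minimal usage sketch of that helper follows; it assumes a jiant checkout from this branch is importable, and the example sentence and settings are arbitrary.

```python
# Usage sketch of the renamed helper (formerly `process_sentence`); the keyword
# arguments mirror the calls in the diff above.
from jiant.tasks.tasks import tokenize_and_truncate

tokens = tokenize_and_truncate(
    tokenizer_name="bert-base-uncased",  # or "MosesTokenizer", "OpenAI.BPE", ...
    sent="The cat sat on the mat .",
    max_seq_len=40,
)
print("TOKENS:", " ".join("[{}]".format(tok) for tok in tokens))

# The "auto" settings are resolved up front, as in the hunk above:
#   if args.tokenizer == "auto": args.tokenizer = tokenizers.select_tokenizer(args)
#   if args.pool_type == "auto": args.pool_type = select_pool_type(args)
```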
2 changes: 0 additions & 2 deletions config/ccg_bert.conf
Expand Up @@ -4,7 +4,6 @@ include "defaults.conf"
pretrain_tasks = ccg
target_tasks = ccg
input_module = bert-base-uncased
tokenizer = ${input_module}
do_target_task_training = 0
transfer_paradigm = finetune

Expand All @@ -16,7 +15,6 @@ skip_embs = 1

// BERT-specific setup
classifier = log_reg // following BERT paper
pool_type = first

dropout = 0.1 // following BERT paper
optimizer = bert_adam
2 changes: 0 additions & 2 deletions config/copa_bert.conf
Expand Up @@ -19,10 +19,8 @@ do_full_eval = 1

// Typical BERT base setup
input_module = bert-base-uncased
tokenizer = bert-base-uncased
transfer_paradigm = finetune
classifier = log_reg
pool_type = first
optimizer = bert_adam
lr = 0.00001
sent_enc = none
111 changes: 57 additions & 54 deletions config/defaults.conf
Expand Up @@ -140,6 +140,9 @@ batch_size = 32 // Training batch size.
optimizer = adam // Optimizer. All valid AllenNLP options are available, including 'sgd'.
// Use 'bert_adam' for reproducing BERT experiments.
// 'adam' uses the newer AMSGrad variant.
// Warning: bert_adam is designed for cases where the number of epochs is known
// in advance, so it may not behave reasonably unless max_epochs is set to a
// reasonable positive value.
lr = 0.0001 // Initial learning rate.
min_lr = 0.000001 // Minimum learning rate. Training will stop when our explicit LR decay lowers
// the LR below this point or if any other stopping criterion applies.
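To make the bert_adam caveat in the hunk above concrete, here is a hypothetical sanity check one could apply to a parsed parameter dictionary; the function name and dict-based access are illustrative only and are not part of jiant.

```python
# Hypothetical guard illustrating the warning above (names are illustrative only).
def check_bert_adam_config(params: dict) -> None:
    """bert_adam schedules its learning rate over a known training length."""
    if params.get("optimizer") == "bert_adam" and params.get("max_epochs", -1) <= 0:
        raise ValueError(
            "optimizer = bert_adam expects the number of epochs to be known in advance; "
            "set max_epochs to a reasonable positive value."
        )

check_bert_adam_config({"optimizer": "bert_adam", "max_epochs": 3})  # passes
```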
Expand Down Expand Up @@ -221,42 +224,41 @@ max_targ_word_v_size = 20000 // Maximum target word vocab size for seq2seq task

// Input Handling //

input_module = "" // The word embedding or contextual word representation layer.
// Currently supported options:
// - scratch: Word embeddings trained from scratch.
// - glove: Leaded GloVe word embeddings. Typically used with
// tokenizer = MosesTokenizer. Note that this is not quite identical to the
// Stanford tokenizer used to train GloVe.
// - fastText: Leaded GloVe word embeddings. Use with
// tokenizer = MosesTokenizer.
// - elmo: AllenNLP's ELMo contextualized word vector model hidden states. Use
// with tokenizer = MosesTokenizer.
// - elmo-chars-only: The dynamic CNN-based word embedding layer of AllenNLP's
// ELMo, but not ELMo's LSTM layer hidden states. Use with
// tokenizer = MosesTokenizer.
// - bert-base-uncased, etc.: Any BERT model specifier that is valid for
// pytorch-pretrained-bert may be specified here. Use with
// tokenizer = ${input_module}
// We support the newer bert-large-uncased-whole-word-masking and
// bert-large-cased-whole-word-masking cased models, but they require
// the git development version of pytorch-pretrained-bert. To use these
// models, follow the instructions under 'From source' here:
// https://github.com/huggingface/pytorch-pretrained-BERT
// Most of these options use MosesTokenizer tokenization, but
// BERT and GPT need more specific tokenization (tokenizer config
// parameter should be equal to input_module for BERT, and should be
// equal to 'OpenAI.BPE' if input_module = gpt).
// For ELMo, BERT, and GPT, there are additional config parameters below.

tokenizer = "MosesTokenizer" // The name of the tokenizer, passed to the Task constructor for
// appropriate handling during data loading. Currently supported
// options:
// - "": Split the input data on whitespace.
// - MosesTokenizer: Our standard word tokenizer. (Support for
// other NLTK tokenizers is pending.)
// - bert-uncased-base, etc.: Use the tokenizer supplied with
// pytorch-pretrained-bert that corresponds to that BERT model.
// - OpenAI.BPE: The tokenizer supplied with OpenAI GPT.
input_module = "" // The word embedding or contextual word representation layer.
// Currently supported options:
// - scratch: Word embeddings trained from scratch.
// - glove: Loaded GloVe word embeddings. Typically used with
// tokenizer = MosesTokenizer. Note that this is not quite identical to
// the Stanford tokenizer used to train GloVe.
// - fastText: Loaded fastText word embeddings. Use with
// tokenizer = MosesTokenizer.
// - elmo: AllenNLP's ELMo contextualized word vector model hidden states. Use
// with tokenizer = MosesTokenizer.
// - elmo-chars-only: The dynamic CNN-based word embedding layer of AllenNLP's
// ELMo, but not ELMo's LSTM layer hidden states. Use with
// tokenizer = MosesTokenizer.
// - gpt: The OpenAI GPT language model encoder.
// Use with tokenizer = OpenAI.BPE.
// - bert-base-uncased, etc.: Any BERT model specifier that is valid for
// pytorch-pretrained-bert may be specified here. Use with
// tokenizer = ${input_module}
// We support the newer bert-large-uncased-whole-word-masking and
// bert-large-cased-whole-word-masking cased models, but they require
// the git development version of pytorch-pretrained-bert. To use these
// models, follow the instructions under 'From source' here:
// https://github.com/huggingface/pytorch-pretrained-BERT

tokenizer = auto // The name of the tokenizer, passed to the Task constructor for
// appropriate handling during data loading. Currently supported
// options:
// - auto: Select the tokenizer that matches the model specified in
// input_module above. Usually a safe default.
// - "": Split the input data on whitespace.
// - MosesTokenizer: Our standard word tokenizer. (Support for
// other NLTK tokenizers is pending.)
// - bert-base-uncased, etc.: Use the tokenizer supplied with
// pytorch-pretrained-bert that corresponds to that BERT model.
// - OpenAI.BPE: The tokenizer supplied with OpenAI GPT.

word_embs_file = ${WORD_EMBS_FILE} // Path to embeddings file, used with glove and fastText.
d_word = 300 // Dimension of word embeddings, used with scratch, glove, or fastText.
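The `tokenizer = auto` option above pairs each input_module with its matching tokenizer. A simplified sketch of that mapping, based only on the options documented in this hunk (jiant's actual tokenizer-selection helper may handle more cases):

```python
# Simplified illustration of "tokenizer = auto"; not jiant's actual implementation.
def select_tokenizer_sketch(input_module: str) -> str:
    if input_module.startswith("bert-") or input_module.startswith("xlnet-"):
        return input_module       # use the model's own wordpiece/sentencepiece tokenizer
    if input_module == "gpt":
        return "OpenAI.BPE"       # GPT ships its own BPE tokenizer
    return "MosesTokenizer"       # scratch, glove, fastText, elmo, elmo-chars-only

assert select_tokenizer_sketch("bert-base-uncased") == "bert-base-uncased"
assert select_tokenizer_sketch("glove") == "MosesTokenizer"
```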
Expand All @@ -282,22 +284,21 @@ openai_embeddings_mode = "none" // How to handle the embedding layer of the Ope
// "mix" uses ELMo-style scalar mixing (with
// learned weights) across all layers.

bert_embeddings_mode = "none" // How to handle the embedding layer of the
// BERT model:
// "none" or "top" returns only top-layer activation,
// "cat" returns top-layer concatenated with
// lexical layer,
// "only" returns only lexical layer,
// "mix" uses ELMo-style scalar mixing (with
// learned weights) across all layers.
bert_max_layer = -1 // Maximum layer to return from BERT encoder. Layer 0 is
// wordpiece embeddings.
// bert_embeddings_mode will behave as if the BERT encoder
// is truncated at this layer, so 'top' will return this
// layer, and 'mix' will return a mix of all layers up to
// and including this layer.
// Set to -1 to use all layers.
// Used for probing experiments.
pytorch_transformers_embedding_mode = "none" // How to handle the embedding layer of the
// BERT/XLNet model:
// "none" or "top" returns only top-layer activation,
// "cat" returns top-layer concatenated with
// lexical layer,
// "only" returns only lexical layer,
// "mix" uses ELMo-style scalar mixing (with
// learned weights) across all layers.
pytorch_transformers_max_layer = -1 // Maximum layer to return from BERT etc. encoder. Layer 0 is
// wordpiece embeddings. pytorch_transformers_embedding_mode
// will behave as if the encoder is truncated at this layer, so 'top'
// will return this layer, and 'mix' will return a mix of all
// layers up to and including this layer.
// Set to -1 to use all layers.
// Used for probing experiments.

force_include_wsj_vocabulary = 0 // Set if using PTB parsing (grammar induction) task. Makes sure
// to include WSJ vocabulary.
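To illustrate the "mix" option and the pytorch_transformers_max_layer truncation described above, here is a minimal ELMo-style scalar mix over a stack of layer activations. This is a generic sketch of the idea, not the mixing module jiant uses, and the tensor shapes are made up.

```python
import torch
import torch.nn as nn

class ScalarMixSketch(nn.Module):
    """ELMo-style scalar mix across encoder layers (illustrative sketch only)."""

    def __init__(self, n_layers: int):
        super().__init__()
        self.scalars = nn.Parameter(torch.zeros(n_layers))  # one learned weight per layer
        self.gamma = nn.Parameter(torch.ones(1))

    def forward(self, layers):
        # layers: list of [batch, seq_len, hidden] activations, index 0 = lexical layer
        weights = torch.softmax(self.scalars, dim=0)
        return self.gamma * sum(w * h for w, h in zip(weights, layers))

# 13 activations: wordpiece embeddings (layer 0) + 12 transformer layers.
layers = [torch.randn(2, 7, 768) for _ in range(13)]
max_layer = 4  # behave as if the encoder were truncated at layer 4
top = layers[max_layer]                                          # what "top" would return
mixed = ScalarMixSketch(max_layer + 1)(layers[: max_layer + 1])  # what "mix" would combine
```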
Expand Down Expand Up @@ -364,8 +365,10 @@ pair_attn = 1 // If true, use attn in sentence-pair classification/regression t
d_hid_attn = 512 // Post-attention LSTM state size.
shared_pair_attn = 0 // If true, share pair_attn parameters across all tasks that use it.
d_proj = 512 // Size of task-specific linear projection applied before before pooling.
pool_type = "max" // Type of pooling to reduce sequences of vectors into a single vector.
// Options: "max", "mean", "first", "final"
pool_type = "auto" // Type of pooling to reduce sequences of vectors into a single vector.
// Options: "auto", "max", "mean", "first", "final"
// "auto" uses "first" for plain BERT (with no sent_enc), "final" for plain
// XLNet and GPT, and "max" in all other settings.
span_classifier_loss_fn = "softmax" // Classifier loss function. Used only in some tasks (notably
// span-related tasks), not mlp/fancy_mlp. Currently supports
// sigmoid and softmax.
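As a companion to the pool_type comment above, here is a sketch of what the pooling options compute, together with the documented "auto" rule. The selection mirrors the comment (and the select_pool_type helper used in cola_inference.py), but this simplified version is for illustration only.

```python
import torch

def pool_sketch(states: torch.Tensor, mask: torch.Tensor, pool_type: str) -> torch.Tensor:
    """Reduce [batch, seq_len, dim] states to [batch, dim]; illustrative only."""
    mask = mask.float()                       # 1 for real tokens, 0 for padding
    lengths = mask.sum(dim=1).long()
    if pool_type == "first":
        return states[:, 0]                   # e.g. BERT's [CLS] position
    if pool_type == "final":
        return states[torch.arange(states.size(0)), lengths - 1]  # e.g. XLNet/GPT last token
    if pool_type == "mean":
        return (states * mask.unsqueeze(-1)).sum(1) / lengths.unsqueeze(-1).float()
    if pool_type == "max":
        return states.masked_fill(mask.unsqueeze(-1) == 0, float("-inf")).max(1).values
    raise ValueError("Unknown pool_type: " + pool_type)

def select_pool_type_sketch(input_module: str, sent_enc: str) -> str:
    """Simplified version of the "auto" rule documented above."""
    if sent_enc == "none" and input_module.startswith("bert-"):
        return "first"
    if sent_enc == "none" and (input_module.startswith("xlnet-") or input_module == "gpt"):
        return "final"
    return "max"
```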
2 changes: 0 additions & 2 deletions config/examples/copa_bert.conf
Expand Up @@ -19,10 +19,8 @@ do_full_eval = 1

// Typical BERT base setup
input_module = bert-base-uncased
tokenizer = bert-base-uncased
transfer_paradigm = finetune
classifier = log_reg
pool_type = first
optimizer = bert_adam
lr = 0.00001
sent_enc = none
4 changes: 1 addition & 3 deletions config/examples/stilts_example.conf
Expand Up @@ -18,8 +18,7 @@ batch_size = 24
write_preds = "val,test"

//BERT-specific parameters
bert_embeddings_mode = "top"
pool_type = "first"
pytorch_transformers_embedding_mode = "top"
sep_embs_for_skip = 1
sent_enc = "none"
classifier = log_reg // following BERT paper
Expand All @@ -34,6 +33,5 @@ patience = 20
max_vals = 10000
transfer_paradigm = "finetune"

tokenizer = "bert-base-uncased"
input_module = "bert-base-uncased"

5 changes: 2 additions & 3 deletions config/superglue-bert.conf
Expand Up @@ -7,11 +7,10 @@ exp_name = "bert-large-cased"
// Data and preprocessing settings
max_seq_len = 256 // Mainly needed for MultiRC, to avoid over-truncating
// But not 512 as that is really hard to fit in memory.
tokenizer = "bert-large-cased"

// Model settings
input_module = "bert-large-cased"
bert_embeddings_mode = "top"
pool_type = "first"
pytorch_transformers_embedding_mode = "top"
pair_attn = 0 // shouldn't be needed but JIC
s2s = {
attention = none
6 changes: 6 additions & 0 deletions environment.yml
Expand Up @@ -32,3 +32,9 @@ dependencies:
- ftfy==5.4.1
- spacy==2.0.11

# Warning: jiant currently depends on *both* pytorch_pretrained_bert > 0.6 _and_
# pytorch_transformers > 1.0. These are the same package, though the name changed between
# these two versions. AllenNLP requires 0.6 to support the BertAdam optimizer, and jiant
# directly requires 1.0 to support XLNet and WWM-BERT.
# This AllenNLP issue is relevant: https://github.com/allenai/allennlp/issues/3067
- pytorch-transformers==1.0.0
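Given the dual-dependency note above, a quick sanity check (assuming the environment was built from this file) is to import both distributions side by side:

```python
# Both packages should be importable in the same environment, per the note above.
import pytorch_pretrained_bert   # 0.6.x line, still needed by AllenNLP for BertAdam
import pytorch_transformers      # 1.0.0, needed directly by jiant for XLNet / WWM-BERT

print("pytorch_transformers", pytorch_transformers.__version__)
```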
4 changes: 2 additions & 2 deletions gcp/config/jiant_paths.sh
Expand Up @@ -12,8 +12,8 @@ export JIANT_PROJECT_PREFIX="$HOME/exp"

# pre-downloaded ELMo models
export ELMO_SRC_DIR="/nfs/jiant/share/elmo"
# cache for BERT models
export PYTORCH_PRETRAINED_BERT_CACHE="/nfs/jiant/share/bert_cache"
# cache for BERT etc. models
export PYTORCH_PRETRAINED_BERT_CACHE="/nfs/jiant/share/pytorch_transformers_cache"
# word embeddings
export WORD_EMBS_FILE="/nfs/jiant/share/wiki-news-300d-1M.vec"

1 change: 0 additions & 1 deletion gcp/kubernetes/run_batch.sh
Expand Up @@ -99,4 +99,3 @@ jsonnet -S -o "${YAML_FILE}" \
##
# Create the Kubernetes pod; this will actually launch the job.
kubectl ${KUBECTL_MODE} -f "${YAML_FILE}"

6 changes: 3 additions & 3 deletions gcp/kubernetes/templates/jiant_env.libsonnet
Expand Up @@ -14,14 +14,14 @@
nfs_exp_dir: "/nfs/jiant/exp",

# Name of pre-built Docker image, accessible from Kubernetes.
gcr_image: "gcr.io/google.com/jiant-stilts/jiant-conda:v0",
gcr_image: "gcr.io/google.com/jiant-stilts/jiant-conda:v2",

# Default location for glue_data
jiant_data_dir: "/nfs/jiant/share/glue_data",
# Path to ELMO cache.
elmo_src_dir: "/nfs/jiant/share/elmo",
# Path to BERT model cache; should be writable by Kubernetes workers.
bert_cache_path: "/nfs/jiant/share/bert_cache",
# Path to BERT etc. model cache; should be writable by Kubernetes workers.
pytorch_transformers_cache_path: "/nfs/jiant/share/pytorch_transformers_cache",
# Path to default word embeddings file
word_embs_file: "/nfs/jiant/share/wiki-news-300d-1M.vec",
}
2 changes: 1 addition & 1 deletion gcp/kubernetes/templates/run_batch.jsonnet
Expand Up @@ -36,7 +36,7 @@ function(job_name, command, project_dir, uid, fsgroup,
},
{
name: "PYTORCH_PRETRAINED_BERT_CACHE",
value: jiant_env.bert_cache_path,
value: jiant_env.pytorch_transformers_cache_path
},
{
name: "ELMO_SRC_DIR",
Empty file removed jiant/bert/__init__.py
Empty file.