diff --git a/README.md b/README.md index 047e1b7686..af5fe8b01e 100644 --- a/README.md +++ b/README.md @@ -1,241 +1,163 @@


- Support Ukraine - MIT License + MIT License Latest Release Build Status Documentation Status - CicleCI Status

+# Self-Supervised Neural Machine Translation -------------------------------------------------------------------------------- - -Fairseq(-py) is a sequence modeling toolkit that allows researchers and -developers to train custom models for translation, summarization, language -modeling and other text generation tasks. - -We provide reference implementations of various sequence modeling papers: - -
List of implemented papers

- -* **Convolutional Neural Networks (CNN)** - + [Language Modeling with Gated Convolutional Networks (Dauphin et al., 2017)](examples/language_model/conv_lm/README.md) - + [Convolutional Sequence to Sequence Learning (Gehring et al., 2017)](examples/conv_seq2seq/README.md) - + [Classical Structured Prediction Losses for Sequence to Sequence Learning (Edunov et al., 2018)](https://github.com/pytorch/fairseq/tree/classic_seqlevel) - + [Hierarchical Neural Story Generation (Fan et al., 2018)](examples/stories/README.md) - + [wav2vec: Unsupervised Pre-training for Speech Recognition (Schneider et al., 2019)](examples/wav2vec/README.md) -* **LightConv and DynamicConv models** - + [Pay Less Attention with Lightweight and Dynamic Convolutions (Wu et al., 2019)](examples/pay_less_attention_paper/README.md) -* **Long Short-Term Memory (LSTM) networks** - + Effective Approaches to Attention-based Neural Machine Translation (Luong et al., 2015) -* **Transformer (self-attention) networks** - + Attention Is All You Need (Vaswani et al., 2017) - + [Scaling Neural Machine Translation (Ott et al., 2018)](examples/scaling_nmt/README.md) - + [Understanding Back-Translation at Scale (Edunov et al., 2018)](examples/backtranslation/README.md) - + [Adaptive Input Representations for Neural Language Modeling (Baevski and Auli, 2018)](examples/language_model/README.adaptive_inputs.md) - + [Lexically constrained decoding with dynamic beam allocation (Post & Vilar, 2018)](examples/constrained_decoding/README.md) - + [Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context (Dai et al., 2019)](examples/truncated_bptt/README.md) - + [Adaptive Attention Span in Transformers (Sukhbaatar et al., 2019)](examples/adaptive_span/README.md) - + [Mixture Models for Diverse Machine Translation: Tricks of the Trade (Shen et al., 2019)](examples/translation_moe/README.md) - + [RoBERTa: A Robustly Optimized BERT Pretraining Approach (Liu et al., 2019)](examples/roberta/README.md) - + [Facebook FAIR's WMT19 News Translation Task Submission (Ng et al., 2019)](examples/wmt19/README.md) - + [Jointly Learning to Align and Translate with Transformer Models (Garg et al., 2019)](examples/joint_alignment_translation/README.md ) - + [Multilingual Denoising Pre-training for Neural Machine Translation (Liu et at., 2020)](examples/mbart/README.md) - + [Neural Machine Translation with Byte-Level Subwords (Wang et al., 2020)](examples/byte_level_bpe/README.md) - + [Unsupervised Quality Estimation for Neural Machine Translation (Fomicheva et al., 2020)](examples/unsupervised_quality_estimation/README.md) - + [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations (Baevski et al., 2020)](examples/wav2vec/README.md) - + [Generating Medical Reports from Patient-Doctor Conversations Using Sequence-to-Sequence Models (Enarvi et al., 2020)](examples/pointer_generator/README.md) - + [Linformer: Self-Attention with Linear Complexity (Wang et al., 2020)](examples/linformer/README.md) - + [Cross-lingual Retrieval for Iterative Self-Supervised Training (Tran et al., 2020)](examples/criss/README.md) - + [Deep Transformers with Latent Depth (Li et al., 2020)](examples/latent_depth/README.md) - + [Unsupervised Cross-lingual Representation Learning for Speech Recognition (Conneau et al., 2020)](https://arxiv.org/abs/2006.13979) - + [Self-training and Pre-training are Complementary for Speech Recognition (Xu et al., 2020)](https://arxiv.org/abs/2010.11430) - + [Robust wav2vec 2.0: Analyzing Domain Shift in 
Self-Supervised Pre-Training (Hsu, et al., 2021)](https://arxiv.org/abs/2104.01027) - + [Unsupervised Speech Recognition (Baevski, et al., 2021)](https://arxiv.org/abs/2105.11084) - + [Simple and Effective Zero-shot Cross-lingual Phoneme Recognition (Xu et al., 2021)](https://arxiv.org/abs/2109.11680) - + [VideoCLIP: Contrastive Pre-training for Zero-shot Video-Text Understanding (Xu et. al., 2021)](https://arxiv.org/pdf/2109.14084.pdf) - + [VLM: Task-agnostic Video-Language Model Pre-training for Video Understanding (Xu et. al., 2021)](https://aclanthology.org/2021.findings-acl.370.pdf) - + [NormFormer: Improved Transformer Pretraining with Extra Normalization (Shleifer et. al, 2021)](examples/normformer/README.md) -* **Non-autoregressive Transformers** - + Non-Autoregressive Neural Machine Translation (Gu et al., 2017) - + Deterministic Non-Autoregressive Neural Sequence Modeling by Iterative Refinement (Lee et al. 2018) - + Insertion Transformer: Flexible Sequence Generation via Insertion Operations (Stern et al. 2019) - + Mask-Predict: Parallel Decoding of Conditional Masked Language Models (Ghazvininejad et al., 2019) - + [Levenshtein Transformer (Gu et al., 2019)](examples/nonautoregressive_translation/README.md) -* **Finetuning** - + [Better Fine-Tuning by Reducing Representational Collapse (Aghajanyan et al. 2020)](examples/rxf/README.md) - -

- -### What's New: -* June 2022 [Released code for wav2vec-U 2.0 from Towards End-to-end Unsupervised Speech Recognition (Liu, et al., 2022)](examples/wav2vec/unsupervised/README.md) -* May 2022 [Integration with xFormers](https://github.com/facebookresearch/xformers) -* December 2021 [Released Direct speech-to-speech translation code](examples/speech_to_speech/README.md) -* October 2021 [Released VideoCLIP and VLM models](examples/MMPT/README.md) -* October 2021 [Released multilingual finetuned XLSR-53 model](examples/wav2vec/README.md) -* September 2021 [`master` branch renamed to `main`](https://github.com/github/renaming). -* July 2021 [Released DrNMT code](examples/discriminative_reranking_nmt/README.md) -* July 2021 [Released Robust wav2vec 2.0 model](examples/wav2vec/README.md) -* June 2021 [Released XLMR-XL and XLMR-XXL models](examples/xlmr/README.md) -* May 2021 [Released Unsupervised Speech Recognition code](examples/wav2vec/unsupervised/README.md) -* March 2021 [Added full parameter and optimizer state sharding + CPU offloading](examples/fully_sharded_data_parallel/README.md) -* February 2021 [Added LASER training code](examples/laser/README.md) -* December 2020: [Added Adaptive Attention Span code](examples/adaptive_span/README.md) -* December 2020: [GottBERT model and code released](examples/gottbert/README.md) -* November 2020: Adopted the [Hydra](https://github.com/facebookresearch/hydra) configuration framework - * [see documentation explaining how to use it for new and existing projects](docs/hydra_integration.md) -* November 2020: [fairseq 0.10.0 released](https://github.com/pytorch/fairseq/releases/tag/v0.10.0) -* October 2020: [Added R3F/R4F (Better Fine-Tuning) code](examples/rxf/README.md) -* October 2020: [Deep Transformer with Latent Depth code released](examples/latent_depth/README.md) -* October 2020: [Added CRISS models and code](examples/criss/README.md) - -
Previous updates

- -* September 2020: [Added Linformer code](examples/linformer/README.md) -* September 2020: [Added pointer-generator networks](examples/pointer_generator/README.md) -* August 2020: [Added lexically constrained decoding](examples/constrained_decoding/README.md) -* August 2020: [wav2vec2 models and code released](examples/wav2vec/README.md) -* July 2020: [Unsupervised Quality Estimation code released](examples/unsupervised_quality_estimation/README.md) -* May 2020: [Follow fairseq on Twitter](https://twitter.com/fairseq) -* April 2020: [Monotonic Multihead Attention code released](examples/simultaneous_translation/README.md) -* April 2020: [Quant-Noise code released](examples/quant_noise/README.md) -* April 2020: [Initial model parallel support and 11B parameters unidirectional LM released](examples/megatron_11b/README.md) -* March 2020: [Byte-level BPE code released](examples/byte_level_bpe/README.md) -* February 2020: [mBART model and code released](examples/mbart/README.md) -* February 2020: [Added tutorial for back-translation](https://github.com/pytorch/fairseq/tree/main/examples/backtranslation#training-your-own-model-wmt18-english-german) -* December 2019: [fairseq 0.9.0 released](https://github.com/pytorch/fairseq/releases/tag/v0.9.0) -* November 2019: [VizSeq released (a visual analysis toolkit for evaluating fairseq models)](https://facebookresearch.github.io/vizseq/docs/getting_started/fairseq_example) -* November 2019: [CamemBERT model and code released](examples/camembert/README.md) -* November 2019: [BART model and code released](examples/bart/README.md) -* November 2019: [XLM-R models and code released](examples/xlmr/README.md) -* September 2019: [Nonautoregressive translation code released](examples/nonautoregressive_translation/README.md) -* August 2019: [WMT'19 models released](examples/wmt19/README.md) -* July 2019: fairseq relicensed under MIT license -* July 2019: [RoBERTa models and code released](examples/roberta/README.md) -* June 2019: [wav2vec models and code released](examples/wav2vec/README.md) - -

- -### Features: - -* multi-GPU training on one machine or across multiple machines (data and model parallel) -* fast generation on both CPU and GPU with multiple search algorithms implemented: - + beam search - + Diverse Beam Search ([Vijayakumar et al., 2016](https://arxiv.org/abs/1610.02424)) - + sampling (unconstrained, top-k and top-p/nucleus) - + [lexically constrained decoding](examples/constrained_decoding/README.md) (Post & Vilar, 2018) -* [gradient accumulation](https://fairseq.readthedocs.io/en/latest/getting_started.html#large-mini-batch-training-with-delayed-updates) enables training with large mini-batches even on a single GPU -* [mixed precision training](https://fairseq.readthedocs.io/en/latest/getting_started.html#training-with-half-precision-floating-point-fp16) (trains faster with less GPU memory on [NVIDIA tensor cores](https://developer.nvidia.com/tensor-cores)) -* [extensible](https://fairseq.readthedocs.io/en/latest/overview.html): easily register new models, criterions, tasks, optimizers and learning rate schedulers -* [flexible configuration](docs/hydra_integration.md) based on [Hydra](https://github.com/facebookresearch/hydra) allowing a combination of code, command-line and file based configuration -* [full parameter and optimizer state sharding](examples/fully_sharded_data_parallel/README.md) -* [offloading parameters to CPU](examples/fully_sharded_data_parallel/README.md) - -We also provide [pre-trained models for translation and language modeling](#pre-trained-models-and-examples) -with a convenient `torch.hub` interface: - -``` python -en2de = torch.hub.load('pytorch/fairseq', 'transformer.wmt19.en-de.single_model') -en2de.translate('Hello world', beam=5) -# 'Hallo Welt' -``` - -See the PyTorch Hub tutorials for [translation](https://pytorch.org/hub/pytorch_fairseq_translation/) -and [RoBERTa](https://pytorch.org/hub/pytorch_fairseq_roberta/) for more examples. +This is the code used for the paper *Self-Supervised Neural Machine Translation*, which describes a joint parallel data extraction and NMT training approach. It is based on a May 2019 copy of the [Fairseq-py](https://github.com/pytorch/fairseq) repository. Be aware that it is therefore not up-to-date with current changes in the original Fairseq(-py) code. # Requirements and Installation -* [PyTorch](http://pytorch.org/) version >= 1.10.0 -* Python version >= 3.8 +All the requirements are listed in `environment.yml` and can be installed using `conda env create -f environment.yml` + +* [PyTorch](http://pytorch.org/) version >= 1.13.1 +* The code has been tested on Python version = 3.8.16 * For training new models, you'll also need an NVIDIA GPU and [NCCL](https://github.com/NVIDIA/nccl) -* **To install fairseq** and develop locally: +* **For faster training** install NVIDIA's [apex](https://github.com/NVIDIA/apex) library with the `--cuda_ext` and `--deprecated_fused_adam` options -``` bash -git clone https://github.com/pytorch/fairseq -cd fairseq -pip install --editable ./ -# on MacOS: -# CFLAGS="-stdlib=libc++" pip install --editable ./ +## Instructions -# to install the latest stable release (0.10.x) -# pip install fairseq -``` +### Data Preparation + +1. Extract original and translated data from [here](https://zenodo.org/record/5596238#.Y2ObSezMJEJ). -* **For faster training** install NVIDIA's [apex](https://github.com/NVIDIA/apex) library: +2. Preprocess (e.g. using [Moses scripts](https://github.com/moses-smt/mosesdecoder/tree/master/scripts)) and apply BPE encoding. 
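As an illustration only, a preprocessing pass with the Moses scripts and [subword-nmt](https://github.com/rsennrich/subword-nmt) could look like the sketch below; the file names, language and number of merge operations are placeholders and should be adapted to the settings described in the instructions that follow.

```
# normalise punctuation and tokenise with the Moses scripts (paths and file names are placeholders)
MOSES=/path/to/mosesdecoder/scripts
$MOSES/tokenizer/normalize-punctuation.perl -l en < train.en > train.norm.en
$MOSES/tokenizer/tokenizer.perl -l en < train.norm.en > train.tok.en

# train and apply a truecasing model
$MOSES/recaser/train-truecaser.perl --model truecase-model.en --corpus train.tok.en
$MOSES/recaser/truecase.perl --model truecase-model.en < train.tok.en > train.true.en

# learn and apply BPE with subword-nmt (e.g. 10k merge operations, as used for EN-ALL below)
subword-nmt learn-bpe -s 10000 < train.true.en > bpe.codes
subword-nmt apply-bpe -c bpe.codes < train.true.en > train.bpe.en
```

Apply the same tokenisation, truecasing model and BPE codes to the dev and test splits so that all data share one vocabulary.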
-``` bash -git clone https://github.com/NVIDIA/apex -cd apex -pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" \ - --global-option="--deprecated_fused_adam" --global-option="--xentropy" \ - --global-option="--fast_multihead_attn" ./ +### At a high level, to run SSNMT with pretraining: + - **Pretraining** + - Tokenise and preprocess the data (Europarl) + - optionally combine the preprocessed Europarl training data with the preprocessed MOTRA training data + - apply BPE: 10k merge operations (~10.3k vocabulary) for EN-ALL and 11k for DE-ALL + - run fairseq-preprocess to binarise the data + - run fairseq-train for BART-style pretraining + - **SSNMT** + - Tokenise and preprocess the data (MOTRA) + - apply the BPE codes learned during pretraining to the MOTRA train, dev and test sets + - binarise the data using fairseq-preprocess + - load the pretrained model checkpoint and finetune it with the `translation_from_pretrained_bart` task + + Note that the tokenisation and [Byte-Pair Encoding](https://github.com/rsennrich/subword-nmt) (BPE) must remain consistent across the data used for DAE pretraining, LM training, finetuning and the style-transfer model. + +### Preprocessing Data for Style Transfer +An example of preprocessing the data before training: +``` +cd fairseq_cli +python3 preprocess.py --destdir /netscratch/anonymous/datasets/motra-sst/ppd_w_europarl-motra-10k_no_dups/en_es_de/unsup_setup/ \ + --source-lang tr \ + --target-lang og \ + --trainpref /netscratch/anonymous/datasets/motra-preprocessed/en_es_de/train/bpe --validpref /netscratch/anonymous/datasets/motra-preprocessed/en_de/dev/europarl-motra-10k-no-dups/bpe \ + --testpref /netscratch/anonymous/datasets/motra-preprocessed/en_de/test/europarl-motra-10k-no-dups/bpe \ + --srcdict /netscratch/anonymous/datasets/data-bin/europarl-motra/subword-nmt-10k/europarl/dict.txt \ + --tgtdict /netscratch/anonymous/datasets/data-bin/europarl-motra/subword-nmt-10k/europarl/dict.txt \ + --dataset-impl raw \ + --workers 60 ``` +### Train +Use `traincomp.py` to train the system. An example run showing how to train SSNMT for Translationese-to-Original Style Transfer using Joint Training is given below. -* **For large datasets** install [PyArrow](https://arrow.apache.org/docs/python/install.html#using-pip): `pip install pyarrow` -* If you use Docker make sure to increase the shared memory size either with `--ipc=host` or `--shm-size` - as command line options to `nvidia-docker run` . 
+``` +python3 traincomp.py /netscratch/anonymous/datasets/motra-sst/ppd_w_europarl-motra-10k_no_dups/en_es_de/unsup_setup/ \ + --arch transformer \ + --share-all-embeddings --checkpoint-activations \ + --share-decoder-input-output-embed \ + --encoder-embed-dim 512 \ + --decoder-embed-dim 512 \ + --task translation_from_pretrained_bart --langs ", " \ + --update-freq 2 \ + --lr 0.0003 \ + --criterion unsupervised_augmented_label_smoothed_cross_entropy \ + --label-smoothing 0.1 --start-unsup 1200 \ + --dropout 0.2 \ + --weight-decay 0.0001 \ + --optimizer adam \ + --adam-betas '(0.9, 0.9995)' \ + --clip-norm 0.0 \ + --write-dual \ + --lr-scheduler inverse_sqrt \ + --warmup-updates 2000 \ + --dataset-impl raw \ + --decoder-learned-pos --encoder-learned-pos \ + --max-sentences 160 --retrieval intersect \ + --max-source-positions 512 \ + --max-target-positions 512 \ + --skip-invalid-size-inputs-valid-test \ + --max-epoch 30 --keep-best-checkpoints 10 --patience 10 \ + --comp-epochs 30 --save-interval-updates 2 \ + --comparable --margin ratio \ + --verbose --faiss --index ivf \ + --representations dual --faiss-output /netscratch/anonymous/logs/from_en_all_bs@en_newLoss_htr.txt \ + --no-base --comp-log /netscratch/anonymous/logs/en_all_bs@en_newLoss_htr/ \ + --comparable-data Comparable/tr_og.list \ + --sim-measure margin --save-dir /netscratch/anonymous/checkpoints/sst/en_all_bs_no_th@newLoss_htr/ \ + --finetune-from-model /netscratch/anonymous/checkpoints/subword-nmt-10k-bpe-transformer_gpu4_cpu50/checkpoint_best.pt \ + --threshold 0 \ + --wandb-project motra_no_thresh_unsup_en \ + --num-workers 0 \ + --log-format json |tee /netscratch/anonymous/logs/train_en_all_bs@en_newLoss_htr.log -# Getting Started +``` -The [full documentation](https://fairseq.readthedocs.io/) contains instructions -for getting started, training new models and extending fairseq with new model -types and tasks. +Run `train.py -h` for more information. -# Pre-trained models and examples +### Evaluation -We provide pre-trained models and pre-processed, binarized test sets for several tasks listed below, -as well as example training and evaluation commands. +All the scripts for evaluation can be found under the `evaluation/` folder -* [Translation](examples/translation/README.md): convolutional and transformer models are available -* [Language Modeling](examples/language_model/README.md): convolutional and transformer models are available +1. Get the style-transferred outputs a) once with --remove-bpe and b) once without (to compute average perplexities using Fairseq's pretrained TransformerLM). +``` +python3 generate.py /netscratch/anonymous/datasets/data-bin/europarl-motra/subword-nmt-10k/europarl/test_bal/ \ +--task translation \ +--path /checkpoint_best.pt \ +--results-path \ +--beam 5 --source-lang tr --target-lang og --dataset-impl raw +``` +2. Generate intermediate data files. 
+``` +python evaluation/gen_test_data.py --file /netscratch/anonymous/results/generations/unsup/motra-old/712551/generate-test.txt --out_dir /netscratch/anonymous/datasets/motra-preprocessed/en_de/test/unsup-generated/ --name pred_712551.tsv -We also have more detailed READMEs to reproduce results from specific papers: +# combine og file with pred file +cat /netscratch/anonymous/datasets/motra-preprocessed/en_de/test/og.tsv /netscratch/anonymous/datasets/motra-preprocessed/en_de/test/unsup-generated/pred_712551.tsv > /netscratch/anonymous/datasets/motra-preprocessed/en_de/test/gen_tsvs/gen_712551.tsv -* [XLS-R: Self-supervised Cross-lingual Speech Representation Learning at Scale (Babu et al., 2021)](examples/wav2vec/xlsr/README.md) -* [Cross-lingual Retrieval for Iterative Self-Supervised Training (Tran et al., 2020)](examples/criss/README.md) -* [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations (Baevski et al., 2020)](examples/wav2vec/README.md) -* [Unsupervised Quality Estimation for Neural Machine Translation (Fomicheva et al., 2020)](examples/unsupervised_quality_estimation/README.md) -* [Training with Quantization Noise for Extreme Model Compression ({Fan*, Stock*} et al., 2020)](examples/quant_noise/README.md) -* [Neural Machine Translation with Byte-Level Subwords (Wang et al., 2020)](examples/byte_level_bpe/README.md) -* [Multilingual Denoising Pre-training for Neural Machine Translation (Liu et at., 2020)](examples/mbart/README.md) -* [Reducing Transformer Depth on Demand with Structured Dropout (Fan et al., 2019)](examples/layerdrop/README.md) -* [Jointly Learning to Align and Translate with Transformer Models (Garg et al., 2019)](examples/joint_alignment_translation/README.md) -* [Levenshtein Transformer (Gu et al., 2019)](examples/nonautoregressive_translation/README.md) -* [Facebook FAIR's WMT19 News Translation Task Submission (Ng et al., 2019)](examples/wmt19/README.md) -* [RoBERTa: A Robustly Optimized BERT Pretraining Approach (Liu et al., 2019)](examples/roberta/README.md) -* [wav2vec: Unsupervised Pre-training for Speech Recognition (Schneider et al., 2019)](examples/wav2vec/README.md) -* [Mixture Models for Diverse Machine Translation: Tricks of the Trade (Shen et al., 2019)](examples/translation_moe/README.md) -* [Pay Less Attention with Lightweight and Dynamic Convolutions (Wu et al., 2019)](examples/pay_less_attention_paper/README.md) -* [Understanding Back-Translation at Scale (Edunov et al., 2018)](examples/backtranslation/README.md) -* [Classical Structured Prediction Losses for Sequence to Sequence Learning (Edunov et al., 2018)](https://github.com/pytorch/fairseq/tree/classic_seqlevel) -* [Hierarchical Neural Story Generation (Fan et al., 2018)](examples/stories/README.md) -* [Scaling Neural Machine Translation (Ott et al., 2018)](examples/scaling_nmt/README.md) -* [Convolutional Sequence to Sequence Learning (Gehring et al., 2017)](examples/conv_seq2seq/README.md) -* [Language Modeling with Gated Convolutional Networks (Dauphin et al., 2017)](examples/language_model/README.conv.md) +# shuffle the test file +shuf -o /netscratch/anonymous/datasets/motra-preprocessed/en_de/test/gen_tsvs/gen_712551.tsv < /netscratch/anonymous/datasets/motra-preprocessed/en_de/test/gen_tsvs/gen_712551.tsv -# Join the fairseq community +python evaluation/extract_ref_hyp.py --file /netscratch/anonymous/results/generations/unsup/motra-old/712551/generate-test.txt --name 712551.tsv -* Twitter: https://twitter.com/fairseq -* Facebook page: 
https://www.facebook.com/groups/fairseq.users -* Google group: https://groups.google.com/forum/#!forum/fairseq-users +python new/fairseq/evaluation/gen_fsq_ppl_data.py --file /netscratch/anonymous/results/generations/unsup/motra-old/712551_ppl/generate-test.txt --out_dir /netscratch/anonymous/test_perplexity/ --exp 712551 +``` -# License +3. Evaluate LM perplexity -fairseq(-py) is MIT-licensed. -The license applies to the pre-trained models as well. +Note: copy dict.txt from the preprocessed FAIRSEQ_DATA to -# Citation +``` +python3 eval_lm.py /netscratch/anonymous/test_perplexity/712551/ --path /netscratch/anonymous/checkpoints/transformer_lm_en_finetuned/checkpoint_best.pt --quiet --output-word-stats --gen-subset test --max-sentences 500 --skip-invalid-size-inputs-valid-test --dataset-impl raw --fp16 --sample-break-mode eos --context-window 50 +``` + +4. Meausre BERT-Score +``` +python3 evaluation/compute_bertscore.py --file /netscratch/anonymous/datasets/motra-preprocessed/en_de/test/src_hyp/712551.tsv --model roberta-base +``` -Please cite as: +5. Run Translationese Classifier +``` +python3 evaluation/binary_classification.py --model /netscratch/anonymous/checkpoints/binaryClassification_balanced/ --test /netscratch/anonymous/datasets/motra-preprocessed/en_de/test/gen_tsvs/gen_712551.tsv +``` -``` bibtex -@inproceedings{ott2019fairseq, - title = {fairseq: A Fast, Extensible Toolkit for Sequence Modeling}, - author = {Myle Ott and Sergey Edunov and Alexei Baevski and Angela Fan and Sam Gross and Nathan Ng and David Grangier and Michael Auli}, - booktitle = {Proceedings of NAACL-HLT 2019: Demonstrations}, - year = {2019}, -} +6. Run Qualitative Analysis +``` +python3 evaluation/qualitative_analysis.py --file /netscratch/anonymous/datasets/motra-preprocessed/en_de/test/src_hyp/712551.tsv ``` + + + + + +![Model](fairseq.gif) diff --git a/environment.yml b/environment.yml new file mode 100644 index 0000000000..b7867cf606 --- /dev/null +++ b/environment.yml @@ -0,0 +1,251 @@ +name: fsq012 +channels: + - pytorch + - nvidia + - conda-forge + - defaults +dependencies: + - _libgcc_mutex=0.1=main + - _openmp_mutex=5.1=1_gnu + - binutils_impl_linux-64=2.33.1=he6710b0_7 + - binutils_linux-64=2.33.1=h9595d00_15 + - blas=1.0=mkl + - brotlipy=0.7.0=py38h27cfd23_1003 + - bzip2=1.0.8=h7b6447c_0 + - ca-certificates=2023.5.7=hbcca054_0 + - certifi=2023.5.7=pyhd8ed1ab_0 + - cffi=1.15.1=py38h5eee18b_3 + - charset-normalizer=2.0.4=pyhd3eb1b0_0 + - cryptography=39.0.1=py38h9ce1e76_0 + - cuda=11.7.1=0 + - cuda-cccl=11.7.91=0 + - cuda-command-line-tools=11.7.1=0 + - cuda-compiler=11.7.1=0 + - cuda-cudart=11.7.99=0 + - cuda-cudart-dev=11.7.99=0 + - cuda-cuobjdump=11.7.91=0 + - cuda-cupti=11.7.101=0 + - cuda-cuxxfilt=11.7.91=0 + - cuda-demo-suite=12.1.55=0 + - cuda-documentation=12.1.55=0 + - cuda-driver-dev=11.7.99=0 + - cuda-gdb=12.1.55=0 + - cuda-libraries=11.7.1=0 + - cuda-libraries-dev=11.7.1=0 + - cuda-memcheck=11.8.86=0 + - cuda-nsight=12.1.55=0 + - cuda-nsight-compute=12.1.0=0 + - cuda-nvcc=11.7.99=0 + - cuda-nvdisasm=12.1.55=0 + - cuda-nvml-dev=11.7.91=0 + - cuda-nvprof=12.1.55=0 + - cuda-nvprune=11.7.91=0 + - cuda-nvrtc=11.7.99=0 + - cuda-nvrtc-dev=11.7.99=0 + - cuda-nvtx=11.7.91=0 + - cuda-nvvp=12.1.55=0 + - cuda-runtime=11.7.1=0 + - cuda-sanitizer-api=12.1.55=0 + - cuda-toolkit=11.7.1=0 + - cuda-tools=11.7.1=0 + - cuda-visual-tools=11.7.1=0 + - cudatoolkit=11.7.0=hd8887f6_10 + - faiss=1.7.2=py38cuda112h5d0fea0_0_cuda + - faiss-gpu=1.7.2=h788eb59_3 + - ffmpeg=4.3=hf484d3e_0 + - 
flit-core=3.6.0=pyhd3eb1b0_0 + - freetype=2.12.1=h4a9f257_0 + - gcc_impl_linux-64=7.3.0=habb00fd_1 + - gcc_linux-64=7.3.0=h553295d_15 + - gds-tools=1.6.0.25=0 + - giflib=5.2.1=h5eee18b_3 + - gmp=6.2.1=h295c915_3 + - gnutls=3.6.15=he1e5248_0 + - gxx_impl_linux-64=7.3.0=hdf63c60_1 + - gxx_linux-64=7.3.0=h553295d_15 + - idna=3.4=py38h06a4308_0 + - intel-openmp=2021.4.0=h06a4308_3561 + - jpeg=9e=h5eee18b_1 + - lame=3.100=h7b6447c_0 + - lcms2=2.12=h3be6417_0 + - ld_impl_linux-64=2.33.1=h53a641e_7 + - lerc=3.0=h295c915_0 + - libblas=3.9.0=12_linux64_mkl + - libcublas=11.10.3.66=0 + - libcublas-dev=11.10.3.66=0 + - libcufft=10.7.2.124=h4fbf590_0 + - libcufft-dev=10.7.2.124=h98a8f43_0 + - libcufile=1.6.0.25=0 + - libcufile-dev=1.6.0.25=0 + - libcurand=10.3.2.56=0 + - libcurand-dev=10.3.2.56=0 + - libcusolver=11.4.0.1=0 + - libcusolver-dev=11.4.0.1=0 + - libcusparse=11.7.4.91=0 + - libcusparse-dev=11.7.4.91=0 + - libdeflate=1.17=h5eee18b_0 + - libfaiss=1.7.2=cuda112hc9ed507_0_cuda + - libfaiss-avx2=1.7.2=cuda112h1234567_0_cuda + - libffi=3.4.2=h6a678d5_6 + - libgcc=7.2.0=h69d50b8_2 + - libgcc-ng=11.2.0=h1234567_1 + - libgomp=11.2.0=h1234567_1 + - libiconv=1.16=h7f8727e_2 + - libidn2=2.3.2=h7f8727e_0 + - liblapack=3.9.0=12_linux64_mkl + - libnpp=11.7.4.75=0 + - libnpp-dev=11.7.4.75=0 + - libnvjpeg=11.8.0.2=0 + - libnvjpeg-dev=11.8.0.2=0 + - libpng=1.6.39=h5eee18b_0 + - libstdcxx-ng=11.2.0=h1234567_1 + - libtasn1=4.16.0=h27cfd23_0 + - libtiff=4.5.0=h6a678d5_2 + - libunistring=0.9.10=h27cfd23_0 + - libwebp=1.2.4=h11a3e52_1 + - libwebp-base=1.2.4=h5eee18b_1 + - lz4-c=1.9.4=h6a678d5_0 + - mkl=2021.4.0=h06a4308_640 + - mkl-service=2.4.0=py38h7f8727e_0 + - mkl_fft=1.3.1=py38hd3c417c_0 + - mkl_random=1.2.2=py38h51133e4_0 + - ncurses=6.4=h6a678d5_0 + - nettle=3.7.3=hbbd107a_1 + - nsight-compute=2023.1.0.15=0 + - numpy-base=1.23.5=py38h31eccc5_0 + - openh264=2.1.1=h4ff587b_0 + - openssl=1.1.1t=h7f8727e_0 + - pillow=9.4.0=py38h6a678d5_0 + - pycparser=2.21=pyhd3eb1b0_0 + - pyopenssl=23.0.0=py38h06a4308_0 + - pysocks=1.7.1=py38h06a4308_0 + - python=3.8.16=h7a1cb2a_2 + - python_abi=3.8=2_cp38 + - pytorch=1.13.1=py3.8_cuda11.7_cudnn8.5.0_0 + - pytorch-cuda=11.7=h67b0de4_1 + - pytorch-mutex=1.0=cuda + - readline=8.2=h5eee18b_0 + - requests=2.28.1=py38h06a4308_0 + - six=1.16.0=pyhd3eb1b0_1 + - sqlite=3.40.1=h5082296_0 + - tk=8.6.12=h1ccaba5_0 + - torchvision=0.14.1=py38_cu117 + - typing_extensions=4.4.0=py38h06a4308_0 + - urllib3=1.26.14=py38h06a4308_0 + - xz=5.2.10=h5eee18b_1 + - zlib=1.2.13=h5eee18b_0 + - zstd=1.5.2=ha4553b6_0 + - pip: + - accelerate==0.19.0 + - aiohttp==3.8.4 + - aiosignal==1.3.1 + - antlr4-python3-runtime==4.8 + - apex==0.1 + - appdirs==1.4.4 + - async-timeout==4.0.2 + - attrs==22.2.0 + - bert-score==0.3.13 + - bitarray==2.7.3 + - blis==0.7.9 + - catalogue==2.0.8 + - click==8.1.3 + - colorama==0.4.6 + - confection==0.0.4 + - contourpy==1.0.7 + - cycler==0.11.0 + - cymem==2.0.7 + - cython==0.29.33 + - datasets==2.10.1 + - de-dep-news-trf==3.5.0 + - deepspeed==0.8.2 + - dill==0.3.6 + - docker-pycreds==0.4.0 + - easynmt==2.0.2 + - en-core-web-trf==3.5.0 + - evaluate==0.4.0 + - fairscale==0.4.13 + - fairseq==0.12.2 + - fastbpe==0.1.0 + - fasttext==0.9.2 + - filelock==3.9.1 + - fonttools==4.39.4 + - frozenlist==1.3.3 + - fsspec==2023.3.0 + - gensim==4.3.1 + - gitdb==4.0.10 + - gitpython==3.1.31 + - hjson==3.1.0 + - huggingface-hub==0.13.2 + - hydra-core==1.0.7 + - importlib-resources==5.12.0 + - jinja2==3.1.2 + - joblib==1.2.0 + - kiwisolver==1.4.4 + - langcodes==3.3.0 + - langdetect==1.0.9 + - 
lxml==4.9.2 + - markupsafe==2.1.3 + - matplotlib==3.7.1 + - multidict==6.0.4 + - multiprocess==0.70.14 + - murmurhash==1.0.9 + - ninja==1.11.1 + - nltk==3.8.1 + - numpy==1.24.2 + - nvidia-cublas-cu11==11.10.3.66 + - nvidia-cuda-nvrtc-cu11==11.7.99 + - nvidia-cuda-runtime-cu11==11.7.99 + - nvidia-cudnn-cu11==8.5.0.96 + - omegaconf==2.0.6 + - packaging==23.0 + - pandas==1.5.3 + - pathtools==0.1.2 + - pathy==0.10.1 + - pip==23.1.2 + - portalocker==2.7.0 + - preshed==3.0.8 + - protobuf==4.22.1 + - psutil==5.9.4 + - py-cpuinfo==9.0.0 + - pyarrow==11.0.0 + - pybind11==2.10.4 + - pydantic==1.10.6 + - pyparsing==3.0.9 + - python-dateutil==2.8.2 + - pytz==2022.7.1 + - pyyaml==6.0 + - regex==2022.10.31 + - responses==0.18.0 + - sacrebleu==2.3.1 + - sacremoses==0.0.53 + - scikit-learn==1.2.1 + - scipy==1.10.1 + - sentencepiece==0.1.97 + - sentry-sdk==1.25.0 + - setproctitle==1.3.2 + - setuptools==67.8.0 + - smart-open==6.3.0 + - smmap==5.0.0 + - spacy==3.5.3 + - spacy-alignments==0.9.0 + - spacy-legacy==3.0.12 + - spacy-loggers==1.0.4 + - spacy-transformers==1.2.4 + - srsly==2.4.6 + - tabulate==0.9.0 + - thinc==8.1.10 + - threadpoolctl==3.1.0 + - tokenizers==0.13.2 + - torch==1.13.1 + - torchaudio==0.13.1 + - tqdm==4.64.1 + - transformers==4.27.0 + - typer==0.7.0 + - typing-extensions==4.5.0 + - wandb==0.15.3 + - wasabi==1.1.2 + - wheel==0.40.0 + - xxhash==3.2.0 + - yarl==1.8.2 + - zipp==3.15.0 +prefix: /netscratch/jalota/miniconda3/envs/fsq012 diff --git a/evaluation/binary_classification.py b/evaluation/binary_classification.py new file mode 100644 index 0000000000..9733472b7f --- /dev/null +++ b/evaluation/binary_classification.py @@ -0,0 +1,207 @@ +import torch +from datasets import load_dataset, load_metric, ClassLabel +from transformers import AutoTokenizer, BertForSequenceClassification +import argparse +import pandas as pd +import numpy as np +from pathlib import Path +import csv +from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer +from datasets import disable_caching, Dataset +disable_caching() +import logging +logging.disable(logging.INFO) + +tokenizer = AutoTokenizer.from_pretrained('bert-base-cased', do_lower_case=False, use_fast=True) # True +model = BertForSequenceClassification.from_pretrained('bert-base-cased', num_labels=2) + +labels = ClassLabel(names=['0', '1']) + +def preprocess_function(examples): + result = tokenizer(examples['text'], truncation=True, padding='max_length', max_length=512) + result['labels'] = [labels.str2int(str(label)) if label is not None else None for label in examples["label"]] + + return result + +def tokenize_text(examples): + result = tokenizer(str(examples["text"]),truncation=True, max_length=512, padding='max_length', return_overflowing_tokens=True) + + sample_map = result.pop("overflow_to_sample_mapping") + for key, values in examples.items(): + result[key] = [values[i] for i in sample_map] + return result + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description='run binary classifer') + parser.add_argument("--train", default="/netscratch/anonymous/datasets/motra-preprocessed/en_de/train/train_bt_bal.tsv") # based on *.tok.norm.true.txt - equal examples in both files! + parser.add_argument("--dev", default="/netscratch/anonymous/datasets/motra-preprocessed/en_de/dev/dev_bt.tsv") # based on translated.tok.norm.true.txt and original.tok.norm.true.txt -- equal examples in both files! 
+ parser.add_argument("--test", default="/netscratch/anonymous/datasets/motra-preprocessed/en_de/test/test_bt_bal.tsv") # based on translated.tok.norm.true.txt and original.tok.norm.true.txt - equal examples in both files! + parser.add_argument("--model", default=None) + parser.add_argument("--out_dir", default="/netscratch/anonymous/results/binaryClassification_balanced_bt_og") + args = parser.parse_args() + # https://discuss.huggingface.co/t/using-trainer-at-inference-time/9378/7 + + print(args.test) + print(args.model) + + if args.model: + test_df = pd.read_csv(args.test, delimiter="\t", names=['text', 'label'], quoting=csv.QUOTE_NONE) + test_dataset = Dataset.from_pandas(test_df) + else: + dataset = load_dataset("csv", delimiter="\t", column_names=['text', 'label'], data_files={"train": args.train, "test": args.test, "dev": args.dev}) # streaming=True + batch_size = 16 + metric_name = "accuracy" # "f1" + metric = load_metric(metric_name) + + def compute_metrics(eval_pred): + predictions, labels = eval_pred + predictions = np.argmax(predictions, axis=1) + return metric.compute(predictions=predictions, references=labels) + + if args.model is None: + # print(dataset["train"]['text']) + encoded_dataset = dataset.map(lambda x: tokenizer(x["text"], truncation=True, padding='max_length', max_length=512), batched=True, batch_size=2000) # for streaming dataset which is an IterableDataset + # encoded_dataset = dataset.map(preprocess_function, batched=True, num_proc=30) + print("done with encoding") + + args = TrainingArguments( + args.out_dir, + evaluation_strategy = "epoch", + save_strategy = "epoch", + learning_rate=2e-5, + per_device_train_batch_size=batch_size, + per_device_eval_batch_size=batch_size, + num_train_epochs=3, + max_steps=1000, + weight_decay=0.01, + load_best_model_at_end=True, + metric_for_best_model=metric_name, + save_total_limit = 2, + push_to_hub=False,) + + trainer = Trainer( + model, + args, + train_dataset=encoded_dataset["train"], + eval_dataset=encoded_dataset["dev"], + tokenizer=tokenizer, + compute_metrics=compute_metrics + ) + + trainer.train() + trainer.evaluate() + print(trainer.predict(encoded_dataset["test"])) + savepath = f"{args.out_dir}/saved_model/" + Path(savepath).mkdir(parents=True, exist_ok=True) + trainer.save_model(savepath) + # "/netscratch/anonymous/checkpoints/binaryClassification_balanced_bt_og/" + + else: + Path(args.out_dir).mkdir(parents=True, exist_ok=True) + # for motra -- uncomment the line below with lambda x + # encoded_dataset = test_dataset.map(lambda x: tokenizer(str(x["text"]), truncation=True, padding='max_length', max_length=512), batched=True, batch_size=2000) + # encoded_dataset = dataset.map(preprocess_function, batched=True) + encoded_dataset = test_dataset.map(tokenize_text, batched=True, batch_size=100) + encoded_dataset = encoded_dataset.filter(lambda example: example['label'] is not None) + #dataset["test"].map(preprocess_function, batched=True) + # print(f"encoded_dataset['test']['label']: {encoded_dataset['test']['label']}") + model = BertForSequenceClassification.from_pretrained(args.model) + + # arguments for Trainer + test_args = TrainingArguments( + output_dir = args.out_dir, + do_train = False, + do_predict = True, + per_device_eval_batch_size = batch_size, + dataloader_drop_last = False + ) + + # init trainer + trainer = Trainer( + model = model, + args = test_args, + compute_metrics = compute_metrics) + + test_results = trainer.predict(encoded_dataset) # ["test"] + print(test_results) + + + +# i think not . 
+# Mr President . +# Thank you . +# Thank you for your attention . +# That is wrong . + +# ! +# But it is not . +# co-rapporteur . +# i ask for your support . +# i do not believe so . +# i do not know . +# i just wanted to point that out . +# It does not . +# It is as simple as that . +# It is not . +# i welcome that . +# Let me give you an example . +# Let me now turn to the amendments . +# Mr President . +# Mr President , on a point of order . +# Mr President , on a point of order . +# No ! +# No . +# No . +# Nothing . +# Our position is quite clear . +# So far , so good . +# Thank you . +# Thank you . +# Thank you , Commissioner . +# Thank you for your attention . +# Thank you for your cooperation . +# Thank you , Mr President . +# Thank you very much . +# Thank you very much . +# Thank you very much , Mr President . +# That is going too far . +# That is my first point . +# That is my first point . +# That is not acceptable . +# That is not acceptable . +# That is not correct . +# That is not the case . +# That is not the case . +# That is right and proper . +# That is the first issue . +# That is the objective . +# That is the reality . +# That is very important . +# That is what this is about . +# The Commission is the guardian of the Treaties . +# The Commission welcomes this . +# The list goes on . +# There is still much to be done , however . +# The same applies to the European Union . +# This is completely unacceptable . +# This is not acceptable . +# This is unacceptable . +# This is unacceptable . +# This is unacceptable . +# We all know that . +# We disagree . +# We know that . +# We know that . +# We should not forget that . +# Why ? +# Why ? +# Why ? +# Why ? +# Why ? +# Why ? +# Why ? +# Why is that the case ? +# Why is this so ? +# Why is this so important ? +# Why not ? 
diff --git a/evaluation/compute_bertscore.py b/evaluation/compute_bertscore.py new file mode 100644 index 0000000000..cb9c53d06d --- /dev/null +++ b/evaluation/compute_bertscore.py @@ -0,0 +1,33 @@ +import evaluate +import argparse +import pandas as pd + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description='compute bertscore given a tsv file with src and hypothesis') + parser.add_argument("--file", default="/netscratch/anonymous/datasets/motra-preprocessed/en_de/test/src_hyp/699528.txt") + parser.add_argument("--batch-size", default=128, type=int) + parser.add_argument("--model", default='roberta-base') + args = parser.parse_args() + print(args.model) + print(args.file) + count = 0 + c = 0 + df = pd.read_csv(args.file, sep="\t", names=['source', 'hypothesis'], header=0) + for index, row in df.iterrows(): + if row['source'] == row['hypothesis']: + count += 1 + else: + if c < 10: + print(f"row['source']: {row['source']}, row['hypothesis]: {row['hypothesis']}") + c += 1 + print(f"count: {count}") + print(f"test set size: {len(df)}") + + bertscore = evaluate.load("bertscore") + results = bertscore.compute(predictions=df.hypothesis.tolist(), references=df.source.tolist(), model_type=args.model, lang='en') + + print(f"average precision: {sum(results['precision'])/len(results['precision'])}") + print(f"average recall: {sum(results['recall'])/len(results['recall'])}") + print(f"average f1: {sum(results['f1'])/len(results['f1'])}") + \ No newline at end of file diff --git a/evaluation/compute_perplexity.py b/evaluation/compute_perplexity.py new file mode 100644 index 0000000000..e1e5e0701d --- /dev/null +++ b/evaluation/compute_perplexity.py @@ -0,0 +1,19 @@ +import evaluate +import argparse +import pandas as pd + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description='compute perplexity of generated sentences') + parser.add_argument("--file", default="/netscratch/anonymous/datasets/motra-preprocessed/en_de/test/unsup-generated/pred_no_th_699528.tsv") + parser.add_argument("--batch-size", default=64, type=int) + parser.add_argument("--model", default='/netscratch/anonymous/checkpoints/gpt2-finetuned-motra/') + args = parser.parse_args() + print(args.model) + print(args.file) + + df = pd.read_csv(args.file, sep="\t", names=['text', 'label']) + + perplexity = evaluate.load("perplexity", module_type="measurement") + ppl_results = perplexity.compute(data=df['text'].tolist(), model_id=args.model, batch_size=args.batch_size, add_start_token=True) + print(f"perplexity: {ppl_results['mean_perplexity']}") \ No newline at end of file diff --git a/evaluation/extract_ref_hyp.py b/evaluation/extract_ref_hyp.py new file mode 100644 index 0000000000..b4c17de5a6 --- /dev/null +++ b/evaluation/extract_ref_hyp.py @@ -0,0 +1,47 @@ +import os +import argparse +import pandas as pd +from pathlib import Path + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description='extract reference and hypothesis from model generation') + parser.add_argument("--file", default="/netscratch/anonymous/results/generations/unsup/motra-old/699517/generate-test.txt") + parser.add_argument("--out_dir", default="/netscratch/jalota/datasets/motra-preprocessed/en_de/test/src_hyp/") + parser.add_argument("--name", default="699517.tsv") + args = parser.parse_args() + contains_dup = False + # if "bt_test" in args.file: + # contains_dup = True + + # gen_modifiedComparable_translated_test.txt + # gen_no_threshold.txt + # gen_w_threshold_translated_test.txt + + 
Path(args.out_dir).mkdir(parents=True, exist_ok=True) + + srcs = [] + hyps = [] + + with open(args.file, encoding="utf-8") as f: + lines = f.readlines() + for i, line in enumerate(lines): + if line.startswith("H-"): + line = line.split() + line = " ".join(line[2:]) + hyps.append(line) + elif line.startswith("S-"): + line = line.split() + line = " ".join(line[1:]) + srcs.append(line) + else: + continue + print(len(srcs), len(hyps)) + df = pd.DataFrame( + { + 'source': srcs, + 'hypothesis': hyps + } + ) + df.to_csv(args.out_dir+args.name, sep='\t', index=False) + \ No newline at end of file diff --git a/evaluation/gen_fsq_ppl_data.py b/evaluation/gen_fsq_ppl_data.py new file mode 100644 index 0000000000..9896dcb7a5 --- /dev/null +++ b/evaluation/gen_fsq_ppl_data.py @@ -0,0 +1,70 @@ +import os +import argparse +import pandas as pd +from pathlib import Path + + +if __name__ == "__main__": + """ + Run: sed -i '1,53d' gen.txt to remove logger outputs before passing the generated file. + sed '1,56d' gen.txt > new_gen.txt + """ + parser = argparse.ArgumentParser(description='generate test data for binary classification from fairseq-generate output') + parser.add_argument("--file", default="/home/jalota/gen_w_threshold_translated_test.txt") + parser.add_argument("--out_dir", default="/netscratch/jalota/test_perplexity/") + parser.add_argument("--name", default="test") + parser.add_argument("--exp", default="712684") + args = parser.parse_args() + contains_dup = False + + path = args.out_dir + args.exp + + Path(path).mkdir(parents=True, exist_ok=True) + + with open(args.file, encoding="utf-8") as f: + lines = f.readlines() + with open(f"{path}/{args.name}", "w") as of: + if not contains_dup: + count = 0 + for i, line in enumerate(lines): + if line.startswith("H-"): + line = line.split() + line = " ".join(line[2:]) + # tr = lines[i-2].split() + # tr = " ".join(tr[1:]) + # if tr.strip() == "!" or tr.strip() == "co-rapporteur ." or tr.strip() == "Thank you very much for your attention .": + # print(tr) + # continue + of.write(f"{line}") + of.write("\n") + count += 1 + else: + continue + print(count) + else: + i = 0 + bt2og_like = dict() + while i < len(lines): + if lines[i].startswith("T-"): + tr = lines[i].split() + tr = " ".join(tr[1:]) + i += 2 + if i < len(lines) and lines[i].startswith("D-"): + og_like = lines[i].split() + og_like = " ".join(og_like[2:]) + + if tr not in bt2og_like: + bt2og_like[tr] = og_like + i += 1 + ogl_list = bt2og_like.values() + print(f"len ogl_list: {len(ogl_list)}") + for ogl in ogl_list: + of.write(f"{ogl}\t1") + of.write("\n") + + + + + + + diff --git a/evaluation/gen_test_data.py b/evaluation/gen_test_data.py new file mode 100644 index 0000000000..36888a6c53 --- /dev/null +++ b/evaluation/gen_test_data.py @@ -0,0 +1,74 @@ +import os +import argparse +import pandas as pd +from pathlib import Path + + +if __name__ == "__main__": + """ + Run: sed -i '1,53d' gen.txt to remove logger outputs before passing the generated file. 
+ sed '1,56d' gen.txt > new_gen.txt + """ + parser = argparse.ArgumentParser(description='generate test data for binary classification from fairseq-generate output') + parser.add_argument("--file", default="/home/anonymous/gen_w_threshold_translated_test.txt") + parser.add_argument("--out_dir", default="/netscratch/anonymous/datasets/motra-preprocessed/en_de/test/generated/") + parser.add_argument("--name", default="pred_test.tsv") + args = parser.parse_args() + contains_dup = False + # if "bt_test" in args.file: + # contains_dup = True + + # gen_modifiedComparable_translated_test.txt + # gen_no_threshold.txt + # gen_w_threshold_translated_test.txt + + Path(args.out_dir).mkdir(parents=True, exist_ok=True) + + with open(args.file, encoding="utf-8") as f: + lines = f.readlines() + with open(args.out_dir+args.name, "w") as of: + if not contains_dup: + count = 0 + for i, line in enumerate(lines): + if line.startswith("H-"): + line = line.split() + line = " ".join(line[2:]) + # tr = lines[i-2].split() + # tr = " ".join(tr[1:]) + # if tr.strip() == "!" or tr.strip() == "co-rapporteur ." or tr.strip() == "Thank you very much for your attention .": + # print(tr) + # continue + # if len(line.split()) < 510: + of.write(f"{line}\t1") + of.write("\n") + count += 1 + else: + continue + print(count) + else: + i = 0 + bt2og_like = dict() + while i < len(lines): + if lines[i].startswith("T-"): + tr = lines[i].split() + tr = " ".join(tr[1:]) + i += 2 + if i < len(lines) and lines[i].startswith("D-"): + og_like = lines[i].split() + og_like = " ".join(og_like[2:]) + + if tr not in bt2og_like: + bt2og_like[tr] = og_like + i += 1 + ogl_list = bt2og_like.values() + print(f"len ogl_list: {len(ogl_list)}") + for ogl in ogl_list: + of.write(f"{ogl}\t1") + of.write("\n") + + + + + + + diff --git a/evaluation/qualitative_analysis.py b/evaluation/qualitative_analysis.py new file mode 100644 index 0000000000..5e6a2d6813 --- /dev/null +++ b/evaluation/qualitative_analysis.py @@ -0,0 +1,69 @@ +import argparse +import spacy +from typing import List +import pandas as pd + +allowed_postags=['NOUN', 'ADJ', 'VERB', 'ADV'] # or any other types +nlp = spacy.load('en_core_web_trf') +# nlp = spacy.load('de_dep_news_trf') + +def token_filter(token): + return (token.pos_ in allowed_postags) & (not (token.is_punct | token.is_space | + token.is_stop)) # | len(token.text) <= 2 + +def type_token_ratio(text): + if not text.strip(): + raise ValueError + tokens = text.split() + types = set(tokens) + return len(types) / len(tokens) + +def lexical_density(all_docs: List[str]): + content_words = 0 + total_words = 0 + all_docs = [str(doc) for doc in all_docs] + + for doc in nlp.pipe(all_docs): + content_toks = [token.lemma_ for token in doc if token_filter(token)] + # print(content_toks) + content_words += len(content_toks) + total_words += len(doc) + + return (content_words/total_words)*100 + +def length_variety(src, hyp): + return abs(len(str(src))-len(str(hyp))) / len(str(src)) + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='run qualitative analysis') + parser.add_argument("--file", default="/netscratch/anonymous/datasets/motra-preprocessed/en_de/test/src_hyp/699528.txt") + parser.add_argument("--translated", action='store_true') + args = parser.parse_args() + print(args.file) + + if not args.translated: + df = pd.read_csv(args.file, sep="\t", names=['source', 'hypothesis'], header=0) + + df['ttr'] = df.apply(lambda row : type_token_ratio(str(row['hypothesis'])), axis = 1) + + df['lv'] = df.apply(lambda row : 
length_variety(row['source'], row['hypothesis']), axis = 1) + + print(f"AVG TTR: {df.loc[:, 'ttr'].mean()}") + print(f"AVG lexical density: {lexical_density(df['hypothesis'].tolist())}") + print(f"AVG length variety: {df.loc[:, 'lv'].mean()}") + + else: + df = pd.read_csv(args.file, sep="\t", names=['text', 'label']) + + df['ttr'] = df.apply(lambda row : type_token_ratio(row['text']), axis = 1) + + print(f"AVG TTR: {df.loc[:, 'ttr'].mean()}") + print(f"AVG lexical density: {lexical_density(df['text'].tolist())}") + + + + + + + + diff --git a/evaluation/remove_duplicates.py b/evaluation/remove_duplicates.py new file mode 100644 index 0000000000..bac91c92cf --- /dev/null +++ b/evaluation/remove_duplicates.py @@ -0,0 +1,70 @@ +import os +import argparse +import pandas as pd +from pathlib import Path +import numpy as np + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description='generate test data for binary classification from fairseq-generate output') + parser.add_argument("--file", default="/netscratch/anonymous/datasets/motra-preprocessed/en_de/test/generated/gen_test_59030.tsv") + parser.add_argument("--og_file", default="/netscratch/anonymous/datasets/motra-preprocessed/en_de/test/test.tsv") + parser.add_argument("--out_dir", default="/netscratch/anonymous/datasets/motra-preprocessed/en_de/test/generated/") + args = parser.parse_args() + + seen = set() + dup = set() + uniq = [] + df = pd.read_csv(args.file, sep='\t', skipinitialspace=True, header=None, names=['text', 'label']) + # df.sort(columns=[0], inplace=True) + df.sort_values(by='text', inplace=True) + for _, row in df.iterrows(): + if row[0] in seen and row[1] == 0: + print(row[0]) + seen.remove(row[0]) + dup.add(row[0]) + elif row[0] in seen and row[1] ==1: + print(f"label 1: {row[0]}") + else: + seen.add(row[0]) + + for _, row in df.iterrows(): + if row[0] not in dup: + uniq.append(row) + + print(len(uniq)) + df2 = pd.DataFrame(uniq) + print(df2.head()) + df2.to_csv(args.out_dir+"no_dup_gen_test_59030.tsv", header=None, index=False, sep='\t') + + print(f"num examples in each label: {df2.groupby('label').size()}") + + print(len(dup)) + + df_og = pd.read_csv(args.og_file, sep='\t', skipinitialspace=True, header=None, names=['text', 'label']) + fin_test = [] + count = 0 + for _, row in df_og.iterrows(): + if row[0] not in dup: # removes original sent. found in dup + fin_test.append(row) + else: + count +=1 + + print(f"count: {count}") + dff = pd.DataFrame(fin_test) + to_remove = np.random.choice(dff[dff['label']==1].index,size=count,replace=False) # remove equal number of translated + dff.drop(to_remove, inplace=True) + print(len(dff)) + dff.to_csv(args.out_dir+"modified_test_59030.tsv", header=None, index=False, sep='\t') + print(f"num examples in each label: {dff.groupby('label').size()}") + + + + +################# +""""" +dev.tsv -> original.tsv translated.tsv (comparable) +test.tsv -> original.tsv translated.tsv (comparable) +""""" + + + diff --git a/evaluation/run_glue_no_trainer.py b/evaluation/run_glue_no_trainer.py new file mode 100644 index 0000000000..fedbd39cda --- /dev/null +++ b/evaluation/run_glue_no_trainer.py @@ -0,0 +1,652 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Finetuning a 🤗 Transformers model for sequence classification on GLUE.""" +import argparse +import json +import logging +import math +import os +import random +from pathlib import Path + +import datasets +import evaluate +import torch +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import set_seed +from datasets import load_dataset +from huggingface_hub import Repository, create_repo +from torch.utils.data import DataLoader +from tqdm.auto import tqdm +from datasets import Features,Value,ClassLabel + +import transformers +from transformers import ( + AutoConfig, + AutoModelForSequenceClassification, + AutoTokenizer, + DataCollatorWithPadding, + PretrainedConfig, + SchedulerType, + default_data_collator, + get_scheduler, +) +from transformers.utils import check_min_version, get_full_repo_name, send_example_telemetry +from transformers.utils.versions import require_version + + +# Will error if the minimal version of Transformers is not installed. Remove at your own risks. +check_min_version("4.30.0.dev0") + +logger = get_logger(__name__) + +require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt") + +task_to_keys = { + "cola": ("sentence", None), + "mnli": ("premise", "hypothesis"), + "mrpc": ("sentence1", "sentence2"), + "qnli": ("question", "sentence"), + "qqp": ("question1", "question2"), + "rte": ("sentence1", "sentence2"), + "sst2": ("sentence", None), + "stsb": ("sentence1", "sentence2"), + "wnli": ("sentence1", "sentence2"), +} + + +def parse_args(): + parser = argparse.ArgumentParser(description="Finetune a transformers model on a text classification task") + parser.add_argument( + "--task_name", + type=str, + default=None, + help="The name of the glue task to train on.", + choices=list(task_to_keys.keys()), + ) + parser.add_argument( + "--train_file", type=str, default=None, help="A csv or a json file containing the training data." + ) + parser.add_argument( + "--validation_file", type=str, default=None, help="A csv or a json file containing the validation data." + ) + parser.add_argument( + "--max_length", + type=int, + default=128, + help=( + "The maximum total input sequence length after tokenization. Sequences longer than this will be truncated," + " sequences shorter will be padded if `--pad_to_max_length` is passed." + ), + ) + parser.add_argument( + "--pad_to_max_length", + action="store_true", + help="If passed, pad all samples to `max_length`. 
Otherwise, dynamic padding is used.", + ) + parser.add_argument( + "--model_name_or_path", + type=str, + help="Path to pretrained model or model identifier from huggingface.co/models.", + required=True, + ) + parser.add_argument( + "--use_slow_tokenizer", + action="store_true", + help="If passed, will use a slow tokenizer (not backed by the 🤗 Tokenizers library).", + ) + parser.add_argument( + "--per_device_train_batch_size", + type=int, + default=8, + help="Batch size (per device) for the training dataloader.", + ) + parser.add_argument( + "--per_device_eval_batch_size", + type=int, + default=8, + help="Batch size (per device) for the evaluation dataloader.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=5e-5, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.") + parser.add_argument("--num_train_epochs", type=int, default=3, help="Total number of training epochs to perform.") + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--lr_scheduler_type", + type=SchedulerType, + default="linear", + help="The scheduler type to use.", + choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"], + ) + parser.add_argument( + "--num_warmup_steps", type=int, default=0, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument("--output_dir", type=str, default=None, help="Where to store the final model.") + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument( + "--hub_model_id", type=str, help="The name of the repository to keep in sync with the local `output_dir`." + ) + parser.add_argument("--hub_token", type=str, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--checkpointing_steps", + type=str, + default=None, + help="Whether the various states should be saved at the end of every n steps, or 'epoch' for each epoch.", + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help="If the training should continue from a checkpoint folder.", + ) + parser.add_argument( + "--with_tracking", + action="store_true", + help="Whether to enable experiment trackers for logging.", + ) + parser.add_argument( + "--report_to", + type=str, + default="all", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,' + ' `"wandb"`, `"comet_ml"` and `"clearml"`. Use `"all"` (default) to report to all integrations.' + "Only applicable when `--with_tracking` is passed." 
+ ), + ) + parser.add_argument( + "--ignore_mismatched_sizes", + action="store_true", + help="Whether or not to enable to load a pretrained model whose head dimensions are different.", + ) + args = parser.parse_args() + + # Sanity checks + if args.task_name is None and args.train_file is None and args.validation_file is None: + raise ValueError("Need either a task name or a training/validation file.") + else: + if args.train_file is not None: + extension = args.train_file.split(".")[-1] + assert extension in ["csv", "json"], "`train_file` should be a csv or a json file." + if args.validation_file is not None: + extension = args.validation_file.split(".")[-1] + assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file." + + if args.push_to_hub: + assert args.output_dir is not None, "Need an `output_dir` to create a repo when `--push_to_hub` is passed." + + return args + + +def main(): + args = parse_args() + # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The + # information sent is the one passed as arguments along with your Python/PyTorch versions. + send_example_telemetry("run_glue_no_trainer", args) + + # Initialize the accelerator. We will let the accelerator handle device placement for us in this example. + # If we're using tracking, we also need to initialize it here and it will by default pick up all supported trackers + # in the environment + accelerator = ( + Accelerator(log_with=args.report_to, logging_dir=args.output_dir) if args.with_tracking else Accelerator() + ) + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + datasets.utils.logging.set_verbosity_warning() + transformers.utils.logging.set_verbosity_info() + else: + datasets.utils.logging.set_verbosity_error() + transformers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + + # Handle the repository creation + if accelerator.is_main_process: + if args.push_to_hub: + if args.hub_model_id is None: + repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) + else: + repo_name = args.hub_model_id + create_repo(repo_name, exist_ok=True, token=args.hub_token) + repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token) + + with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: + if "step_*" not in gitignore: + gitignore.write("step_*\n") + if "epoch_*" not in gitignore: + gitignore.write("epoch_*\n") + elif args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + accelerator.wait_for_everyone() + + # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below) + # or specify a GLUE benchmark task (the dataset will be downloaded automatically from the datasets Hub). + + # For CSV/JSON files, this script will use as labels the column called 'label' and as pair of sentences the + # sentences in columns called 'sentence1' and 'sentence2' if such column exists or the first two columns not named + # label if at least two columns are provided. + + # If the CSVs/JSONs contain only one non-label column, the script does single sentence classification on this + # single column. 
You can easily tweak this behavior (see below) + + # In distributed training, the load_dataset function guarantee that only one local process can concurrently + # download the dataset. + if args.task_name is not None: + # Downloading and loading a dataset from the hub. + raw_datasets = load_dataset("glue", args.task_name) + else: + # Loading the dataset from local csv or json file. + data_files = {} + if args.train_file is not None: + data_files["train"] = args.train_file + if args.validation_file is not None: + data_files["validation"] = args.validation_file + extension = (args.train_file if args.train_file is not None else args.validation_file).split(".")[-1] + raw_datasets = load_dataset(extension, data_files=data_files, delimiter="\t") + # See more about loading any type of standard or custom dataset at + # https://huggingface.co/docs/datasets/loading_datasets.html. + + # Labels + if args.task_name is not None: + is_regression = args.task_name == "stsb" + if not is_regression: + label_list = raw_datasets["train"].features["label"].names + num_labels = len(label_list) + else: + num_labels = 1 + else: + # Trying to have good defaults here, don't hesitate to tweak to your needs. + is_regression = raw_datasets["train"].features["label"].dtype in ["float32", "float64"] + if is_regression: + num_labels = 1 + else: + # A useful fast method: + # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.unique + label_list = raw_datasets["train"].unique("label") + label_list.sort() # Let's sort it for determinism + num_labels = len(label_list) + + # Load pretrained model and tokenizer + # + # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently + # download model & vocab. + config = AutoConfig.from_pretrained(args.model_name_or_path, num_labels=num_labels, finetuning_task=args.task_name, cache_dir="/netscratch/anonymous/datasets/hf-cache/") + tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=not args.use_slow_tokenizer, cache_dir="/netscratch/anonymous/datasets/hf-cache/") + model = AutoModelForSequenceClassification.from_pretrained( + args.model_name_or_path, + from_tf=bool(".ckpt" in args.model_name_or_path), + config=config, + ignore_mismatched_sizes=args.ignore_mismatched_sizes, + cache_dir="/netscratch/anonymous/datasets/hf-cache/" + ) + + # Preprocessing the datasets + if args.task_name is not None: + sentence1_key, sentence2_key = task_to_keys[args.task_name] + else: + # Again, we try to have some nice defaults but don't hesitate to tweak to your use case. + non_label_column_names = [name for name in raw_datasets["train"].column_names if name != "label"] + if "sentence1" in non_label_column_names and "sentence2" in non_label_column_names: + sentence1_key, sentence2_key = "sentence1", "sentence2" + else: + if len(non_label_column_names) >= 2: + sentence1_key, sentence2_key = non_label_column_names[:2] + else: + sentence1_key, sentence2_key = non_label_column_names[0], None + + # Some models have set the order of the labels to use, so let's make sure we do use it. + label_to_id = None + if ( + model.config.label2id != PretrainedConfig(num_labels=num_labels).label2id + and args.task_name is not None + and not is_regression + ): + # Some have all caps in their config, some don't. 
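+        # Illustrative (hypothetical) example: if the checkpoint's config has
+        # label2id = {"NEGATIVE": 0, "POSITIVE": 1} and the dataset's label_list is
+        # ["negative", "positive"], lowercasing the keys lets the two match and the
+        # resulting label_to_id keeps the pretrained head's label ordering.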
+ label_name_to_id = {k.lower(): v for k, v in model.config.label2id.items()} + if sorted(label_name_to_id.keys()) == sorted(label_list): + logger.info( + f"The configuration of the model provided the following label correspondence: {label_name_to_id}. " + "Using it!" + ) + label_to_id = {i: label_name_to_id[label_list[i]] for i in range(num_labels)} + else: + logger.warning( + "Your model seems to have been trained with labels, but they don't match the dataset: ", + f"model labels: {sorted(label_name_to_id.keys())}, dataset labels: {sorted(label_list)}." + "\nIgnoring the model labels as a result.", + ) + elif args.task_name is None and not is_regression: + label_to_id = {v: i for i, v in enumerate(label_list)} + + if label_to_id is not None: + model.config.label2id = label_to_id + model.config.id2label = {id: label for label, id in config.label2id.items()} + elif args.task_name is not None and not is_regression: + model.config.label2id = {l: i for i, l in enumerate(label_list)} + model.config.id2label = {id: label for label, id in config.label2id.items()} + + padding = "max_length" if args.pad_to_max_length else False + + # features = Features({ 'label': ClassLabel(names=['0', '1']), 'text': Value('string')}) + # num_labels = features['label'].num_classes + labels = ClassLabel(names=['0', '1']) + + def preprocess_function(examples): + # Tokenize the texts + logger.info(f'num examples: {len(examples)}') + # examples = examples.filter(lambda example: example['label'] is not None) + logger.info(f'num examples after filtering: {len(examples)}') + + texts = ( + (examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key]) + ) + result = tokenizer(*texts, padding=padding, max_length=args.max_length, truncation=True) + + if "label" in examples: + # if label_to_id is not None: + # # Map labels to IDs (not necessary for GLUE tasks) + # result["labels"] = [label_to_id[l] for l in examples["label"]] + # else: + # In all cases, rename the column to labels because the model will expect that. + # logger.info(f"examples['label']: {examples['label']}") + result["labels"] = [labels.str2int(str(label)) if label is not None else None for label in examples["label"]] + #features["label"].str2int(examples["label"]) #examples["label"] + return result + + with accelerator.main_process_first(): + processed_datasets = raw_datasets.map( + preprocess_function, + batched=True, + remove_columns=raw_datasets["train"].column_names, + desc="Running tokenizer on dataset", + ) + logger.info(f"processed_datasets.column_names: {processed_datasets.column_names}") + + # processed_datasets = processed_datasets.filter(lambda example: example['label'] is not None) + + train_dataset = processed_datasets["train"] + eval_dataset = processed_datasets["validation_matched" if args.task_name == "mnli" else "validation"] + train_dataset = train_dataset.filter(lambda example: example['labels'] is not None) + eval_dataset = eval_dataset.filter(lambda example: example['labels'] is not None) + + # Log a few random samples from the training set: + for index in random.sample(range(len(train_dataset)), 3): + logger.info(f"Sample {index} of the training set: {train_dataset[index]}.") + + # DataLoaders creation: + if args.pad_to_max_length: + # If padding was already done ot max length, we use the default data collator that will just convert everything + # to tensors. 
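+        # For example: with --pad_to_max_length every batch is padded to --max_length
+        # (128 by default), whereas the dynamic branch below pads each batch only to
+        # its own longest sequence.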
+ data_collator = default_data_collator + else: + # Otherwise, `DataCollatorWithPadding` will apply dynamic padding for us (by padding to the maximum length of + # the samples passed). When using mixed precision, we add `pad_to_multiple_of=8` to pad all tensors to multiple + # of 8s, which will enable the use of Tensor Cores on NVIDIA hardware with compute capability >= 7.5 (Volta). + data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=(8 if accelerator.use_fp16 else None)) + + train_dataloader = DataLoader( + train_dataset, shuffle=True, collate_fn=data_collator, batch_size=args.per_device_train_batch_size + ) + eval_dataloader = DataLoader(eval_dataset, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size) + + # Optimizer + # Split weights in two groups, one with weight decay and the other not. + no_decay = ["bias", "LayerNorm.weight"] + optimizer_grouped_parameters = [ + { + "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], + "weight_decay": args.weight_decay, + }, + { + "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], + "weight_decay": 0.0, + }, + ] + optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate) + + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + lr_scheduler = get_scheduler( + name=args.lr_scheduler_type, + optimizer=optimizer, + num_warmup_steps=args.num_warmup_steps, + num_training_steps=args.max_train_steps, + ) + + # Prepare everything with our `accelerator`. + model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare( + model, optimizer, train_dataloader, eval_dataloader, lr_scheduler + ) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # Figure out how many steps we should save the Accelerator states + checkpointing_steps = args.checkpointing_steps + if checkpointing_steps is not None and checkpointing_steps.isdigit(): + checkpointing_steps = int(checkpointing_steps) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if args.with_tracking: + experiment_config = vars(args) + # TensorBoard cannot log Enums, need the raw value + experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value + accelerator.init_trackers("glue_no_trainer", experiment_config) + + # Get the metric function + if args.task_name is not None: + metric = evaluate.load("glue", args.task_name) + else: + metric = evaluate.load("accuracy") + + # Train! 
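+    # Worked example (hypothetical 4-process run with the argparse defaults):
+    # 8 per device * 4 processes * 1 gradient-accumulation step = 32 examples per optimizer update.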
+ total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + # Only show the progress bar once on each machine. + progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process) + completed_steps = 0 + starting_epoch = 0 + # Potentially load in the weights and states from a previous save + if args.resume_from_checkpoint: + if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "": + accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}") + accelerator.load_state(args.resume_from_checkpoint) + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the most recent checkpoint + dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()] + dirs.sort(key=os.path.getctime) + path = dirs[-1] # Sorts folders by date modified, most recent checkpoint is the last + # Extract `epoch_{i}` or `step_{i}` + training_difference = os.path.splitext(path)[0] + + if "epoch" in training_difference: + starting_epoch = int(training_difference.replace("epoch_", "")) + 1 + resume_step = None + else: + resume_step = int(training_difference.replace("step_", "")) + starting_epoch = resume_step // len(train_dataloader) + resume_step -= starting_epoch * len(train_dataloader) + + for epoch in range(starting_epoch, args.num_train_epochs): + model.train() + if args.with_tracking: + total_loss = 0 + if args.resume_from_checkpoint and epoch == starting_epoch and resume_step is not None: + # We skip the first `n` batches in the dataloader when resuming from a checkpoint + active_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step) + else: + active_dataloader = train_dataloader + for step, batch in enumerate(active_dataloader): + outputs = model(**batch) + loss = outputs.loss + # We keep track of the loss at each epoch + if args.with_tracking: + total_loss += loss.detach().float() + loss = loss / args.gradient_accumulation_steps + accelerator.backward(loss) + if step % args.gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1: + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + progress_bar.update(1) + completed_steps += 1 + + if isinstance(checkpointing_steps, int): + if completed_steps % checkpointing_steps == 0: + output_dir = f"step_{completed_steps }" + if args.output_dir is not None: + output_dir = os.path.join(args.output_dir, output_dir) + accelerator.save_state(output_dir) + + if completed_steps >= args.max_train_steps: + break + + model.eval() + samples_seen = 0 + for step, batch in enumerate(eval_dataloader): + with torch.no_grad(): + outputs = model(**batch) + predictions = outputs.logits.argmax(dim=-1) if not is_regression else outputs.logits.squeeze() + predictions, references = accelerator.gather((predictions, batch["labels"])) + # If we are in a multiprocess environment, the last batch has duplicates + if accelerator.num_processes > 1: + if step == len(eval_dataloader) - 1: + predictions = predictions[: 
len(eval_dataloader.dataset) - samples_seen] + references = references[: len(eval_dataloader.dataset) - samples_seen] + else: + samples_seen += references.shape[0] + metric.add_batch( + predictions=predictions, + references=references, + ) + + eval_metric = metric.compute() + logger.info(f"epoch {epoch}: {eval_metric}") + + if args.with_tracking: + accelerator.log( + { + "accuracy" if args.task_name is not None else "glue": eval_metric, + "train_loss": total_loss.item() / len(train_dataloader), + "epoch": epoch, + "step": completed_steps, + }, + step=completed_steps, + ) + + if args.push_to_hub and epoch < args.num_train_epochs - 1: + accelerator.wait_for_everyone() + unwrapped_model = accelerator.unwrap_model(model) + unwrapped_model.save_pretrained( + args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save + ) + if accelerator.is_main_process: + tokenizer.save_pretrained(args.output_dir) + repo.push_to_hub( + commit_message=f"Training in progress epoch {epoch}", blocking=False, auto_lfs_prune=True + ) + + if args.checkpointing_steps == "epoch": + output_dir = f"epoch_{epoch}" + if args.output_dir is not None: + output_dir = os.path.join(args.output_dir, output_dir) + accelerator.save_state(output_dir) + + if args.with_tracking: + accelerator.end_training() + + if args.output_dir is not None: + accelerator.wait_for_everyone() + unwrapped_model = accelerator.unwrap_model(model) + unwrapped_model.save_pretrained( + args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save + ) + if accelerator.is_main_process: + tokenizer.save_pretrained(args.output_dir) + if args.push_to_hub: + repo.push_to_hub(commit_message="End of training", auto_lfs_prune=True) + + if args.task_name == "mnli": + # Final evaluation on mismatched validation set + eval_dataset = processed_datasets["validation_mismatched"] + eval_dataloader = DataLoader( + eval_dataset, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size + ) + eval_dataloader = accelerator.prepare(eval_dataloader) + + model.eval() + for step, batch in enumerate(eval_dataloader): + outputs = model(**batch) + predictions = outputs.logits.argmax(dim=-1) + metric.add_batch( + predictions=accelerator.gather(predictions), + references=accelerator.gather(batch["labels"]), + ) + + eval_metric = metric.compute() + logger.info(f"mnli-mm: {eval_metric}") + + if args.output_dir is not None: + all_results = {f"eval_{k}": v for k, v in eval_metric.items()} + with open(os.path.join(args.output_dir, "all_results.json"), "w") as f: + json.dump(all_results, f) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/examples/semi_supervised/1gpu_bart_run_pretrain.sh b/examples/semi_supervised/1gpu_bart_run_pretrain.sh new file mode 100755 index 0000000000..8c621d6681 --- /dev/null +++ b/examples/semi_supervised/1gpu_bart_run_pretrain.sh @@ -0,0 +1,99 @@ +#!/bin/bash + +# We jump into the submission dir +cd ${SLURM_SUBMIT_DIR} +MASTER=`echo $SLURM_JOB_NODELIST | cut -d"," -f1 | sed 's/[][]//g' | cut -d "-" -f 1,2` +NUM_GPUS=1; # number of gpu +MKL_NUM_THREADS=1 +NUMEXPR_NUM_THREADS=1 +export OMP_NUM_THREADS=1 +export USE_OPENMP=1 # prevents openblas to override OMP_NUM_THREADS + +# But if you are using conda (uncomment the lines below) +. 
/netscratch/jalota/miniconda3/etc/profile.d/conda.sh +conda activate fsq012 +python3 --version +# conda install pytorch torchvision torchaudio pytorch-cuda=11.7 -c pytorch -c nvidia +# conda install -c conda-forge cudatoolkit==11.7.0 +export CUDA_HOME=/usr/local/cuda +nvcc --version +var=$(which nvcc) +echo "var: ${var}" +echo "cuda home: ${CUDA_HOME}" +nvidia-smi + +echo "ngpus: ${NUM_GPUS}" +echo "SLURM_CPUS_PER_TASK: ${SLURM_CPUS_PER_TASK}" +echo "master: ${MASTER}" + +# # Step 1. preprocess the data! +# python3 preprocess.py --dir /netscratch/jalota/datasets/europarl/ --out /netscratch/jalota/datasets/europarl-ppd/ +# python3 preprocess.py --dir /netscratch/jalota/datasets/motra/en_de/dev --out /netscratch/jalota/datasets/motra-preprocessed/en_de/dev/ +# python3 preprocess.py --dir /netscratch/jalota/datasets/motra/en_de/test --out /netscratch/jalota/datasets/motra-preprocessed/en_de/test/ +# python3 preprocess.py --dir /netscratch/jalota/datasets/motra/en_es/train --out /netscratch/jalota/datasets/motra-preprocessed/en_es/train/ +# python3 preprocess.py --dir /netscratch/jalota/datasets/motra/en_es/dev --out /netscratch/jalota/datasets/motra-preprocessed/en_es/dev/ +# python3 preprocess.py --dir /netscratch/jalota/datasets/motra/en_es/test --out /netscratch/jalota/datasets/motra-preprocessed/en_es/test/ + +# # Step 2. run preprocess.sh +# ./preprocess.sh +# cat /netscratch/jalota/datasets/europarl-ppd/europarl.txt | sacremoses -l en -j 4 normalize -c -q -d tokenize truecase -m /netscratch/jalota/eu.truemodel > /netscratch/jalota/datasets/europarl-ppd/europarl.tok.txt + +# in="/netscratch/jalota/datasets/europarl-ppd/europarl.tok.txt" +# train="$in.train" +# test="$in.test" +# awk -v train="$train" -v test="$test" '{if(rand()<0.9) {print > train} else {print > test}}' $in +# Step 3. Apply bpe and then add style-labels +# ./subword.sh + +which nvcc + +./subword-nmt/learn_bpe.py -s 8000 < /netscratch/jalota/datasets/europarl-motra/europarl-motra-train.txt > /netscratch/jalota/datasets/europarl-motra/codes.txt +./subword-nmt/apply_bpe.py -c /netscratch/jalota/datasets/europarl-motra/codes.txt < /netscratch/jalota/datasets/europarl-motra/europarl-motra-train.txt > /netscratch/jalota/datasets/europarl-motra/bpe/europarl.train.bpe +./subword-nmt/apply_bpe.py -c /netscratch/jalota/datasets/europarl-motra/codes.txt < /netscratch/jalota/datasets/europarl-ppd/europarl.valid > /netscratch/jalota/datasets/europarl-motra/bpe/europarl.valid.bpe +./subword-nmt/apply_bpe.py -c /netscratch/jalota/datasets/europarl-motra/codes.txt < /netscratch/jalota/datasets/europarl-ppd/europarl.test > /netscratch/jalota/datasets/europarl-motra/bpe/europarl.test.bpe + +cd fb_fsq/fairseq/fairseq_cli/ +python3 preprocess.py \ + --only-source \ + --trainpref /netscratch/jalota/datasets/europarl-motra/bpe/europarl.train.bpe \ + --validpref /netscratch/jalota/datasets/europarl-motra/bpe/europarl.valid.bpe \ + --testpref /netscratch/jalota/datasets/europarl-motra/bpe/europarl.test.bpe \ + --destdir /netscratch/jalota/datasets/data-bin/europarl-motra/subword-nmt/europarl \ + --workers 60 + +export CUDA_VISIBLE_DEVICES=0,1,2,3 +DATA=/netscratch/jalota/datasets/data-bin/subword-nmt/europarl +export PYTORCH_NO_CUDA_MEMORY_CACHING=1 +export PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:16024 +export CUDA_LAUNCH_BLOCKING=1 +cd .. 
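+# The active torchrun call below (the commented lines are alternative launchers) starts
+# single-node BART-style denoising pretraining: the fairseq `denoising` task applies the
+# noising flags given further down (--mask 0.3, --poisson-lambda 3.5, --permute-sentences 1),
+# and --update-freq 16 accumulates gradients to emulate a larger effective batch on one GPU.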
+# torchrun --standalone --nnodes=1 --nproc_per_node=${NUM_GPUS} \ +# python3 -m torch.distributed.launch \ +# --nproc_per_node=${NUM_GPUS} \ +# --nnodes=$SLURM_JOB_NUM_NODES \ +# --master_addr=127.0.0.1 \ +# --master_port=5905 \ +# torchrun --nnodes=$NUM_NODES \ +# --nproc_per_node=${NUM_GPUS} --max_restarts=3 --rdzv_id=$JOB_ID \ +# --rdzv_backend=c10d --rdzv_endpoint=$HOST_NODE_ADDR:5905 \ +torchrun --standalone --nnodes=1 --nproc_per_node=${NUM_GPUS} \ +train.py $DATA \ +--mask 0.3 --tokens-per-sample 512 --fp16 --fp16-init-scale 4 \ +--total-num-update 500000 --max-update 500000 --checkpoint-activations \ +--warmup-updates 10000 --task denoising --save-interval 1 \ +--max-source-positions 512 \ +--max-target-positions 512 \ +--arch transformer --optimizer adam --lr-scheduler polynomial_decay \ +--lr 0.0004 --dropout 0.1 --criterion cross_entropy --max-tokens 8048 \ +--weight-decay 0.01 --attention-dropout 0.1 --share-all-embeddings \ +--clip-norm 0.1 --skip-invalid-size-inputs-valid-test --log-format json \ +--log-interval 1000 --save-interval-updates 5000 --keep-interval-updates 1 \ +--update-freq 16 --seed 4 --distributed-world-size $NUM_GPUS \ +--keep-best-checkpoints 10 \ +--mask-length span-poisson --replace-length 1 --encoder-learned-pos \ +--decoder-learned-pos --rotate 0.0 --mask-random 0.1 --save-dir /netscratch/jalota/checkpoints/subword-nmt-bpe-transformer_gpu1_cpu12 \ +--permute-sentences 1 --insert 0.0 --poisson-lambda 3.5 \ +--dataset-impl mmap --num-workers ${SLURM_CPUS_PER_TASK} + +# --bpe subword_nmt --optimizer cpu_adam --cpu-offload --ddp-backend fully_sharded + diff --git a/examples/semi_supervised/multi-gpu_bart_run_pretrain.sh b/examples/semi_supervised/multi-gpu_bart_run_pretrain.sh new file mode 100755 index 0000000000..fd3bc38199 --- /dev/null +++ b/examples/semi_supervised/multi-gpu_bart_run_pretrain.sh @@ -0,0 +1,95 @@ +#!/bin/bash + +# We jump into the submission dir +cd ${SLURM_SUBMIT_DIR} +MASTER=`echo $SLURM_JOB_NODELIST | cut -d"," -f1 | sed 's/[][]//g' | cut -d "-" -f 1,2` +NUM_GPUS=4; # number of gpu +MKL_NUM_THREADS=1 +NUMEXPR_NUM_THREADS=1 +export OMP_NUM_THREADS=1 +export USE_OPENMP=1 # prevents openblas to override OMP_NUM_THREADS + +# But if you are using conda (uncomment the lines below) +. /netscratch/jalota/miniconda3/etc/profile.d/conda.sh +conda activate fsq012 +python3 --version +# conda install pytorch torchvision torchaudio pytorch-cuda=11.7 -c pytorch -c nvidia +# conda install -c conda-forge cudatoolkit==11.7.0 +export CUDA_HOME=/usr/local/cuda +nvcc --version +var=$(which nvcc) +echo "var: ${var}" +echo "cuda home: ${CUDA_HOME}" +nvidia-smi + +echo "ngpus: ${NUM_GPUS}" +echo "SLURM_CPUS_PER_TASK: ${SLURM_CPUS_PER_TASK}" +echo "master: ${MASTER}" + +# # Step 1. preprocess the data! 
+# python3 preprocess.py --dir /netscratch/jalota/datasets/europarl/ --out /netscratch/jalota/datasets/europarl-ppd/ +# python3 preprocess.py --dir /netscratch/jalota/datasets/motra/en_de/dev --out /netscratch/jalota/datasets/motra-preprocessed/en_de/dev/ +# python3 preprocess.py --dir /netscratch/jalota/datasets/motra/en_de/test --out /netscratch/jalota/datasets/motra-preprocessed/en_de/test/ +# python3 preprocess.py --dir /netscratch/jalota/datasets/motra/en_es/train --out /netscratch/jalota/datasets/motra-preprocessed/en_es/train/ +# python3 preprocess.py --dir /netscratch/jalota/datasets/motra/en_es/dev --out /netscratch/jalota/datasets/motra-preprocessed/en_es/dev/ +# python3 preprocess.py --dir /netscratch/jalota/datasets/motra/en_es/test --out /netscratch/jalota/datasets/motra-preprocessed/en_es/test/ + +# # Step 2. run preprocess.sh +# ./preprocess.sh +# cat /netscratch/jalota/datasets/europarl-ppd/europarl.txt | sacremoses -l en -j 4 normalize -c -q -d tokenize truecase -m /netscratch/jalota/eu.truemodel > /netscratch/jalota/datasets/europarl-ppd/europarl.tok.txt + +# in="/netscratch/jalota/datasets/europarl-ppd/europarl.tok.txt" +# train="$in.train" +# test="$in.test" +# awk -v train="$train" -v test="$test" '{if(rand()<0.9) {print > train} else {print > test}}' $in +# Step 3. Apply bpe and then add style-labels +# ./subword.sh + +which nvcc + +# ./subword-nmt/learn_bpe.py -s 8000 < /netscratch/jalota/datasets/europarl-motra/europarl-motra-train.txt > /netscratch/jalota/datasets/europarl-motra/codes.txt +# ./subword-nmt/apply_bpe.py -c /netscratch/jalota/datasets/europarl-motra/codes.txt < /netscratch/jalota/datasets/europarl-motra/europarl-motra-train.txt > /netscratch/jalota/datasets/europarl-motra/bpe/europarl.train.bpe +# ./subword-nmt/apply_bpe.py -c /netscratch/jalota/datasets/europarl-motra/codes.txt < /netscratch/jalota/datasets/europarl-ppd/europarl.valid > /netscratch/jalota/datasets/europarl-motra/bpe/europarl.valid.bpe +# ./subword-nmt/apply_bpe.py -c /netscratch/jalota/datasets/europarl-motra/codes.txt < /netscratch/jalota/datasets/europarl-ppd/europarl.test > /netscratch/jalota/datasets/europarl-motra/bpe/europarl.test.bpe + +# cd fb_fsq/fairseq/fairseq_cli/ +# python3 preprocess.py \ +# --only-source \ +# --trainpref /netscratch/jalota/datasets/europarl-motra/bpe/europarl.train.bpe \ +# --validpref /netscratch/jalota/datasets/europarl-motra/bpe/europarl.valid.bpe \ +# --testpref /netscratch/jalota/datasets/europarl-motra/bpe/europarl.test.bpe \ +# --destdir /netscratch/jalota/datasets/data-bin/europarl-motra/subword-nmt/europarl \ +# --workers 60 + +export CUDA_VISIBLE_DEVICES=0,1,2,3 +DATA=/netscratch/jalota/datasets/data-bin/subword-nmt/europarl +export PYTORCH_NO_CUDA_MEMORY_CACHING=1 +export PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:16024 +export CUDA_LAUNCH_BLOCKING=1 +# cd .. 
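+# The BPE and binarization steps above are commented out here: this run assumes the
+# data-bin directory already exists (e.g. built by the single-GPU script). Rough upper
+# bound on tokens per update: 4 GPUs * --update-freq 8 * --max-tokens 16048 ≈ 513k,
+# assuming every batch fills the token budget.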
+cd fb_fsq/fairseq/ +# torchrun --standalone --nnodes=1 --nproc_per_node=${NUM_GPUS} \ + +# torchrun --nnodes=$NUM_NODES \ +# --nproc_per_node=${NUM_GPUS} --max_restarts=3 --rdzv_id=$JOB_ID \ +# --rdzv_backend=c10d --rdzv_endpoint=$HOST_NODE_ADDR:5905 \ +torchrun --standalone --nnodes=1 --nproc_per_node=${NUM_GPUS} \ +train.py $DATA \ +--mask 0.3 --tokens-per-sample 512 --fp16 --fp16-init-scale 4 \ +--total-num-update 500000 --max-update 500000 --checkpoint-activations \ +--warmup-updates 10000 --task denoising --save-interval 1 \ +--max-source-positions 512 \ +--max-target-positions 512 \ +--arch transformer --optimizer adam --lr-scheduler polynomial_decay \ +--lr 0.0004 --dropout 0.1 --criterion cross_entropy --max-tokens 16048 \ +--weight-decay 0.01 --attention-dropout 0.1 --share-all-embeddings \ +--clip-norm 0.1 --skip-invalid-size-inputs-valid-test --log-format json \ +--log-interval 1000 --save-interval-updates 5000 --keep-interval-updates 1 \ +--update-freq 8 --seed 4 --distributed-world-size $NUM_GPUS \ +--keep-best-checkpoints 10 \ +--mask-length span-poisson --replace-length 1 --encoder-learned-pos \ +--decoder-learned-pos --rotate 0.0 --mask-random 0.1 --save-dir /netscratch/jalota/checkpoints/subword-nmt-bpe-transformer_gpu4_cpu56 \ +--permute-sentences 1 --insert 0.0 --poisson-lambda 3.5 \ +--dataset-impl mmap --num-workers ${SLURM_CPUS_PER_TASK} + +# --bpe subword_nmt --optimizer cpu_adam --cpu-offload --ddp-backend fully_sharded diff --git a/examples/semi_supervised/pretrain_sbatch.sh b/examples/semi_supervised/pretrain_sbatch.sh new file mode 100755 index 0000000000..405da11425 --- /dev/null +++ b/examples/semi_supervised/pretrain_sbatch.sh @@ -0,0 +1,44 @@ +#!/bin/bash + +#SBATCH --nodes=1 # Number of nodes or servers. See: http://koeln.kl.dfki.de:3000/d/slurm-resources/resources?orgId=1&refresh=15s +#SBATCH --ntasks-per-node=1 # Number of task in each node, we want 1 +#SBATCH --cpus-per-task=48 # We want 4 cores for this job. +#SBATCH --mem-per-cpu=8gb # each core to have 16 Gb RAM +#SBATCH --gres=gpu:4 # We want 4 GPUs in each node for this job. +#SBATCH --time=10-0:00 # Run this task no longer than 10 days. +#SBATCH --job-name=v2_pretrain_bart_thrd1_cpu2 +#SBATCH --partition=V100-32GB,RTXA6000,RTX3090,A100-40GB,A100-80GB,A100-PCI,V100-16GB,RTX2080Ti +#SBATCH --output=logs/v2_pretrain_transformer_gpu1_thrd1_cpu12%A.logs + +echo "#############################" +date +echo "Current dir: " ${SLURM_SUBMIT_DIR} +echo "Hostname: `hostname`" + +# Print the task details. 
+echo "Job ID: ${SLURM_JOBID}" +echo "SLURM array task ID: ${SLURM_ARRAY_TASK_ID}" +echo "Node list: ${SLURM_JOB_NODELIST}" +echo "Cluster name: ${SLURM_CLUSTER_NAME}" +echo "Partition name: ${SLURM_JOB_PARTITION}" +echo "num nodes: ${SLURM_JOB_NUM_NODES}" +echo "Using: `which python`" +echo -e "#############################\n" + +NGPUS=4; # number of gpu +NCPUS_PER_TASK=48; # number of cpu per task +# MEM=50000 # memory increase this if needed + +srun -v \ +--container-mounts=/netscratch/jalota:/netscratch/jalota,/ds:/ds:ro,"`pwd`":"`pwd`",/home/jalota/:/home/jalota/ \ +--container-image=/netscratch/$USER/containers/cuda:11.7.0-devel-ubuntu22.04.sqsh \ +--container-workdir="`pwd`" --no-container-remap-root --gpu-bind=none \ +--gpus=$NGPUS \ +bash {1gpu|multi-gpu}_bart_run_pretrain.sh + +#--cpus-per-task=$NCPUS_PER_TASK +# /netscratch/$USER/containers/cuda:11.7.1-devel-ubuntu22.04 +# /netscratch/$USER/containers/cuda:11.7.0-devel-ubuntu18.04.sqsh +# --container-save=/netscratch/$USER/containers/jalota_pytorch22.12.sqsh \ +# /netscratch/enroot/nvcr.io_nvidia_pytorch_22.12-py3.sqsh -n1 --mem=$MEM +# RTXA6000,V100-32GB,RTX3090 # Run this only in these mentioned GPUs. If you don't have any choice over GPUs, remove this parameter. \ No newline at end of file diff --git a/fairseq/criterions/label_smoothed_cross_entropy_with_unsupervised_loss.py b/fairseq/criterions/label_smoothed_cross_entropy_with_unsupervised_loss.py new file mode 100644 index 0000000000..1bd56ae7d3 --- /dev/null +++ b/fairseq/criterions/label_smoothed_cross_entropy_with_unsupervised_loss.py @@ -0,0 +1,643 @@ +""" +Classes and methods used for multi-task tuning on Supervised style-transfer and unsupervised LM and cosine similarity losses. +""" + +import torch.nn.functional as F +import torch.nn as nn +from fairseq import criterions +from dataclasses import dataclass, field +import torch +import numpy as np +from torch.nn import MSELoss, CosineSimilarity +import math +from fairseq import metrics, utils +from fairseq.criterions import register_criterion +from torch.nn.functional import gumbel_softmax +from torch.distributions import Categorical +from torch.utils.data import Dataset +from fairseq.data import LMContextWindowDataset, MonolingualDataset +import evaluate +import random +from fairseq.lm_perplexity import LanguageModel, LanguageModelValidation + +from fairseq.criterions.label_smoothed_cross_entropy import ( + LabelSmoothedCrossEntropyCriterion, + LabelSmoothedCrossEntropyCriterionConfig, +) +import logging, os, sys +from fairseq.data.data_utils import collate_tokens +from evaluate import load +from fairseq.scoring.perplexity import Perplexity + +logging.basicConfig( + format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + level=os.environ.get("LOGLEVEL", "INFO").upper(), + stream=sys.stdout, +) +logger = logging.getLogger("fairseq.criterion.UnsupervisedAugmentedCrossEntropyLoss") + + +class torchDataset(Dataset): + def __init__(self, data_list): + self.data_list = data_list + + def __getitem__(self, index): + return self.data_list[index] + + def __len__(self): + return len(self.data_list) + +def cross_entropy(pred, soft_targets): + # logger.info(f"pred.size(): {pred.size()}") + # logger.info(f"soft_targets.size(): {soft_targets.size()}") + logsoftmax = nn.LogSoftmax(dim=1) + return torch.mean(torch.sum(- soft_targets * logsoftmax(pred), 1)) + + +def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=None, reduce=True): + if target.dim() == lprobs.dim() - 1: + target = 
target.unsqueeze(-1) + nll_loss = -lprobs.gather(dim=-1, index=target) + smooth_loss = -lprobs.sum(dim=-1, keepdim=True) + if ignore_index is not None: + pad_mask = target.eq(ignore_index) + nll_loss.masked_fill_(pad_mask, 0.0) + smooth_loss.masked_fill_(pad_mask, 0.0) + else: + nll_loss = nll_loss.squeeze(-1) + smooth_loss = smooth_loss.squeeze(-1) + if reduce: + nll_loss = nll_loss.sum() + smooth_loss = smooth_loss.sum() + eps_i = epsilon / (lprobs.size(-1) - 1) + loss = (1.0 - epsilon - eps_i) * nll_loss + eps_i * smooth_loss + return loss, nll_loss + + +@dataclass +class UnsupervisedAugmentedLabelSmoothedCrossEntropyCriterionConfig( + LabelSmoothedCrossEntropyCriterionConfig +): + lm_weight: float = field( + default=0.5, + metadata={"help": "weight fot per-word LM entropy."}, + ) + cosine_weight: float = field( + default=0.5, + metadata={"help": "weight for cosine similarity loss."}, + ) + unsupervised_weight: str = field( + default=0.5, + metadata={"help": "unsupervised loss weightage"}, + ) + supervised_weight: str = field( + default=0.5, + metadata={"help": "supervised loss weightage"}, + ) + pretrained_lm: str = field( + default="/netscratch/jalota/checkpoints/transformer_en_hansard/", + metadata={ + "help": "pretrained fairseq LM model to evaluate PPL during unsupervised training." + }, + ) + pretrained_lm_dict_path: str = field( + default="/netscratch/jalota/datasets/data-bin/canadianHansard/lm/", + metadata={ + "help": "dict path for pretrained fairseq LM model to evaluate PPL during unsupervised training." + }, + ) + lm_context_window: int = field( + default=5, metadata={"help": "context window size for evaluating PPL"} + ) + bertscore_model: str = field( + default="roberta-base", + metadata={ + "help": "which model to use for evaluating semantic similarity. 
for EN: roberta-base, DE: t5-base" + }, + ) + + +@register_criterion( + "unsupervised_augmented_label_smoothed_cross_entropy", + dataclass=UnsupervisedAugmentedLabelSmoothedCrossEntropyCriterionConfig, +) +class UnsupervisedAugmentedLabelSmoothedCrossEntropyCriterion( + LabelSmoothedCrossEntropyCriterion +): + def __init__( + self, + task, + sentence_avg, + label_smoothing, + ignore_prefix_size, + report_accuracy, + lm_weight=1, + cosine_weight=1, + unsupervised_weight=1, + supervised_weight=1, + bertscore_model='roberta-base', + lm_context_window=5, + pretrained_lm_dict_path="/netscratch/jalota/datasets/data-bin/canadianHansard/lm/", + pretrained_lm="/netscratch/jalota/checkpoints/transformer_en_hansard/", + tau_gumbel_softmax=0.1, + hard_gumbel_softmax=False, + eps_gumbel_softmax=1e-10, + soft_bert_score=False + ): + # 'microsoft/deberta-v3-base' t5-base + # roberta-base for EN + super().__init__( + task, sentence_avg, label_smoothing, ignore_prefix_size, report_accuracy + ) + self.lm_weight = torch.tensor(1) + self.cosine_weight = torch.tensor(1) + self.unsupervised_weight = torch.tensor(0.3) + self.supervised_weight = torch.tensor(0.7) + self.perplexity = Perplexity() + self.cosine_sim = CosineSimilarity() + self.mse_loss = MSELoss(reduction='mean') + self.bertscore_model = bertscore_model + + self.tau_gumbel_softmax = tau_gumbel_softmax + self.hard_gumbel_softmax = hard_gumbel_softmax + self.eps_gumbel_softmax = eps_gumbel_softmax + self.pretrained_lm = pretrained_lm + self.pretrained_lm_dict_path = pretrained_lm_dict_path + self.lm_context_window = lm_context_window + + # self.bert_scorer = BERTScorer(self.bert_model, soft_bert_score=soft_bert_score) # , device='cpu') + # self.pad_token_id = self.bert_scorer._tokenizer.convert_tokens_to_ids('[PAD]') + # hansard: /netscratch/jalota/checkpoints/transformer_en_hansard/ + # hansard_data: /netscratch/jalota/datasets/data-bin/canadianHansard/lm/ + # de: /netscratch/jalota/checkpoints/transformer_lm_de_finetuned/ + # de_data: /netscratch/jalota/datasets/motra-sst/de/unsup_setup_raw/lm_finetuning/ + self.bertscore = evaluate.load("bertscore") + self.lm = LanguageModel(path=self.pretrained_lm,tgt_dict=task.tgt_dict,data_name_or_path=self.pretrained_lm_dict_path) + self.val_lm = LanguageModelValidation(path=self.pretrained_lm,tgt_dict=task.tgt_dict, context_window=self.lm_context_window,data_name_or_path=self.pretrained_lm_dict_path) + # /netscratch/jalota/datasets/motra-sst/de/unsup_setup_raw/lm_finetuning/ + # DE: /netscratch/jalota/checkpoints/transformer_lm_de_finetuned/ + # EN: /netscratch/jalota/checkpoints/transformer_lm_en_finetuned/ + # data_name_or_path='/netscratch/jalota/datasets/motra-sst/ppd_w_europarl-motra-10k_no_dups/en_es_de/unsup_setup/lm_finetune/' + + #load("perplexity", module_type="measurement") + + def forward(self, model, sample, seqeunce_generator=None, tgt_dict=None,reduce=True, unsup=False, src_dict=None, train=True, only_unsupervised=False): + + logging_output = {} + loss = 0.0 + sample_size_set = False + # only_unsupervised = False + if train and not only_unsupervised: + net_output = model(**sample["sup"]["net_input"]) + loss_sum, nll_loss_sum = self.compute_loss(model, net_output, sample["sup"], reduce=reduce) + sample_size = ( + sample['sup']["target"].size(0) if self.sentence_avg else sample['sup']["ntokens"] + ) + # logger.info(f'sample["sup"]["net_input"]["prev_output_tokens"]: {sample["sup"]["net_input"]["prev_output_tokens"]}') + ## take the mean of loss and nll_loss here and convert them from log base e to 
2 + loss = loss_sum / sample_size / math.log(2) + nll_loss = nll_loss_sum / sample['sup']["ntokens"] / math.log(2) + # NOTE: + # # we don't need to use sample_size/ntokens as denominator for the gradient + # # here sample_size & ntokens are just used for logging + sample_size = 1 + sample_size_set = True + if unsup: + loss = self.supervised_weight * loss + + logging_output = { + "loss" : loss.data, + "nll_loss": nll_loss.data, + "ntokens": 1, + "nsentences": sample['sup']["target"].size(0), + "sample_size": sample_size, + } + # sample['sup']["ntokens"] + if unsup: + if train: + if not only_unsupervised: + sample = sample['unsup'] + # in case of eval, dataset is not RoundRobin, thus 'sample' can be used directly, & is not an OrderedDict! + + def decode(toks, src=False, escape_unk=False): + if src: + s = src_dict.string(toks, bpe_symbol="subword_nmt",) + return s.replace("", "").rstrip() + + return tgt_dict.string( + toks, + bpe_symbol="subword_nmt", + ).replace("", "").rstrip() + + with torch.no_grad(): + if any(sample["net_input"]["src_lengths"]) > 510: + logger.info(f'sample["net_input"]["src_lengths"]: {sample["net_input"]["src_lengths"]}') + gen_out = seqeunce_generator.generate( + [model], sample, prefix_tokens=None, constraints=None) + + # logger.info(f"gen_out: {gen_out}") + hyps, hyps_tok = [], [] + for i in range(len(gen_out)): + # s = decode(gen_out[i][0]["tokens"]).strip() + # if len(s) > 0: + # hyps_tok.append(s) + hyps.append(gen_out[i][0]["tokens"]) + + msize = max(v.size(0) for v in hyps) + msize = msize if msize <= 512 else 512 + # logger.info(f"msize: {msize}") + + hyps = collate_tokens(hyps, src_dict.pad(), src_dict.eos(), move_eos_to_beginning=False, left_pad=False, pad_to_length=512,pad_to_bsz=None) + + batch_size = len(hyps) + + if not train: + # calculate bertscore and PPL straight-away! 
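+                    # During validation the unsupervised objective is computed on decoded text
+                    # directly: a similarity term (1 - mean BERTScore F1 between generations and
+                    # their sources, see compute_bertLoss) plus the per-word entropy of the
+                    # generations under the pretrained fairseq LM; the two terms are summed
+                    # with weight 1.0 each further below.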
+ refs_list = [] + hyps_tok = [] + refs = sample['net_input']['src_tokens'] + for i in range(len(refs)): + s = decode(refs[i]).strip() + hs = decode(gen_out[i][0]["tokens"]).strip() + if len(s.split()) > 2 and len(hs.split()) > 1: + hyps_tok.append(hs) + refs_list.append(s) + + # refs_list.append(s) + + # logger.info(f"len(refs_list): {len(refs_list)}") + # logger.info(f"len(hyps_tok): {len(hyps_tok)}") + + # logger.info(f"refs_list: {refs_list}") + # logger.info(f"hyps_tok: {hyps_tok}") + + sim_loss, _ = self.compute_bertLoss(hyps_tok, refs_list) + + # ppl_results = self.perplexity.compute(data=hyps_tok, model_id='/netscratch/jalota/checkpoints/gpt2-finetuned-motra/', batch_size=len(hyps_tok), add_start_token=True) + hyps_cpu, gen_sizes = [], [] + for h in hyps: + # if h.size(0) <= 512: + hyps_cpu.append(h.cpu()) + gen_sizes.append(msize) + + # hyps = [h.cpu() for h in hyps] + # logger.info(f"len(hyps_cpu): {len(hyps_cpu)}") + # logger.info(f"gen_sizes: {gen_sizes}") + + genData = torchDataset(data_list=hyps_cpu) + # gen_sizes = [msize for _ in range(len(genData))] + gen_data = MonolingualDataset(genData, gen_sizes, src_vocab=tgt_dict, fixed_pad_length=512) + + ppl_results = self.val_lm.get_lm_perplexity(gen_data, batch_size) + + # logger.info(f"ppl: {ppl_results['mean_perplexity']}") + # gpt2-finetuned-motra-de-40epochs/ - DE + # gpt2-finetuned-motra/ - EN + + mean_per_word_entropy = ppl_results['loss'] + # math.log2(ppl_results['mean_perplexity']) + + unsupervised_loss = 1.0 * sim_loss + 1.0 * mean_per_word_entropy + loss += self.unsupervised_weight * unsupervised_loss + logging_output["loss"] = loss + logging_output["sim_loss"] = sim_loss + logging_output["mean_per_word_entropy"] = mean_per_word_entropy + logging_output["lm_ppl"] = ppl_results['perplexity'] + logging_output["unsupervised_loss"] = unsupervised_loss + + else: + # use the hyps to create prev_output_tokens + # a shifted version of hyps for feeding the + # previous output token(s) into the next decoder step + sample = self.prepare_second_pass_input(sample, tgt_dict, src_dict, hyps) + + # logger.info(f"sample: {sample}") + + # logger.info(f"enable gradient") + # second-pass through the decoder in training mode + with torch.enable_grad(): + net_output = model(**sample['net_input']) + lprobs = model.get_normalized_probs(net_output, log_probs=True) + + gsm_samples = gumbel_softmax(lprobs, tau=self.tau_gumbel_softmax, hard=self.hard_gumbel_softmax,eps=self.eps_gumbel_softmax, dim=-1) + + lm_out = self.lm.get_lm_out_from_decoder_inp(gsm_samples) + + lm_loss = self.compute_loss(model, lm_out, gsm_samples, unsup=unsup, reduce=reduce) + + sample_size = gsm_samples.size()[1] + ## take the mean of loss and nll_loss here and convert them from log base e to 2 + lm_loss = lm_loss / math.log(2) + + sim_loss = self.get_similarity_loss(model, gsm_samples, sample, src_dict.pad()) + + unsupervised_loss = 1.0 * sim_loss + 1.0 * lm_loss + + loss += self.unsupervised_weight * unsupervised_loss + logging_output["loss"] = loss.data + logging_output["lm_loss"] = lm_loss.data + # logging_output["lm_nll_loss"] = lm_nll_loss.data + logging_output["unsupervised_loss"] = unsupervised_loss.data + logging_output["sim_loss"] = sim_loss.data + logging_output["unsup_nsentences"] = 1 + + if self.report_accuracy: + n_correct, total = self.compute_accuracy(model, net_output, sample) + logging_output["n_correct"] = utils.item(n_correct.data) + logging_output["total"] = utils.item(total.data) + + if not sample_size_set: + sample_size= 1 + + return loss, 
sample_size, logging_output + + + def get_similarity_loss(self, model, preds_tensor, sample, pad_token_id): + + emb_matrix = model.encoder.embed_tokens.weight + + # get bert embeddings from tensor + batch_size, max_seq_len, vocab_size = preds_tensor.size() + emb_size = emb_matrix.size()[-1] + + with torch.autocast("cuda"): + preds_tensor_embs = torch.mm(preds_tensor.contiguous().view(-1, vocab_size), emb_matrix) + preds_tensor_embs = preds_tensor_embs.view(-1, max_seq_len, emb_size) + + # logger.info(f"preds_tensor_embs: {preds_tensor_embs.dtype}") + + with torch.no_grad(): + source_emb = model.encoder.forward(sample['net_input']['src_tokens']) + preds_enc_emb = model.encoder.forward(preds_tensor_embs) + + source_sent_repr = torch.sum(source_emb['encoder_out'][0], dim=0) + output_sent_repr = torch.sum(preds_enc_emb['encoder_out'][0], dim=0) + target_labels = torch.ones(source_sent_repr.shape[0], dtype=source_sent_repr.dtype).cuda() + #cosineLoss = torch.nn.CosineEmbeddingLoss(reduction='mean') + # cos_sim_loss = cosineLoss(source_sent_repr, output_sent_repr, target_labels) + cosine_out = self.cosine_sim(source_sent_repr, output_sent_repr) + # logger.info(f"cosine_out: {cosine_out}") + # similarity_labels = torch.FloatTensor(np.array([1]*len(source_sent_repr)), dtype=source_sent_repr.dtype).cuda() + # if similarity_labels is not None: + similarity_loss = self.mse_loss(cosine_out, target_labels.view(-1)) + # logger.info(f'cos_sim_loss: {cos_sim_loss}') + # logger.info(f"similarity_loss: {similarity_loss}") + return similarity_loss + + def compute_loss(self, model, net_output, sample, unsup=False, reduce=True): + if not unsup: + lprobs, target = self.get_lprobs_and_target(model, net_output, sample) + loss, nll_loss = label_smoothed_nll_loss( + lprobs, + target, + self.eps, + ignore_index=self.padding_idx, + reduce=reduce, + ) + return loss, nll_loss + else: + lm_out = net_output + decoder_out = sample + lm_probs = model.get_normalized_probs(lm_out, log_probs=True) + if self.ignore_prefix_size > 0: + # lprobs: B x T x C + lm_probs = lm_probs[:, self.ignore_prefix_size :, :].contiguous() + decoder_out = decoder_out[:, self.ignore_prefix_size :].contiguous() + lprobs = lm_probs + target = decoder_out # as per the eqn in the paper Yang et. al. 
2019 + return cross_entropy(lm_probs, decoder_out) + + + def prepare_second_pass_input(self, sample, tgt_dict, src_dict, hyps): + prev_output_tokens = collate_tokens( + hyps, + tgt_dict.pad(), + tgt_dict.eos(), + left_pad=False, + move_eos_to_beginning=True, + pad_to_length=512, + pad_to_multiple=1 + ) + # logger.info(f"prev_output_tokens: {prev_output_tokens}") + + # logger.info(f"tgt_dict.eos():{tgt_dict.eos()}") + + src_lengths = sample["net_input"]["src_lengths"] + src_tokens = sample["net_input"]["src_tokens"] + # logger.info(f"src_lengths: {src_lengths}") + + # sort by descending src lengths + src_lengths, sort_order = src_lengths.sort(descending=True) + + sample['id'] = sample['id'].index_select(0, sort_order) + sample["net_input"]["prev_output_tokens"] = prev_output_tokens.index_select(0, sort_order) + sample["net_input"]["src_lengths"] = src_lengths + sample["net_input"]["src_tokens"] = src_tokens.index_select(0, sort_order) + + return sample + + def compute_bertLoss(self, preds_list, refs_list, reduce=True): + # logger.info(f"len(refs_list): {len(refs_list)}") + # logger.info(f"len(preds_list): {len(preds_list)}") + results = self.bertscore.compute(predictions=preds_list, references=refs_list, model_type=self.bertscore_model) + avg_f1 = sum(results['f1'])/len(results['f1']) + bert_loss = 1-avg_f1 + return bert_loss, avg_f1 + + def compute_cosineSimilarityLoss(self, model, sample, hyps, train): + # with torch.no_grad(): + if not train: + with torch.no_grad(): + source_emb = model.encoder.forward(sample['net_input']['src_tokens'].cuda()) + gen_out_emb = model.encoder.forward(hyps) + + source_sent_repr = torch.sum(source_emb['encoder_out'][0], dim=0) + + output_sent_repr = torch.sum(gen_out_emb['encoder_out'][0], dim=0).cuda() + target_labels = torch.ones(source_sent_repr.shape[0], dtype=source_sent_repr.dtype).cuda() + #cosineLoss = torch.nn.CosineEmbeddingLoss(reduction='mean') + # cos_sim_loss = cosineLoss(source_sent_repr, output_sent_repr, target_labels) + cosine_out = self.cosine_sim(source_sent_repr, output_sent_repr) + # similarity_labels = torch.FloatTensor(np.array([1]*len(source_sent_repr)), dtype=source_sent_repr.dtype).cuda() + # if similarity_labels is not None: + similarity_loss = self.mse_loss(cosine_out, target_labels.view(-1)) + # logger.info(f'cos_sim_loss: {cos_sim_loss}') + return similarity_loss #cos_sim_loss + else: + source_emb = model.encoder.forward(sample['net_input']['src_tokens'].cuda()) + gen_out_emb = model.encoder.forward(hyps) + + source_sent_repr = torch.sum(source_emb['encoder_out'][0], dim=0) + # logger.info(f"source_sent_repr: {source_sent_repr}") + + output_sent_repr = torch.sum(gen_out_emb['encoder_out'][0], dim=0).cuda() + + # logger.info(f"output_sent_repr: {output_sent_repr}") + target_labels = torch.ones(source_sent_repr.shape[0], dtype=source_sent_repr.dtype).cuda() + # cosineLoss = torch.nn.CosineEmbeddingLoss(reduction='mean') + # cos_sim_loss = cosineLoss(source_sent_repr, output_sent_repr, target_labels) + cosine_out = self.cosine_sim(source_sent_repr, output_sent_repr) + # similarity_labels = torch.FloatTensor(np.array([1]*len(source_sent_repr)), dtype=source_sent_repr.dtype).cuda() + similarity_loss = self.mse_loss(cosine_out, target_labels.view(-1)) + + return similarity_loss #cos_sim_loss + + @classmethod + def reduce_metrics(cls, logging_outputs) -> None: + # super().reduce_metrics(logging_outputs) + loss_sum = sum(log.get("loss", 0) for log in logging_outputs) + nll_loss_sum = sum(log.get("nll_loss", 0) for log in 
logging_outputs) + ntokens = sum(log.get("ntokens", 1) for log in logging_outputs) + sample_size = sum(log.get("sample_size", 1) for log in logging_outputs) + sim_loss = sum(log.get("sim_loss", 0) for log in logging_outputs) + mean_per_word_entropy = sum(log.get("mean_per_word_entropy", 0) for log in logging_outputs) + unsupervised_loss = sum(log.get("unsupervised_loss", 0) for log in logging_outputs) + unsup_nsentences = sum(log.get("unsup_nsentences", 1) for log in logging_outputs) + lm_loss = sum(log.get("lm_loss", 0) for log in logging_outputs) + lm_ppl = sum(log.get("lm_ppl", 0) for log in logging_outputs) + # lm_nll_loss = sum(log.get("lm_nll_loss", 0) for log in logging_outputs) + + metrics.log_scalar( + "loss", loss_sum / sample_size, sample_size, round=3 + ) # loss and nll_loss are already in base 2! + metrics.log_scalar( + "nll_loss", nll_loss_sum / ntokens, ntokens, round=3 + ) + metrics.log_derived( + "ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg) + ) + # metrics.log_derived( + # "lm_ppl", lm_ppl, unsup_nsentences, + # ) + metrics.log_scalar( + "sim_loss", sim_loss, unsup_nsentences, round=3 + ) + metrics.log_scalar( + "lm_loss", lm_loss / sample_size, sample_size, round=3 + ) + # metrics.log_scalar( + # "lm_nll_loss", lm_nll_loss / ntokens, ntokens, round=3 + # ) + # metrics.log_derived( + # "lm_ppl", lambda meters: utils.get_perplexity(meters["lm_nll_loss"].avg) + # ) + metrics.log_scalar( + "mean_per_word_entropy", mean_per_word_entropy, unsup_nsentences, round=3 + ) + metrics.log_scalar( + "unsupervised_loss", unsupervised_loss, unsup_nsentences, round=3 + ) + total = utils.item(sum(log.get("total", 0) for log in logging_outputs)) + if total > 0: + metrics.log_scalar("total", total) + n_correct = utils.item( + sum(log.get("n_correct", 0) for log in logging_outputs) + ) + metrics.log_scalar("n_correct", n_correct) + metrics.log_derived( + "accuracy", + lambda meters: round( + meters["n_correct"].sum * 100.0 / meters["total"].sum, 3 + ) + if meters["total"].sum > 0 + else float("nan"), + ) + + # average_entropy = 0.0 + # refs=sample['net_input']['src_tokens'] + # rows, cols = refs.size() + # # logger.info(f"refs_size: {refs.size()}") + # # logger.info(f"gsm_samples: {gsm_samples.size()}") + # refs_list = [] + # preds_list = [] + + # for i in range(rows): + # ref_sentence = [] + # # pred_sentence = [] + # for j in range(cols): + # ref_word = model.decoder.dictionary.__getitem__(refs[i, j].cpu().detach().numpy()) + # # pred_word = model.decoder.dictionary.__getitem__(gsm_samples[i, j].argmax().cpu().detach().numpy()) + # # prob_entropy = Categorical(gsm_samples[i,j,:]).entropy().cpu().detach().numpy() + + # if refs[i, j] != tgt_dict.pad(): + # # average_entropy += prob_entropy + # if ref_word != '' or '' not in ref_word: + # ref_sentence.append(ref_word) + # # if pred_word != '' or '' not in pred_word or pred_word != '': + # # pred_sentence.append(pred_word) + # refs_list.append(" ".join(ref_sentence).replace("@@ ", "").replace("", "").rstrip()) + # # preds_list.append(" ".join(pred_sentence).replace("@@ ", "").replace("", "").replace("", "").rstrip()) + # # average_entropy = average_entropy / (rows*cols) + + # rows, cols, _ = gsm_samples.size() + # for i in range(rows): + # pred_sentence = [] + # for j in range(cols): + # pred_word = model.decoder.dictionary.__getitem__(gsm_samples[i, j].argmax().cpu().detach().numpy()) + + # if pred_word != '' or '' not in pred_word or pred_word != '': + # pred_sentence.append(pred_word) + # preds_list.append(" 
".join(pred_sentence).replace("@@ ", "").replace("", "").replace("", "").rstrip()) + # # average_entropy = average_entropy / (rows*cols) + + # # inds = random.sample(range(len(refs_list)), 2) + # # rr = [refs_list[i] for i in inds] + # # rp = [preds_list[i] for i in inds] + # # logger.info(f"ref_list: {rr}") + # # logger.info(f"pred_list: {rp}") + # # logger.info(f"avg_entropy: {average_entropy}") + + # bert_loss, avg_f1 = self.compute_bertLoss(preds_list, refs_list) + + # ppl_results = self.perplexity.compute(data=preds_list, model_id='/netscratch/jalota/checkpoints/gpt2-finetuned-motra/', batch_size=len(preds_list), add_start_token=True) + + # if not train: + # logger.info(f"sample: {sample}") + # with torch.no_grad(): + # net_output = model(**sample['net_input']) + + # lprobs = model.get_normalized_probs(net_output, log_probs=True) + + # gsm_samples = gumbel_softmax(lprobs, tau=self.tau_gumbel_softmax, hard=self.hard_gumbel_softmax,eps=self.eps_gumbel_softmax, dim=-1) + + # # gen_out = seqeunce_generator.generate( + # # [model], sample, prefix_tokens=None, constraints=None) + # else: + # # if cosine: + # # gen_out = seqeunce_generator.generate( + # # [model], sample, prefix_tokens=None, constraints=None) + # # else: + # net_output = model(**sample['net_input']) + + # lprobs = model.get_normalized_probs(net_output, log_probs=True) + + # gsm_samples = gumbel_softmax(lprobs, tau=self.tau_gumbel_softmax, hard=self.hard_gumbel_softmax,eps=self.eps_gumbel_softmax, dim=-1) + + # logger.info(f"hyps: {hyps.size()}") + # logger.info(f"shape sample['net_input']['src_tokens']: {sample['net_input']['src_tokens'].size()}") + # logger.info(f"hyps[0]: {hyps[0]}") + + # encoder_out = getattr(net_output, "encoder_out") + # logger.info(f"hyps_tok: {hyps_tok}") + # /netscratch/jalota/checkpoints/gpt2-finetuned/ + # /netscratch/jalota/checkpoints/gpt2-finetuned-motra/ + # if cosine: + # cos_sim_loss = self.compute_cosineSimilarityLoss(model, sample, preds_list, train) + # hyps, hyps_tok = [], [] + # for i in range(len(gen_out)): + # s = decode(gen_out[i][0]["tokens"]).strip() + # if len(s) > 0: + # hyps_tok.append(s) + # hyps.append(gen_out[i][0]["tokens"]) + # # [h.clone().detach() for h in hyps] + # hyps = collate_tokens(hyps, src_dict.pad(), src_dict.eos(), left_pad=False, pad_to_length=None,pad_to_bsz=None) + + # cos_sim_loss = self.compute_cosineSimilarityLoss(model, sample, hyps, train) + # if torch.isnan(cos_sim_loss): + # logger.info(f"hyps: {hyps}") + # logger.info(f"sample['net_input']['src_tokens']: {sample['net_input']['src_tokens']}") + # cos_sim_loss = torch.tensor(1e-10) + + # ppl_results = self.perplexity.compute(data=hyps_tok, model_id='/netscratch/jalota/checkpoints/gpt2-finetuned-motra/', batch_size=len(hyps_tok), add_start_token=True) # {'perplexities': [], 'mean_perplexity': float_value } + # mean_per_word_entropy = math.log2(ppl_results['mean_perplexity']) diff --git a/fairseq/data/data_utils.py b/fairseq/data/data_utils.py index 9a19cc3c18..d02733fc32 100644 --- a/fairseq/data/data_utils.py +++ b/fairseq/data/data_utils.py @@ -47,7 +47,11 @@ def collate_tokens( ): """Convert a list of 1d tensors into a padded 2d tensor.""" size = max(v.size(0) for v in values) + if size < 512 and pad_to_length is not None: + pad_to_length = size size = size if pad_to_length is None else max(size, pad_to_length) + if size >= 512: + logger.info(f"size!: {size}") if pad_to_multiple != 1 and size % pad_to_multiple != 0: size = int(((size - 0.1) // pad_to_multiple + 1) * pad_to_multiple) @@ -102,6 +106,7 
@@ def load_indexed_dataset( raise e dataset_impl_k = dataset_impl + # print(f'dataset_impl_k: {dataset_impl_k}') if dataset_impl_k is None: dataset_impl_k = indexed_dataset.infer_dataset_impl(path_k) dataset = indexed_dataset.make_dataset( @@ -110,6 +115,7 @@ def load_indexed_dataset( fix_lua_indexing=True, dictionary=dictionary, ) + # print(f"dataset: {dataset}") if dataset is None: break logger.info("loaded {:,} examples from: {}".format(len(dataset), path_k)) @@ -159,6 +165,7 @@ def collect_filtered(function, iterable, filtered): def _filter_by_size_dynamic(indices, size_fn, max_positions, raise_exception=False): + # logger.info(f"size_fn: {size_fn}") def compare_leq(a, b): return a <= b if not isinstance(a, tuple) else max(a) <= b @@ -209,6 +216,8 @@ def filter_by_size(indices, dataset, max_positions, raise_exception=False): "Use `FairseqDataset::filter_indices_by_size` instead.", stacklevel=2, ) + # logger.info(f"max_positions: {max_positions}") + # logger.info(f"dataset: {dataset}") if isinstance(max_positions, float) or isinstance(max_positions, int): if hasattr(dataset, "sizes") and isinstance(dataset.sizes, np.ndarray): ignored = indices[dataset.sizes[indices] > max_positions].tolist() @@ -353,6 +362,7 @@ def batch_by_size( max_sentences, bsz_mult, ) + #logger.info(f"b: {b}") if bsz_mult > 1 and len(b[-1]) % bsz_mult != 0: b = b[:-1] diff --git a/fairseq/data/iterators.py b/fairseq/data/iterators.py index a488265137..c538db9750 100644 --- a/fairseq/data/iterators.py +++ b/fairseq/data/iterators.py @@ -591,7 +591,7 @@ def __init__(self, iterable, chunk_size, skip_remainder_batch=False): ) else: total_num_itrs = int(math.ceil(len(iterable) / float(chunk_size))) - logger.info(f"grouped total_num_itrs = {total_num_itrs}") + # logger.info(f"grouped total_num_itrs = {total_num_itrs}") itr = _chunk_iterator(iterable, chunk_size, skip_remainder_batch) super().__init__( diff --git a/fairseq/data/lm_context_window_dataset.py b/fairseq/data/lm_context_window_dataset.py index 1a945927cf..91fb10d634 100644 --- a/fairseq/data/lm_context_window_dataset.py +++ b/fairseq/data/lm_context_window_dataset.py @@ -67,7 +67,15 @@ def collater(self, samples) -> Dict: if extra > 0: self.prev_tokens = self.prev_tokens[extra:] pads = np.full(self.context_window - len(self.prev_tokens), pad) + # if toks[i].get_device() != -1: + # toks[i] = toks[i].cpu().data # move the tensor to cpu + # print(f"self.prev_tokens: {self.prev_tokens}") + # print(f"toks[i]: {toks[i]}") new_toks[i] = np.concatenate([self.prev_tokens, toks[i].numpy(), pads]) + # print(f"new_toks[i]: {new_toks[i]}") + # print(f"tgt[i]: {tgt[i]}") + tgt[i] = tgt[i].cpu().data #.numpy() + # print(f"tgt[i]: {tgt[i]}") new_tgt[ i, len(self.prev_tokens) : len(self.prev_tokens) + len(tgt[i]) ] = tgt[i] diff --git a/fairseq/data/monolingual_dataset.py b/fairseq/data/monolingual_dataset.py index 54fd583b64..82015aa023 100644 --- a/fairseq/data/monolingual_dataset.py +++ b/fairseq/data/monolingual_dataset.py @@ -6,9 +6,16 @@ import numpy as np import torch +from typing import Callable, Dict, List from . 
import FairseqDataset, data_utils +def uniform_sampler(x, k=50000): + # Uniformly sample k items from x without replacement + if len(x) < k: + return x + return np.random.choice(x, k, replace=False).tolist() + def collate(samples, pad_idx, eos_idx, fixed_pad_length=None, pad_to_bsz=None): if len(samples) == 0: return {} @@ -83,6 +90,8 @@ def __init__( pad_to_bsz=None, src_lang_idx=None, tgt_lang_idx=None, + perform_sampling=False, + num_samples=1000, ): self.dataset = dataset self.sizes = np.array(sizes) @@ -95,6 +104,11 @@ def __init__( self.pad_to_bsz = pad_to_bsz self.src_lang_idx = src_lang_idx self.tgt_lang_idx = tgt_lang_idx + # if sampling_func is None: + # sampling_func = uniform_sampler + self.sampling_func = uniform_sampler + self.perform_sampling = perform_sampling + self.num_samples = num_samples # num samples to sample if perform_sampling is True! assert targets is None or all( t in {"self", "future", "past"} for t in targets @@ -217,14 +231,27 @@ def collater(self, samples): target sentence of shape `(bsz, tgt_len)`. Padding will appear on the right. """ + if self.perform_sampling: + selected_samples = self.sampling_func(samples, self.num_samples) + else: + selected_samples = samples + return collate( - samples, + selected_samples, self.vocab.pad(), self.vocab.eos(), self.fixed_pad_length, self.pad_to_bsz, ) + # return collate( + # samples, + # self.vocab.pad(), + # self.vocab.eos(), + # self.fixed_pad_length, + # self.pad_to_bsz, + # ) + def num_tokens(self, index): """Return the number of tokens in a sample. This value is used to enforce ``--max-tokens`` during batching.""" @@ -251,3 +278,23 @@ def supports_prefetch(self): def prefetch(self, indices): self.dataset.prefetch(indices) + + def filter_indices_by_size(self, indices, max_sizes): + """Filter a list of sample indices. Remove those that are longer + than specified in max_sizes. + + Args: + indices (np.array): original array of sample indices + max_sizes (int or list[int] or tuple[int]): max sample size, + can be defined separately for src and tgt (then list or tuple) + + Returns: + np.array: filtered sample array + list: list of removed indices + """ + if not isinstance(max_sizes, int): + max_sizes = max_sizes[0] + + return data_utils.filter_by_size( + indices=indices, dataset=self.dataset, max_positions=max_sizes + ), [] diff --git a/fairseq/data/round_robin_zip_datasets.py b/fairseq/data/round_robin_zip_datasets.py index 2cb7447ea9..a3f706fd82 100644 --- a/fairseq/data/round_robin_zip_datasets.py +++ b/fairseq/data/round_robin_zip_datasets.py @@ -5,14 +5,16 @@ import logging from collections import OrderedDict -from typing import Dict, Sequence - +from typing import Callable, Dict, List, Sequence import numpy as np -from . import FairseqDataset, LanguagePairDataset +from . import FairseqDataset, LanguagePairDataset, MonolingualDataset logger = logging.getLogger(__name__) +def uniform_sampler(x): + # Sample one element from x uniformly at random + return np.random.choice(x, 1).item() class RoundRobinZipDatasets(FairseqDataset): """Zip multiple :class:`~fairseq.data.FairseqDataset` instances together. @@ -27,7 +29,8 @@ class RoundRobinZipDatasets(FairseqDataset): this instance to pass-through batches from *datasets[eval_key]*.
""" - def __init__(self, datasets, eval_key=None): + def __init__(self, datasets, eval_key=None, + sampling_func: Callable[[List], int] = None,): super().__init__() if isinstance(datasets, dict): datasets = OrderedDict(datasets) @@ -38,6 +41,9 @@ def __init__(self, datasets, eval_key=None): self.datasets = datasets self.eval_key = eval_key + if sampling_func is None: + sampling_func = uniform_sampler + self.sampling_func = sampling_func self.longest_dataset_key = max(datasets, key=lambda k: len(datasets[k])) self.longest_dataset = datasets[self.longest_dataset_key] @@ -72,6 +78,9 @@ def collater(self, samples): if len(samples) == 0: return None if self.eval_key is None: + # selected_key = self.sampling_func(list(self.datasets.keys())) + # selected_samples = [sample[selected_key] for sample in samples] + return OrderedDict( [ (key, dataset.collater([sample[key] for sample in samples])) @@ -121,6 +130,8 @@ def filter_indices_by_size(self, indices, max_positions=None): def _deep_until_language_pair(dataset): if isinstance(dataset, LanguagePairDataset): return dataset + if isinstance(dataset, MonolingualDataset): + return dataset if hasattr(dataset, "tgt_dataset"): return _deep_until_language_pair(dataset.tgt_dataset) if hasattr(dataset, "dataset"): diff --git a/fairseq/dataclass/configs.py b/fairseq/dataclass/configs.py index edd3485fd7..289bd76c30 100644 --- a/fairseq/dataclass/configs.py +++ b/fairseq/dataclass/configs.py @@ -541,7 +541,7 @@ class DatasetConfig(FairseqDataclass): }, ) batch_size_valid: Optional[int] = field( - default=II("dataset.batch_size"), + default=500, # II("dataset.batch_size") metadata={ "help": "batch size of the validation batch (defaults to --batch-size)", "argparse_alias": "--max-sentences-valid", @@ -644,6 +644,16 @@ class ComparableConfig(FairseqDataclass): comparable: bool = field( default=False, metadata={"help": 'Use comparable data during training.'} ) + temp: str = field( + default="exp0", + metadata={"help": """temporary folder in /netscratch/jalota/pickle/ where internal representations will be saved"""} + ) + only_unsupervised: bool = field( + default=False, metadata={"help": 'runs only unsupervised training and and validation without self-supervision'} + ) + max_sentences: int = field( + default=1200, metadata={"help": 'Number of sentences in a batch'} + ) sim_measure: str = field( default="margin", metadata={"help": """Similarity measure to be used for extrtacting @@ -653,12 +663,16 @@ class ComparableConfig(FairseqDataclass): default=float('-inf'), metadata={"help": "Decision threshold for keeping a similar pair."} ) + use_threshold: bool = field( + default=False, + metadata={"help": """When passed, threshold is used along with the dual representation filtering"""} + ) threshold_dynamics: str = field( default="static", metadata={"help": """Set threshold dynamics. 
Options: [static|grow|decay]"""} ) comp_example_limit: int = field( - default=float('inf'), + default=500000000, metadata={"help": """Limit number of training samples from comparable data."""} ) @@ -668,7 +682,7 @@ class ComparableConfig(FairseqDataclass): comparable training directly."""} ) comp_log: str = field( - default=None, + default='', metadata={ "help":"""Path where comparable rejected/accepted pairs will be logged.""" } @@ -686,7 +700,7 @@ class ComparableConfig(FairseqDataclass): } ) comparable_data: str = field( - default=None, + default='', metadata={ "help":"""Path to comparable data list.""" } @@ -747,24 +761,24 @@ class ComparableConfig(FairseqDataclass): } ) test_data: str = field( - default=None, + default='', metadata={ "help":"""Test data that should be excluded from training.""" } ) - use_bt: str = field( + use_bt: bool = field( default=False, metadata={ "help":"""Apply backtranslation to non-match sentences.""" } ) trans_opts: str = field( - default=None, + default='', metadata={ "help":"""Translator options comfiguration file.""" } ) - filter_bt: str = field( + filter_bt: bool = field( default=False, metadata={ "help":"""Filter backtranslations using SSNMT.""" @@ -777,7 +791,7 @@ class ComparableConfig(FairseqDataclass): } ) mono: str = field( - default=None, + default='', metadata={ "help":"""Path to list of monolingual corpora.""" } @@ -795,7 +809,7 @@ class ComparableConfig(FairseqDataclass): } ) vocab_list: str = field( - default=None, + default='', metadata={ "help":"""Path to list of vocabulary files used for substitution.""" } @@ -837,7 +851,7 @@ class ComparableConfig(FairseqDataclass): } ) retrieval: str = field( - default='max', + default='intersect', metadata={ "help":"Retrieval strategy ['fwd', 'bwd', 'max', 'intersect']" } @@ -1359,6 +1373,7 @@ class FairseqConfig(FairseqDataclass): generation: GenerationConfig = GenerationConfig() eval_lm: EvalLMConfig = EvalLMConfig() interactive: InteractiveConfig = InteractiveConfig() + comparable: ComparableConfig = ComparableConfig() model: Any = MISSING task: Any = None criterion: Any = None @@ -1367,5 +1382,4 @@ class FairseqConfig(FairseqDataclass): scoring: Any = None bpe: Any = None tokenizer: Any = None - ema: EMAConfig = EMAConfig() - comparable: ComparableConfig = ComparableConfig() + ema: EMAConfig = EMAConfig() \ No newline at end of file diff --git a/fairseq/dataclass/utils.py b/fairseq/dataclass/utils.py index f6467d5f40..8c5dec9202 100644 --- a/fairseq/dataclass/utils.py +++ b/fairseq/dataclass/utils.py @@ -156,6 +156,8 @@ def get_kwargs_from_dc( for k in dataclass_instance._get_all_attributes(): field_name = argparse_name(dataclass_instance._get_name(k)) field_type = dataclass_instance._get_type(k) + # print(f"field_name: {field_name}") + # print(f"field_type: {field_type}") if field_name is None: continue elif inspect.isclass(field_type) and issubclass(field_type, FairseqDataclass): diff --git a/fairseq/lm_perplexity/__init__.py b/fairseq/lm_perplexity/__init__.py new file mode 100644 index 0000000000..79ab08f1a4 --- /dev/null +++ b/fairseq/lm_perplexity/__init__.py @@ -0,0 +1,2 @@ +from .lm import * +from .compute_nll_loss import * \ No newline at end of file diff --git a/fairseq/lm_perplexity/compute_nll_loss.py b/fairseq/lm_perplexity/compute_nll_loss.py new file mode 100644 index 0000000000..6deced036c --- /dev/null +++ b/fairseq/lm_perplexity/compute_nll_loss.py @@ -0,0 +1,360 @@ +""" +Measures the perplexity of sentences using a trained language model. 
+""" +import logging +import math +import os +import sys +from argparse import Namespace +from typing import Iterable, List, Optional + +import torch +from omegaconf import DictConfig +from fairseq.data import LMContextWindowDataset, MonolingualDataset +import fairseq +from fairseq import checkpoint_utils, distributed_utils, options, tasks, utils +from fairseq.dataclass.utils import convert_namespace_to_omegaconf +from fairseq.logging import progress_bar +from fairseq.logging.meters import StopwatchMeter +from fairseq.sequence_scorer import SequenceScorer +from fairseq.models.transformer import TransformerModel +from fairseq.data.data_utils import batch_by_size, filter_by_size +from fairseq.data.iterators import EpochBatchIterator +from fairseq import utils + +logging.basicConfig( + format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + level=os.environ.get("LOGLEVEL", "INFO").upper(), + stream=sys.stdout, +) +logger = logging.getLogger("lm_perplexity.compute_nll_loss") + +def eval_lm( + model: fairseq.models.FairseqModel, + source_dictionary: fairseq.data.Dictionary, + batch_iterator: Iterable, + post_process: Optional[str] = None, + output_word_probs: bool = False, + output_word_stats: bool = False, + target_dictionary: Optional[fairseq.data.Dictionary] = None, + softmax_batch: int = 0, + remove_bos_token: bool = False, + device: Optional[torch.device] = None, +): + """ + Args: + model (~fairseq.models.FairseqModel): the model to + evaluate. It is essentially an `nn.Module` instance, but + must be compatible with fairseq's `SequenceScorer`. + source_dictionary (~fairseq.data.Dictionary): dictionary for + applying any relevant post processing or outputting word + probs/stats. + batch_iterator (Iterable): yields batches of data + post_process (Optional[str]): post-process text by removing BPE, + letter segmentation, etc. Valid options can be found in + fairseq.data.utils.post_process, although not all options + are implemented here.
+ output_word_probs (Optional[bool]): output words and their + predicted log probabilities + output_word_stats (Optional[bool]): output word statistics such + as word count and average probability + target_dictionary (Optional[~fairseq.data.Dictionary]): output + dictionary (defaults to *source_dictionary*) + softmax_batch (Optional[bool]): if BxT is more than this, will + batch the softmax over vocab to this amount of tokens, in + order to fit into GPU memory + remove_bos_token (Optional[bool]): if True, confirm that the + first token is the beginning-of-sentence symbol (according + to the relevant dictionary) and remove it from the output + device (Optional[torch.device]): device to use for evaluation + (defaults to device of first model parameter) + """ + if target_dictionary is None: + target_dictionary = source_dictionary + if device is None: + device = next(model.parameters()).device #next(models[0].parameters()).device + + gen_timer = StopwatchMeter() + scorer = SequenceScorer(target_dictionary, softmax_batch) + + score_sum = 0.0 + count = 0 + + if post_process is not None: + if post_process in {"subword_nmt", "@@ "}: + bpe_cont = post_process.rstrip() + bpe_toks = { + i + for i in range(len(source_dictionary)) + if source_dictionary[i].endswith(bpe_cont) + } + else: + raise NotImplementedError( + f"--post-process={post_process} is not implemented" + ) + bpe_len = len(bpe_cont) + else: + bpe_toks = None + bpe_len = 0 + + word_stats = dict() + models = [model] + + for sample in batch_iterator: + if "net_input" not in sample: + continue + + sample = utils.move_to_cuda(sample, device=device) + + gen_timer.start() + hypos = scorer.generate(models, sample) + # logger.info(f"hypos: {hypos}") + gen_timer.stop(sample["ntokens"]) + + for i, hypos_i in enumerate(hypos): + # logger.info(f"hypos_i: {hypos_i}") + hypo = hypos_i[0] + sample_id = sample["id"][i] + # logger.info(f"hypo: {hypo}") + tokens = hypo["tokens"] + tgt_len = tokens.numel() + pos_scores = hypo["positional_scores"].float() + # logger.info(f"target_dictionary.bos(): {target_dictionary.bos()}") + # logger.info(f"remove_bos_token: {remove_bos_token}") + + if torch.any(hypo["positional_scores"].isnan()): + continue + + if remove_bos_token: + assert hypo["tokens"][0].item() == target_dictionary.bos() + tokens = tokens[1:] + pos_scores = pos_scores[1:] + + skipped_toks = 0 + if bpe_toks is not None: + for i in range(tgt_len - 1): + if tokens[i].item() in bpe_toks: + skipped_toks += 1 + pos_scores[i + 1] += pos_scores[i] + pos_scores[i] = 0 + + inf_scores = pos_scores.eq(float("inf")) | pos_scores.eq(float("-inf")) + if inf_scores.any(): + # logger.info( + # "skipping tokens with inf scores:", + # target_dictionary.string(tokens[inf_scores.nonzero()]), + # ) + pos_scores = pos_scores[(~inf_scores).nonzero()] + score_sum += pos_scores.sum().cpu() + count += pos_scores.numel() - skipped_toks + + if output_word_probs or output_word_stats: + w = "" + word_prob = [] + is_bpe = False + for i in range(len(tokens)): + w_ind = tokens[i].item() + w += source_dictionary[w_ind] + if bpe_toks is not None and w_ind in bpe_toks: + w = w[:-bpe_len] + is_bpe = True + else: + word_prob.append((w, pos_scores[i].item())) + + next_prob = None + ind = i + 1 + while ind < len(tokens): + if pos_scores[ind].item() != 0: + next_prob = pos_scores[ind] + break + ind += 1 + + word_stats.setdefault(w, WordStat(w, is_bpe)).add( + pos_scores[i].item(), next_prob + ) + is_bpe = False + w = "" + if output_word_probs: + logger.info( + str(int(sample_id)) + + " " + + 
( + "\t".join( + "{} [{:2f}]".format(x[0], x[1]) for x in word_prob + ) + ) + ) + + avg_nll_loss = ( + -score_sum / count / math.log(2) if count > 0 else 0 + ) # convert to base 2 + # logger.info( + # "Evaluated {:,} tokens in {:.1f}s ({:.2f} tokens/s)".format( + # gen_timer.n, gen_timer.sum, 1.0 / gen_timer.avg if gen_timer.avg > 0 else 0 + # ) + # ) + + if output_word_stats: + for ws in sorted(word_stats.values(), key=lambda x: x.count, reverse=True): + logger.info(ws) + + return { + "loss": avg_nll_loss, + "perplexity": 2**avg_nll_loss, + } + + +class WordStat(object): + def __init__(self, word, is_bpe): + self.word = word + self.is_bpe = is_bpe + self.log_prob = 0 + self.next_word_prob = 0 + self.count = 0 + self.missing_next_words = 0 + + def add(self, log_prob, next_word_prob): + """increments counters for the sum of log probs of current word and next + word (given context ending at current word). Since the next word might be at the end of the example, + or it might be not counted because it is not an ending subword unit, + also keeps track of how many of those we have seen""" + if next_word_prob is not None: + self.next_word_prob += next_word_prob + else: + self.missing_next_words += 1 + self.log_prob += log_prob + self.count += 1 + + def __str__(self): + return "{}\t{}\t{}\t{}\t{}\t{}".format( + self.word, + self.count, + self.log_prob, + self.is_bpe, + self.next_word_prob, + self.count - self.missing_next_words, + ) + +class LanguageModelValidation: + def __init__( + self, + path='/netscratch/jalota/checkpoints/transformer_lm_en_finetuned/', + checkpoint_file='checkpoint_best.pt', + data_name_or_path='/netscratch/jalota/datasets/motra-sst/ppd_w_europarl-motra-10k_no_dups/en_es_de/unsup_setup/lm_finetune/', + device=None, + tgt_dict=None, + context_window=5, + tokens_per_sample=512 + + ): + self.context_window = context_window + self.tokens_per_sample = tokens_per_sample + self.tgt_dict = tgt_dict + if context_window > 0: + # reduce tokens per sample by the required context window size + tokens_per_sample -= context_window + + # Load ensemble + obj = TransformerModel.from_pretrained( + path, + checkpoint_file, + data_name_or_path, + ) + self.model = obj.models[0] + + self.model.half() ## use fp16 + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + else: + device = device + self.model.to(device) + + # logger.info( + # "num. 
model params: {:,}".format(sum(p.numel() for p in self.model.parameters())) + # ) + + def eval_lm_dataloader(self, + dataset, + max_tokens: Optional[int] = 36000, + batch_size: Optional[int] = None, + max_positions: Optional[int] = None, + num_shards: int = 1, + shard_id: int = 0, + num_workers: int = 0, + data_buffer_size: int = 10, + # ensures that every evaluated token has access to a context of at least + # this size, if possible + context_window: int = 0, + ): + + # logger.info(f"len(dataset): {len(dataset)}") + + if context_window > 0: + dataset = LMContextWindowDataset( + dataset=dataset, + tokens_per_sample=self.tokens_per_sample, + context_window=self.context_window, + pad_idx=self.tgt_dict.pad(), + ) + + # logger.info(f"len(LMdataset): {len(dataset)}") + + indices = dataset.ordered_indices() + # logger.info(f"indices: {indices}") + + indices = filter_by_size(indices, dataset, max_positions=512, raise_exception=False) + + batch_sampler = batch_by_size(indices, dataset.num_tokens, max_sentences=80, required_batch_size_multiple=8) + + itrs = EpochBatchIterator(dataset=dataset, collate_fn=dataset.collater, batch_sampler=batch_sampler, seed=23, epoch=0, num_workers=0) + return itrs.next_epoch_itr(shuffle=False) + + # return self.get_batch_iterator( + # dataset=dataset, + # max_tokens=max_tokens, + # max_sentences=batch_size, + # max_positions=max_positions, + # ignore_invalid_inputs=True, + # num_shards=num_shards, + # shard_id=shard_id, + # num_workers=num_workers, + # data_buffer_size=data_buffer_size, + # ).next_epoch_itr(shuffle=False) + + def get_lm_perplexity(self, dataset, batch_size): + + dataset = dataset + batch_size = batch_size + # Load dataset splits + + itr = self.eval_lm_dataloader( + dataset=dataset, + max_tokens=36000, + batch_size=batch_size, + max_positions=utils.resolve_max_positions(self.model.max_positions() + ), + context_window=self.context_window, + ) + # *[model.max_positions() for model in models] + + itr = progress_bar.progress_bar( + itr, log_format='json', + log_interval=100, + default_log_format='tqdm', + ) + + results = eval_lm( + model=self.model, + source_dictionary=self.tgt_dict, + batch_iterator=itr, + target_dictionary=self.tgt_dict, + ) + + # logger.info( + # "Loss (base 2): {:.4f}, Perplexity: {:.2f}".format( + # results["loss"], results["perplexity"] + # ) + # ) + + return results \ No newline at end of file diff --git a/fairseq/lm_perplexity/lm.py b/fairseq/lm_perplexity/lm.py new file mode 100644 index 0000000000..10cba24701 --- /dev/null +++ b/fairseq/lm_perplexity/lm.py @@ -0,0 +1,162 @@ +from fairseq.models.transformer import TransformerModel +from fairseq.models.transformer_lm import TransformerLanguageModel +from fairseq.tasks.language_modeling import LanguageModelingTask +import torch +import logging +import os, sys +import torch.nn as nn +from fairseq import tasks, checkpoint_utils +from fairseq.dataclass.utils import convert_namespace_to_omegaconf +from argparse import Namespace + +logging.basicConfig( + format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + level=os.environ.get("LOGLEVEL", "INFO").upper(), + stream=sys.stdout, +) +logger = logging.getLogger("fairseq.lm_perplexity.lm") + +# task_args= Namespace(task="language_modeling", data="/netscratch/jalota/datasets/motra-sst/ppd_w_europarl-motra-10k_no_dups/en_es_de/unsup_setup/lm_finetune/", tokens_per_sample=512, output_dictionary_size= -1, dataset_impl='mmap', future_target=False, self_target=False, past_target=False, ) + +# model_args = 
Namespace(tokens_per_sample=512, data="/netscratch/jalota/datasets/motra-sst/ppd_w_europarl-motra-10k_no_dups/en_es_de/unsup_setup/lm_finetune/", arch="transformer_lm", activation_fn='relu', dropout=0.1, attention_dropout= 0.0, activation_dropout= 0.0, relu_dropout= 0.0, decoder_embed_dim= 512, decoder_output_dim= 512, decoder_input_dim= 512, decoder_ffn_embed_dim= 2048, decoder_layers= 6, decoder_attention_heads= 8, decoder_normalize_before= False, no_decoder_final_norm= False, adaptive_softmax_cutoff=None, adaptive_softmax_dropout=0.0, adaptive_softmax_factor=4.0, no_token_positional_embeddings=False, share_decoder_input_output_embed=True, character_embeddings= False, character_filters='[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]', character_embedding_dim=4, char_embedder_highway_layers=2, adaptive_input= False, adaptive_input_factor= 4.0, adaptive_input_cutoff= None, tie_adaptive_weights=False, tie_adaptive_proj= False, decoder_learned_pos= False, layernorm_embedding= False, no_scale_embedding= False, checkpoint_activations= False, offload_activations= False, decoder_layerdrop= 0.0, decoder_layers_to_keep= None, quant_noise_pq= 0.0, quant_noise_pq_block_size=8, quant_noise_scalar= 0.0, min_params_to_wrap=100000000, base_layers= 0, base_sublayers=1, base_shuffle=1, scale_fc=False, scale_attn=False, scale_heads=False, scale_resids= False, decoder_xformers_att_config=None, add_bos_token= False, max_target_positions= None, tpu= False) +# /netscratch/jalota/datasets/motra-sst/ppd_w_europarl-motra-10k_no_dups/en_es_de/unsup_setup/lm_finetune/ +# /netscratch/jalota/datasets/motra-sst/de/unsup_setup_raw/lm_finetuning/ +class LanguageModel: + """ + Transformer LanguageModel to compute perplexity or cross entropy. + """ + def __init__( + self, + path=None, + checkpoint_file='checkpoint_best.pt', + data_name_or_path='/netscratch/jalota/datasets/motra-sst/ppd_w_europarl-motra-10k_no_dups/en_es_de/unsup_setup/lm_finetune/', + device = None, + tgt_dict=None + ): + # /netscratch/jalota/datasets/motra-sst/de/unsup_setup_raw/lm_finetuning/ + obj = TransformerModel.from_pretrained( + path, + checkpoint_file, + data_name_or_path, + ) + self._model = obj.models[0] + # logger.info("self._model.type: {self._model.__class__.__name__}") + # print(self._model) + if device is None: + self.device = "cuda" if torch.cuda.is_available() else "cpu" + else: + self.device = device + + self._emb_matrix = None + # we want to keep the weights of the pretrained model frozen + for name, param in self._model.named_parameters(): + # logger.info(f"name: {name}") + if 'embed_tokens.weight' in name: + emb_matrix = param + self._emb_matrix = emb_matrix + param.requires_grad = False + + old_num_tokens, old_embedding_dim = self._emb_matrix.size() + + # build new embeddings: https://huggingface.co/transformers/v2.11.0/_modules/transformers/modeling_utils.html#PreTrainedModel.resize_token_embeddings + new_emb_matrix = torch.nn.Embedding(len(tgt_dict), old_embedding_dim).requires_grad_(False) + + # new_emb_matrix = init_weights(new_emb_matrix) + nn.init.normal_(new_emb_matrix.weight, mean=0, std=old_embedding_dim**-0.5) + + # print(f"new_emb_matrix.weight: {new_emb_matrix.weight}") + # print(f"self._emb_matrix.weight: {self._emb_matrix}") + + # Copy token embeddings from the previous weights + num_tokens_to_copy = min(old_num_tokens, len(tgt_dict)) + # logger.info(f"num_tokens_to_copy: {num_tokens_to_copy}") + new_emb_matrix.weight[:num_tokens_to_copy, :] = self._emb_matrix[:num_tokens_to_copy, :] + + 
self._emb_matrix = new_emb_matrix.weight + + self._model.decoder.embed_tokens.weight = new_emb_matrix.weight + self._model.decoder.output_projection.weight = new_emb_matrix.weight + + # logger.info(f"self._model.decoder.embed_tokens.weight: {self._model.decoder.embed_tokens.weight.size()}") + + # logger.info(f"self._emb_matrix.size(): {self._emb_matrix.size()}") + + self._model.to(self.device) + self._emb_matrix.to(self.device) + + + def get_lm_out_from_decoder_inp(self, preds, batch_size=64, verbose=True): + """ + Args: + - :param: `preds` (torch.tensor BxTxE): predicted logits + + Return: + Return: + - :param: 'lm_out' (torch.tensor BxTxE): output of LM when fed predicted logits from the decoder + """ + + return self.get_lm_output( + preds, + verbose=verbose, + device=self.device, + batch_size=batch_size, + ) + + def get_lm_output(self, + preds, verbose=False, device="cuda:0", batch_size=64): + """ + Args: + - :param: `preds` (torch.tensor BxTxE): predicted logits + """ + return self.get_lm_out_from_tensor( + preds, device=device) + + + def get_lm_out_from_tensor(self, + preds_tensor, + device="cuda:0"): + """ + Compute LM embedding in batches. + + Args: + - :param: `preds_tensor` (torch.tensor) : preds tensor. + - :param: `device` (str): device to use, e.g. 'cpu' or 'cuda' + """ + batch_size, max_seq_len, vocab_size = preds_tensor.size() + + emb_size = self._emb_matrix.size()[-1] + + preds_tensor = preds_tensor.to(device) + self._emb_matrix = self._emb_matrix.to(device) + + preds_tensor_embs = torch.mm(preds_tensor.contiguous().view(-1, vocab_size), self._emb_matrix) + # logger.info("preds_tensor_embs.size(): {preds_tensor_embs.size()}") + preds_tensor_embs = preds_tensor_embs.view(-1, max_seq_len, emb_size) + + # logger.info(f"model.__class__.__name__: {self._model.__class__.__name__}") + + lm_out = self._model(preds_tensor_embs) + + # logger.info(f"lm_out: {lm_out}") + return lm_out + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/fairseq/models/__init__.py b/fairseq/models/__init__.py index 11cf6ee530..64158c97f4 100644 --- a/fairseq/models/__init__.py +++ b/fairseq/models/__init__.py @@ -58,6 +58,8 @@ def build_model(cfg: FairseqDataclass, task, from_checkpoint=False): model = None model_type = getattr(cfg, "_name", None) or getattr(cfg, "arch", None) + print(cfg) + if not model_type and len(cfg) == 1: # this is hit if config object is nested in directory that is named after model type diff --git a/fairseq/models/fairseq_model.py b/fairseq/models/fairseq_model.py index 65ead9dcf2..d4d25b1d4f 100644 --- a/fairseq/models/fairseq_model.py +++ b/fairseq/models/fairseq_model.py @@ -79,7 +79,7 @@ def get_normalized_probs_scriptable( log_probs: bool, sample: Optional[Dict[str, Tensor]] = None, ): - """Scriptable helper function for get_normalized_probs in ~BaseFairseqModel""" + """Scriptable helper function for get_normalized_probs in ~BaseFairseqModel""" if hasattr(self, "decoder"): return self.decoder.get_normalized_probs(net_output, log_probs, sample) elif torch.is_tensor(net_output): diff --git a/fairseq/models/transformer/transformer_decoder.py b/fairseq/models/transformer/transformer_decoder.py index c22e5625d4..3f855fcd56 100644 --- a/fairseq/models/transformer/transformer_decoder.py +++ b/fairseq/models/transformer/transformer_decoder.py @@ -279,7 +279,9 @@ def extract_features_scriptable( - the decoder's features of shape `(batch, tgt_len, embed_dim)` - a dictionary with any model-specific outputs """ - bs, slen = prev_output_tokens.size() + # 
print(f"prev_output_tokens.size(): {prev_output_tokens.size()}") + if len(prev_output_tokens.size()) == 2: + bs, slen = prev_output_tokens.size() if alignment_layer is None: alignment_layer = self.num_layers - 1 @@ -290,22 +292,28 @@ def extract_features_scriptable( if encoder_out is not None and len(encoder_out["encoder_padding_mask"]) > 0: padding_mask = encoder_out["encoder_padding_mask"][0] - # embed positions positions = None - if self.embed_positions is not None: - positions = self.embed_positions( - prev_output_tokens, incremental_state=incremental_state - ) - if incremental_state is not None: - prev_output_tokens = prev_output_tokens[:, -1:] - if positions is not None: - positions = positions[:, -1:] + if len(prev_output_tokens.size()) == 2: + + # embed positions + if self.embed_positions is not None: + positions = self.embed_positions( + prev_output_tokens, incremental_state=incremental_state + ) - # Prevent torchscript exporting issue for dynamic quant embedding - prev_output_tokens = prev_output_tokens.contiguous() - # embed tokens and positions - x = self.embed_scale * self.embed_tokens(prev_output_tokens) + if incremental_state is not None: + prev_output_tokens = prev_output_tokens[:, -1:] + if positions is not None: + positions = positions[:, -1:] + + # Prevent torchscript exporting issue for dynamic quant embedding + prev_output_tokens = prev_output_tokens.contiguous() + # embed tokens and positions + x = self.embed_scale * self.embed_tokens(prev_output_tokens) + + else: + x = self.embed_scale * prev_output_tokens if self.quant_noise is not None: x = self.quant_noise(x) diff --git a/fairseq/models/transformer/transformer_encoder.py b/fairseq/models/transformer/transformer_encoder.py index 17369ec8dc..08ebf57158 100644 --- a/fairseq/models/transformer/transformer_encoder.py +++ b/fairseq/models/transformer/transformer_encoder.py @@ -61,6 +61,7 @@ def __init__(self, cfg, dictionary, embed_tokens, return_fc=False): self.max_source_positions = cfg.max_source_positions self.embed_tokens = embed_tokens + # print(f"embed_tokens: {embed_tokens}") self.embed_scale = 1.0 if cfg.no_scale_embedding else math.sqrt(embed_dim) @@ -120,18 +121,26 @@ def forward_embedding( self, src_tokens, token_embedding: Optional[torch.Tensor] = None ): # embed tokens and positions + # print("inside forward embedding") if token_embedding is None: + # print(f"src_tokens: {src_tokens}") + # print(f"src_tokens.size(): {src_tokens.size()}") + # print(f"self.embed_tokens(src_tokens): {self.embed_tokens(src_tokens)}") token_embedding = self.embed_tokens(src_tokens) x = embed = self.embed_scale * token_embedding + # print(f"type(x): {type(x)}") + # print(f"self.embed_positions: {self.embed_positions}") + if self.embed_positions is not None: x = embed + self.embed_positions(src_tokens) + # print("added embed_positions") if self.layernorm_embedding is not None: x = self.layernorm_embedding(x) x = self.dropout_module(x) if self.quant_noise is not None: x = self.quant_noise(x) return x, embed - + def forward( self, src_tokens, @@ -201,20 +210,26 @@ def forward_scriptable( Only populated if *return_all_hiddens* is True. """ # compute padding mask - encoder_padding_mask = src_tokens.eq(self.padding_idx) - has_pads = ( - torch.tensor(src_tokens.device.type == "xla") or encoder_padding_mask.any() - ) - # Torchscript doesn't handle bool Tensor correctly, so we need to work around. 
- if torch.jit.is_scripting(): - has_pads = torch.tensor(1) if has_pads else torch.tensor(0) + has_pads = torch.tensor(0) + if len(src_tokens.size()) == 2: + encoder_padding_mask = src_tokens.eq(self.padding_idx) + has_pads = ( + torch.tensor(src_tokens.device.type == "xla") or encoder_padding_mask.any() + ) + # Torchscript doesn't handle bool Tensor correctly, so we need to work around. + if torch.jit.is_scripting(): + has_pads = torch.tensor(1) if has_pads else torch.tensor(0) - x, encoder_embedding = self.forward_embedding(src_tokens, token_embeddings) + x, encoder_embedding = self.forward_embedding(src_tokens, token_embeddings) + + else: + x = src_tokens # BxTxemb_dim # account for padding while computing the representation - x = x * ( - 1 - encoder_padding_mask.unsqueeze(-1).type_as(x) * has_pads.type_as(x) - ) + if len(src_tokens.size()) == 2 and has_pads: + x = x * ( + 1 - encoder_padding_mask.unsqueeze(-1).type_as(x) * has_pads.type_as(x) + ) # B x T x C -> T x B x C x = x.transpose(0, 1) @@ -249,20 +264,25 @@ def forward_scriptable( # `forward` so we use a dictionary instead. # TorchScript does not support mixed values so the values are all lists. # The empty list is equivalent to None. - src_lengths = ( - src_tokens.ne(self.padding_idx) - .sum(dim=1, dtype=torch.int32) - .reshape(-1, 1) - .contiguous() - ) + + if len(src_tokens.size()) == 2: + src_lengths = ( + src_tokens.ne(self.padding_idx) + .sum(dim=1, dtype=torch.int32) + .reshape(-1, 1) + .contiguous() + ) + return { + "encoder_out": [x], # T x B x C + "encoder_padding_mask": [encoder_padding_mask], # B x T + "encoder_embedding": [encoder_embedding], # B x T x C + "encoder_states": encoder_states, # List[T x B x C] + "fc_results": fc_results, # List[T x B x C] + "src_tokens": [], + "src_lengths": [src_lengths], + } return { "encoder_out": [x], # T x B x C - "encoder_padding_mask": [encoder_padding_mask], # B x T - "encoder_embedding": [encoder_embedding], # B x T x C - "encoder_states": encoder_states, # List[T x B x C] - "fc_results": fc_results, # List[T x B x C] - "src_tokens": [], - "src_lengths": [src_lengths], } @torch.jit.export diff --git a/fairseq/options.py b/fairseq/options.py index 4cf5d6b859..2fb799c552 100644 --- a/fairseq/options.py +++ b/fairseq/options.py @@ -21,6 +21,7 @@ InteractiveConfig, OptimizationConfig, EMAConfig, + ComparableConfig ) from fairseq.dataclass.utils import gen_parser_from_dataclass @@ -42,6 +43,8 @@ def get_training_parser(default_task="translation"): add_optimization_args(parser) add_checkpoint_args(parser) add_comparable_args(parser) + print("parser") + add_generation_args(parser) add_ema_args(parser) return parser @@ -310,9 +313,12 @@ def add_preprocess_args(parser): return parser -def add_comparable_args(parser, train=True, gen=False): - group = parser.add_argument_group('comparable_data') +def add_comparable_args(parser): + print(f"add comp args") + group = parser.add_argument_group('comparable') gen_parser_from_dataclass(group, ComparableConfig()) + print(group) + return group # group.add_argument('--comparable', action="store_true", default=False, # help='Use comparable data during training.') # group.add_argument('--sim_measure', '-sim_measure', default="margin", @@ -402,7 +408,7 @@ def add_comparable_args(parser, train=True, gen=False): # group.add_argument('--faiss-output', default='/netscratch/jalota/logs', help='faiss alignment output') # group.add_argument('--faiss-use-gpu', default=False, action='store_true', help='whether to store the index and perform search on GPU') # 
group.add_argument('--index', default='flat', choices=['flat', 'ivf', 'pq'], help="which faiss index to use.") - return group + # return group def add_dataset_args(parser, train=False, gen=False): diff --git a/fairseq/scoring/perplexity.py b/fairseq/scoring/perplexity.py new file mode 100644 index 0000000000..31dc822adf --- /dev/null +++ b/fairseq/scoring/perplexity.py @@ -0,0 +1,193 @@ +# Copyright 2022 The HuggingFace Datasets Authors and the current dataset script contributor. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Perplexity Metric.""" + +import datasets +import numpy as np +import torch +from torch.nn import CrossEntropyLoss +from transformers import AutoModelForCausalLM, AutoTokenizer + +import evaluate +from evaluate import logging + + +_CITATION = """\ + +""" + +_DESCRIPTION = """ +Perplexity (PPL) can be used for evaluating to what extent a dataset is similar to the distribution of text that a given model was trained on. +It is defined as the exponentiated average negative log-likelihood of a sequence. + +For more information, see https://huggingface.co/docs/transformers/perplexity +""" + +_KWARGS_DESCRIPTION = """ +Args: + model_id (str): model used for calculating Perplexity + NOTE: Perplexity can only be calculated for causal language models. + This includes models such as gpt2, causal variations of bert, + causal versions of t5, and more (the full list can be found + in the AutoModelForCausalLM documentation here: + https://huggingface.co/docs/transformers/master/en/model_doc/auto#transformers.AutoModelForCausalLM ) + + data (list of str): input data, each separate text snippet + is one list entry. + batch_size (int): the batch size to run texts through the model. Defaults to 16. + add_start_token (bool): whether to add the start token to the texts, + so the perplexity can include the probability of the first word. Defaults to True. + device (str): device to run on, defaults to 'cuda' when available +Returns: + perplexity: dictionary containing the perplexity scores for the texts + in the input list, as well as the mean perplexity. If one of the input texts is + longer than the max input length of the model, then it is truncated to the + max length for the perplexity computation. +Examples: + Example 1: + >>> perplexity = evaluate.load("perplexity", module_type="measurement") + >>> data = ["lorem ipsum", "Happy Birthday!", "Bienvenue"] + >>> results = perplexity.compute(model_id='gpt2', + ... add_start_token=False, + ... data=data) # doctest:+ELLIPSIS + >>> print(list(results.keys())) + ['perplexities', 'mean_perplexity'] + >>> print(round(results["mean_perplexity"], 2)) + 78.22 + >>> print(round(results["perplexities"][0], 2)) + 11.11 + + Example 2: + >>> from datasets import load_dataset + >>> perplexity = evaluate.load("perplexity", module_type="measurement") + >>> data = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")["text"][:10] # doctest: +SKIP + >>> data = [s for s in data if s!=''] + >>> results = perplexity.compute(model_id='gpt2', + ... 
data=data) + >>> print(list(results.keys())) + ['perplexities', 'mean_perplexity'] + >>> print(round(results["mean_perplexity"], 2)) # doctest: +SKIP + 60.35 + >>> print(round(results["perplexities"][0], 2)) # doctest: +SKIP + 81.12 +""" + + +@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION) +class Perplexity(evaluate.EvaluationModule): + def _info(self): + return evaluate.EvaluationModuleInfo( + module_type="measurement", + description=_DESCRIPTION, + citation=_CITATION, + inputs_description=_KWARGS_DESCRIPTION, + features=datasets.Features( + { + "data": datasets.Value("string"), + } + ), + reference_urls=["https://huggingface.co/docs/transformers/perplexity"], + ) + + def _compute(self, data, model_id, batch_size: int = 16, add_start_token: bool = True, device=None): + + if device is not None: + assert device in ["gpu", "cpu", "cuda"], "device should be either gpu or cpu." + if device == "gpu": + device = "cuda" + else: + device = "cuda" if torch.cuda.is_available() else "cpu" + + model = AutoModelForCausalLM.from_pretrained(model_id) + model = model.to(device) + + tokenizer = AutoTokenizer.from_pretrained(model_id) + tokenizer.pad_token = tokenizer.eos_token + tokenizer.padding_side = 'left' + + # if batch_size > 1 (which generally leads to padding being required), and + # if there is not an already assigned pad_token, assign an existing + # special token to also be the padding token + if tokenizer.pad_token is None and batch_size > 1: + existing_special_tokens = list(tokenizer.special_tokens_map_extended.values()) + # check that the model already has at least one special token defined + assert ( + len(existing_special_tokens) > 0 + ), "If batch_size > 1, model must have at least one special token to use for padding. Please use a different model or set batch_size=1." + # assign one of the special tokens to also be the pad token + tokenizer.add_special_tokens({"pad_token": existing_special_tokens[0]}) + + if add_start_token: + # leave room for token to be added: + assert ( + tokenizer.bos_token is not None + ), "Input model must already have a BOS token if using add_start_token=True. Please use a different model, or set add_start_token=False" + max_tokenized_len = model.config.max_length - 1 + else: + max_tokenized_len = model.config.max_length + + encodings = tokenizer( + data, + add_special_tokens=False, + padding=True, + truncation=True, + max_length=max_tokenized_len, + return_tensors="pt", + return_attention_mask=True, + ).to(device) + + encoded_texts = encodings["input_ids"] + attn_masks = encodings["attention_mask"] + + # check that each input is long enough: + if add_start_token: + assert torch.all(torch.ge(attn_masks.sum(1), 1)), "Each input text must be at least one token long." + else: + assert torch.all( + torch.ge(attn_masks.sum(1), 2) + ), "When add_start_token=False, each input text must be at least two tokens long. Run with add_start_token=True if inputting strings of only one token, and remove all empty input strings." 
+ + ppls = [] + loss_fct = CrossEntropyLoss(reduction='none') + + # for start_index in logging.tqdm(range(0, len(encoded_texts), batch_size)): + for start_index in range(0, len(encoded_texts), batch_size): + end_index = min(start_index + batch_size, len(encoded_texts)) + encoded_batch = encoded_texts[start_index:end_index] + attn_mask = attn_masks[start_index:end_index] + + if add_start_token: + bos_tokens_tensor = torch.tensor([[tokenizer.bos_token_id]] * encoded_batch.size(dim=0)).to(device) + encoded_batch = torch.cat([bos_tokens_tensor, encoded_batch], dim=1) + attn_mask = torch.cat( + [torch.ones(bos_tokens_tensor.size(), dtype=torch.int64).to(device), attn_mask], dim=1 + ) + + labels = encoded_batch + + with torch.no_grad(): + out_logits = model(encoded_batch, attention_mask=attn_mask).logits + + shift_logits = out_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + shift_attention_mask_batch = attn_mask[..., 1:].contiguous() + + perplexity_batch = torch.exp2( + (loss_fct(shift_logits.transpose(1, 2), shift_labels) * shift_attention_mask_batch).sum(1) + / shift_attention_mask_batch.sum(1) + ) + + ppls += perplexity_batch.tolist() + + return {"perplexities": ppls, "mean_perplexity": np.mean(ppls)} \ No newline at end of file diff --git a/fairseq/tasks/fairseq_task.py b/fairseq/tasks/fairseq_task.py index 7481e0f4bf..0c24586f4a 100644 --- a/fairseq/tasks/fairseq_task.py +++ b/fairseq/tasks/fairseq_task.py @@ -165,7 +165,7 @@ def dataset(self, split): return self.datasets[split] def filter_indices_by_size( - self, indices, dataset, max_positions=None, ignore_invalid_inputs=False + self, indices, dataset, max_positions=None, ignore_invalid_inputs=True ): """ Filter examples that are too large @@ -307,8 +307,8 @@ def make_batches(dataset, epoch): reuse_dataloader = getattr(self.cfg, "reuse_dataloader", True) persistent_workers = getattr(self.cfg, "persistent_workers", True) rebuild_batches = getattr(self.cfg, "rebuild_batches", False) - logger.info(f"reuse_dataloader = {reuse_dataloader}") - logger.info(f"rebuild_batches = {rebuild_batches}") + # logger.info(f"reuse_dataloader = {reuse_dataloader}") + # logger.info(f"rebuild_batches = {rebuild_batches}") if rebuild_batches: logger.info("batches will be rebuilt for each epoch") @@ -529,6 +529,10 @@ def train_step( with torch.autograd.profiler.record_function("forward"): with torch.cuda.amp.autocast(enabled=(isinstance(optimizer, AMPOptimizer))): loss, sample_size, logging_output = criterion(model, sample) + + # add unsupervised criterion + + if ignore_grad: loss *= 0 with torch.autograd.profiler.record_function("backward"): diff --git a/fairseq/tasks/translation.py b/fairseq/tasks/translation.py index 79752279a6..44ac4b47d4 100644 --- a/fairseq/tasks/translation.py +++ b/fairseq/tasks/translation.py @@ -11,9 +11,10 @@ from typing import Optional from argparse import Namespace from omegaconf import II - +import torch import numpy as np from fairseq import metrics, utils +from fairseq.optim.amp_optimizer import AMPOptimizer from fairseq.data import ( AppendTokenDataset, ConcatDataset, @@ -28,7 +29,8 @@ from fairseq.data.indexed_dataset import get_available_dataset_impl from fairseq.dataclass import ChoiceEnum, FairseqDataclass from fairseq.tasks import FairseqTask, register_task - +import matplotlib.pyplot as plt +from matplotlib.lines import Line2D EVAL_BLEU_ORDER = 4 @@ -143,6 +145,9 @@ def split_exists(split, src, tgt, lang, data_path): tgt_dataset, tgt_dict.index("[{}]".format(tgt)) ) eos = 
tgt_dict.index("[{}]".format(tgt)) + print(f"tgt_dict: {tgt_dict}") + # src_dict.add_symbol("") + # tgt_dict.add_symbol("") align_dataset = None if load_alignments: @@ -227,7 +232,18 @@ class TranslationConfig(FairseqDataclass): "dataset.dataset_impl" ) required_seq_len_multiple: int = II("dataset.required_seq_len_multiple") - + unsup_gen_args: Optional[str] = field( + default="{}", + metadata={ + "help": 'generation args for Unsupervised Learning, e.g., \'{"beam": 4, "lenpen": 0.6}\', as JSON string' + }, + ) + start_unsup: int = field( + default=400, metadata={"help": "start unsupervised training from xx updates onwards."} + ) + only_unsupervised: bool = field( + default=False, metadata={"help": "whether to perform only unsupervised training"} + ) # options for reporting BLEU during validation eval_bleu: bool = field( default=False, metadata={"help": "evaluation with BLEU scores"} @@ -285,6 +301,8 @@ def __init__(self, cfg: TranslationConfig, src_dict, tgt_dict): super().__init__(cfg) self.src_dict = src_dict self.tgt_dict = tgt_dict + self.start_unsup = cfg.start_unsup + self.only_unsupervised = cfg.only_unsupervised @classmethod def setup_task(cls, cfg: TranslationConfig, **kwargs): @@ -366,7 +384,11 @@ def build_dataset_for_inference(self, src_tokens, src_lengths, constraints=None) ) def build_model(self, cfg, from_checkpoint=False): - model = super().build_model(cfg, from_checkpoint) + model = super().build_model(cfg.model, from_checkpoint) + # gen_args = json.loads(cfg.generation) + self.sequence_generator = self.build_generator( + [model], cfg.generation + ) if self.cfg.eval_bleu: detok_args = json.loads(self.cfg.eval_bleu_detok_args) self.tokenizer = encoders.build_tokenizer( @@ -379,6 +401,53 @@ def build_model(self, cfg, from_checkpoint=False): ) return model + # def train_step( + # self, sample, model, criterion, optimizer, update_num, ignore_grad=False + # ): + # """ + # Do forward and backward, and return the loss as computed by *criterion* + # for the given *model* and *sample*. + + # Args: + # sample (dict): the mini-batch. The format is defined by the + # :class:`~fairseq.data.FairseqDataset`. 
+ # model (~fairseq.models.BaseFairseqModel): the model + # criterion (~fairseq.criterions.FairseqCriterion): the criterion + # optimizer (~fairseq.optim.FairseqOptimizer): the optimizer + # update_num (int): the current update + # ignore_grad (bool): multiply loss by 0 if this is set to True + + # Returns: + # tuple: + # - the loss + # - the sample size, which is used as the denominator for the + # gradient + # - logging outputs to display while training + # """ + # model.train() + # model.set_num_updates(update_num) + # unsup = False + # # logger.info(f"self.only_unsupervised: {self.only_unsupervised}") + # if update_num >= self.start_unsup: # warm-up updates before unsupervised training starts - half of warm updates set in the config + # unsup = True + # with torch.autograd.profiler.record_function("forward"): + # with torch.cuda.amp.autocast(enabled=(isinstance(optimizer, AMPOptimizer))): + # loss, sample_size, logging_output = criterion(model, sample, self.sequence_generator, self.tgt_dict, unsup=unsup, src_dict=self.src_dict, only_unsupervised=self.only_unsupervised) + + # if ignore_grad: + # loss *= 0 + # with torch.autograd.profiler.record_function("backward"): + # optimizer.backward(loss) + # # plot_grad_flow(model.named_parameters().cpu()) + # return loss, sample_size, logging_output + + + # def valid_step(self, sample, model, criterion): + # model.eval() + # with torch.no_grad(): + # loss, sample_size, logging_output = criterion(model, sample, self.sequence_generator, self.tgt_dict, unsup=True, src_dict=self.src_dict, train=False) + # return loss, sample_size, logging_output + def valid_step(self, sample, model, criterion): loss, sample_size, logging_output = super().valid_step(sample, model, criterion) if self.cfg.eval_bleu: diff --git a/fairseq/tasks/translation_from_pretrained_bart.py b/fairseq/tasks/translation_from_pretrained_bart.py index 0fd7a5b29f..edfffd7130 100644 --- a/fairseq/tasks/translation_from_pretrained_bart.py +++ b/fairseq/tasks/translation_from_pretrained_bart.py @@ -6,10 +6,15 @@ import torch from fairseq import utils from fairseq.data import LanguagePairDataset +import json +import logging +from fairseq.data import encoders +from argparse import Namespace from . import register_task from .translation import TranslationTask, load_langpair_dataset +logger = logging.getLogger(__name__) @register_task("translation_from_pretrained_bart") class TranslationFromPretrainedBARTTask(TranslationTask): @@ -51,11 +56,15 @@ def add_args(parser): def __init__(self, args, src_dict, tgt_dict): super().__init__(args, src_dict, tgt_dict) - self.langs = args.langs.split(",") + # self.langs = args.langs.split(",") + logger.info(f"len(src_dict): {len(src_dict)}") + logger.info(f"len(tgt_dict): {len(tgt_dict)}") for d in [src_dict, tgt_dict]: - for l in self.langs: - d.add_symbol("[{}]".format(l)) + # for l in self.langs: + # d.add_symbol("[{}]".format(l)) d.add_symbol("") + logger.info(f"len(src_dict): {len(src_dict)}") + logger.info(f"len(tgt_dict): {len(tgt_dict)}") def load_dataset(self, split, epoch=1, combine=False, **kwargs): """Load a given dataset split. 
@@ -63,12 +72,15 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs): Args: split (str): name of the split (e.g., train, valid, test) """ - paths = utils.split_paths(self.args.data) + paths = utils.split_paths(self.cfg.data) assert len(paths) > 0 data_path = paths[(epoch - 1) % len(paths)] # infer langcode - src, tgt = self.args.source_lang, self.args.target_lang + src, tgt = self.cfg.source_lang, self.cfg.target_lang + + logger.info(f"len(self.src_dict): {len(self.src_dict)}") + logger.info(f"len(self.tgt_dict): {len(self.tgt_dict)}") self.datasets[split] = load_langpair_dataset( data_path, @@ -78,47 +90,49 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs): tgt, self.tgt_dict, combine=combine, - dataset_impl=self.args.dataset_impl, - upsample_primary=self.args.upsample_primary, - left_pad_source=self.args.left_pad_source, - left_pad_target=self.args.left_pad_target, - max_source_positions=getattr(self.args, "max_source_positions", 1024), - max_target_positions=getattr(self.args, "max_target_positions", 1024), - load_alignments=self.args.load_alignments, - prepend_bos=getattr(self.args, "prepend_bos", False), + dataset_impl=self.cfg.dataset_impl, + upsample_primary=self.cfg.upsample_primary, + left_pad_source=self.cfg.left_pad_source, + left_pad_target=self.cfg.left_pad_target, + max_source_positions=getattr(self.cfg, "max_source_positions", 1024), + max_target_positions=getattr(self.cfg, "max_target_positions", 1024), + load_alignments=self.cfg.load_alignments, + prepend_bos=getattr(self.cfg, "prepend_bos", False), append_source_id=True, ) - def build_generator(self, models, args, **unused): - if getattr(args, "score_reference", False): - from fairseq.sequence_scorer import SequenceScorer - - return SequenceScorer( - self.target_dictionary, - eos=self.tgt_dict.index("[{}]".format(self.args.target_lang)), - ) - else: - from fairseq.sequence_generator import SequenceGenerator - - return SequenceGenerator( - models, - self.target_dictionary, - beam_size=getattr(args, "beam", 5), - max_len_a=getattr(args, "max_len_a", 0), - max_len_b=getattr(args, "max_len_b", 200), - min_len=getattr(args, "min_len", 1), - normalize_scores=(not getattr(args, "unnormalized", False)), - len_penalty=getattr(args, "lenpen", 1), - unk_penalty=getattr(args, "unkpen", 0), - temperature=getattr(args, "temperature", 1.0), - match_source_len=getattr(args, "match_source_len", False), - no_repeat_ngram_size=getattr(args, "no_repeat_ngram_size", 0), - eos=self.tgt_dict.index("[{}]".format(self.args.target_lang)), - ) + # def build_generator(self, models, args, **unused): + # if getattr(cfg, "score_reference", False): + # from fairseq.sequence_scorer import SequenceScorer + + # return SequenceScorer( + # self.target_dictionary, + # eos=self.tgt_dict.index("[{}]".format(self.cfg.target_lang)), + # ) + # else: + # from fairseq.sequence_generator import SequenceGenerator + + # return SequenceGenerator( + # models, + # self.target_dictionary, + # beam_size=getattr(args, "beam", 5), + # max_len_a=getattr(args, "max_len_a", 0), + # max_len_b=getattr(args, "max_len_b", 200), + # min_len=getattr(args, "min_len", 1), + # normalize_scores=(not getattr(args, "unnormalized", False)), + # len_penalty=getattr(args, "lenpen", 1), + # unk_penalty=getattr(args, "unkpen", 0), + # temperature=getattr(args, "temperature", 1.0), + # match_source_len=getattr(args, "match_source_len", False), + # no_repeat_ngram_size=getattr(args, "no_repeat_ngram_size", 0), + # 
eos=self.tgt_dict.index("[{}]".format(self.args.target_lang)), + # ) def build_dataset_for_inference(self, src_tokens, src_lengths, constraints=None): src_lang_id = self.source_dictionary.index("[{}]".format(self.args.source_lang)) source_tokens = [] + logger.info(f"self.source_dictionary: {len(self.source_dictionary)}") + for s_t in src_tokens: s_t = torch.cat([s_t, s_t.new(1).fill_(src_lang_id)]) source_tokens.append(s_t) @@ -130,3 +144,17 @@ def build_dataset_for_inference(self, src_tokens, src_lengths, constraints=None) constraints=constraints, ) return dataset + + # def build_model(self, cfg, from_checkpoint=False): + # model = super().build_model(cfg, from_checkpoint) + # if self.cfg.eval_bleu: + # detok_args = json.loads(self.cfg.eval_bleu_detok_args) + # self.tokenizer = encoders.build_tokenizer( + # Namespace(tokenizer=self.cfg.eval_bleu_detok, **detok_args) + # ) + + # gen_args = json.loads(self.cfg.eval_bleu_args) + # self.sequence_generator = self.build_generator( + # [model], Namespace(**gen_args) + # ) + # return model diff --git a/fairseq_cli/Comparable4.py b/fairseq_cli/Comparable4.py index ce3609f2e2..ea500c4144 100644 --- a/fairseq_cli/Comparable4.py +++ b/fairseq_cli/Comparable4.py @@ -1,7 +1,6 @@ """ Classes and methods used for training and extraction of parallel pairs from a comparable dataset. -Author: Alabi Jesujoba """ import tracemalloc #import gc @@ -9,14 +8,17 @@ import itertools import random import faiss +import faiss.contrib.torch_utils import numpy as np from collections import defaultdict import torch import time +from pathlib import Path from fairseq.data import ( MonolingualDataset, LanguagePairDataset ) +from tqdm import tqdm from fairseq.data.data_utils import load_indexed_dataset,numpy_seed,batch_by_size,filter_by_size from fairseq.data.iterators import EpochBatchIterator, GroupedIterator from fairseq import ( @@ -25,8 +27,25 @@ from fairseq.logging import meters, metrics, progress_bar from omegaconf import DictConfig, OmegaConf import argparse -import os +import os, sys from typing import Any, Callable, Dict, List, Optional, Tuple +import logging +from fairseq.trainer import Trainer +from fairseq.distributed import utils as distributed_utils +torch.manual_seed(10) +import torch.multiprocessing +torch.multiprocessing.set_sharing_strategy('file_system') +# debug_mode = "error" +# torch.cuda.set_sync_debug_mode(debug_mode) + +# We need to setup root logger before importing any fairseq libraries. +logging.basicConfig( + format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + level=os.environ.get("LOGLEVEL", "INFO").upper(), + stream=sys.stdout, +) +logger = logging.getLogger("fairseq_cli.comparable") def get_src_len(src, use_gpu, device=""): if use_gpu: @@ -37,19 +56,6 @@ def get_src_len(src, use_gpu, device=""): else: return torch.tensor([src.size(0)]) -# def indexPhraseData(phrases, dictionary, append_eos, reverse_order): -# tokens_list = [] -# sizes = [] -# for line in phrases: -# # self.lines.append(line.strip('\n')) -# tokens = dictionary.encode_line( -# line, add_if_not_exist=False, -# append_eos=append_eos, reverse_order=reverse_order, -# ).long() -# tokens_list.append(tokens) -# sizes.append(len(tokens)) -# return tokens_list, sizes - #this method is to remove spaces added within strings when dict.string is used. 
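A note on the `[lang]` tags handled by `TranslationFromPretrainedBARTTask` above: `build_dataset_for_inference` appends the source-language tag id to every source sentence, and the commented-out `build_generator` used the target-language tag as the EOS symbol the decoder has to emit. The sketch below shows only the tagging step; it assumes a fairseq `Dictionary` named `src_dict` that already contains the tag, and the helper name and example language code are illustrative, not part of this patch.

```python
import torch

def append_lang_tag(src_tokens, src_dict, lang="de_DE"):
    """Append the `[lang]` tag id to each source sentence (mBART-style tagging)."""
    lang_id = src_dict.index("[{}]".format(lang))
    return [torch.cat([t, t.new(1).fill_(lang_id)]) for t in src_tokens]
```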
#it removed remove spaces between characters and consecutive spaces def removeSpaces(s): @@ -91,61 +97,6 @@ def read_vocabulary(vocab_file, threshold=20): return vocabulary -# class PhraseBank(): -# """ -# Class that saves the sentence pairs from which we want to extract phrases -# Args: -# candidate(tuple(src,tgt,score)) -# args(argparse.Namespace): option object - -# """ - -# def __init__(self, tasks, phrase_length): -# self.tasks = tasks -# self.sourcesent = set() -# self.targetsent = set() -# self.phrase_length = phrase_length -# self.lsrc = [] -# self.ltgt = [] -# self.nlp_src = None -# self.nlp_tgt = None -# '''self.use_gpu = False -# if args.cpu == False: -# self.use_gpu = True -# else: -# self.use_gpu = False -# ''' - -# def add_example(self, src, tgt): -# """ Add an example from a batch to the PairBank (self.pairs). -# Args: -# src(torch.Tensor): src sequence (size(seq)) -# tgt(torch.Tensor): tgt sequence(size(tgt)) -# fields(list(str)): list of keys of fields -# """ -# # Get example from src/tgt and remove original padding -# self.sourcesent.add(str(src)) -# self.targetsent.add(str(tgt)) - -# def getexamples(self): -# return self.sourcesent, self.targetsent - -# def getexampleslen(self): -# return len(self.sourcesent), len(self.targetsent) - -# def remove_from_phrase_candidates(self, seq, side): -# hash_key = hash(str(seq)) -# # print(len(self.bt_candidates)) -# if side == 'src': -# self.lsrc.extend([x for x in self.sourcesent if x[0] == hash_key]) -# self.sourcesent = set([x for x in self.sourcesent if x[0] != hash_key]) -# elif side == 'tgt': -# self.ltgt.extend([x for x in self.targetsent if x[0] == hash_key]) -# self.targetsent = set([x for x in self.targetsent if x[0] != hash_key]) -# # print(len(self.bt_candidates)) -# # print('........') -# return None - def convert2string(self, side): lstString = [] if side == 'src': @@ -177,28 +128,22 @@ def setLang(self, s, t): self.s = s self.t = t - # def extractPhrasesSNL(self, sentences, side='src'): - # if side == 'src': - # #phrases = [list(set(extract_phrase(self.nlp_src.parse(x), 'NP'))) for x in sentences] - # #phrases = [noun_phrases(self.client_src,x,_annotators="tokenize,ssplit,pos,lemma,parse") for x in sentences] - # phrases = [noun_phrases(self.client_src,x) for x in sentences] #,_annotators="tokenize,ssplit,pos,lemma,parse" - # elif side == 'tgt': - # #phrases = [list(set(extract_phrase(self.nlp_tgt.parse(x), 'NP'))) for x in sentences] - # phrases = [noun_phrases(self.client_tgt,x) for x in sentences] #,_annotators="tokenize,ssplit,pos,lemma,parse" - - - # phrases = list(itertools.chain(*phrases)) - # if side == 'src': - # return ["<"+self.t+"> "+self.srcbpe.process_line(item) for item in phrases if len(item.split()) >= self.phrase_length] - # elif side == 'tgt': - # #print("From target", ["<"+self.s+"> "+self.tgtbpe.process_line(item) for item in phrases if len(item.split()) >= self.phrase_length] ) - # return ["<"+self.s+"> "+self.tgtbpe.process_line(item) for item in phrases if len(item.split()) >= self.phrase_length] - def resetData(self): self.sourcesent = set() self.targetsent = set() +# class RejectedBank(): +# """ +# Class that saves and prepares rejected monostylistic SRC sentences and their resulting +# batches. 
+# Args: +# batch_size(int): number of examples in a batch +# opt(argparse.Namespace): option object +# """ +# pass + + class PairBank(): """ Class that saves and prepares parallel pairs and their resulting @@ -208,15 +153,15 @@ class PairBank(): opt(argparse.Namespace): option object """ - def __init__(self, batcher, args): + def __init__(self, batcher, cfg): self.pairs = [] self.index_memory = set() - self.batch_size = args.max_sentences + self.batch_size = cfg.dataset.batch_size #max_sentences self.batcher = batcher self.use_gpu = False self.mps = False self.cuda = False - if args.cpu == False: + if cfg.common.cpu == False: self.use_gpu = True if torch.backends.mps.is_available(): self.mps = True @@ -225,9 +170,11 @@ def __init__(self, batcher, args): self.cuda = True else: self.use_gpu = False - self.update_freq = args.update_freq + self.update_freq = cfg.optimization.update_freq self.explen = self.batch_size * self.update_freq[-1] + def __len__(self): + return len(self.pairs) def removePadding(side): """ Removes original padding from a sequence. @@ -354,11 +301,12 @@ def __init__(self, dataset, src, tgt, src_length, tgt_length, index): class BatchCreator(): - def __init__(self, task, args): + def __init__(self, task, cfg, trainer): self.task = task - self.args = args + self.cfg = cfg + self.trainer = trainer - def create_batch(self, src_examples, tgt_examples, src_lengths, tgt_lengths, no_target=False): + def create_batch(self, src_examples, tgt_examples, src_lengths, tgt_lengths, no_target=False, shard_batch_itr=False): """ Creates a batch object from previously extracted parallel data. Args: src_examples(list): list of src sequence tensors @@ -376,19 +324,19 @@ def create_batch(self, src_examples, tgt_examples, src_lengths, tgt_lengths, no_ pairData = LanguagePairDataset( src_examples, src_lengths, self.task.src_dict, tgt_examples, tgt_lengths, self.task.tgt_dict, - left_pad_source=self.args.left_pad_source, - left_pad_target=self.args.left_pad_target, - max_source_positions=self.args.max_source_positions, - max_target_positions=self.args.max_target_positions, + left_pad_source=self.cfg.task.left_pad_source, + left_pad_target=self.cfg.task.left_pad_target ) + # max_source_positions=self.cfg.task.max_source_positions, + # max_target_positions=self.cfg.task.max_target_positions, - with numpy_seed(self.args.seed): + with numpy_seed(self.cfg.common.seed): indices = pairData.ordered_indices() - batch_sampler = batch_by_size(indices, pairData.num_tokens, max_sentences=self.args.max_sentences, - required_batch_size_multiple=self.args.required_batch_size_multiple, ) - itrs = EpochBatchIterator(dataset=pairData, collate_fn=pairData.collater, batch_sampler=batch_sampler, - seed=self.args.seed, epoch=0, num_workers=self.args.num_workers) + batch_sampler = batch_by_size(indices, pairData.num_tokens, + max_sentences=self.cfg.comparable.max_sentences, required_batch_size_multiple=self.cfg.dataset.required_batch_size_multiple, ) + itrs = EpochBatchIterator(dataset=pairData, collate_fn=pairData.collater, + batch_sampler=batch_sampler, seed=self.cfg.common.seed, epoch=0, num_workers=self.cfg.dataset.num_workers,num_shards=self.trainer.data_parallel_world_size if shard_batch_itr else 1, shard_id=self.trainer.data_parallel_rank if shard_batch_itr else 0,) indices = None return itrs @@ -401,6 +349,119 @@ def knn(x, y, k, use_gpu, index='flat'): large query batch, large index: GPU is typically faster ''' return knnGPU(x, y, k, index) if use_gpu else knnCPU(x, y, k, index) + +def knn_v2(x, y, k, use_gpu, 
index='flat'): + ''' + small query batch, small index: CPU is typically faster + small query batch, large index: GPU is typically faster + large query batch, small index: could go either way + large query batch, large index: GPU is typically faster + string_factory = "L2norm,OPQ16_64,IVF30000_HNSW32,PQ16" # Flat + ''' + return knnGPU_v2(x, y, k, index) if use_gpu else knnCPU(x, y, k, index) + +def knnGPU_v2(x, y, k, index='flat', faiss_verbose=True, batch_size=100000, train_size=1170000, use_float16=True): + ngpus = faiss.get_num_gpus() + dim = 512 #x.shape[1] + string_factory = "OPQ16_64,IVF30000,PQ16" + logger.info(f"index string_factory: {string_factory}") + #"OPQ32,IMI2x8,PQ32" "OPQ32,IVF256,PQ32" + #"PCA64,IVF30000,Flat" + #"IVF200,Flat" # -- works for < 1M indices + #"OPQ16_64,IVF30000,PQ16" # OPQ is a CPU-based vector transform + # IVF30000,PQ16 + idx = faiss.index_factory(dim, string_factory, faiss.METRIC_INNER_PRODUCT) + # https://www.pinecone.io/learn/composite-indexes/ + ivf = faiss.extract_index_ivf(idx) + res = faiss.StandardGpuResources() + res.setDefaultNullStreamAllDevices() + co = faiss.GpuMultipleClonerOptions() + co.shard = True + co.verbose = True + co.indicesOptions = faiss.INDICES_CPU + co.common_ivf_quantizer = False + co.usePrecomputed = False + co.useFloat16 = use_float16 + co.useFloat16CoarseQuantizer = True + if ngpus == 1: + gpu_idx = faiss.index_cpu_to_gpu(res, 0, idx) + else: # multiple gpus + res_list = [res for _ in range(ngpus)] + gpu_idx = faiss.index_cpu_to_gpu_multiple_py(resources=res_list, index=idx, co=co) + # https://github.com/KevinMusgrave/pytorch-metric-learning/issues/491 + # inputs to gpu_idx have to be on cpu and inside the function, they would be moved back to gpu! + + # logger.info("Created faiss index of type {}".format(type(gpu_idx))) + faiss_verbose = None + + # Set verbosity level + if faiss_verbose is not None: + if hasattr(gpu_idx, "index") and gpu_idx.index is not None: + gpu_idx.index.verbose = faiss_verbose + if hasattr(gpu_idx, "quantizer") and gpu_idx.quantizer is not None: + gpu_idx.quantizer.verbose = faiss_verbose + if hasattr(gpu_idx, "clustering_index") and gpu_idx.clustering_index is not None: + gpu_idx.clustering_index.verbose = faiss_verbose + + # Train + # logger.info("training index") + + ys, xs = [], [] + + for i in tqdm(range(0, len(y), batch_size)): + yb = torch.stack(list(y[i : i + batch_size]), dim=0) + yb = yb.type(torch.float32).cpu() # convert to float32 to move to cpu! + yb = torch.nn.functional.normalize(yb, p=2, dim=1) + # logger.info(f"yb.size: {yb.size()}") + ys.append(yb) + + y = torch.cat(ys, dim=0) + + train_vecs = y + if train_size is not None: + # train_vecs = y[:train_size].cpu() + with torch.no_grad(): + indices = torch.tensor(random.sample(range(y.size()[0]), train_size)) + indices = torch.tensor(indices) + train_vecs = train_vecs[indices] + logger.info("Training the index with {} randomly-sampled vectors".format(len(train_vecs))) + gpu_idx.train(train_vecs) + + # Add vectors + logger.info("Adding {} vectors to the faiss index".format(len(y))) + gpu_batch_size=200000 + for i in tqdm(range(0, len(y), gpu_batch_size)): + vecs = y[i : i + batch_size] + vecs = vecs.type(torch.float32).cpu() # convert to float32 to move to cpu! 
+ gpu_idx.add(vecs) + + # send batched queries to the full faiss index + # size of sim and inds should be equal to len(y) + logger.info("Querying {} vectors to the faiss index".format(len(x))) + sim, ind = [], [] + faiss.GpuParameterSpace().set_index_parameter(ivf, "nprobe", 50) + # faiss.ParameterSpace().set_index_parameter(ivf, "nprobe", 50) + for i in tqdm(range(0, len(x), gpu_batch_size)): + xb = torch.stack(list(x[i : i + gpu_batch_size]), dim=0) + xb = xb.type(torch.float32).cpu() # convert to float32 to move to cpu! + xb = torch.nn.functional.normalize(xb, p=2, dim=1) + xs.append(xb) + bsim, bind = gpu_idx.search(xb, k) # x[i : i + batch_size].cpu() + sim.append(bsim) + # print(f"len(sim): {len(sim)}") + ind.append(bind) + + # logger.info(f"concat results..") + # logger.info(f"sim[0].size(): {sim[0].size()}") + rsim = torch.cat(sim, dim=0).cpu() # along the rows + rind = torch.cat(ind, dim=0).cpu() + x = torch.cat(xs, dim=0) + # logger.info(f"similarity results size: {rsim.size()}") + + # logger.info(f"type(rsim): {type(rsim)} type(rind): {type(rsim)}") + + return rsim, rind, x, y + def knnCPU(x, y, k, index='flat'): start=time.time() @@ -411,13 +472,16 @@ def knnCPU(x, y, k, index='flat'): if index == 'ivf': # quantizer = faiss.IndexFlatIP(dim) # idx = faiss.IndexIVFFlat(quantizer, dim, nlist) - idx = faiss.index_factory(dim, "IVF100,Flat", faiss.METRIC_INNER_PRODUCT) + idx = faiss.index_factory(dim, "IVF200,Flat", faiss.METRIC_INNER_PRODUCT) + idx.train(y) + # print(f"idx.is_trained: {idx.is_trained}") 40000 + elif index == 'hnsw': + idx = faiss.index_factory(dim, "PCA64,IVF30000_HNSW32,Flat", faiss.METRIC_INNER_PRODUCT) idx.train(y) - # print(f"idx.is_trained: {idx.is_trained}") elif index =='pq': # quantizer = faiss.IndexFlatIP(dim) # idx = faiss.IndexIVFPQ(quantizer, dim, nlist, m, bits) - idx = faiss.index_factory(dim, "IVF100,PQ16", faiss.METRIC_INNER_PRODUCT) + idx = faiss.index_factory(dim, "IVF200,PQ16", faiss.METRIC_INNER_PRODUCT) idx.train(y) else: idx = faiss.IndexFlatIP(dim) @@ -428,9 +492,10 @@ def knnCPU(x, y, k, index='flat'): # print(f"sim[:3]: {sim[:3]}") # print(f"ind: {ind}") # print(f"time taken to build the index: {time.time()-start} secs") - return sim, ind + return sim, ind, x, y -def knnGPU(x, y, k, index='flat', mem=5*1024*1024*1024): +def knnGPU(x, y, k, index='flat', mem=48*1024*1024* + 1024): # d = srcRep.shape[1] # print(f"d: {d}") @@ -445,15 +510,22 @@ def knnGPU(x, y, k, index='flat', mem=5*1024*1024*1024): # for each sub-vector set. # 3. 
In the vector of sub-vecs, replace each sub-vec with the ID of its nearest set-specific centroid # ''' - # m = 8 # number of centroid IDs in final compressed vectors - # bits = 8 # number of bits in each centroid - # nlist = 100 # how many cells + # https://github.com/facebookresearch/LASER/blob/main/source/mine_bitexts.py + print(f"faiss.get_num_gpus(): {faiss.get_num_gpus()}") + ngpus = faiss.get_num_gpus() dim = x.shape[1] + m = 8 # number of centroid IDs in final compressed vectors + bits = 8 # number of bits in each centroid + nlist = 100 # how many cells + res = faiss.StandardGpuResources() + res.setDefaultNullStreamAllDevices() + co = faiss.GpuClonerOptions() batch_size = mem // (dim*4) print(f"batch_size: {batch_size}") if batch_size > x.shape[0]: - batch_size = x.shape[0] // 5 + batch_size = 64 #x.shape[0] // 10000 print(f"batch_size: {batch_size}") + sim = np.zeros((x.shape[0], k), dtype=np.float32) ind = np.zeros((x.shape[0], k), dtype=np.int64) for xfrom in range(0, x.shape[0], batch_size): @@ -461,8 +533,31 @@ def knnGPU(x, y, k, index='flat', mem=5*1024*1024*1024): bsims, binds = [], [] for yfrom in range(0, y.shape[0], batch_size): yto = min(yfrom + batch_size, y.shape[0]) # to_trg_ind - # print('{}-{} -> {}-{}'.format(xfrom, xto, yfrom, yto)) - idx = faiss.IndexFlatIP(dim) + logger.info('{}-{} -> {}-{}'.format(xfrom, xto, yfrom, yto)) + if index == 'ivf': # below 1M vectors + idx = faiss.index_factory(dim, "IVF1200,Flat", faiss.METRIC_INNER_PRODUCT) + idx = faiss.index_cpu_to_gpu(res, 0, index=idx) + idx.train(y) + elif index =='pq': + # quantizer = faiss.IndexFlatIP(dim) + # idx = faiss.IndexIVFPQ(quantizer, dim, nlist, m, bits) + idx = faiss.index_factory(dim, "IVF1200,PQ16", faiss.METRIC_INNER_PRODUCT) + idx = faiss.index_cpu_to_all_gpus(index=idx, co=co) + idx.train(y) + elif index == 'hnsw': # for 1M-10M vectors + idx = faiss.index_factory(dim, "PCA64,IVF30000_HNSW32,Flat", faiss.METRIC_INNER_PRODUCT) + idx_ivf = faiss.extract_index_ivf(idx) + res = [faiss.StandardGpuResources() for _ in range(ngpus)] + #clustering_index = faiss.index_cpu_to_all_gpus(faiss.IndexFlatIP(idx_ivf.d), co, ngpu=1) + # clustering_index = faiss.index_cpu_to_gpu(res, 0, faiss.IndexFlatIP(idx_ivf.d)) + # index_cpu_to_gpu_multiple_py(resources, index, co=None, gpus=None) + #clustering_index = faiss.index_cpu_to_all_gpus(res=res, index=faiss.IndexFlatIP(idx_ivf.d), co=co) + clustering_index = faiss.index_cpu_to_gpu_multiple_py(resources=res, index=faiss.IndexFlatIP(idx_ivf.d), co=co) + idx_ivf.clustering_index = clustering_index + idx.train(y) + else: + idx = faiss.IndexFlatIP(dim) + idx = faiss.index_cpu_to_all_gpus(index=idx, co=co) # quantizer = faiss.IndexFlatL2(d) # idx = faiss.IndexIVFFlat(quantizer, d, nlist) # #idx = faiss.IndexIVFPQ(quantizer, d, nlist, m, bits) @@ -474,7 +569,6 @@ def knnGPU(x, y, k, index='flat', mem=5*1024*1024*1024): # idx.nprobe = 1 # to increase the search scope # # large nprobe values = slower but more accurate search - idx = faiss.index_cpu_to_all_gpus(idx) idx.add(y[yfrom:yto]) # added trg_batch = batch_size to the index bsim, bind = idx.search(x[xfrom:xto], min(k, yto-yfrom)) # find k nearest neighbours for the batched queries bsims.append(bsim) @@ -500,8 +594,6 @@ def score_candidates(x, y, candidate_inds, fwd_mean, bwd_mean, margin, verbose=F for j in range(scores.shape[1]): k = candidate_inds[i, j] scores[i, j] = score(x[i], y[k], fwd_mean[i], bwd_mean[k], margin) - # print(f"x[i]: {x[i]}, y[k]: {y[k]} fwd_mean[i]: {fwd_mean[i]}, bwd_mean[k]: {bwd_mean[k]}") 
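For reference, the `knnGPU_v2`/`knnGPU` hunks above all follow the same FAISS recipe: build a composite index with `index_factory` (e.g. `"OPQ16_64,IVF30000,PQ16"`), train it on a random sample of the target vectors, add the vectors in batches, widen `nprobe`, and run batched queries. The sketch below reproduces that recipe on CPU with a much smaller IVF list count; the GPU cloning (`GpuMultipleClonerOptions`, `index_cpu_to_gpu_multiple_py`) is omitted and all sizes are illustrative, not the values used in this patch.

```python
import numpy as np
import faiss

def ivf_pq_search(x, y, k=4, train_size=10_000, nprobe=32):
    """Batched IVF+PQ search over L2-normalised float32 arrays x (queries) and y (index)."""
    dim = y.shape[1]
    index = faiss.index_factory(dim, "OPQ16_64,IVF256,PQ16", faiss.METRIC_INNER_PRODUCT)
    sample = y[np.random.choice(len(y), size=min(train_size, len(y)), replace=False)]
    index.train(sample)                               # train OPQ + IVF + PQ on a sample
    for start in range(0, len(y), 100_000):           # add the index side in batches
        index.add(y[start:start + 100_000])
    faiss.extract_index_ivf(index).nprobe = nprobe    # search more than the default single cell
    sims, ids = index.search(x, k)                    # batched queries
    return sims, ids
```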
- # print(f"scores[i, j] : {scores[i, j]}") return scores @@ -536,11 +628,14 @@ class Comparable(): def __init__(self, model, trainer, task, cfg): self.sim_measure = cfg.comparable.sim_measure self.threshold = cfg.comparable.threshold + self.use_threshold = cfg.comparable.use_threshold self.model_name = cfg.comparable.model_name self.save_dir = cfg.comparable.save_dir self.use_phrase = cfg.comparable.use_phrase - #self.model = trainer.get_model().encoder + self.model = model #trainer.get_model().encoder self.usepos = cfg.comparable.usepos + self.temp = cfg.comparable.temp + Path(f"/netscratch/jalota/pickle/{self.temp}/").mkdir(parents=True, exist_ok=True) # print("Use positional encoding = ", self.usepos) self.trainer = trainer # print(f"self.trainer: {self.trainer}") @@ -548,7 +643,7 @@ def __init__(self, model, trainer, task, cfg): self.encoder = self.trainer.get_model().encoder # print(f"self.encoder: {self.encoder}") self.batch_size = cfg.comparable.max_sentences - self.batcher = BatchCreator(task, cfg) + self.batcher = BatchCreator(task, cfg, trainer) self.similar_pairs = PairBank(self.batcher, cfg) self.accepted = 0 self.accepted_limit = 0 @@ -556,6 +651,7 @@ def __init__(self, model, trainer, task, cfg): self.total = 0 self.cfg = cfg self.comp_log = cfg.comparable.comp_log + Path(self.comp_log).mkdir(parents=True, exist_ok=True) self.cove_type = cfg.comparable.cove_type self.update_freq = cfg.optimization.update_freq self.k = cfg.comparable.k #20 #cfg.comparable.k @@ -623,12 +719,14 @@ def write_sentence(self, src, tgt, status, score=None): elif status == 'embed_only': with open(self.embed_file, 'a', encoding='utf8') as f: f.write(out) + # return out elif status == 'hidden_only': with open(self.hidden_file, 'a', encoding='utf8') as f: f.write(out) + # return out return None - def extract_parallel_sents(self, candidates, candidate_pool, phrasese=False): + def extract_parallel_sents(self, candidates, candidate_pool, phrasese=False, use_threshold=False): """ Extracts parallel sentences from candidates and adds them to the PairBank (secondary filter). 
@@ -651,50 +749,45 @@ def extract_parallel_sents(self, candidates, candidate_pool, phrasese=False): self.write_sentence(candidate[0], candidate[1], 'hidden_only', candidate[2]) continue - '''if self.no_swaps: - swap = False - # Swap src-tgt direction randomly - else: - swap = np.random.randint(2) - if swap: - src = candidate[1] - tgt = candidate[0] - else: - src = candidate[0] - tgt = candidate[1]''' - - src = candidate[0] - tgt = candidate[1] - score = candidate[2] - - # Apply threshold (single-representation systems only) - if score >= self.threshold: - # print("Score is greater than threshold") - # Check if no maximum of allowed unique accepted pairs reached - # if self.similar_pairs.no_limit_reached(src, tgt): - # Add to PairBank - self.similar_pairs.add_example(src, tgt) - self.write_sentence(removePadding(src), removePadding(tgt), 'accepted', score) - self.accepted += 1 - if self.symmetric: - self.similar_pairs.add_example(tgt, src) - # self.write_sentence(tgt, src, 'accepted', score) - - - # if self.use_phrase and phrasese is False: - # print("checking phrases to remove.......") - # src_rm = removePadding(src) - # self.phrases.remove_from_phrase_candidates(src_rm, 'src') - # tgt_rm = removePadding(tgt) - # self.phrases.remove_from_phrase_candidates(tgt_rm, 'tgt') - # # write the accepted phrases to file - # if self.use_phrase and phrasese is True and self.args.write_phrase: - # self.write_sentence(removePadding(src), removePadding(tgt), 'phrase', score) - else: - # print("threshold not met!!!") - self.declined += 1 - self.total += 1 - + + elif self.in_candidate_pool(candidate, candidate_pool) and not use_threshold: + src = candidate[0] + tgt = candidate[1] + score = candidate[2] + + self.similar_pairs.add_example(src, tgt) + self.write_sentence(removePadding(src), removePadding(tgt), 'accepted', score) + self.accepted += 1 + if self.symmetric: + self.similar_pairs.add_example(tgt, src) + self.write_sentence(tgt, src, 'accepted', score) + self.total += 1 + + elif use_threshold or self.in_candidate_pool(candidate, candidate_pool): + # Apply threshold (single-representation systems only) + + src = candidate[0] + tgt = candidate[1] + score = candidate[2] + + if score >= self.threshold: + # print("Score is greater than threshold") + # Check if no maximum of allowed unique accepted pairs reached + # if self.similar_pairs.no_limit_reached(src, tgt): + # Add to PairBank + self.similar_pairs.add_example(src, tgt) + self.write_sentence(removePadding(src), removePadding(tgt), 'accepted', score) + self.accepted += 1 + if self.symmetric: + self.similar_pairs.add_example(tgt, src) + self.write_sentence(tgt, src, 'accepted', score) + else: + # print("threshold not met!!!") + self.declined += 1 + self.total += 1 + else: # doesnt match thresold or not in candidate-pool + continue + return None def write_embed_only(self, candidates, cand_embed): @@ -714,6 +807,209 @@ def write_embed_only(self, candidates, cand_embed): score = candidate[2] self.write_sentence(src, tgt, 'embed_only', score) + def faiss_sent_scoring_v2(self, src_sents, tgt_sents): + """ Score source and target combinations. 
+ Args: + src_sents(list(tuple(torch.Tensor...))): + list of src sentences in their sequential and semantic representation + tgt_sents(list(tuple(torch.Tensor...))): list of tgt sentences + Returns: + src2tgt(dict(dict(float))): dictionary mapping a src to a tgt and their score + tgt2src(dict(dict(float))): dictionary mapping a tgt to a src and their score + similarities(list(float)): list of cosine similarities + scores(list(float)): list of scores + """ + start = time.time() + + srcSent, srcRep = zip(*src_sents) + # print(f"srcSent: {srcSent}") + tgtSent, tgtRep = zip(*tgt_sents) + # print(f"tgtSent: {tgtSent}") + + # print("faiss sent scoring") + + if self.faiss_use_gpu: + # https://github.com/facebookresearch/faiss/wiki/Faiss-on-the-GPU + ngpus = faiss.get_num_gpus() + logger.info(f"number of GPUs: {ngpus}") + + # srcSent2ind = {sent:i for i, sent in enumerate(srcSent)} + # tgtSent2ind = {sent:i for i, sent in enumerate(tgtSent)} + + # logger.info(f"len srcRep: {len(srcRep)}") + # logger.info(f"srcRep: {srcRep}") + + + + # x = torch.stack(list(srcRep), dim=0) #torch.cat(srcRep, dim=1) # concat along the rows + # y = torch.stack(list(tgtRep), dim=0) + if self.faiss_use_gpu: + x, y = srcRep, tgtRep + + # logger.info(f"len x: {x.size()}") + # logger.info(f"x: {x}") + else: + x= np.asarray([rep.detach().cpu().numpy() for rep in srcRep]) + y= np.asarray([rep.detach().cpu().numpy() for rep in tgtRep]) + + # print(f"normalising x.dtype : {x.dtype}") + faiss.normalize_L2(x) + faiss.normalize_L2(y) + # logger.info("done torch normalizing") + # logger.info(f"x.size(): {x.size()}") + # https://discuss.pytorch.org/t/how-to-normalize-embedding-vectors/1209/9 + # x = torch.nn.functional.normalize(x, p=2, dim=1) + # y = torch.nn.functional.normalize(y, p=2, dim=1) + #F.normalize(x, p=2, dim=1) + + candidates = [] + + # torch.from_numpy(a) + + # calculate knn in both directions + if self.retrieval != 'bwd': + if self.verbose: + print(' - perform {:d}-nn source against target'.format(self.k)) + x2y_sim, x2y_ind, x, y = knn_v2(x, y, self.k, self.faiss_use_gpu, self.index) + if self.faiss_use_gpu: + x2y_sim = x2y_sim.numpy() #.detach().cpu().numpy() + x2y_ind = x2y_ind.numpy() # .detach().cpu() + x2y_mean = x2y_sim.mean(axis=1) + #x2y_mean = torch.mean(x2y_sim, 1) + + # print(f"x2y_sim.shape: {x2y_sim.shape}") + # print(f"x2y_ind.shape: {x2y_ind.shape}") + + if self.retrieval != 'fwd': + if self.verbose: + print(' - perform {:d}-nn target against source'.format(self.k)) + y2x_sim, y2x_ind, _, _ = knn_v2(y, x, self.k, self.faiss_use_gpu, self.index) + # logger.info(f"type(y2x_sim): {type(y2x_sim)}, type(y2x_ind): {type(y2x_ind)}") + if self.faiss_use_gpu: + y2x_sim = y2x_sim.numpy() # .detach().cpu() + y2x_ind = y2x_ind.numpy() # .detach().cpu().numpy() + y2x_mean = y2x_sim.mean(axis=1) + # y2x_mean = torch.mean(y2x_sim, 1) + + # margin function + if self.margin == 'absolute': + margin = lambda a, b: a + elif self.margin == 'distance': + margin = lambda a, b: a - b + else: # args.margin == 'ratio': + margin = lambda a, b: a / b + + # print(f"margin: {margin}") + + fout = open(self.faiss_output, mode='w', encoding='utf8', errors='surrogateescape') + + src_inds=list(range(len(srcSent))) + trg_inds=list(range(len(tgtSent))) + + if self.mode == 'search': + if self.verbose: + print(' - Searching for closest sentences in target') + print(' - writing alignments to {:s}'.format(self.faiss_output)) + scores = score_candidates(x, y, x2y_ind, x2y_mean, y2x_mean, margin, self.verbose) + best = 
x2y_ind[np.arange(x.shape[0]), scores.argmax(axis=1)] + + print(f"best: {best}") + + nbex = x.shape[0] + ref = np.linspace(0, nbex-1, nbex).astype(int) # [0, nbex) + err = nbex - np.equal(best.reshape(nbex), ref).astype(int).sum() + print(' - errors: {:d}={:.2f}%'.format(err, 100*err/nbex)) + for i in src_inds: + print(tgtSent[best[i]], file=fout) + + elif self.mode == 'score': + for i, j in zip(src_inds, trg_inds): + s = score(x[i], y[j], x2y_mean[i], y2x_mean[j], margin) + src = srcSent[i] + tgt = tgtSent[j] + src_words = self.task.src_dict.string(src) + tgt_words = self.task.tgt_dict.string(tgt) + out = 'src: {}\ttgt: {}\tsimilarity: {}\n'.format(removeSpaces(' '.join(src_words)), + removeSpaces(' '.join(tgt_words)), s) + print(out, file=fout) + + elif self.mode == 'mine': + if self.verbose: + logger.info(' - mining for parallel data') + fwd_scores = score_candidates(x, y, x2y_ind, x2y_mean, y2x_mean, margin, self.verbose) + bwd_scores = score_candidates(y, x, y2x_ind, y2x_mean, x2y_mean, margin, self.verbose) + fwd_best = x2y_ind[np.arange(x.shape[0]), fwd_scores.argmax(axis=1)] + # print(f"fwd_best: {fwd_best}") + bwd_best = y2x_ind[np.arange(y.shape[0]), bwd_scores.argmax(axis=1)] + # print(f"bwd_best: {bwd_best}") + if self.verbose: + logger.info(' - writing alignments to {:s}'.format(self.faiss_output)) + if self.threshold > 0: + logger.info(' - with threshold of {:f}'.format(self.threshold)) + if self.retrieval == 'fwd': + for i, j in enumerate(fwd_best): + s = fwd_scores[i].max() + src = srcSent[i] + tgt = tgtSent[j] + src_words = self.task.src_dict.string(src) + tgt_words = self.task.tgt_dict.string(tgt) + out = 'src: {}\ttgt: {}\tsimilarity: {}\n'.format(removeSpaces(' '.join(src_words)), + removeSpaces(' '.join(tgt_words)), s) + print(out, file=fout) + # print(fwd_scores[i].max(), srcSent[i], tgtSent[j], sep='\t', file=fout) + candidates.append((srcSent[i], tgtSent[j], s)) + if self.retrieval == 'bwd': + for j, i in enumerate(bwd_best): + s = bwd_scores[j].max() + src = srcSent[i] + tgt = tgtSent[j] + src_words = self.task.src_dict.string(src) + tgt_words = self.task.tgt_dict.string(tgt) + out = 'src: {}\ttgt: {}\tsimilarity: {}\n'.format(removeSpaces(' '.join(src_words)), + removeSpaces(' '.join(tgt_words)), s) + print(out, file=fout) + # print(bwd_scores[j].max(), srcSent[i], tgtSent[j], sep='\t', file=fout) + candidates.append((srcSent[i], tgtSent[j], s)) + if self.retrieval == 'intersect': + for i, j in enumerate(fwd_best): + if bwd_best[j] == i: + s = fwd_scores[i].max() + src = srcSent[i] + tgt = tgtSent[j] + src_words = self.task.src_dict.string(src) + tgt_words = self.task.tgt_dict.string(tgt) + out = 'src: {}\ttgt: {}\tsimilarity: {}\n'.format(removeSpaces(' '.join(src_words)), + removeSpaces(' '.join(tgt_words)), s) + print(out, file=fout) + # print(fwd_scores[i].max(), srcSent[i], tgtSent[j], sep='\t', file=fout) + candidates.append((srcSent[i], tgtSent[j], s)) + if self.retrieval == 'max': + indices = np.stack((np.concatenate((np.arange(x.shape[0]), bwd_best)), + np.concatenate((fwd_best, np.arange(y.shape[0])))), axis=1) + scores = np.concatenate((fwd_scores.max(axis=1), bwd_scores.max(axis=1))) + seen_src, seen_trg = set(), set() + for i in np.argsort(-scores): + src_ind, trg_ind = indices[i] + if not src_ind in seen_src and not trg_ind in seen_trg: + seen_src.add(src_ind) + seen_trg.add(trg_ind) + if scores[i] > self.threshold: + s = scores[i] + src = srcSent[src_ind] + tgt = tgtSent[trg_ind] + src_words = self.task.src_dict.string(src) + tgt_words = 
self.task.tgt_dict.string(tgt) + out = 'src: {}\ttgt: {}\tsimilarity: {}\n'.format(removeSpaces(' '.join(src_words)), + removeSpaces(' '.join(tgt_words)), s) + print(out, file=fout) + # print(scores[i], srcSent[src_ind], tgtSent[trg_ind], sep='\t', file=fout) + candidates.append((srcSent[src_ind], tgtSent[trg_ind], scores[i])) + + fout.close() + logger.info(f"time taken by faiss sent scoring: {time.time()-start} seconds.") + # logger.info(f"num candidates: {len(candidates)}") + return candidates + def faiss_sent_scoring(self, src_sents, tgt_sents): """ Score source and target combinations. @@ -728,8 +1024,18 @@ def faiss_sent_scoring(self, src_sents, tgt_sents): scores(list(float)): list of scores """ start = time.time() + srcSent, srcRep = zip(*src_sents) + # print(f"srcSent: {srcSent}") tgtSent, tgtRep = zip(*tgt_sents) + # print(f"tgtSent: {tgtSent}") + + print("faiss sent scoring") + + if self.faiss_use_gpu: + # https://github.com/facebookresearch/faiss/wiki/Faiss-on-the-GPU + ngpus = faiss.get_num_gpus() + logger.info(f"number of GPUs: {ngpus}") # srcSent2ind = {sent:i for i, sent in enumerate(srcSent)} # tgtSent2ind = {sent:i for i, sent in enumerate(tgtSent)} @@ -737,9 +1043,12 @@ def faiss_sent_scoring(self, src_sents, tgt_sents): x= np.asarray([rep.detach().cpu().numpy() for rep in srcRep]) y= np.asarray([rep.detach().cpu().numpy() for rep in tgtRep]) - # print(f"x : {x}") + print(f"normalising x.dtype : {x.dtype}") faiss.normalize_L2(x) faiss.normalize_L2(y) + + logger.info("done faiss normalizing") + candidates = [] # torch.from_numpy(a) @@ -750,14 +1059,15 @@ def faiss_sent_scoring(self, src_sents, tgt_sents): print(' - perform {:d}-nn source against target'.format(self.k)) x2y_sim, x2y_ind = knn(x, y, min(y.shape[0], self.k), self.faiss_use_gpu, self.index) x2y_mean = x2y_sim.mean(axis=1) - # print(f"x2y_sim.shape: {x2y_sim.shape}") - # print(f"x2y_ind.shape: {x2y_ind.shape}") + # logger.info(f"x2y_sim.shape: {x2y_sim.shape}") + logger.info(f"x2y_ind.shape: {x2y_ind.shape}") if self.retrieval != 'fwd': if self.verbose: print(' - perform {:d}-nn target against source'.format(self.k)) y2x_sim, y2x_ind = knn(y, x, min(x.shape[0], self.k), self.faiss_use_gpu, self.index) y2x_mean = y2x_sim.mean(axis=1) + logger.info(f"y2x_ind.shape: {y2x_ind.shape}") # margin function if self.margin == 'absolute': @@ -781,7 +1091,7 @@ def faiss_sent_scoring(self, src_sents, tgt_sents): scores = score_candidates(x, y, x2y_ind, x2y_mean, y2x_mean, margin, self.verbose) best = x2y_ind[np.arange(x.shape[0]), scores.argmax(axis=1)] - # print(f"best: {best}") + print(f"best: {best}") nbex = x.shape[0] ref = np.linspace(0, nbex-1, nbex).astype(int) # [0, nbex) @@ -825,6 +1135,7 @@ def faiss_sent_scoring(self, src_sents, tgt_sents): removeSpaces(' '.join(tgt_words)), s) print(out, file=fout) # print(fwd_scores[i].max(), srcSent[i], tgtSent[j], sep='\t', file=fout) + candidates.append((srcSent[i], tgtSent[j], s)) if self.retrieval == 'bwd': for j, i in enumerate(bwd_best): s = bwd_scores[j].max() @@ -836,6 +1147,7 @@ def faiss_sent_scoring(self, src_sents, tgt_sents): removeSpaces(' '.join(tgt_words)), s) print(out, file=fout) # print(bwd_scores[j].max(), srcSent[i], tgtSent[j], sep='\t', file=fout) + candidates.append((srcSent[i], tgtSent[j], s)) if self.retrieval == 'intersect': for i, j in enumerate(fwd_best): if bwd_best[j] == i: @@ -848,6 +1160,7 @@ def faiss_sent_scoring(self, src_sents, tgt_sents): removeSpaces(' '.join(tgt_words)), s) print(out, file=fout) # print(fwd_scores[i].max(), srcSent[i], 
tgtSent[j], sep='\t', file=fout) + candidates.append((srcSent[i], tgtSent[j], s)) if self.retrieval == 'max': indices = np.stack((np.concatenate((np.arange(x.shape[0]), bwd_best)), np.concatenate((fwd_best, np.arange(y.shape[0])))), axis=1) @@ -871,7 +1184,8 @@ def faiss_sent_scoring(self, src_sents, tgt_sents): candidates.append((srcSent[src_ind], tgtSent[trg_ind], scores[i])) fout.close() - print(f"time taken by faiss sent scoring: {time.time()-start} seconds.") + logger.info(f"time taken by faiss sent scoring: {time.time()-start} seconds.") + # logger.info(f"num candidates: {len(candidates)}") return candidates def score_sents(self, src_sents, tgt_sents): @@ -1000,7 +1314,9 @@ def get_article_coves(self, article, representation='memory', mean=False, side= list of sentences in their sequential (seq) and semantic representation (cove) """ sents = [] - # print("inside get_article_coves") + print("inside get_article_coves") + # print(f"self.cfg.task.arch: ") + # print(f"{self.cfg.task.arch}") #for k in article:#tqdm(article): # print("next(article)") id = 0 @@ -1008,13 +1324,16 @@ def get_article_coves(self, article, representation='memory', mean=False, side= # print(f"len(article): {len(article)}") for k in article: # print("inside article!") + # print(f"k['net_input']['src_tokens']: {k['net_input']['src_tokens']}") + # print(f"self.cfg.model.arch: {self.cfg.model.arch}") # print(f"article id: {id}") # if id == 3013: # print("skipping 3013") # continue - # print(f"k['net_input']['src_tokens']: {k['net_input']['src_tokens']}") + # print(f"k['net_input']['src_tokens']: {k['net_input']['src_tokens']}") sent_repr = None - if self.args.modeltype == "lstm": # if the model architecture is LSTM + if self.cfg.model.arch == "lstm": # if the model architecture is LSTM + # replaced self.cfg.task.arch lengths = k['net_input']['src_lengths'] texts = k['net_input']['src_tokens'] ordered_len, ordered_idx = lengths.sort(0, descending=True) @@ -1040,7 +1359,7 @@ def get_article_coves(self, article, representation='memory', mean=False, side= sent_repr = torch.mean(hidden_embed, dim=0) else: sent_repr = torch.sum(hidden_embed, dim=0) - elif self.args.modeltype == "transformer": + elif self.cfg.model.arch == "transformer": # print("In the transformer representation") if representation == 'memory': with torch.no_grad(): @@ -1054,21 +1373,26 @@ def get_article_coves(self, article, representation='memory', mean=False, side= k['net_input']['src_lengths'].to(self.mps_device)) elif self.use_gpu and self.cuda: # print("going into encoder forward") - encoderOut = self.encoder.forward(k['net_input']['src_tokens'].cuda(), - k['net_input']['src_lengths'].cuda()) + encoderOut = self.encoder.forward(k['net_input']['src_tokens'].cuda(), k['net_input']['src_lengths'].cuda()) + # print("got encoderOut") else: encoderOut = self.encoder.forward(k['net_input']['src_tokens'], k['net_input']['src_lengths']) - hidden_embed = getattr(encoderOut, 'encoder_out') # T x B x C + # print(f"encoderOut: {encoderOut}") + # print(f"len(encoderOut['encoder_out']): {len(encoderOut['encoder_out'])}") + hidden_embed = encoderOut['encoder_out'][0] + # hidden_embed = getattr(encoderOut, 'encoder_out') # T x B x C + # print(f"hidden_embed: {hidden_embed}") if mean: sent_repr = torch.mean(hidden_embed, dim=0) else: sent_repr = torch.sum(hidden_embed, dim=0) elif representation == 'embed': + # print(f"k['net_input']['src_tokens']: {k['net_input']['src_tokens']}") + # print(f"k['net_input']['src_lengths']: {k['net_input']['src_lengths']}") + # 
print("going into encoder forward emb") + # print(f"self.usepos: {self.usepos}") with torch.no_grad(): - # print(f"k['net_input']['src_tokens']: {k['net_input']['src_tokens']}") - # print(f"k['net_input']['src_lengths']: {k['net_input']['src_lengths']}") - # print("going into encoder forward emb") if self.usepos: if self.use_gpu and self.mps: input_emb,_ = self.encoder.forward_embedding(k['net_input']['src_tokens'].to(self.mps_device)) @@ -1084,6 +1408,7 @@ def get_article_coves(self, article, representation='memory', mean=False, side= else: _, input_emb = self.encoder.forward_embedding(k['net_input']['src_tokens']) # print(f"type(input_emb): {type(input_emb)}") + # print(f"self.cuda: {self.cuda}") if self.mps: input_emb = input_emb.to(self.mps_device) @@ -1091,26 +1416,22 @@ def get_article_coves(self, article, representation='memory', mean=False, side= input_emb = input_emb.cuda() #input_emb = getattr(encoderOut, 'encoder_embedding') # B x T x C - # print(f"type(input_emb): {type(input_emb)}") + # print(f"input_emb.size(): {input_emb.size()}") input_emb = input_emb.transpose(0, 1) if mean: sent_repr = torch.mean(input_emb, dim=0) else: sent_repr = torch.sum(input_emb, dim=0) - if self.args.modeltype == "transformer": + if self.cfg.model.arch == "transformer": # print(f"inside modeltype == transformer") - # print(f"k['net_input']['src_tokens'][i]: {k['net_input']['src_tokens'][i]}") - # print(f"rang(i): {range(k['net_input']['src_tokens'].shape[0])}") + for i in range(k['net_input']['src_tokens'].shape[0]): - #print(f"i : {i}") + # print(f"i : {i}") + # print(f"k['net_input']['src_tokens'][i]: {k['net_input']['src_tokens'][i]}") + # print(f"rang(i): {range(k['net_input']['src_tokens'].shape[0])}") sents.append((k['net_input']['src_tokens'][i], sent_repr[i])) - if side == 'src' and use_phrase is True: - st = removePadding(k['net_input']['src_tokens'][i]) - self.phrases.sourcesent.add((hash(str(st)), st)) - elif side == 'tgt' and use_phrase is True: - st = removePadding(k['net_input']['src_tokens'][i]) - self.phrases.targetsent.add((hash(str(st)), st)) - elif self.args.modeltype == "lstm": + + elif self.cfg.model.arch == "lstm": for i in range(texts.shape[0]): sents.append((texts[i], sent_repr[i])) # print(f"finishing {id}") @@ -1204,7 +1525,7 @@ def filter_candidates(self, src2tgt, tgt2src, second=False): candidates = list(src_tgt_max & tgt_src_max) return candidates # [(src_x, tgt_y, score_xy)] - def _get_iterator(self, sent, dictn, max_position, epoch, fix_batches_to_gpus=False): + def _get_iterator(self, sent, dictn, max_position, epoch, fix_batches_to_gpus=False, shard_batch_itr=False,disable_iterator_cache=False): """ Creates an iterator object from a text file. 
Args: @@ -1213,20 +1534,25 @@ def _get_iterator(self, sent, dictn, max_position, epoch, fix_batches_to_gpus=Fa data_iter(.EpochIterator): iterator object """ # get indices ordered by example size - with numpy_seed(self.args.seed): + with numpy_seed(self.cfg.common.seed): indices = sent.ordered_indices() # filter out examples that are too large max_positions = (max_position) if max_positions is not None: indices = filter_by_size(indices, sent, max_positions, raise_exception=(not True), ) # create mini-batches with given size constraints - max_sentences = self.args.max_sentences # 30 - batch_sampler = batch_by_size(indices, sent.num_tokens, max_sentences=max_sentences, - required_batch_size_multiple=self.args.required_batch_size_multiple, ) + print(f"self.cfg.comparable.max_sentences: {self.cfg.comparable.max_sentences}") + max_sentences = self.cfg.comparable.max_sentences # 30 + print(f"max_sentences: {max_sentences}") + print(f"self.cfg.dataset.num_workers: {self.cfg.dataset.num_workers}") + print(f"sent.num_tokens: {sent.num_tokens}") + + batch_sampler = batch_by_size(indices, sent.num_tokens, max_sentences=max_sentences, required_batch_size_multiple=self.cfg.dataset.required_batch_size_multiple, ) # print(f"tuple(batch_sampler): {tuple(batch_sampler)}") - itrs = EpochBatchIterator(dataset=sent, collate_fn=sent.collater, batch_sampler=batch_sampler, seed=self.args.seed,num_workers=self.args.num_workers, epoch=epoch) + itrs = EpochBatchIterator(dataset=sent, collate_fn=sent.collater, batch_sampler=batch_sampler, seed=self.cfg.common.seed,num_workers=self.cfg.dataset.num_workers, epoch=epoch, num_shards=self.trainer.data_parallel_world_size if shard_batch_itr else 1, shard_id=self.trainer.data_parallel_rank if shard_batch_itr else 0,) #data_iter = itrs.next_epoch_itr(shuffle=False, fix_batches_to_gpus=fix_batches_to_gpus) # print(f"itrs.state_dict: {itrs.state_dict()}") + # print(f"itrs: {itrs}") # print(f"itrs.n(): {itrs.n()}") # print(f"itrs.first_batch(): {itrs.first_batch()}") # print(f"next(itrs)") @@ -1254,149 +1580,24 @@ def get_cove(self, memory, ex, mean=False): return cove def getdata(self, articles): + logger.info(f"self.cfg.dataset.dataset_impl: raw") trainingSetSrc = load_indexed_dataset(articles[0], self.task.src_dict, - dataset_impl=self.args.dataset_impl, combine=False, + dataset_impl='raw', combine=False, default='cached') trainingSetTgt = load_indexed_dataset(articles[1], self.task.tgt_dict, - dataset_impl=self.args.dataset_impl, combine=False, + dataset_impl='raw', combine=False, default='cached') + logger.info(f"trainingSetSrc: {trainingSetSrc}") # print("read the text file ")self.args.data + # convert the read files to Monolingual dataset to make padding easy - src_mono = MonolingualDataset(dataset=trainingSetSrc, sizes=trainingSetSrc.sizes, - src_vocab=self.task.src_dict, - tgt_vocab=None, shuffle=False, add_eos_for_other_targets=False) - tgt_mono = MonolingualDataset(dataset=trainingSetTgt, sizes=trainingSetTgt.sizes, - src_vocab=self.task.tgt_dict, - tgt_vocab=None, shuffle=False, add_eos_for_other_targets=False) - + src_mono = MonolingualDataset(dataset=trainingSetSrc, sizes=trainingSetSrc.sizes, src_vocab=self. 
task.src_dict, tgt_vocab=None, shuffle=False, add_eos_for_other_targets=False) + tgt_mono = MonolingualDataset(dataset=trainingSetTgt, sizes=trainingSetTgt.sizes, src_vocab=self.task.tgt_dict, tgt_vocab=None, shuffle=False, add_eos_for_other_targets=False) del trainingSetSrc, trainingSetTgt # print("Monolingual data") # print(f"src_mono.num_tokens(1): {src_mono.num_tokens(1)}") # print(f"tgt_mono.num_tokens(1): {tgt_mono.num_tokens(1)}") return src_mono, tgt_mono - # def extract_phrase_train(self, srcPhrase, tgtPhrase, epoch): - # src_sents = [] - # tgt_sents = [] - # src_embeds = [] - # tgt_embeds = [] - # # load the dataset from the files for both source and target - # src_indexed, src_sizes = indexPhraseData(srcPhrase, dictionary = self.task.src_dict, append_eos = True, reverse_order = False) - # tgt_indexed, tgt_sizes = indexPhraseData(tgtPhrase, dictionary = self.task.tgt_dict, append_eos=True, reverse_order=False) - # #print(src_indexed) - - # src_mono = MonolingualDataset(dataset=src_indexed, sizes=src_sizes, - # src_vocab=self.task.src_dict, - # tgt_vocab=None, shuffle=False, add_eos_for_other_targets=False) - # tgt_mono = MonolingualDataset(dataset=tgt_indexed, sizes=tgt_sizes, - # src_vocab=self.task.tgt_dict, - # tgt_vocab=None, shuffle=False, add_eos_for_other_targets=False) - - # # Prepare iterator objects for current src/tgt document - # src_article = self._get_iterator(src_mono, dictn=self.task.src_dict, - # max_position=self.args.max_source_positions, epoch=epoch, - # fix_batches_to_gpus=False) - # tgt_article = self._get_iterator(tgt_mono, dictn=self.task.tgt_dict, - # max_position=self.args.max_target_positions, epoch=epoch, - # fix_batches_to_gpus=False) - # # Get sentence representations - # #try: - # if self.representations == 'embed-only': - # # print("Using Embeddings only for representation") - # # C_e - # itr_src = src_article._get_iterator_for_epoch(epoch=epoch, shuffle=True) - # itr_tgt = tgt_article._get_iterator_for_epoch(epoch=epoch, shuffle=True) - # src_sents += self.get_article_coves(itr_src, representation='embed', mean=False, side='src') - # tgt_sents += self.get_article_coves(itr_tgt, representation='embed', mean=False, side='tgt') - # else: - # # C_e and C_h - - # it1 = src_article.next_epoch_itr(shuffle=False, fix_batches_to_gpus=False) - # src_embeds += self.get_article_coves(it1, representation='embed', mean=False, side='src', - # use_phrase=self.use_phrase) - # it1 = src_article.next_epoch_itr(shuffle=False, fix_batches_to_gpus=False) - # src_sents += self.get_article_coves(it1, representation='memory', mean=False, side='src') - - # it3 = tgt_article.next_epoch_itr(shuffle=False, fix_batches_to_gpus=False) - # tgt_embeds += self.get_article_coves(it3, representation='embed', mean=False, side='tgt', - # use_phrase=self.use_phrase) - # it3 = tgt_article.next_epoch_itr(shuffle=False, fix_batches_to_gpus=False) - # tgt_sents += self.get_article_coves(it3, representation='memory', mean=False, side='tgt') - - # # return - # '''except: - # # Skip document pair in case of errors - # print("error") - # src_sents = [] - # tgt_sents = [] - # src_embeds = [] - # tgt_embeds = []''' - # # free resources for Gabbage, not necessary tho - # #src_mono.dataset.tokens_list = None - # #src_mono.dataset.sizes = None - # #src_mono.sizes = None - # #tgt_mono.sizes = None - # del tgt_mono - # del src_mono - # #print('source = ',src_sents[0][0]) - # #print('target = ', tgt_sents[0][0]) - # # Score src and tgt sentences - - # #try: - # src2tgt, tgt2src, similarities, scores = 
self.score_sents(src_sents, tgt_sents) - # # src2tgt = { "dis a src sent": {"dis a tg": 0.2, "dis s a TRG": 0.6, "dis": 0.12} } - # # this score could be from margin / cosine similarity - # # similarities containes only sim scores (useless var) - # # scores is a useless var - - # '''except: - # # print('Error occurred in: {}\n'.format(article_pair), flush=True) - # print(src_sents, flush=True) - # print(tgt_sents, flush=True) - # src_sents = [] - # tgt_sents = [] - # return''' - # # print("source 2 target ", src2tgt) - # # Keep statistics - # # epoch_similarities += similarities - # # epoch_scores += scores - # src_sents = [] - # tgt_sents = [] - - # #try: - # if self.representations == 'dual': # means fwd and bwd - # # For dual representation systems, filter C_h... - # candidates = self.filter_candidates(src2tgt, tgt2src, second=self.second) - # # Filter candidates (primary filter), such that only those which are top candidates in - # # both src2tgt and tgt2src direction pass. - # # ...and C_e - # comparison_pool, cand_embed = self.get_comparison_pool(src_embeds, - # tgt_embeds) - # print("The number of candidates from Phrases = ", len(candidates)) - # src_embeds = [] - # tgt_embeds = [] - # if self.write_dual: - # # print("writing the sentences to file....") - # self.write_embed_only(candidates, cand_embed) - # else: - # print("Using Embedings only for Filtering ......") - # # Filter C_e or C_h for single representation system - # candidates = self.filter_candidates(src2tgt, tgt2src) - # comparison_pool = None - # '''except: - # # Skip document pair in case of errors - # print("Error Occured!!!!") - # # print('Error occured in: {}\n'.format(article_pair), flush=True) - # src_embeds = [] - # tgt_embeds = [] - # return''' - - - # # Extract parallel samples (secondary filter) - # phrasese = True - # self.extract_parallel_sents(candidates, comparison_pool, phrasese) - # return None - def extract_and_train(self, comparable_data_list, epoch): tracemalloc.start() @@ -1434,6 +1635,11 @@ def extract_and_train(self, comparable_data_list, epoch): print(f"on article {ap}") cur_article += 1 articles = article_pair.split(' ') + # if ap >= 1: + # src_sents = torch.load(f"/netscratch/jalota/pickle/{self.temp}/src_sents.pt") + # src_embeds = torch.load(f"/netscratch/jalota/pickle/{self.temp}/src_embeds.pt") + # tgt_sents = torch.load(f"/netscratch/jalota/pickle/{self.temp}/tgt_sents.pt") + # tgt_embeds = torch.load(f"/netscratch/jalota/pickle/{self.temp}/tgt_embeds.pt") # print(f"articles: {articles}") # print(f"len(articles): {len(articles)}") # Discard malaligned documents @@ -1442,45 +1648,51 @@ def extract_and_train(self, comparable_data_list, epoch): #load the dataset from the files for both source and target src_mono, tgt_mono = self.getdata(articles) # Prepare iterator objects for current src/tgt document - # print(f"self.task.src_dict: {self.task.src_dict}") - # print(f"self.args.max_source_positions: {self.args.max_source_positions}") - # print(f"get iterator") - src_article = self._get_iterator(src_mono, dictn=self.task.src_dict, - max_position=self.args.max_source_positions, epoch=epoch, fix_batches_to_gpus=False) - tgt_article = self._get_iterator(tgt_mono, dictn=self.task.tgt_dict, - max_position=self.args.max_target_positions, epoch=epoch, fix_batches_to_gpus=False) + print(f"self.task.src_dict: {self.task.src_dict}") + print(f"self.cfg.max_source_positions: {self.cfg.task.max_source_positions}") + print(f"get iterator") + src_article = self._get_iterator(src_mono, dictn=self.task.src_dict, 
max_position=self.cfg.task.max_source_positions, epoch=epoch, fix_batches_to_gpus=False) + tgt_article = self._get_iterator(tgt_mono, dictn=self.task.tgt_dict, max_position=self.cfg.task.max_target_positions, epoch=epoch, fix_batches_to_gpus=False) # Get sentence representations try: if self.representations == 'embed-only': - # print("Using Embeddings only for representation") + print("Using Embeddings only for representation") # C_e itr_src = src_article._get_iterator_for_epoch(epoch=epoch, shuffle=True) itr_tgt = tgt_article._get_iterator_for_epoch(epoch=epoch, shuffle=True) - # print(f"src article, rep=embed") + print(f"src article, rep=embed") src_sents += self.get_article_coves(itr_src, representation='embed', mean=False) # print(f"tgt article, rep=embed") + # torch.save(src_sents, "/netscratch/jalota/pickle/exp2/src_sents.pt") tgt_sents += self.get_article_coves(itr_tgt, representation='embed', mean=False) + # torch.save(tgt_sents, "/netscratch/jalota/pickle/exp2/tgt_sents.pt") else: # C_e and C_h '''it1, it2 = itertools.tee(src_article) it3, it4 = itertools.tee(tgt_article)''' - # print(f"src article, rep=embed") - it1 = src_article.next_epoch_itr(shuffle=False, fix_batches_to_gpus=False) - src_embeds += self.get_article_coves(it1, representation='embed', mean=False, side='src', - use_phrase=self.use_phrase) - # print(f"src article, rep=memory") - it1 = src_article.next_epoch_itr(shuffle=False, fix_batches_to_gpus=False) - src_sents += self.get_article_coves(it1, representation='memory', mean=False, side='src') - - # print(f"tgt article, rep=embed") - it3 = tgt_article.next_epoch_itr(shuffle=False, fix_batches_to_gpus=False) - tgt_embeds += self.get_article_coves(it3, representation='embed', mean=False, side='tgt', - use_phrase=self.use_phrase) - # print(f"tgt article, rep=memory") - it3 = tgt_article.next_epoch_itr(shuffle=False, fix_batches_to_gpus=False) - tgt_sents += self.get_article_coves(it3, representation='memory', mean=False, side='tgt') - + print(f"src article, rep=embed") + it = src_article.next_epoch_itr(shuffle=False, fix_batches_to_gpus=False) + src_embeds += self.get_article_coves(it, representation='embed', mean=False, side='src', use_phrase=self.use_phrase) + # torch.save(src_embeds, f"/netscratch/jalota/pickle/{self.temp}/src_embeds.pt") + print(f"src article, rep=memory") + # del src_embeds + it = src_article.next_epoch_itr(shuffle=False, fix_batches_to_gpus=False) + src_sents += self.get_article_coves(it, representation='memory', mean=False, side='src') + # torch.save(src_sents, f"/netscratch/jalota/pickle/{self.temp}/src_sents.pt") + # del src_sents + print(f"tgt article, rep=embed") + it = tgt_article.next_epoch_itr(shuffle=False, fix_batches_to_gpus=False) + tgt_embeds += self.get_article_coves(it, representation='embed', mean=False, side='tgt', use_phrase=self.use_phrase) + # torch.save(tgt_embeds, f"/netscratch/jalota/pickle/{self.temp}/tgt_embeds.pt") + # del tgt_embeds + print(f"tgt article, rep=memory") + it = tgt_article.next_epoch_itr(shuffle=False, fix_batches_to_gpus=False) + print(f"got it4") + tgt_sents += self.get_article_coves(it, representation='memory', mean=False, side='tgt') + # torch.save(tgt_sents, f"/netscratch/jalota/pickle/{self.temp}/tgt_sents.pt") + # del tgt_sents + print(f"done with tgt article, rep=memory") #return except: #Skip document pair in case of errors @@ -1497,10 +1709,13 @@ def extract_and_train(self, comparable_data_list, epoch): del tgt_mono del src_mono + # src_sents = 
torch.load(f"/netscratch/jalota/pickle/{self.temp}/src_sents.pt") + # tgt_sents = torch.load(f"/netscratch/jalota/pickle/{self.temp}/tgt_sents.pt") + if len(src_sents) < 15 or len(tgt_sents) < 15: #print("Length LEss tahn 15") continue - # print("Proceeding") + print("Proceeding") # Score src and tgt sentences print("In all we have got ", len(src_sents), "source sentences and ", len(tgt_sents), "target") @@ -1508,21 +1723,29 @@ def extract_and_train(self, comparable_data_list, epoch): # get src2gt , tgt2src try: - print(f"self.faiss: {self.faiss}") + logger.info(f"self.faiss: {self.faiss}") if self.faiss: - candidates = self.faiss_sent_scoring(src_sents, tgt_sents) - # print(f"done with faiss scoring of src sents and tgt sents") - candidates_embed = self.faiss_sent_scoring(src_embeds, tgt_embeds) - # print(f"done with faiss scoring of src embeds and tgt embeds") + candidates = self.faiss_sent_scoring_v2(src_sents, tgt_sents) + logger.info(f"done with faiss scoring of src sents and tgt sents") + del src_sents + del tgt_sents + + # src_embeds = torch.load(f"/netscratch/jalota/pickle/{self.temp}/src_embeds.pt") + # tgt_embeds = torch.load(f"/netscratch/jalota/pickle/{self.temp}/tgt_embeds.pt") + + candidates_embed = self.faiss_sent_scoring_v2(src_embeds, tgt_embeds) + logger.info(f"num candidates = {len(candidates)}") + logger.info(f"num candidates_embed = {len(candidates_embed)}") + logger.info(f"done with faiss scoring of src embeds and tgt embeds") embed_comparison_pool = set_embed = set([hash((str(c[0]), str(c[1]))) for c in candidates_embed]) # candidates : [(src_sent_x, tgt_sent_y, score_xy)] - # print(f"made embed_comparison_pool") + # logger.info(f"made embed_comparison_pool") if self.write_dual: #print("writing the sentences to file....") self.write_embed_only(candidates, candidates_embed) # Extract parallel samples (secondary filter) - # print(f"starting to extract parallel sents") - self.extract_parallel_sents(candidates, embed_comparison_pool) + logger.info(f"starting to extract parallel sents") + self.extract_parallel_sents(candidates, embed_comparison_pool, use_threshold=self.use_threshold) else: src2tgt, tgt2src, similarities, scores = self.score_sents(src_sents, tgt_sents) # src2tgt = { "dis a src sent": {"dis a tg": 0.2, "dis s a TRG": 0.6, "dis": 0.12} } @@ -1580,12 +1803,12 @@ def extract_and_train(self, comparable_data_list, epoch): continue # Extract parallel samples (secondary filter) - self.extract_parallel_sents(candidates, comparison_pool) + self.extract_parallel_sents(candidates, comparison_pool, use_threshold=self.use_threshold) # if phrase extraction is to be used - print("pair bank = ",len((self.similar_pairs.pairs))) + logger.info(f"pair bank = {len(self.similar_pairs.pairs)}") # Train on extracted sentences - end_of_epoch = self.train(epoch) + self.train(epoch) if not self.faiss: del src2tgt, tgt2src #gc.collect() @@ -1594,18 +1817,14 @@ def extract_and_train(self, comparable_data_list, epoch): snapshot = tracemalloc.take_snapshot() top_stats = snapshot.statistics('lineno') - if len((self.similar_pairs.pairs)) > 0: - print("batching and training") - end_of_epoch = self.train(epoch, last=True) + if len(self.similar_pairs.pairs) > 0: + logger.info("batching and training") + self.train(epoch, last=True) self.accepted_file.close() if self.use_phrase == True: self.accepted_phrase.close() - # log end-of-epoch stats - #stats = get_training_stats(metrics.get_smoothed_values('train')) - #self.progress.print(stats, tag='train', step=num_updates) - # log end-of-epoch stats 
logger.info("end of epoch {} (average epoch stats below)".format(epoch)) num_updates = self.trainer.get_num_updates() @@ -1614,6 +1833,7 @@ def extract_and_train(self, comparable_data_list, epoch): # reset epoch-level meters metrics.reset_meters('train') + end_of_epoch = True return num_updates, end_of_epoch ''' @metrics.aggregate('train') @@ -1641,19 +1861,16 @@ def trainRest(self, epoch): @metrics.aggregate('train') def train(self, epoch, last=False): # Check if enough parallel sentences were collected + # is_epoch_end = False if last is False: while self.similar_pairs.contains_batch(): # print("IT has batch.....") # try: itrs = self.similar_pairs.yield_batch() - itr = itrs.next_epoch_itr(shuffle=True, fix_batches_to_gpus=cfg.distributed_training.fix_batches_to_gpus) - itr = GroupedIterator(itr, self.update_freq[-1], skip_remainder_batch=cfg.optimization.skip_remainder_batch) - if cfg.common.tpu: + itr = itrs.next_epoch_itr(shuffle=True, fix_batches_to_gpus=self.cfg.distributed_training.fix_batches_to_gpus) + itr = GroupedIterator(itr, self.update_freq[-1], skip_remainder_batch=self.cfg.optimization.skip_remainder_batch) + if self.cfg.common.tpu: itr = utils.tpu_data_loader(itr) - - # self.progress = progress_bar.build_progress_bar( - # self.cfg, itr, epoch, no_progress_bar='simple', - # ) epoch=epoch_itr.epoch, self.progress = progress_bar.progress_bar( itr, log_format=self.cfg.common.log_format, @@ -1691,9 +1908,9 @@ def train(self, epoch, last=False): else False ), ) - self.progress.update_config(_flatten_config(self.cfg) - logger.info("Start iterating over samples") - for i, samples in enumerate(progress): + self.progress.update_config(_flatten_config(self.cfg)) + # logger.info(f"Start iterating over samples") + for i, samples in enumerate(self.progress): with metrics.aggregate("train_inner"), torch.autograd.profiler.record_function( "train_step-%d" % i ): @@ -1709,20 +1926,14 @@ def train(self, epoch, last=False): # reset mid-epoch stats after each log interval # the end-of-epoch stats will still be preserved metrics.reset_meters('train_inner') - # end_of_epoch = not itr.has_next() - # if log_output is None: - # continue - # log mid-epoch stats - # stats = get_training_stats(metrics.get_smoothed_values('train_inner')) - # self.progress.print(stats, tag='train_inner', step=num_updates) - # self.progress.log(stats, tag='train_inner', step=num_updates) - # metrics.reset_meters('train_inner') + # end_of_epoch = not itr.has_next() + # is_epoch_end = end_of_epoch else: # numberofex = self.similar_pairs.get_num_examples() itrs = self.similar_pairs.yield_batch() - itr = itrs.next_epoch_itr(shuffle=True, fix_batches_to_gpus=cfg.distributed_training.fix_batches_to_gpus) - itr = GroupedIterator(itr, self.update_freq[-1], skip_remainder_batch=cfg.optimization.skip_remainder_batch) - if cfg.common.tpu: + itr = itrs.next_epoch_itr(shuffle=True, fix_batches_to_gpus=self.cfg.distributed_training.fix_batches_to_gpus) + itr = GroupedIterator(itr, self.update_freq[-1], skip_remainder_batch=self.cfg.optimization.skip_remainder_batch) + if self.cfg.common.tpu: itr = utils.tpu_data_loader(itr) self.progress = progress_bar.progress_bar( itr, @@ -1761,9 +1972,9 @@ def train(self, epoch, last=False): else False ), ) - self.progress.update_config(_flatten_config(self.cfg) + self.progress.update_config(_flatten_config(self.cfg)) logger.info("Start iterating over samples") - for i, samples in enumerate(progress): + for i, samples in enumerate(self.progress): with metrics.aggregate('train_inner'): log_output = 
self.trainer.train_step(samples) num_updates = self.trainer.get_num_updates() @@ -1774,8 +1985,9 @@ def train(self, epoch, last=False): self.progress.print(stats, tag='train_inner', step=num_updates) self.progress.log(stats, tag='train_inner', step=num_updates) metrics.reset_meters('train_inner') - end_of_epoch = not itr.has_next() - return end_of_epoch + # end_of_epoch = not itr.has_next() + # is_epoch_end = end_of_epoch + # return is_epoch_end #end_of_epoch def validate(self, epoch, subsets): @@ -1794,8 +2006,9 @@ def validate(self, epoch, subsets): itr = self.trainer.get_valid_iterator(subset).next_epoch_itr( shuffle=False, set_dataset_epoch=False # use a fixed valid set ) - if cfg.common.tpu: + if self.cfg.common.tpu: itr = utils.tpu_data_loader(itr) + # print(f"self.cfg.distributed_training: {self.cfg.distributed_training}") progress = progress_bar.progress_bar( itr, log_format=self.cfg.common.log_format, @@ -1842,82 +2055,28 @@ def validate(self, epoch, subsets): # log validation stats # only tracking the best metric on the 1st validation subset - tracking_best = subset_idx == 0 + # logger.info(f"subset_idx: {subset_idx}") + tracking_best = True + #subset_idx == 0 stats = get_valid_stats(self.cfg, self.trainer, agg.get_smoothed_values(), tracking_best) - if hasattr(task, "post_validate"): - task.post_validate(self.trainer.get_model(), stats, agg) + if hasattr(self.task, "post_validate"): + # logger.info(f"post_validate = True") + self.task.post_validate(self.trainer.get_model(), stats, agg) + # logger.info(f"stats in {subset} subset: {stats}") progress.print(stats, tag=subset, step=self.trainer.get_num_updates()) + # logger.info(f"self.cfg.checkpoint.best_checkpoint_metric: {self.cfg.checkpoint.best_checkpoint_metric}") + # logger.info(f"stats[self.cfg.checkpoint.best_checkpoint_metric]: {stats[self.cfg.checkpoint.best_checkpoint_metric]}") + valid_losses.append(stats[self.cfg.checkpoint.best_checkpoint_metric]) return valid_losses - # if self.args.fixed_validation_seed is not None: - # # set fixed seed for every validation - # utils.set_torch_seed(self.args.fixed_validation_seed) - - # valid_losses = [] - # for subset in subsets: - # # print(f"subset: {subset}") - # # Initialize data iterator - # itr = self.task.get_batch_iterator( - # dataset=self.task.dataset(subset), - # max_tokens=self.args.max_tokens_valid, - # max_sentences=self.args.max_sentences_valid, - # max_positions=utils.resolve_max_positions( - # self.task.max_positions(), - # self.trainer.get_model().max_positions(), - # ), - # ignore_invalid_inputs=self.args.skip_invalid_size_inputs_valid_test, - # required_batch_size_multiple=self.args.required_batch_size_multiple, - # seed=self.args.seed, - # num_shards=self.args.distributed_world_size, - # shard_id=self.args.distributed_rank, - # num_workers=self.args.num_workers, - # ).next_epoch_itr(shuffle=False) - # progress = progress_bar.build_progress_bar( - # self.args, itr, epoch, - # prefix='valid on \'{}\' subset'.format(subset), - # no_progress_bar='simple' - # ) - - # # create a new root metrics aggregator so validation metrics - # # don't pollute other aggregators (e.g., train meters) - # with metrics.aggregate(new_root=True) as agg: - # for sample in progress: - # # print(f"sample: {sample}") - # self.trainer.valid_step(sample) - - # # log validation stats - # stats = get_valid_stats(self.args, self.trainer, agg.get_smoothed_values()) - # progress.print(stats, tag=subset, step=self.trainer.get_num_updates()) - - # # print(f"self.args.best_checkpoint_metric: 
{self.args.best_checkpoint_metric}") - - # valid_losses.append(stats[self.args.best_checkpoint_metric]) - # return valid_losses - def save_comp_chkp(self, epoch): dirs = self.save_dir + '/' + self.model_name + '_' + str(epoch) + self.src + "-" + self.tgt + ".pt" self.trainer.save_checkpoint(dirs, {"train_iterator": {"epoch": epoch}}) -# def get_valid_stats(cfg, trainer, stats): -# if 'nll_loss' in stats and 'ppl' not in stats: -# stats['ppl'] = utils.get_perplexity(stats['nll_loss']) -# stats['num_updates'] = trainer.get_num_updates() -# # print(f"stats['num_updates']: {stats['num_updates']}") -# # print(f"hasattr(checkpoint_utils.save_checkpoint, 'best'): {hasattr(checkpoint_utils.save_checkpoint, 'best')}") -# if hasattr(checkpoint_utils.save_checkpoint, 'best'): -# key = 'best_{0}'.format(args.best_checkpoint_metric) -# # print(f"key: {key}") -# # print(f"args.best_checkpoint_metric: {args.best_checkpoint_metric}") -# best_function = max if args.maximize_best_checkpoint_metric else min -# stats[key] = best_function( -# checkpoint_utils.save_checkpoint.best, -# stats[args.best_checkpoint_metric], -# ) -# return stats def get_valid_stats( cfg: DictConfig, diff --git a/fairseq_cli/Comparable_unsup.py b/fairseq_cli/Comparable_unsup.py new file mode 100644 index 0000000000..67f8ddb0fc --- /dev/null +++ b/fairseq_cli/Comparable_unsup.py @@ -0,0 +1,2203 @@ +""" +Classes and methods used for training and extraction of parallel pairs +from a comparable dataset. +""" +import tracemalloc +#import gc +import re +import itertools +import random +import faiss +import faiss.contrib.torch_utils +from pathlib import Path +import numpy as np +from collections import OrderedDict, defaultdict +import torch +import pandas as pd +import time +from tqdm import tqdm +from torch.utils.data import Dataset +from fairseq.data import ( + MonolingualDataset, + LanguagePairDataset, + BacktranslationDataset, + ConcatDataset, + RoundRobinZipDatasets +) +from fairseq.data.data_utils import load_indexed_dataset,numpy_seed,batch_by_size,filter_by_size +from fairseq.data.iterators import EpochBatchIterator, GroupedIterator +from fairseq import ( + checkpoint_utils, utils +) +from fairseq.logging import meters, metrics, progress_bar +from omegaconf import DictConfig, OmegaConf +import argparse +import os, sys +from typing import Any, Callable, Dict, List, Optional, Tuple +import logging +from fairseq.trainer import Trainer +from fairseq.sequence_generator import SequenceGenerator +from fairseq.distributed import utils as distributed_utils +from itertools import cycle +torch.manual_seed(10) + +# We need to setup root logger before importing any fairseq libraries. +logging.basicConfig( + format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + level=os.environ.get("LOGLEVEL", "INFO").upper(), + stream=sys.stdout, +) +logger = logging.getLogger("fairseq_cli.comparable") + +def get_src_len(src, use_gpu, device=""): + if use_gpu: + if device=="mps": + return torch.tensor([src.size(0)], device="mps") #.cuda() + else: + return torch.tensor([src.size(0)]).cuda() + else: + return torch.tensor([src.size(0)]) + +#this method is to remove spaces added within strings when dict.string is used. +#it removed remove spaces between characters and consecutive spaces +def removeSpaces(s): + k = re.sub(' (?! 
)',"",s) + k = re.sub(' +', ' ', k) + return k + +# get noun phrases with tregex using stanza +def noun_phrases(_client, _text): + pattern = "NP" + matches = _client.tregex(_text, pattern) + s = "\n".join(["\t" + sentence[match_id]['spanString'] for sentence in matches['sentences'] for match_id in sentence]) + phrases = [x.strip() for x in s.split('\n\t')] + return phrases + +def extract_phrase(tree_str, label): + phrases = [] + trees = Tree.fromstring(tree_str) + for tree in trees: + for subtree in tree.subtrees(): + if subtree.label() == label: + t = subtree + t = ' '.join(t.leaves()) + phrases.append(t) + + return phrases + +def read_vocabulary(vocab_file, threshold=20): + """read vocabulary file produced by get_vocab.py, and filter according to frequency threshold. + """ + vocabulary = set() + + for line in vocab_file: + word, freq = line.strip('\r\n ').split(' ') + freq = int(freq) + if threshold == None or freq >= threshold: + vocabulary.add(word) + + return vocabulary + + + def convert2string(self, side): + lstString = [] + if side == 'src': + lstString = [removeSpaces(' '.join(self.tasks.src_dict.string(x[1], bpe_symbol='@@ '))).replace("",'').strip() for x in self.sourcesent] + elif side == 'tgt': + lstString = [removeSpaces(' '.join(self.tasks.tgt_dict.string(x[1], bpe_symbol='@@ '))).replace("",'').strip() for x in self.targetsent] + self.resetData() + return lstString + + def resetSource(self): + self.sourcesent = set() + + def resetTarget(self): + self.targetsent = set() + + def setparsers(self, nlp_src, nlp_tgt): + self.nlp_src = nlp_src + self.nlp_tgt = nlp_tgt + + def setclients(self, nlp_src, nlp_tgt): + self.client_src = nlp_src + self.client_tgt = nlp_tgt + + def bpe(self, src, tgt): + self.srcbpe = src + self.tgtbpe = tgt + + def setLang(self, s, t): + self.s = s + self.t = t + + def resetData(self): + self.sourcesent = set() + self.targetsent = set() + + +class torchDataset(Dataset): + def __init__(self, data_list): + self.data_list = data_list + + def __getitem__(self, index): + return self.data_list[index] + + def __len__(self): + return len(self.data_list) + + +class PairBank(): + """ + Class that saves and prepares parallel pairs and their resulting + batches. + Args: + batch_size(int): number of examples in a batch + opt(argparse.Namespace): option object + """ + + def __init__(self, batcher, cfg): + self.pairs = [] + self.index_memory = set() + self.batch_size = cfg.dataset.batch_size + self.sizes = [] + self.srcs = [] + self.tgts = [] + self.src_lens = [] + self.tgt_lens = [] + #max_sentences + self.batcher = batcher + self.use_gpu = False + self.mps = False + self.cuda = False + if cfg.common.cpu == False: + self.use_gpu = True + if torch.backends.mps.is_available(): + self.mps = True + self.mps_device = torch.device("mps") + else: + self.cuda = True + else: + self.use_gpu = False + self.update_freq = cfg.optimization.update_freq + self.explen = self.batch_size * self.update_freq[-1] + + def __len__(self): + return len(self.pairs) + + def removePadding(side): + """ Removes original padding from a sequence. + Args: + side(torch.Tensor): src/tgt sequence (size(seq)) + Returns: + side(torch.Tensor): src/tgt sequence without padding + NOTE: This only works as long as PAD_ID==1! 
+ """ + # Get indexes of paddings in sequence + padding_idx = (side == 1).nonzero() + # If there is any padding, cut sequence from first occurence of a pad + if padding_idx.size(0) != 0: + first_pad = padding_idx.data.tolist()[0][0] + side = side[:first_pad] + return side + + def add_example(self, src, tgt): + """ Add an example from a batch to the PairBank (self.pairs). + Args: + src(torch.Tensor): src sequence (size(seq)) + tgt(torch.Tensor): tgt sequence(size(tgt)) + fields(list(str)): list of keys of fields + """ + # Get example from src/tgt and remove original padding + src = PairBank.removePadding(src) + tgt = PairBank.removePadding(tgt) + self.srcs.append(src) + self.tgts.append(tgt) + self.src_lens.append(src.size(0)) + self.tgt_lens.append(tgt.size(0)) + return None + + def contains_batch(self): + """Check if enough parallel pairs found to create a batch. + """ + return (len(self.pairs) >= self.explen) + + def no_limit_reached(self, src, tgt): + """ Check if no assigned limit of unique src-tgt pairs is reached. + Args: + src(torch.Tensor): src sequence (size(seq)) + tgt(torch.Tensor): tgt sequence(size(tgt)) + """ + # src = PairBank.removePadding(src) + # tgt = PairBank.removePadding(tgt) + return (hash((str(src), str(tgt))) in self.index_memory or len(self.index_memory) < self.limit) + + def get_num_examples(self): + """Returns batch size if no maximum number of extracted parallel data + used for training is met. Otherwise returns number of examples that can be yielded + without exceeding that maximum. + """ + if len(self.pairs) < self.explen: + return len(self.pairs) + return self.explen + + def yield_batch(self): + """ Prepare and yield a new batch from self.pairs. + Returns: + batch(fairseq.data.LanguagePairDataset): batch of extracted parallel data + """ + src_examples = [] + tgt_examples = [] + src_lengths = [] + tgt_lengths = [] + indices = [] + num_examples = self.get_num_examples() + + # Get as many examples as needed to fill a batch or a given limit + random.shuffle(self.pairs) + for ex in range(num_examples): + example = self.pairs.pop() # removes the pair from the list!! + src_len = example.src_length.item() + tgt_len = example.tgt_length.item() + # print(f"example.src_length: {src_len}") + src_examples.append(example.src) + tgt_examples.append(example.tgt) + src_lengths.append(src_len) # example.src_length + tgt_lengths.append(tgt_len) # example.tgt_length + indices.append(example.index) + + dataset = None + # fields = CompExample.get_fields() + batch = self.batcher.create_batch(src_examples, tgt_examples, src_lengths, tgt_lengths) + # enumerate to yield batch here + return batch + + +class CompExample(): + """ + Class that stores the information of one parallel data example. 
+ Args: + dataset(fairseq.data): dataset object + src(torch.Tensor): src sequence (size(seq)) + tgt(torch.Tensor): tgt sequence (size(seq)) + src_length(torch.Tensor): the length of the src sequence (size([])) + index(torch.Tensor): the index of the example in dataset + """ + # These should be the same for all examples (else: consistency problem) + _dataset = None + + def __init__(self, dataset, src, tgt, src_length, tgt_length, index): + self.src = src + self.tgt = tgt + self.src_length = src_length + self.tgt_length = tgt_length + self.index = index + + if CompExample._dataset == None: + CompExample._dataset = dataset + + def to_dict(self): + return { + 'index': self.index, + 'src': self.src, + 'tgt': self.tgt, + 'src_length': self.src_length, + 'tgt_length': self.tgt_length + } + + +class BatchCreator(): + def __init__(self, task, cfg): + self.task = task + self.cfg = cfg + + def create_batch(self, src_examples, tgt_examples, src_lengths, tgt_lengths, no_target=False): + """ Creates a batch object from previously extracted parallel data. + Args: + src_examples(list): list of src sequence tensors + tgt_examples(list): list of tgt sequence tensors + src_lenths(list): list of the lengths of each src sequence + tgt_lenths(list): list of the lengths of each tgt sequence + indices(list): list of indices of example instances in dataset + dataset(fairseq.data): dataset object + Returns: + batch(fairseq.data.LanguagePairDataset): batch object + """ + # print(f"src_lengths type: {type(src_lengths)}") + # src_lengths = src_lengths.detach().cpu().numpy() + # tgt_lengths = tgt_lengths.detach().cpu().numpy() + pairData = LanguagePairDataset( + src_examples, src_lengths, self.task.src_dict, + tgt_examples, tgt_lengths, self.task.tgt_dict, + left_pad_source=self.cfg.task.left_pad_source, + left_pad_target=self.cfg.task.left_pad_target, + ) + # max_source_positions=self.cfg.task.max_source_positions, + # max_target_positions=self.cfg.task.max_target_positions, + + with numpy_seed(self.cfg.common.seed): + indices = pairData.ordered_indices() + + batch_sampler = batch_by_size(indices, pairData.num_tokens, + max_sentences=self.cfg.comparable.max_sentences, required_batch_size_multiple=self.cfg.dataset.required_batch_size_multiple, ) + itrs = EpochBatchIterator(dataset=pairData, collate_fn=pairData.collater, + batch_sampler=batch_sampler, seed=self.cfg.common.seed, epoch=0, num_workers=self.cfg.dataset.num_workers) + indices = None + return itrs + + +def knn_v2(x, y, k, use_gpu, index='flat'): + ''' + small query batch, small index: CPU is typically faster + small query batch, large index: GPU is typically faster + large query batch, small index: could go either way + large query batch, large index: GPU is typically faster + string_factory = "L2norm,OPQ16_64,IVF30000_HNSW32,PQ16" # Flat + ''' + return knnGPU_v2(x, y, k, index) if use_gpu else knnCPU(x, y, k, index) + + +def knnGPU_v2(x, y, k, index='flat', faiss_verbose=True, batch_size=100000, train_size=1170000, use_float16=True): + ngpus = faiss.get_num_gpus() + dim = 512 #x.shape[1] + string_factory = "OPQ16_64,IVF30000,PQ16" + logger.info(f"index string_factory: {string_factory}") + #"OPQ32,IMI2x8,PQ32" "OPQ32,IVF256,PQ32" + #"PCA64,IVF30000,Flat" + #"IVF200,Flat" # -- works for < 1M indices + #"OPQ16_64,IVF30000,PQ16" # OPQ is a CPU-based vector transform + # IVF30000,PQ16 + idx = faiss.index_factory(dim, string_factory, faiss.METRIC_INNER_PRODUCT) + # https://www.pinecone.io/learn/composite-indexes/ + ivf = faiss.extract_index_ivf(idx) + res = 
faiss.StandardGpuResources() + res.setDefaultNullStreamAllDevices() + co = faiss.GpuMultipleClonerOptions() + co.shard = True + co.verbose = True + co.indicesOptions = faiss.INDICES_CPU + co.common_ivf_quantizer = False + co.usePrecomputed = False + co.useFloat16 = use_float16 + co.useFloat16CoarseQuantizer = True + if ngpus == 1: + gpu_idx = faiss.index_cpu_to_gpu(res, 0, idx) + else: # multiple gpus + res_list = [res for _ in range(ngpus)] + gpu_idx = faiss.index_cpu_to_gpu_multiple_py(resources=res_list, index=idx, co=co) + # https://github.com/KevinMusgrave/pytorch-metric-learning/issues/491 + # inputs to gpu_idx have to be on cpu and inside the function, they would be moved back to gpu! + + # logger.info("Created faiss index of type {}".format(type(gpu_idx))) + faiss_verbose = None + + # Set verbosity level + if faiss_verbose is not None: + if hasattr(gpu_idx, "index") and gpu_idx.index is not None: + gpu_idx.index.verbose = faiss_verbose + if hasattr(gpu_idx, "quantizer") and gpu_idx.quantizer is not None: + gpu_idx.quantizer.verbose = faiss_verbose + if hasattr(gpu_idx, "clustering_index") and gpu_idx.clustering_index is not None: + gpu_idx.clustering_index.verbose = faiss_verbose + + # Train + # logger.info("training index") + + ys, xs = [], [] + + for i in tqdm(range(0, len(y), batch_size)): + yb = torch.stack(list(y[i : i + batch_size]), dim=0) + yb = yb.type(torch.float32).cpu() # convert to float32 to move to cpu! + yb = torch.nn.functional.normalize(yb, p=2, dim=1) + # logger.info(f"yb.size: {yb.size()}") + ys.append(yb) + + y = torch.cat(ys, dim=0) + + train_vecs = y + if train_size is not None: + # train_vecs = y[:train_size].cpu() + with torch.no_grad(): + indices = torch.tensor(random.sample(range(y.size()[0]), train_size)) + indices = torch.tensor(indices) + train_vecs = train_vecs[indices] + logger.info("Training the index with {} randomly-sampled vectors".format(len(train_vecs))) + gpu_idx.train(train_vecs) + + # Add vectors + logger.info("Adding {} vectors to the faiss index".format(len(y))) + gpu_batch_size=200000 + for i in tqdm(range(0, len(y), gpu_batch_size)): + vecs = y[i : i + batch_size] + vecs = vecs.type(torch.float32).cpu() # convert to float32 to move to cpu! + gpu_idx.add(vecs) + + # send batched queries to the full faiss index + # size of sim and inds should be equal to len(y) + logger.info("Querying {} vectors to the faiss index".format(len(x))) + sim, ind = [], [] + faiss.GpuParameterSpace().set_index_parameter(ivf, "nprobe", 50) + # faiss.ParameterSpace().set_index_parameter(ivf, "nprobe", 50) + for i in tqdm(range(0, len(x), gpu_batch_size)): + xb = torch.stack(list(x[i : i + gpu_batch_size]), dim=0) + xb = xb.type(torch.float32).cpu() # convert to float32 to move to cpu! 
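        # Queries are L2-normalized below, so inner-product search over the
        # (already normalized) index is equivalent to cosine similarity;
        # per-batch scores/indices are collected and concatenated after this loop.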
+ xb = torch.nn.functional.normalize(xb, p=2, dim=1) + xs.append(xb) + bsim, bind = gpu_idx.search(xb, k) # x[i : i + batch_size].cpu() + sim.append(bsim) + # print(f"len(sim): {len(sim)}") + ind.append(bind) + + # logger.info(f"concat results..") + # logger.info(f"sim[0].size(): {sim[0].size()}") + rsim = torch.cat(sim, dim=0).cpu() # along the rows + rind = torch.cat(ind, dim=0).cpu() + x = torch.cat(xs, dim=0) + # logger.info(f"similarity results size: {rsim.size()}") + + # logger.info(f"type(rsim): {type(rsim)} type(rind): {type(rsim)}") + + return rsim, rind, x, y + + +def knn(x, y, k, use_gpu, index='flat'): + ''' + small query batch, small index: CPU is typically faster + small query batch, large index: GPU is typically faster + large query batch, small index: could go either way + large query batch, large index: GPU is typically faster + ''' + return knnGPU(x, y, k, index) if use_gpu else knnCPU(x, y, k, index) + +def knnCPU(x, y, k, index='flat'): + start=time.time() + dim = x.shape[1] + m = 8 # number of centroid IDs in final compressed vectors + bits = 8 # number of bits in each centroid + nlist = 100 # how many cells + if index == 'ivf': + # quantizer = faiss.IndexFlatIP(dim) + # idx = faiss.IndexIVFFlat(quantizer, dim, nlist) + idx = faiss.index_factory(dim, "IVF200,Flat", faiss.METRIC_INNER_PRODUCT) + idx.train(y) + # print(f"idx.is_trained: {idx.is_trained}") 40000 + elif index == 'hnsw': + idx = faiss.index_factory(dim, "IVF40000_HNSW32,Flat", faiss.METRIC_INNER_PRODUCT) + idx.train(y) + elif index =='pq': + # quantizer = faiss.IndexFlatIP(dim) + # idx = faiss.IndexIVFPQ(quantizer, dim, nlist, m, bits) + idx = faiss.index_factory(dim, "IVF200,PQ16", faiss.METRIC_INNER_PRODUCT) + idx.train(y) + else: + idx = faiss.IndexFlatIP(dim) + + # print(f"num embeddings indexed: {idx.ntotal}") + idx.add(y) + sim, ind = idx.search(x, k) + # print(f"sim[:3]: {sim[:3]}") + # print(f"ind: {ind}") + # print(f"time taken to build the index: {time.time()-start} secs") + return sim, ind, x, y + +def knnGPU(x, y, k, index='flat', mem=8*1024*1024*1024): + # d = srcRep.shape[1] + # print(f"d: {d}") + + # 1. take a query vector xq 2. identify the cell it belongs to + # 3. use IndexFlat2 to search btw query vector & all other vectors + # belonging to that specific cell + # ''' + # PQ = Product Quantization. IVF reduces the scope of our search, PQ approximates + # distance/similarity calculation. + # 1. split OG vector into several subvectors. + # 2. for each set of subvector, perform a clustering operation - creating multiple centroids + # for each sub-vector set. + # 3. 
In the vector of sub-vecs, replace each sub-vec with the ID of its nearest set-specific centroid + # ''' + # https://github.com/facebookresearch/LASER/blob/main/source/mine_bitexts.py + print(faiss.get_num_gpus()) + dim = x.shape[1] + m = 8 # number of centroid IDs in final compressed vectors + bits = 8 # number of bits in each centroid + nlist = 100 # how many cells + res = faiss.StandardGpuResources() + co = faiss.GpuClonerOptions() + batch_size = mem // (dim*4) + print(f"batch_size: {batch_size}") + if batch_size > x.shape[0]: + batch_size = x.shape[0] // 5 + print(f"batch_size: {batch_size}") + + sim = np.zeros((x.shape[0], k), dtype=np.float32) + ind = np.zeros((x.shape[0], k), dtype=np.int64) + for xfrom in range(0, x.shape[0], batch_size): + xto = min(xfrom + batch_size, x.shape[0]) # to_src_ind + bsims, binds = [], [] + for yfrom in range(0, y.shape[0], batch_size): + yto = min(yfrom + batch_size, y.shape[0]) # to_trg_ind + # print('{}-{} -> {}-{}'.format(xfrom, xto, yfrom, yto)) + if index == 'ivf': # below 1M vectors + idx = faiss.index_factory(dim, "IVF1200,Flat", faiss.METRIC_INNER_PRODUCT) + idx = faiss.index_cpu_to_gpu(res, 0, idx, co) + idx.train(y) + elif index =='pq': + # quantizer = faiss.IndexFlatIP(dim) + # idx = faiss.IndexIVFPQ(quantizer, dim, nlist, m, bits) + idx = faiss.index_factory(dim, "IVF1200,PQ16", faiss.METRIC_INNER_PRODUCT) + idx = faiss.index_cpu_to_gpu(res, 0, idx, co) + idx.train(y) + elif index == 'hnsw': # for 1M-10M vectors + idx = faiss.index_factory(dim, "IVF40000_HNSW32,Flat", faiss.METRIC_INNER_PRODUCT) + idx_ivf = faiss.extract_index_ivf(idx) + clustering_index = faiss.index_cpu_to_all_gpus(res, 0, faiss.IndexFlatIP(idx_ivf.d), co) + idx_ivf.clustering_index = clustering_index + idx.train(y) + else: + idx = faiss.IndexFlatIP(dim) + idx = faiss.index_cpu_to_all_gpus(idx) + # quantizer = faiss.IndexFlatL2(d) + # idx = faiss.IndexIVFFlat(quantizer, d, nlist) + # #idx = faiss.IndexIVFPQ(quantizer, d, nlist, m, bits) + # idx.train(srcRep) + # print(f"idx.is_trained: {idx.is_trained}") + # idx.add(srcRep) + # print(f"num embeddings indexed: {idx.ntotal}") + + # idx.nprobe = 1 # to increase the search scope + # # large nprobe values = slower but more accurate search + + idx.add(y[yfrom:yto]) # added trg_batch = batch_size to the index + bsim, bind = idx.search(x[xfrom:xto], min(k, yto-yfrom)) # find k nearest neighbours for the batched queries + bsims.append(bsim) + binds.append(bind + yfrom) + del idx + bsims = np.concatenate(bsims, axis=1) + binds = np.concatenate(binds, axis=1) + aux = np.argsort(-bsims, axis=1) + for i in range(xfrom, xto): + for j in range(k): + sim[i, j] = bsims[i-xfrom, aux[i-xfrom, j]] + ind[i, j] = binds[i-xfrom, aux[i-xfrom, j]] + return sim, ind + +def score(x, y, fwd_mean, bwd_mean, margin): + return margin(x.dot(y), (fwd_mean + bwd_mean) / 2) + +def score_candidates(x, y, candidate_inds, fwd_mean, bwd_mean, margin, verbose=False): + if verbose: + logger.info(' - scoring {:d} candidates'.format(x.shape[0])) + scores = np.zeros(candidate_inds.shape) + for i in range(scores.shape[0]): + for j in range(scores.shape[1]): + k = candidate_inds[i, j] + scores[i, j] = score(x[i], y[k], fwd_mean[i], bwd_mean[k], margin) + # print(f"x[i]: {x[i]}, y[k]: {y[k]} fwd_mean[i]: {fwd_mean[i]}, bwd_mean[k]: {bwd_mean[k]}") + # print(f"scores[i, j] : {scores[i, j]}") + return scores + + +def _flatten_config(cfg: DictConfig): + config = OmegaConf.to_container(cfg) + # remove any legacy Namespaces and replace with a single "args" + namespace = 
None + for k, v in list(config.items()): + if isinstance(v, argparse.Namespace): + namespace = v + del config[k] + if namespace is not None: + config["args"] = vars(namespace) + return config + + +class Comparable(): + """ + Class that controls the extraction of parallel sentences and manages their + storage and training. + Args: + model(:py:class:'fairseq.models'): + translation model used for extraction and training + trainer(:obj:'fairseq.trainer'): + trainer that controlls the training process + fields(dict): fields and vocabulary + logger(logging.RootLogger): + logger that reports information about extraction and training + opt(argparse.Namespace): option object + """ + + def __init__(self, model, trainer, task, cfg): + self.sim_measure = cfg.comparable.sim_measure + self.threshold = cfg.comparable.threshold + self.use_threshold = cfg.comparable.use_threshold + self.model_name = cfg.comparable.model_name + self.save_dir = cfg.comparable.save_dir + self.use_phrase = cfg.comparable.use_phrase + #self.model = trainer.get_model().encoder + self.model = model + self.usepos = cfg.comparable.usepos + # print("Use positional encoding = ", self.usepos) + self.trainer = trainer + # print(f"self.trainer: {self.trainer}") + self.task = self.trainer.task + self.encoder = self.trainer.get_model().encoder + # print(f"self.encoder: {self.encoder}") + self.batch_size = cfg.dataset.batch_size + # cfg.comparable.max_sentences + self.batcher = BatchCreator(task, cfg) + self.similar_pairs = PairBank(self.batcher, cfg) + self.unsup_itr = None + self.accepted = 0 + self.accepted_limit = 0 + self.declined = 0 + self.total = 0 + self.cfg = cfg + self.comp_log = cfg.comparable.comp_log + self.cove_type = cfg.comparable.cove_type + self.update_freq = cfg.optimization.update_freq + self.k = cfg.comparable.k #20 #cfg.comparable.k + self.trainstep = 0 + self.second = cfg.comparable.second + self.representations = cfg.comparable.representations + # self.task = task + self.write_dual = cfg.comparable.write_dual + self.no_swaps = cfg.comparable.no_swaps + self.symmetric = cfg.comparable.symmetric + self.add_noise = cfg.comparable.add_noise + self.use_bt = cfg.comparable.use_bt + self.stats = None + self.progress = None + self.src, self.tgt = "tr", "og" #args.source_lang, args.target_lang + self.use_gpu = False + self.mps = False + self.cuda = False + self.mps_device = None + self.log_interval = cfg.common.log_interval #5 + self.margin = cfg.comparable.margin + self.verbose = cfg.comparable.verbose + self.mode = cfg.comparable.mode + self.faiss = cfg.comparable.faiss + self.retrieval = cfg.comparable.retrieval + self.faiss_use_gpu = cfg.comparable.faiss_use_gpu + self.faiss_output = cfg.comparable.faiss_output + self.index=cfg.comparable.index + self.only_unsupervised = cfg.comparable.only_unsupervised + Path(self.comp_log).mkdir(parents=True, exist_ok=True) + # print(f"args.cpu: {args.cpu}") + if cfg.common.cpu == False: + self.use_gpu = True + if torch.backends.mps.is_available(): + self.mps = True + self.mps_device = torch.device("mps") + self.div = 2 * torch.tensor(self.k).to(self.mps_device) #.cuda() + else: + self.div = 2 * torch.tensor(self.k).cuda() + self.cuda = True + else: + self.use_gpu = False + self.div = 2 * torch.tensor(self.k) #, device="mps") #.cuda() + + def getstring(self, vec, dict): + words = dict.string(vec) + return removeSpaces(' '.join(words)) + + def write_sentence(self, src, tgt, status, score=None): + """ + Writes an accepted parallel sentence candidate pair to a file. 
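        Each pair is written as one tab-separated line of the form
        'src: <src>\ttgt: <tgt>\tsimilarity: <score>\tstatus: <status>';
        accepted pairs go to self.accepted_file, while 'embed_only' /
        'hidden_only' pairs go to the separate dual-representation log files.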
+ Args: + src(torch.tensor): src sequence (size(seq)) + tgt(torch.tensor): tgt sequence (size(seq)) + status(str): ['accepted', 'accepted-limit', 'rejected'] + score(float): score of the sentence pair + """ + src_words = self.task.src_dict.string(src) + tgt_words = self.task.tgt_dict.string(tgt) + out = 'src: {}\ttgt: {}\tsimilarity: {}\tstatus: {}\n'.format(removeSpaces(' '.join(src_words)), + removeSpaces(' '.join(tgt_words)), score, status) + if 'accepted' in status: + self.accepted_file.write(out) + # print(out) + elif 'phrase' in status: + self.accepted_phrase.write(out) + elif status == 'embed_only': + with open(self.embed_file, 'a', encoding='utf8') as f: + f.write(out) + elif status == 'hidden_only': + with open(self.hidden_file, 'a', encoding='utf8') as f: + f.write(out) + return None + + def extract_parallel_sents(self, candidates, candidate_pool, phrasese=False, use_threshold=False): + """ + Extracts parallel sentences from candidates and adds them to the + PairBank (secondary filter). + Args: + candidates(list): list of src, tgt pairs (C_h) # memory reps + candidates(list(tuple(torch.Tensor...)): list of src-tgt candidates + candidate_pool(list(hash)): list of hashed C_e candidates + """ + # print("extract parallel") + for candidate in candidates: + # candidate_pair = hash((str(candidate[0]), str(candidate[1]))) + # For dual representation systems... + # print("Dual representation checking") + if candidate_pool: + # ...skip C_h pairs not in C_e (secondary filter) + if self.in_candidate_pool(candidate, candidate_pool) == False: + self.declined += 1 + self.total += 1 + if self.write_dual: + self.write_sentence(candidate[0], candidate[1], + 'hidden_only', candidate[2]) + continue + + elif self.in_candidate_pool(candidate, candidate_pool) and not use_threshold: + src = candidate[0] + tgt = candidate[1] + score = candidate[2] + + self.similar_pairs.add_example(src, tgt) + self.write_sentence(removePadding(src), removePadding(tgt), 'accepted', score) + self.accepted += 1 + if self.symmetric: + self.similar_pairs.add_example(tgt, src) + self.write_sentence(tgt, src, 'accepted', score) + self.total += 1 + + elif use_threshold or self.in_candidate_pool(candidate, candidate_pool): + # Apply threshold (single-representation systems only) + src = candidate[0] + tgt = candidate[1] + score = candidate[2] + + if score >= self.threshold: + # print("Score is greater than threshold") + # Check if no maximum of allowed unique accepted pairs reached + # if self.similar_pairs.no_limit_reached(src, tgt): + # Add to PairBank + self.similar_pairs.add_example(src, tgt) + self.write_sentence(removePadding(src), removePadding(tgt), 'accepted', score) + self.accepted += 1 + if self.symmetric: + self.similar_pairs.add_example(tgt, src) + self.write_sentence(tgt, src, 'accepted', score) + else: + # print("threshold not met!!!") + self.declined += 1 + self.total += 1 + else: # doesnt match thresold or not in candidate-pool + continue + + return None + + def write_embed_only(self, candidates, cand_embed): + """ Writes C_e scores to file (if --write-dual is set). 
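        Only pairs that appear in C_e but not in C_h are logged here, i.e. what
        the embedding-based scoring would accept on its own.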
+ Args: + candidates(list): list of src, tgt pairs (C_h) # memory reps + cand_embed(list): list of src, tgt pairs (C_e) # embed reps + """ + candidate_pool = set([hash((str(c[0]), str(c[1]))) for c in candidates]) + + for candidate in cand_embed: + candidate_pair = hash((str(candidate[0]), str(candidate[1]))) + # Write statistics only if C_e pair not in C_h + if candidate_pair not in candidate_pool: + src = candidate[0] + tgt = candidate[1] + score = candidate[2] + self.write_sentence(src, tgt, 'embed_only', score) + + + def faiss_sent_scoring(self, src_sents, tgt_sents): + """ Score source and target combinations. + Args: + src_sents(list(tuple(torch.Tensor...))): + list of src sentences in their sequential and semantic representation + tgt_sents(list(tuple(torch.Tensor...))): list of tgt sentences + Returns: + src2tgt(dict(dict(float))): dictionary mapping a src to a tgt and their score + tgt2src(dict(dict(float))): dictionary mapping a tgt to a src and their score + similarities(list(float)): list of cosine similarities + scores(list(float)): list of scores + """ + start = time.time() + + srcSent, srcRep = zip(*src_sents) + # print(f"srcSent: {srcSent}") + tgtSent, tgtRep = zip(*tgt_sents) + # print(f"tgtSent: {tgtSent}") + + print("faiss sent scoring") + + if self.faiss_use_gpu: + # https://github.com/facebookresearch/faiss/wiki/Faiss-on-the-GPU + ngpus = faiss.get_num_gpus() + logger.info(f"number of GPUs: {ngpus}") + + # srcSent2ind = {sent:i for i, sent in enumerate(srcSent)} + # tgtSent2ind = {sent:i for i, sent in enumerate(tgtSent)} + + x= np.asarray([rep.detach().cpu().numpy() for rep in srcRep]) + y= np.asarray([rep.detach().cpu().numpy() for rep in tgtRep]) + + print(f"normalising x.dtype : {x.dtype}") + faiss.normalize_L2(x) + faiss.normalize_L2(y) + + logger.info("done faiss normalizing") + + candidates = [] + + # torch.from_numpy(a) + + # calculate knn in both directions + if self.retrieval != 'bwd': + if self.verbose: + print(' - perform {:d}-nn source against target'.format(self.k)) + x2y_sim, x2y_ind = knn(x, y, min(y.shape[0], self.k), self.faiss_use_gpu, self.index) + x2y_mean = x2y_sim.mean(axis=1) + # print(f"x2y_sim.shape: {x2y_sim.shape}") + # print(f"x2y_ind.shape: {x2y_ind.shape}") + + if self.retrieval != 'fwd': + if self.verbose: + print(' - perform {:d}-nn target against source'.format(self.k)) + y2x_sim, y2x_ind = knn(y, x, min(x.shape[0], self.k), self.faiss_use_gpu, self.index) + y2x_mean = y2x_sim.mean(axis=1) + + # margin function + if self.margin == 'absolute': + margin = lambda a, b: a + elif self.margin == 'distance': + margin = lambda a, b: a - b + else: # args.margin == 'ratio': + margin = lambda a, b: a / b + + # print(f"margin: {margin}") + + fout = open(self.faiss_output, mode='w', encoding='utf8', errors='surrogateescape') + + src_inds=list(range(len(srcSent))) + trg_inds=list(range(len(tgtSent))) + + if self.mode == 'search': + if self.verbose: + print(' - Searching for closest sentences in target') + print(' - writing alignments to {:s}'.format(self.faiss_output)) + scores = score_candidates(x, y, x2y_ind, x2y_mean, y2x_mean, margin, self.verbose) + best = x2y_ind[np.arange(x.shape[0]), scores.argmax(axis=1)] + + print(f"best: {best}") + + nbex = x.shape[0] + ref = np.linspace(0, nbex-1, nbex).astype(int) # [0, nbex) + err = nbex - np.equal(best.reshape(nbex), ref).astype(int).sum() + print(' - errors: {:d}={:.2f}%'.format(err, 100*err/nbex)) + for i in src_inds: + print(tgtSent[best[i]], file=fout) + + elif self.mode == 'score': + for i, j in 
zip(src_inds, trg_inds): + s = score(x[i], y[j], x2y_mean[i], y2x_mean[j], margin) + src = srcSent[i] + tgt = tgtSent[j] + src_words = self.task.src_dict.string(src) + tgt_words = self.task.tgt_dict.string(tgt) + out = 'src: {}\ttgt: {}\tsimilarity: {}\n'.format(removeSpaces(' '.join(src_words)), + removeSpaces(' '.join(tgt_words)), s) + print(out, file=fout) + + elif self.mode == 'mine': + if self.verbose: + logger.info(' - mining for parallel data') + fwd_scores = score_candidates(x, y, x2y_ind, x2y_mean, y2x_mean, margin, self.verbose) + bwd_scores = score_candidates(y, x, y2x_ind, y2x_mean, x2y_mean, margin, self.verbose) + fwd_best = x2y_ind[np.arange(x.shape[0]), fwd_scores.argmax(axis=1)] + # print(f"fwd_best: {fwd_best}") + bwd_best = y2x_ind[np.arange(y.shape[0]), bwd_scores.argmax(axis=1)] + # print(f"bwd_best: {bwd_best}") + if self.verbose: + logger.info(' - writing alignments to {:s}'.format(self.faiss_output)) + if self.threshold > 0: + logger.info(' - with threshold of {:f}'.format(self.threshold)) + if self.retrieval == 'fwd': + for i, j in enumerate(fwd_best): + s = fwd_scores[i].max() + src = srcSent[i] + tgt = tgtSent[j] + src_words = self.task.src_dict.string(src) + tgt_words = self.task.tgt_dict.string(tgt) + out = 'src: {}\ttgt: {}\tsimilarity: {}\n'.format(removeSpaces(' '.join(src_words)), + removeSpaces(' '.join(tgt_words)), s) + print(out, file=fout) + # print(fwd_scores[i].max(), srcSent[i], tgtSent[j], sep='\t', file=fout) + candidates.append((srcSent[i], tgtSent[j], s)) + if self.retrieval == 'bwd': + for j, i in enumerate(bwd_best): + s = bwd_scores[j].max() + src = srcSent[i] + tgt = tgtSent[j] + src_words = self.task.src_dict.string(src) + tgt_words = self.task.tgt_dict.string(tgt) + out = 'src: {}\ttgt: {}\tsimilarity: {}\n'.format(removeSpaces(' '.join(src_words)), + removeSpaces(' '.join(tgt_words)), s) + print(out, file=fout) + # print(bwd_scores[j].max(), srcSent[i], tgtSent[j], sep='\t', file=fout) + candidates.append((srcSent[i], tgtSent[j], s)) + if self.retrieval == 'intersect': + for i, j in enumerate(fwd_best): + if bwd_best[j] == i: + s = fwd_scores[i].max() + src = srcSent[i] + tgt = tgtSent[j] + src_words = self.task.src_dict.string(src) + tgt_words = self.task.tgt_dict.string(tgt) + out = 'src: {}\ttgt: {}\tsimilarity: {}\n'.format(removeSpaces(' '.join(src_words)), + removeSpaces(' '.join(tgt_words)), s) + print(out, file=fout) + # print(fwd_scores[i].max(), srcSent[i], tgtSent[j], sep='\t', file=fout) + candidates.append((srcSent[i], tgtSent[j], s)) + if self.retrieval == 'max': + indices = np.stack((np.concatenate((np.arange(x.shape[0]), bwd_best)), + np.concatenate((fwd_best, np.arange(y.shape[0])))), axis=1) + scores = np.concatenate((fwd_scores.max(axis=1), bwd_scores.max(axis=1))) + seen_src, seen_trg = set(), set() + for i in np.argsort(-scores): + src_ind, trg_ind = indices[i] + if not src_ind in seen_src and not trg_ind in seen_trg: + seen_src.add(src_ind) + seen_trg.add(trg_ind) + if scores[i] > self.threshold: + s = scores[i] + src = srcSent[src_ind] + tgt = tgtSent[trg_ind] + src_words = self.task.src_dict.string(src) + tgt_words = self.task.tgt_dict.string(tgt) + out = 'src: {}\ttgt: {}\tsimilarity: {}\n'.format(removeSpaces(' '.join(src_words)), + removeSpaces(' '.join(tgt_words)), s) + print(out, file=fout) + # print(scores[i], srcSent[src_ind], tgtSent[trg_ind], sep='\t', file=fout) + candidates.append((srcSent[src_ind], tgtSent[trg_ind], scores[i])) + + fout.close() + logger.info(f"time taken by faiss sent scoring: 
{time.time()-start} seconds.") + logger.info(f"num candidates: {len(candidates)}") + return candidates + + def score_sents(self, src_sents, tgt_sents): + """ Score source and target combinations. + Args: + src_sents(list(tuple(torch.Tensor...))): + list of src sentences in their sequential and semantic representation + tgt_sents(list(tuple(torch.Tensor...))): list of tgt sentences + Returns: + src2tgt(dict(dict(float))): dictionary mapping a src to a tgt and their score + tgt2src(dict(dict(float))): dictionary mapping a tgt to a src and their score + similarities(list(float)): list of cosine similarities + scores(list(float)): list of scores + """ + src2tgt = defaultdict(dict) + tgt2src = defaultdict(dict) + similarities = [] + scores = [] + + srcSent, srcRep = zip(*src_sents) + tgtSent, tgtRep = zip(*tgt_sents) + + #print("At the point of unzipping the list of tuple....................") + #unzip the list ot tiples to have two lists of equal length each sent, repre + + #print("Stacking the representations to cuda....................") + #stack the representation list into a tensor and use that to compute the similarity + if self.mps: + srcRp=torch.stack(srcRep).to(self.mps_device) #.cuda() + tgtRp=torch.stack(tgtRep).to(self.mps_device) #.cuda() + elif self.cuda: + srcRp=torch.stack(srcRep).cuda() + tgtRp=torch.stack(tgtRep).cuda() + else: + srcRp = torch.stack(srcRep) + tgtRp = torch.stack(tgtRep) + + # print(f"tgtRp: {tgtRp}") + # print(f"self.sim_measure: {self.sim_measure}") + + # Return cosine similarity if that is the scoring function + if self.sim_measure == 'cosine': + matx = self.sim_matrix(srcRp, tgtRp) + # print(f"going into double loop") + for i in range(len(srcSent)): + for j in range(len(tgtSent)): + #print(f"i: {i}, j: {j}") + if srcSent[i][0] == tgtSent[j][0]: + continue + src2tgt[srcSent[i]][tgtSent[j]] = matx[i][j].tolist() # for each sent in SRC -> every TGT is assigned a score + tgt2src[tgtSent[j]][srcSent[i]] = matx[i][j].tolist() + # src2tgt = { "dis a src sent": {"dis a tg": 0.2, "dis s a TRG": 0.6, "dis": 0.12} } + similarities.append(matx[i][j].tolist()) + return src2tgt, tgt2src, similarities, similarities + else: + sim_mt, sumDistSource, sumDistTarget = self.sim_matrix(srcRp, tgtRp) + # sim_mt, nearestSrc, nearestTgt = self.sim_matrix(srcRp, tgtRp) + # sumDistSource = torch.sum(nearestSrc, 1).cuda() /self.div + # sumDistTarget = torch.sum(nearestTgt, 0).cuda() /self.div + # print(f"sumDistSource device: {sumDistSource.get_device()}") + # print(f"sim_mt: {sim_mt}") + + # print(f"going into double loop") + for i in range(len(srcSent)): # m + for j in range(len(tgtSent)): # n + #print(f"i: {i}, j: {j}") + if srcSent[i][0] == tgtSent[j][0]: + continue + # assign margin scores + tgt2src[tgtSent[j]][srcSent[i]] = src2tgt[srcSent[i]][tgtSent[j]] = sim_mt[i][j].tolist() / (sumDistSource[i].tolist() + sumDistTarget[j].tolist()) + #tgt2src[tgtSent[j]][srcSent[i]] = sim_mt[i][j].tolist() / (sumDistTarget[j].tolist() + sumDistSource[i].tolist()) + similarities.append(sim_mt[i][j].tolist()) + + # Get list of scores for statistics + '''for src in list(src2tgt.keys()): + scores += list(src2tgt[src].values())''' + # print(f"finished with the double loop. 
going out of score_sents.") + return src2tgt, tgt2src, similarities, scores + + def sim_matrix(self, a, b, eps=1e-8): + """ + added eps for numerical stability + """ + a_n, b_n = a.norm(dim=1)[:, None], b.norm(dim=1)[:, None] + a_norm = a / torch.max(a_n, eps * torch.ones_like(a_n)) + b_norm = b / torch.max(b_n, eps * torch.ones_like(b_n)) + sim_mt = torch.mm(a_norm, b_norm.transpose(0, 1)).detach().cpu() + # print(f"sim_mt: {sim_mt}") + # print(f"sim_mt shape in sim_matrix: {sim_mt.shape}") + del a_n, b_n, a_norm, b_norm + if self.sim_measure == 'cosine': + return sim_mt.cuda() + # print(f"self.k: {self.k}") + # print("nearestSrc") + # print(torch.topk(sim_mt, self.k, dim=1, largest=True, sorted=False, out=None)) + nearestSrc = torch.topk(sim_mt, self.k, dim=1, largest=True, sorted=False, out=None) + #sumDistSource = torch.sum(nearestSrc[0], 1) + # print(f"nearestSrc: {nearestSrc}") + # print("nearestTgt") + nearestTgt = torch.topk(sim_mt, self.k, dim=0, largest=True, sorted=False, out=None) + #sumDistTarget = torch.sum(nearestTgt[0], 0) + # print(f"nearestTgt: {nearestTgt}") + # print(f"device self.div: {self.div.get_device()}") + sim_mt = sim_mt.cuda() + # print(f"after sim_mt: {sim_mt}") + # return sim_mt, nearestSrc[0], nearestTgt[0] + c = torch.sum(nearestSrc[0], 1)/self.div.detach().cpu() + d = torch.sum(nearestTgt[0], 0)/self.div.detach().cpu() + # print(f"torch.sum(nearestSrc[0], 1): {c.shape}") + # print(f"torch.sum(nearestTgt[0], 0): {d.shape}") + + return sim_mt , c.cuda(), d.cuda() + # return sim_mt, torch.sum(nearestSrc[0], 1)/self.div, torch.sum(nearestTgt[0], 0)/self.div + # return sim_mt, c, d + + + def get_article_coves(self, article, representation='memory', mean=False, side='phr', use_phrase=False): + """ Get representations (C_e or C_h) for sentences in a document. 
+ Args: + article(inputters.OrderedIterator): iterator over sentences in document + representation(str): if 'memory', create C_h; if 'embed', create C_e + fast(boolean): if true, only look at first batch in article + mean(boolean): if true, use mean over time-step representations; else, sum + Returns: + sents(list(tuple(torch.Tensor...))): + list of sentences in their sequential (seq) and semantic representation (cove) + """ + sents = [] + # print("inside get_article_coves") + #for k in article:#tqdm(article): + # p:rint("next(article)") + id = 0 + # print(f"next(article): {next(article)}") + # print(f"len(article): {len(article)}") + # logger.info(f"torch.cuda.current_device(): {torch.cuda.current_device()}") + for k in article: + # print("inside article!") + # print(f"self.cfg.task.arch: {self.cfg.task.arch}") + # print(f"article id: {id}") + # if id == 3013: + # print("skipping 3013") + # continue + # print(f"k['net_input']['src_tokens']: {k['net_input']['src_tokens']}") + sent_repr = None + if self.cfg.task.arch == "lstm": # if the model architecture is LSTM + lengths = k['net_input']['src_lengths'] + texts = k['net_input']['src_tokens'] + ordered_len, ordered_idx = lengths.sort(0, descending=True) + if self.use_gpu and self.mps: + texts = texts[ordered_idx].to(self.mps_device) + ordered_len = ordered_len.to(self.mps_device) + elif self.use_gpu and self.cuda: + texts = texts[ordered_idx].cuda() + ordered_len = ordered_len.cuda() + else: + texts = texts[ordered_idx] + with torch.no_grad(): + output = self.encoder.forward(texts, ordered_len) # texts.cuda() + + if representation == 'memory': + sent_repr = output['encoder_out'][1].squeeze() + # print("In the lstm representation",sent_repr) + elif representation == 'embed': + # print("Collecting Embedding") + hidden_embed = output['encoder_out'][0] + # print(hidden_embed)tr + if mean: + sent_repr = torch.mean(hidden_embed, dim=0) + else: + sent_repr = torch.sum(hidden_embed, dim=0) + elif self.cfg.task.arch == "transformer": + # print("In the transformer representation") + if representation == 'memory': + with torch.no_grad(): + # print(f"k['net_input']['src_tokens']: {k['net_input']['src_tokens']}") + # print(f"k['net_input']['src_lengths']: {k['net_input']['src_lengths']}") + # encoderOut = self.encoder.forward(k['net_input']['src_tokens'].cuda(), + # k['net_input']['src_lengths'].cuda()) + if self.use_gpu and self.mps: + # print("going into encoder forward") + encoderOut = self.encoder.forward(k['net_input']['src_tokens'].to(self.mps_device), + k['net_input']['src_lengths'].to(self.mps_device)) + elif self.use_gpu and self.cuda: + # print("going into encoder forward") + encoderOut = self.encoder.forward(k['net_input']['src_tokens'].cuda(), k['net_input']['src_lengths'].cuda()) + # print("got encoderOut") + else: + encoderOut = self.encoder.forward(k['net_input']['src_tokens'], + k['net_input']['src_lengths']) + # print(f"encoderOut: {encoderOut}") + # print(f"len(encoderOut['encoder_out']): {len(encoderOut['encoder_out'])}") + hidden_embed = encoderOut['encoder_out'][0] + # hidden_embed = getattr(encoderOut, 'encoder_out') # T x B x C + # print(f"hidden_embed: {hidden_embed}") + if mean: + sent_repr = torch.mean(hidden_embed, dim=0) + else: + sent_repr = torch.sum(hidden_embed, dim=0) + elif representation == 'embed': + with torch.no_grad(): + # print(f"k['net_input']['src_tokens']: {k['net_input']['src_tokens']}") + # print(f"k['net_input']['src_lengths']: {k['net_input']['src_lengths']}") + # print("going into encoder forward emb") + # 
print(f"self.usepos: {self.usepos}") + if self.usepos: + if self.use_gpu and self.mps: + input_emb,_ = self.encoder.forward_embedding(k['net_input']['src_tokens'].to(self.mps_device)) + elif self.use_gpu and self.cuda: + input_emb,_ = self.encoder.forward_embedding(k['net_input']['src_tokens'].cuda()) + else: + input_emb,_ = self.encoder.forward_embedding(k['net_input']['src_tokens']) + else: + if self.use_gpu and self.mps: + _, input_emb = self.encoder.forward_embedding(k['net_input']['src_tokens'].to(self.mps_device)) # .to(self.mps_device) + elif self.use_gpu and self.cuda: + _, input_emb = self.encoder.forward_embedding(k['net_input']['src_tokens'].cuda()) + else: + _, input_emb = self.encoder.forward_embedding(k['net_input']['src_tokens']) + # print(f"type(input_emb): {type(input_emb)}") + # print(f"self.cuda: {self.cuda}") + + if self.mps: + input_emb = input_emb.to(self.mps_device) + if self.cuda: + input_emb = input_emb.cuda() + + #input_emb = getattr(encoderOut, 'encoder_embedding') # B x T x C + # print(f"input_emb.size(): {input_emb.size()}") + input_emb = input_emb.transpose(0, 1) + if mean: + sent_repr = torch.mean(input_emb, dim=0) + else: + sent_repr = torch.sum(input_emb, dim=0) + if self.cfg.task.arch == "transformer": + # print(f"inside modeltype == transformer") + + for i in range(k['net_input']['src_tokens'].shape[0]): + # print(f"i : {i}") + # print(f"k['net_input']['src_tokens'][i]: {k['net_input']['src_tokens'][i]}") + # print(f"rang(i): {range(k['net_input']['src_tokens'].shape[0])}") + sents.append((k['net_input']['src_tokens'][i], sent_repr[i])) + + elif self.cfg.task.arch == "lstm": + for i in range(texts.shape[0]): + sents.append((texts[i], sent_repr[i])) + # print(f"finishing {id}") + id += 1 + + # print(f"len(sents): {len(sents)}") + return sents + + def get_comparison_pool(self, src_embeds, tgt_embeds): + """ Perform scoring and filtering for C_e (in dual representation system) + Args: + src_embeds(list): list of source embeddings (C_e) + tgt_embeds(list): list of target embeddings (C_e) + Returns: + candidate_pool(set): set of hashed src-tgt C_e pairs + candidate_embed(list): list of src-tgt C_e pairs + """ + # Scoring + src2tgt_embed, tgt2src_embed, _, _ = self.score_sents(src_embeds, tgt_embeds) + # Filtering (primary filter) + print("candidate filtering") + candidates_embed = self.filter_candidates(src2tgt_embed, tgt2src_embed) + # candidates_embed: [(src_sent_x, tgt_sent_y, score_xy)] + # Filter candidates (primary filter), such that only those which are top candidates in + # both src2tgt and tgt2src direction pass. + # Create set of hashed pairs (for easy comparison in secondary filter) + set_embed = set([hash((str(c[0]), str(c[1]))) for c in candidates_embed]) + candidate_pool = set_embed # unique set of hashed (src_sent_x, tgt_sent_y) pairs + return candidate_pool, candidates_embed + + def in_candidate_pool(self, candidate, candidate_pool): + candidate_pair = hash((str(candidate[0]), str(candidate[1]))) + # For dual representation systems... + # ...skip C_h pairs not in C_e (secondary filter) + if candidate_pair in candidate_pool: + return True + return False + + def filter_candidates(self, src2tgt, tgt2src, second=False): + """ Filter candidates (primary filter), such that only those which are top candidates in + both src2tgt and tgt2src direction pass. 
+ Args: + src2tgt(dict(dict(float))): mapping src sequence to tgt sequence and score + tgt2src(dict(dict(float))): mapping tgt sequence to src sequence and score + second(boolean): if true, also include second-best candidate for src2tgt direction + (medium permissibility mode only) + Returns: + candidates(list(tuple(torch.Tensor...)): list of src-tgt candidates + """ + src_tgt_max = set() + tgt_src_max = set() + src_tgt_second = set() + tgt_src_second = set() + i = 0 + + # For each src... + for src in list(src2tgt.keys()): + # print(f"src: {src}") + # sort the dict of dict based on sim scores + toplist = sorted(src2tgt[src].items(), key=lambda x: x[1], reverse=True) + # ... get the top scoring tgt + max_tgt = toplist[0] + # Get src, tgt and score + src_tgt_max.add((src, max_tgt[0], max_tgt[1])) + if second: + # If high permissibility mode, also get second-best tgt + second_tgt = toplist[1] + src_tgt_second.add((src, second_tgt[0], second_tgt[1])) + i += 1 + + # For each tgt... + i = 0 + for tgt in list(tgt2src.keys()): + # print(f"tgt {i}") + toplist = sorted(tgt2src[tgt].items(), key=lambda x: x[1], reverse=True) + # ... get the top scoring src + max_src = toplist[0] + tgt_src_max.add((max_src[0], tgt, max_src[1])) + i += 1 + + if second: + # Intersection as defined in medium permissibility mode + src_tgt = (src_tgt_max | src_tgt_second) & tgt_src_max + candidates = list(src_tgt) + return candidates + + # Intersection as defined in low permissibility + print("Length of s2t max",len(src_tgt_max)) + print("Length of t2s max", len(tgt_src_max)) + # print("Intersection = ",list(src_tgt_max & tgt_src_max)) + candidates = list(src_tgt_max & tgt_src_max) + return candidates # [(src_x, tgt_y, score_xy)] + + def _get_iterator(self, sent, dictn, max_position, epoch, fix_batches_to_gpus=False, shard_batch_itr=False): + """ + Creates an iterator object from a text file. 
+ Args: + path(str): path to text file to process + Returns: + data_iter(.EpochIterator): iterator object + """ + # get indices ordered by example size + with numpy_seed(self.cfg.common.seed): + indices = sent.ordered_indices() + # filter out examples that are too large + max_positions = (max_position) + if max_positions is not None: + indices = filter_by_size(indices, sent, max_positions, raise_exception=(not True), ) + # create mini-batches with given size constraints + # print(f"self.cfg.comparable.max_sentences: {self.cfg.comparable.max_sentences}") + max_sentences = self.batch_size #self.cfg.comparable.max_sentences # 30 + # print(f"max_sentences: {max_sentences}") + # print(f"self.cfg.dataset.num_workers: {self.cfg.dataset.num_workers}") + # print(f"sent.num_tokens: {sent.num_tokens}") + + batch_sampler = batch_by_size(indices, sent.num_tokens, max_sentences=max_sentences, required_batch_size_multiple=self.cfg.dataset.required_batch_size_multiple, ) + # print(f"tuple(batch_sampler): {tuple(batch_sampler)}") + itrs = EpochBatchIterator(dataset=sent, collate_fn=sent.collater, batch_sampler=batch_sampler, seed=self.cfg.common.seed,num_workers=self.cfg.dataset.num_workers, epoch=epoch, num_shards=self.trainer.data_parallel_world_size if shard_batch_itr else 1, shard_id=self.trainer.data_parallel_rank if shard_batch_itr else 0,) + #data_iter = itrs.next_epoch_itr(shuffle=False, fix_batches_to_gpus=fix_batches_to_gpus) + # print(f"itrs.state_dict: {itrs.state_dict()}") + # print(f"itrs.n(): {itrs.n()}") + # print(f"itrs.first_batch(): {itrs.first_batch()}") + # print(f"next(itrs)") + # print(f"{next(itrs)}") + + return itrs + #return data_iter + #return data_loader + + def get_cove(self, memory, ex, mean=False): + """ Get sentence representation. + Args: + memory(torch.Tensor): hidden states or word embeddings of batch + ex(int): index of example in batch + mean(boolean): if true, take mean over time-steps; else, sum + Returns: + cove(torch.Tensor): sentence representation C_e or C_h + """ + # Get current example + seq_ex = memory[:, ex, :] + if self.cove_type == 'mean': + cove = torch.mean(seq_ex, dim=0) + else: + cove = torch.sum(seq_ex, dim=0) + return cove + + # def get_source_monolingual_data(self, articles): + # trainingSetSrc = load_indexed_dataset(articles[0], self.task.src_dict, + # dataset_impl='raw', combine=False, + # default='cached') + # src_mono = MonolingualDataset(dataset=trainingSetSrc, sizes=trainingSetSrc.sizes, + # src_vocab=self.task.src_dict, + # tgt_vocab=None, shuffle=False, add_eos_for_other_targets=False) + # src_mono = MonolingualDataset(dataset=trainingSetSrc, sizes=trainingSetSrc.sizes, + # src_vocab=self.task.src_dict, + # tgt_vocab=None, shuffle=False, add_eos_for_other_targets=False) + # return src_mono + + def get_unsupervised_data(self, articles): + trainingSetSrc = load_indexed_dataset(articles[0], self.task.src_dict, + dataset_impl='raw', combine=False, + default='cached') + src_mono = MonolingualDataset(dataset=trainingSetSrc, sizes=trainingSetSrc.sizes, + src_vocab=self.task.src_dict, + tgt_vocab=None, shuffle=False, add_eos_for_other_targets=False, fixed_pad_length=512, perform_sampling=True, num_samples=40000) + del trainingSetSrc + # perform_sampling=True, num_samples=20000 + return src_mono + + + def getdata(self, articles): + # logger.info(f"self.cfg.dataset.dataset_impl: raw") + trainingSetSrc = load_indexed_dataset(articles[0], self.task.src_dict, + dataset_impl='raw', combine=False, + default='cached') + trainingSetTgt = 
load_indexed_dataset(articles[1], self.task.tgt_dict, + dataset_impl='raw', combine=False, + default='cached') + # logger.info(f"trainingSetSrc.sizes: {trainingSetSrc.sizes}") + # print("read the text file ")self.args.data + + # convert the read files to Monolingual dataset to make padding easy + src_mono = MonolingualDataset(dataset=trainingSetSrc, sizes=trainingSetSrc.sizes, + src_vocab=self.task.src_dict, + tgt_vocab=None, shuffle=False, add_eos_for_other_targets=False, fixed_pad_length=512) + tgt_mono = MonolingualDataset(dataset=trainingSetTgt, sizes=trainingSetTgt.sizes, + src_vocab=self.task.tgt_dict, + tgt_vocab=None, shuffle=False, add_eos_for_other_targets=False, fixed_pad_length=512) + + del trainingSetSrc, trainingSetTgt + # print("Monolingual data") + # print(f"src_mono.num_tokens(1): {src_mono.num_tokens(1)}") + # print(f"tgt_mono.num_tokens(1): {tgt_mono.num_tokens(1)}") + # logger.info(f"src_mono.sizes: {src_mono.sizes}") + return src_mono, tgt_mono + + def generate_output(self, + sample, + ): + def decode(toks, escape_unk=False): + s = self.task.tgt_dict.string( + toks.int().cpu(), + unk_string=("UNKNOWNTOKENINREF" if escape_unk else "UNKNOWNTOKENINHYP"), + ) + return s + # bos_token=self.task.tgt_dict.eos() + gen_out = self.task.inference_step(self.task.sequence_generator, [self.model], sample, prefix_tokens=None) + hyps = [] + for i in range(len(gen_out)): + hyps.append(decode(gen_out[i][0]["tokens"])) + logger.info("example hypothesis: " + hyps[0]) + # self.task.src_dict.string(src) + logger.info(f"input reference: {self.task.src_dict.string(removePadding(sample['net_input']['src_tokens']))}") + + return gen_out + + + def unsupervised_training(self, comparable_data_list, epoch): + # tracemalloc.start() + """ + Trains with Mono-(lingual/stylistic) SRC data using model from the previous time-step + """ + translate_fn = self.generate_output + # Go through comparable data + with open(comparable_data_list, encoding='utf8') as c: + comp_list = c.read().split('\n') + cur_article = 0 + for ap, article_pair in enumerate(comp_list): + print(f"on article {ap}") + cur_article += 1 + articles = article_pair.split(' ') + + # Discard malaligned documents + if len(articles) != 2: + continue + + trainingSetSrc = load_indexed_dataset(articles[0], self.task.src_dict, dataset_impl='raw', combine=False, default='cached') + + logger.info(f"trainingSetSrc type: {type(trainingSetSrc)}") + + src_mono = MonolingualDataset(dataset=trainingSetSrc, sizes=trainingSetSrc.sizes, src_vocab=self.task.src_dict, tgt_vocab=None, shuffle=False, add_eos_for_other_targets=False) + # logger.info(f"src mono type: {type(src_mono)}") + # in monolingualDataset, sample['target] is None. + + #src_dict (~fairseq.data.Dictionary): the dictionary of backtranslated + #sentences. i.e. in TR-OG case, = self.task.tgt_dict = OG dict + # tgt_dict (~fairseq.data.Dictionary, optional): the dictionary of + # sentences to be backtranslated. = self.task.src_dict = TR dict + + src_gen_data = BacktranslationDataset(tgt_dataset=src_mono, sizes=trainingSetSrc.sizes, src_dict=self.task.tgt_dict, tgt_dict=self.task.src_dict, backtranslation_fn=translate_fn, cuda=True) + + #1. get iterator from SRC-GEN data + logger.info("get iterator from SRC-GEN data") + data_itr = self._get_iterator(src_gen_data, dictn=self.task.src_dict, max_position=self.cfg.task.max_source_positions, epoch=epoch, fix_batches_to_gpus=False) + + #2. get ENC representations for both the instances in the criterion!! 
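+ # NOTE: data_itr above iterates over the BacktranslationDataset; in fairseq this
+ # dataset is expected to call backtranslation_fn (here generate_output) from its
+ # collater, so SRC-GEN outputs are produced per collated batch rather than per sentence.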
+ logger.info("get ENC representations for both the instances") + + it1 = data_itr.next_epoch_itr(shuffle=False, fix_batches_to_gpus=False) + unsup_itr = GroupedIterator(it1, self.update_freq[-1], skip_remainder_batch=self.cfg.optimization.skip_remainder_batch) + + unsup_progress = progress_bar.progress_bar( + unsup_itr, + log_format=self.cfg.common.log_format, + log_file=self.cfg.common.log_file, + log_interval=self.log_interval, + epoch=epoch, + aim_repo=( + self.cfg.common.aim_repo + if distributed_utils.is_master(self.cfg.distributed_training) + else None + ), + aim_run_hash=( + self.cfg.common.aim_run_hash + if distributed_utils.is_master(self.cfg.distributed_training) + else None + ), + aim_param_checkpoint_dir=self.cfg.checkpoint.save_dir, + tensorboard_logdir=( + self.cfg.common.tensorboard_logdir + if distributed_utils.is_master(self.cfg.distributed_training) + else None + ), + default_log_format=("tqdm" if not self.cfg.common.no_progress_bar else "simple"), + wandb_project=( + self.cfg.common.wandb_project + if distributed_utils.is_master(self.cfg.distributed_training) + else None + ), + wandb_run_name=os.environ.get( + "WANDB_NAME", os.path.basename(self.cfg.checkpoint.save_dir) + ), + azureml_logging=( + self.cfg.common.azureml_logging + if distributed_utils.is_master(self.cfg.distributed_training) + else False + ), + ) + unsup_progress.update_config(_flatten_config(self.cfg)) + logger.info(f"Start iterating over SRC-TGT' samples") + + + # for k in it1: + # print(f"k: {k}") + # break + + # 3. combine Supervised and Backtranslation data using concat datasets/Round-robin dataset - how will I make a distinction? -- think!! + # 4. have one criteria that computes supervised and unsupervised loss! + # 5. unsupervised loss can either be computed using model at the previous timestep. or it can be computed real-time like in semi-fst. just check if the generation happens per sample or over a batch of samples!! + # 6. valid_step already changed in translate_mod.py -- make this change to translate.py to compute unsupervised loss during validation and we are all set! + + # 3. compute cosine similarity loss + # 4. compute entopy-loss for generations + # 5. update gradients + # 6. In train_unsup_comp.py, change validation function and make it similar to unsup training. + + def faiss_sent_scoring_v2(self, src_sents, tgt_sents): + """ Score source and target combinations. 
+ Args: + src_sents(list(tuple(torch.Tensor...))): + list of src sentences in their sequential and semantic representation + tgt_sents(list(tuple(torch.Tensor...))): list of tgt sentences + Returns: + src2tgt(dict(dict(float))): dictionary mapping a src to a tgt and their score + tgt2src(dict(dict(float))): dictionary mapping a tgt to a src and their score + similarities(list(float)): list of cosine similarities + scores(list(float)): list of scores + """ + start = time.time() + + srcSent, srcRep = zip(*src_sents) + # print(f"srcSent: {srcSent}") + tgtSent, tgtRep = zip(*tgt_sents) + # print(f"tgtSent: {tgtSent}") + + # print("faiss sent scoring") + + if self.faiss_use_gpu: + # https://github.com/facebookresearch/faiss/wiki/Faiss-on-the-GPU + ngpus = faiss.get_num_gpus() + logger.info(f"number of GPUs: {ngpus}") + + # srcSent2ind = {sent:i for i, sent in enumerate(srcSent)} + # tgtSent2ind = {sent:i for i, sent in enumerate(tgtSent)} + + # logger.info(f"len srcRep: {len(srcRep)}") + # logger.info(f"srcRep: {srcRep}") + + + + # x = torch.stack(list(srcRep), dim=0) #torch.cat(srcRep, dim=1) # concat along the rows + # y = torch.stack(list(tgtRep), dim=0) + if self.faiss_use_gpu: + x, y = srcRep, tgtRep + + # logger.info(f"len x: {x.size()}") + # logger.info(f"x: {x}") + else: + x= np.asarray([rep.detach().cpu().numpy() for rep in srcRep]) + y= np.asarray([rep.detach().cpu().numpy() for rep in tgtRep]) + + # print(f"normalising x.dtype : {x.dtype}") + faiss.normalize_L2(x) + faiss.normalize_L2(y) + # logger.info("done torch normalizing") + # logger.info(f"x.size(): {x.size()}") + # https://discuss.pytorch.org/t/how-to-normalize-embedding-vectors/1209/9 + # x = torch.nn.functional.normalize(x, p=2, dim=1) + # y = torch.nn.functional.normalize(y, p=2, dim=1) + #F.normalize(x, p=2, dim=1) + + candidates = [] + + # torch.from_numpy(a) + + # calculate knn in both directions + if self.retrieval != 'bwd': + if self.verbose: + logger.info(' - perform {:d}-nn source against target'.format(self.k)) + x2y_sim, x2y_ind, x, y = knn_v2(x, y, self.k, self.faiss_use_gpu, self.index) + if self.faiss_use_gpu: + x2y_sim = x2y_sim.numpy() #.detach().cpu().numpy() + x2y_ind = x2y_ind.numpy() # .detach().cpu() + x2y_mean = x2y_sim.mean(axis=1) + #x2y_mean = torch.mean(x2y_sim, 1) + + # print(f"x2y_sim.shape: {x2y_sim.shape}") + # print(f"x2y_ind.shape: {x2y_ind.shape}") + + if self.retrieval != 'fwd': + if self.verbose: + logger.info(' - perform {:d}-nn target against source'.format(self.k)) + y2x_sim, y2x_ind, _, _ = knn_v2(y, x, self.k, self.faiss_use_gpu, self.index) + # logger.info(f"type(y2x_sim): {type(y2x_sim)}, type(y2x_ind): {type(y2x_ind)}") + if self.faiss_use_gpu: + y2x_sim = y2x_sim.numpy() # .detach().cpu() + y2x_ind = y2x_ind.numpy() # .detach().cpu().numpy() + y2x_mean = y2x_sim.mean(axis=1) + # y2x_mean = torch.mean(y2x_sim, 1) + + # margin function + if self.margin == 'absolute': + margin = lambda a, b: a + elif self.margin == 'distance': + margin = lambda a, b: a - b + else: # args.margin == 'ratio': + margin = lambda a, b: a / b + + # print(f"margin: {margin}") + + fout = open(self.faiss_output, mode='w', encoding='utf8', errors='surrogateescape') + + src_inds=list(range(len(srcSent))) + trg_inds=list(range(len(tgtSent))) + + if self.mode == 'search': + if self.verbose: + print(' - Searching for closest sentences in target') + print(' - writing alignments to {:s}'.format(self.faiss_output)) + scores = score_candidates(x, y, x2y_ind, x2y_mean, y2x_mean, margin, self.verbose) + best = 
x2y_ind[np.arange(x.shape[0]), scores.argmax(axis=1)] + + print(f"best: {best}") + + nbex = x.shape[0] + ref = np.linspace(0, nbex-1, nbex).astype(int) # [0, nbex) + err = nbex - np.equal(best.reshape(nbex), ref).astype(int).sum() + print(' - errors: {:d}={:.2f}%'.format(err, 100*err/nbex)) + for i in src_inds: + print(tgtSent[best[i]], file=fout) + + elif self.mode == 'score': + for i, j in zip(src_inds, trg_inds): + s = score(x[i], y[j], x2y_mean[i], y2x_mean[j], margin) + src = srcSent[i] + tgt = tgtSent[j] + src_words = self.task.src_dict.string(src) + tgt_words = self.task.tgt_dict.string(tgt) + out = 'src: {}\ttgt: {}\tsimilarity: {}\n'.format(removeSpaces(' '.join(src_words)), + removeSpaces(' '.join(tgt_words)), s) + print(out, file=fout) + + elif self.mode == 'mine': + if self.verbose: + logger.info(' - mining for parallel data') + fwd_scores = score_candidates(x, y, x2y_ind, x2y_mean, y2x_mean, margin, self.verbose) + bwd_scores = score_candidates(y, x, y2x_ind, y2x_mean, x2y_mean, margin, self.verbose) + fwd_best = x2y_ind[np.arange(x.shape[0]), fwd_scores.argmax(axis=1)] + # print(f"fwd_best: {fwd_best}") + bwd_best = y2x_ind[np.arange(y.shape[0]), bwd_scores.argmax(axis=1)] + # print(f"bwd_best: {bwd_best}") + if self.verbose: + logger.info(' - writing alignments to {:s}'.format(self.faiss_output)) + if self.threshold > 0: + logger.info(' - with threshold of {:f}'.format(self.threshold)) + if self.retrieval == 'fwd': + for i, j in enumerate(fwd_best): + s = fwd_scores[i].max() + src = srcSent[i] + tgt = tgtSent[j] + src_words = self.task.src_dict.string(src) + tgt_words = self.task.tgt_dict.string(tgt) + out = 'src: {}\ttgt: {}\tsimilarity: {}\n'.format(removeSpaces(' '.join(src_words)), + removeSpaces(' '.join(tgt_words)), s) + print(out, file=fout) + # print(fwd_scores[i].max(), srcSent[i], tgtSent[j], sep='\t', file=fout) + candidates.append((srcSent[i], tgtSent[j], s)) + if self.retrieval == 'bwd': + for j, i in enumerate(bwd_best): + s = bwd_scores[j].max() + src = srcSent[i] + tgt = tgtSent[j] + src_words = self.task.src_dict.string(src) + tgt_words = self.task.tgt_dict.string(tgt) + out = 'src: {}\ttgt: {}\tsimilarity: {}\n'.format(removeSpaces(' '.join(src_words)), + removeSpaces(' '.join(tgt_words)), s) + print(out, file=fout) + # print(bwd_scores[j].max(), srcSent[i], tgtSent[j], sep='\t', file=fout) + candidates.append((srcSent[i], tgtSent[j], s)) + if self.retrieval == 'intersect': + for i, j in enumerate(fwd_best): + if bwd_best[j] == i: + s = fwd_scores[i].max() + src = srcSent[i] + tgt = tgtSent[j] + src_words = self.task.src_dict.string(src) + tgt_words = self.task.tgt_dict.string(tgt) + out = 'src: {}\ttgt: {}\tsimilarity: {}\n'.format(removeSpaces(' '.join(src_words)), + removeSpaces(' '.join(tgt_words)), s) + print(out, file=fout) + # print(fwd_scores[i].max(), srcSent[i], tgtSent[j], sep='\t', file=fout) + candidates.append((srcSent[i], tgtSent[j], s)) + if self.retrieval == 'max': + indices = np.stack((np.concatenate((np.arange(x.shape[0]), bwd_best)), + np.concatenate((fwd_best, np.arange(y.shape[0])))), axis=1) + scores = np.concatenate((fwd_scores.max(axis=1), bwd_scores.max(axis=1))) + seen_src, seen_trg = set(), set() + for i in np.argsort(-scores): + src_ind, trg_ind = indices[i] + if not src_ind in seen_src and not trg_ind in seen_trg: + seen_src.add(src_ind) + seen_trg.add(trg_ind) + if scores[i] > self.threshold: + s = scores[i] + src = srcSent[src_ind] + tgt = tgtSent[trg_ind] + src_words = self.task.src_dict.string(src) + tgt_words = 
self.task.tgt_dict.string(tgt) + out = 'src: {}\ttgt: {}\tsimilarity: {}\n'.format(removeSpaces(' '.join(src_words)), + removeSpaces(' '.join(tgt_words)), s) + print(out, file=fout) + # print(scores[i], srcSent[src_ind], tgtSent[trg_ind], sep='\t', file=fout) + candidates.append((srcSent[src_ind], tgtSent[trg_ind], scores[i])) + + fout.close() + logger.info(f"time taken by faiss sent scoring: {time.time()-start} seconds.") + # logger.info(f"num candidates: {len(candidates)}") + return candidates + + + def extract_and_train(self, comparable_data_list, epoch): + + tracemalloc.start() + """ Manages the alternating extraction of parallel sentences and training. + Args: + comparable_data_list(str): path to list of mapped documents + Returns: + train_stats(:obj:'onmt.Trainer.Statistics'): epoch loss statistics + """ + + self.accepted_file = open('{}_accepted-e{}.txt'.format(self.comp_log, epoch), 'w+', encoding='utf8') + if self.use_phrase == True: + self.accepted_phrase = open('{}_accepted_phrase-e{}.txt'.format(self.comp_log, epoch), 'w+', + encoding='utf8') + self.status_file = '{}_status-e{}.txt'.format(self.comp_log, epoch) + if self.write_dual: + self.embed_file = '{}_accepted_embed-e{}.txt'.format(self.comp_log, + epoch) + self.hidden_file = '{}_accepted_hidden-e{}.txt'.format(self.comp_log, + epoch) + + epoch_similarities = [] + epoch_scores = [] + src_sents = [] + tgt_sents = [] + src_embeds = [] + tgt_embeds = [] + + # Go through comparable data + with open(comparable_data_list, encoding='utf8') as c: + comp_list = c.read().split('\n') + #num_articles = len(comp_list) + unsup_data = None + cur_article = 0 + for ap, article_pair in enumerate(comp_list): + logger.info(f"on article {ap}") + cur_article += 1 + articles = article_pair.split(' ') + # print(f"articles: {articles}") + # print(f"len(articles): {len(articles)}") + # Discard malaligned documents + if len(articles) != 2: + continue + #load the dataset from the files for both source and target + src_mono, tgt_mono = self.getdata(articles) + # trainingSetSrc = load_indexed_dataset(articles[0], self.task.src_dict, + # dataset_impl='raw', combine=False, + # default='cached') + # src_mono = MonolingualDataset(dataset=trainingSetSrc, sizes=trainingSetSrc.sizes, + # src_vocab=self.task.src_dict, + # tgt_vocab=None, shuffle=False, add_eos_for_other_targets=False) + unsup_data = self.get_unsupervised_data(articles) + # Prepare iterator objects for current src/tgt document + # print(f"self.task.src_dict: {self.task.src_dict}") + # print(f"self.cfg.max_source_positions: {self.cfg.task.max_source_positions}") + # print(f"get iterator") + + src_article = self._get_iterator(src_mono, dictn=self.task.src_dict, max_position=self.cfg.task.max_source_positions, epoch=epoch, fix_batches_to_gpus=False) + tgt_article = self._get_iterator(tgt_mono, dictn=self.task.tgt_dict, max_position=self.cfg.task.max_target_positions, epoch=epoch, fix_batches_to_gpus=False) + + # re-use src_article for monolingual unsupervised training + # self.unsup_itr = src_article + + # Get sentence representations + try: + if self.representations == 'embed-only': + print("Using Embeddings only for representation") + # C_e + itr_src = src_article._get_iterator_for_epoch(epoch=epoch, shuffle=True) + itr_tgt = tgt_article._get_iterator_for_epoch(epoch=epoch, shuffle=True) + print(f"src article, rep=embed") + src_sents += self.get_article_coves(itr_src, representation='embed', mean=False) + # print(f"tgt article, rep=embed") + tgt_sents += self.get_article_coves(itr_tgt, 
representation='embed', mean=False) + else: + # C_e and C_h + '''it1, it2 = itertools.tee(src_article) + it3, it4 = itertools.tee(tgt_article)''' + logger.info(f"src article, rep=embed") + it1 = src_article.next_epoch_itr(shuffle=False, fix_batches_to_gpus=False) + src_embeds += self.get_article_coves(it1, representation='embed', mean=False, side='src', + use_phrase=self.use_phrase) + logger.info(f"src article, rep=memory") + it1 = src_article.next_epoch_itr(shuffle=False, fix_batches_to_gpus=False) + src_sents += self.get_article_coves(it1, representation='memory', mean=False, side='src') + + logger.info(f"tgt article, rep=embed") + it3 = tgt_article.next_epoch_itr(shuffle=False, fix_batches_to_gpus=False) + tgt_embeds += self.get_article_coves(it3, representation='embed', mean=False, side='tgt', + use_phrase=self.use_phrase) + logger.info(f"tgt article, rep=memory") + it3 = tgt_article.next_epoch_itr(shuffle=False, fix_batches_to_gpus=False) + tgt_sents += self.get_article_coves(it3, representation='memory', mean=False, side='tgt') + + #return + except: + #Skip document pair in case of errors + print("error") + src_sents = [] + tgt_sents = [] + src_embeds = [] + tgt_embeds = [] + continue + # src_mono.dataset.tokens_list = None + # src_mono.dataset.sizes = None + # src_mono.sizes = None + tgt_mono.sizes = None + del tgt_mono + # del src_mono + + if len(src_sents) < 15 or len(tgt_sents) < 15: + #print("Length LEss tahn 15") + continue + print("Proceeding") + # Score src and tgt sentences + print("In all we have got ", len(src_sents), "source sentences and ", len(tgt_sents), "target") + + try: + logger.info(f"self.faiss: {self.faiss}") + if self.faiss: + candidates = self.faiss_sent_scoring_v2(src_sents, tgt_sents) + logger.info(f"done with faiss scoring of src sents and tgt sents") + candidates_embed = self.faiss_sent_scoring_v2(src_embeds, tgt_embeds) + logger.info(f"num candidates = {len(candidates)}") + logger.info(f"num candidates_embed = {len(candidates_embed)}") + logger.info(f"done with faiss scoring of src embeds and tgt embeds") + embed_comparison_pool = set_embed = set([hash((str(c[0]), str(c[1]))) for c in candidates_embed]) + # candidates : [(src_sent_x, tgt_sent_y, score_xy)] + logger.info(f"made embed_comparison_pool") + if self.write_dual: + #print("writing the sentences to file....") + self.write_embed_only(candidates, candidates_embed) + # Extract parallel samples (secondary filter) + logger.info(f"starting to extract parallel sents") + self.extract_parallel_sents(candidates, embed_comparison_pool, use_threshold=self.use_threshold) + else: + src2tgt, tgt2src, similarities, scores = self.score_sents(src_sents, tgt_sents) + # src2tgt = { "dis a src sent": {"dis a tg": 0.2, "dis s a TRG": 0.6, "dis": 0.12} } + # this score could be from margin / cosine similarity + # similarities containes only sim scores (useless var) + # scores is a useless var + except Exception as e: + print('Error occurred in: {}\n'.format(article_pair), flush=True) + print(f"e: {e}") + print("src_sents") + # print(src_sents, flush=True) + print("tgt_sents") + # print(tgt_sents, flush=True) + src_sents = [] + tgt_sents = [] + continue + # print("source 2 target ", src2tgt) + # Keep statistics + #epoch_similarities += similarities + #epoch_scores += scores + src_sents = [] + tgt_sents = [] + + if not self.faiss: + try: + if self.representations == 'dual': + # For dual representation systems, filter C_h... 
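+ # second=self.second additionally admits the second-best src2tgt match
+ # (medium permissibility mode, see filter_candidates)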
+ candidates = self.filter_candidates(src2tgt, tgt2src, second=self.second) + # candidates : [(src_sent_x, tgt_sent_y, score_xy)] + # candidates generated from memory representations + # Filter candidates (primary filter), such that only those which are top candidates in + # both src2tgt and tgt2src direction pass. + # ...and C_e + comparison_pool, cand_embed = self.get_comparison_pool(src_embeds,tgt_embeds) + # comparison_pool: unique set of hashed (src_sent_x, tgt_sent_y) pairs + # cand_embed: candidates generated from embedding representations + # [(src_sent_x, tgt_sent_y, score_xy)] + src_embeds = [] + tgt_embeds = [] + if self.write_dual: + #print("writing the sentences to file....") + self.write_embed_only(candidates, cand_embed) + else: + print("Using Embedings only for Filtering ......") + # Filter C_e or C_h for single representation system + candidates = self.filter_candidates(src2tgt, tgt2src) + comparison_pool = None + except: + # Skip document pair in case of errors + print("Error Occured!!!!") + print('Error occured in: {}\n'.format(article_pair), flush=True) + src_embeds = [] + tgt_embeds = [] + continue + + # Extract parallel samples (secondary filter) + self.extract_parallel_sents(candidates, comparison_pool, use_threshold=self.use_threshold) + # if phrase extraction is to be used + + logger.info(f"pair bank = {len(self.similar_pairs.srcs)}") + + src_data = torchDataset(data_list=self.similar_pairs.srcs) + tgt_data = torchDataset(data_list=self.similar_pairs.tgts) + pairData = LanguagePairDataset( + src_data, self.similar_pairs.src_lens, self.task.src_dict, + tgt_data, self.similar_pairs.tgt_lens, self.task.tgt_dict, + left_pad_source=self.cfg.task.left_pad_source, + left_pad_target=self.cfg.task.left_pad_target) + + self.concat_data = RoundRobinZipDatasets( + OrderedDict([('sup', pairData)] + [('unsup', unsup_data)]), + eval_key=None + ) + # indices = self.concat_data.ordered_indices() + + # self.concat_data.filter_indices_by_size(indices=indices, max_positions=512) + + self.train(epoch) + self.reset_pairbank() + + if not self.faiss: + del src2tgt, tgt2src + #gc.collect() + # Add to leaky code within python_script_being_profiled.py + + snapshot = tracemalloc.take_snapshot() + top_stats = snapshot.statistics('lineno') + + # if len(self.similar_pairs.pairs) > 0: + # # print("batching and training") + # logger.info("batching and training") + # self.train(epoch, last=True) + + self.accepted_file.close() + if self.use_phrase == True: + self.accepted_phrase.close() + + # log end-of-epoch stats + logger.info("end of epoch {} (average epoch stats below)".format(epoch)) + num_updates = self.trainer.get_num_updates() + stats = get_training_stats(metrics.get_smoothed_values('train')) + self.progress.print(stats, tag='train', step=num_updates) + + # reset epoch-level meters + metrics.reset_meters('train') + end_of_epoch = True + return num_updates, end_of_epoch + + def reset_pairbank(self): + self.similar_pairs.srcs = [] + self.similar_pairs.tgts = [] + self.similar_pairs.src_lens = [] + self.similar_pairs.tgt_lens = [] + ''' + @metrics.aggregate('train') + def trainRest(self, epoch): + itrs = self.similar_pairs.yield_batch() + itr = itrs.next_epoch_itr(shuffle=True, fix_batches_to_gpus=False) + itr = GroupedIterator(itr, 1) + self.progress = progress_bar.build_progress_bar( + self.args, itr, epoch, no_progress_bar='simple', + ) + for samples in self.progress: + log_output = self.trainer.train_step(samples) + num_updates = self.trainer.get_num_updates() + if log_output is None: + 
continue + # log mid-epoch stats + stats = get_training_stats(metrics.get_smoothed_values('train')) + self.progress.log(stats, tag='train', step=num_updates) + self.progress.print(stats, tag='train', step=num_updates) + + print("done") + #del itrs, itr + ''' + + @metrics.aggregate('train') + def train(self, epoch, itrs=None, last=False): + # Check if enough parallel sentences were collected + # is_epoch_end = False + if last is False: + if itrs is None: + itrs = self.task.get_batch_iterator(self.concat_data, max_sentences=self.batch_size, epoch=0, max_positions=self.cfg.task.max_source_positions) + itr = itrs.next_epoch_itr(shuffle=True, fix_batches_to_gpus=self.cfg.distributed_training.fix_batches_to_gpus) + itr = GroupedIterator(itr, self.update_freq[-1], skip_remainder_batch=self.cfg.optimization.skip_remainder_batch) + if self.cfg.common.tpu: + itr = utils.tpu_data_loader(itr) + self.progress = progress_bar.progress_bar( + itr, + log_format=self.cfg.common.log_format, + log_file=self.cfg.common.log_file, + log_interval=self.log_interval, + epoch=epoch, + aim_repo=( + self.cfg.common.aim_repo + if distributed_utils.is_master(self.cfg.distributed_training) + else None + ), + aim_run_hash=( + self.cfg.common.aim_run_hash + if distributed_utils.is_master(self.cfg.distributed_training) + else None + ), + aim_param_checkpoint_dir=self.cfg.checkpoint.save_dir, + tensorboard_logdir=( + self.cfg.common.tensorboard_logdir + if distributed_utils.is_master(self.cfg.distributed_training) + else None + ), + default_log_format=("tqdm" if not self.cfg.common.no_progress_bar else "simple"), + wandb_project=( + self.cfg.common.wandb_project + if distributed_utils.is_master(self.cfg.distributed_training) + else None + ), + wandb_run_name=os.environ.get( + "WANDB_NAME", os.path.basename(self.cfg.checkpoint.save_dir) + ), + azureml_logging=( + self.cfg.common.azureml_logging + if distributed_utils.is_master(self.cfg.distributed_training) + else False + ), + ) + self.progress.update_config(_flatten_config(self.cfg)) + # logger.info(f"Start iterating over samples") + for i, samples in enumerate(self.progress): + with metrics.aggregate("train_inner"), torch.autograd.profiler.record_function( + "train_step-%d" % i + ): + log_output = self.trainer.train_step(samples) + if log_output is not None: # not OOM, overflow, ... 
+ # log mid-epoch stats + num_updates = self.trainer.get_num_updates() + if num_updates % self.log_interval == 0: + stats = get_training_stats(metrics.get_smoothed_values("train_inner")) + self.progress.log(stats, tag="train_inner", step=num_updates) + + # reset mid-epoch stats after each log interval + # the end-of-epoch stats will still be preserved + metrics.reset_meters('train_inner') + + def validate(self, epoch, itr): + """Evaluate the model on the validation set(s) and return the losses.""" + + if self.cfg.dataset.fixed_validation_seed is not None: + # set fixed seed for every validation + utils.set_torch_seed(self.cfg.dataset.fixed_validation_seed) + + self.trainer.begin_valid_epoch(epoch) + valid_losses = [] + # for subset_idx, subset in enumerate(subsets): + # logger.info(' "{}" subset'.format(subset)) + + # Initialize data iterator + # itr = self.trainer.get_valid_iterator(subset).next_epoch_itr( + # shuffle=False, set_dataset_epoch=False # use a fixed valid set + # ) + if self.cfg.common.tpu: + itr = utils.tpu_data_loader(itr) + # print(f"self.cfg.distributed_training: {self.cfg.distributed_training}") + progress = progress_bar.progress_bar( + itr, + log_format=self.cfg.common.log_format, + log_interval=self.cfg.common.log_interval, + epoch=epoch, + prefix=f"valid on 'validation' subset", + aim_repo=( + self.cfg.common.aim_repo + if distributed_utils.is_master(self.cfg.distributed_training) + else None + ), + aim_run_hash=( + self.cfg.common.aim_run_hash + if distributed_utils.is_master(self.cfg.distributed_training) + else None + ), + aim_param_checkpoint_dir=self.cfg.checkpoint.save_dir, + tensorboard_logdir=( + self.cfg.common.tensorboard_logdir + if distributed_utils.is_master(self.cfg.distributed_training) + else None + ), + default_log_format=("tqdm" if not self.cfg.common.no_progress_bar else "simple"), + wandb_project=( + self.cfg.common.wandb_project + if distributed_utils.is_master(self.cfg.distributed_training) + else None + ), + wandb_run_name=os.environ.get( + "WANDB_NAME", os.path.basename(self.cfg.checkpoint.save_dir) + ), + ) + + # create a new root metrics aggregator so validation metrics + # don't pollute other aggregators (e.g., train meters) + with metrics.aggregate(new_root=True) as agg: + for i, sample in enumerate(progress): + if ( + self.cfg.dataset.max_valid_steps is not None + and i > self.cfg.dataset.max_valid_steps + ): + break + self.trainer.valid_step(sample) + + # log validation stats + # only tracking the best metric on the 1st validation subset + tracking_best = True #subset_idx == 0 + stats = get_valid_stats(self.cfg, self.trainer, agg.get_smoothed_values(), tracking_best) + + if hasattr(self.task, "post_validate"): + self.task.post_validate(self.trainer.get_model(), stats, agg) + + progress.print(stats, tag='valid', step=self.trainer.get_num_updates()) + + # logger.info(f"stats: {stats}") + + valid_losses.append(stats[self.cfg.checkpoint.best_checkpoint_metric]) + return valid_losses + + def save_comp_chkp(self, epoch): + dirs = self.save_dir + '/' + self.model_name + '_' + str(epoch) + self.src + "-" + self.tgt + ".pt" + self.trainer.save_checkpoint(dirs, {"train_iterator": {"epoch": epoch}}) + + +def get_valid_stats( + cfg: DictConfig, + trainer: Trainer, + stats: Dict[str, Any], + tracking_best: bool, +) -> Dict[str, Any]: + stats["num_updates"] = trainer.get_num_updates() + if tracking_best and hasattr(checkpoint_utils.save_checkpoint, "best"): + key = "best_{0}".format(cfg.checkpoint.best_checkpoint_metric) + best_function = max if 
cfg.checkpoint.maximize_best_checkpoint_metric else min + stats[key] = best_function( + checkpoint_utils.save_checkpoint.best, + stats[cfg.checkpoint.best_checkpoint_metric], + ) + # logger.info(f"stats in get_valid_stats: {stats}") + return stats + +def get_training_stats(stats): + if 'nll_loss' in stats and 'ppl' not in stats: + stats['ppl'] = utils.get_perplexity(stats['nll_loss']) + stats['wall'] = round(metrics.get_meter('default', 'wall').elapsed_time, 0) + return stats + +def removePadding(side): + """ Removes original padding from a sequence. + Args: + side(torch.Tensor): src/tgt sequence (size(seq)) + Returns: + side(torch.Tensor): src/tgt sequence without padding + NOTE: This only works as long as PAD_ID==1! + """ + # Get indexes of paddings in sequence + padding_idx = (side == 1).nonzero() + # If there is any padding, cut sequence from first occurence of a pad + if padding_idx.size(0) != 0: + first_pad = padding_idx.data.tolist()[0][0] + side = side[:first_pad] + return side diff --git a/fairseq_cli/eval_lm.py b/fairseq_cli/eval_lm.py index dbd1450a9e..52ffc8ca71 100644 --- a/fairseq_cli/eval_lm.py +++ b/fairseq_cli/eval_lm.py @@ -113,10 +113,13 @@ def eval_lm( gen_timer.start() hypos = scorer.generate(models, sample) gen_timer.stop(sample["ntokens"]) + logger.info(f"hypos: {hypos}") for i, hypos_i in enumerate(hypos): hypo = hypos_i[0] + logger.info(f"hypos_i: {hypos_i}") sample_id = sample["id"][i] + logger.info(f"hypo: {hypo}") tokens = hypo["tokens"] tgt_len = tokens.numel() diff --git a/fairseq_cli/generate.py b/fairseq_cli/generate.py index b8757835d4..bd8feb23b0 100644 --- a/fairseq_cli/generate.py +++ b/fairseq_cli/generate.py @@ -80,14 +80,17 @@ def _main(cfg: DictConfig, output_file): use_cuda = torch.cuda.is_available() and not cfg.common.cpu # Load dataset splits + # print(f"cfg.task: {cfg.task}") task = tasks.setup_task(cfg.task) # Set dictionaries try: src_dict = getattr(task, "source_dictionary", None) + logger.info(f"src: {len(src_dict)}") except NotImplementedError: src_dict = None tgt_dict = task.target_dictionary + logger.info(f"task.target_dictionary: {len(task.target_dictionary)}") overrides = ast.literal_eval(cfg.common_eval.model_overrides) diff --git a/fairseq_cli/train_unsup_comp.py b/fairseq_cli/train_unsup_comp.py new file mode 100644 index 0000000000..8e6af4246d --- /dev/null +++ b/fairseq_cli/train_unsup_comp.py @@ -0,0 +1,1009 @@ +#!/usr/bin/env python3 -u +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +""" +Train a new model on one or across multiple GPUs. +""" + +import argparse +import logging +import math +import os +import sys +from typing import Any, Callable, Dict, List, Optional, Tuple +from fairseq.file_io import PathManager +from fairseq.dataclass.configs import CheckpointConfig +import logging +import ast +import collections + +# We need to setup root logger before importing any fairseq libraries. 
+logging.basicConfig( + format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + level=os.environ.get("LOGLEVEL", "INFO").upper(), + stream=sys.stdout, +) +logger = logging.getLogger("fairseq_cli.traincomp") + +import numpy as np +import torch +from omegaconf import DictConfig, OmegaConf + +from fairseq import checkpoint_utils, options, quantization_utils, tasks, utils +from fairseq.data import data_utils, iterators, indexed_dataset, MonolingualDataset +from fairseq.data.plasma_utils import PlasmaStore +from fairseq.dataclass.configs import FairseqConfig +from fairseq.dataclass.initialize import add_defaults +from fairseq.dataclass.utils import convert_namespace_to_omegaconf +from fairseq.distributed import fsdp_enable_wrap, fsdp_wrap +from fairseq.distributed import utils as distributed_utils +from fairseq.file_io import PathManager +from fairseq.logging import meters, metrics, progress_bar +from fairseq.model_parallel.megatron_trainer import MegatronTrainer +from fairseq.trainer import Trainer +# from fairseq_cli.Comparable4 import Comparable +from fairseq_cli.Comparable_unsup import Comparable + +def load_validation_data(data_path, src, tgt, src_dict, dataset_impl, split='valid', left_pad_source=True): + def split_exists(split, src, tgt, data_path): + filename = os.path.join(data_path, "{}.{}-{}.{}".format(split, src, tgt, src)) + return indexed_dataset.dataset_exists(filename, impl=dataset_impl) + + if split_exists(split, src, tgt, data_path): + prefix = os.path.join(data_path, "{}.{}-{}.".format(split, src, tgt)) + + dataset = data_utils.load_indexed_dataset(prefix + src, src_dict, dataset_impl) + + return MonolingualDataset(dataset=dataset, sizes=dataset.sizes, src_vocab=src_dict, tgt_vocab=None, shuffle=False,add_eos_for_other_targets=False) + + +def get_valid_iterator(cfg, dataset, trainer, task, disable_iterator_cache=False): + batch_iterator = task.get_batch_iterator( + dataset=dataset, + max_tokens=cfg.dataset.max_tokens_valid, + max_sentences=cfg.dataset.batch_size_valid, + max_positions=utils.resolve_max_positions( + task.max_positions(), + trainer.model.max_positions(), + ), + ignore_invalid_inputs=cfg.dataset.skip_invalid_size_inputs_valid_test, + required_batch_size_multiple=cfg.dataset.required_batch_size_multiple, + seed=cfg.common.seed, + num_shards=trainer.data_parallel_world_size, + shard_id=trainer.data_parallel_rank, + num_workers=cfg.dataset.num_workers, + # always pass a fixed "epoch" to keep validation data consistent + # across training epochs + epoch=1, + data_buffer_size=cfg.dataset.data_buffer_size, + disable_iterator_cache=disable_iterator_cache, + skip_remainder_batch=False, + ) + trainer.reset_dummy_batch(batch_iterator.first_batch) + return batch_iterator + +def main(cfg: FairseqConfig) -> None: + if isinstance(cfg, argparse.Namespace): + # print(f"convert namespace") + cfg = convert_namespace_to_omegaconf(cfg) + + utils.import_user_module(cfg.common) + # print(f"added user module") + add_defaults(cfg) + # print(f"added defaults") + + if ( + distributed_utils.is_master(cfg.distributed_training) + and "job_logging_cfg" in cfg + ): + # make hydra logging work with ddp (see # see https://github.com/facebookresearch/hydra/issues/1126) + logging.config.dictConfig(OmegaConf.to_container(cfg.job_logging_cfg)) + + assert ( + cfg.dataset.max_tokens is not None or cfg.dataset.batch_size is not None + ), "Must specify batch size either with --max-tokens or --batch-size" + metrics.reset() + + if cfg.common.log_file is not None: + 
handler = logging.FileHandler(filename=cfg.common.log_file) + logger.addHandler(handler) + + np.random.seed(cfg.common.seed) + utils.set_torch_seed(cfg.common.seed) + + if distributed_utils.is_master(cfg.distributed_training): + checkpoint_utils.verify_checkpoint_directory(cfg.checkpoint.save_dir) + + # Print args + logger.info(cfg) + + if cfg.checkpoint.write_checkpoints_asynchronously: + try: + import iopath # noqa: F401 + except ImportError: + logging.exception( + "Asynchronous checkpoint writing is specified but iopath is " + "not installed: `pip install iopath`" + ) + return + + # Setup task, e.g., translation, language modeling, etc. + task = tasks.setup_task(cfg.task) + # cfg.task.src_dict.add_symbol("") + # cfg.task.tgt_dict.add_symbol("") + + assert cfg.criterion, "Please specify criterion to train a model" + + # Build model and criterion + if cfg.distributed_training.ddp_backend == "fully_sharded": + with fsdp_enable_wrap(cfg.distributed_training): + model = fsdp_wrap(task.build_model(cfg.model)) + else: + model = task.build_model(cfg) #.model + criterion = task.build_criterion(cfg.criterion) + # generator = task.build_generator([model]) # SequenceGenerator object + logger.info(model) + logger.info("task: {}".format(task.__class__.__name__)) + logger.info("model: {}".format(model.__class__.__name__)) + logger.info("criterion: {}".format(criterion.__class__.__name__)) + # logger.info("generator: {}".format(generator.__class__.__name__)) + logger.info( + "num. shared model params: {:,} (num. trained: {:,})".format( + sum( + p.numel() for p in model.parameters() if not getattr(p, "expert", False) + ), + sum( + p.numel() + for p in model.parameters() + if not getattr(p, "expert", False) and p.requires_grad + ), + ) + ) + + logger.info( + "num. expert model params: {} (num. 
trained: {})".format( + sum(p.numel() for p in model.parameters() if getattr(p, "expert", False)), + sum( + p.numel() + for p in model.parameters() + if getattr(p, "expert", False) and p.requires_grad + ), + ) + ) + + # Load valid dataset (we load training data below, based on the latest checkpoint) + # We load the valid dataset AFTER building the model + train_dataset = None + if not cfg.dataset.disable_validation: + data_utils.raise_if_valid_subsets_unintentionally_ignored(cfg) + + if cfg.comparable.comparable: + paths = utils.split_paths(cfg.task.data) + assert len(paths) > 0 + # logger.info(f"paths: {paths}") + src, tgt = cfg.task.source_lang, cfg.task.target_lang + data_path = paths[0] + # logger.info(f"data_path: {data_path}") + train_dataset = load_validation_data(data_path, src, tgt, src_dict=task.src_dict, dataset_impl=cfg.dataset.dataset_impl, split='train') + vaild_dataset = load_validation_data(data_path,src, tgt,src_dict=task.src_dict, dataset_impl=cfg.dataset.dataset_impl) + + elif cfg.dataset.combine_valid_subsets: + task.load_dataset("valid", combine=True, epoch=1) + else: + for valid_sub_split in cfg.dataset.valid_subset.split(","): + task.load_dataset(valid_sub_split, combine=False, epoch=1) + + # (optionally) Configure quantization + if cfg.common.quantization_config_path is not None: + quantizer = quantization_utils.Quantizer( + config_path=cfg.common.quantization_config_path, + max_epoch=cfg.optimization.max_epoch, + max_update=cfg.optimization.max_update, + ) + else: + quantizer = None + + # Build trainer + if cfg.common.model_parallel_size == 1: + logger.info("trainer") + trainer = Trainer(cfg, task, model, criterion, quantizer) + else: + logger.info("MegatronTrainer") + trainer = MegatronTrainer(cfg, task, model, criterion) + + logger.info( + "training on {} devices (GPUs/TPUs)".format( + cfg.distributed_training.distributed_world_size + ) + ) + logger.info( + "max tokens per device = {} and max sentences per device = {}".format( + cfg.dataset.max_tokens, + cfg.dataset.batch_size, + ) + ) + + # Load the latest checkpoint if one is available and restore the + # corresponding train iterator + # extra_state, epoch_itr = checkpoint_utils.load_checkpoint( + # cfg.checkpoint, + # trainer, + # # don't cache epoch iterators for sharded datasets + # disable_iterator_cache=task.has_sharded_data("train"), + # ) + extra_state, epoch = load_checkpoint(cfg, trainer) + if cfg.common.tpu: + import torch_xla.core.xla_model as xm + + xm.rendezvous("load_checkpoint") # wait for all workers + + # Train until the learning rate gets too small + max_epoch = cfg.optimization.max_epoch or math.inf + # max_update = cfg.optimization.max_update or math.inf + lr = trainer.get_lr() + + # TODO: a dry run on validation set to pin the memory + # valid_subsets = cfg.dataset.valid_subset.split(",") + if not cfg.dataset.disable_validation: + logger.info('begin dry-run validation on valid subset') + valid_itr = get_valid_iterator(cfg, vaild_dataset, trainer, task).next_epoch_itr( + shuffle=False, set_dataset_epoch=False # use a fixed valid set + ) + # for subset in valid_subsets: + # logger.info('begin dry-run validation on "{}" subset'.format(subset)) + # itr = trainer.get_valid_iterator(subset).next_epoch_itr( + # shuffle=False, set_dataset_epoch=False # use a fixed valid set + # ) + # if cfg.common.tpu: + # itr = utils.tpu_data_loader(itr) + # for _ in itr: + # pass + # TODO: end of dry run section + + train_meter = meters.StopwatchMeter() + train_meter.start() + + if cfg.comparable.comparable: + 
comp = Comparable(model, trainer, task, cfg) + + while epoch <= max_epoch: # _itr.next_epoch_idx + if lr <= cfg.optimization.stop_min_lr: + logger.info( + f"stopping training because current learning rate ({lr}) is smaller " + "than or equal to minimum learning rate " + f"(--stop-min-lr={cfg.optimization.stop_min_lr})" + ) + break + + if cfg.comparable.only_unsupervised: + unsup_itr = comp.task.get_batch_iterator(train_dataset, max_sentences=comp.batch_size, epoch=0) + + comp.trainer.begin_epoch(unsup_itr.epoch) + + comp.train(epoch=unsup_itr.epoch, itrs=unsup_itr) + end_of_epoch = True + + logger.info("end of epoch {} (average epoch stats below)".format(epoch)) + + num_updates = comp.trainer.get_num_updates() + + stats = get_training_stats(metrics.get_smoothed_values('train')) + comp.progress.print(stats, tag='train', step=num_updates) + # reset epoch-level meters + metrics.reset_meters('train') + else: + # train for one epoch + logger.info(f"begin epoch") + comp.task.begin_epoch(epoch, comp.trainer.get_model()) + + # Extract parallel data and train + num_updates, end_of_epoch = comp.extract_and_train(cfg.comparable.comparable_data, epoch) + + max_update = cfg.optimization.max_update or math.inf + should_stop = False + + if num_updates >= max_update: + should_stop = True + logger.info( + f"Stopping training due to " + f"num_updates: {num_updates} >= max_update: {max_update}" + ) + + training_time_hours = trainer.cumulative_training_time() / (60 * 60) + if ( + cfg.optimization.stop_time_hours > 0 + and training_time_hours > cfg.optimization.stop_time_hours + ): + should_stop = True + logger.info( + f"Stopping training due to " + f"cumulative_training_time: {training_time_hours} > " + f"stop_time_hours: {cfg.optimization.stop_time_hours} hour(s)" + ) + + do_save = ( + (end_of_epoch and epoch % cfg.checkpoint.save_interval == 0) + or should_stop + or ( + cfg.checkpoint.save_interval_updates > 0 + and num_updates > 0 + and num_updates % cfg.checkpoint.save_interval_updates == 0 + and num_updates >= cfg.dataset.validate_after_updates + ) + ) + do_validate = ( + ( + (not end_of_epoch and do_save) # validate during mid-epoch saves + or (end_of_epoch and epoch % cfg.dataset.validate_interval == 0) + or should_stop + or ( + cfg.dataset.validate_interval_updates > 0 + and num_updates > 0 + and num_updates % cfg.dataset.validate_interval_updates == 0 + ) + ) + and not cfg.dataset.disable_validation + and num_updates >= cfg.dataset.validate_after_updates + ) + # epoch_itr. + # Validate + valid_losses = [None] + if do_validate: + valid_losses = comp.validate(epoch, valid_itr) + + valid_itr = get_valid_iterator(cfg, vaild_dataset, trainer, task).next_epoch_itr( + shuffle=False, set_dataset_epoch=False # use a fixed valid set + ) + + should_stop |= should_stop_early(cfg, valid_losses[0]) + + # Save checkpoint + if do_save or should_stop: + cp_path = save_checkpoint( + cfg.checkpoint, trainer, epoch, valid_losses[0] + ) + if cp_path is not None and hasattr(task, "post_save"): + task.post_save(cp_path, num_updates) + + if should_stop: + break + + # only use first validation loss to update the learning rate + lr = trainer.lr_step(epoch, valid_losses[0]) + epoch += 1 + + train_meter.stop() + logger.info("done training in {:.1f} seconds".format(train_meter.sum)) + + # ioPath implementation to wait for all asynchronous file writes to complete. + if cfg.checkpoint.write_checkpoints_asynchronously: + logger.info( + "ioPath PathManager waiting for all asynchronous checkpoint " + "writes to finish." 
+ ) + PathManager.async_close() + logger.info("ioPath PathManager finished waiting.") + +def load_checkpoint(cfg, trainer, **passthrough_args): + """ + Load a checkpoint and restore the training iterator. + + *passthrough_args* will be passed through to + ``trainer.get_train_iterator``. + """ + # only one worker should attempt to create the required dir + reset_optimizer = cfg.checkpoint.reset_optimizer + reset_lr_scheduler = cfg.checkpoint.reset_lr_scheduler + # print(f"cfg.optimizer_overrides: {cfg.optimizer_overrides}") + optimizer_overrides = ast.literal_eval(cfg.checkpoint.optimizer_overrides) + reset_meters = cfg.checkpoint.reset_meters + reset_dataloader = cfg.checkpoint.reset_dataloader + + if cfg.distributed_training.distributed_rank == 0: + print(f"cfg.checkpoint.save_dir: {cfg.checkpoint.save_dir}") + os.makedirs(cfg.checkpoint.save_dir, exist_ok=True) + + if cfg.checkpoint.finetune_from_model is not None and ( + reset_optimizer or reset_lr_scheduler or reset_meters or reset_dataloader + ): + raise ValueError( + "--finetune-from-model can not be set together with either --reset-optimizer" + " or reset_lr_scheduler or reset_meters or reset_dataloader" + ) + suffix = trainer.checkpoint_suffix + + if cfg.checkpoint.restore_file == "checkpoint_last.pt": + checkpoint_path = os.path.join(cfg.checkpoint.save_dir, "checkpoint_last{}.pt".format(suffix)) + first_launch = not PathManager.exists(checkpoint_path) + if first_launch and getattr(cfg.checkpoint, "continue_once", None) is not None: + checkpoint_path = cfg.checkpoint.continue_once + elif cfg.checkpoint.finetune_from_model is not None and first_launch: + # if there is no last checkpoint to restore, start the finetune from pretrained model + # else just use usual logic to load checkpoint, e.g. restart from last checkpoint and etc. 
+ if PathManager.exists(cfg.checkpoint.finetune_from_model): + checkpoint_path = cfg.checkpoint.finetune_from_model + reset_optimizer = True + reset_lr_scheduler = True + reset_meters = True + reset_dataloader = True + logger.info( + f"loading pretrained model from {checkpoint_path}: " + "optimizer, lr scheduler, meters, dataloader will be reset" + ) + else: + raise ValueError( + f"--finetune-from-model {cfg.finetune_from_model} does not exist" + ) + elif suffix is not None: + checkpoint_path = cfg.checkpoint.restore_file.replace(".pt", suffix + ".pt") + else: + checkpoint_path = os.path.join(cfg.checkpoint.save_dir, cfg.checkpoint.restore_file) + + if cfg.checkpoint.restore_file != "checkpoint_last.pt" and cfg.checkpoint.finetune_from_model: + raise ValueError( + "--finetune-from-model and --restore-file (non-default value) " + "can not be specified together: " + str(cfg) + ) + + extra_state = trainer.load_checkpoint( + checkpoint_path, + reset_optimizer, + reset_lr_scheduler, + optimizer_overrides, + reset_meters=reset_meters, + ) + + # if ( + # extra_state is not None + # and "best" in extra_state + # and not args.reset_optimizer + # and not args.reset_meters + # ): + # save_checkpoint.best = extra_state["best"] + + if ( + extra_state is not None + and "best" in extra_state + and not reset_optimizer + and not reset_meters + ): + save_checkpoint.best = extra_state["best"] + + if extra_state is not None and not reset_dataloader: + # restore iterator from checkpoint + itr_state = extra_state["train_iterator"] + # epoch_itr = trainer.get_train_iterator( + # epoch=itr_state["epoch"], load_dataset=False, **passthrough_args + # ) + epoch = extra_state["train_iterator"]["epoch"] + 1 + # epoch_itr.load_state_dict(itr_state) + else: + epoch = 1 + # epoch_itr = trainer.get_train_iterator( + # epoch=1, load_dataset=False, **passthrough_args + # ) + + trainer.lr_step(epoch) + # trainer.lr_step(epoch_itr.epoch) + + return extra_state, epoch + # return extra_state, epoch_itr + + +def save_checkpoint(cfg: CheckpointConfig, trainer, epoch, val_loss): + from fairseq import meters + + # only one worker should attempt to create the required dir + if trainer.data_parallel_rank == 0: + os.makedirs(cfg.save_dir, exist_ok=True) + + prev_best = getattr(save_checkpoint, "best", val_loss) + if val_loss is not None: + best_function = max if cfg.maximize_best_checkpoint_metric else min + save_checkpoint.best = best_function(val_loss, prev_best) + + if cfg.no_save: + return None + + trainer.consolidate_optimizer() # TODO(SS): do we need this if no_save_optimizer_state + + if not trainer.should_save_checkpoint_on_current_rank: + if trainer.always_call_state_dict_during_save_checkpoint: + trainer.state_dict() + return None + + write_timer = meters.StopwatchMeter() + write_timer.start() + + # epoch = epoch_itr.epoch + # end_of_epoch = epoch_itr.end_of_epoch() + updates = trainer.get_num_updates() + + logger.info(f"Preparing to save checkpoint for epoch {epoch} @ {updates} updates") + + def is_better(a, b): + return a >= b if cfg.maximize_best_checkpoint_metric else a <= b + + suffix = trainer.checkpoint_suffix + checkpoint_conds = collections.OrderedDict() + checkpoint_conds["checkpoint{}{}.pt".format(epoch, suffix)] = ( + # end_of_epoch and + not cfg.no_epoch_checkpoints and epoch % cfg.save_interval == 0 + ) + checkpoint_conds["checkpoint_{}_{}{}.pt".format(epoch, updates, suffix)] = ( + # not end_of_epoch and + cfg.save_interval_updates > 0 + and updates % cfg.save_interval_updates == 0 + ) + 
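+ # checkpoint_best{suffix}.pt is refreshed whenever the new val_loss improves on
+ # the best value tracked in save_checkpoint.best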
checkpoint_conds["checkpoint_best{}.pt".format(suffix)] = val_loss is not None and ( + not hasattr(save_checkpoint, "best") + or is_better(val_loss, save_checkpoint.best) + ) + if val_loss is not None and cfg.keep_best_checkpoints > 0: + worst_best = getattr(save_checkpoint, "best", None) + chkpts = checkpoint_utils.checkpoint_paths( + cfg.save_dir, + pattern=r"checkpoint\.best_{}_(\d+\.?\d*){}\.pt".format( + cfg.best_checkpoint_metric, suffix + ), + ) + if len(chkpts) > 0: + p = chkpts[-1] if cfg.maximize_best_checkpoint_metric else chkpts[0] + worst_best = float(p.rsplit("_")[-1].replace("{}.pt".format(suffix), "")) + # add random digits to resolve ties + with data_utils.numpy_seed(epoch, updates, val_loss): + rand_sfx = np.random.randint(0, cfg.keep_best_checkpoints) + + checkpoint_conds[ + "checkpoint.best_{}_{:.3f}{}{}.pt".format( + cfg.best_checkpoint_metric, val_loss, rand_sfx, suffix + ) + ] = worst_best is None or is_better(val_loss, worst_best) + checkpoint_conds[ + "checkpoint_last{}.pt".format(suffix) + ] = not cfg.no_last_checkpoints + + extra_state = {"train_iterator": {"epoch": epoch}, "val_loss": val_loss} + if hasattr(save_checkpoint, "best"): + extra_state.update({"best": save_checkpoint.best}) + + checkpoints = [ + os.path.join(cfg.save_dir, fn) for fn, cond in checkpoint_conds.items() if cond + ] + saved_cp = None + if len(checkpoints) > 0 and trainer.should_save_checkpoint_on_current_rank: + saved_cp = trainer.save_checkpoint(checkpoints[0], extra_state) + for cp in checkpoints[1:]: + if cfg.write_checkpoints_asynchronously: + # TODO[ioPath]: Need to implement a delayed asynchronous + # file copying/moving feature. + logger.warning( + f"ioPath is not copying {checkpoints[0]} to {cp} " + "since async write mode is on." + ) + else: + assert PathManager.copy( + checkpoints[0], cp, overwrite=True + ), f"Failed to copy {checkpoints[0]} to {cp}" + + write_timer.stop() + logger.info( + "Saved checkpoint {} (epoch {} @ {} updates, score {}) (writing took {} seconds)".format( + checkpoints[0], epoch, updates, val_loss, write_timer.sum + ) + ) + + # if ( + # # not end_of_epoch and + # cfg.keep_interval_updates > 0 + # and trainer.should_save_checkpoint_on_current_rank + # ): + # # remove old checkpoints; checkpoints are sorted in descending order + # if cfg.keep_interval_updates_pattern == -1: + # checkpoints = checkpoint_paths( + # cfg.save_dir, pattern=r"checkpoint_\d+_(\d+){}\.pt".format(suffix) + # ) + # else: + # checkpoints = checkpoint_paths( + # cfg.save_dir, + # pattern=r"checkpoint_\d+_(\d+){}\.pt".format(suffix), + # keep_match=True, + # ) + # checkpoints = [ + # x[0] + # for x in checkpoints + # if x[1] % cfg.keep_interval_updates_pattern != 0 + # ] + + # for old_chk in checkpoints[cfg.keep_interval_updates :]: + # if os.path.lexists(old_chk): + # os.remove(old_chk) + # elif PathManager.exists(old_chk): + # PathManager.rm(old_chk) + + if cfg.keep_last_epochs > 0 and trainer.should_save_checkpoint_on_current_rank: + # remove old epoch checkpoints; checkpoints are sorted in descending order + checkpoints = checkpoint_utils.checkpoint_paths( + cfg.save_dir, pattern=r"checkpoint(\d+){}\.pt".format(suffix) + ) + for old_chk in checkpoints[cfg.keep_last_epochs :]: + if os.path.lexists(old_chk): + os.remove(old_chk) + elif PathManager.exists(old_chk): + PathManager.rm(old_chk) + + if cfg.keep_best_checkpoints > 0 and trainer.should_save_checkpoint_on_current_rank: + # only keep the best N checkpoints according to validation metric + checkpoints = 
checkpoint_utils.checkpoint_paths( + cfg.save_dir, + pattern=r"checkpoint\.best_{}_(\d+\.?\d*){}\.pt".format( + cfg.best_checkpoint_metric, suffix + ), + ) + if not cfg.maximize_best_checkpoint_metric: + checkpoints = checkpoints[::-1] + for old_chk in checkpoints[cfg.keep_best_checkpoints :]: + if os.path.lexists(old_chk): + os.remove(old_chk) + elif PathManager.exists(old_chk): + PathManager.rm(old_chk) + + return saved_cp + + +def should_stop_early(cfg: DictConfig, valid_loss: float) -> bool: + # skip check if no validation was done in the current epoch + if valid_loss is None: + return False + if cfg.checkpoint.patience <= 0: + return False + + def is_better(a, b): + return a > b if cfg.checkpoint.maximize_best_checkpoint_metric else a < b + + prev_best = getattr(should_stop_early, "best", None) + if prev_best is None or is_better(valid_loss, prev_best): + should_stop_early.best = valid_loss + should_stop_early.num_runs = 0 + return False + else: + should_stop_early.num_runs += 1 + if should_stop_early.num_runs >= cfg.checkpoint.patience: + logger.info( + "early stop since valid performance hasn't improved for last {} runs".format( + cfg.checkpoint.patience + ) + ) + return True + else: + return False + + +@metrics.aggregate("train") +def train( + cfg: DictConfig, trainer: Trainer, task: tasks.FairseqTask, epoch_itr +) -> Tuple[List[Optional[float]], bool]: + """Train the model for one epoch and return validation losses.""" + # Initialize data iterator + itr = epoch_itr.next_epoch_itr( + fix_batches_to_gpus=cfg.distributed_training.fix_batches_to_gpus, + shuffle=(epoch_itr.next_epoch_idx > cfg.dataset.curriculum), + ) + update_freq = ( + cfg.optimization.update_freq[epoch_itr.epoch - 1] + if epoch_itr.epoch <= len(cfg.optimization.update_freq) + else cfg.optimization.update_freq[-1] + ) + itr = iterators.GroupedIterator( + itr, + update_freq, + skip_remainder_batch=cfg.optimization.skip_remainder_batch, + ) + if cfg.common.tpu: + itr = utils.tpu_data_loader(itr) + progress = progress_bar.progress_bar( + itr, + log_format=cfg.common.log_format, + log_file=cfg.common.log_file, + log_interval=cfg.common.log_interval, + epoch=epoch_itr.epoch, + aim_repo=( + cfg.common.aim_repo + if distributed_utils.is_master(cfg.distributed_training) + else None + ), + aim_run_hash=( + cfg.common.aim_run_hash + if distributed_utils.is_master(cfg.distributed_training) + else None + ), + aim_param_checkpoint_dir=cfg.checkpoint.save_dir, + tensorboard_logdir=( + cfg.common.tensorboard_logdir + if distributed_utils.is_master(cfg.distributed_training) + else None + ), + default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"), + wandb_project=( + cfg.common.wandb_project + if distributed_utils.is_master(cfg.distributed_training) + else None + ), + wandb_run_name=os.environ.get( + "WANDB_NAME", os.path.basename(cfg.checkpoint.save_dir) + ), + azureml_logging=( + cfg.common.azureml_logging + if distributed_utils.is_master(cfg.distributed_training) + else False + ), + ) + progress.update_config(_flatten_config(cfg)) + + trainer.begin_epoch(epoch_itr.epoch) + + valid_subsets = cfg.dataset.valid_subset.split(",") + should_stop = False + num_updates = trainer.get_num_updates() + logger.info("Start iterating over samples") + for i, samples in enumerate(progress): + with metrics.aggregate("train_inner"), torch.autograd.profiler.record_function( + "train_step-%d" % i + ): + log_output = trainer.train_step(samples) + + if log_output is not None: # not OOM, overflow, ... 
+ # log mid-epoch stats + num_updates = trainer.get_num_updates() + if num_updates % cfg.common.log_interval == 0: + stats = get_training_stats(metrics.get_smoothed_values("train_inner")) + progress.log(stats, tag="train_inner", step=num_updates) + + # reset mid-epoch stats after each log interval + # the end-of-epoch stats will still be preserved + metrics.reset_meters("train_inner") + + end_of_epoch = not itr.has_next() + valid_losses, should_stop = validate_and_save( + cfg, trainer, task, epoch_itr, valid_subsets, end_of_epoch + ) + + if should_stop: + break + + # log end-of-epoch stats + logger.info("end of epoch {} (average epoch stats below)".format(epoch_itr.epoch)) + stats = get_training_stats(metrics.get_smoothed_values("train")) + progress.print(stats, tag="train", step=num_updates) + + # reset epoch-level meters + metrics.reset_meters("train") + return valid_losses, should_stop + + +def _flatten_config(cfg: DictConfig): + config = OmegaConf.to_container(cfg) + # remove any legacy Namespaces and replace with a single "args" + namespace = None + for k, v in list(config.items()): + if isinstance(v, argparse.Namespace): + namespace = v + del config[k] + if namespace is not None: + config["args"] = vars(namespace) + return config + + +def validate_and_save( + cfg: DictConfig, + trainer: Trainer, + task: tasks.FairseqTask, + epoch_itr, + valid_subsets: List[str], + end_of_epoch: bool, +) -> Tuple[List[Optional[float]], bool]: + num_updates = trainer.get_num_updates() + max_update = cfg.optimization.max_update or math.inf + + # Stopping conditions (and an additional one based on validation loss later + # on) + should_stop = False + if num_updates >= max_update: + should_stop = True + logger.info( + f"Stopping training due to " + f"num_updates: {num_updates} >= max_update: {max_update}" + ) + + training_time_hours = trainer.cumulative_training_time() / (60 * 60) + if ( + cfg.optimization.stop_time_hours > 0 + and training_time_hours > cfg.optimization.stop_time_hours + ): + should_stop = True + logger.info( + f"Stopping training due to " + f"cumulative_training_time: {training_time_hours} > " + f"stop_time_hours: {cfg.optimization.stop_time_hours} hour(s)" + ) + + do_save = ( + (end_of_epoch and epoch_itr.epoch % cfg.checkpoint.save_interval == 0) + or should_stop + or ( + cfg.checkpoint.save_interval_updates > 0 + and num_updates > 0 + and num_updates % cfg.checkpoint.save_interval_updates == 0 + and num_updates >= cfg.dataset.validate_after_updates + ) + ) + do_validate = ( + ( + (not end_of_epoch and do_save) # validate during mid-epoch saves + or (end_of_epoch and epoch_itr.epoch % cfg.dataset.validate_interval == 0) + or should_stop + or ( + cfg.dataset.validate_interval_updates > 0 + and num_updates > 0 + and num_updates % cfg.dataset.validate_interval_updates == 0 + ) + ) + and not cfg.dataset.disable_validation + and num_updates >= cfg.dataset.validate_after_updates + ) + + # Validate + valid_losses = [None] + if do_validate: + valid_losses = validate(cfg, trainer, task, epoch_itr, valid_subsets) + + should_stop |= should_stop_early(cfg, valid_losses[0]) + + # Save checkpoint + if do_save or should_stop: + cp_path = checkpoint_utils.save_checkpoint( + cfg.checkpoint, trainer, epoch_itr, valid_losses[0] + ) + if cp_path is not None and hasattr(task, "post_save"): + task.post_save(cp_path, num_updates) + + return valid_losses, should_stop + + +def get_training_stats(stats: Dict[str, Any]) -> Dict[str, Any]: + stats["wall"] = round(metrics.get_meter("default", "wall").elapsed_time, 
0) + return stats + + +def validate( + cfg: DictConfig, + trainer: Trainer, + task: tasks.FairseqTask, + epoch_itr, + subsets: List[str], +) -> List[Optional[float]]: + """Evaluate the model on the validation set(s) and return the losses.""" + + if cfg.dataset.fixed_validation_seed is not None: + # set fixed seed for every validation + utils.set_torch_seed(cfg.dataset.fixed_validation_seed) + + trainer.begin_valid_epoch(epoch_itr.epoch) + valid_losses = [] + for subset_idx, subset in enumerate(subsets): + logger.info('begin validation on "{}" subset'.format(subset)) + + # Initialize data iterator + itr = trainer.get_valid_iterator(subset).next_epoch_itr( + shuffle=False, set_dataset_epoch=False # use a fixed valid set + ) + if cfg.common.tpu: + itr = utils.tpu_data_loader(itr) + progress = progress_bar.progress_bar( + itr, + log_format=cfg.common.log_format, + log_interval=cfg.common.log_interval, + epoch=epoch_itr.epoch, + prefix=f"valid on '{subset}' subset", + aim_repo=( + cfg.common.aim_repo + if distributed_utils.is_master(cfg.distributed_training) + else None + ), + aim_run_hash=( + cfg.common.aim_run_hash + if distributed_utils.is_master(cfg.distributed_training) + else None + ), + aim_param_checkpoint_dir=cfg.checkpoint.save_dir, + tensorboard_logdir=( + cfg.common.tensorboard_logdir + if distributed_utils.is_master(cfg.distributed_training) + else None + ), + default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"), + wandb_project=( + cfg.common.wandb_project + if distributed_utils.is_master(cfg.distributed_training) + else None + ), + wandb_run_name=os.environ.get( + "WANDB_NAME", os.path.basename(cfg.checkpoint.save_dir) + ), + ) + + # create a new root metrics aggregator so validation metrics + # don't pollute other aggregators (e.g., train meters) + with metrics.aggregate(new_root=True) as agg: + for i, sample in enumerate(progress): + if ( + cfg.dataset.max_valid_steps is not None + and i > cfg.dataset.max_valid_steps + ): + break + trainer.valid_step(sample) + + # log validation stats + # only tracking the best metric on the 1st validation subset + tracking_best = subset_idx == 0 + stats = get_valid_stats(cfg, trainer, agg.get_smoothed_values(), tracking_best) + + if hasattr(task, "post_validate"): + task.post_validate(trainer.get_model(), stats, agg) + + progress.print(stats, tag=subset, step=trainer.get_num_updates()) + + valid_losses.append(stats[cfg.checkpoint.best_checkpoint_metric]) + return valid_losses + + +def get_valid_stats( + cfg: DictConfig, + trainer: Trainer, + stats: Dict[str, Any], + tracking_best: bool, +) -> Dict[str, Any]: + stats["num_updates"] = trainer.get_num_updates() + if tracking_best and hasattr(checkpoint_utils.save_checkpoint, "best"): + key = "best_{0}".format(cfg.checkpoint.best_checkpoint_metric) + best_function = max if cfg.checkpoint.maximize_best_checkpoint_metric else min + stats[key] = best_function( + checkpoint_utils.save_checkpoint.best, + stats[cfg.checkpoint.best_checkpoint_metric], + ) + return stats + + +def cli_main( + modify_parser: Optional[Callable[[argparse.ArgumentParser], None]] = None +) -> None: + # print(f"get parser") + parser = options.get_training_parser() + # print(f"parser: {parser}") + args = options.parse_args_and_arch(parser, modify_parser=modify_parser) + logger.info(f"args: {args}") + + cfg = convert_namespace_to_omegaconf(args) + + if cfg.common.use_plasma_view: + server = PlasmaStore(path=cfg.common.plasma_path) + logger.info( + f"Started plasma server pid {server.server.pid} 
{cfg.common.plasma_path}" + ) + + if cfg.common.profile: + with torch.cuda.profiler.profile(): + with torch.autograd.profiler.emit_nvtx(): + distributed_utils.call_main(cfg, main) + else: + distributed_utils.call_main(cfg, main) + + # if cfg.common.use_plasma_view: + # server.server.kill() + + +if __name__ == "__main__": + cli_main() diff --git a/fairseq_cli/traincomp.py b/fairseq_cli/traincomp.py index 9b274db648..b12a92134f 100644 --- a/fairseq_cli/traincomp.py +++ b/fairseq_cli/traincomp.py @@ -13,6 +13,11 @@ import os import sys from typing import Any, Callable, Dict, List, Optional, Tuple +from fairseq.file_io import PathManager +from fairseq.dataclass.configs import CheckpointConfig +import logging +import ast +import collections # We need to setup root logger before importing any fairseq libraries. logging.basicConfig( @@ -21,10 +26,12 @@ level=os.environ.get("LOGLEVEL", "INFO").upper(), stream=sys.stdout, ) -logger = logging.getLogger("fairseq_cli.train") +logger = logging.getLogger("fairseq_cli.traincomp") import numpy as np import torch +import torch.multiprocessing +torch.multiprocessing.set_sharing_strategy('file_system') from omegaconf import DictConfig, OmegaConf from fairseq import checkpoint_utils, options, quantization_utils, tasks, utils @@ -44,10 +51,13 @@ def main(cfg: FairseqConfig) -> None: if isinstance(cfg, argparse.Namespace): + print(f"convert namespace") cfg = convert_namespace_to_omegaconf(cfg) utils.import_user_module(cfg.common) + print(f"added user module") add_defaults(cfg) + print(f"added defaults") if ( distributed_utils.is_master(cfg.distributed_training) @@ -86,20 +96,25 @@ def main(cfg: FairseqConfig) -> None: # Setup task, e.g., translation, language modeling, etc. task = tasks.setup_task(cfg.task) + logger.info(f"model.arch: {cfg.model.arch}") + # cfg.task.src_dict.add_symbol("") + # cfg.task.tgt_dict.add_symbol("") assert cfg.criterion, "Please specify criterion to train a model" # Build model and criterion if cfg.distributed_training.ddp_backend == "fully_sharded": with fsdp_enable_wrap(cfg.distributed_training): - model = fsdp_wrap(task.build_model(cfg.model)) + model = fsdp_wrap(task.build_model(cfg)) # .model else: - model = task.build_model(cfg.model) + model = task.build_model(cfg) # .model criterion = task.build_criterion(cfg.criterion) + # generator = task.build_generator(cfg.generation) logger.info(model) logger.info("task: {}".format(task.__class__.__name__)) logger.info("model: {}".format(model.__class__.__name__)) logger.info("criterion: {}".format(criterion.__class__.__name__)) + # logger.info("generator: {}".format(generator.__class__.__name__)) logger.info( "num. shared model params: {:,} (num. 
trained: {:,})".format( sum( @@ -146,9 +161,12 @@ def main(cfg: FairseqConfig) -> None: # Build trainer if cfg.common.model_parallel_size == 1: + logger.info("trainer") trainer = Trainer(cfg, task, model, criterion, quantizer) else: + logger.info("MegatronTrainer") trainer = MegatronTrainer(cfg, task, model, criterion) + logger.info( "training on {} devices (GPUs/TPUs)".format( cfg.distributed_training.distributed_world_size @@ -163,18 +181,21 @@ def main(cfg: FairseqConfig) -> None: # Load the latest checkpoint if one is available and restore the # corresponding train iterator - extra_state, epoch_itr = checkpoint_utils.load_checkpoint( - cfg.checkpoint, - trainer, - # don't cache epoch iterators for sharded datasets - disable_iterator_cache=task.has_sharded_data("train"), - ) + # extra_state, epoch_itr = checkpoint_utils.load_checkpoint( + # cfg.checkpoint, + # trainer, + # # don't cache epoch iterators for sharded datasets + # disable_iterator_cache=task.has_sharded_data("train"), + # ) + extra_state, epoch = load_checkpoint(cfg, trainer) if cfg.common.tpu: import torch_xla.core.xla_model as xm xm.rendezvous("load_checkpoint") # wait for all workers + # Train until the learning rate gets too small max_epoch = cfg.optimization.max_epoch or math.inf + # max_update = cfg.optimization.max_update or math.inf lr = trainer.get_lr() # TODO: a dry run on validation set to pin the memory @@ -194,11 +215,10 @@ def main(cfg: FairseqConfig) -> None: train_meter = meters.StopwatchMeter() train_meter.start() - if cfg.comparable.comparable: comp = Comparable(model, trainer, task, cfg) - while epoch_itr.next_epoch_idx <= max_epoch: + while epoch <= max_epoch: # _itr.next_epoch_idx if lr <= cfg.optimization.stop_min_lr: logger.info( f"stopping training because current learning rate ({lr}) is smaller " @@ -208,13 +228,16 @@ def main(cfg: FairseqConfig) -> None: break # train for one epoch - comp.task.begin_epoch(epoch_itr.next_epoch_idx, comp.trainer.get_model()) + print(f"begin epoch") + comp.task.begin_epoch(epoch, comp.trainer.get_model()) + # _itr.next_epoch_idx # valid_losses, should_stop = train(cfg, trainer, task, epoch_itr) # if should_stop: # break - print(f"epoch_itr.next_epoch_id: {epoch_itr.next_epoch_id}") - print(f"epoch_itr.epoch: {epoch_itr.epoch}") - num_updates, end_of_epoch = comp.extract_and_train(cfg.comparable.comparable_data, epoch_itr.next_epoch_idx) + # print(f"epoch_itr.next_epoch_id: {epoch_itr.next_epoch_id}") + # print(f"epoch_itr.epoch: {epoch_itr.epoch}") + # Extract parallel data and train + num_updates, end_of_epoch = comp.extract_and_train(cfg.comparable.comparable_data, epoch) #_itr.next_epoch_idx max_update = cfg.optimization.max_update or math.inf should_stop = False @@ -238,7 +261,7 @@ def main(cfg: FairseqConfig) -> None: ) do_save = ( - (end_of_epoch and epoch_itr.epoch % cfg.checkpoint.save_interval == 0) + (end_of_epoch and epoch % cfg.checkpoint.save_interval == 0) or should_stop or ( cfg.checkpoint.save_interval_updates > 0 @@ -250,7 +273,7 @@ def main(cfg: FairseqConfig) -> None: do_validate = ( ( (not end_of_epoch and do_save) # validate during mid-epoch saves - or (end_of_epoch and epoch_itr.epoch % cfg.dataset.validate_interval == 0) + or (end_of_epoch and epoch % cfg.dataset.validate_interval == 0) or should_stop or ( cfg.dataset.validate_interval_updates > 0 @@ -261,10 +284,12 @@ def main(cfg: FairseqConfig) -> None: and not cfg.dataset.disable_validation and num_updates >= cfg.dataset.validate_after_updates ) + # epoch_itr. 
# Validate valid_losses = [None] if do_validate: - valid_losses = comp.validate(epoch_itr.next_epoch_idx, valid_subsets) + valid_losses = comp.validate(epoch, valid_subsets) + # _itr.next_epoch_idx # if (not cfg.dataset.disable_validation # and cfg.checkpoint.save_interval_updates > 0 # and num_updates % cfg.checkpoint.save_interval_updates == 0 @@ -278,8 +303,8 @@ def main(cfg: FairseqConfig) -> None: # Save checkpoint if do_save or should_stop: - cp_path = checkpoint_utils.save_checkpoint( - cfg.checkpoint, trainer, epoch_itr, valid_losses[0] + cp_path = save_checkpoint( + cfg.checkpoint, trainer, epoch, valid_losses[0] ) if cp_path is not None and hasattr(task, "post_save"): task.post_save(cp_path, num_updates) @@ -288,15 +313,16 @@ def main(cfg: FairseqConfig) -> None: break # only use first validation loss to update the learning rate - lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0]) - - epoch_itr = trainer.get_train_iterator( - epoch_itr.next_epoch_idx, - # sharded data: get train iterator for next epoch - load_dataset=task.has_sharded_data("train"), - # don't cache epoch iterators for sharded datasets - disable_iterator_cache=task.has_sharded_data("train"), - ) + lr = trainer.lr_step(epoch, valid_losses[0]) + epoch += 1 + + # epoch_itr = trainer.get_train_iterator( + # epoch, + # # sharded data: get train iterator for next epoch + # load_dataset=task.has_sharded_data("train"), + # # don't cache epoch iterators for sharded datasets + # disable_iterator_cache=task.has_sharded_data("train"), + # ) train_meter.stop() logger.info("done training in {:.1f} seconds".format(train_meter.sum)) @@ -309,6 +335,272 @@ def main(cfg: FairseqConfig) -> None: PathManager.async_close() logger.info("ioPath PathManager finished waiting.") +def load_checkpoint(cfg, trainer, **passthrough_args): + """ + Load a checkpoint and restore the training iterator. + + *passthrough_args* will be passed through to + ``trainer.get_train_iterator``. + """ + # only one worker should attempt to create the required dir + reset_optimizer = cfg.checkpoint.reset_optimizer + reset_lr_scheduler = cfg.checkpoint.reset_lr_scheduler + # print(f"cfg.optimizer_overrides: {cfg.optimizer_overrides}") + optimizer_overrides = ast.literal_eval(cfg.checkpoint.optimizer_overrides) + reset_meters = cfg.checkpoint.reset_meters + reset_dataloader = cfg.checkpoint.reset_dataloader + + if cfg.distributed_training.distributed_rank == 0: + print(f"cfg.checkpoint.save_dir: {cfg.checkpoint.save_dir}") + os.makedirs(cfg.checkpoint.save_dir, exist_ok=True) + + if cfg.checkpoint.finetune_from_model is not None and ( + reset_optimizer or reset_lr_scheduler or reset_meters or reset_dataloader + ): + raise ValueError( + "--finetune-from-model can not be set together with either --reset-optimizer" + " or reset_lr_scheduler or reset_meters or reset_dataloader" + ) + suffix = trainer.checkpoint_suffix + + if cfg.checkpoint.restore_file == "checkpoint_last.pt": + checkpoint_path = os.path.join(cfg.checkpoint.save_dir, "checkpoint_last{}.pt".format(suffix)) + first_launch = not PathManager.exists(checkpoint_path) + if first_launch and getattr(cfg.checkpoint, "continue_once", None) is not None: + checkpoint_path = cfg.checkpoint.continue_once + elif cfg.checkpoint.finetune_from_model is not None and first_launch: + # if there is no last checkpoint to restore, start the finetune from pretrained model + # else just use usual logic to load checkpoint, e.g. restart from last checkpoint and etc. 
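+            # added comment: on a first launch with --finetune-from-model, start from the pretrained weights and reset optimizer, lr scheduler, meters and dataloader state (see the resets just below)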
+ if PathManager.exists(cfg.checkpoint.finetune_from_model): + checkpoint_path = cfg.checkpoint.finetune_from_model + reset_optimizer = True + reset_lr_scheduler = True + reset_meters = True + reset_dataloader = True + logger.info( + f"loading pretrained model from {checkpoint_path}: " + "optimizer, lr scheduler, meters, dataloader will be reset" + ) + else: + raise ValueError( + f"--finetune-from-model {cfg.finetune_from_model} does not exist" + ) + elif suffix is not None: + checkpoint_path = cfg.checkpoint.restore_file.replace(".pt", suffix + ".pt") + else: + checkpoint_path = os.path.join(cfg.checkpoint.save_dir, cfg.checkpoint.restore_file) + + if cfg.checkpoint.restore_file != "checkpoint_last.pt" and cfg.checkpoint.finetune_from_model: + raise ValueError( + "--finetune-from-model and --restore-file (non-default value) " + "can not be specified together: " + str(cfg) + ) + + extra_state = trainer.load_checkpoint( + checkpoint_path, + reset_optimizer, + reset_lr_scheduler, + optimizer_overrides, + reset_meters=reset_meters, + ) + + # if ( + # extra_state is not None + # and "best" in extra_state + # and not args.reset_optimizer + # and not args.reset_meters + # ): + # save_checkpoint.best = extra_state["best"] + + if ( + extra_state is not None + and "best" in extra_state + and not reset_optimizer + and not reset_meters + ): + save_checkpoint.best = extra_state["best"] + + if extra_state is not None and not reset_dataloader: + # restore iterator from checkpoint + itr_state = extra_state["train_iterator"] + # epoch_itr = trainer.get_train_iterator( + # epoch=itr_state["epoch"], load_dataset=False, **passthrough_args + # ) + epoch = extra_state["train_iterator"]["epoch"] + 1 + # epoch_itr.load_state_dict(itr_state) + else: + epoch = 1 + # epoch_itr = trainer.get_train_iterator( + # epoch=1, load_dataset=False, **passthrough_args + # ) + + trainer.lr_step(epoch) + # trainer.lr_step(epoch_itr.epoch) + + return extra_state, epoch + # return extra_state, epoch_itr + + +def save_checkpoint(cfg: CheckpointConfig, trainer, epoch, val_loss): + from fairseq import meters + + # only one worker should attempt to create the required dir + if trainer.data_parallel_rank == 0: + os.makedirs(cfg.save_dir, exist_ok=True) + + prev_best = getattr(save_checkpoint, "best", val_loss) + if val_loss is not None: + best_function = max if cfg.maximize_best_checkpoint_metric else min + save_checkpoint.best = best_function(val_loss, prev_best) + + if cfg.no_save: + return None + + trainer.consolidate_optimizer() # TODO(SS): do we need this if no_save_optimizer_state + + if not trainer.should_save_checkpoint_on_current_rank: + if trainer.always_call_state_dict_during_save_checkpoint: + trainer.state_dict() + return None + + write_timer = meters.StopwatchMeter() + write_timer.start() + + # epoch = epoch_itr.epoch + # end_of_epoch = epoch_itr.end_of_epoch() + updates = trainer.get_num_updates() + + logger.info(f"Preparing to save checkpoint for epoch {epoch} @ {updates} updates") + + def is_better(a, b): + return a >= b if cfg.maximize_best_checkpoint_metric else a <= b + + suffix = trainer.checkpoint_suffix + checkpoint_conds = collections.OrderedDict() + checkpoint_conds["checkpoint{}{}.pt".format(epoch, suffix)] = ( + # end_of_epoch and + not cfg.no_epoch_checkpoints and epoch % cfg.save_interval == 0 + ) + checkpoint_conds["checkpoint_{}_{}{}.pt".format(epoch, updates, suffix)] = ( + # not end_of_epoch and + cfg.save_interval_updates > 0 + and updates % cfg.save_interval_updates == 0 + ) + 
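+    # added comment: checkpoint_best{suffix}.pt is (re)written whenever the current validation score matches or beats the best seen so far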
checkpoint_conds["checkpoint_best{}.pt".format(suffix)] = val_loss is not None and ( + not hasattr(save_checkpoint, "best") + or is_better(val_loss, save_checkpoint.best) + ) + if val_loss is not None and cfg.keep_best_checkpoints > 0: + worst_best = getattr(save_checkpoint, "best", None) + chkpts = checkpoint_utils.checkpoint_paths( + cfg.save_dir, + pattern=r"checkpoint\.best_{}_(\d+\.?\d*){}\.pt".format( + cfg.best_checkpoint_metric, suffix + ), + ) + if len(chkpts) > 0: + p = chkpts[-1] if cfg.maximize_best_checkpoint_metric else chkpts[0] + worst_best = float(p.rsplit("_")[-1].replace("{}.pt".format(suffix), "")) + # add random digits to resolve ties + with data_utils.numpy_seed(epoch, updates, val_loss): + rand_sfx = np.random.randint(0, cfg.keep_best_checkpoints) + + checkpoint_conds[ + "checkpoint.best_{}_{:.3f}{}{}.pt".format( + cfg.best_checkpoint_metric, val_loss, rand_sfx, suffix + ) + ] = worst_best is None or is_better(val_loss, worst_best) + checkpoint_conds[ + "checkpoint_last{}.pt".format(suffix) + ] = not cfg.no_last_checkpoints + + extra_state = {"train_iterator": {"epoch": epoch}, "val_loss": val_loss} + if hasattr(save_checkpoint, "best"): + extra_state.update({"best": save_checkpoint.best}) + + checkpoints = [ + os.path.join(cfg.save_dir, fn) for fn, cond in checkpoint_conds.items() if cond + ] + saved_cp = None + if len(checkpoints) > 0 and trainer.should_save_checkpoint_on_current_rank: + saved_cp = trainer.save_checkpoint(checkpoints[0], extra_state) + for cp in checkpoints[1:]: + if cfg.write_checkpoints_asynchronously: + # TODO[ioPath]: Need to implement a delayed asynchronous + # file copying/moving feature. + logger.warning( + f"ioPath is not copying {checkpoints[0]} to {cp} " + "since async write mode is on." + ) + else: + assert PathManager.copy( + checkpoints[0], cp, overwrite=True + ), f"Failed to copy {checkpoints[0]} to {cp}" + + write_timer.stop() + logger.info( + "Saved checkpoint {} (epoch {} @ {} updates, score {}) (writing took {} seconds)".format( + checkpoints[0], epoch, updates, val_loss, write_timer.sum + ) + ) + + # if ( + # # not end_of_epoch and + # cfg.keep_interval_updates > 0 + # and trainer.should_save_checkpoint_on_current_rank + # ): + # # remove old checkpoints; checkpoints are sorted in descending order + # if cfg.keep_interval_updates_pattern == -1: + # checkpoints = checkpoint_paths( + # cfg.save_dir, pattern=r"checkpoint_\d+_(\d+){}\.pt".format(suffix) + # ) + # else: + # checkpoints = checkpoint_paths( + # cfg.save_dir, + # pattern=r"checkpoint_\d+_(\d+){}\.pt".format(suffix), + # keep_match=True, + # ) + # checkpoints = [ + # x[0] + # for x in checkpoints + # if x[1] % cfg.keep_interval_updates_pattern != 0 + # ] + + # for old_chk in checkpoints[cfg.keep_interval_updates :]: + # if os.path.lexists(old_chk): + # os.remove(old_chk) + # elif PathManager.exists(old_chk): + # PathManager.rm(old_chk) + + if cfg.keep_last_epochs > 0 and trainer.should_save_checkpoint_on_current_rank: + # remove old epoch checkpoints; checkpoints are sorted in descending order + checkpoints = checkpoint_utils.checkpoint_paths( + cfg.save_dir, pattern=r"checkpoint(\d+){}\.pt".format(suffix) + ) + for old_chk in checkpoints[cfg.keep_last_epochs :]: + if os.path.lexists(old_chk): + os.remove(old_chk) + elif PathManager.exists(old_chk): + PathManager.rm(old_chk) + + if cfg.keep_best_checkpoints > 0 and trainer.should_save_checkpoint_on_current_rank: + # only keep the best N checkpoints according to validation metric + checkpoints = 
checkpoint_utils.checkpoint_paths( + cfg.save_dir, + pattern=r"checkpoint\.best_{}_(\d+\.?\d*){}\.pt".format( + cfg.best_checkpoint_metric, suffix + ), + ) + if not cfg.maximize_best_checkpoint_metric: + checkpoints = checkpoints[::-1] + for old_chk in checkpoints[cfg.keep_best_checkpoints :]: + if os.path.lexists(old_chk): + os.remove(old_chk) + elif PathManager.exists(old_chk): + PathManager.rm(old_chk) + + return saved_cp + def should_stop_early(cfg: DictConfig, valid_loss: float) -> bool: # skip check if no validation was done in the current epoch @@ -636,8 +928,11 @@ def get_valid_stats( def cli_main( modify_parser: Optional[Callable[[argparse.ArgumentParser], None]] = None ) -> None: + print(f"get parser") parser = options.get_training_parser() + print(f"parser: {parser}") args = options.parse_args_and_arch(parser, modify_parser=modify_parser) + print(f"args: {args}") cfg = convert_namespace_to_omegaconf(args) @@ -647,7 +942,7 @@ def cli_main( f"Started plasma server pid {server.server.pid} {cfg.common.plasma_path}" ) - if args.profile: + if cfg.common.profile: with torch.cuda.profiler.profile(): with torch.autograd.profiler.emit_nvtx(): distributed_utils.call_main(cfg, main) diff --git a/scripts/clean_hansard.py b/scripts/clean_hansard.py new file mode 100644 index 0000000000..ccf211f758 --- /dev/null +++ b/scripts/clean_hansard.py @@ -0,0 +1,28 @@ +import argparse + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='extract comparable corpus from Haifa Hansard') + parser.add_argument("--out", default="/netscratch/jalota/datasets/haifa-hansard/") + parser.add_argument("--inp", default="/netscratch/jalota/datasets/haifa-hansard/train/uniq_og") + parser.add_argument("--split", default='train') + parser.add_argument("--name", default="pp_uniq_og", help="without extension") + args = parser.parse_args() + clean = False # for MOTRA True for Hansard + count = 0 + + with open(args.inp) as f: + # with open(f"{args.out}{args.split}/{args.name}", "w") as fo: + with open(args.out, "w") as fo: + for line in f.readlines(): + wds = line.split() + if clean: + if len(wds) > 3 and len(wds) < 505: + fo.write(f"{line}") + else: + if len(wds) > 505: + count +=1 + print(len(wds)) + else: + fo.write(f"{line}") + + print(f"{count} lines have length > 505:") diff --git a/scripts/detokenize_bpe.py b/scripts/detokenize_bpe.py new file mode 100644 index 0000000000..c4c89344e5 --- /dev/null +++ b/scripts/detokenize_bpe.py @@ -0,0 +1,17 @@ +import argparse + +def detokenize_bpe_string(bpe_string): + return bpe_string.replace("@@ ", "") + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='extract comparable corpus from Haifa Hansard') + parser.add_argument("--out", default="/netscratch/jalota/datasets/haifa-hansard/test/") + parser.add_argument("--inp", default="/netscratch/jalota/datasets/haifa-hansard/fairseq-pp/unsup_setup/filtered/test.tr-og.og") + parser.add_argument("--name", default="new_filtered_og", help="without extension") + args = parser.parse_args() + + with open(args.inp) as f: + with open(f"{args.out}/{args.name}", "w") as fo: + # with open(args.out, "w") as fo: + for line in f.readlines(): + fo.write(f"{detokenize_bpe_string(line)}") \ No newline at end of file diff --git a/scripts/finetuning_requirements.txt b/scripts/finetuning_requirements.txt new file mode 100644 index 0000000000..0a652a9a79 --- /dev/null +++ b/scripts/finetuning_requirements.txt @@ -0,0 +1,7 @@ +accelerate >= 0.12.0 +torch >= 1.3 +datasets >= 1.8.0 +sentencepiece != 0.1.92 
+protobuf +evaluate +scikit-learn \ No newline at end of file diff --git a/scripts/get_balanced_data.py b/scripts/get_balanced_data.py new file mode 100644 index 0000000000..587dd49881 --- /dev/null +++ b/scripts/get_balanced_data.py @@ -0,0 +1,86 @@ +import argparse +import random +from pathlib import Path + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument("--bpe", default="/netscratch/jalota/datasets/haifa-hansard/test/bpe-12k/") + parser.add_argument("--txt", default="/netscratch/jalota/datasets/haifa-hansard/test/") + parser.add_argument("--nsamples", default=4068, help="should be equal to the num samples in translated.txt", type=int) + args = parser.parse_args() + for txt_file in Path(args.txt).glob('*.txt'): + txt_file = str(txt_file) + # if 'bt_original' in txt_file: + # with open(txt_file) as f: + # bt_txt = f.readlines() + # elif 'original' in txt_file: + # with open(txt_file) as f: + # og_txt = f.readlines() + if 'translated' in txt_file: + with open(txt_file) as f: + tr_txt = f.readlines() + else: + continue + for txt_file in Path(args.bpe).glob('*.bpe'): + txt_file = str(txt_file) + # if 'bt_original' in txt_file: + # with open(txt_file) as f: + # bt_bpe = f.readlines() + # elif 'original' in txt_file: + # with open(txt_file) as f: + # og_bpe = f.readlines() + if 'translated' in txt_file: + with open(txt_file) as f: + tr_bpe = f.readlines() + else: + continue + + # print(f"len(og_bpe): {len(og_bpe)}") + # print(f"len(og_txt): {len(og_txt)}") + # print(f"len(bt_bpe): {len(bt_bpe)}") + # print(f"len(bt_txt): {len(bt_txt)}") + print(f"len(tr_txt): {len(tr_txt)}") + print(f"len(tr_bpe): {len(tr_bpe)}") + + # r = [random.randint(0, len(og_bpe)-1) for _ in range(args.nsamples)] + r = [random.randint(0, len(tr_bpe)-1) for _ in range(args.nsamples)] + # og_txt = [og_txt[i] for i in r] + # og_bpe = [og_bpe[i] for i in r] + # bt_bpe = [bt_bpe[i] for i in r] + # bt_txt = [bt_txt[i] for i in r] + tr_bpe = [tr_bpe[i] for i in r] + tr_txt = [tr_txt[i] for i in r] + # print(len(og_bpe)) + # print(og_bpe[0]) + # print(bt_bpe[0]) + print(tr_bpe[0]) + print(tr_txt[0]) + + # with open(f"{args.txt}bt_bal.txt", "w") as wf: + # for line in bt_txt: + # wf.write(line) + + # with open(f"{args.txt}og_bal.txt", "w") as wf: + # for line in og_txt: + # wf.write(line) + + # with open(f"{args.bpe}bt_bal.bpe", "w") as wf: + # for line in bt_bpe: + # wf.write(line) + + # with open(f"{args.bpe}og_bal.bpe", "w") as wf: + # for line in og_bpe: + # wf.write(line) + + with open(f"{args.txt}tr_bal.txt", "w") as wf: + for line in tr_txt: + wf.write(line) + + with open(f"{args.bpe}tr_bal.bpe", "w") as wf: + for line in tr_bpe: + wf.write(line) + + + + + diff --git a/scripts/get_bt_data.py b/scripts/get_bt_data.py new file mode 100644 index 0000000000..a32c9a110d --- /dev/null +++ b/scripts/get_bt_data.py @@ -0,0 +1,48 @@ +from easynmt import EasyNMT +m2pivot = EasyNMT('m2m_100_418M') # mbart50_en2m +pivot2m = EasyNMT('m2m_100_418M') # mbart50_m2en + +# process_pool = en2es.start_multi_process_pool(['cuda', 'cuda']) +# pp = es2en.start_multi_process_pool(['cuda', 'cuda']) +# /netscratch/jalota/datasets/haifa-hansard/dev/original.txt +with open("/netscratch/jalota/datasets/motra-preprocessed/de_all/train/og_bal.tok") as f: + lines = f.readlines() + # es_translations = en2es.translate(lines, target_lang='es') + count = 0 + es_trans = [] + with open("/netscratch/jalota/datasets/motra-preprocessed/de_all/train/de2en.txt", "w") as fo: + for translation in m2pivot.translate_stream(lines, 
show_progress_bar=True, chunk_size=80, source_lang='de', target_lang='en'): + count +=1 + print(count) + # es_trans.append(translation) + fo.write(translation) + # fo.write("\n") + #Do some warm-up + # en2es.translate_multi_process(process_pool, lines[:100], source_lang='en', target_lang='es', show_progress_bar=False) + # es_translations_multi_p = en2es.translate_multi_process(process_pool, lines, source_lang='en', target_lang='es', show_progress_bar=True) + # en2es.stop_multi_process_pool(process_pool) + + #Do some warm-up + # es2en.translate_multi_process(pp, es_translations_multi_p[:100], source_lang='es', target_lang='en', show_progress_bar=False) + # en_translations_multi_p = es2en.translate_multi_process(pp, es_translations_multi_p, source_lang='es', target_lang='en', show_progress_bar=True) + # es2en.stop_multi_process_pool(pp) + # en_trans = [] +with open("/netscratch/jalota/datasets/motra-preprocessed/de_all/train/de2en.txt") as f: + lines = f.readlines() + count = 0 + with open("/netscratch/jalota/datasets/motra-preprocessed/de_all/train/bt_de.txt", "w") as foo: + for translation in pivot2m.translate_stream(lines, show_progress_bar=True, chunk_size=80, source_lang='en', target_lang='de'): + count +=1 + print(count) + foo.write(translation) + # foo.write("\n") + # en_trans.append(translation) + # en_translations = es2en.translate(es_trans, target_lang='en') + +# --- Remove blank lines --- +# awk -i inplace NF bt_og.txt + +# with open("/netscratch/jalota/datasets/motra-preprocessed/en_es/test/bt_og.txt", "w") as f: +# for tr in en_trans: +# f.write(tr) +# f.write("\n") \ No newline at end of file diff --git a/scripts/get_bt_fairseq.py b/scripts/get_bt_fairseq.py new file mode 100644 index 0000000000..71e9e01e71 --- /dev/null +++ b/scripts/get_bt_fairseq.py @@ -0,0 +1,26 @@ +import torch + +de2en = torch.hub.load('pytorch/fairseq', 'transformer.wmt19.de-en.single_model', tokenizer='moses', bpe='fastbpe') + +en2de = torch.hub.load('pytorch/fairseq', 'transformer.wmt19.en-de.single_model', + tokenizer='moses', bpe='fastbpe') + +de2en.eval() +en2de.eval() + +en2de.cuda() +de2en.cuda() + +with open("/netscratch/jalota/datasets/motra-preprocessed/en_de/train/original.txt") as f: + lines = f.readlines() + out = [] + de = en2de.translate(lines) + print("done de translations") + bt = de2en.translate(de) + print("got round-trip translations") + # out.append(bt) + + with open("/netscratch/jalota/datasets/motra-preprocessed/en_de/train/bt_original.txt", "w") as foo: + for line in bt: + foo.write(line) + foo.write("\n") \ No newline at end of file diff --git a/scripts/get_comparable_hansard.py b/scripts/get_comparable_hansard.py new file mode 100644 index 0000000000..8c9af97a88 --- /dev/null +++ b/scripts/get_comparable_hansard.py @@ -0,0 +1,48 @@ +import pandas as pd +from pathlib import Path +import argparse + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='extract comparable corpus from Haifa Hansard') + parser.add_argument("--dir", default="/ds/text/corpora_translationese_research_rabinovich/hansard.EN-FR/committees/") + parser.add_argument("--out", default="/netscratch/jalota/datasets/haifa-hansard/") + #parser.add_argument("--fname", default="snli_train", help="without extension") + args = parser.parse_args() + + Path(args.out).mkdir(parents=True, exist_ok=True) + + for path in Path(args.dir).glob('devtest.*'): + print(path) + path = str(path) + + og_ids = set() + tr_ids = set() + names = ['dev1', 'dev2', 'test1', 'test2'] + + for name in names: + og = 
open(args.out+name+"_original2.txt", 'w') + tr = open(args.out+name+"_translated_fr2.txt", 'w') + + with open(path+"/"+name+".id") as idf: + ids = idf.readlines() + for i, line in enumerate(ids): + if 'LANGUAGE="EN"' in line: + og_ids.add(i) + else: + tr_ids.add(i) + print(len(tr_ids)) + print(len(og_ids)) + + + with open(path+"/"+name+".en.tok") as f: + for i, line in enumerate(f): + if i in og_ids: + og.write(line) + else: + tr.write(line) + + og.close() + tr.close() + + + diff --git a/scripts/gpt2_accelerate_finetuning.py b/scripts/gpt2_accelerate_finetuning.py new file mode 100644 index 0000000000..a9b0e26ee7 --- /dev/null +++ b/scripts/gpt2_accelerate_finetuning.py @@ -0,0 +1,89 @@ +import os +from accelerate import infer_auto_device_map, init_empty_weights +os.environ['TRANSFORMERS_CACHE'] = '/netscratch/jalota/hf-cache/' +from datasets import load_dataset +datasets = load_dataset("text", data_files={"train": "/netscratch/jalota/datasets/haifa-hansard/train/original.txt", "validation":"/netscratch/jalota/datasets/haifa-hansard/dev/original.txt"}) + +# print(datasets["train"][0]) +# print(datasets["train"][200]) +# print(datasets["train"][10]) +# print(datasets["train"][1000]) +# print(datasets["train"][50]) +# print(datasets["train"][100]) + +model_checkpoint = "gpt2" +from transformers import AutoTokenizer +tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_fast=True) + +def tokenize_function(examples): + return tokenizer(examples["text"]) + +tokenized_datasets = datasets.map(tokenize_function, batched=True, num_proc=4, remove_columns=["text"]) + +# print(tokenized_datasets["train"][1]) +block_size = 512 + +def group_texts(examples): + # Concatenate all texts. + concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()} + total_length = len(concatenated_examples[list(examples.keys())[0]]) + # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can + # customize this part to your needs. + total_length = (total_length // block_size) * block_size + # Split by chunks of max_len. 
+ result = { + k: [t[i : i + block_size] for i in range(0, total_length, block_size)] + for k, t in concatenated_examples.items() + } + result["labels"] = result["input_ids"].copy() + return result + +lm_datasets = tokenized_datasets.map( + group_texts, + batched=True, + batch_size=8000, + num_proc=12, +) + +# print(tokenizer.decode(lm_datasets["train"][1]["input_ids"])) + +from transformers import AutoConfig, AutoModelForCausalLM +import torch + +config = AutoConfig.from_pretrained(model_checkpoint) +with init_empty_weights(): + model = AutoModelForCausalLM.from_config(config) + +model = AutoModelForCausalLM.from_pretrained(model_checkpoint, device_map="auto", offload_folder="offload", offload_state_dict = True, torch_dtype=torch.float16) + +device_map = infer_auto_device_map(model) + +from transformers import Trainer, TrainingArguments +model_name = model_checkpoint.split("/")[-1] + +training_args = TrainingArguments( + f"/netscratch/jalota/checkpoints/{model_name}-finetuned-CanadianHansardOriginals", + evaluation_strategy="epoch", + learning_rate=2e-5, + weight_decay=0.01, + push_to_hub=False, + num_train_epochs=5, +) + +trainer = Trainer( + model=model, + args=training_args, + train_dataset=lm_datasets["train"], + eval_dataset=lm_datasets["validation"], +) + +print("started training") + +trainer.train() +import math +eval_results = trainer.evaluate() +print(f"Perplexity: {math.exp(eval_results['eval_loss']):.2f}") + +trainer.save_model("/netscratch/jalota/checkpoints/huggingface-gpt2-3/") + + diff --git a/scripts/gpt2_finetuning.py b/scripts/gpt2_finetuning.py new file mode 100644 index 0000000000..e3b217b643 --- /dev/null +++ b/scripts/gpt2_finetuning.py @@ -0,0 +1,78 @@ +import os +os.environ['TRANSFORMERS_CACHE'] = '/netscratch/jalota/hf-cache/' +from datasets import load_dataset +datasets = load_dataset("text", data_files={"train": "/netscratch/jalota/datasets/haifa-hansard/train/original.txt", "validation":"/netscratch/jalota/datasets/haifa-hansard/dev/original.txt"}) + +# print(datasets["train"][0]) +# print(datasets["train"][200]) +# print(datasets["train"][10]) +# print(datasets["train"][1000]) +# print(datasets["train"][50]) +# print(datasets["train"][100]) + +model_checkpoint = "gpt2" +from transformers import AutoTokenizer +tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_fast=True) + +def tokenize_function(examples): + return tokenizer(examples["text"]) + +tokenized_datasets = datasets.map(tokenize_function, batched=True, num_proc=4, remove_columns=["text"]) + +# print(tokenized_datasets["train"][1]) +block_size = 512 + +def group_texts(examples): + # Concatenate all texts. + concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()} + total_length = len(concatenated_examples[list(examples.keys())[0]]) + # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can + # customize this part to your needs. + total_length = (total_length // block_size) * block_size + # Split by chunks of max_len. 
+ result = { + k: [t[i : i + block_size] for i in range(0, total_length, block_size)] + for k, t in concatenated_examples.items() + } + result["labels"] = result["input_ids"].copy() + return result + +lm_datasets = tokenized_datasets.map( + group_texts, + batched=True, + batch_size=8000, + num_proc=12, +) + +# print(tokenizer.decode(lm_datasets["train"][1]["input_ids"])) + +from transformers import AutoModelForCausalLM +model = AutoModelForCausalLM.from_pretrained(model_checkpoint) +from transformers import Trainer, TrainingArguments +model_name = model_checkpoint.split("/")[-1] +training_args = TrainingArguments( + f"/netscratch/jalota/checkpoints/{model_name}-finetuned-CanadianHansardOriginals", + evaluation_strategy="epoch", + learning_rate=2e-5, + weight_decay=0.01, + push_to_hub=False, + num_train_epochs=5, +) + +trainer = Trainer( + model=model, + args=training_args, + train_dataset=lm_datasets["train"], + eval_dataset=lm_datasets["validation"], +) + +print("started training") + +trainer.train() +import math +eval_results = trainer.evaluate() +print(f"Perplexity: {math.exp(eval_results['eval_loss']):.2f}") + +trainer.save_model("/netscratch/jalota/checkpoints/huggingface-gpt2-3/") + + diff --git a/scripts/pp_accepted_pairs.py b/scripts/pp_accepted_pairs.py new file mode 100644 index 0000000000..636b5831b5 --- /dev/null +++ b/scripts/pp_accepted_pairs.py @@ -0,0 +1,31 @@ +import argparse +import pandas as pd + +def read_file(path): + srcs = [] + tgts = [] + with open(path) as f: + lines = f.readlines() + for line in lines: + parts = line.split("\t") + for part in parts: + if part.startswith('src'): + instances = part.split() + srcs.append(instances[1].strip()) + elif part.startswith('tgt'): + instances = part.split() + tgts.append(instances[1].strip()) + else: + continue + return srcs, tgts + +if __name__ == '__main__': + # argparse the file name and outdir + src, tgt = read_file(path) + sdf = pd.DataFrame(src) + tdf = pd.DataFrame(tgt) + sdf['label'] = 1 + tdf['label'] = 0 + df = pd.concat([sdf, tdf], ignore_index=True) + df = df.sample(frac=1).reset_index(drop=True) + df.to_csv(args.out) diff --git a/scripts/run_clm_no_trainer.py b/scripts/run_clm_no_trainer.py new file mode 100644 index 0000000000..c0fb33980c --- /dev/null +++ b/scripts/run_clm_no_trainer.py @@ -0,0 +1,686 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Fine-tuning the library models for causal language modeling (GPT, GPT-2, CTRL, ...) +on a text file or a dataset without using HuggingFace Trainer. + +Here is the full list of checkpoints on the hub that can be fine-tuned by this script: +https://huggingface.co/models?filter=text-generation +""" +# You can also adapt this script on your own causal language modeling task. Pointers for this are left as comments. 
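+# added comment: local change vs. the upstream HuggingFace example script: TRANSFORMERS_CACHE is redirected to a scratch directory below, before transformers is imported, so the cache location takes effect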
+ +import argparse +import json +import logging +import math +import os +os.environ['TRANSFORMERS_CACHE'] = '/netscratch/jalota/hf-cache/' +import random +from itertools import chain +from pathlib import Path + +import datasets +import torch +from accelerate import Accelerator, DistributedType +from accelerate.logging import get_logger +from accelerate.utils import set_seed +from datasets import load_dataset +from huggingface_hub import Repository, create_repo +from torch.utils.data import DataLoader +from tqdm import tqdm + +import transformers +from transformers import ( + CONFIG_MAPPING, + MODEL_MAPPING, + AutoConfig, + AutoModelForCausalLM, + AutoTokenizer, + SchedulerType, + default_data_collator, + get_scheduler, +) +from transformers.utils import check_min_version, get_full_repo_name, send_example_telemetry +from transformers.utils.versions import require_version + + +# Will error if the minimal version of Transformers is not installed. Remove at your own risks. +check_min_version("4.30.0.dev0") + +logger = get_logger(__name__) + +require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt") + +MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys()) +MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) + + +def parse_args(): + parser = argparse.ArgumentParser(description="Finetune a transformers model on a causal language modeling task") + parser.add_argument( + "--dataset_name", + type=str, + default=None, + help="The name of the dataset to use (via the datasets library).", + ) + parser.add_argument( + "--dataset_config_name", + type=str, + default=None, + help="The configuration name of the dataset to use (via the datasets library).", + ) + parser.add_argument( + "--train_file", type=str, default=None, help="A csv or a json file containing the training data." + ) + parser.add_argument( + "--validation_file", type=str, default=None, help="A csv or a json file containing the validation data." 
+ ) + parser.add_argument( + "--validation_split_percentage", + default=5, + help="The percentage of the train set used as validation set in case there's no validation split", + ) + parser.add_argument( + "--model_name_or_path", + type=str, + help="Path to pretrained model or model identifier from huggingface.co/models.", + required=False, + ) + parser.add_argument( + "--config_name", + type=str, + default=None, + help="Pretrained config name or path if not the same as model_name", + ) + parser.add_argument( + "--tokenizer_name", + type=str, + default=None, + help="Pretrained tokenizer name or path if not the same as model_name", + ) + parser.add_argument( + "--use_slow_tokenizer", + action="store_true", + help="If passed, will use a slow tokenizer (not backed by the 🤗 Tokenizers library).", + ) + parser.add_argument( + "--per_device_train_batch_size", + type=int, + default=8, + help="Batch size (per device) for the training dataloader.", + ) + parser.add_argument( + "--per_device_eval_batch_size", + type=int, + default=8, + help="Batch size (per device) for the evaluation dataloader.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=5e-5, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.") + parser.add_argument("--num_train_epochs", type=int, default=3, help="Total number of training epochs to perform.") + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--lr_scheduler_type", + type=SchedulerType, + default="linear", + help="The scheduler type to use.", + choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"], + ) + parser.add_argument( + "--num_warmup_steps", type=int, default=0, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument("--output_dir", type=str, default=None, help="Where to store the final model.") + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--model_type", + type=str, + default=None, + help="Model type to use if training from scratch.", + choices=MODEL_TYPES, + ) + parser.add_argument( + "--block_size", + type=int, + default=None, + help=( + "Optional input sequence length after tokenization. The training dataset will be truncated in block of" + " this size for training. Default to the model max input length for single sentence inputs (take into" + " account special tokens)." + ), + ) + parser.add_argument( + "--preprocessing_num_workers", + type=int, + default=None, + help="The number of processes to use for the preprocessing.", + ) + parser.add_argument( + "--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets" + ) + parser.add_argument( + "--no_keep_linebreaks", action="store_true", help="Do not keep line breaks when using TXT files." + ) + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument( + "--hub_model_id", type=str, help="The name of the repository to keep in sync with the local `output_dir`." 
+ ) + parser.add_argument("--hub_token", type=str, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--checkpointing_steps", + type=str, + default=None, + help="Whether the various states should be saved at the end of every n steps, or 'epoch' for each epoch.", + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help="If the training should continue from a checkpoint folder.", + ) + parser.add_argument( + "--with_tracking", + action="store_true", + help="Whether to enable experiment trackers for logging.", + ) + parser.add_argument( + "--report_to", + type=str, + default="all", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,' + ' `"wandb"`, `"comet_ml"` and `"clearml"`. Use `"all"` (default) to report to all integrations.' + "Only applicable when `--with_tracking` is passed." + ), + ) + parser.add_argument( + "--low_cpu_mem_usage", + action="store_true", + help=( + "It is an option to create the model as an empty shell, then only materialize its parameters when the pretrained weights are loaded." + "If passed, LLM loading time and RAM consumption will be benefited." + ), + ) + args = parser.parse_args() + + # Sanity checks + if args.dataset_name is None and args.train_file is None and args.validation_file is None: + raise ValueError("Need either a dataset name or a training/validation file.") + else: + if args.train_file is not None: + extension = args.train_file.split(".")[-1] + assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, json or txt file." + if args.validation_file is not None: + extension = args.validation_file.split(".")[-1] + assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, json or txt file." + + if args.push_to_hub: + assert args.output_dir is not None, "Need an `output_dir` to create a repo when `--push_to_hub` is passed." + + return args + + +def main(): + args = parse_args() + + # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The + # information sent is the one passed as arguments along with your Python/PyTorch versions. + send_example_telemetry("run_clm_no_trainer", args) + + # Initialize the accelerator. We will let the accelerator handle device placement for us in this example. + # If we're using tracking, we also need to initialize it here and it will by default pick up all supported trackers + # in the environment + accelerator_log_kwargs = {} + + if args.with_tracking: + accelerator_log_kwargs["log_with"] = args.report_to + accelerator_log_kwargs["logging_dir"] = args.output_dir + + accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps, **accelerator_log_kwargs) + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + datasets.utils.logging.set_verbosity_warning() + transformers.utils.logging.set_verbosity_info() + else: + datasets.utils.logging.set_verbosity_error() + transformers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. 
+ if args.seed is not None: + set_seed(args.seed) + + # Handle the repository creation + if accelerator.is_main_process: + if args.push_to_hub: + if args.hub_model_id is None: + repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) + else: + repo_name = args.hub_model_id + create_repo(repo_name, exist_ok=True, token=args.hub_token) + repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token) + + with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: + if "step_*" not in gitignore: + gitignore.write("step_*\n") + if "epoch_*" not in gitignore: + gitignore.write("epoch_*\n") + elif args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + accelerator.wait_for_everyone() + + # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below) + # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/ + # (the dataset will be downloaded automatically from the datasets Hub). + # + # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called + # 'text' is found. You can easily tweak this behavior (see below). + # + # In distributed training, the load_dataset function guarantee that only one local process can concurrently + # download the dataset. + if args.dataset_name is not None: + # Downloading and loading a dataset from the hub. + raw_datasets = load_dataset(args.dataset_name, args.dataset_config_name) + if "validation" not in raw_datasets.keys(): + raw_datasets["validation"] = load_dataset( + args.dataset_name, + args.dataset_config_name, + split=f"train[:{args.validation_split_percentage}%]", + ) + raw_datasets["train"] = load_dataset( + args.dataset_name, + args.dataset_config_name, + split=f"train[{args.validation_split_percentage}%:]", + ) + else: + data_files = {} + dataset_args = {} + if args.train_file is not None: + data_files["train"] = args.train_file + if args.validation_file is not None: + data_files["validation"] = args.validation_file + extension = args.train_file.split(".")[-1] + if extension == "txt": + extension = "text" + dataset_args["keep_linebreaks"] = not args.no_keep_linebreaks + raw_datasets = load_dataset(extension, data_files=data_files, **dataset_args) + # If no validation data is there, validation_split_percentage will be used to divide the dataset. + if "validation" not in raw_datasets.keys(): + raw_datasets["validation"] = load_dataset( + extension, + data_files=data_files, + split=f"train[:{args.validation_split_percentage}%]", + **dataset_args, + ) + raw_datasets["train"] = load_dataset( + extension, + data_files=data_files, + split=f"train[{args.validation_split_percentage}%:]", + **dataset_args, + ) + + # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at + # https://huggingface.co/docs/datasets/loading_datasets.html. + + # Load pretrained model and tokenizer + # + # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently + # download model & vocab. 
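+    # added comment: local change vs. upstream: the from_pretrained calls below pin cache_dir to a scratch path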
+ if args.config_name: + config = AutoConfig.from_pretrained(args.config_name, cache_dir="/netscratch/jalota/datasets/hf-cache/") + elif args.model_name_or_path: + config = AutoConfig.from_pretrained(args.model_name_or_path, cache_dir="/netscratch/jalota/datasets/hf-cache/") + else: + config = CONFIG_MAPPING[args.model_type]() + logger.warning("You are instantiating a new config instance from scratch.") + + if args.tokenizer_name: + tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, use_fast=not args.use_slow_tokenizer, cache_dir="/netscratch/jalota/datasets/hf-cache/") + elif args.model_name_or_path: + tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=not args.use_slow_tokenizer, cache_dir="/netscratch/jalota/datasets/hf-cache/") + else: + raise ValueError( + "You are instantiating a new tokenizer from scratch. This is not supported by this script." + "You can do it from another script, save it, and load it from here, using --tokenizer_name." + ) + + if args.model_name_or_path: + model = AutoModelForCausalLM.from_pretrained( + args.model_name_or_path, + from_tf=bool(".ckpt" in args.model_name_or_path), + config=config, + low_cpu_mem_usage=args.low_cpu_mem_usage, cache_dir="/netscratch/jalota/datasets/hf-cache/" + ) + else: + logger.info("Training new model from scratch") + model = AutoModelForCausalLM.from_config(config) + + # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch + # on a small vocab and want a smaller embedding size, remove this test. + embedding_size = model.get_input_embeddings().weight.shape[0] + if len(tokenizer) > embedding_size: + model.resize_token_embeddings(len(tokenizer)) + + # Preprocessing the datasets. + # First we tokenize all the texts. + column_names = raw_datasets["train"].column_names + text_column_name = "text" if "text" in column_names else column_names[0] + + def tokenize_function(examples): + return tokenizer(examples[text_column_name]) + + with accelerator.main_process_first(): + tokenized_datasets = raw_datasets.map( + tokenize_function, + batched=True, + num_proc=args.preprocessing_num_workers, + remove_columns=column_names, + load_from_cache_file=not args.overwrite_cache, + desc="Running tokenizer on dataset", + ) + + if args.block_size is None: + block_size = tokenizer.model_max_length + if block_size > 1024: + logger.warning( + "The chosen tokenizer supports a `model_max_length` that is longer than the default `block_size` value" + " of 1024. If you would like to use a longer `block_size` up to `tokenizer.model_max_length` you can" + " override this default with `--block_size xxx`." + ) + block_size = 1024 + else: + if args.block_size > tokenizer.model_max_length: + logger.warning( + f"The block_size passed ({args.block_size}) is larger than the maximum length for the model" + f"({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}." + ) + block_size = min(args.block_size, tokenizer.model_max_length) + + # Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size. + def group_texts(examples): + # Concatenate all texts. + concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()} + total_length = len(concatenated_examples[list(examples.keys())[0]]) + # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can + # customize this part to your needs. 
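+        # added comment: unlike a bare floor division, this guard keeps a corpus shorter than block_size as one short chunk instead of dropping it entirely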
+        if total_length >= block_size:
+            total_length = (total_length // block_size) * block_size
+        # Split by chunks of max_len.
+        result = {
+            k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
+            for k, t in concatenated_examples.items()
+        }
+        result["labels"] = result["input_ids"].copy()
+        return result
+
+    # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a remainder
+    # for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value might be slower
+    # to preprocess.
+    #
+    # To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
+    # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
+
+    with accelerator.main_process_first():
+        lm_datasets = tokenized_datasets.map(
+            group_texts,
+            batched=True,
+            num_proc=args.preprocessing_num_workers,
+            load_from_cache_file=not args.overwrite_cache,
+            desc=f"Grouping texts in chunks of {block_size}",
+        )
+
+    train_dataset = lm_datasets["train"]
+    eval_dataset = lm_datasets["validation"]
+
+    # Log a few random samples from the training set:
+    for index in random.sample(range(len(train_dataset)), 3):
+        logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")
+
+    # DataLoaders creation:
+    train_dataloader = DataLoader(
+        train_dataset, shuffle=True, collate_fn=default_data_collator, batch_size=args.per_device_train_batch_size
+    )
+    eval_dataloader = DataLoader(
+        eval_dataset, collate_fn=default_data_collator, batch_size=args.per_device_eval_batch_size
+    )
+
+    # Optimizer
+    # Split weights in two groups, one with weight decay and the other not.
+    no_decay = ["bias", "layer_norm.weight"]
+    optimizer_grouped_parameters = [
+        {
+            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
+            "weight_decay": args.weight_decay,
+        },
+        {
+            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
+            "weight_decay": 0.0,
+        },
+    ]
+    optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
+
+    # Scheduler and math around the number of training steps.
+    overrode_max_train_steps = False
+    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+    if args.max_train_steps is None:
+        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+        overrode_max_train_steps = True
+
+    lr_scheduler = get_scheduler(
+        name=args.lr_scheduler_type,
+        optimizer=optimizer,
+        num_warmup_steps=args.num_warmup_steps * args.gradient_accumulation_steps,
+        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+    )
+
+    # Prepare everything with our `accelerator`.
+    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
+        model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
+    )
+
+    # On TPU, the tied weights in our model have been disconnected, so we need to restore the ties.
+    if accelerator.distributed_type == DistributedType.TPU:
+        model.tie_weights()
+
+    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
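+    # (accelerator.prepare shards each dataloader across processes, so len(train_dataloader) on each
+    # process is now roughly the original number of batches divided by the number of processes.)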
+    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+    if overrode_max_train_steps:
+        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+    # Afterwards we recalculate our number of training epochs
+    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+    # Figure out how often we should save the Accelerator states
+    checkpointing_steps = args.checkpointing_steps
+    if checkpointing_steps is not None and checkpointing_steps.isdigit():
+        checkpointing_steps = int(checkpointing_steps)
+
+    # We need to initialize the trackers we use, and also store our configuration.
+    # The trackers initialize automatically on the main process.
+    if args.with_tracking:
+        experiment_config = vars(args)
+        # TensorBoard cannot log Enums, need the raw value
+        experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value
+        accelerator.init_trackers("clm_no_trainer", experiment_config)
+
+    # Train!
+    total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+    logger.info("***** Running training *****")
+    logger.info(f"  Num examples = {len(train_dataset)}")
+    logger.info(f"  Num Epochs = {args.num_train_epochs}")
+    logger.info(f"  Instantaneous batch size per device = {args.per_device_train_batch_size}")
+    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+    logger.info(f"  Total optimization steps = {args.max_train_steps}")
+    # Only show the progress bar once on each machine.
+    progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
+    completed_steps = 0
+    starting_epoch = 0
+
+    # Potentially load in the weights and states from a previous save
+    if args.resume_from_checkpoint:
+        if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "":
+            accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}")
+            accelerator.load_state(args.resume_from_checkpoint)
+            path = os.path.basename(args.resume_from_checkpoint)
+        else:
+            # Get the most recent checkpoint
+            dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()]
+            dirs.sort(key=os.path.getctime)
+            path = dirs[-1]  # Sorts folders by date modified, most recent checkpoint is the last
+        # Extract `epoch_{i}` or `step_{i}`
+        training_difference = os.path.splitext(path)[0]
+
+        if "epoch" in training_difference:
+            starting_epoch = int(training_difference.replace("epoch_", "")) + 1
+            resume_step = None
+        else:
+            # we need to multiply by `gradient_accumulation_steps` to reflect real (micro-)steps
+            resume_step = int(training_difference.replace("step_", "")) * args.gradient_accumulation_steps
+            starting_epoch = resume_step // len(train_dataloader)
+            resume_step -= starting_epoch * len(train_dataloader)
+
+        # update the progress bar when resuming from a checkpoint
+        progress_bar.update(starting_epoch * num_update_steps_per_epoch)
+        completed_steps = starting_epoch * num_update_steps_per_epoch
+
+    for epoch in range(starting_epoch, args.num_train_epochs):
+        model.train()
+        if args.with_tracking:
+            total_loss = 0
+        for step, batch in enumerate(train_dataloader):
+            # We need to skip steps until we reach the resumed step
+            if args.resume_from_checkpoint and epoch == starting_epoch:
+                if resume_step is not None and step < resume_step:
+                    if step % args.gradient_accumulation_steps == 0:
+                        progress_bar.update(1)
+                        completed_steps += 1
+                    continue
+
+            with accelerator.accumulate(model):
+                outputs = model(**batch)
+                loss = outputs.loss
+                # We keep track of the loss at each epoch
+                if args.with_tracking:
+                    total_loss += loss.detach().float()
+                accelerator.backward(loss)
+                optimizer.step()
+                lr_scheduler.step()
+                optimizer.zero_grad()
+
+            # Checks if the accelerator has performed an optimization step behind the scenes
+            if accelerator.sync_gradients:
+                progress_bar.update(1)
+                completed_steps += 1
+
+            if isinstance(checkpointing_steps, int):
+                if completed_steps % checkpointing_steps == 0:
+                    output_dir = f"step_{completed_steps}"
+                    if args.output_dir is not None:
+                        output_dir = os.path.join(args.output_dir, output_dir)
+                    accelerator.save_state(output_dir)
+            if completed_steps >= args.max_train_steps:
+                break
+
+        model.eval()
+        losses = []
+        for step, batch in enumerate(eval_dataloader):
+            with torch.no_grad():
+                outputs = model(**batch)
+
+            loss = outputs.loss
+            losses.append(accelerator.gather_for_metrics(loss.repeat(args.per_device_eval_batch_size)))
+
+        losses = torch.cat(losses)
+        try:
+            eval_loss = torch.mean(losses)
+            perplexity = math.exp(eval_loss)
+        except OverflowError:
+            perplexity = float("inf")
+
+        logger.info(f"epoch {epoch}: perplexity: {perplexity} eval_loss: {eval_loss}")
+
+        if args.with_tracking:
+            accelerator.log(
+                {
+                    "perplexity": perplexity,
+                    "eval_loss": eval_loss,
+                    "train_loss": total_loss.item() / len(train_dataloader),
+                    "epoch": epoch,
+                    "step": completed_steps,
+                },
+                step=completed_steps,
+            )
+
+        if args.push_to_hub and epoch < args.num_train_epochs - 1:
+            accelerator.wait_for_everyone()
+            unwrapped_model = accelerator.unwrap_model(model)
+            unwrapped_model.save_pretrained(
+                args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save
+            )
+            if accelerator.is_main_process:
+                tokenizer.save_pretrained(args.output_dir)
+                repo.push_to_hub(
+                    commit_message=f"Training in progress epoch {epoch}", blocking=False, auto_lfs_prune=True
+                )
+
+        if args.checkpointing_steps == "epoch":
+            output_dir = f"epoch_{epoch}"
+            if args.output_dir is not None:
+                output_dir = os.path.join(args.output_dir, output_dir)
+            accelerator.save_state(output_dir)
+
+    if args.with_tracking:
+        accelerator.end_training()
+
+    if args.output_dir is not None:
+        accelerator.wait_for_everyone()
+        unwrapped_model = accelerator.unwrap_model(model)
+        unwrapped_model.save_pretrained(
+            args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save
+        )
+        if accelerator.is_main_process:
+            tokenizer.save_pretrained(args.output_dir)
+            if args.push_to_hub:
+                repo.push_to_hub(commit_message="End of training", auto_lfs_prune=True)
+
+            with open(os.path.join(args.output_dir, "all_results.json"), "w") as f:
+                json.dump({"perplexity": perplexity}, f)
+
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
diff --git a/scripts/split_train_test_dev_hansard.py b/scripts/split_train_test_dev_hansard.py
new file mode 100644
index 0000000000..83a745c75c
--- /dev/null
+++ b/scripts/split_train_test_dev_hansard.py
@@ -0,0 +1,43 @@
+import pandas as pd
+
+if __name__ == '__main__':
+    train_tr = "/netscratch/jalota/datasets/europarl-ppd/de/europarl_no_dup.tok"
+    dev_tr = "/netscratch/jalota/datasets/europarl-ppd/de/europarl_dev.txt"
+    test_tr = "/netscratch/jalota/datasets/europarl-ppd/de/europarl_test.txt"
+    # train_tr = "/netscratch/jalota/datasets/haifa-hansard/train/translated.txt"
+    # dev_tr = "/netscratch/jalota/datasets/haifa-hansard/dev/translated_4k_train.txt"
"/netscratch/jalota/datasets/haifa-hansard/dev/translated_4k_train.txt" + # test_tr = "/netscratch/jalota/datasets/haifa-hansard/test/translated_4k_train.txt" + + # extract 8k sentences from train split and redistribute to dev and test splits + train = open(train_tr) + dev = open(dev_tr, 'w') + test = open(test_tr, 'w') + train_w = open("/netscratch/jalota/datasets/europarl-ppd/de/europarl_train.txt", "w") + # train_w = open("/netscratch/jalota/datasets/haifa-hansard/train/tr_new", "w") + + lines = train.readlines() + df = pd.DataFrame(lines, columns=['text']) + print(df.head()) + dev_df = df.sample(n=5000, random_state=23) + df = df.drop(dev_df.index) + test_df = df.sample(n=5000, random_state=23) + df = df.drop(test_df.index) + + print(len(df), len(test_df), len(dev_df)) + print(not set(df).isdisjoint(test_df)) + print(not set(df).isdisjoint(dev_df)) + print(not set(test_df).isdisjoint(dev_df)) + + for row in test_df['text']: + test.write(row) + for row in dev_df['text']: + dev.write(row) + for row in df['text']: + train_w.write(row) + + train.close() + train_w.close() + dev.close() + test.close() + + diff --git a/traincomp.py b/traincomp.py index 923c13ee05..0c917dee08 100644 --- a/traincomp.py +++ b/traincomp.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. from fairseq_cli.traincomp import cli_main - +# from fairseq_cli.train_unsup_comp import cli_main if __name__ == '__main__': cli_main()