Our codebase has been significantly influenced by the repository from the paper "Rethinking the Role of Scale for In-Context Learning: An Interpretability-based Case Study at 66 Billion Scale".

We re-implement a DejaVu-style predictor in this codebase, along with our ShadowLLM predictor. We simplify the DejaVu implementation by using per-layer predictors instead of look-ahead predictors. We expect this variant to be more accurate, since it uses more sparsity predictors in our tests.
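A minimal sketch of the per-layer wiring, assuming each predictor is a small function that maps the input of its own layer to one importance score per attention head. The names `make_linear_predictor` and `per_layer_scores` are hypothetical, and the random linear map only stands in for a trained predictor. A look-ahead design, as in DejaVu, would instead feed a predictor activations from an earlier layer.

```python
import random
from typing import Callable, List

Vector = List[float]
Predictor = Callable[[Vector], Vector]


def make_linear_predictor(d_in: int, n_heads: int, seed: int) -> Predictor:
    """Hypothetical toy predictor: a fixed random linear map from a layer input
    to one score per head (a stand-in for a trained model)."""
    rng = random.Random(seed)
    weights = [[rng.uniform(-1.0, 1.0) for _ in range(d_in)] for _ in range(n_heads)]

    def predict(x: Vector) -> Vector:
        # One dot product per head.
        return [sum(w * xi for w, xi in zip(row, x)) for row in weights]

    return predict


def per_layer_scores(layer_inputs: List[Vector], predictors: List[Predictor]) -> List[Vector]:
    """Per-layer scheme: predictor i reads the input to layer i itself."""
    if len(layer_inputs) != len(predictors):
        raise ValueError("need exactly one predictor per layer")
    return [predict(x) for predict, x in zip(predictors, layer_inputs)]


if __name__ == "__main__":
    predictors = [make_linear_predictor(d_in=4, n_heads=3, seed=i) for i in range(2)]
    inputs = [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]]
    for layer, scores in enumerate(per_layer_scores(inputs, predictors)):
        print(layer, [round(s, 3) for s in scores])
```

The trade-off this illustrates: a per-layer predictor sees the exact input of the layer it prunes, at the cost of having to run it on the critical path, whereas a look-ahead predictor can run earlier but sees slightly stale activations.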

Setup

Set up and activate an initial conda environment using the provided llm-interpret/environment.yml file.

conda env create -f environment.yml
conda activate opt

Install PyTorch based on your system configuration. We used the following:

conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch

We also modify the Hugging Face transformers implementation, so it must be installed from the bundled source:

cd transformers/
python -m pip install -e .

Getting Started

Our code builds heavily on the llm-interpret codebase.

The model and evaluation code is based on 🤗 Hugging Face's transformers and EleutherAI's lm-evaluation-harness libraries.

To regenerate figures

Figures in our paper can be regenerated by running the Python files as described in llm-interpret/lm-evaluation-harness/reproduction_log.log.

To generate results

All of our results require generating input-activation traces as well as predictor models. In our case, the ZCP (pruning criteria) traces required over 700 GB of storage. The prediction models, when trained for all cases, require at least 50 GB of storage.

Running med_genjobs.py generates the commands that should reproduce the entire flow. We recommend first generating and testing the whole flow with a single proxy (pruning criterion) and a single few-shot setting (e.g., 0, 3, or 5), since some file paths may need fixing. Please inspect the file carefully to understand how the commands are launched.
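For a first dry run, a minimal sketch of how one might emit the trace-generation command for a single proxy and a single few-shot setting. The flags are copied from the sample commands in this README; the helper name `zcp_trace_command` and the `{num_fewshot}shot_{task}_original.pkl` naming pattern are hypothetical and are not taken from med_genjobs.py.

```python
import shlex
from typing import List

MODEL = "facebook/opt-13b"  # model used in the sample commands


def zcp_trace_command(task: str, num_fewshot: int, zcp: str = "plainact") -> List[str]:
    """Hypothetical helper: argv for one ZCP-trace generation run."""
    model_args = ",".join([
        f"zcp_calc={zcp}",
        f"pretrained={MODEL}",
        "model_cache_dir=opt13b_checkpoints",
        "tokenizer_cache_dir=opt13b_tokenizer",
    ])
    # Assumed output naming scheme; adapt it to whatever your pipeline expects.
    save_path = f"logs/head_importance/opt13b/{num_fewshot}shot_{task}_original.pkl"
    return [
        "python", "main.py", "--model", "opt", "--model_args", model_args,
        "--tasks", task, "--head_importance_calc",
        "--save_importance_path", save_path,
        "--num_fewshot", str(num_fewshot), "--method", "predictor",
    ]


if __name__ == "__main__":
    # One proxy, one few-shot setting; extend the task list once this works.
    for task in ["piqa"]:
        print(shlex.join(zcp_trace_command(task, num_fewshot=0)))
```

Returning an argument list rather than a single string keeps quoting safe if the command is later handed to `subprocess.run` or written into a SLURM script.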

WikiText2 evaluation may require changing a link in the code; the command for this is provided in slurmrunner_medium.slurm. Please adapt it to your own paths.

Sample Commands

To generate ZCP (plainact) traces for PIQA in a 5-shot setting:

python main.py --model opt --model_args zcp_calc=plainact,pretrained=facebook/opt-13b,model_cache_dir=opt13b_checkpoints,tokenizer_cache_dir=opt13b_tokenizer --tasks piqa --head_importance_calc --save_importance_path logs/head_importance/opt13b/0shot_piqa_original.pkl --num_fewshot 5 --method predictor
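`--model_args` packs all model options into one comma-separated `key=value` string. The sketch below assumes this fork keeps lm-evaluation-harness's simple parsing (split on commas, then on `=`); `parse_model_args` and `build_model_args` are illustrative names, not functions in this repository. Two consequences: a value can never contain a comma, and every value arrives as a string (for example `aggr_all=False` yields `"False"`), so any type conversion is up to the model class.

```python
from typing import Dict


def parse_model_args(args_string: str) -> Dict[str, str]:
    """Split 'k1=v1,k2=v2' into a dict of strings, harness-style."""
    args_string = args_string.strip()
    if not args_string:
        return {}
    parsed: Dict[str, str] = {}
    for pair in args_string.split(","):
        key, value = pair.split("=", 1)
        parsed[key] = value
    return parsed


def build_model_args(options: Dict[str, object]) -> str:
    """Join options back into the comma-separated string passed to --model_args."""
    for key, value in options.items():
        if "," in str(value):
            raise ValueError(f"value for {key!r} must not contain a comma")
    return ",".join(f"{k}={v}" for k, v in options.items())


if __name__ == "__main__":
    args = parse_model_args(
        "zcp_calc=plainact,pretrained=facebook/opt-13b,"
        "model_cache_dir=opt13b_checkpoints,tokenizer_cache_dir=opt13b_tokenizer"
    )
    print(args["pretrained"])
    print(build_model_args(args))
```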

To evaluate aggregate-score-based pruning for PIQA in a 0-shot setting (static pruning strategy, 50% sparsity):

python main.py --model opt --model_args prune_style=global,ffn_percent_mask=50,fcmaskmethod=fc,aggr_all=False,zcp_calc=plainact,pretrained=facebook/opt-13b,model_cache_dir=opt13b_checkpoints,tokenizer_cache_dir=opt13b_tokenizer,mask_heads=1,head_importance_path=zcps/opt-13b/plainact_all_5.pkl,head_percent_mask=50,maskmethod=predictorL,predictor_=all --tasks piqa --output_path results/13b/piqa/0shot_piqa_predictor.txt --batch_size 1 --num_fewshot 0 --method predictor
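To illustrate what `prune_style=global` with `head_percent_mask=50` plausibly means, here is a sketch that ranks every head in every layer jointly and masks the lowest-scoring half. This is not the repository's masking code: the 1/0 mask layout, the `global_head_mask` name, and truncating the prune count with `int()` are all assumptions.

```python
from typing import List


def global_head_mask(scores: List[List[float]], percent_mask: float) -> List[List[int]]:
    """Mask the lowest-scoring heads across ALL layers jointly (1 = keep, 0 = pruned)."""
    flat = [(s, layer, head) for layer, row in enumerate(scores) for head, s in enumerate(row)]
    n_prune = int(len(flat) * percent_mask / 100)  # assumed rounding: truncate
    mask = [[1] * len(row) for row in scores]
    # Sorting the (score, layer, head) tuples orders by score, breaking ties by position.
    for _, layer, head in sorted(flat)[:n_prune]:
        mask[layer][head] = 0
    return mask


if __name__ == "__main__":
    # Toy importance scores: 2 layers x 4 heads.
    scores = [[0.9, 0.1, 0.8, 0.7], [0.2, 0.3, 0.6, 0.05]]
    print(global_head_mask(scores, percent_mask=50))  # → [[1, 0, 1, 1], [0, 0, 1, 0]]
```

Because the ranking is global, sparsity need not be uniform: in the toy example layer 0 keeps three heads while layer 1 keeps only one.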

To train the head predictor on ZCP traces across all tasks (use combined_generator.py to combine all traces for FFN and head predictor training):

python emnlp_activation_predictor.py --fewshot 5 --dataset all --dataset_cname all --zcp_metric plainact --basemodel opt-13b --execmode train --emb_style b1e --rerun

To train the FFN predictor on ZCP traces across all tasks:

python emnlp_activation_ffn_predictor.py --fewshot 5 --dataset all --dataset_cname all --zcp_metric plainact --basemodel opt-13b --execmode train --emb_style b1e --rerun

To evaluate ShadowLLM-style per-layer (local) pruning on WikiText2, using the predictor trained across all downstream tasks (dynamic pruning strategy, 40% sparsity):

python main.py --model opt --model_args prune_style=perlayer,ffn_percent_mask=40,fcmaskmethod=fc,aggr_all=False,zcp_calc=plainact,pretrained=facebook/opt-13b,model_cache_dir=opt13b_checkpoints,tokenizer_cache_dir=opt13b_tokenizer,mask_heads=1,head_importance_path=zcps/opt-13b/plainact_all_5.pkl,head_percent_mask=40,maskmethod=predictorL,predictor_=all --tasks wikitext --output_path results/13b/wikitext/0shot_wikitext_predictor.txt --batch_size 1 --num_fewshot 0 --method predictor
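For contrast with global ranking, a sketch of what `prune_style=perlayer` with `head_percent_mask=40` plausibly means: each layer independently masks its own lowest-scoring 40% of heads, so every layer ends up with the same sparsity. Under a dynamic strategy the scores would come from the predictor for each input, so the mask would be recomputed per input; the sketch shows the mask for one set of scores. The `per_layer_head_mask` name, the 1/0 layout, and `int()` truncation are assumptions, not the repository's code.

```python
from typing import List


def per_layer_head_mask(scores: List[List[float]], percent_mask: float) -> List[List[int]]:
    """Within each layer, mask its lowest-scoring heads (1 = keep, 0 = pruned)."""
    mask = []
    for row in scores:
        n_prune = int(len(row) * percent_mask / 100)  # assumed rounding: truncate
        # Head indices ordered from least to most important within this layer.
        order = sorted(range(len(row)), key=lambda h: row[h])
        pruned = set(order[:n_prune])
        mask.append([0 if h in pruned else 1 for h in range(len(row))])
    return mask


if __name__ == "__main__":
    # Toy predicted scores for one input: 2 layers x 5 heads.
    scores = [[0.9, 0.1, 0.8, 0.7, 0.4], [0.2, 0.3, 0.6, 0.05, 0.5]]
    print(per_layer_head_mask(scores, percent_mask=40))  # → [[1, 0, 1, 1, 0], [0, 1, 1, 0, 1]]
```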