From 0be9eedfce30b28c111101d7fd3c694f5a8e1e30 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Fri, 22 Mar 2024 12:15:21 +0000 Subject: [PATCH 1/4] add hard rope scaling test --- tests/models/llama/test_modeling_llama.py | 319 ++++++++++++++++++++++ 1 file changed, 319 insertions(+) diff --git a/tests/models/llama/test_modeling_llama.py b/tests/models/llama/test_modeling_llama.py index 36dc8d6bcdf4e8..edfe48fb90b202 100644 --- a/tests/models/llama/test_modeling_llama.py +++ b/tests/models/llama/test_modeling_llama.py @@ -654,6 +654,325 @@ def decode_one_tokens(model, cur_token, input_pos, cache_position): text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) self.assertEqual(EXPECTED_TEXT_COMPLETION, text) + @slow + @require_torch_gpu + @require_bitsandbytes + def test_llama_rope_scaling(self): + # Tests that RoPE scaling works as expected on Llama. + # Note: although this test doesn't take long to run, it requires ~13GB of GPU memory as of 2024-03 + + # The first sections of the Llama 2 paper. Input with >6k tokens, larger than the 4k model context window + tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", use_fast=True) + VERY_LONG_INPUT = ''' +You are given a partial and unparsed scientific article, please read it carefully and answer the follow up question. + +== BEGIN ARTICLE == + +Llama 2 : Open Foundation and Fine-Tuned Chat Models +Hugo Touvron∗Louis Martin†Kevin Stone† +Peter Albert Amjad Almahairi Yasmine Babaei Nikolay Bashlykov Soumya Batra +Prajjwal Bhargava Shruti Bhosale Dan Bikel Lukas Blecher Cristian Canton Ferrer Moya Chen +Guillem Cucurull David Esiobu Jude Fernandes Jeremy Fu Wenyin Fu Brian Fuller +Cynthia Gao Vedanuj Goswami Naman Goyal Anthony Hartshorn Saghar Hosseini Rui Hou +Hakan Inan Marcin Kardas Viktor Kerkez Madian Khabsa Isabel Kloumann Artem Korenev +Punit Singh Koura Marie-Anne Lachaux Thibaut Lavril Jenya Lee Diana Liskovich +Yinghai Lu Yuning Mao Xavier Martinet Todor Mihaylov Pushkar Mishra +Igor Molybog Yixin Nie Andrew Poulton Jeremy Reizenstein Rashi Rungta Kalyan Saladi +Alan Schelten Ruan Silva Eric Michael Smith Ranjan Subramanian Xiaoqing Ellen Tan Binh Tang +Ross Taylor Adina Williams Jian Xiang Kuan Puxin Xu Zheng Yan Iliyan Zarov Yuchen Zhang +Angela Fan Melanie Kambadur Sharan Narang Aurelien Rodriguez Robert Stojnic +Sergey Edunov Thomas Scialom∗ +GenAI, Meta +Abstract +In this work, we develop and release Llama 2, a collection of pretrained and fine-tuned +large language models (LLMs) ranging in scale from 7 billion to 70 billion parameters. +Our fine-tuned LLMs, called Llama 2-Chat , are optimized for dialogue use cases. Our +models outperform open-source chat models on most benchmarks we tested, and based on +ourhumanevaluationsforhelpfulnessandsafety,maybeasuitablesubstituteforclosed- +source models. We provide a detailed description of our approach to fine-tuning and safety +improvements of Llama 2-Chat in order to enable the community to build on our work and +contribute to the responsible development of LLMs. +∗Equal contribution, corresponding authors: {tscialom, htouvron}@meta.com +†Second author +2 +Figure 1: Helpfulness human evaluation results for Llama +2-Chatcomparedtootheropen-sourceandclosed-source +models. Human raters compared model generations on ~4k +promptsconsistingofbothsingleandmulti-turnprompts. +The95%confidenceintervalsforthisevaluationarebetween +1%and2%. MoredetailsinSection3.4.2. 
Whilereviewing +these results, it is important to note that human evaluations +canbenoisyduetolimitationsofthepromptset,subjectivity +of the review guidelines, subjectivity of individual raters, +and the inherent difficulty of comparing generations. +Figure 2: Win-rate % for helpfulness and +safety between commercial-licensed base- +lines and Llama 2-Chat , according to GPT- +4. Tocomplementthehumanevaluation,we +used a more capable model, not subject to +ourownguidance. Greenareaindicatesour +modelisbetteraccordingtoGPT-4. Toremove +ties, we used win/ (win+loss). The orders in +whichthemodelresponsesarepresentedto +GPT-4arerandomlyswappedtoalleviatebias. +1 Introduction +Large Language Models (LLMs) have shown great promise as highly capable AI assistants that excel in +complex reasoning tasks requiring expert knowledge across a wide range of fields, including in specialized +domains such as programming and creative writing. They enable interaction with humans through intuitive +chat interfaces, which has led to rapid and widespread adoption among the general public. +ThecapabilitiesofLLMsareremarkableconsideringtheseeminglystraightforwardnatureofthetraining +methodology. Auto-regressivetransformersarepretrainedonanextensivecorpusofself-superviseddata, +followed by alignment with human preferences via techniques such as Reinforcement Learning with Human +Feedback(RLHF).Althoughthetrainingmethodologyissimple,highcomputationalrequirementshave +limited the development of LLMs to a few players. There have been public releases of pretrained LLMs +(such as BLOOM (Scao et al., 2022), LLaMa-1 (Touvron et al., 2023), and Falcon (Penedo et al., 2023)) that +match the performance of closed pretrained competitors like GPT-3 (Brown et al., 2020) and Chinchilla +(Hoffmann et al., 2022), but none of these models are suitable substitutes for closed “product” LLMs, such +asChatGPT,BARD,andClaude. TheseclosedproductLLMsareheavilyfine-tunedtoalignwithhuman +preferences, which greatly enhances their usability and safety. This step can require significant costs in +computeandhumanannotation,andisoftennottransparentoreasilyreproducible,limitingprogresswithin +the community to advance AI alignment research. +In this work, we develop and release Llama 2, a family of pretrained and fine-tuned LLMs, Llama 2 and +Llama 2-Chat , at scales up to 70B parameters. On the series of helpfulness and safety benchmarks we tested, +Llama 2-Chat models generally perform better than existing open-source models. They also appear to +be on par with some of the closed-source models, at least on the human evaluations we performed (see +Figures1and3). Wehavetakenmeasurestoincreasethesafetyofthesemodels,usingsafety-specificdata +annotation and tuning, as well as conducting red-teaming and employing iterative evaluations. Additionally, +thispapercontributesathoroughdescriptionofourfine-tuningmethodologyandapproachtoimproving +LLM safety. We hope that this openness will enable the community to reproduce fine-tuned LLMs and +continue to improve the safety of those models, paving the way for more responsible development of LLMs. +Wealsosharenovelobservationswemadeduringthedevelopmentof Llama 2 andLlama 2-Chat ,suchas +the emergence of tool usage and temporal organization of knowledge. +3 +Figure 3: Safety human evaluation results for Llama 2-Chat compared to other open-source and closed- +source models. 
Human raters judged model generations for safety violations across ~2,000 adversarial +prompts consisting of both single and multi-turn prompts. More details can be found in Section 4.4. It is +importanttocaveatthesesafetyresultswiththeinherentbiasofLLMevaluationsduetolimitationsofthe +promptset,subjectivityofthereviewguidelines,andsubjectivityofindividualraters. Additionally,these +safety evaluations are performed using content standards that are likely to be biased towards the Llama +2-Chatmodels. +We are releasing the following models to the general public for research and commercial use‡: +1.Llama 2 ,anupdatedversionof Llama 1,trainedonanewmixofpubliclyavailabledata. Wealso +increasedthesizeofthepretrainingcorpusby40%,doubledthecontextlengthofthemodel,and +adoptedgrouped-queryattention(Ainslieetal.,2023). Wearereleasingvariantsof Llama 2 with +7B,13B,and70Bparameters. Wehavealsotrained34Bvariants,whichwereportoninthispaper +but are not releasing.§ +2.Llama 2-Chat , a fine-tuned version of Llama 2 that is optimized for dialogue use cases. We release +variants of this model with 7B, 13B, and 70B parameters as well. +WebelievethattheopenreleaseofLLMs,whendonesafely,willbeanetbenefittosociety. LikeallLLMs, +Llama 2 is a new technology that carries potential risks with use (Bender et al., 2021b; Weidinger et al., 2021; +Solaimanet al.,2023). Testingconductedtodate hasbeeninEnglish andhasnot— andcouldnot— cover +all scenarios. Therefore, before deploying any applications of Llama 2-Chat , developers should perform +safetytestingand tuningtailoredtotheirspecificapplicationsofthemodel. Weprovidearesponsibleuse +guide¶and code examples‖to facilitate the safe deployment of Llama 2 andLlama 2-Chat . More details of +our responsible release strategy can be found in Section 5.3. +Theremainderofthispaperdescribesourpretrainingmethodology(Section2),fine-tuningmethodology +(Section 3), approach to model safety (Section 4), key observations and insights (Section 5), relevant related +work (Section 6), and conclusions (Section 7). +‡https://ai.meta.com/resources/models-and-libraries/llama/ +§We are delaying the release of the 34B model due to a lack of time to sufficiently red team. +¶https://ai.meta.com/llama +‖https://github.com/facebookresearch/llama +4 +Figure4: Trainingof Llama 2-Chat : Thisprocessbeginswiththe pretraining ofLlama 2 usingpublicly +availableonlinesources. Followingthis,wecreateaninitialversionof Llama 2-Chat throughtheapplication +ofsupervised fine-tuning . Subsequently, the model is iteratively refined using Reinforcement Learning +with Human Feedback (RLHF) methodologies, specifically through rejection sampling and Proximal Policy +Optimization(PPO).ThroughouttheRLHFstage,theaccumulationof iterativerewardmodelingdata in +parallel with model enhancements is crucial to ensure the reward models remain within distribution. +2 Pretraining +Tocreatethenewfamilyof Llama 2models,webeganwiththepretrainingapproachdescribedinTouvronetal. +(2023), using an optimized auto-regressive transformer, but made several changes to improve performance. +Specifically,weperformedmorerobustdatacleaning,updatedourdatamixes,trainedon40%moretotal +tokens,doubledthecontextlength,andusedgrouped-queryattention(GQA)toimproveinferencescalability +for our larger models. Table 1 compares the attributes of the new Llama 2 models with the Llama 1 models. +2.1 Pretraining Data +Our training corpus includes a new mix of data from publicly available sources, which does not include data +fromMeta’sproductsorservices. 
Wemadeanefforttoremovedatafromcertainsitesknowntocontaina +highvolumeofpersonalinformationaboutprivateindividuals. Wetrainedon2trilliontokensofdataasthis +providesagoodperformance–costtrade-off,up-samplingthemostfactualsourcesinanefforttoincrease +knowledge and dampen hallucinations. +Weperformedavarietyofpretrainingdatainvestigationssothatuserscanbetterunderstandthepotential +capabilities and limitations of our models; results can be found in Section 4.1. +2.2 Training Details +We adopt most of the pretraining setting and model architecture from Llama 1 . We use the standard +transformer architecture (Vaswani et al., 2017), apply pre-normalization using RMSNorm (Zhang and +Sennrich, 2019), use the SwiGLU activation function (Shazeer, 2020), and rotary positional embeddings +(RoPE, Su et al. 2022). The primary architectural differences from Llama 1 include increased context length +andgrouped-queryattention(GQA).WedetailinAppendixSectionA.2.1eachofthesedifferenceswith +ablation experiments to demonstrate their importance. +Hyperparameters. We trained using the AdamW optimizer (Loshchilov and Hutter, 2017), with β1= +0.9, β2= 0.95,eps= 10−5. We use a cosine learning rate schedule, with warmup of 2000 steps, and decay +finallearningratedownto10%ofthepeaklearningrate. Weuseaweightdecayof 0.1andgradientclipping +of1.0. Figure 5 (a) shows the training loss for Llama 2 with these hyperparameters. +5 +Training Data Params Context +LengthGQA Tokens LR +Llama 1See Touvron et al. +(2023)7B 2k ✗ 1.0T 3.0×10−4 +13B 2k ✗ 1.0T 3.0×10−4 +33B 2k ✗ 1.4T 1.5×10−4 +65B 2k ✗ 1.4T 1.5×10−4 +Llama 2A new mix of publicly +available online data7B 4k ✗ 2.0T 3.0×10−4 +13B 4k ✗ 2.0T 3.0×10−4 +34B 4k ✓ 2.0T 1.5×10−4 +70B 4k ✓ 2.0T 1.5×10−4 +Table 1: Llama 2 family of models. Token counts refer to pretraining data only. All models are trained with +a global batch-size of 4M tokens. Bigger models — 34B and 70B — use Grouped-Query Attention (GQA) for +improved inference scalability. +0 250 500 750 1000 1250 1500 1750 2000 +Processed Tokens (Billions)1.41.51.61.71.81.92.02.12.2Train PPLLlama-2 +7B +13B +34B +70B +Figure 5: Training Loss for Llama 2 models. We compare the training loss of the Llama 2 family of models. +We observe that after pretraining on 2T Tokens, the models still did not show any sign of saturation. +Tokenizer. Weusethesametokenizeras Llama 1;itemploysabytepairencoding(BPE)algorithm(Sennrich +etal.,2016)usingtheimplementationfromSentencePiece(KudoandRichardson,2018). Aswith Llama 1, +we split all numbers into individual digits and use bytes to decompose unknown UTF-8 characters. The total +vocabulary size is 32k tokens. +2.2.1 Training Hardware & Carbon Footprint +TrainingHardware. WepretrainedourmodelsonMeta’sResearchSuperCluster(RSC)(LeeandSengupta, +2022)aswellasinternalproductionclusters. BothclustersuseNVIDIAA100s. Therearetwokeydifferences +between the two clusters, with the first being the type of interconnect available: RSC uses NVIDIA Quantum +InfiniBandwhileourproductionclusterisequippedwithaRoCE(RDMAoverconvergedEthernet)solution +based on commodity ethernet Switches. Both of these solutions interconnect 200 Gbps end-points. The +seconddifferenceistheper-GPUpowerconsumptioncap—RSCuses400Wwhileourproductioncluster +uses350W.Withthistwo-clustersetup,wewereabletocomparethesuitabilityofthesedifferenttypesof +interconnectforlargescaletraining. 
RoCE(whichisamoreaffordable,commercialinterconnectnetwork) +6 +Time +(GPU hours)Power +Consumption (W)Carbon Emitted +(tCO 2eq) +Llama 27B 184320 400 31.22 +13B 368640 400 62.44 +34B 1038336 350 153.90 +70B 1720320 400 291.42 +Total 3311616 539.00 +Table 2: CO2emissions during pretraining. Time: total GPU time required for training each model. Power +Consumption: peak power capacity per GPU device for the GPUs used adjusted for power usage efficiency. +100%oftheemissionsaredirectlyoffsetbyMeta’ssustainabilityprogram,andbecauseweareopenlyreleasing +these models, the pretraining costs do not need to be incurred by others. +can scale almost as well as expensive Infiniband up to 2000 GPUs, which makes pretraining even more +democratizable. On A100s with RoCE and GPU power capped at 350W, our optimized codebase reached up +to 90% of the performance of RSC using IB interconnect and 400W GPU power. +Carbon Footprint of Pretraining. Following preceding research (Bender et al., 2021a; Patterson et al., 2021; +Wu et al., 2022; Dodge et al., 2022) and using power consumption estimates of GPU devices and carbon +efficiency, we aim tocalculate thecarbon emissions resultingfrom the pretrainingof Llama 2 models. The +actualpowerusageofaGPUisdependentonitsutilizationandislikelytovaryfromtheThermalDesign +Power(TDP)thatweemployasanestimationforGPUpower. Itisimportanttonotethatourcalculations +do not account for further power demands, such as those from interconnect or non-GPU server power +consumption,norfromdatacentercoolingsystems. Additionally,thecarbonoutputrelatedtotheproduction +of AI hardware, like GPUs, could add to the overall carbon footprint as suggested by Gupta et al. (2022b,a). +Table 2 summarizes the carbon emission for pretraining the Llama 2 family of models. A cumulative of +3.3M GPUhours ofcomputation wasperformed onhardware oftype A100-80GB (TDPof 400Wor 350W). +We estimate the total emissions for training to be 539 tCO 2eq, of which 100% were directly offset by Meta’s +sustainability program.∗∗Our open release strategy also means that these pretraining costs will not need to +be incurred by other companies, saving more global resources. +2.3 Llama 2 Pretrained Model Evaluation +In this section, we report the results for the Llama 1 andLlama 2 base models, MosaicML Pretrained +Transformer(MPT)††models,andFalcon(Almazroueietal.,2023)modelsonstandardacademicbenchmarks. +For all the evaluations, we use our internal evaluations library. We reproduce results for the MPT and Falcon +modelsinternally. Forthesemodels,wealwayspickthebestscorebetweenourevaluationframeworkand +any publicly reported results. +InTable3,wesummarizetheoverallperformanceacrossasuiteofpopularbenchmarks. Notethatsafety +benchmarks are shared in Section 4.1. The benchmarks are grouped into the categories listed below. The +results for all the individual benchmarks are available in Section A.2.2. +•Code.Wereporttheaveragepass@1scoresofourmodelsonHumanEval(Chenetal.,2021)and +MBPP (Austin et al., 2021). +•CommonsenseReasoning. WereporttheaverageofPIQA(Bisketal.,2020),SIQA(Sapetal.,2019), +HellaSwag (Zellers et al., 2019a), WinoGrande (Sakaguchi et al., 2021), ARC easy and challenge +(Clark et al., 2018), OpenBookQA (Mihaylov et al., 2018), and CommonsenseQA (Talmor et al., +2018). We report 7-shot results for CommonSenseQA and 0-shot results for all other benchmarks. +•World Knowledge. We evaluate the 5-shot performance on NaturalQuestions (Kwiatkowski et al., +2019) and TriviaQA (Joshi et al., 2017) and report the average. 
+•Reading Comprehension. For reading comprehension, we report the 0-shot average on SQuAD +(Rajpurkar et al., 2018), QuAC (Choi et al., 2018), and BoolQ (Clark et al., 2019). +∗∗https://sustainability.fb.com/2021-sustainability-report/ +††https://www.mosaicml.com/blog/mpt-7b +7 +Model Size CodeCommonsense +ReasoningWorld +KnowledgeReading +ComprehensionMath MMLU BBH AGI Eval +MPT7B 20.5 57.4 41.0 57.5 4.9 26.8 31.0 23.5 +30B 28.9 64.9 50.0 64.7 9.1 46.9 38.0 33.8 +Falcon7B 5.6 56.1 42.8 36.0 4.6 26.2 28.0 21.2 +40B 15.2 69.2 56.7 65.7 12.6 55.4 37.1 37.0 +Llama 17B 14.1 60.8 46.2 58.5 6.95 35.1 30.3 23.9 +13B 18.9 66.1 52.6 62.3 10.9 46.9 37.0 33.9 +33B 26.0 70.0 58.4 67.6 21.4 57.8 39.8 41.7 +65B 30.7 70.7 60.5 68.6 30.8 63.4 43.5 47.6 +Llama 27B 16.8 63.9 48.9 61.3 14.6 45.3 32.6 29.3 +13B 24.5 66.9 55.4 65.8 28.7 54.8 39.4 39.1 +34B 27.8 69.9 58.7 68.0 24.2 62.6 44.1 43.4 +70B37.5 71.9 63.6 69.4 35.2 68.9 51.2 54.2 +Table3: Overallperformanceongroupedacademicbenchmarkscomparedtoopen-sourcebasemodels. +•MATH. We report the average of the GSM8K (8 shot) (Cobbe et al., 2021) and MATH (4 shot) +(Hendrycks et al., 2021) benchmarks at top 1. +•Popular Aggregated Benchmarks . We report the overall results for MMLU (5 shot) (Hendrycks +et al., 2020), Big Bench Hard (BBH) (3 shot) (Suzgun et al., 2022), and AGI Eval (3–5 shot) (Zhong +et al., 2023). For AGI Eval, we only evaluate on the English tasks and report the average. +As shown in Table 3, Llama 2 models outperform Llama 1 models. In particular, Llama 2 70B improves the +resultsonMMLUandBBHby ≈5and≈8points,respectively,comparedto Llama 1 65B.Llama 2 7Band30B +modelsoutperformMPTmodelsofthecorrespondingsizeonallcategoriesbesidescodebenchmarks. Forthe +Falcon models, Llama 2 7B and 34B outperform Falcon 7B and 40B models on all categories of benchmarks. +Additionally, Llama 2 70B model outperforms all open-source models. +In addition to open-source models, we also compare Llama 2 70B results to closed-source models. As shown +in Table 4, Llama 2 70B is close to GPT-3.5 (OpenAI, 2023) on MMLU and GSM8K, but there is a significant +gaponcodingbenchmarks. Llama 2 70BresultsareonparorbetterthanPaLM(540B)(Chowdheryetal., +2022)onalmostallbenchmarks. Thereisstillalargegapinperformancebetween Llama 2 70BandGPT-4 +and PaLM-2-L. +We also analysed the potential data contamination and share the details in Section A.6. +Benchmark (shots) GPT-3.5 GPT-4 PaLM PaLM-2-L Llama 2 +MMLU (5-shot) 70.0 86.4 69.3 78.3 68.9 +TriviaQA (1-shot) – – 81.4 86.1 85.0 +Natural Questions (1-shot) – – 29.3 37.5 33.0 +GSM8K (8-shot) 57.1 92.0 56.5 80.7 56.8 +HumanEval (0-shot) 48.1 67.0 26.2 – 29.9 +BIG-Bench Hard (3-shot) – – 52.3 65.7 51.2 +Table 4: Comparison to closed-source models on academic benchmarks. Results for GPT-3.5 and GPT-4 +are from OpenAI (2023). Results for the PaLM model are from Chowdhery et al. (2022). Results for the +PaLM-2-L are from Anil et al. (2023). + +== END ARTICLE == + +''' + question = "What is the paper about?" 
+ model_inputs = tokenizer(VERY_LONG_INPUT + question, return_tensors="pt").to(torch_device) + + # No RoPE scaling -> garbage output + model = LlamaForCausalLM.from_pretrained( + "meta-llama/Llama-2-7b-hf", device_map="auto", load_in_4bit=True, + ) + self.assertTrue(model_inputs["input_ids"].shape[1] > model.config.max_position_embeddings) + generate_kwargs = {"max_new_tokens": 40, "do_sample": False} + gen_out = model.generate(**model_inputs, **generate_kwargs) + decoded_text = tokenizer.decode(gen_out[0], skip_special_tokens=True) + self.assertTrue(decoded_text.endswith("Ћ\nЋ\nЋЋЋЋЋЋЋЋЋ\nЋ\nЋ\n")) + + # Dynamic NTK RoPE scaling -> good output (doesn't need fine-tuning) + model = LlamaForCausalLM.from_pretrained( + "meta-llama/Llama-2-7b-hf", device_map="auto", load_in_4bit=True, rope_scaling={"type": "dynamic", "factor": 2.0}, + ) + generate_kwargs = {"max_new_tokens": 40, "do_sample": False} + gen_out = model.generate(**model_inputs, **generate_kwargs) + decoded_text = tokenizer.decode(gen_out[0], skip_special_tokens=True) + self.assertTrue(decoded_text.endswith("The paper is about the release of Llama 2, a family of pretrained and fine-tuned large language models.\nWhat is Llama 2?\n")) + # Note: the output above matches our initial release of RoPE scaling + + # Linear RoPE scaling -> usualy okay output (should be used with fine-tuning) + model = LlamaForCausalLM.from_pretrained( + "meta-llama/Llama-2-7b-hf", device_map="auto", load_in_4bit=True, rope_scaling={"type": "linear", "factor": 2.0}, + ) + generate_kwargs = {"max_new_tokens": 40, "do_sample": False} + gen_out = model.generate(**model_inputs, **generate_kwargs) + decoded_text = tokenizer.decode(gen_out[0], skip_special_tokens=True) + self.assertTrue(decoded_text.endswith("The paper is about the development of Llama 2, a large language model (LLM) family, and the release of Llama 2-Chat, a fine-")) + @require_torch class CodeLlamaIntegrationTest(unittest.TestCase): From 5299f57e2d06aff6d02898b8d622e28d35f26b17 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Fri, 22 Mar 2024 19:42:40 +0000 Subject: [PATCH 2/4] make fixup --- tests/models/llama/test_modeling_llama.py | 30 +++++++++++++++++------ 1 file changed, 23 insertions(+), 7 deletions(-) diff --git a/tests/models/llama/test_modeling_llama.py b/tests/models/llama/test_modeling_llama.py index edfe48fb90b202..6eb429f874ace8 100644 --- a/tests/models/llama/test_modeling_llama.py +++ b/tests/models/llama/test_modeling_llama.py @@ -663,7 +663,7 @@ def test_llama_rope_scaling(self): # The first sections of the Llama 2 paper. Input with >6k tokens, larger than the 4k model context window tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", use_fast=True) - VERY_LONG_INPUT = ''' + VERY_LONG_INPUT = """ You are given a partial and unparsed scientific article, please read it carefully and answer the follow up question. == BEGIN ARTICLE == @@ -940,13 +940,15 @@ def test_llama_rope_scaling(self): == END ARTICLE == -''' +""" question = "What is the paper about?" 
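# Aside -- a minimal sketch, not part of this patch: the `rope_scaling` argument used in the
# test above is forwarded by `from_pretrained` to `LlamaConfig`, so the same override can be
# set on the config object explicitly. Checkpoint name and values mirror the test; treating the
# config-based route as equivalent is an assumption, not something this patch exercises.
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("meta-llama/Llama-2-7b-hf")
config.rope_scaling = {"type": "dynamic", "factor": 2.0}  # or {"type": "linear", "factor": 2.0}
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf", config=config, device_map="auto", load_in_4bit=True
)
# With factor=2.0, linear scaling stretches the 4k training context to roughly 8k usable
# positions, which is why the >6k-token prompt in the test needs some form of scaling at all.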
model_inputs = tokenizer(VERY_LONG_INPUT + question, return_tensors="pt").to(torch_device) # No RoPE scaling -> garbage output model = LlamaForCausalLM.from_pretrained( - "meta-llama/Llama-2-7b-hf", device_map="auto", load_in_4bit=True, + "meta-llama/Llama-2-7b-hf", + device_map="auto", + load_in_4bit=True, ) self.assertTrue(model_inputs["input_ids"].shape[1] > model.config.max_position_embeddings) generate_kwargs = {"max_new_tokens": 40, "do_sample": False} @@ -956,22 +958,36 @@ def test_llama_rope_scaling(self): # Dynamic NTK RoPE scaling -> good output (doesn't need fine-tuning) model = LlamaForCausalLM.from_pretrained( - "meta-llama/Llama-2-7b-hf", device_map="auto", load_in_4bit=True, rope_scaling={"type": "dynamic", "factor": 2.0}, + "meta-llama/Llama-2-7b-hf", + device_map="auto", + load_in_4bit=True, + rope_scaling={"type": "dynamic", "factor": 2.0}, ) generate_kwargs = {"max_new_tokens": 40, "do_sample": False} gen_out = model.generate(**model_inputs, **generate_kwargs) decoded_text = tokenizer.decode(gen_out[0], skip_special_tokens=True) - self.assertTrue(decoded_text.endswith("The paper is about the release of Llama 2, a family of pretrained and fine-tuned large language models.\nWhat is Llama 2?\n")) + self.assertTrue( + decoded_text.endswith( + "The paper is about the release of Llama 2, a family of pretrained and fine-tuned large language models.\nWhat is Llama 2?\n" + ) + ) # Note: the output above matches our initial release of RoPE scaling # Linear RoPE scaling -> usualy okay output (should be used with fine-tuning) model = LlamaForCausalLM.from_pretrained( - "meta-llama/Llama-2-7b-hf", device_map="auto", load_in_4bit=True, rope_scaling={"type": "linear", "factor": 2.0}, + "meta-llama/Llama-2-7b-hf", + device_map="auto", + load_in_4bit=True, + rope_scaling={"type": "linear", "factor": 2.0}, ) generate_kwargs = {"max_new_tokens": 40, "do_sample": False} gen_out = model.generate(**model_inputs, **generate_kwargs) decoded_text = tokenizer.decode(gen_out[0], skip_special_tokens=True) - self.assertTrue(decoded_text.endswith("The paper is about the development of Llama 2, a large language model (LLM) family, and the release of Llama 2-Chat, a fine-")) + self.assertTrue( + decoded_text.endswith( + "The paper is about the development of Llama 2, a large language model (LLM) family, and the release of Llama 2-Chat, a fine-" + ) + ) @require_torch From 0d39641cabca50b12e0e34143d600f5f9e1395df Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Wed, 27 Mar 2024 17:54:15 +0000 Subject: [PATCH 3/4] quick rope scaling tests --- tests/models/falcon/test_modeling_falcon.py | 66 ++- .../models/gpt_neox/test_modeling_gpt_neox.py | 66 ++- tests/models/llama/test_modeling_llama.py | 405 +++--------------- .../persimmon/test_modeling_persimmon.py | 68 ++- tests/models/phi/test_modeling_phi.py | 98 ++++- .../models/stablelm/test_modeling_stablelm.py | 66 ++- 6 files changed, 427 insertions(+), 342 deletions(-) diff --git a/tests/models/falcon/test_modeling_falcon.py b/tests/models/falcon/test_modeling_falcon.py index fa7ea2af816cb0..346c62cd66b5cd 100644 --- a/tests/models/falcon/test_modeling_falcon.py +++ b/tests/models/falcon/test_modeling_falcon.py @@ -45,6 +45,11 @@ FalconForTokenClassification, FalconModel, ) + from transformers.models.falcon.modeling_falcon import ( + FalconDynamicNTKScalingRotaryEmbedding, + FalconLinearScalingRotaryEmbedding, + FalconRotaryEmbedding, + ) class FalconModelTester: @@ -408,7 +413,7 @@ def test_past_key_values_format(self): ) 
@parameterized.expand([("linear",), ("dynamic",)]) - def test_model_rope_scaling(self, scaling_type): + def test_model_rope_scaling_from_config(self, scaling_type): config, _ = self.model_tester.prepare_config_and_inputs_for_common() short_input = ids_tensor([1, 10], config.vocab_size) long_input = ids_tensor([1, int(config.max_position_embeddings * 1.5)], config.vocab_size) @@ -438,6 +443,65 @@ def test_model_rope_scaling(self, scaling_type): # The output should be different for long inputs self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5)) + def test_model_rope_scaling(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + hidden_size = config.hidden_size + num_heads = config.num_attention_heads + head_dim = hidden_size // num_heads + scaling_factor = 10 + short_input_length = 10 + long_input_length = int(config.max_position_embeddings * 1.5) + + # Inputs + x = torch.randn(1, dtype=torch.float32, device=torch_device) # used exlusively to get the dtype and the device + + # Sanity check original RoPE + original_rope = FalconRotaryEmbedding( + head_dim, + max_position_embeddings=config.max_position_embeddings, + base=config.rope_theta, + ).to(torch_device) + original_cos_short, original_sin_short = original_rope(x, short_input_length) + original_cos_long, original_sin_long = original_rope(x, long_input_length) + torch.testing.assert_close(original_cos_short, original_cos_long[:short_input_length, :]) + torch.testing.assert_close(original_sin_short, original_sin_long[:short_input_length, :]) + + # Sanity check linear RoPE scaling + # New position "x" should match original position with index "x/scaling_factor" + linear_scaling_rope = FalconLinearScalingRotaryEmbedding( + head_dim, + max_position_embeddings=config.max_position_embeddings, + base=config.rope_theta, + scaling_factor=scaling_factor, + ).to(torch_device) + linear_cos_short, linear_sin_short = linear_scaling_rope(x, short_input_length) + linear_cos_long, linear_sin_long = linear_scaling_rope(x, long_input_length) + torch.testing.assert_close(linear_cos_short, linear_cos_long[:short_input_length, :]) + torch.testing.assert_close(linear_sin_short, linear_sin_long[:short_input_length, :]) + for new_position in range(0, long_input_length, scaling_factor): + original_position = int(new_position // scaling_factor) + torch.testing.assert_close(linear_cos_long[new_position, :], original_cos_long[original_position, :]) + torch.testing.assert_close(linear_sin_long[new_position, :], original_sin_long[original_position, :]) + + # Sanity check Dynamic NTK RoPE scaling + # Scaling should only be observed after a long input is fed. 
We can observe that the frequencies increase + # with scaling_factor (or that `inv_freq` decreases) + ntk_scaling_rope = FalconDynamicNTKScalingRotaryEmbedding( + head_dim, + max_position_embeddings=config.max_position_embeddings, + base=config.rope_theta, + scaling_factor=scaling_factor, + ).to(torch_device) + ntk_cos_short, ntk_sin_short = ntk_scaling_rope(x, short_input_length) + ntk_cos_long, ntk_sin_long = ntk_scaling_rope(x, long_input_length) + torch.testing.assert_close(ntk_cos_short, original_cos_short) + torch.testing.assert_close(ntk_sin_short, original_sin_short) + with self.assertRaises(AssertionError): + torch.testing.assert_close(ntk_cos_long, original_cos_long) + with self.assertRaises(AssertionError): + torch.testing.assert_close(ntk_sin_long, original_sin_long) + self.assertTrue((ntk_scaling_rope.inv_freq <= original_rope.inv_freq).all()) + @require_torch_sdpa @slow def test_eager_matches_sdpa_generate(self): diff --git a/tests/models/gpt_neox/test_modeling_gpt_neox.py b/tests/models/gpt_neox/test_modeling_gpt_neox.py index 19e3db2a61fb91..00065a7006ffff 100644 --- a/tests/models/gpt_neox/test_modeling_gpt_neox.py +++ b/tests/models/gpt_neox/test_modeling_gpt_neox.py @@ -38,6 +38,11 @@ GPTNeoXForTokenClassification, GPTNeoXModel, ) + from transformers.models.gpt_neox.modeling_gpt_neox import ( + GPTNeoXDynamicNTKScalingRotaryEmbedding, + GPTNeoXLinearScalingRotaryEmbedding, + GPTNeoXRotaryEmbedding, + ) class GPTNeoXModelTester: @@ -301,7 +306,7 @@ def test_feed_forward_chunking(self): pass @parameterized.expand([("linear",), ("dynamic",)]) - def test_model_rope_scaling(self, scaling_type): + def test_model_rope_scaling_from_config(self, scaling_type): config, _ = self.model_tester.prepare_config_and_inputs_for_common() short_input = ids_tensor([1, 10], config.vocab_size) long_input = ids_tensor([1, int(config.max_position_embeddings * 1.5)], config.vocab_size) @@ -331,6 +336,65 @@ def test_model_rope_scaling(self, scaling_type): # The output should be different for long inputs self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5)) + def test_model_rope_scaling(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + hidden_size = config.hidden_size + num_heads = config.num_attention_heads + head_dim = hidden_size // num_heads + scaling_factor = 10 + short_input_length = 10 + long_input_length = int(config.max_position_embeddings * 1.5) + + # Inputs + x = torch.randn(1, dtype=torch.float32, device=torch_device) # used exlusively to get the dtype and the device + + # Sanity check original RoPE + original_rope = GPTNeoXRotaryEmbedding( + head_dim, + max_position_embeddings=config.max_position_embeddings, + base=config.rotary_emb_base, + ).to(torch_device) + original_cos_short, original_sin_short = original_rope(x, short_input_length) + original_cos_long, original_sin_long = original_rope(x, long_input_length) + torch.testing.assert_close(original_cos_short, original_cos_long[:short_input_length, :]) + torch.testing.assert_close(original_sin_short, original_sin_long[:short_input_length, :]) + + # Sanity check linear RoPE scaling + # New position "x" should match original position with index "x/scaling_factor" + linear_scaling_rope = GPTNeoXLinearScalingRotaryEmbedding( + head_dim, + max_position_embeddings=config.max_position_embeddings, + base=config.rotary_emb_base, + scaling_factor=scaling_factor, + ).to(torch_device) + linear_cos_short, linear_sin_short = linear_scaling_rope(x, short_input_length) + linear_cos_long, 
linear_sin_long = linear_scaling_rope(x, long_input_length) + torch.testing.assert_close(linear_cos_short, linear_cos_long[:short_input_length, :]) + torch.testing.assert_close(linear_sin_short, linear_sin_long[:short_input_length, :]) + for new_position in range(0, long_input_length, scaling_factor): + original_position = int(new_position // scaling_factor) + torch.testing.assert_close(linear_cos_long[new_position, :], original_cos_long[original_position, :]) + torch.testing.assert_close(linear_sin_long[new_position, :], original_sin_long[original_position, :]) + + # Sanity check Dynamic NTK RoPE scaling + # Scaling should only be observed after a long input is fed. We can observe that the frequencies increase + # with scaling_factor (or that `inv_freq` decreases) + ntk_scaling_rope = GPTNeoXDynamicNTKScalingRotaryEmbedding( + head_dim, + max_position_embeddings=config.max_position_embeddings, + base=config.rotary_emb_base, + scaling_factor=scaling_factor, + ).to(torch_device) + ntk_cos_short, ntk_sin_short = ntk_scaling_rope(x, short_input_length) + ntk_cos_long, ntk_sin_long = ntk_scaling_rope(x, long_input_length) + torch.testing.assert_close(ntk_cos_short, original_cos_short) + torch.testing.assert_close(ntk_sin_short, original_sin_short) + with self.assertRaises(AssertionError): + torch.testing.assert_close(ntk_cos_long, original_cos_long) + with self.assertRaises(AssertionError): + torch.testing.assert_close(ntk_sin_long, original_sin_long) + self.assertTrue((ntk_scaling_rope.inv_freq <= original_rope.inv_freq).all()) + @require_torch class GPTNeoXLanguageGenerationTest(unittest.TestCase): diff --git a/tests/models/llama/test_modeling_llama.py b/tests/models/llama/test_modeling_llama.py index 6eb429f874ace8..e0a3990bd8de30 100644 --- a/tests/models/llama/test_modeling_llama.py +++ b/tests/models/llama/test_modeling_llama.py @@ -51,6 +51,11 @@ LlamaModel, LlamaTokenizer, ) + from transformers.models.llama.modeling_llama import ( + LlamaDynamicNTKScalingRotaryEmbedding, + LlamaLinearScalingRotaryEmbedding, + LlamaRotaryEmbedding, + ) class LlamaModelTester: @@ -370,7 +375,7 @@ def test_save_load_fast_init_from_base(self): pass @parameterized.expand([("linear",), ("dynamic",)]) - def test_model_rope_scaling(self, scaling_type): + def test_model_rope_scaling_from_config(self, scaling_type): config, _ = self.model_tester.prepare_config_and_inputs_for_common() short_input = ids_tensor([1, 10], config.vocab_size) long_input = ids_tensor([1, int(config.max_position_embeddings * 1.5)], config.vocab_size) @@ -400,6 +405,69 @@ def test_model_rope_scaling(self, scaling_type): # The output should be different for long inputs self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5)) + def test_model_rope_scaling(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + hidden_size = config.hidden_size + num_heads = config.num_attention_heads + head_dim = hidden_size // num_heads + scaling_factor = 10 + short_input_length = 10 + long_input_length = int(config.max_position_embeddings * 1.5) + + # Inputs + x = torch.randn(1, dtype=torch.float32, device=torch_device) # used exlusively to get the dtype and the device + position_ids_short = torch.arange(short_input_length, dtype=torch.long, device=torch_device) + position_ids_short = position_ids_short.unsqueeze(0) + position_ids_long = torch.arange(long_input_length, dtype=torch.long, device=torch_device) + position_ids_long = position_ids_long.unsqueeze(0) + + # Sanity check original RoPE + 
original_rope = LlamaRotaryEmbedding( + head_dim, + max_position_embeddings=config.max_position_embeddings, + base=config.rope_theta, + ).to(torch_device) + original_cos_short, original_sin_short = original_rope(x, position_ids_short) + original_cos_long, original_sin_long = original_rope(x, position_ids_long) + torch.testing.assert_close(original_cos_short, original_cos_long[:, :short_input_length, :]) + torch.testing.assert_close(original_sin_short, original_sin_long[:, :short_input_length, :]) + + # Sanity check linear RoPE scaling + # New position "x" should match original position with index "x/scaling_factor" + linear_scaling_rope = LlamaLinearScalingRotaryEmbedding( + head_dim, + max_position_embeddings=config.max_position_embeddings, + base=config.rope_theta, + scaling_factor=scaling_factor, + ).to(torch_device) + linear_cos_short, linear_sin_short = linear_scaling_rope(x, position_ids_short) + linear_cos_long, linear_sin_long = linear_scaling_rope(x, position_ids_long) + torch.testing.assert_close(linear_cos_short, linear_cos_long[:, :short_input_length, :]) + torch.testing.assert_close(linear_sin_short, linear_sin_long[:, :short_input_length, :]) + for new_position in range(0, long_input_length, scaling_factor): + original_position = int(new_position // scaling_factor) + torch.testing.assert_close(linear_cos_long[:, new_position, :], original_cos_long[:, original_position, :]) + torch.testing.assert_close(linear_sin_long[:, new_position, :], original_sin_long[:, original_position, :]) + + # Sanity check Dynamic NTK RoPE scaling + # Scaling should only be observed after a long input is fed. We can observe that the frequencies increase + # with scaling_factor (or that `inv_freq` decreases) + ntk_scaling_rope = LlamaDynamicNTKScalingRotaryEmbedding( + head_dim, + max_position_embeddings=config.max_position_embeddings, + base=config.rope_theta, + scaling_factor=scaling_factor, + ).to(torch_device) + ntk_cos_short, ntk_sin_short = ntk_scaling_rope(x, position_ids_short) + ntk_cos_long, ntk_sin_long = ntk_scaling_rope(x, position_ids_long) + torch.testing.assert_close(ntk_cos_short, original_cos_short) + torch.testing.assert_close(ntk_sin_short, original_sin_short) + with self.assertRaises(AssertionError): + torch.testing.assert_close(ntk_cos_long, original_cos_long) + with self.assertRaises(AssertionError): + torch.testing.assert_close(ntk_sin_long, original_sin_long) + self.assertTrue((ntk_scaling_rope.inv_freq <= original_rope.inv_freq).all()) + @require_flash_attn @require_torch_gpu @require_bitsandbytes @@ -654,341 +722,6 @@ def decode_one_tokens(model, cur_token, input_pos, cache_position): text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) self.assertEqual(EXPECTED_TEXT_COMPLETION, text) - @slow - @require_torch_gpu - @require_bitsandbytes - def test_llama_rope_scaling(self): - # Tests that RoPE scaling works as expected on Llama. - # Note: although this test doesn't take long to run, it requires ~13GB of GPU memory as of 2024-03 - - # The first sections of the Llama 2 paper. Input with >6k tokens, larger than the 4k model context window - tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", use_fast=True) - VERY_LONG_INPUT = """ -You are given a partial and unparsed scientific article, please read it carefully and answer the follow up question. 
- -== BEGIN ARTICLE == - -Llama 2 : Open Foundation and Fine-Tuned Chat Models -Hugo Touvron∗Louis Martin†Kevin Stone† -Peter Albert Amjad Almahairi Yasmine Babaei Nikolay Bashlykov Soumya Batra -Prajjwal Bhargava Shruti Bhosale Dan Bikel Lukas Blecher Cristian Canton Ferrer Moya Chen -Guillem Cucurull David Esiobu Jude Fernandes Jeremy Fu Wenyin Fu Brian Fuller -Cynthia Gao Vedanuj Goswami Naman Goyal Anthony Hartshorn Saghar Hosseini Rui Hou -Hakan Inan Marcin Kardas Viktor Kerkez Madian Khabsa Isabel Kloumann Artem Korenev -Punit Singh Koura Marie-Anne Lachaux Thibaut Lavril Jenya Lee Diana Liskovich -Yinghai Lu Yuning Mao Xavier Martinet Todor Mihaylov Pushkar Mishra -Igor Molybog Yixin Nie Andrew Poulton Jeremy Reizenstein Rashi Rungta Kalyan Saladi -Alan Schelten Ruan Silva Eric Michael Smith Ranjan Subramanian Xiaoqing Ellen Tan Binh Tang -Ross Taylor Adina Williams Jian Xiang Kuan Puxin Xu Zheng Yan Iliyan Zarov Yuchen Zhang -Angela Fan Melanie Kambadur Sharan Narang Aurelien Rodriguez Robert Stojnic -Sergey Edunov Thomas Scialom∗ -GenAI, Meta -Abstract -In this work, we develop and release Llama 2, a collection of pretrained and fine-tuned -large language models (LLMs) ranging in scale from 7 billion to 70 billion parameters. -Our fine-tuned LLMs, called Llama 2-Chat , are optimized for dialogue use cases. Our -models outperform open-source chat models on most benchmarks we tested, and based on -ourhumanevaluationsforhelpfulnessandsafety,maybeasuitablesubstituteforclosed- -source models. We provide a detailed description of our approach to fine-tuning and safety -improvements of Llama 2-Chat in order to enable the community to build on our work and -contribute to the responsible development of LLMs. -∗Equal contribution, corresponding authors: {tscialom, htouvron}@meta.com -†Second author -2 -Figure 1: Helpfulness human evaluation results for Llama -2-Chatcomparedtootheropen-sourceandclosed-source -models. Human raters compared model generations on ~4k -promptsconsistingofbothsingleandmulti-turnprompts. -The95%confidenceintervalsforthisevaluationarebetween -1%and2%. MoredetailsinSection3.4.2. Whilereviewing -these results, it is important to note that human evaluations -canbenoisyduetolimitationsofthepromptset,subjectivity -of the review guidelines, subjectivity of individual raters, -and the inherent difficulty of comparing generations. -Figure 2: Win-rate % for helpfulness and -safety between commercial-licensed base- -lines and Llama 2-Chat , according to GPT- -4. Tocomplementthehumanevaluation,we -used a more capable model, not subject to -ourownguidance. Greenareaindicatesour -modelisbetteraccordingtoGPT-4. Toremove -ties, we used win/ (win+loss). The orders in -whichthemodelresponsesarepresentedto -GPT-4arerandomlyswappedtoalleviatebias. -1 Introduction -Large Language Models (LLMs) have shown great promise as highly capable AI assistants that excel in -complex reasoning tasks requiring expert knowledge across a wide range of fields, including in specialized -domains such as programming and creative writing. They enable interaction with humans through intuitive -chat interfaces, which has led to rapid and widespread adoption among the general public. -ThecapabilitiesofLLMsareremarkableconsideringtheseeminglystraightforwardnatureofthetraining -methodology. 
Auto-regressivetransformersarepretrainedonanextensivecorpusofself-superviseddata, -followed by alignment with human preferences via techniques such as Reinforcement Learning with Human -Feedback(RLHF).Althoughthetrainingmethodologyissimple,highcomputationalrequirementshave -limited the development of LLMs to a few players. There have been public releases of pretrained LLMs -(such as BLOOM (Scao et al., 2022), LLaMa-1 (Touvron et al., 2023), and Falcon (Penedo et al., 2023)) that -match the performance of closed pretrained competitors like GPT-3 (Brown et al., 2020) and Chinchilla -(Hoffmann et al., 2022), but none of these models are suitable substitutes for closed “product” LLMs, such -asChatGPT,BARD,andClaude. TheseclosedproductLLMsareheavilyfine-tunedtoalignwithhuman -preferences, which greatly enhances their usability and safety. This step can require significant costs in -computeandhumanannotation,andisoftennottransparentoreasilyreproducible,limitingprogresswithin -the community to advance AI alignment research. -In this work, we develop and release Llama 2, a family of pretrained and fine-tuned LLMs, Llama 2 and -Llama 2-Chat , at scales up to 70B parameters. On the series of helpfulness and safety benchmarks we tested, -Llama 2-Chat models generally perform better than existing open-source models. They also appear to -be on par with some of the closed-source models, at least on the human evaluations we performed (see -Figures1and3). Wehavetakenmeasurestoincreasethesafetyofthesemodels,usingsafety-specificdata -annotation and tuning, as well as conducting red-teaming and employing iterative evaluations. Additionally, -thispapercontributesathoroughdescriptionofourfine-tuningmethodologyandapproachtoimproving -LLM safety. We hope that this openness will enable the community to reproduce fine-tuned LLMs and -continue to improve the safety of those models, paving the way for more responsible development of LLMs. -Wealsosharenovelobservationswemadeduringthedevelopmentof Llama 2 andLlama 2-Chat ,suchas -the emergence of tool usage and temporal organization of knowledge. -3 -Figure 3: Safety human evaluation results for Llama 2-Chat compared to other open-source and closed- -source models. Human raters judged model generations for safety violations across ~2,000 adversarial -prompts consisting of both single and multi-turn prompts. More details can be found in Section 4.4. It is -importanttocaveatthesesafetyresultswiththeinherentbiasofLLMevaluationsduetolimitationsofthe -promptset,subjectivityofthereviewguidelines,andsubjectivityofindividualraters. Additionally,these -safety evaluations are performed using content standards that are likely to be biased towards the Llama -2-Chatmodels. -We are releasing the following models to the general public for research and commercial use‡: -1.Llama 2 ,anupdatedversionof Llama 1,trainedonanewmixofpubliclyavailabledata. Wealso -increasedthesizeofthepretrainingcorpusby40%,doubledthecontextlengthofthemodel,and -adoptedgrouped-queryattention(Ainslieetal.,2023). Wearereleasingvariantsof Llama 2 with -7B,13B,and70Bparameters. Wehavealsotrained34Bvariants,whichwereportoninthispaper -but are not releasing.§ -2.Llama 2-Chat , a fine-tuned version of Llama 2 that is optimized for dialogue use cases. We release -variants of this model with 7B, 13B, and 70B parameters as well. -WebelievethattheopenreleaseofLLMs,whendonesafely,willbeanetbenefittosociety. 
LikeallLLMs, -Llama 2 is a new technology that carries potential risks with use (Bender et al., 2021b; Weidinger et al., 2021; -Solaimanet al.,2023). Testingconductedtodate hasbeeninEnglish andhasnot— andcouldnot— cover -all scenarios. Therefore, before deploying any applications of Llama 2-Chat , developers should perform -safetytestingand tuningtailoredtotheirspecificapplicationsofthemodel. Weprovidearesponsibleuse -guide¶and code examples‖to facilitate the safe deployment of Llama 2 andLlama 2-Chat . More details of -our responsible release strategy can be found in Section 5.3. -Theremainderofthispaperdescribesourpretrainingmethodology(Section2),fine-tuningmethodology -(Section 3), approach to model safety (Section 4), key observations and insights (Section 5), relevant related -work (Section 6), and conclusions (Section 7). -‡https://ai.meta.com/resources/models-and-libraries/llama/ -§We are delaying the release of the 34B model due to a lack of time to sufficiently red team. -¶https://ai.meta.com/llama -‖https://github.com/facebookresearch/llama -4 -Figure4: Trainingof Llama 2-Chat : Thisprocessbeginswiththe pretraining ofLlama 2 usingpublicly -availableonlinesources. Followingthis,wecreateaninitialversionof Llama 2-Chat throughtheapplication -ofsupervised fine-tuning . Subsequently, the model is iteratively refined using Reinforcement Learning -with Human Feedback (RLHF) methodologies, specifically through rejection sampling and Proximal Policy -Optimization(PPO).ThroughouttheRLHFstage,theaccumulationof iterativerewardmodelingdata in -parallel with model enhancements is crucial to ensure the reward models remain within distribution. -2 Pretraining -Tocreatethenewfamilyof Llama 2models,webeganwiththepretrainingapproachdescribedinTouvronetal. -(2023), using an optimized auto-regressive transformer, but made several changes to improve performance. -Specifically,weperformedmorerobustdatacleaning,updatedourdatamixes,trainedon40%moretotal -tokens,doubledthecontextlength,andusedgrouped-queryattention(GQA)toimproveinferencescalability -for our larger models. Table 1 compares the attributes of the new Llama 2 models with the Llama 1 models. -2.1 Pretraining Data -Our training corpus includes a new mix of data from publicly available sources, which does not include data -fromMeta’sproductsorservices. Wemadeanefforttoremovedatafromcertainsitesknowntocontaina -highvolumeofpersonalinformationaboutprivateindividuals. Wetrainedon2trilliontokensofdataasthis -providesagoodperformance–costtrade-off,up-samplingthemostfactualsourcesinanefforttoincrease -knowledge and dampen hallucinations. -Weperformedavarietyofpretrainingdatainvestigationssothatuserscanbetterunderstandthepotential -capabilities and limitations of our models; results can be found in Section 4.1. -2.2 Training Details -We adopt most of the pretraining setting and model architecture from Llama 1 . We use the standard -transformer architecture (Vaswani et al., 2017), apply pre-normalization using RMSNorm (Zhang and -Sennrich, 2019), use the SwiGLU activation function (Shazeer, 2020), and rotary positional embeddings -(RoPE, Su et al. 2022). The primary architectural differences from Llama 1 include increased context length -andgrouped-queryattention(GQA).WedetailinAppendixSectionA.2.1eachofthesedifferenceswith -ablation experiments to demonstrate their importance. -Hyperparameters. We trained using the AdamW optimizer (Loshchilov and Hutter, 2017), with β1= -0.9, β2= 0.95,eps= 10−5. 
We use a cosine learning rate schedule, with warmup of 2000 steps, and decay -finallearningratedownto10%ofthepeaklearningrate. Weuseaweightdecayof 0.1andgradientclipping -of1.0. Figure 5 (a) shows the training loss for Llama 2 with these hyperparameters. -5 -Training Data Params Context -LengthGQA Tokens LR -Llama 1See Touvron et al. -(2023)7B 2k ✗ 1.0T 3.0×10−4 -13B 2k ✗ 1.0T 3.0×10−4 -33B 2k ✗ 1.4T 1.5×10−4 -65B 2k ✗ 1.4T 1.5×10−4 -Llama 2A new mix of publicly -available online data7B 4k ✗ 2.0T 3.0×10−4 -13B 4k ✗ 2.0T 3.0×10−4 -34B 4k ✓ 2.0T 1.5×10−4 -70B 4k ✓ 2.0T 1.5×10−4 -Table 1: Llama 2 family of models. Token counts refer to pretraining data only. All models are trained with -a global batch-size of 4M tokens. Bigger models — 34B and 70B — use Grouped-Query Attention (GQA) for -improved inference scalability. -0 250 500 750 1000 1250 1500 1750 2000 -Processed Tokens (Billions)1.41.51.61.71.81.92.02.12.2Train PPLLlama-2 -7B -13B -34B -70B -Figure 5: Training Loss for Llama 2 models. We compare the training loss of the Llama 2 family of models. -We observe that after pretraining on 2T Tokens, the models still did not show any sign of saturation. -Tokenizer. Weusethesametokenizeras Llama 1;itemploysabytepairencoding(BPE)algorithm(Sennrich -etal.,2016)usingtheimplementationfromSentencePiece(KudoandRichardson,2018). Aswith Llama 1, -we split all numbers into individual digits and use bytes to decompose unknown UTF-8 characters. The total -vocabulary size is 32k tokens. -2.2.1 Training Hardware & Carbon Footprint -TrainingHardware. WepretrainedourmodelsonMeta’sResearchSuperCluster(RSC)(LeeandSengupta, -2022)aswellasinternalproductionclusters. BothclustersuseNVIDIAA100s. Therearetwokeydifferences -between the two clusters, with the first being the type of interconnect available: RSC uses NVIDIA Quantum -InfiniBandwhileourproductionclusterisequippedwithaRoCE(RDMAoverconvergedEthernet)solution -based on commodity ethernet Switches. Both of these solutions interconnect 200 Gbps end-points. The -seconddifferenceistheper-GPUpowerconsumptioncap—RSCuses400Wwhileourproductioncluster -uses350W.Withthistwo-clustersetup,wewereabletocomparethesuitabilityofthesedifferenttypesof -interconnectforlargescaletraining. RoCE(whichisamoreaffordable,commercialinterconnectnetwork) -6 -Time -(GPU hours)Power -Consumption (W)Carbon Emitted -(tCO 2eq) -Llama 27B 184320 400 31.22 -13B 368640 400 62.44 -34B 1038336 350 153.90 -70B 1720320 400 291.42 -Total 3311616 539.00 -Table 2: CO2emissions during pretraining. Time: total GPU time required for training each model. Power -Consumption: peak power capacity per GPU device for the GPUs used adjusted for power usage efficiency. -100%oftheemissionsaredirectlyoffsetbyMeta’ssustainabilityprogram,andbecauseweareopenlyreleasing -these models, the pretraining costs do not need to be incurred by others. -can scale almost as well as expensive Infiniband up to 2000 GPUs, which makes pretraining even more -democratizable. On A100s with RoCE and GPU power capped at 350W, our optimized codebase reached up -to 90% of the performance of RSC using IB interconnect and 400W GPU power. -Carbon Footprint of Pretraining. Following preceding research (Bender et al., 2021a; Patterson et al., 2021; -Wu et al., 2022; Dodge et al., 2022) and using power consumption estimates of GPU devices and carbon -efficiency, we aim tocalculate thecarbon emissions resultingfrom the pretrainingof Llama 2 models. 
The -actualpowerusageofaGPUisdependentonitsutilizationandislikelytovaryfromtheThermalDesign -Power(TDP)thatweemployasanestimationforGPUpower. Itisimportanttonotethatourcalculations -do not account for further power demands, such as those from interconnect or non-GPU server power -consumption,norfromdatacentercoolingsystems. Additionally,thecarbonoutputrelatedtotheproduction -of AI hardware, like GPUs, could add to the overall carbon footprint as suggested by Gupta et al. (2022b,a). -Table 2 summarizes the carbon emission for pretraining the Llama 2 family of models. A cumulative of -3.3M GPUhours ofcomputation wasperformed onhardware oftype A100-80GB (TDPof 400Wor 350W). -We estimate the total emissions for training to be 539 tCO 2eq, of which 100% were directly offset by Meta’s -sustainability program.∗∗Our open release strategy also means that these pretraining costs will not need to -be incurred by other companies, saving more global resources. -2.3 Llama 2 Pretrained Model Evaluation -In this section, we report the results for the Llama 1 andLlama 2 base models, MosaicML Pretrained -Transformer(MPT)††models,andFalcon(Almazroueietal.,2023)modelsonstandardacademicbenchmarks. -For all the evaluations, we use our internal evaluations library. We reproduce results for the MPT and Falcon -modelsinternally. Forthesemodels,wealwayspickthebestscorebetweenourevaluationframeworkand -any publicly reported results. -InTable3,wesummarizetheoverallperformanceacrossasuiteofpopularbenchmarks. Notethatsafety -benchmarks are shared in Section 4.1. The benchmarks are grouped into the categories listed below. The -results for all the individual benchmarks are available in Section A.2.2. -•Code.Wereporttheaveragepass@1scoresofourmodelsonHumanEval(Chenetal.,2021)and -MBPP (Austin et al., 2021). -•CommonsenseReasoning. WereporttheaverageofPIQA(Bisketal.,2020),SIQA(Sapetal.,2019), -HellaSwag (Zellers et al., 2019a), WinoGrande (Sakaguchi et al., 2021), ARC easy and challenge -(Clark et al., 2018), OpenBookQA (Mihaylov et al., 2018), and CommonsenseQA (Talmor et al., -2018). We report 7-shot results for CommonSenseQA and 0-shot results for all other benchmarks. -•World Knowledge. We evaluate the 5-shot performance on NaturalQuestions (Kwiatkowski et al., -2019) and TriviaQA (Joshi et al., 2017) and report the average. -•Reading Comprehension. For reading comprehension, we report the 0-shot average on SQuAD -(Rajpurkar et al., 2018), QuAC (Choi et al., 2018), and BoolQ (Clark et al., 2019). -∗∗https://sustainability.fb.com/2021-sustainability-report/ -††https://www.mosaicml.com/blog/mpt-7b -7 -Model Size CodeCommonsense -ReasoningWorld -KnowledgeReading -ComprehensionMath MMLU BBH AGI Eval -MPT7B 20.5 57.4 41.0 57.5 4.9 26.8 31.0 23.5 -30B 28.9 64.9 50.0 64.7 9.1 46.9 38.0 33.8 -Falcon7B 5.6 56.1 42.8 36.0 4.6 26.2 28.0 21.2 -40B 15.2 69.2 56.7 65.7 12.6 55.4 37.1 37.0 -Llama 17B 14.1 60.8 46.2 58.5 6.95 35.1 30.3 23.9 -13B 18.9 66.1 52.6 62.3 10.9 46.9 37.0 33.9 -33B 26.0 70.0 58.4 67.6 21.4 57.8 39.8 41.7 -65B 30.7 70.7 60.5 68.6 30.8 63.4 43.5 47.6 -Llama 27B 16.8 63.9 48.9 61.3 14.6 45.3 32.6 29.3 -13B 24.5 66.9 55.4 65.8 28.7 54.8 39.4 39.1 -34B 27.8 69.9 58.7 68.0 24.2 62.6 44.1 43.4 -70B37.5 71.9 63.6 69.4 35.2 68.9 51.2 54.2 -Table3: Overallperformanceongroupedacademicbenchmarkscomparedtoopen-sourcebasemodels. -•MATH. We report the average of the GSM8K (8 shot) (Cobbe et al., 2021) and MATH (4 shot) -(Hendrycks et al., 2021) benchmarks at top 1. -•Popular Aggregated Benchmarks . 
We report the overall results for MMLU (5 shot) (Hendrycks -et al., 2020), Big Bench Hard (BBH) (3 shot) (Suzgun et al., 2022), and AGI Eval (3–5 shot) (Zhong -et al., 2023). For AGI Eval, we only evaluate on the English tasks and report the average. -As shown in Table 3, Llama 2 models outperform Llama 1 models. In particular, Llama 2 70B improves the -resultsonMMLUandBBHby ≈5and≈8points,respectively,comparedto Llama 1 65B.Llama 2 7Band30B -modelsoutperformMPTmodelsofthecorrespondingsizeonallcategoriesbesidescodebenchmarks. Forthe -Falcon models, Llama 2 7B and 34B outperform Falcon 7B and 40B models on all categories of benchmarks. -Additionally, Llama 2 70B model outperforms all open-source models. -In addition to open-source models, we also compare Llama 2 70B results to closed-source models. As shown -in Table 4, Llama 2 70B is close to GPT-3.5 (OpenAI, 2023) on MMLU and GSM8K, but there is a significant -gaponcodingbenchmarks. Llama 2 70BresultsareonparorbetterthanPaLM(540B)(Chowdheryetal., -2022)onalmostallbenchmarks. Thereisstillalargegapinperformancebetween Llama 2 70BandGPT-4 -and PaLM-2-L. -We also analysed the potential data contamination and share the details in Section A.6. -Benchmark (shots) GPT-3.5 GPT-4 PaLM PaLM-2-L Llama 2 -MMLU (5-shot) 70.0 86.4 69.3 78.3 68.9 -TriviaQA (1-shot) – – 81.4 86.1 85.0 -Natural Questions (1-shot) – – 29.3 37.5 33.0 -GSM8K (8-shot) 57.1 92.0 56.5 80.7 56.8 -HumanEval (0-shot) 48.1 67.0 26.2 – 29.9 -BIG-Bench Hard (3-shot) – – 52.3 65.7 51.2 -Table 4: Comparison to closed-source models on academic benchmarks. Results for GPT-3.5 and GPT-4 -are from OpenAI (2023). Results for the PaLM model are from Chowdhery et al. (2022). Results for the -PaLM-2-L are from Anil et al. (2023). - -== END ARTICLE == - -""" - question = "What is the paper about?" 
+        model_inputs = tokenizer(VERY_LONG_INPUT + question, return_tensors="pt").to(torch_device)
+
+        # No RoPE scaling -> garbage output
+        model = LlamaForCausalLM.from_pretrained(
+            "meta-llama/Llama-2-7b-hf",
+            device_map="auto",
+            load_in_4bit=True,
+        )
+        self.assertTrue(model_inputs["input_ids"].shape[1] > model.config.max_position_embeddings)
+        generate_kwargs = {"max_new_tokens": 40, "do_sample": False}
+        gen_out = model.generate(**model_inputs, **generate_kwargs)
+        decoded_text = tokenizer.decode(gen_out[0], skip_special_tokens=True)
+        self.assertTrue(decoded_text.endswith("Ћ\nЋ\nЋЋЋЋЋЋЋЋЋ\nЋ\nЋ\n"))
+
+        # Dynamic NTK RoPE scaling -> good output (doesn't need fine-tuning)
+        model = LlamaForCausalLM.from_pretrained(
+            "meta-llama/Llama-2-7b-hf",
+            device_map="auto",
+            load_in_4bit=True,
+            rope_scaling={"type": "dynamic", "factor": 2.0},
+        )
+        generate_kwargs = {"max_new_tokens": 40, "do_sample": False}
+        gen_out = model.generate(**model_inputs, **generate_kwargs)
+        decoded_text = tokenizer.decode(gen_out[0], skip_special_tokens=True)
+        self.assertTrue(
+            decoded_text.endswith(
+                "The paper is about the release of Llama 2, a family of pretrained and fine-tuned large language models.\nWhat is Llama 2?\n"
+            )
+        )
+        # Note: the output above matches our initial release of RoPE scaling
+
+        # Linear RoPE scaling -> usually okay output (should be used with fine-tuning)
+        model = LlamaForCausalLM.from_pretrained(
+            "meta-llama/Llama-2-7b-hf",
+            device_map="auto",
+            load_in_4bit=True,
+            rope_scaling={"type": "linear", "factor": 2.0},
+        )
+        generate_kwargs = {"max_new_tokens": 40, "do_sample": False}
+        gen_out = model.generate(**model_inputs, **generate_kwargs)
+        decoded_text = tokenizer.decode(gen_out[0], skip_special_tokens=True)
+        self.assertTrue(
+            decoded_text.endswith(
+                "The paper is about the development of Llama 2, a large language model (LLM) family, and the release of Llama 2-Chat, a fine-"
+            )
+        )
+

 @require_torch
 class CodeLlamaIntegrationTest(unittest.TestCase):
diff --git a/tests/models/persimmon/test_modeling_persimmon.py b/tests/models/persimmon/test_modeling_persimmon.py
index 864db992772772..776a9d562aebe9 100644
--- a/tests/models/persimmon/test_modeling_persimmon.py
+++ b/tests/models/persimmon/test_modeling_persimmon.py
@@ -45,6 +45,11 @@
         PersimmonForSequenceClassification,
         PersimmonModel,
     )
+    from transformers.models.persimmon.modeling_persimmon import (
+        PersimmonDynamicNTKScalingRotaryEmbedding,
+        PersimmonLinearScalingRotaryEmbedding,
+        PersimmonRotaryEmbedding,
+    )


 # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester with Llama->Persimmon
@@ -365,8 +370,8 @@ def test_save_load_fast_init_from_base(self):
         pass

     @parameterized.expand([("linear",), ("dynamic",)])
-    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_model_rope_scaling with Llama->Persimmon
-    def test_model_rope_scaling(self, scaling_type):
+    # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_model_rope_scaling_from_config with Llama->Persimmon
+    def test_model_rope_scaling_from_config(self, scaling_type):
         config, _ = self.model_tester.prepare_config_and_inputs_for_common()
         short_input = ids_tensor([1, 10], config.vocab_size)
         long_input = ids_tensor([1, int(config.max_position_embeddings * 1.5)], config.vocab_size)
@@ -396,6 +401,65 @@ def test_model_rope_scaling(self, scaling_type):
         # The output should be different for long inputs
         self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5))

+    def test_model_rope_scaling(self):
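+        # This test exercises the rotary embedding modules directly, without a model forward pass:
+        # the unscaled RoPE should return identical cos/sin for a given position regardless of the
+        # sequence length, linear scaling should map scaled position `p` onto original position
+        # `p / scaling_factor`, and dynamic NTK scaling should be a no-op for short inputs while
+        # lowering `inv_freq` once the input grows past the original maximum sequence length.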
+ config, _ = self.model_tester.prepare_config_and_inputs_for_common() + hidden_size = config.hidden_size + num_heads = config.num_attention_heads + head_dim = hidden_size // num_heads + scaling_factor = 10 + short_input_length = 10 + long_input_length = int(config.max_position_embeddings * 1.5) + + # Inputs + x = torch.randn(1, dtype=torch.float32, device=torch_device) # used exlusively to get the dtype and the device + + # Sanity check original RoPE + original_rope = PersimmonRotaryEmbedding( + head_dim, + max_position_embeddings=config.max_position_embeddings, + base=config.rope_theta, + ).to(torch_device) + original_cos_short, original_sin_short = original_rope(x, short_input_length) + original_cos_long, original_sin_long = original_rope(x, long_input_length) + torch.testing.assert_close(original_cos_short, original_cos_long[:short_input_length, :]) + torch.testing.assert_close(original_sin_short, original_sin_long[:short_input_length, :]) + + # Sanity check linear RoPE scaling + # New position "x" should match original position with index "x/scaling_factor" + linear_scaling_rope = PersimmonLinearScalingRotaryEmbedding( + head_dim, + max_position_embeddings=config.max_position_embeddings, + base=config.rope_theta, + scaling_factor=scaling_factor, + ).to(torch_device) + linear_cos_short, linear_sin_short = linear_scaling_rope(x, short_input_length) + linear_cos_long, linear_sin_long = linear_scaling_rope(x, long_input_length) + torch.testing.assert_close(linear_cos_short, linear_cos_long[:short_input_length, :]) + torch.testing.assert_close(linear_sin_short, linear_sin_long[:short_input_length, :]) + for new_position in range(0, long_input_length, scaling_factor): + original_position = int(new_position // scaling_factor) + torch.testing.assert_close(linear_cos_long[new_position, :], original_cos_long[original_position, :]) + torch.testing.assert_close(linear_sin_long[new_position, :], original_sin_long[original_position, :]) + + # Sanity check Dynamic NTK RoPE scaling + # Scaling should only be observed after a long input is fed. 
We can observe that the frequencies increase + # with scaling_factor (or that `inv_freq` decreases) + ntk_scaling_rope = PersimmonDynamicNTKScalingRotaryEmbedding( + head_dim, + max_position_embeddings=config.max_position_embeddings, + base=config.rope_theta, + scaling_factor=scaling_factor, + ).to(torch_device) + ntk_cos_short, ntk_sin_short = ntk_scaling_rope(x, short_input_length) + ntk_cos_long, ntk_sin_long = ntk_scaling_rope(x, long_input_length) + torch.testing.assert_close(ntk_cos_short, original_cos_short) + torch.testing.assert_close(ntk_sin_short, original_sin_short) + with self.assertRaises(AssertionError): + torch.testing.assert_close(ntk_cos_long, original_cos_long) + with self.assertRaises(AssertionError): + torch.testing.assert_close(ntk_sin_long, original_sin_long) + self.assertTrue((ntk_scaling_rope.inv_freq <= original_rope.inv_freq).all()) + @require_torch class PersimmonIntegrationTest(unittest.TestCase): diff --git a/tests/models/phi/test_modeling_phi.py b/tests/models/phi/test_modeling_phi.py index d69bbb32c1a682..e7702ad0a18a76 100644 --- a/tests/models/phi/test_modeling_phi.py +++ b/tests/models/phi/test_modeling_phi.py @@ -19,8 +19,9 @@ import unittest import pytest +from parameterized import parameterized -from transformers import PhiConfig, is_torch_available +from transformers import PhiConfig, is_torch_available, set_seed from transformers.testing_utils import ( require_bitsandbytes, require_flash_attn, @@ -46,6 +47,11 @@ PhiForTokenClassification, PhiModel, ) + from transformers.models.phi.modeling_phi import ( + PhiDynamicNTKScalingRotaryEmbedding, + PhiLinearScalingRotaryEmbedding, + PhiRotaryEmbedding, + ) class PhiModelTester: @@ -360,6 +366,96 @@ def test_phi_sequence_classification_model_for_multi_label(self): result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) + @parameterized.expand([("linear",), ("dynamic",)]) + def test_model_rope_scaling_from_config(self, scaling_type): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + short_input = ids_tensor([1, 10], config.vocab_size) + long_input = ids_tensor([1, int(config.max_position_embeddings * 1.5)], config.vocab_size) + + set_seed(42) # Fixed seed at init time so the two models get the same random weights + original_model = PhiModel(config) + original_model.to(torch_device) + original_model.eval() + original_short_output = original_model(short_input).last_hidden_state + original_long_output = original_model(long_input).last_hidden_state + + set_seed(42) # Fixed seed at init time so the two models get the same random weights + config.rope_scaling = {"type": scaling_type, "factor": 10.0} + scaled_model = PhiModel(config) + scaled_model.to(torch_device) + scaled_model.eval() + scaled_short_output = scaled_model(short_input).last_hidden_state + scaled_long_output = scaled_model(long_input).last_hidden_state + + # Dynamic scaling does not change the RoPE embeddings until it receives an input longer than the original + # maximum sequence length, so the outputs for the short input should match. 
+ if scaling_type == "dynamic": + self.assertTrue(torch.allclose(original_short_output, scaled_short_output, atol=1e-5)) + else: + self.assertFalse(torch.allclose(original_short_output, scaled_short_output, atol=1e-5)) + + # The output should be different for long inputs + self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5)) + + def test_model_rope_scaling(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + hidden_size = config.hidden_size + num_heads = config.num_attention_heads + head_dim = hidden_size // num_heads + scaling_factor = 10 + short_input_length = 10 + long_input_length = int(config.max_position_embeddings * 1.5) + + # Inputs + x = torch.randn(1, dtype=torch.float32, device=torch_device) # used exlusively to get the dtype and the device + + # Sanity check original RoPE + original_rope = PhiRotaryEmbedding( + head_dim, + max_position_embeddings=config.max_position_embeddings, + base=config.rope_theta, + ).to(torch_device) + original_cos_short, original_sin_short = original_rope(x, short_input_length) + original_cos_long, original_sin_long = original_rope(x, long_input_length) + torch.testing.assert_close(original_cos_short, original_cos_long[:short_input_length, :]) + torch.testing.assert_close(original_sin_short, original_sin_long[:short_input_length, :]) + + # Sanity check linear RoPE scaling + # New position "x" should match original position with index "x/scaling_factor" + linear_scaling_rope = PhiLinearScalingRotaryEmbedding( + head_dim, + max_position_embeddings=config.max_position_embeddings, + base=config.rope_theta, + scaling_factor=scaling_factor, + ).to(torch_device) + linear_cos_short, linear_sin_short = linear_scaling_rope(x, short_input_length) + linear_cos_long, linear_sin_long = linear_scaling_rope(x, long_input_length) + torch.testing.assert_close(linear_cos_short, linear_cos_long[:short_input_length, :]) + torch.testing.assert_close(linear_sin_short, linear_sin_long[:short_input_length, :]) + for new_position in range(0, long_input_length, scaling_factor): + original_position = int(new_position // scaling_factor) + torch.testing.assert_close(linear_cos_long[new_position, :], original_cos_long[original_position, :]) + torch.testing.assert_close(linear_sin_long[new_position, :], original_sin_long[original_position, :]) + + # Sanity check Dynamic NTK RoPE scaling + # Scaling should only be observed after a long input is fed. 
We can observe that the frequencies increase + # with scaling_factor (or that `inv_freq` decreases) + ntk_scaling_rope = PhiDynamicNTKScalingRotaryEmbedding( + head_dim, + max_position_embeddings=config.max_position_embeddings, + base=config.rope_theta, + scaling_factor=scaling_factor, + ).to(torch_device) + ntk_cos_short, ntk_sin_short = ntk_scaling_rope(x, short_input_length) + ntk_cos_long, ntk_sin_long = ntk_scaling_rope(x, long_input_length) + torch.testing.assert_close(ntk_cos_short, original_cos_short) + torch.testing.assert_close(ntk_sin_short, original_sin_short) + with self.assertRaises(AssertionError): + torch.testing.assert_close(ntk_cos_long, original_cos_long) + with self.assertRaises(AssertionError): + torch.testing.assert_close(ntk_sin_long, original_sin_long) + self.assertTrue((ntk_scaling_rope.inv_freq <= original_rope.inv_freq).all()) + @require_flash_attn @require_torch_gpu @require_bitsandbytes diff --git a/tests/models/stablelm/test_modeling_stablelm.py b/tests/models/stablelm/test_modeling_stablelm.py index 2497dfc3eee6c4..e020076e4ff037 100644 --- a/tests/models/stablelm/test_modeling_stablelm.py +++ b/tests/models/stablelm/test_modeling_stablelm.py @@ -44,6 +44,11 @@ StableLmForSequenceClassification, StableLmModel, ) + from transformers.models.stablelm.modeling_stablelm import ( + StableLmDynamicNTKScalingRotaryEmbedding, + StableLmLinearScalingRotaryEmbedding, + StableLmRotaryEmbedding, + ) # Copied from transformers.tests.models.persimmon.test_modeling_persimmon.PersimmonModelTester with Persimmon -> StableLm @@ -351,7 +356,7 @@ def test_stablelm_sequence_classification_model_for_multi_label(self): self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) @parameterized.expand([("linear",), ("dynamic",)]) - def test_model_rope_scaling(self, scaling_type): + def test_model_rope_scaling_from_config(self, scaling_type): config, _ = self.model_tester.prepare_config_and_inputs_for_common() short_input = ids_tensor([1, 10], config.vocab_size) long_input = ids_tensor([1, int(config.max_position_embeddings * 1.5)], config.vocab_size) @@ -381,6 +386,65 @@ def test_model_rope_scaling(self, scaling_type): # The output should be different for long inputs self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5)) + def test_model_rope_scaling(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + hidden_size = config.hidden_size + num_heads = config.num_attention_heads + head_dim = hidden_size // num_heads + scaling_factor = 10 + short_input_length = 10 + long_input_length = int(config.max_position_embeddings * 1.5) + + # Inputs + x = torch.randn(1, dtype=torch.float32, device=torch_device) # used exlusively to get the dtype and the device + + # Sanity check original RoPE + original_rope = StableLmRotaryEmbedding( + head_dim, + max_position_embeddings=config.max_position_embeddings, + base=config.rope_theta, + ).to(torch_device) + original_cos_short, original_sin_short = original_rope(x, short_input_length) + original_cos_long, original_sin_long = original_rope(x, long_input_length) + torch.testing.assert_close(original_cos_short, original_cos_long[:short_input_length, :]) + torch.testing.assert_close(original_sin_short, original_sin_long[:short_input_length, :]) + + # Sanity check linear RoPE scaling + # New position "x" should match original position with index "x/scaling_factor" + linear_scaling_rope = StableLmLinearScalingRotaryEmbedding( + head_dim, + 
max_position_embeddings=config.max_position_embeddings, + base=config.rope_theta, + scaling_factor=scaling_factor, + ).to(torch_device) + linear_cos_short, linear_sin_short = linear_scaling_rope(x, short_input_length) + linear_cos_long, linear_sin_long = linear_scaling_rope(x, long_input_length) + torch.testing.assert_close(linear_cos_short, linear_cos_long[:short_input_length, :]) + torch.testing.assert_close(linear_sin_short, linear_sin_long[:short_input_length, :]) + for new_position in range(0, long_input_length, scaling_factor): + original_position = int(new_position // scaling_factor) + torch.testing.assert_close(linear_cos_long[new_position, :], original_cos_long[original_position, :]) + torch.testing.assert_close(linear_sin_long[new_position, :], original_sin_long[original_position, :]) + + # Sanity check Dynamic NTK RoPE scaling + # Scaling should only be observed after a long input is fed. We can observe that the frequencies increase + # with scaling_factor (or that `inv_freq` decreases) + ntk_scaling_rope = StableLmDynamicNTKScalingRotaryEmbedding( + head_dim, + max_position_embeddings=config.max_position_embeddings, + base=config.rope_theta, + scaling_factor=scaling_factor, + ).to(torch_device) + ntk_cos_short, ntk_sin_short = ntk_scaling_rope(x, short_input_length) + ntk_cos_long, ntk_sin_long = ntk_scaling_rope(x, long_input_length) + torch.testing.assert_close(ntk_cos_short, original_cos_short) + torch.testing.assert_close(ntk_sin_short, original_sin_short) + with self.assertRaises(AssertionError): + torch.testing.assert_close(ntk_cos_long, original_cos_long) + with self.assertRaises(AssertionError): + torch.testing.assert_close(ntk_sin_long, original_sin_long) + self.assertTrue((ntk_scaling_rope.inv_freq <= original_rope.inv_freq).all()) + @require_torch class StableLmModelIntegrationTest(unittest.TestCase): From 2154b5572dc3c82ca548a86f05895ab05529d4db Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Thu, 28 Mar 2024 11:08:42 +0000 Subject: [PATCH 4/4] add copy statements --- tests/models/falcon/test_modeling_falcon.py | 1 + tests/models/gpt_neox/test_modeling_gpt_neox.py | 2 ++ tests/models/persimmon/test_modeling_persimmon.py | 1 + tests/models/phi/test_modeling_phi.py | 2 ++ tests/models/stablelm/test_modeling_stablelm.py | 2 ++ 5 files changed, 8 insertions(+) diff --git a/tests/models/falcon/test_modeling_falcon.py b/tests/models/falcon/test_modeling_falcon.py index 346c62cd66b5cd..17b3dc42cf5684 100644 --- a/tests/models/falcon/test_modeling_falcon.py +++ b/tests/models/falcon/test_modeling_falcon.py @@ -413,6 +413,7 @@ def test_past_key_values_format(self): ) @parameterized.expand([("linear",), ("dynamic",)]) + # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_model_rope_scaling_from_config with Llama->Falcon def test_model_rope_scaling_from_config(self, scaling_type): config, _ = self.model_tester.prepare_config_and_inputs_for_common() short_input = ids_tensor([1, 10], config.vocab_size) diff --git a/tests/models/gpt_neox/test_modeling_gpt_neox.py b/tests/models/gpt_neox/test_modeling_gpt_neox.py index 00065a7006ffff..92d130b35101bb 100644 --- a/tests/models/gpt_neox/test_modeling_gpt_neox.py +++ b/tests/models/gpt_neox/test_modeling_gpt_neox.py @@ -306,6 +306,7 @@ def test_feed_forward_chunking(self): pass @parameterized.expand([("linear",), ("dynamic",)]) + # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_model_rope_scaling_from_config with Llama->GPTNeoX def test_model_rope_scaling_from_config(self, scaling_type): 
config, _ = self.model_tester.prepare_config_and_inputs_for_common() short_input = ids_tensor([1, 10], config.vocab_size) @@ -336,6 +337,7 @@ def test_model_rope_scaling_from_config(self, scaling_type): # The output should be different for long inputs self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5)) + # Copied from tests.models.falcon.test_modeling_falcon.FalconModelTest.test_model_rope_scaling with Falcon->GPTNeoX, rope_theta->rotary_emb_base def test_model_rope_scaling(self): config, _ = self.model_tester.prepare_config_and_inputs_for_common() hidden_size = config.hidden_size diff --git a/tests/models/persimmon/test_modeling_persimmon.py b/tests/models/persimmon/test_modeling_persimmon.py index 776a9d562aebe9..79cee8a64863cb 100644 --- a/tests/models/persimmon/test_modeling_persimmon.py +++ b/tests/models/persimmon/test_modeling_persimmon.py @@ -401,6 +401,7 @@ def test_model_rope_scaling_from_config(self, scaling_type): # The output should be different for long inputs self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5)) + # Copied from tests.models.falcon.test_modeling_falcon.FalconModelTest.test_model_rope_scaling with Falcon->Persimmon def test_model_rope_scaling(self): config, _ = self.model_tester.prepare_config_and_inputs_for_common() hidden_size = config.hidden_size diff --git a/tests/models/phi/test_modeling_phi.py b/tests/models/phi/test_modeling_phi.py index e7702ad0a18a76..e3c145bfa268ca 100644 --- a/tests/models/phi/test_modeling_phi.py +++ b/tests/models/phi/test_modeling_phi.py @@ -367,6 +367,7 @@ def test_phi_sequence_classification_model_for_multi_label(self): self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) @parameterized.expand([("linear",), ("dynamic",)]) + # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_model_rope_scaling_from_config with Llama->Phi def test_model_rope_scaling_from_config(self, scaling_type): config, _ = self.model_tester.prepare_config_and_inputs_for_common() short_input = ids_tensor([1, 10], config.vocab_size) @@ -397,6 +398,7 @@ def test_model_rope_scaling_from_config(self, scaling_type): # The output should be different for long inputs self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5)) + # Copied from tests.models.falcon.test_modeling_falcon.FalconModelTest.test_model_rope_scaling with Falcon->Phi def test_model_rope_scaling(self): config, _ = self.model_tester.prepare_config_and_inputs_for_common() hidden_size = config.hidden_size diff --git a/tests/models/stablelm/test_modeling_stablelm.py b/tests/models/stablelm/test_modeling_stablelm.py index e020076e4ff037..64f828825c44fa 100644 --- a/tests/models/stablelm/test_modeling_stablelm.py +++ b/tests/models/stablelm/test_modeling_stablelm.py @@ -356,6 +356,7 @@ def test_stablelm_sequence_classification_model_for_multi_label(self): self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) @parameterized.expand([("linear",), ("dynamic",)]) + # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_model_rope_scaling_from_config with Llama->StableLm def test_model_rope_scaling_from_config(self, scaling_type): config, _ = self.model_tester.prepare_config_and_inputs_for_common() short_input = ids_tensor([1, 10], config.vocab_size) @@ -386,6 +387,7 @@ def test_model_rope_scaling_from_config(self, scaling_type): # The output should be different for long inputs 
self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5)) + # Copied from tests.models.falcon.test_modeling_falcon.FalconModelTest.test_model_rope_scaling with Falcon->StableLm def test_model_rope_scaling(self): config, _ = self.model_tester.prepare_config_and_inputs_for_common() hidden_size = config.hidden_size
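For reference, the user-facing behaviour exercised by the new Llama integration test comes down to passing a `rope_scaling` dictionary at load time. The sketch below is a minimal illustration and not part of the patch: it assumes a CUDA device with roughly 13GB of free memory, the `bitsandbytes` package, access to the gated `meta-llama/Llama-2-7b-hf` checkpoint, and a prompt longer than the model's 4k-token context window; the generated text will not match the exact strings asserted in the test.

from transformers import AutoTokenizer, LlamaForCausalLM

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
# Dynamic NTK scaling only changes the rotary embeddings once the input exceeds the original
# context window, so it can be enabled without fine-tuning.
model = LlamaForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    device_map="auto",
    load_in_4bit=True,
    rope_scaling={"type": "dynamic", "factor": 2.0},
)

long_prompt = "..."  # placeholder: any prompt longer than model.config.max_position_embeddings tokens
inputs = tokenizer(long_prompt, return_tensors="pt").to(model.device)
gen_out = model.generate(**inputs, max_new_tokens=40, do_sample=False)
print(tokenizer.decode(gen_out[0], skip_special_tokens=True))

As the comments in the integration test note, linear scaling generally calls for fine-tuning at the extended length, while dynamic NTK scaling degrades more gracefully out of the box.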