From 9f46413219c1b33049c9a93fac43ae807c223171 Mon Sep 17 00:00:00 2001 From: Dan Stahlke Date: Sun, 22 Sep 2024 16:24:50 -0700 Subject: [PATCH 1/6] added nanoGPT --- text/nanogpt/Manifest.toml | 1052 +++++++++++++++++++ text/nanogpt/Project.toml | 10 + text/nanogpt/README.md | 28 + text/nanogpt/docs/Full_GPT_architecture.svg | 156 +++ text/nanogpt/gpt.jl | 263 +++++ 5 files changed, 1509 insertions(+) create mode 100644 text/nanogpt/Manifest.toml create mode 100644 text/nanogpt/Project.toml create mode 100644 text/nanogpt/README.md create mode 100644 text/nanogpt/docs/Full_GPT_architecture.svg create mode 100644 text/nanogpt/gpt.jl diff --git a/text/nanogpt/Manifest.toml b/text/nanogpt/Manifest.toml new file mode 100644 index 00000000..8ab01068 --- /dev/null +++ b/text/nanogpt/Manifest.toml @@ -0,0 +1,1052 @@ +# This file is machine-generated - editing it directly is not advised + +julia_version = "1.10.0" +manifest_format = "2.0" +project_hash = "e88c592f8b016de0c2ec0bdcabed8d5bce446df1" + +[[deps.AbstractFFTs]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "d92ad398961a3ed262d8bf04a1a2b8340f915fef" +uuid = "621f4979-c628-5d54-868e-fcf4e3e8185c" +version = "1.5.0" +weakdeps = ["ChainRulesCore", "Test"] + + [deps.AbstractFFTs.extensions] + AbstractFFTsChainRulesCoreExt = "ChainRulesCore" + AbstractFFTsTestExt = "Test" + +[[deps.Accessors]] +deps = ["CompositionsBase", "ConstructionBase", "InverseFunctions", "LinearAlgebra", "MacroTools", "Markdown"] +git-tree-sha1 = "b392ede862e506d451fc1616e79aa6f4c673dab8" +uuid = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" +version = "0.1.38" + + [deps.Accessors.extensions] + AccessorsAxisKeysExt = "AxisKeys" + AccessorsDatesExt = "Dates" + AccessorsIntervalSetsExt = "IntervalSets" + AccessorsStaticArraysExt = "StaticArrays" + AccessorsStructArraysExt = "StructArrays" + AccessorsTestExt = "Test" + AccessorsUnitfulExt = "Unitful" + + [deps.Accessors.weakdeps] + AxisKeys = "94b1ba4f-4ee9-5380-92f1-94cde586c3c5" + Dates = "ade2ca70-3891-5945-98fb-dc099432e06a" + IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953" + Requires = "ae029012-a4dd-5104-9daa-d747884805df" + StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" + StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" + Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" + Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" + +[[deps.Adapt]] +deps = ["LinearAlgebra", "Requires"] +git-tree-sha1 = "6a55b747d1812e699320963ffde36f1ebdda4099" +uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" +version = "4.0.4" +weakdeps = ["StaticArrays"] + + [deps.Adapt.extensions] + AdaptStaticArraysExt = "StaticArrays" + +[[deps.ArgCheck]] +git-tree-sha1 = "a3a402a35a2f7e0b87828ccabbd5ebfbebe356b4" +uuid = "dce04be8-c92d-5529-be00-80e4d2c0e197" +version = "2.3.0" + +[[deps.ArgTools]] +uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f" +version = "1.1.1" + +[[deps.Artifacts]] +uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" + +[[deps.Atomix]] +deps = ["UnsafeAtomics"] +git-tree-sha1 = "c06a868224ecba914baa6942988e2f2aade419be" +uuid = "a9b6321e-bd34-4604-b9c9-b65b8de01458" +version = "0.1.0" + +[[deps.BFloat16s]] +deps = ["LinearAlgebra", "Printf", "Random", "Test"] +git-tree-sha1 = "2c7cc21e8678eff479978a0a2ef5ce2f51b63dff" +uuid = "ab4f0b2a-ad5b-11e8-123f-65d77653426b" +version = "0.5.0" + +[[deps.BangBang]] +deps = ["Accessors", "ConstructionBase", "InitialValues", "LinearAlgebra", "Requires"] +git-tree-sha1 = "e2144b631226d9eeab2d746ca8880b7ccff504ae" +uuid = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" +version = "0.4.3" + + [deps.BangBang.extensions] + BangBangChainRulesCoreExt = "ChainRulesCore" + BangBangDataFramesExt = "DataFrames" + BangBangStaticArraysExt = "StaticArrays" + BangBangStructArraysExt = "StructArrays" + BangBangTablesExt = "Tables" + BangBangTypedTablesExt = "TypedTables" + + [deps.BangBang.weakdeps] + ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" + DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" + StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" + StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" + Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" + TypedTables = "9d95f2ec-7b3d-5a63-8d20-e2491e220bb9" + +[[deps.Base64]] +uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" + +[[deps.Baselet]] +git-tree-sha1 = "aebf55e6d7795e02ca500a689d326ac979aaf89e" +uuid = "9718e550-a3fa-408a-8086-8db961cd8217" +version = "0.1.1" + +[[deps.CEnum]] +git-tree-sha1 = "389ad5c84de1ae7cf0e28e381131c98ea87d54fc" +uuid = "fa961155-64e5-5f13-b03f-caf6b980ea82" +version = "0.5.0" + +[[deps.CUDA]] +deps = ["AbstractFFTs", "Adapt", "BFloat16s", "CEnum", "CUDA_Driver_jll", "CUDA_Runtime_Discovery", "CUDA_Runtime_jll", "Crayons", "DataFrames", "ExprTools", "GPUArrays", "GPUCompiler", "KernelAbstractions", "LLVM", "LLVMLoopInfo", "LazyArtifacts", "Libdl", "LinearAlgebra", "Logging", "NVTX", "Preferences", "PrettyTables", "Printf", "Random", "Random123", "RandomNumbers", "Reexport", "Requires", "SparseArrays", "StaticArrays", "Statistics"] +git-tree-sha1 = "fdd9dfb67dfefd548f51000cc400bb51003de247" +uuid = "052768ef-5323-5732-b1bb-66c8b64840ba" +version = "5.4.3" + + [deps.CUDA.extensions] + ChainRulesCoreExt = "ChainRulesCore" + EnzymeCoreExt = "EnzymeCore" + SpecialFunctionsExt = "SpecialFunctions" + + [deps.CUDA.weakdeps] + ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" + EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" + SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" + +[[deps.CUDA_Driver_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "325058b426c2b421e3d2df3d5fa646d72d2e3e7e" +uuid = "4ee394cb-3365-5eb0-8335-949819d2adfc" +version = "0.9.2+0" + +[[deps.CUDA_Runtime_Discovery]] +deps = ["Libdl"] +git-tree-sha1 = "33576c7c1b2500f8e7e6baa082e04563203b3a45" +uuid = "1af6417a-86b4-443c-805f-a4643ffb695f" +version = "0.3.5" + +[[deps.CUDA_Runtime_jll]] +deps = ["Artifacts", "CUDA_Driver_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"] +git-tree-sha1 = "afea94249b821dc754a8ca6695d3daed851e1f5a" +uuid = "76a88914-d11a-5bdc-97e0-2f5a05c973a2" +version = "0.14.1+0" + +[[deps.CUDNN_jll]] +deps = ["Artifacts", "CUDA_Runtime_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"] +git-tree-sha1 = "cbf7d75f8c58b147bdf6acea2e5bc96cececa6d4" +uuid = "62b44479-cb7b-5706-934f-f13b2eb2e645" +version = "9.0.0+1" + +[[deps.ChainRules]] +deps = ["Adapt", "ChainRulesCore", "Compat", "Distributed", "GPUArraysCore", "IrrationalConstants", "LinearAlgebra", "Random", "RealDot", "SparseArrays", "SparseInverseSubset", "Statistics", "StructArrays", "SuiteSparse"] +git-tree-sha1 = "be227d253d132a6d57f9ccf5f67c0fb6488afd87" +uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" +version = "1.71.0" + +[[deps.ChainRulesCore]] +deps = ["Compat", "LinearAlgebra"] +git-tree-sha1 = "3e4b134270b372f2ed4d4d0e936aabaefc1802bc" +uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +version = "1.25.0" +weakdeps = ["SparseArrays"] + + [deps.ChainRulesCore.extensions] + ChainRulesCoreSparseArraysExt = "SparseArrays" + +[[deps.ColorTypes]] +deps = ["FixedPointNumbers", "Random"] +git-tree-sha1 = "b10d0b65641d57b8b4d5e234446582de5047050d" +uuid = "3da002f7-5984-5a60-b8a6-cbb66c0b333f" +version = "0.11.5" + +[[deps.Colors]] +deps = ["ColorTypes", "FixedPointNumbers", "Reexport"] +git-tree-sha1 = "362a287c3aa50601b0bc359053d5c2468f0e7ce0" +uuid = "5ae59095-9a9b-59fe-a467-6f913c188581" +version = "0.12.11" + +[[deps.CommonSubexpressions]] +deps = ["MacroTools"] +git-tree-sha1 = "cda2cfaebb4be89c9084adaca7dd7333369715c5" +uuid = "bbf7d656-a473-5ed7-a52c-81e309532950" +version = "0.3.1" + +[[deps.Compat]] +deps = ["TOML", "UUIDs"] +git-tree-sha1 = "8ae8d32e09f0dcf42a36b90d4e17f5dd2e4c4215" +uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" +version = "4.16.0" +weakdeps = ["Dates", "LinearAlgebra"] + + [deps.Compat.extensions] + CompatLinearAlgebraExt = "LinearAlgebra" + +[[deps.CompilerSupportLibraries_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" +version = "1.0.5+1" + +[[deps.CompositionsBase]] +git-tree-sha1 = "802bb88cd69dfd1509f6670416bd4434015693ad" +uuid = "a33af91c-f02d-484b-be07-31d278c5ca2b" +version = "0.1.2" +weakdeps = ["InverseFunctions"] + + [deps.CompositionsBase.extensions] + CompositionsBaseInverseFunctionsExt = "InverseFunctions" + +[[deps.ConstructionBase]] +git-tree-sha1 = "76219f1ed5771adbb096743bff43fb5fdd4c1157" +uuid = "187b0558-2788-49d3-abe0-74a17ed4e7c9" +version = "1.5.8" + + [deps.ConstructionBase.extensions] + ConstructionBaseIntervalSetsExt = "IntervalSets" + ConstructionBaseLinearAlgebraExt = "LinearAlgebra" + ConstructionBaseStaticArraysExt = "StaticArrays" + + [deps.ConstructionBase.weakdeps] + IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953" + LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" + StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" + +[[deps.ContextVariablesX]] +deps = ["Compat", "Logging", "UUIDs"] +git-tree-sha1 = "25cc3803f1030ab855e383129dcd3dc294e322cc" +uuid = "6add18c4-b38d-439d-96f6-d6bc489c04c5" +version = "0.1.3" + +[[deps.Crayons]] +git-tree-sha1 = "249fe38abf76d48563e2f4556bebd215aa317e15" +uuid = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f" +version = "4.1.1" + +[[deps.DataAPI]] +git-tree-sha1 = "abe83f3a2f1b857aac70ef8b269080af17764bbe" +uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" +version = "1.16.0" + +[[deps.DataFrames]] +deps = ["Compat", "DataAPI", "DataStructures", "Future", "InlineStrings", "InvertedIndices", "IteratorInterfaceExtensions", "LinearAlgebra", "Markdown", "Missings", "PooledArrays", "PrecompileTools", "PrettyTables", "Printf", "REPL", "Random", "Reexport", "SentinelArrays", "SortingAlgorithms", "Statistics", "TableTraits", "Tables", "Unicode"] +git-tree-sha1 = "04c738083f29f86e62c8afc341f0967d8717bdb8" +uuid = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" +version = "1.6.1" + +[[deps.DataStructures]] +deps = ["Compat", "InteractiveUtils", "OrderedCollections"] +git-tree-sha1 = "1d0a14036acb104d9e89698bd408f63ab58cdc82" +uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" +version = "0.18.20" + +[[deps.DataValueInterfaces]] +git-tree-sha1 = "bfc1187b79289637fa0ef6d4436ebdfe6905cbd6" +uuid = "e2d170a0-9d28-54be-80f0-106bbe20a464" +version = "1.0.0" + +[[deps.Dates]] +deps = ["Printf"] +uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" + +[[deps.DefineSingletons]] +git-tree-sha1 = "0fba8b706d0178b4dc7fd44a96a92382c9065c2c" +uuid = "244e2a9f-e319-4986-a169-4d1fe445cd52" +version = "0.1.2" + +[[deps.DelimitedFiles]] +deps = ["Mmap"] +git-tree-sha1 = "9e2f36d3c96a820c678f2f1f1782582fcf685bae" +uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab" +version = "1.9.1" + +[[deps.DiffResults]] +deps = ["StaticArraysCore"] +git-tree-sha1 = "782dd5f4561f5d267313f23853baaaa4c52ea621" +uuid = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" +version = "1.1.0" + +[[deps.DiffRules]] +deps = ["IrrationalConstants", "LogExpFunctions", "NaNMath", "Random", "SpecialFunctions"] +git-tree-sha1 = "23163d55f885173722d1e4cf0f6110cdbaf7e272" +uuid = "b552c78f-8df3-52c6-915a-8e097449b14b" +version = "1.15.1" + +[[deps.Distributed]] +deps = ["Random", "Serialization", "Sockets"] +uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" + +[[deps.DocStringExtensions]] +deps = ["LibGit2"] +git-tree-sha1 = "2fb1e02f2b635d0845df5d7c167fec4dd739b00d" +uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" +version = "0.9.3" + +[[deps.Downloads]] +deps = ["ArgTools", "FileWatching", "LibCURL", "NetworkOptions"] +uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6" +version = "1.6.0" + +[[deps.ExprTools]] +git-tree-sha1 = "27415f162e6028e81c72b82ef756bf321213b6ec" +uuid = "e2ba6199-217a-4e67-a87a-7c52f15ade04" +version = "0.1.10" + +[[deps.FLoops]] +deps = ["BangBang", "Compat", "FLoopsBase", "InitialValues", "JuliaVariables", "MLStyle", "Serialization", "Setfield", "Transducers"] +git-tree-sha1 = "0a2e5873e9a5f54abb06418d57a8df689336a660" +uuid = "cc61a311-1640-44b5-9fba-1b764f453329" +version = "0.2.2" + +[[deps.FLoopsBase]] +deps = ["ContextVariablesX"] +git-tree-sha1 = "656f7a6859be8673bf1f35da5670246b923964f7" +uuid = "b9860ae5-e623-471e-878b-f6a53c775ea6" +version = "0.1.1" + +[[deps.FileIO]] +deps = ["Pkg", "Requires", "UUIDs"] +git-tree-sha1 = "82d8afa92ecf4b52d78d869f038ebfb881267322" +uuid = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549" +version = "1.16.3" + +[[deps.FileWatching]] +uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee" + +[[deps.FillArrays]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "6a70198746448456524cb442b8af316927ff3e1a" +uuid = "1a297f60-69ca-5386-bcde-b61e274b549b" +version = "1.13.0" + + [deps.FillArrays.extensions] + FillArraysPDMatsExt = "PDMats" + FillArraysSparseArraysExt = "SparseArrays" + FillArraysStatisticsExt = "Statistics" + + [deps.FillArrays.weakdeps] + PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150" + SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" + Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" + +[[deps.FixedPointNumbers]] +deps = ["Statistics"] +git-tree-sha1 = "05882d6995ae5c12bb5f36dd2ed3f61c98cbb172" +uuid = "53c48c17-4a7d-5ca2-90c5-79b7896eea93" +version = "0.8.5" + +[[deps.Flux]] +deps = ["Adapt", "ChainRulesCore", "Compat", "Functors", "LinearAlgebra", "MLUtils", "MacroTools", "NNlib", "OneHotArrays", "Optimisers", "Preferences", "ProgressLogging", "Random", "Reexport", "Setfield", "SparseArrays", "SpecialFunctions", "Statistics", "Zygote"] +git-tree-sha1 = "d7d0a182089d9d3ff0cd0b761d21020fea2b1035" +uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c" +version = "0.14.20" + + [deps.Flux.extensions] + FluxAMDGPUExt = "AMDGPU" + FluxCUDAExt = "CUDA" + FluxCUDAcuDNNExt = ["CUDA", "cuDNN"] + FluxEnzymeExt = "Enzyme" + FluxMPIExt = "MPI" + FluxMPINCCLExt = ["CUDA", "MPI", "NCCL"] + FluxMetalExt = "Metal" + + [deps.Flux.weakdeps] + AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" + CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" + Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" + MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195" + Metal = "dde4c033-4e86-420c-a63e-0dd931031962" + NCCL = "3fe64909-d7a1-4096-9b7d-7a0f12cf0f6b" + cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" + +[[deps.ForwardDiff]] +deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "LinearAlgebra", "LogExpFunctions", "NaNMath", "Preferences", "Printf", "Random", "SpecialFunctions"] +git-tree-sha1 = "cf0fe81336da9fb90944683b8c41984b08793dad" +uuid = "f6369f11-7733-5829-9624-2563aa707210" +version = "0.10.36" +weakdeps = ["StaticArrays"] + + [deps.ForwardDiff.extensions] + ForwardDiffStaticArraysExt = "StaticArrays" + +[[deps.Functors]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "64d8e93700c7a3f28f717d265382d52fac9fa1c1" +uuid = "d9f16b24-f501-4c13-a1f2-28368ffc5196" +version = "0.4.12" + +[[deps.Future]] +deps = ["Random"] +uuid = "9fa8497b-333b-5362-9e8d-4d0656e87820" + +[[deps.GPUArrays]] +deps = ["Adapt", "GPUArraysCore", "LLVM", "LinearAlgebra", "Printf", "Random", "Reexport", "Serialization", "Statistics"] +git-tree-sha1 = "62ee71528cca49be797076a76bdc654a170a523e" +uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" +version = "10.3.1" + +[[deps.GPUArraysCore]] +deps = ["Adapt"] +git-tree-sha1 = "ec632f177c0d990e64d955ccc1b8c04c485a0950" +uuid = "46192b85-c4d5-4398-a991-12ede77f4527" +version = "0.1.6" + +[[deps.GPUCompiler]] +deps = ["ExprTools", "InteractiveUtils", "LLVM", "Libdl", "Logging", "Preferences", "Scratch", "Serialization", "TOML", "TimerOutputs", "UUIDs"] +git-tree-sha1 = "ab29216184312f99ff957b32cd63c2fe9c928b91" +uuid = "61eb1bfa-7361-4325-ad38-22787b887f55" +version = "0.26.7" + +[[deps.IRTools]] +deps = ["InteractiveUtils", "MacroTools"] +git-tree-sha1 = "950c3717af761bc3ff906c2e8e52bd83390b6ec2" +uuid = "7869d1d1-7146-5819-86e3-90919afe41df" +version = "0.4.14" + +[[deps.InitialValues]] +git-tree-sha1 = "4da0f88e9a39111c2fa3add390ab15f3a44f3ca3" +uuid = "22cec73e-a1b8-11e9-2c92-598750a2cf9c" +version = "0.3.1" + +[[deps.InlineStrings]] +git-tree-sha1 = "45521d31238e87ee9f9732561bfee12d4eebd52d" +uuid = "842dd82b-1e85-43dc-bf29-5d0ee9dffc48" +version = "1.4.2" + + [deps.InlineStrings.extensions] + ArrowTypesExt = "ArrowTypes" + ParsersExt = "Parsers" + + [deps.InlineStrings.weakdeps] + ArrowTypes = "31f734f8-188a-4ce0-8406-c8a06bd891cd" + Parsers = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0" + +[[deps.InteractiveUtils]] +deps = ["Markdown"] +uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" + +[[deps.InverseFunctions]] +git-tree-sha1 = "a779299d77cd080bf77b97535acecd73e1c5e5cb" +uuid = "3587e190-3f89-42d0-90ee-14403ec27112" +version = "0.1.17" +weakdeps = ["Dates", "Test"] + + [deps.InverseFunctions.extensions] + InverseFunctionsDatesExt = "Dates" + InverseFunctionsTestExt = "Test" + +[[deps.InvertedIndices]] +git-tree-sha1 = "0dc7b50b8d436461be01300fd8cd45aa0274b038" +uuid = "41ab1584-1d38-5bbf-9106-f11c6c58b48f" +version = "1.3.0" + +[[deps.IrrationalConstants]] +git-tree-sha1 = "630b497eafcc20001bba38a4651b327dcfc491d2" +uuid = "92d709cd-6900-40b7-9082-c6be49f344b6" +version = "0.2.2" + +[[deps.IteratorInterfaceExtensions]] +git-tree-sha1 = "a3f24677c21f5bbe9d2a714f95dcd58337fb2856" +uuid = "82899510-4779-5014-852e-03e436cf321d" +version = "1.0.0" + +[[deps.JLD2]] +deps = ["FileIO", "MacroTools", "Mmap", "OrderedCollections", "PrecompileTools", "Requires", "TranscodingStreams"] +git-tree-sha1 = "07f9dec43deef049c2f0daa96f67bfc0baa20a17" +uuid = "033835bb-8acc-5ee8-8aae-3f567f8a3819" +version = "0.5.3" + +[[deps.JLLWrappers]] +deps = ["Artifacts", "Preferences"] +git-tree-sha1 = "f389674c99bfcde17dc57454011aa44d5a260a40" +uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210" +version = "1.6.0" + +[[deps.JuliaNVTXCallbacks_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "af433a10f3942e882d3c671aacb203e006a5808f" +uuid = "9c1d0b0a-7046-5b2e-a33f-ea22f176ac7e" +version = "0.2.1+0" + +[[deps.JuliaVariables]] +deps = ["MLStyle", "NameResolution"] +git-tree-sha1 = "49fb3cb53362ddadb4415e9b73926d6b40709e70" +uuid = "b14d175d-62b4-44ba-8fb7-3064adc8c3ec" +version = "0.2.4" + +[[deps.KernelAbstractions]] +deps = ["Adapt", "Atomix", "InteractiveUtils", "MacroTools", "PrecompileTools", "Requires", "StaticArrays", "UUIDs", "UnsafeAtomics", "UnsafeAtomicsLLVM"] +git-tree-sha1 = "5126765c5847f74758c411c994312052eb7117ef" +uuid = "63c18a36-062a-441e-b654-da1e3ab1ce7c" +version = "0.9.27" + + [deps.KernelAbstractions.extensions] + EnzymeExt = "EnzymeCore" + LinearAlgebraExt = "LinearAlgebra" + SparseArraysExt = "SparseArrays" + + [deps.KernelAbstractions.weakdeps] + EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" + LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" + SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" + +[[deps.LLVM]] +deps = ["CEnum", "LLVMExtra_jll", "Libdl", "Preferences", "Printf", "Requires", "Unicode"] +git-tree-sha1 = "2470e69781ddd70b8878491233cd09bc1bd7fc96" +uuid = "929cbde3-209d-540e-8aea-75f648917ca0" +version = "8.1.0" +weakdeps = ["BFloat16s"] + + [deps.LLVM.extensions] + BFloat16sExt = "BFloat16s" + +[[deps.LLVMExtra_jll]] +deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"] +git-tree-sha1 = "597d1c758c9ae5d985ba4202386a607c675ee700" +uuid = "dad2f222-ce93-54a1-a47d-0025e8a3acab" +version = "0.0.31+0" + +[[deps.LLVMLoopInfo]] +git-tree-sha1 = "2e5c102cfc41f48ae4740c7eca7743cc7e7b75ea" +uuid = "8b046642-f1f6-4319-8d3c-209ddc03c586" +version = "1.0.0" + +[[deps.LaTeXStrings]] +git-tree-sha1 = "50901ebc375ed41dbf8058da26f9de442febbbec" +uuid = "b964fa9f-0449-5b57-a5c2-d3ea65f4040f" +version = "1.3.1" + +[[deps.LazyArtifacts]] +deps = ["Artifacts", "Pkg"] +uuid = "4af54fe1-eca0-43a8-85a7-787d91b784e3" + +[[deps.LibCURL]] +deps = ["LibCURL_jll", "MozillaCACerts_jll"] +uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21" +version = "0.6.4" + +[[deps.LibCURL_jll]] +deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"] +uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0" +version = "8.4.0+0" + +[[deps.LibGit2]] +deps = ["Base64", "LibGit2_jll", "NetworkOptions", "Printf", "SHA"] +uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" + +[[deps.LibGit2_jll]] +deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll"] +uuid = "e37daf67-58a4-590a-8e99-b0245dd2ffc5" +version = "1.6.4+0" + +[[deps.LibSSH2_jll]] +deps = ["Artifacts", "Libdl", "MbedTLS_jll"] +uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8" +version = "1.11.0+1" + +[[deps.Libdl]] +uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" + +[[deps.LinearAlgebra]] +deps = ["Libdl", "OpenBLAS_jll", "libblastrampoline_jll"] +uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" + +[[deps.LogExpFunctions]] +deps = ["DocStringExtensions", "IrrationalConstants", "LinearAlgebra"] +git-tree-sha1 = "a2d09619db4e765091ee5c6ffe8872849de0feea" +uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688" +version = "0.3.28" + + [deps.LogExpFunctions.extensions] + LogExpFunctionsChainRulesCoreExt = "ChainRulesCore" + LogExpFunctionsChangesOfVariablesExt = "ChangesOfVariables" + LogExpFunctionsInverseFunctionsExt = "InverseFunctions" + + [deps.LogExpFunctions.weakdeps] + ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" + ChangesOfVariables = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0" + InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112" + +[[deps.Logging]] +uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" + +[[deps.MLStyle]] +git-tree-sha1 = "bc38dff0548128765760c79eb7388a4b37fae2c8" +uuid = "d8e11817-5142-5d16-987a-aa16d5891078" +version = "0.4.17" + +[[deps.MLUtils]] +deps = ["ChainRulesCore", "Compat", "DataAPI", "DelimitedFiles", "FLoops", "NNlib", "Random", "ShowCases", "SimpleTraits", "Statistics", "StatsBase", "Tables", "Transducers"] +git-tree-sha1 = "b45738c2e3d0d402dffa32b2c1654759a2ac35a4" +uuid = "f1d291b0-491e-4a28-83b9-f70985020b54" +version = "0.4.4" + +[[deps.MacroTools]] +deps = ["Markdown", "Random"] +git-tree-sha1 = "2fa9ee3e63fd3a4f7a9a4f4744a52f4856de82df" +uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" +version = "0.5.13" + +[[deps.Markdown]] +deps = ["Base64"] +uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" + +[[deps.MbedTLS_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1" +version = "2.28.2+1" + +[[deps.MicroCollections]] +deps = ["Accessors", "BangBang", "InitialValues"] +git-tree-sha1 = "44d32db644e84c75dab479f1bc15ee76a1a3618f" +uuid = "128add7d-3638-4c79-886c-908ea0c25c34" +version = "0.2.0" + +[[deps.Missings]] +deps = ["DataAPI"] +git-tree-sha1 = "ec4f7fbeab05d7747bdf98eb74d130a2a2ed298d" +uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28" +version = "1.2.0" + +[[deps.Mmap]] +uuid = "a63ad114-7e13-5084-954f-fe012c677804" + +[[deps.MozillaCACerts_jll]] +uuid = "14a3606d-f60d-562e-9121-12d972cd8159" +version = "2023.1.10" + +[[deps.NNlib]] +deps = ["Adapt", "Atomix", "ChainRulesCore", "GPUArraysCore", "KernelAbstractions", "LinearAlgebra", "Random", "Statistics"] +git-tree-sha1 = "da09a1e112fd75f9af2a5229323f01b56ec96a4c" +uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" +version = "0.9.24" + + [deps.NNlib.extensions] + NNlibAMDGPUExt = "AMDGPU" + NNlibCUDACUDNNExt = ["CUDA", "cuDNN"] + NNlibCUDAExt = "CUDA" + NNlibEnzymeCoreExt = "EnzymeCore" + NNlibFFTWExt = "FFTW" + NNlibForwardDiffExt = "ForwardDiff" + + [deps.NNlib.weakdeps] + AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" + CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" + EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" + FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" + ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" + cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" + +[[deps.NVTX]] +deps = ["Colors", "JuliaNVTXCallbacks_jll", "Libdl", "NVTX_jll"] +git-tree-sha1 = "53046f0483375e3ed78e49190f1154fa0a4083a1" +uuid = "5da4648a-3479-48b8-97b9-01cb529c0a1f" +version = "0.3.4" + +[[deps.NVTX_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "ce3269ed42816bf18d500c9f63418d4b0d9f5a3b" +uuid = "e98f9f5b-d649-5603-91fd-7774390e6439" +version = "3.1.0+2" + +[[deps.NaNMath]] +deps = ["OpenLibm_jll"] +git-tree-sha1 = "0877504529a3e5c3343c6f8b4c0381e57e4387e4" +uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" +version = "1.0.2" + +[[deps.NameResolution]] +deps = ["PrettyPrint"] +git-tree-sha1 = "1a0fa0e9613f46c9b8c11eee38ebb4f590013c5e" +uuid = "71a1bf82-56d0-4bbc-8a3c-48b961074391" +version = "0.1.5" + +[[deps.NetworkOptions]] +uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908" +version = "1.2.0" + +[[deps.OneHotArrays]] +deps = ["Adapt", "ChainRulesCore", "Compat", "GPUArraysCore", "LinearAlgebra", "NNlib"] +git-tree-sha1 = "963a3f28a2e65bb87a68033ea4a616002406037d" +uuid = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f" +version = "0.2.5" + +[[deps.OpenBLAS_jll]] +deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"] +uuid = "4536629a-c528-5b80-bd46-f80d51c5b363" +version = "0.3.23+2" + +[[deps.OpenLibm_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "05823500-19ac-5b8b-9628-191a04bc5112" +version = "0.8.1+2" + +[[deps.OpenSpecFun_jll]] +deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "13652491f6856acfd2db29360e1bbcd4565d04f1" +uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e" +version = "0.5.5+0" + +[[deps.Optimisers]] +deps = ["ChainRulesCore", "Functors", "LinearAlgebra", "Random", "Statistics"] +git-tree-sha1 = "6572fe0c5b74431aaeb0b18a4aa5ef03c84678be" +uuid = "3bd65402-5787-11e9-1adc-39752487f4e2" +version = "0.3.3" + +[[deps.OrderedCollections]] +git-tree-sha1 = "dfdf5519f235516220579f949664f1bf44e741c5" +uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" +version = "1.6.3" + +[[deps.Pkg]] +deps = ["Artifacts", "Dates", "Downloads", "FileWatching", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"] +uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" +version = "1.10.0" + +[[deps.PooledArrays]] +deps = ["DataAPI", "Future"] +git-tree-sha1 = "36d8b4b899628fb92c2749eb488d884a926614d3" +uuid = "2dfb63ee-cc39-5dd5-95bd-886bf059d720" +version = "1.4.3" + +[[deps.PrecompileTools]] +deps = ["Preferences"] +git-tree-sha1 = "5aa36f7049a63a1528fe8f7c3f2113413ffd4e1f" +uuid = "aea7be01-6a6a-4083-8856-8a6e6704d82a" +version = "1.2.1" + +[[deps.Preferences]] +deps = ["TOML"] +git-tree-sha1 = "9306f6085165d270f7e3db02af26a400d580f5c6" +uuid = "21216c6a-2e73-6563-6e65-726566657250" +version = "1.4.3" + +[[deps.PrettyPrint]] +git-tree-sha1 = "632eb4abab3449ab30c5e1afaa874f0b98b586e4" +uuid = "8162dcfd-2161-5ef2-ae6c-7681170c5f98" +version = "0.2.0" + +[[deps.PrettyTables]] +deps = ["Crayons", "LaTeXStrings", "Markdown", "PrecompileTools", "Printf", "Reexport", "StringManipulation", "Tables"] +git-tree-sha1 = "66b20dd35966a748321d3b2537c4584cf40387c7" +uuid = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d" +version = "2.3.2" + +[[deps.Printf]] +deps = ["Unicode"] +uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" + +[[deps.ProgressLogging]] +deps = ["Logging", "SHA", "UUIDs"] +git-tree-sha1 = "80d919dee55b9c50e8d9e2da5eeafff3fe58b539" +uuid = "33c8b6b6-d38a-422a-b730-caa89a2f386c" +version = "0.1.4" + +[[deps.ProgressMeter]] +deps = ["Distributed", "Printf"] +git-tree-sha1 = "8f6bc219586aef8baf0ff9a5fe16ee9c70cb65e4" +uuid = "92933f4c-e287-5a05-a399-4b506db050ca" +version = "1.10.2" + +[[deps.REPL]] +deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"] +uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" + +[[deps.Random]] +deps = ["SHA"] +uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" + +[[deps.Random123]] +deps = ["Random", "RandomNumbers"] +git-tree-sha1 = "4743b43e5a9c4a2ede372de7061eed81795b12e7" +uuid = "74087812-796a-5b5d-8853-05524746bad3" +version = "1.7.0" + +[[deps.RandomNumbers]] +deps = ["Random"] +git-tree-sha1 = "c6ec94d2aaba1ab2ff983052cf6a606ca5985902" +uuid = "e6cf234a-135c-5ec9-84dd-332b85af5143" +version = "1.6.0" + +[[deps.RealDot]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "9f0a1b71baaf7650f4fa8a1d168c7fb6ee41f0c9" +uuid = "c1ae055f-0cd5-4b69-90a6-9a35b1a98df9" +version = "0.1.0" + +[[deps.Reexport]] +git-tree-sha1 = "45e428421666073eab6f2da5c9d310d99bb12f9b" +uuid = "189a3867-3050-52da-a836-e630ba90ab69" +version = "1.2.2" + +[[deps.Requires]] +deps = ["UUIDs"] +git-tree-sha1 = "838a3a4188e2ded87a4f9f184b4b0d78a1e91cb7" +uuid = "ae029012-a4dd-5104-9daa-d747884805df" +version = "1.3.0" + +[[deps.SHA]] +uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" +version = "0.7.0" + +[[deps.Scratch]] +deps = ["Dates"] +git-tree-sha1 = "3bac05bc7e74a75fd9cba4295cde4045d9fe2386" +uuid = "6c6a2e73-6563-6170-7368-637461726353" +version = "1.2.1" + +[[deps.SentinelArrays]] +deps = ["Dates", "Random"] +git-tree-sha1 = "ff11acffdb082493657550959d4feb4b6149e73a" +uuid = "91c51154-3ec4-41a3-a24f-3f23e20d615c" +version = "1.4.5" + +[[deps.Serialization]] +uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" + +[[deps.Setfield]] +deps = ["ConstructionBase", "Future", "MacroTools", "StaticArraysCore"] +git-tree-sha1 = "e2cc6d8c88613c05e1defb55170bf5ff211fbeac" +uuid = "efcf1570-3423-57d1-acb7-fd33fddbac46" +version = "1.1.1" + +[[deps.ShowCases]] +git-tree-sha1 = "7f534ad62ab2bd48591bdeac81994ea8c445e4a5" +uuid = "605ecd9f-84a6-4c9e-81e2-4798472b76a3" +version = "0.1.0" + +[[deps.SimpleTraits]] +deps = ["InteractiveUtils", "MacroTools"] +git-tree-sha1 = "5d7e3f4e11935503d3ecaf7186eac40602e7d231" +uuid = "699a6c99-e7fa-54fc-8d76-47d257e15c1d" +version = "0.9.4" + +[[deps.Sockets]] +uuid = "6462fe0b-24de-5631-8697-dd941f90decc" + +[[deps.SortingAlgorithms]] +deps = ["DataStructures"] +git-tree-sha1 = "66e0a8e672a0bdfca2c3f5937efb8538b9ddc085" +uuid = "a2af1166-a08f-5f64-846c-94a0d3cef48c" +version = "1.2.1" + +[[deps.SparseArrays]] +deps = ["Libdl", "LinearAlgebra", "Random", "Serialization", "SuiteSparse_jll"] +uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" +version = "1.10.0" + +[[deps.SparseInverseSubset]] +deps = ["LinearAlgebra", "SparseArrays", "SuiteSparse"] +git-tree-sha1 = "52962839426b75b3021296f7df242e40ecfc0852" +uuid = "dc90abb0-5640-4711-901d-7e5b23a2fada" +version = "0.1.2" + +[[deps.SpecialFunctions]] +deps = ["IrrationalConstants", "LogExpFunctions", "OpenLibm_jll", "OpenSpecFun_jll"] +git-tree-sha1 = "2f5d4697f21388cbe1ff299430dd169ef97d7e14" +uuid = "276daf66-3868-5448-9aa4-cd146d93841b" +version = "2.4.0" +weakdeps = ["ChainRulesCore"] + + [deps.SpecialFunctions.extensions] + SpecialFunctionsChainRulesCoreExt = "ChainRulesCore" + +[[deps.SplittablesBase]] +deps = ["Setfield", "Test"] +git-tree-sha1 = "e08a62abc517eb79667d0a29dc08a3b589516bb5" +uuid = "171d559e-b47b-412a-8079-5efa626c420e" +version = "0.1.15" + +[[deps.StaticArrays]] +deps = ["LinearAlgebra", "PrecompileTools", "Random", "StaticArraysCore"] +git-tree-sha1 = "eeafab08ae20c62c44c8399ccb9354a04b80db50" +uuid = "90137ffa-7385-5640-81b9-e52037218182" +version = "1.9.7" +weakdeps = ["ChainRulesCore", "Statistics"] + + [deps.StaticArrays.extensions] + StaticArraysChainRulesCoreExt = "ChainRulesCore" + StaticArraysStatisticsExt = "Statistics" + +[[deps.StaticArraysCore]] +git-tree-sha1 = "192954ef1208c7019899fbf8049e717f92959682" +uuid = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" +version = "1.4.3" + +[[deps.Statistics]] +deps = ["LinearAlgebra", "SparseArrays"] +uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +version = "1.10.0" + +[[deps.StatsAPI]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "1ff449ad350c9c4cbc756624d6f8a8c3ef56d3ed" +uuid = "82ae8749-77ed-4fe6-ae5f-f523153014b0" +version = "1.7.0" + +[[deps.StatsBase]] +deps = ["DataAPI", "DataStructures", "LinearAlgebra", "LogExpFunctions", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "StatsAPI"] +git-tree-sha1 = "5cf7606d6cef84b543b483848d4ae08ad9832b21" +uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" +version = "0.34.3" + +[[deps.StringManipulation]] +deps = ["PrecompileTools"] +git-tree-sha1 = "a04cabe79c5f01f4d723cc6704070ada0b9d46d5" +uuid = "892a3eda-7b42-436c-8928-eab12a02cf0e" +version = "0.3.4" + +[[deps.StructArrays]] +deps = ["ConstructionBase", "DataAPI", "Tables"] +git-tree-sha1 = "f4dc295e983502292c4c3f951dbb4e985e35b3be" +uuid = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" +version = "0.6.18" +weakdeps = ["Adapt", "GPUArraysCore", "SparseArrays", "StaticArrays"] + + [deps.StructArrays.extensions] + StructArraysAdaptExt = "Adapt" + StructArraysGPUArraysCoreExt = "GPUArraysCore" + StructArraysSparseArraysExt = "SparseArrays" + StructArraysStaticArraysExt = "StaticArrays" + +[[deps.SuiteSparse]] +deps = ["Libdl", "LinearAlgebra", "Serialization", "SparseArrays"] +uuid = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9" + +[[deps.SuiteSparse_jll]] +deps = ["Artifacts", "Libdl", "libblastrampoline_jll"] +uuid = "bea87d4a-7f5b-5778-9afe-8cc45184846c" +version = "7.2.1+1" + +[[deps.TOML]] +deps = ["Dates"] +uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76" +version = "1.0.3" + +[[deps.TableTraits]] +deps = ["IteratorInterfaceExtensions"] +git-tree-sha1 = "c06b2f539df1c6efa794486abfb6ed2022561a39" +uuid = "3783bdb8-4a98-5b6b-af9a-565f29a5fe9c" +version = "1.0.1" + +[[deps.Tables]] +deps = ["DataAPI", "DataValueInterfaces", "IteratorInterfaceExtensions", "OrderedCollections", "TableTraits"] +git-tree-sha1 = "598cd7c1f68d1e205689b1c2fe65a9f85846f297" +uuid = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" +version = "1.12.0" + +[[deps.Tar]] +deps = ["ArgTools", "SHA"] +uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e" +version = "1.10.0" + +[[deps.Test]] +deps = ["InteractiveUtils", "Logging", "Random", "Serialization"] +uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" + +[[deps.TimerOutputs]] +deps = ["ExprTools", "Printf"] +git-tree-sha1 = "5a13ae8a41237cff5ecf34f73eb1b8f42fff6531" +uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f" +version = "0.5.24" + +[[deps.TranscodingStreams]] +git-tree-sha1 = "e84b3a11b9bece70d14cce63406bbc79ed3464d2" +uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa" +version = "0.11.2" + +[[deps.Transducers]] +deps = ["Accessors", "Adapt", "ArgCheck", "BangBang", "Baselet", "CompositionsBase", "ConstructionBase", "DefineSingletons", "Distributed", "InitialValues", "Logging", "Markdown", "MicroCollections", "Requires", "SplittablesBase", "Tables"] +git-tree-sha1 = "5215a069867476fc8e3469602006b9670e68da23" +uuid = "28d57a85-8fef-5791-bfe6-a80928e7c999" +version = "0.4.82" + + [deps.Transducers.extensions] + TransducersBlockArraysExt = "BlockArrays" + TransducersDataFramesExt = "DataFrames" + TransducersLazyArraysExt = "LazyArrays" + TransducersOnlineStatsBaseExt = "OnlineStatsBase" + TransducersReferenceablesExt = "Referenceables" + + [deps.Transducers.weakdeps] + BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e" + DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" + LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02" + OnlineStatsBase = "925886fa-5bf2-5e8e-b522-a9147a512338" + Referenceables = "42d2dcc6-99eb-4e98-b66c-637b7d73030e" + +[[deps.UUIDs]] +deps = ["Random", "SHA"] +uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" + +[[deps.Unicode]] +uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" + +[[deps.UnsafeAtomics]] +git-tree-sha1 = "6331ac3440856ea1988316b46045303bef658278" +uuid = "013be700-e6cd-48c3-b4a1-df204f14c38f" +version = "0.2.1" + +[[deps.UnsafeAtomicsLLVM]] +deps = ["LLVM", "UnsafeAtomics"] +git-tree-sha1 = "2d17fabcd17e67d7625ce9c531fb9f40b7c42ce4" +uuid = "d80eeb9a-aca5-4d75-85e5-170c8b632249" +version = "0.2.1" + +[[deps.Zlib_jll]] +deps = ["Libdl"] +uuid = "83775a58-1f1d-513f-b197-d71354ab007a" +version = "1.2.13+1" + +[[deps.Zygote]] +deps = ["AbstractFFTs", "ChainRules", "ChainRulesCore", "DiffRules", "Distributed", "FillArrays", "ForwardDiff", "GPUArrays", "GPUArraysCore", "IRTools", "InteractiveUtils", "LinearAlgebra", "LogExpFunctions", "MacroTools", "NaNMath", "PrecompileTools", "Random", "Requires", "SparseArrays", "SpecialFunctions", "Statistics", "ZygoteRules"] +git-tree-sha1 = "19c586905e78a26f7e4e97f81716057bd6b1bc54" +uuid = "e88e6eb3-aa80-5325-afca-941959d7151f" +version = "0.6.70" + + [deps.Zygote.extensions] + ZygoteColorsExt = "Colors" + ZygoteDistancesExt = "Distances" + ZygoteTrackerExt = "Tracker" + + [deps.Zygote.weakdeps] + Colors = "5ae59095-9a9b-59fe-a467-6f913c188581" + Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" + Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" + +[[deps.ZygoteRules]] +deps = ["ChainRulesCore", "MacroTools"] +git-tree-sha1 = "27798139afc0a2afa7b1824c206d5e87ea587a00" +uuid = "700de1a5-db45-46bc-99cf-38207098b444" +version = "0.2.5" + +[[deps.cuDNN]] +deps = ["CEnum", "CUDA", "CUDA_Runtime_Discovery", "CUDNN_jll"] +git-tree-sha1 = "4909e87d6d62c29a897d54d9001c63932e41cb0e" +uuid = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" +version = "1.3.2" + +[[deps.libblastrampoline_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "8e850b90-86db-534c-a0d3-1478176c7d93" +version = "5.8.0+1" + +[[deps.nghttp2_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d" +version = "1.52.0+1" + +[[deps.p7zip_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0" +version = "17.4.0+2" diff --git a/text/nanogpt/Project.toml b/text/nanogpt/Project.toml new file mode 100644 index 00000000..ce6b93b8 --- /dev/null +++ b/text/nanogpt/Project.toml @@ -0,0 +1,10 @@ +[deps] +CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" +Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" +JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819" +MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" +ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" +cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" diff --git a/text/nanogpt/README.md b/text/nanogpt/README.md new file mode 100644 index 00000000..07451575 --- /dev/null +++ b/text/nanogpt/README.md @@ -0,0 +1,28 @@ +# Generative pre-trained transformer + +![GPT architecture](docs/Full_GPT_architecture.svg) + +[Source](https://en.wikipedia.org/wiki/Generative_pre-trained_transformer) + +## Model Information + +GPT is built of a multi-head attention architecture. We offer here a very small instance based on Andrej Karpathy's [nanoGPT](https://github.com/karpathy/nanoGPT). The default parameters give a model much smaller than nanoGPT, tuned for fastest convergence on a very small data set (Shakespeare). + +This model takes as input a sequence of existing text (context) and produces as output the predicted next character. Actually, it produces the predicted next character for each initial sub-sequence of the input, in effect giving an extra degree of parallelism for the purposes of training. + +For the attention mechanism, we use [Flux.MultiHeadAttention](https://fluxml.ai/Flux.jl/stable/reference/models/layers/#MultiHeadAttention). + + +## Training + +```shell +cd text/gpt +julia --project gpt.jl +``` + +## References + +* [Attention is all you need](https://proceedings.neurips.cc/paper_files/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf) +* [Youtube (3blue1brown): Attention in transformers, visually explained](https://www.youtube.com/watch?v=eMlx5fFNoYc) +* [Youtube (Karpathy): Let's build GPT: from scratch, in code, spelled out](https://www.youtube.com/watch?v=kCc8FmEb1nY) +* [nanoGPT](https://github.com/karpathy/nanoGPT) diff --git a/text/nanogpt/docs/Full_GPT_architecture.svg b/text/nanogpt/docs/Full_GPT_architecture.svg new file mode 100644 index 00000000..32e9cdb6 --- /dev/null +++ b/text/nanogpt/docs/Full_GPT_architecture.svg @@ -0,0 +1,156 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + Transformer Block + Layer L + Transformer Block + Layer ... + Transformer Block + Layer 1 + Positional + Encoding + Input + Embedding + Head 1 + Head ... + Head H + + + Softmax + LayerNorm + Dropout + Linear + Output + Input + Softmax + Mask + Matmul + Dropout + Matmul + Softmax + Mask + LayerNorm + LayerNorm + Dropout + Dropout + Dropout + Matmul + Softmax + Mask + Matmul + Dropout + Matmul + Linear + Linear + Linear + Linear + Gelu + Matmul + Transformer Block Input + Transformer Block Output + + + diff --git a/text/nanogpt/gpt.jl b/text/nanogpt/gpt.jl new file mode 100644 index 00000000..1fc48122 --- /dev/null +++ b/text/nanogpt/gpt.jl @@ -0,0 +1,263 @@ +## Multi-head attention (GPT) + +# GPT is built of a multi-head attention architecture. We offer here a very small instance based on +# Andrej Karpathy's [nanoGPT](https://github.com/karpathy/nanoGPT). The default parameters give a +# model much smaller than nanoGPT, tuned for fastest convergence on a very small data set +# (Shakespeare). + +# This model takes as input a sequence of existing text (context) and produces as output the +# predicted next character. Actually, it produces the predicted next character for each initial +# sub-sequence of the input, in effect giving an extra degree of parallelism for the purposes of +# training. + +# To run this example, we need the following packages: + +For the attention mechanism, we use [Flux.MultiHeadAttention](https://fluxml.ai/Flux.jl/stable/reference/models/layers/#MultiHeadAttention). + +using JLD2 +using CUDA, cuDNN +using Flux +using MLUtils +using Random +using Statistics +using StatsBase +using ProgressMeter + +device = Flux.get_device() + +# With these options, each epoch takes 22 seconds on an RTX 4090. +# Loss is 1.81 after 1 epoch, and generates recognizable text. +# Loss is 1.58 after 5 epochs. +# Loss is 1.52 after 20 epochs. +Base.@kwdef mutable struct Args + n_embed::Int = 64 # Length of latent vector + n_hidden::Int = 256 # Hidden dim for MLP layer + n_heads::Int = 4 # Number of attention heads + qk_dim::Int = 16 # Attn query/key size, typically n_embed / n_heads + v_dim::Int = 16 # Attn value size, typically n_embed / n_heads + n_layers::Int = 6 # Number of attention/MLP layers + seqlen::Int = 64 # Context length + batchsz::Int = 128 # Number of sequences in each batch + dropout::Float32 = 0.0 # Dropout fraction during training + testpercent::Float64 = 0.1 # Percent of corpus examples to use for testing + lr::Float64 = 1e-2 # Learning rate + epochs::Int = 20 # Number of epochs +end + + + +# One layer of the GPT model. We will have args.n_layers of these. +struct GPTBlock + layernorm1::LayerNorm + mha::MultiHeadAttention + mlp::Chain +end + +Flux.@layer GPTBlock + +function GPTBlock(; n_embed, n_hidden, qk_dim, v_dim, n_heads, dropout) + GPTBlock( + LayerNorm(n_embed), + MultiHeadAttention(n_embed => (qk_dim, v_dim) => n_embed; nheads=n_heads, dropout_prob=dropout), + Chain( + LayerNorm(n_embed), + Dense(n_embed => n_hidden, gelu), + Dense(n_hidden => n_embed), + Dropout(dropout) + ), + ) +end + +function (m::GPTBlock)(x) + y, α = m.mha(m.layernorm1(x); mask=NNlib.make_causal_mask(x)) + x += y + x += m.mlp(x) +end + + + +struct GPT + alphabet::Vector{Char} + tok_embed::Embedding + pos_embed::Embedding + dropout::Dropout + blocks::Vector{GPTBlock} + layernorm1::LayerNorm + output_layer::Dense +end + +Flux.@layer GPT + +function GPT(args::Args, alphabet::AbstractVector{Char}) + n_vocab = length(alphabet) + GPT( + alphabet, + Embedding(n_vocab => args.n_embed), + Embedding(args.seqlen => args.n_embed), + Dropout(args.dropout), + map(_ -> GPTBlock( + n_embed = args.n_embed, + n_hidden = args.n_hidden, + qk_dim = args.qk_dim, + v_dim = args.v_dim, + n_heads = args.n_heads, + dropout = args.dropout), 1:args.n_layers), + LayerNorm(args.n_embed), + Dense(args.n_embed => n_vocab), + ) +end + +function (m::GPT)(tokens) + T, B = size(tokens) + te = m.tok_embed(tokens) + pe = m.pos_embed(1:T) + x = m.dropout(te .+ pe) + for blk in m.blocks + x = blk(x) + end + x = m.layernorm1(x) + x = m.output_layer(x) + return x +end + +# Infer args.seqlen from the given model. +context_length(m::GPT) = size(m.pos_embed.weight, 2) + + + +# Use the model to generate some text. +function generate(model, seed, outlen) + seqlen = context_length(model) + if isempty(seed) + seed = "_" + end + x = map(c -> findfirst(==(c), model.alphabet)::Int64, collect(seed)) + while length(x) < outlen + tail = x[max(1, end-seqlen+1):end] + tail = reshape(tail, length(tail), 1) + y = model(tail |> device) |> cpu + p = softmax(y[:,end,1]) + j = sample(1:length(model.alphabet), Weights(p)) + #j = argmax(p) + #x = vcat(x, [j]) + push!(x, j) + end + String(map(j -> model.alphabet[j], x)) +end + + + +# Load data from input file, and partition into training and testing subsets. +function getdata(args::Args) + isfile("input.txt") || download( + "https://cs.stanford.edu/people/karpathy/char-rnn/shakespeare_input.txt", + "input.txt", + ) + + text = String(read("input.txt")) + + ## an array of all unique characters + alphabet = [unique(text)..., '_'] + stop = alphabet[end] + + B = (length(text)-1) ÷ args.seqlen + Xs = reshape(collect(text[1:B*args.seqlen]), args.seqlen, B) + Ys = reshape(collect(text[2:B*args.seqlen+1]), args.seqlen, B) + + Xs[1,:] .= stop + + # Xs (input) should consist of indices into `alphabet` because this is what Embedding expects. + # Ys (output) should be one-hot because this is what logitcrossentropy expects. + Xs = map(c -> Int32(findfirst(==(c), alphabet)), Xs) + Ys = Flux.onehotbatch(Ys, alphabet) + #@show Xs |> typeof # = Matrix{Int32} + #@show Xs |> size # = (64, 71458) + #@show Ys |> typeof # = OneHotArrays.OneHotArray{UInt32, 2, 3, Matrix{UInt32}} + #@show Ys |> size # = (68, 64, 71458) + + numbatch = size(Xs, 2) + split = floor(Int, (1-args.testpercent) * numbatch) + + trainX, trainY = Xs[:,1:split], Ys[:,:,1:split] + testX, testY = Xs[:,(split+1):end], Ys[:,:,(split+1):end] + + return (alphabet, trainX, trainY, testX, testY) +end + + + +function train(; kws...) + args = Args(; kws...) + + @info "Training on $device" + + # Load data from input file, and partition into training and testing subsets. + alphabet, trainX, trainY, testX, testY = getdata(args) + + # Move data to the device (CPU or GPU). + trainX, trainY, testX, testY = device.((trainX, trainY, testX, testY)) + + @info "Training size: $(size(trainX, 2)) sequences." + @info "Testing size: $(size(testX, 2)) sequences." + + # This will iterate over the training data, giving us batches of size args.batchsz. + loader = MLUtils.DataLoader((trainX, trainY), batchsize=args.batchsz, shuffle=true) + + # Construct the model. + model = GPT(args, alphabet) |> device + @info "Number of params: $(sum(length, Flux.params(model)))" + + function loss(m, xs, ys) + return sum(Flux.logitcrossentropy(m(xs), ys)) + end + + opt_state = Flux.setup(Adam(args.lr), model) + + for epoch = 1:args.epochs + @info "Training, epoch $(epoch) / $(args.epochs)" + trainmode!(model) # Enable dropout, for training + @showprogress for (x,y) in loader + grad = Flux.gradient(model) do m + loss(m, x, y) + end + Flux.update!(opt_state, model, grad[1]) + end + + testmode!(model) # Disable dropout, for testing/inference + + # Save model checkpoint. + jldsave("model-checkpoint.jld2", + model_state=Flux.state(model |> cpu), + alphabet=alphabet, + args=args) + + # Show loss per character for the testing dataset. + @show loss(model, testX, testY) + + # Generate some text. The character "_" is the stop character, and we're using it here to + # represent that we are starting with zero context. + for _ in 1:5 + @show generate(model, "_", 50) + end + end + + return args, model +end + +# Load a model from a checkpoint (see `jldsave` above). +function load_model(filename) + args = JLD2.load(filename, "args") + alphabet = JLD2.load(filename, "alphabet") + model = GPT(args, alphabet) + model_state = JLD2.load(filename, "model_state") + model = Flux.loadmodel!(model, model_state); + return args, model +end + +if true + args, model = train() +else + args, model = load_model("model-checkpoint.jld2") |> device +end + +generate(model, "The", 50) From 5768b7210e75bba2df544875b8251e9f0354deb5 Mon Sep 17 00:00:00 2001 From: Dan Stahlke Date: Sun, 22 Sep 2024 16:39:32 -0700 Subject: [PATCH 2/6] fixed comment --- text/nanogpt/gpt.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/text/nanogpt/gpt.jl b/text/nanogpt/gpt.jl index 1fc48122..c671da71 100644 --- a/text/nanogpt/gpt.jl +++ b/text/nanogpt/gpt.jl @@ -10,9 +10,9 @@ # sub-sequence of the input, in effect giving an extra degree of parallelism for the purposes of # training. -# To run this example, we need the following packages: +# For the attention mechanism, we use [Flux.MultiHeadAttention](https://fluxml.ai/Flux.jl/stable/reference/models/layers/#MultiHeadAttention). -For the attention mechanism, we use [Flux.MultiHeadAttention](https://fluxml.ai/Flux.jl/stable/reference/models/layers/#MultiHeadAttention). +# To run this example, we need the following packages: using JLD2 using CUDA, cuDNN From f14649efc4ac11d455f6c2f0b081b8d8cc747133 Mon Sep 17 00:00:00 2001 From: Dan Stahlke Date: Sun, 22 Sep 2024 19:43:30 -0700 Subject: [PATCH 3/6] fixes for War+Peace --- text/nanogpt/gpt.jl | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/text/nanogpt/gpt.jl b/text/nanogpt/gpt.jl index c671da71..37a8eca1 100644 --- a/text/nanogpt/gpt.jl +++ b/text/nanogpt/gpt.jl @@ -151,19 +151,27 @@ end function getdata(args::Args) isfile("input.txt") || download( "https://cs.stanford.edu/people/karpathy/char-rnn/shakespeare_input.txt", + #"https://cs.stanford.edu/people/karpathy/char-rnn/warpeace_input.txt", "input.txt", ) text = String(read("input.txt")) + # For aesthetic reasons, replace newlines with strings. This is not necessary, but makes + # strings print nicer. + text = replace(text, r"\r?\n" => " ") + ## an array of all unique characters alphabet = [unique(text)..., '_'] stop = alphabet[end] B = (length(text)-1) ÷ args.seqlen - Xs = reshape(collect(text[1:B*args.seqlen]), args.seqlen, B) - Ys = reshape(collect(text[2:B*args.seqlen+1]), args.seqlen, B) + # We must collect() before indexing, because String indexing does strange things with multi-byte + # characters and we could end up with the wrong length. + Xs = reshape(collect(text)[1:B*args.seqlen], args.seqlen, B) + Ys = reshape(collect(text)[2:B*args.seqlen+1], args.seqlen, B) + # Input string starts with stop character '_', representing zero context. Xs[1,:] .= stop # Xs (input) should consist of indices into `alphabet` because this is what Embedding expects. @@ -197,6 +205,7 @@ function train(; kws...) # Move data to the device (CPU or GPU). trainX, trainY, testX, testY = device.((trainX, trainY, testX, testY)) + @info "Alphabet size: $(length(alphabet))" @info "Training size: $(size(trainX, 2)) sequences." @info "Testing size: $(size(testX, 2)) sequences." From f898bcac7720d8d175186187fb670951b0ca196f Mon Sep 17 00:00:00 2001 From: Dan Stahlke Date: Mon, 23 Sep 2024 19:35:45 -0700 Subject: [PATCH 4/6] Update text/nanogpt/gpt.jl Co-authored-by: Carlo Lucibello --- text/nanogpt/gpt.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/text/nanogpt/gpt.jl b/text/nanogpt/gpt.jl index 37a8eca1..ea5269b5 100644 --- a/text/nanogpt/gpt.jl +++ b/text/nanogpt/gpt.jl @@ -72,6 +72,7 @@ function (m::GPTBlock)(x) y, α = m.mha(m.layernorm1(x); mask=NNlib.make_causal_mask(x)) x += y x += m.mlp(x) + return x end From 3dbed4c29359bd5dc8a8abd72b9fc408f2f207f1 Mon Sep 17 00:00:00 2001 From: Dan Stahlke Date: Mon, 23 Sep 2024 19:40:38 -0700 Subject: [PATCH 5/6] gpt.jl: save optimizer checkpoints --- text/nanogpt/gpt.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/text/nanogpt/gpt.jl b/text/nanogpt/gpt.jl index ea5269b5..0f1cd563 100644 --- a/text/nanogpt/gpt.jl +++ b/text/nanogpt/gpt.jl @@ -222,6 +222,7 @@ function train(; kws...) end opt_state = Flux.setup(Adam(args.lr), model) + #opt_state = JLD2.load("model-checkpoint.jld2", "opt_state") for epoch = 1:args.epochs @info "Training, epoch $(epoch) / $(args.epochs)" @@ -238,6 +239,7 @@ function train(; kws...) # Save model checkpoint. jldsave("model-checkpoint.jld2", model_state=Flux.state(model |> cpu), + opt_state=opt_state, alphabet=alphabet, args=args) From 9465c80e33879874320f29cbabb7d417e16fa820 Mon Sep 17 00:00:00 2001 From: Dan Stahlke Date: Mon, 23 Sep 2024 20:27:39 -0700 Subject: [PATCH 6/6] gpt.jl: added example output --- text/nanogpt/README.md | 16 ++++++++++++++++ text/nanogpt/gpt.jl | 19 ++++++++++++++++--- 2 files changed, 32 insertions(+), 3 deletions(-) diff --git a/text/nanogpt/README.md b/text/nanogpt/README.md index 07451575..a0f7c18a 100644 --- a/text/nanogpt/README.md +++ b/text/nanogpt/README.md @@ -20,6 +20,22 @@ cd text/gpt julia --project gpt.jl ``` +## Example output + +After one epoch: + + generate(model, "_", 50) = "_me, but plept fairs, And heards, verchean my word" + generate(model, "_", 50) = "_ows know yought, This alce! totether him. weliest" + generate(model, "The", 50) = "These prurd passtion? CINCESSIT: He eloucy I must" + generate(model, "The", 50) = "The bitherse dresic in to so shall with a his the " + +After 20 epochs: + + generate(model, "_", 50) = "_ething a calling do me diseases Of, on he's to th" + generate(model, "_", 50) = "_ ragg Thou flatters all in wators the selfsarut o" + generate(model, "The", 50) = "The Mirtouggake Go: For my mischance lords his sea" + generate(model, "The", 50) = "The oll-gakemoremo his dead: All this man make gen" + ## References * [Attention is all you need](https://proceedings.neurips.cc/paper_files/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf) diff --git a/text/nanogpt/gpt.jl b/text/nanogpt/gpt.jl index 0f1cd563..dfe5e512 100644 --- a/text/nanogpt/gpt.jl +++ b/text/nanogpt/gpt.jl @@ -12,6 +12,18 @@ # For the attention mechanism, we use [Flux.MultiHeadAttention](https://fluxml.ai/Flux.jl/stable/reference/models/layers/#MultiHeadAttention). +# Example output after one epoch: +# generate(model, "_", 50) = "_me, but plept fairs, And heards, verchean my word" +# generate(model, "_", 50) = "_ows know yought, This alce! totether him. weliest" +# generate(model, "The", 50) = "These prurd passtion? CINCESSIT: He eloucy I must" +# generate(model, "The", 50) = "The bitherse dresic in to so shall with a his the " + +# Example output after 20 epochs: +# generate(model, "_", 50) = "_ething a calling do me diseases Of, on he's to th" +# generate(model, "_", 50) = "_ ragg Thou flatters all in wators the selfsarut o" +# generate(model, "The", 50) = "The Mirtouggake Go: For my mischance lords his sea" +# generate(model, "The", 50) = "The oll-gakemoremo his dead: All this man make gen" + # To run this example, we need the following packages: using JLD2 @@ -248,9 +260,10 @@ function train(; kws...) # Generate some text. The character "_" is the stop character, and we're using it here to # represent that we are starting with zero context. - for _ in 1:5 - @show generate(model, "_", 50) - end + @show generate(model, "_", 50) + @show generate(model, "_", 50) + @show generate(model, "The", 50) + @show generate(model, "The", 50) end return args, model