diff --git a/examples/example_llama3.py b/examples/example_llama3.py
index eeff9ffb..df57751f 100644
--- a/examples/example_llama3.py
+++ b/examples/example_llama3.py
@@ -42,18 +42,36 @@ use_vocab_parallel = not use_1d_mesh
 
 device = torch.device("cuda")
 
+model_type = "8b"
+
 
 def model_fn():
-    model_args = TransformerModelArgs(
-        dim=4096,
-        n_heads=32,
-        n_layers=32,
-        vocab_size=vocab_size,
-        max_seq_len=seqlen,
-        multiple_of=1024,
-        ffn_dim_multiplier=1.3,
-        n_kv_heads=8,
-    )
+    if model_type == "8b":
+        model_args = TransformerModelArgs(
+            dim=4096,
+            n_layers=32,
+            n_heads=32,
+            n_kv_heads=8,
+            ffn_dim_multiplier=1.3,
+            multiple_of=1024,
+            rope_theta=500000,
+            vocab_size=vocab_size,
+            max_seq_len=seqlen,
+        )
+    elif model_type == "70b":
+        model_args = TransformerModelArgs(
+            dim=8192,
+            n_layers=80,
+            n_heads=64,
+            n_kv_heads=8,
+            ffn_dim_multiplier=1.3,
+            multiple_of=4096,
+            rope_theta=500000,
+            vocab_size=vocab_size,
+            max_seq_len=seqlen,
+        )
+    else:
+        raise ValueError(f"{model_type} not available")
     m = Transformer(model_args)
     return m
 
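
Usage sketch (not part of the diff): with the hard-coded `model_type = "8b"` above, switching sizes means editing the file. A minimal way to drive the selection from the command line instead, assuming only the standard-library argparse and the names introduced in the diff; the flag name and wiring are illustrative, not existing code in example_llama3.py:

import argparse

parser = argparse.ArgumentParser(description="Llama 3 example")
parser.add_argument(
    "--model-type",
    choices=["8b", "70b"],  # the two sizes model_fn() knows how to build
    default="8b",           # matches the diff's hard-coded default
    help="Llama 3 parametrization to construct",
)
args = parser.parse_args()

# Replaces the module-level assignment; model_fn() reads this name.
model_type = args.model_type

Keeping the explicit else/raise branch in model_fn() still guards against typos if the module-level variable is set directly rather than through argparse's choices check.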