
Support DBRX Model #29911

Open · milocress opened this issue Mar 27, 2024 · 8 comments

Comments

@milocress

Feature request

Support the DBRX model (only correct pronunciation: DB-Rex), introduced in the blog post.

The code is in the open-source databricks/dbrx repository.

Motivation

Across a range of standard benchmarks, DBRX sets a new state-of-the-art for established open LLMs. Moreover, it provides the open community and enterprises building their own LLMs with capabilities that were previously limited to closed model APIs; according to our measurements, it surpasses GPT-3.5, and it is competitive with Gemini 1.0 Pro. It is an especially capable code model, surpassing specialized models like CodeLLaMA-70B on programming, in addition to its strength as a general-purpose LLM.

Your contribution

#29910

milocress mentioned this issue Mar 27, 2024
@NielsRogge
Contributor

Looks like the authors are already on it: #29921

@jmwoloso
Contributor

Using the example code on the model card page results in an error.

model = AutoModelForCausalLM.from_pretrained("databricks/dbrx-instruct", device_map="auto", torch_dtype=torch.bfloat16, trust_remote_code=True, token=token, cache_dir=HF_CACHE)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-6-e8e588a00506> in <module>
----> 1 model = AutoModelForCausalLM.from_pretrained("databricks/dbrx-instruct", device_map="auto", torch_dtype=torch.bfloat16, trust_remote_code=True, token=token, cache_dir=HF_CACHE)

/databricks/python/lib/python3.8/site-packages/transformers/models/auto/auto_factory.py in from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs)
    561         elif type(config) in cls._model_mapping.keys():
    562             model_class = _get_model_class(config, cls._model_mapping)
--> 563             return model_class.from_pretrained(
    564                 pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs, **kwargs
    565             )

/databricks/python/lib/python3.8/site-packages/transformers/modeling_utils.py in from_pretrained(cls, pretrained_model_name_or_path, config, cache_dir, ignore_mismatched_sizes, force_download, local_files_only, token, revision, use_safetensors, *model_args, **kwargs)
   3548         with ContextManagers(init_contexts):
   3549             # Let's make sure we don't run the init function of buffer modules
-> 3550             model = cls(config, *model_args, **model_kwargs)
   3551 
   3552         # make sure we use the model's config since the __init__ call might have copied it

/databricks/python/lib/python3.8/site-packages/transformers/models/dbrx/modeling_dbrx.py in __init__(self, config)
   1293     def __init__(self, config: DbrxConfig):
   1294         super().__init__(config)
-> 1295         self.transformer = DbrxModel(config)
   1296         self.vocab_size = config.vocab_size
   1297         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

/databricks/python/lib/python3.8/site-packages/transformers/models/dbrx/modeling_dbrx.py in __init__(self, config)
   1078 
   1079         self.wte = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
-> 1080         self.blocks = nn.ModuleList([DbrxBlock(config, block_idx) for block_idx in range(config.n_layers)])
   1081         self.norm_f = nn.LayerNorm(config.d_model, bias=False)
   1082         self.gradient_checkpointing = False

/databricks/python/lib/python3.8/site-packages/transformers/models/dbrx/modeling_dbrx.py in <listcomp>(.0)
   1078 
   1079         self.wte = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
-> 1080         self.blocks = nn.ModuleList([DbrxBlock(config, block_idx) for block_idx in range(config.n_layers)])
   1081         self.norm_f = nn.LayerNorm(config.d_model, bias=False)
   1082         self.gradient_checkpointing = False

/databricks/python/lib/python3.8/site-packages/transformers/models/dbrx/modeling_dbrx.py in __init__(self, config, block_idx)
    834         self.resid_pdrop = config.resid_pdrop
    835         self.block_idx = block_idx
--> 836         self.norm_attn_norm = DbrxNormAttentionNorm(
    837             config=config,
    838             block_idx=block_idx,

/databricks/python/lib/python3.8/site-packages/transformers/models/dbrx/modeling_dbrx.py in __init__(self, config, block_idx)
    645         self.block_idx = block_idx
    646         self.resid_pdrop = config.resid_pdrop
--> 647         self.norm_1 = nn.LayerNorm(config.d_model, bias=False)
    648         self.attn = DBRX_ATTENTION_CLASSES[config._attn_implementation](
    649             config=config,

TypeError: __init__() got an unexpected keyword argument 'bias'
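
For what it's worth, the traceback bottoms out at nn.LayerNorm(config.d_model, bias=False), so the rejected keyword looks like a PyTorch-side limitation rather than anything in the DBRX code itself. A bare reproduction, independent of transformers (just a sketch, not from the model card):

import torch
import torch.nn as nn

print("torch version:", torch.__version__)

# On a PyTorch release that accepts the keyword this succeeds; on an older
# release it raises the same error as above:
#   TypeError: __init__() got an unexpected keyword argument 'bias'
norm = nn.LayerNorm(16, bias=False)
print(norm)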

@amyeroberts
Collaborator

Hi @jmwoloso - thanks for flagging! cc @Rocketknight1

@jmwoloso
Contributor

Thanks for implementing @amyeroberts 🤗

@Rocketknight1
Member

Rocketknight1 commented Apr 22, 2024

Pinging @eitanturok on this one - I'm guessing this probably means the bias kwarg is only supported in some but not all of the attention implementations? Let me know if you're stuck or busy and I can try to work out a fix!

@Rocketknight1
Member

Rocketknight1 commented Apr 24, 2024

Hi @jmwoloso, just realized that the example is using trust_remote_code=True, but DBRX is now fully supported in transformers! Can you try:

  1. Updating to the latest version of transformers
  2. Running the same code without trust_remote_code=True

and let us know if you get the same error? If not, then all we need to do is fix the example in the docs.
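
Concretely, the two steps amount to something like this (same model id and dtype as your earlier snippet; pass token=... and cache_dir=... as before if you need them):

# Step 1: update to the latest transformers release (shell command)
#   pip install -U transformers

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Step 2: DBRX is natively supported, so trust_remote_code is dropped.
tokenizer = AutoTokenizer.from_pretrained("databricks/dbrx-instruct")
model = AutoModelForCausalLM.from_pretrained(
    "databricks/dbrx-instruct",
    device_map="auto",
    torch_dtype=torch.bfloat16,
)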

@jmwoloso
Contributor

@Rocketknight1 thanks for the prompt (ha!) response, we'll try that out shortly and let you know.

@jmwoloso
Contributor

Same error @Rocketknight1. I'm using transformers==4.40.1:

model = AutoModelForCausalLM.from_pretrained(HF_CACHE + DBRX, device_map="auto", torch_dtype=torch.bfloat16, token=token)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-8-6dd28f3a319a> in <module>
      2 
      3 # use cached
----> 4 model = AutoModelForCausalLM.from_pretrained(HF_CACHE + DBRX, device_map="auto", torch_dtype=torch.bfloat16, token=token)

/databricks/python/lib/python3.8/site-packages/transformers/models/auto/auto_factory.py in from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs)
    561         elif type(config) in cls._model_mapping.keys():
    562             model_class = _get_model_class(config, cls._model_mapping)
--> 563             return model_class.from_pretrained(
    564                 pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs, **kwargs
    565             )

/databricks/python/lib/python3.8/site-packages/transformers/modeling_utils.py in from_pretrained(cls, pretrained_model_name_or_path, config, cache_dir, ignore_mismatched_sizes, force_download, local_files_only, token, revision, use_safetensors, *model_args, **kwargs)
   3548         with ContextManagers(init_contexts):
   3549             # Let's make sure we don't run the init function of buffer modules
-> 3550             model = cls(config, *model_args, **model_kwargs)
   3551 
   3552         # make sure we use the model's config since the __init__ call might have copied it

/databricks/python/lib/python3.8/site-packages/transformers/models/dbrx/modeling_dbrx.py in __init__(self, config)
   1293     def __init__(self, config: DbrxConfig):
   1294         super().__init__(config)
-> 1295         self.transformer = DbrxModel(config)
   1296         self.vocab_size = config.vocab_size
   1297         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

/databricks/python/lib/python3.8/site-packages/transformers/models/dbrx/modeling_dbrx.py in __init__(self, config)
   1078 
   1079         self.wte = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
-> 1080         self.blocks = nn.ModuleList([DbrxBlock(config, block_idx) for block_idx in range(config.n_layers)])
   1081         self.norm_f = nn.LayerNorm(config.d_model, bias=False)
   1082         self.gradient_checkpointing = False

/databricks/python/lib/python3.8/site-packages/transformers/models/dbrx/modeling_dbrx.py in <listcomp>(.0)
   1078 
   1079         self.wte = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
-> 1080         self.blocks = nn.ModuleList([DbrxBlock(config, block_idx) for block_idx in range(config.n_layers)])
   1081         self.norm_f = nn.LayerNorm(config.d_model, bias=False)
   1082         self.gradient_checkpointing = False

/databricks/python/lib/python3.8/site-packages/transformers/models/dbrx/modeling_dbrx.py in __init__(self, config, block_idx)
    834         self.resid_pdrop = config.resid_pdrop
    835         self.block_idx = block_idx
--> 836         self.norm_attn_norm = DbrxNormAttentionNorm(
    837             config=config,
    838             block_idx=block_idx,

/databricks/python/lib/python3.8/site-packages/transformers/models/dbrx/modeling_dbrx.py in __init__(self, config, block_idx)
    645         self.block_idx = block_idx
    646         self.resid_pdrop = config.resid_pdrop
--> 647         self.norm_1 = nn.LayerNorm(config.d_model, bias=False)
    648         self.attn = DBRX_ATTENTION_CLASSES[config._attn_implementation](
    649             config=config,

TypeError: __init__() got an unexpected keyword argument 'bias'
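
One possibility worth ruling out (a guess, not something I've verified on this cluster): the failing call is nn.LayerNorm(..., bias=False), and the bias keyword was only added to LayerNorm in fairly recent PyTorch releases, so the Python 3.8 Databricks runtime in the paths above may simply ship an older torch. A quick sanity check:

import inspect

import torch
import torch.nn as nn
import transformers

print("transformers:", transformers.__version__)
print("torch:", torch.__version__)

# If this prints False, the installed PyTorch predates the LayerNorm `bias`
# keyword and modeling_dbrx.py will fail regardless of the transformers
# version; upgrading torch on the cluster should clear the TypeError.
print("LayerNorm accepts bias kwarg:",
      "bias" in inspect.signature(nn.LayerNorm.__init__).parameters)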
