Support DBRX Model #29911
Comments
Looks like the authors are already on it: #29921
Using the example code on the model card page results in an error.

```python
model = AutoModelForCausalLM.from_pretrained(
    "databricks/dbrx-instruct",
    device_map="auto",
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    token=token,
    cache_dir=HF_CACHE,
)
```

```
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-6-e8e588a00506> in <module>
----> 1 model = AutoModelForCausalLM.from_pretrained("databricks/dbrx-instruct", device_map="auto", torch_dtype=torch.bfloat16, trust_remote_code=True, token=token, cache_dir=HF_CACHE)

/databricks/python/lib/python3.8/site-packages/transformers/models/auto/auto_factory.py in from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs)
    561         elif type(config) in cls._model_mapping.keys():
    562             model_class = _get_model_class(config, cls._model_mapping)
--> 563             return model_class.from_pretrained(
    564                 pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs, **kwargs
    565             )

/databricks/python/lib/python3.8/site-packages/transformers/modeling_utils.py in from_pretrained(cls, pretrained_model_name_or_path, config, cache_dir, ignore_mismatched_sizes, force_download, local_files_only, token, revision, use_safetensors, *model_args, **kwargs)
   3548         with ContextManagers(init_contexts):
   3549             # Let's make sure we don't run the init function of buffer modules
-> 3550             model = cls(config, *model_args, **model_kwargs)
   3551
   3552         # make sure we use the model's config since the __init__ call might have copied it

/databricks/python/lib/python3.8/site-packages/transformers/models/dbrx/modeling_dbrx.py in __init__(self, config)
   1293     def __init__(self, config: DbrxConfig):
   1294         super().__init__(config)
-> 1295         self.transformer = DbrxModel(config)
   1296         self.vocab_size = config.vocab_size
   1297         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

/databricks/python/lib/python3.8/site-packages/transformers/models/dbrx/modeling_dbrx.py in __init__(self, config)
   1078
   1079         self.wte = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
-> 1080         self.blocks = nn.ModuleList([DbrxBlock(config, block_idx) for block_idx in range(config.n_layers)])
   1081         self.norm_f = nn.LayerNorm(config.d_model, bias=False)
   1082         self.gradient_checkpointing = False

/databricks/python/lib/python3.8/site-packages/transformers/models/dbrx/modeling_dbrx.py in <listcomp>(.0)
   1078
   1079         self.wte = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
-> 1080         self.blocks = nn.ModuleList([DbrxBlock(config, block_idx) for block_idx in range(config.n_layers)])
   1081         self.norm_f = nn.LayerNorm(config.d_model, bias=False)
   1082         self.gradient_checkpointing = False

/databricks/python/lib/python3.8/site-packages/transformers/models/dbrx/modeling_dbrx.py in __init__(self, config, block_idx)
    834         self.resid_pdrop = config.resid_pdrop
    835         self.block_idx = block_idx
--> 836         self.norm_attn_norm = DbrxNormAttentionNorm(
    837             config=config,
    838             block_idx=block_idx,

/databricks/python/lib/python3.8/site-packages/transformers/models/dbrx/modeling_dbrx.py in __init__(self, config, block_idx)
    645         self.block_idx = block_idx
    646         self.resid_pdrop = config.resid_pdrop
--> 647         self.norm_1 = nn.LayerNorm(config.d_model, bias=False)
    648         self.attn = DBRX_ATTENTION_CLASSES[config._attn_implementation](
    649             config=config,

TypeError: __init__() got an unexpected keyword argument 'bias'
```
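The failing call is `nn.LayerNorm(config.d_model, bias=False)`. The `bias` keyword on `nn.LayerNorm` only exists in relatively recent PyTorch releases (added around torch 2.1, to the best of my recollection), so an older install rejects it with exactly this `TypeError`. One defensive pattern, sketched here with a hypothetical `supports_kwarg` helper that is not part of transformers, is to probe the constructor's signature before passing the argument:

```python
import inspect

def supports_kwarg(callable_obj, name: str) -> bool:
    """Return True if `callable_obj` accepts a keyword argument `name`."""
    try:
        return name in inspect.signature(callable_obj).parameters
    except (TypeError, ValueError):
        # Some builtins/extension types expose no introspectable signature.
        return False

# Usage against torch (assuming it is installed):
# import torch.nn as nn
# if supports_kwarg(nn.LayerNorm, "bias"):
#     norm = nn.LayerNorm(d_model, bias=False)
# else:
#     norm = nn.LayerNorm(d_model)  # older torch: bias is always present
```

In practice the simpler fix is usually to upgrade PyTorch rather than patch the modeling code.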
Hi @jmwoloso - thanks for flagging! cc @Rocketknight1
Thanks for implementing @amyeroberts 🤗
Pinging @eitanturok to this one - I'm guessing this probably means the
Hi @jmwoloso, just realized that example is using
and let us know if you get the same error? If not, then all we need to do is fix the example in the docs.
@Rocketknight1 thanks for the prompt (ha!) response, we'll try that out shortly and let you know.
same error @Rocketknight1. I'm using

```python
model = AutoModelForCausalLM.from_pretrained(HF_CACHE + DBRX, device_map="auto", torch_dtype=torch.bfloat16, token=token)
```
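Since the traceback still ends at `nn.LayerNorm(..., bias=False)`, a quick sanity check is to compare the installed `torch.__version__` against the release assumed to have introduced the `bias` keyword. Both the 2.1 cutoff and the `layernorm_supports_bias` helper below are illustrative assumptions, not part of transformers or torch:

```python
def layernorm_supports_bias(torch_version: str) -> bool:
    """Heuristic: assume nn.LayerNorm's `bias` kwarg needs torch >= 2.1."""
    # Strip any local build suffix ("2.1.0+cu118") and compare (major, minor).
    major, minor = (int(p) for p in torch_version.split("+")[0].split(".")[:2])
    return (major, minor) >= (2, 1)

# Run against the live install:
# import torch
# print(torch.__version__, layernorm_supports_bias(torch.__version__))
```

If this returns `False` for your environment (the Databricks runtime path above shows Python 3.8, which suggests an older stack), upgrading torch should be the first thing to try.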
Feature request
Support the DBRX model (only correct pronunciation: DB-Rex); see the announcement blog post. Code is in the open-source databricks/dbrx repository.
Motivation
Your contribution
#29910