diff --git a/tests/quantization/bnb/test_mixed_int8.py b/tests/quantization/bnb/test_mixed_int8.py index b926c80398c25a..1c4362ef0cec1a 100644 --- a/tests/quantization/bnb/test_mixed_int8.py +++ b/tests/quantization/bnb/test_mixed_int8.py @@ -42,7 +42,7 @@ def get_some_linear_layer(model): - if model.config.model_type == "openai-community/gpt2": + if model.config.model_type == "gpt2": return model.transformer.h[0].mlp.c_fc return model.transformer.h[0].mlp.dense_4h_to_h