Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Add Nested quantization check #3916

Merged
merged 5 commits into from
Jan 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions ludwig/trainers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
MAX_CPU_BATCH_SIZE,
MINIMIZE,
MODEL_ECD,
MODEL_LLM,
TEST,
TRAINING,
USED_TOKENS,
Expand Down Expand Up @@ -68,6 +69,7 @@
from ludwig.utils import time_utils
from ludwig.utils.batch_size_tuner import BatchSizeEvaluator
from ludwig.utils.checkpoint_utils import Checkpoint, CheckpointManager
from ludwig.utils.config_utils import get_quantization
from ludwig.utils.data_utils import load_json
from ludwig.utils.defaults import default_random_seed
from ludwig.utils.fs_utils import path_exists
Expand Down Expand Up @@ -1133,19 +1135,19 @@ def train(

# For a full explanation of this 8-bit workaround, see https://github.com/ludwig-ai/ludwig/pull/3606
# TODO (jeffkinnison): Determine why `SCB` and `CB` are deleted from parameter state
if (
hasattr(self.model.config_obj, "quantization")
and self.model.config_obj.quantization
and self.model.config_obj.quantization.bits == 8
):
quantization = get_quantization(self.model.config_obj)
uses_quantization = bool(quantization) if not isinstance(quantization, list) else any(quantization)
if uses_quantization and 8 in quantization:
# If the model was previously placed on GPU, 8-bit parameter state will be updated with several
# matrices containing quantization information. These are recorded matrices are recorded in the
# training checkpoint state dicts, but do not necessarily exist in the parameter object, leading
# to a RuntimeError in `load_state_dict`. Explicitly call `model.cuda()` to make sure the
# matrices are part of model state. This workaround is necessary because the matrices are
# deleted during the model's forward pass.
if self.model.model.device.type == "cuda":
if self.model.config_obj.model_type == MODEL_LLM and self.model.model.device.type == "cuda":
self.model.model.cuda()
elif self.model.config_obj.model_type == MODEL_ECD and self.model.device.type == "cuda":
self.model.cuda()
_, unexpected_keys = self.model.load_state_dict(state_dict, strict=False)
only_weights_format_keys = ["weights_format" in k for k in unexpected_keys]

Expand Down
42 changes: 41 additions & 1 deletion ludwig/utils/config_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict, Set, Union
from typing import Any, Dict, List, Set, Union

from ludwig.api_annotations import DeveloperAPI
from ludwig.constants import (
Expand Down Expand Up @@ -142,3 +142,43 @@ def config_uses_llm(config: Union[Dict[str, Any], ModelConfig]) -> bool:
raise ValueError(f"Invalid config cannot be checked for LLM usage. Config: {config}")

return uses_llm


def get_quantization(config: Union[Dict[str, Any], ModelConfig]) -> Union[int, List[int], None]:
"""Get the quantization specified in a config at any level.

Args:
config: Ludwig config object or dictionary

Returns:
For LLM models, the value of quantization.bits or None if it is not specified.
For ECD and GBM models, the list of values of quantization.bits for each encoder. If the encoder does not
support quantization or no quantization config is specified, the list entry is None.
"""
if isinstance(config, ModelConfig):
if config.model_type == MODEL_LLM:
return config.quantization.bits if config.quantization else None
else:
quantization_bits = []
for feature in config.input_features:
try:
quantization = feature.encoder.quantization.bits
except AttributeError:
quantization = None
quantization_bits.append(quantization)
return quantization_bits
elif isinstance(config, dict) and config:
if config.get(MODEL_TYPE, MODEL_ECD) == MODEL_LLM:
return config.get("quantization", {}).get("bits")
elif INPUT_FEATURES in config:
quantization_bits = []
for feature in config.get(INPUT_FEATURES, []):
quantization_bits.append(feature.get(ENCODER, {}).get("quantization", {}).get("bits"))
return quantization_bits
else:
raise ValueError(
"Invalid config cannot be checked for quantization because it has no input features."
f"Config: {config}"
)
else:
raise ValueError(f"Invalid config cannot be checked for quantization. Config: {config}")
215 changes: 168 additions & 47 deletions tests/ludwig/utils/test_config_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Any, Dict, Optional
import copy
from typing import Any, Dict, List, Optional, Union

import pytest

Expand All @@ -20,7 +21,7 @@
from ludwig.schema.encoders.utils import get_encoder_cls
from ludwig.schema.features.preprocessing.text import TextPreprocessingConfig
from ludwig.schema.model_config import ModelConfig
from ludwig.utils.config_utils import config_uses_llm
from ludwig.utils.config_utils import config_uses_llm, get_quantization


@pytest.mark.parametrize(
Expand Down Expand Up @@ -84,11 +85,6 @@ def llm_config_dict() -> Dict[str, Any]:
}


@pytest.fixture(scope="module")
def llm_config_object(llm_config_dict: Dict[str, Any]) -> ModelConfig:
return ModelConfig.from_dict(llm_config_dict)


@pytest.fixture(scope="module")
def ecd_config_dict_llm_encoder() -> Dict[str, Any]:
return {
Expand All @@ -104,11 +100,6 @@ def ecd_config_dict_llm_encoder() -> Dict[str, Any]:
}


@pytest.fixture(scope="module")
def ecd_config_object_llm_encoder(ecd_config_dict_llm_encoder: Dict[str, Any]) -> ModelConfig:
return ModelConfig.from_dict(ecd_config_dict_llm_encoder)


@pytest.fixture(scope="module")
def ecd_config_dict_llm_encoder_multiple_features() -> Dict[str, Any]:
return {
Expand All @@ -125,13 +116,6 @@ def ecd_config_dict_llm_encoder_multiple_features() -> Dict[str, Any]:
}


@pytest.fixture(scope="module")
def ecd_config_object_llm_encoder_multiple_features(
ecd_config_dict_llm_encoder_multiple_features: Dict[str, Any]
) -> ModelConfig:
return ModelConfig.from_dict(ecd_config_dict_llm_encoder_multiple_features)


@pytest.fixture(scope="module")
def ecd_config_dict_no_llm_encoder() -> Dict[str, Any]:
return {
Expand All @@ -141,11 +125,6 @@ def ecd_config_dict_no_llm_encoder() -> Dict[str, Any]:
}


@pytest.fixture(scope="module")
def ecd_config_object_no_llm_encoder(ecd_config_dict_no_llm_encoder: Dict[str, Any]) -> ModelConfig:
return ModelConfig.from_dict(ecd_config_dict_no_llm_encoder)


@pytest.fixture(scope="module")
def ecd_config_dict_no_text_features() -> Dict[str, Any]:
return {
Expand All @@ -155,11 +134,6 @@ def ecd_config_dict_no_text_features() -> Dict[str, Any]:
}


@pytest.fixture(scope="module")
def ecd_config_object_no_text_features(ecd_config_dict_no_text_features: Dict[str, Any]) -> ModelConfig:
return ModelConfig.from_dict(ecd_config_dict_no_text_features)


@pytest.fixture(scope="module")
def gbm_config_dict() -> Dict[str, Any]:
return {
Expand All @@ -169,11 +143,6 @@ def gbm_config_dict() -> Dict[str, Any]:
}


@pytest.fixture(scope="module")
def gbm_config_object(gbm_config_dict: Dict[str, Any]) -> ModelConfig:
return ModelConfig.from_dict(gbm_config_dict)


@pytest.fixture(scope="module")
def gbm_config_dict_no_text_features() -> Dict[str, Any]:
return {
Expand All @@ -183,38 +152,27 @@ def gbm_config_dict_no_text_features() -> Dict[str, Any]:
}


@pytest.fixture(scope="module")
def gbm_config_object_no_text_features(gbm_config_dict_no_text_features: Dict[str, Any]) -> ModelConfig:
return ModelConfig.from_dict(gbm_config_dict_no_text_features)


@pytest.mark.parametrize(
"config,expectation",
[
# LLM configurations
("llm_config_dict", True),
("llm_config_object", True),
# LLM encoder configurations
("ecd_config_dict_llm_encoder", True),
("ecd_config_object_llm_encoder", True),
# LLM encoder configurations, multiple features
("ecd_config_dict_llm_encoder_multiple_features", True),
("ecd_config_object_llm_encoder_multiple_features", True),
# ECD configuration with text feature and non-LLM encoder
("ecd_config_dict_no_llm_encoder", False),
("ecd_config_object_no_llm_encoder", False),
# ECD configuration with no text features
("ecd_config_dict_no_text_features", False),
("ecd_config_object_no_text_features", False),
# GBM configuration with text feature. "tf_idf" is the only valid text encoder
("gbm_config_dict", False),
("gbm_config_object", False),
# GBM configuration with no text features
("gbm_config_dict_no_text_features", False),
("gbm_config_object_no_text_features", False),
],
)
def test_is_or_uses_llm(config, expectation, request):
@pytest.mark.parametrize("config_type", ["dict", "object"])
def test_is_or_uses_llm(config: Dict[str, Any], expectation: bool, config_type, request):
"""Test LLM detection on a variety of configs. Configs that use an LLM anywhere should return True, otherwise
False.

Expand All @@ -224,6 +182,8 @@ def test_is_or_uses_llm(config, expectation, request):
request: pytest `request` fixture
"""
config = request.getfixturevalue(config)
if config_type == "object":
config = ModelConfig.from_dict(config)
assert config_uses_llm(config) == expectation


Expand All @@ -238,3 +198,164 @@ def test_is_or_uses_llm_invalid_input(invalid_config):
"""
with pytest.raises(ValueError):
config_uses_llm(invalid_config)


@pytest.fixture(scope="module")
def quantization_4bit_config() -> Dict[str, Any]:
return {"quantization": {"bits": 4}}


@pytest.fixture(scope="module")
def quantization_8bit_config() -> Dict[str, Any]:
return {"quantization": {"bits": 8}}


@pytest.fixture(scope="module")
def llm_config_dict_4bit(llm_config_dict: Dict[str, Any], quantization_4bit_config: Dict[str, Any]) -> Dict[str, Any]:
config = copy.deepcopy(llm_config_dict)
config.update(quantization_4bit_config)
return config


@pytest.fixture(scope="module")
def llm_config_dict_8bit(llm_config_dict: Dict[str, Any], quantization_8bit_config: Dict[str, Any]) -> Dict[str, Any]:
config = copy.deepcopy(llm_config_dict)
config.update(quantization_8bit_config)
return config


@pytest.fixture(scope="module")
def ecd_config_dict_llm_encoder_4bit(
ecd_config_dict_llm_encoder: Dict[str, Any], quantization_4bit_config: Dict[str, Any]
) -> Dict[str, Any]:
config = copy.deepcopy(ecd_config_dict_llm_encoder)
config[INPUT_FEATURES][0][ENCODER].update(quantization_4bit_config)
return config


@pytest.fixture(scope="module")
def ecd_config_dict_llm_encoder_8bit(
ecd_config_dict_llm_encoder: Dict[str, Any], quantization_8bit_config: Dict[str, Any]
) -> Dict[str, Any]:
config = copy.deepcopy(ecd_config_dict_llm_encoder)
config[INPUT_FEATURES][0][ENCODER].update(quantization_8bit_config)
return config


@pytest.mark.parametrize(
"config,expectation",
[
# LLM configurations
("llm_config_dict", None),
("llm_config_dict_4bit", 4),
("llm_config_dict_8bit", 8),
# LLM encoder configurations with one feature
("ecd_config_dict_llm_encoder", [None]),
("ecd_config_dict_llm_encoder_4bit", [4]),
("ecd_config_dict_llm_encoder_8bit", [8]),
# GBM configuration with text feature. "tf_idf" is the only valid text encoder
("gbm_config_dict", [None]),
# GBM configuration with no text features
("gbm_config_dict_no_text_features", [None]),
],
)
@pytest.mark.parametrize("config_type", ["dict", "object"])
def test_get_quantization(
config: Dict[str, Any], expectation: Union[int, List[int], None, List[None]], config_type: str, request
):
"""Test get_quantization with LLM and single-feature ECD/GBM configs.

Args:
config: The configuration to test
expectation: The expected quantization
config_type: Whether to test the config as a dict or object
request: pytest builtin fixture
"""
config = request.getfixturevalue(config)
if config_type == "object":
config = ModelConfig.from_dict(config)
assert get_quantization(config) == expectation


TEST_FEATURE_CONFIGS = [
(
{
TYPE: BINARY,
},
None,
),
(
{
TYPE: TEXT,
},
None,
),
({TYPE: TEXT, ENCODER: {TYPE: MODEL_LLM, BASE_MODEL: "HuggingFaceH4/tiny-random-LlamaForCausalLM"}}, None),
(
{
TYPE: TEXT,
ENCODER: {
TYPE: MODEL_LLM,
BASE_MODEL: "HuggingFaceH4/tiny-random-LlamaForCausalLM",
"quantization": {"bits": 4},
},
},
4,
),
(
{
TYPE: TEXT,
ENCODER: {
TYPE: MODEL_LLM,
BASE_MODEL: "HuggingFaceH4/tiny-random-LlamaForCausalLM",
"quantization": {"bits": 8},
},
},
8,
),
]

TEST_FEATURE_CONFIGS_IDS = [BINARY, TEXT, MODEL_LLM, f"{MODEL_LLM}-4bit", f"{MODEL_LLM}-8bit"]


@pytest.mark.parametrize("feature1,quantization1", TEST_FEATURE_CONFIGS, ids=TEST_FEATURE_CONFIGS_IDS)
@pytest.mark.parametrize("feature2,quantization2", TEST_FEATURE_CONFIGS, ids=TEST_FEATURE_CONFIGS_IDS)
@pytest.mark.parametrize("config_type", ["dict", "object"])
def test_get_quantization_multiple_features(
ecd_config_dict_llm_encoder_multiple_features: Dict[str, Any],
feature1: Dict[str, Any],
quantization1: int,
feature2: Dict[str, Any],
quantization2: int,
config_type: str,
):
"""Test get_quantization with multiple features.

Args:
ecd_config_dict_llm_encoder_multiple_features: Baseline config to add features to.
feature1: First input feature config dict
quantization1: First input feature expected quantization
feature2: Second input feature config dict
quantization2: Second input feature expected quantization
config_type: Whether to test the config as a dict or object
"""
config = copy.deepcopy(ecd_config_dict_llm_encoder_multiple_features)
feature1 = dict(name="in1", **feature1)
feature2 = dict(name="in2", **feature2)
config[INPUT_FEATURES] = [feature1, feature2]

if config_type == "object":
config = ModelConfig.from_dict(config)

assert get_quantization(config) == [quantization1, quantization2]


@pytest.mark.parametrize("invalid_config", [1, 1.0, "foo", True, False, None, [], {}, {"foo": "bar"}])
def test_get_quantization_invalid_input(invalid_config):
"""Test get_quantization with invalid configs. These should always raise a ValueError.

Args:
invalid_config: The invalid config to test
"""
with pytest.raises(ValueError):
get_quantization(invalid_config)
Loading