Merged
recipes/eleuther_eval.py (34 changes: 23 additions & 11 deletions)
@@ -28,6 +28,7 @@
 from torchtune.modules.tokenizers import ModelTokenizer
 from torchtune.modules.transforms import Transform
 from torchtune.recipe_interfaces import EvalRecipeInterface
+from torchtune.training import FullModelTorchTuneCheckpointer
 
 try:
     import lm_eval
@@ -475,28 +476,39 @@ def setup(self, cfg: DictConfig) -> None:

         # Load checkpoint
         checkpointer = config.instantiate(cfg.checkpointer)
-        if quantization_mode is None:
-            ckpt_dict = checkpointer.load_checkpoint()
-        else:
-            # weights_only needs to be False when loading a quantized model
-            # currently loading a quantized model is only supported with the
-            # FullModelTorchTuneCheckpointer
-            ckpt_dict = checkpointer.load_checkpoint(weights_only=False)
 
         # Initialize model
         with training.set_default_dtype(self.dtype), self.device:
             model = config.instantiate(cfg.model)
 
         # Quantize model if requested
         if quantization_mode is not None:
+            if not isinstance(checkpointer, FullModelTorchTuneCheckpointer):
+                raise ValueError(
+                    "Quantization is only supported for models quantized and saved with the "
+                    "FullModelTorchTuneCheckpointer - please ensure you have quantized your "
+                    "model and are using the quantized weights!"
+                )
+            if "qat" in quantization_mode:
+                raise ValueError(
+                    "You have specified a quantizer with 'QAT' - "
+                    "QAT quantizers should only be used during quantization aware training "
+                    "and when quantizing models. Please use the corresponding post-training "
+                    "quantizer e.g. Int8DynActInt4WeightQuantizer for Int8DynActInt4WeightQATQuantizer."
+                )
             model = quantizer.quantize(model)
             model = model.to(device=self.device, dtype=self.dtype)
-            for k, v in model_state_dict.items():
-                model_state_dict[k] = v.to(self._device)
-            model.load_state_dict(model_state_dict, assign=True)
+            ckpt_dict = checkpointer.load_checkpoint(weights_only=False)[
+                training.MODEL_KEY
+            ]
+            for k, v in ckpt_dict.items():
+                ckpt_dict[k] = v.to(self.device)
+            model.load_state_dict(ckpt_dict, assign=True)
+        else:
+            ckpt_dict = checkpointer.load_checkpoint()[training.MODEL_KEY]
+            model.load_state_dict(ckpt_dict)
 
         # Load model weights into initialized model
-        model.load_state_dict(ckpt_dict[training.MODEL_KEY])
         self.logger.info(f"Model is initialized with precision {self.dtype}.")
 
         # Put model in eval mode.
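Side note on the new load path (an illustrative aside, not part of the diff): load_state_dict(..., assign=True) swaps the module's tensors for the checkpoint's tensors instead of copying values into the pre-existing parameters, which is what lets already-quantized tensors with a different dtype or layout land in the freshly initialized model. A minimal sketch of that behavior (assumes PyTorch 2.1+, with a toy nn.Linear standing in for the real model):

# Illustrative sketch only (not from the PR): toy stand-ins for the model and
# for checkpointer.load_checkpoint(weights_only=False)[training.MODEL_KEY].
import torch
import torch.nn as nn

model = nn.Linear(4, 4, bias=False)  # fresh fp32 module, as after config.instantiate

# Pretend checkpoint: a tensor whose dtype differs from the module's parameter.
ckpt_dict = {"weight": torch.zeros(4, 4, dtype=torch.float16)}

# A plain load_state_dict() would copy (and silently cast) into the existing
# fp32 parameter; assign=True instead attaches the checkpoint tensor as-is,
# preserving its dtype.
model.load_state_dict(ckpt_dict, assign=True)
print(model.weight.dtype)  # torch.float16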
tests/recipes/test_eleuther_eval.py (76 changes: 75 additions & 1 deletion)
@@ -14,7 +14,7 @@
 import pytest
 
 from tests.common import TUNE_PATH
-from tests.recipes.utils import llama2_test_config
+from tests.recipes.utils import llama2_test_config, write_hf_ckpt_config
 from tests.test_utils import CKPT_MODEL_PATHS
 
@@ -126,6 +126,80 @@ def test_eval_recipe_errors_without_lm_eval(self, capsys, monkeypatch, tmpdir):
             in printed_err
         )

+    @pytest.mark.integration_test
+    def test_eval_recipe_errors_with_quantization_hf_checkpointer(
+        self, capsys, monkeypatch, tmpdir
+    ):
+        ckpt = "llama2_hf"
+        ckpt_path = Path(CKPT_MODEL_PATHS[ckpt])
+        ckpt_dir = ckpt_path.parent
+
+        # Config file needed for model conversion.
+        write_hf_ckpt_config(ckpt_dir)
+
+        cmd = f"""
+        tune run eleuther_eval \
+            --config eleuther_evaluation \
+            output_dir={tmpdir} \
+            checkpointer=torchtune.training.FullModelHFCheckpointer \
+            checkpointer.checkpoint_dir='{ckpt_dir}' \
+            checkpointer.checkpoint_files=[{ckpt_path}]\
+            checkpointer.output_dir={tmpdir} \
+            checkpointer.model_type=LLAMA2 \
+            tokenizer.path=/tmp/test-artifacts/tokenizer.model \
+            tokenizer.prompt_template=null \
+            limit=1 \
+            dtype=fp32 \
+            device=cpu \
+            quantizer._component_=torchtune.training.quantization.Int8DynActInt4WeightQuantizer \
+            quantizer.groupsize=256 \
+        """.split()
+
+        model_config = llama2_test_config()
+        cmd = cmd + model_config
+
+        monkeypatch.setattr(sys, "argv", cmd)
+        with pytest.raises(
+            ValueError,
+            match="Quantization is only supported for models quantized and saved with the "
+            "FullModelTorchTuneCheckpointer",
+        ):
+            runpy.run_path(TUNE_PATH, run_name="__main__")
+
+    @pytest.mark.integration_test
+    def test_eval_recipe_errors_with_qat_quantizer(self, capsys, monkeypatch, tmpdir):
+        ckpt = "llama2_tune"
+        ckpt_path = Path(CKPT_MODEL_PATHS[ckpt])
+        ckpt_dir = ckpt_path.parent
+
+        cmd = f"""
+        tune run eleuther_eval \
+            --config eleuther_evaluation \
+            output_dir={tmpdir} \
+            checkpointer=torchtune.training.FullModelTorchTuneCheckpointer \
+            checkpointer.checkpoint_dir='{ckpt_dir}' \
+            checkpointer.checkpoint_files=[{ckpt_path}]\
+            checkpointer.output_dir={tmpdir} \
+            checkpointer.model_type=LLAMA2 \
+            tokenizer.path=/tmp/test-artifacts/tokenizer.model \
+            tokenizer.prompt_template=null \
+            limit=1 \
+            dtype=fp32 \
+            device=cpu \
+            quantizer._component_=torchtune.training.quantization.Int8DynActInt4WeightQATQuantizer \
+            quantizer.groupsize=32\
+        """.split()
+
+        model_config = llama2_test_config()
+        cmd = cmd + model_config
+
+        monkeypatch.setattr(sys, "argv", cmd)
+        with pytest.raises(
+            ValueError,
+            match="QAT quantizers should only be used during quantization aware training",
+        ):
+            runpy.run_path(TUNE_PATH, run_name="__main__")
+
     @pytest.mark.integration_test
     def test_eval_recipe_errors_with_generate_until_and_mc_tasks(
         self, caplog, capsys, monkeypatch, tmpdir
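Usage note (my sketch, not from this diff): taken together, the two error paths tested above imply the supported configuration for quantized evaluation, namely weights saved with FullModelTorchTuneCheckpointer plus a post-training quantizer. The component paths and groupsize below are copied from the test commands; the OmegaConf wrapper is only for illustration.

# Hypothetical eval-config fragment; component paths mirror the tests above.
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "checkpointer": {
            # Quantized weights must have been saved with this checkpointer:
            "_component_": "torchtune.training.FullModelTorchTuneCheckpointer",
        },
        "quantizer": {
            # Rejected by the recipe (QAT quantizers are for training time):
            #   torchtune.training.quantization.Int8DynActInt4WeightQATQuantizer
            # Post-training counterpart to use at eval time:
            "_component_": "torchtune.training.quantization.Int8DynActInt4WeightQuantizer",
            "groupsize": 256,
        },
    }
)
print(cfg.quantizer["_component_"])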