Skip to content

Commit 7cfc8ae

Browse files
committed
Add IndependentSpeculator w/ unit tests
Signed-off-by: Fynn Schmitt-Ulms <fschmitt@redhat.com>
1 parent e49fcaa commit 7cfc8ae

File tree

4 files changed

+667
-12
lines changed

4 files changed

+667
-12
lines changed

src/speculators/models/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
from .eagle import EagleSpeculator, EagleSpeculatorConfig
2-
from .independent import IndependentSpeculatorConfig
2+
from .independent import IndependentSpeculator, IndependentSpeculatorConfig
33
from .mlp import MLPSpeculatorConfig
44

55
__all__ = [
66
"EagleSpeculator",
77
"EagleSpeculatorConfig",
8+
"IndependentSpeculator",
89
"IndependentSpeculatorConfig",
910
"MLPSpeculatorConfig",
1011
]
Lines changed: 78 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,13 @@
1-
from transformers import PretrainedConfig
1+
import os
2+
from typing import ClassVar, Literal, Optional, Union
3+
4+
import torch
5+
from pydantic.fields import Field
6+
from transformers import PretrainedConfig, PreTrainedModel
7+
from transformers.modeling_outputs import CausalLMOutputWithPast
28

39
from speculators import SpeculatorModelConfig, SpeculatorsConfig
10+
from speculators.model import SpeculatorModel
411

512
# Public API of this module: the speculator model and its configuration class.
__all__ = ["IndependentSpeculator", "IndependentSpeculatorConfig"]
613

@@ -16,16 +23,76 @@ def from_pretrained_config(
1623

1724
return cls(**pretrained_dict, speculators_config=speculators_config)
1825

19-
speculators_model_type: str = "independent"
26+
speculators_model_type: Literal["independent"] = "independent"
27+
architectures: list[str] = Field(
28+
default_factory=lambda: ["LlamaForCausalLM"],
29+
description=("List of model architectures that can be used with the model "),
30+
)
31+
draft_model: str = Field(
32+
default="",
33+
description=(
34+
"The name or path to the draft model to use for the speculator. "
35+
"Must be a model that is compatible with the speculator."
36+
),
37+
)
38+
39+
40+
@SpeculatorModel.register("independent")
41+
class IndependentSpeculator(SpeculatorModel):
42+
config_class: ClassVar[type[IndependentSpeculatorConfig]] = ( # type: ignore[misc]
43+
IndependentSpeculatorConfig
44+
)
45+
_keys_to_ignore_on_load_missing: ClassVar[list[str]] = [ # type: ignore[misc]
46+
"verifier*",
47+
]
48+
_keys_to_ignore_on_save: ClassVar[list[str]] = [ # type: ignore[assignment,misc]
49+
"verifier*",
50+
]
51+
52+
def __init__(
53+
self,
54+
config: IndependentSpeculatorConfig,
55+
verifier: Optional[Union[str, os.PathLike, PreTrainedModel]] = None,
56+
verifier_attachment_mode: Optional[
57+
Literal["detached", "full", "train_only"]
58+
] = None,
59+
):
60+
if not isinstance(config, IndependentSpeculatorConfig):
61+
raise ValueError(
62+
"config must be an instance of IndependentSpeculatorConfig, "
63+
f"got {type(config)} instead."
64+
)
2065

21-
def __init__(self, **kwargs):
22-
super().__init__(**kwargs)
66+
super().__init__(
67+
config=config,
68+
verifier=verifier,
69+
verifier_attachment_mode=verifier_attachment_mode,
70+
)
2371

24-
# ensure we set the model_type to the one from the original config
25-
self._model_type = kwargs.get("model_type")
72+
self.draft_model: PreTrainedModel = self.resolve_verifier(config.draft_model)
73+
self.post_init()
2674

27-
def to_dict(self):
28-
config_dict = super().to_dict()
29-
config_dict["model_type"] = self._model_type
30-
del config_dict["_model_type"]
31-
return config_dict
75+
def forward(
    self,
    input_ids: torch.LongTensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[tuple[tuple[torch.FloatTensor]]] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
) -> Union[torch.FloatTensor, CausalLMOutputWithPast]:
    """
    Run the wrapped draft model and return its output unchanged.

    :param input_ids: Token ids handed to the draft model.
    :param attention_mask: Optional attention mask handed to the draft model.
    :param position_ids: Optional position ids handed to the draft model.
    :param past_key_values: Optional cached key/values handed to the draft model.
    :param use_cache: Optional flag handed to the draft model.
    :param output_attentions: Optional flag handed to the draft model.
    :param output_hidden_states: Optional flag handed to the draft model.
    :param return_dict: Optional flag handed to the draft model; when None it
        falls back to ``self.config.use_return_dict``.
    :return: Whatever ``self.draft_model`` returns for these arguments.
    """
    return_dict = (
        return_dict if return_dict is not None else self.config.use_return_dict
    )
    # Forward everything by keyword. Causal-LM forward signatures in transformers
    # (e.g. LlamaForCausalLM) declare `inputs_embeds` and `labels` between
    # `past_key_values` and `use_cache`, so positional forwarding would bind
    # `use_cache` to `inputs_embeds`, `output_attentions` to `labels`, and so on.
    return self.draft_model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )
Lines changed: 226 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,226 @@
1+
"""
2+
Unit tests for the independent model module in the Speculators library.
3+
"""
4+
5+
import pytest
6+
from pydantic import BaseModel, ValidationError
7+
8+
from speculators import (
9+
SpeculatorModelConfig,
10+
SpeculatorsConfig,
11+
VerifierConfig,
12+
)
13+
from speculators.models import IndependentSpeculatorConfig
14+
from speculators.proposals import GreedyTokenProposalConfig
15+
16+
# ===== Fixtures =====
17+
18+
19+
@pytest.fixture
def sample_verifier_config():
    """Provide a VerifierConfig for a placeholder Llama-architecture verifier."""
    verifier_kwargs = {
        "name_or_path": "test/verifier",
        "architectures": ["LlamaForCausalLM"],
    }
    return VerifierConfig(**verifier_kwargs)
25+
26+
27+
@pytest.fixture
def sample_token_proposal_config():
    """Provide a GreedyTokenProposalConfig with fixed test settings."""
    proposal_kwargs = {
        "speculative_tokens": 5,
        "verifier_accept_k": 1,
        "accept_tolerance": 0.0,
    }
    return GreedyTokenProposalConfig(**proposal_kwargs)
34+
35+
36+
@pytest.fixture
def sample_speculators_config(sample_token_proposal_config, sample_verifier_config):
    """Provide a SpeculatorsConfig built from the proposal and verifier fixtures."""
    proposals = [sample_token_proposal_config]
    return SpeculatorsConfig(
        algorithm="independent",
        proposal_methods=proposals,
        default_proposal_method="greedy",
        verifier=sample_verifier_config,
    )
44+
45+
46+
@pytest.fixture
def independent_config_dict():
    """Provide a plain-dict description of an independent speculator config."""
    verifier_section = {
        "name_or_path": "test/verifier",
        "architectures": ["LlamaForCausalLM"],
        "hidden_size": 768,
        "intermediate_size": 3072,
        "vocab_size": 32000,
        "max_position_embeddings": 2048,
        "bos_token_id": 1,
        "eos_token_id": 2,
    }
    greedy_proposal = {
        "proposal_type": "greedy",
        "speculative_tokens": 5,
        "verifier_accept_k": 1,
        "accept_tolerance": 0.0,
    }
    speculators_section = {
        "algorithm": "independent",
        "proposal_methods": [greedy_proposal],
        "default_proposal_method": "greedy",
        "verifier": verifier_section,
    }
    return {
        "speculators_model_type": "independent",
        "architectures": ["LlamaForCausalLM"],
        "draft_model": "test/draft",
        "speculators_config": speculators_section,
    }
75+
76+
77+
# ===== IndependentSpeculatorConfig Tests =====
78+
79+
80+
@pytest.mark.smoke
def test_independent_speculator_config_initialization():
    """Check the field values of an IndependentSpeculatorConfig built with no args."""
    default_config = IndependentSpeculatorConfig()

    # Independent-specific defaults first, then the default from the base class
    expected_defaults = {
        "speculators_model_type": "independent",
        "architectures": ["LlamaForCausalLM"],
        "draft_model": "",
        "model_type": "speculator_model",
    }
    for field_name, expected_value in expected_defaults.items():
        assert getattr(default_config, field_name) == expected_value

    # No speculators config is attached unless one is passed in
    assert default_config.speculators_config is None
93+
94+
95+
@pytest.mark.smoke
def test_independent_speculator_config_custom_initialization(sample_speculators_config):
    """Check that explicitly passed values are stored on the config."""
    custom_kwargs = {
        "architectures": ["CustomIndependentSpeculator"],
        "draft_model": "test/draft",
        "speculators_config": sample_speculators_config,
    }
    custom_config = IndependentSpeculatorConfig(**custom_kwargs)

    # The model type stays fixed; the remaining fields reflect the inputs
    assert custom_config.speculators_model_type == "independent"
    assert "CustomIndependentSpeculator" in custom_config.architectures
    assert custom_config.draft_model == "test/draft"
    assert custom_config.speculators_config == sample_speculators_config
109+
110+
111+
@pytest.mark.smoke
def test_independent_speculator_config_base_initialization(sample_speculators_config):
    """Check that a dumped config validated via the base class keeps type and values."""
    source_config = IndependentSpeculatorConfig(
        architectures=["CustomIndependentSpeculator"],
        draft_model="test/draft",
        speculators_config=sample_speculators_config,
    )

    # Round-trip: dump to a dict, then validate through SpeculatorModelConfig
    dumped = source_config.model_dump()
    rebuilt = SpeculatorModelConfig.model_validate(dumped)

    # The rebuilt object must be the derived class with the same field values
    assert isinstance(rebuilt, IndependentSpeculatorConfig)
    assert rebuilt.speculators_model_type == "independent"
    assert "CustomIndependentSpeculator" in rebuilt.architectures
    assert rebuilt.draft_model == "test/draft"
    assert rebuilt.speculators_config == sample_speculators_config
130+
131+
132+
@pytest.mark.regression
def test_eagle_speculator_config_nested_initialization():
    """Check IndependentSpeculatorConfig as a field, list item and dict value.

    NOTE(review): despite the ``eagle`` in its name, this test only exercises
    IndependentSpeculatorConfig; the name looks like a copy-paste leftover.
    """

    class ParentModel(BaseModel):
        single_config: IndependentSpeculatorConfig
        config_list: list[IndependentSpeculatorConfig]
        config_dict: dict[str, IndependentSpeculatorConfig]

    listed = [
        IndependentSpeculatorConfig(draft_model="test/draft1"),
        IndependentSpeculatorConfig(draft_model="test/draft2"),
    ]
    keyed = {
        "draft1": IndependentSpeculatorConfig(draft_model="test/draft1"),
        "draft2": IndependentSpeculatorConfig(draft_model="test/draft2"),
    }
    parent = ParentModel(
        single_config=IndependentSpeculatorConfig(draft_model="test/draft"),
        config_list=listed,
        config_dict=keyed,
    )

    # Single nested config
    assert isinstance(parent.single_config, IndependentSpeculatorConfig)
    assert parent.single_config.draft_model == "test/draft"

    # List of nested configs: count, element type and order
    assert len(parent.config_list) == 2
    for item in parent.config_list:
        assert isinstance(item, IndependentSpeculatorConfig)
    assert [item.draft_model for item in parent.config_list] == [
        "test/draft1",
        "test/draft2",
    ]

    # Dict of nested configs: count, value type and per-key content
    assert len(parent.config_dict) == 2
    for item in parent.config_dict.values():
        assert isinstance(item, IndependentSpeculatorConfig)
    assert parent.config_dict["draft1"].draft_model == "test/draft1"
    assert parent.config_dict["draft2"].draft_model == "test/draft2"
168+
169+
170+
@pytest.mark.smoke
def test_independent_speculator_config_invalid_initialization():
    """Check that each badly-typed field raises a ValidationError naming that field."""
    # (keyword arguments that must be rejected, field expected in the error text)
    invalid_cases = [
        ({"speculators_model_type": "invalid"}, "speculators_model_type"),
        ({"architectures": "not_a_list"}, "architectures"),
        ({"draft_model": 123}, "draft_model"),
    ]

    for bad_kwargs, field_name in invalid_cases:
        with pytest.raises(ValidationError) as exc_info:
            IndependentSpeculatorConfig(**bad_kwargs)  # type: ignore[arg-type]
        assert field_name in str(exc_info.value)
186+
187+
188+
@pytest.mark.smoke
def test_independent_speculator_config_auto_registry():
    """Check that the config class is registered under the "independent" key."""
    registered_names = [
        registered.__name__ for registered in SpeculatorModelConfig.registered_classes()
    ]

    # The class itself must appear among the registered classes
    assert "IndependentSpeculatorConfig" in registered_names

    # The registry must map the "independent" key to that class
    registry = SpeculatorModelConfig.registry
    assert registry is not None
    assert "independent" in registry
    assert registry["independent"] == IndependentSpeculatorConfig
200+
201+
202+
@pytest.mark.smoke
def test_independent_speculator_config_marshalling(sample_speculators_config):
    """Check model_dump() output and model_validate() via base and derived classes."""
    source_config = IndependentSpeculatorConfig(
        draft_model="test/draft",
        speculators_config=sample_speculators_config,
    )

    # model_dump() yields a plain dict carrying the same values
    dumped = source_config.model_dump()
    assert isinstance(dumped, dict)
    assert dumped["speculators_model_type"] == "independent"
    assert dumped["draft_model"] == "test/draft"
    assert dumped["speculators_config"] == sample_speculators_config.model_dump()

    # model_validate() on the base class first, then on the derived class;
    # both must rebuild an IndependentSpeculatorConfig with the same values
    for validating_cls in (SpeculatorModelConfig, IndependentSpeculatorConfig):
        rebuilt = validating_cls.model_validate(dumped)
        assert isinstance(rebuilt, IndependentSpeculatorConfig)
        assert rebuilt.draft_model == "test/draft"
        assert rebuilt.speculators_config == sample_speculators_config

0 commit comments

Comments
 (0)