Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion src/tabpfn/architectures/base/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -601,7 +601,14 @@ def forward( # noqa: PLR0912, C901
)

# out: s b e
train_encoder_out = encoder_out[:, :single_eval_pos, -1].transpose(0, 1)
thinking_rows_offset = (
self.add_thinking_tokens.num_thinking_rows
if self.add_thinking_tokens is not None
else 0
)
train_encoder_out = encoder_out[
:, thinking_rows_offset:single_eval_pos, -1
].transpose(0, 1)
output_decoded["train_embeddings"] = train_encoder_out
output_decoded["test_embeddings"] = test_encoder_out

Expand Down
5 changes: 2 additions & 3 deletions src/tabpfn/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@
from tabpfn.preprocessors.preprocessing_helpers import get_ordinal_encoder
from tabpfn.utils import (
DevicesSpecification,
balance_probas_by_class_counts,
fix_dtypes,
get_embeddings,
infer_categorical_features,
Expand Down Expand Up @@ -1141,9 +1142,7 @@ def _apply_softmax(self, logits: torch.Tensor) -> torch.Tensor:

def _apply_balancing(self, probas: torch.Tensor) -> torch.Tensor:
"""Applies class balancing to a probability tensor."""
class_prob_in_train = self.class_counts_ / self.class_counts_.sum()
balanced_probas = probas / torch.Tensor(class_prob_in_train).to(probas.device)
return balanced_probas / balanced_probas.sum(dim=-1, keepdim=True)
return balance_probas_by_class_counts(probas, self.class_counts_)

def logits_to_probabilities(
self,
Expand Down
20 changes: 20 additions & 0 deletions src/tabpfn/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -987,3 +987,23 @@ def meta_dataset_collator(batch: list, padding_val: float = 0.0) -> tuple:
items_list.append([batch[r][item_idx] for r in range(batch_sz)])

return tuple(items_list)


def balance_probas_by_class_counts(
probas: torch.Tensor,
class_counts: np.ndarray,
) -> torch.Tensor:
"""Balance probabilities by class counts.

Args:
probas: The probabilities to balance.
class_counts: The class counts to use for balancing.

Returns:
The balanced probabilities.
"""
class_prob_in_train = class_counts / class_counts.sum()
balanced_probas = probas / torch.from_numpy(class_prob_in_train).float().to(
probas.device
)
return balanced_probas / balanced_probas.sum(dim=-1, keepdim=True)
52 changes: 28 additions & 24 deletions tests/test_classifier_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,27 +504,39 @@ def test_sklearn_compatible_estimator(
check(estimator)


def test_balanced_probabilities(X_y: tuple[np.ndarray, np.ndarray]) -> None:
def test_balanced_probabilities() -> None:
"""Test that balance_probabilities=True works correctly."""
X, y = X_y
n_classes = 2
n_features = 3

model = TabPFNClassifier(
balance_probabilities=True,
# Create an IMBALANCED dataset
X, y = sklearn.datasets.make_classification(
n_samples=60,
n_classes=n_classes,
n_features=n_features,
n_informative=n_features,
n_redundant=0,
weights=[0.7, 0.3], # Imbalanced classes
random_state=42,
)

model.fit(X, y)
probabilities = model.predict_proba(X)
model_unbalanced = TabPFNClassifier(balance_probabilities=False, random_state=42)
model_unbalanced.fit(X, y)
proba_unbalanced = model_unbalanced.predict_proba(X)

assert np.allclose(probabilities.sum(axis=1), 1.0)
model_balanced = TabPFNClassifier(balance_probabilities=True, random_state=42)
model_balanced.fit(X, y)
proba_balanced = model_balanced.predict_proba(X)

# Check that the mean probability for each class is roughly equal
mean_probs = probabilities.mean(axis=0)
expected_mean = 1.0 / len(np.unique(y))
assert np.allclose(
mean_probs,
expected_mean,
rtol=0.1,
), "Class probabilities are not properly balanced"
mean_proba_unbalanced = proba_unbalanced.mean(axis=0)
mean_proba_balanced = proba_balanced.mean(axis=0)

# Balanced should be MORE uniform than unbalanced
balanced_deviation = np.std(mean_proba_balanced)
unbalanced_deviation = np.std(mean_proba_unbalanced)
assert balanced_deviation < unbalanced_deviation, (
"Balancing did not make probabilities more uniform"
)


def test_classifier_in_pipeline(X_y: tuple[np.ndarray, np.ndarray]) -> None:
Expand All @@ -549,15 +561,7 @@ def test_classifier_in_pipeline(X_y: tuple[np.ndarray, np.ndarray]) -> None:

# Check that probabilities sum to 1 for each prediction
assert np.allclose(probabilities.sum(axis=1), 1.0)

# Check that the mean probability for each class is roughly equal
mean_probs = probabilities.mean(axis=0)
expected_mean = 1.0 / len(np.unique(y))
assert np.allclose(
mean_probs,
expected_mean,
rtol=0.1,
), "Class probabilities are not properly balanced in pipeline"
assert probabilities.shape == (X.shape[0], len(np.unique(y)))


def test_dict_vs_object_preprocessor_config(X_y: tuple[np.ndarray, np.ndarray]) -> None:
Expand Down
8 changes: 2 additions & 6 deletions tests/test_regressor_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,16 +422,12 @@ def test_get_embeddings(X_y: tuple[np.ndarray, np.ndarray], data_source: str) ->

# Need to access the model through the executor
model_instances = typing.cast(typing.Any, model.executor_).models
encoder_shape = next(
m.out_features
for m in model_instances[0].encoder.modules()
if isinstance(m, nn.Linear)
)
hidden_size = model_instances[0].ninp

assert isinstance(embeddings, np.ndarray)
assert embeddings.shape[0] == n_estimators
assert embeddings.shape[1] == X.shape[0]
assert embeddings.shape[2] == encoder_shape
assert embeddings.shape[2] == hidden_size


def test_overflow_bug_does_not_occur():
Expand Down
16 changes: 16 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from tabpfn.inference_config import InferenceConfig
from tabpfn.preprocessors.preprocessing_helpers import get_ordinal_encoder
from tabpfn.utils import (
balance_probas_by_class_counts,
fix_dtypes,
get_total_memory_windows,
infer_categorical_features,
Expand Down Expand Up @@ -372,3 +373,18 @@ def test_process_text_na_dataframe(prepared_tabpfn_data):
assert len(np.unique(output_col[~pd.isna(output_col)])) == len(
np.unique(gt_col[~pd.isna(gt_col)])
)


def test_balance_probas_by_class_counts():
"""Test balancing probabilities by class counts."""
probas = torch.tensor([[0.2, 0.8], [0.6, 0.4], [0.5, 0.5]])
class_counts = np.array([1, 2])

balanced = balance_probas_by_class_counts(probas, class_counts)

# Check that each row sums to one
sums = balanced.sum(dim=-1)
assert torch.allclose(sums, torch.ones(3), rtol=1e-5, atol=1e-5)

expected_balanced = torch.tensor([[1 / 3, 2 / 3], [0.75, 0.25], [2 / 3, 1 / 3]])
assert torch.allclose(balanced, expected_balanced, rtol=1e-4, atol=1e-4)
Loading