68 changes: 59 additions & 9 deletions src/tabpfn/classifier.py
@@ -676,7 +676,13 @@ def fit(self, X: XType, y: YType) -> Self:
 
         return self
 
-    def _raw_predict(self, X: XType, *, return_logits: bool) -> torch.Tensor:
+    def _raw_predict(
+        self,
+        X: XType,
+        *,
+        return_logits: bool,
+        return_raw_logits: bool = False,
+    ) -> torch.Tensor:
         """Internal method to run prediction.
 
         Handles input validation, preprocessing, and the forward pass.
@@ -685,13 +691,15 @@ def _raw_predict(self, X: XType, *, return_logits: bool) -> torch.Tensor:
 
         Args:
             X: The input data for prediction.
-            return_logits: If True, the raw logits are returned. Otherwise,
+            return_logits: If True, the logits are returned. Otherwise,
                 probabilities are returned after softmax and other
                 post-processing steps.
+            return_raw_logits: If True, returns the raw logits without
+                averaging estimators or temperature scaling.
 
         Returns:
             The raw torch.Tensor output, either logits or probabilities,
-            depending on `return_logits`.
+            depending on `return_logits` and `return_raw_logits`.
         """
         check_is_fitted(self)
 
@@ -700,7 +708,12 @@ def _raw_predict(self, X: XType, *, return_logits: bool) -> torch.Tensor:
         X = fix_dtypes(X, cat_indices=self.inferred_categorical_indices_)
         X = process_text_na_dataframe(X, ord_encoder=self.preprocessor_)
 
-        return self.forward(X, use_inference_mode=True, return_logits=return_logits)
+        return self.forward(
+            X,
+            use_inference_mode=True,
+            return_logits=return_logits,
+            return_raw_logits=return_raw_logits,
+        )
 
     @track_model_call(model_method="predict", param_names=["X"])
     def predict(self, X: XType) -> np.ndarray:
@@ -737,6 +750,30 @@ def predict_logits(self, X: XType) -> np.ndarray:
         logits_tensor = self._raw_predict(X, return_logits=True)
         return logits_tensor.float().detach().cpu().numpy()
 
+    @config_context(transform_output="default")
+    @track_model_call(model_method="predict", param_names=["X"])
+    def predict_raw_logits(self, X: XType) -> np.ndarray:
+        """Predict the raw logits for the provided input samples.
+
+        Logits represent the unnormalized log-probabilities of the classes
+        before the softmax activation function is applied. In contrast to the
+        `predict_logits` method, this method returns the raw logits for each
+        estimator, without averaging estimators or temperature scaling.
+
+        Args:
+            X: The input data for prediction.
+
+        Returns:
+            An array of predicted logits for each estimator, with shape
+            (n_estimators, n_samples, n_classes).
+        """
+        logits_tensor = self._raw_predict(
+            X,
+            return_logits=False,
+            return_raw_logits=True,
+        )
+        return logits_tensor.float().detach().cpu().numpy()
+
     @track_model_call(model_method="predict", param_names=["X"])
     def predict_proba(self, X: XType) -> np.ndarray:
         """Predict the probabilities of the classes for the provided input samples.
@@ -800,6 +837,7 @@ def forward(  # noqa: C901, PLR0912
         *,
         use_inference_mode: bool = False,
         return_logits: bool = False,
+        return_raw_logits: bool = False,
     ) -> torch.Tensor:
         """Forward pass returning predicted probabilities or logits
         for TabPFNClassifier Inference Engine. Used in
@@ -814,14 +852,23 @@ def forward(  # noqa: C901, PLR0912
             use_inference_mode: Flag for inference mode., default at False since
                 it is called within predict. During FineTuning forward() is called
                 directly by user, so default should be False here.
-            return_logits: If True, returns raw logits. Otherwise, probabilities.
+            return_logits: If True, returns logits averaged across estimators.
+                Otherwise, probabilities are returned.
+            return_raw_logits: If True, returns the raw logits, without
+                averaging estimators or temperature scaling.
 
         Returns:
             The predicted probabilities or logits of the classes as a torch.Tensor.
             - If `use_inference_mode` is True: Shape (N_samples, N_classes)
             - If `use_inference_mode` is False (e.g., for training/fine-tuning):
               Shape (Batch_size, N_classes, N_samples), suitable for NLLLoss.
+            - If `return_raw_logits` is True: Shape (n_estimators, n_samples, n_classes)
         """
+        if return_logits and return_raw_logits:
+            raise ValueError(
+                "Cannot return both logits and raw logits. Please specify only one."
+            )
+
         # Scenario 1: Standard inference path
         is_standard_inference = use_inference_mode and not isinstance(
             self.executor_, InferenceEngineBatchedNoPreprocessing
@@ -905,13 +952,16 @@ def forward(  # noqa: C901, PLR0912
         stacked_outputs = torch.stack(outputs)
 
         # --- Build the processing pipeline by composing the steps in order ---
-        # The first step is always to apply the temperature scaling.
-        pipeline = [self._apply_temperature]
+        pipeline = []
 
         if return_logits:
             # For logits, we just average the temperature-scaled logits.
-            pipeline.append(self._average_across_estimators)
+            pipeline.extend([self._apply_temperature, self._average_across_estimators])
+        elif return_raw_logits:
+            pass  # no post-processing
+        else:
+            pipeline.append(self._apply_temperature)
 
         # For probabilities, the order of averaging and softmax is crucial.
         if self.average_before_softmax:
             pipeline.extend([self._average_across_estimators, self._apply_softmax])
@@ -931,7 +981,7 @@ def forward(  # noqa: C901, PLR0912
 
         # --- Final output shaping ---
         if output.ndim > 2 and use_inference_mode:
-            output = output.squeeze(1)
+            output = output.squeeze(1) if not return_raw_logits else output.squeeze(2)
 
         if not use_inference_mode:
             # This case is primarily for fine-tuning where NLLLoss expects [B, C, N]
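For reference, a minimal usage sketch of the new method against the pipeline above. This is not part of the PR; it assumes the classifier is fitted, that `softmax_temperature` is readable as a constructor attribute, and that `_apply_temperature` divides logits by that temperature, so averaged logits should match the temperature-scaled raw logits averaged across estimators:

```python
import numpy as np
from tabpfn import TabPFNClassifier

# Hypothetical toy data; any numeric (n_samples, n_features) input works.
rng = np.random.default_rng(0)
X = rng.normal(size=(32, 4))
y = rng.integers(0, 3, size=32)

clf = TabPFNClassifier(n_estimators=4, random_state=42)
clf.fit(X, y)

raw = clf.predict_raw_logits(X)  # shape: (n_estimators, n_samples, n_classes)
avg = clf.predict_logits(X)      # shape: (n_samples, n_classes)

# Sketch of the relationship implied by the pipeline: temperature-scale,
# then average across estimators (assumption, not guaranteed by this diff alone).
approx = (raw / clf.softmax_temperature).mean(axis=0)
print(np.allclose(approx, avg, atol=1e-4))
```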
31 changes: 31 additions & 0 deletions tests/test_classifier_interface.py
@@ -260,6 +260,37 @@ def test_predict_logits_and_consistency(
     assert log_loss(y, proba_from_predict_proba) < 5.0
 
 
+@pytest.mark.parametrize(("n_estimators"), [1, 2])
+def test_predict_raw_logits(
+    X_y: tuple[np.ndarray, np.ndarray],
+    n_estimators: int,
+):
+    """Tests the predict_raw_logits method."""
+    X, y = X_y
+
+    # Ensure y is int64 for consistency with classification tasks
+    y = y.astype(np.int64)
+
+    classifier = TabPFNClassifier(
+        n_estimators=n_estimators,
+        random_state=42,
+    )
+    classifier.fit(X, y)
+
+    logits = classifier.predict_raw_logits(X)
+    assert logits.shape[0] == n_estimators
+    assert isinstance(logits, np.ndarray)
+    assert logits.shape == (n_estimators, X.shape[0], classifier.n_classes_)
+    assert logits.dtype == np.float32
+    assert not np.isnan(logits).any()
+    assert not np.isinf(logits).any()
+    if classifier.n_classes_ > 1:
+        assert not np.all(logits == logits[:, :, 0:1]), (
+            "Logits are identical across classes for all samples, indicating "
+            "trivial output."
+        )
+
+
 def test_softmax_temperature_impact_on_logits_magnitude(
     X_y: tuple[np.ndarray, np.ndarray],
 ):
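As a companion to the test, here is a sketch of how probabilities compose from raw logits, mirroring the `average_before_softmax` branch in `forward`. The helper is hypothetical and not part of this PR; the division-by-temperature semantics and the 0.9 default are assumptions:

```python
import numpy as np

def probs_from_raw_logits(
    raw: np.ndarray,               # (n_estimators, n_samples, n_classes)
    temperature: float = 0.9,      # assumed default for softmax_temperature
    average_before_softmax: bool = False,
) -> np.ndarray:
    """Hypothetical re-implementation of the probability post-processing."""
    scaled = raw / temperature     # mirrors _apply_temperature (assumption)

    def softmax(z: np.ndarray) -> np.ndarray:
        # Numerically stable softmax over the class axis.
        e = np.exp(z - z.max(axis=-1, keepdims=True))
        return e / e.sum(axis=-1, keepdims=True)

    if average_before_softmax:
        # _average_across_estimators, then _apply_softmax
        return softmax(scaled.mean(axis=0))
    # _apply_softmax, then _average_across_estimators
    return softmax(scaled).mean(axis=0)
```

The order matters: averaging logits before the softmax and averaging probabilities after it generally give different results, which is why `forward` composes the pipeline explicitly rather than hard-coding one order.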