Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Ensure `TabPFNValidationError` wraps both custom and sklearn's validate_data() errors ([#732](https://github.com/PriorLabs/TabPFN/pull/732))
- Refactor of model encoder. Move imports from `tabpfn.architectures.base.encoders` to `tabpfn.architectures.encoders` ([#733](https://github.com/PriorLabs/TabPFN/pull/733))
- Renamed the estimator's `preprocessor_` attribute to `ordinal_encoder_` ([#756](https://github.com/PriorLabs/TabPFN/pull/756))
- Pass through kwargs in `FinetunedTabPFNClassifier` and `FinetunedTabPFNRegressor` predict and predict_proba methods to allow additional options like `output_type='full'` ([#772](https://github.com/PriorLabs/TabPFN/pull/772))


## [6.3.1] - 2026-01-14
Expand Down
1 change: 1 addition & 0 deletions changelog/772.added.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Pass through kwargs in FinetunedTabPFNClassifier and FinetunedTabPFNRegressor predict and predict_proba methods to allow additional options like output_type='full'
12 changes: 8 additions & 4 deletions src/tabpfn/finetuning/finetuned_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,30 +398,34 @@ def fit(
super().fit(X, y, X_val=X_val, y_val=y_val, output_dir=output_dir)
return self

def predict_proba(self, X: XType) -> np.ndarray:
def predict_proba(self, X: XType, **kwargs) -> np.ndarray:
"""Predict class probabilities for X.

Args:
X: The input samples of shape (n_samples, n_features).
**kwargs: Additional keyword arguments to pass to the underlying
inference classifier.

Returns:
The class probabilities of the input samples with shape
(n_samples, n_classes).
"""
check_is_fitted(self)

return self.finetuned_inference_classifier_.predict_proba(X) # type: ignore
return self.finetuned_inference_classifier_.predict_proba(X, **kwargs) # type: ignore

@override
def predict(self, X: XType) -> np.ndarray:
def predict(self, X: XType, **kwargs) -> np.ndarray:
"""Predict the class for X.

Args:
X: The input samples of shape (n_samples, n_features).
**kwargs: Additional keyword arguments to pass to the underlying
inference classifier.

Returns:
The predicted classes with shape (n_samples,).
"""
check_is_fitted(self)

return self.finetuned_inference_classifier_.predict(X) # type: ignore
return self.finetuned_inference_classifier_.predict(X, **kwargs) # type: ignore
7 changes: 5 additions & 2 deletions src/tabpfn/finetuning/finetuned_regressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
if TYPE_CHECKING:
from tabpfn.constants import XType, YType
from tabpfn.finetuning.data_util import RegressorBatch
from tabpfn.regressor import RegressionResultType


def _compute_regression_loss( # noqa: C901
Expand Down Expand Up @@ -567,15 +568,17 @@ def fit(
return self

@override
def predict(self, X: XType) -> np.ndarray:
def predict(self, X: XType, **kwargs) -> RegressionResultType:
"""Predict target values for X.

Args:
X: The input samples of shape (n_samples, n_features).
**kwargs: Additional keyword arguments to pass to the underlying
inference regressor (e.g., output_type, quantiles).

Returns:
The predicted target values with shape (n_samples,).
"""
check_is_fitted(self)

return self.finetuned_inference_regressor_.predict(X)
return self.finetuned_inference_regressor_.predict(X, **kwargs)
Loading