diff --git a/CHANGELOG.md b/CHANGELOG.md index 6707ada1b..aa3c421ad 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -27,6 +27,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - - Ensure `TabPFNValidationError` wraps both custom and sklearn's validate_data() errors ([#732](https://github.com/PriorLabs/TabPFN/pull/732)) - Refactor of model encoder. Move imports from `tabpfn.architectures.base.encoders` to `tabpfn.architectures.encoders` ([#733](https://github.com/PriorLabs/TabPFN/pull/733)) - Renamed the estimator's `preprocessor_` attribute to `ordinal_encoder_` ([#756](https://github.com/PriorLabs/TabPFN/pull/756)) +- Pass through kwargs in `FinetunedTabPFNClassifier` and `FinetunedTabPFNRegressor` predict and predict_proba methods to allow additional options like `output_type='full'` ([#772](https://github.com/PriorLabs/TabPFN/pull/772)) ## [6.3.1] - 2026-01-14 diff --git a/changelog/772.added.md b/changelog/772.added.md new file mode 100644 index 000000000..f5fb26a1b --- /dev/null +++ b/changelog/772.added.md @@ -0,0 +1 @@ +Pass through kwargs in FinetunedTabPFNClassifier and FinetunedTabPFNRegressor predict and predict_proba methods to allow additional options like output_type='full' \ No newline at end of file diff --git a/src/tabpfn/finetuning/finetuned_classifier.py b/src/tabpfn/finetuning/finetuned_classifier.py index 7eda0a6fe..ee1432f0a 100644 --- a/src/tabpfn/finetuning/finetuned_classifier.py +++ b/src/tabpfn/finetuning/finetuned_classifier.py @@ -398,11 +398,13 @@ def fit( super().fit(X, y, X_val=X_val, y_val=y_val, output_dir=output_dir) return self - def predict_proba(self, X: XType) -> np.ndarray: + def predict_proba(self, X: XType, **kwargs) -> np.ndarray: """Predict class probabilities for X. Args: X: The input samples of shape (n_samples, n_features). + **kwargs: Additional keyword arguments to pass to the underlying + inference classifier. 
Returns: The class probabilities of the input samples with shape @@ -410,18 +412,20 @@ def predict_proba(self, X: XType) -> np.ndarray: """ check_is_fitted(self) - return self.finetuned_inference_classifier_.predict_proba(X) # type: ignore + return self.finetuned_inference_classifier_.predict_proba(X, **kwargs) # type: ignore @override - def predict(self, X: XType) -> np.ndarray: + def predict(self, X: XType, **kwargs) -> np.ndarray: """Predict the class for X. Args: X: The input samples of shape (n_samples, n_features). + **kwargs: Additional keyword arguments to pass to the underlying + inference classifier. Returns: The predicted classes with shape (n_samples,). """ check_is_fitted(self) - return self.finetuned_inference_classifier_.predict(X) # type: ignore + return self.finetuned_inference_classifier_.predict(X, **kwargs) # type: ignore diff --git a/src/tabpfn/finetuning/finetuned_regressor.py b/src/tabpfn/finetuning/finetuned_regressor.py index 95832fede..f27cecdd0 100644 --- a/src/tabpfn/finetuning/finetuned_regressor.py +++ b/src/tabpfn/finetuning/finetuned_regressor.py @@ -28,6 +28,7 @@ if TYPE_CHECKING: from tabpfn.constants import XType, YType from tabpfn.finetuning.data_util import RegressorBatch + from tabpfn.regressor import RegressionResultType def _compute_regression_loss( # noqa: C901 @@ -567,15 +568,17 @@ def fit( return self @override - def predict(self, X: XType) -> np.ndarray: + def predict(self, X: XType, **kwargs) -> RegressionResultType: """Predict target values for X. Args: X: The input samples of shape (n_samples, n_features). + **kwargs: Additional keyword arguments to pass to the underlying + inference regressor (e.g., output_type, quantiles). Returns: - The predicted target values with shape (n_samples,). + The predicted target values with shape (n_samples,) by default; kwargs such as output_type='full' yield a richer result (see RegressionResultType). """ check_is_fitted(self) - return self.finetuned_inference_regressor_.predict(X) + return self.finetuned_inference_regressor_.predict(X, **kwargs)