diff --git a/src/tabpfn/classifier.py b/src/tabpfn/classifier.py
index 44fb36b7c..548f1d1cb 100644
--- a/src/tabpfn/classifier.py
+++ b/src/tabpfn/classifier.py
@@ -676,7 +676,13 @@ def fit(self, X: XType, y: YType) -> Self:
 
         return self
 
-    def _raw_predict(self, X: XType, *, return_logits: bool) -> torch.Tensor:
+    def _raw_predict(
+        self,
+        X: XType,
+        *,
+        return_logits: bool,
+        return_raw_logits: bool = False,
+    ) -> torch.Tensor:
         """Internal method to run prediction.
 
         Handles input validation, preprocessing, and the forward pass.
@@ -685,13 +691,15 @@ def _raw_predict(self, X: XType, *, return_logits: bool) -> torch.Tensor:
 
         Args:
            X: The input data for prediction.
-            return_logits: If True, the raw logits are returned. Otherwise,
+            return_logits: If True, the logits are returned. Otherwise,
                 probabilities are returned after softmax and other
                 post-processing steps.
+            return_raw_logits: If True, returns the raw logits, without
+                averaging across estimators or temperature scaling.
 
         Returns:
             The raw torch.Tensor output, either logits or probabilities,
-            depending on `return_logits`.
+            depending on `return_logits` and `return_raw_logits`.
         """
         check_is_fitted(self)
 
@@ -700,7 +708,12 @@ def _raw_predict(self, X: XType, *, return_logits: bool) -> torch.Tensor:
         X = fix_dtypes(X, cat_indices=self.inferred_categorical_indices_)
         X = process_text_na_dataframe(X, ord_encoder=self.preprocessor_)
 
-        return self.forward(X, use_inference_mode=True, return_logits=return_logits)
+        return self.forward(
+            X,
+            use_inference_mode=True,
+            return_logits=return_logits,
+            return_raw_logits=return_raw_logits,
+        )
 
     @track_model_call(model_method="predict", param_names=["X"])
     def predict(self, X: XType) -> np.ndarray:
@@ -737,6 +750,30 @@ def predict_logits(self, X: XType) -> np.ndarray:
         logits_tensor = self._raw_predict(X, return_logits=True)
         return logits_tensor.float().detach().cpu().numpy()
 
+    @config_context(transform_output="default")
+    @track_model_call(model_method="predict", param_names=["X"])
+    def predict_raw_logits(self, X: XType) -> np.ndarray:
+        """Predict the raw logits for the provided input samples.
+
+        Logits represent the unnormalized log-probabilities of the classes
+        before the softmax activation function is applied. In contrast to
+        `predict_logits`, this method returns the raw logits of each
+        estimator, without averaging across estimators or temperature scaling.
+
+        Args:
+            X: The input data for prediction.
+
+        Returns:
+            An array of predicted logits for each estimator, with shape
+            (n_estimators, n_samples, n_classes).
+        """
+        logits_tensor = self._raw_predict(
+            X,
+            return_logits=False,
+            return_raw_logits=True,
+        )
+        return logits_tensor.float().detach().cpu().numpy()
+
     @track_model_call(model_method="predict", param_names=["X"])
     def predict_proba(self, X: XType) -> np.ndarray:
         """Predict the probabilities of the classes for the provided input samples.
@@ -800,6 +837,7 @@ def forward(  # noqa: C901, PLR0912
         *,
         use_inference_mode: bool = False,
         return_logits: bool = False,
+        return_raw_logits: bool = False,
     ) -> torch.Tensor:
         """Forward pass returning predicted probabilities or logits for
         TabPFNClassifier Inference Engine. Used in
@@ -814,14 +852,23 @@ def forward(  # noqa: C901, PLR0912
             use_inference_mode: Flag for inference mode., default at False since
                 it is called within predict. During FineTuning forward() is
                 called directly by user, so default should be False here.
-            return_logits: If True, returns raw logits. Otherwise, probabilities.
+            return_logits: If True, returns logits averaged across estimators.
+                Otherwise, probabilities are returned.
+            return_raw_logits: If True, returns the raw logits, without
+                averaging across estimators or temperature scaling.
 
         Returns:
             The predicted probabilities or logits of the classes as a torch.Tensor.
             - If `use_inference_mode` is True: Shape (N_samples, N_classes)
             - If `use_inference_mode` is False (e.g., for training/fine-tuning):
               Shape (Batch_size, N_classes, N_samples), suitable for NLLLoss.
+            - If `return_raw_logits` is True: Shape (n_estimators, n_samples, n_classes)
         """
+        if return_logits and return_raw_logits:
+            raise ValueError(
+                "Cannot return both logits and raw logits. Please specify only one."
+            )
+
         # Scenario 1: Standard inference path
         is_standard_inference = use_inference_mode and not isinstance(
             self.executor_, InferenceEngineBatchedNoPreprocessing
@@ -905,13 +952,16 @@ def forward(  # noqa: C901, PLR0912
         stacked_outputs = torch.stack(outputs)
 
         # --- Build the processing pipeline by composing the steps in order ---
-        # The first step is always to apply the temperature scaling.
-        pipeline = [self._apply_temperature]
+        pipeline = []
 
         if return_logits:
             # For logits, we just average the temperature-scaled logits.
-            pipeline.append(self._average_across_estimators)
+            pipeline.extend([self._apply_temperature, self._average_across_estimators])
+        elif return_raw_logits:
+            pass  # raw logits: skip temperature scaling and averaging entirely
         else:
+            pipeline.append(self._apply_temperature)
+
             # For probabilities, the order of averaging and softmax is crucial.
             if self.average_before_softmax:
                 pipeline.extend([self._average_across_estimators, self._apply_softmax])
@@ -931,7 +981,7 @@ def forward(  # noqa: C901, PLR0912
 
         # --- Final output shaping ---
         if output.ndim > 2 and use_inference_mode:
-            output = output.squeeze(1)
+            output = output.squeeze(1) if not return_raw_logits else output.squeeze(2)
 
         if not use_inference_mode:
             # This case is primarily for fine-tuning where NLLLoss expects [B, C, N]
diff --git a/tests/test_classifier_interface.py b/tests/test_classifier_interface.py
index b18802c6f..ce13394d8 100644
--- a/tests/test_classifier_interface.py
+++ b/tests/test_classifier_interface.py
@@ -260,6 +260,37 @@ def test_predict_logits_and_consistency(
     assert log_loss(y, proba_from_predict_proba) < 5.0
 
 
+@pytest.mark.parametrize("n_estimators", [1, 2])
+def test_predict_raw_logits(
+    X_y: tuple[np.ndarray, np.ndarray],
+    n_estimators: int,
+):
+    """Tests the predict_raw_logits method."""
+    X, y = X_y
+
+    # Ensure y is int64 for consistency with classification tasks
+    y = y.astype(np.int64)
+
+    classifier = TabPFNClassifier(
+        n_estimators=n_estimators,
+        random_state=42,
+    )
+    classifier.fit(X, y)
+
+    logits = classifier.predict_raw_logits(X)
+    assert isinstance(logits, np.ndarray)
+    assert logits.shape[0] == n_estimators
+    assert logits.shape == (n_estimators, X.shape[0], classifier.n_classes_)
+    assert logits.dtype == np.float32
+    assert not np.isnan(logits).any()
+    assert not np.isinf(logits).any()
+    if classifier.n_classes_ > 1:
+        assert not np.all(logits == logits[:, :, 0:1]), (
+            "Logits are identical across classes for all samples, indicating "
+            "trivial output."
+        )
+
+
 def test_softmax_temperature_impact_on_logits_magnitude(
     X_y: tuple[np.ndarray, np.ndarray],
 ):
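Reviewer note, not part of the patch: a minimal usage sketch of the new `predict_raw_logits` method. The synthetic three-class dataset and the `n_estimators=4` setting are illustrative assumptions; the manual softmax at the end deliberately skips the temperature scaling and estimator averaging that `predict_proba` applies internally.

    import numpy as np
    from sklearn.datasets import make_classification
    from tabpfn import TabPFNClassifier

    # Hypothetical toy data; any small tabular classification set works.
    X, y = make_classification(
        n_samples=60, n_features=8, n_informative=4, n_classes=3, random_state=0
    )

    clf = TabPFNClassifier(n_estimators=4, random_state=0)
    clf.fit(X, y)

    raw = clf.predict_raw_logits(X)
    # One logit matrix per ensemble member: (n_estimators, n_samples, n_classes).
    assert raw.shape == (4, 60, 3)

    # Per-estimator probabilities via a numerically stable softmax; unlike
    # predict_proba, no temperature scaling or estimator averaging is applied.
    shifted = raw - raw.max(axis=-1, keepdims=True)
    per_estimator_proba = np.exp(shifted) / np.exp(shifted).sum(axis=-1, keepdims=True)

This exposes the per-estimator ensemble outputs that `predict_logits` collapses, which is useful for inspecting ensemble disagreement or applying custom calibration downstream.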