diff --git a/changelog/775.added.md b/changelog/775.added.md new file mode 100644 index 000000000..ee82ad3ae --- /dev/null +++ b/changelog/775.added.md @@ -0,0 +1 @@ +Added automatic batching option in predict() functions. \ No newline at end of file diff --git a/src/tabpfn/base.py b/src/tabpfn/base.py index 7e1faab49..ef4f77212 100644 --- a/src/tabpfn/base.py +++ b/src/tabpfn/base.py @@ -429,6 +429,28 @@ def initialize_telemetry() -> None: capture_session() +def set_multiquery_item_attention( + model: TabPFNClassifier | TabPFNRegressor, + *, + enabled: bool, +) -> None: + """Set multiquery_item_attention_for_test_set on all model layers. + + This controls whether test samples attend to each other during inference. + Disabling it ensures predictions are consistent across different batch sizes. + + Args: + model: The fitted TabPFN model. + enabled: If True, test samples can attend to each other. + If False, test samples only attend to training samples. + """ + for model_cache in model.executor_.model_caches: + for m in model_cache._models.values(): + for module in m.modules(): + if hasattr(module, "multiquery_item_attention_for_test_set"): + module.multiquery_item_attention_for_test_set = enabled + + def get_embeddings( model: TabPFNClassifier | TabPFNRegressor, X: XType, diff --git a/src/tabpfn/classifier.py b/src/tabpfn/classifier.py index d067beb3f..73b2b416d 100644 --- a/src/tabpfn/classifier.py +++ b/src/tabpfn/classifier.py @@ -41,6 +41,7 @@ get_embeddings, initialize_model_variables_helper, initialize_telemetry, + set_multiquery_item_attention, ) from tabpfn.constants import ( PROBABILITY_EPSILON_ROUND_ZERO, @@ -378,8 +379,9 @@ class in Fine-Tuning. The fit_from_preprocessed() function sets this False and True. !!! warning - This does not batch the original input data. We still recommend to - batch the test set as necessary if you run out of memory. + This does not batch the original input data. If you run out of + memory during prediction, use `batch_size_predict` in the predict + method to automatically batch the test set. random_state: Controls the randomness of the model. Pass an int for reproducible @@ -1014,6 +1016,8 @@ def _raw_predict( *, return_logits: bool, return_raw_logits: bool = False, + batch_size_predict: int | None = None, + batch_predict_enable_test_interaction: bool = False, ) -> torch.Tensor: """Internal method to run prediction. @@ -1028,6 +1032,13 @@ def _raw_predict( post-processing steps. return_raw_logits: If True, returns the raw logits without averaging estimators or temperature scaling. + batch_size_predict: If set, predictions are batched into chunks + of this size. If None, no batching is performed. + batch_predict_enable_test_interaction: If False (default), test + samples only attend to training samples during batched prediction, + ensuring predictions match unbatched. If True, test samples can + attend to each other within a batch, so predictions may vary + depending on batch size. 
Returns: The raw torch.Tensor output, either logits or probabilities, @@ -1044,6 +1055,16 @@ def _raw_predict( ord_encoder=getattr(self, "ordinal_encoder_", None), ) + # If batch_size_predict is set, batch the predictions + if batch_size_predict is not None: + return self._batched_raw_predict( + X, + batch_size_predict=batch_size_predict, + batch_predict_enable_test_interaction=batch_predict_enable_test_interaction, + return_logits=return_logits, + return_raw_logits=return_raw_logits, + ) + with handle_oom_errors(self.devices_, X, model_type="classifier"): return self.forward( X, @@ -1052,17 +1073,90 @@ def _raw_predict( return_raw_logits=return_raw_logits, ) + def _batched_raw_predict( + self, + X: XType, + *, + batch_size_predict: int, + batch_predict_enable_test_interaction: bool, + return_logits: bool, + return_raw_logits: bool = False, + ) -> torch.Tensor: + """Run batched prediction to avoid OOM on large test sets. + + Args: + X: The input data for prediction. + batch_size_predict: The batch size for predictions. + batch_predict_enable_test_interaction: If False, test samples only + attend to training samples, ensuring predictions match unbatched. + If True, predictions may vary depending on batch size. + return_logits: If True, the logits are returned. + return_raw_logits: If True, returns the raw logits without + averaging estimators or temperature scaling. + + Returns: + The concatenated predictions from all batches. + """ + # Disable multiquery attention for consistent predictions (matching unbatched) + # unless batch_predict_enable_test_interaction is True + if not batch_predict_enable_test_interaction: + set_multiquery_item_attention(self, enabled=False) + + try: + results = [] + n_samples = X.shape[0] if hasattr(X, "shape") else len(X) + + for start in range(0, n_samples, batch_size_predict): + end = min(start + batch_size_predict, n_samples) + X_batch = X[start:end] + + with handle_oom_errors(self.devices_, X_batch, model_type="classifier"): + batch_result = self.forward( + X_batch, + use_inference_mode=True, + return_logits=return_logits, + return_raw_logits=return_raw_logits, + ) + results.append(batch_result) + + # Concatenate along the appropriate dimension + # raw logits: (n_estimators, n_samples, n_classes) -> dim 1 + # logits/probas: (n_samples, n_classes) -> dim 0 + concat_dim = 1 if return_raw_logits else 0 + return torch.cat(results, dim=concat_dim) + finally: + # Restore multiquery attention if we disabled it + if not batch_predict_enable_test_interaction: + set_multiquery_item_attention(self, enabled=True) + @track_model_call(model_method="predict", param_names=["X"]) - def predict(self, X: XType) -> np.ndarray: + def predict( + self, + X: XType, + *, + batch_size_predict: int | None = None, + batch_predict_enable_test_interaction: bool = False, + ) -> np.ndarray: """Predict the class labels for the provided input samples. Args: X: The input data for prediction. + batch_size_predict: If set, predictions are batched into chunks + of this size to avoid OOM errors. If None, no batching is performed. + batch_predict_enable_test_interaction: If False (default), test + samples only attend to training samples during batched prediction, + ensuring predictions match unbatched. If True, test samples can + attend to each other within a batch, so predictions may vary + depending on batch size. Returns: The predicted class labels as a NumPy array. 
""" - probas = self._predict_proba(X=X) + probas = self._predict_proba( + X=X, + batch_size_predict=batch_size_predict, + batch_predict_enable_test_interaction=batch_predict_enable_test_interaction, + ) y_pred = np.argmax(probas, axis=1) if hasattr(self, "label_encoder_") and self.label_encoder_ is not None: return self.label_encoder_.inverse_transform(y_pred) @@ -1071,7 +1165,13 @@ def predict(self, X: XType) -> np.ndarray: @config_context(transform_output="default") @track_model_call(model_method="predict", param_names=["X"]) - def predict_logits(self, X: XType) -> np.ndarray: + def predict_logits( + self, + X: XType, + *, + batch_size_predict: int | None = None, + batch_predict_enable_test_interaction: bool = False, + ) -> np.ndarray: """Predict the raw logits for the provided input samples. Logits represent the unnormalized log-probabilities of the classes @@ -1079,16 +1179,34 @@ def predict_logits(self, X: XType) -> np.ndarray: Args: X: The input data for prediction. + batch_size_predict: If set, predictions are batched into chunks + of this size to avoid OOM errors. If None, no batching is performed. + batch_predict_enable_test_interaction: If False (default), test + samples only attend to training samples during batched prediction, + ensuring predictions match unbatched. If True, test samples can + attend to each other within a batch, so predictions may vary + depending on batch size. Returns: The predicted logits as a NumPy array. Shape (n_samples, n_classes). """ - logits_tensor = self._raw_predict(X, return_logits=True) + logits_tensor = self._raw_predict( + X, + return_logits=True, + batch_size_predict=batch_size_predict, + batch_predict_enable_test_interaction=batch_predict_enable_test_interaction, + ) return logits_tensor.float().detach().cpu().numpy() @config_context(transform_output="default") @track_model_call(model_method="predict", param_names=["X"]) - def predict_raw_logits(self, X: XType) -> np.ndarray: + def predict_raw_logits( + self, + X: XType, + *, + batch_size_predict: int | None = None, + batch_predict_enable_test_interaction: bool = False, + ) -> np.ndarray: """Predict the raw logits for the provided input samples. Logits represent the unnormalized log-probabilities of the classes @@ -1098,6 +1216,13 @@ def predict_raw_logits(self, X: XType) -> np.ndarray: Args: X: The input data for prediction. + batch_size_predict: If set, predictions are batched into chunks + of this size to avoid OOM errors. If None, no batching is performed. + batch_predict_enable_test_interaction: If False (default), test + samples only attend to training samples during batched prediction, + ensuring predictions match unbatched. If True, test samples can + attend to each other within a batch, so predictions may vary + depending on batch size. Returns: An array of predicted logits for each estimator, @@ -1107,37 +1232,78 @@ def predict_raw_logits(self, X: XType) -> np.ndarray: X, return_logits=False, return_raw_logits=True, + batch_size_predict=batch_size_predict, + batch_predict_enable_test_interaction=batch_predict_enable_test_interaction, ) return logits_tensor.float().detach().cpu().numpy() @track_model_call(model_method="predict", param_names=["X"]) - def predict_proba(self, X: XType) -> np.ndarray: + def predict_proba( + self, + X: XType, + *, + batch_size_predict: int | None = None, + batch_predict_enable_test_interaction: bool = False, + ) -> np.ndarray: """Predict the probabilities of the classes for the provided input samples. This is a wrapper around the `_predict_proba` method. 
Args: X: The input data for prediction. + batch_size_predict: If set, predictions are batched into chunks + of this size to avoid OOM errors. If None, no batching is performed. + batch_predict_enable_test_interaction: If False (default), test + samples only attend to training samples during batched prediction, + ensuring predictions match unbatched. If True, test samples can + attend to each other within a batch, so predictions may vary + depending on batch size. Returns: The predicted probabilities of the classes as a NumPy array. Shape (n_samples, n_classes). """ - return self._predict_proba(X) + return self._predict_proba( + X, + batch_size_predict=batch_size_predict, + batch_predict_enable_test_interaction=batch_predict_enable_test_interaction, + ) @config_context(transform_output="default") # type: ignore - def _predict_proba(self, X: XType) -> np.ndarray: + def _predict_proba( + self, + X: XType, + *, + batch_size_predict: int | None = None, + batch_predict_enable_test_interaction: bool = False, + ) -> np.ndarray: """Predict the probabilities of the classes for the provided input samples. Args: X: The input data for prediction. + batch_size_predict: If set, predictions are batched into chunks + of this size. If None, no batching is performed. + batch_predict_enable_test_interaction: If False (default), test + samples only attend to training samples during batched prediction, + ensuring predictions match unbatched. If True, test samples can + attend to each other within a batch, so predictions may vary + depending on batch size. Returns: The predicted probabilities of the classes as a NumPy array. Shape (n_samples, n_classes). """ probas = ( - self._raw_predict(X, return_logits=False).float().detach().cpu().numpy() + self._raw_predict( + X, + return_logits=False, + batch_size_predict=batch_size_predict, + batch_predict_enable_test_interaction=batch_predict_enable_test_interaction, + ) + .float() + .detach() + .cpu() + .numpy() ) probas = self._maybe_reweight_probas(probas=probas) if self.inference_config_.USE_SKLEARN_16_DECIMAL_PRECISION: diff --git a/src/tabpfn/errors.py b/src/tabpfn/errors.py index 87e99f04c..2b9ff5afd 100644 --- a/src/tabpfn/errors.py +++ b/src/tabpfn/errors.py @@ -63,19 +63,15 @@ def __init__( n_test_samples: int | None = None, model_type: str = "classifier", ): predict_method = "predict_proba" if model_type == "classifier" else "predict" size_info = f" with {n_test_samples:,} test samples" if n_test_samples else "" message = ( f"{self.device_name} out of memory{size_info}.\n\n" - f"Solution: Split your test data into smaller batches:\n\n" - f" batch_size = 1000 # depends on hardware\n" - f" predictions = []\n" - f" for i in range(0, len(X_test), batch_size):\n" - f" batch = model.{predict_method}(X_test[i:i + batch_size])\n" - f" predictions.append(batch)\n" - f" predictions = np.vstack(predictions)" + f"Solution: Pass batch_size_predict to the predict method to batch " + f"the test set automatically:\n\n" + f" predictions = model.{predict_method}(X_test, batch_size_predict=1000)" ) if original_error is not None: message += f"\n\nOriginal error: {original_error}" diff --git a/src/tabpfn/regressor.py b/src/tabpfn/regressor.py index 29ade4071..e63c7c45a 100644 --- a/src/tabpfn/regressor.py +++ b/src/tabpfn/regressor.py @@ -46,6 +46,7 @@ get_embeddings, initialize_model_variables_helper, initialize_telemetry, + set_multiquery_item_attention, ) from tabpfn.constants import REGRESSION_CONSTANT_TARGET_BORDER_EPSILON, ModelVersion from tabpfn.errors import TabPFNValidationError,
handle_oom_errors @@ -379,8 +380,9 @@ class in Fine-Tuning. The fit_from_preprocessed() function sets this False and True. !!! warning - This does not batch the original input data. We still recommend to - batch the test set as necessary if you run out of memory. + This does not batch the original input data. If you run out of + memory during prediction, use `batch_size_predict` in the predict + method to automatically batch the test set. random_state: Controls the randomness of the model. Pass an int for reproducible @@ -818,6 +820,8 @@ def predict( *, output_type: Literal["mean", "median", "mode"] = "mean", quantiles: list[float] | None = None, + batch_size_predict: int | None = None, + batch_predict_enable_test_interaction: bool = False, ) -> np.ndarray: ... @overload @@ -827,6 +831,8 @@ def predict( *, output_type: Literal["quantiles"], quantiles: list[float] | None = None, + batch_size_predict: int | None = None, + batch_predict_enable_test_interaction: bool = False, ) -> list[np.ndarray]: ... @overload @@ -836,6 +842,8 @@ def predict( *, output_type: Literal["main"], quantiles: list[float] | None = None, + batch_size_predict: int | None = None, + batch_predict_enable_test_interaction: bool = False, ) -> MainOutputDict: ... @overload @@ -845,6 +853,8 @@ def predict( *, output_type: Literal["full"], quantiles: list[float] | None = None, + batch_size_predict: int | None = None, + batch_predict_enable_test_interaction: bool = False, ) -> FullOutputDict: ... @config_context(transform_output="default") # type: ignore @@ -856,6 +866,8 @@ def predict( # TODO: support "ei", "pi" output_type: OutputType = "mean", quantiles: list[float] | None = None, + batch_size_predict: int | None = None, + batch_predict_enable_test_interaction: bool = False, ) -> RegressionResultType: """Runs the forward() method and then transform the logits from the binning space in order to predict target variable. @@ -882,6 +894,16 @@ def predict( quantiles are returned. The predictions per quantile match the input order. + batch_size_predict: + If set, predictions are batched into chunks of this size to avoid + OOM errors. If None, no batching is performed. + + batch_predict_enable_test_interaction: + If False (default), test samples only attend to training samples + during batched prediction, ensuring predictions match unbatched. + If True, test samples can attend to each other within a batch, + so predictions may vary depending on batch size. + Returns: The prediction, which can be a numpy array, a list of arrays (for quantiles), or a dictionary with detailed outputs. @@ -912,6 +934,16 @@ def predict( ord_encoder=getattr(self, "ordinal_encoder_", None), ) + # If batch_size_predict is set, use batched prediction + if batch_size_predict is not None: + return self._batched_predict( + X, + batch_size_predict=batch_size_predict, + batch_predict_enable_test_interaction=batch_predict_enable_test_interaction, + output_type=output_type, + quantiles=quantiles, + ) + # Runs over iteration engine with handle_oom_errors(self.devices_, X, model_type="regressor"): ( @@ -982,6 +1014,136 @@ def predict( return logit_to_output(output_type=output_type) + def _batched_predict( + self, + X: XType, + *, + batch_size_predict: int, + batch_predict_enable_test_interaction: bool, + output_type: OutputType, + quantiles: list[float], + ) -> RegressionResultType: + """Run batched prediction to avoid OOM on large test sets. + + Args: + X: The input data for prediction (already preprocessed). + batch_size_predict: The batch size for predictions. 
+ batch_predict_enable_test_interaction: If False, test samples only + attend to training samples, ensuring predictions match unbatched. + If True, predictions may vary depending on batch size. + output_type: The type of output to return. + quantiles: The quantiles to compute if output_type is "quantiles". + + Returns: + The concatenated predictions from all batches. + + Raises: + TabPFNValidationError: If output_type is "full" or "main" with batching, + as these return complex structures that can't be easily batched. + """ + if output_type in ["full", "main"]: + raise TabPFNValidationError( + f"output_type='{output_type}' is not supported with " + f"batch_size_predict. Use 'mean', 'median', 'mode', or 'quantiles' " + f"instead, or set batch_size_predict=None." + ) + + # Disable multiquery attention for consistent predictions (matching unbatched) + # unless batch_predict_enable_test_interaction is True + if not batch_predict_enable_test_interaction: + set_multiquery_item_attention(self, enabled=False) + + try: + n_samples = X.shape[0] if hasattr(X, "shape") else len(X) + + if output_type == "quantiles": + # For quantiles, we need to collect results per quantile + all_quantile_results: list[list[np.ndarray]] = [[] for _ in quantiles] + + for start in range(0, n_samples, batch_size_predict): + end = min(start + batch_size_predict, n_samples) + X_batch = X[start:end] + + with handle_oom_errors( + self.devices_, X_batch, model_type="regressor" + ): + (_, outputs, borders) = self.forward( + X_batch, use_inference_mode=True + ) + + logits = self._process_forward_outputs(outputs, borders) + + for i, q in enumerate(quantiles): + q_result = ( + self.raw_space_bardist_.icdf(logits, q) + .cpu() + .detach() + .numpy() + ) + all_quantile_results[i].append(q_result) + + return [np.concatenate(q_results) for q_results in all_quantile_results] + + # For mean, median, mode + results = [] + for start in range(0, n_samples, batch_size_predict): + end = min(start + batch_size_predict, n_samples) + X_batch = X[start:end] + + with handle_oom_errors(self.devices_, X_batch, model_type="regressor"): + (_, outputs, borders) = self.forward( + X_batch, use_inference_mode=True + ) + + logits = self._process_forward_outputs(outputs, borders) + batch_result = _logits_to_output( + output_type=output_type, + logits=logits, + criterion=self.raw_space_bardist_, + quantiles=quantiles, + ) + results.append(batch_result) + + return np.concatenate(results) + finally: + # Restore multiquery attention if we disabled it + if not batch_predict_enable_test_interaction: + set_multiquery_item_attention(self, enabled=True) + + def _process_forward_outputs( + self, + outputs: list[torch.Tensor], + borders: list[np.ndarray], + ) -> torch.Tensor: + """Process forward outputs into final logits. + + Args: + outputs: List of tensors from forward pass. + borders: List of border arrays for each estimator. + + Returns: + Processed logits tensor. 
+ """ + transformed_logits = [ + translate_probs_across_borders( + logits, + frm=torch.as_tensor(borders_t, device=logits.device), + to=self.znorm_space_bardist_.borders.to(logits.device), + ) + for logits, borders_t in zip(outputs, borders) + ] + stacked_logits = torch.stack(transformed_logits, dim=0) + if self.average_before_softmax: + logits = stacked_logits.log().mean(dim=0).softmax(dim=-1) + else: + logits = stacked_logits.mean(dim=0) + + logits = logits.log() + if logits.dtype == torch.float16: + logits = logits.float() + + return logits + def forward( self, X: list[torch.Tensor] | XType, diff --git a/tests/test_classifier_interface.py b/tests/test_classifier_interface.py index caa000f24..530395347 100644 --- a/tests/test_classifier_interface.py +++ b/tests/test_classifier_interface.py @@ -1232,3 +1232,68 @@ def test__create_default_for_version__passes_through_overrides() -> None: assert estimator.n_estimators == 16 assert estimator.softmax_temperature == 0.9 + + +def test__batch_size_predict__matches_unbatched_with_test_interaction( + X_y: tuple, +) -> None: + """Test that batch_size=1 with test interaction matches unbatched predictions. + + When batch_predict_enable_test_interaction=True and batch_size equals the full + test set size, predictions should match unbatched predictions exactly. + """ + X, y = X_y + n_classes = len(np.unique(y)) + clf = TabPFNClassifier(n_estimators=2, random_state=42) + clf.fit(X, y) + + # Get unbatched predictions + proba_unbatched = clf.predict_proba(X) + + # Get batched predictions with batch_size=full test set and test interaction enabled + # This should match unbatched since all samples can interact + proba_batched = clf.predict_proba( + X, + batch_size_predict=len(X), + batch_predict_enable_test_interaction=True, + ) + + # Results should match unbatched + np.testing.assert_allclose(proba_batched, proba_unbatched, rtol=1e-5, atol=1e-5) + + # Also verify valid probability distribution + assert proba_batched.shape == (X.shape[0], n_classes) + np.testing.assert_allclose(proba_batched.sum(axis=1), 1.0, atol=1e-5) + assert np.all(proba_batched >= 0) + assert np.all(proba_batched <= 1) + + +@pytest.mark.parametrize("batch_size", [1, 2, 5]) +def test__batch_size_predict__consistent_across_batch_sizes( + X_y: tuple, batch_size: int +) -> None: + """Test that batched predictions are consistent regardless of batch size. + + By default, batch_predict_enable_test_interaction=False disables multiquery + attention between test samples, ensuring predictions are identical regardless + of how the test set is batched. 
+ """ + X, y = X_y + n_classes = len(np.unique(y)) + clf = TabPFNClassifier(n_estimators=2, random_state=42) + clf.fit(X, y) + + # Get predictions with batch_size=1 as reference + proba_reference = clf.predict_proba(X, batch_size_predict=1) + + # Get predictions with the test batch_size + proba_batched = clf.predict_proba(X, batch_size_predict=batch_size) + + # Results should match regardless of batch size + np.testing.assert_allclose(proba_batched, proba_reference, rtol=1e-5, atol=1e-5) + + # Also verify valid probability distribution + assert proba_batched.shape == (X.shape[0], n_classes) + np.testing.assert_allclose(proba_batched.sum(axis=1), 1.0, atol=1e-5) + assert np.all(proba_batched >= 0) + assert np.all(proba_batched <= 1) diff --git a/tests/test_regressor_interface.py b/tests/test_regressor_interface.py index 545cd71da..2916173cd 100644 --- a/tests/test_regressor_interface.py +++ b/tests/test_regressor_interface.py @@ -21,6 +21,7 @@ from tabpfn import TabPFNRegressor from tabpfn.base import RegressorModelSpecs, initialize_tabpfn_model from tabpfn.constants import ModelVersion +from tabpfn.errors import TabPFNValidationError from tabpfn.model_loading import ModelSource, prepend_cache_path from tabpfn.preprocessing import PreprocessorConfig from tabpfn.settings import settings @@ -836,3 +837,115 @@ def test__create_default_for_version__passes_through_overrides() -> None: assert estimator.n_estimators == 16 assert estimator.softmax_temperature == 0.9 + + +@pytest.mark.parametrize("output_type", ["full", "main"]) +def test__batch_size_predict__complex_output_raises_error( + X_y: tuple, output_type: str +) -> None: + """Test that 'full' and 'main' output types raise error with batching.""" + X, y = X_y + reg = TabPFNRegressor(n_estimators=2, random_state=42) + reg.fit(X, y) + + with pytest.raises(TabPFNValidationError, match=f"output_type='{output_type}'"): + reg.predict(X, output_type=output_type, batch_size_predict=3) + + +@pytest.mark.parametrize("output_type", ["mean", "median", "mode"]) +def test__batch_size_predict__matches_unbatched_with_test_interaction( + X_y: tuple, output_type: str +) -> None: + """Test that full batch with test interaction matches unbatched predictions. + + When batch_predict_enable_test_interaction=True and batch_size equals the full + test set size, predictions should match unbatched predictions exactly. + """ + X, y = X_y + reg = TabPFNRegressor(n_estimators=2, random_state=42) + reg.fit(X, y) + + # Get unbatched predictions + pred_unbatched = reg.predict(X, output_type=output_type) + + # Get batched predictions with batch_size=full test set and test interaction enabled + # This should match unbatched since all samples can interact + pred_batched = reg.predict( + X, + output_type=output_type, + batch_size_predict=len(X), + batch_predict_enable_test_interaction=True, + ) + + # Results should match unbatched + np.testing.assert_allclose(pred_batched, pred_unbatched, rtol=1e-5, atol=1e-5) + + # Also verify valid output + assert pred_batched.shape == (X.shape[0],) + assert np.all(np.isfinite(pred_batched)) + + +@pytest.mark.parametrize("batch_size", [1, 3, 5]) +@pytest.mark.parametrize("output_type", ["mean", "median", "mode"]) +def test__batch_size_predict__consistent_across_batch_sizes( + X_y: tuple, batch_size: int, output_type: str +) -> None: + """Test that batched predictions are consistent regardless of batch size. 
+ + By default, batch_predict_enable_test_interaction=False disables multiquery + attention between test samples, ensuring predictions are identical regardless + of how the test set is batched. + """ + X, y = X_y + reg = TabPFNRegressor(n_estimators=2, random_state=42) + reg.fit(X, y) + + # Get predictions with batch_size=1 as reference + pred_reference = reg.predict(X, output_type=output_type, batch_size_predict=1) + + # Get predictions with the test batch_size + pred_batched = reg.predict( + X, output_type=output_type, batch_size_predict=batch_size + ) + + # Results should match regardless of batch size + np.testing.assert_allclose(pred_batched, pred_reference, rtol=1e-5, atol=1e-5) + + # Also verify valid output + assert pred_batched.shape == (X.shape[0],) + assert np.all(np.isfinite(pred_batched)) + + +@pytest.mark.parametrize("batch_size", [1, 3, 5]) +def test__batch_size_predict__quantiles_consistent_across_batch_sizes( + X_y: tuple, batch_size: int +) -> None: + """Test that batched quantile predictions are consistent regardless of batch size. + + By default, batch_predict_enable_test_interaction=False disables multiquery + attention between test samples, ensuring predictions are identical regardless + of how the test set is batched. + """ + X, y = X_y + quantile_values = [0.1, 0.5, 0.9] + reg = TabPFNRegressor(n_estimators=2, random_state=42) + reg.fit(X, y) + + # Get predictions with batch_size=1 as reference + result_reference = reg.predict( + X, output_type="quantiles", quantiles=quantile_values, batch_size_predict=1 + ) + + # Get predictions with the test batch_size + result_batched = reg.predict( + X, + output_type="quantiles", + quantiles=quantile_values, + batch_size_predict=batch_size, + ) + + # Results should match regardless of batch size + for q_batched, q_reference in zip(result_batched, result_reference): + np.testing.assert_allclose(q_batched, q_reference, rtol=1e-5, atol=1e-5) + assert q_batched.shape == (X.shape[0],) + assert np.all(np.isfinite(q_batched))
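Usage sketch for the batching parameters added in this diff. Everything below is illustrative and not part of the change itself: the synthetic datasets, split sizes, and batch_size_predict=128 are assumptions; the estimator classes and predict parameters are the ones introduced above.

import numpy as np
from sklearn.datasets import make_classification, make_regression

from tabpfn import TabPFNClassifier, TabPFNRegressor

# Hypothetical data; any tabular X/y works.
X, y = make_classification(n_samples=600, n_features=10, random_state=0)
X_train, y_train, X_test = X[:500], y[:500], X[500:]

clf = TabPFNClassifier(n_estimators=2, random_state=42)
clf.fit(X_train, y_train)

# Batch the test set into chunks of 128 rows during prediction to avoid OOM.
# With the default batch_predict_enable_test_interaction=False, predictions
# are identical regardless of the chosen batch size.
proba = clf.predict_proba(X_test, batch_size_predict=128)
labels = clf.predict(X_test, batch_size_predict=128)
assert proba.shape == (len(X_test), len(np.unique(y_train)))

# The regressor takes the same parameters; note that "full" and "main"
# output types are not supported together with batch_size_predict.
Xr, yr = make_regression(n_samples=600, n_features=10, random_state=0)
reg = TabPFNRegressor(n_estimators=2, random_state=42)
reg.fit(Xr[:500], yr[:500])
mean_pred = reg.predict(Xr[500:], batch_size_predict=128)
q10, q50, q90 = reg.predict(
    Xr[500:],
    output_type="quantiles",
    quantiles=[0.1, 0.5, 0.9],
    batch_size_predict=128,
)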