Automatic Batching option in predict() functions #775

Open

klemens-floege wants to merge 6 commits into main from batch_size_predict

Conversation

@klemens-floege klemens-floege commented Feb 4, 2026

Summary

  • Add batch_size_predict and batch_predict_enable_test_interaction parameters to predict methods in TabPFNClassifier and TabPFNRegressor
  • Allows automatic batching of large test sets to avoid OOM errors
  • By default (batch_predict_enable_test_interaction=False), predictions are consistent across batch sizes
  • With batch_predict_enable_test_interaction=True, predictions match unbatched when using full test set as batch size

Changes

  • src/tabpfn/base.py: Add set_multiquery_item_attention() helper function
  • src/tabpfn/classifier.py: Add batching parameters to predict(), predict_proba(), predict_logits(), predict_raw_logits()
  • src/tabpfn/regressor.py: Add batching parameters to predict() method
  • src/tabpfn/errors.py: Update OOM error message to mention batch_size_predict
  • Tests: Add tests for batch consistency and for matching unbatched predictions (see the sketch below)
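
A minimal sketch of the kind of check those tests make (clf and X_test are illustrative fixtures, not the actual test code; tolerances are assumptions):

import numpy as np

# clf: a fitted TabPFNClassifier; X_test: test features (illustrative)
# Default mode: predictions are consistent across different batch sizes
p_a = clf.predict_proba(X_test, batch_size_predict=64)
p_b = clf.predict_proba(X_test, batch_size_predict=512)
np.testing.assert_allclose(p_a, p_b, rtol=1e-5)

# With test interaction enabled and the full test set as a single batch,
# predictions match the unbatched path
p_unbatched = clf.predict_proba(X_test)
p_full_batch = clf.predict_proba(
    X_test,
    batch_size_predict=len(X_test),
    batch_predict_enable_test_interaction=True,
)
np.testing.assert_allclose(p_unbatched, p_full_batch, rtol=1e-5)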

Usage

# Batch predictions to avoid OOM
clf.predict_proba(X_test, batch_size_predict=1000)

# With test interaction (predictions match unbatched when batch_size_predict=len(X_test))
clf.predict_proba(X_test, batch_size_predict=len(X_test),
                  batch_predict_enable_test_interaction=True)
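
The regressor's predict() accepts the same batching parameter (a sketch based on the regressor changes listed above; reg is assumed to be a fitted TabPFNRegressor and the output_type value is only illustrative):

# Batch regressor predictions to avoid OOM (illustrative example)
reg.predict(X_test, output_type="mean", batch_size_predict=1000)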

@klemens-floege klemens-floege requested a review from a team as a code owner February 4, 2026 19:14
@klemens-floege klemens-floege requested review from noahho and removed request for a team February 4, 2026 19:14
@chatgpt-codex-connector

Codex usage limits have been reached for code reviews. Please check with the admins of this repo to increase the limits by adding credits.
Credits must be used to enable repository wide code reviews.

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces a batch_size_predict parameter to both TabPFNClassifier and TabPFNRegressor, allowing for batched predictions on large test sets to prevent out-of-memory errors. The implementation is solid, with new methods for batched prediction and updated error messages to guide users. The accompanying tests are comprehensive and cover the new functionality well.

I have one suggestion for a minor refactoring in src/tabpfn/regressor.py to reduce code duplication, which will improve maintainability.

I am having trouble creating individual review comments, so my feedback is included inline below.

src/tabpfn/regressor.py (lines 999-1070) - severity: medium

There's some code duplication in _batched_predict. The loop for iterating over batches and calling self.forward is repeated for the quantiles case and the mean/median/mode case.

To improve maintainability and adhere to the DRY (Don't Repeat Yourself) principle, you could refactor this by extracting the batch iteration and logit processing into a local helper generator function. This would remove the duplication while keeping the logic encapsulated within _batched_predict.

    def _batched_predict(
        self,
        X: XType,
        *,
        output_type: OutputType,
        quantiles: list[float],
    ) -> RegressionResultType:
        """Run batched prediction to avoid OOM on large test sets.

        Args:
            X: The input data for prediction (already preprocessed).
            output_type: The type of output to return.
            quantiles: The quantiles to compute if output_type is "quantiles".

        Returns:
            The concatenated predictions from all batches.

        Raises:
            TabPFNValidationError: If output_type is "full" or "main" with batching,
                as these return complex structures that can't be easily batched.
        """
        if output_type in ["full", "main"]:
            raise TabPFNValidationError(
                f"output_type='{output_type}' is not supported with "
                f"batch_size_predict. Use 'mean', 'median', 'mode', or 'quantiles' "
                f"instead, or set batch_size_predict=None."
            )

        n_samples = X.shape[0] if hasattr(X, "shape") else len(X)

        def _iter_logits():
            """Yields logits for each batch of X."""
            for start in range(0, n_samples, self.batch_size_predict):
                end = min(start + self.batch_size_predict, n_samples)
                X_batch = X[start:end]

                with handle_oom_errors(self.devices_, X_batch, model_type="regressor"):
                    (_, outputs, borders) = self.forward(
                        X_batch, use_inference_mode=True
                    )
                yield self._process_forward_outputs(outputs, borders)

        if output_type == "quantiles":
            # For quantiles, we need to collect results per quantile
            all_quantile_results: list[list[np.ndarray]] = [[] for _ in quantiles]
            for logits in _iter_logits():
                for i, q in enumerate(quantiles):
                    q_result = (
                        self.raw_space_bardist_.icdf(logits, q).cpu().detach().numpy()
                    )
                    all_quantile_results[i].append(q_result)

            return [np.concatenate(q_results) for q_results in all_quantile_results]

        # For mean, median, mode
        results = []
        for logits in _iter_logits():
            batch_result = _logits_to_output(
                output_type=output_type,
                logits=logits,
                criterion=self.raw_space_bardist_,
                quantiles=quantiles,
            )
            results.append(batch_result)

        return np.concatenate(results)

@klemens-floege klemens-floege changed the title from "Batch size predict parameter" to "Automatic Batching option in predict() functions" on Feb 4, 2026

@noahho noahho left a comment

Thanks a lot for looking into this; however, the code looks a bit vibe-coded?

  • set_multiquery_item_attention seems like a large, complex piece of code that doesn't make sense for our transformer (it's already independent)
  • there seem to be a few of these large code changes

Can you please go through the code and make sure it's actually all correct? Or, in case I'm missing something, let me know and I'll take a deeper look.

Also, I just noticed it doesn't pass the tests - I assume this one isn't ready then. Please prefix the PR title with "WIP" (i.e., "WIP: Automatic Batching option in predict() functions") so it doesn't get reviewed yet.
