1 change: 1 addition & 0 deletions changelog/775.added.md
@@ -0,0 +1 @@
Added an automatic batching option (`batch_size_predict`) to the `predict()` functions.
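For context, a minimal usage sketch of the option this changelog entry refers to — a hypothetical script that assumes this PR's branch is installed and uses scikit-learn's `make_classification` purely to produce example data:

```python
from sklearn.datasets import make_classification

from tabpfn import TabPFNClassifier  # assumes this PR's branch is installed

X, y = make_classification(n_samples=600, n_features=10, random_state=0)
X_train, y_train, X_test = X[:200], y[:200], X[200:]

clf = TabPFNClassifier()
clf.fit(X_train, y_train)

# Predict in chunks of 128 test rows to bound memory use; by default the
# chunked output matches an unbatched call
# (see batch_predict_enable_test_interaction below).
probas = clf.predict_proba(X_test, batch_size_predict=128)
preds = clf.predict(X_test, batch_size_predict=128)
```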
22 changes: 22 additions & 0 deletions src/tabpfn/base.py
@@ -429,6 +429,28 @@ def initialize_telemetry() -> None:
capture_session()


def set_multiquery_item_attention(
model: TabPFNClassifier | TabPFNRegressor,
*,
enabled: bool,
) -> None:
"""Set multiquery_item_attention_for_test_set on all model layers.
This controls whether test samples attend to each other during inference.
Disabling it ensures predictions are consistent across different batch sizes.
Args:
model: The fitted TabPFN model.
enabled: If True, test samples can attend to each other.
If False, test samples only attend to training samples.
"""
for model_cache in model.executor_.model_caches:
for m in model_cache._models.values():
for module in m.modules():
if hasattr(module, "multiquery_item_attention_for_test_set"):
module.multiquery_item_attention_for_test_set = enabled
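For illustration, a hedged sketch of calling this helper directly, mirroring the disable/restore pattern the batched predict path applies in classifier.py below; `model` and `X_test` are assumed to be a fitted classifier and a test matrix, and the import path follows this file's location (src/tabpfn/base.py):

```python
from tabpfn.base import set_multiquery_item_attention

# Make predictions independent of how the test set is chunked: while the
# flag is disabled, test rows attend only to training rows.
set_multiquery_item_attention(model, enabled=False)
try:
    probas = model.predict_proba(X_test)
finally:
    # Restore the default attention behaviour afterwards.
    set_multiquery_item_attention(model, enabled=True)
```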


def get_embeddings(
model: TabPFNClassifier | TabPFNRegressor,
X: XType,
188 changes: 177 additions & 11 deletions src/tabpfn/classifier.py
@@ -41,6 +41,7 @@
get_embeddings,
initialize_model_variables_helper,
initialize_telemetry,
set_multiquery_item_attention,
)
from tabpfn.constants import (
PROBABILITY_EPSILON_ROUND_ZERO,
@@ -378,8 +379,9 @@ class in Fine-Tuning. The fit_from_preprocessed() function sets this
False and True.
!!! warning
This does not batch the original input data. We still recommend to
batch the test set as necessary if you run out of memory.
This does not batch the original input data. If you run out of
memory during prediction, use `batch_size_predict` in the predict
method to automatically batch the test set.
random_state:
Controls the randomness of the model. Pass an int for reproducible
@@ -1014,6 +1016,8 @@ def _raw_predict(
*,
return_logits: bool,
return_raw_logits: bool = False,
batch_size_predict: int | None = None,
batch_predict_enable_test_interaction: bool = False,
) -> torch.Tensor:
"""Internal method to run prediction.
@@ -1028,6 +1032,13 @@ def _raw_predict(
post-processing steps.
return_raw_logits: If True, returns the raw logits without
averaging estimators or temperature scaling.
batch_size_predict: If set, predictions are batched into chunks
of this size. If None, no batching is performed.
batch_predict_enable_test_interaction: If False (default), test
samples only attend to training samples during batched prediction,
ensuring predictions match the unbatched output. If True, test samples can
attend to each other within a batch, so predictions may vary
depending on batch size.
Returns:
The raw torch.Tensor output, either logits or probabilities,
@@ -1044,6 +1055,16 @@
ord_encoder=getattr(self, "ordinal_encoder_", None),
)

# If batch_size_predict is set, batch the predictions
if batch_size_predict is not None:
return self._batched_raw_predict(
X,
batch_size_predict=batch_size_predict,
batch_predict_enable_test_interaction=batch_predict_enable_test_interaction,
return_logits=return_logits,
return_raw_logits=return_raw_logits,
)

with handle_oom_errors(self.devices_, X, model_type="classifier"):
return self.forward(
X,
@@ -1052,17 +1073,90 @@
return_raw_logits=return_raw_logits,
)

def _batched_raw_predict(
self,
X: XType,
*,
batch_size_predict: int,
batch_predict_enable_test_interaction: bool,
return_logits: bool,
return_raw_logits: bool = False,
) -> torch.Tensor:
"""Run batched prediction to avoid OOM on large test sets.
Args:
X: The input data for prediction.
batch_size_predict: The batch size for predictions.
batch_predict_enable_test_interaction: If False, test samples only
attend to training samples, ensuring predictions match the unbatched output.
If True, predictions may vary depending on batch size.
return_logits: If True, the logits are returned.
return_raw_logits: If True, returns the raw logits without
averaging estimators or temperature scaling.
Returns:
The concatenated predictions from all batches.
"""
# Disable multiquery attention for consistent predictions (matching unbatched)
# unless batch_predict_enable_test_interaction is True
if not batch_predict_enable_test_interaction:
set_multiquery_item_attention(self, enabled=False)

try:
results = []
n_samples = X.shape[0] if hasattr(X, "shape") else len(X)

for start in range(0, n_samples, batch_size_predict):
end = min(start + batch_size_predict, n_samples)
X_batch = X[start:end]

with handle_oom_errors(self.devices_, X_batch, model_type="classifier"):
batch_result = self.forward(
X_batch,
use_inference_mode=True,
return_logits=return_logits,
return_raw_logits=return_raw_logits,
)
results.append(batch_result)

# Concatenate along the appropriate dimension
# raw logits: (n_estimators, n_samples, n_classes) -> dim 1
# logits/probas: (n_samples, n_classes) -> dim 0
concat_dim = 1 if return_raw_logits else 0
return torch.cat(results, dim=concat_dim)
finally:
# Restore multiquery attention if we disabled it
if not batch_predict_enable_test_interaction:
set_multiquery_item_attention(self, enabled=True)
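As a sanity check of the default behaviour documented above (batched output matching the unbatched output), something like the following sketch could be run against a fitted classifier; `clf` and `X_test` are assumed from the earlier example:

```python
import numpy as np

p_unbatched = clf.predict_proba(X_test)
p_batched = clf.predict_proba(X_test, batch_size_predict=64)

# With batch_predict_enable_test_interaction=False (the default), the two
# results should agree up to numerical noise.
assert np.allclose(p_unbatched, p_batched, atol=1e-5)

# Allowing test rows to attend to each other within a chunk may change the
# probabilities depending on the chosen batch size.
p_interacting = clf.predict_proba(
    X_test, batch_size_predict=64, batch_predict_enable_test_interaction=True
)
```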

@track_model_call(model_method="predict", param_names=["X"])
def predict(self, X: XType) -> np.ndarray:
def predict(
self,
X: XType,
*,
batch_size_predict: int | None = None,
batch_predict_enable_test_interaction: bool = False,
) -> np.ndarray:
"""Predict the class labels for the provided input samples.
Args:
X: The input data for prediction.
batch_size_predict: If set, predictions are batched into chunks
of this size to avoid OOM errors. If None, no batching is performed.
batch_predict_enable_test_interaction: If False (default), test
samples only attend to training samples during batched prediction,
ensuring predictions match the unbatched output. If True, test samples can
attend to each other within a batch, so predictions may vary
depending on batch size.
Returns:
The predicted class labels as a NumPy array.
"""
probas = self._predict_proba(X=X)
probas = self._predict_proba(
X=X,
batch_size_predict=batch_size_predict,
batch_predict_enable_test_interaction=batch_predict_enable_test_interaction,
)
y_pred = np.argmax(probas, axis=1)
if hasattr(self, "label_encoder_") and self.label_encoder_ is not None:
return self.label_encoder_.inverse_transform(y_pred)
Expand All @@ -1071,24 +1165,48 @@ def predict(self, X: XType) -> np.ndarray:

@config_context(transform_output="default")
@track_model_call(model_method="predict", param_names=["X"])
def predict_logits(self, X: XType) -> np.ndarray:
def predict_logits(
self,
X: XType,
*,
batch_size_predict: int | None = None,
batch_predict_enable_test_interaction: bool = False,
) -> np.ndarray:
"""Predict the raw logits for the provided input samples.
Logits represent the unnormalized log-probabilities of the classes
before the softmax activation function is applied.
Args:
X: The input data for prediction.
batch_size_predict: If set, predictions are batched into chunks
of this size to avoid OOM errors. If None, no batching is performed.
batch_predict_enable_test_interaction: If False (default), test
samples only attend to training samples during batched prediction,
ensuring predictions match the unbatched output. If True, test samples can
attend to each other within a batch, so predictions may vary
depending on batch size.
Returns:
The predicted logits as a NumPy array. Shape (n_samples, n_classes).
"""
logits_tensor = self._raw_predict(X, return_logits=True)
logits_tensor = self._raw_predict(
X,
return_logits=True,
batch_size_predict=batch_size_predict,
batch_predict_enable_test_interaction=batch_predict_enable_test_interaction,
)
return logits_tensor.float().detach().cpu().numpy()

@config_context(transform_output="default")
@track_model_call(model_method="predict", param_names=["X"])
def predict_raw_logits(self, X: XType) -> np.ndarray:
def predict_raw_logits(
self,
X: XType,
*,
batch_size_predict: int | None = None,
batch_predict_enable_test_interaction: bool = False,
) -> np.ndarray:
"""Predict the raw logits for the provided input samples.
Logits represent the unnormalized log-probabilities of the classes
Expand All @@ -1098,6 +1216,13 @@ def predict_raw_logits(self, X: XType) -> np.ndarray:
Args:
X: The input data for prediction.
batch_size_predict: If set, predictions are batched into chunks
of this size to avoid OOM errors. If None, no batching is performed.
batch_predict_enable_test_interaction: If False (default), test
samples only attend to training samples during batched prediction,
ensuring predictions match the unbatched output. If True, test samples can
attend to each other within a batch, so predictions may vary
depending on batch size.
Returns:
An array of predicted logits for each estimator,
@@ -1107,37 +1232,78 @@
X,
return_logits=False,
return_raw_logits=True,
batch_size_predict=batch_size_predict,
batch_predict_enable_test_interaction=batch_predict_enable_test_interaction,
)
return logits_tensor.float().detach().cpu().numpy()
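A short sketch of the expected output shapes, following the concatenation comment in `_batched_raw_predict`; `clf` and `X_test` are assumed from the earlier example, and `n_estimators` refers to the size of the model's internal ensemble:

```python
logits = clf.predict_logits(X_test, batch_size_predict=64)
raw_logits = clf.predict_raw_logits(X_test, batch_size_predict=64)

# Logits are concatenated along dim 0: (n_samples, n_classes).
# Raw logits keep the estimator axis and are concatenated along dim 1:
# (n_estimators, n_samples, n_classes).
print(logits.shape)      # e.g. (400, 2)
print(raw_logits.shape)  # e.g. (n_estimators, 400, 2)
```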

@track_model_call(model_method="predict", param_names=["X"])
def predict_proba(self, X: XType) -> np.ndarray:
def predict_proba(
self,
X: XType,
*,
batch_size_predict: int | None = None,
batch_predict_enable_test_interaction: bool = False,
) -> np.ndarray:
"""Predict the probabilities of the classes for the provided input samples.
This is a wrapper around the `_predict_proba` method.
Args:
X: The input data for prediction.
batch_size_predict: If set, predictions are batched into chunks
of this size to avoid OOM errors. If None, no batching is performed.
batch_predict_enable_test_interaction: If False (default), test
samples only attend to training samples during batched prediction,
ensuring predictions match the unbatched output. If True, test samples can
attend to each other within a batch, so predictions may vary
depending on batch size.
Returns:
The predicted probabilities of the classes as a NumPy array.
Shape (n_samples, n_classes).
"""
return self._predict_proba(X)
return self._predict_proba(
X,
batch_size_predict=batch_size_predict,
batch_predict_enable_test_interaction=batch_predict_enable_test_interaction,
)

@config_context(transform_output="default") # type: ignore
def _predict_proba(self, X: XType) -> np.ndarray:
def _predict_proba(
self,
X: XType,
*,
batch_size_predict: int | None = None,
batch_predict_enable_test_interaction: bool = False,
) -> np.ndarray:
"""Predict the probabilities of the classes for the provided input samples.
Args:
X: The input data for prediction.
batch_size_predict: If set, predictions are batched into chunks
of this size. If None, no batching is performed.
batch_predict_enable_test_interaction: If False (default), test
samples only attend to training samples during batched prediction,
ensuring predictions match the unbatched output. If True, test samples can
attend to each other within a batch, so predictions may vary
depending on batch size.
Returns:
The predicted probabilities of the classes as a NumPy array.
Shape (n_samples, n_classes).
"""
probas = (
self._raw_predict(X, return_logits=False).float().detach().cpu().numpy()
self._raw_predict(
X,
return_logits=False,
batch_size_predict=batch_size_predict,
batch_predict_enable_test_interaction=batch_predict_enable_test_interaction,
)
.float()
.detach()
.cpu()
.numpy()
)
probas = self._maybe_reweight_probas(probas=probas)
if self.inference_config_.USE_SKLEARN_16_DECIMAL_PRECISION:
12 changes: 3 additions & 9 deletions src/tabpfn/errors.py
@@ -63,19 +63,13 @@ def __init__(
n_test_samples: int | None = None,
model_type: str = "classifier",
):
predict_method = "predict_proba" if model_type == "classifier" else "predict"

size_info = f" with {n_test_samples:,} test samples" if n_test_samples else ""

model_class = f"TabPFN{model_type.title()}"
message = (
f"{self.device_name} out of memory{size_info}.\n\n"
f"Solution: Split your test data into smaller batches:\n\n"
f" batch_size = 1000 # depends on hardware\n"
f" predictions = []\n"
f" for i in range(0, len(X_test), batch_size):\n"
f" batch = model.{predict_method}(X_test[i:i + batch_size])\n"
f" predictions.append(batch)\n"
f" predictions = np.vstack(predictions)"
f"Solution: Set batch_size_predict when creating the model:\n\n"
f" model = {model_class}(batch_size_predict=1000)"
)
if original_error is not None:
message += f"\n\nOriginal error: {original_error}"