Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions changelog/745.changed.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
- Optimize regressor predict method for memory efficiency
- Average ensemble outputs on-the-fly instead of accumulating all outputs
- Reduces memory usage by avoiding storage of all intermediate outputs, especially beneficial for large `n_estimators`
118 changes: 64 additions & 54 deletions src/tabpfn/regressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import logging
import typing
import warnings
from collections.abc import Sequence
from collections.abc import Iterator, Sequence
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Annotated, Any, Literal, Union
Expand Down Expand Up @@ -849,7 +849,7 @@ def predict(

@config_context(transform_output="default") # type: ignore
@track_model_call(model_method="predict", param_names=["X"])
def predict(
def predict( # noqa: C901, PLR0912
self,
X: XType,
*,
Expand Down Expand Up @@ -913,29 +913,34 @@ def predict(
)

# Runs over iteration engine

n_estimators = 0
averaged_logits: torch.Tensor | None = None
with handle_oom_errors(self.devices_, X, model_type="regressor"):
(
_,
# list of tensors [N_est, N_samples, N_borders] (after forward)
outputs,
# list of numpy arrays containing borders for each estimator
borders,
) = self.forward(X, use_inference_mode=True)

# --- Translate probs, average, get final logits ---
transformed_logits = [
translate_probs_across_borders(
logits,
frm=torch.as_tensor(borders_t, device=logits.device),
to=self.znorm_space_bardist_.borders.to(logits.device),
)
for logits, borders_t in zip(outputs, borders)
]
stacked_logits = torch.stack(transformed_logits, dim=0)
for borders_t, output in self._iter_forward_executor(
X, use_inference_mode=True
):
# Transform probabilities across borders
transformed = translate_probs_across_borders(
output,
frm=torch.as_tensor(borders_t, device=output.device),
to=self.znorm_space_bardist_.borders.to(output.device),
)

if self.average_before_softmax:
transformed = transformed.log()

if averaged_logits is None:
averaged_logits = transformed
else:
averaged_logits = averaged_logits + transformed
n_estimators += 1

# Finalize averaging
if self.average_before_softmax:
logits = stacked_logits.log().mean(dim=0).softmax(dim=-1)
logits = (averaged_logits / n_estimators).softmax(dim=-1) # type: ignore
else:
logits = stacked_logits.mean(dim=0)
logits = averaged_logits / n_estimators # type: ignore
Comment on lines 940 to +943
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

If _iter_forward_executor yields no items (e.g., if n_estimators=0), averaged_logits will remain None. This will cause a TypeError when it is used in the division on lines 941 or 943. The existing `type: ignore` comments suppress the type-checker warnings that would otherwise flag this potential runtime error.

To make the method more robust, I suggest adding a check that averaged_logits is not None before proceeding with the calculation. This provides a clearer error message when no estimators were run and allows the `type: ignore` comments to be removed.

Suggested change
if self.average_before_softmax:
logits = stacked_logits.log().mean(dim=0).softmax(dim=-1)
logits = (averaged_logits / n_estimators).softmax(dim=-1) # type: ignore
else:
logits = stacked_logits.mean(dim=0)
logits = averaged_logits / n_estimators # type: ignore
if averaged_logits is None:
raise ValueError("Cannot make predictions, possibly due to `n_estimators=0`.")
elif self.average_before_softmax:
logits = (averaged_logits / n_estimators).softmax(dim=-1)
else:
logits = averaged_logits / n_estimators


# Post-process the logits
logits = logits.log()
Expand Down Expand Up @@ -982,31 +987,12 @@ def predict(

return logit_to_output(output_type=output_type)

def forward(
def _iter_forward_executor(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

There's a typo in the return type hint for _iter_forward_executor. np.ndaarray should be np.ndarray.

    ) -> Iterator[tuple[np.ndarray, torch.Tensor]]:

Copy link
Contributor Author

@poonai poonai Jan 22, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

resolved

self,
X: list[torch.Tensor] | XType,
*,
use_inference_mode: bool = False,
) -> tuple[torch.Tensor | None, list[torch.Tensor], list[np.ndarray]]:
"""Forward pass for TabPFNRegressor Inference Engine.
Used in fine-tuning and prediction. Called directly
in FineTuning training loop or by predict() function
with the use_inference_mode flag explicitly set to True.

Iterates over outputs of InferenceEngine.

Args:
X: list[torch.Tensor] in fine-tuning, XType in normal predictions.
use_inference_mode: Flag for inference mode., default at False since
it is called within predict. During FineTuning forward() is called
directly by user, so default should be False here.

Returns:
A tuple containing:
- Averaged logits over the ensemble (for fine-tuning).
- Raw outputs from each estimator in the ensemble.
- Borders used for each estimator.
"""
) -> Iterator[tuple[np.ndarray, torch.Tensor]]:
# Scenario 1: Standard inference path
is_standard_inference = use_inference_mode and not isinstance(
self.executor_, InferenceEngineBatchedNoPreprocessing
Expand Down Expand Up @@ -1036,18 +1022,12 @@ def forward(
"fine-tuning workflow (requires float32 for backpropagation)."
)

check_is_fitted(self)
# Ensure torch.inference_mode is OFF to allow gradients
if self.fit_mode in ["fit_preprocessors", "batched"]:
# only these two modes support this option
self.executor_.use_torch_inference_mode(use_inference=use_inference_mode)

check_is_fitted(self)

std_borders = self.znorm_space_bardist_.borders.cpu().numpy()
outputs: list[torch.Tensor] = []
borders: list[np.ndarray] = []

# Iterate over estimators
for output, config in self.executor_.iter_outputs(
X, autocast=self.use_autocast_
):
Expand Down Expand Up @@ -1091,19 +1071,49 @@ def forward(
if descending_borders:
borders_t = borders_t.flip(-1) # type: ignore

borders.append(borders_t)

if logit_cancel_mask is not None:
output = output.clone() # noqa: PLW2901
output[..., logit_cancel_mask] = float("-inf")

yield borders_t, output
else:
raise ValueError(
"Unexpected config format "
"and Batch prediction is not supported yet!"
)

outputs.append(output) # type: ignore
def forward(
self,
X: list[torch.Tensor] | XType,
*,
use_inference_mode: bool = False,
) -> tuple[torch.Tensor | None, list[torch.Tensor], list[np.ndarray]]:
"""Forward pass for TabPFNRegressor Inference Engine.
Used in fine-tuning and prediction. Called directly
in FineTuning training loop or by predict() function
with the use_inference_mode flag explicitly set to True.

Iterates over outputs of InferenceEngine.

Args:
X: list[torch.Tensor] in fine-tuning, XType in normal predictions.
use_inference_mode: Flag for inference mode., default at False since
it is called within predict. During FineTuning forward() is called
directly by user, so default should be False here.

Returns:
A tuple containing:
- Averaged logits over the ensemble (for fine-tuning).
- Raw outputs from each estimator in the ensemble.
- Borders used for each estimator.
"""
outputs: list[torch.Tensor] = []
borders: list[np.ndarray] = []

for border, output in self._iter_forward_executor(
X, use_inference_mode=use_inference_mode
):
borders.append(border)
outputs.append(output)

averaged_logits = None
all_logits = None
Expand Down