diff --git a/changelog/745.changed.md b/changelog/745.changed.md new file mode 100644 index 000000000..f48b0f7c0 --- /dev/null +++ b/changelog/745.changed.md @@ -0,0 +1,3 @@ +- Optimize regressor predict method for memory efficiency + - Average ensemble outputs on-the-fly instead of accumulating all outputs + - Reduces memory usage by avoiding storage of all intermediate outputs, especially beneficial for large `n_estimators` diff --git a/src/tabpfn/regressor.py b/src/tabpfn/regressor.py index 29ade4071..779e1fd34 100644 --- a/src/tabpfn/regressor.py +++ b/src/tabpfn/regressor.py @@ -20,7 +20,7 @@ import logging import typing import warnings -from collections.abc import Sequence +from collections.abc import Iterator, Sequence from functools import partial from pathlib import Path from typing import TYPE_CHECKING, Annotated, Any, Literal, Union @@ -849,7 +849,7 @@ def predict( @config_context(transform_output="default") # type: ignore @track_model_call(model_method="predict", param_names=["X"]) - def predict( + def predict( # noqa: C901, PLR0912 self, X: XType, *, @@ -913,29 +913,34 @@ def predict( ) # Runs over iteration engine + + n_estimators = 0 + averaged_logits: torch.Tensor | None = None with handle_oom_errors(self.devices_, X, model_type="regressor"): - ( - _, - # list of tensors [N_est, N_samples, N_borders] (after forward) - outputs, - # list of numpy arrays containing borders for each estimator - borders, - ) = self.forward(X, use_inference_mode=True) - - # --- Translate probs, average, get final logits --- - transformed_logits = [ - translate_probs_across_borders( - logits, - frm=torch.as_tensor(borders_t, device=logits.device), - to=self.znorm_space_bardist_.borders.to(logits.device), - ) - for logits, borders_t in zip(outputs, borders) - ] - stacked_logits = torch.stack(transformed_logits, dim=0) + for borders_t, output in self._iter_forward_executor( + X, use_inference_mode=True + ): + # Transform probabilities across borders + transformed = 
translate_probs_across_borders( + output, + frm=torch.as_tensor(borders_t, device=output.device), + to=self.znorm_space_bardist_.borders.to(output.device), + ) + + if self.average_before_softmax: + transformed = transformed.log() + + if averaged_logits is None: + averaged_logits = transformed + else: + averaged_logits = averaged_logits + transformed + n_estimators += 1 + + # Finalize averaging if self.average_before_softmax: - logits = stacked_logits.log().mean(dim=0).softmax(dim=-1) + logits = (averaged_logits / n_estimators).softmax(dim=-1) # type: ignore else: - logits = stacked_logits.mean(dim=0) + logits = averaged_logits / n_estimators # type: ignore # Post-process the logits logits = logits.log() @@ -982,31 +987,12 @@ def predict( return logit_to_output(output_type=output_type) - def forward( + def _iter_forward_executor( self, X: list[torch.Tensor] | XType, *, use_inference_mode: bool = False, - ) -> tuple[torch.Tensor | None, list[torch.Tensor], list[np.ndarray]]: - """Forward pass for TabPFNRegressor Inference Engine. - Used in fine-tuning and prediction. Called directly - in FineTuning training loop or by predict() function - with the use_inference_mode flag explicitly set to True. - - Iterates over outputs of InferenceEngine. - - Args: - X: list[torch.Tensor] in fine-tuning, XType in normal predictions. - use_inference_mode: Flag for inference mode., default at False since - it is called within predict. During FineTuning forward() is called - directly by user, so default should be False here. - - Returns: - A tuple containing: - - Averaged logits over the ensemble (for fine-tuning). - - Raw outputs from each estimator in the ensemble. - - Borders used for each estimator. 
- """ + ) -> Iterator[tuple[np.ndarray, torch.Tensor]]: # Scenario 1: Standard inference path is_standard_inference = use_inference_mode and not isinstance( self.executor_, InferenceEngineBatchedNoPreprocessing @@ -1036,18 +1022,12 @@ def forward( "fine-tuning workflow (requires float32 for backpropagation)." ) + check_is_fitted(self) # Ensure torch.inference_mode is OFF to allow gradients if self.fit_mode in ["fit_preprocessors", "batched"]: # only these two modes support this option self.executor_.use_torch_inference_mode(use_inference=use_inference_mode) - - check_is_fitted(self) - std_borders = self.znorm_space_bardist_.borders.cpu().numpy() - outputs: list[torch.Tensor] = [] - borders: list[np.ndarray] = [] - - # Iterate over estimators for output, config in self.executor_.iter_outputs( X, autocast=self.use_autocast_ ): @@ -1091,19 +1071,49 @@ def forward( if descending_borders: borders_t = borders_t.flip(-1) # type: ignore - borders.append(borders_t) - if logit_cancel_mask is not None: output = output.clone() # noqa: PLW2901 output[..., logit_cancel_mask] = float("-inf") - + yield borders_t, output else: raise ValueError( "Unexpected config format " "and Batch prediction is not supported yet!" ) - outputs.append(output) # type: ignore + def forward( + self, + X: list[torch.Tensor] | XType, + *, + use_inference_mode: bool = False, + ) -> tuple[torch.Tensor | None, list[torch.Tensor], list[np.ndarray]]: + """Forward pass for TabPFNRegressor Inference Engine. + Used in fine-tuning and prediction. Called directly + in FineTuning training loop or by predict() function + with the use_inference_mode flag explicitly set to True. + + Iterates over outputs of InferenceEngine. + + Args: + X: list[torch.Tensor] in fine-tuning, XType in normal predictions. + use_inference_mode: Flag for inference mode. Defaults to False: + predict() sets it to True explicitly, whereas fine-tuning calls + forward() directly and relies on the False default. 
+ + Returns: + A tuple containing: + - Averaged logits over the ensemble (for fine-tuning). + - Raw outputs from each estimator in the ensemble. + - Borders used for each estimator. + """ + outputs: list[torch.Tensor] = [] + borders: list[np.ndarray] = [] + + for border, output in self._iter_forward_executor( + X, use_inference_mode=use_inference_mode + ): + borders.append(border) + outputs.append(output) averaged_logits = None all_logits = None