improved regressor memory usage by 60% #745
Open: poonai wants to merge 5 commits into PriorLabs:main from poonai:poonai/optimize_regressor (base: main).
Changes from all 5 commits:
```diff
@@ -0,0 +1,3 @@
+- Optimize regressor predict method for memory efficiency
+- Average ensemble outputs on-the-fly instead of accumulating all outputs
+- Reduces memory usage by avoiding storage of all intermediate outputs, especially beneficial for large `n_estimators`
```
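The changelog describes the core idea: fold each ensemble output into a running sum as it is produced, rather than materializing all outputs and averaging at the end. A minimal sketch of the two memory patterns, in plain Python with hypothetical names (the real code operates on PyTorch tensors):

```python
def average_accumulate_all(outputs):
    # Old pattern: materialize every estimator output at once,
    # peak memory grows linearly with the number of estimators.
    collected = [list(o) for o in outputs]
    n = len(collected)
    return [sum(col) / n for col in zip(*collected)]


def average_streaming(outputs):
    # New pattern: keep only one running sum; each output can be
    # freed as soon as it has been folded in.
    running = None
    n_estimators = 0
    for out in outputs:
        if running is None:
            running = list(out)
        else:
            running = [r + o for r, o in zip(running, out)]
        n_estimators += 1
    return [r / n_estimators for r in running]
```

Both produce the same average; only the peak memory differs, which is why the win is largest for large `n_estimators`.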
```diff
@@ -20,7 +20,7 @@
 import logging
 import typing
 import warnings
-from collections.abc import Sequence
+from collections.abc import Iterator, Sequence
 from functools import partial
 from pathlib import Path
 from typing import TYPE_CHECKING, Annotated, Any, Literal, Union
@@ -849,7 +849,7 @@ def predict(
     @config_context(transform_output="default")  # type: ignore
     @track_model_call(model_method="predict", param_names=["X"])
-    def predict(
+    def predict(  # noqa: C901, PLR0912
         self,
         X: XType,
         *,
@@ -913,29 +913,34 @@ def predict(
             )

         # Runs over iteration engine
+        n_estimators = 0
+        averaged_logits: torch.Tensor | None = None
         with handle_oom_errors(self.devices_, X, model_type="regressor"):
-            (
-                _,
-                # list of tensors [N_est, N_samples, N_borders] (after forward)
-                outputs,
-                # list of numpy arrays containing borders for each estimator
-                borders,
-            ) = self.forward(X, use_inference_mode=True)
-
-            # --- Translate probs, average, get final logits ---
-            transformed_logits = [
-                translate_probs_across_borders(
-                    logits,
-                    frm=torch.as_tensor(borders_t, device=logits.device),
-                    to=self.znorm_space_bardist_.borders.to(logits.device),
-                )
-                for logits, borders_t in zip(outputs, borders)
-            ]
-            stacked_logits = torch.stack(transformed_logits, dim=0)
+            for borders_t, output in self._iter_forward_executor(
+                X, use_inference_mode=True
+            ):
+                # Transform probabilities across borders
+                transformed = translate_probs_across_borders(
+                    output,
+                    frm=torch.as_tensor(borders_t, device=output.device),
+                    to=self.znorm_space_bardist_.borders.to(output.device),
+                )
+
+                if self.average_before_softmax:
+                    transformed = transformed.log()
+
+                if averaged_logits is None:
+                    averaged_logits = transformed
+                else:
+                    averaged_logits = averaged_logits + transformed
+                n_estimators += 1

         # Finalize averaging
         if self.average_before_softmax:
-            logits = stacked_logits.log().mean(dim=0).softmax(dim=-1)
+            logits = (averaged_logits / n_estimators).softmax(dim=-1)  # type: ignore
         else:
-            logits = stacked_logits.mean(dim=0)
+            logits = averaged_logits / n_estimators  # type: ignore

         # Post-process the logits
         logits = logits.log()
@@ -982,31 +987,12 @@ def predict(
         return logit_to_output(output_type=output_type)

-    def forward(
+    def _iter_forward_executor(
         self,
         X: list[torch.Tensor] | XType,
         *,
         use_inference_mode: bool = False,
-    ) -> tuple[torch.Tensor | None, list[torch.Tensor], list[np.ndarray]]:
-        """Forward pass for TabPFNRegressor Inference Engine.
-        Used in fine-tuning and prediction. Called directly
-        in FineTuning training loop or by predict() function
-        with the use_inference_mode flag explicitly set to True.
-
-        Iterates over outputs of InferenceEngine.
-
-        Args:
-            X: list[torch.Tensor] in fine-tuning, XType in normal predictions.
-            use_inference_mode: Flag for inference mode., default at False since
-                it is called within predict. During FineTuning forward() is called
-                directly by user, so default should be False here.
-
-        Returns:
-            A tuple containing:
-                - Averaged logits over the ensemble (for fine-tuning).
-                - Raw outputs from each estimator in the ensemble.
-                - Borders used for each estimator.
-        """
+    ) -> Iterator[tuple[np.ndarray, torch.Tensor]]:
         # Scenario 1: Standard inference path
         is_standard_inference = use_inference_mode and not isinstance(
             self.executor_, InferenceEngineBatchedNoPreprocessing
@@ -1036,18 +1022,12 @@ def forward(
                 "fine-tuning workflow (requires float32 for backpropagation)."
             )

+        check_is_fitted(self)
         # Ensure torch.inference_mode is OFF to allow gradients
         if self.fit_mode in ["fit_preprocessors", "batched"]:
             # only these two modes support this option
             self.executor_.use_torch_inference_mode(use_inference=use_inference_mode)

-        check_is_fitted(self)
-
         std_borders = self.znorm_space_bardist_.borders.cpu().numpy()
-        outputs: list[torch.Tensor] = []
-        borders: list[np.ndarray] = []

         # Iterate over estimators
         for output, config in self.executor_.iter_outputs(
             X, autocast=self.use_autocast_
         ):
@@ -1091,19 +1071,49 @@ def forward(
                 if descending_borders:
                     borders_t = borders_t.flip(-1)  # type: ignore

-                borders.append(borders_t)
-
                 if logit_cancel_mask is not None:
                     output = output.clone()  # noqa: PLW2901
                     output[..., logit_cancel_mask] = float("-inf")
+
+                yield borders_t, output
             else:
                 raise ValueError(
                     "Unexpected config format "
                     "and Batch prediction is not supported yet!"
                 )

-            outputs.append(output)  # type: ignore
+    def forward(
+        self,
+        X: list[torch.Tensor] | XType,
+        *,
+        use_inference_mode: bool = False,
+    ) -> tuple[torch.Tensor | None, list[torch.Tensor], list[np.ndarray]]:
+        """Forward pass for TabPFNRegressor Inference Engine.
+        Used in fine-tuning and prediction. Called directly
+        in FineTuning training loop or by predict() function
+        with the use_inference_mode flag explicitly set to True.
+
+        Iterates over outputs of InferenceEngine.
+
+        Args:
+            X: list[torch.Tensor] in fine-tuning, XType in normal predictions.
+            use_inference_mode: Flag for inference mode., default at False since
+                it is called within predict. During FineTuning forward() is called
+                directly by user, so default should be False here.
+
+        Returns:
+            A tuple containing:
+                - Averaged logits over the ensemble (for fine-tuning).
+                - Raw outputs from each estimator in the ensemble.
+                - Borders used for each estimator.
+        """
+        outputs: list[torch.Tensor] = []
+        borders: list[np.ndarray] = []
+
+        for border, output in self._iter_forward_executor(
+            X, use_inference_mode=use_inference_mode
+        ):
+            borders.append(border)
+            outputs.append(output)

         averaged_logits = None
         all_logits = None
```
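The refactor's shape is worth noting: the per-estimator loop is extracted into a single generator, which `predict` consumes in a streaming fashion while `forward` collects everything for fine-tuning. A hypothetical minimal sketch of that pattern (names are illustrative, not the real API):

```python
from collections.abc import Iterator


def iter_outputs(n: int) -> Iterator[tuple[int, float]]:
    # Stands in for _iter_forward_executor: yields one
    # (borders, output) pair per estimator.
    for i in range(n):
        yield i, float(i * 10)


def predict_streaming(n: int) -> float:
    # predict-style consumer: running sum, constant memory.
    total, count = 0.0, 0
    for _borders, out in iter_outputs(n):
        total += out
        count += 1
    return total / count


def collect_all(n: int) -> tuple[list[int], list[float]]:
    # forward-style consumer: keeps every output, as fine-tuning
    # needs the raw per-estimator tensors.
    borders, outputs = [], []
    for b, o in iter_outputs(n):
        borders.append(b)
        outputs.append(o)
    return borders, outputs
```

Sharing one generator keeps the estimator-iteration logic in a single place, so the memory optimization in `predict` cannot drift out of sync with the fine-tuning path.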
Reviewer comment: If `_iter_forward_executor` yields no items (e.g., if `n_estimators=0`), `averaged_logits` will remain `None`. This will cause a `TypeError` when it's used in the division on lines 941 or 943. The existing `# type: ignore` comments suppress this potential runtime error.

To make the method more robust, I suggest adding a check to ensure `averaged_logits` is not `None` before proceeding with the calculation. This will provide a clearer error message if no estimators were run, and allows the `# type: ignore` comments to be removed.
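A minimal sketch of the guard the reviewer suggests, extracted into a standalone function with illustrative names (the real code would inline this in `predict`):

```python
def finalize_average(averaged_logits, n_estimators):
    # Fail with a clear message instead of letting `None / int`
    # raise a bare TypeError when no estimators produced output.
    if averaged_logits is None or n_estimators == 0:
        raise RuntimeError(
            "No estimator outputs were produced; cannot average logits."
        )
    return averaged_logits / n_estimators
```

With this check in place, the type checker can see that `averaged_logits` is non-`None` at the division, so the suppression comments become unnecessary.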