Automatic Batching option in predict() functions #775
klemens-floege wants to merge 6 commits into main from
Conversation
Code Review
This pull request introduces a batch_size_predict parameter to both TabPFNClassifier and TabPFNRegressor, allowing for batched predictions on large test sets to prevent out-of-memory errors. The implementation is solid, with new methods for batched prediction and updated error messages to guide users. The accompanying tests are comprehensive and cover the new functionality well.
I have one suggestion for a minor refactoring in src/tabpfn/regressor.py to reduce code duplication, which will improve maintainability.
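For context, here is a minimal usage sketch of the new parameter. Only `batch_size_predict` is taken from this PR; the surrounding data setup and the scikit-learn-style `fit`/`predict`/`predict_proba` calls are illustrative assumptions about the existing TabPFN interface, not part of the change itself.

```python
import numpy as np
from tabpfn import TabPFNClassifier, TabPFNRegressor

rng = np.random.default_rng(0)
X_train, X_test = rng.normal(size=(500, 10)), rng.normal(size=(100_000, 10))
y_class, y_reg = rng.integers(0, 2, size=500), rng.normal(size=500)

# Cap each forward pass at 2048 test rows; per-batch results are concatenated,
# so the returned arrays have the same shape as without batching.
clf = TabPFNClassifier(batch_size_predict=2048)
clf.fit(X_train, y_class)
proba = clf.predict_proba(X_test)

reg = TabPFNRegressor(batch_size_predict=2048)
reg.fit(X_train, y_reg)
y_mean = reg.predict(X_test)
```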
src/tabpfn/regressor.py (999-1070)
There's some code duplication in _batched_predict. The loop for iterating over batches and calling self.forward is repeated for the quantiles case and the mean/median/mode case.
To improve maintainability and adhere to the DRY (Don't Repeat Yourself) principle, you could refactor this by extracting the batch iteration and logit processing into a local helper generator function. This would remove the duplication while keeping the logic encapsulated within _batched_predict.
def _batched_predict(
    self,
    X: XType,
    *,
    output_type: OutputType,
    quantiles: list[float],
) -> RegressionResultType:
    """Run batched prediction to avoid OOM on large test sets.

    Args:
        X: The input data for prediction (already preprocessed).
        output_type: The type of output to return.
        quantiles: The quantiles to compute if output_type is "quantiles".

    Returns:
        The concatenated predictions from all batches.

    Raises:
        TabPFNValidationError: If output_type is "full" or "main" with batching,
            as these return complex structures that can't be easily batched.
    """
    if output_type in ["full", "main"]:
        raise TabPFNValidationError(
            f"output_type='{output_type}' is not supported with "
            f"batch_size_predict. Use 'mean', 'median', 'mode', or 'quantiles' "
            f"instead, or set batch_size_predict=None."
        )

    n_samples = X.shape[0] if hasattr(X, "shape") else len(X)

    def _iter_logits():
        """Yields logits for each batch of X."""
        for start in range(0, n_samples, self.batch_size_predict):
            end = min(start + self.batch_size_predict, n_samples)
            X_batch = X[start:end]
            with handle_oom_errors(
                self.devices_, X_batch, model_type="regressor"
            ):
                (_, outputs, borders) = self.forward(
                    X_batch, use_inference_mode=True
                )
            yield self._process_forward_outputs(outputs, borders)

    if output_type == "quantiles":
        # For quantiles, we need to collect results per quantile
        all_quantile_results: list[list[np.ndarray]] = [[] for _ in quantiles]
        for logits in _iter_logits():
            for i, q in enumerate(quantiles):
                q_result = (
                    self.raw_space_bardist_.icdf(logits, q).cpu().detach().numpy()
                )
                all_quantile_results[i].append(q_result)
        return [np.concatenate(q_results) for q_results in all_quantile_results]

    # For mean, median, mode
    results = []
    for logits in _iter_logits():
        batch_result = _logits_to_output(
            output_type=output_type,
            logits=logits,
            criterion=self.raw_space_bardist_,
            quantiles=quantiles,
        )
        results.append(batch_result)
    return np.concatenate(results)
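For completeness, a hedged caller-side sketch of the batched quantile path, continuing the example above and assuming `predict` dispatches to `_batched_predict` whenever `batch_size_predict` is set. The generator-based refactor keeps only one batch of logits in memory at a time, which is the main benefit over duplicating the loop.

```python
# Hypothetical caller-side view (reuses X_train, y_reg, X_test from the sketch above).
reg = TabPFNRegressor(batch_size_predict=1024)
reg.fit(X_train, y_reg)

# output_type="quantiles" returns one array per requested quantile, each of
# length len(X_test), concatenated across batches inside _batched_predict.
q10, q50, q90 = reg.predict(
    X_test, output_type="quantiles", quantiles=[0.1, 0.5, 0.9]
)

# output_type="full" or "main" raises TabPFNValidationError once batching is enabled.
```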
Thanks a lot for looking into this; however, the code looks a bit vibe-coded?
- set_multiquery_item_attention seems like a large, complex piece of code that doesn't make sense for our transformer (it's already independent)
- there seem to be a few of these large code changes
Can you go through the code and make sure it's actually all correct please? Or in case I'm missing something let me know and I'll take a deeper look
Ohh also just seeing it doesn't pass tests - I assume this one isn't ready then. Please add "WIP" -> "WIP: Automatic Batching option in predict() functions" to the PR so it doesn't get reviewed.