254 changes: 157 additions & 97 deletions src/tabpfn/inference.py
@@ -4,11 +4,11 @@

from __future__ import annotations

import contextlib
from abc import ABC, abstractmethod
from collections.abc import Iterator, Sequence
from copy import deepcopy
from dataclasses import dataclass
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Literal
from typing_extensions import override
@@ -18,6 +18,7 @@
import torch

from tabpfn.architectures.base.memory import MemoryUsageEstimator
from tabpfn.parallel_execute import parallel_execute
from tabpfn.preprocessing import fit_preprocessing
from tabpfn.utils import get_autocast_context

@@ -198,9 +199,6 @@ def iter_outputs(
autocast: bool,
only_return_standard_out: bool = True,
) -> Iterator[tuple[torch.Tensor | dict, EnsembleConfig]]:
# This engine currently only supports one device, so just take the first.
device = devices[0]

rng = np.random.default_rng(self.static_seed)
itr = fit_preprocessing(
configs=self.ensemble_configs,
@@ -212,50 +210,72 @@
parallel_mode="as-ready",
)

self.model = self.model.to(device)
if self.force_inference_dtype is not None:
self.model = self.model.type(self.force_inference_dtype)

for config, preprocessor, X_train, y_train, cat_ix in itr:
X_train = torch.as_tensor(X_train, dtype=torch.float32, device=device) # noqa: PLW2901

X_test = preprocessor.transform(X).X
X_test = torch.as_tensor(X_test, dtype=torch.float32, device=device)
self.model.type(self.force_inference_dtype)

model_forward_functions = (
partial(
self._call_model,
X_train=X_train,
X_test=preprocessor.transform(X).X,
y_train=y_train,
cat_ix=cat_ix,
only_return_standard_out=only_return_standard_out,
autocast=autocast,
)
for _, preprocessor, X_train, y_train, cat_ix in itr
)
outputs = parallel_execute(devices, model_forward_functions)

X_full = torch.cat([X_train, X_test], dim=0).unsqueeze(1)
batched_cat_ix = [cat_ix]
y_train = torch.as_tensor(y_train, dtype=torch.float32, device=device) # type: ignore # noqa: PLW2901
for config, output in zip(self.ensemble_configs, outputs):
yield _move_and_squeeze_output(output, devices[0]), config

MemoryUsageEstimator.reset_peak_memory_if_required(
save_peak_mem=self.save_peak_mem,
model=self.model,
X=X_full,
cache_kv=False,
dtype_byte_size=self.dtype_byte_size,
device=device,
safety_factor=1.2, # TODO(Arjun): make customizable
)
self.model.cpu()

if self.force_inference_dtype is not None:
X_full = X_full.type(self.force_inference_dtype)
y_train = y_train.type(self.force_inference_dtype) # type: ignore # noqa: PLW2901
def _call_model(
self,
*,
device: torch.device,
is_parallel: bool,
X_train: torch.Tensor | np.ndarray,
X_test: torch.Tensor | np.ndarray,
y_train: torch.Tensor | np.ndarray,
cat_ix: list[int],
autocast: bool,
only_return_standard_out: bool,
) -> torch.Tensor | dict[str, torch.Tensor]:
"""Execute a model forward pass on the provided device.

with (
get_autocast_context(device, enabled=autocast),
torch.inference_mode(),
):
output = self.model(
X_full,
y_train,
only_return_standard_out=only_return_standard_out,
categorical_inds=batched_cat_ix,
)
Note that several instances of this function may be executed in parallel in
different threads, one for each device in the system.
"""
# If several estimators are being run in parallel, then each thread needs its
# own copy of the model so it can move it to its device.
model = deepcopy(self.model) if is_parallel else self.model
model.to(device)

output = output if isinstance(output, dict) else output.squeeze(1)
X_full, y_train = _prepare_model_inputs(
device, self.force_inference_dtype, X_train, X_test, y_train
)
batched_cat_ix = [cat_ix]

yield output, config
MemoryUsageEstimator.reset_peak_memory_if_required(
save_peak_mem=self.save_peak_mem,
model=model,
X=X_full,
cache_kv=False,
dtype_byte_size=self.dtype_byte_size,
device=device,
safety_factor=1.2,
)

self.model = self.model.cpu()
with get_autocast_context(device, enabled=autocast), torch.inference_mode():
return model(
X_full,
y_train,
only_return_standard_out=only_return_standard_out,
categorical_inds=batched_cat_ix,
)
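
The call above assumes that parallel_execute fans the prepared forward functions out across the available devices, injects the device and is_parallel keyword arguments that _call_model expects, and returns the results in submission order (the zip with self.ensemble_configs relies on that ordering). A minimal sketch of that assumed contract, for reference only; the real implementation lives in tabpfn/parallel_execute.py and may differ:

from concurrent.futures import ThreadPoolExecutor

def parallel_execute_sketch(devices, functions):
    """Hypothetical stand-in for tabpfn.parallel_execute.parallel_execute.

    Runs at most one function at a time per device and preserves submission
    order in the returned list.
    """
    devices = list(devices)
    is_parallel = len(devices) > 1
    # One single-threaded executor per device; functions are assigned round-robin.
    pools = [ThreadPoolExecutor(max_workers=1) for _ in devices]
    futures = [
        pools[i % len(devices)].submit(
            fn, device=devices[i % len(devices)], is_parallel=is_parallel
        )
        for i, fn in enumerate(functions)
    ]
    results = [future.result() for future in futures]
    for pool in pools:
        pool.shutdown()
    return results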


@dataclass
@@ -458,69 +478,86 @@ def iter_outputs(
autocast: bool,
only_return_standard_out: bool = True,
) -> Iterator[tuple[torch.Tensor | dict, EnsembleConfig]]:
# This engine currently only supports one device, so just take the first.
device = devices[0]

self.model = self.model.to(device)
if self.force_inference_dtype is not None:
self.model = self.model.type(self.force_inference_dtype)
for preprocessor, X_train, y_train, config, cat_ix in zip(
self.preprocessors,
self.X_trains,
self.y_trains,
self.ensemble_configs,
self.cat_ixs,
):
if not isinstance(X_train, torch.Tensor):
X_train = torch.as_tensor(X_train, dtype=torch.float32) # noqa: PLW2901
X_train = X_train.to(device) # noqa: PLW2901
X_test = preprocessor.transform(X).X if not self.no_preprocessing else X
if not isinstance(X_test, torch.Tensor):
X_test = torch.as_tensor(X_test, dtype=torch.float32)
X_test = X_test.to(device)
X_full = torch.cat([X_train, X_test], dim=0).unsqueeze(1)
if not isinstance(y_train, torch.Tensor):
y_train = torch.as_tensor(y_train, dtype=torch.float32) # noqa: PLW2901
y_train = y_train.to(device) # noqa: PLW2901
self.model.type(self.force_inference_dtype)

batched_cat_ix = [cat_ix]
if self.no_preprocessing:
X_tests = (X for _ in range(len(self.ensemble_configs)))
else:
X_tests = (
preprocessor.transform(X).X for preprocessor in self.preprocessors
)

# Handle type casting
with contextlib.suppress(Exception): # Avoid overflow error
X_full = X_full.float()
if self.force_inference_dtype is not None:
X_full = X_full.type(self.force_inference_dtype)
y_train = y_train.type(self.force_inference_dtype) # type: ignore # noqa: PLW2901

if self.inference_mode:
MemoryUsageEstimator.reset_peak_memory_if_required(
save_peak_mem=self.save_peak_mem,
model=self.model,
X=X_full,
cache_kv=False,
device=device,
dtype_byte_size=self.dtype_byte_size,
safety_factor=1.2, # TODO(Arjun): make customizable
)
else:
pass
model_forward_functions = (
partial(
self._call_model,
X_train=X_train,
X_test=X_test,
y_train=y_train,
cat_ix=cat_ix,
autocast=autocast,
only_return_standard_out=only_return_standard_out,
)
for X_train, X_test, y_train, cat_ix in zip(
self.X_trains, X_tests, self.y_trains, self.cat_ixs
)
)
outputs = parallel_execute(devices, model_forward_functions)

with (
get_autocast_context(device, enabled=autocast),
torch.inference_mode(self.inference_mode),
):
output = self.model(
X_full,
y_train,
only_return_standard_out=only_return_standard_out,
categorical_inds=batched_cat_ix,
)
for output, config in zip(outputs, self.ensemble_configs):
yield _move_and_squeeze_output(output, devices[0]), config

output = output if isinstance(output, dict) else output.squeeze(1)
if self.inference_mode:
self.model.cpu()

yield output, config
if self.inference_mode: ## if inference
self.model = self.model.cpu()
def _call_model(
self,
*,
device: torch.device,
is_parallel: bool,
X_train: torch.Tensor | np.ndarray,
X_test: torch.Tensor | np.ndarray,
y_train: torch.Tensor | np.ndarray,
cat_ix: list[int],
autocast: bool,
only_return_standard_out: bool,
) -> torch.Tensor | dict[str, torch.Tensor]:
"""Execute a model forward pass on the provided device.

Note that several instances of this function may be executed in parallel in
different threads, one for each device in the system.
"""
# If several estimators are being run in parallel, then each thread needs its
# own copy of the model so it can move it to its device.
model = deepcopy(self.model) if is_parallel else self.model
model.to(device)

X_full, y_train = _prepare_model_inputs(
device, self.force_inference_dtype, X_train, X_test, y_train
)
batched_cat_ix = [cat_ix]

if self.inference_mode:
MemoryUsageEstimator.reset_peak_memory_if_required(
save_peak_mem=self.save_peak_mem,
model=model,
X=X_full,
cache_kv=False,
device=device,
dtype_byte_size=self.dtype_byte_size,
safety_factor=1.2,
)

with (
get_autocast_context(device, enabled=autocast),
torch.inference_mode(self.inference_mode),
):
return model(
X_full,
y_train,
only_return_standard_out=only_return_standard_out,
categorical_inds=batched_cat_ix,
)
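
Note that torch.inference_mode accepts a boolean, so the context manager above acts as a real inference-mode guard when self.inference_mode is True and behaves like a no-op otherwise. A small illustration of that PyTorch behaviour:

import torch

x = torch.ones(3, requires_grad=True)
with torch.inference_mode(True):
    y = x * 2
    assert not y.requires_grad  # tensors created here never track gradients
with torch.inference_mode(False):
    z = x * 2
    assert z.requires_grad  # autograd behaves normally when the mode is off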

@override
def use_torch_inference_mode(self, *, use_inference: bool) -> None:
@@ -701,3 +738,26 @@ def iter_outputs(
output = output if isinstance(output, dict) else output.squeeze(1)

yield output, config


def _prepare_model_inputs(
device: torch.device,
force_inference_dtype: torch.dtype | None,
X_train: torch.Tensor | np.ndarray,
X_test: torch.Tensor | np.ndarray,
y_train: torch.Tensor | np.ndarray,
) -> tuple[torch.Tensor, torch.Tensor]:
dtype = force_inference_dtype if force_inference_dtype else torch.float32
X_train = torch.as_tensor(X_train, dtype=dtype, device=device)
X_test = torch.as_tensor(X_test, dtype=dtype, device=device)
X_full = torch.cat([X_train, X_test], dim=0).unsqueeze(1)
y_train = torch.as_tensor(y_train, dtype=dtype, device=device)
return X_full, y_train
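
A brief, hypothetical illustration of what this helper produces: the train and test rows are concatenated along dim 0 and a singleton dimension is inserted at position 1, so with 100 training rows, 20 test rows and 5 features the model input has shape (120, 1, 5):

import numpy as np
import torch

X_tr = np.random.rand(100, 5).astype(np.float32)
X_te = np.random.rand(20, 5).astype(np.float32)
y_tr = np.random.randint(0, 2, size=100)

X_full, y_train = _prepare_model_inputs(
    device=torch.device("cpu"),
    force_inference_dtype=None,
    X_train=X_tr,
    X_test=X_te,
    y_train=y_tr,
)
assert X_full.shape == (120, 1, 5)
assert y_train.dtype == torch.float32  # falls back to float32 when no dtype is forced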


def _move_and_squeeze_output(
output: dict | torch.Tensor, device: torch.device
) -> dict[str, torch.Tensor] | torch.Tensor:
if isinstance(output, dict):
return {k: v.to(device) for k, v in output.items()}
return output.squeeze(1).to(device)
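
This helper mirrors the per-iteration post-processing it replaces (output if isinstance(output, dict) else output.squeeze(1)) while also moving results from whichever device ran the forward pass back to devices[0]. An illustration of the two branches, with example shapes:

import torch

cpu = torch.device("cpu")

tensor_out = torch.zeros(120, 1, 10)
assert _move_and_squeeze_output(tensor_out, cpu).shape == (120, 10)

dict_out = {"logits": torch.zeros(120, 1, 10)}
moved = _move_and_squeeze_output(dict_out, cpu)
assert moved["logits"].shape == (120, 1, 10)  # dict values are moved, not squeezed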