diff --git a/src/tabpfn/inference.py b/src/tabpfn/inference.py
index 2b3e7a5be..04f5c7308 100644
--- a/src/tabpfn/inference.py
+++ b/src/tabpfn/inference.py
@@ -4,11 +4,11 @@
 from __future__ import annotations
 
-import contextlib
 from abc import ABC, abstractmethod
 from collections.abc import Iterator, Sequence
 from copy import deepcopy
 from dataclasses import dataclass
+from functools import partial
 from pathlib import Path
 from typing import TYPE_CHECKING, Literal
 from typing_extensions import override
 
@@ -18,6 +18,7 @@
 import torch
 
 from tabpfn.architectures.base.memory import MemoryUsageEstimator
+from tabpfn.parallel_execute import parallel_execute
 from tabpfn.preprocessing import fit_preprocessing
 from tabpfn.utils import get_autocast_context
 
@@ -198,9 +199,6 @@ def iter_outputs(
         autocast: bool,
         only_return_standard_out: bool = True,
     ) -> Iterator[tuple[torch.Tensor | dict, EnsembleConfig]]:
-        # This engine currently only supports one device, so just take the first.
-        device = devices[0]
-
         rng = np.random.default_rng(self.static_seed)
         itr = fit_preprocessing(
             configs=self.ensemble_configs,
@@ -212,50 +210,72 @@ def iter_outputs(
             parallel_mode="as-ready",
         )
 
-        self.model = self.model.to(device)
         if self.force_inference_dtype is not None:
-            self.model = self.model.type(self.force_inference_dtype)
-
-        for config, preprocessor, X_train, y_train, cat_ix in itr:
-            X_train = torch.as_tensor(X_train, dtype=torch.float32, device=device)  # noqa: PLW2901
-
-            X_test = preprocessor.transform(X).X
-            X_test = torch.as_tensor(X_test, dtype=torch.float32, device=device)
+            self.model.type(self.force_inference_dtype)
+
+        model_forward_functions = (
+            partial(
+                self._call_model,
+                X_train=X_train,
+                X_test=preprocessor.transform(X).X,
+                y_train=y_train,
+                cat_ix=cat_ix,
+                only_return_standard_out=only_return_standard_out,
+                autocast=autocast,
+            )
+            for _, preprocessor, X_train, y_train, cat_ix in itr
+        )
+        outputs = parallel_execute(devices, model_forward_functions)
 
-            X_full = torch.cat([X_train, X_test], dim=0).unsqueeze(1)
-            batched_cat_ix = [cat_ix]
-            y_train = torch.as_tensor(y_train, dtype=torch.float32, device=device)  # type: ignore  # noqa: PLW2901
+        for config, output in zip(self.ensemble_configs, outputs):
+            yield _move_and_squeeze_output(output, devices[0]), config
 
-            MemoryUsageEstimator.reset_peak_memory_if_required(
-                save_peak_mem=self.save_peak_mem,
-                model=self.model,
-                X=X_full,
-                cache_kv=False,
-                dtype_byte_size=self.dtype_byte_size,
-                device=device,
-                safety_factor=1.2,  # TODO(Arjun): make customizable
-            )
+        self.model.cpu()
 
-            if self.force_inference_dtype is not None:
-                X_full = X_full.type(self.force_inference_dtype)
-                y_train = y_train.type(self.force_inference_dtype)  # type: ignore  # noqa: PLW2901
+    def _call_model(
+        self,
+        *,
+        device: torch.device,
+        is_parallel: bool,
+        X_train: torch.Tensor | np.ndarray,
+        X_test: torch.Tensor | np.ndarray,
+        y_train: torch.Tensor | np.ndarray,
+        cat_ix: list[int],
+        autocast: bool,
+        only_return_standard_out: bool,
+    ) -> torch.Tensor | dict[str, torch.Tensor]:
+        """Execute a model forward pass on the provided device.
 
-        with (
-            get_autocast_context(device, enabled=autocast),
-            torch.inference_mode(),
-        ):
-            output = self.model(
-                X_full,
-                y_train,
-                only_return_standard_out=only_return_standard_out,
-                categorical_inds=batched_cat_ix,
-            )
+        Note that several instances of this function may be executed in parallel in
+        different threads, one for each device in the system.
+        """
+        # If several estimators are being run in parallel, then each thread needs its
+        # own copy of the model so it can move it to its device.
+        model = deepcopy(self.model) if is_parallel else self.model
+        model.to(device)
 
-            output = output if isinstance(output, dict) else output.squeeze(1)
+        X_full, y_train = _prepare_model_inputs(
+            device, self.force_inference_dtype, X_train, X_test, y_train
+        )
+        batched_cat_ix = [cat_ix]
 
-            yield output, config
+        MemoryUsageEstimator.reset_peak_memory_if_required(
+            save_peak_mem=self.save_peak_mem,
+            model=model,
+            X=X_full,
+            cache_kv=False,
+            dtype_byte_size=self.dtype_byte_size,
+            device=device,
+            safety_factor=1.2,
+        )
 
-        self.model = self.model.cpu()
+        with get_autocast_context(device, enabled=autocast), torch.inference_mode():
+            return model(
+                X_full,
+                y_train,
+                only_return_standard_out=only_return_standard_out,
+                categorical_inds=batched_cat_ix,
+            )
 
 
 @dataclass
@@ -458,69 +478,86 @@ def iter_outputs(
         autocast: bool,
         only_return_standard_out: bool = True,
     ) -> Iterator[tuple[torch.Tensor | dict, EnsembleConfig]]:
-        # This engine currently only supports one device, so just take the first.
-        device = devices[0]
-
-        self.model = self.model.to(device)
         if self.force_inference_dtype is not None:
-            self.model = self.model.type(self.force_inference_dtype)
-        for preprocessor, X_train, y_train, config, cat_ix in zip(
-            self.preprocessors,
-            self.X_trains,
-            self.y_trains,
-            self.ensemble_configs,
-            self.cat_ixs,
-        ):
-            if not isinstance(X_train, torch.Tensor):
-                X_train = torch.as_tensor(X_train, dtype=torch.float32)  # noqa: PLW2901
-            X_train = X_train.to(device)  # noqa: PLW2901
-            X_test = preprocessor.transform(X).X if not self.no_preprocessing else X
-            if not isinstance(X_test, torch.Tensor):
-                X_test = torch.as_tensor(X_test, dtype=torch.float32)
-            X_test = X_test.to(device)
-            X_full = torch.cat([X_train, X_test], dim=0).unsqueeze(1)
-            if not isinstance(y_train, torch.Tensor):
-                y_train = torch.as_tensor(y_train, dtype=torch.float32)  # noqa: PLW2901
-            y_train = y_train.to(device)  # noqa: PLW2901
+            self.model.type(self.force_inference_dtype)
 
-            batched_cat_ix = [cat_ix]
+        if self.no_preprocessing:
+            X_tests = (X for _ in range(len(self.ensemble_configs)))
+        else:
+            X_tests = (
+                preprocessor.transform(X).X for preprocessor in self.preprocessors
+            )
 
-            # Handle type casting
-            with contextlib.suppress(Exception):  # Avoid overflow error
-                X_full = X_full.float()
-            if self.force_inference_dtype is not None:
-                X_full = X_full.type(self.force_inference_dtype)
-                y_train = y_train.type(self.force_inference_dtype)  # type: ignore  # noqa: PLW2901
-
-            if self.inference_mode:
-                MemoryUsageEstimator.reset_peak_memory_if_required(
-                    save_peak_mem=self.save_peak_mem,
-                    model=self.model,
-                    X=X_full,
-                    cache_kv=False,
-                    device=device,
-                    dtype_byte_size=self.dtype_byte_size,
-                    safety_factor=1.2,  # TODO(Arjun): make customizable
-                )
-            else:
-                pass
+        model_forward_functions = (
+            partial(
+                self._call_model,
+                X_train=X_train,
+                X_test=X_test,
+                y_train=y_train,
+                cat_ix=cat_ix,
+                autocast=autocast,
+                only_return_standard_out=only_return_standard_out,
+            )
+            for X_train, X_test, y_train, cat_ix in zip(
+                self.X_trains, X_tests, self.y_trains, self.cat_ixs
+            )
+        )
+        outputs = parallel_execute(devices, model_forward_functions)
 
-            with (
-                get_autocast_context(device, enabled=autocast),
-                torch.inference_mode(self.inference_mode),
-            ):
-                output = self.model(
-                    X_full,
-                    y_train,
-                    only_return_standard_out=only_return_standard_out,
-                    categorical_inds=batched_cat_ix,
-                )
+        for output, config in zip(outputs, self.ensemble_configs):
+            yield _move_and_squeeze_output(output, devices[0]), config
 
-            output = output if isinstance(output, dict) else output.squeeze(1)
+        if self.inference_mode:
+            self.model.cpu()
 
-            yield output, config
-        if self.inference_mode:  ## if inference
-            self.model = self.model.cpu()
+    def _call_model(
+        self,
+        *,
+        device: torch.device,
+        is_parallel: bool,
+        X_train: torch.Tensor | np.ndarray,
+        X_test: torch.Tensor | np.ndarray,
+        y_train: torch.Tensor | np.ndarray,
+        cat_ix: list[int],
+        autocast: bool,
+        only_return_standard_out: bool,
+    ) -> torch.Tensor | dict[str, torch.Tensor]:
+        """Execute a model forward pass on the provided device.
+
+        Note that several instances of this function may be executed in parallel in
+        different threads, one for each device in the system.
+        """
+        # If several estimators are being run in parallel, then each thread needs its
+        # own copy of the model so it can move it to its device.
+        model = deepcopy(self.model) if is_parallel else self.model
+        model.to(device)
+
+        X_full, y_train = _prepare_model_inputs(
+            device, self.force_inference_dtype, X_train, X_test, y_train
+        )
+        batched_cat_ix = [cat_ix]
+
+        if self.inference_mode:
+            MemoryUsageEstimator.reset_peak_memory_if_required(
+                save_peak_mem=self.save_peak_mem,
+                model=model,
+                X=X_full,
+                cache_kv=False,
+                device=device,
+                dtype_byte_size=self.dtype_byte_size,
+                safety_factor=1.2,
+            )
+
+        with (
+            get_autocast_context(device, enabled=autocast),
+            torch.inference_mode(self.inference_mode),
+        ):
+            return model(
+                X_full,
+                y_train,
+                only_return_standard_out=only_return_standard_out,
+                categorical_inds=batched_cat_ix,
+            )
 
     @override
     def use_torch_inference_mode(self, *, use_inference: bool) -> None:
@@ -701,3 +738,26 @@ def iter_outputs(
             output = output if isinstance(output, dict) else output.squeeze(1)
 
             yield output, config
+
+
+def _prepare_model_inputs(
+    device: torch.device,
+    force_inference_dtype: torch.dtype | None,
+    X_train: torch.Tensor | np.ndarray,
+    X_test: torch.Tensor | np.ndarray,
+    y_train: torch.Tensor | np.ndarray,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    dtype = force_inference_dtype if force_inference_dtype else torch.float32
+    X_train = torch.as_tensor(X_train, dtype=dtype, device=device)
+    X_test = torch.as_tensor(X_test, dtype=dtype, device=device)
+    X_full = torch.cat([X_train, X_test], dim=0).unsqueeze(1)
+    y_train = torch.as_tensor(y_train, dtype=dtype, device=device)
+    return X_full, y_train
+
+
+def _move_and_squeeze_output(
+    output: dict | torch.Tensor, device: torch.device
+) -> dict[str, torch.Tensor] | torch.Tensor:
+    if isinstance(output, dict):
+        return {k: v.to(device) for k, v in output.items()}
+    return output.squeeze(1).to(device)
diff --git a/src/tabpfn/parallel_execute.py b/src/tabpfn/parallel_execute.py
new file mode 100644
index 000000000..b01050b5b
--- /dev/null
+++ b/src/tabpfn/parallel_execute.py
@@ -0,0 +1,107 @@
+"""Parallel evaluation of a set of functions across multiple PyTorch devices."""
+
+from __future__ import annotations
+
+import queue
+from collections.abc import Generator, Iterable, Sequence
+from multiprocessing.pool import ThreadPool
+from typing import Generic, Protocol, TypeVar
+
+import torch
+
+R_co = TypeVar("R_co", covariant=True)
+
+
+class ParallelFunction(Protocol, Generic[R_co]):
+    """Interface that functions submitted to `parallel_execute()` should implement."""
+
+    def __call__(self, *, device: torch.device, is_parallel: bool) -> R_co:
+        """Execute the function.
+
+        Args:
+            device: PyTorch device that all computation should be performed on.
+            is_parallel: Indicates whether this function is being executed in parallel
+                with other functions. If True, then the function should take care to
+                copy any state shared with other functions before mutating it. For
+                example, any nn.Modules should be deep copied before moving them to
+                `device`. If False, then copying can be avoided to reduce overhead.
+
+        Returns:
+            Any desired value. Any Tensors in the returned value can be on any device.
+        """
+        ...
+
+
+def parallel_execute(
+    devices: Sequence[torch.device],
+    functions: Iterable[ParallelFunction[R_co]],
+) -> Generator[R_co]:
+    """Evaluate the given functions in parallel across `devices`.
+
+    The function evaluations are parallelised using Python threads, so this will only
+    result in a speed-up if the functions do not hold the global interpreter lock. It
+    works well for functions that spend most of their time executing GPU kernels.
+
+    If only one device is provided, then the functions are executed in the current
+    thread to reduce overhead.
+
+    Args:
+        devices: The devices to use for evaluation.
+        functions: The functions to evaluate following the `ParallelFunction` protocol.
+
+    Returns:
+        A generator consisting of the return values of the functions, in the same order
+        as `functions`.
+    """
+    if len(devices) == 1:
+        # If we only have one device then just use the current thread to avoid overhead.
+        yield from _execute_in_current_thread(devices[0], functions)
+    else:
+        yield from _execute_with_multithreading(devices, functions)
+
+
+def _execute_in_current_thread(
+    device: torch.device, functions: Iterable[ParallelFunction[R_co]]
+) -> Generator[R_co]:
+    for function in functions:
+        yield function(device=device, is_parallel=False)
+
+
+def _execute_with_multithreading(
+    devices: Sequence[torch.device],
+    functions: Iterable[ParallelFunction[R_co]],
+) -> Generator[R_co]:
+    free_devices: queue.Queue[int] = queue.Queue(maxsize=len(devices))
+    for device_index, _ in enumerate(devices):
+        free_devices.put(device_index, block=False)
+
+    with ThreadPool(processes=len(devices)) as pool:
+        async_results = [
+            pool.apply_async(_execute_function_in_thread, (devices, free_devices, func))
+            for func in functions
+        ]
+        for async_result in async_results:
+            yield async_result.get()
+
+
+def _execute_function_in_thread(
+    all_devices: Sequence[torch.device],
+    free_devices: queue.Queue[int],
+    function: ParallelFunction[R_co],
+) -> R_co:
+    device_index = free_devices.get(block=True)
+    try:
+        device = all_devices[device_index]
+        if device.type == "cuda":
+            # We use a separate stream per thread so that threads can execute kernels in
+            # parallel.
+            with (
+                torch.cuda.stream(torch.cuda.Stream(device)),
+                torch.cuda.device(device),
+            ):
+                return function(device=device, is_parallel=True)
+        # Theoretically it is possible to parallelise over classes of device other than
+        # GPUs, but mainly this is useful for unit testing with multiple CPU devices.
+        return function(device=device, is_parallel=True)
+    finally:
+        free_devices.put(device_index)
diff --git a/tests/test_inference.py b/tests/test_inference.py
new file mode 100644
index 000000000..00af57a30
--- /dev/null
+++ b/tests/test_inference.py
@@ -0,0 +1,157 @@
+"""Test the inference engines."""
+
+from __future__ import annotations
+
+from typing import Literal, overload
+from typing_extensions import override
+
+import torch
+from numpy.random import default_rng
+from torch import Tensor
+
+from tabpfn.architectures.interface import Architecture
+from tabpfn.inference import InferenceEngineCachePreprocessing, InferenceEngineOnDemand
+from tabpfn.preprocessing import (
+    ClassifierEnsembleConfig,
+    PreprocessorConfig,
+)
+
+
+class TestModel(Architecture):
+    @overload
+    def forward(
+        self,
+        x: Tensor | dict[str, Tensor],
+        y: Tensor | dict[str, Tensor] | None,
+        *,
+        only_return_standard_out: Literal[True] = True,
+        categorical_inds: list[list[int]] | None = None,
+    ) -> Tensor: ...
+
+    @overload
+    def forward(
+        self,
+        x: Tensor | dict[str, Tensor],
+        y: Tensor | dict[str, Tensor] | None,
+        *,
+        only_return_standard_out: Literal[False],
+        categorical_inds: list[list[int]] | None = None,
+    ) -> dict[str, Tensor]: ...
+
+    @override
+    def forward(
+        self,
+        x: Tensor | dict[str, Tensor],
+        y: Tensor | dict[str, Tensor] | None,
+        *,
+        only_return_standard_out: bool = True,
+        categorical_inds: list[list[int]] | None = None,
+    ) -> Tensor | dict[str, Tensor]:
+        return torch.zeros(size=(10, 1, 10))
+
+    @property
+    def ninp(self) -> int:
+        return 2
+
+    @property
+    def features_per_group(self) -> int:
+        return 2
+
+    def reset_save_peak_mem_factor(self, factor: int | None = None) -> None:
+        pass
+
+
+def test__cache_preprocessing__result_equal_in_serial_and_in_parallel() -> None:
+    rng = default_rng(seed=0)
+    n_train = 10
+    X_train = rng.standard_normal(size=(n_train, 1, 2))
+    y_train = rng.standard_normal(size=(n_train, 1))
+    X_test = rng.standard_normal(size=(2, 1, 2))
+
+    ensemble_config = ClassifierEnsembleConfig(
+        preprocess_config=PreprocessorConfig(name="power", categorical_name="none"),
+        add_fingerprint_feature=False,
+        polynomial_features="no",
+        feature_shift_count=0,
+        feature_shift_decoder="shuffle",
+        subsample_ix=None,
+        class_permutation=None,
+    )
+    engine = InferenceEngineCachePreprocessing.prepare(
+        X_train,
+        y_train,
+        cat_ix=[0] * n_train,
+        model=TestModel(),
+        ensemble_configs=[ensemble_config] * 2,
+        n_workers=0,
+        rng=rng,
+        dtype_byte_size=4,
+        force_inference_dtype=None,
+        save_peak_mem=True,
+        inference_mode=True,
+    )
+
+    outputs_sequential = list(
+        engine.iter_outputs(X_test, devices=[torch.device("cpu")], autocast=False)
+    )
+    outputs_parallel = list(
+        engine.iter_outputs(
+            X_test, devices=[torch.device("cpu"), torch.device("cpu")], autocast=False
+        )
+    )
+    assert len(outputs_sequential) == len(outputs_parallel)
+    for (seq_output, seq_config), (par_output, par_config) in zip(
+        outputs_sequential, outputs_parallel
+    ):
+        assert isinstance(seq_output, Tensor)
+        assert isinstance(par_output, Tensor)
+        assert torch.allclose(seq_output, par_output)
+        assert seq_config == par_config
+
+
+def test__on_demand__result_equal_in_serial_and_in_parallel() -> None:
+    rng = default_rng(seed=0)
+    n_train = 10
+    n_estimators = 5
+    X_train = rng.standard_normal(size=(n_train, 1, 2))
+    y_train = rng.standard_normal(size=(n_train, 1))
+    X_test = rng.standard_normal(size=(2, 1, 2))
+
+    ensemble_config = ClassifierEnsembleConfig(
+        preprocess_config=PreprocessorConfig(name="power", categorical_name="none"),
+        add_fingerprint_feature=False,
+        polynomial_features="no",
+        feature_shift_count=0,
+        feature_shift_decoder="shuffle",
+        subsample_ix=None,
+        class_permutation=None,
+    )
+    engine = InferenceEngineOnDemand.prepare(
+        X_train,
+        y_train,
+        cat_ix=[0] * n_train,
+        model=TestModel(),
+        ensemble_configs=[ensemble_config] * n_estimators,
+        n_workers=0,
+        rng=rng,
+        dtype_byte_size=4,
+        force_inference_dtype=None,
+        save_peak_mem=True,
+    )
+
+    outputs_sequential = list(
+        engine.iter_outputs(X_test, devices=[torch.device("cpu")], autocast=False)
+    )
+    outputs_parallel = list(
+        engine.iter_outputs(
+            X_test, devices=[torch.device("cpu"), torch.device("cpu")], autocast=False
+        )
+    )
+    assert len(outputs_sequential) == len(outputs_parallel)
+    for (seq_output, seq_config), (par_output, par_config) in zip(
+        outputs_sequential, outputs_parallel
+    ):
+        assert isinstance(seq_output, Tensor)
+        assert isinstance(par_output, Tensor)
+        assert torch.allclose(seq_output, par_output)
+        assert seq_config == par_config
diff --git a/tests/test_parallel_execute.py b/tests/test_parallel_execute.py
new file mode 100644
index 000000000..d3abbf581
--- /dev/null
+++ b/tests/test_parallel_execute.py
@@ -0,0 +1,92 @@
+"""Tests for tabpfn.parallel_execute."""
+
+from __future__ import annotations
+
+import threading
+
+import torch
+
+from tabpfn.parallel_execute import parallel_execute
+
+
+def test__parallel_execute__single_device__executes_in_current_thread() -> None:
+    def test_function(device: torch.device, is_parallel: bool) -> int:  # noqa: ARG001
+        return threading.get_ident()
+
+    thread_ids = parallel_execute(
+        devices=[torch.device("cpu")], functions=[test_function, test_function]
+    )
+
+    current_thread_id = threading.get_ident()
+    assert list(thread_ids) == [current_thread_id, current_thread_id]
+
+
+def test__parallel_execute__single_device__sets_is_parallel_to_False() -> None:
+    def test_function(device: torch.device, is_parallel: bool) -> bool:  # noqa: ARG001
+        return is_parallel
+
+    is_parallels = parallel_execute(
+        devices=[torch.device("cpu")], functions=[test_function, test_function]
+    )
+
+    assert list(is_parallels) == [False, False]
+
+
+def test__parallel_execute__single_device__results_in_same_order_as_functions() -> None:
+    def a(device: torch.device, is_parallel: bool) -> str:  # noqa: ARG001
+        return "a"
+
+    def b(device: torch.device, is_parallel: bool) -> str:  # noqa: ARG001
+        return "b"
+
+    def c(device: torch.device, is_parallel: bool) -> str:  # noqa: ARG001
+        return "c"
+
+    results = parallel_execute(devices=[torch.device("cpu")], functions=[a, b, c])
+
+    assert list(results) == ["a", "b", "c"]
+
+
+def test__parallel_execute__multiple_devices__executes_in_worker_threads() -> None:
+    def test_function(device: torch.device, is_parallel: bool) -> int:  # noqa: ARG001
+        return threading.get_ident()
+
+    thread_ids = parallel_execute(
+        devices=[torch.device("cpu"), torch.device("meta")],
+        functions=[test_function, test_function],
+    )
+
+    current_thread_id = threading.get_ident()
+    for thread_id in thread_ids:
+        assert thread_id != current_thread_id
+
+
+def test__parallel_execute__multiple_devices__sets_is_parallel_to_True() -> None:
+    def test_function(device: torch.device, is_parallel: bool) -> bool:  # noqa: ARG001
+        return is_parallel
+
+    is_parallels = parallel_execute(
+        devices=[torch.device("cpu"), torch.device("meta")],
+        functions=[test_function, test_function],
+    )
+
+    assert list(is_parallels) == [True, True]
+
+
+def test__parallel_execute__multiple_devices__results_in_same_order_as_functions() -> (
+    None
+):
+    def a(device: torch.device, is_parallel: bool) -> str:  # noqa: ARG001
+        return "a"
+
+    def b(device: torch.device, is_parallel: bool) -> str:  # noqa: ARG001
+        return "b"
+
+    def c(device: torch.device, is_parallel: bool) -> str:  # noqa: ARG001
+        return "c"
+
+    results = parallel_execute(
+        devices=[torch.device("meta"), torch.device("meta")], functions=[a, b, c]
+    )
+
+    assert list(results) == ["a", "b", "c"]
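Reviewer note (not part of the patch): the snippet below is a minimal, hypothetical sketch of how the new `parallel_execute()` and the `ParallelFunction` protocol fit together when tried in isolation. It only assumes the signatures introduced in this diff; the toy matrix-multiply workload and the duplicated CPU device list are illustrative stand-ins, not code from the PR.

# Illustrative sketch only; assumes tabpfn.parallel_execute is importable as in this PR.
from functools import partial

import torch

from tabpfn.parallel_execute import parallel_execute


def toy_forward(
    x: torch.Tensor, *, device: torch.device, is_parallel: bool
) -> torch.Tensor:
    # A real ParallelFunction would deep-copy shared nn.Modules when
    # is_parallel is True; a plain input tensor can simply be moved.
    x = x.to(device)
    return (x @ x.T).cpu()


inputs = [torch.randn(8, 4) for _ in range(4)]
# partial() binds the payload, leaving only the keyword-only protocol arguments.
functions = (partial(toy_forward, x) for x in inputs)

# With two devices (here both CPU, just for illustration) the calls run in worker
# threads; with a single device they run in the current thread.
for result in parallel_execute([torch.device("cpu"), torch.device("cpu")], functions):
    print(result.shape)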