Skip to content

[Bug] Memory leak when using a DataLoader for SingleTaskVariationalGP #2526

Closed
@mikkelbue

Description

@mikkelbue

🐛 Bug

When training a SingleTaskVariationalGP with a DataLoader, there is a memory leak that causes the device to run out of memory as training progresses. The leak does not show up immediately after the optimization starts, only after a couple of training iterations. I believe this is because `fval` in `torch_minimize` is not detached before it is passed to `stopping_criterion`.

To reproduce

import math
import torch
from torch.utils.data import TensorDataset, DataLoader
from botorch.models import SingleTaskVariationalGP
from gpytorch.mlls import VariationalELBO
from botorch import fit_gpytorch_mll

# Minimal reproduction: fit a SingleTaskVariationalGP on 100k points by
# streaming shuffled mini-batches through a DataLoader. Per the report, GPU
# memory is not exhausted immediately but after a few training iterations,
# ending in a CUDA OutOfMemoryError (see the traceback below the script).
# Reporter's hypothesis (not confirmed by this script alone): `fval` in
# BoTorch's `torch_minimize` is passed to the stopping criterion without
# being detached, so autograd graphs from earlier steps stay alive.
n_train = 100000

# Column of n_train evenly spaced inputs on [0, 1]; unsqueeze(-1) makes it
# shape (n_train, 1).
train_x = torch.linspace(0, 1, n_train).unsqueeze(-1)
# NOTE(review): the Gaussian noise is added inside the sin() argument (i.e.
# phase noise) rather than to the output; likely unintended, but it should
# not affect the memory leak being reported.
train_y = torch.sin(train_x * (4 * math.pi) + torch.randn(train_x.size()) * 0.2)

# Pick 4096 row indices of train_x to use as inducing points. torch.randint
# samples independently (with replacement), so duplicate indices, and hence
# duplicate inducing points, are possible.
n_inducing = 4096
idx_inducing = torch.randint(0, train_x.shape[0], size=(n_inducing,))

# Move the data to the GPU when one is available; the reported failure is a
# CUDA out-of-memory error.
if torch.cuda.is_available():
    train_x, train_y = train_x.cuda(), train_y.cuda()

# Shuffled mini-batches of 2048 (x, y) pairs, handed to fit_gpytorch_mll
# below via `data_loader=`.
train_dataset = TensorDataset(train_x, train_y)
train_loader = DataLoader(train_dataset, batch_size=2048, shuffle=True)

# The model is constructed with only the first 10 points as its stored
# training data; the full dataset reaches the optimizer through
# `train_loader`. Inducing points are the rows sampled above.
model = SingleTaskVariationalGP(train_x[:10,:],
                                train_y[:10,:],
                                inducing_points=train_x[idx_inducing,:])
# ELBO built on the wrapper's likelihood and its inner `.model`; the third
# argument is the full dataset size (train_x.shape[0]), not the 10 stored
# points.
mll = VariationalELBO(model.likelihood, model.model, train_x.shape[0])

if torch.cuda.is_available():
    mll = mll.cuda()

# Per the traceback below, this call goes through _fit_fallback_approximate
# to fit_gpytorch_mll_torch / torch_minimize. The OOM is raised in the
# backward pass inside the optimization closure.
_ = fit_gpytorch_mll(mll, data_loader=train_loader)
---------------------------------------------------------------------------
OutOfMemoryError                          Traceback (most recent call last)
Cell In[8], line 1
----> 1 _ = fit_gpytorch_mll(mll, data_loader=train_loader)

File ~/.cache/pypoetry/virtualenvs/new-project-mAQ5cvJS-py3.10/lib/python3.10/site-packages/botorch/fit.py:104, in fit_gpytorch_mll(mll, closure, optimizer, closure_kwargs, optimizer_kwargs, **kwargs)
    101 if optimizer is not None:  # defer to per-method defaults
    102     kwargs["optimizer"] = optimizer
--> 104 return FitGPyTorchMLL(
    105     mll,
    106     type(mll.likelihood),
    107     type(mll.model),
    108     closure=closure,
    109     closure_kwargs=closure_kwargs,
    110     optimizer_kwargs=optimizer_kwargs,
    111     **kwargs,
    112 )

File ~/.cache/pypoetry/virtualenvs/new-project-mAQ5cvJS-py3.10/lib/python3.10/site-packages/botorch/utils/dispatcher.py:93, in Dispatcher.__call__(self, *args, **kwargs)
     91 func = self.__getitem__(types=types)
     92 try:
---> 93     return func(*args, **kwargs)
     94 except MDNotImplementedError:
     95     # Traverses registered methods in order, yields whenever a match is found
     96     funcs = self.dispatch_iter(*types)

File ~/.cache/pypoetry/virtualenvs/new-project-mAQ5cvJS-py3.10/lib/python3.10/site-packages/botorch/fit.py:331, in _fit_fallback_approximate(mll, _, __, closure, data_loader, optimizer, full_batch_limit, **kwargs)
    324 if optimizer is None:
    325     optimizer = (
    326         fit_gpytorch_mll_scipy
    327         if closure is None and len(mll.model.train_targets) <= full_batch_limit
    328         else fit_gpytorch_mll_torch
    329     )
--> 331 return _fit_fallback(mll, _, __, closure=closure, optimizer=optimizer, **kwargs)

File ~/.cache/pypoetry/virtualenvs/new-project-mAQ5cvJS-py3.10/lib/python3.10/site-packages/botorch/fit.py:204, in _fit_fallback(mll, _, __, closure, optimizer, closure_kwargs, optimizer_kwargs, max_attempts, pick_best_of_all_attempts, warning_handler, caught_exception_types, **ignore)
    202 with catch_warnings(record=True) as warning_list, debug(True):
    203     simplefilter("always", category=OptimizationWarning)
--> 204     result = optimizer(mll, closure=closure, **optimizer_kwargs)
    206 # Resolve warnings and determine whether or not to retry
    207 success = True

File ~/.cache/pypoetry/virtualenvs/new-project-mAQ5cvJS-py3.10/lib/python3.10/site-packages/botorch/optim/fit.py:164, in fit_gpytorch_mll_torch(mll, parameters, bounds, closure, closure_kwargs, step_limit, stopping_criterion, optimizer, scheduler, callback, timeout_sec)
    161 if closure_kwargs is not None:
    162     closure = partial(closure, **closure_kwargs)
--> 164 return torch_minimize(
    165     closure=closure,
    166     parameters=parameters,
    167     bounds=bounds_dict if bounds is None else {**bounds_dict, **bounds},
    168     optimizer=optimizer,
    169     scheduler=scheduler,
    170     step_limit=step_limit,
    171     stopping_criterion=stopping_criterion,
    172     callback=callback,
    173     timeout_sec=timeout_sec,
    174 )

File ~/.cache/pypoetry/virtualenvs/new-project-mAQ5cvJS-py3.10/lib/python3.10/site-packages/botorch/optim/core.py:195, in torch_minimize(closure, parameters, bounds, callback, optimizer, scheduler, step_limit, timeout_sec, stopping_criterion)
    189 _bounds = (
    190     {}
    191     if bounds is None
    192     else {name: limits for name, limits in bounds.items() if name in parameters}
    193 )
    194 for step in tqdm.tqdm(range(1, step_limit + 1)):
--> 195     fval, _ = closure()
    196     runtime = monotonic() - start_time
    197     result = OptimizationResult(
    198         step=step,
    199         fval=fval.detach().cpu().item(),
    200         status=OptimizationStatus.RUNNING,
    201         runtime=runtime,
    202     )

File ~/.cache/pypoetry/virtualenvs/new-project-mAQ5cvJS-py3.10/lib/python3.10/site-packages/botorch/optim/closures/core.py:66, in ForwardBackwardClosure.__call__(self, **kwargs)
     64 values = self.forward(**kwargs)
     65 value = values if self.reducer is None else self.reducer(values)
---> 66 self.backward(value)
     68 grads = tuple(param.grad for param in self.parameters.values())
     69 if self.callback:

File ~/.cache/pypoetry/virtualenvs/new-project-mAQ5cvJS-py3.10/lib/python3.10/site-packages/torch/_tensor.py:521, in Tensor.backward(self, gradient, retain_graph, create_graph, inputs)
    511 if has_torch_function_unary(self):
    512     return handle_torch_function(
    513         Tensor.backward,
    514         (self,),
   (...)
    519         inputs=inputs,
    520     )
--> 521 torch.autograd.backward(
    522     self, gradient, retain_graph, create_graph, inputs=inputs
    523 )

File ~/.cache/pypoetry/virtualenvs/new-project-mAQ5cvJS-py3.10/lib/python3.10/site-packages/torch/autograd/__init__.py:289, in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
    284     retain_graph = create_graph
    286 # The reason we repeat the same comment below is that
    287 # some Python versions print out the first line of a multi-line function
    288 # calls in the traceback and some print out the last line
--> 289 _engine_run_backward(
    290     tensors,
    291     grad_tensors_,
    292     retain_graph,
    293     create_graph,
    294     inputs,
    295     allow_unreachable=True,
    296     accumulate_grad=True,
    297 )

File ~/.cache/pypoetry/virtualenvs/new-project-mAQ5cvJS-py3.10/lib/python3.10/site-packages/torch/autograd/graph.py:769, in _engine_run_backward(t_outputs, *args, **kwargs)
    767     unregister_hooks = _register_logging_hooks_on_whole_graph(t_outputs)
    768 try:
--> 769     return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
    770         t_outputs, *args, **kwargs
    771     )  # Calls into the C++ engine to run the backward pass
    772 finally:
    773     if attach_logging_hooks:

OutOfMemoryError: CUDA out of memory. Tried to allocate 66.00 MiB. GPU 0 has a total capacity of 5.76 GiB of which 67.56 MiB is free. Process 1641283 has 2.36 GiB memory in use. Including non-PyTorch memory, this process has 2.74 GiB memory in use. Of the allocated memory 2.20 GiB is allocated by PyTorch, and 404.29 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

Expected Behavior

Training should complete without running out of memory.

System information

Please complete the following information:

  • BoTorch Version: 0.11.3
  • GPyTorch Version: 1.12
  • PyTorch Version: 2.4.2+cu121
  • Ubuntu 22.04.4

Metadata

Metadata

Assignees

No one assigned

    Labels

    bug — Something isn't working

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions