🐛 Bug
When training a SingleTaskVariationalGP using a DataLoader, there is a memory leak that causes the device to run out of memory as training progresses. This does not happen immediately after starting the optimization, but after a couple of training iterations. I believe this is because fval in torch_minimize is not detached before being passed to stopping_criterion.
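To reproduce
A minimal sketch of the kind of setup that triggers this for me; the shapes, inducing-point count, and batch size below are placeholders rather than my exact values:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

from botorch.fit import fit_gpytorch_mll
from botorch.models import SingleTaskVariationalGP
from gpytorch.mlls import VariationalELBO

device = torch.device("cuda")
train_X = torch.rand(10_000, 4, device=device)
train_Y = torch.randn(10_000, 1, device=device)

model = SingleTaskVariationalGP(train_X, train_Y, inducing_points=128).to(device)
mll = VariationalELBO(model.likelihood, model.model, num_data=train_X.shape[0])

train_loader = DataLoader(TensorDataset(train_X, train_Y), batch_size=256, shuffle=True)

# Passing a DataLoader routes fitting through fit_gpytorch_mll_torch / torch_minimize.
_ = fit_gpytorch_mll(mll, data_loader=train_loader)
```

With a setup along these lines, training eventually fails with: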
---------------------------------------------------------------------------
OutOfMemoryError Traceback (most recent call last)
Cell In[8], line 1
----> 1 _ = fit_gpytorch_mll(mll, data_loader=train_loader)
File ~/.cache/pypoetry/virtualenvs/new-project-mAQ5cvJS-py3.10/lib/python3.10/site-packages/botorch/fit.py:104, in fit_gpytorch_mll(mll, closure, optimizer, closure_kwargs, optimizer_kwargs, **kwargs)
101 if optimizer is not None: # defer to per-method defaults
102 kwargs["optimizer"] = optimizer
--> 104 return FitGPyTorchMLL(
105 mll,
106 type(mll.likelihood),
107 type(mll.model),
108 closure=closure,
109 closure_kwargs=closure_kwargs,
110 optimizer_kwargs=optimizer_kwargs,
111 **kwargs,
112 )
File ~/.cache/pypoetry/virtualenvs/new-project-mAQ5cvJS-py3.10/lib/python3.10/site-packages/botorch/utils/dispatcher.py:93, in Dispatcher.__call__(self, *args, **kwargs)
91 func = self.__getitem__(types=types)
92 try:
---> 93 return func(*args, **kwargs)
94 except MDNotImplementedError:
95 # Traverses registered methods in order, yields whenever a match is found
96 funcs = self.dispatch_iter(*types)
File ~/.cache/pypoetry/virtualenvs/new-project-mAQ5cvJS-py3.10/lib/python3.10/site-packages/botorch/fit.py:331, in _fit_fallback_approximate(mll, _, __, closure, data_loader, optimizer, full_batch_limit, **kwargs)
324 if optimizer is None:
325 optimizer = (
326 fit_gpytorch_mll_scipy
327 if closure is None and len(mll.model.train_targets) <= full_batch_limit
328 else fit_gpytorch_mll_torch
329 )
--> 331 return _fit_fallback(mll, _, __, closure=closure, optimizer=optimizer, **kwargs)
File ~/.cache/pypoetry/virtualenvs/new-project-mAQ5cvJS-py3.10/lib/python3.10/site-packages/botorch/fit.py:204, in _fit_fallback(mll, _, __, closure, optimizer, closure_kwargs, optimizer_kwargs, max_attempts, pick_best_of_all_attempts, warning_handler, caught_exception_types, **ignore)
202 with catch_warnings(record=True) as warning_list, debug(True):
203 simplefilter("always", category=OptimizationWarning)
--> 204 result = optimizer(mll, closure=closure, **optimizer_kwargs)
206 # Resolve warnings and determine whether or not to retry
207 success = True
File ~/.cache/pypoetry/virtualenvs/new-project-mAQ5cvJS-py3.10/lib/python3.10/site-packages/botorch/optim/fit.py:164, in fit_gpytorch_mll_torch(mll, parameters, bounds, closure, closure_kwargs, step_limit, stopping_criterion, optimizer, scheduler, callback, timeout_sec)
161 if closure_kwargs is not None:
162 closure = partial(closure, **closure_kwargs)
--> 164 return torch_minimize(
165 closure=closure,
166 parameters=parameters,
167 bounds=bounds_dict if bounds is None else {**bounds_dict, **bounds},
168 optimizer=optimizer,
169 scheduler=scheduler,
170 step_limit=step_limit,
171 stopping_criterion=stopping_criterion,
172 callback=callback,
173 timeout_sec=timeout_sec,
174 )
File ~/.cache/pypoetry/virtualenvs/new-project-mAQ5cvJS-py3.10/lib/python3.10/site-packages/botorch/optim/core.py:195, in torch_minimize(closure, parameters, bounds, callback, optimizer, scheduler, step_limit, timeout_sec, stopping_criterion)
189 _bounds = (
190 {}
191 if bounds is None
192 else {name: limits for name, limits in bounds.items() if name in parameters}
193 )
194 for step in tqdm.tqdm(range(1, step_limit + 1)):
--> 195 fval, _ = closure()
196 runtime = monotonic() - start_time
197 result = OptimizationResult(
198 step=step,
199 fval=fval.detach().cpu().item(),
200 status=OptimizationStatus.RUNNING,
201 runtime=runtime,
202 )
File ~/.cache/pypoetry/virtualenvs/new-project-mAQ5cvJS-py3.10/lib/python3.10/site-packages/botorch/optim/closures/core.py:66, in ForwardBackwardClosure.__call__(self, **kwargs)
64 values = self.forward(**kwargs)
65 value = values if self.reducer is None else self.reducer(values)
---> 66 self.backward(value)
68 grads = tuple(param.grad for param in self.parameters.values())
69 if self.callback:
File ~/.cache/pypoetry/virtualenvs/new-project-mAQ5cvJS-py3.10/lib/python3.10/site-packages/torch/_tensor.py:521, in Tensor.backward(self, gradient, retain_graph, create_graph, inputs)
511 if has_torch_function_unary(self):
512 return handle_torch_function(
513 Tensor.backward,
514 (self,),
(...)
519 inputs=inputs,
520 )
--> 521 torch.autograd.backward(
522 self, gradient, retain_graph, create_graph, inputs=inputs
523 )
File ~/.cache/pypoetry/virtualenvs/new-project-mAQ5cvJS-py3.10/lib/python3.10/site-packages/torch/autograd/__init__.py:289, in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
284 retain_graph = create_graph
286 # The reason we repeat the same comment below is that
287 # some Python versions print out the first line of a multi-line function
288 # calls in the traceback and some print out the last line
--> 289 _engine_run_backward(
290 tensors,
291 grad_tensors_,
292 retain_graph,
293 create_graph,
294 inputs,
295 allow_unreachable=True,
296 accumulate_grad=True,
297 )
File ~/.cache/pypoetry/virtualenvs/new-project-mAQ5cvJS-py3.10/lib/python3.10/site-packages/torch/autograd/graph.py:769, in _engine_run_backward(t_outputs, *args, **kwargs)
767 unregister_hooks = _register_logging_hooks_on_whole_graph(t_outputs)
768 try:
--> 769 return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
770 t_outputs, *args, **kwargs
771 ) # Calls into the C++ engine to run the backward pass
772 finally:
773 if attach_logging_hooks:
OutOfMemoryError: CUDA out of memory. Tried to allocate 66.00 MiB. GPU 0 has a total capacity of 5.76 GiB of which 67.56 MiB is free. Process 1641283 has 2.36 GiB memory in use. Including non-PyTorch memory, this process has 2.74 GiB memory in use. Of the allocated memory 2.20 GiB is allocated by PyTorch, and 404.29 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
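As far as I can tell from the frames above, the loop in torch_minimize evaluates the closure and then hands the still graph-attached fval to the stopping criterion, whose history keeps every iteration's graph alive. A sketch of the change I have in mind (based on my reading of botorch/optim/core.py; the names mirror the traceback, and this is not a tested patch):

```python
# Inside the optimization loop of torch_minimize (sketch, not the actual source):
fval, _ = closure()
runtime = monotonic() - start_time
result = OptimizationResult(
    step=step,
    fval=fval.detach().cpu().item(),
    status=OptimizationStatus.RUNNING,
    runtime=runtime,
)

# Detach before handing the value to the stopping criterion so that its history
# does not retain the autograd graph built on each iteration:
if stopping_criterion is not None and stopping_criterion(fval.detach()):
    break  # stop the fit once the criterion is met
```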
Expected Behavior
That the training would complete without running out of memory.
System information
Please complete the following information:
BoTorch Version: 0.11.3
GPyTorch Version: 1.12
PyTorch Version: 2.4.2+cu121
OS: Ubuntu 22.04.4