-
Notifications
You must be signed in to change notification settings - Fork 530
[Fix] Lazily instantiate backends to avoid unnecessary GPU memory pre-allocations #520
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 9 commits
e43735a
e69fcfb
b394350
095ab3a
c33416d
5241f8b
b5bc5a8
1a694bb
929b78a
d5219f2
6f4be00
98e1aa9
9571538
207a693
db29971
67a6129
26cb118
a016406
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,19 +4,45 @@ | |
|
||
# License: MIT License | ||
|
||
import pytest | ||
from ot.backend import jax, tf | ||
from ot.backend import get_backend_list | ||
import functools | ||
import os | ||
import pytest | ||
|
||
from ot.backend import ( | ||
get_available_backend_implementations, | ||
get_backend_list, | ||
_get_backend_instance, | ||
jax, | ||
tf | ||
) | ||
|
||
|
||
if jax: | ||
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false' | ||
from jax.config import config | ||
config.update("jax_enable_x64", True) | ||
|
||
if tf: | ||
# make sure TF doesn't allocate entire GPU | ||
import tensorflow as tf | ||
physical_devices = tf.config.list_physical_devices('GPU') | ||
for device in physical_devices: | ||
try: | ||
tf.config.experimental.set_memory_growth(device, True) | ||
except Exception: | ||
pass | ||
|
||
# allow numpy API for TF | ||
from tensorflow.python.ops.numpy_ops import np_config | ||
np_config.enable_numpy_behavior() | ||
|
||
|
||
# before taking list of backends, we need to make sure all | ||
# available implementations are instantiated. looks somewhat hacky, | ||
# but hopefully it won't be needed for a common library use | ||
for backend_impl in get_available_backend_implementations(): | ||
|
||
_get_backend_instance(backend_impl) | ||
|
||
backend_list = get_backend_list() | ||
|
||
|
||
|
Uh oh!
There was an error while loading. Please reload this page.