Skip to content

Commit

Permalink
Remove platform canonicalization from xla_bridge.py (#2815)
Browse files Browse the repository at this point in the history
  • Loading branch information
skye authored Apr 24, 2020
1 parent 11d7fb0 commit 343e486
Showing 1 changed file with 3 additions and 13 deletions.
16 changes: 3 additions & 13 deletions jax/lib/xla_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,23 +106,13 @@ def register_backend(name, factory):

def _get_local_backend(platform=None):
if not platform:
platform = FLAGS.jax_platform_name

# Canonicalize platform names.
cpu = 'cpu'
gpu = 'gpu'
if platform == 'Host':
platform = cpu
elif platform == 'CUDA':
platform = gpu
elif platform == '':
platform = None
platform = FLAGS.jax_platform_name or None

backend = xla_client.get_local_backend(platform)
if backend is None:
raise RuntimeError("No local XLA backends found.")

if backend.platform == cpu and platform != cpu:
if backend.platform == 'cpu' and platform != 'cpu':
warnings.warn('No GPU/TPU found, falling back to CPU.')

return backend
Expand Down Expand Up @@ -406,4 +396,4 @@ def Slice(self, operand, start_indices, limit_indices, strides=None):


def computation_builder_shim(b):
return b if version > (0, 1, 45) else ComputationBuilderShim(b)
return b if version > (0, 1, 45) else ComputationBuilderShim(b)

0 comments on commit 343e486

Please sign in to comment.