Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Combine backend and device arguments #1898

Closed
wants to merge 2 commits into from
Closed

Combine backend and device arguments #1898

wants to merge 2 commits into from

Conversation

skye
Copy link
Member

@skye skye commented Dec 19, 2019

No description provided.

@mattjj
Copy link
Collaborator

mattjj commented Jan 7, 2020

This looks promising! What's the status? Want to get it reviewed, or is there more work to be done?

@skye
Copy link
Member Author

skye commented Jan 7, 2020

Thanks for the ping, I forgot about this PR! @hawkinsp @shoyer we talked about this in December, WDYT of this change?

Copy link
Collaborator

@gnecula gnecula left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I really like the idea of eliminating backend (gradually), but I think we should discuss more allowing strings to be passed in lieu of devices.

@@ -107,8 +107,8 @@ def jit(fun, static_argnums=(), device=None, backend=None):
Optional, the Device the jitted function will run on. (Available devices
can be retrieved via ``jax.devices()``.) The default is inherited from
XLA's DeviceAssignment logic and is usually to use ``jax.devices()[0]``.
backend: This is an experimental feature and the API is likely to change.
Optional, a string representing the xla backend. 'cpu','gpu', or 'tpu'.
Can also be passed ``'cpu'``, ``'gpu'``, or ``'tpu'`` to use the default
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It does not seem nice to overload the device argument with a string denoting the platform. I assume this would be rarely needed, why not ask the user to use devices=devices("gpu")[0]?

@@ -107,8 +107,8 @@ def jit(fun, static_argnums=(), device=None, backend=None):
Optional, the Device the jitted function will run on. (Available devices
can be retrieved via ``jax.devices()``.) The default is inherited from
XLA's DeviceAssignment logic and is usually to use ``jax.devices()[0]``.
backend: This is an experimental feature and the API is likely to change.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that we should get in the habit of deprecating arguments from public APIs before removing them. I would suggest keeping the backend, immediately turning it into device=devices(backend)[0], issuing in awarning, adding the note to Changelog, adding a reminder somewhere to remove it in a future release. This is a lot more expensive than just removing it, but I think that as we have more and more users it is a good thing to do.

Also, it may make sense to do codesearch/ for people who use the backend argument

backend: This is an experimental feature and the API is likely to change.
Optional, a string representing the xla backend. 'cpu','gpu', or 'tpu'.
device: This is an experimental feature and the API is likely to change.
Optional, either a Device object or ``'cpu'``, ``'gpu'``, or ``'tpu'``.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we do remove the string device arguments, we should fix this docstring too


def canonicalize_device_arg(device_arg):
if isinstance(device_arg, str):
if device_arg not in VALID_PLATFORMS:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe check instead whether it is a supported platform.

@@ -48,7 +48,7 @@ class MultiBackendTest(jtu.JaxTestCase):
def testMultiBackend(self, backend):
if backend not in ('cpu', jtu.device_under_test(), None):
raise SkipTest()
@partial(api.jit, backend=backend)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of doing this, I would rewrite (or eliminate) the test. It is not clear that in the absence of the backend argument this test adds value.

@hawkinsp
Copy link
Collaborator

hawkinsp commented Mar 4, 2023

Closing this PR. Whatever its merits were, it is quite stale.

@hawkinsp hawkinsp closed this Mar 4, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants