Discovering supported float types per namespace and device #678

Closed
betatim opened this issue Aug 29, 2023 · 4 comments
Labels
Duplicate This issue or pull request already exists.

Comments

@betatim
Member

betatim commented Aug 29, 2023

I find myself wanting to programmatically find out which "highest precision float type" a particular library supports on a particular device. Concretely: PyTorch and the MPS device (PyTorch's name for the GPU in an Apple M1 (and M2?)). PyTorch does not support float64 on the MPS device, which is how I ended up wanting a way to find out which highest-precision float type is available.

```python
import torch
import array_api_compat

x = torch.tensor([1, 2, 3], device="mps", dtype=torch.float32)
xp = array_api_compat.get_namespace(x)
# side quest: is there a better way to get the torch namespace?

x = xp.asarray([1, 2, 3], device="mps", dtype=torch.float32)

# Maybe `can_cast` is the right tool?
xp.can_cast(xp.float32, xp.float64)  # -> True
xp.can_cast(x, xp.float64)  # -> True
```

Presumably both calls to can_cast return True because PyTorch supports float64 in general, and the implementation of can_cast does not inspect the device of x (the first call only receives dtypes, so it has no device to inspect at all). So, at least as it is currently implemented, I don't think can_cast is the right tool for finding out whether float64 exists and falling back to float32 if it doesn't. The alternative is writing my own highest_precision_float().
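A minimal sketch of such a fallback, assuming that asking a namespace to allocate an array with an unsupported dtype on a device raises an exception (as PyTorch does for float64 on "mps"; the exact exception type is an assumption, so several are caught). `preferred_float_dtype` is a hypothetical helper, not part of array-api-compat; it only relies on the standard `zeros(shape, dtype=..., device=...)` creation function and probes the device directly, since `can_cast` compares dtypes only.

```python
def preferred_float_dtype(xp, device=None, candidates=("float64", "float32")):
    """Return the first dtype in `candidates` that `xp` can allocate on `device`.

    Hypothetical helper: it probes by allocating a one-element array, because
    `xp.can_cast` only compares dtypes and never looks at a device.
    """
    # Only pass `device` when given, so namespaces without a `device=` keyword still work.
    kwargs = {} if device is None else {"device": device}
    for name in candidates:
        dtype = getattr(xp, name, None)
        if dtype is None:  # the namespace does not define this dtype at all
            continue
        try:
            xp.zeros(1, dtype=dtype, **kwargs)
        except (TypeError, RuntimeError, ValueError):
            # e.g. PyTorch refuses float64 tensors on "mps"; try the next candidate
            continue
        return dtype
    raise ValueError(f"none of {candidates} can be allocated on {device!r}")
```

Usage would look like `dtype = preferred_float_dtype(xp, device="mps")`. Probing costs one tiny allocation per candidate, which is cheap enough to do once and cache.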

betatim changed the title "Discovering support float types per namespace and device" to "Discovering supported float types per namespace and device" Aug 29, 2023
@honno
Member

honno commented Aug 29, 2023

I believe #640 would resolve this! The proposed xp.__array_namespace_info__() would return an info object with a dtypes() method that returns a list of (supported) dtype objects, so one could call info.dtypes(device=...) to get the supported dtypes. The kind kwarg could restrict this to floats, e.g. info.dtypes(device=..., kind="real floating").
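For illustration, here is how the use case could look with the inspection API as it was later standardized (array API 2023.12). There, `info.dtypes(device=..., kind="real floating")` returns a dict mapping dtype names to dtype objects rather than a list, so "float64 if available, else float32" becomes a lookup. The helper name is hypothetical, and the sketch assumes the namespace (or its array-api-compat wrapper) implements that revision; whether a backend actually omits float64 for a device such as "mps" depends on its implementation.

```python
def preferred_float_dtype_from_info(xp, device=None):
    """Pick float64 if the namespace reports it for `device`, else float32.

    Hypothetical helper built on the array API inspection API
    (`__array_namespace_info__`, standardized in the 2023.12 revision).
    """
    info = xp.__array_namespace_info__()
    # Dict mapping names such as "float32"/"float64" to dtypes supported on `device`.
    floats = info.dtypes(device=device, kind="real floating")
    for name in ("float64", "float32"):
        if name in floats:
            return floats[name]
    raise ValueError(f"no float32/float64 reported for device {device!r}")
```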

@rgommers
Member

Thanks @betatim! I have two thoughts here:

  • You probably don't really want the highest-precision floating-point dtype, since NumPy's long double dtypes are not great. Rather, I think you want "float64 if available, and the next-lowest-precision dtype otherwise". Correct?
  • JAX may be a problem here, since float64 may exist, but using it may return arrays with dtype float32. I think you'll have to special-case that for the time being. #640 (RFC: add a unified inspection API namespace) may indeed fix it, but that'll take quite a while to materialize, I think.

+1 to what @honno wrote - this use case is extra support for implementing that introspection API.
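Because of the JAX behaviour described above, the absence of an error is not proof that float64 works: with 64-bit mode disabled (the default), JAX truncates a requested float64 to float32 and only emits a warning. A minimal sketch of a check that therefore verifies the dtype of the array it actually got back; `float64_really_available` is a hypothetical name, and the final comparison assumes `array.dtype == xp.float64` is a meaningful equality for the namespace in question.

```python
import warnings


def float64_really_available(xp, device=None):
    """Return True only if a float64 array created by `xp` really has dtype float64.

    Hypothetical check: some namespaces (e.g. JAX without 64-bit mode) accept
    dtype=float64 but hand back float32 instead of raising.
    """
    if not hasattr(xp, "float64"):
        return False
    kwargs = {} if device is None else {"device": device}
    with warnings.catch_warnings():
        # Silence any truncation warning; the dtype comparison below decides.
        warnings.simplefilter("ignore")
        try:
            probe = xp.zeros(1, dtype=xp.float64, **kwargs)
        except (TypeError, RuntimeError, ValueError):
            return False
    return probe.dtype == xp.float64
```

Checking the returned dtype covers both failure modes discussed in this thread: an outright error (PyTorch on MPS) and a silent downcast (JAX).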

kgryte added this to the v2023 milestone Aug 29, 2023
@betatim
Member Author

betatim commented Aug 30, 2023

Indeed, what I really want is "float64 if available and the next lowest otherwise".

Should we close this issue or somehow mark it as "please don't discuss here, go to #640 instead"? It seems like it is kind of a duplicate.

rgommers added the Duplicate label Aug 30, 2023
@rgommers
Member

Sounds good - done!

rgommers closed this as not planned Aug 30, 2023

4 participants