I find myself wanting to programmatically find out the "highest precision float type" that a particular library supports on a particular device. Concretely: PyTorch and the MPS device (their name for the GPU in an Apple M1 (and M2?)). The MPS device doesn't support `float64`, which is how I ended up wanting something that lets me find out what the highest-precision available float type is.
```python
import torch
import array_api_compat

x = torch.tensor([1, 2, 3], device="mps", dtype=torch.float32)
xp = array_api_compat.get_namespace(x)
# side quest: is there a better way to get the torch namespace?
x = xp.asarray([1, 2, 3], device="mps", dtype=torch.float32)

# Maybe `can_cast` is the right tool?
xp.can_cast(xp.float32, xp.float64)  # -> True
xp.can_cast(x, xp.float64)  # -> True
```
Presumably the two calls to `can_cast` return `True` because PyTorch supports `float64` in general and the implementation of `can_cast` does not inspect the device of `x`? So at least for now / as it is currently implemented, I think `can_cast` is not the right tool for finding out whether `float64` exists (and using `float32` if not), or for making my own `highest_precision_float()`.
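In the meantime, a rough workaround might look like the sketch below: probe the device by creating a tiny `float64` array and fall back to `float32` if that fails. `highest_precision_float` is just a name made up for this issue, and the sketch assumes `asarray` accepts the `device` keyword in the compat namespace.

```python
import torch
import array_api_compat


def highest_precision_float(xp, device=None):
    """Return xp.float64 if it can actually be created on `device`, else xp.float32.

    Probe-based sketch, not an existing API: try to materialise a tiny
    float64 array on the target device and fall back if the library refuses.
    """
    try:
        xp.asarray([0.0], dtype=xp.float64, device=device)
    except Exception:
        return xp.float32
    return xp.float64


x = torch.tensor([1, 2, 3], device="mps", dtype=torch.float32)
xp = array_api_compat.get_namespace(x)
print(highest_precision_float(xp, device="mps"))  # float32 on MPS, float64 on CPU/CUDA
```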
betatim changed the title from "Discovering support float types per namespace and device" to "Discovering supported float types per namespace and device" on Aug 29, 2023
I believe with #640 this would get resolved! The proposed `xp.__array_namespace_info__()` would return an info object with a `dtypes()` method that returns a list of (supported) dtype objects, so one could call `info.dtypes(device=...)` to get the supported dtypes. The `kind` kwarg could restrict this to just floats, e.g. `info.dtypes(device=..., kind="real floating")`.
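A sketch of how that could look for this use case, assuming the API lands roughly as proposed in #640 (the `__array_namespace_info__` / `dtypes()` names and return value are taken from the proposal, not a shipped implementation):

```python
import torch
import array_api_compat

x = torch.tensor([1, 2, 3], device="mps", dtype=torch.float32)
xp = array_api_compat.get_namespace(x)

# Hedged sketch based on the #640 proposal, not an existing API.
info = xp.__array_namespace_info__()
float_dtypes = info.dtypes(device="mps", kind="real floating")

# If dtypes() ends up returning a mapping of names to dtype objects rather
# than a list, the membership test would need float_dtypes.values() instead.
dtype = xp.float64 if xp.float64 in float_dtypes else xp.float32
```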
You probably don't really want the highest-precision floating-point dtype, since NumPy's long double dtypes are not great. Rather, I think you want "`float64` if it is available, and the next-lowest-precision dtype otherwise". Correct?
JAX may be a problem here, since `float64` may exist but, when used, may return arrays with dtype `float32`. I think you'll have to special-case that for the time being. #640 (RFC: add a unified inspection API namespace) may indeed fix it, but that'll take quite a while to materialize, I think.
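For reference, one way that special-casing could look (a sketch, assuming default JAX settings where `jax_enable_x64` is off): check the dtype of an array you actually get back, since JAX downcasts silently rather than raising.

```python
import jax.numpy as jnp

# With jax_enable_x64 off (the default), jnp.float64 exists but JAX silently
# returns float32 arrays, so a try/except probe never triggers. Checking the
# dtype of the array that comes back does catch it.
probe = jnp.asarray([0.0], dtype=jnp.float64)
best = jnp.float64 if probe.dtype == jnp.float64 else jnp.float32
print(best)  # float32 unless jax_enable_x64 has been enabled
```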
+1 to what @honno wrote - this use case is extra support for implementing that introspection API.