Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

More fixes for 2023.12 support #166

Merged
merged 29 commits into from
Sep 30, 2024
Merged

More fixes for 2023.12 support #166

merged 29 commits into from
Sep 30, 2024

Conversation

asmeurer
Copy link
Member

@asmeurer asmeurer commented Jul 29, 2024

#127

Fixes #152

@asmeurer asmeurer mentioned this pull request Jul 29, 2024
18 tasks
- Ensure the arrays that are created are created on the same device as x.
  (fixes data-apis#177)

- Make clip() work with dask.array. The workaround avoid uint64 -> float64
  promotion does not work here. (fixes data-apis#176)

- Fix loss of precision when clipping a float64 tensor with torch due to the
  scalar being converted to a float32 tensor.
I'm not sure if all the details here are correct. See
data-apis#127 (comment).
Some of these things have to be inspected manually, and I'm not completely
certain everything here is correct.
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm hoping someone with more knowledge of pytorch can review this. I've made various assumptions here, which I'll try to outline in comments below. If anyone can confirm whether those assumptions are valid, that would be helpful.

"""
return {
"boolean indexing": True,
"data-dependent shapes": True,
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm assuming boolean indexing and data-dependent shapes (i.e., functions like unique) always work in pytorch.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you clarify what you mean by always work? unique and nonzero will generally work in PyTorch the same way they do for NumPy, considering normal eager execution.

With the compiler, maybe depending on the opinions chosen, is a graph break considered working?

Complex dtypes will not work, but not due to issues with data dependent shapes, it is because our implementation will unconditionally sort the input.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I hadn't even considered the torch compiler. But I think as long as it functions without an error, that should be considered working (if that's not the case, then it might actually be worth flipping this flag in a compiler context, assuming that's easy).

These flags exist for libraries like dask or JAX that straight up don't allow these types of operations.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In that case I would say you are probably correct the way you have it. Those are intended to be supported and lack of support/conditional support in the compile context is considered a deficiency.

'cpu'

"""
return torch.device("cpu")
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm assuming the default device in pytorch is always "cpu" (note there is currently an unfortunate distinction between the "default device" and "current device" in the standard. See data-apis/array-api#835). By "default device", I mean the device that is used by default when pytorch is first started.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this is correct.

'indexing': torch.int64}

"""
default_floating = torch.get_default_dtype()
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I'm understanding the docs correctly, this function will always give the default floating-point dtype (i.e., what is used by default for something like torch.ones() with no dtype argument).

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this is true. However that "default" in torch is not static. It can be changed by the user, so that would be the "current default" and the "default default" would be torch.float32

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a note to myself: I checked and torch does correctly fail if you set the default dtype to float64 and try to create a tensor on a device that doesn't support float64:

>>> torch.set_default_dtype(torch.float64)
>>> torch.asarray(0., device='mps')
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.

(unlike certain other libraries that silently map float64 back to float32)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so that would be the "current default" and the "default default" would be torch.float32

I'm not sure which this should use then. I think it should be the "current default", but the meaning of "default" is ambiguous. I mentioned this at data-apis/array-api#835 (comment)


"""
default_floating = torch.get_default_dtype()
default_complex = torch.complex64 if default_floating == torch.float32 else torch.complex128
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The docs explicitly state that the default complex dtype always matches the floating dtype. https://pytorch.org/docs/stable/generated/torch.set_default_dtype.html

"""
default_floating = torch.get_default_dtype()
default_complex = torch.complex64 if default_floating == torch.float32 else torch.complex128
default_integral = torch.asarray(0, device=device).dtype
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a way to access this that doesn't require creating a tensor?

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can also just hard code it, we don't have a default for integers internally. If this were to change it would break bc and would need to be updated everywhere an integer tensor is produced, manually.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You mean it's always int64? I wasn't sure if this would be different for certain devices.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't have a concept of the default integer type. It would always need to be provided or the type is deduced from the argument that is what happens in this case, python ints can be larger than int32 so we use int64.

I mean that I do not know of any mechanism for changing the behavior you see here. Our deduction rules for creation from a non array type can be seen here. If the device does not support int64 I would expect you to see an error like you do with MPS and float64 (which I did not know would happen, I expect they had to do some work to make sure that error gets raised.)

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that that layer of code is going to run "above" the dispatch to device specific logic (as in closer to python), so there is no way for the device to influence how that type is deduced, the device will be involved when the underlying tensor object is constructed where it will only see that a dtype argument has been provided.

"real floating": default_floating,
"complex floating": default_complex,
"integral": default_integral,
"indexing": default_integral,
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm assuming the default indexing type is always the same as the default integer dtype (the default indexing type is a concept in the array API for functions that return indexing arrays. For example, nonzero should return an array with the default indexing type https://data-apis.org/array-api/latest/API_specification/generated/array_api.nonzero.html)

uint8 = getattr(torch, "uint8", None)
uint16 = getattr(torch, "uint16", None)
uint32 = getattr(torch, "uint32", None)
uint64 = getattr(torch, "uint64", None)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are these the only dtypes that can be undefined (I know newer torch versions have them, but I want to make sure older versions work here too).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also I know the support for some of these is limited. Would it be more correct to always omit them, even when they are technically defined here? They aren't really fully supported from the point of view of the array API standard.

Copy link

@amjames amjames Sep 12, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would say that depends on the purpose of the declaration. If it is to list the definitions for data types provided by the library leave them. If the point is to declare the data types supported by the array api then drop them.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I should probably remove them then. Too many array API things don't actually work with them, and this API is supposed to be a better way to check that than hasattr(xp, 'uint64').

del res[k]
continue
try:
torch.empty((0,), dtype=v, device=device)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this the best way to test if a dtype is supported on a given device?

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What defines supported?

If dtype and device are valid this should always work, storage is untyped so the dtype is only used to compute the number of bytes needed.

Any given operator may or may not correctly dispatch for the underlying dtype. If the type is not explicitly handed for a given operator it will throw.

return res
raise ValueError(f"unsupported kind: {kind!r}")

@cache
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm assuming it's safe to cache the output of this function (it's potentially expensive since it constructs tensors on the given device to test if a dtype is supported).

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is currently no way to define a new dtype than adding an entry to a enum in the source code.

del res[k]
return res

@cache
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm assuming the set of legal devices never changes at runtime and can be cached.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can use hooks to register a custom device backend dynamically at runtime. I am unsure if this will simply add a new accepted device type string, or if the privateuseone string is overwritten by it.

# currently supported devices. To do this, we first parse the error
# message of torch.device to get the list of all possible types of
# device:
try:
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the big one. This is code is based on some discussions with @pearu. To get the list of possible torch devices, we first parse an error message, then check which of those devices actually work. If there is any better way of doing this, please let me know. I would definitely prefer if this functionality were built-in to pytorch (ditto for the other functions here too).

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is unfortunate. I think there might be a less expensive, but more ugly way to do this. Would adding something to surface this info even help, or would you still need to have something like this to support older versions?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Adding something would definitely be helpful. Note that ideally, pytorch will implement this exact API here, since it's part of the array API.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I noticed that pytorch sort of has APIs to do this better, e.g., https://pytorch.org/docs/stable/generated/torch.mps.device_count.html#torch.mps.device_count, but they are not consistent across all device types, and I didn't want to hard-code all the possible device types here since torch seems to support a lot of them.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: I think that API should exist for optional devices. And you can find the module programmatically with torch.get_device_module(<name>). The always available devices are of course special, but I can look into handling this on the pytorch side.

'int64': cupy.int64}

"""
# TODO: Does this depend on device?
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@leofang is it possible for cupy to not support some of the array API dtypes depending on a given device?

@asmeurer
Copy link
Member Author

So a couple of small things are still missing here #127. Most notable is repeat which is missing for PyTorch and has some minor issues with NumPy. There's also some concerns with default_device in the inspection API which might end up changing (see data-apis/array-api#835). But there are quite a few changes here so I want to get them merged and released. I will hold off on bumping the __array_api_version__ for now, though.

@asmeurer asmeurer merged commit b30a59e into data-apis:main Sep 30, 2024
43 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Result of numpy.sum with float32 input is float64
2 participants