How to infer appropriate dtype from uint to int and float to complex? #859

Open
34j opened this issue Nov 25, 2024 · 3 comments

Comments

@34j

34j commented Nov 25, 2024

I would like to compute $f(x) := xi$ and $g(y) := y - 1$ using the array API, where $i$ is the imaginary unit, $x$ is a floating-point array, and $y$ is an unsigned-integer array. However, I am not sure what the best way to implement this is. Following the type promotion rules, I came up with:

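```python
import array_api_strict as xp  # or any other array API namespace

def f(x):
    # float32 -> complex64, float64 -> complex128
    return x * xp.asarray(1j, dtype=xp.complex64 if x.dtype == xp.float32 else xp.complex128)

def g(x):
    # uint8 -> int16, uint16 -> int32, otherwise int64
    return x - xp.asarray(1, dtype=xp.int16 if x.dtype == xp.uint8 else xp.int32 if x.dtype == xp.uint16 else xp.int64)
```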

This seems overly verbose. What is the proper way to do this?

@lucascolley
Contributor

See gh-841.

@asmeurer
Member

That issue covers the complex case. Once it is fixed, x*1j should work as you would expect when x has a floating-point dtype.

For the second cast, the standard works like this:

```python
>>> import array_api_strict as xp
>>> xp.asarray([0], dtype=xp.uint8) - 1
Array([255], dtype=array_api_strict.uint8)
```

The difference is that your example would make the result dtype int16. The type promotion rules for scalars say that a Python scalar is first cast to the dtype of the array, so `<uint8 array> - <python int>` always produces a uint8 array. In this particular case, the result is not actually defined: see the note at https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__neg__.html#neg, together with the note saying `x - y == x + (-y)` at https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__sub__.html#sub. (Small nit: `__sub__` should probably state this more directly, since the result is only undefined when `x - y` is negative.)

The rule that scalars always cast to the same dtype as the array is not something that should change, so you'd want some other way to spell y - 1 so that it upcasts to a signed dtype.

I think the astype suggestion in #841 (comment) would be the cleanest way to do this. If it were implemented, you could write

```python
xp.astype(x, 'signed') - 1
```

where `xp.astype(<uint8 array>, 'signed')` would return an int16 array (and would raise an error for uint64, but that's always a tricky dtype to deal with in the context of type promotion).

The astype improvements idea should be split out into its own issue. I doubt it would be implemented for the 2024 standard release, since it hasn't even been fleshed out yet (though it's not impossible). The complex-scalar-to-float-array issue will definitely be fixed for 2024.

@lucascolley
Contributor

> I think the astype suggestion in #841 (comment) would be the cleanest way to do this. If it were implemented

> I doubt it would be implemented for the 2024 standard release, since it hasn't even been fleshed out yet (though it's not impossible)

gh-848 😉 (are there complexities which I haven't thought about? A review would be appreciated!)
