RFC: Consider updating `copy` semantics in `astype` to `False`/`None`/`True` #788
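The default specification behavior follows NumPy and its derivatives. According to the NumPy docs, `astype` always returns a newly allocated array by default, and the signature is `ndarray.astype(dtype, order='K', casting='unsafe', subok=True, copy=True)`. The default is `copy=True`.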
If NumPy is the motivation for this, then I think there may be a problem: NumPy has explicitly decided to differ from the Array API in this respect; see https://numpy.org/neps/nep-0056-array-api-main-namespace.html#copy-keyword-semantics
My reading of that NEP is not that NumPy is choosing to differ. NumPy is following the spec where [...]. However, the situation is actually reversed: we based the current spec on what NumPy and its kin did at the time, which was to always copy by default.
I see, thanks.
For the default, yes. NumPy does support `copy=None` in `astype`:

```python
>>> import numpy as np
>>> x = np.ones((3, 2), dtype=np.float32)
>>> y = x.astype(dtype=np.float32, copy=None)  # same dtype: no copy needed
>>> y is x
True
>>> y = x.astype(dtype=np.float64, copy=None)  # dtype change: a new array
>>> y is x
False
```

I'm not sure whether this was left out of the standard's version of `astype`.
In NumPy's implementation, do [...]
Yes. But that's only because there was resistance to changing [...]
Maybe it'll be changed on a longer timescale; I don't know.
Although the historical creation of the [...]

To summarize: [...]
I think you can get "copy never" behavior in the current standard using [...]
I am trying to figure out why this matters for JAX, since JAX never actually has to make a physical memory copy: by design (immutability), it guarantees that there is no observable difference between a copy and the original. My understanding is that JAX had a [...]

This has now come up several times, so we really should make it clearer. The first time, I believe, was at #495 (comment). More recently, we had a more extensive discussion of this in dmlc/dlpack#136 for DLPack. For DLPack the question is about exchange between two libraries rather than semantics within a single library, so we put more effort into defining what "copy" actually means. In summary, the [...]
True copies are sometimes important in JAX. For example, functions can be called with donated buffers (whose memory the compiler is free to reuse), and a donated buffer cannot be used in subsequent function calls. If a user wants a copy of an array for this purpose, we currently recommend [...]

In recent work to make JAX compatible with the Array API, we've been trying to understand the recommended semantics of the [...]
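To make the donation argument concrete, here is a pure-Python toy model. It is not JAX code: `Buffer`, `donating_call`, and `copy_like` are hypothetical names, and donation is simulated with a flag. It illustrates why a no-op "copy" (returning the input unchanged when `copy=None`) is unsafe to donate: the result aliases the original, so invalidating one invalidates both.

```python
# Hypothetical toy model of buffer donation; this is NOT JAX's API.
class Buffer:
    """Stand-in for a device array whose memory can be donated."""

    def __init__(self, data):
        self.data = list(data)
        self.deleted = False

    def read(self):
        if self.deleted:
            raise RuntimeError("buffer was donated and can no longer be used")
        return list(self.data)


def donating_call(buf):
    """Simulate calling a function with a donated argument: after the call,
    the donated buffer is invalidated because its memory may be reused."""
    result = [v * 2 for v in buf.read()]
    buf.deleted = True
    return result


def copy_like(buf, copy):
    """Hypothetical helper: copy=True always returns a fresh Buffer, while
    copy=None (or False) returns the input itself when no copy is needed."""
    if copy:
        return Buffer(buf.read())
    return buf


x = Buffer([1, 2, 3])

# A true copy keeps x usable after the copy has been donated.
donating_call(copy_like(x, copy=True))
print(x.read())  # [1, 2, 3]

# A no-op "copy" aliases x, so donating it invalidates x as well.
donating_call(copy_like(x, copy=None))
try:
    x.read()
except RuntimeError as err:
    print("x is unusable:", err)
```

In this model, only `copy=True` gives the caller something that is safe to hand over for donation while keeping the original alive.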
Currently the specification for `astype`, as added in #290, specifies `copy=False/True`. There is no room for a "copy never" option, and the default `copy=True` means that calls to `astype` that do not specify the `copy` kwarg never have a chance to be a no-op; this results in unnecessary copies as the default behavior.

Other functions which use the `copy=False/None/True` semantics include:

- `asarray`
- `__dlpack__`
- `from_dlpack`
- `reshape`
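The three-valued `copy` keyword these functions share can be summarized as a small decision rule. The sketch below is a hypothetical helper (`must_copy` is not part of the standard), and `copy_needed` stands in for whatever check a library actually performs, such as a dtype change or an incompatible memory layout. It follows the standard's wording for `asarray`: `True` always copies, `None` reuses existing memory when possible, and `False` never copies and raises `ValueError` when a copy would be required.

```python
from typing import Optional


def must_copy(copy: Optional[bool], copy_needed: bool) -> bool:
    """Hypothetical resolver for the tri-state ``copy`` keyword.

    copy=True  -> always copy
    copy=None  -> copy only if a copy is needed
    copy=False -> never copy; raise ValueError if a copy would be needed
    """
    if copy is True:
        return True
    if copy is None:
        return copy_needed
    if copy is False:
        if copy_needed:
            raise ValueError("a copy is required, but copy=False was requested")
        return False
    raise TypeError(f"copy must be True, False, or None, got {copy!r}")


print(must_copy(True, copy_needed=False))   # True
print(must_copy(None, copy_needed=False))   # False
print(must_copy(None, copy_needed=True))    # True
```

The key property is that `False` is a hard guarantee ("copy never") rather than a preference, which is exactly what the current `astype` wording lacks.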
Hence:

- `copy=False` in `astype` implies a copy will still occur (whenever the dtype has to change), while elsewhere it means "copy never".
- `copy=True` is, in my opinion, a bad default that leads to more memory movement than may be necessary.
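The mismatch can be sketched side by side with a toy array type. `ToyArray`, `astype_current`, and `astype_proposed` are hypothetical names. The proposed version assumes that `astype` would adopt the same `False`/`None`/`True` rules as `asarray`, with `copy=None` as the new default; that is one reading of this RFC, not settled behavior.

```python
from typing import Optional


class ToyArray:
    """Minimal stand-in for an array: a list of values plus a dtype name."""

    def __init__(self, values, dtype):
        self.values = list(values)
        self.dtype = dtype


def _cast(x, dtype):
    # In this sketch, casting always allocates a new ToyArray.
    convert = {"int": int, "float": float}[dtype]
    return ToyArray([convert(v) for v in x.values], dtype)


def astype_current(x, dtype, copy: bool = True):
    """Current spec: copy=True always copies; copy=False avoids a copy only
    when the dtype already matches, and still copies when a cast is needed."""
    if not copy and x.dtype == dtype:
        return x
    return _cast(x, dtype)


def astype_proposed(x, dtype, copy: Optional[bool] = None):
    """Assumed proposal: asarray-style False/None/True, default None."""
    needs_cast = x.dtype != dtype
    if copy is False and needs_cast:
        raise ValueError("casting requires a copy, but copy=False was given")
    if copy is True or needs_cast:
        return _cast(x, dtype)
    return x  # no cast needed and no copy forced: a no-op


x = ToyArray([1, 2], "int")
print(astype_current(x, "int") is x)                # False: default copy=True copies
print(astype_proposed(x, "int") is x)               # True: no-op under copy=None
print(astype_current(x, "float", copy=False) is x)  # False: copy=False still copies
```

Under the current rules, `copy=False` silently copies when a cast is needed; under the assumed proposal it raises instead, and the default becomes a no-op whenever the dtype already matches.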