einops-like packing notation #180
What you can do today

So it's a little less elegant, but you can do this today via

```python
def unpack(x: Float[Array, "B C H_W"]) -> Float[Array, "B C H (H_W//H)"]:
    ...  # magic unpacking
    return y
```

where the …

The reverse direction is a bit neater, as the logic (which should typically go on the RHS) is easier to read:

```python
def pack(x: Float[Array, "B C H W"]) -> Float[Array, "B C H*W"]:
    ...  # magic packing
    return y
```

Note that this works because in each case the output shape is a function of the input shape, rather than the other way around!

What we could do tomorrow

First of all, just on the syntax: I think if we were to support something like this, then I'd probably suggest using the syntax … And in fact, we can actually already write this:

```python
def unpack(x: Float[Array, "B C H*W"]) -> Float[Array, "B C H W"]:
    ...  # magic unpacking
    return y
```

but this would raise an error if you were to do runtime type-checking, as it won't yet have seen `H` or `W` when it checks the argument `x`.

So if we were to change anything, I think it would probably be to (a) allow such "incomplete" annotations when checking the arguments, and then (b) go back and check them all again after the function has finished running and we have its output. WDYT?
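The (a)/(b) idea can be sketched without any framework. This is a minimal sketch, not jaxtyping's actual internals: it assumes annotations are plain space-separated strings, that `*` marks a packed axis, and that only a single argument is checked. `check_call`, `ShapeError` and the underscore helpers are hypothetical names.

```python
import math


class ShapeError(TypeError):
    """Raised when a shape does not match its annotation (hypothetical)."""


Pending = list[tuple[str, int]]


def _resolve(pending: Pending, memo: dict[str, int]) -> Pending:
    """Check each packed dim (e.g. "H*W") whose factors are all bound; return the rest."""
    unresolved = []
    for dim, size in pending:
        factors = dim.split("*")
        if all(f in memo for f in factors):
            expected = math.prod(memo[f] for f in factors)
            if expected != size:
                raise ShapeError(f"{dim}: expected {expected}, got {size}")
        else:
            unresolved.append((dim, size))
    return unresolved


def _bind_and_check(spec: str, shape: tuple[int, ...], memo: dict[str, int]) -> Pending:
    """Match `shape` against a spec such as "B C H*W".

    Plain names are bound in `memo` (or compared with an existing binding);
    packed dims that cannot be checked yet are returned instead of failing.
    """
    dims = spec.split()
    if len(dims) != len(shape):
        raise ShapeError(f"rank mismatch: {spec!r} vs {shape}")
    pending = []
    for dim, size in zip(dims, shape):
        if "*" in dim:
            pending.append((dim, size))
        elif memo.setdefault(dim, size) != size:
            raise ShapeError(f"{dim}: expected {memo[dim]}, got {size}")
    return _resolve(pending, memo)


def check_call(arg_spec: str, arg_shape: tuple[int, ...],
               out_spec: str, out_shape: tuple[int, ...]) -> dict[str, int]:
    """Simulate checking one argument and the return value of a function."""
    memo: dict[str, int] = {}
    # (a) Check the argument, tolerating "incomplete" packed dims such as H*W.
    deferred = _bind_and_check(arg_spec, arg_shape, memo)
    # ... the function body would run here and produce an output of `out_shape` ...
    # (b) Check the output, then re-check everything deferred from step (a).
    deferred += _bind_and_check(out_spec, out_shape, memo)
    unresolved = _resolve(deferred, memo)
    if unresolved:
        raise ShapeError(f"could not resolve {[dim for dim, _ in unresolved]}")
    return memo


print(check_call("B C H*W", (2, 3, 12), "B C H W", (2, 3, 3, 4)))
# {'B': 2, 'C': 3, 'H': 3, 'W': 4}
```

Deferring rather than failing is the key design choice here: the argument check leaves `H*W` unchecked, the output binds `H` and `W`, and the final pass either confirms `3 * 4 == 12` or raises. A packed dim that is still unresolved after the output check is treated as an error.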
Cool! Thanks, it's great that this is already possible! I think this will already be useful enough for me. The second option is perhaps a bit nicer if it is not too difficult to add, i.e., doing …
I'd be happy to add the second notation; I'd just have to ask for a PR on it, as I don't have the time to implement this myself :D If you or anyone else feels strongly about this, then I'd be happy to explain how to tweak the jaxtyping internals to accomplish this.
For posterity: I'm also short on time at the moment due to teaching. (Anybody reading this thread: feel free to take a stab at this!)
Hey @patrick-kidger,
I'm wondering how hard it would be to have einops-like notation for packed axes. For example, … would indicate that the last axis is a flattened version of the height and width axes.

This means that if you have the full signature as: …

then jaxtyping would check that `y.shape[2] * y.shape[3] == x.shape[2]`.

Note that in many cases it would not be able to confirm `H` and `W` individually. I think that is okay; they're just free variables. But if it can confirm the individual shapes, then it can do the type check. What do you think? Does this make sense?