-
-
Notifications
You must be signed in to change notification settings - Fork 101
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Split TryConcatAlong
into different traits
#892
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Some tiny changes, but otherwise looks good and I was able to verify that the current main doesn't compile the test and the fix does 👍
@swfsql looks good - do you mind merging with the webgpu changes I just merged? Should be pretty straightforward 🤞 No need to test with webgpu features yet |
- Deprecated `TryConcatAlong` in favor of `TryConcatTensorAlong` or `TryConcatShapeAlong`. - Created `concat_tensor_along/` and `concat_shape_along/`. - Copied relevant sections and files from `concat_along`, adjusting where necessary. - Moved `concat_along/` kernels to `concat_tensor_along/`. - Adjusted the issue's integration test to the new trait, which runs successfully.
Co-authored-by: Corey Lowman <clowman1993@gmail.com>
@coreylowman sure np, I've rebased and basically just moved the webgpu kernel to where the others are. I've made one change, remaking this item as public so it's the same behavior as from before.
|
Woohoo, thanks for this change! 🎉 |
Closes #891.
dfdx_core
. But there's no Module indfdx::nn
representing tensor concatenation, so I've added as an integration test.TryConcatAlong
in favor ofTryConcatTensorAlong
orTryConcatShapeAlong
.concat_tensor_along/
andconcat_shape_along/
.concat_along
, adjusting where necessary.concat_along/
kernels toconcat_tensor_along/
.The added test, that initially failed to compile:
dfdx/dfdx/tests/issue_tests.rs
Lines 4 to 44 in b292254