
Split TryConcatAlong into different traits #891

Closed · swfsql opened this issue Nov 16, 2023 · 0 comments · Fixed by #892

swfsql (Contributor) commented Nov 16, 2023

This issue is a request to separate the TryConcatAlong trait into two different traits: one for Tensors and another for Shapes.

edit: A small example on disc.


I'm experimenting with creating structs to represent tensor operations, to be used as Modules inside model definitions. But when trying to create a Module that represents tensor concatenation, the Rust type/trait system goes haywire on the trait bounds. Usually the shape of a tensor is a generic parameter that's actually a "shape", but when doing something like:

impl<A, B, Ax, E: Dtype, D, T: Tape<E, D>, R: Tape<E, D>>
    Module<(Tensor<A, E, D, T>, Tensor<B, E, D, R>)> for ConcatTensorAlong<Ax>
where
    (A, B): TryConcatAlong<Ax>, // <- this line is problematic
    // etc
{
    type Output = /**/;

    fn try_forward(
        &self,
        x: (Tensor<A, E, D, T>, Tensor<B, E, D, R>),
    ) -> Result<Self::Output, Error> {
        // etc
    }
}

In this case, ConcatTensorAlong would be the struct representing a Module.
But when I add the line marked above as "problematic"*, the Rust type system (for some reason) insists on considering A and B to be yet other Tensors, so it tries to verify the trait bounds of Tensor<Tensor<Tensor<...>>> recursively until it crashes.

*but only if I actually try to use ConcatTensorAlong inside a Model and call forward on it.
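
For context on why the solver can recurse here: the single TryConcatAlong trait is implemented both for pairs of shapes and for pairs of tensors, with the tensor impl bounded by the shape case. Below is a trimmed-down sketch of that structure (hypothetical, simplified signatures; the real dfdx trait also threads the Dtype, device, and tape parameters):

pub struct Axis<const I: isize>;
pub struct Tensor<S>(pub S);

pub trait TryConcatAlong<Ax> {
    type Output;
}

// Shape-level case: two rank-2 shapes concatenated along axis 0.
// (dfdx generates one such impl per tuple arity via a macro.)
impl TryConcatAlong<Axis<0>> for ((usize, usize), (usize, usize)) {
    type Output = (usize, usize);
}

// Tensor-level case on the *same* trait, bounded by the shape-level case.
impl<A, B, Ax> TryConcatAlong<Ax> for (Tensor<A>, Tensor<B>)
where
    (A, B): TryConcatAlong<Ax>,
{
    type Output = Tensor<<(A, B) as TryConcatAlong<Ax>>::Output>;
}

With `(A, B): TryConcatAlong<Ax>` as a bound on still-unresolved generic parameters, the tensor impl is always a candidate: the solver can unify A with Tensor<_>, and that impl's own where-clause then re-poses the same question one level deeper, nesting Tensors until the recursion limit is hit.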
This is the error I get:

error[E0275]: overflow evaluating the requirement `((_, _, _, _), (..., ..., ..., ...)): dfdx::prelude::TryConcatAlong<...>`
   --> src/c4/w3/pa_02_image_segmentation.rs:211:35
    |
211 |         let prediction: _ = model.forward(x);
    |                                   ^^^^^^^
    |
    = help: consider increasing the recursion limit by adding a `#![recursion_limit = "256"]` attribute to your crate (`coursera_exercises`)
    = note: required for `(dfdx::prelude::Tensor<(_, _, _, _), _, _, _>, dfdx::prelude::Tensor<(_, _, _, _), _, _, _>)` to implement `dfdx::prelude::TryConcatAlong<dfdx::prelude::Axis<1>>`
    = note: 123 redundant requirements hidden
    = note: required for `(Tensor<Tensor<Tensor<Tensor<..., ..., ..., ...>, ..., ..., ...>, ..., ..., ...>, ..., ..., ...>, ...)` to implement `dfdx::prelude::TryConcatAlong<dfdx::prelude::Axis<1>>`
    = note: the full type name has been written to '/workspaces/coursera-deep-learning-specialization/r/target/debug/deps/coursera_exercises-ec2247d6c3cc6475.long-type-15920087071787293129.txt'
    = note: required for `dfdx::nn::ops::ConcatTensorAlong<dfdx::prelude::Axis<1>>` to implement `dfdx::nn::Module<(dfdx::prelude::Tensor<dfdx::prelude::Tensor<dfdx::prelude::Tensor<...>, _, _, _>, _, _, _>, ...)>` [hundreds of levels of Tensor nesting elided]

And well, increasing the recursion limit doesn't help. At first I didn't read the error closely, assumed that my model was simply too big, and increased the recursion limit to 1024, but the Rust compiler itself crashed instead.
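
(For reference, the attribute the compiler's help message suggests goes at the crate root; here it only postpones the overflow:)

// at the top of main.rs / lib.rs; raising this only defers the
// overflow when the candidate recursion is unbounded
#![recursion_limit = "1024"]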

I've noticed that if TryConcatAlong is separated into two different traits, one for Tensor and another for Shape, and the trait bounds around the code are adjusted accordingly, Rust no longer crashes for that kind of Module definition.
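
A minimal sketch of what the split could look like, continuing the simplified types from the sketch above (the trait names here are assumptions following the issue title; the actual split landed in #892):

pub trait TryConcatShapeAlong<Ax> {
    type Output;
}

pub trait TryConcatTensorAlong<Ax> {
    type Output;
}

// Shapes only implement the shape-level trait...
impl TryConcatShapeAlong<Axis<0>> for ((usize, usize), (usize, usize)) {
    type Output = (usize, usize);
}

// ...and tensors only implement the tensor-level trait, bounded by it.
impl<A, B, Ax> TryConcatTensorAlong<Ax> for (Tensor<A>, Tensor<B>)
where
    (A, B): TryConcatShapeAlong<Ax>,
{
    type Output = Tensor<<(A, B) as TryConcatShapeAlong<Ax>>::Output>;
}

The Module impl can then bound on `(A, B): TryConcatShapeAlong<Ax>`; since no impl of the shape-level trait mentions Tensor, the solver can no longer guess that A is itself a tensor, and the recursion disappears.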

@swfsql swfsql changed the title Separate TryConcatAlong into different traits Split TryConcatAlong into different traits Nov 16, 2023