How to specify which dimension to reduce along when using sum? #613

Closed
mspronesti opened this issue Jan 18, 2023 · 2 comments

mspronesti commented Jan 18, 2023

Basically, I want to convert torch.sum(tensor, 1). tch::Tensor has a sum(dtype) method, but it doesn't let you specify which dimension to reduce along: it reduces along all of them, as torch.sum does when no dim is passed.
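To make the distinction concrete, here is a minimal plain-Rust sketch of the two reductions on a 2-D array. It does not use tch at all: the nested `Vec` stands in for a tensor, and the helper names `sum_all` and `sum_dim1` are made up for illustration.

```rust
/// Sum every element: the analogue of a reduction over all dimensions.
fn sum_all(rows: &[Vec<i64>]) -> i64 {
    rows.iter().flatten().sum()
}

/// Sum along dimension 1: each row collapses to a single value,
/// so a 2x3 input yields 2 outputs.
fn sum_dim1(rows: &[Vec<i64>]) -> Vec<i64> {
    rows.iter().map(|row| row.iter().sum::<i64>()).collect()
}
```

For the input `[[1, 2, 3], [4, 5, 6]]`, `sum_all` returns `21`, while `sum_dim1` returns `[6, 15]`.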

I guess the correct method to call is sum_dim_intlist. However, dim must implement Into<Option<&'a [i64]>>, which forces the caller to pass a slice. Is there anything simpler than this call?

some_tensor.sum_dim_intlist([1].as_slice(), false, Kind::Int64)

LaurentMazare (Owner) commented

I think you're right, and that's the appropriate way to do it. I don't know of a simpler way right now, but I've just prototyped in #621 a way to simplify this, which would allow one to write the following:

some_tensor.sum_dim_intlist(1, false, Kind::Int64)

However, this is a bit of a breaking change, so we probably want to think about it a bit more (beyond just making it a major revision change, which is already the case).
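As a rough illustration of how an API can accept either a single dimension or a slice of dimensions, here is a self-contained sketch using a conversion trait. This is not the code from #621 and not tch's actual implementation: the trait `IntoDims` and the function `reduced_dims` are hypothetical names, and the function only reports which dimensions it was asked to reduce.

```rust
/// Hypothetical conversion trait: anything that can name one or more dimensions.
trait IntoDims {
    fn into_dims(self) -> Vec<i64>;
}

/// A single integer is treated as a one-element list of dimensions.
impl IntoDims for i64 {
    fn into_dims(self) -> Vec<i64> {
        vec![self]
    }
}

/// A slice is copied as-is, so existing slice-based callers keep working.
impl IntoDims for &[i64] {
    fn into_dims(self) -> Vec<i64> {
        self.to_vec()
    }
}

/// Stand-in for a reduction method (assumption: a real method would
/// pass these dimensions on to the underlying library call).
fn reduced_dims(dims: impl IntoDims) -> Vec<i64> {
    dims.into_dims()
}
```

With this shape, both `reduced_dims(1_i64)` and `reduced_dims([0_i64, 2].as_slice())` compile against the same function. Changing a parameter's bound like this can still break callers that relied on the old `Into<Option<&[i64]>>` bound, which fits the concern about a breaking change.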

LaurentMazare (Owner) commented

I've merged these changes via #682, so the following should work now.

some_tensor.sum_dim_intlist(1, false, Kind::Int64)

Hopefully this makes it simpler to write code that involves a reduction across a dimension. Feel free to re-open if you run into further issues.
