Skip to content

Commit

Permalink
Merge pull request #835 from LaurentMazare/index_result
Browse files Browse the repository at this point in the history
index result
  • Loading branch information
LaurentMazare authored Jan 6, 2024
2 parents a5f6ea8 + 7f41f93 commit ca92737
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 17 deletions.
61 changes: 44 additions & 17 deletions src/tensor/index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
//! shape mismatch error due to advanced indexing rule. Another distinction
//! is that `i` guarantees the input and result tensor shares the same
//! underlying storage, while NumPy may copy the tensor in certain scenarios.
use crate::{TchError, Tensor};
use crate::{Result, TchError, Tensor};
use std::ops::{
Bound, Range, RangeBounds, RangeFrom, RangeFull, RangeInclusive, RangeTo, RangeToInclusive,
};
Expand Down Expand Up @@ -133,14 +133,19 @@ impl_from_range!(RangeToInclusive<i64>);

pub trait IndexOp<T> {
fn i(&self, index: T) -> Tensor;
fn f_i(&self, index: T) -> Result<Tensor>;
}

impl<A> IndexOp<A> for Tensor
where
A: Into<TensorIndexer>,
{
fn i(&self, index: A) -> Tensor {
self.indexer(&[index.into()])
self.f_i(index).unwrap()
}

fn f_i(&self, index: A) -> Result<Tensor> {
self.f_indexer(&[index.into()])
}
}

Expand All @@ -149,8 +154,12 @@ where
A: Into<TensorIndexer>,
{
fn i(&self, index: (A,)) -> Tensor {
self.f_i(index).unwrap()
}

fn f_i(&self, index: (A,)) -> Result<Tensor> {
let idx_a = index.0.into();
self.indexer(&[idx_a])
self.f_indexer(&[idx_a])
}
}

Expand All @@ -160,9 +169,13 @@ where
B: Into<TensorIndexer>,
{
fn i(&self, index: (A, B)) -> Tensor {
self.f_i(index).unwrap()
}

fn f_i(&self, index: (A, B)) -> Result<Tensor> {
let idx_a = index.0.into();
let idx_b = index.1.into();
self.indexer(&[idx_a, idx_b])
self.f_indexer(&[idx_a, idx_b])
}
}

Expand All @@ -173,10 +186,14 @@ where
C: Into<TensorIndexer>,
{
fn i(&self, index: (A, B, C)) -> Tensor {
self.f_i(index).unwrap()
}

fn f_i(&self, index: (A, B, C)) -> Result<Tensor> {
let idx_a = index.0.into();
let idx_b = index.1.into();
let idx_c = index.2.into();
self.indexer(&[idx_a, idx_b, idx_c])
self.f_indexer(&[idx_a, idx_b, idx_c])
}
}

Expand All @@ -188,11 +205,15 @@ where
D: Into<TensorIndexer>,
{
fn i(&self, index: (A, B, C, D)) -> Tensor {
self.f_i(index).unwrap()
}

fn f_i(&self, index: (A, B, C, D)) -> Result<Tensor> {
let idx_a = index.0.into();
let idx_b = index.1.into();
let idx_c = index.2.into();
let idx_d = index.3.into();
self.indexer(&[idx_a, idx_b, idx_c, idx_d])
self.f_indexer(&[idx_a, idx_b, idx_c, idx_d])
}
}

Expand All @@ -205,12 +226,16 @@ where
E: Into<TensorIndexer>,
{
fn i(&self, index: (A, B, C, D, E)) -> Tensor {
self.f_i(index).unwrap()
}

fn f_i(&self, index: (A, B, C, D, E)) -> Result<Tensor> {
let idx_a = index.0.into();
let idx_b = index.1.into();
let idx_c = index.2.into();
let idx_d = index.3.into();
let idx_e = index.4.into();
self.indexer(&[idx_a, idx_b, idx_c, idx_d, idx_e])
self.f_indexer(&[idx_a, idx_b, idx_c, idx_d, idx_e])
}
}

Expand All @@ -224,13 +249,17 @@ where
F: Into<TensorIndexer>,
{
fn i(&self, index: (A, B, C, D, E, F)) -> Tensor {
self.f_i(index).unwrap()
}

fn f_i(&self, index: (A, B, C, D, E, F)) -> Result<Tensor> {
let idx_a = index.0.into();
let idx_b = index.1.into();
let idx_c = index.2.into();
let idx_d = index.3.into();
let idx_e = index.4.into();
let idx_f = index.5.into();
self.indexer(&[idx_a, idx_b, idx_c, idx_d, idx_e, idx_f])
self.f_indexer(&[idx_a, idx_b, idx_c, idx_d, idx_e, idx_f])
}
}

Expand All @@ -245,19 +274,23 @@ where
G: Into<TensorIndexer>,
{
fn i(&self, index: (A, B, C, D, E, F, G)) -> Tensor {
self.f_i(index).unwrap()
}

fn f_i(&self, index: (A, B, C, D, E, F, G)) -> Result<Tensor> {
let idx_a = index.0.into();
let idx_b = index.1.into();
let idx_c = index.2.into();
let idx_d = index.3.into();
let idx_e = index.4.into();
let idx_f = index.5.into();
let idx_g = index.6.into();
self.indexer(&[idx_a, idx_b, idx_c, idx_d, idx_e, idx_f, idx_g])
self.f_indexer(&[idx_a, idx_b, idx_c, idx_d, idx_e, idx_f, idx_g])
}
}

impl Tensor {
fn f_indexer(&self, index_spec: &[TensorIndexer]) -> Result<Tensor, TchError> {
fn f_indexer(&self, index_spec: &[TensorIndexer]) -> Result<Tensor> {
use std::ops::Bound::*;
use TensorIndexer::*;

Expand Down Expand Up @@ -321,7 +354,7 @@ impl Tensor {
(Excluded(start), Included(end)) => Some((*start + 1, *end - *start)),
(Excluded(start), Excluded(end)) => Some((*start + 1, *end - *start - 1)),
} {
(curr_tensor.narrow(curr_idx, start, length.max(0)), curr_idx + 1)
(curr_tensor.f_narrow(curr_idx, start, length.max(0))?, curr_idx + 1)
} else {
(curr_tensor, curr_idx + 1)
}
Expand All @@ -331,15 +364,9 @@ impl Tensor {
(curr_tensor.index_select(curr_idx, &index_tensor), curr_idx + 1)
}
};

curr_tensor = next_tensor;
curr_idx = next_idx;
}

Ok(curr_tensor)
}

fn indexer(&self, index_spec: &[TensorIndexer]) -> Tensor {
self.f_indexer(index_spec).unwrap()
}
}
1 change: 1 addition & 0 deletions tests/tensor_tests.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#![allow(clippy::unnecessary_fallible_conversions)]
use anyhow::Result;
use half::f16;
use std::convert::{TryFrom, TryInto};
Expand Down

0 comments on commit ca92737

Please sign in to comment.