diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/impl/dimension/DimensionalSpace.java b/ndarray/src/main/java/org/tensorflow/ndarray/impl/dimension/DimensionalSpace.java index e4bdc53c713..7d0f0222bbe 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/impl/dimension/DimensionalSpace.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/impl/dimension/DimensionalSpace.java @@ -18,8 +18,8 @@ package org.tensorflow.ndarray.impl.dimension; import java.util.Arrays; -import org.tensorflow.ndarray.index.Index; import org.tensorflow.ndarray.Shape; +import org.tensorflow.ndarray.index.Index; public class DimensionalSpace { @@ -35,24 +35,42 @@ public static DimensionalSpace create(Shape shape) { } public RelativeDimensionalSpace mapTo(Index[] indices) { - if (dimensions == null || indices.length > dimensions.length) { + if (dimensions == null) { throw new ArrayIndexOutOfBoundsException(); } int dimIdx = 0; + int indexIdx = 0; int newDimIdx = 0; int segmentationIdx = -1; long initialOffset = 0; - Dimension[] newDimensions = new Dimension[dimensions.length]; - while (dimIdx < indices.length) { + int newAxes = 0; + boolean seenEllipsis = false; + for (Index idx : indices) { + if (idx.isNewAxis()) { + newAxes += 1; + } + if (idx.isEllipsis()) { + if (seenEllipsis) { + throw new IllegalArgumentException("Only one ellipsis allowed"); + } else { + seenEllipsis = true; + } + } + } + int newLength = dimensions.length + newAxes; + + Dimension[] newDimensions = new Dimension[newLength]; + while (indexIdx < indices.length) { - if (indices[dimIdx].isPoint()) { + if (indices[indexIdx].isPoint()) { // When an index targets a single point in a given dimension, calculate the offset of this // point and cumulate the offset of any subsequent point as well long offset = 0; do { - offset += indices[dimIdx].mapCoordinate(0, dimensions[dimIdx]); - } while (++dimIdx < indices.length && indices[dimIdx].isPoint()); + offset += indices[indexIdx].mapCoordinate(0, dimensions[dimIdx]); + dimIdx++; + } while (++indexIdx < indices.length && indices[indexIdx].isPoint()); // If this is the first index, then the offset is the position of the whole dimension // space within the original one. If not, then we apply the offset to the last vectorial @@ -65,14 +83,47 @@ public RelativeDimensionalSpace mapTo(Index[] indices) { segmentationIdx = newDimIdx - 1; } + } else if (indices[indexIdx].isNewAxis()) { + long newSize; + if (dimIdx == 0) { + // includes everything. Should really include future reduction (at()) but that doesn't seem to cause issues + // elsewhere + newSize = dimensions[0].numElements() * dimensions[0].elementSize(); + } else { + newSize = dimensions[dimIdx - 1].elementSize(); + } + + newDimensions[newDimIdx] = new Axis(1, newSize); + segmentationIdx = newDimIdx; // is this correct? + ++newDimIdx; + ++indexIdx; + } else if (indices[indexIdx].isEllipsis()) { + int remainingDimensions = dimensions.length - dimIdx; + int requiredDimensions = 0; + for (int i = indexIdx + 1; i < indices.length; i++) { + if (!indices[i].isNewAxis()) { + requiredDimensions++; + } + } + // while the number of dimensions left < the number of indices that consume axes + while (remainingDimensions > requiredDimensions) { + Dimension dim = dimensions[dimIdx++]; + if (dim.isSegmented()) { + segmentationIdx = newDimIdx; + } + newDimensions[newDimIdx++] = dim; + remainingDimensions--; + } + indexIdx++; } else { // Map any other index to the appropriate dimension of this space - Dimension newDimension = indices[dimIdx].apply(dimensions[dimIdx++]); + Dimension newDimension = indices[indexIdx].apply(dimensions[dimIdx++]); newDimensions[newDimIdx] = newDimension; if (newDimension.isSegmented()) { segmentationIdx = newDimIdx; } ++newDimIdx; + ++indexIdx; } } diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/index/All.java b/ndarray/src/main/java/org/tensorflow/ndarray/index/All.java index b38e33d5e22..9d3139f3248 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/index/All.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/index/All.java @@ -39,4 +39,19 @@ public Dimension apply(Dimension dim) { private All() { } + + @Override + public boolean beginMask() { + return true; + } + + @Override + public boolean endMask() { + return true; + } + + @Override + public String toString() { + return All.class.getSimpleName() + "()"; + } } diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/index/At.java b/ndarray/src/main/java/org/tensorflow/ndarray/index/At.java index 5d92ee3286b..31ce021ddc8 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/index/At.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/index/At.java @@ -16,6 +16,7 @@ */ package org.tensorflow.ndarray.index; +import java.util.StringJoiner; import org.tensorflow.ndarray.impl.dimension.Dimension; final class At implements Index { @@ -27,22 +28,47 @@ public long numElements(Dimension dim) { @Override public long mapCoordinate(long coordinate, Dimension dim) { - return dim.positionOf(coord); // TODO validate coordinate is 0? + long coord = this.coord >= 0 ? this.coord : dim.numElements() + this.coord; + return dim.positionOf(coord); } @Override public Dimension apply(Dimension dim) { - throw new IllegalStateException(); // FIXME? + if (!keepDim) { + throw new UnsupportedOperationException("Should be handled in DimensionalSpace."); + } + + return dim.withIndex(this); } @Override public boolean isPoint() { - return true; + return !keepDim; } - At(long coord) { + At(long coord, boolean keepDim) { this.coord = coord; + this.keepDim = keepDim; } private final long coord; + private final boolean keepDim; + + @Override + public long begin() { + return coord; + } + + @Override + public long end() { + return coord + 1; + } + + @Override + public String toString() { + return new StringJoiner(", ", At.class.getSimpleName() + "(", ")") + .add("coord=" + coord) + .add("keepDim=" + keepDim) + .toString(); + } } diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/index/Even.java b/ndarray/src/main/java/org/tensorflow/ndarray/index/Ellipsis.java similarity index 61% rename from ndarray/src/main/java/org/tensorflow/ndarray/index/Even.java rename to ndarray/src/main/java/org/tensorflow/ndarray/index/Ellipsis.java index 54f53853c32..d4085735df2 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/index/Even.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/index/Ellipsis.java @@ -1,5 +1,5 @@ /* - Copyright 2019 The TensorFlow Authors. All Rights Reserved. + Copyright 2020 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -12,26 +12,37 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. - ======================================================================= + ============================================================================== */ package org.tensorflow.ndarray.index; import org.tensorflow.ndarray.impl.dimension.Dimension; -final class Even implements Index { +final class Ellipsis implements Index { - static final Even INSTANCE = new Even(); + static final Ellipsis INSTANCE = new Ellipsis(); + + private Ellipsis() { + + } @Override public long numElements(Dimension dim) { - return (dim.numElements() >> 1) + (dim.numElements() % 2); + throw new UnsupportedOperationException("Should be handled in DimensionalSpace."); } @Override public long mapCoordinate(long coordinate, Dimension dim) { - return coordinate << 1; + throw new UnsupportedOperationException("Should be handled in DimensionalSpace."); } - private Even() { + @Override + public boolean isEllipsis() { + return true; + } + + @Override + public String toString() { + return Ellipsis.class.getSimpleName() + "()"; } } diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/index/Flip.java b/ndarray/src/main/java/org/tensorflow/ndarray/index/Flip.java deleted file mode 100644 index 7914d8faad5..00000000000 --- a/ndarray/src/main/java/org/tensorflow/ndarray/index/Flip.java +++ /dev/null @@ -1,34 +0,0 @@ -/* - Copyright 2019 The TensorFlow Authors. All Rights Reserved. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - ======================================================================= - */ -package org.tensorflow.ndarray.index; - -import org.tensorflow.ndarray.impl.dimension.Dimension; - -final class Flip implements Index { - - static final Flip INSTANCE = new Flip(); - - @Override - public long numElements(Dimension dim) { - return dim.numElements(); - } - - @Override - public long mapCoordinate(long coordinate, Dimension dim) { - return dim.numElements() - coordinate - 1; - } -} diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/index/From.java b/ndarray/src/main/java/org/tensorflow/ndarray/index/From.java deleted file mode 100644 index c541e8370b2..00000000000 --- a/ndarray/src/main/java/org/tensorflow/ndarray/index/From.java +++ /dev/null @@ -1,38 +0,0 @@ -/* - Copyright 2019 The TensorFlow Authors. All Rights Reserved. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - ======================================================================= - */ -package org.tensorflow.ndarray.index; - -import org.tensorflow.ndarray.impl.dimension.Dimension; - -final class From implements Index { - - @Override - public long numElements(Dimension dim) { - return dim.numElements() - start; - } - - @Override - public long mapCoordinate(long coordinate, Dimension dim) { - return start + coordinate; - } - - From(long start) { - this.start = start; - } - - private final long start; -} diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/index/Hyperslab.java b/ndarray/src/main/java/org/tensorflow/ndarray/index/Hyperslab.java index 00b411d0167..55c4e510748 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/index/Hyperslab.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/index/Hyperslab.java @@ -15,6 +15,7 @@ */ package org.tensorflow.ndarray.index; +import java.util.StringJoiner; import org.tensorflow.ndarray.impl.dimension.Dimension; /** @@ -71,4 +72,19 @@ public boolean isPoint() { private final long stride; private final long count; private final long block; + + @Override + public String toString() { + return new StringJoiner(", ", Hyperslab.class.getSimpleName() + "Hyperslab(", ")") + .add("start=" + start) + .add("stride=" + stride) + .add("count=" + count) + .add("block=" + block) + .toString(); + } + + @Override + public boolean isStridedSlicingCompliant() { + return false; + } } diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/index/Index.java b/ndarray/src/main/java/org/tensorflow/ndarray/index/Index.java index da6aa9049f6..617ca4d474b 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/index/Index.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/index/Index.java @@ -23,19 +23,16 @@ * An index used for slicing a view out of an N-dimensional array. * *
A slice, i.e. a reduced view, of an N-dimensional array is obtain by calling - * {@link NdArray#slice(Index...)}, given a list of indices - * that select which elements on a given dimension should be included/excluded - * from that view. + * {@link NdArray#slice(Index...)}, given a list of indices that select which elements on a given dimension should be + * included/excluded from that view. */ public interface Index { /** - * Returns the number of elements that can be retrieved using this index on the - * given dimension. + * Returns the number of elements that can be retrieved using this index on the given dimension. * *
An index that maps one-by-one all elements of the dimensions will return a value - * equal to {@code dim.numElements()}, while an index that only maps a subset of these - * will return a smaller value. + * equal to {@code dim.numElements()}, while an index that only maps a subset of these will return a smaller value. * * @param dim the indexed dimension * @return number of elements accessible @@ -43,8 +40,7 @@ public interface Index { long numElements(Dimension dim); /** - * Transforms an element coordinate to a new coordinate by applying this index to the - * given dimension. + * Transforms an element coordinate to a new coordinate by applying this index to the given dimension. * *
For example, if the coordinate is 0 and this index flips the {@code n} elements on this * dimension, then the returned value will be {@code n-1}. @@ -74,4 +70,62 @@ default Dimension apply(Dimension dim) { default boolean isPoint() { return false; } + + /** + * Returns true if this index is a new axis, adding a dimension of size 1 + */ + default boolean isNewAxis() { + return false; + } + + /** + * Returns true if this index is an ellipsis, expanding to take as many dimensions as possible (and applying all() to + * them) + */ + default boolean isEllipsis() { + return false; + } + + /** + * Get whether the Index supports strided slice style indexing (using start, end, stride, and flags, i.e. TensorFlow's). + */ + default boolean isStridedSlicingCompliant() { + return true; + } + + /** + * Get the start of the index, for strided slice style indexing. + */ + default long begin() { + return 0; + } + + /** + * Get the end of the index, strided slice style indexing. + */ + default long end() { + return 0; + } + + /** + * Get the stride of the index, for strided slice style indexing. + */ + default long stride() { + return 1; + } + + /** + * Get whether the Index should start at the beginning of the dimension, for strided slice style indexing. + */ + default boolean beginMask() { + return false; + } + + /** + * Get whether the Index should end at the beginning of the dimension, for strided slice style indexing. + */ + default boolean endMask() { + return false; + } } diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/index/Indices.java b/ndarray/src/main/java/org/tensorflow/ndarray/index/Indices.java index abc72195c82..346ab705595 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/index/Indices.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/index/Indices.java @@ -34,14 +34,14 @@ public final class Indices { * single element and therefore is excluded from the computation of the rank. * *
For example, given a 3D matrix on the axis [x, y, z], if - * {@code matrix.slice(all(), at(0), at(0)}, then the rank of the returned slice is 1 and its - * number of elements is {@code x.numElements()} + * {@code matrix.slice(all(), at(0), at(0)}, then the rank of the returned slice is 1 and its number of elements is + * {@code x.numElements()} * * @param coord coordinate of the element on the indexed axis * @return index */ public static Index at(long coord) { - return new At(coord); + return new At(coord, false); } /** @@ -58,7 +58,46 @@ public static Index at(NdArray extends Number> coord) { if (coord.rank() > 0) { throw new IllegalRankException("Only scalars are accepted as a value index"); } - return new At(coord.getObject().longValue()); + return new At(coord.getObject().longValue(), false); + } + + /** + * A coordinate that selects a specific element on a given dimension. + * + *
When this index is applied to a given dimension, the dimension is resolved as a + * single element and therefore, if {@code keepDim} is false, is excluded from the computation of the rank. If {@code} + * keepDim is true, the dimension is collapsed down to one element. + * + *
For example, given a 3D matrix on the axis [x, y, z], if + * {@code matrix.slice(all(), at(0), at(0)}, then the rank of the returned slice is 1 and its number of elements is + * {@code x.numElements()} + * + * @param coord coordinate of the element on the indexed axis + * @param keepDim whether to remove the dimension. + * @return index + */ + public static Index at(long coord, boolean keepDim) { + return new At(coord, keepDim); + } + + /** + * A coordinate that selects a specific element on a given dimension. + * + *
This is equivalent to call {@link #at(long, boolean)} but where the value of the coordinate is + * provided by an N-dimensional array. + *
+ * If {@code} keepDim is true, the dimension is collapsed down to one element instead of being removed. + * + * @param coord scalar indicating the coordinate of the element on the indexed axis + * @param keepDim whether to remove the dimension. + * @return index + * @throws IllegalRankException if {@code coord} is not a scalar (rank 0) + */ + public static Index at(NdArray extends Number> coord, boolean keepDim) { + if (coord.rank() > 0) { + throw new IllegalRankException("Only scalars are accepted as a value index"); + } + return new At(coord.getObject().longValue(), keepDim); } /** @@ -110,8 +149,7 @@ public static Index seq(NdArray extends Number> coords) { } /** - * An index that returns only elements found at an even position in the - * original dimension. + * An index that returns only elements found at an even position in the original dimension. * *
For example, given a vector with {@code n} elements on the {@code x} axis, and n is even, * {@code even()} returns x0, x2, ..., xn-2 @@ -119,12 +157,11 @@ public static Index seq(NdArray extends Number> coords) { * @return index */ public static Index even() { - return Even.INSTANCE; + return step(2); } /** - * An index that returns only elements found at an odd position in the - * original dimension. + * An index that returns only elements found at an odd position in the original dimension. * *
For example, given a vector with {@code n} elements on the {@code x} axis, and n is even, * {@code odd()} returns x1, x3, ..., xn-1 @@ -132,7 +169,7 @@ public static Index even() { * @return index */ public static Index odd() { - return Odd.INSTANCE; + return sliceFrom(1, 2); } /** @@ -141,30 +178,44 @@ public static Index odd() { *
For example, given a vector with {@code n} elements on the {@code x} axis, * {@code step(k)} returns x0, xk, xk*2, ... * - * @param stepLength the number of elements between each steps + * @param stride the number of elements between each steps + * @return index + */ + public static Index step(long stride) { + return new Step(stride); + } + + /** + * An index that returns only elements on a given dimension starting at a specific coordinate. + * + *
For example, given a vector with {@code n} elements on the {@code x} axis, and {@code n > k}, + * {@code from(k)} returns xk, xk+1, ..., xn-1 + * + * @param start coordinate of the first element of the sequence * @return index */ - public static Index step(long stepLength) { - return new Step(stepLength); + public static Index sliceFrom(long start) { + return sliceFrom(start, 1); } /** - * An index that returns only elements on a given dimension starting at a - * specific coordinate. + * An index that returns only elements on a given dimension starting at a specific coordinate, using the given + * stride. * *
For example, given a vector with {@code n} elements on the {@code x} axis, and {@code n > k}, * {@code from(k)} returns xk, xk+1, ..., xn-1 * * @param start coordinate of the first element of the sequence + * @param stride the stride to use * @return index + * @see #slice(long, long, long) */ - public static Index from(long start) { - return new From(start); + public static Index sliceFrom(long start, long stride) { + return new SliceFrom(start, stride); } /** - * An index that returns only elements on a given dimension up to a - * specific coordinate. + * An index that returns only elements on a given dimension up to a specific coordinate. * *
For example, given a vector with {@code n} elements on the {@code x} axis, and {@code n > k}, * {@code to(k)} returns x0, x1, ..., xk @@ -172,8 +223,23 @@ public static Index from(long start) { * @param end coordinate of the last element of the sequence (exclusive) * @return index */ - public static Index to(long end) { - return new To(end); + public static Index sliceTo(long end) { + return sliceTo(end, 1); + } + + /** + * An index that returns only elements on a given dimension up to a specific coordinate, using the given stride. + * + *
For example, given a vector with {@code n} elements on the {@code x} axis, and {@code n > k}, + * {@code to(k)} returns x0, x1, ..., xk + * + * @param end coordinate of the last element of the sequence (exclusive) + * @param stride the stride to use + * @return index + * @see #slice(long, long, long) + */ + public static Index sliceTo(long end, long stride) { + return new SliceTo(end, stride); } /** @@ -187,7 +253,7 @@ public static Index to(long end) { * @return index */ public static Index range(long start, long end) { - return new Range(start, end); + return slice(start, end); } /** @@ -199,21 +265,99 @@ public static Index range(long start, long end) { * @return index */ public static Index flip() { - return Flip.INSTANCE; + return slice(null, null, -1); } - + /** - * An index that returns elements according to an hyperslab defined by {@code start}, - * {@code stride}, {@code count}, {@code block}. See {@link Hyperslab}. - * + * An index that returns elements according to an hyperslab defined by {@code start}, {@code stride}, {@code count}, + * {@code block}. See {@link Hyperslab}. + * * @param start Starting location for the hyperslab. * @param stride The number of elements to separate each element or block to be selected. * @param count The number of elements or blocks to select along the dimension. * @param block The size of the block selected from the dimension. - * * @return index */ public static Index hyperslab(long start, long stride, long count, long block) { return new Hyperslab(start, stride, count, block); } + + /** + * An index that inserts a new dimension of size 1 into the resulting array. + * + * @return index + */ + public static Index newAxis() { + return NewAxis.INSTANCE; + } + + /** + * An index that expands to fill all available source dimensions. Works the same as Python's {@code ...}. + * + * @return index + */ + public static Index ellipsis() { + return Ellipsis.INSTANCE; + } + + /** + * An index that returns elements between {@code start} and {@code end}. If {@code start} or {@code end} is {@code + * null}, starts or ends at the beginning or the end, respectively. + *
+ * Analogous to Python's {@code :} slice syntax. + * + * @return index + */ + public static Index slice(long start, long end) { + return slice(start, end, 1); + } + + /** + * An index that returns every {@code stride}-th element between {@code start} and {@code end}. If {@code start} or + * {@code end} is {@code null}, starts or ends at the beginning or the end, respectively. + *
+ * Analogous to Python's {@code :} slice syntax. + * + * @return index + */ + public static Index slice(long start, long end, long stride) { + return new Slice(start, end, stride); + } + + /** + * An index that returns elements between {@code start} and {@code end}. If {@code start} or {@code end} is {@code + * null}, starts or ends at the beginning or the end, respectively. + *
+ * Analogous to Python's {@code :} slice syntax. + * + * @return index + */ + public static Index slice(Long start, Long end) { + return slice(start, end, 1); + } + + /** + * An index that returns every {@code stride}-th element between {@code start} and {@code end}. If {@code start} or + * {@code end} is {@code null}, starts or ends at the beginning or the end, respectively. + *
+ * Analogous to Python's {@code :} slice syntax.
+ *
+ * @return index
+ */
+ public static Index slice(Long start, Long end, long stride) {
+ if (start == null && end == null) {
+ if (stride == 1) {
+ return Indices.all();
+ } else {
+ return Indices.step(stride);
+ }
+ } else if (start == null) {
+ return Indices.sliceTo(end, stride);
+ } else if (end == null) {
+ return Indices.sliceFrom(start, stride);
+ }
+
+ return slice(start.longValue(), end.longValue(), stride);
+ }
+
}
diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/index/To.java b/ndarray/src/main/java/org/tensorflow/ndarray/index/NewAxis.java
similarity index 65%
rename from ndarray/src/main/java/org/tensorflow/ndarray/index/To.java
rename to ndarray/src/main/java/org/tensorflow/ndarray/index/NewAxis.java
index 167d1c6865e..a68b1ed9ad1 100644
--- a/ndarray/src/main/java/org/tensorflow/ndarray/index/To.java
+++ b/ndarray/src/main/java/org/tensorflow/ndarray/index/NewAxis.java
@@ -1,5 +1,5 @@
/*
- Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+ Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -12,17 +12,23 @@
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
- =======================================================================
+ ==============================================================================
*/
package org.tensorflow.ndarray.index;
import org.tensorflow.ndarray.impl.dimension.Dimension;
-final class To implements Index {
+final class NewAxis implements Index {
+
+ static final NewAxis INSTANCE = new NewAxis();
+
+ private NewAxis() {
+
+ }
@Override
public long numElements(Dimension dim) {
- return end;
+ return 1;
}
@Override
@@ -30,9 +36,18 @@ public long mapCoordinate(long coordinate, Dimension dim) {
return coordinate;
}
- To(long end) {
- this.end = end;
+ @Override
+ public Dimension apply(Dimension dim) {
+ throw new IllegalStateException();
+ }
+
+ @Override
+ public boolean isNewAxis() {
+ return true;
}
- private final long end;
+ @Override
+ public String toString() {
+ return NewAxis.class.getSimpleName() + "()";
+ }
}
diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/index/Odd.java b/ndarray/src/main/java/org/tensorflow/ndarray/index/Odd.java
deleted file mode 100644
index 070331f1ffb..00000000000
--- a/ndarray/src/main/java/org/tensorflow/ndarray/index/Odd.java
+++ /dev/null
@@ -1,37 +0,0 @@
-/*
- Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- =======================================================================
- */
-package org.tensorflow.ndarray.index;
-
-import org.tensorflow.ndarray.impl.dimension.Dimension;
-
-final class Odd implements Index {
-
- static final Odd INSTANCE = new Odd();
-
- @Override
- public long numElements(Dimension dim) {
- return dim.numElements() >> 1;
- }
-
- @Override
- public long mapCoordinate(long coordinate, Dimension dim) {
- return (coordinate << 1) + 1;
- }
-
- private Odd() {
- }
-}
diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/index/Range.java b/ndarray/src/main/java/org/tensorflow/ndarray/index/Range.java
deleted file mode 100644
index e5d6003d87b..00000000000
--- a/ndarray/src/main/java/org/tensorflow/ndarray/index/Range.java
+++ /dev/null
@@ -1,40 +0,0 @@
-/*
- Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- =======================================================================
- */
-package org.tensorflow.ndarray.index;
-
-import org.tensorflow.ndarray.impl.dimension.Dimension;
-
-final class Range implements Index {
-
- @Override
- public long numElements(Dimension dim) {
- return end - start;
- }
-
- @Override
- public long mapCoordinate(long coordinate, Dimension dim) {
- return start + coordinate;
- }
-
- Range(long start, long end) {
- this.start = start;
- this.end = end;
- }
-
- private final long start;
- private final long end;
-}
diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/index/Sequence.java b/ndarray/src/main/java/org/tensorflow/ndarray/index/Sequence.java
index 41d37d05806..5b93e434e54 100644
--- a/ndarray/src/main/java/org/tensorflow/ndarray/index/Sequence.java
+++ b/ndarray/src/main/java/org/tensorflow/ndarray/index/Sequence.java
@@ -16,6 +16,7 @@
*/
package org.tensorflow.ndarray.index;
+import java.util.StringJoiner;
import org.tensorflow.ndarray.NdArray;
import org.tensorflow.ndarray.impl.dimension.Dimension;
@@ -36,4 +37,16 @@ public long mapCoordinate(long coordinate, Dimension dim) {
}
private final NdArray extends Number> coords;
+
+ @Override
+ public String toString() {
+ return new StringJoiner(", ", Sequence.class.getSimpleName() + "(", ")")
+ .add("coords=" + coords)
+ .toString();
+ }
+
+ @Override
+ public boolean isStridedSlicingCompliant() {
+ return false;
+ }
}
diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/index/Slice.java b/ndarray/src/main/java/org/tensorflow/ndarray/index/Slice.java
new file mode 100644
index 00000000000..1be4368261c
--- /dev/null
+++ b/ndarray/src/main/java/org/tensorflow/ndarray/index/Slice.java
@@ -0,0 +1,89 @@
+/*
+ Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ ==============================================================================
+ */
+package org.tensorflow.ndarray.index;
+
+import java.util.StringJoiner;
+import org.tensorflow.ndarray.impl.dimension.Dimension;
+
+final class Slice implements Index {
+
+ Slice(long start, long end, long stride) {
+ this.start = start;
+ this.end = end;
+ this.stride = stride;
+
+ if (stride == 0) {
+ throw new IllegalArgumentException("Can not have a stride of 0");
+ }
+ }
+
+ @Override
+ public long numElements(Dimension dim) {
+ long length = end(dim) - start(dim);
+
+ return (length / stride) + (length % stride != 0 ? 1 : 0);
+ }
+
+ @Override
+ public long mapCoordinate(long coordinate, Dimension dim) {
+ return start(dim) + stride * coordinate;
+ }
+
+ @Override
+ public long begin() {
+ return start;
+ }
+
+ @Override
+ public long end() {
+ return end;
+ }
+
+ @Override
+ public long stride() {
+ return stride;
+ }
+
+ @Override
+ public String toString() {
+ return new StringJoiner(", ", Slice.class.getSimpleName() + "(", ")")
+ .add("start=" + start)
+ .add("end=" + end)
+ .add("stride=" + stride)
+ .toString();
+ }
+
+ private long start(Dimension dim) {
+ if (start < 0) {
+ return dim.numElements() + start;
+ }
+
+ return start;
+ }
+
+ private long end(Dimension dim) {
+ if (end < 0) {
+ return dim.numElements() + end;
+ } else {
+ return end;
+ }
+ }
+
+ private final long start;
+ private final long end;
+ private final long stride;
+}
diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/index/SliceFrom.java b/ndarray/src/main/java/org/tensorflow/ndarray/index/SliceFrom.java
new file mode 100644
index 00000000000..c968a325cf7
--- /dev/null
+++ b/ndarray/src/main/java/org/tensorflow/ndarray/index/SliceFrom.java
@@ -0,0 +1,86 @@
+/*
+ Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ ==============================================================================
+ */
+package org.tensorflow.ndarray.index;
+
+import java.util.StringJoiner;
+import org.tensorflow.ndarray.impl.dimension.Dimension;
+
+final class SliceFrom implements Index {
+
+ SliceFrom(long start, long stride) {
+ this.start = start;
+ this.stride = stride;
+
+ if (stride == 0) {
+ throw new IllegalArgumentException("Can not have a stride of 0");
+ }
+ }
+
+ @Override
+ public long numElements(Dimension dim) {
+ long length = end(dim) - start(dim);
+
+ return (length / stride) + (length % stride != 0 ? 1 : 0);
+ }
+
+ @Override
+ public long mapCoordinate(long coordinate, Dimension dim) {
+ return start(dim) + stride * coordinate;
+ }
+
+ @Override
+ public long begin() {
+ return start;
+ }
+
+ @Override
+ public boolean endMask() {
+ return true;
+ }
+
+ @Override
+ public long stride() {
+ return stride;
+ }
+
+ @Override
+ public String toString() {
+ return new StringJoiner(", ", SliceFrom.class.getSimpleName() + "(", ")")
+ .add("start=" + start)
+ .add("stride=" + stride)
+ .toString();
+ }
+
+ private long start(Dimension dim) {
+ if (start < 0) {
+ return dim.numElements() + start;
+ }
+
+ return start;
+ }
+
+ private long end(Dimension dim) {
+ if (stride > 0) {
+ return dim.numElements();
+ } else {
+ return -1; // it's exclusive
+ }
+ }
+
+ private final long start;
+ private final long stride;
+}
diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/index/SliceTo.java b/ndarray/src/main/java/org/tensorflow/ndarray/index/SliceTo.java
new file mode 100644
index 00000000000..761d1d52a3a
--- /dev/null
+++ b/ndarray/src/main/java/org/tensorflow/ndarray/index/SliceTo.java
@@ -0,0 +1,86 @@
+/*
+ Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ ==============================================================================
+ */
+package org.tensorflow.ndarray.index;
+
+import java.util.StringJoiner;
+import org.tensorflow.ndarray.impl.dimension.Dimension;
+
+final class SliceTo implements Index {
+
+ SliceTo(long end, long stride) {
+ this.end = end;
+ this.stride = stride;
+
+ if (stride == 0) {
+ throw new IllegalArgumentException("Can not have a stride of 0");
+ }
+ }
+
+ @Override
+ public long numElements(Dimension dim) {
+ long length = end(dim) - start(dim);
+
+ return (length / stride) + (length % stride != 0 ? 1 : 0);
+ }
+
+ @Override
+ public long mapCoordinate(long coordinate, Dimension dim) {
+ return start(dim) + stride * coordinate;
+ }
+
+ @Override
+ public long end() {
+ return end;
+ }
+
+ @Override
+ public boolean beginMask() {
+ return true;
+ }
+
+ @Override
+ public long stride() {
+ return stride;
+ }
+
+ @Override
+ public String toString() {
+ return new StringJoiner(", ", SliceTo.class.getSimpleName() + "(", ")")
+ .add("end=" + end)
+ .add("stride=" + stride)
+ .toString();
+ }
+
+ private long start(Dimension dim) {
+ if (stride > 0) {
+ return 0;
+ }
+
+ return dim.numElements() - 1; // it's inclusive
+ }
+
+ private long end(Dimension dim) {
+ if (end < 0) {
+ return dim.numElements() + end;
+ } else {
+ return end;
+ }
+ }
+
+ private final long end;
+ private final long stride;
+}
diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/index/Step.java b/ndarray/src/main/java/org/tensorflow/ndarray/index/Step.java
index 725abd8f2e7..c9a21c507b6 100644
--- a/ndarray/src/main/java/org/tensorflow/ndarray/index/Step.java
+++ b/ndarray/src/main/java/org/tensorflow/ndarray/index/Step.java
@@ -1,5 +1,5 @@
/*
- Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+ Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -12,27 +12,72 @@
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
- =======================================================================
+ ==============================================================================
*/
package org.tensorflow.ndarray.index;
+import java.util.StringJoiner;
import org.tensorflow.ndarray.impl.dimension.Dimension;
final class Step implements Index {
+ Step(long stride) {
+ this.stride = stride;
+
+ if (stride == 0) {
+ throw new IllegalArgumentException("Can not have a stride of 0");
+ }
+ }
+
@Override
public long numElements(Dimension dim) {
- return (dim.numElements() / stepLength) + 1; // FIXME always include element 0?
+ long length = end(dim) - start(dim);
+
+ return (length / stride) + (length % stride != 0 ? 1 : 0);
}
@Override
public long mapCoordinate(long coordinate, Dimension dim) {
- return coordinate * stepLength;
+ return start(dim) + stride * coordinate;
+ }
+
+ @Override
+ public boolean beginMask() {
+ return true;
+ }
+
+ @Override
+ public boolean endMask() {
+ return true;
+ }
+
+ @Override
+ public long stride() {
+ return stride;
+ }
+
+ @Override
+ public String toString() {
+ return new StringJoiner(", ", Step.class.getSimpleName() + "(", ")")
+ .add("stride=" + stride)
+ .toString();
+ }
+
+ private long start(Dimension dim) {
+ if (stride > 0) {
+ return 0;
+ }
+
+ return dim.numElements() - 1; // it's inclusive
}
- Step(long stepLength) {
- this.stepLength = stepLength;
+ private long end(Dimension dim) {
+ if (stride > 0) {
+ return dim.numElements();
+ } else {
+ return -1; // it's exclusive
+ }
}
- private final long stepLength;
+ private final long stride;
}
diff --git a/ndarray/src/test/java/org/tensorflow/ndarray/IndexTest.java b/ndarray/src/test/java/org/tensorflow/ndarray/IndexTest.java
new file mode 100644
index 00000000000..6f92dab9b99
--- /dev/null
+++ b/ndarray/src/test/java/org/tensorflow/ndarray/IndexTest.java
@@ -0,0 +1,205 @@
+/*
+ Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ ==============================================================================
+ */
+package org.tensorflow.ndarray;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+import org.junit.jupiter.api.Test;
+import org.tensorflow.ndarray.index.Indices;
+
+public class IndexTest {
+ @Test
+ public void testNullConversions(){
+ assertTrue(Indices.slice(null, 0L).beginMask(),
+ "Passed null for slice start but didn't set begin mask");
+
+ assertTrue(Indices.slice(null, 0L).beginMask(),
+ "Passed null for slice start but didn't set begin mask");
+
+ assertTrue(Indices.slice(null, null).beginMask(),
+ "Passed null for slice start but didn't set begin mask");
+
+ assertTrue(Indices.slice(0L, null).endMask(),
+ "Passed null for slice end but didn't set end mask");
+
+ assertTrue(Indices.slice(0L, null).endMask(),
+ "Passed null for slice end but didn't set end mask");
+
+ assertTrue(Indices.slice(null, null).endMask(),
+ "Passed null for slice end but didn't set end mask");
+ }
+
+ @Test
+ public void testNewaxis(){
+ IntNdArray matrix3d = NdArrays.ofInts(Shape.of(5, 4, 5));
+
+ matrix3d.scalars().forEachIndexed((coords, scalar) ->
+ scalar.setInt((int)coords[2])
+ );
+
+ IntNdArray slice1 = matrix3d.slice(Indices.all(), Indices.all(), Indices.all(), Indices.newAxis());
+
+ assertEquals(Shape.of(5, 4, 5, 1), slice1.shape());
+ assertEquals(0, slice1.getInt(0, 0, 0, 0));
+ assertEquals(1, slice1.getInt(0, 0, 1, 0));
+ assertEquals(4, slice1.getInt(0, 0, 4, 0));
+ assertEquals(2, slice1.getInt(0, 1, 2, 0));
+
+ IntNdArray slice2 = matrix3d.slice(Indices.all(), Indices.all(), Indices.newAxis(), Indices.all());
+
+ assertEquals(Shape.of(5, 4, 1, 5), slice2.shape());
+ assertEquals(0, slice2.getInt(0, 0, 0, 0));
+ assertEquals(1, slice2.getInt(0, 0, 0, 1));
+ assertEquals(4, slice2.getInt(0, 0, 0, 4));
+ assertEquals(2, slice2.getInt(0, 1, 0, 2));
+
+ IntNdArray slice3 = matrix3d.slice(Indices.all(), Indices.newAxis(), Indices.all(), Indices.all());
+
+ assertEquals(Shape.of(5, 1, 4, 5), slice3.shape());
+ assertEquals(0, slice3.getInt(0, 0, 0, 0));
+ assertEquals(1, slice3.getInt(0, 0, 0, 1));
+ assertEquals(4, slice3.getInt(0, 0, 0, 4));
+ assertEquals(2, slice3.getInt(0, 0, 1, 2));
+
+ IntNdArray slice4 = matrix3d.slice(Indices.newAxis(), Indices.all(), Indices.all(), Indices.all());
+
+ assertEquals(Shape.of(1, 5, 4, 5), slice4.shape());
+ assertEquals(0, slice4.getInt(0, 0, 0, 0));
+ assertEquals(1, slice4.getInt(0, 0, 0, 1));
+ assertEquals(4, slice4.getInt(0, 0, 0, 4));
+ assertEquals(2, slice4.getInt(0, 0, 1, 2));
+
+ }
+
+ @Test
+ public void testEllipsis(){
+ IntNdArray matrix3d = NdArrays.ofInts(Shape.of(5, 4, 5));
+
+ matrix3d.scalars().forEachIndexed((coords, scalar) ->
+ scalar.setInt((int)coords[2])
+ );
+
+ assertEquals(
+ matrix3d.slice(Indices.all(), Indices.all(), Indices.at(0)),
+ matrix3d.slice(Indices.ellipsis(), Indices.at(0))
+ );
+
+ assertEquals(
+ matrix3d.slice(Indices.at(0), Indices.all(), Indices.all()),
+ matrix3d.slice(Indices.at(0), Indices.ellipsis())
+ );
+
+ assertEquals(
+ matrix3d.slice(Indices.at(0), Indices.all(), Indices.at(0)),
+ matrix3d.slice(Indices.at(0), Indices.ellipsis(), Indices.at(0))
+ );
+
+ // newaxis interacts specially with ellipsis (since it doesn't consume a dimension), test this
+
+ assertEquals(
+ matrix3d.slice(Indices.all(), Indices.all(), Indices.newAxis(), Indices.at(0)),
+ matrix3d.slice(Indices.ellipsis(), Indices.newAxis(), Indices.at(0))
+ );
+
+ assertEquals(
+ matrix3d.slice(Indices.newAxis(), Indices.all(), Indices.all(), Indices.at(0)),
+ matrix3d.slice(Indices.newAxis(), Indices.ellipsis(), Indices.at(0))
+ );
+
+ assertEquals(
+ matrix3d.slice(Indices.all(), Indices.all(), Indices.at(0), Indices.newAxis()),
+ matrix3d.slice(Indices.ellipsis(), Indices.at(0), Indices.newAxis())
+ );
+ }
+
+ @Test
+ public void testSlice(){
+ IntNdArray matrix3d = NdArrays.ofInts(Shape.of(5, 4, 5));
+
+ matrix3d.scalars().forEachIndexed((coords, scalar) ->
+ scalar.setInt((int)coords[2])
+ );
+
+ IntNdArray slice1 = matrix3d.slice(Indices.all(), Indices.sliceTo(3), Indices.all());
+
+ assertEquals(Shape.of(5, 3, 5), slice1.shape());
+ assertEquals(0, slice1.getInt(0, 0, 0));
+ assertEquals(1, slice1.getInt(0, 0, 1));
+ assertEquals(2, slice1.getInt(0, 1, 2));
+
+ IntNdArray slice2 = matrix3d.slice(Indices.all(), Indices.all(), Indices.slice(1, 4));
+
+ assertEquals(Shape.of(5, 4, 3), slice2.shape());
+ assertEquals(1, slice2.getInt(0, 0, 0));
+ assertEquals(3, slice2.getInt(0, 0, 2));
+ assertEquals(2, slice2.getInt(0, 1, 1));
+
+ assertEquals(slice2, matrix3d.slice(Indices.all(), Indices.all(), Indices.slice(1, -1)));
+
+ assertEquals(slice2, matrix3d.slice(Indices.all(), Indices.all(), Indices.slice(-4, -1)));
+
+ assertEquals(Shape.of(5, 4, 0), matrix3d.slice(Indices.all(), Indices.all(), Indices.slice(1, 4, -2)).shape());
+
+ IntNdArray slice3 = matrix3d.slice(Indices.all(), Indices.all(), Indices.slice(4, 1, -2));
+
+ assertEquals(Shape.of(5, 4, 2), slice3.shape());
+ assertEquals(4, slice3.getInt(0, 0, 0));
+ assertEquals(2, slice3.getInt(0, 1, 1));
+
+ assertEquals(slice3, matrix3d.slice(Indices.all(), Indices.all(), Indices.slice(-1, 1, -2)));
+
+ assertEquals(slice3, matrix3d.slice(Indices.all(), Indices.all(), Indices.slice(-1, -4, -2)));
+
+ IntNdArray slice4 = matrix3d.slice(Indices.all(), Indices.all(), Indices.slice(null, null, -1));
+
+ assertEquals(Shape.of(5, 4, 5), slice4.shape());
+ assertEquals(4, slice4.getInt(0, 0, 0));
+ assertEquals(3, slice4.getInt(0, 0, 1));
+ assertEquals(2, slice4.getInt(0, 1, 2));
+ }
+
+ @Test
+ public void testAt(){
+ IntNdArray matrix3d = NdArrays.ofInts(Shape.of(5, 4, 5));
+
+ matrix3d.scalars().forEachIndexed((coords, scalar) ->
+ scalar.setInt((int)coords[2])
+ );
+
+ IntNdArray slice1 = matrix3d.slice(Indices.all(), Indices.all(), Indices.at(0));
+
+ assertEquals(Shape.of(5, 4), slice1.shape());
+ assertEquals(0, slice1.getInt(0, 0));
+
+ IntNdArray slice2 = matrix3d.slice(Indices.all(), Indices.all(), Indices.at(3));
+
+ assertEquals(Shape.of(5, 4), slice2.shape());
+ assertEquals(3, slice2.getInt(0, 0));
+
+ IntNdArray slice3 = matrix3d.slice(Indices.all(), Indices.all(), Indices.at(-3));
+
+ assertEquals(Shape.of(5, 4), slice3.shape());
+ assertEquals(2, slice3.getInt(0, 0));
+
+ IntNdArray slice4 = matrix3d.slice(Indices.all(), Indices.all(), Indices.at(-3, true));
+
+ assertEquals(Shape.of(5, 4, 1), slice4.shape());
+ assertEquals(2, slice4.getInt(0, 0, 0));
+ }
+
+}
diff --git a/ndarray/src/test/java/org/tensorflow/ndarray/NdArrayTestBase.java b/ndarray/src/test/java/org/tensorflow/ndarray/NdArrayTestBase.java
index 1c1d89680e7..26ac533daa8 100644
--- a/ndarray/src/test/java/org/tensorflow/ndarray/NdArrayTestBase.java
+++ b/ndarray/src/test/java/org/tensorflow/ndarray/NdArrayTestBase.java
@@ -24,11 +24,11 @@
import static org.tensorflow.ndarray.index.Indices.at;
import static org.tensorflow.ndarray.index.Indices.even;
import static org.tensorflow.ndarray.index.Indices.flip;
-import static org.tensorflow.ndarray.index.Indices.from;
+import static org.tensorflow.ndarray.index.Indices.sliceFrom;
import static org.tensorflow.ndarray.index.Indices.odd;
import static org.tensorflow.ndarray.index.Indices.range;
import static org.tensorflow.ndarray.index.Indices.seq;
-import static org.tensorflow.ndarray.index.Indices.to;
+import static org.tensorflow.ndarray.index.Indices.sliceTo;
import java.nio.BufferOverflowException;
import java.nio.BufferUnderflowException;
@@ -212,13 +212,13 @@ public void slices() {
assertEquals(val101, vector10_flip.getObject(3));
// Vector (1,0,[from 1]) from vector (1,0,*)
- NdArray
+ * The goal of this op is to produce a new tensor with a subset of the elements from the `n` dimensional `input`
+ * tensor. The subset is chosen using a sequence of `m` sparse range specifications encoded into the arguments of this
+ * function. Note, in some cases `m` could be equal to `n`, but this need not be the case. Each range specification
+ * entry can be one of the following:
+ *
+ * - An ellipsis (...) using {@link Indices#ellipsis()}. Ellipses are used to imply zero or more dimensions of
+ * full-dimension selection. For example, {@code stridedSlice(foo, Indices.ellipsis()} is the identity slice.
+ *
+ * - A new axis using {@link Indices#newAxis()}. This is used to insert a new shape=1 dimension.
+ * For example, `{@code stridedSlice(foo, Indices.newAxis())} where {@code foo} is shape {@code (3, 4)}
+ * produces a {@code (1, 3, 4)} tensor.
+ *
+ * - A range {@code begin:end:stride} using {@link Indices#slice(Long, Long, long)} Index.slice()} or {@link Indices#all()}. This is used to specify
+ * how much to choose from a given dimension. {@code stride} can be any integer but 0. {@code begin} is an integer which
+ * represents the index of the first value to select while {@code end} represents the index of the last value to select
+ * (exclusive). Begin and end can be null, in which case the index begins or ends at the beginning or end of the dimension,
+ * respectively (reversed if stride is negative). When both are null, {@code slice()} is the same as {@code all()}.
+ * The number of values selected in each dimension is {@code end - begin} if {@code stride > 0} and {@code begin - end}
+ * if {@code stride < 0}. {@code begin} and {@code end} can be negative where {@code -1} is the last element, {@code -2}
+ * is the second to last. For example, given a shape {@code (3,)} tensor {@code stridedSlice(foo, Indices.all())}, the
+ * effective {@code begin} and {@code end} are {@code 0} and {@code 3}. Do not assume this is equivalent to
+ * {@code stridedSlice(foo, Indices.slice(0, -1))} which has an effective {@code begin} and {@code end} of {@code 0} and
+ * {@code 2}. Another example is {@code stridedSlice(foo, Indices.slice(-2, null, -1))} which reverses the first dimension
+ * of a tensor while dropping the last two (in the original order elements). For example {@code foo = [1,2,3,4];
+ * stridedSlice(foo, Indices.slice(-2, null, -1)} is {@code [4,3]}.
+ *
+ * - A single index using {@link Indices#at(long)}. This is used to keep only elements that have a given index. For
+ * example ({@code stridedSlice(foo, Indices.at(2))} on a shape {@code (5,6)} tensor produces a shape {@code (6,)} tensor.
+ * The dimension can be kept with size one using {@link Indices#at(long, boolean)}.
+ *
+ * These semantics generally follow NumPy's indexing semantics, which can be found here:
+ * https://numpy.org/doc/stable/reference/arrays.indexing.html
+ *
+ *
+ * Requirements:
+ * `0 != strides[i] for i in [0, m)` Only one ellipsis.
+ *
+ * @param scope current scope
+ * @param
@@ -6012,6 +6064,28 @@ public
+ * The values of `value` are assigned to the positions in the variable `ref` that are selected by the slice
+ * parameters. The slice parameters `begin`, `end`, `strides`, etc. work exactly as in `StridedSlice`.
+ *
+ * NOTE this op currently does not support broadcasting and so `value`'s shape must be exactly the shape produced by
+ * the slice of `ref`.
+ *
+ * @param
diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/StridedSliceHelper.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/StridedSliceHelper.java
new file mode 100644
index 00000000000..e97934ee312
--- /dev/null
+++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/StridedSliceHelper.java
@@ -0,0 +1,221 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+package org.tensorflow.op.core;
+
+import org.tensorflow.Operand;
+import org.tensorflow.ndarray.index.Indices;
+import org.tensorflow.ndarray.index.Index;
+import org.tensorflow.op.Scope;
+import org.tensorflow.op.annotation.Endpoint;
+import org.tensorflow.op.annotation.Operator;
+import org.tensorflow.types.family.TType;
+
+/**
+ * Helper endpoint methods for Python like indexing.
+ *
+ * @see org.tensorflow.ndarray.index.Indices
+ */
+@Operator
+public abstract class StridedSliceHelper {
+
+ static class StridedSliceArgs {
+
+ final int[] begin;
+ final int[] end;
+ final int[] strides;
+ final long beginMask;
+ final long endMask;
+ final long ellipsisMask;
+ final long newAxisMask;
+ final long shrinkAxisMask;
+
+ private StridedSliceArgs(int[] begin, int[] end, int[] strides, long beginMask, long endMask, long ellipsisMask,
+ long newAxisMask, long shrinkAxisMask) {
+ this.begin = begin;
+ this.end = end;
+ this.strides = strides;
+ this.beginMask = beginMask;
+ this.endMask = endMask;
+ this.ellipsisMask = ellipsisMask;
+ this.newAxisMask = newAxisMask;
+ this.shrinkAxisMask = shrinkAxisMask;
+ }
+ }
+
+ static StridedSliceArgs mergeIndexes(Index[] indices) {
+ int[] begin = new int[indices.length];
+ int[] end = new int[indices.length];
+ int[] strides = new int[indices.length];
+ long beginMask = 0;
+ long endMask = 0;
+ long ellipsisMask = 0;
+ long newAxisMask = 0;
+ long shrinkAxisMask = 0;
+
+ for (int i = 0; i < indices.length; i++) {
+ Index idx = indices[i];
+
+ if (idx == null) {
+ idx = Indices.all();
+ }
+
+ if (!idx.isStridedSlicingCompliant()) {
+ throw new UnsupportedOperationException("Index " + idx + " is not supported for Tensors");
+ }
+
+ begin[i] = (int) idx.begin();
+ if (begin[i] != idx.begin()) {
+ throw new IllegalArgumentException(
+ "Can't convert long begin value to int for index " + idx + ": Out of bounds");
+ }
+
+ end[i] = (int) idx.end();
+ if (end[i] != idx.end()) {
+ throw new IllegalArgumentException("Can't convert long end value to int for index " + idx + ": Out of bounds");
+ }
+
+ strides[i] = (int) idx.stride();
+ if (strides[i] != idx.stride()) {
+ throw new IllegalArgumentException(
+ "Can't convert long stride value to int for index " + idx + ": Out of bounds");
+ }
+
+ if (idx.beginMask()) {
+ beginMask |= 1L << i;
+ }
+
+ if (idx.endMask()) {
+ endMask |= 1L << i;
+ }
+
+ if (idx.isEllipsis()) {
+ if (ellipsisMask != 0) {
+ throw new IllegalArgumentException("Can not have two ellipsis in a slice");
+ }
+ ellipsisMask |= 1L << i;
+ }
+
+ if (idx.isNewAxis()) {
+ newAxisMask |= 1L << i;
+ }
+
+ if (idx.isPoint()) {
+ shrinkAxisMask |= 1L << i;
+ }
+ }
+
+ return new StridedSliceArgs(begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask);
+ }
+
+ /**
+ * Return a strided slice from `input`.
+ *
+ * The goal of this op is to produce a new tensor with a subset of the elements from the `n` dimensional `input`
+ * tensor. The subset is chosen using a sequence of `m` sparse range specifications encoded into the arguments of this
+ * function. Note, in some cases `m` could be equal to `n`, but this need not be the case. Each range specification
+ * entry can be one of the following:
+ *
+ * - An ellipsis (...) using {@link Indices#ellipsis()}. Ellipses are used to imply zero or more dimensions of
+ * full-dimension selection. For example, {@code stridedSlice(foo, Indices.ellipsis()} is the identity slice.
+ *
+ * - A new axis using {@link Indices#newAxis()}. This is used to insert a new shape=1 dimension.
+ * For example, `{@code stridedSlice(foo, Indices.newAxis())} where {@code foo} is shape {@code (3, 4)}
+ * produces a {@code (1, 3, 4)} tensor.
+ *
+ * - A range {@code begin:end:stride} using {@link Indices#slice(Long, Long, long)} Index.slice()} or {@link Indices#all()}. This is used to specify
+ * how much to choose from a given dimension. {@code stride} can be any integer but 0. {@code begin} is an integer which
+ * represents the index of the first value to select while {@code end} represents the index of the last value to select
+ * (exclusive). Begin and end can be null, in which case the index begins or ends at the beginning or end of the dimension,
+ * respectively (reversed if stride is negative). When both are null, {@code slice()} is the same as {@code all()}.
+ * The number of values selected in each dimension is {@code end - begin} if {@code stride > 0} and {@code begin - end}
+ * if {@code stride < 0}. {@code begin} and {@code end} can be negative where {@code -1} is the last element, {@code -2}
+ * is the second to last. For example, given a shape {@code (3,)} tensor {@code stridedSlice(foo, Indices.all())}, the
+ * effective {@code begin} and {@code end} are {@code 0} and {@code 3}. Do not assume this is equivalent to
+ * {@code stridedSlice(foo, Indices.slice(0, -1))} which has an effective {@code begin} and {@code end} of {@code 0} and
+ * {@code 2}. Another example is {@code stridedSlice(foo, Indices.slice(-2, null, -1))} which reverses the first dimension
+ * of a tensor while dropping the last two (in the original order elements). For example {@code foo = [1,2,3,4];
+ * stridedSlice(foo, Indices.slice(-2, null, -1)} is {@code [4,3]}.
+ *
+ * - A single index using {@link Indices#at(long)}. This is used to keep only elements that have a given index. For
+ * example ({@code stridedSlice(foo, Indices.at(2))} on a shape {@code (5,6)} tensor produces a shape {@code (6,)} tensor.
+ * The dimension can be kept with size one using {@link Indices#at(long, boolean)}.
+ *
+ * These semantics generally follow NumPy's indexing semantics, which can be found here:
+ * https://numpy.org/doc/stable/reference/arrays.indexing.html
+ *
+ *
+ * Requirements:
+ * `0 != strides[i] for i in [0, m)` Only one ellipsis.
+ *
+ * @param scope current scope
+ * @param
+ * The values of `value` are assigned to the positions in the variable `ref` that are selected by the slice
+ * parameters. The slice parameters `begin`, `end`, `strides`, etc. work exactly as in `StridedSlice`.
+ *
+ * NOTE this op currently does not support broadcasting and so `value`'s shape must be exactly the shape produced by
+ * the slice of `ref`.
+ *
+ * @param