diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/impl/dimension/DimensionalSpace.java b/ndarray/src/main/java/org/tensorflow/ndarray/impl/dimension/DimensionalSpace.java index e4bdc53c713..7d0f0222bbe 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/impl/dimension/DimensionalSpace.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/impl/dimension/DimensionalSpace.java @@ -18,8 +18,8 @@ package org.tensorflow.ndarray.impl.dimension; import java.util.Arrays; -import org.tensorflow.ndarray.index.Index; import org.tensorflow.ndarray.Shape; +import org.tensorflow.ndarray.index.Index; public class DimensionalSpace { @@ -35,24 +35,42 @@ public static DimensionalSpace create(Shape shape) { } public RelativeDimensionalSpace mapTo(Index[] indices) { - if (dimensions == null || indices.length > dimensions.length) { + if (dimensions == null) { throw new ArrayIndexOutOfBoundsException(); } int dimIdx = 0; + int indexIdx = 0; int newDimIdx = 0; int segmentationIdx = -1; long initialOffset = 0; - Dimension[] newDimensions = new Dimension[dimensions.length]; - while (dimIdx < indices.length) { + int newAxes = 0; + boolean seenEllipsis = false; + for (Index idx : indices) { + if (idx.isNewAxis()) { + newAxes += 1; + } + if (idx.isEllipsis()) { + if (seenEllipsis) { + throw new IllegalArgumentException("Only one ellipsis allowed"); + } else { + seenEllipsis = true; + } + } + } + int newLength = dimensions.length + newAxes; + + Dimension[] newDimensions = new Dimension[newLength]; + while (indexIdx < indices.length) { - if (indices[dimIdx].isPoint()) { + if (indices[indexIdx].isPoint()) { // When an index targets a single point in a given dimension, calculate the offset of this // point and cumulate the offset of any subsequent point as well long offset = 0; do { - offset += indices[dimIdx].mapCoordinate(0, dimensions[dimIdx]); - } while (++dimIdx < indices.length && indices[dimIdx].isPoint()); + offset += indices[indexIdx].mapCoordinate(0, dimensions[dimIdx]); + dimIdx++; + } while (++indexIdx < indices.length && indices[indexIdx].isPoint()); // If this is the first index, then the offset is the position of the whole dimension // space within the original one. If not, then we apply the offset to the last vectorial @@ -65,14 +83,47 @@ public RelativeDimensionalSpace mapTo(Index[] indices) { segmentationIdx = newDimIdx - 1; } + } else if (indices[indexIdx].isNewAxis()) { + long newSize; + if (dimIdx == 0) { + // includes everything. Should really include future reduction (at()) but that doesn't seem to cause issues + // elsewhere + newSize = dimensions[0].numElements() * dimensions[0].elementSize(); + } else { + newSize = dimensions[dimIdx - 1].elementSize(); + } + + newDimensions[newDimIdx] = new Axis(1, newSize); + segmentationIdx = newDimIdx; // is this correct? + ++newDimIdx; + ++indexIdx; + } else if (indices[indexIdx].isEllipsis()) { + int remainingDimensions = dimensions.length - dimIdx; + int requiredDimensions = 0; + for (int i = indexIdx + 1; i < indices.length; i++) { + if (!indices[i].isNewAxis()) { + requiredDimensions++; + } + } + // while the number of dimensions left < the number of indices that consume axes + while (remainingDimensions > requiredDimensions) { + Dimension dim = dimensions[dimIdx++]; + if (dim.isSegmented()) { + segmentationIdx = newDimIdx; + } + newDimensions[newDimIdx++] = dim; + remainingDimensions--; + } + indexIdx++; } else { // Map any other index to the appropriate dimension of this space - Dimension newDimension = indices[dimIdx].apply(dimensions[dimIdx++]); + Dimension newDimension = indices[indexIdx].apply(dimensions[dimIdx++]); newDimensions[newDimIdx] = newDimension; if (newDimension.isSegmented()) { segmentationIdx = newDimIdx; } ++newDimIdx; + ++indexIdx; } } diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/index/All.java b/ndarray/src/main/java/org/tensorflow/ndarray/index/All.java index b38e33d5e22..9d3139f3248 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/index/All.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/index/All.java @@ -39,4 +39,19 @@ public Dimension apply(Dimension dim) { private All() { } + + @Override + public boolean beginMask() { + return true; + } + + @Override + public boolean endMask() { + return true; + } + + @Override + public String toString() { + return All.class.getSimpleName() + "()"; + } } diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/index/At.java b/ndarray/src/main/java/org/tensorflow/ndarray/index/At.java index 5d92ee3286b..31ce021ddc8 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/index/At.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/index/At.java @@ -16,6 +16,7 @@ */ package org.tensorflow.ndarray.index; +import java.util.StringJoiner; import org.tensorflow.ndarray.impl.dimension.Dimension; final class At implements Index { @@ -27,22 +28,47 @@ public long numElements(Dimension dim) { @Override public long mapCoordinate(long coordinate, Dimension dim) { - return dim.positionOf(coord); // TODO validate coordinate is 0? + long coord = this.coord >= 0 ? this.coord : dim.numElements() + this.coord; + return dim.positionOf(coord); } @Override public Dimension apply(Dimension dim) { - throw new IllegalStateException(); // FIXME? + if (!keepDim) { + throw new UnsupportedOperationException("Should be handled in DimensionalSpace."); + } + + return dim.withIndex(this); } @Override public boolean isPoint() { - return true; + return !keepDim; } - At(long coord) { + At(long coord, boolean keepDim) { this.coord = coord; + this.keepDim = keepDim; } private final long coord; + private final boolean keepDim; + + @Override + public long begin() { + return coord; + } + + @Override + public long end() { + return coord + 1; + } + + @Override + public String toString() { + return new StringJoiner(", ", At.class.getSimpleName() + "(", ")") + .add("coord=" + coord) + .add("keepDim=" + keepDim) + .toString(); + } } diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/index/Even.java b/ndarray/src/main/java/org/tensorflow/ndarray/index/Ellipsis.java similarity index 61% rename from ndarray/src/main/java/org/tensorflow/ndarray/index/Even.java rename to ndarray/src/main/java/org/tensorflow/ndarray/index/Ellipsis.java index 54f53853c32..d4085735df2 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/index/Even.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/index/Ellipsis.java @@ -1,5 +1,5 @@ /* - Copyright 2019 The TensorFlow Authors. All Rights Reserved. + Copyright 2020 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -12,26 +12,37 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. - ======================================================================= + ============================================================================== */ package org.tensorflow.ndarray.index; import org.tensorflow.ndarray.impl.dimension.Dimension; -final class Even implements Index { +final class Ellipsis implements Index { - static final Even INSTANCE = new Even(); + static final Ellipsis INSTANCE = new Ellipsis(); + + private Ellipsis() { + + } @Override public long numElements(Dimension dim) { - return (dim.numElements() >> 1) + (dim.numElements() % 2); + throw new UnsupportedOperationException("Should be handled in DimensionalSpace."); } @Override public long mapCoordinate(long coordinate, Dimension dim) { - return coordinate << 1; + throw new UnsupportedOperationException("Should be handled in DimensionalSpace."); } - private Even() { + @Override + public boolean isEllipsis() { + return true; + } + + @Override + public String toString() { + return Ellipsis.class.getSimpleName() + "()"; } } diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/index/Flip.java b/ndarray/src/main/java/org/tensorflow/ndarray/index/Flip.java deleted file mode 100644 index 7914d8faad5..00000000000 --- a/ndarray/src/main/java/org/tensorflow/ndarray/index/Flip.java +++ /dev/null @@ -1,34 +0,0 @@ -/* - Copyright 2019 The TensorFlow Authors. All Rights Reserved. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - ======================================================================= - */ -package org.tensorflow.ndarray.index; - -import org.tensorflow.ndarray.impl.dimension.Dimension; - -final class Flip implements Index { - - static final Flip INSTANCE = new Flip(); - - @Override - public long numElements(Dimension dim) { - return dim.numElements(); - } - - @Override - public long mapCoordinate(long coordinate, Dimension dim) { - return dim.numElements() - coordinate - 1; - } -} diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/index/From.java b/ndarray/src/main/java/org/tensorflow/ndarray/index/From.java deleted file mode 100644 index c541e8370b2..00000000000 --- a/ndarray/src/main/java/org/tensorflow/ndarray/index/From.java +++ /dev/null @@ -1,38 +0,0 @@ -/* - Copyright 2019 The TensorFlow Authors. All Rights Reserved. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - ======================================================================= - */ -package org.tensorflow.ndarray.index; - -import org.tensorflow.ndarray.impl.dimension.Dimension; - -final class From implements Index { - - @Override - public long numElements(Dimension dim) { - return dim.numElements() - start; - } - - @Override - public long mapCoordinate(long coordinate, Dimension dim) { - return start + coordinate; - } - - From(long start) { - this.start = start; - } - - private final long start; -} diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/index/Hyperslab.java b/ndarray/src/main/java/org/tensorflow/ndarray/index/Hyperslab.java index 00b411d0167..55c4e510748 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/index/Hyperslab.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/index/Hyperslab.java @@ -15,6 +15,7 @@ */ package org.tensorflow.ndarray.index; +import java.util.StringJoiner; import org.tensorflow.ndarray.impl.dimension.Dimension; /** @@ -71,4 +72,19 @@ public boolean isPoint() { private final long stride; private final long count; private final long block; + + @Override + public String toString() { + return new StringJoiner(", ", Hyperslab.class.getSimpleName() + "Hyperslab(", ")") + .add("start=" + start) + .add("stride=" + stride) + .add("count=" + count) + .add("block=" + block) + .toString(); + } + + @Override + public boolean isStridedSlicingCompliant() { + return false; + } } diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/index/Index.java b/ndarray/src/main/java/org/tensorflow/ndarray/index/Index.java index da6aa9049f6..617ca4d474b 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/index/Index.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/index/Index.java @@ -23,19 +23,16 @@ * An index used for slicing a view out of an N-dimensional array. * *

A slice, i.e. a reduced view, of an N-dimensional array is obtain by calling - * {@link NdArray#slice(Index...)}, given a list of indices - * that select which elements on a given dimension should be included/excluded - * from that view. + * {@link NdArray#slice(Index...)}, given a list of indices that select which elements on a given dimension should be + * included/excluded from that view. */ public interface Index { /** - * Returns the number of elements that can be retrieved using this index on the - * given dimension. + * Returns the number of elements that can be retrieved using this index on the given dimension. * *

An index that maps one-by-one all elements of the dimensions will return a value - * equal to {@code dim.numElements()}, while an index that only maps a subset of these - * will return a smaller value. + * equal to {@code dim.numElements()}, while an index that only maps a subset of these will return a smaller value. * * @param dim the indexed dimension * @return number of elements accessible @@ -43,8 +40,7 @@ public interface Index { long numElements(Dimension dim); /** - * Transforms an element coordinate to a new coordinate by applying this index to the - * given dimension. + * Transforms an element coordinate to a new coordinate by applying this index to the given dimension. * *

For example, if the coordinate is 0 and this index flips the {@code n} elements on this * dimension, then the returned value will be {@code n-1}. @@ -74,4 +70,62 @@ default Dimension apply(Dimension dim) { default boolean isPoint() { return false; } + + /** + * Returns true if this index is a new axis, adding a dimension of size 1 + */ + default boolean isNewAxis() { + return false; + } + + /** + * Returns true if this index is an ellipsis, expanding to take as many dimensions as possible (and applying all() to + * them) + */ + default boolean isEllipsis() { + return false; + } + + /** + * Get whether the Index supports strided slice style indexing (using start, end, stride, and flags, i.e. TensorFlow's). + */ + default boolean isStridedSlicingCompliant() { + return true; + } + + /** + * Get the start of the index, for strided slice style indexing. + */ + default long begin() { + return 0; + } + + /** + * Get the end of the index, strided slice style indexing. + */ + default long end() { + return 0; + } + + /** + * Get the stride of the index, for strided slice style indexing. + */ + default long stride() { + return 1; + } + + /** + * Get whether the Index should start at the beginning of the dimension, for strided slice style indexing. + */ + default boolean beginMask() { + return false; + } + + /** + * Get whether the Index should end at the beginning of the dimension, for strided slice style indexing. + */ + default boolean endMask() { + return false; + } } diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/index/Indices.java b/ndarray/src/main/java/org/tensorflow/ndarray/index/Indices.java index abc72195c82..346ab705595 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/index/Indices.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/index/Indices.java @@ -34,14 +34,14 @@ public final class Indices { * single element and therefore is excluded from the computation of the rank. * *

For example, given a 3D matrix on the axis [x, y, z], if - * {@code matrix.slice(all(), at(0), at(0)}, then the rank of the returned slice is 1 and its - * number of elements is {@code x.numElements()} + * {@code matrix.slice(all(), at(0), at(0)}, then the rank of the returned slice is 1 and its number of elements is + * {@code x.numElements()} * * @param coord coordinate of the element on the indexed axis * @return index */ public static Index at(long coord) { - return new At(coord); + return new At(coord, false); } /** @@ -58,7 +58,46 @@ public static Index at(NdArray coord) { if (coord.rank() > 0) { throw new IllegalRankException("Only scalars are accepted as a value index"); } - return new At(coord.getObject().longValue()); + return new At(coord.getObject().longValue(), false); + } + + /** + * A coordinate that selects a specific element on a given dimension. + * + *

When this index is applied to a given dimension, the dimension is resolved as a + * single element and therefore, if {@code keepDim} is false, is excluded from the computation of the rank. If {@code} + * keepDim is true, the dimension is collapsed down to one element. + * + *

For example, given a 3D matrix on the axis [x, y, z], if + * {@code matrix.slice(all(), at(0), at(0)}, then the rank of the returned slice is 1 and its number of elements is + * {@code x.numElements()} + * + * @param coord coordinate of the element on the indexed axis + * @param keepDim whether to remove the dimension. + * @return index + */ + public static Index at(long coord, boolean keepDim) { + return new At(coord, keepDim); + } + + /** + * A coordinate that selects a specific element on a given dimension. + * + *

This is equivalent to call {@link #at(long, boolean)} but where the value of the coordinate is + * provided by an N-dimensional array. + *

+ * If {@code} keepDim is true, the dimension is collapsed down to one element instead of being removed. + * + * @param coord scalar indicating the coordinate of the element on the indexed axis + * @param keepDim whether to remove the dimension. + * @return index + * @throws IllegalRankException if {@code coord} is not a scalar (rank 0) + */ + public static Index at(NdArray coord, boolean keepDim) { + if (coord.rank() > 0) { + throw new IllegalRankException("Only scalars are accepted as a value index"); + } + return new At(coord.getObject().longValue(), keepDim); } /** @@ -110,8 +149,7 @@ public static Index seq(NdArray coords) { } /** - * An index that returns only elements found at an even position in the - * original dimension. + * An index that returns only elements found at an even position in the original dimension. * *

For example, given a vector with {@code n} elements on the {@code x} axis, and n is even, * {@code even()} returns x0, x2, ..., xn-2 @@ -119,12 +157,11 @@ public static Index seq(NdArray coords) { * @return index */ public static Index even() { - return Even.INSTANCE; + return step(2); } /** - * An index that returns only elements found at an odd position in the - * original dimension. + * An index that returns only elements found at an odd position in the original dimension. * *

For example, given a vector with {@code n} elements on the {@code x} axis, and n is even, * {@code odd()} returns x1, x3, ..., xn-1 @@ -132,7 +169,7 @@ public static Index even() { * @return index */ public static Index odd() { - return Odd.INSTANCE; + return sliceFrom(1, 2); } /** @@ -141,30 +178,44 @@ public static Index odd() { *

For example, given a vector with {@code n} elements on the {@code x} axis, * {@code step(k)} returns x0, xk, xk*2, ... * - * @param stepLength the number of elements between each steps + * @param stride the number of elements between each steps + * @return index + */ + public static Index step(long stride) { + return new Step(stride); + } + + /** + * An index that returns only elements on a given dimension starting at a specific coordinate. + * + *

For example, given a vector with {@code n} elements on the {@code x} axis, and {@code n > k}, + * {@code from(k)} returns xk, xk+1, ..., xn-1 + * + * @param start coordinate of the first element of the sequence * @return index */ - public static Index step(long stepLength) { - return new Step(stepLength); + public static Index sliceFrom(long start) { + return sliceFrom(start, 1); } /** - * An index that returns only elements on a given dimension starting at a - * specific coordinate. + * An index that returns only elements on a given dimension starting at a specific coordinate, using the given + * stride. * *

For example, given a vector with {@code n} elements on the {@code x} axis, and {@code n > k}, * {@code from(k)} returns xk, xk+1, ..., xn-1 * * @param start coordinate of the first element of the sequence + * @param stride the stride to use * @return index + * @see #slice(long, long, long) */ - public static Index from(long start) { - return new From(start); + public static Index sliceFrom(long start, long stride) { + return new SliceFrom(start, stride); } /** - * An index that returns only elements on a given dimension up to a - * specific coordinate. + * An index that returns only elements on a given dimension up to a specific coordinate. * *

For example, given a vector with {@code n} elements on the {@code x} axis, and {@code n > k}, * {@code to(k)} returns x0, x1, ..., xk @@ -172,8 +223,23 @@ public static Index from(long start) { * @param end coordinate of the last element of the sequence (exclusive) * @return index */ - public static Index to(long end) { - return new To(end); + public static Index sliceTo(long end) { + return sliceTo(end, 1); + } + + /** + * An index that returns only elements on a given dimension up to a specific coordinate, using the given stride. + * + *

For example, given a vector with {@code n} elements on the {@code x} axis, and {@code n > k}, + * {@code to(k)} returns x0, x1, ..., xk + * + * @param end coordinate of the last element of the sequence (exclusive) + * @param stride the stride to use + * @return index + * @see #slice(long, long, long) + */ + public static Index sliceTo(long end, long stride) { + return new SliceTo(end, stride); } /** @@ -187,7 +253,7 @@ public static Index to(long end) { * @return index */ public static Index range(long start, long end) { - return new Range(start, end); + return slice(start, end); } /** @@ -199,21 +265,99 @@ public static Index range(long start, long end) { * @return index */ public static Index flip() { - return Flip.INSTANCE; + return slice(null, null, -1); } - + /** - * An index that returns elements according to an hyperslab defined by {@code start}, - * {@code stride}, {@code count}, {@code block}. See {@link Hyperslab}. - * + * An index that returns elements according to an hyperslab defined by {@code start}, {@code stride}, {@code count}, + * {@code block}. See {@link Hyperslab}. + * * @param start Starting location for the hyperslab. * @param stride The number of elements to separate each element or block to be selected. * @param count The number of elements or blocks to select along the dimension. * @param block The size of the block selected from the dimension. - * * @return index */ public static Index hyperslab(long start, long stride, long count, long block) { return new Hyperslab(start, stride, count, block); } + + /** + * An index that inserts a new dimension of size 1 into the resulting array. + * + * @return index + */ + public static Index newAxis() { + return NewAxis.INSTANCE; + } + + /** + * An index that expands to fill all available source dimensions. Works the same as Python's {@code ...}. + * + * @return index + */ + public static Index ellipsis() { + return Ellipsis.INSTANCE; + } + + /** + * An index that returns elements between {@code start} and {@code end}. If {@code start} or {@code end} is {@code + * null}, starts or ends at the beginning or the end, respectively. + *

+ * Analogous to Python's {@code :} slice syntax. + * + * @return index + */ + public static Index slice(long start, long end) { + return slice(start, end, 1); + } + + /** + * An index that returns every {@code stride}-th element between {@code start} and {@code end}. If {@code start} or + * {@code end} is {@code null}, starts or ends at the beginning or the end, respectively. + *

+ * Analogous to Python's {@code :} slice syntax. + * + * @return index + */ + public static Index slice(long start, long end, long stride) { + return new Slice(start, end, stride); + } + + /** + * An index that returns elements between {@code start} and {@code end}. If {@code start} or {@code end} is {@code + * null}, starts or ends at the beginning or the end, respectively. + *

+ * Analogous to Python's {@code :} slice syntax. + * + * @return index + */ + public static Index slice(Long start, Long end) { + return slice(start, end, 1); + } + + /** + * An index that returns every {@code stride}-th element between {@code start} and {@code end}. If {@code start} or + * {@code end} is {@code null}, starts or ends at the beginning or the end, respectively. + *

+ * Analogous to Python's {@code :} slice syntax. + * + * @return index + */ + public static Index slice(Long start, Long end, long stride) { + if (start == null && end == null) { + if (stride == 1) { + return Indices.all(); + } else { + return Indices.step(stride); + } + } else if (start == null) { + return Indices.sliceTo(end, stride); + } else if (end == null) { + return Indices.sliceFrom(start, stride); + } + + return slice(start.longValue(), end.longValue(), stride); + } + } diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/index/To.java b/ndarray/src/main/java/org/tensorflow/ndarray/index/NewAxis.java similarity index 65% rename from ndarray/src/main/java/org/tensorflow/ndarray/index/To.java rename to ndarray/src/main/java/org/tensorflow/ndarray/index/NewAxis.java index 167d1c6865e..a68b1ed9ad1 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/index/To.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/index/NewAxis.java @@ -1,5 +1,5 @@ /* - Copyright 2019 The TensorFlow Authors. All Rights Reserved. + Copyright 2020 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -12,17 +12,23 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. - ======================================================================= + ============================================================================== */ package org.tensorflow.ndarray.index; import org.tensorflow.ndarray.impl.dimension.Dimension; -final class To implements Index { +final class NewAxis implements Index { + + static final NewAxis INSTANCE = new NewAxis(); + + private NewAxis() { + + } @Override public long numElements(Dimension dim) { - return end; + return 1; } @Override @@ -30,9 +36,18 @@ public long mapCoordinate(long coordinate, Dimension dim) { return coordinate; } - To(long end) { - this.end = end; + @Override + public Dimension apply(Dimension dim) { + throw new IllegalStateException(); + } + + @Override + public boolean isNewAxis() { + return true; } - private final long end; + @Override + public String toString() { + return NewAxis.class.getSimpleName() + "()"; + } } diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/index/Odd.java b/ndarray/src/main/java/org/tensorflow/ndarray/index/Odd.java deleted file mode 100644 index 070331f1ffb..00000000000 --- a/ndarray/src/main/java/org/tensorflow/ndarray/index/Odd.java +++ /dev/null @@ -1,37 +0,0 @@ -/* - Copyright 2019 The TensorFlow Authors. All Rights Reserved. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - ======================================================================= - */ -package org.tensorflow.ndarray.index; - -import org.tensorflow.ndarray.impl.dimension.Dimension; - -final class Odd implements Index { - - static final Odd INSTANCE = new Odd(); - - @Override - public long numElements(Dimension dim) { - return dim.numElements() >> 1; - } - - @Override - public long mapCoordinate(long coordinate, Dimension dim) { - return (coordinate << 1) + 1; - } - - private Odd() { - } -} diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/index/Range.java b/ndarray/src/main/java/org/tensorflow/ndarray/index/Range.java deleted file mode 100644 index e5d6003d87b..00000000000 --- a/ndarray/src/main/java/org/tensorflow/ndarray/index/Range.java +++ /dev/null @@ -1,40 +0,0 @@ -/* - Copyright 2019 The TensorFlow Authors. All Rights Reserved. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - ======================================================================= - */ -package org.tensorflow.ndarray.index; - -import org.tensorflow.ndarray.impl.dimension.Dimension; - -final class Range implements Index { - - @Override - public long numElements(Dimension dim) { - return end - start; - } - - @Override - public long mapCoordinate(long coordinate, Dimension dim) { - return start + coordinate; - } - - Range(long start, long end) { - this.start = start; - this.end = end; - } - - private final long start; - private final long end; -} diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/index/Sequence.java b/ndarray/src/main/java/org/tensorflow/ndarray/index/Sequence.java index 41d37d05806..5b93e434e54 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/index/Sequence.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/index/Sequence.java @@ -16,6 +16,7 @@ */ package org.tensorflow.ndarray.index; +import java.util.StringJoiner; import org.tensorflow.ndarray.NdArray; import org.tensorflow.ndarray.impl.dimension.Dimension; @@ -36,4 +37,16 @@ public long mapCoordinate(long coordinate, Dimension dim) { } private final NdArray coords; + + @Override + public String toString() { + return new StringJoiner(", ", Sequence.class.getSimpleName() + "(", ")") + .add("coords=" + coords) + .toString(); + } + + @Override + public boolean isStridedSlicingCompliant() { + return false; + } } diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/index/Slice.java b/ndarray/src/main/java/org/tensorflow/ndarray/index/Slice.java new file mode 100644 index 00000000000..1be4368261c --- /dev/null +++ b/ndarray/src/main/java/org/tensorflow/ndarray/index/Slice.java @@ -0,0 +1,89 @@ +/* + Copyright 2020 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ============================================================================== + */ +package org.tensorflow.ndarray.index; + +import java.util.StringJoiner; +import org.tensorflow.ndarray.impl.dimension.Dimension; + +final class Slice implements Index { + + Slice(long start, long end, long stride) { + this.start = start; + this.end = end; + this.stride = stride; + + if (stride == 0) { + throw new IllegalArgumentException("Can not have a stride of 0"); + } + } + + @Override + public long numElements(Dimension dim) { + long length = end(dim) - start(dim); + + return (length / stride) + (length % stride != 0 ? 1 : 0); + } + + @Override + public long mapCoordinate(long coordinate, Dimension dim) { + return start(dim) + stride * coordinate; + } + + @Override + public long begin() { + return start; + } + + @Override + public long end() { + return end; + } + + @Override + public long stride() { + return stride; + } + + @Override + public String toString() { + return new StringJoiner(", ", Slice.class.getSimpleName() + "(", ")") + .add("start=" + start) + .add("end=" + end) + .add("stride=" + stride) + .toString(); + } + + private long start(Dimension dim) { + if (start < 0) { + return dim.numElements() + start; + } + + return start; + } + + private long end(Dimension dim) { + if (end < 0) { + return dim.numElements() + end; + } else { + return end; + } + } + + private final long start; + private final long end; + private final long stride; +} diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/index/SliceFrom.java b/ndarray/src/main/java/org/tensorflow/ndarray/index/SliceFrom.java new file mode 100644 index 00000000000..c968a325cf7 --- /dev/null +++ b/ndarray/src/main/java/org/tensorflow/ndarray/index/SliceFrom.java @@ -0,0 +1,86 @@ +/* + Copyright 2020 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ============================================================================== + */ +package org.tensorflow.ndarray.index; + +import java.util.StringJoiner; +import org.tensorflow.ndarray.impl.dimension.Dimension; + +final class SliceFrom implements Index { + + SliceFrom(long start, long stride) { + this.start = start; + this.stride = stride; + + if (stride == 0) { + throw new IllegalArgumentException("Can not have a stride of 0"); + } + } + + @Override + public long numElements(Dimension dim) { + long length = end(dim) - start(dim); + + return (length / stride) + (length % stride != 0 ? 1 : 0); + } + + @Override + public long mapCoordinate(long coordinate, Dimension dim) { + return start(dim) + stride * coordinate; + } + + @Override + public long begin() { + return start; + } + + @Override + public boolean endMask() { + return true; + } + + @Override + public long stride() { + return stride; + } + + @Override + public String toString() { + return new StringJoiner(", ", SliceFrom.class.getSimpleName() + "(", ")") + .add("start=" + start) + .add("stride=" + stride) + .toString(); + } + + private long start(Dimension dim) { + if (start < 0) { + return dim.numElements() + start; + } + + return start; + } + + private long end(Dimension dim) { + if (stride > 0) { + return dim.numElements(); + } else { + return -1; // it's exclusive + } + } + + private final long start; + private final long stride; +} diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/index/SliceTo.java b/ndarray/src/main/java/org/tensorflow/ndarray/index/SliceTo.java new file mode 100644 index 00000000000..761d1d52a3a --- /dev/null +++ b/ndarray/src/main/java/org/tensorflow/ndarray/index/SliceTo.java @@ -0,0 +1,86 @@ +/* + Copyright 2020 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ============================================================================== + */ +package org.tensorflow.ndarray.index; + +import java.util.StringJoiner; +import org.tensorflow.ndarray.impl.dimension.Dimension; + +final class SliceTo implements Index { + + SliceTo(long end, long stride) { + this.end = end; + this.stride = stride; + + if (stride == 0) { + throw new IllegalArgumentException("Can not have a stride of 0"); + } + } + + @Override + public long numElements(Dimension dim) { + long length = end(dim) - start(dim); + + return (length / stride) + (length % stride != 0 ? 1 : 0); + } + + @Override + public long mapCoordinate(long coordinate, Dimension dim) { + return start(dim) + stride * coordinate; + } + + @Override + public long end() { + return end; + } + + @Override + public boolean beginMask() { + return true; + } + + @Override + public long stride() { + return stride; + } + + @Override + public String toString() { + return new StringJoiner(", ", SliceTo.class.getSimpleName() + "(", ")") + .add("end=" + end) + .add("stride=" + stride) + .toString(); + } + + private long start(Dimension dim) { + if (stride > 0) { + return 0; + } + + return dim.numElements() - 1; // it's inclusive + } + + private long end(Dimension dim) { + if (end < 0) { + return dim.numElements() + end; + } else { + return end; + } + } + + private final long end; + private final long stride; +} diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/index/Step.java b/ndarray/src/main/java/org/tensorflow/ndarray/index/Step.java index 725abd8f2e7..c9a21c507b6 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/index/Step.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/index/Step.java @@ -1,5 +1,5 @@ /* - Copyright 2019 The TensorFlow Authors. All Rights Reserved. + Copyright 2020 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -12,27 +12,72 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. - ======================================================================= + ============================================================================== */ package org.tensorflow.ndarray.index; +import java.util.StringJoiner; import org.tensorflow.ndarray.impl.dimension.Dimension; final class Step implements Index { + Step(long stride) { + this.stride = stride; + + if (stride == 0) { + throw new IllegalArgumentException("Can not have a stride of 0"); + } + } + @Override public long numElements(Dimension dim) { - return (dim.numElements() / stepLength) + 1; // FIXME always include element 0? + long length = end(dim) - start(dim); + + return (length / stride) + (length % stride != 0 ? 1 : 0); } @Override public long mapCoordinate(long coordinate, Dimension dim) { - return coordinate * stepLength; + return start(dim) + stride * coordinate; + } + + @Override + public boolean beginMask() { + return true; + } + + @Override + public boolean endMask() { + return true; + } + + @Override + public long stride() { + return stride; + } + + @Override + public String toString() { + return new StringJoiner(", ", Step.class.getSimpleName() + "(", ")") + .add("stride=" + stride) + .toString(); + } + + private long start(Dimension dim) { + if (stride > 0) { + return 0; + } + + return dim.numElements() - 1; // it's inclusive } - Step(long stepLength) { - this.stepLength = stepLength; + private long end(Dimension dim) { + if (stride > 0) { + return dim.numElements(); + } else { + return -1; // it's exclusive + } } - private final long stepLength; + private final long stride; } diff --git a/ndarray/src/test/java/org/tensorflow/ndarray/IndexTest.java b/ndarray/src/test/java/org/tensorflow/ndarray/IndexTest.java new file mode 100644 index 00000000000..6f92dab9b99 --- /dev/null +++ b/ndarray/src/test/java/org/tensorflow/ndarray/IndexTest.java @@ -0,0 +1,205 @@ +/* + Copyright 2020 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ============================================================================== + */ +package org.tensorflow.ndarray; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import org.junit.jupiter.api.Test; +import org.tensorflow.ndarray.index.Indices; + +public class IndexTest { + @Test + public void testNullConversions(){ + assertTrue(Indices.slice(null, 0L).beginMask(), + "Passed null for slice start but didn't set begin mask"); + + assertTrue(Indices.slice(null, 0L).beginMask(), + "Passed null for slice start but didn't set begin mask"); + + assertTrue(Indices.slice(null, null).beginMask(), + "Passed null for slice start but didn't set begin mask"); + + assertTrue(Indices.slice(0L, null).endMask(), + "Passed null for slice end but didn't set end mask"); + + assertTrue(Indices.slice(0L, null).endMask(), + "Passed null for slice end but didn't set end mask"); + + assertTrue(Indices.slice(null, null).endMask(), + "Passed null for slice end but didn't set end mask"); + } + + @Test + public void testNewaxis(){ + IntNdArray matrix3d = NdArrays.ofInts(Shape.of(5, 4, 5)); + + matrix3d.scalars().forEachIndexed((coords, scalar) -> + scalar.setInt((int)coords[2]) + ); + + IntNdArray slice1 = matrix3d.slice(Indices.all(), Indices.all(), Indices.all(), Indices.newAxis()); + + assertEquals(Shape.of(5, 4, 5, 1), slice1.shape()); + assertEquals(0, slice1.getInt(0, 0, 0, 0)); + assertEquals(1, slice1.getInt(0, 0, 1, 0)); + assertEquals(4, slice1.getInt(0, 0, 4, 0)); + assertEquals(2, slice1.getInt(0, 1, 2, 0)); + + IntNdArray slice2 = matrix3d.slice(Indices.all(), Indices.all(), Indices.newAxis(), Indices.all()); + + assertEquals(Shape.of(5, 4, 1, 5), slice2.shape()); + assertEquals(0, slice2.getInt(0, 0, 0, 0)); + assertEquals(1, slice2.getInt(0, 0, 0, 1)); + assertEquals(4, slice2.getInt(0, 0, 0, 4)); + assertEquals(2, slice2.getInt(0, 1, 0, 2)); + + IntNdArray slice3 = matrix3d.slice(Indices.all(), Indices.newAxis(), Indices.all(), Indices.all()); + + assertEquals(Shape.of(5, 1, 4, 5), slice3.shape()); + assertEquals(0, slice3.getInt(0, 0, 0, 0)); + assertEquals(1, slice3.getInt(0, 0, 0, 1)); + assertEquals(4, slice3.getInt(0, 0, 0, 4)); + assertEquals(2, slice3.getInt(0, 0, 1, 2)); + + IntNdArray slice4 = matrix3d.slice(Indices.newAxis(), Indices.all(), Indices.all(), Indices.all()); + + assertEquals(Shape.of(1, 5, 4, 5), slice4.shape()); + assertEquals(0, slice4.getInt(0, 0, 0, 0)); + assertEquals(1, slice4.getInt(0, 0, 0, 1)); + assertEquals(4, slice4.getInt(0, 0, 0, 4)); + assertEquals(2, slice4.getInt(0, 0, 1, 2)); + + } + + @Test + public void testEllipsis(){ + IntNdArray matrix3d = NdArrays.ofInts(Shape.of(5, 4, 5)); + + matrix3d.scalars().forEachIndexed((coords, scalar) -> + scalar.setInt((int)coords[2]) + ); + + assertEquals( + matrix3d.slice(Indices.all(), Indices.all(), Indices.at(0)), + matrix3d.slice(Indices.ellipsis(), Indices.at(0)) + ); + + assertEquals( + matrix3d.slice(Indices.at(0), Indices.all(), Indices.all()), + matrix3d.slice(Indices.at(0), Indices.ellipsis()) + ); + + assertEquals( + matrix3d.slice(Indices.at(0), Indices.all(), Indices.at(0)), + matrix3d.slice(Indices.at(0), Indices.ellipsis(), Indices.at(0)) + ); + + // newaxis interacts specially with ellipsis (since it doesn't consume a dimension), test this + + assertEquals( + matrix3d.slice(Indices.all(), Indices.all(), Indices.newAxis(), Indices.at(0)), + matrix3d.slice(Indices.ellipsis(), Indices.newAxis(), Indices.at(0)) + ); + + assertEquals( + matrix3d.slice(Indices.newAxis(), Indices.all(), Indices.all(), Indices.at(0)), + matrix3d.slice(Indices.newAxis(), Indices.ellipsis(), Indices.at(0)) + ); + + assertEquals( + matrix3d.slice(Indices.all(), Indices.all(), Indices.at(0), Indices.newAxis()), + matrix3d.slice(Indices.ellipsis(), Indices.at(0), Indices.newAxis()) + ); + } + + @Test + public void testSlice(){ + IntNdArray matrix3d = NdArrays.ofInts(Shape.of(5, 4, 5)); + + matrix3d.scalars().forEachIndexed((coords, scalar) -> + scalar.setInt((int)coords[2]) + ); + + IntNdArray slice1 = matrix3d.slice(Indices.all(), Indices.sliceTo(3), Indices.all()); + + assertEquals(Shape.of(5, 3, 5), slice1.shape()); + assertEquals(0, slice1.getInt(0, 0, 0)); + assertEquals(1, slice1.getInt(0, 0, 1)); + assertEquals(2, slice1.getInt(0, 1, 2)); + + IntNdArray slice2 = matrix3d.slice(Indices.all(), Indices.all(), Indices.slice(1, 4)); + + assertEquals(Shape.of(5, 4, 3), slice2.shape()); + assertEquals(1, slice2.getInt(0, 0, 0)); + assertEquals(3, slice2.getInt(0, 0, 2)); + assertEquals(2, slice2.getInt(0, 1, 1)); + + assertEquals(slice2, matrix3d.slice(Indices.all(), Indices.all(), Indices.slice(1, -1))); + + assertEquals(slice2, matrix3d.slice(Indices.all(), Indices.all(), Indices.slice(-4, -1))); + + assertEquals(Shape.of(5, 4, 0), matrix3d.slice(Indices.all(), Indices.all(), Indices.slice(1, 4, -2)).shape()); + + IntNdArray slice3 = matrix3d.slice(Indices.all(), Indices.all(), Indices.slice(4, 1, -2)); + + assertEquals(Shape.of(5, 4, 2), slice3.shape()); + assertEquals(4, slice3.getInt(0, 0, 0)); + assertEquals(2, slice3.getInt(0, 1, 1)); + + assertEquals(slice3, matrix3d.slice(Indices.all(), Indices.all(), Indices.slice(-1, 1, -2))); + + assertEquals(slice3, matrix3d.slice(Indices.all(), Indices.all(), Indices.slice(-1, -4, -2))); + + IntNdArray slice4 = matrix3d.slice(Indices.all(), Indices.all(), Indices.slice(null, null, -1)); + + assertEquals(Shape.of(5, 4, 5), slice4.shape()); + assertEquals(4, slice4.getInt(0, 0, 0)); + assertEquals(3, slice4.getInt(0, 0, 1)); + assertEquals(2, slice4.getInt(0, 1, 2)); + } + + @Test + public void testAt(){ + IntNdArray matrix3d = NdArrays.ofInts(Shape.of(5, 4, 5)); + + matrix3d.scalars().forEachIndexed((coords, scalar) -> + scalar.setInt((int)coords[2]) + ); + + IntNdArray slice1 = matrix3d.slice(Indices.all(), Indices.all(), Indices.at(0)); + + assertEquals(Shape.of(5, 4), slice1.shape()); + assertEquals(0, slice1.getInt(0, 0)); + + IntNdArray slice2 = matrix3d.slice(Indices.all(), Indices.all(), Indices.at(3)); + + assertEquals(Shape.of(5, 4), slice2.shape()); + assertEquals(3, slice2.getInt(0, 0)); + + IntNdArray slice3 = matrix3d.slice(Indices.all(), Indices.all(), Indices.at(-3)); + + assertEquals(Shape.of(5, 4), slice3.shape()); + assertEquals(2, slice3.getInt(0, 0)); + + IntNdArray slice4 = matrix3d.slice(Indices.all(), Indices.all(), Indices.at(-3, true)); + + assertEquals(Shape.of(5, 4, 1), slice4.shape()); + assertEquals(2, slice4.getInt(0, 0, 0)); + } + +} diff --git a/ndarray/src/test/java/org/tensorflow/ndarray/NdArrayTestBase.java b/ndarray/src/test/java/org/tensorflow/ndarray/NdArrayTestBase.java index 1c1d89680e7..26ac533daa8 100644 --- a/ndarray/src/test/java/org/tensorflow/ndarray/NdArrayTestBase.java +++ b/ndarray/src/test/java/org/tensorflow/ndarray/NdArrayTestBase.java @@ -24,11 +24,11 @@ import static org.tensorflow.ndarray.index.Indices.at; import static org.tensorflow.ndarray.index.Indices.even; import static org.tensorflow.ndarray.index.Indices.flip; -import static org.tensorflow.ndarray.index.Indices.from; +import static org.tensorflow.ndarray.index.Indices.sliceFrom; import static org.tensorflow.ndarray.index.Indices.odd; import static org.tensorflow.ndarray.index.Indices.range; import static org.tensorflow.ndarray.index.Indices.seq; -import static org.tensorflow.ndarray.index.Indices.to; +import static org.tensorflow.ndarray.index.Indices.sliceTo; import java.nio.BufferOverflowException; import java.nio.BufferUnderflowException; @@ -212,13 +212,13 @@ public void slices() { assertEquals(val101, vector10_flip.getObject(3)); // Vector (1,0,[from 1]) from vector (1,0,*) - NdArray vector10_1toX = vector10X.slice(from(1)); + NdArray vector10_1toX = vector10X.slice(sliceFrom(1)); assertEquals(vector10_1toX.shape(), Shape.of(4)); assertEquals(val101, vector10_1toX.getObject(0)); assertEquals(val102, vector10_1toX.getObject(1)); // Vector (1,0,[to 1]) from vector (1,0,*) - NdArray vector10_Xto1 = vector10X.slice(to(2)); + NdArray vector10_Xto1 = vector10X.slice(sliceTo(2)); assertEquals(vector10_Xto1.shape(), Shape.of(2)); assertEquals(val100, vector10_Xto1.getObject(0)); assertEquals(val101, vector10_Xto1.getObject(1)); diff --git a/ndarray/src/test/java/org/tensorflow/ndarray/benchmark/NdArrayBenchmark.java b/ndarray/src/test/java/org/tensorflow/ndarray/benchmark/NdArrayBenchmark.java index 8acfdff7721..fb7022bc830 100644 --- a/ndarray/src/test/java/org/tensorflow/ndarray/benchmark/NdArrayBenchmark.java +++ b/ndarray/src/test/java/org/tensorflow/ndarray/benchmark/NdArrayBenchmark.java @@ -38,7 +38,7 @@ import org.tensorflow.ndarray.NdArrays; import org.tensorflow.ndarray.StdArrays; -@Fork(value = 0, jvmArgs = {"-Xms4G", "-Xmx4G"}) +@Fork(value = 1, jvmArgs = {"-Xms4G", "-Xmx4G"}) @BenchmarkMode(Mode.AverageTime) @Warmup(iterations = 3) @Measurement(iterations = 5) diff --git a/ndarray/src/test/java/org/tensorflow/ndarray/impl/dense/DenseNdArrayTest.java b/ndarray/src/test/java/org/tensorflow/ndarray/impl/dense/DenseNdArrayTest.java index d5b5ca809a4..375f7643875 100644 --- a/ndarray/src/test/java/org/tensorflow/ndarray/impl/dense/DenseNdArrayTest.java +++ b/ndarray/src/test/java/org/tensorflow/ndarray/impl/dense/DenseNdArrayTest.java @@ -40,7 +40,7 @@ public void equalsAndHashCodeOnSlices() { {{3, 4}, {6, 7}} }); - assertTrue(vector1.equals(vector2.slice(Indices.from(2)))); + assertTrue(vector1.equals(vector2.slice(Indices.sliceFrom(2)))); assertTrue(vector1.equals(matrix1.get(1))); assertTrue(vector1.equals(matrix2.get(1).slice(Indices.even()))); assertTrue(matrix1.equals(matrix2.slice(Indices.all(), Indices.even()))); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java index cf7c5b47030..3cf293f759d 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java @@ -38,6 +38,7 @@ import org.tensorflow.ndarray.buffer.FloatDataBuffer; import org.tensorflow.ndarray.buffer.IntDataBuffer; import org.tensorflow.ndarray.buffer.LongDataBuffer; +import org.tensorflow.ndarray.index.Index; import org.tensorflow.op.core.Abort; import org.tensorflow.op.core.All; import org.tensorflow.op.core.Any; @@ -210,6 +211,7 @@ import org.tensorflow.op.core.StridedSlice; import org.tensorflow.op.core.StridedSliceAssign; import org.tensorflow.op.core.StridedSliceGrad; +import org.tensorflow.op.core.StridedSliceHelper; import org.tensorflow.op.core.Sum; import org.tensorflow.op.core.SwitchCond; import org.tensorflow.op.core.TemporaryVariable; @@ -5900,6 +5902,56 @@ public StopGradient stopGradient(Operand input) { return StopGradient.create(scope, input); } + /** + * Return a strided slice from `input`. + *

+ * The goal of this op is to produce a new tensor with a subset of the elements from the `n` dimensional `input` + * tensor. The subset is chosen using a sequence of `m` sparse range specifications encoded into the arguments of this + * function. Note, in some cases `m` could be equal to `n`, but this need not be the case. Each range specification + * entry can be one of the following: + *

+ * - An ellipsis (...) using {@link Indices#ellipsis()}. Ellipses are used to imply zero or more dimensions of + * full-dimension selection. For example, {@code stridedSlice(foo, Indices.ellipsis()} is the identity slice. + *

+ * - A new axis using {@link Indices#newAxis()}. This is used to insert a new shape=1 dimension. + * For example, `{@code stridedSlice(foo, Indices.newAxis())} where {@code foo} is shape {@code (3, 4)} + * produces a {@code (1, 3, 4)} tensor. + *

+ * - A range {@code begin:end:stride} using {@link Indices#slice(Long, Long, long)} Index.slice()} or {@link Indices#all()}. This is used to specify + * how much to choose from a given dimension. {@code stride} can be any integer but 0. {@code begin} is an integer which + * represents the index of the first value to select while {@code end} represents the index of the last value to select + * (exclusive). Begin and end can be null, in which case the index begins or ends at the beginning or end of the dimension, + * respectively (reversed if stride is negative). When both are null, {@code slice()} is the same as {@code all()}. + * The number of values selected in each dimension is {@code end - begin} if {@code stride > 0} and {@code begin - end} + * if {@code stride < 0}. {@code begin} and {@code end} can be negative where {@code -1} is the last element, {@code -2} + * is the second to last. For example, given a shape {@code (3,)} tensor {@code stridedSlice(foo, Indices.all())}, the + * effective {@code begin} and {@code end} are {@code 0} and {@code 3}. Do not assume this is equivalent to + * {@code stridedSlice(foo, Indices.slice(0, -1))} which has an effective {@code begin} and {@code end} of {@code 0} and + * {@code 2}. Another example is {@code stridedSlice(foo, Indices.slice(-2, null, -1))} which reverses the first dimension + * of a tensor while dropping the last two (in the original order elements). For example {@code foo = [1,2,3,4]; + * stridedSlice(foo, Indices.slice(-2, null, -1)} is {@code [4,3]}. + *

+ * - A single index using {@link Indices#at(long)}. This is used to keep only elements that have a given index. For + * example ({@code stridedSlice(foo, Indices.at(2))} on a shape {@code (5,6)} tensor produces a shape {@code (6,)} tensor. + * The dimension can be kept with size one using {@link Indices#at(long, boolean)}. + *

+ * These semantics generally follow NumPy's indexing semantics, which can be found here: + * https://numpy.org/doc/stable/reference/arrays.indexing.html + *

+ * + * Requirements: + * `0 != strides[i] for i in [0, m)` Only one ellipsis. + * + * @param scope current scope + * @param data type for {@code output()} output + * @param indices The indices to slice. See {@link Indices}. + * @return a new instance of StridedSlice + * @see Indices + */ + public StridedSlice stridedSlice(Operand input, Index... indices) { + return StridedSliceHelper.stridedSlice(scope, input, indices); + } + /** * Return a strided slice from `input`. *

@@ -6012,6 +6064,28 @@ public StridedSlice stridedSlice(Operand return StridedSlice.create(scope, input, begin, end, strides, options); } + /** + * Assign `value` to the sliced l-value reference of `ref`. + *

+ * The values of `value` are assigned to the positions in the variable `ref` that are selected by the slice + * parameters. The slice parameters `begin`, `end`, `strides`, etc. work exactly as in `StridedSlice`. + *

+ * NOTE this op currently does not support broadcasting and so `value`'s shape must be exactly the shape produced by + * the slice of `ref`. + * + * @param data type for {@code outputRef()} output + * @param scope current scope + * @param ref the tensor to assign to. + * @param value the value to assign. + * @param indices The indices to slice. See {@link Indices}. + * @return a new instance of StridedSliceAssign + * @see org.tensorflow.op.Ops#stridedSlice(Operand, Index...) + */ + public StridedSliceAssign stridedSliceAssign(Operand ref, + Operand value, Index... indices) { + return StridedSliceHelper.stridedSliceAssign(scope, ref, value, indices); + } + /** * Assign `value` to the sliced l-value reference of `ref`. *

diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/StridedSliceHelper.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/StridedSliceHelper.java new file mode 100644 index 00000000000..e97934ee312 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/StridedSliceHelper.java @@ -0,0 +1,221 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +package org.tensorflow.op.core; + +import org.tensorflow.Operand; +import org.tensorflow.ndarray.index.Indices; +import org.tensorflow.ndarray.index.Index; +import org.tensorflow.op.Scope; +import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; +import org.tensorflow.types.family.TType; + +/** + * Helper endpoint methods for Python like indexing. + * + * @see org.tensorflow.ndarray.index.Indices + */ +@Operator +public abstract class StridedSliceHelper { + + static class StridedSliceArgs { + + final int[] begin; + final int[] end; + final int[] strides; + final long beginMask; + final long endMask; + final long ellipsisMask; + final long newAxisMask; + final long shrinkAxisMask; + + private StridedSliceArgs(int[] begin, int[] end, int[] strides, long beginMask, long endMask, long ellipsisMask, + long newAxisMask, long shrinkAxisMask) { + this.begin = begin; + this.end = end; + this.strides = strides; + this.beginMask = beginMask; + this.endMask = endMask; + this.ellipsisMask = ellipsisMask; + this.newAxisMask = newAxisMask; + this.shrinkAxisMask = shrinkAxisMask; + } + } + + static StridedSliceArgs mergeIndexes(Index[] indices) { + int[] begin = new int[indices.length]; + int[] end = new int[indices.length]; + int[] strides = new int[indices.length]; + long beginMask = 0; + long endMask = 0; + long ellipsisMask = 0; + long newAxisMask = 0; + long shrinkAxisMask = 0; + + for (int i = 0; i < indices.length; i++) { + Index idx = indices[i]; + + if (idx == null) { + idx = Indices.all(); + } + + if (!idx.isStridedSlicingCompliant()) { + throw new UnsupportedOperationException("Index " + idx + " is not supported for Tensors"); + } + + begin[i] = (int) idx.begin(); + if (begin[i] != idx.begin()) { + throw new IllegalArgumentException( + "Can't convert long begin value to int for index " + idx + ": Out of bounds"); + } + + end[i] = (int) idx.end(); + if (end[i] != idx.end()) { + throw new IllegalArgumentException("Can't convert long end value to int for index " + idx + ": Out of bounds"); + } + + strides[i] = (int) idx.stride(); + if (strides[i] != idx.stride()) { + throw new IllegalArgumentException( + "Can't convert long stride value to int for index " + idx + ": Out of bounds"); + } + + if (idx.beginMask()) { + beginMask |= 1L << i; + } + + if (idx.endMask()) { + endMask |= 1L << i; + } + + if (idx.isEllipsis()) { + if (ellipsisMask != 0) { + throw new IllegalArgumentException("Can not have two ellipsis in a slice"); + } + ellipsisMask |= 1L << i; + } + + if (idx.isNewAxis()) { + newAxisMask |= 1L << i; + } + + if (idx.isPoint()) { + shrinkAxisMask |= 1L << i; + } + } + + return new StridedSliceArgs(begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask); + } + + /** + * Return a strided slice from `input`. + *

+ * The goal of this op is to produce a new tensor with a subset of the elements from the `n` dimensional `input` + * tensor. The subset is chosen using a sequence of `m` sparse range specifications encoded into the arguments of this + * function. Note, in some cases `m` could be equal to `n`, but this need not be the case. Each range specification + * entry can be one of the following: + *

+ * - An ellipsis (...) using {@link Indices#ellipsis()}. Ellipses are used to imply zero or more dimensions of + * full-dimension selection. For example, {@code stridedSlice(foo, Indices.ellipsis()} is the identity slice. + *

+ * - A new axis using {@link Indices#newAxis()}. This is used to insert a new shape=1 dimension. + * For example, `{@code stridedSlice(foo, Indices.newAxis())} where {@code foo} is shape {@code (3, 4)} + * produces a {@code (1, 3, 4)} tensor. + *

+ * - A range {@code begin:end:stride} using {@link Indices#slice(Long, Long, long)} Index.slice()} or {@link Indices#all()}. This is used to specify + * how much to choose from a given dimension. {@code stride} can be any integer but 0. {@code begin} is an integer which + * represents the index of the first value to select while {@code end} represents the index of the last value to select + * (exclusive). Begin and end can be null, in which case the index begins or ends at the beginning or end of the dimension, + * respectively (reversed if stride is negative). When both are null, {@code slice()} is the same as {@code all()}. + * The number of values selected in each dimension is {@code end - begin} if {@code stride > 0} and {@code begin - end} + * if {@code stride < 0}. {@code begin} and {@code end} can be negative where {@code -1} is the last element, {@code -2} + * is the second to last. For example, given a shape {@code (3,)} tensor {@code stridedSlice(foo, Indices.all())}, the + * effective {@code begin} and {@code end} are {@code 0} and {@code 3}. Do not assume this is equivalent to + * {@code stridedSlice(foo, Indices.slice(0, -1))} which has an effective {@code begin} and {@code end} of {@code 0} and + * {@code 2}. Another example is {@code stridedSlice(foo, Indices.slice(-2, null, -1))} which reverses the first dimension + * of a tensor while dropping the last two (in the original order elements). For example {@code foo = [1,2,3,4]; + * stridedSlice(foo, Indices.slice(-2, null, -1)} is {@code [4,3]}. + *

+ * - A single index using {@link Indices#at(long)}. This is used to keep only elements that have a given index. For + * example ({@code stridedSlice(foo, Indices.at(2))} on a shape {@code (5,6)} tensor produces a shape {@code (6,)} tensor. + * The dimension can be kept with size one using {@link Indices#at(long, boolean)}. + *

+ * These semantics generally follow NumPy's indexing semantics, which can be found here: + * https://numpy.org/doc/stable/reference/arrays.indexing.html + *

+ * + * Requirements: + * `0 != strides[i] for i in [0, m)` Only one ellipsis. + * + * @param scope current scope + * @param data type for {@code output()} output + * @param indices The indices to slice. See {@link Indices}. + * @return a new instance of StridedSlice + * @see Indices + */ + @Endpoint(name = "stridedSlice") + public static StridedSlice stridedSlice(Scope scope, Operand input, Index... indices) { + StridedSliceArgs args = mergeIndexes(indices); + return StridedSlice.create( + scope, + input, + Constant.vectorOf(scope, args.begin), + Constant.vectorOf(scope, args.end), + Constant.vectorOf(scope, args.strides), + StridedSlice.beginMask(args.beginMask), + StridedSlice.endMask(args.endMask), + StridedSlice.ellipsisMask(args.ellipsisMask), + StridedSlice.newAxisMask(args.newAxisMask), + StridedSlice.shrinkAxisMask(args.shrinkAxisMask) + ); + } + + /** + * Assign `value` to the sliced l-value reference of `ref`. + *

+ * The values of `value` are assigned to the positions in the variable `ref` that are selected by the slice + * parameters. The slice parameters `begin`, `end`, `strides`, etc. work exactly as in `StridedSlice`. + *

+ * NOTE this op currently does not support broadcasting and so `value`'s shape must be exactly the shape produced by + * the slice of `ref`. + * + * @param data type for {@code outputRef()} output + * @param scope current scope + * @param ref the tensor to assign to. + * @param value the value to assign. + * @param indices The indices to slice. See {@link Indices}. + * @return a new instance of StridedSliceAssign + * @see org.tensorflow.op.Ops#stridedSlice(Operand, Index...) + */ + @Endpoint(name = "stridedSliceAssign") + public static StridedSliceAssign stridedSliceAssign(Scope scope, Operand ref, + Operand value, Index... indices) { + StridedSliceArgs args = mergeIndexes(indices); + return StridedSliceAssign.create( + scope, + ref, + Constant.vectorOf(scope, args.begin), + Constant.vectorOf(scope, args.end), + Constant.vectorOf(scope, args.strides), + value, + StridedSliceAssign.beginMask(args.beginMask), + StridedSliceAssign.endMask(args.endMask), + StridedSliceAssign.ellipsisMask(args.ellipsisMask), + StridedSliceAssign.newAxisMask(args.newAxisMask), + StridedSliceAssign.shrinkAxisMask(args.shrinkAxisMask) + ); + } + +} diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/IndexingTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/IndexingTest.java new file mode 100644 index 00000000000..6e86573b7cf --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/IndexingTest.java @@ -0,0 +1,72 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +package org.tensorflow.op.core; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; + +import org.junit.Test; +import org.tensorflow.Graph; +import org.tensorflow.Session; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.ndarray.index.Indices; +import org.tensorflow.ndarray.index.Index; +import org.tensorflow.op.Scope; +import org.tensorflow.types.TFloat32; + +public class IndexingTest { + + // [2, 1:2, :, tf.newaxis, ..., :4, 4::2] + private static final Index[] slice = new Index[]{ + Indices.at(2), + Indices.at(1, true), + Indices.all(), + Indices.newAxis(), + Indices.ellipsis(), + Indices.sliceTo( 4), + Indices.sliceFrom(4, 2) + }; + + @Test + public void testIndexMerge() { + StridedSliceHelper.StridedSliceArgs args = StridedSliceHelper.mergeIndexes(slice); + + assertArrayEquals(new int[]{2, 1, 0, 0, 0, 0, 4}, args.begin); + assertArrayEquals(new int[]{3, 2, 0, 0, 0, 4, 0}, args.end); + assertArrayEquals(new int[]{1, 1, 1, 1, 1, 1, 2}, args.strides); + assertEquals(0b0100100, args.beginMask); + assertEquals(0b1000100, args.endMask); + assertEquals(0b0010000, args.ellipsisMask); + assertEquals(0b0001000, args.newAxisMask); + assertEquals(0b0000001, args.shrinkAxisMask); + + } + + @Test + public void testStridedSliceIndex(){ + try (Graph g = new Graph(); + Session sess = new Session(g)) { + Scope scope = new Scope(g); + long[] shape = {10, 10, 10, 10, 10, 10, 10, 10}; + Zeros op = Zeros.create(scope, Constant.vectorOf(scope, shape), TFloat32.class); + StridedSlice output = StridedSliceHelper.stridedSlice(scope, op, slice); + try (TFloat32 result = (TFloat32) sess.runner().fetch(output.asOutput()).run().get(0)) { + // expected shape from Python tensorflow + assertEquals(Shape.of(1, 10, 1, 10, 10, 10, 4, 3), result.shape(), "Slice index didn't match expected (Python)"); + } + } + } + +}