Skip to content

Indexing API #166

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 24 commits into from
Feb 11, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
package org.tensorflow.ndarray.impl.dimension;

import java.util.Arrays;
import org.tensorflow.ndarray.index.Index;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.ndarray.index.Index;

public class DimensionalSpace {

Expand All @@ -35,24 +35,42 @@ public static DimensionalSpace create(Shape shape) {
}

public RelativeDimensionalSpace mapTo(Index[] indices) {
if (dimensions == null || indices.length > dimensions.length) {
if (dimensions == null) {
throw new ArrayIndexOutOfBoundsException();
}
int dimIdx = 0;
int indexIdx = 0;
int newDimIdx = 0;
int segmentationIdx = -1;
long initialOffset = 0;

Dimension[] newDimensions = new Dimension[dimensions.length];
while (dimIdx < indices.length) {
int newAxes = 0;
boolean seenEllipsis = false;
for (Index idx : indices) {
if (idx.isNewAxis()) {
newAxes += 1;
}
if (idx.isEllipsis()) {
if (seenEllipsis) {
throw new IllegalArgumentException("Only one ellipsis allowed");
} else {
seenEllipsis = true;
}
}
}
int newLength = dimensions.length + newAxes;

Dimension[] newDimensions = new Dimension[newLength];
while (indexIdx < indices.length) {

if (indices[dimIdx].isPoint()) {
if (indices[indexIdx].isPoint()) {
// When an index targets a single point in a given dimension, calculate the offset of this
// point and cumulate the offset of any subsequent point as well
long offset = 0;
do {
offset += indices[dimIdx].mapCoordinate(0, dimensions[dimIdx]);
} while (++dimIdx < indices.length && indices[dimIdx].isPoint());
offset += indices[indexIdx].mapCoordinate(0, dimensions[dimIdx]);
dimIdx++;
} while (++indexIdx < indices.length && indices[indexIdx].isPoint());

// If this is the first index, then the offset is the position of the whole dimension
// space within the original one. If not, then we apply the offset to the last vectorial
Expand All @@ -65,14 +83,47 @@ public RelativeDimensionalSpace mapTo(Index[] indices) {
segmentationIdx = newDimIdx - 1;
}

} else if (indices[indexIdx].isNewAxis()) {
long newSize;
if (dimIdx == 0) {
// includes everything. Should really include future reduction (at()) but that doesn't seem to cause issues
// elsewhere
newSize = dimensions[0].numElements() * dimensions[0].elementSize();
} else {
newSize = dimensions[dimIdx - 1].elementSize();
}

newDimensions[newDimIdx] = new Axis(1, newSize);
segmentationIdx = newDimIdx; // is this correct?
++newDimIdx;
++indexIdx;
} else if (indices[indexIdx].isEllipsis()) {
int remainingDimensions = dimensions.length - dimIdx;
int requiredDimensions = 0;
for (int i = indexIdx + 1; i < indices.length; i++) {
if (!indices[i].isNewAxis()) {
requiredDimensions++;
}
}
// while the number of dimensions left < the number of indices that consume axes
while (remainingDimensions > requiredDimensions) {
Dimension dim = dimensions[dimIdx++];
if (dim.isSegmented()) {
segmentationIdx = newDimIdx;
}
newDimensions[newDimIdx++] = dim;
remainingDimensions--;
}
indexIdx++;
} else {
// Map any other index to the appropriate dimension of this space
Dimension newDimension = indices[dimIdx].apply(dimensions[dimIdx++]);
Dimension newDimension = indices[indexIdx].apply(dimensions[dimIdx++]);
newDimensions[newDimIdx] = newDimension;
if (newDimension.isSegmented()) {
segmentationIdx = newDimIdx;
}
++newDimIdx;
++indexIdx;
}
}

Expand Down
15 changes: 15 additions & 0 deletions ndarray/src/main/java/org/tensorflow/ndarray/index/All.java
Original file line number Diff line number Diff line change
Expand Up @@ -39,4 +39,19 @@ public Dimension apply(Dimension dim) {

private All() {
}

@Override
public boolean beginMask() {
return true;
}

@Override
public boolean endMask() {
return true;
}

@Override
public String toString() {
return All.class.getSimpleName() + "()";
}
}
34 changes: 30 additions & 4 deletions ndarray/src/main/java/org/tensorflow/ndarray/index/At.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
*/
package org.tensorflow.ndarray.index;

import java.util.StringJoiner;
import org.tensorflow.ndarray.impl.dimension.Dimension;

final class At implements Index {
Expand All @@ -27,22 +28,47 @@ public long numElements(Dimension dim) {

@Override
public long mapCoordinate(long coordinate, Dimension dim) {
return dim.positionOf(coord); // TODO validate coordinate is 0?
long coord = this.coord >= 0 ? this.coord : dim.numElements() + this.coord;
return dim.positionOf(coord);
}

@Override
public Dimension apply(Dimension dim) {
throw new IllegalStateException(); // FIXME?
if (!keepDim) {
throw new UnsupportedOperationException("Should be handled in DimensionalSpace.");
}

return dim.withIndex(this);
}

@Override
public boolean isPoint() {
return true;
return !keepDim;
}

At(long coord) {
At(long coord, boolean keepDim) {
this.coord = coord;
this.keepDim = keepDim;
}

private final long coord;
private final boolean keepDim;

@Override
public long begin() {
return coord;
}

@Override
public long end() {
return coord + 1;
}

@Override
public String toString() {
return new StringJoiner(", ", At.class.getSimpleName() + "(", ")")
.add("coord=" + coord)
.add("keepDim=" + keepDim)
.toString();
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand All @@ -12,26 +12,37 @@
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
=======================================================================
==============================================================================
*/
package org.tensorflow.ndarray.index;

import org.tensorflow.ndarray.impl.dimension.Dimension;

final class Even implements Index {
final class Ellipsis implements Index {

static final Even INSTANCE = new Even();
static final Ellipsis INSTANCE = new Ellipsis();

private Ellipsis() {

}

@Override
public long numElements(Dimension dim) {
return (dim.numElements() >> 1) + (dim.numElements() % 2);
throw new UnsupportedOperationException("Should be handled in DimensionalSpace.");
}

@Override
public long mapCoordinate(long coordinate, Dimension dim) {
return coordinate << 1;
throw new UnsupportedOperationException("Should be handled in DimensionalSpace.");
}

private Even() {
@Override
public boolean isEllipsis() {
return true;
}

@Override
public String toString() {
return Ellipsis.class.getSimpleName() + "()";
}
}
34 changes: 0 additions & 34 deletions ndarray/src/main/java/org/tensorflow/ndarray/index/Flip.java

This file was deleted.

38 changes: 0 additions & 38 deletions ndarray/src/main/java/org/tensorflow/ndarray/index/From.java

This file was deleted.

16 changes: 16 additions & 0 deletions ndarray/src/main/java/org/tensorflow/ndarray/index/Hyperslab.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/
package org.tensorflow.ndarray.index;

import java.util.StringJoiner;
import org.tensorflow.ndarray.impl.dimension.Dimension;

/**
Expand Down Expand Up @@ -71,4 +72,19 @@ public boolean isPoint() {
private final long stride;
private final long count;
private final long block;

@Override
public String toString() {
return new StringJoiner(", ", Hyperslab.class.getSimpleName() + "Hyperslab(", ")")
.add("start=" + start)
.add("stride=" + stride)
.add("count=" + count)
.add("block=" + block)
.toString();
}

@Override
public boolean isStridedSlicingCompliant() {
return false;
}
}
Loading