Skip to content

Add Session Result class #167

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 19 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -347,10 +347,10 @@ public final class Ops {

public final SignalOps signal;

public final QuantizationOps quantization;

public final TrainOps train;

public final QuantizationOps quantization;

private final Scope scope;

private Ops(Scope scope) {
Expand All @@ -372,8 +372,8 @@ private Ops(Scope scope) {
math = new MathOps(this);
audio = new AudioOps(this);
signal = new SignalOps(this);
quantization = new QuantizationOps(this);
train = new TrainOps(this);
quantization = new QuantizationOps(this);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ public Map<String, Tensor> call(Map<String, Tensor> arguments)
Map<String, TensorInfo> outputToNode = signatureDef.getOutputsMap();
outputToNode.values().forEach(t -> runner.fetch(t.getName()));

List<Tensor> resultTensors = runner.run();
List<Tensor> resultTensors = runner.run().getResults();
try {
ListIterator<Tensor> resultTensorIter = resultTensors.listIterator();
Map<String, Tensor> returnMap = new HashMap<String, Tensor>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,66 @@ public GraphOperation operation(String name) {
}
}


/**
* Returns the operation (node in the Graph) with the provided name.
* <p>
* Or throws an {@code IllegalArgumentException} if no such operation exists in the Graph.
*
* @param name name of the operation to look for
* @return operation in the graph with this name
* @see #operation(String)
*/
public GraphOperation operationOrError(String name) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why wouldn't this return null or Optional<GraphOperation> rather than throwing? Is it a Kotlin thing?

Copy link
Contributor Author

@rnett rnett Dec 9, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Kind of, it definitely works better from it, eventually I'd want to mark it @NonNull. I moved the method here from session, where it did throw like that rather than returning null. Since we already had the operation in Graph (which returns null if not found) I added this as operationOrError rather than moving the check to any uses.

GraphOperation op = operation(name);
if (op == null) {
throw new IllegalArgumentException("No Operation named [" + name + "] in the Graph");
}
return op;
}

/**
* Returns the {@code index}-th output of {@code operation}.
* Throws {@code IllegalArgumentException} if the operation is not found, or does not have an output at {@code index}.
*
* @param operation The operation to get the output of.
* @param index The index of the output to get.
* @return The {@code index}-th output of {@code operation}.
*/
public Output<?> getOutput(String operation, int index){
GraphOperation graphOp = operationOrError(operation);
if(index < 0 || index >= graphOp.numOutputs()){
throw new IllegalArgumentException("Index out of bounds for operation " + operation +
". Operation has " + graphOp.numOutputs() + " outputs");
}

return graphOp.output(index);
}

/**
* Returns the output specified by {@code output}.
* Will try to parse the output index from {@code output}.
* I.e. {@code "scope/op:2"} will get the 2nd (0-indexed) output of {@code scope/op}.
* Otherwise, will return the 0th output.
*
* @param output The operation to get the output of, with the index optionally specified by colon.
* @return The output specified by {@code output}.
*/
@SuppressWarnings("rawtypes")
public Output<?> getOutput(String output) {
int colon = output.lastIndexOf(':');
if (colon == -1 || colon == output.length() - 1) {
return new Output(operationOrError(output), 0);
}
try {
String op = output.substring(0, colon);
int index = Integer.parseInt(output.substring(colon + 1));
return new Output(operationOrError(op), index);
} catch (NumberFormatException e) {
return new Output(operationOrError(output), 0);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This fallback case seems odd, as it doesn't log anything when it's likely to be programmer error if it's triggered right?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I moved that from Session, I don't think it's too odd? It's less about the actual number and more about if you have an op name like scope:myOp you want to get scope:myOp, not try to interpret myOp as a number.

Logging it would be a good idea, but it doesn't look like there's any logging currently set up. Is there one I can use?

}
}

/**
* Iterator over all the {@link Operation}s in the graph.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,18 @@ public RawTensor asRawTensor() {

@Override
public void close() {
if(closed) {
throw new IllegalStateException("Tensor has already been closed");
}
tensorScope.close();
closed = true;
}

/**
* @return {@code true} if this tensor has been closed;
*/
public boolean isClosed() {
return closed;
}

/**
Expand Down Expand Up @@ -222,6 +233,7 @@ private static long[] shape(TF_Tensor handle) {
}

private PointerScope tensorScope;
private boolean closed = false;
private TF_Tensor tensorHandle;
private final TensorTypeInfo<? extends TType> typeInfo;
private final Shape shape;
Expand Down
Loading