
Add TensorScope #188


Closed

wants to merge 35 commits into from

35 commits
998677d
Start of TensorScope
rnett Jan 9, 2021
591d44d
Finish TensorScope, add test
rnett Jan 14, 2021
34e1429
Javadoc updates
rnett Jan 14, 2021
700f0f8
Add a TensorScope to EagerSession to replicate pointer attachment
rnett Jan 14, 2021
8231407
Make auto-attach optional, use TensorScope in eager session to close …
rnett Jan 14, 2021
a9982a4
Test for non-auto-attach scope
rnett Jan 14, 2021
d30d7b1
cleanup scopes
rnett Jan 14, 2021
8a7e8be
HasTensors abstraction for resource management of multiple tensors
rnett Jan 14, 2021
38b98f0
Iterable attach methods
rnett Jan 15, 2021
bfe6aa5
refactor hierarchy, add release to parent methods
rnett Jan 16, 2021
393b5da
fix NPE
rnett Jan 16, 2021
7a3f365
Javadoc updates
rnett Jan 16, 2021
cd6f3d2
New tests, remove eager session tensor closing test
rnett Jan 16, 2021
6bc6fce
remove incorrect test
rnett Jan 16, 2021
e2cd366
clarify threading docs
rnett Jan 16, 2021
f3ff90a
grammar
rnett Jan 16, 2021
37e0374
formatting
rnett Jan 16, 2021
0837634
Add note about different scopes
rnett Jan 16, 2021
555b13b
Add option to not require parent to Tensor and HasTensors
rnett Jan 16, 2021
4319a65
Adjust API to be more explicit, add release
rnett Jan 19, 2021
c13c8db
Make constructor package private, use static methods.
rnett Jan 23, 2021
7ebb447
format
rnett Jan 23, 2021
49ac26e
fixes
rnett Jan 24, 2021
b4b3ed4
remove extra closed tracking
rnett Jan 25, 2021
775d8bd
New tests, make static methods build on each other
rnett Jan 25, 2021
8032310
Convert to scope-passing style
rnett Jan 29, 2021
dacd0c3
Add no-output run method
rnett Jan 29, 2021
9ceeefd
Doc updates
rnett Jan 29, 2021
344fe2b
Fix framework
rnett Jan 29, 2021
9a846dc
Set TF_Tensor size properly, update docs
rnett Jan 29, 2021
a10043e
remove unneeded synchronizeds
rnett Jan 29, 2021
eabf063
Don't register memory for view tensors (i.e. from eager operations)
rnett Jan 30, 2021
cf8b24b
Initial Rebase fixes
rnett Feb 16, 2021
a50d06c
More framework fixes
rnett Mar 4, 2021
46645ee
Rebase fixes
rnett Mar 4, 2021

@@ -84,8 +84,9 @@ public String toString() {
*
* <p>This is only supported in an eager execution environment.
*
* @param scope the {@link TensorScope} to create the tensor in
* @param outputIdx index of the output of this operation
* @return output tensor
*/
abstract Tensor tensor(int outputIdx);
abstract Tensor tensor(TensorScope scope, int outputIdx);
Contributor

I think the scope is going to end up by convention as the last argument. That allows it to be made optional more naturally.

Contributor Author

Will we ever want to make it optional though? I'd agree if so, but I never got that impression.

Contributor

It's totally possible, for example:

double reduceSomething(Tensor input);
double reduceSomething(Tensor input, Tensor optionalOutput, Scope optionalScopeForOptionalOutput);

Do you see an advantage in having it as the first argument though?

Contributor

@karllessard @Craigacp Any preferences?

Collaborator

I'm not sure we need to enforce explicit scoping of tensors resulting from an eager operation like this in all cases. The default behavior would tie these tensors to their eager session, which releases them once it is closed. Do we want to offer TensorScope as a simple utility or as something mandatory? I was under the impression it was the former, meaning that we should keep an endpoint that accepts no TensorScope.

As for the order of the arguments in these methods, I don't have a very firm opinion, but if I have to pick, I also prefer having optional arguments at the end of the signature rather than at the beginning.
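
For concreteness, a hypothetical overload pair illustrating the "optional scope goes last" shape being weighed here; neither method nor the sessionScope field exists in this PR, and the fallback-to-session behavior is only assumed:

// Hypothetical sketch, not from this PR: the scope-less convenience overload
// delegates to the explicit-scope one and falls back to a session-tied scope.
public Tensor fetch(int outputIdx) {
  return fetch(outputIdx, sessionScope); // assumed session-lifetime fallback
}

public Tensor fetch(int outputIdx, TensorScope scope) {
  return resolveAndAttach(outputIdx, scope); // assumed internal helper
}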

Contributor

It would make most sense to me for an eager session to be a TensorScope, but we'd probably want to have it without attach() and detach(). Maybe we could have a special MutableTensorScope if we really want to have something with attach() and detach()?

Contributor Author

Agreed, but I'd just use composition instead of inheritance; there's no need to expose it. Eventually I think it would make sense to tie the tensor lifetime to the lifetime of the eager operand, but since we don't have operand lifetimes yet, this is fine. I just don't like using EagerSession for scope management very much, since it is usually very long-lived, but if we have an optional tensor scope version it's fine for now.

For the first-vs-last argument question, I was under the impression that the scope would be mandatory, like I thought Panama's was, but looking at it again I'm not actually sure. If we keep detach (see the other reply), I'm not sure what benefit optional scopes would really provide, other than removing some boilerplate. If we do decide to use optional scopes, I agree about putting it last; I made it first since it is (was?) a required scope, and those live on the left, at least in Kotlin. It really should be a receiver.
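
A minimal sketch of the composition idea, with the session owning a private scope rather than extending TensorScope; the class name, the create() and close() methods, and the resolve helper are all assumptions, not code from this PR:

// Hypothetical composition sketch: the session holds an internal scope, so
// attach()/detach() never become part of the public API.
final class EagerSessionSketch implements AutoCloseable {

  // Assumed factory; the PR's actual way of opening a scope is not shown here.
  private final TensorScope sessionScope = TensorScope.create();

  // Tensors resolved without an explicit scope fall back to the session's scope.
  Tensor resolve(Output<? extends TType> output) {
    return output.asTensor(sessionScope); // asTensor(TensorScope) as in this diff
  }

  @Override
  public void close() {
    sessionScope.close(); // releases every tensor attached to the session's scope
    // ... the existing native session teardown would follow here ...
  }
}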

Collaborator

OK, let's see how it goes by enforcing a scope for the tensors; it's definitely a better safety net than the GC, since these tensors can be quite large compared to the operands' native objects.
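
In usage terms, the agreed-on scope-enforced style works out roughly to the sketch below. TensorScope.create() is an assumed factory (the constructor was made package-private and the actual static entry points are not visible in this excerpt), and tf, a, b, function, and input are assumed to be an eager Ops instance, two TFloat32 operands, a ConcreteFunction, and an input tensor; asTensor(scope) and call(scope, ...) follow the signatures in the diff:

// Sketch only: the scope factory and try-with-resources support are assumed.
try (TensorScope scope = TensorScope.create()) {
  // Eager output -> tensor, registered with the given scope.
  TFloat32 sum = tf.math.add(a, b).asTensor(scope);

  // ConcreteFunction outputs are created in the given scope as well.
  Tensor out = function.call(scope, input);

  // ... read values from sum and out ...
} // closing the scope releases every tensor attached to it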

}

@@ -16,9 +16,9 @@
package org.tensorflow;

import java.io.IOException;
import java.util.HashMap;
import java.util.List;
import java.util.ListIterator;
import java.util.HashMap;
import java.util.Map;
import java.util.function.Function;
import org.tensorflow.op.Ops;
@@ -43,12 +43,12 @@ public class ConcreteFunction implements AutoCloseable {
* Creates a function by building a new graph.
*
* <p>The {@code functionBuilder} must initialize the function graph from the provided
* {@link Ops} instance and return a valid signature that will be used to feed the input tensors
* and fetch the output tensors on execution.
* {@link Ops} instance and return a valid signature that will be used to feed the input tensors and fetch the output
* tensors on execution.
*
* <p>The function will be the owner of the new graph and its resulting session. Therefore,
* the function must be enclosed properly with a try-with-resources block to guarantee that
* all native resources will be freed once the function is discarded. For example:
* the function must be enclosed properly with a try-with-resources block to guarantee that all native resources will
* be freed once the function is discarded. For example:
*
* <pre>{@code
* public class MyModel {
@@ -87,8 +87,8 @@ public static ConcreteFunction create(Function<Ops, Signature> functionBuilder)
* Create a function from a signature and an existing graph.
*
* <p>The function will keep the ownership of the session used to run the graph but not
* the graph itself, meaning that the lifetime of the latter can extend beyond the scope
* of the function. For example:
* the graph itself, meaning that the lifetime of the latter can extend beyond the scope of the function. For
* example:
*
* <pre>{@code
* try (Graph g = new Graph()) {
@@ -116,8 +116,8 @@ public static ConcreteFunction create(Signature signature, Graph graph) {
* Create a function from a signature and a valid graph session.
*
* <p>The function will not own the session nor its graph, meaning that their lifetime
* can extend beyond the scope of the function. Therefore the function does not need to be
* closed after its usage. For example:
* can extend beyond the scope of the function. Therefore the function does not need to be closed after its usage. For
* example:
*
* <pre>{@code
* try (Graph g = new Graph()) {
@@ -156,14 +156,11 @@ public Signature signature() {
/**
* Invokes a function.
*
* <p>Caller is responsible for closing all Tensors.
*
* @param arguments list of tensors to pass in input to the function,
* mapped by their signature name
* @return output tensors resulting from the execution of the function,
* mapped by their signature name
* @param scope the {@link TensorScope} to create the outputs in
* @param arguments list of tensors to pass in input to the function, mapped by their signature name
* @return output tensors resulting from the execution of the function, mapped by their signature name
*/
public Map<String, Tensor> call(Map<String, Tensor> arguments)
public Map<String, Tensor> call(TensorScope scope, Map<String, Tensor> arguments)
throws IllegalArgumentException {

final SignatureDef signatureDef = signature.asSignatureDef();
@@ -180,13 +177,13 @@ public Map<String, Tensor> call(Map<String, Tensor> arguments)
Map<String, TensorInfo> outputToNode = signatureDef.getOutputsMap();
outputToNode.values().forEach(t -> runner.fetch(t.getName()));

List<Tensor> resultTensors = runner.run();
List<Tensor> resultTensors = runner.run(scope);
try {
ListIterator<Tensor> resultTensorIter = resultTensors.listIterator();
Map<String, Tensor> returnMap = new HashMap<String, Tensor>();

// Use the output names as present in the signature definition
for (String nodeName: outputToNode.keySet()) {
for (String nodeName : outputToNode.keySet()) {
returnMap.put(nodeName, resultTensorIter.next());
}
return returnMap;
@@ -203,29 +200,27 @@ public Map<String, Tensor> call(Map<String, Tensor> arguments)
/**
* Invokes a function with a single input and output.
*
* <p>Caller is responsible for closing all Tensors.
*
* @param scope the {@link TensorScope} to create the output in
* @param tensor input tensor
* @return output tensor
* @throws IllegalArgumentException if there are multiple input or output parameters defined
* in the function
* @throws IllegalArgumentException if there are multiple input or output parameters defined in the function
*/
public Tensor call(Tensor tensor) throws IllegalArgumentException {
public Tensor call(TensorScope scope, Tensor tensor) throws IllegalArgumentException {
final SignatureDef signatureDef = signature.asSignatureDef();

if (signatureDef.getInputsCount() != 1) {
throw new IllegalArgumentException(
String.format("Function [%s] requires multiple inputs", signatureDef.getMethodName()));
String.format("Function [%s] requires multiple inputs", signatureDef.getMethodName()));
}
String inputNodeName = signatureDef.getInputsMap().values().iterator().next().getName();

if (signatureDef.getOutputsCount() != 1) {
throw new IllegalArgumentException(
String.format("Function [%s] has multiple outputs", signatureDef.getMethodName()));
String.format("Function [%s] has multiple outputs", signatureDef.getMethodName()));
}
String outputNodeName = signatureDef.getOutputsMap().values().iterator().next().getName();

return session.runner().feed(inputNodeName, tensor).fetch(outputNodeName).run().get(0);
return session.runner().feed(inputNodeName, tensor).fetch(outputNodeName).run(scope).get(0);
}

/**
@@ -245,8 +240,8 @@ public void save(String exportDir) throws IOException {
* Returns the session used to execute the graph when calling this function
*
* <p>In general, a user does not need to handle directly the session of a function and rely
* on {@link #call(Map)} to execute the graph instead. But in some cases, direct access to
* the session might be necessary, as it allows more running options.
* on {@link #call(TensorScope, Map)} to execute the graph instead. But in some cases, direct access to the session
* might be necessary, as it allows more running options.
*
* @return the function session
*/

@@ -35,8 +35,8 @@
* Implementation of an {@link Operation} executed eagerly.
*
* <p>EagerOperation instances are valid only as long as the {@link EagerSession} they are a part of
* is valid. Thus, if {@link EagerSession#close()} has been invoked, then methods on the
* EagerOperation instance may fail with an {@code IllegalStateException}.
* is valid. Thus, if {@link EagerSession#close()} has been invoked, then methods on the EagerOperation instance may
* fail with an {@code IllegalStateException}.
*
* <p>EagerOperation instances are thread-safe.
*/
@@ -120,10 +120,10 @@ DataType dtype(int outputIndex) {
}

@Override
Tensor tensor(int outputIndex) {
Tensor tensor(TensorScope scope, int outputIndex) {
Tensor tensor = outputTensors.get(outputIndex);
if (tensor == null) {
tensor = resolveTensor(outputIndex);
tensor = resolveTensor(scope, outputIndex);
}
return tensor;
}
@@ -133,11 +133,11 @@ Tensor tensor(int outputIndex) {
private final String name;
private final AtomicReferenceArray<Tensor> outputTensors;

private Tensor resolveTensor(int outputIndex) {
private Tensor resolveTensor(TensorScope scope, int outputIndex) {
// Take an optimistic approach, where we attempt to resolve the output tensor without locking.
// If another thread has resolved it meanwhile, release our copy and reuse the existing one
// instead.
Tensor tensor = resolveTensorHandle(getUnsafeNativeHandle(outputIndex), session);
Tensor tensor = resolveTensorHandle(getUnsafeNativeHandle(outputIndex), scope);
if (!outputTensors.compareAndSet(outputIndex, null, tensor)) {
session.detach(tensor.asRawTensor().nativeHandle());
tensor = outputTensors.get(outputIndex);
@@ -160,13 +160,13 @@ private static void requireTensorHandle(TFE_TensorHandle handle) {
}
}

private static Tensor resolveTensorHandle(TFE_TensorHandle handle, EagerSession session) {
private static Tensor resolveTensorHandle(TFE_TensorHandle handle, TensorScope tensorScope) {
requireTensorHandle(handle);
try (PointerScope scope = new PointerScope()) {
TF_Status status = TF_Status.newStatus();
TF_Tensor tensor = TFE_TensorHandleResolve(handle, status).withDeallocator();
TF_Tensor tensor = TFE_TensorHandleResolve(handle, status).withDeallocator(true);
status.throwExceptionIfNotOK();
return RawTensor.fromHandle(tensor, session).asTypedTensor();
return RawTensor.fromHandle(tensorScope, tensor).asTypedTensor();
}
}


@@ -37,8 +37,8 @@
* Implementation for an {@link Operation} added as a node to a {@link Graph}.
*
* <p>GraphOperation instances are valid only as long as the {@link Graph} they are a part of is
* valid. Thus, if {@link Graph#close()} has been invoked, then methods on the GraphOperation
* instance may fail with an {@code IllegalStateException}.
* valid. Thus, if {@link Graph#close()} has been invoked, then methods on the GraphOperation instance may fail with an
* {@code IllegalStateException}.
*
* <p>GraphOperation instances are immutable and thread-safe.
*/
@@ -166,7 +166,7 @@ DataType dtype(int outputIdx) {
}

@Override
Tensor tensor(int outputIdx) {
Tensor tensor(TensorScope scope, int outputIdx) {
throw new IllegalStateException("Graph tensors must be fetched by running a session");
}

@@ -236,7 +236,9 @@ private static long[] shape(TF_Graph graphHandle, TF_Operation opHandle, int out
TF_Status status = TF_Status.newStatus();
int numDims = TF_GraphGetTensorNumDims(graphHandle, output, status);
status.throwExceptionIfNotOK();
if (numDims < 0) return null;
if (numDims < 0) {
return null;
}
long[] dims = new long[numDims];
TF_GraphGetTensorShape(graphHandle, output, dims, numDims, status);
status.throwExceptionIfNotOK();
@@ -250,8 +252,8 @@ private static int dtype(TF_Graph graphHandle, TF_Operation opHandle, int output

int numOutputs = TF_OperationNumOutputs(opHandle);
if (outputIndex < 0 || outputIndex >= numOutputs) {
throw new IndexOutOfBoundsException("invalid output index (" + outputIndex
+ ") for an operation that has " + numOutputs + " outputs");
throw new IndexOutOfBoundsException("invalid output index (" + outputIndex
+ ") for an operation that has " + numOutputs + " outputs");
}

try (PointerScope scope = new PointerScope()) {

@@ -58,11 +58,12 @@ public interface Operand<T extends TType> extends Op, Shaped {
*
* <i>Only works when running in an eager execution</i>
*
* @param scope the {@link TensorScope} to create the tensor in
* @return the tensor
* @throws IllegalStateException if this is an operand of a graph
*/
default T asTensor() {
return asOutput().asTensor();
default T asTensor(TensorScope scope) {
return asOutput().asTensor(scope);
}

/**

@@ -33,31 +33,36 @@
*/
public final class Output<T extends TType> implements Operand<T> {

/** Returns the index into the outputs of the Operation. */
/**
* Returns the index into the outputs of the Operation.
*/
public int index() {
return index;
}

/** Returns the DataType of the tensor referred to by this Output. */
/**
* Returns the DataType of the tensor referred to by this Output.
*/
@SuppressWarnings("unchecked")
public DataType dataType() {
return operation.dtype(index);
}

/** Returns the type of the tensor referred to by this Output. */
/**
* Returns the type of the tensor referred to by this Output.
*/
@SuppressWarnings("unchecked")
@Override
public Class<T> type() {
return (Class<T>)TensorTypeRegistry.find(dataType()).type();
return (Class<T>) TensorTypeRegistry.find(dataType()).type();
}

/**
* Returns this Output object with the type {@code Output<U>}. This method is useful when given a
* value of type {@code Output<?>}.
* Returns this Output object with the type {@code Output<U>}. This method is useful when given a value of type {@code
* Output<?>}.
*
* @param type any supported tensor type
* @throws IllegalArgumentException if the actual data type of this object does not match the type
* {@code U}.
* @throws IllegalArgumentException if the actual data type of this object does not match the type {@code U}.
*/
@SuppressWarnings("unchecked")
public <U extends TType> Output<U> expect(Class<U> type) {
@@ -72,8 +77,7 @@ public <U extends TType> Output<U> expect(Class<U> type) {
* Returns the tensor at this output.
*
* <p>This operation is only supported on the outputs of an operation executed eagerly. For graph
* environments, output tensors must be fetched by running a session, using {@link
* Session.Runner#fetch(Output)}.
* environments, output tensors must be fetched by running a session, using {@link Session.Runner#fetch(Output)}.
*
* <p>It is recommended to close explicitly the returned tensor as soon as possible, since the
* garbage collector is not aware of the amount of memory it consumes, which can be significant.
@@ -84,8 +88,8 @@ public <U extends TType> Output<U> expect(Class<U> type) {
* @see EagerSession
*/
@SuppressWarnings("unchecked")
public T asTensor() {
return (T)operation.tensor(index);
public T asTensor(TensorScope scope) {
return (T) operation.tensor(scope, index);
}

/**
@@ -130,7 +134,9 @@ public String toString() {
operation.type(), operation.name(), index, shape().toString(), dataType());
}

/** Handle to the idx-th output of the Operation {@code op}. */
/**
* Handle to the idx-th output of the Operation {@code op}.
*/
Output(AbstractOperation op, int idx) {
operation = op;
index = idx;