Skip to content

Commit

Permalink
Map experimental C (actually C++) API for gradient tape
Browse files Browse the repository at this point in the history
  • Loading branch information
saudet committed Apr 9, 2021
1 parent 0d73a9b commit 93a827e
Show file tree
Hide file tree
Showing 88 changed files with 2,504 additions and 69 deletions.
2 changes: 1 addition & 1 deletion tensorflow-core/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@
<javacpp.platform.macosx-x86_64.extension>macosx-x86_64${javacpp.platform.extension}</javacpp.platform.macosx-x86_64.extension>
<javacpp.platform.windows-x86.extension>windows-x86${javacpp.platform.extension}</javacpp.platform.windows-x86.extension>
<javacpp.platform.windows-x86_64.extension>windows-x86_64${javacpp.platform.extension}</javacpp.platform.windows-x86_64.extension>
<javacpp.version>1.5.4</javacpp.version>
<javacpp.version>1.5.5</javacpp.version>
</properties>

<profiles>
Expand Down
25 changes: 25 additions & 0 deletions tensorflow-core/tensorflow-core-api/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,19 @@
</execution>
</executions>
</plugin>
<plugin>
<artifactId>maven-resources-plugin</artifactId>
<version>3.1.0</version>
<executions>
<execution>
<id>javacpp-parser</id>
<phase>generate-sources</phase>
<goals>
<goal>resources</goal>
</goals>
</execution>
</executions>
</plugin>
<plugin>
<artifactId>maven-compiler-plugin</artifactId>
<version>3.8.0</version>
Expand Down Expand Up @@ -209,7 +222,15 @@
<classPath>${project.build.outputDirectory}</classPath>
<includePaths>
<includePath>${project.basedir}/</includePath>
<includePath>${project.basedir}/bazel-bin/external/llvm-project/llvm/include/</includePath>
<includePath>${project.basedir}/bazel-bin/external/org_tensorflow/</includePath>
<includePath>${project.basedir}/bazel-${project.artifactId}/external/eigen_archive/</includePath>
<includePath>${project.basedir}/bazel-${project.artifactId}/external/com_google_absl/</includePath>
<includePath>${project.basedir}/bazel-${project.artifactId}/external/com_google_protobuf/src/</includePath>
<includePath>${project.basedir}/bazel-${project.artifactId}/external/farmhash_archive/src/</includePath>
<includePath>${project.basedir}/bazel-${project.artifactId}/external/llvm-project/llvm/include/</includePath>
<includePath>${project.basedir}/bazel-${project.artifactId}/external/org_tensorflow/</includePath>
<includePath>${project.basedir}/target/classes/org/tensorflow/internal/c_api/include/</includePath>
</includePaths>
<linkPaths>
<linkPath>${project.basedir}/bazel-bin/external/llvm_openmp/</linkPath>
Expand Down Expand Up @@ -315,6 +336,10 @@
<outputDirectory>${project.build.directory}/native/org/tensorflow/internal/c_api/${native.classifier}/</outputDirectory>
<skip>${javacpp.compiler.skip}</skip>
<classOrPackageName>org.tensorflow.internal.c_api.**</classOrPackageName>
<compilerOptions>
<!-- TODO: Remove files from here as they get integrated into the Bazel build -->
<compilerOption>${project.basedir}/bazel-${project.artifactId}/external/org_tensorflow/tensorflow/c/eager/gradients.cc</compilerOption>
</compilerOptions>
<copyLibs>true</copyLibs>
<copyResources>true</copyResources>
</configuration>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
// Targeted by JavaCPP version 1.5.5: DO NOT EDIT THIS FILE

package org.tensorflow.internal.c_api;

import java.nio.*;
import org.bytedeco.javacpp.*;
import org.bytedeco.javacpp.annotation.*;

import static org.tensorflow.internal.c_api.global.tensorflow.*;


// Abstract interface to a context.
//
// This serves as a factory for creating `AbstractOperation`s and for
// registering traced functions.
// Operations creation within a context can only be executed in that context
// (for now at least).
// Implementations of the context may contain some state e.g. an execution
// environment, a traced representation etc.
@Namespace("tensorflow") @NoOffset @Properties(inherit = org.tensorflow.internal.c_api.presets.tensorflow.class)
public class AbstractContext extends Pointer {
static { Loader.load(); }
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
public AbstractContext(Pointer p) { super(p); }

public native int getKind();

// Release any underlying resources, including the interface object.
//
// WARNING: The destructor of this class is marked as protected to disallow
// clients from directly destroying this object since it may manage it's own
// lifetime through ref counting. Thus clients MUST call Release() in order to
// destroy an instance of this class.
public native void Release();

// Creates an operation builder and ties it to this context.
// The returned object can be used for setting operation's attributes,
// adding inputs and finally executing (immediately or lazily as in tracing)
// it in this context.
public native AbstractOperation CreateOperation();

// Registers a function with this context, after this the function is
// available to be called/referenced by its name in this context.
public native @ByVal Status RegisterFunction(AbstractFunction arg0);
// Remove a function. 'func' argument is the name of a previously added
// FunctionDef. The name is in fdef.signature.name.
public native @ByVal Status RemoveFunction(@StdString BytePointer func);
public native @ByVal Status RemoveFunction(@StdString String func);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// Targeted by JavaCPP version 1.5.5: DO NOT EDIT THIS FILE

package org.tensorflow.internal.c_api;

import java.nio.*;
import org.bytedeco.javacpp.*;
import org.bytedeco.javacpp.annotation.*;

import static org.tensorflow.internal.c_api.global.tensorflow.*;

@Namespace("tensorflow::internal") @Properties(inherit = org.tensorflow.internal.c_api.presets.tensorflow.class)
public class AbstractContextDeleter extends Pointer {
static { Loader.load(); }
/** Default native constructor. */
public AbstractContextDeleter() { super((Pointer)null); allocate(); }
/** Native array allocator. Access with {@link Pointer#position(long)}. */
public AbstractContextDeleter(long size) { super((Pointer)null); allocateArray(size); }
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
public AbstractContextDeleter(Pointer p) { super(p); }
private native void allocate();
private native void allocateArray(long size);
@Override public AbstractContextDeleter position(long position) {
return (AbstractContextDeleter)super.position(position);
}
@Override public AbstractContextDeleter getPointer(long i) {
return new AbstractContextDeleter((Pointer)this).position(position + i);
}

public native @Name("operator ()") void apply(AbstractContext p);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
// Targeted by JavaCPP version 1.5.5: DO NOT EDIT THIS FILE

package org.tensorflow.internal.c_api;

import java.nio.*;
import org.bytedeco.javacpp.*;
import org.bytedeco.javacpp.annotation.*;

import static org.tensorflow.internal.c_api.global.tensorflow.*;


// A traced function: this hides the complexity of converting the serialized
// representation between various supported formats e.g. FunctionDef and Mlir
// function.
@Namespace("tensorflow") @NoOffset @Properties(inherit = org.tensorflow.internal.c_api.presets.tensorflow.class)
public class AbstractFunction extends Pointer {
static { Loader.load(); }
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
public AbstractFunction(Pointer p) { super(p); }

// Returns which subclass is this instance of.
public native int getKind();

// Returns the AbstractFunction as a FunctionDef.
public native @ByVal Status GetFunctionDef(@Cast("tensorflow::FunctionDef**") PointerPointer arg0);
public native @ByVal Status GetFunctionDef(@Cast("tensorflow::FunctionDef**") @ByPtrPtr Pointer arg0);
}
Loading

0 comments on commit 93a827e

Please sign in to comment.