diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/internal/c_api/TFE_CancellationManager.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/internal/c_api/TFE_CancellationManager.java new file mode 100644 index 00000000000..f0ea1d914b2 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/internal/c_api/TFE_CancellationManager.java @@ -0,0 +1,21 @@ +// Targeted by JavaCPP version 1.5.6: DO NOT EDIT THIS FILE + +package org.tensorflow.internal.c_api; + +import java.nio.*; +import org.bytedeco.javacpp.*; +import org.bytedeco.javacpp.annotation.*; + +import static org.tensorflow.internal.c_api.global.tensorflow.*; + + +// ----------------------------------------------------------------------------- +// Cancellation APIs. + +@Opaque @Properties(inherit = org.tensorflow.internal.c_api.presets.tensorflow.class) +public class TFE_CancellationManager extends Pointer { + /** Empty constructor. Calls {@code super((Pointer)null)}. */ + public TFE_CancellationManager() { super((Pointer)null); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public TFE_CancellationManager(Pointer p) { super(p); } +} diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/internal/c_api/TFE_Executor.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/internal/c_api/TFE_Executor.java new file mode 100644 index 00000000000..a27312ae11b --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/internal/c_api/TFE_Executor.java @@ -0,0 +1,20 @@ +// Targeted by JavaCPP version 1.5.6: DO NOT EDIT THIS FILE + +package org.tensorflow.internal.c_api; + +import java.nio.*; +import org.bytedeco.javacpp.*; +import org.bytedeco.javacpp.annotation.*; + +import static org.tensorflow.internal.c_api.global.tensorflow.*; + + +// ----------------------------------------------------------------------------- +// Eager Executor APIs. +@Opaque @Properties(inherit = org.tensorflow.internal.c_api.presets.tensorflow.class) +public class TFE_Executor extends Pointer { + /** Empty constructor. Calls {@code super((Pointer)null)}. */ + public TFE_Executor() { super((Pointer)null); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public TFE_Executor(Pointer p) { super(p); } +} diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/internal/c_api/TFE_MonitoringBuckets.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/internal/c_api/TFE_MonitoringBuckets.java new file mode 100644 index 00000000000..cdfe2e5c329 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/internal/c_api/TFE_MonitoringBuckets.java @@ -0,0 +1,19 @@ +// Targeted by JavaCPP version 1.5.6: DO NOT EDIT THIS FILE + +package org.tensorflow.internal.c_api; + +import java.nio.*; +import org.bytedeco.javacpp.*; +import org.bytedeco.javacpp.annotation.*; + +import static org.tensorflow.internal.c_api.global.tensorflow.*; + + +// APIs for sampler buckets +@Opaque @Properties(inherit = org.tensorflow.internal.c_api.presets.tensorflow.class) +public class TFE_MonitoringBuckets extends Pointer { + /** Empty constructor. Calls {@code super((Pointer)null)}. */ + public TFE_MonitoringBuckets() { super((Pointer)null); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public TFE_MonitoringBuckets(Pointer p) { super(p); } +} diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/internal/c_api/TFE_MonitoringStringGauge3.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/internal/c_api/TFE_MonitoringStringGauge3.java new file mode 100644 index 00000000000..86348a4232f --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/internal/c_api/TFE_MonitoringStringGauge3.java @@ -0,0 +1,19 @@ +// Targeted by JavaCPP version 1.5.6: DO NOT EDIT THIS FILE + +package org.tensorflow.internal.c_api; + +import java.nio.*; +import org.bytedeco.javacpp.*; +import org.bytedeco.javacpp.annotation.*; + +import static org.tensorflow.internal.c_api.global.tensorflow.*; + + +// APIs for String Gauge with 3 labels. +@Opaque @Properties(inherit = org.tensorflow.internal.c_api.presets.tensorflow.class) +public class TFE_MonitoringStringGauge3 extends Pointer { + /** Empty constructor. Calls {@code super((Pointer)null)}. */ + public TFE_MonitoringStringGauge3() { super((Pointer)null); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public TFE_MonitoringStringGauge3(Pointer p) { super(p); } +} diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/internal/c_api/TFE_MonitoringStringGauge4.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/internal/c_api/TFE_MonitoringStringGauge4.java new file mode 100644 index 00000000000..afcbcc52089 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/internal/c_api/TFE_MonitoringStringGauge4.java @@ -0,0 +1,19 @@ +// Targeted by JavaCPP version 1.5.6: DO NOT EDIT THIS FILE + +package org.tensorflow.internal.c_api; + +import java.nio.*; +import org.bytedeco.javacpp.*; +import org.bytedeco.javacpp.annotation.*; + +import static org.tensorflow.internal.c_api.global.tensorflow.*; + + +// APIs for String Gauge with 4 labels. +@Opaque @Properties(inherit = org.tensorflow.internal.c_api.presets.tensorflow.class) +public class TFE_MonitoringStringGauge4 extends Pointer { + /** Empty constructor. Calls {@code super((Pointer)null)}. */ + public TFE_MonitoringStringGauge4() { super((Pointer)null); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public TFE_MonitoringStringGauge4(Pointer p) { super(p); } +} diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/internal/c_api/TFE_OpAttrs.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/internal/c_api/TFE_OpAttrs.java index 30398899f83..89b72d78844 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/internal/c_api/TFE_OpAttrs.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/internal/c_api/TFE_OpAttrs.java @@ -8,8 +8,15 @@ import static org.tensorflow.internal.c_api.global.tensorflow.*; -// Parsed from tensorflow/c/eager/c_api_experimental.h +// APIs for generically dealing with op attributes (e.g. when forwarding them +// through custom device implementations). +// +// TODO(allenl): Currently these are black boxes, but we should have some way to +// inspect values. This would let people e.g. copy over most attributes and then +// modify some based on their values. + +// A reference to an op's name -> attribute mapping @Opaque @Properties(inherit = org.tensorflow.internal.c_api.presets.tensorflow.class) public class TFE_OpAttrs extends Pointer { /** Empty constructor. Calls {@code super((Pointer)null)}. */ diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/internal/c_api/TF_AttrBuilder.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/internal/c_api/TF_AttrBuilder.java new file mode 100644 index 00000000000..69b0906860a --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/internal/c_api/TF_AttrBuilder.java @@ -0,0 +1,21 @@ +// Targeted by JavaCPP version 1.5.6: DO NOT EDIT THIS FILE + +package org.tensorflow.internal.c_api; + +import java.nio.*; +import org.bytedeco.javacpp.*; +import org.bytedeco.javacpp.annotation.*; + +import static org.tensorflow.internal.c_api.global.tensorflow.*; + + +// TF_NewAttrBuilder() returns an object that you can set attributes on as +// though it were an op. This allows querying properties of that op for +// type-checking purposes like if the op will run on a particular device type. +@Opaque @Properties(inherit = org.tensorflow.internal.c_api.presets.tensorflow.class) +public class TF_AttrBuilder extends Pointer { + /** Empty constructor. Calls {@code super((Pointer)null)}. */ + public TF_AttrBuilder() { super((Pointer)null); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public TF_AttrBuilder(Pointer p) { super(p); } +} diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/internal/c_api/TF_CheckpointReader.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/internal/c_api/TF_CheckpointReader.java new file mode 100644 index 00000000000..8f44b6399e2 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/internal/c_api/TF_CheckpointReader.java @@ -0,0 +1,20 @@ +// Targeted by JavaCPP version 1.5.6: DO NOT EDIT THIS FILE + +package org.tensorflow.internal.c_api; + +import java.nio.*; +import org.bytedeco.javacpp.*; +import org.bytedeco.javacpp.annotation.*; + +import static org.tensorflow.internal.c_api.global.tensorflow.*; + + +// TF_NewCheckpointReader() return the CheckpointReader that can be use to +// investigate or load the variable from the checkpoint file +@Opaque @Properties(inherit = org.tensorflow.internal.c_api.presets.tensorflow.class) +public class TF_CheckpointReader extends Pointer { + /** Empty constructor. Calls {@code super((Pointer)null)}. */ + public TF_CheckpointReader() { super((Pointer)null); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public TF_CheckpointReader(Pointer p) { super(p); } +} diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/internal/c_api/TF_ShapeAndType.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/internal/c_api/TF_ShapeAndType.java new file mode 100644 index 00000000000..e12ed8c27d0 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/internal/c_api/TF_ShapeAndType.java @@ -0,0 +1,37 @@ +// Targeted by JavaCPP version 1.5.6: DO NOT EDIT THIS FILE + +package org.tensorflow.internal.c_api; + +import java.nio.*; +import org.bytedeco.javacpp.*; +import org.bytedeco.javacpp.annotation.*; + +import static org.tensorflow.internal.c_api.global.tensorflow.*; + + +// Information about the shape of a Tensor and its type. +@Properties(inherit = org.tensorflow.internal.c_api.presets.tensorflow.class) +public class TF_ShapeAndType extends Pointer { + static { Loader.load(); } + /** Default native constructor. */ + public TF_ShapeAndType() { super((Pointer)null); allocate(); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public TF_ShapeAndType(long size) { super((Pointer)null); allocateArray(size); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public TF_ShapeAndType(Pointer p) { super(p); } + private native void allocate(); + private native void allocateArray(long size); + @Override public TF_ShapeAndType position(long position) { + return (TF_ShapeAndType)super.position(position); + } + @Override public TF_ShapeAndType getPointer(long i) { + return new TF_ShapeAndType((Pointer)this).offsetAddress(i); + } + + // Number of dimensions. -1 indicates unknown rank. + public native int num_dims(); public native TF_ShapeAndType num_dims(int setter); + // Array of dimensions. -1 indicates unknown dim. + public native @Cast("int64_t*") LongPointer dims(); public native TF_ShapeAndType dims(LongPointer setter); + // The data type. May be 0 to denote unknown type. + public native @Cast("TF_DataType") int dtype(); public native TF_ShapeAndType dtype(int setter); +} diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/internal/c_api/TF_ShapeAndTypeList.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/internal/c_api/TF_ShapeAndTypeList.java new file mode 100644 index 00000000000..3a7b8652c97 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/internal/c_api/TF_ShapeAndTypeList.java @@ -0,0 +1,33 @@ +// Targeted by JavaCPP version 1.5.6: DO NOT EDIT THIS FILE + +package org.tensorflow.internal.c_api; + +import java.nio.*; +import org.bytedeco.javacpp.*; +import org.bytedeco.javacpp.annotation.*; + +import static org.tensorflow.internal.c_api.global.tensorflow.*; + + +// A list of TF_ShapeAndType elements.. +@Properties(inherit = org.tensorflow.internal.c_api.presets.tensorflow.class) +public class TF_ShapeAndTypeList extends Pointer { + static { Loader.load(); } + /** Default native constructor. */ + public TF_ShapeAndTypeList() { super((Pointer)null); allocate(); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public TF_ShapeAndTypeList(long size) { super((Pointer)null); allocateArray(size); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public TF_ShapeAndTypeList(Pointer p) { super(p); } + private native void allocate(); + private native void allocateArray(long size); + @Override public TF_ShapeAndTypeList position(long position) { + return (TF_ShapeAndTypeList)super.position(position); + } + @Override public TF_ShapeAndTypeList getPointer(long i) { + return new TF_ShapeAndTypeList((Pointer)this).offsetAddress(i); + } + + public native int num_items(); public native TF_ShapeAndTypeList num_items(int setter); + public native TF_ShapeAndType items(); public native TF_ShapeAndTypeList items(TF_ShapeAndType setter); +} diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/internal/c_api/global/tensorflow.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/internal/c_api/global/tensorflow.java index 0e153289dff..4656099ea7f 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/internal/c_api/global/tensorflow.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/internal/c_api/global/tensorflow.java @@ -3157,6 +3157,406 @@ public static native void TF_RegisterFilesystemPlugin( // #endif // TENSORFLOW_C_C_API_H_ +// Parsed from tensorflow/c/c_api_experimental.h + +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// #ifndef TENSORFLOW_C_C_API_EXPERIMENTAL_H_ +// #define TENSORFLOW_C_C_API_EXPERIMENTAL_H_ + +// #include +// #include + +// #include "tensorflow/c/c_api.h" +// #include "tensorflow/c/eager/c_api.h" + +// -------------------------------------------------------------------------- +// Experimental C API for TensorFlow. +// +// The API here is subject to changes in the future. +// -------------------------------------------------------------------------- + +// Macro to control visibility of exported symbols in the shared library (.so, +// .dylib, .dll). +// This duplicates the TF_EXPORT macro definition in +// tensorflow/core/platform/macros.h in order to keep this .h file independent +// of any other includes.$a +// #ifdef SWIG +// #define TF_CAPI_EXPORT +// #else +// #if defined(_WIN32) +// #ifdef TF_COMPILE_LIBRARY +// #define TF_CAPI_EXPORT __declspec(dllexport) +// #else +// #define TF_CAPI_EXPORT __declspec(dllimport) +// #endif // TF_COMPILE_LIBRARY +// #else +// #define TF_CAPI_EXPORT __attribute__((visibility("default"))) +// #endif // _WIN32 +// #endif // SWIG + +// #ifdef __cplusplus +// #endif + +// When `enable` is true, set +// tensorflow.ConfigProto.OptimizerOptions.global_jit_level to ON_1, and also +// set XLA flag values to prepare for XLA compilation. Otherwise set +// global_jit_level to OFF. +// +// This and the next API are syntax sugar over TF_SetConfig(), and is used by +// clients that cannot read/write the tensorflow.ConfigProto proto. +// TODO: Migrate to TF_CreateConfig() below. +public static native void TF_EnableXLACompilation(TF_SessionOptions options, + @Cast("unsigned char") byte enable); + +// Set XLA's internal BuildXlaOpsPassFlags.tf_xla_enable_lazy_compilation to the +// value of 'enabled'. Also returns the original value of that flag. +// +// Use in tests to allow XLA to fallback to TF classic. This has global effect. +public static native @Cast("unsigned char") byte TF_SetXlaEnableLazyCompilation( + @Cast("unsigned char") byte enable); +public static native @Cast("unsigned char") byte TF_SetTfXlaCpuGlobalJit(@Cast("unsigned char") byte enable); + +// Sets XLA's auto jit mode according to the specified string, which is parsed +// as if passed in XLA_FLAGS. This has global effect. +public static native void TF_SetXlaAutoJitMode(@Cast("const char*") BytePointer mode); +public static native void TF_SetXlaAutoJitMode(String mode); + +// Sets XLA's minimum cluster size. This has global effect. +public static native void TF_SetXlaMinClusterSize(int size); + +// Gets/Sets TF/XLA flag for whether(true) or not(false) to disable constant +// folding. This is for testing to ensure that XLA is being tested rather than +// Tensorflow's CPU implementation through constant folding. +public static native @Cast("unsigned char") byte TF_GetXlaConstantFoldingDisabled(); +public static native void TF_SetXlaConstantFoldingDisabled( + @Cast("unsigned char") byte should_enable); + +// Create a serialized tensorflow.ConfigProto proto, where: +// +// a) ConfigProto.optimizer_options.global_jit_level is set to ON_1 if +// `enable_xla_compilation` is non-zero, and OFF otherwise. +// b) ConfigProto.gpu_options.allow_growth is set to `gpu_memory_allow_growth`. +// c) ConfigProto.device_count is set to `num_cpu_devices`. +public static native TF_Buffer TF_CreateConfig( + @Cast("unsigned char") byte enable_xla_compilation, @Cast("unsigned char") byte gpu_memory_allow_growth, + @Cast("unsigned int") int num_cpu_devices); + +// Create a serialized tensorflow.RunOptions proto, where RunOptions.trace_level +// is set to FULL_TRACE if `enable_full_trace` is non-zero, and NO_TRACE +// otherwise. +public static native TF_Buffer TF_CreateRunOptions( + @Cast("unsigned char") byte enable_full_trace); + +// Returns the graph content in a human-readable format, with length set in +// `len`. The format is subject to change in the future. +// The returned string is heap-allocated, and caller should call free() on it. +public static native @Cast("const char*") BytePointer TF_GraphDebugString(TF_Graph graph, + @Cast("size_t*") SizeTPointer len); + +// Returns the function content in a human-readable format, with length set in +// `len`. The format is subject to change in the future. +// The returned string is heap-allocated, and caller should call free() on it. +// +// Do not return const char*, because some foreign language binding +// (e.g. swift) cannot then call free() on the returned pointer. +public static native @Cast("char*") BytePointer TF_FunctionDebugString(TF_Function func, + @Cast("size_t*") SizeTPointer len); + +// On success, dequeues a tensor from a TF-managed FifoQueue given by +// `tensor_id`, associated with `session`. There must be a graph node named +// "fifo_queue_dequeue_", to be executed by this API call. + +// Caller must call TF_DeleteTensor() over the returned tensor. If the queue is +// empty, this call is blocked. +// +// Tensors are enqueued via the corresponding TF enqueue op. +// TODO(hongm): Add support for `timeout_ms`. +public static native TF_Tensor TF_DequeueNamedTensor(TF_Session session, + int tensor_id, + TF_Status status); + +// On success, enqueues `tensor` into a TF-managed FifoQueue given by +// `tensor_id`, associated with `session`. There must be a graph node named +// "fifo_queue_enqueue_", to be executed by this API call. It reads +// from a placeholder node "arg_tensor_enqueue_". +// +// `tensor` is still owned by the caller. This call will be blocked if the queue +// has reached its capacity, and will be unblocked when the queued tensors again +// drop below the capacity due to dequeuing. +// +// Tensors are dequeued via the corresponding TF dequeue op. +// TODO(hongm): Add support for `timeout_ms`. +public static native void TF_EnqueueNamedTensor(TF_Session session, + int tensor_id, + TF_Tensor tensor, + TF_Status status); +// Create a serialized tensorflow.ServerDef proto. +public static native TF_Buffer TFE_GetServerDef(@Cast("const char*") BytePointer text_proto, TF_Status status); +public static native TF_Buffer TFE_GetServerDef(String text_proto, TF_Status status); + +public static native void TF_MakeInternalErrorStatus(TF_Status status, + @Cast("const char*") BytePointer errMsg); +public static native void TF_MakeInternalErrorStatus(TF_Status status, + String errMsg); +// Targeting ../TF_CheckpointReader.java + + +public static native TF_CheckpointReader TF_NewCheckpointReader( + @Cast("const char*") BytePointer filename, TF_Status status); +public static native TF_CheckpointReader TF_NewCheckpointReader( + String filename, TF_Status status); +public static native void TF_DeleteCheckpointReader( + TF_CheckpointReader reader); +public static native int TF_CheckpointReaderHasTensor( + TF_CheckpointReader reader, @Cast("const char*") BytePointer name); +public static native int TF_CheckpointReaderHasTensor( + TF_CheckpointReader reader, String name); +// Get the variable name at the given index +public static native @Cast("const char*") BytePointer TF_CheckpointReaderGetVariable( + TF_CheckpointReader reader, int index); +// Get the number of variable in the checkpoint +public static native int TF_CheckpointReaderSize(TF_CheckpointReader reader); +// Get the DataType of a variable +public static native @Cast("TF_DataType") int TF_CheckpointReaderGetVariableDataType( + TF_CheckpointReader reader, @Cast("const char*") BytePointer name); +public static native @Cast("TF_DataType") int TF_CheckpointReaderGetVariableDataType( + TF_CheckpointReader reader, String name); +// Read the shape of a variable and write to `dims` +public static native void TF_CheckpointReaderGetVariableShape( + TF_CheckpointReader reader, @Cast("const char*") BytePointer name, @Cast("int64_t*") LongPointer dims, int num_dims, + TF_Status status); +public static native void TF_CheckpointReaderGetVariableShape( + TF_CheckpointReader reader, String name, @Cast("int64_t*") LongBuffer dims, int num_dims, + TF_Status status); +public static native void TF_CheckpointReaderGetVariableShape( + TF_CheckpointReader reader, @Cast("const char*") BytePointer name, @Cast("int64_t*") long[] dims, int num_dims, + TF_Status status); +public static native void TF_CheckpointReaderGetVariableShape( + TF_CheckpointReader reader, String name, @Cast("int64_t*") LongPointer dims, int num_dims, + TF_Status status); +public static native void TF_CheckpointReaderGetVariableShape( + TF_CheckpointReader reader, @Cast("const char*") BytePointer name, @Cast("int64_t*") LongBuffer dims, int num_dims, + TF_Status status); +public static native void TF_CheckpointReaderGetVariableShape( + TF_CheckpointReader reader, String name, @Cast("int64_t*") long[] dims, int num_dims, + TF_Status status); +// Get the number of dimension of a variable +public static native int TF_CheckpointReaderGetVariableNumDims( + TF_CheckpointReader reader, @Cast("const char*") BytePointer name); +public static native int TF_CheckpointReaderGetVariableNumDims( + TF_CheckpointReader reader, String name); +// Load the weight of a variable +public static native TF_Tensor TF_CheckpointReaderGetTensor( + TF_CheckpointReader reader, @Cast("const char*") BytePointer name, TF_Status status); +public static native TF_Tensor TF_CheckpointReaderGetTensor( + TF_CheckpointReader reader, String name, TF_Status status); +// Targeting ../TF_AttrBuilder.java + + +public static native TF_AttrBuilder TF_NewAttrBuilder(@Cast("const char*") BytePointer op_name); +public static native TF_AttrBuilder TF_NewAttrBuilder(String op_name); +public static native void TF_DeleteAttrBuilder(TF_AttrBuilder builder); +public static native void TF_AttrBuilderSetType(TF_AttrBuilder builder, + @Cast("const char*") BytePointer attr_name, + @Cast("TF_DataType") int value); +public static native void TF_AttrBuilderSetType(TF_AttrBuilder builder, + String attr_name, + @Cast("TF_DataType") int value); +public static native void TF_AttrBuilderSetTypeList(TF_AttrBuilder builder, + @Cast("const char*") BytePointer attr_name, + @Cast("const TF_DataType*") IntPointer values, + int num_values); +public static native void TF_AttrBuilderSetTypeList(TF_AttrBuilder builder, + String attr_name, + @Cast("const TF_DataType*") IntBuffer values, + int num_values); +public static native void TF_AttrBuilderSetTypeList(TF_AttrBuilder builder, + @Cast("const char*") BytePointer attr_name, + @Cast("const TF_DataType*") int[] values, + int num_values); +public static native void TF_AttrBuilderSetTypeList(TF_AttrBuilder builder, + String attr_name, + @Cast("const TF_DataType*") IntPointer values, + int num_values); +public static native void TF_AttrBuilderSetTypeList(TF_AttrBuilder builder, + @Cast("const char*") BytePointer attr_name, + @Cast("const TF_DataType*") IntBuffer values, + int num_values); +public static native void TF_AttrBuilderSetTypeList(TF_AttrBuilder builder, + String attr_name, + @Cast("const TF_DataType*") int[] values, + int num_values); + +// Checks the tensorflow::NodeDef built via the methods above to see if it can +// run on device_type. +public static native void TF_AttrBuilderCheckCanRunOnDevice( + TF_AttrBuilder builder, @Cast("const char*") BytePointer device_type, TF_Status status); +public static native void TF_AttrBuilderCheckCanRunOnDevice( + TF_AttrBuilder builder, String device_type, TF_Status status); + +// For argument number input_index, fetch the corresponding number_attr that +// needs to be updated with the argument length of the input list. +// Returns nullptr if there is any problem like op_name is not found, or the +// argument does not support this attribute type. +public static native @Cast("const char*") BytePointer TF_GetNumberAttrForOpListInput( + @Cast("const char*") BytePointer op_name, int input_index, TF_Status status); +public static native String TF_GetNumberAttrForOpListInput( + String op_name, int input_index, TF_Status status); + +// Returns 1 if the op is stateful, 0 otherwise. The return value is undefined +// if the status is not ok. +public static native int TF_OpIsStateful(@Cast("const char*") BytePointer op_type, + TF_Status status); +public static native int TF_OpIsStateful(String op_type, + TF_Status status); + +// Platform specific initialization routine. Very few platforms actually require +// this to be called. +public static native void TF_InitMain(@Cast("const char*") BytePointer usage, IntPointer argc, @Cast("char***") @ByPtrPtr PointerPointer argv); +public static native void TF_InitMain(String usage, IntBuffer argc, @Cast("char***") @ByPtrPtr PointerPointer argv); +public static native void TF_InitMain(@Cast("const char*") BytePointer usage, int[] argc, @Cast("char***") @ByPtrPtr PointerPointer argv); +public static native void TF_InitMain(String usage, IntPointer argc, @Cast("char***") @ByPtrPtr PointerPointer argv); +public static native void TF_InitMain(@Cast("const char*") BytePointer usage, IntBuffer argc, @Cast("char***") @ByPtrPtr PointerPointer argv); +public static native void TF_InitMain(String usage, int[] argc, @Cast("char***") @ByPtrPtr PointerPointer argv); + +// Platform-specific implementation to return an unused port. (This should used +// in tests only.) +public static native int TF_PickUnusedPortOrDie(); + +// Fast path method that makes constructing a single scalar tensor require less +// overhead and copies. +public static native TFE_TensorHandle TFE_NewTensorHandleFromScalar( + @Cast("TF_DataType") int data_type, Pointer data, @Cast("size_t") long len, TF_Status status); + +// Specify the server_def that enables collective ops. +// This is different to the above function in that it doesn't create remote +// contexts, and remotely executing ops is not possible. It just enables +// communication for collective ops. +public static native void TFE_EnableCollectiveOps(TFE_Context ctx, + @Const Pointer proto, + @Cast("size_t") long proto_len, + TF_Status status); + +// Aborts all ongoing collectives with the specified status. After abortion, +// subsequent collectives will error with this status immediately. To reset the +// collectives, create a new EagerContext. +// +// This is intended to be used when a peer failure is detected. +public static native void TFE_AbortCollectiveOps(TFE_Context ctx, + TF_Status status); + +// Checks the health of collective ops peers. Explicit health check is needed in +// multi worker collective ops to detect failures in the cluster. If a peer is +// down, collective ops may hang. +public static native void TFE_CollectiveOpsCheckPeerHealth( + TFE_Context ctx, @Cast("const char*") BytePointer task, @Cast("int64_t") long timeout_in_ms, + TF_Status status); +public static native void TFE_CollectiveOpsCheckPeerHealth( + TFE_Context ctx, String task, @Cast("int64_t") long timeout_in_ms, + TF_Status status); +// Targeting ../TF_ShapeAndType.java + + +// Targeting ../TF_ShapeAndTypeList.java + + + +// API for manipulating TF_ShapeAndTypeList objects. +// +public static native TF_ShapeAndTypeList TF_NewShapeAndTypeList( + int num_shapes); +public static native void TF_ShapeAndTypeListSetShape( + TF_ShapeAndTypeList shape_list, int index, @Cast("const int64_t*") LongPointer dims, + int num_dims); +public static native void TF_ShapeAndTypeListSetShape( + TF_ShapeAndTypeList shape_list, int index, @Cast("const int64_t*") LongBuffer dims, + int num_dims); +public static native void TF_ShapeAndTypeListSetShape( + TF_ShapeAndTypeList shape_list, int index, @Cast("const int64_t*") long[] dims, + int num_dims); +public static native void TF_ShapeAndTypeListSetUnknownShape( + TF_ShapeAndTypeList shape_list, int index); +public static native void TF_ShapeAndTypeListSetDtype( + TF_ShapeAndTypeList shape_list, int index, @Cast("TF_DataType") int dtype); +public static native void TF_DeleteShapeAndTypeList( + TF_ShapeAndTypeList shape_list); +public static native void TF_DeleteShapeAndTypeListArray( + @Cast("TF_ShapeAndTypeList**") PointerPointer shape_list_array, int num_items); +public static native void TF_DeleteShapeAndTypeListArray( + @ByPtrPtr TF_ShapeAndTypeList shape_list_array, int num_items); + +// Infer shapes for the given `op`. The arguments mimic the arguments of the +// `shape_inference::InferenceContext` constructor. Note the following: +// - The inputs of the `op` are not used for shape inference. So, it is +// OK to not have the inputs properly set in `op`. See `input_tensors` +// if you want shape inference to consider the input tensors of the +// op for shape inference. +// - The types need not be set in `input_shapes` as it is not used. +// - The number of `input_tensors` should be the same as the number of items +// in `input_shapes`. +// +// The results are returned in `output_shapes` and +// `output_resource_shapes_and_types`. The caller is responsible for freeing the +// memory in these buffers by calling `TF_DeleteShapeAndTypeList`. +public static native void TFE_InferShapes( + TFE_Op op, TF_ShapeAndTypeList input_shapes, @Cast("TF_Tensor**") PointerPointer input_tensors, + TF_ShapeAndTypeList input_tensor_as_shapes, + @Cast("TF_ShapeAndTypeList**") PointerPointer input_resource_shapes_and_types, + @Cast("TF_ShapeAndTypeList**") PointerPointer output_shapes, + @Cast("TF_ShapeAndTypeList***") @ByPtrPtr PointerPointer output_resource_shapes_and_types, TF_Status status); +public static native void TFE_InferShapes( + TFE_Op op, TF_ShapeAndTypeList input_shapes, @ByPtrPtr TF_Tensor input_tensors, + TF_ShapeAndTypeList input_tensor_as_shapes, + @ByPtrPtr TF_ShapeAndTypeList input_resource_shapes_and_types, + @ByPtrPtr TF_ShapeAndTypeList output_shapes, + @Cast("TF_ShapeAndTypeList***") @ByPtrPtr PointerPointer output_resource_shapes_and_types, TF_Status status); + +public static native void TF_ImportGraphDefOptionsSetValidateColocationConstraints( + TF_ImportGraphDefOptions opts, @Cast("unsigned char") byte enable); + +// Load the library specified by library_filename and register the pluggable +// device and related kernels present in that library. This function is not +// supported on embedded on mobile and embedded platforms and will fail if +// called. +// +// Pass "library_filename" to a platform-specific mechanism for dynamically +// loading a library. The rules for determining the exact location of the +// library are platform-specific and are not documented here. +// +// On success, returns the newly created library handle and places OK in status. +// The caller owns the library handle. +// +// On failure, returns nullptr and places an error status in status. +public static native TF_Library TF_LoadPluggableDeviceLibrary( + @Cast("const char*") BytePointer library_filename, TF_Status status); +public static native TF_Library TF_LoadPluggableDeviceLibrary( + String library_filename, TF_Status status); + +// Frees the memory associated with the library handle. +// Does NOT unload the library. +public static native void TF_DeletePluggableDeviceLibraryHandle( + TF_Library lib_handle); + +// #ifdef __cplusplus /* end extern "C" */ +// #endif + +// #endif // TENSORFLOW_C_C_API_EXPERIMENTAL_H_ + + // Parsed from tensorflow/c/kernels.h /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. @@ -4831,6 +5231,311 @@ public static native void TFE_ContextExportRunMetadata(TFE_Context ctx, // #endif // TENSORFLOW_C_EAGER_C_API_H_ +// Parsed from tensorflow/c/eager/c_api_experimental.h + +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// #ifndef TENSORFLOW_C_EAGER_C_API_EXPERIMENTAL_H_ +// #define TENSORFLOW_C_EAGER_C_API_EXPERIMENTAL_H_ + +// #include "tensorflow/c/c_api.h" +// #include "tensorflow/c/eager/c_api.h" + +// #ifdef __cplusplus +// #endif + +// Resets `op_to_reset` with `op_or_function_name` and `raw_device_name`. This +// is for performance optimization by reusing an exiting unused op rather than +// creating a new op every time. If `raw_device_name` is `NULL` or empty, it +// does not set the device name. If it's not `NULL`, then it attempts to parse +// and set the device name. It's effectively `TFE_OpSetDevice`, but it is faster +// than separately calling it because if the existing op has the same +// `raw_device_name`, it skips parsing and just leave as it is. +public static native void TFE_OpReset(TFE_Op op_to_reset, + @Cast("const char*") BytePointer op_or_function_name, + @Cast("const char*") BytePointer raw_device_name, + TF_Status status); +public static native void TFE_OpReset(TFE_Op op_to_reset, + String op_or_function_name, + String raw_device_name, + TF_Status status); + +// Enables only graph collection in RunMetadata on the functions executed from +// this context. +public static native void TFE_ContextEnableGraphCollection(TFE_Context ctx); + +// Disables only graph collection in RunMetadata on the functions executed from +// this context. +public static native void TFE_ContextDisableGraphCollection(TFE_Context ctx); + +// TODO(fishx): Move these monitoring APIs into a separate file. +// ----------------------------------------------------------------------------- +// Monitoring Counter APIs. +// These APIs de-templated monitoring Counter for swig. + +// Atomically increments the value of the cell. The value must be non-negative. + +// Retrieves the current value of the cell. + +// APIs for Counter without label. +// Returns a new Counter metric object. The caller should manage lifetime of +// the object. Using duplicate metric name will crash the program with fatal +// error. +// Deletes the Counter object. +// Retrieves the cell from the Counter object. The Counter object will manage +// lifetime of the cell. + +// APIs for Counter with 1 label. + +// APIs for Counter with 2 labels. + +// ----------------------------------------------------------------------------- +// Monitoring Gauge APIs. +// These APIs de-templated monitoring Gauge for swig. + +// Atomically set the value of the cell. + +// Retrieves the current value of the cell. + +// APIs for Int Gauge without label. + +// APIs for Int Gauge with 1 label. + +// APIs for Int Gauge with 2 label. +// Retrieves the string value and saves it in buffer. + +// APIs for String Gauge without label. + +// APIs for String Gauge with 1 label. + +// APIs for String Gauge with 2 label. +// Targeting ../TFE_MonitoringStringGauge3.java + + +public static native TFE_MonitoringStringGauge3 TFE_MonitoringNewStringGauge3( + @Cast("const char*") BytePointer name, TF_Status out_status, @Cast("const char*") BytePointer description, + @Cast("const char*") BytePointer label1, @Cast("const char*") BytePointer label2, @Cast("const char*") BytePointer label3); +public static native TFE_MonitoringStringGauge3 TFE_MonitoringNewStringGauge3( + String name, TF_Status out_status, String description, + String label1, String label2, String label3); +public static native void TFE_MonitoringDeleteStringGauge3( + TFE_MonitoringStringGauge3 gauge); +// Targeting ../TFE_MonitoringStringGauge4.java + + +public static native TFE_MonitoringStringGauge4 TFE_MonitoringNewStringGauge4( + @Cast("const char*") BytePointer name, TF_Status out_status, @Cast("const char*") BytePointer description, + @Cast("const char*") BytePointer label1, @Cast("const char*") BytePointer label2, @Cast("const char*") BytePointer label3, + @Cast("const char*") BytePointer label4); +public static native TFE_MonitoringStringGauge4 TFE_MonitoringNewStringGauge4( + String name, TF_Status out_status, String description, + String label1, String label2, String label3, + String label4); +public static native void TFE_MonitoringDeleteStringGauge4( + TFE_MonitoringStringGauge4 gauge); + +// APIs for Bool Gauge without label. + +// APIs for Bool Gauge with 1 label. + +// APIs for Bool Gauge with 2 label. + +// ----------------------------------------------------------------------------- +// Monitoring Sampler APIs. +// These APIs de-templated monitoring Sampler for swig. + +// Atomically add the value of the cell. + +// Retrieves the current value of the cell. The return value is a HistogramProto +// saved in buffer. +// Targeting ../TFE_MonitoringBuckets.java + + +public static native TFE_MonitoringBuckets TFE_MonitoringNewExponentialBuckets(double scale, double growth_factor, + int bucket_count); +public static native void TFE_MonitoringDeleteBuckets( + TFE_MonitoringBuckets buckets); + +// APIs for Sampler without label. + +// APIs for Sampler with 1 label. + +// APIs for Sampler with 2 label. + +// Sets whether to use TFRT +public static native void TFE_ContextOptionsSetTfrt(TFE_ContextOptions arg0, + @Cast("bool") boolean use_tfrt); + +// Sets whether to use TFRT distributed runtime +public static native void TFE_ContextOptionsSetTfrtDistributedRuntime( + TFE_ContextOptions options, @Cast("bool") boolean use_tfrt_distributed_runtime); + +// Returns the context_id from the EagerContext which is used by the +// EagerService to maintain consistency between client and worker. The +// context_id is initialized with a dummy value and is later set when the worker +// is initialized (either locally or remotely). The context_id can change during +// the process lifetime although this should cause the worker to be +// reinitialized (e.g. cleared caches) as well. +public static native @Cast("uint64_t") long TFE_GetContextId(TFE_Context ctx); +// Targeting ../TFE_CancellationManager.java + + +public static native TFE_CancellationManager TFE_NewCancellationManager(); +public static native @Cast("bool") boolean TFE_CancellationManagerIsCancelled( + TFE_CancellationManager arg0); +public static native void TFE_CancellationManagerStartCancel( + TFE_CancellationManager arg0); +public static native void TFE_DeleteCancellationManager( + TFE_CancellationManager arg0); + +// Associates the given `cancellation_manager` with `op`, so that invoking +// `TFE_CancellationManagerStartCancel(cancellation_manager)` will cancel the +// execution of `op`. +public static native void TFE_OpSetCancellationManager( + TFE_Op op, TFE_CancellationManager cancellation_manager, + TF_Status status); +// Targeting ../TFE_Executor.java + + + +// Creates a new eager Executor. Nodes in one executor are guaranteed to be +// executed in sequence. Assigning nodes to different executors allows executing +// nodes in parallel. +public static native TFE_Executor TFE_NewExecutor(@Cast("bool") boolean is_async); + +// Deletes the eager Executor without waiting for enqueued nodes. Please call +// TFE_ExecutorWaitForAllPendingNodes before calling this API if you want to +// make sure all nodes are finished. +public static native void TFE_DeleteExecutor(TFE_Executor arg0); + +// Returns true if the executor is in async mode. +public static native @Cast("bool") boolean TFE_ExecutorIsAsync(TFE_Executor arg0); + +// Causes the calling thread to block till all ops dispatched in this executor +// have been executed. Note that "execution" here refers to kernel execution / +// scheduling of copies, etc. Similar to sync execution, it doesn't guarantee +// that lower level device queues (like GPU streams) have been flushed. +// +// This call may not block for execution of ops enqueued concurrently with this +// call. +public static native void TFE_ExecutorWaitForAllPendingNodes( + TFE_Executor arg0, TF_Status status); + +// When an error happens, any pending operations are discarded and newly issued +// ops return an error. This call clears the error state and re-enables +// execution of newly issued ops. +// +// Note that outputs of discarded ops remain in a corrupt state and should not +// be used for future calls. +// TODO(agarwal): mark the affected handles and raise errors if they are used. +public static native void TFE_ExecutorClearError(TFE_Executor arg0); + +// Sets a custom Executor for current thread. All nodes created by this thread +// will be added to this Executor. It will override current executor. +public static native void TFE_ContextSetExecutorForThread(TFE_Context arg0, + TFE_Executor arg1); + +// Returns the Executor for current thread. +public static native TFE_Executor TFE_ContextGetExecutorForThread( + TFE_Context arg0); + +// ----------------------------------------------------------------------------- +// Dynamic cluster API. + +// Update an existing context with a new set of servers defined in a ServerDef +// proto. Servers can be added to and removed from the list of remote workers +// in the context. New set of servers identified by the ServerDef must be up +// when the context is updated. +// +// This API is for experimental usage and may be subject to change. +public static native void TFE_ContextUpdateServerDef(TFE_Context ctx, + int keep_alive_secs, + @Const Pointer proto, + @Cast("size_t") long proto_len, + TF_Status status); + +// Checks whether a remote worker is alive or not. This will return true even if +// the context doesn't exist on the remote worker. +public static native @Cast("bool") boolean TFE_ContextCheckAlive(TFE_Context ctx, + @Cast("const char*") BytePointer worker_name, + TF_Status status); +public static native @Cast("bool") boolean TFE_ContextCheckAlive(TFE_Context ctx, + String worker_name, + TF_Status status); + +// Sync pending nodes in local executors (including the context default executor +// and thread executors) and streaming requests to remote executors, and get the +// combined status. +public static native void TFE_ContextAsyncWait(TFE_Context ctx, + TF_Status status); + +// This function will block till the operation that produces `h` has +// completed. This is only valid on local TFE_TensorHandles. The pointer +// returned will be on the device in which the TFE_TensorHandle resides (so e.g. +// for a GPU tensor this will return a pointer to GPU memory). The pointer is +// only guaranteed to be valid until TFE_DeleteTensorHandle is called on this +// TensorHandle. Only supports POD data types. +public static native Pointer TFE_TensorHandleDevicePointer(TFE_TensorHandle arg0, + TF_Status arg1); + +// This function will block till the operation that produces `h` has +// completed. This is only valid on local TFE_TensorHandles. Returns the size in +// bytes of the memory pointed to by the device pointer returned above. +public static native @Cast("size_t") long TFE_TensorHandleDeviceMemorySize(TFE_TensorHandle arg0, + TF_Status arg1); + +// Creates a new TensorHandle from memory residing in the physical device +// device_name. Takes ownership of the memory, and will call deleter to release +// it after TF no longer needs it or in case of error. +// +// Custom devices must use TFE_NewCustomDeviceTensorHandle instead. +public static native TFE_TensorHandle TFE_NewTensorHandleFromDeviceMemory( + TFE_Context ctx, @Cast("const char*") BytePointer device_name, @Cast("TF_DataType") int arg2, @Cast("const int64_t*") LongPointer dims, + int num_dims, Pointer data, @Cast("size_t") long len, + Deallocator_Pointer_long_Pointer deallocator, + Pointer deallocator_arg, TF_Status status); +public static native TFE_TensorHandle TFE_NewTensorHandleFromDeviceMemory( + TFE_Context ctx, String device_name, @Cast("TF_DataType") int arg2, @Cast("const int64_t*") LongBuffer dims, + int num_dims, Pointer data, @Cast("size_t") long len, + Deallocator_Pointer_long_Pointer deallocator, + Pointer deallocator_arg, TF_Status status); +public static native TFE_TensorHandle TFE_NewTensorHandleFromDeviceMemory( + TFE_Context ctx, @Cast("const char*") BytePointer device_name, @Cast("TF_DataType") int arg2, @Cast("const int64_t*") long[] dims, + int num_dims, Pointer data, @Cast("size_t") long len, + Deallocator_Pointer_long_Pointer deallocator, + Pointer deallocator_arg, TF_Status status); +public static native TFE_TensorHandle TFE_NewTensorHandleFromDeviceMemory( + TFE_Context ctx, String device_name, @Cast("TF_DataType") int arg2, @Cast("const int64_t*") LongPointer dims, + int num_dims, Pointer data, @Cast("size_t") long len, + Deallocator_Pointer_long_Pointer deallocator, + Pointer deallocator_arg, TF_Status status); +public static native TFE_TensorHandle TFE_NewTensorHandleFromDeviceMemory( + TFE_Context ctx, @Cast("const char*") BytePointer device_name, @Cast("TF_DataType") int arg2, @Cast("const int64_t*") LongBuffer dims, + int num_dims, Pointer data, @Cast("size_t") long len, + Deallocator_Pointer_long_Pointer deallocator, + Pointer deallocator_arg, TF_Status status); +public static native TFE_TensorHandle TFE_NewTensorHandleFromDeviceMemory( + TFE_Context ctx, String device_name, @Cast("TF_DataType") int arg2, @Cast("const int64_t*") long[] dims, + int num_dims, Pointer data, @Cast("size_t") long len, + Deallocator_Pointer_long_Pointer deallocator, + Pointer deallocator_arg, TF_Status status); + +// Retrieves the address space (i.e. job, replia, task) of the local host and +// saves it in the buffer. +public static native void TFE_HostAddressSpace(TFE_Context ctx, + TF_Buffer buf); // Targeting ../TFE_OpAttrs.java @@ -4869,6 +5574,205 @@ public static native void TFE_OpSetAttrValueProto(@Const TFE_Op op, public static final int TFE_CUSTOM_DEVICE_VERSION = 4; +// Struct to be filled in. Functions are required except where indicated. + +// Registers a custom device for use with eager execution. +// +// Eager operations may be placed on this device, e.g. `with +// tf.device("CUSTOM"):` from Python if `device_name` for this call is +// "/job:localhost/replica:0/task:0/device:CUSTOM:0". +// +// The custom device defines copy operations for moving TensorHandles on and +// off, and an execution operation for named operations. Often execution will +// simply wrap op execution on one or more physical devices. +// +// device_info is an opaque caller-defined type stored with the custom device +// which is passed to the functions referenced in the TFE_CustomDevice struct +// `device` (execute, delete_device, etc.). It can for example contain the +// names of wrapped devices. +// +// There are currently no graph semantics implemented for registered custom +// devices, so executing tf.functions which contain operations placed on custom +// devices will fail. +// +// `device_name` must not name an existing physical or custom device. It must +// follow the format: +// +// /job:/replica:/task:/device:: +// +// If the device is successfully registered, `status` is set to TF_OK. Otherwise +// the device is not usable. In case of a bad status, `device.delete_device` is +// still called on `device_info` (i.e. the caller does not retain ownership). +// +// This API is highly experimental, and in particular is expected to change when +// it starts supporting operations with attributes and when tf.function support +// is added. + +// Struct to be filled in to define a custom device tensor handle. Fields are +// required except where indicated. + +// Creates a new TensorHandle from memory residing in a custom device. Takes +// ownership of the memory pointed to by `tensor_handle_data`, and calls +// `methods.deallocator` to release it after TF no longer needs it or in case of +// an error. +// +// This call is similar to `TFE_NewTensorHandleFromDeviceMemory`, but supports +// custom devices instead of physical devices and does not require blocking +// waiting for exact shapes. + +public static native void TFE_ContextGetFunctionDef(TFE_Context ctx, + @Cast("const char*") BytePointer function_name, + TF_Buffer buf, + TF_Status status); +public static native void TFE_ContextGetFunctionDef(TFE_Context ctx, + String function_name, + TF_Buffer buf, + TF_Status status); + +// Allocate and return a new Tensor on the host. +// +// The caller must set the Tensor values by writing them to the pointer returned +// by TF_TensorData with length TF_TensorByteSize. +public static native TF_Tensor TFE_AllocateHostTensor(TFE_Context ctx, + @Cast("TF_DataType") int dtype, + @Cast("const int64_t*") LongPointer dims, + int num_dims, + TF_Status status); +public static native TF_Tensor TFE_AllocateHostTensor(TFE_Context ctx, + @Cast("TF_DataType") int dtype, + @Cast("const int64_t*") LongBuffer dims, + int num_dims, + TF_Status status); +public static native TF_Tensor TFE_AllocateHostTensor(TFE_Context ctx, + @Cast("TF_DataType") int dtype, + @Cast("const int64_t*") long[] dims, + int num_dims, + TF_Status status); + +// Given a Tensor, wrap it with a TensorHandle +// +// Similar to TFE_NewTensorHandle, but includes a pointer to the TFE_Context. +// The context should be identical to that of the Tensor. +public static native TFE_TensorHandle TFE_NewTensorHandleFromTensor( + TFE_Context ctx, TF_Tensor t, TF_Status status); + +// Create a packed TensorHandle with the given list of TensorHandles. +// If `handles` are on the same device, assign the same device to the packed +// handle; if `handles` are on different deivces, assign a CompositeDevice to +// it. +public static native TFE_TensorHandle TFE_CreatePackedTensorHandle( + TFE_Context ctx, @Cast("TFE_TensorHandle**") PointerPointer handles, IntPointer num_handles, + TF_Status status); +public static native TFE_TensorHandle TFE_CreatePackedTensorHandle( + TFE_Context ctx, @ByPtrPtr TFE_TensorHandle handles, IntPointer num_handles, + TF_Status status); +public static native TFE_TensorHandle TFE_CreatePackedTensorHandle( + TFE_Context ctx, @ByPtrPtr TFE_TensorHandle handles, IntBuffer num_handles, + TF_Status status); +public static native TFE_TensorHandle TFE_CreatePackedTensorHandle( + TFE_Context ctx, @ByPtrPtr TFE_TensorHandle handles, int[] num_handles, + TF_Status status); + +// Configure soft device placement policy for the eager executor. Note this +// policy is applied to any subsequent op executions. +public static native void TFE_ContextSetSoftDevicePlacement(TFE_Context ctx, + @Cast("unsigned char") byte enable, + TF_Status status); + +// Configure device placement policy logging for the eager executor. Note this +// policy is applied to any subsequent op executions. +public static native void TFE_ContextSetLogDevicePlacement(TFE_Context ctx, + @Cast("unsigned char") byte enable, + TF_Status status); + +// Returns the device type of the operation that produced `h`. +public static native @Cast("const char*") BytePointer TFE_TensorHandleDeviceType( + TFE_TensorHandle h, TF_Status status); + +// Returns the device ID of the operation that produced `h`. +public static native int TFE_TensorHandleDeviceID(TFE_TensorHandle h, + TF_Status status); + +// Returns the status for the tensor handle. In TFRT, a tensor handle can carry +// error info if error happens. If so, status will be set with the error info. +// If not, status will be set as OK. +public static native void TFE_TensorHandleGetStatus(TFE_TensorHandle h, + TF_Status status); + +// Get a comma-separated list of op names executed in graph functions dispatched +// to `ctx`. This feature is currently only enabled for TFRT debug builds, for +// performance and simplicity reasons. +public static native void TFE_GetExecutedOpNames(TFE_Context ctx, + TF_Buffer buf, + TF_Status status); + +// Set logical devices to the context's device manager. +// If logical devices are already configured at context initialization +// through TFE_ContextOptions, this method should not be called. +public static native void TFE_SetLogicalCpuDevices(TFE_Context ctx, + int num_cpus, + @Cast("const char*") BytePointer prefix, + TF_Status status); +public static native void TFE_SetLogicalCpuDevices(TFE_Context ctx, + int num_cpus, + String prefix, + TF_Status status); + +// Set configuration key and value using coordination service. +// If coordination service is enabled, the key-value will be stored on the +// leader and become accessible to all workers in the cluster. +// Currently, a config key can only be set with one value, and subsequently +// setting the same key will lead to errors. +// +// Note that the key-values are only expected to be used for cluster +// configuration data, and should not be used for storing large amount of data +// or being accessed very frequently. +public static native void TFE_InsertConfigKeyValue(TFE_Context ctx, + @Cast("const char*") BytePointer key, + @Cast("const char*") BytePointer value, + TF_Status status); +public static native void TFE_InsertConfigKeyValue(TFE_Context ctx, + String key, + String value, + TF_Status status); + +// Get configuration key and value using coordination service. +// The config key must be set before getting its value. Getting value of +// non-existing config keys will result in errors. +public static native void TFE_GetConfigKeyValue(TFE_Context ctx, + @Cast("const char*") BytePointer key, + TF_Buffer value_buf, + TF_Status status); +public static native void TFE_GetConfigKeyValue(TFE_Context ctx, + String key, + TF_Buffer value_buf, + TF_Status status); + +// Delete configuration key-value. If `key` is a directory, recursively clean up +// all key-values under the path specified by `key`. +public static native void TFE_DeleteConfigKeyValue(TFE_Context ctx, + @Cast("const char*") BytePointer key, + TF_Status status); +public static native void TFE_DeleteConfigKeyValue(TFE_Context ctx, + String key, + TF_Status status); + +// Report error (specified by error_code and error_message) to other tasks in +// the cluster. +public static native void TFE_ReportErrorToCluster(TFE_Context ctx, + int error_code, + @Cast("const char*") BytePointer error_message, + TF_Status status); +public static native void TFE_ReportErrorToCluster(TFE_Context ctx, + int error_code, + String error_message, + TF_Status status); + +// #ifdef __cplusplus /* end extern "C" */ +// #endif + +// #endif // TENSORFLOW_C_EAGER_C_API_EXPERIMENTAL_H_ + // Parsed from tensorflow/cc/framework/scope.h diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/c_api/presets/tensorflow.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/c_api/presets/tensorflow.java index 9b3258fb08c..c5983d554dc 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/c_api/presets/tensorflow.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/c_api/presets/tensorflow.java @@ -51,6 +51,7 @@ "tensorflow/c/tf_tensor.h", "tensorflow/c/tf_tstring.h", "tensorflow/c/c_api.h", + "tensorflow/c/c_api_experimental.h", // "tensorflow/c/env.h", "tensorflow/c/kernels.h", "tensorflow/c/ops.h", @@ -333,11 +334,6 @@ public void map(InfoMap infoMap) { "static TF_Operation\\* TF_FinishOperationLocked\\(TF_OperationDescription\\* desc,", "\\}")) .put(new Info("OutputTensor", "TensorId", "tensorflow::AttrValue").skip()) - .put( - new Info("c_api_experimental.h") - .linePatterns( - "typedef struct TFE_OpAttrs TFE_OpAttrs;", - "#define TFE_CUSTOM_DEVICE_VERSION 4")) .put( new Info("TF_CAPI_EXPORT", "TF_Bool", "TF_GUARDED_BY", "TF_MUST_USE_RESULT") .cppTypes()