From 06c1b7687b3788652c48334d0c08e69f5f2b594e Mon Sep 17 00:00:00 2001 From: Samuel Audet Date: Thu, 16 May 2019 20:03:10 +0900 Subject: [PATCH] * Overload `Tensor.create()` factory methods for TensorFlow with handy `long... shape` (issue bytedeco/javacpp#301) --- CHANGELOG.md | 1 + .../org/bytedeco/tensorflow/CollectiveContext.java | 4 ++-- .../bytedeco/tensorflow/FunctionLibraryRuntime.java | 2 +- .../java/org/bytedeco/tensorflow/OpKernelContext.java | 10 +++++----- .../gen/java/org/bytedeco/tensorflow/SessionState.java | 2 +- .../bytedeco/tensorflow/TF_ImportGraphDefResults.java | 2 +- .../java/org/bytedeco/tensorflow/TF_WhileParams.java | 2 +- .../java/org/bytedeco/tensorflow/TensorSliceSet.java | 2 +- .../java/org/bytedeco/tensorflow/AbstractTensor.java | 8 ++++++++ 9 files changed, 21 insertions(+), 12 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2bc60c939ec..8f9d9b0c2ec 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,4 +1,5 @@ + * Overload `Tensor.create()` factory methods for TensorFlow with handy `long... shape` ([issue bytedeco/javacpp#301](https://github.com/bytedeco/javacpp/issues/301)) * Add build for `linux-arm64` to presets for OpenBLAS ([pull #726](https://github.com/bytedeco/javacpp-presets/pull/726)) * Bundle complete binary packages of CPython itself for convenience ([issue #712](https://github.com/bytedeco/javacpp-presets/issues/712)) * Fix and refine mapping for `HoughLines`, `HoughLinesP`, and `HoughCircles` ([issue #717](https://github.com/bytedeco/javacpp-presets/issues/717)) diff --git a/tensorflow/src/gen/java/org/bytedeco/tensorflow/CollectiveContext.java b/tensorflow/src/gen/java/org/bytedeco/tensorflow/CollectiveContext.java index 0608c02f928..26c13a5c5ec 100644 --- a/tensorflow/src/gen/java/org/bytedeco/tensorflow/CollectiveContext.java +++ b/tensorflow/src/gen/java/org/bytedeco/tensorflow/CollectiveContext.java @@ -34,13 +34,13 @@ private native void allocate(CollectiveExecutor col_exec, @Const DeviceMgr dev_m @Cast("tensorflow::int64") long step_id, @Const Tensor input, Tensor output); public native CollectiveExecutor col_exec(); public native CollectiveContext col_exec(CollectiveExecutor setter); // Not owned - @MemberGetter public native @Const DeviceMgr dev_mgr(); // Not owned + public native @Const DeviceMgr dev_mgr(); public native CollectiveContext dev_mgr(DeviceMgr setter); // Not owned public native OpKernelContext op_ctx(); public native CollectiveContext op_ctx(OpKernelContext setter); // Not owned public native OpKernelContext.Params op_params(); public native CollectiveContext op_params(OpKernelContext.Params setter); // Not owned @MemberGetter public native @Const @ByRef CollectiveParams col_params(); @MemberGetter public native @StdString BytePointer exec_key(); @MemberGetter public native @Cast("const tensorflow::int64") long step_id(); - @MemberGetter public native @Const Tensor input(); // Not owned + public native @Const Tensor input(); public native CollectiveContext input(Tensor setter); // Not owned public native Tensor output(); public native CollectiveContext output(Tensor setter); // Not owned public native Device device(); public native CollectiveContext device(Device setter); // The device for which this instance labors @MemberGetter public native @StdString BytePointer device_name(); diff --git a/tensorflow/src/gen/java/org/bytedeco/tensorflow/FunctionLibraryRuntime.java b/tensorflow/src/gen/java/org/bytedeco/tensorflow/FunctionLibraryRuntime.java index 74c107008b6..d9362ebf007 100644 --- a/tensorflow/src/gen/java/org/bytedeco/tensorflow/FunctionLibraryRuntime.java +++ b/tensorflow/src/gen/java/org/bytedeco/tensorflow/FunctionLibraryRuntime.java @@ -53,7 +53,7 @@ public static class InstantiateOptions extends Pointer { // between a set of libraries (e.g. by allowing a // `FunctionLibraryDefinition` to store an `outer_scope` pointer // and implementing name resolution across libraries). - @MemberGetter public native @Const FunctionLibraryDefinition overlay_lib(); + public native @Const FunctionLibraryDefinition overlay_lib(); public native InstantiateOptions overlay_lib(FunctionLibraryDefinition setter); // This interface is EXPERIMENTAL and subject to change. // diff --git a/tensorflow/src/gen/java/org/bytedeco/tensorflow/OpKernelContext.java b/tensorflow/src/gen/java/org/bytedeco/tensorflow/OpKernelContext.java index 2b81696eb84..a3476e28f54 100644 --- a/tensorflow/src/gen/java/org/bytedeco/tensorflow/OpKernelContext.java +++ b/tensorflow/src/gen/java/org/bytedeco/tensorflow/OpKernelContext.java @@ -66,7 +66,7 @@ public static class Params extends Pointer { public native @Cast("bool") boolean record_tensor_accesses(); public native Params record_tensor_accesses(boolean setter); // Array indexed by output number for this node - @MemberGetter public native @Const AllocatorAttributes output_attr_array(); + public native @Const AllocatorAttributes output_attr_array(); public native Params output_attr_array(AllocatorAttributes setter); // Shared resources accessible by this op kernel invocation. public native ResourceMgr resource_manager(); public native Params resource_manager(ResourceMgr setter); @@ -94,13 +94,13 @@ public static class Params extends Pointer { public native CancellationManager cancellation_manager(); public native Params cancellation_manager(CancellationManager setter); // Inputs to this op kernel. - @MemberGetter public native @Const TensorValueVector inputs(); + public native @Const TensorValueVector inputs(); public native Params inputs(TensorValueVector setter); public native @Cast("bool") boolean is_input_dead(); public native Params is_input_dead(boolean setter); - @MemberGetter public native @Const AllocatorAttributesVector input_alloc_attrs(); + public native @Const AllocatorAttributesVector input_alloc_attrs(); public native Params input_alloc_attrs(AllocatorAttributesVector setter); // Device contexts. - @MemberGetter public native @Const DeviceContextInlinedVector input_device_contexts(); + public native @Const DeviceContextInlinedVector input_device_contexts(); public native Params input_device_contexts(DeviceContextInlinedVector setter); public native DeviceContext op_device_context(); public native Params op_device_context(DeviceContext setter); // Control-flow op supports. @@ -122,7 +122,7 @@ public static class Params extends Pointer { @MemberGetter public static native int kNoReservation(); public static final int kNoReservation = kNoReservation(); // Values in [0,...) represent reservations for the indexed output. - @MemberGetter public native @Const IntPointer forward_from_array(); + public native @Const IntPointer forward_from_array(); public native Params forward_from_array(IntPointer setter); } // params must outlive the OpKernelContext. diff --git a/tensorflow/src/gen/java/org/bytedeco/tensorflow/SessionState.java b/tensorflow/src/gen/java/org/bytedeco/tensorflow/SessionState.java index ce4c6e88aa8..954cb317bc4 100644 --- a/tensorflow/src/gen/java/org/bytedeco/tensorflow/SessionState.java +++ b/tensorflow/src/gen/java/org/bytedeco/tensorflow/SessionState.java @@ -41,5 +41,5 @@ public class SessionState extends Pointer { public native @Cast("tensorflow::int64") long GetNewId(); - @MemberGetter public static native @Cast("const char*") BytePointer kTensorHandleResourceTypeName(); + public static native @Cast("const char*") BytePointer kTensorHandleResourceTypeName(); public static native void kTensorHandleResourceTypeName(BytePointer setter); } diff --git a/tensorflow/src/gen/java/org/bytedeco/tensorflow/TF_ImportGraphDefResults.java b/tensorflow/src/gen/java/org/bytedeco/tensorflow/TF_ImportGraphDefResults.java index c5157327a14..b64bcd4ce95 100644 --- a/tensorflow/src/gen/java/org/bytedeco/tensorflow/TF_ImportGraphDefResults.java +++ b/tensorflow/src/gen/java/org/bytedeco/tensorflow/TF_ImportGraphDefResults.java @@ -27,7 +27,7 @@ public class TF_ImportGraphDefResults extends Pointer { public native @StdVector TF_Output return_tensors(); public native TF_ImportGraphDefResults return_tensors(TF_Output setter); public native @Cast("TF_Operation**") @StdVector PointerPointer return_nodes(); public native TF_ImportGraphDefResults return_nodes(PointerPointer setter); - @MemberGetter public native @Cast("const char**") @StdVector PointerPointer missing_unused_key_names(); + public native @Cast("const char**") @StdVector PointerPointer missing_unused_key_names(); public native TF_ImportGraphDefResults missing_unused_key_names(PointerPointer setter); public native @StdVector IntPointer missing_unused_key_indexes(); public native TF_ImportGraphDefResults missing_unused_key_indexes(IntPointer setter); // Backing memory for missing_unused_key_names values. diff --git a/tensorflow/src/gen/java/org/bytedeco/tensorflow/TF_WhileParams.java b/tensorflow/src/gen/java/org/bytedeco/tensorflow/TF_WhileParams.java index c71d33a6b56..3226f56d594 100644 --- a/tensorflow/src/gen/java/org/bytedeco/tensorflow/TF_WhileParams.java +++ b/tensorflow/src/gen/java/org/bytedeco/tensorflow/TF_WhileParams.java @@ -34,5 +34,5 @@ public class TF_WhileParams extends Pointer { // Unique null-terminated name for this while loop. This is used as a prefix // for created operations. - @MemberGetter public native @Cast("const char*") BytePointer name(); + public native @Cast("const char*") BytePointer name(); public native TF_WhileParams name(BytePointer setter); } diff --git a/tensorflow/src/gen/java/org/bytedeco/tensorflow/TensorSliceSet.java b/tensorflow/src/gen/java/org/bytedeco/tensorflow/TensorSliceSet.java index 113193aade2..3d7acc864ac 100644 --- a/tensorflow/src/gen/java/org/bytedeco/tensorflow/TensorSliceSet.java +++ b/tensorflow/src/gen/java/org/bytedeco/tensorflow/TensorSliceSet.java @@ -74,7 +74,7 @@ public static class SliceInfo extends Pointer { public native @ByRef TensorSlice slice(); public native SliceInfo slice(TensorSlice setter); public native @StdString BytePointer tag(); public native SliceInfo tag(BytePointer setter); - @MemberGetter public native @Const FloatPointer data(); + public native @Const FloatPointer data(); public native SliceInfo data(FloatPointer setter); public native @Cast("tensorflow::int64") long num_floats(); public native SliceInfo num_floats(long setter); } diff --git a/tensorflow/src/main/java/org/bytedeco/tensorflow/AbstractTensor.java b/tensorflow/src/main/java/org/bytedeco/tensorflow/AbstractTensor.java index 069b8196814..90bbb03745c 100644 --- a/tensorflow/src/main/java/org/bytedeco/tensorflow/AbstractTensor.java +++ b/tensorflow/src/main/java/org/bytedeco/tensorflow/AbstractTensor.java @@ -34,6 +34,14 @@ public abstract class AbstractTensor extends Pointer implements Indexable { static { Loader.load(); } public AbstractTensor(Pointer p) { super(p); } + public static Tensor create(float[] data, long... shape) { return create(data, new TensorShape(shape)); } + public static Tensor create(double[] data, long... shape) { return create(data, new TensorShape(shape)); } + public static Tensor create(int[] data, long... shape) { return create(data, new TensorShape(shape)); } + public static Tensor create(short[] data, long... shape) { return create(data, new TensorShape(shape)); } + public static Tensor create(byte[] data, long... shape) { return create(data, new TensorShape(shape)); } + public static Tensor create(long[] data, long... shape) { return create(data, new TensorShape(shape)); } + public static Tensor create(String[] data, long... shape) { return create(data, new TensorShape(shape)); } + public static Tensor create(float[] data, TensorShape shape) { Tensor t = new Tensor(DT_FLOAT, shape); FloatBuffer b = t.createBuffer(); b.put(data); return t; } public static Tensor create(double[] data, TensorShape shape) { Tensor t = new Tensor(DT_DOUBLE, shape); DoubleBuffer b = t.createBuffer(); b.put(data); return t; } public static Tensor create(int[] data, TensorShape shape) { Tensor t = new Tensor(DT_INT32, shape); IntBuffer b = t.createBuffer(); b.put(data); return t; }