Skip to content

Commit

Permalink
* Add more samples for TensorFlow including a complete training exam…
Browse files Browse the repository at this point in the history
…ple (pull #563)
  • Loading branch information
Neiko2002 authored and saudet committed May 26, 2018
1 parent c67748c commit 07362ae
Show file tree
Hide file tree
Showing 9 changed files with 9,623 additions and 48 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@

* Add more samples for TensorFlow including a complete training example ([pull #563](https://github.com/bytedeco/javacpp-presets/pull/563))
* Add helper for `PIX`, `FPIX`, and `DPIX` of Leptonica, facilitating access to image data of Tesseract ([issue #517](https://github.com/bytedeco/javacpp-presets/issues/517))
* Add presets for the NVBLAS, NVGRAPH, NVRTC, and NVML modules of CUDA ([issue deeplearning4j/nd4j#2895](https://github.com/deeplearning4j/nd4j/issues/2895))
* Link OpenBLAS with `-Wl,-z,noexecstack` on `linux-armhf` as required by the JDK ([issue deeplearning4j/libnd4j#700](https://github.com/deeplearning4j/libnd4j/issues/700))
Expand Down
24 changes: 24 additions & 0 deletions tensorflow/samples/pom.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>

<groupId>org.bytedeco.javacpp-presets</groupId>
<artifactId>tensorflow-samples</artifactId>
<version>1.8.0-1.4.2-SNAPSHOT</version>
<name>JavaCPP Presets Samples for TensorFlow</name>

<properties>
<maven.compiler.target>1.7</maven.compiler.target>
<maven.compiler.source>1.7</maven.compiler.source>
</properties>

<dependencies>
<dependency>
<groupId>${project.groupId}</groupId>
<artifactId>tensorflow-platform</artifactId>
<version>${project.version}</version>
</dependency>
</dependencies>

</project>
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
package org.bytedeco.javacpp.samples.tensorflow;

import static org.bytedeco.javacpp.tensorflow.Const;
import static org.bytedeco.javacpp.tensorflow.InitMain;
import static org.bytedeco.javacpp.tensorflow.TF_CHECK_OK;

import java.nio.IntBuffer;

import org.bytedeco.javacpp.Loader;
import org.bytedeco.javacpp.tensorflow.AddN;
import org.bytedeco.javacpp.tensorflow.GraphDef;
import org.bytedeco.javacpp.tensorflow.Input;
import org.bytedeco.javacpp.tensorflow.InputList;
import org.bytedeco.javacpp.tensorflow.L2Loss;
import org.bytedeco.javacpp.tensorflow.Output;
import org.bytedeco.javacpp.tensorflow.OutputVector;
import org.bytedeco.javacpp.tensorflow.Scope;
import org.bytedeco.javacpp.tensorflow.Session;
import org.bytedeco.javacpp.tensorflow.SessionOptions;
import org.bytedeco.javacpp.tensorflow.StringTensorPairVector;
import org.bytedeco.javacpp.tensorflow.StringVector;
import org.bytedeco.javacpp.tensorflow.Tensor;
import org.bytedeco.javacpp.tensorflow.TensorShape;
import org.bytedeco.javacpp.tensorflow.TensorVector;

/**
* Showcase the usage of OutputVector and the AddN operator.
*
* @author Nico Hezel
*/
public class AddNExample {

public static void main(String[] args) {

// Load all javacpp-preset classes and native libraries
Loader.load(org.bytedeco.javacpp.tensorflow.class);

// Platform-specific initialization routine
InitMain("trainer", (int[])null, null);

// Create a new empty graph
Scope scope = Scope.NewRootScope();

// (2,1) matrix of ones, sixes and tens
TensorShape shape = new TensorShape(2, 1);
Output ones = Const(scope.WithOpName("ones"), 1, shape);
Output sixes = Const(scope.WithOpName("sixes"), 6, shape);
Output tens = Const(scope.WithOpName("tens"), 10, shape);

// Adding all matrices element-wise
OutputVector ov = new OutputVector(ones, sixes, tens);
InputList inputList = new InputList(ov);
AddN add = new AddN(scope.WithOpName("add"), inputList);

// Build a graph definition object
GraphDef def = new GraphDef();
TF_CHECK_OK(scope.ToGraphDef(def));

// Creates a session.
SessionOptions options = new SessionOptions();
try(final Session session = new Session(options)) {

// Create the graph to be used for the session.
TF_CHECK_OK(session.Create(def));

// Input and output of a single session run.
StringTensorPairVector input_feed = new StringTensorPairVector();
StringVector output_tensor_name = new StringVector("add:0");
StringVector target_tensor_name = new StringVector();
TensorVector outputs = new TensorVector();

// Run the session once
TF_CHECK_OK(session.Run(input_feed, output_tensor_name, target_tensor_name, outputs));

// Print the add-output
for (Tensor output : outputs.get()) {
IntBuffer y_flat = output.createBuffer();
for (int i = 0; i < output.NumElements(); i++)
System.out.println(y_flat.get(i));
}
}
}
}
Loading

0 comments on commit 07362ae

Please sign in to comment.