Skip to content

Commit 037a2ca

Browse files
committed
Add support for Tensor NIO
1 parent 02d0403 commit 037a2ca

File tree

1,131 files changed

+25181
-10149
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

1,131 files changed

+25181
-10149
lines changed

pom.xml

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,40 @@
3030

3131
<modules>
3232
<module>tensorflow-core</module>
33-
<!--module>tensorflow-utils</module-->
33+
<module>tensorflow-utils</module>
3434
<!--module>tensorflow-frameworks</module> TODO -->
3535
<!--module>tensorflow-starters</module> TODO -->
3636
</modules>
3737

38+
<properties>
39+
<maven.compiler.source>1.8</maven.compiler.source>
40+
<maven.compiler.target>1.8</maven.compiler.target>
41+
<junit.version>4.12</junit.version>
42+
<jmh.version>1.21</jmh.version>
43+
</properties>
44+
45+
<dependencyManagement>
46+
<dependencies>
47+
<dependency>
48+
<groupId>junit</groupId>
49+
<artifactId>junit</artifactId>
50+
<version>${junit.version}</version>
51+
</dependency>
52+
<dependency>
53+
<groupId>org.openjdk.jmh</groupId>
54+
<artifactId>jmh-core</artifactId>
55+
<version>${jmh.version}</version>
56+
<scope>test</scope>
57+
</dependency>
58+
<dependency>
59+
<groupId>org.openjdk.jmh</groupId>
60+
<artifactId>jmh-generator-annprocess</artifactId>
61+
<version>${jmh.version}</version>
62+
<scope>test</scope>
63+
</dependency>
64+
</dependencies>
65+
</dependencyManagement>
66+
3867
<!-- Two profiles are used:
3968
ossrh - deploys to ossrh/maven central
4069
bintray - deploys to bintray/jcenter. -->
@@ -64,6 +93,7 @@
6493
</distributionManagement>
6594
</profile>
6695
</profiles>
96+
6797
<!-- http://central.sonatype.org/pages/requirements.html#developer-information -->
6898
<developers>
6999
<developer>
@@ -72,6 +102,7 @@
72102
<organizationUrl>http://www.tensorflow.org</organizationUrl>
73103
</developer>
74104
</developers>
105+
75106
<build>
76107
<plugins>
77108
<!-- GPG signed components: http://central.sonatype.org/pages/apache-maven.html#gpg-signed-components -->

tensorflow-core/tensorflow-core-api/pom.xml

Lines changed: 20 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,11 @@
3131
<version>${project.version}</version>
3232
<optional>true</optional> <!-- for compilation only -->
3333
</dependency>
34+
<dependency>
35+
<groupId>org.tensorflow</groupId>
36+
<artifactId>nio-utils</artifactId>
37+
<version>${project.version}</version>
38+
</dependency>
3439
<dependency>
3540
<groupId>junit</groupId>
3641
<artifactId>junit</artifactId>
@@ -76,10 +81,6 @@
7681
<plugin>
7782
<artifactId>maven-compiler-plugin</artifactId>
7883
<version>3.8.0</version>
79-
<configuration>
80-
<source>1.7</source>
81-
<target>1.7</target>
82-
</configuration>
8384
<executions>
8485
<execution>
8586
<id>default-compile</id>
@@ -99,7 +100,7 @@
99100
</goals>
100101
<configuration>
101102
<includes>
102-
<include>org/tensorflow/c_api/presets/*.java</include>
103+
<include>org/tensorflow/internal/c_api/presets/*.java</include>
103104
</includes>
104105
</configuration>
105106
</execution>
@@ -150,6 +151,7 @@
150151
<compilerOption>${project.basedir}/src/main/native/server_jni.cc</compilerOption>
151152
<compilerOption>${project.basedir}/src/main/native/session_jni.cc</compilerOption>
152153
<compilerOption>${project.basedir}/src/main/native/tensorflow_jni.cc</compilerOption>
154+
<compilerOption>${project.basedir}/src/main/native/tensor_buffers_jni.cc</compilerOption>
153155
<compilerOption>${project.basedir}/src/main/native/tensor_jni.cc</compilerOption>
154156
<compilerOption>${project.basedir}/src/main/native/utils_jni.cc</compilerOption>
155157
</compilerOptions>
@@ -201,7 +203,7 @@
201203
<configuration>
202204
<skip>${javacpp.parser.skip}</skip>
203205
<outputDirectory>${project.basedir}/src/gen/java</outputDirectory>
204-
<classOrPackageName>org.tensorflow.c_api.presets.*</classOrPackageName>
206+
<classOrPackageName>org.tensorflow.internal.c_api.presets.*</classOrPackageName>
205207
</configuration>
206208
</execution>
207209
<execution>
@@ -211,9 +213,9 @@
211213
<goal>build</goal>
212214
</goals>
213215
<configuration>
214-
<outputDirectory>${project.build.directory}/native/org/tensorflow/c_api/${javacpp.platform}${javacpp.platform.extension}/</outputDirectory>
216+
<outputDirectory>${project.build.directory}/native/org/tensorflow/internal/c_api/${javacpp.platform}${javacpp.platform.extension}/</outputDirectory>
215217
<skip>${javacpp.compiler.skip}</skip>
216-
<classOrPackageName>org.tensorflow.c_api.**</classOrPackageName>
218+
<classOrPackageName>org.tensorflow.internal.c_api.**</classOrPackageName>
217219
<copyLibs>true</copyLibs>
218220
<copyResources>true</copyResources>
219221
</configuration>
@@ -224,6 +226,9 @@
224226
<artifactId>maven-surefire-plugin</artifactId>
225227
<version>2.22.0</version>
226228
<configuration>
229+
<argLine>
230+
-Djava.library.path=${project.build.directory}/native/org/tensorflow/internal/c_api/${javacpp.platform}${javacpp.platform.extension}
231+
</argLine>
227232
<additionalClasspathElements>${project.build.directory}/native/</additionalClasspathElements>
228233
</configuration>
229234
</plugin>
@@ -256,16 +261,16 @@
256261
<!-- In case of successive builds for multiple platforms
257262
without cleaning, ensures we only include files for
258263
this platform. -->
259-
<include>org/tensorflow/c_api/${javacpp.platform}${javacpp.platform.extension}/</include>
264+
<include>org/tensorflow/internal/c_api/${javacpp.platform}${javacpp.platform.extension}/</include>
260265
</includes>
261266
<classesDirectory>${project.build.directory}/native</classesDirectory>
262267
<excludes>
263-
<exclude>org/tensorflow/c_api/${javacpp.platform}${javacpp.platform.extension}/*.exp</exclude>
264-
<exclude>org/tensorflow/c_api/${javacpp.platform}${javacpp.platform.extension}/*.lib</exclude>
265-
<exclude>org/tensorflow/c_api/${javacpp.platform}${javacpp.platform.extension}/*.obj</exclude>
266-
<exclude>org/tensorflow/c_api/${javacpp.platform}${javacpp.platform.extension}/*mklml*</exclude>
267-
<exclude>org/tensorflow/c_api/${javacpp.platform}${javacpp.platform.extension}/*iomp5*</exclude>
268-
<exclude>org/tensorflow/c_api/${javacpp.platform}${javacpp.platform.extension}/*msvcr120*</exclude>
268+
<exclude>org/tensorflow/internal/c_api/${javacpp.platform}${javacpp.platform.extension}/*.exp</exclude>
269+
<exclude>org/tensorflow/internal/c_api/${javacpp.platform}${javacpp.platform.extension}/*.lib</exclude>
270+
<exclude>org/tensorflow/internal/c_api/${javacpp.platform}${javacpp.platform.extension}/*.obj</exclude>
271+
<exclude>org/tensorflow/internal/c_api/${javacpp.platform}${javacpp.platform.extension}/*mklml*</exclude>
272+
<exclude>org/tensorflow/internal/c_api/${javacpp.platform}${javacpp.platform.extension}/*iomp5*</exclude>
273+
<exclude>org/tensorflow/internal/c_api/${javacpp.platform}${javacpp.platform.extension}/*msvcr120*</exclude>
269274
</excludes>
270275
</configuration>
271276
</execution>
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
op {
2+
graph_op_name: "KafkaDataset"
3+
endpoint {
4+
name: "data.KafkaDataset"
5+
}
6+
}

tensorflow-core/tensorflow-core-api/src/bazel/op_generator/java_defs.h

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -97,22 +97,25 @@ class Type {
9797
static Type IterableOf(const Type& type) {
9898
return Interface("Iterable").add_parameter(type);
9999
}
100+
static Type DataTypeOf(const Type& type) {
101+
return Class("DataType", "org.tensorflow").add_parameter(type);
102+
}
100103
static Type ForDataType(DataType data_type) {
101104
switch (data_type) {
102105
case DataType::DT_BOOL:
103-
return Class("Boolean");
106+
return Class("TBool", "org.tensorflow.types");
104107
case DataType::DT_STRING:
105-
return Class("String");
108+
return Class("TString", "org.tensorflow.types");
106109
case DataType::DT_FLOAT:
107-
return Class("Float");
110+
return Class("TFloat", "org.tensorflow.types");
108111
case DataType::DT_DOUBLE:
109-
return Class("Double");
112+
return Class("TDouble", "org.tensorflow.types");
110113
case DataType::DT_UINT8:
111-
return Class("UInt8", "org.tensorflow.types");
114+
return Class("TUInt8", "org.tensorflow.types");
112115
case DataType::DT_INT32:
113-
return Class("Integer");
116+
return Class("TInt32", "org.tensorflow.types");
114117
case DataType::DT_INT64:
115-
return Class("Long");
118+
return Class("TInt64", "org.tensorflow.types");
116119
case DataType::DT_RESOURCE:
117120
// TODO(karllessard) create a Resource utility class that could be
118121
// used to store a resource and its type (passed in a second argument).

tensorflow-core/tensorflow-core-api/src/bazel/op_generator/op_generator.cc

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -124,24 +124,16 @@ void WriteSetAttrDirective(const AttributeSpec& attr, bool optional,
124124
.EndLine()
125125
.BeginBlock("for (int i = 0; i < " + array_name + ".length; ++i)")
126126
.Append(array_name + "[i] = ");
127-
if (attr.type().kind() == Type::GENERIC) {
128-
writer->Append("DataType.fromClass(" + var_name + ".get(i));");
129-
} else {
130-
writer->Append(var_name + ".get(i);");
131-
}
127+
writer->Append(var_name + ".get(i);");
132128
writer->EndLine()
133129
.EndBlock()
134130
.Append("opBuilder.setAttr(\"" + attr.op_def_name() + "\", ")
135131
.Append(array_name + ");")
136132
.EndLine();
137133
} else {
138-
writer->Append("opBuilder.setAttr(\"" + attr.op_def_name() + "\", ");
139-
if (attr.var().type().name() == "Class") {
140-
writer->Append("DataType.fromClass(" + var_name + "));");
141-
} else {
142-
writer->Append(var_name + ");");
143-
}
144-
writer->EndLine();
134+
writer->Append("opBuilder.setAttr(\"" + attr.op_def_name() + "\", ")
135+
.Append(var_name + ");")
136+
.EndLine();
145137
}
146138
}
147139

@@ -179,7 +171,7 @@ void RenderSecondaryFactoryMethod(const OpSpec& op, const Type& op_class,
179171
if (attr.type().kind() == Type::GENERIC &&
180172
default_types.find(attr.type().name()) != default_types.end()) {
181173
factory_statement << default_types.at(attr.type().name()).name()
182-
<< ".class";
174+
<< ".DTYPE";
183175
} else {
184176
AddArgument(attr.var(), attr.description(), &factory, &factory_doc);
185177
factory_statement << attr.var().name();

tensorflow-core/tensorflow-core-api/src/bazel/op_generator/op_specs.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ std::pair<Type, Type> TypeResolver::TypesOf(const OpDef_AttrDef& attr_def,
148148
types = MakeTypePair(Type::Class("Boolean"), Type::Boolean());
149149

150150
} else if (attr_type == "shape") {
151-
types = MakeTypePair(Type::Class("Shape", "org.tensorflow"));
151+
types = MakeTypePair(Type::Class("Shape", "org.tensorflow.nio.nd"));
152152

153153
} else if (attr_type == "tensor") {
154154
types = MakeTypePair(Type::Class("Tensor", "org.tensorflow")
@@ -157,7 +157,7 @@ std::pair<Type, Type> TypeResolver::TypesOf(const OpDef_AttrDef& attr_def,
157157
} else if (attr_type == "type") {
158158
Type type = *iterable_out ? Type::Wildcard() : NextGeneric();
159159
if (IsRealNumbers(attr_def.allowed_values())) {
160-
type.add_supertype(Type::Class("Number"));
160+
type.add_supertype(Type::Class("TNumber", "org.tensorflow.types.family"));
161161
}
162162
types = MakeTypePair(type, Type::Enum("DataType", "org.tensorflow"));
163163

@@ -305,7 +305,7 @@ AttributeSpec CreateAttribute(const OpDef_AttrDef& attr_def,
305305
bool iterable = false;
306306
std::pair<Type, Type> types = type_resolver->TypesOf(attr_def, &iterable);
307307
Type var_type = types.first.kind() == Type::GENERIC
308-
? Type::ClassOf(types.first)
308+
? Type::DataTypeOf(types.first)
309309
: types.first;
310310
if (iterable) {
311311
var_type = Type::ListOf(var_type);

tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/AudioOps.java

Lines changed: 32 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@
55
import org.tensorflow.op.audio.DecodeWav;
66
import org.tensorflow.op.audio.EncodeWav;
77
import org.tensorflow.op.audio.Mfcc;
8+
import org.tensorflow.types.TFloat;
9+
import org.tensorflow.types.TInt32;
10+
import org.tensorflow.types.TString;
811

912
/**
1013
* An API for building {@code audio} operations as {@link Op Op}s
@@ -18,33 +21,6 @@ public final class AudioOps {
1821
this.scope = scope;
1922
}
2023

21-
/**
22-
* Builds an {@link AudioSpectrogram} operation
23-
*
24-
* @param input Float representation of audio data.
25-
* @param windowSize How wide the input window is in samples. For the highest efficiency
26-
* @param stride How widely apart the center of adjacent sample windows should be.
27-
* @param options carries optional attributes values
28-
* @return a new instance of AudioSpectrogram
29-
* @see org.tensorflow.op.audio.AudioSpectrogram
30-
*/
31-
public AudioSpectrogram audioSpectrogram(Operand<Float> input, Long windowSize, Long stride,
32-
AudioSpectrogram.Options... options) {
33-
return AudioSpectrogram.create(scope, input, windowSize, stride, options);
34-
}
35-
36-
/**
37-
* Builds an {@link EncodeWav} operation
38-
*
39-
* @param audio 2-D with shape `[length, channels]`.
40-
* @param sampleRate Scalar containing the sample frequency.
41-
* @return a new instance of EncodeWav
42-
* @see org.tensorflow.op.audio.EncodeWav
43-
*/
44-
public EncodeWav encodeWav(Operand<Float> audio, Operand<Integer> sampleRate) {
45-
return EncodeWav.create(scope, audio, sampleRate);
46-
}
47-
4824
/**
4925
* Builds an {@link DecodeWav} operation
5026
*
@@ -53,7 +29,7 @@ public EncodeWav encodeWav(Operand<Float> audio, Operand<Integer> sampleRate) {
5329
* @return a new instance of DecodeWav
5430
* @see org.tensorflow.op.audio.DecodeWav
5531
*/
56-
public DecodeWav decodeWav(Operand<String> contents, DecodeWav.Options... options) {
32+
public DecodeWav decodeWav(Operand<TString> contents, DecodeWav.Options... options) {
5733
return DecodeWav.create(scope, contents, options);
5834
}
5935

@@ -66,8 +42,35 @@ public DecodeWav decodeWav(Operand<String> contents, DecodeWav.Options... option
6642
* @return a new instance of Mfcc
6743
* @see org.tensorflow.op.audio.Mfcc
6844
*/
69-
public Mfcc mfcc(Operand<Float> spectrogram, Operand<Integer> sampleRate,
45+
public Mfcc mfcc(Operand<TFloat> spectrogram, Operand<TInt32> sampleRate,
7046
Mfcc.Options... options) {
7147
return Mfcc.create(scope, spectrogram, sampleRate, options);
7248
}
49+
50+
/**
51+
* Builds an {@link EncodeWav} operation
52+
*
53+
* @param audio 2-D with shape `[length, channels]`.
54+
* @param sampleRate Scalar containing the sample frequency.
55+
* @return a new instance of EncodeWav
56+
* @see org.tensorflow.op.audio.EncodeWav
57+
*/
58+
public EncodeWav encodeWav(Operand<TFloat> audio, Operand<TInt32> sampleRate) {
59+
return EncodeWav.create(scope, audio, sampleRate);
60+
}
61+
62+
/**
63+
* Builds an {@link AudioSpectrogram} operation
64+
*
65+
* @param input Float representation of audio data.
66+
* @param windowSize How wide the input window is in samples. For the highest efficiency
67+
* @param stride How widely apart the center of adjacent sample windows should be.
68+
* @param options carries optional attributes values
69+
* @return a new instance of AudioSpectrogram
70+
* @see org.tensorflow.op.audio.AudioSpectrogram
71+
*/
72+
public AudioSpectrogram audioSpectrogram(Operand<TFloat> input, Long windowSize, Long stride,
73+
AudioSpectrogram.Options... options) {
74+
return AudioSpectrogram.create(scope, input, windowSize, stride, options);
75+
}
7376
}

0 commit comments

Comments
 (0)