diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Layer.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Layer.java new file mode 100644 index 00000000000..5fce449ca96 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Layer.java @@ -0,0 +1,134 @@ +package org.tensorflow.framework.layers; + +import org.tensorflow.DataType; +import org.tensorflow.Operand; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Variable; +import org.tensorflow.tools.Shape; +import org.tensorflow.types.family.TType; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.stream.Collectors; + +public abstract class Layer extends Module implements LayerFunction { + private final boolean trainable; + private final boolean dynamic; + private final DataType dtype; + + public List> inboundNodes; + public List> outboundNodes; + + protected boolean built; + + public Layer(Ops tf, String name, boolean trainable, boolean dynamic, DataType dtype) { + super(tf, name, dtype); + this.trainable = trainable; + this.dynamic = dynamic; + this.dtype = dtype; + } + + /** + * Builds this layer (add layer weights) NOTE: This method MUST set `built` to true + * + *

{@code this.built = true} + */ + public abstract void build(List inputShapes); + + public abstract List computeOutputShapes(List inputShapes); + + protected abstract List> call(List> inputs); + + @SafeVarargs + public final List> apply(Operand... inputs) { + return apply(Arrays.asList(inputs)); + } + + @Override + public final List> apply(List> inputs) { + if (!isBuilt()) throw new IllegalStateException("Cannot call a layer until it is built."); + + if (isDynamic() && tf.scope().env().isGraph()) + throw new IllegalStateException("Dynamic layers can only be used " + "in eager mode."); + + List expectedOutputShapes = computeOutputShapes(getShapes(inputs)); + List> outputs = call(inputs); + + for (int i = 0; i < inputs.size(); i++) { + if (expectedOutputShapes.get(i) != outputs.get(i).asOutput().shape()) { + throw new IllegalStateException( + "Shape " + + outputs.get(i).asOutput().shape() + + " at output " + + i + + "does not " + + "match expected shape " + + expectedOutputShapes.get(i)); + } + } + + return outputs; + } + + @Override + public Iterable> getDirectSubmodules() { + return Collections::emptyIterator; + } + + /** + * Returns a list of all trainable and non-trainable weights (in that order) + * + * @return all the weights of this layer (concatenation of getTrainableWeights() and + * getNonTrainableWeights()) + */ + public List> getWeights() { + List> weights = getTrainableWeights(); + weights.addAll(getNonTrainableWeights()); + return weights; + } + + /** + * List of variables to be included in backpropagation + * + * @return all trainable weights of this layer + */ + public List> getTrainableWeights() { + return getModuleWeights().stream() + .filter(w -> w.trainable) + .map(w -> w.variable) + .collect(Collectors.toList()); + } + + /** + * List of variables to be excluded from backpropagation + * + * @return all non-trainable weights of this layer + */ + public List> getNonTrainableWeights() { + return getModuleWeights().stream() + .filter(w -> !w.trainable) + .map(w -> w.variable) + .collect(Collectors.toList()); + } + + private List getShapes(List> operands) { + return operands.stream().map(op -> op.asOutput().shape()).collect(Collectors.toList()); + } + + public boolean isTrainable() { + return trainable; + } + + public boolean isDynamic() { + return dynamic; + } + + public boolean isBuilt() { + return built; + } + + public DataType getDtype() { + return dtype; + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/LayerFunction.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/LayerFunction.java new file mode 100644 index 00000000000..94d5dc4bf59 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/LayerFunction.java @@ -0,0 +1,13 @@ +package org.tensorflow.framework.layers; + +import org.tensorflow.Operand; +import org.tensorflow.types.family.TType; + +import java.util.List; +import java.util.function.Function; + +@FunctionalInterface +public interface LayerFunction + extends Function>, List>> { + List> apply(List> inputs); +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Module.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Module.java new file mode 100644 index 00000000000..4d8938b22e5 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Module.java @@ -0,0 +1,66 @@ +package org.tensorflow.framework.layers; + +import org.tensorflow.DataType; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Variable; +import org.tensorflow.tools.Shape; +import org.tensorflow.types.family.TType; + +import java.util.ArrayList; +import java.util.LinkedList; +import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.StreamSupport; + +public abstract class Module { + protected final Ops tf; + private final String name; + private List> weights; + + public Module(Ops tf, String name, DataType dtype) { + this.tf = tf.withName(name); + this.name = name; + this.weights = new LinkedList<>(); + } + + public abstract Iterable> getDirectSubmodules(); + + public Iterable> getSubmodules(boolean recurse) { + if (!recurse) return getDirectSubmodules(); + + List> submodules = new ArrayList<>(); + + for (Module module : getDirectSubmodules()) + module.getSubmodules(true).forEach(submodules::add); + + return submodules; + } + + public Variable addWeight(String name, boolean trainable, Shape shape, DataType dtype) { + ModuleVariable moduleVariable = new ModuleVariable<>(name, tf.variable(shape, dtype), trainable); + this.weights.add(moduleVariable); + + return moduleVariable.variable; + } + + List> getModuleWeights() { + return StreamSupport.stream(getSubmodules(true).spliterator(), false) + .flatMap(module -> module.weights.stream()) + .collect(Collectors.toList()); + } + + public String getName() { + return name; + } +} + +class ModuleVariable { + String name; + boolean trainable; + Variable variable; + + public ModuleVariable(String name, Variable variable, boolean trainable) { + this.trainable = trainable; + this.variable = variable; + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Node.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Node.java new file mode 100644 index 00000000000..9c8f0a0d7ef --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Node.java @@ -0,0 +1,28 @@ +package org.tensorflow.framework.layers; + +import org.tensorflow.Operand; +import org.tensorflow.types.family.TType; + +import java.util.List; + +public class Node { + /** The Layer that takes input tensors and turns them into output tensors */ + private Layer outboundLayer; + + /** The layers from which input tensors originate */ + private List> inboundLayers; + + /** + * A list of integers, the same length as `inboundLayers`. `nodeIndices[i]` is the origin of + * inputTensors[i] + */ + private List nodeIndices; + + private Layer layer; + private List> outputs; + + public Node(Layer layer, List> outputs) { + this.layer = layer; + this.outputs = outputs; + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Sequential.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Sequential.java new file mode 100644 index 00000000000..0bdcc48cff3 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Sequential.java @@ -0,0 +1,55 @@ +package org.tensorflow.framework.layers; + +import org.tensorflow.DataType; +import org.tensorflow.Operand; +import org.tensorflow.op.Ops; +import org.tensorflow.tools.Shape; +import org.tensorflow.types.family.TType; + +import java.util.Arrays; +import java.util.List; + +public class Sequential extends Layer { + private final List> layers; + + @SafeVarargs + public Sequential(Ops tf, DataType dtype, Layer... layers) { + super(tf, "Sequential", true, true, dtype); + this.layers = Arrays.asList(layers); + } + + @Override + public void build(List inputShapes) { + List shapes = inputShapes; + + for (Layer layer : layers) { + layer.build(shapes); + shapes = layer.computeOutputShapes(shapes); + } + } + + @Override + public List computeOutputShapes(List inputShapes) { + List shapes = inputShapes; + for (Layer layer : layers) { + shapes = layer.computeOutputShapes(shapes); + } + + return shapes; + } + + @Override + protected List> call(List> inputs) { + List> outputs = inputs; + for (Layer layer : layers) { + outputs = layer.call(inputs); + } + + return outputs; + } + + @Override + public Iterable> getDirectSubmodules() { + return (List>) (List) layers; + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/LayerTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/LayerTest.java new file mode 100644 index 00000000000..127b10920ef --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/LayerTest.java @@ -0,0 +1,46 @@ +package org.tensorflow.framework.layers; + +import org.tensorflow.Operand; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Variable; +import org.tensorflow.tools.Shape; +import org.tensorflow.types.TFloat32; + +import java.util.Collections; +import java.util.List; + +class Dense extends Layer { + private final int units; + public static Variable kernel; + public static Variable bias; + + public Dense(Ops tf, int units) { + super(tf, "dense", 1, true, false, TFloat32.DTYPE); + this.units = units; + } + + @Override + public void build(List inputShapes) { + kernel = addWeight("KERNEL", true, inputShapes.get(0), TFloat32.DTYPE); + bias = addWeight("BIAS", true, inputShapes.get(0), TFloat32.DTYPE); + this.built = true; + } + + @Override + public List computeOutputShapes(List inputShapes) { + return Collections.singletonList(inputShapes.get(0).replaceLast(units)); + } + + @Override + public List> call(List> inputs) { + Operand input = inputs.get(0); + return Collections.singletonList(tf.math.add(tf.linalg.matMul(input, kernel), bias)); + } + + @Override + public Iterable> getDirectSubmodules() { + return Collections::emptyIterator; + } +} + +public class LayerTest {} diff --git a/tensorflow-tools/src/main/java/org/tensorflow/tools/Shape.java b/tensorflow-tools/src/main/java/org/tensorflow/tools/Shape.java index be2a62a8b1e..3e031499597 100644 --- a/tensorflow-tools/src/main/java/org/tensorflow/tools/Shape.java +++ b/tensorflow-tools/src/main/java/org/tensorflow/tools/Shape.java @@ -148,6 +148,30 @@ public Shape prepend(long firstDimension) { return Shape.of(newDimensions); } + public boolean isKnown(int i) { + return dimensionSizes[i] != -1; + } + + public void assertKnown(int i) { + if (!isKnown(i)) { + throw new IllegalStateException("Dimension " + i + " in shape needs to be known."); + } + } + + public Shape replaceFirst(long dim) { + return replace(0, dim); + } + + public Shape replaceLast(long dim) { + return replace(dimensionSizes.length - 1, dim); + } + + public Shape replace(int i, long dim) { + Shape newShape = new Shape(Arrays.copyOf(dimensionSizes, dimensionSizes.length)); + newShape.dimensionSizes[i] = dim; + return newShape; + } + private static long computeSize(long[] dimensionSizes) { if (dimensionSizes == null) { return UNKNOWN_SIZE;