AdaDelta.java
@@ -20,6 +20,7 @@
import org.tensorflow.Operand;
import org.tensorflow.Output;
import org.tensorflow.op.Op;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.Variable;
import org.tensorflow.op.train.ApplyAdadelta;
import org.tensorflow.types.family.TType;
@@ -150,16 +151,16 @@ private <T extends TType> void createAdaDeltaSlot(Output<T> v) {

/** {@inheritDoc} */
@Override
protected <T extends TType> Op applyDense(Output<T> gradient, Output<T> variable) {
protected <T extends TType> Op applyDense(Ops deps, Output<T> gradient, Output<T> variable) {
Variable<T> accumSlot = getSlot(variable, ACCUMULATOR).get();
Variable<T> accumUpdateSlot = getSlot(variable, ACCUMULATOR_UPDATE).get();
return tf.train.applyAdadelta(
return deps.train.applyAdadelta(
variable,
accumSlot,
accumUpdateSlot,
tf.dtypes.cast(tf.constant(learningRate), gradient.type()),
tf.dtypes.cast(tf.constant(rho), gradient.type()),
tf.dtypes.cast(tf.constant(epsilon), gradient.type()),
deps.dtypes.cast(deps.constant(learningRate), gradient.type()),
deps.dtypes.cast(deps.constant(rho), gradient.type()),
deps.dtypes.cast(deps.constant(epsilon), gradient.type()),
gradient,
ApplyAdadelta.useLocking(true));
}
AdaGrad.java
@@ -20,6 +20,7 @@
import org.tensorflow.Operand;
import org.tensorflow.Output;
import org.tensorflow.op.Op;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.Variable;
import org.tensorflow.op.train.ApplyAdagrad;
import org.tensorflow.types.family.TType;
@@ -140,10 +141,10 @@ private <T extends TType> void createAdaGradSlot(Output<T> v) {

/** {@inheritDoc} */
@Override
protected <T extends TType> Op applyDense(Output<T> gradient, Output<T> variable) {
protected <T extends TType> Op applyDense(Ops deps, Output<T> gradient, Output<T> variable) {
Variable<T> slot = getSlot(variable, ACCUMULATOR).get();
return tf.train.applyAdagrad(
variable, slot, tf.dtypes.cast(tf.constant(learningRate), gradient.type()), gradient, opts);
return deps.train.applyAdagrad(
variable, slot, deps.dtypes.cast(deps.constant(learningRate), gradient.type()), gradient, opts);
}

/** {@inheritDoc} */
AdaGradDA.java
@@ -22,6 +22,7 @@
import org.tensorflow.Output;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.op.Op;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.Variable;
import org.tensorflow.op.train.ApplyAdagradDa;
import org.tensorflow.types.TInt64;
@@ -209,17 +210,17 @@ private <T extends TType> void createAdaGradDASlot(Output<T> v) {

/** {@inheritDoc} */
@Override
protected <T extends TType> Op applyDense(Output<T> gradient, Output<T> variable) {
protected <T extends TType> Op applyDense(Ops deps, Output<T> gradient, Output<T> variable) {
Variable<T> gradSlot = getSlot(variable, ACCUMULATOR).get();
Variable<T> gradSquaredSlot = getSlot(variable, SQUARED_ACCUMULATOR).get();
return tf.train.applyAdagradDa(
return deps.train.applyAdagradDa(
variable,
gradSlot,
gradSquaredSlot,
gradient,
tf.dtypes.cast(tf.constant(learningRate), gradient.type()),
tf.dtypes.cast(tf.constant(l1Strength), gradient.type()),
tf.dtypes.cast(tf.constant(l2Strength), gradient.type()),
deps.dtypes.cast(deps.constant(learningRate), gradient.type()),
deps.dtypes.cast(deps.constant(l1Strength), gradient.type()),
deps.dtypes.cast(deps.constant(l2Strength), gradient.type()),
globalStep,
ApplyAdagradDa.useLocking(true));
}
Adam.java
@@ -22,6 +22,7 @@
import org.tensorflow.Output;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.op.Op;
import org.tensorflow.op.Ops;
import org.tensorflow.op.Scope;
import org.tensorflow.op.annotation.Endpoint;
import org.tensorflow.op.annotation.Operator;
@@ -223,19 +224,19 @@ private <T extends TType> void createAdamSlot(Output<T> v) {

/** {@inheritDoc} */
@Override
protected <T extends TType> Op applyDense(Output<T> gradient, Output<T> variable) {
protected <T extends TType> Op applyDense(Ops deps, Output<T> gradient, Output<T> variable) {
Variable<T> firstMomentSlot = getSlot(variable, FIRST_MOMENT).get();
Variable<T> secondMomentSlot = getSlot(variable, SECOND_MOMENT).get();
return tf.train.applyAdam(
return deps.train.applyAdam(
variable,
firstMomentSlot,
secondMomentSlot,
tf.dtypes.cast(betaOnePower, gradient.type()),
tf.dtypes.cast(betaTwoPower, gradient.type()),
tf.dtypes.cast(learningRateConst, gradient.type()),
tf.dtypes.cast(betaOneConst, gradient.type()),
tf.dtypes.cast(betaTwoConst, gradient.type()),
tf.dtypes.cast(epsilonConst, gradient.type()),
deps.dtypes.cast(betaOnePower, gradient.type()),
deps.dtypes.cast(betaTwoPower, gradient.type()),
deps.dtypes.cast(learningRateConst, gradient.type()),
deps.dtypes.cast(betaOneConst, gradient.type()),
deps.dtypes.cast(betaTwoConst, gradient.type()),
deps.dtypes.cast(epsilonConst, gradient.type()),
gradient,
ApplyAdam.useLocking(true));
}
Adamax.java
@@ -7,6 +7,7 @@
import org.tensorflow.Output;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.op.Op;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.Constant;
import org.tensorflow.op.core.Variable;
import org.tensorflow.op.train.ApplyAdaMax;
@@ -155,19 +156,19 @@ private <T extends TType> void createAdamaxSlot(Output<T> v) {

/** {@inheritDoc} */
@Override
protected <T extends TType> Op applyDense(Output<T> gradient, Output<T> variable) {
protected <T extends TType> Op applyDense(Ops deps, Output<T> gradient, Output<T> variable) {
Variable<T> firstMomentSlot = getSlot(variable, FIRST_MOMENT).get();
Variable<T> secondMomentSlot = getSlot(variable, SECOND_MOMENT).get();
return ApplyAdaMax.create(
this.tf.scope(),
deps.scope(),
variable,
firstMomentSlot,
secondMomentSlot,
tf.dtypes.cast(betaOnePower, gradient.type()),
tf.dtypes.cast(learningRateConst, gradient.type()),
tf.dtypes.cast(betaOneConst, gradient.type()),
tf.dtypes.cast(betaTwoConst, gradient.type()),
tf.dtypes.cast(epsilonConst, gradient.type()),
deps.dtypes.cast(betaOnePower, gradient.type()),
deps.dtypes.cast(learningRateConst, gradient.type()),
deps.dtypes.cast(betaOneConst, gradient.type()),
deps.dtypes.cast(betaTwoConst, gradient.type()),
deps.dtypes.cast(epsilonConst, gradient.type()),
gradient,
ApplyAdaMax.useLocking(true));
}
Ftrl.java
@@ -5,6 +5,7 @@
import org.tensorflow.Operand;
import org.tensorflow.Output;
import org.tensorflow.op.Op;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.Variable;
import org.tensorflow.op.train.ApplyFtrl;
import org.tensorflow.types.family.TType;
@@ -238,21 +239,21 @@ private <T extends TType> void createFtrlSlot(Output<T> v) {

/** {@inheritDoc} */
@Override
protected <T extends TType> Op applyDense(Output<T> gradient, Output<T> variable) {
protected <T extends TType> Op applyDense(Ops deps, Output<T> gradient, Output<T> variable) {
Variable<T> accumSlot = getSlot(variable, ACCUMULATOR).get();
Variable<T> linearSlot = getSlot(variable, LINEAR_ACCUMULATOR).get();
ApplyFtrl.Options options = ApplyFtrl.useLocking(true);
return this.tf.train.applyFtrl(
return deps.train.applyFtrl(
variable,
accumSlot, // accum
linearSlot, // linear
gradient, // gradient
tf.dtypes.cast(tf.constant(learningRate), gradient.type()), // lr
tf.dtypes.cast(tf.constant(l1RegularizationStrength), gradient.type()), // l1
tf.dtypes.cast(tf.constant(l2RegularizationStrength), gradient.type()), // l2
tf.dtypes.cast(
tf.constant(l2ShrinkageRegularizationStrength), gradient.type()), // l2Shrinkage
tf.dtypes.cast(tf.constant(learningRatePower), gradient.type()), // lrPower
deps.dtypes.cast(deps.constant(learningRate), gradient.type()), // lr
deps.dtypes.cast(deps.constant(l1RegularizationStrength), gradient.type()), // l1
deps.dtypes.cast(deps.constant(l2RegularizationStrength), gradient.type()), // l2
deps.dtypes.cast(
deps.constant(l2ShrinkageRegularizationStrength), gradient.type()), // l2Shrinkage
deps.dtypes.cast(deps.constant(learningRatePower), gradient.type()), // lrPower
options);
}

GradientDescent.java
@@ -18,6 +18,7 @@
import org.tensorflow.Graph;
import org.tensorflow.Output;
import org.tensorflow.op.Op;
import org.tensorflow.op.Ops;
import org.tensorflow.op.train.ApplyGradientDescent;
import org.tensorflow.types.family.TType;

@@ -65,10 +66,10 @@ public GradientDescent(Graph graph, String name, float learningRate) {

/** {@inheritDoc} */
@Override
protected <T extends TType> Op applyDense(Output<T> gradient, Output<T> variable) {
return tf.train.applyGradientDescent(
protected <T extends TType> Op applyDense(Ops deps, Output<T> gradient, Output<T> variable) {
return deps.train.applyGradientDescent(
variable,
tf.dtypes.cast(tf.constant(learningRate), gradient.type()),
deps.dtypes.cast(deps.constant(learningRate), gradient.type()),
gradient,
ApplyGradientDescent.useLocking(true));
}
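None of the changes in this diff touch the public surface of the optimizers; the new Ops parameter only threads control dependencies from applyGradients into each applyDense override. For orientation, a minimal usage sketch in the usual tensorflow-core-api style (the class name, the toy loss, and the loop count are illustrative; minimize(...) is the framework entry point that eventually calls applyGradients):

import org.tensorflow.Graph;
import org.tensorflow.Operand;
import org.tensorflow.Session;
import org.tensorflow.framework.optimizers.GradientDescent;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.op.Op;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.Variable;
import org.tensorflow.types.TFloat32;

public class SgdExample {
  public static void main(String[] args) {
    try (Graph graph = new Graph()) {
      Ops tf = Ops.create(graph);

      // One scalar weight and a toy quadratic loss: (w - 3)^2.
      Variable<TFloat32> w = tf.variable(Shape.scalar(), TFloat32.class);
      Op init = tf.assign(w, tf.constant(0.0f));
      Operand<TFloat32> loss = tf.math.square(tf.math.sub(w, tf.constant(3.0f)));

      // The optimizer registers its slots and update ops in the same graph; the Op it
      // returns is the single run target described in Optimizer.java further below.
      GradientDescent optimizer = new GradientDescent(graph, 0.1f);
      Op minimize = optimizer.minimize(loss);

      try (Session session = new Session(graph)) {
        session.runner().addTarget(init).run();
        for (int i = 0; i < 100; i++) {
          session.runner().addTarget(minimize).run();
        }
      }
    }
  }
}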
Momentum.java
@@ -20,6 +20,7 @@
import org.tensorflow.Operand;
import org.tensorflow.Output;
import org.tensorflow.op.Op;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.Variable;
import org.tensorflow.op.train.ApplyMomentum;
import org.tensorflow.types.family.TType;
@@ -130,14 +131,14 @@ private <T extends TType> void createMomentumSlot(Output<T> v) {

/** {@inheritDoc} */
@Override
protected <T extends TType> Op applyDense(Output<T> gradient, Output<T> variable) {
protected <T extends TType> Op applyDense(Ops deps, Output<T> gradient, Output<T> variable) {
Variable<T> slot = getSlot(variable, MOMENTUM).get();
return tf.train.applyMomentum(
return deps.train.applyMomentum(
variable,
slot,
tf.dtypes.cast(tf.constant(learningRate), gradient.type()),
deps.dtypes.cast(deps.constant(learningRate), gradient.type()),
gradient,
tf.dtypes.cast(tf.constant(momentum), gradient.type()),
deps.dtypes.cast(deps.constant(momentum), gradient.type()),
ApplyMomentum.useNesterov(useNesterov),
ApplyMomentum.useLocking(true));
}
Nadam.java
@@ -7,6 +7,7 @@
import org.tensorflow.Output;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.op.Op;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.Assign;
import org.tensorflow.op.core.Constant;
import org.tensorflow.op.core.Variable;
@@ -224,53 +225,53 @@ protected Optional<Op> prepare(String scopeName) {

/** {@inheritDoc} */
@Override
protected <T extends TType> Op applyDense(Output<T> gradient, Output<T> variable) {
protected <T extends TType> Op applyDense(Ops deps, Output<T> gradient, Output<T> variable) {
Class<T> type = gradient.type();
Variable<T> m = getSlot(variable, FIRST_MOMENT).get(); // first Moment
Variable<T> v = getSlot(variable, SECOND_MOMENT).get(); // Second Moment

// gPrime = grad / coefficients['oneMinusMScheduleNew']
Operand<T> gPrime = tf.math.div(gradient, tf.dtypes.cast(oneMinusMScheduleNew, type));
Operand<T> gPrime = deps.math.div(gradient, deps.dtypes.cast(oneMinusMScheduleNew, type));
// mT = (coefficients['beta_1_t'] * m + coefficients['one_minus_beta_1_t'] * grad)
Operand<T> mT =
tf.math.add(
tf.math.mul(tf.dtypes.cast(betaOneConst, type), m),
tf.math.mul(tf.dtypes.cast(oneMinusBeta1, type), gradient));
deps.math.add(
deps.math.mul(deps.dtypes.cast(betaOneConst, type), m),
deps.math.mul(deps.dtypes.cast(oneMinusBeta1, type), gradient));
// mT = state_ops.assign(m, mT, use_locking=self._use_locking)
// update m
mT = tf.assign(m, mT, Assign.useLocking(true));
mT = deps.assign(m, mT, Assign.useLocking(true));

// mTPrime = mT / coefficients['oneMinusMScheduleNext']
Operand<T> mTPrime = tf.math.div(mT, tf.dtypes.cast(oneMinusMScheduleNext, type));
Operand<T> mTPrime = deps.math.div(mT, deps.dtypes.cast(oneMinusMScheduleNext, type));

// vT = (coefficients['beta_2_t'] * v + coefficients['one_minus_beta_2_t'] *
// math_ops.square(grad))
Operand<T> vT =
tf.math.add(
tf.math.mul(tf.dtypes.cast(betaTwoConst, type), v),
tf.math.mul(tf.dtypes.cast(oneMinusBeta2, type), tf.math.square(gradient)));
deps.math.add(
deps.math.mul(deps.dtypes.cast(betaTwoConst, type), v),
deps.math.mul(deps.dtypes.cast(oneMinusBeta2, type), deps.math.square(gradient)));
// vT = state_ops.assign(v, vT, use_locking=self._use_locking)
// update v
vT = tf.assign(v, vT, Assign.useLocking(true));
vT = deps.assign(v, vT, Assign.useLocking(true));

// vTPrime = vT / coefficients['vTPrimeDenominator']
Operand<T> vTPrime = tf.math.div(vT, tf.dtypes.cast(vTPrimeDenominator, type));
Operand<T> vTPrime = deps.math.div(vT, deps.dtypes.cast(vTPrimeDenominator, type));

// m_t_bar = (coefficients['oneMinusMT'] * gPrime + coefficients['mT1'] * mTPrime)
Operand<T> m_t_bar =
tf.math.add(
tf.math.mul(tf.dtypes.cast(oneMinusMT, type), gPrime),
tf.math.mul(tf.dtypes.cast(mT1, type), mTPrime));
deps.math.add(
deps.math.mul(deps.dtypes.cast(oneMinusMT, type), gPrime),
deps.math.mul(deps.dtypes.cast(mT1, type), mTPrime));
// varT = var - coefficients['lr_t'] * m_t_bar / (math_ops.sqrt(vTPrime) +
// coefficients['epsilon'])
Operand<T> varT =
tf.math.sub(
deps.math.sub(
variable,
tf.math.div(
tf.math.mul(tf.dtypes.cast(learningRateConst, type), m_t_bar),
tf.math.add(tf.math.sqrt(vTPrime), tf.dtypes.cast(epsilonConst, type))));
deps.math.div(
deps.math.mul(deps.dtypes.cast(learningRateConst, type), m_t_bar),
deps.math.add(deps.math.sqrt(vTPrime), deps.dtypes.cast(epsilonConst, type))));

return tf.assign(variable, varT, Assign.useLocking(true));
return deps.assign(variable, varT, Assign.useLocking(true));
}

/**
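For reference, the update sequence this rewritten Nadam applyDense builds is unchanged apart from being routed through deps. Following its own inline comments (g is the gradient, theta the variable, alpha the learning rate; the remaining names are the scalar constants the class prepares elsewhere):

\begin{aligned}
g' &= g / \mathrm{oneMinusMScheduleNew} \\
m &\leftarrow \beta_1 m + (1-\beta_1)\,g \\
m' &= m / \mathrm{oneMinusMScheduleNext} \\
v &\leftarrow \beta_2 v + (1-\beta_2)\,g^2 \\
v' &= v / \mathrm{vTPrimeDenominator} \\
\bar{m} &= \mathrm{oneMinusMT}\cdot g' + \mathrm{mT1}\cdot m' \\
\theta &\leftarrow \theta - \alpha\,\bar{m} / (\sqrt{v'} + \epsilon)
\end{aligned}

Only the final assign to the variable is returned as the update op; the assigns to m and v are threaded into it through the expressions above.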
Optimizer.java
@@ -168,14 +168,16 @@ public Op applyGradients(List<GradAndVar<? extends TType>> gradsAndVars, String
gradsAndVars.stream().map(GradAndVar::getVariable).collect(Collectors.toList());

createSlots(variables);
List<Op> gradients = gradsAndVars.stream().map(GradAndVar::getGradient).filter(g -> !g.isClosed()).collect(Collectors.toList());
Ops tfOpsGrads = tf.withControlDependencies(gradients);

Optional<Op> prepOp = prepare(name + "/prepare");

List<Op> updateOps = new ArrayList<>();
prepOp.ifPresent(updateOps::add);
for (GradAndVar<? extends TType> pair : gradsAndVars) {
if (!pair.gradient.isClosed()) {
updateOps.add(applyDense(pair));
updateOps.add(applyDense(tfOpsGrads, pair));
}
}

@@ -261,8 +263,8 @@ protected void createSlots(List<Output<? extends TType>> variables) {}
* @param <T> the datatype of the gradients and variables.
* @return An operand which applies the desired optimizer update to the variable.
*/
private <T extends TType> Op applyDense(GradAndVar<T> gradVarPair) {
return applyDense(gradVarPair.getGradient(), gradVarPair.getVariable());
private <T extends TType> Op applyDense(Ops opDependencies, GradAndVar<T> gradVarPair) {
return applyDense(opDependencies, gradVarPair.getGradient(), gradVarPair.getVariable());
}

/**
@@ -273,7 +275,7 @@ private <T extends TType> Op applyDense(GradAndVar<T> gradVarPair) {
* @param <T> The type of the variable.
* @return An operand which applies the desired optimizer update to the variable.
*/
protected abstract <T extends TType> Op applyDense(Output<T> gradient, Output<T> variable);
protected abstract <T extends TType> Op applyDense(Ops opDependencies, Output<T> gradient, Output<T> variable);

/**
* Gathers up the update operations into a single op that can be used as a run target.
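Because the abstract applyDense above now takes an Ops argument, optimizer subclasses outside this diff have to adopt the new signature as well. A minimal sketch of such a subclass follows (the class is hypothetical; its body simply mirrors GradientDescent earlier in the diff). The essential point is that the update op is built through the supplied deps, the Ops instance that applyGradients wrapped with control dependencies on the live gradients, so each variable update runs only after every gradient for the step has been computed:

import org.tensorflow.Graph;
import org.tensorflow.Output;
import org.tensorflow.framework.optimizers.Optimizer;
import org.tensorflow.op.Op;
import org.tensorflow.op.Ops;
import org.tensorflow.op.train.ApplyGradientDescent;
import org.tensorflow.types.family.TType;

// Hypothetical subclass, shown only to illustrate the new applyDense contract.
public class PlainSgd extends Optimizer {

  private final float learningRate;

  public PlainSgd(Graph graph, float learningRate) {
    super(graph);
    this.learningRate = learningRate;
  }

  @Override
  public String getOptimizerName() {
    return "PlainSgd";
  }

  /** {@inheritDoc} */
  @Override
  protected <T extends TType> Op applyDense(Ops deps, Output<T> gradient, Output<T> variable) {
    // Build the update through `deps`, not a captured `tf`, so the op inherits the
    // control dependencies that applyGradients placed on this step's gradients.
    return deps.train.applyGradientDescent(
        variable,
        deps.dtypes.cast(deps.constant(learningRate), gradient.type()),
        gradient,
        ApplyGradientDescent.useLocking(true));
  }
}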