Description
While working on Model training, an issue with the Optimizer has surfaced.
Currently, when minimize(loss) is called on an Optimizer instance, the Optimizer walks the entire Graph and collects every Variable defined in it. The idea is that minimize(loss) then builds gradients with respect to all of these variables. However, when working with Model, this "all variables" approach breaks down, because some variables do not lie on the execution path of the loss operand. This produces the following error:
org.tensorflow.exceptions.TFInvalidArgumentException: Cannot compute the partial derivative for node 'model/mse_total' as it's unreachable from the output node(s).
This specific error occurs because the MSE metric's internal variables are not on the loss execution path. Such non-trainable variables (weights) appear in most Metric classes and in the Model itself, so the problem is widespread. What we need is a way to distinguish between trainable and non-trainable variables; only the trainable variables would then be used by the Optimizer to compute gradients.
In Python TensorFlow, Keras Layers track their trainable variables in an attribute list, and the Model passes the collected lists to the Optimizer's minimize method.
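The Keras-style bookkeeping described above can be sketched in plain Java. This is a minimal, self-contained illustration only: `Var`, `SketchLayer`, and `SketchModel` are hypothetical stand-ins, not TensorFlow Java classes, and the sketch assumes each variable is tagged as trainable or not at creation time. The model collects only the trainable list, which is what would be handed to a call like `Optimizer.minimize(loss, trainableVariables)`.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

// Hypothetical stand-in for a graph variable (NOT org.tensorflow's Variable).
record Var(String name) {}

// Hypothetical layer that records which of its variables are trainable,
// mirroring the Keras trainable / non-trainable split.
class SketchLayer {
    private final List<Var> trainable = new ArrayList<>();
    private final List<Var> nonTrainable = new ArrayList<>();

    // Creates a variable and files it under the matching list.
    Var addVariable(String name, boolean isTrainable) {
        Var v = new Var(name);
        (isTrainable ? trainable : nonTrainable).add(v);
        return v;
    }

    List<Var> trainableVariables() { return Collections.unmodifiableList(trainable); }

    List<Var> nonTrainableVariables() { return Collections.unmodifiableList(nonTrainable); }
}

// Hypothetical model that gathers only the trainable variables of its layers.
class SketchModel {
    private final List<SketchLayer> layers = new ArrayList<>();

    void add(SketchLayer layer) { layers.add(layer); }

    // The list an Optimizer would differentiate against, instead of every Variable in the Graph.
    List<Var> trainableVariables() {
        List<Var> all = new ArrayList<>();
        for (SketchLayer layer : layers) {
            all.addAll(layer.trainableVariables());
        }
        return all;
    }
}

public class TrainableTracking {
    public static void main(String[] args) {
        SketchLayer dense = new SketchLayer();
        dense.addVariable("dense/kernel", true);
        dense.addVariable("dense/bias", true);

        // Metric state such as mse_total is registered as non-trainable.
        SketchLayer metric = new SketchLayer();
        metric.addVariable("model/mse_total", false);
        metric.addVariable("model/mse_count", false);

        SketchModel model = new SketchModel();
        model.add(dense);
        model.add(metric);

        System.out.println(model.trainableVariables());
    }
}
```

The metric variables never reach the list passed to gradient computation, so the "unreachable from the output node(s)" error cannot be triggered by them.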
There are a couple of options here:
- Mimic TF Keras, and have each Layer identify its trainable variables. Then pass the trainable variables as a List<Variable<?>> using a call like Optimizer.minimize(loss, trainableVariables), and have the Optimizer minimize routine call addGradients with this variable list, rather than walking the whole Graph, to compute the gradients.
- Within Optimizer.minimize(loss), walk the loss operand's execution path to locate any variables contributing to the loss calculation, then pass these to addGradients. A solution based on this option may be facilitated by #232, "Add graph walking functions to Graph and GraphOperation".
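The second option amounts to a backward traversal from the loss node over operation inputs, keeping any variable it reaches. The sketch below shows that traversal on a hypothetical toy graph (`Node` is an illustrative record, not the TensorFlow Java `GraphOperation` API, and it does not assume anything about the functions proposed in #232). A metric variable like `model/mse_total` that is updated *from* the loss, but does not feed *into* it, is never visited.

```java
import java.util.ArrayDeque;
import java.util.Collections;
import java.util.Deque;
import java.util.IdentityHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;

// Hypothetical toy graph node: either a variable or an op with input nodes.
record Node(String name, boolean isVariable, List<Node> inputs) {
    static Node variable(String name) { return new Node(name, true, List.of()); }

    static Node op(String name, Node... inputs) { return new Node(name, false, List.of(inputs)); }
}

public class LossPathVariables {
    // Walks backwards from the loss over op inputs and returns the names of
    // all variables that contribute to the loss.
    static Set<String> variablesReachableFrom(Node loss) {
        Set<String> found = new LinkedHashSet<>();
        // Identity-based visited set: graph nodes are distinct objects.
        Set<Node> visited = Collections.newSetFromMap(new IdentityHashMap<>());
        Deque<Node> stack = new ArrayDeque<>();
        stack.push(loss);
        while (!stack.isEmpty()) {
            Node n = stack.pop();
            if (!visited.add(n)) {
                continue; // already explored via another path
            }
            if (n.isVariable()) {
                found.add(n.name());
            }
            for (Node in : n.inputs()) {
                stack.push(in);
            }
        }
        return found;
    }

    public static void main(String[] args) {
        Node kernel = Node.variable("dense/kernel");
        Node bias = Node.variable("dense/bias");
        Node x = Node.op("x");          // input placeholder, not a variable
        Node yTrue = Node.op("y_true"); // label placeholder
        Node logits = Node.op("add", Node.op("matmul", x, kernel), bias);
        Node loss = Node.op("mse_loss", logits, yTrue);

        // Metric state: updated from the loss but not feeding into it.
        Node mseTotal = Node.variable("model/mse_total");
        Node metricUpdate = Node.op("assign_add", mseTotal, loss);

        System.out.println("variables on loss path: " + variablesReachableFrom(loss));
        System.out.println("metric update op (not on loss path): " + metricUpdate.name());
    }
}
```

Unlike the first option, this needs no per-Layer bookkeeping, but it would still pick up any non-trainable variable that does sit on the loss path.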