Automatic differentiation is a crucial component of DMFF and plays a significant role in optimizing neural networks. This technique computes the derivatives of outputs with respect to inputs using backpropagation, so parameter optimization can be carried out with gradient descent algorithms. Because it is efficient for high-dimensional parameter spaces, the technique is not limited to training neural networks but is also suitable for optimizing any physical model (i.e., the molecular force field in the case of DMFF). A typical optimization recipe needs two key ingredients: 1. gradient evaluation, which can be done easily using JAX; and 2. an optimizer that takes gradients as input and updates the parameters according to a chosen optimization algorithm. To help users build optimization workflows, DMFF provides a wrapper API for the optimizers implemented in Optax, which is introduced here.
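To make the two ingredients concrete, here is a minimal sketch that uses only JAX. The harmonic `energy` function, the `loss`, the reference data, and the hand-written gradient descent step are all hypothetical stand-ins chosen for illustration; they are not part of DMFF.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy "force field": a harmonic bond energy with two parameters.
def energy(params, r):
    return 0.5 * params["k"] * (r - params["r0"]) ** 2

# Hypothetical loss: mean squared error against reference energies.
def loss(params, r, e_ref):
    return jnp.mean((energy(params, r) - e_ref) ** 2)

r = jnp.linspace(0.8, 1.2, 9)
e_ref = 0.5 * 3.0 * (r - 1.0) ** 2  # reference data generated with k=3.0, r0=1.0

params = {"k": jnp.array(1.0), "r0": jnp.array(0.9)}

# Ingredient 1: gradient evaluation with JAX.
grad_fn = jax.jit(jax.value_and_grad(loss))

initial_loss = float(loss(params, r, e_ref))

# Ingredient 2: an optimizer, here plain gradient descent written by hand.
lr = 0.1
for _ in range(200):
    value, grads = grad_fn(params, r, e_ref)
    params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

final_loss = float(loss(params, r, e_ref))
print(initial_loss, final_loss)
```

In practice the hand-written update would be replaced by an Optax optimizer, which is what the wrapper API described in this section is for.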
Function periodic_move:
- Creates a function to perform a periodic update on parameters. If the update causes the parameters to exceed a given range, they are wrapped around in a periodic manner (like an angle that wraps around after 360 degrees).
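The wrapping idea can be sketched in plain Python as follows. This is an illustration of periodic wrapping only, not DMFF's implementation: the name `make_periodic_update` and its `(param, update)` signature are assumptions made for this example.

```python
def make_periodic_update(pmin, pmax):
    """Return a function that applies an update and wraps the result into [pmin, pmax)."""
    period = pmax - pmin

    def apply(param, update):
        moved = param + update
        # Shift to a zero-based range, wrap with modulo, then shift back.
        return pmin + (moved - pmin) % period

    return apply


wrap_angle = make_periodic_update(0.0, 360.0)
a = wrap_angle(350.0, 20.0)   # 370 degrees wraps around to 10
b = wrap_angle(5.0, -15.0)    # -10 degrees wraps around to 350
c = wrap_angle(90.0, 45.0)    # 135 is already inside the range
print(a, b, c)
```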
Function genOptimizer:
- It's a function to generate an optimizer based on user preferences.
- Depending on the arguments, it can produce various optimization schemes, such as SGD, Nesterov, Adam, etc.
- Supports learning rate schedules like exponential decay and warmup exponential decay.
- The optimizer can be further augmented with features like gradient clipping, periodic parameter wrapping, and keeping parameters non-negative.
Functions label_iter, mark_iter, and label2trans_iter:
- These are utility functions used for tree-like structured data (common with JAX's pytree concept).
- label_iter recursively labels the tree nodes.
- mark_iter marks each node in the tree with False.
- label2trans_iter checks and updates the mask tree based on whether the label exists in the transformations. If not, it sets that transformation to zero.
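The following plain-Python sketch illustrates what labeling and marking a nested parameter tree can look like. The helpers `label_tree` and `mark_tree`, and the choice of a slash-joined path as the label, are assumptions for illustration; the actual DMFF functions may differ in signature and label format.

```python
def label_tree(tree, prefix=""):
    """Replace every leaf with a string label built from the path to that leaf."""
    if isinstance(tree, dict):
        return {
            key: label_tree(sub, f"{prefix}/{key}" if prefix else key)
            for key, sub in tree.items()
        }
    return prefix


def mark_tree(tree, value=False):
    """Build a tree of the same shape whose leaves are all `value`."""
    if isinstance(tree, dict):
        return {key: mark_tree(sub, value) for key, sub in tree.items()}
    return value


# Hypothetical parameter tree; lists stand in for parameter arrays.
params = {
    "Force1": {"Parameter1": [1.0, 2.0], "Parameter2": [0.5]},
    "Force2": {"Parameter1": [3.0]},
}

labels = label_tree(params)
mask = mark_tree(params)
print(labels)
print(mask)
```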
Class MultiTransform:
- Manages multiple transformations simultaneously on tree-structured data.
- Maintains a record of transformations, labels, and masks.
- The finalize method ensures that every label has a corresponding transformation, setting any missing transformations to zero.
To set up an optimization, you should follow these steps:
- Initialize MultiTransform with Parameter Tree:
multiTrans = MultiTransform(your_parameter_tree)
- Define Optimizers for Desired Labels:
For each part of the parameter tree you want to optimize differently, set an optimizer. For example:
multiTrans["Force1"]["Parameter1"] = genOptimizer(learning_rate=lr1, clip=clip1)
multiTrans["Force1"]["Parameter2"] = genOptimizer(learning_rate=lr2, clip=clip2)
- Finalize MultiTransform:
Before using it, always finalize the MultiTransform:
multiTrans.finalize()
- Create a Combined Gradient Transformation:
grad_transform = optax.multi_transform(multiTrans.transforms, multiTrans.labels)
- Mask Parameters (If Needed)
If you have parameters in your tree that shouldn't be updated, create a mask and then mask your transformation:
grad_transform = optax.masked(grad_transform, hamiltonian.getParameters().mask)
- Initialize Optimization State:
opt_state = grad_transform.init(hamiltonian.getParameters().parameters)
By following the above steps, you'll have a grad_transform that can handle complex parameter trees and an optimization state opt_state ready for your optimization routine.