From 68eab0bb18073e52d549955645d8df723b2d1aa7 Mon Sep 17 00:00:00 2001 From: OlgaOvcharenko Date: Thu, 22 Aug 2024 12:26:57 +0200 Subject: [PATCH] Initial commit --- scripts/nn/optim/shampoo.dml | 44 ++++++++++++++++++++++++++++++++++++ 1 file changed, 44 insertions(+) create mode 100644 scripts/nn/optim/shampoo.dml diff --git a/scripts/nn/optim/shampoo.dml b/scripts/nn/optim/shampoo.dml new file mode 100644 index 00000000000..e8b73c6de1b --- /dev/null +++ b/scripts/nn/optim/shampoo.dml @@ -0,0 +1,44 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +/* + * Shampoo optimizer. + */ + +update = function(matrix[double] X, matrix[double] dX, matrix[double] L, matrix[double] R, double lr) + return (matrix[double] X) { + /* + * Performs a vanilla SGD update. + * + * Inputs: + * - X: Parameters to update, of shape (any, any). + * - dX: Gradient wrt `X` of a loss function being optimized, of same shape as `X`. + * - L: Left second-moment information of the accumulated gradients. + * - R: Right second-moment information of the accumulated gradients. + * - lr: Learning rate. + * + * Outputs: + * - X: Updated parameters `X`, of same shape as input `X`. + */ + L = L + dX %*% t(dX) + R = R + t(dX) %*% dX + X = X – lr * pow(L, -1/4) %*% dX %*% pow(R, -1/4)) +} \ No newline at end of file