From d2534498afef777f73422f9936a62e4899ba8516 Mon Sep 17 00:00:00 2001 From: Marc Sauter Date: Mon, 22 May 2023 13:58:42 +0200 Subject: [PATCH 01/15] changed loss derivative to be a vector --- znnl/training_recording/jax_recording.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/znnl/training_recording/jax_recording.py b/znnl/training_recording/jax_recording.py index 90b01d6..68f47b3 100644 --- a/znnl/training_recording/jax_recording.py +++ b/znnl/training_recording/jax_recording.py @@ -548,8 +548,7 @@ def _update_loss_derivative(self, parsed_data): vector_loss_derivative = self._loss_derivative_fn.calculate( parsed_data["predictions"], self._data_set["targets"] ) - loss_derivative = calculate_l_pq_norm(vector_loss_derivative) - self._loss_derivative_array.append(loss_derivative) + self._loss_derivative_array.append(vector_loss_derivative) def gather_recording(self, selected_properties: list = None) -> dataclass: """ @@ -593,6 +592,7 @@ def gather_recording(self, selected_properties: list = None) -> dataclass: db_data = self._data_storage.fetch_data(selected_properties) # Add db data to the selected data dict. for item, data in selected_data.items(): + print(item) selected_data[item] = onp.concatenate((db_data[item], data), axis=0) except FileNotFoundError: # There is no database. From 5e3f8f553d7300ac7f60a82cd85b0b190e1d2dd5 Mon Sep 17 00:00:00 2001 From: Marc Sauter Date: Tue, 23 May 2023 12:53:53 +0200 Subject: [PATCH 02/15] Implemented the recording of the fisher trace --- znnl/training_recording/jax_recording.py | 62 +++++++++++++++++++++--- 1 file changed, 54 insertions(+), 8 deletions(-) diff --git a/znnl/training_recording/jax_recording.py b/znnl/training_recording/jax_recording.py index 68f47b3..020a525 100644 --- a/znnl/training_recording/jax_recording.py +++ b/znnl/training_recording/jax_recording.py @@ -39,7 +39,6 @@ from znnl.models.jax_model import JaxModel from znnl.training_recording.data_storage import DataStorage from znnl.utils.matrix_utils import ( - calculate_l_pq_norm, compute_magnitude_density, normalize_gram_matrix, ) @@ -90,6 +89,10 @@ class JaxRecorder: loss_derivative : bool (default=False) If true, the derivative of the loss function with respect to the network output will be recorded. + fisher_trace : bool (default=False) + If true, the trace of the fisher matrix will be recorded. Requires the ntk + and the loss derivative to be calculated. + Warning, large overhead. update_rate : int (default=1) How often the values are updated. @@ -148,6 +151,10 @@ class JaxRecorder: loss_derivative: bool = False _loss_derivative_array: list = None + # Fisher trace + fisher_trace: bool = False + _fisher_trace_array: list = None + # Class helpers update_rate: int = 1 _loss_fn: SimpleLoss = None @@ -244,7 +251,7 @@ def instantiate_recorder(self, data_set: dict = None, overwrite: bool = False): if overwrite: self._index_count = 0 - # Check if we need an NTK computation and update the class accordingly + # Check if we need an NTK computation, update the class accordingly if any( [ "ntk" in self._selected_properties, @@ -255,10 +262,19 @@ def instantiate_recorder(self, data_set: dict = None, overwrite: bool = False): "covariance_entropy" in self._selected_properties, "eigenvalues" in self._selected_properties, "trace" in self._selected_properties, + "fisher_trace" in self._selected_properties, ] ): self._compute_ntk = True + # Check if we need a loss derivative computation, update the class accordingly + if any( + [ + "fisher_trace" in self._selected_properties, + ] + ): + self._compute_loss_derivative = True + if "loss_derivative" in self._selected_properties: self._loss_derivative_fn = LossDerivative(self._loss_fn) @@ -291,7 +307,7 @@ def update_recorder(self, epoch: int, model: JaxModel): predictions = predictions[0] parsed_data["predictions"] = predictions - # Compute ntk here to avoid repeated computation. + # Compute ntk and loss derivative here to avoid repeated computation. if self._compute_ntk: try: ntk = self._model.compute_ntk( @@ -311,6 +327,11 @@ def update_recorder(self, epoch: int, model: JaxModel): self.covariance_entropy = False self.eigenvalues = False self._read_selected_attributes() + if self._compute_loss_derivative: + vector_loss_derivative = self._loss_derivative_fn.calculate( + parsed_data["predictions"], self._data_set["targets"] + ) + parsed_data["loss_derivative"] = vector_loss_derivative for item in self._selected_properties: call_fn = getattr(self, f"_update_{item}") # get the callable function @@ -538,17 +559,42 @@ def _update_loss_derivative(self, parsed_data): """ Update the loss derivative array. - The loss derivative is normalized by the L_pq matrix norm. + Parameters + ---------- + parsed_data : dict + Data computed before the update to prevent repeated calculations. + """ + self._loss_derivative_array.append(parsed_data["loss_derivative"]) + + def _update_fisher_trace(self, parsed_data): + """ + Update the fisher trace array. Parameters ---------- parsed_data : dict Data computed before the update to prevent repeated calculations. """ - vector_loss_derivative = self._loss_derivative_fn.calculate( - parsed_data["predictions"], self._data_set["targets"] - ) - self._loss_derivative_array.append(vector_loss_derivative) + loss_derivative = parsed_data["loss_derivative"] + ntk = parsed_data["ntk"] + + try: + assert len(ntk.shape) == 4 + except (AssertionError): + raise TypeError( + "The ntk needs to have 4 dimensions for the fisher trace calculation." + "Maybe you have set the model to trace over the output dimensions?" + ) + + dataset_size = loss_derivative.shape[0] + dimensionality = loss_derivative.shape[1] + fisher_trace = 0 + for i in range(dataset_size): + for l1 in range(dimensionality): + for l2 in range(dimensionality): + fisher_trace += loss_derivative[i, l1] * loss_derivative[i, l2] * \ + ntk[i, i, l1, l2] + self._fisher_trace_array.append(fisher_trace / dataset_size) def gather_recording(self, selected_properties: list = None) -> dataclass: """ From 928fbc7b56699a58cd7673a7cd5308e6520c09e3 Mon Sep 17 00:00:00 2001 From: Marc Sauter Date: Tue, 23 May 2023 12:53:53 +0200 Subject: [PATCH 03/15] Implemented the recording of the fisher trace --- znnl/training_recording/jax_recording.py | 67 +++++++++++++++++++----- 1 file changed, 55 insertions(+), 12 deletions(-) diff --git a/znnl/training_recording/jax_recording.py b/znnl/training_recording/jax_recording.py index 68f47b3..b98e58a 100644 --- a/znnl/training_recording/jax_recording.py +++ b/znnl/training_recording/jax_recording.py @@ -38,11 +38,7 @@ from znnl.loss_functions import SimpleLoss from znnl.models.jax_model import JaxModel from znnl.training_recording.data_storage import DataStorage -from znnl.utils.matrix_utils import ( - calculate_l_pq_norm, - compute_magnitude_density, - normalize_gram_matrix, -) +from znnl.utils.matrix_utils import compute_magnitude_density, normalize_gram_matrix logger = logging.getLogger(__name__) @@ -90,6 +86,10 @@ class JaxRecorder: loss_derivative : bool (default=False) If true, the derivative of the loss function with respect to the network output will be recorded. + fisher_trace : bool (default=False) + If true, the trace of the fisher matrix will be recorded. Requires the ntk + and the loss derivative to be calculated. + Warning, large overhead. update_rate : int (default=1) How often the values are updated. @@ -148,6 +148,10 @@ class JaxRecorder: loss_derivative: bool = False _loss_derivative_array: list = None + # Fisher trace + fisher_trace: bool = False + _fisher_trace_array: list = None + # Class helpers update_rate: int = 1 _loss_fn: SimpleLoss = None @@ -244,7 +248,7 @@ def instantiate_recorder(self, data_set: dict = None, overwrite: bool = False): if overwrite: self._index_count = 0 - # Check if we need an NTK computation and update the class accordingly + # Check if we need an NTK computation, update the class accordingly if any( [ "ntk" in self._selected_properties, @@ -255,10 +259,19 @@ def instantiate_recorder(self, data_set: dict = None, overwrite: bool = False): "covariance_entropy" in self._selected_properties, "eigenvalues" in self._selected_properties, "trace" in self._selected_properties, + "fisher_trace" in self._selected_properties, ] ): self._compute_ntk = True + # Check if we need a loss derivative computation, update the class accordingly + if any( + [ + "fisher_trace" in self._selected_properties, + ] + ): + self._compute_loss_derivative = True + if "loss_derivative" in self._selected_properties: self._loss_derivative_fn = LossDerivative(self._loss_fn) @@ -291,7 +304,7 @@ def update_recorder(self, epoch: int, model: JaxModel): predictions = predictions[0] parsed_data["predictions"] = predictions - # Compute ntk here to avoid repeated computation. + # Compute ntk and loss derivative here to avoid repeated computation. if self._compute_ntk: try: ntk = self._model.compute_ntk( @@ -311,6 +324,11 @@ def update_recorder(self, epoch: int, model: JaxModel): self.covariance_entropy = False self.eigenvalues = False self._read_selected_attributes() + if self._compute_loss_derivative: + vector_loss_derivative = self._loss_derivative_fn.calculate( + parsed_data["predictions"], self._data_set["targets"] + ) + parsed_data["loss_derivative"] = vector_loss_derivative for item in self._selected_properties: call_fn = getattr(self, f"_update_{item}") # get the callable function @@ -538,17 +556,42 @@ def _update_loss_derivative(self, parsed_data): """ Update the loss derivative array. - The loss derivative is normalized by the L_pq matrix norm. + Parameters + ---------- + parsed_data : dict + Data computed before the update to prevent repeated calculations. + """ + self._loss_derivative_array.append(parsed_data["loss_derivative"]) + + def _update_fisher_trace(self, parsed_data): + """ + Update the fisher trace array. Parameters ---------- parsed_data : dict Data computed before the update to prevent repeated calculations. """ - vector_loss_derivative = self._loss_derivative_fn.calculate( - parsed_data["predictions"], self._data_set["targets"] - ) - self._loss_derivative_array.append(vector_loss_derivative) + loss_derivative = parsed_data["loss_derivative"] + ntk = parsed_data["ntk"] + + try: + assert len(ntk.shape) == 4 + except (AssertionError): + raise TypeError( + "The ntk needs to have 4 dimensions for the fisher trace calculation." + "Maybe you have set the model to trace over the output dimensions?" + ) + + dataset_size = loss_derivative.shape[0] + dimensionality = loss_derivative.shape[1] + fisher_trace = 0 + for i in range(dataset_size): + for l1 in range(dimensionality): + for l2 in range(dimensionality): + fisher_trace += loss_derivative[i, l1] * loss_derivative[i, l2] * \ + ntk[i, i, l1, l2] + self._fisher_trace_array.append(fisher_trace / dataset_size) def gather_recording(self, selected_properties: list = None) -> dataclass: """ From 1f6888d53285d927460922d36689d7fcaea1d02c Mon Sep 17 00:00:00 2001 From: Marc Sauter Date: Tue, 23 May 2023 13:20:04 +0200 Subject: [PATCH 04/15] fixing black formatting --- znnl/training_recording/jax_recording.py | 19 ++++++------------- 1 file changed, 6 insertions(+), 13 deletions(-) diff --git a/znnl/training_recording/jax_recording.py b/znnl/training_recording/jax_recording.py index bf5a6ee..816b06c 100644 --- a/znnl/training_recording/jax_recording.py +++ b/znnl/training_recording/jax_recording.py @@ -586,16 +586,6 @@ def _update_loss_derivative(self, parsed_data): """ self._loss_derivative_array.append(parsed_data["loss_derivative"]) - def _update_fisher_trace(self, parsed_data): - """ - Update the fisher trace array. - Parameters - ---------- - parsed_data : dict - Data computed before the update to prevent repeated calculations. - """ - self._loss_derivative_array.append(parsed_data["loss_derivative"]) - def _update_fisher_trace(self, parsed_data): """ Update the fisher trace array. @@ -610,7 +600,7 @@ def _update_fisher_trace(self, parsed_data): try: assert len(ntk.shape) == 4 - except (AssertionError): + except AssertionError: raise TypeError( "The ntk needs to have 4 dimensions for the fisher trace calculation." "Maybe you have set the model to trace over the output dimensions?" @@ -622,8 +612,11 @@ def _update_fisher_trace(self, parsed_data): for i in range(dataset_size): for l1 in range(dimensionality): for l2 in range(dimensionality): - fisher_trace += loss_derivative[i, l1] * loss_derivative[i, l2] * \ - ntk[i, i, l1, l2] + fisher_trace += ( + loss_derivative[i, l1] + * loss_derivative[i, l2] + * ntk[i, i, l1, l2] + ) self._fisher_trace_array.append(fisher_trace / dataset_size) def gather_recording(self, selected_properties: list = None) -> dataclass: From 1b386abd545f3daa1cbe8b8afbdb77874025fb0f Mon Sep 17 00:00:00 2001 From: Marc Sauter Date: Tue, 23 May 2023 13:34:30 +0200 Subject: [PATCH 05/15] somehow i did every change twice, this commit removed the doubles --- znnl/training_recording/jax_recording.py | 23 ----------------------- 1 file changed, 23 deletions(-) diff --git a/znnl/training_recording/jax_recording.py b/znnl/training_recording/jax_recording.py index 816b06c..e8c17e4 100644 --- a/znnl/training_recording/jax_recording.py +++ b/znnl/training_recording/jax_recording.py @@ -90,10 +90,6 @@ class JaxRecorder: If true, the trace of the fisher matrix will be recorded. Requires the ntk and the loss derivative to be calculated. Warning, large overhead. - fisher_trace : bool (default=False) - If true, the trace of the fisher matrix will be recorded. Requires the ntk - and the loss derivative to be calculated. - Warning, large overhead. update_rate : int (default=1) How often the values are updated. @@ -156,10 +152,6 @@ class JaxRecorder: fisher_trace: bool = False _fisher_trace_array: list = None - # Fisher trace - fisher_trace: bool = False - _fisher_trace_array: list = None - # Class helpers update_rate: int = 1 _loss_fn: SimpleLoss = None @@ -256,7 +248,6 @@ def instantiate_recorder(self, data_set: dict = None, overwrite: bool = False): if overwrite: self._index_count = 0 - # Check if we need an NTK computation, update the class accordingly # Check if we need an NTK computation, update the class accordingly if any( [ @@ -281,14 +272,6 @@ def instantiate_recorder(self, data_set: dict = None, overwrite: bool = False): ): self._compute_loss_derivative = True - # Check if we need a loss derivative computation, update the class accordingly - if any( - [ - "fisher_trace" in self._selected_properties, - ] - ): - self._compute_loss_derivative = True - if "loss_derivative" in self._selected_properties: self._loss_derivative_fn = LossDerivative(self._loss_fn) @@ -321,7 +304,6 @@ def update_recorder(self, epoch: int, model: JaxModel): predictions = predictions[0] parsed_data["predictions"] = predictions - # Compute ntk and loss derivative here to avoid repeated computation. # Compute ntk and loss derivative here to avoid repeated computation. if self._compute_ntk: try: @@ -347,11 +329,6 @@ def update_recorder(self, epoch: int, model: JaxModel): parsed_data["predictions"], self._data_set["targets"] ) parsed_data["loss_derivative"] = vector_loss_derivative - if self._compute_loss_derivative: - vector_loss_derivative = self._loss_derivative_fn.calculate( - parsed_data["predictions"], self._data_set["targets"] - ) - parsed_data["loss_derivative"] = vector_loss_derivative for item in self._selected_properties: call_fn = getattr(self, f"_update_{item}") # get the callable function From e553e680a54c183ca5fd46fa3ee600b7981eaf39 Mon Sep 17 00:00:00 2001 From: SamTov Date: Tue, 23 May 2023 14:23:10 +0200 Subject: [PATCH 06/15] changes --- znnl/training_recording/jax_recording.py | 24 ++++++++++++++++++------ 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/znnl/training_recording/jax_recording.py b/znnl/training_recording/jax_recording.py index e8c17e4..95ccc7d 100644 --- a/znnl/training_recording/jax_recording.py +++ b/znnl/training_recording/jax_recording.py @@ -29,6 +29,8 @@ from os import path from pathlib import Path +import jax + import numpy as onp from znnl.accuracy_functions.accuracy_function import AccuracyFunction @@ -248,7 +250,7 @@ def instantiate_recorder(self, data_set: dict = None, overwrite: bool = False): if overwrite: self._index_count = 0 - # Check if we need an NTK computation, update the class accordingly + # Check if we need an NTK computation, update the class accordingly. if any( [ "ntk" in self._selected_properties, @@ -268,12 +270,11 @@ def instantiate_recorder(self, data_set: dict = None, overwrite: bool = False): if any( [ "fisher_trace" in self._selected_properties, + "loss_derivative" in self._selected_properties ] ): self._compute_loss_derivative = True - - if "loss_derivative" in self._selected_properties: - self._loss_derivative_fn = LossDerivative(self._loss_fn) + self._loss_derivative_fn = LossDerivative(self._loss_fn) def update_recorder(self, epoch: int, model: JaxModel): """ @@ -360,7 +361,7 @@ def visualize_recorder(self): ------- """ - raise NotImplementedError("Not yet available in ZnRND.") + raise NotImplementedError("Not yet available in ZnNL.") @property def loss_fn(self): @@ -579,9 +580,20 @@ def _update_fisher_trace(self, parsed_data): assert len(ntk.shape) == 4 except AssertionError: raise TypeError( - "The ntk needs to have 4 dimensions for the fisher trace calculation." + "The ntk needs to be rank 4 for the fisher trace calculation." "Maybe you have set the model to trace over the output dimensions?" ) + + def _inner_fn(a, b, c): + + return a * b * c + + map_1 = jax.vmap(_inner_fn, in_axes=(None, 0, 0)) + map_2 = jax.vmap(map_1, in_axes=(0, None, 0)) + map_3 = jax.vmap(map_2, in_axes=(0, 0, 0)) + + fisher_trace = onp.sum(map_3(loss_derivative, loss_derivative, )) + dataset_size = loss_derivative.shape[0] dimensionality = loss_derivative.shape[1] From 19843821df266ef6cf405015a1f724c98498571f Mon Sep 17 00:00:00 2001 From: Marc Sauter Date: Tue, 23 May 2023 14:49:17 +0200 Subject: [PATCH 07/15] optimized fisher trace calculation --- znnl/training_recording/jax_recording.py | 24 +++++++----------------- 1 file changed, 7 insertions(+), 17 deletions(-) diff --git a/znnl/training_recording/jax_recording.py b/znnl/training_recording/jax_recording.py index 95ccc7d..0376176 100644 --- a/znnl/training_recording/jax_recording.py +++ b/znnl/training_recording/jax_recording.py @@ -30,7 +30,6 @@ from pathlib import Path import jax - import numpy as onp from znnl.accuracy_functions.accuracy_function import AccuracyFunction @@ -274,7 +273,7 @@ def instantiate_recorder(self, data_set: dict = None, overwrite: bool = False): ] ): self._compute_loss_derivative = True - self._loss_derivative_fn = LossDerivative(self._loss_fn) + self._loss_derivative_fn = LossDerivative(self._loss_fn) def update_recorder(self, epoch: int, model: JaxModel): """ @@ -583,29 +582,20 @@ def _update_fisher_trace(self, parsed_data): "The ntk needs to be rank 4 for the fisher trace calculation." "Maybe you have set the model to trace over the output dimensions?" ) - + def _inner_fn(a, b, c): return a * b * c - + map_1 = jax.vmap(_inner_fn, in_axes=(None, 0, 0)) map_2 = jax.vmap(map_1, in_axes=(0, None, 0)) map_3 = jax.vmap(map_2, in_axes=(0, 0, 0)) - fisher_trace = onp.sum(map_3(loss_derivative, loss_derivative, )) - - dataset_size = loss_derivative.shape[0] - dimensionality = loss_derivative.shape[1] - fisher_trace = 0 - for i in range(dataset_size): - for l1 in range(dimensionality): - for l2 in range(dimensionality): - fisher_trace += ( - loss_derivative[i, l1] - * loss_derivative[i, l2] - * ntk[i, i, l1, l2] - ) + indices = onp.arange(dataset_size) + fisher_trace = onp.sum(map_3(loss_derivative, loss_derivative, + ntk[indices, indices, :, :])) + self._fisher_trace_array.append(fisher_trace / dataset_size) def gather_recording(self, selected_properties: list = None) -> dataclass: From fac41d8523bed55538de2062c53b33e81f0a9a12 Mon Sep 17 00:00:00 2001 From: Marc Sauter Date: Tue, 23 May 2023 14:58:42 +0200 Subject: [PATCH 08/15] black formatting again --- znnl/training_recording/jax_recording.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/znnl/training_recording/jax_recording.py b/znnl/training_recording/jax_recording.py index 0376176..cfc110a 100644 --- a/znnl/training_recording/jax_recording.py +++ b/znnl/training_recording/jax_recording.py @@ -269,7 +269,7 @@ def instantiate_recorder(self, data_set: dict = None, overwrite: bool = False): if any( [ "fisher_trace" in self._selected_properties, - "loss_derivative" in self._selected_properties + "loss_derivative" in self._selected_properties, ] ): self._compute_loss_derivative = True @@ -593,8 +593,9 @@ def _inner_fn(a, b, c): dataset_size = loss_derivative.shape[0] indices = onp.arange(dataset_size) - fisher_trace = onp.sum(map_3(loss_derivative, loss_derivative, - ntk[indices, indices, :, :])) + fisher_trace = onp.sum( + map_3(loss_derivative, loss_derivative, ntk[indices, indices, :, :]) + ) self._fisher_trace_array.append(fisher_trace / dataset_size) From f3c64b5503235b825f0d4fd3ca15023ee5b5d310 Mon Sep 17 00:00:00 2001 From: Marc Sauter Date: Tue, 23 May 2023 14:59:58 +0200 Subject: [PATCH 09/15] and again --- znnl/training_recording/jax_recording.py | 1 - 1 file changed, 1 deletion(-) diff --git a/znnl/training_recording/jax_recording.py b/znnl/training_recording/jax_recording.py index cfc110a..91ce769 100644 --- a/znnl/training_recording/jax_recording.py +++ b/znnl/training_recording/jax_recording.py @@ -584,7 +584,6 @@ def _update_fisher_trace(self, parsed_data): ) def _inner_fn(a, b, c): - return a * b * c map_1 = jax.vmap(_inner_fn, in_axes=(None, 0, 0)) From 49a33e0bbe1cfe059cdf1ef7792acf1a9871fd75 Mon Sep 17 00:00:00 2001 From: SamTov Date: Tue, 23 May 2023 15:55:46 +0200 Subject: [PATCH 10/15] Move fisher computation to own module. --- znnl/observables/__init__.py | 30 ++++++++++ znnl/observables/fisher_trace.py | 72 ++++++++++++++++++++++++ znnl/training_recording/jax_recording.py | 24 ++------ 3 files changed, 106 insertions(+), 20 deletions(-) create mode 100644 znnl/observables/__init__.py create mode 100644 znnl/observables/fisher_trace.py diff --git a/znnl/observables/__init__.py b/znnl/observables/__init__.py new file mode 100644 index 0000000..180060c --- /dev/null +++ b/znnl/observables/__init__.py @@ -0,0 +1,30 @@ +""" +ZnNL: A Zincwarecode package. + +License +------- +This program and the accompanying materials are made available under the terms +of the Eclipse Public License v2.0 which accompanies this distribution, and is +available at https://www.eclipse.org/legal/epl-v20.html + +SPDX-License-Identifier: EPL-2.0 + +Copyright Contributors to the Zincwarecode Project. + +Contact Information +------------------- +email: zincwarecode@gmail.com +github: https://github.com/zincware +web: https://zincwarecode.com/ + +Citation +-------- +If you use this module please cite us with: + +Summary +------- +Module for the observables. +""" +from znnl.observables.fisher_trace import compute_fisher_trace + +__all__ = [compute_fisher_trace.__name__] diff --git a/znnl/observables/fisher_trace.py b/znnl/observables/fisher_trace.py new file mode 100644 index 0000000..3094ee7 --- /dev/null +++ b/znnl/observables/fisher_trace.py @@ -0,0 +1,72 @@ +""" +ZnNL: A Zincwarecode package. + +License +------- +This program and the accompanying materials are made available under the terms +of the Eclipse Public License v2.0 which accompanies this distribution, and is +available at https://www.eclipse.org/legal/epl-v20.html + +SPDX-License-Identifier: EPL-2.0 + +Copyright Contributors to the Zincwarecode Project. + +Contact Information +------------------- +email: zincwarecode@gmail.com +github: https://github.com/zincware +web: https://zincwarecode.com/ + +Citation +-------- +If you use this module please cite us with: + +Summary +------- +Module for the computation of the Fisher trace. +""" +import jax.numpy as np +import jax + + +def compute_fisher_trace(loss_derivative: np.ndarray, ntk: np.ndarray) -> float: + """ + Compute the Fisher matrix trace from the NTK. + + Parameters + ---------- + loss_derivative : np.ndarray (n_data_points, network_output) + Loss derivative to use in the computation. + ntk : np.ndarray + NTK of the network in one state. + + Returns + ------- + fisher_trace : float + Trace of the Fisher matrix corresponding to the NTK. + """ + try: + assert len(ntk.shape) == 4 + except AssertionError: + raise TypeError( + "The ntk needs to be rank 4 for the fisher trace calculation." + "Maybe you have set the model to trace over the output dimensions?" + ) + + def _inner_fn(a, b, c): + """ + Function to be mapped over. + """ + return a * b * c + + map_1 = jax.vmap(_inner_fn, in_axes=(None, 0, 0)) + map_2 = jax.vmap(map_1, in_axes=(0, None, 0)) + map_3 = jax.vmap(map_2, in_axes=(0, 0, 0)) + + dataset_size = loss_derivative.shape[0] + indices = np.arange(dataset_size) + fisher_trace = np.sum( + map_3(loss_derivative, loss_derivative, ntk[indices, indices, :, :]) + ) + + return fisher_trace / dataset_size diff --git a/znnl/training_recording/jax_recording.py b/znnl/training_recording/jax_recording.py index 91ce769..76a4d2b 100644 --- a/znnl/training_recording/jax_recording.py +++ b/znnl/training_recording/jax_recording.py @@ -39,6 +39,7 @@ from znnl.loss_functions import SimpleLoss from znnl.models.jax_model import JaxModel from znnl.training_recording.data_storage import DataStorage +from znnl.observables.fisher_trace import compute_fisher_trace from znnl.utils.matrix_utils import compute_magnitude_density, normalize_gram_matrix logger = logging.getLogger(__name__) @@ -575,28 +576,11 @@ def _update_fisher_trace(self, parsed_data): loss_derivative = parsed_data["loss_derivative"] ntk = parsed_data["ntk"] - try: - assert len(ntk.shape) == 4 - except AssertionError: - raise TypeError( - "The ntk needs to be rank 4 for the fisher trace calculation." - "Maybe you have set the model to trace over the output dimensions?" + fisher_trace = compute_fisher_trace( + loss_derivative=loss_derivative, ntk=ntk ) - def _inner_fn(a, b, c): - return a * b * c - - map_1 = jax.vmap(_inner_fn, in_axes=(None, 0, 0)) - map_2 = jax.vmap(map_1, in_axes=(0, None, 0)) - map_3 = jax.vmap(map_2, in_axes=(0, 0, 0)) - - dataset_size = loss_derivative.shape[0] - indices = onp.arange(dataset_size) - fisher_trace = onp.sum( - map_3(loss_derivative, loss_derivative, ntk[indices, indices, :, :]) - ) - - self._fisher_trace_array.append(fisher_trace / dataset_size) + self._fisher_trace_array.append(fisher_trace) def gather_recording(self, selected_properties: list = None) -> dataclass: """ From 14a8e7ffabd28b066509233676ac83c5dbc3a5a6 Mon Sep 17 00:00:00 2001 From: Marc Sauter Date: Tue, 23 May 2023 17:28:45 +0200 Subject: [PATCH 11/15] added test for the fisher_trace calculation --- .../test_fisher_trace_calculation.py | 54 +++++++++++++++++++ znnl/training_recording/jax_recording.py | 5 +- 2 files changed, 55 insertions(+), 4 deletions(-) create mode 100644 CI/unit_tests/observables/test_fisher_trace_calculation.py diff --git a/CI/unit_tests/observables/test_fisher_trace_calculation.py b/CI/unit_tests/observables/test_fisher_trace_calculation.py new file mode 100644 index 0000000..44c9e0e --- /dev/null +++ b/CI/unit_tests/observables/test_fisher_trace_calculation.py @@ -0,0 +1,54 @@ +""" +ZnNL: A Zincwarecode package. + +License +------- +This program and the accompanying materials are made available under the terms +of the Eclipse Public License v2.0 which accompanies this distribution, and is +available at https://www.eclipse.org/legal/epl-v20.html + +SPDX-License-Identifier: EPL-2.0 + +Copyright Contributors to the Zincwarecode Project. + +Contact Information +------------------- +email: zincwarecode@gmail.com +github: https://github.com/zincware +web: https://zincwarecode.com/ + +Citation +-------- +If you use this module please cite us with: + +Summary +------- +""" + +from znnl.observables.fisher_trace import compute_fisher_trace +import numpy as np + +ntk = np.array( + [ + [ + [[1, 2, 3], + [4, 5, 6], + [7, 8, 9]], + np.random.rand(3, 3) + ], + [ + np.random.rand(3, 3), + [[2, 1, 3], + [1, 2, 3], + [3, 2, 1]] + ] + ] +) +loss_derivative = np.array( + [ + [5, 4, 3], + [2, 1, 0] + ] +) + +assert compute_fisher_trace(loss_derivative=loss_derivative, ntk=ntk) == 638 / 2 diff --git a/znnl/training_recording/jax_recording.py b/znnl/training_recording/jax_recording.py index 76a4d2b..c4e0eb9 100644 --- a/znnl/training_recording/jax_recording.py +++ b/znnl/training_recording/jax_recording.py @@ -29,7 +29,6 @@ from os import path from pathlib import Path -import jax import numpy as onp from znnl.accuracy_functions.accuracy_function import AccuracyFunction @@ -576,9 +575,7 @@ def _update_fisher_trace(self, parsed_data): loss_derivative = parsed_data["loss_derivative"] ntk = parsed_data["ntk"] - fisher_trace = compute_fisher_trace( - loss_derivative=loss_derivative, ntk=ntk - ) + fisher_trace = compute_fisher_trace(loss_derivative=loss_derivative, ntk=ntk) self._fisher_trace_array.append(fisher_trace) From 87ea89de834a0d85460557a962b4d557e590235b Mon Sep 17 00:00:00 2001 From: Marc Sauter Date: Tue, 23 May 2023 17:30:48 +0200 Subject: [PATCH 12/15] changes for isort --- CI/unit_tests/observables/test_fisher_trace_calculation.py | 3 ++- znnl/observables/fisher_trace.py | 2 +- znnl/training_recording/jax_recording.py | 2 +- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/CI/unit_tests/observables/test_fisher_trace_calculation.py b/CI/unit_tests/observables/test_fisher_trace_calculation.py index 44c9e0e..3b3f72f 100644 --- a/CI/unit_tests/observables/test_fisher_trace_calculation.py +++ b/CI/unit_tests/observables/test_fisher_trace_calculation.py @@ -25,9 +25,10 @@ ------- """ -from znnl.observables.fisher_trace import compute_fisher_trace import numpy as np +from znnl.observables.fisher_trace import compute_fisher_trace + ntk = np.array( [ [ diff --git a/znnl/observables/fisher_trace.py b/znnl/observables/fisher_trace.py index 3094ee7..08c3abb 100644 --- a/znnl/observables/fisher_trace.py +++ b/znnl/observables/fisher_trace.py @@ -25,8 +25,8 @@ ------- Module for the computation of the Fisher trace. """ -import jax.numpy as np import jax +import jax.numpy as np def compute_fisher_trace(loss_derivative: np.ndarray, ntk: np.ndarray) -> float: diff --git a/znnl/training_recording/jax_recording.py b/znnl/training_recording/jax_recording.py index c4e0eb9..ee49bc7 100644 --- a/znnl/training_recording/jax_recording.py +++ b/znnl/training_recording/jax_recording.py @@ -37,8 +37,8 @@ from znnl.analysis.loss_fn_derivative import LossDerivative from znnl.loss_functions import SimpleLoss from znnl.models.jax_model import JaxModel -from znnl.training_recording.data_storage import DataStorage from znnl.observables.fisher_trace import compute_fisher_trace +from znnl.training_recording.data_storage import DataStorage from znnl.utils.matrix_utils import compute_magnitude_density, normalize_gram_matrix logger = logging.getLogger(__name__) From 1fad2f22d31a4572f98b74186960a26281dac148 Mon Sep 17 00:00:00 2001 From: Marc Sauter Date: Tue, 23 May 2023 17:33:16 +0200 Subject: [PATCH 13/15] black formatter changes --- .../test_fisher_trace_calculation.py | 21 +++---------------- 1 file changed, 3 insertions(+), 18 deletions(-) diff --git a/CI/unit_tests/observables/test_fisher_trace_calculation.py b/CI/unit_tests/observables/test_fisher_trace_calculation.py index 3b3f72f..63077ab 100644 --- a/CI/unit_tests/observables/test_fisher_trace_calculation.py +++ b/CI/unit_tests/observables/test_fisher_trace_calculation.py @@ -31,25 +31,10 @@ ntk = np.array( [ - [ - [[1, 2, 3], - [4, 5, 6], - [7, 8, 9]], - np.random.rand(3, 3) - ], - [ - np.random.rand(3, 3), - [[2, 1, 3], - [1, 2, 3], - [3, 2, 1]] - ] - ] -) -loss_derivative = np.array( - [ - [5, 4, 3], - [2, 1, 0] + [[[1, 2, 3], [4, 5, 6], [7, 8, 9]], np.random.rand(3, 3)], + [np.random.rand(3, 3), [[2, 1, 3], [1, 2, 3], [3, 2, 1]]], ] ) +loss_derivative = np.array([[5, 4, 3], [2, 1, 0]]) assert compute_fisher_trace(loss_derivative=loss_derivative, ntk=ntk) == 638 / 2 From 83a72897b39bdb2b8799271af734021a23f7717b Mon Sep 17 00:00:00 2001 From: Marc Sauter Date: Tue, 23 May 2023 17:43:48 +0200 Subject: [PATCH 14/15] refining the fisher trace test --- .../test_fisher_trace_calculation.py | 35 ++++++++++++++----- 1 file changed, 27 insertions(+), 8 deletions(-) diff --git a/CI/unit_tests/observables/test_fisher_trace_calculation.py b/CI/unit_tests/observables/test_fisher_trace_calculation.py index 63077ab..bbe09fe 100644 --- a/CI/unit_tests/observables/test_fisher_trace_calculation.py +++ b/CI/unit_tests/observables/test_fisher_trace_calculation.py @@ -23,18 +23,37 @@ Summary ------- +This module tests the implementation of the fisher trace computation module. """ import numpy as np from znnl.observables.fisher_trace import compute_fisher_trace -ntk = np.array( - [ - [[[1, 2, 3], [4, 5, 6], [7, 8, 9]], np.random.rand(3, 3)], - [np.random.rand(3, 3), [[2, 1, 3], [1, 2, 3], [3, 2, 1]]], - ] -) -loss_derivative = np.array([[5, 4, 3], [2, 1, 0]]) -assert compute_fisher_trace(loss_derivative=loss_derivative, ntk=ntk) == 638 / 2 +class TestFisherTrace: + """ + Class for testing the implementation of the fisher trace calculation + """ + + def test_fisher_trace_computation(self): + """ + Function tests if the fisher trace computation works correctly for an + example which was calculated by hand before. + + Returns + ------- + Asserts the calculated fisher trace for the manually defined inputs + is what it should be. + """ + + ntk = np.array( + [ + [[[1, 2, 3], [4, 5, 6], [7, 8, 9]], np.random.rand(3, 3)], + [np.random.rand(3, 3), [[2, 1, 3], [1, 2, 3], [3, 2, 1]]], + ] + ) + loss_derivative = np.array([[5, 4, 3], [2, 1, 0]]) + + trace = compute_fisher_trace(loss_derivative=loss_derivative, ntk=ntk) + assert trace == 638 / 2 From 68c6ab7c5f58d65c26e0615c557070b563b47322 Mon Sep 17 00:00:00 2001 From: Marc Sauter Date: Tue, 23 May 2023 17:47:39 +0200 Subject: [PATCH 15/15] Added some info to a warning --- znnl/observables/fisher_trace.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/znnl/observables/fisher_trace.py b/znnl/observables/fisher_trace.py index 08c3abb..398a1b8 100644 --- a/znnl/observables/fisher_trace.py +++ b/znnl/observables/fisher_trace.py @@ -37,7 +37,7 @@ def compute_fisher_trace(loss_derivative: np.ndarray, ntk: np.ndarray) -> float: ---------- loss_derivative : np.ndarray (n_data_points, network_output) Loss derivative to use in the computation. - ntk : np.ndarray + ntk : np.ndarray (n_data_points, n_data_points, network_output, network_output) NTK of the network in one state. Returns @@ -51,6 +51,7 @@ def compute_fisher_trace(loss_derivative: np.ndarray, ntk: np.ndarray) -> float: raise TypeError( "The ntk needs to be rank 4 for the fisher trace calculation." "Maybe you have set the model to trace over the output dimensions?" + "Try adding trace_axes=() to the models parameters." ) def _inner_fn(a, b, c):