From c7f486755ace92b45494a58963dc085fb01dbc30 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?H=C3=A5vard=20Homleid=20Haugen?=
Date: Mon, 11 Nov 2024 15:11:19 +0100
Subject: [PATCH 01/13] Implementation of aw_rescaling

---
 src/anemoi/training/data/scaling.py     | 55 +++++++++++++++++++++++++
 src/anemoi/training/train/forecaster.py | 16 ++++++-
 2 files changed, 70 insertions(+), 1 deletion(-)

diff --git a/src/anemoi/training/data/scaling.py b/src/anemoi/training/data/scaling.py
index 83419a88..dcacd3ae 100644
--- a/src/anemoi/training/data/scaling.py
+++ b/src/anemoi/training/data/scaling.py
@@ -14,6 +14,11 @@
 
 import numpy as np
 
+import torch
+from torch_geometric.data import HeteroData
+from scipy.spatial import SphericalVoronoi
+from anemoi.graphs.generate.transforms import latlon_rad_to_cartesian
+
 LOGGER = logging.getLogger(__name__)
 
@@ -77,3 +82,53 @@ def scaler(plev: float) -> np.ndarray:
         del plev  # unused
         # no scaling, always return 1.0
         return 1.0
+
+class BaseAreaWeights:
+    """Method for overwriting the area-weights stored in the graph object."""
+
+    def __init__(self, target_nodes: str):
+        """Initialize area weights with target nodes.
+
+        Parameters
+        ----------
+        target_nodes : str
+            Name of the set of nodes to be rescaled (defined when creating the graph).
+        """
+        self.target = target_nodes
+
+    def area_weights(self, graph_data) -> torch.Tensor:
+        return torch.from_numpy(self.global_area_weights(graph_data))
+
+    def global_area_weights(self, graph_data: HeteroData) -> np.ndarray:
+        lats, lons = graph_data[self.target].x[:,0], graph_data[self.target].x[:,1]
+        points = latlon_rad_to_cartesian((np.asarray(lats), np.asarray(lons)))
+        sv = SphericalVoronoi(points, radius=1.0, center=[0.0, 0.0, 0.0])
+        area_weights = sv.calculate_areas()
+
+        return area_weights / np.max(area_weights)
+
+class StretchedGridCutoutAreaWeights(BaseAreaWeights):
+    """Rescale area weight of nodes inside the cutout area by setting their sum to a fraction of the sum of global nodes area weight."""
+
+    def __init__(self, target_nodes: str, cutout_weight_frac_of_global: float):
+        """Initialize area weights with target nodes and scaling factor for the cutout nodes area weight.
+
+        Parameters
+        ----------
+        target_nodes : str
+            Name of the set of nodes to be rescaled (defined when creating the graph).
+        cutout_weight_frac_of_global: float
+            Scaling factor for the cutout nodes area weight - sum of cutout nodes area weight set to a fraction of the sum of global nodes area weight.
+        """
+        super().__init__(target_nodes=target_nodes)
+        self.fraction = cutout_weight_frac_of_global
+
+    def area_weights(self, graph_data: HeteroData) -> torch.Tensor:
+        area_weights = self.global_area_weights(graph_data)
+        mask = graph_data[self.target]["cutout"].squeeze().bool()
+
+        global_sum = np.sum(area_weights[~mask])
+        weight_per_cutout_node = self.fraction * global_sum / sum(mask)
+        area_weights[mask] = weight_per_cutout_node
+
+        return torch.from_numpy(area_weights)
\ No newline at end of file
diff --git a/src/anemoi/training/train/forecaster.py b/src/anemoi/training/train/forecaster.py
index ebba3c6d..3b2ccb5b 100644
--- a/src/anemoi/training/train/forecaster.py
+++ b/src/anemoi/training/train/forecaster.py
@@ -86,7 +86,7 @@ def __init__(
         self.save_hyperparameters()
 
         self.latlons_data = graph_data[config.graph.data].x
-        self.node_weights = graph_data[config.graph.data][config.model.node_loss_weight].squeeze()
+        self.node_weights = self.get_node_weights(config, graph_data)
 
         if config.model.get("output_mask", None) is not None:
             self.output_mask = Boolean1DMask(graph_data[config.graph.data][config.model.output_mask])
@@ -289,6 +289,20 @@ def get_feature_weights(
                 LOGGER.debug("Parameter %s was not scaled.", key)
 
         return torch.from_numpy(loss_scaling)
+
+    @staticmethod
+    def get_node_weights(
+        config: DictConfig,
+        graph_data: HeteroData
+    ) -> torch.Tensor:
+        node_weights = graph_data[config.graph.data][config.model.node_loss_weight].squeeze()
+
+        if "spatial" in config.training.loss_scaling:
+            spatial_loss_scaler = instantiate(config.training.loss_scaling.spatial)
+            node_weights = spatial_loss_scaler.area_weights(graph_data)
+            LOGGER.info("Rescaling area weights")
+
+        return node_weights
 
     def set_model_comm_group(self, model_comm_group: ProcessGroup) -> None:
         LOGGER.debug("set_model_comm_group: %s", model_comm_group)

From e99a5a7abeef088508307698beed25f72f6a2847 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?H=C3%A5vard=20Homleid=20Haugen?=
Date: Mon, 11 Nov 2024 15:24:01 +0100
Subject: [PATCH 02/13] Pre-commit

---
 src/anemoi/training/data/scaling.py     | 23 ++++++++++++++---------
 src/anemoi/training/train/forecaster.py |  9 +++------
 2 files changed, 17 insertions(+), 15 deletions(-)

diff --git a/src/anemoi/training/data/scaling.py b/src/anemoi/training/data/scaling.py
index dcacd3ae..f890a1f1 100644
--- a/src/anemoi/training/data/scaling.py
+++ b/src/anemoi/training/data/scaling.py
@@ -13,11 +13,10 @@
 from abc import abstractmethod
 
 import numpy as np
-
 import torch
-from torch_geometric.data import HeteroData
-from scipy.spatial import SphericalVoronoi
 from anemoi.graphs.generate.transforms import latlon_rad_to_cartesian
+from scipy.spatial import SphericalVoronoi
+from torch_geometric.data import HeteroData
 
 LOGGER = logging.getLogger(__name__)
 
@@ -83,6 +82,7 @@ def scaler(plev: float) -> np.ndarray:
         # no scaling, always return 1.0
         return 1.0
 
+
 class BaseAreaWeights:
     """Method for overwriting the area-weights stored in the graph object."""
 
@@ -96,29 +96,34 @@ def __init__(self, target_nodes: str):
         """
         self.target = target_nodes
 
-    def area_weights(self, graph_data) -> torch.Tensor:
+    def area_weights(self, graph_data: HeteroData) -> torch.Tensor:
         return torch.from_numpy(self.global_area_weights(graph_data))
 
     def global_area_weights(self, graph_data: HeteroData) -> np.ndarray:
-        lats, lons = graph_data[self.target].x[:,0], graph_data[self.target].x[:,1]
+        lats, lons = graph_data[self.target].x[:, 0], graph_data[self.target].x[:, 1]
         points = latlon_rad_to_cartesian((np.asarray(lats), np.asarray(lons)))
         sv = SphericalVoronoi(points, radius=1.0, center=[0.0, 0.0, 0.0])
         area_weights = sv.calculate_areas()
 
         return area_weights / np.max(area_weights)
 
+
 class StretchedGridCutoutAreaWeights(BaseAreaWeights):
-    """Rescale area weight of nodes inside the cutout area by setting their sum to a fraction of the sum of global nodes area weight."""
+    """Rescale area weight of nodes inside the cutout area.
+
+    Sum of the area weight of cutout nodes set to a specified fraction of sum of global nodes.
+    """
 
     def __init__(self, target_nodes: str, cutout_weight_frac_of_global: float):
         """Initialize area weights with target nodes and scaling factor for the cutout nodes area weight.
-
+
         Parameters
         ----------
         target_nodes : str
             Name of the set of nodes to be rescaled (defined when creating the graph).
         cutout_weight_frac_of_global: float
-            Scaling factor for the cutout nodes area weight - sum of cutout nodes area weight set to a fraction of the sum of global nodes area weight.
+            Scaling factor for the cutout nodes area weight - sum of cutout nodes area weight set to a fraction of
+            the sum of global nodes area weight.
         """
         super().__init__(target_nodes=target_nodes)
         self.fraction = cutout_weight_frac_of_global
@@ -131,4 +136,4 @@ def area_weights(self, graph_data: HeteroData) -> torch.Tensor:
         weight_per_cutout_node = self.fraction * global_sum / sum(mask)
         area_weights[mask] = weight_per_cutout_node
 
-        return torch.from_numpy(area_weights)
\ No newline at end of file
+        return torch.from_numpy(area_weights)
diff --git a/src/anemoi/training/train/forecaster.py b/src/anemoi/training/train/forecaster.py
index 3b2ccb5b..94caa435 100644
--- a/src/anemoi/training/train/forecaster.py
+++ b/src/anemoi/training/train/forecaster.py
@@ -289,19 +289,16 @@ def get_feature_weights(
                 LOGGER.debug("Parameter %s was not scaled.", key)
 
         return torch.from_numpy(loss_scaling)
-
+
     @staticmethod
-    def get_node_weights(
-        config: DictConfig,
-        graph_data: HeteroData
-    ) -> torch.Tensor:
+    def get_node_weights(config: DictConfig, graph_data: HeteroData) -> torch.Tensor:
         node_weights = graph_data[config.graph.data][config.model.node_loss_weight].squeeze()
 
         if "spatial" in config.training.loss_scaling:
             spatial_loss_scaler = instantiate(config.training.loss_scaling.spatial)
             node_weights = spatial_loss_scaler.area_weights(graph_data)
             LOGGER.info("Rescaling area weights")
-
+
         return node_weights
 
     def set_model_comm_group(self, model_comm_group: ProcessGroup) -> None:

From 8dc5e11795b5a75233c867cd7f6a5eeb06951e88 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?H=C3=A5vard=20Homleid=20Haugen?=
Date: Thu, 14 Nov 2024 10:34:16 +0100
Subject: [PATCH 03/13] Updated implementation based on feedback

---
 .../training/config/training/default.yaml |  5 ++
 src/anemoi/training/data/scaling.py       | 60 -------------
 src/anemoi/training/losses/nodeweigths.py | 89 +++++++++++++++++++
 src/anemoi/training/train/forecaster.py   |  9 +-
 4 files changed, 96 insertions(+), 67 deletions(-)
 create mode 100644 src/anemoi/training/losses/nodeweigths.py

diff --git a/src/anemoi/training/config/training/default.yaml b/src/anemoi/training/config/training/default.yaml
index e1826d2c..41f64323 100644
--- a/src/anemoi/training/config/training/default.yaml
+++ b/src/anemoi/training/config/training/default.yaml
@@ -112,3 +112,8 @@ pressure_level_scaler:
   _target_: anemoi.training.data.scaling.ReluPressureLevelScaler
   minimum: 0.2
   slope: 0.001
+
+node_loss_weights:
+  _target_: anemoi.traininig.losses.nodeweights.GraphNodeAttribute
+  target_nodes: ${graph.data}
+  node_attribute: area_weight
diff --git a/src/anemoi/training/data/scaling.py b/src/anemoi/training/data/scaling.py
index f890a1f1..83419a88 100644
--- a/src/anemoi/training/data/scaling.py
+++ b/src/anemoi/training/data/scaling.py
@@ -13,10 +13,6 @@
 from abc import abstractmethod
 
 import numpy as np
-import torch
-from anemoi.graphs.generate.transforms import latlon_rad_to_cartesian
-from scipy.spatial import SphericalVoronoi
-from torch_geometric.data import HeteroData
 
 LOGGER = logging.getLogger(__name__)
 
@@ -81,59 +77,3 @@ def scaler(plev: float) -> np.ndarray:
         del plev  # unused
         # no scaling, always return 1.0
         return 1.0
-
-
-class BaseAreaWeights:
-    """Method for overwriting the area-weights stored in the graph object."""
-
-    def __init__(self, target_nodes: str):
-        """Initialize area weights with target nodes.
-
-        Parameters
-        ----------
-        target_nodes : str
-            Name of the set of nodes to be rescaled (defined when creating the graph).
-        """
-        self.target = target_nodes
-
-    def area_weights(self, graph_data: HeteroData) -> torch.Tensor:
-        return torch.from_numpy(self.global_area_weights(graph_data))
-
-    def global_area_weights(self, graph_data: HeteroData) -> np.ndarray:
-        lats, lons = graph_data[self.target].x[:, 0], graph_data[self.target].x[:, 1]
-        points = latlon_rad_to_cartesian((np.asarray(lats), np.asarray(lons)))
-        sv = SphericalVoronoi(points, radius=1.0, center=[0.0, 0.0, 0.0])
-        area_weights = sv.calculate_areas()
-
-        return area_weights / np.max(area_weights)
-
-
-class StretchedGridCutoutAreaWeights(BaseAreaWeights):
-    """Rescale area weight of nodes inside the cutout area.
-
-    Sum of the area weight of cutout nodes set to a specified fraction of sum of global nodes.
-    """
-
-    def __init__(self, target_nodes: str, cutout_weight_frac_of_global: float):
-        """Initialize area weights with target nodes and scaling factor for the cutout nodes area weight.
-
-        Parameters
-        ----------
-        target_nodes : str
-            Name of the set of nodes to be rescaled (defined when creating the graph).
-        cutout_weight_frac_of_global: float
-            Scaling factor for the cutout nodes area weight - sum of cutout nodes area weight set to a fraction of
-            the sum of global nodes area weight.
-        """
-        super().__init__(target_nodes=target_nodes)
-        self.fraction = cutout_weight_frac_of_global
-
-    def area_weights(self, graph_data: HeteroData) -> torch.Tensor:
-        area_weights = self.global_area_weights(graph_data)
-        mask = graph_data[self.target]["cutout"].squeeze().bool()
-
-        global_sum = np.sum(area_weights[~mask])
-        weight_per_cutout_node = self.fraction * global_sum / sum(mask)
-        area_weights[mask] = weight_per_cutout_node
-
-        return torch.from_numpy(area_weights)
diff --git a/src/anemoi/training/losses/nodeweigths.py b/src/anemoi/training/losses/nodeweigths.py
new file mode 100644
index 00000000..72a17fd3
--- /dev/null
+++ b/src/anemoi/training/losses/nodeweigths.py
@@ -0,0 +1,89 @@
+# (C) Copyright 2024 Anemoi contributors.
+#
+# This software is licensed under the terms of the Apache Licence Version 2.0
+# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+#
+# In applying this licence, ECMWF does not waive the privileges and immunities
+# granted to it by virtue of its status as an intergovernmental organisation
+# nor does it submit to any jurisdiction.
+
+
+import logging
+
+import numpy as np
+import torch
+from anemoi.graphs.generate.transforms import latlon_rad_to_cartesian
+from scipy.spatial import SphericalVoronoi
+from torch_geometric.data import HeteroData
+
+LOGGER = logging.getLogger(__name__)
+
+
+class GraphNodeAttribute:
+    """Method to load and optionally change the weighting of node attributes in the graph."""
+
+    def __init__(self, target_nodes: str, node_attribute: str):
+        self.target = target_nodes
+        self.node_attribute = node_attribute
+
+    def area_weights(self, graph_data: HeteroData) -> np.ndarray:
+        lats, lons = graph_data[self.target].x[:, 0], graph_data[self.target].x[:, 1]
+        points = latlon_rad_to_cartesian((np.asarray(lats), np.asarray(lons)))
+        sv = SphericalVoronoi(points, radius=1.0, center=[0.0, 0.0, 0.0])
+        area_weights = sv.calculate_areas()
+
+        return area_weights / np.max(area_weights)
+
+    def weights(self, graph_data: HeteroData) -> torch.Tensor:
+        try:
+            attr_weight = graph_data[self.target][self.node_attribute].squeeze()
+
+            LOGGER.info("Loading node attribute %s from the graph", self.node_attribute)
+        except KeyError:
+            attr_weight = torch.from_numpy(self.global_area_weights(graph_data))
+
+            LOGGER.info(
+                "Node attribute %s not found in graph. Default area weighting will be used",
+                self.node_attribute,
+            )
+
+        return attr_weight
+
+
+class ReweightedGraphNodeAttribute(GraphNodeAttribute):
+    """Method to reweight a subset of the target nodes defined by scaled_attributes.
+
+    Subset nodes will be scaled such that their weight sum equals weight_frac_of_total of the sum
+    over all nodes.
+    """
+
+    def __init__(self, target_nodes: str, node_attribute: str, scaled_attribute: str, weight_frac_of_total: float):
+        super().__init__(target_nodes=target_nodes, node_attribute=node_attribute)
+        self.scaled_attribute = scaled_attribute
+        self.fraction = weight_frac_of_total
+
+    def weights(self, graph_data: HeteroData) -> torch.Tensor:
+        try:
+            attr_weight = graph_data[self.target][self.node_attribute].squeeze()
+
+            LOGGER.info("Loading node attribute %s from the graph", self.node_attribute)
+        except KeyError:
+            attr_weight = torch.from_numpy(self.global_area_weights(graph_data))
+
+            LOGGER.info(
+                "Node attribute %s not found in graph. Default area weighting will be used",
+                self.node_attribute,
+            )
+
+        mask = graph_data[self.target][self.scaled_attribute].squeeze().bool()
+
+        unmasked_sum = torch.sum(attr_weight[~mask])
+        weight_per_masked_node = self.fraction / (1 - self.fraction) * unmasked_sum / sum(mask)
+        attr_weight[mask] = weight_per_masked_node
+        LOGGER.info(
+            "Weight of nodes in %s rescaled such that their sum equals %.3f of the sum over all nodes",
+            self.node_attribute,
+            self.fraction,
+        )
+
+        return attr_weight
diff --git a/src/anemoi/training/train/forecaster.py b/src/anemoi/training/train/forecaster.py
index 94caa435..9dba8316 100644
--- a/src/anemoi/training/train/forecaster.py
+++ b/src/anemoi/training/train/forecaster.py
@@ -292,14 +292,9 @@ def get_feature_weights(
 
     @staticmethod
     def get_node_weights(config: DictConfig, graph_data: HeteroData) -> torch.Tensor:
-        node_weights = graph_data[config.graph.data][config.model.node_loss_weight].squeeze()
+        node_weighting = instantiate(config.training.node_loss_weights)
 
-        if "spatial" in config.training.loss_scaling:
-            spatial_loss_scaler = instantiate(config.training.loss_scaling.spatial)
-            node_weights = spatial_loss_scaler.area_weights(graph_data)
-            LOGGER.info("Rescaling area weights")
-
-        return node_weights
+        return node_weighting.weights(graph_data)
 
     def set_model_comm_group(self, model_comm_group: ProcessGroup) -> None:
         LOGGER.debug("set_model_comm_group: %s", model_comm_group)

From cc4f38b4cf666d1af5a6a254934ccc951b69cc7f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?H=C3=A5vard=20Homleid=20Haugen?=
Date: Thu, 14 Nov 2024 16:20:20 +0100
Subject: [PATCH 04/13] Small fixes - training now worked for all cases

---
 .../training/config/training/default.yaml        |  2 +-
 .../losses/{nodeweigths.py => nodeweights.py}    | 19 ++++---------------
 2 files changed, 5 insertions(+), 16 deletions(-)
 rename src/anemoi/training/losses/{nodeweigths.py => nodeweights.py} (82%)

diff --git a/src/anemoi/training/config/training/default.yaml b/src/anemoi/training/config/training/default.yaml
index 41f64323..4da69b34 100644
--- a/src/anemoi/training/config/training/default.yaml
+++ b/src/anemoi/training/config/training/default.yaml
@@ -114,6 +114,6 @@ pressure_level_scaler:
   slope: 0.001
 
 node_loss_weights:
-  _target_: anemoi.traininig.losses.nodeweights.GraphNodeAttribute
+  _target_: anemoi.training.losses.nodeweights.GraphNodeAttribute
   target_nodes: ${graph.data}
   node_attribute: area_weight
diff --git a/src/anemoi/training/losses/nodeweigths.py b/src/anemoi/training/losses/nodeweights.py
similarity index 82%
rename from src/anemoi/training/losses/nodeweigths.py
rename to src/anemoi/training/losses/nodeweights.py
index 72a17fd3..2845e563 100644
--- a/src/anemoi/training/losses/nodeweigths.py
+++ b/src/anemoi/training/losses/nodeweights.py
@@ -40,7 +40,7 @@ def weights(self, graph_data: HeteroData) -> torch.Tensor:
 
             LOGGER.info("Loading node attribute %s from the graph", self.node_attribute)
         except KeyError:
-            attr_weight = torch.from_numpy(self.global_area_weights(graph_data))
+            attr_weight = torch.from_numpy(self.area_weights(graph_data))
 
             LOGGER.info(
@@ -51,7 +51,7 @@ def weights(self, graph_data: HeteroData) -> torch.Tensor:
 
 
 class ReweightedGraphNodeAttribute(GraphNodeAttribute):
-    """Method to reweight a subset of the target nodes defined by scaled_attributes.
+    """Method to reweight a subset of the target nodes defined by scaled_attribute.
 
     Subset nodes will be scaled such that their weight sum equals weight_frac_of_total of the sum
     over all nodes.
@@ -63,26 +63,15 @@ def __init__(self, target_nodes: str, node_attribute: str, scaled_attribute: str
         self.fraction = weight_frac_of_total
 
     def weights(self, graph_data: HeteroData) -> torch.Tensor:
-        try:
-            attr_weight = graph_data[self.target][self.node_attribute].squeeze()
-
-            LOGGER.info("Loading node attribute %s from the graph", self.node_attribute)
-        except KeyError:
-            attr_weight = torch.from_numpy(self.global_area_weights(graph_data))
-
-            LOGGER.info(
-                "Node attribute %s not found in graph. Default area weighting will be used",
-                self.node_attribute,
-            )
+        attr_weight = super().weights(graph_data)
 
         mask = graph_data[self.target][self.scaled_attribute].squeeze().bool()
-
         unmasked_sum = torch.sum(attr_weight[~mask])
         weight_per_masked_node = self.fraction / (1 - self.fraction) * unmasked_sum / sum(mask)
         attr_weight[mask] = weight_per_masked_node
         LOGGER.info(
             "Weight of nodes in %s rescaled such that their sum equals %.3f of the sum over all nodes",
-            self.node_attribute,
+            self.scaled_attribute,
             self.fraction,
         )

From d0d2b5755f928f2149c9f40fc17e19d86f56315a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?H=C3=A5vard=20Homleid=20Haugen?=
Date: Fri, 15 Nov 2024 09:17:03 +0100
Subject: [PATCH 05/13] Docstrings GraphNodeAttributes, minor fixes

---
 src/anemoi/training/losses/nodeweights.py | 57 +++++++++++++++++++++--
 1 file changed, 54 insertions(+), 3 deletions(-)

diff --git a/src/anemoi/training/losses/nodeweights.py b/src/anemoi/training/losses/nodeweights.py
index 2845e563..909de1f2 100644
--- a/src/anemoi/training/losses/nodeweights.py
+++ b/src/anemoi/training/losses/nodeweights.py
@@ -20,13 +20,48 @@
 
 
 class GraphNodeAttribute:
-    """Method to load and optionally change the weighting of node attributes in the graph."""
+    """Base class to load and optionally change the weight attribute of nodes in the graph.
+
+    Attributes
+    ----------
+    target: str
+        name of target nodes, key in HeteroData graph object
+    node_attribute: str
+        name of node weight attribute, key in HeteroData graph object
+
+    Methods
+    -------
+    weights(self, graph_data)
+        Load node weight attribute. Compute area weights if they can not be found in graph
+        object.
+    """
 
     def __init__(self, target_nodes: str, node_attribute: str):
+        """Initialize graph node attribute with target nodes and node attribute.
+
+        Parameters
+        ----------
+        target_nodes: str
+            name of nodes, key in HeteroData graph object
+        node_attribute: str
+            name of node weight attribute, key in HeteroData graph object
+        """
         self.target = target_nodes
         self.node_attribute = node_attribute
 
     def area_weights(self, graph_data: HeteroData) -> np.ndarray:
+        """Nodes weighted by the size of the geographical area they represent.
+
+        Parameters
+        ----------
+        graph_data: HeteroData
+            graph object
+
+        Returns
+        -------
+        np.ndarray
+            area weights of the target nodes
+        """
         lats, lons = graph_data[self.target].x[:, 0], graph_data[self.target].x[:, 1]
         points = latlon_rad_to_cartesian((np.asarray(lats), np.asarray(lons)))
         sv = SphericalVoronoi(points, radius=1.0, center=[0.0, 0.0, 0.0])
@@ -35,11 +70,26 @@ def area_weights(self, graph_data: HeteroData) -> np.ndarray:
         return area_weights / np.max(area_weights)
 
     def weights(self, graph_data: HeteroData) -> torch.Tensor:
-        try:
+        """Returns weight of type self.node_attribute for nodes self.target.
+
+        Attempts to load from graph_data and calculates area weights for the target
+        nodes if they do not exist.
+
+        Parameters
+        ----------
+        graph_data: HeteroData
+            graph object
+
+        Returns
+        -------
+        torch.Tensor
+            weight of target nodes
+        """
+        if self.node_attribute in graph_data[self.target]:
             attr_weight = graph_data[self.target][self.node_attribute].squeeze()
 
             LOGGER.info("Loading node attribute %s from the graph", self.node_attribute)
-        except KeyError:
+        else:
             attr_weight = torch.from_numpy(self.area_weights(graph_data))
 
             LOGGER.info(
@@ -69,6 +119,7 @@ def weights(self, graph_data: HeteroData) -> torch.Tensor:
         unmasked_sum = torch.sum(attr_weight[~mask])
         weight_per_masked_node = self.fraction / (1 - self.fraction) * unmasked_sum / sum(mask)
         attr_weight[mask] = weight_per_masked_node
+
         LOGGER.info(
             "Weight of nodes in %s rescaled such that their sum equals %.3f of the sum over all nodes",
             self.scaled_attribute,
             self.fraction,

From 8502ebcb55e2f1b13b4c3b693e3fe4501ec9a30d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?H=C3=A5vard=20Homleid=20Haugen?=
Date: Fri, 15 Nov 2024 10:48:03 +0100
Subject: [PATCH 06/13] Update changelog

---
 CHANGELOG.md | 1 +
 1 file changed, 1 insertion(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 287d76ea..9deebe47 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -30,6 +30,7 @@ Keep it human-readable, your future self will thank you!
 - Feat: Anemoi Profiler compatible with mlflow and using Pytorch (Kineto) Profiler for memory report [38](https://github.com/ecmwf/anemoi-training/pull/38/)
 - New limited area config file added, limited_area.yaml. [#134](https://github.com/ecmwf/anemoi-training/pull/134/)
 - New stretched grid config added, stretched_grid.yaml [#133](https://github.com/ecmwf/anemoi-training/pull/133)
+- Functionality to change the weight attribute of nodes in the graph at the start of training without re-generating the graph. [#136] (https://github.com/ecmwf/anemoi-training/pull/136)
 
 ### Changed
 - Renamed frequency keys in callbacks configuration. [#118](https://github.com/ecmwf/anemoi-training/pull/118)

From bb4969b7eb491e72a80dea9438504805874c79b3 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?H=C3=A5vard=20Homleid=20Haugen?=
Date: Fri, 15 Nov 2024 10:48:52 +0100
Subject: [PATCH 07/13] Removed obsolete config options

---
 src/anemoi/training/config/model/gnn.yaml              | 2 --
 src/anemoi/training/config/model/graphtransformer.yaml | 2 --
 src/anemoi/training/config/model/transformer.yaml      | 2 --
 3 files changed, 6 deletions(-)

diff --git a/src/anemoi/training/config/model/gnn.yaml b/src/anemoi/training/config/model/gnn.yaml
index 4f4c176c..92a17fd4 100644
--- a/src/anemoi/training/config/model/gnn.yaml
+++ b/src/anemoi/training/config/model/gnn.yaml
@@ -45,8 +45,6 @@ attributes:
     - edge_dirs
   nodes: []
 
-node_loss_weight: area_weight
-
 # Bounding configuration
 bounding: #These are applied in order
diff --git a/src/anemoi/training/config/model/graphtransformer.yaml b/src/anemoi/training/config/model/graphtransformer.yaml
index 5c2e819a..9c48967b 100644
--- a/src/anemoi/training/config/model/graphtransformer.yaml
+++ b/src/anemoi/training/config/model/graphtransformer.yaml
@@ -50,8 +50,6 @@ attributes:
    - edge_dirs
   nodes: []
 
-node_loss_weight: area_weight
-
 # Bounding configuration
 bounding: #These are applied in order
diff --git a/src/anemoi/training/config/model/transformer.yaml b/src/anemoi/training/config/model/transformer.yaml
index b26c9ecc..cd6a1e7b 100644
--- a/src/anemoi/training/config/model/transformer.yaml
+++ b/src/anemoi/training/config/model/transformer.yaml
@@ -49,8 +49,6 @@ attributes:
     - edge_dirs
   nodes: []
 
-node_loss_weight: area_weight
-
 # Bounding configuration
 bounding: #These are applied in order

From 569316f916b9e4e4bad1e52076d804244130bb39 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?H=C3=A5vard=20Homleid=20Haugen?=
Date: Fri, 15 Nov 2024 13:04:41 +0100
Subject: [PATCH 08/13] Docstrings

---
 src/anemoi/training/losses/nodeweights.py | 14 ++++++++++++++
 1 file changed, 14 insertions(+)

diff --git a/src/anemoi/training/losses/nodeweights.py b/src/anemoi/training/losses/nodeweights.py
index 909de1f2..e229b71b 100644
--- a/src/anemoi/training/losses/nodeweights.py
+++ b/src/anemoi/training/losses/nodeweights.py
@@ -108,6 +108,20 @@ class ReweightedGraphNodeAttribute(GraphNodeAttribute):
     """
 
     def __init__(self, target_nodes: str, node_attribute: str, scaled_attribute: str, weight_frac_of_total: float):
+        """Initialize reweighted graph node attribute.
+
+        Parameters
+        ----------
+        target_nodes: str
+            name of nodes, key in HeteroData graph object
+        node_attribute: str
+            name of node weight attribute, key in HeteroData graph object
+        scaled_attribute: str
+            name of node attribute defining the subset of nodes to be scaled, key in HeteroData graph object
+        weight_frac_of_total: float
+            sum of weight of subset nodes as a fraction of sum of weight of all nodes after rescaling
+
+        """
         super().__init__(target_nodes=target_nodes, node_attribute=node_attribute)
         self.scaled_attribute = scaled_attribute
         self.fraction = weight_frac_of_total

From 5e850ce925906d3b7995493117231c3e1f4b5350 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?H=C3=A5vard=20Homleid=20Haugen?=
Date: Tue, 19 Nov 2024 13:35:29 +0100
Subject: [PATCH 09/13] Unit testing

---
 tests/train/test_nodeweights.py | 98 +++++++++++++++++++++++++++++++++
 1 file changed, 98 insertions(+)
 create mode 100644 tests/train/test_nodeweights.py

diff --git a/tests/train/test_nodeweights.py b/tests/train/test_nodeweights.py
new file mode 100644
index 00000000..422e516b
--- /dev/null
+++ b/tests/train/test_nodeweights.py
@@ -0,0 +1,98 @@
+# (C) Copyright 2024 Anemoi contributors.
+#
+# This software is licensed under the terms of the Apache Licence Version 2.0
+# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+#
+# In applying this licence, ECMWF does not waive the privileges and immunities
+# granted to it by virtue of its status as an intergovernmental organisation
+# nor does it submit to any jurisdiction.
+
+
+import numpy as np
+import pytest
+import torch
+from anemoi.graphs.generate.transforms import latlon_rad_to_cartesian
+from scipy.spatial import SphericalVoronoi
+from torch_geometric.data import HeteroData
+
+from anemoi.training.losses.nodeweights import GraphNodeAttribute
+from anemoi.training.losses.nodeweights import ReweightedGraphNodeAttribute
+
+
+def fake_graph() -> HeteroData:
+    hdata = HeteroData()
+    lons = torch.tensor([1.56, 3.12, 4.68, 6.24])
+    lats = torch.tensor([-3.12, -1.56, 1.56, 3.12])
+    cutout_mask = torch.tensor([False, True, False, False]).unsqueeze(1)
+    area_weights = torch.ones(cutout_mask.shape)
+    hdata["data"]["x"] = torch.stack((lats, lons), dim=1)
+    hdata["data"]["cutout"] = cutout_mask
+    hdata["data"]["area_weight"] = area_weights
+
+    return hdata
+
+
+def fake_sv_area_weights() -> torch.Tensor:
+    lats, lons = fake_graph()["data"]["x"][:, 0], fake_graph()["data"]["x"][:, 1]
+    points = latlon_rad_to_cartesian((np.asarray(lats), np.asarray(lons)))
+    sv = SphericalVoronoi(points, radius=1.0, center=[0.0, 0.0, 0.0])
+    area_weights = sv.calculate_areas()
+
+    return torch.from_numpy(area_weights / np.max(area_weights))
+
+
+def fake_reweighted_sv_area_weights(frac: float) -> torch.Tensor:
+    weights = fake_sv_area_weights().unsqueeze(1)
+    cutout_mask = fake_graph()["data"]["cutout"]
+    unmasked_sum = torch.sum(weights[~cutout_mask])
+    weight_per_masked_node = frac / (1.0 - frac) * unmasked_sum / sum(cutout_mask)
+    weights[cutout_mask] = weight_per_masked_node
+
+    return weights.squeeze()
+
+
+@pytest.mark.parametrize(
+    ("target_nodes", "node_attribute", "fake_graph", "expected_weights"),
+    [
+        ("data", "area_weight", fake_graph(), fake_graph()["data"]["area_weight"]),
+        ("data", "non_existent_attr", fake_graph(), fake_sv_area_weights()),
+    ],
+)
+def test_grap_node_attributes(
+    target_nodes: str,
+    node_attribute: str,
+    fake_graph: HeteroData,
+    expected_weights: torch.Tensor,
+) -> None:
+    weights = GraphNodeAttribute(target_nodes=target_nodes, node_attribute=node_attribute).weights(fake_graph)
+    assert isinstance(weights, torch.Tensor)
+    assert torch.allclose(weights, expected_weights)
+
+
+@pytest.mark.parametrize(
+    ("target_nodes", "node_attribute", "scaled_attribute", "weight_frac_of_total", "fake_graph", "expected_weights"),
+    [
+        ("data", "area_weight", "cutout", 0.0, fake_graph(), torch.tensor([1.0, 0.0, 1.0, 1.0])),
+        ("data", "area_weight", "cutout", 0.5, fake_graph(), torch.tensor([1.0, 3.0, 1.0, 1.0])),
+        ("data", "area_weight", "cutout", 0.97, fake_graph(), torch.tensor([1.0, 97.0, 1.0, 1.0])),
+        ("data", "non_existent_attr", "cutout", 0.0, fake_graph(), fake_reweighted_sv_area_weights(0.0)),
+        ("data", "non_existent_attr", "cutout", 0.5, fake_graph(), fake_reweighted_sv_area_weights(0.5)),
+        ("data", "non_existent_attr", "cutout", 0.99, fake_graph(), fake_reweighted_sv_area_weights(0.99)),
+    ],
+)
+def test_graph_node_attributes(
+    target_nodes: str,
+    node_attribute: str,
+    scaled_attribute: str,
+    weight_frac_of_total: float,
+    fake_graph: HeteroData,
+    expected_weights: torch.Tensor,
+) -> None:
+    weights = ReweightedGraphNodeAttribute(
+        target_nodes=target_nodes,
+        node_attribute=node_attribute,
+        scaled_attribute=scaled_attribute,
+        weight_frac_of_total=weight_frac_of_total,
+    ).weights(graph_data=fake_graph)
+    assert isinstance(weights, torch.Tensor)
+    assert torch.allclose(weights, expected_weights)

From d2fec0b5755f928f2149c9f40fc17e19d86f56315a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?H=C3=A5vard=20Homleid=20Haugen?=
Date: Thu, 21 Nov 2024 12:34:27 +0100
Subject: [PATCH 10/13] Updated documentation

---
 docs/user-guide/training.rst | 22 ++++++++++++++++++++++
 1 file changed, 22 insertions(+)

diff --git a/docs/user-guide/training.rst b/docs/user-guide/training.rst
index 5be08222..f05fdc48 100644
--- a/docs/user-guide/training.rst
+++ b/docs/user-guide/training.rst
@@ -183,6 +183,28 @@ levels nearer to the surface). By default anemoi-training uses a ReLU
 Pressure Level scaler with a minimum weighting of 0.2 (i.e. no pressure
 level has a weighting less than 0.2).
 
+The loss is also scaled by assigning a weight to each node on the output
+grid. These weights are calculated during graph-creation and stored as
+an attribute in the graph object. The configuration option
+``config.training.node_loss_weights`` is used to specify the node
+attribute used as weights in the loss function. By default
+anemoi-training uses area weighting, where each node is weighted
+according to the size of the geographical area it represents.
+
+It is also possible to rescale the weight of a subset of nodes after
+they are loaded from the graph. For instance, for a stretched grid setup
+we can rescale the weight of nodes in the limited area such that their
+sum equals 0.25 of the sum of all node weights with the following config
+setup
+
+.. code:: yaml
+
+   node_loss_weights:
+     _target_: anemoi.training.losses.nodeweights.ReweightedGraphNodeAttribute
+     target_nodes: data
+     scaled_attribute: cutout
+     weight_frac_of_total: 0.25
+
 ***************
  Learning rate
 ***************

From 9c0ac296d325ba381fa16c9eec1c19243b29bb50 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?H=C3=A5vard=20Homleid=20Haugen?=
Date: Fri, 22 Nov 2024 10:15:29 +0100
Subject: [PATCH 11/13] area_weights uses AreaWeights from anemoi-graphs

---
 src/anemoi/training/losses/nodeweights.py | 17 +++++------------
 tests/train/test_nodeweights.py           | 11 ++---------
 2 files changed, 7 insertions(+), 21 deletions(-)

diff --git a/src/anemoi/training/losses/nodeweights.py b/src/anemoi/training/losses/nodeweights.py
index e229b71b..1b8fbf9a 100644
--- a/src/anemoi/training/losses/nodeweights.py
+++ b/src/anemoi/training/losses/nodeweights.py
@@ -10,10 +10,8 @@
 
 import logging
 
-import numpy as np
 import torch
-from anemoi.graphs.generate.transforms import latlon_rad_to_cartesian
-from scipy.spatial import SphericalVoronoi
+from anemoi.graphs.nodes.attributes import AreaWeights
 from torch_geometric.data import HeteroData
 
 LOGGER = logging.getLogger(__name__)
@@ -49,7 +47,7 @@ def __init__(self, target_nodes: str, node_attribute: str):
         self.target = target_nodes
         self.node_attribute = node_attribute
 
-    def area_weights(self, graph_data: HeteroData) -> np.ndarray:
+    def area_weights(self, graph_data: HeteroData) -> torch.Tensor:
         """Nodes weighted by the size of the geographical area they represent.
 
         Parameters
@@ -59,15 +57,10 @@ def area_weights(self, graph_data: HeteroData) -> np.ndarray:
 
         Returns
         -------
-        np.ndarray
+        torch.Tensor
             area weights of the target nodes
         """
-        lats, lons = graph_data[self.target].x[:, 0], graph_data[self.target].x[:, 1]
-        points = latlon_rad_to_cartesian((np.asarray(lats), np.asarray(lons)))
-        sv = SphericalVoronoi(points, radius=1.0, center=[0.0, 0.0, 0.0])
-        area_weights = sv.calculate_areas()
-
-        return area_weights / np.max(area_weights)
+        return AreaWeights(norm="unit-max", fill_value=0).compute(graph_data, self.target)
 
     def weights(self, graph_data: HeteroData) -> torch.Tensor:
         """Returns weight of type self.node_attribute for nodes self.target.
@@ -90,7 +83,7 @@ def weights(self, graph_data: HeteroData) -> torch.Tensor:
 
             LOGGER.info("Loading node attribute %s from the graph", self.node_attribute)
         else:
-            attr_weight = torch.from_numpy(self.area_weights(graph_data))
+            attr_weight = self.area_weights(graph_data).squeeze()
 
             LOGGER.info(
                 "Node attribute %s not found in graph. Default area weighting will be used",
diff --git a/tests/train/test_nodeweights.py b/tests/train/test_nodeweights.py
index 422e516b..00b1f41d 100644
--- a/tests/train/test_nodeweights.py
+++ b/tests/train/test_nodeweights.py
@@ -8,11 +8,9 @@
 # nor does it submit to any jurisdiction.
 
 
-import numpy as np
 import pytest
 import torch
-from anemoi.graphs.generate.transforms import latlon_rad_to_cartesian
-from scipy.spatial import SphericalVoronoi
+from anemoi.graphs.nodes.attributes import AreaWeights
 from torch_geometric.data import HeteroData
 
 from anemoi.training.losses.nodeweights import GraphNodeAttribute
 from anemoi.training.losses.nodeweights import ReweightedGraphNodeAttribute
 
 
 def fake_graph() -> HeteroData:
@@ -31,12 +31,7 @@ def fake_graph() -> HeteroData:
 
 
 def fake_sv_area_weights() -> torch.Tensor:
-    lats, lons = fake_graph()["data"]["x"][:, 0], fake_graph()["data"]["x"][:, 1]
-    points = latlon_rad_to_cartesian((np.asarray(lats), np.asarray(lons)))
-    sv = SphericalVoronoi(points, radius=1.0, center=[0.0, 0.0, 0.0])
-    area_weights = sv.calculate_areas()
-
-    return torch.from_numpy(area_weights / np.max(area_weights))
+    return AreaWeights(norm="unit-max", fill_value=0).compute(fake_graph(), "data").squeeze()
 
 
 def fake_reweighted_sv_area_weights(frac: float) -> torch.Tensor:

From f4bf9c0728cc7ef215844887716529aac8dc80d1 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?H=C3=A5vard=20Homleid=20Haugen?=
Date: Mon, 25 Nov 2024 15:50:42 +0100
Subject: [PATCH 12/13] pre-commit

---
 src/anemoi/training/train/forecaster.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/anemoi/training/train/forecaster.py b/src/anemoi/training/train/forecaster.py
index d2067d68..da6c8d3f 100644
--- a/src/anemoi/training/train/forecaster.py
+++ b/src/anemoi/training/train/forecaster.py
@@ -319,7 +319,7 @@ def get_node_weights(config: DictConfig, graph_data: HeteroData) -> torch.Tensor
         node_weighting = instantiate(config.training.node_loss_weights)
 
         return node_weighting.weights(graph_data)
-
+
     def set_model_comm_group(
         self,
         model_comm_group: ProcessGroup,

From 87262ee3777190e71be38839e9bbfe985fa502d7 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?H=C3=A5vard=20Homleid=20Haugen?=
Date: Tue, 26 Nov 2024 09:34:03 +0100
Subject: [PATCH 13/13] if test to check for scaled_attribute

---
 src/anemoi/training/losses/nodeweights.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/src/anemoi/training/losses/nodeweights.py b/src/anemoi/training/losses/nodeweights.py
index 1b8fbf9a..ed4afaf4 100644
--- a/src/anemoi/training/losses/nodeweights.py
+++ b/src/anemoi/training/losses/nodeweights.py
@@ -122,7 +122,12 @@ def __init__(self, target_nodes: str, node_attribute: str, scaled_attribute: str
     def weights(self, graph_data: HeteroData) -> torch.Tensor:
         attr_weight = super().weights(graph_data)
 
-        mask = graph_data[self.target][self.scaled_attribute].squeeze().bool()
+        if self.scaled_attribute in graph_data[self.target]:
+            mask = graph_data[self.target][self.scaled_attribute].squeeze().bool()
+        else:
+            error_msg = f"scaled_attribute {self.scaled_attribute} not found in graph_object"
+            raise KeyError(error_msg)
+
         unmasked_sum = torch.sum(attr_weight[~mask])
         weight_per_masked_node = self.fraction / (1 - self.fraction) * unmasked_sum / sum(mask)
         attr_weight[mask] = weight_per_masked_node
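
For reference, the behaviour this series converges on can be exercised directly. A minimal
sketch: only the class names, constructor arguments, the weights() entry point and the
AreaWeights fallback are taken from the patches above; the toy graph reuses the layout of
the fake_graph() fixture in tests/train/test_nodeweights.py.

    import torch
    from torch_geometric.data import HeteroData

    from anemoi.training.losses.nodeweights import GraphNodeAttribute
    from anemoi.training.losses.nodeweights import ReweightedGraphNodeAttribute

    # Four "data" nodes: lat/lon coordinates in radians, a boolean cutout mask
    # and unit area weights, mirroring the unit-test fixture.
    graph = HeteroData()
    graph["data"]["x"] = torch.stack(
        (torch.tensor([-3.12, -1.56, 1.56, 3.12]), torch.tensor([1.56, 3.12, 4.68, 6.24])),
        dim=1,
    )
    graph["data"]["cutout"] = torch.tensor([False, True, False, False]).unsqueeze(1)
    graph["data"]["area_weight"] = torch.ones(4, 1)

    # Default behaviour: load the stored node attribute; if the attribute were
    # missing, the class falls back to computing area weights via anemoi-graphs.
    node_weights = GraphNodeAttribute(target_nodes="data", node_attribute="area_weight").weights(graph)

    # Stretched-grid style rescaling: after this call the weights of the cutout
    # nodes sum to 0.25 of the total weight over all nodes.
    rescaled = ReweightedGraphNodeAttribute(
        target_nodes="data",
        node_attribute="area_weight",
        scaled_attribute="cutout",
        weight_frac_of_total=0.25,
    ).weights(graph)

This is also the path the trainer takes at startup: GraphForecaster.get_node_weights
instantiates config.training.node_loss_weights via Hydra and calls weights(graph_data)
on the result, so any class exposing that method can be swapped in from the config.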