diff --git a/fseval/callbacks/wandb.py b/fseval/callbacks/wandb.py index 2896713..b2ce823 100644 --- a/fseval/callbacks/wandb.py +++ b/fseval/callbacks/wandb.py @@ -1,8 +1,9 @@ import copy import sys import time +from collections import UserDict from logging import Logger, getLogger -from typing import Dict, Optional, cast +from typing import Dict, List, Optional, cast from fseval.types import Callback from fseval.utils.dict_utils import dict_flatten, dict_merge @@ -12,6 +13,40 @@ from wandb.viz import CustomChart, custom_chart_panel_config +class QueryField(dict): + """Wrapper class for query field panel configs. See `PanelConfig`.""" + + def __init__(self, *fields): + self["name"] = "runSets" + self["args"] = [{"name": "runSets", "value": r"${runSets}"}] + self["fields"] = [ + {"name": "id", "fields": []}, + {"name": "name", "fields": []}, + {"name": "_defaultColorIndex", "fields": []}, + *fields, + ] + + def add_field(self, field: Dict): + self["fields"].append(field) + + +class PanelConfig(dict): + """Wrapper class for panel configs. Is basically what `custom_chart_panel_config` + in wandb.viz does, but allows more dynamic configuring. Skips the construction + of a CustomChart class. 
+ + see https://github.com/wandb/client/blob/master/wandb/viz.py""" + + def __init__(self, viz_id: str, fields: Dict = {}, string_fields: Dict = {}): + self["userQuery"] = {"queryFields": []} + self["panelDefId"] = viz_id + self["fieldSettings"] = fields + self["stringSettings"] = string_fields + + def add_query_field(self, query_field: QueryField): + self["userQuery"]["queryFields"].append(query_field) + + class WandbCallback(Callback): def __init__(self, **kwargs): super(WandbCallback, self).__init__() @@ -92,15 +127,66 @@ def upload_table(self, df, name): logs[name] = table wandb.log(logs) - def add_panel( + def create_panel_config( self, viz_id: str, - panel_name: str, - table_key: str, + table_key: Optional[str] = None, + summary_fields: List = [], + config_fields: List = [], fields: Dict = {}, string_fields: Dict = {}, - panel_config_callback=lambda panel_config: panel_config, - ): + ) -> PanelConfig: + """Construct a `PanelConfig`, which is to be passed to the `add_panel_to_run` + function. Allows uploading custom charts right to the wandb run. 
Allows + configuring summary fields, config and tables.""" + + # create panel config + panel_config = PanelConfig( + viz_id=viz_id, fields=fields, string_fields=string_fields + ) + + ### Add query fields + query_field = QueryField() + + # add table + if table_key: + panel_config["transform"] = {"name": "tableWithLeafColNames"} + query_field.add_field( + { + "name": "summaryTable", + "args": [{"name": "tableKey", "value": table_key}], + "fields": [], + } + ) + + # add summary fields + if summary_fields: + query_field.add_field( + { + "name": "summary", + "args": [{"name": "keys", "value": summary_fields}], + "fields": [], + } + ) + + # add config fields + if config_fields: + query_field.add_field( + { + "name": "config", + "args": [{"name": "keys", "value": config_fields}], + "fields": [], + } + ) + + # add query field to panel config + panel_config.add_query_field(query_field) + + return panel_config + + def add_panel_to_run( + self, panel_name: str, panel_config: PanelConfig, panel_type: str = "Vega2" + ) -> None: """Adds a custom chart panel to the current wandb run. This function uses internal wandb functions, so might be prone to changes in their code. The function is a mixup of the following modules / functions: @@ -114,23 +200,16 @@ def add_panel( assert wandb.run is not None, "no wandb run in progress. wandb.run is None." - # create custom chart. is just a data holder class for its attributes. - custom_chart = CustomChart( - viz_id=viz_id, - table=None, - fields=fields, - string_fields=string_fields, - ) - - # create custom chart config. - # Function `custom_chart_panel_config(custom_chart, key, table_key)` has a - # useless attribute, `key`. - panel_config = custom_chart_panel_config(custom_chart, None, table_key) - panel_config = panel_config_callback(panel_config) - # add chart to current run. 
- wandb.run._add_panel(panel_name, "Vega2", panel_config) + wandb.run._add_panel(panel_name, panel_type, panel_config) # "publish" chart to backend if wandb.run._backend: wandb.run._backend.interface.publish_history({}, wandb.run.step) + + def add_panel(self, panel_name: str, viz_id: str, **panel_config) -> None: + """Convenience method to create panel config, and add a panel to the run with the + config right away. See `add_panel_to_run()`.""" + + panel_config = self.create_panel_config(viz_id, **panel_config) + self.add_panel_to_run(panel_name, panel_config) diff --git a/fseval/pipelines/rank_and_validate/rank_and_validate.py b/fseval/pipelines/rank_and_validate/rank_and_validate.py index 6939376..cc7070f 100644 --- a/fseval/pipelines/rank_and_validate/rank_and_validate.py +++ b/fseval/pipelines/rank_and_validate/rank_and_validate.py @@ -255,11 +255,183 @@ def score(self, X, y, **kwargs): if wandb_callback: self.logger.info(f"Tables uploaded {tc.green('✓')}") - # Upload charts - # create chart: + ##### Upload charts + # has ground truth + rank_and_validate_estimator = self.estimators[0] + ranking_validator_estimator = rank_and_validate_estimator.ranking_validator + X_importances = ranking_validator_estimator.X_importances + has_ground_truth = X_importances is not None + + ### Aggregated charts if wandb_callback: - # use wandb_callback.add_panel - # "dunnkers/fseval/feature-importances-all-bootstraps-with-ticks" - ... 
+ # mean validation scores + wandb_callback.add_panel( + panel_name="mean_validation_score", + viz_id="dunnkers/fseval/datasets-vs-rankers", + table_key="validation_scores_mean", + config_fields=[ + "ranker/name", + "dataset/name", + "validator/name", + "dataset/task", + ], + fields={ + "validator": "validator/name", + "score": "score", + "x": "ranker/name", + "y": "dataset/name", + "task": "dataset/task", + }, + string_fields={ + "title": "Feature ranker performance", + "subtitle": "→ Mean validation score over all bootstraps", + "ylabel": "accuracy or r2-score", + "aggregation_op": "mean", + "scale_type": "pow", + "color_exponent": "5", + "opacity_exponent": "15", + "scale_string": "x^5", + "color_scheme": "redyellowblue", + "reverse_colorscheme": "", + "text_threshold": "0.95", + }, + ) + + # stability scores + wandb_callback.add_panel( + panel_name="feature_importance_stability", + viz_id="dunnkers/fseval/datasets-vs-rankers", + table_key="feature_importances", + config_fields=[ + "ranker/name", + "dataset/name", + "validator/name", + "dataset/task", + ], + fields={ + "validator": "validator/name", + "score": "feature_importances", + "x": "ranker/name", + "y": "dataset/name", + "task": "dataset/task", + }, + string_fields={ + "title": "Algorithm Stability", + "subtitle": "→ Mean stdev of feature importances. 
Lower is better.", + "aggregation_op": "stdev", + "scale_type": "quantile", + "color_exponent": "5", + "opacity_exponent": "0", + "scale_string": "qnt", + "color_scheme": "reds", + "reverse_colorscheme": "", + "text_threshold": "0.95", + }, + ) + + # fitting time + wandb_callback.add_panel( + panel_name="fitting_times", + viz_id="dunnkers/fseval/datasets-vs-rankers", + table_key="ranking_scores", + config_fields=[ + "ranker/name", + "dataset/name", + "validator/name", + "dataset/task", + ], + fields={ + "validator": "validator/name", + "score": "fit_time", + "x": "ranker/name", + "y": "dataset/name", + "task": "dataset/task", + }, + string_fields={ + "title": "Fitting time (seconds)", + "subtitle": "→ As mean over all bootstraps", + "aggregation_op": "mean", + "scale_type": "pow", + "color_exponent": "0.1", + "opacity_exponent": "15", + "scale_string": "x^0.1", + "color_scheme": "redyellowblue", + "reverse_colorscheme": "true", + "text_threshold": "0.95", + }, + ) + + ### Individual charts + if wandb_callback: + # validation scores bootstraps + wandb_callback.add_panel( + panel_name="validation_score_bootstraps", + viz_id="dunnkers/fseval/validation-score-bootstraps", + table_key="validation_scores", + config_fields=["validator/name", "ranker/name"], + string_fields={ + "title": "Classification accuracy vs. Subset size", + "subtitle": "→ for all bootstraps", + "ylabel": "accuracy or r2-score", + }, + ) + + # validation scores + wandb_callback.add_panel( + panel_name="validation_score", + viz_id="dunnkers/fseval/validation-score", + table_key="validation_scores", + config_fields=["ranker/name"], + fields={ + "hue": "ranker/name", + }, + string_fields={ + "title": "Classification accuracy vs. 
Subset size", + "subtitle": "→ as the mean over all bootstraps", + "ylabel": "accuracy or r2-score", + }, + ) + + # mean feature importance + wandb_callback.add_panel( + panel_name="feature_importances_mean", + viz_id="wandb/bar/v0", + table_key="feature_importances", + fields={ + "label": "feature_index", + "value": "feature_importances", + }, + string_fields={"title": "Feature importance per feature"}, + ) + + # feature importance & stability + wandb_callback.add_panel( + panel_name="feature_importances_stability", + viz_id="dunnkers/fseval/feature-importances-stability", + table_key="feature_importances", + fields={ + "x": "feature_index", + "y": "feature_importances", + }, + string_fields={ + "title": "Feature importance & Stability", + "subtitle": "→ a smaller stdev means more stability", + }, + ) + + # feature importance vs feature index + wandb_callback.add_panel( + panel_name="feature_importances_all_bootstraps", + viz_id="dunnkers/fseval/feature-importances-all-bootstraps-with-ticks", + table_key="feature_importances", + fields={ + "x": "feature_index", + "y": "feature_importances", + }, + string_fields={ + "title": "Feature importance vs. Feature index", + "subtitle": "→ estimated feature importance per feature", + }, + ) return summary