Skip to content

Commit

Permalink
Add charts as part of the pipeline 💎 closes #32
Browse files Browse the repository at this point in the history
  • Loading branch information
dunnkers committed Jun 23, 2021
1 parent c5f0eec commit f050a87
Show file tree
Hide file tree
Showing 2 changed files with 277 additions and 26 deletions.
121 changes: 100 additions & 21 deletions fseval/callbacks/wandb.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import copy
import sys
import time
from collections import UserDict
from logging import Logger, getLogger
from typing import Dict, Optional, cast
from typing import Dict, List, Optional, cast

from fseval.types import Callback
from fseval.utils.dict_utils import dict_flatten, dict_merge
Expand All @@ -12,6 +13,40 @@
from wandb.viz import CustomChart, custom_chart_panel_config


class QueryField(dict):
    """Wrapper class for query field panel configs. See `PanelConfig`."""

    # Fields that every query field requests for each run, regardless of
    # what the caller adds on top.
    _BASE_FIELD_NAMES = ("id", "name", "_defaultColorIndex")

    def __init__(self, *fields):
        """Build a ``runSets`` query field, pre-populated with the base run
        fields plus any caller-supplied extra `fields`."""
        self["name"] = "runSets"
        self["args"] = [{"name": "runSets", "value": "${runSets}"}]
        base_fields = [
            {"name": base_name, "fields": []} for base_name in self._BASE_FIELD_NAMES
        ]
        self["fields"] = base_fields + list(fields)

    def add_field(self, field: Dict):
        """Append an extra field entry to this query field's `fields` list."""
        self["fields"].append(field)


class PanelConfig(dict):
    """Wrapper class for panel configs. Is basically what `custom_chart_panel_config`
    in wandb.viz does, but allows more dynamic configuring. Skips the construction
    of a CustomChart class.
    see https://github.com/wandb/client/blob/master/wandb/viz.py"""

    def __init__(
        self,
        viz_id: str,
        fields: Optional[Dict] = None,
        string_fields: Optional[Dict] = None,
    ):
        """Create a panel config for the custom Vega panel `viz_id`.

        Args:
            viz_id: identifier of the custom chart preset.
            fields: panel field settings (chart field -> data key mapping).
            string_fields: panel string settings, e.g. titles and labels.
        """
        # `None` sentinels instead of `{}` defaults: a mutable default dict
        # would be shared across *all* PanelConfig instances, so mutating one
        # panel's settings would silently leak into every other panel.
        self["userQuery"] = {"queryFields": []}
        self["panelDefId"] = viz_id
        self["fieldSettings"] = {} if fields is None else fields
        self["stringSettings"] = {} if string_fields is None else string_fields

    def add_query_field(self, query_field: "QueryField"):
        """Append a `QueryField` to this panel's user query."""
        self["userQuery"]["queryFields"].append(query_field)


class WandbCallback(Callback):
def __init__(self, **kwargs):
super(WandbCallback, self).__init__()
Expand Down Expand Up @@ -92,15 +127,66 @@ def upload_table(self, df, name):
logs[name] = table
wandb.log(logs)

def add_panel(
def create_panel_config(
self,
viz_id: str,
panel_name: str,
table_key: str,
table_key: Optional[str] = None,
summary_fields: List = [],
config_fields: List = [],
fields: Dict = {},
string_fields: Dict = {},
panel_config_callback=lambda panel_config: panel_config,
):
) -> PanelConfig:
"""Construct a `PanelConfig`, which is to be passed to the `add_panel_to_run`
function. Allows uploading custom charts right to the wandb run. Allows
configuring summary fields, config and tables."""

# create panel config
panel_config = PanelConfig(
viz_id=viz_id, fields=fields, string_fields=string_fields
)

### Add query fields
query_field = QueryField()

# add table
if table_key:
panel_config["transform"] = {"name": "tableWithLeafColNames"}
query_field.add_field(
{
"name": "summaryTable",
"args": [{"name": "tableKey", "value": table_key}],
"fields": [],
}
)

# add summary fields
if summary_fields:
query_field.add_field(
{
"name": "summary",
"args": [{"name": "keys", "value": summary_fields}],
"fields": [],
}
)

# add config fields
if config_fields:
query_field.add_field(
{
"name": "config",
"args": [{"name": "keys", "value": config_fields}],
"fields": [],
}
)

# add query field to panel config
panel_config.add_query_field(query_field)

return panel_config

def add_panel_to_run(
self, panel_name: str, panel_config: PanelConfig, panel_type: str = "Vega2"
) -> None:
"""Adds a custom chart panel to the current wandb run. This function uses
internal wandb functions, so might be prone to changes in their code. The
function is a mixup of the following modules / functions:
Expand All @@ -114,23 +200,16 @@ def add_panel(

assert wandb.run is not None, "no wandb run in progress. wandb.run is None."

# create custom chart. is just a data holder class for its attributes.
custom_chart = CustomChart(
viz_id=viz_id,
table=None,
fields=fields,
string_fields=string_fields,
)

# create custom chart config.
# Function `custom_chart_panel_config(custom_chart, key, table_key)` has a
# useless attribute, `key`.
panel_config = custom_chart_panel_config(custom_chart, None, table_key)
panel_config = panel_config_callback(panel_config)

# add chart to current run.
wandb.run._add_panel(panel_name, "Vega2", panel_config)
wandb.run._add_panel(panel_name, panel_type, panel_config)

# "publish" chart to backend
if wandb.run._backend:
wandb.run._backend.interface.publish_history({}, wandb.run.step)

def add_panel(self, panel_name: str, viz_id: str, **panel_config) -> None:
"""Convience method to create panel config, and add a panel to the run with the
config right away. See `add_panel_to_run()`."""

panel_config = self.create_panel_config(viz_id, **panel_config)
self.add_panel_to_run(panel_name, panel_config)
182 changes: 177 additions & 5 deletions fseval/pipelines/rank_and_validate/rank_and_validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,11 +255,183 @@ def score(self, X, y, **kwargs):
if wandb_callback:
self.logger.info(f"Tables uploaded {tc.green('✓')}")

# Upload charts
# create chart:
##### Upload charts
# has ground truth
rank_and_validate_estimator = self.estimators[0]
ranking_validator_estimator = rank_and_validate_estimator.ranking_validator
X_importances = ranking_validator_estimator.X_importances
has_ground_truth = X_importances is not None

### Aggregated charts
if wandb_callback:
# use wandb_callback.add_panel
# "dunnkers/fseval/feature-importances-all-bootstraps-with-ticks"
...
# mean validation scores
wandb_callback.add_panel(
panel_name="mean_validation_score",
viz_id="dunnkers/fseval/datasets-vs-rankers",
table_key="validation_scores_mean",
config_fields=[
"ranker/name",
"dataset/name",
"validator/name",
"dataset/task",
],
fields={
"validator": "validator/name",
"score": "score",
"x": "ranker/name",
"y": "dataset/name",
"task": "dataset/task",
},
string_fields={
"title": "Feature ranker performance",
"subtitle": "→ Mean validation score over all bootstraps",
"ylabel": "accuracy or r2-score",
"aggregation_op": "mean",
"scale_type": "pow",
"color_exponent": "5",
"opacity_exponent": "15",
"scale_string": "x^5",
"color_scheme": "redyellowblue",
"reverse_colorscheme": "",
"text_threshold": "0.95",
},
)

# stability scores
wandb_callback.add_panel(
panel_name="feature_importance_stability",
viz_id="dunnkers/fseval/datasets-vs-rankers",
table_key="feature_importances",
config_fields=[
"ranker/name",
"dataset/name",
"validator/name",
"dataset/task",
],
fields={
"validator": "validator/name",
"score": "feature_importances",
"x": "ranker/name",
"y": "dataset/name",
"task": "dataset/task",
},
string_fields={
"title": "Algorithm Stability",
"subtitle": "→ Mean stdev of feature importances. Lower is better.",
"aggregation_op": "stdev",
"scale_type": "quantile",
"color_exponent": "5",
"opacity_exponent": "0",
"scale_string": "qnt",
"color_scheme": "reds",
"reverse_colorscheme": "",
"text_threshold": "0.95",
},
)

# fitting time
wandb_callback.add_panel(
panel_name="fitting_times",
viz_id="dunnkers/fseval/datasets-vs-rankers",
table_key="ranking_scores",
config_fields=[
"ranker/name",
"dataset/name",
"validator/name",
"dataset/task",
],
fields={
"validator": "validator/name",
"score": "fit_time",
"x": "ranker/name",
"y": "dataset/name",
"task": "dataset/task",
},
string_fields={
"title": "Fitting time (seconds)",
"subtitle": "→ As mean over all bootstraps",
"aggregation_op": "mean",
"scale_type": "pow",
"color_exponent": "0.1",
"opacity_exponent": "15",
"scale_string": "x^0.1",
"color_scheme": "redyellowblue",
"reverse_colorscheme": "true",
"text_threshold": "0.95",
},
)

### Individual charts
if wandb_callback:
# validation scores bootstraps
wandb_callback.add_panel(
panel_name="validation_score_bootstraps",
viz_id="dunnkers/fseval/validation-score-bootstraps",
table_key="validation_scores",
config_fields=["validator/name", "ranker/name"],
string_fields={
"title": "Classification accuracy vs. Subset size",
"subtitle": "→ for all bootstraps",
"ylabel": "accuracy or r2-score",
},
)

# validation scores
wandb_callback.add_panel(
panel_name="validation_score",
viz_id="dunnkers/fseval/validation-score",
table_key="validation_scores",
config_fields=["ranker/name"],
fields={
"hue": "ranker/name",
},
string_fields={
"title": "Classification accuracy vs. Subset size",
"subtitle": "→ as the mean over all bootstraps",
"ylabel": "accuracy or r2-score",
},
)

# mean feature importance
wandb_callback.add_panel(
panel_name="feature_importances_mean",
viz_id="wandb/bar/v0",
table_key="feature_importances",
fields={
"label": "feature_index",
"value": "feature_importances",
},
string_fields={"title": "Feature importance per feature"},
)

# feature importance & stability
wandb_callback.add_panel(
panel_name="feature_importances_stability",
viz_id="dunnkers/fseval/feature-importances-stability",
table_key="feature_importances",
fields={
"x": "feature_index",
"y": "feature_importances",
},
string_fields={
"title": "Feature importance & Stability",
"subtitle": "→ a smaller stdev means more stability",
},
)

# feature importance vs feature index
wandb_callback.add_panel(
panel_name="feature_importances_all_bootstraps",
viz_id="dunnkers/fseval/feature-importances-all-bootstraps-with-ticks",
table_key="feature_importances",
fields={
"x": "feature_index",
"y": "feature_importances",
},
string_fields={
"title": "Feature importance vs. Feature index",
"subtitle": "→ estimated feature importance per feature",
},
)

return summary

0 comments on commit f050a87

Please sign in to comment.