Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

adding charting function for Sankey diagram #95

Merged
merged 9 commits into from
Oct 3, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file added dune_client/viz/__init__.py
Empty file.
83 changes: 83 additions & 0 deletions dune_client/viz/graphs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
"""
Functions you can call to make different graphs
"""

import plotly.graph_objects as go
import colorlover as cl
import pandas as pd
bh2smith marked this conversation as resolved.
Show resolved Hide resolved

# function to create Sankey diagram
def create_sankey(
query_result: pd.DataFrame,
predefined_colors: dict,
columns: dict,
viz_config: dict,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the test suggests that these could be

Suggested change
predefined_colors: dict,
columns: dict,
viz_config: dict,
predefined_colors: dict[str, str],
columns: dict[str, str],
viz_config: dict[str, int | float],

Although it also kinda looks like columns could be list[str].

Note that earlier python versions might require Union[int, float] (if you decide to go with this suggestion).

title: str = "unnamed",
):
"""
Creates a Sankey diagram based on input query_result
, which must contain source, target, value columns
agaperste marked this conversation as resolved.
Show resolved Hide resolved
"""
# Check if the dataframe contains required columns
required_columns = [columns["source"], columns["target"], columns["value"]]
for col in required_columns:
if col not in query_result.columns:
raise ValueError(f"Error: The dataframe is missing the '{col}' column")

# Check if 'value' column is numeric
if not pd.api.types.is_numeric_dtype(query_result[columns["value"]]):
raise ValueError("Error: The 'value' column must be numeric")

# preprocess query result dataframe
all_nodes = list(
pd.concat(
[query_result[columns["source"]], query_result[columns["target"]]]
).unique()
)
# In Sankey, 'source' and 'target' must be indices. Thus, you need to map projects to indices.
query_result["source_idx"] = query_result[columns["source"]].map(all_nodes.index)
query_result["target_idx"] = query_result[columns["target"]].map(all_nodes.index)

# create color map for Sankey
colors = cl.scales["12"]["qual"]["Set3"] # default color
color_map = {}
for node in all_nodes:
for name, color in predefined_colors.items():
if name.lower() in node.lower(): # check if name exists in the node name
color_map[node] = color
break
else:
color_map[node] = colors[
len(color_map) % len(colors)
] # default color assignment

fig = go.Figure(
go.Sankey(
node=dict(
pad=viz_config["node_pad"],
thickness=viz_config["node_thickness"],
line=dict(color="black", width=viz_config["node_line_width"]),
label=all_nodes,
color=[
color_map.get(node, "blue") for node in all_nodes
], # customize node color
),
link=dict(
source=query_result["source_idx"],
target=query_result["target_idx"],
value=query_result[columns["value"]],
color=[
color_map.get(query_result[columns["source"]].iloc[i], "black")
for i in range(len(query_result))
], # customize link color
),
)
)
fig.update_layout(
title_text=title,
font_size=viz_config["font_size"],
height=viz_config["figure_height"],
width=viz_config["figure_width"],
)

return fig
139 changes: 139 additions & 0 deletions tests/unit/test_viz_sankey.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
import unittest
from unittest.mock import patch
import pandas as pd
from dune_client.viz.graphs import create_sankey


class TestCreateSankey(unittest.TestCase):

# Setting up a dataframe for testing
def setUp(self):
self.df = pd.DataFrame(
bh2smith marked this conversation as resolved.
Show resolved Hide resolved
{
"source": [
"WBTC",
"USDC",
"USDC",
"USDC",
"USDC",
"COMP",
"DAI",
"DAI",
"USDT",
"WBTC",
"DAI",
"DAI",
"USDC",
"MKR",
"DAI",
"USDT",
"UNI",
"USDT",
"USDT",
"WBTC",
"USDC",
],
"target": [
"WETH",
"WBTC",
"COMP",
"MKR",
"DAI",
"WETH",
"COMP",
"WETH",
"MKR",
"USDT",
"MKR",
"USDT",
"USDT",
"WETH",
"WBTC",
"DAI",
"WETH",
"WETH",
"WBTC",
"DAI",
"UNI",
],
"value": [
2184,
2076,
447,
158,
4294,
519,
72,
4070,
123,
99,
85,
188,
4675,
352,
281,
230,
59,
4482,
103,
171,
54,
],
}
)

self.predefined_colors = {
"USDC": "rgb(38, 112, 196)",
"USDT": "rgb(0, 143, 142)",
"WETH": "rgb(144, 144, 144)",
"WBTC": "rgb(247, 150, 38)",
"COMP": "rgb(32, 217, 152)",
"DAI": "rgb(254, 175, 48)",
"MKR": "rgb(38, 173, 158)",
"UNI": "rgb(255, 21, 126)",
}

self.columns = {"source": "source", "target": "target", "value": "value"}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are the columns mappings of themselves to themselves? Is there an example where they aren't just the same values mapping to themselves or does this even need to be a dict?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was thinking of giving folks the option to name their columns differently, which ultimately has to be mapped to be source, target, value.

Copy link
Collaborator

@bh2smith bh2smith Sep 30, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps this fact can be included somewhere in the doc strings. And also here in the test you could use targets like "renamed_column" to make things extra clear.

The test should serve as simple example that demonstrates all the way the function can be used.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

got it makes sense, changed!

self.viz_config: dict = {
"node_pad": 15,
"node_thickness": 20,
"node_line_width": 0.5,
"font_size": 10,
"figure_height": 1000,
"figure_width": 1500,
}

def test_missing_column(self):
# Remove a required column from dataframe
df_without_target = self.df.drop(columns=["target"])
with self.assertRaises(ValueError):
create_sankey(
df_without_target, self.predefined_colors, self.columns, self.viz_config
)

def test_value_column_not_numeric(self):
# Change the 'value' column to a non-numeric type
df_with_str_values = self.df.copy()
df_with_str_values["value"] = ["10"] * len(df_with_str_values)
agaperste marked this conversation as resolved.
Show resolved Hide resolved
with self.assertRaises(ValueError):
create_sankey(
df_with_str_values,
self.predefined_colors,
self.columns,
self.viz_config,
)

# Mocking the visualization creation and just testing the processing logic
@patch("plotly.graph_objects.Figure")
def test_mocked_visualization(self, MockFigure):

result = create_sankey(
self.df, self.predefined_colors, self.columns, self.viz_config, "test"
)

# Ensuring our mocked Figure was called with the correct parameters
MockFigure.assert_called_once()


if __name__ == "__main__":
unittest.main()
Loading