Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

adding charting function for Sankey diagram #95

Merged
merged 9 commits into from
Oct 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion dune_client/api/extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def run_query_dataframe(
This is a convenience method that uses run_query_csv() + pandas.read_csv() underneath
"""
try:
import pandas # type: ignore # pylint: disable=import-outside-toplevel
import pandas # pylint: disable=import-outside-toplevel
except ImportError as exc:
raise ImportError(
"dependency failure, pandas is required but missing"
Expand Down
2 changes: 1 addition & 1 deletion dune_client/client_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ async def refresh_into_dataframe(
This is a convenience method that uses refresh_csv underneath
"""
try:
import pandas # type: ignore # pylint: disable=import-outside-toplevel
import pandas # pylint: disable=import-outside-toplevel
except ImportError as exc:
raise ImportError(
"dependency failure, pandas is required but missing"
Expand Down
Empty file added dune_client/viz/__init__.py
Empty file.
90 changes: 90 additions & 0 deletions dune_client/viz/graphs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
"""
Functions you can call to make different graphs
"""

from typing import Dict, Union

# https://github.com/plotly/colorlover/issues/35
import colorlover as cl # type: ignore[import]
import pandas as pd
import plotly.graph_objects as go # type: ignore[import]
from plotly.graph_objs import Figure # type: ignore[import]


# function to create Sankey diagram
def create_sankey(
    query_result: pd.DataFrame,
    predefined_colors: Dict[str, str],
    columns: Dict[str, str],
    viz_config: Dict[str, Union[int, float]],
    title: str = "unnamed",
) -> Figure:
    """
    Create a Sankey diagram from a dataframe of flows.

    The dataframe must conceptually contain a source, a target and a value
    column. Their actual names are supplied through ``columns`` and the value
    column must be numeric. ``query_result`` is only read, never modified.

    Args:
        query_result: One row per flow (link) between two nodes.
        predefined_colors: Maps a name fragment to a color. A node whose label
            contains the fragment (case-insensitive) gets that color; the
            first matching entry wins. Other nodes cycle through a default
            qualitative palette.
        columns: Must provide the keys ``"source"``, ``"target"`` and
            ``"value"``, each mapped to the matching column name in
            ``query_result``.
        viz_config: Must provide the keys ``"node_pad"``, ``"node_thickness"``,
            ``"node_line_width"``, ``"font_size"``, ``"figure_height"`` and
            ``"figure_width"``.
        title: Title text of the figure.

    Returns:
        The plotly figure holding the Sankey trace.

    Raises:
        ValueError: If a required column is missing from ``query_result`` or
            the value column is not numeric.
        KeyError: If ``columns`` or ``viz_config`` lacks a required key.
    """
    source_col = columns["source"]
    target_col = columns["target"]
    value_col = columns["value"]

    # Check if the dataframe contains required columns
    for col in (source_col, target_col, value_col):
        if col not in query_result.columns:
            raise ValueError(f"Error: The dataframe is missing the '{col}' column")

    # Check if 'value' column is numeric
    if not pd.api.types.is_numeric_dtype(query_result[value_col]):
        raise ValueError("Error: The 'value' column must be numeric")

    sources = query_result[source_col]
    targets = query_result[target_col]

    # Unique node labels in order of first appearance (sources first).
    all_nodes = list(pd.concat([sources, targets]).unique())

    # In Sankey, 'source' and 'target' must be indices into the node list.
    # A dict gives constant-time lookups, and keeping the mapped series local
    # avoids adding helper columns to the caller's dataframe.
    node_index = {node: position for position, node in enumerate(all_nodes)}
    source_idx = sources.map(node_index)
    target_idx = targets.map(node_index)

    # create color map for Sankey
    default_colors = cl.scales["12"]["qual"]["Set3"]
    lowered_colors = [
        (name.lower(), color) for name, color in predefined_colors.items()
    ]
    color_map: Dict[object, str] = {}
    for node in all_nodes:
        # str() lets non-string labels (e.g. numeric ids) be matched as well.
        node_name = str(node).lower()
        # The fallback is picked by the number of nodes colored so far, so
        # predefined matches also advance the position in the palette.
        fallback = default_colors[len(color_map) % len(default_colors)]
        color_map[node] = next(
            (color for name, color in lowered_colors if name in node_name),
            fallback,
        )

    fig = go.Figure(
        go.Sankey(
            node={
                "pad": viz_config["node_pad"],
                "thickness": viz_config["node_thickness"],
                "line": {"color": "black", "width": viz_config["node_line_width"]},
                "label": all_nodes,
                # customize node color
                "color": [color_map.get(node, "blue") for node in all_nodes],
            },
            link={
                "source": source_idx,
                "target": target_idx,
                "value": query_result[value_col],
                # each link takes the color of the node it starts from
                "color": [color_map.get(node, "black") for node in sources],
            },
        )
    )
    fig.update_layout(
        title_text=title,
        font_size=viz_config["font_size"],
        height=viz_config["figure_height"],
        width=viz_config["figure_width"],
    )

    return fig
3 changes: 3 additions & 0 deletions requirements/dev.txt
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
-r prod.txt
black>=23.7.0
pandas>=1.0.0
pandas-stubs>=1.0.0
pylint>=2.17.5
pytest>=7.4.1
python-dotenv>=1.0.0
mypy>=1.5.1
aiounittest>=1.4.2
colorlover>=0.3.0
plotly>=5.9.0
69 changes: 69 additions & 0 deletions tests/unit/test_viz_sankey.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import unittest
from unittest.mock import patch
import pandas as pd
from dune_client.viz.graphs import create_sankey


class TestCreateSankey(unittest.TestCase):
    """Unit tests for create_sankey: input validation and figure assembly."""

    def setUp(self):
        """Build a two-row flow dataframe plus matching configuration."""
        self.df = pd.DataFrame(
            {
                "source_col": ["WBTC", "USDC"],
                "target_col": ["USDC", "WBTC"],
                "value_col": [2184, 2076],
            }
        )

        self.predefined_colors = {
            "USDC": "rgb(38, 112, 196)",
            "WBTC": "rgb(247, 150, 38)",
        }

        self.columns = {
            "source": "source_col",
            "target": "target_col",
            "value": "value_col",
        }
        self.viz_config: dict = {
            "node_pad": 15,
            "node_thickness": 20,
            "node_line_width": 0.5,
            "font_size": 10,
            "figure_height": 1000,
            "figure_width": 1500,
        }

    def test_missing_column(self):
        """A dataframe lacking a required column is rejected."""
        # Remove a required column from dataframe
        df_without_target = self.df.drop(columns=["target_col"])
        with self.assertRaises(ValueError):
            create_sankey(
                df_without_target, self.predefined_colors, self.columns, self.viz_config
            )

    def test_value_column_not_numeric(self):
        """A non-numeric value column is rejected."""
        # Change the 'value' column to a non-numeric type
        df_with_str_values = self.df.copy()
        df_with_str_values["value_col"] = ["10", "11"]
        with self.assertRaises(ValueError):
            create_sankey(
                df_with_str_values,
                self.predefined_colors,
                self.columns,
                self.viz_config,
            )

    # Mocking the figure creation and just testing the processing logic
    @patch("plotly.graph_objects.Figure")
    def test_mocked_visualization(self, MockFigure):
        """The figure is built once from the expected trace and layout."""
        result = create_sankey(
            self.df, self.predefined_colors, self.columns, self.viz_config, "test"
        )

        # The mocked Figure is constructed exactly once and handed back.
        MockFigure.assert_called_once()
        self.assertIs(result, MockFigure.return_value)

        # Layout settings come straight from the title and viz_config.
        MockFigure.return_value.update_layout.assert_called_once_with(
            title_text="test",
            font_size=10,
            height=1000,
            width=1500,
        )

        # The single positional argument is the Sankey trace; nodes are listed
        # in order of first appearance and links refer to them by index.
        sankey = MockFigure.call_args[0][0]
        self.assertEqual(list(sankey.node.label), ["WBTC", "USDC"])
        self.assertEqual(list(sankey.link.source), [0, 1])
        self.assertEqual(list(sankey.link.target), [1, 0])


# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()
Loading