Skip to content

Commit

Permalink
adding charting function for Sankey diagram
Browse files Browse the repository at this point in the history
  • Loading branch information
agaperste committed Sep 29, 2023
1 parent b510c97 commit 16a88c3
Show file tree
Hide file tree
Showing 3 changed files with 132 additions and 0 deletions.
Empty file added dune_client/viz/__init__.py
Empty file.
83 changes: 83 additions & 0 deletions dune_client/viz/graphs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
"""
Functions you can call to make different graphs
"""

import plotly.graph_objects as go
import colorlover as cl
import pandas as pd

# function to create Sankey diagram
def create_sankey(
    query_result: pd.DataFrame,
    predefined_colors: dict,
    source: str = "source",
    target: str = "target",
    value: str = "value",
    title: str = "unnamed",
    node_pad: int = 15,
    node_thickness: int = 20,
    node_line_width: float = 0.5,
    font_size: int = 10,
    figure_height: int = 1000,
    figure_width: int = 1500,
) -> "go.Figure":
    """
    Create a Sankey diagram from a dataframe of weighted links.

    Each row of ``query_result`` describes one link: the ``source`` column
    holds the node the link leaves, the ``target`` column the node it
    enters, and the ``value`` column the (numeric) weight of the link.
    The input dataframe is only read; it is never modified.

    Args:
        query_result: Dataframe containing the source, target and value
            columns.
        predefined_colors: Mapping of name -> color. A node whose label
            contains a name (case-insensitive substring match) gets that
            color; the first matching entry wins. May be empty or None.
        source: Name of the column holding link origins.
        target: Name of the column holding link destinations.
        value: Name of the column holding link weights.
        title: Title text of the figure.
        node_pad: Padding between nodes.
        node_thickness: Thickness of the nodes.
        node_line_width: Width of the black outline drawn around nodes.
        font_size: Font size used in the figure layout.
        figure_height: Figure height.
        figure_width: Figure width.

    Returns:
        The figure holding the Sankey trace.

    Raises:
        ValueError: If a required column is missing, or if the value
            column is not numeric.
    """
    # Check that the dataframe contains the required columns
    for col in (source, target, value):
        if col not in query_result.columns:
            raise ValueError(f"Error: The dataframe is missing the '{col}' column")

    # Check that the value column is numeric; report the actual column name
    # so the message stays accurate when `value` is customised.
    if not pd.api.types.is_numeric_dtype(query_result[value]):
        raise ValueError(f"Error: The '{value}' column must be numeric")

    # In Sankey, 'source' and 'target' must be indices into the node list,
    # so map every node label to its position. A dict gives O(1) lookups,
    # and keeping the index series local avoids adding helper columns to
    # the caller's dataframe.
    all_nodes = list(pd.concat([query_result[source], query_result[target]]).unique())
    node_index = {node: idx for idx, node in enumerate(all_nodes)}
    source_idx = query_result[source].map(node_index)
    target_idx = query_result[target].map(node_index)

    # Build the node -> color map: predefined colors win, everything else
    # cycles through the default palette.
    palette = cl.scales["12"]["qual"]["Set3"]  # default colors
    lowered_colors = [
        (name.lower(), color) for name, color in (predefined_colors or {}).items()
    ]
    color_map = {}
    for node in all_nodes:
        # str() so that non-string node labels (e.g. integers) do not crash
        node_label = str(node).lower()
        for name, color in lowered_colors:
            if name in node_label:  # predefined name occurs in the node label
                color_map[node] = color
                break
        else:
            # The palette position is driven by how many nodes have been
            # colored so far, predefined ones included.
            color_map[node] = palette[len(color_map) % len(palette)]

    fig = go.Figure(
        go.Sankey(
            node=dict(
                pad=node_pad,
                thickness=node_thickness,
                line=dict(color="black", width=node_line_width),
                label=all_nodes,
                color=[color_map[node] for node in all_nodes],
            ),
            link=dict(
                source=source_idx,
                target=target_idx,
                value=query_result[value],
                # each link takes the color of the node it originates from
                color=[color_map.get(src, "black") for src in query_result[source]],
            ),
        )
    )
    fig.update_layout(
        title_text=title, font_size=font_size, height=figure_height, width=figure_width
    )

    return fig
49 changes: 49 additions & 0 deletions tests/unit/test_viz_sankey.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import unittest
from unittest.mock import patch
import pandas as pd
from dune_client.viz.graphs import create_sankey

class TestCreateSankey(unittest.TestCase):
    """Unit tests for ``create_sankey``: input validation and figure creation."""

    def setUp(self):
        """Build a fresh links dataframe and color mapping for every test."""
        self.df = pd.DataFrame(
            {
                "source": [
                    "WBTC", "USDC", "USDC", "USDC", "USDC", "COMP", "DAI",
                    "DAI", "USDT", "WBTC", "DAI", "DAI", "USDC", "MKR",
                    "DAI", "USDT", "UNI", "USDT", "USDT", "WBTC", "USDC",
                ],
                "target": [
                    "WETH", "WBTC", "COMP", "MKR", "DAI", "WETH", "COMP",
                    "WETH", "MKR", "USDT", "MKR", "USDT", "USDT", "WETH",
                    "WBTC", "DAI", "WETH", "WETH", "WBTC", "DAI", "UNI",
                ],
                "value": [
                    2184, 2076, 447, 158, 4294, 519, 72,
                    4070, 123, 99, 85, 188, 4675, 352,
                    281, 230, 59, 4482, 103, 171, 54,
                ],
            }
        )

        self.predefined_colors = {
            "USDC": "rgb(38, 112, 196)",
            "USDT": "rgb(0, 143, 142)",
            "WETH": "rgb(144, 144, 144)",
            "WBTC": "rgb(247, 150, 38)",
            "COMP": "rgb(32, 217, 152)",
            "DAI": "rgb(254, 175, 48)",
            "MKR": "rgb(38, 173, 158)",
            "UNI": "rgb(255, 21, 126)",
        }

    def test_missing_column(self):
        """A dataframe lacking a required column is rejected with ValueError."""
        df_without_target = self.df.drop(columns=["target"])
        with self.assertRaises(ValueError):
            create_sankey(df_without_target, self.predefined_colors)

    def test_value_column_not_numeric(self):
        """A non-numeric value column is rejected with ValueError."""
        df_with_str_values = self.df.copy()
        df_with_str_values["value"] = ["10"] * len(df_with_str_values)
        with self.assertRaises(ValueError):
            create_sankey(df_with_str_values, self.predefined_colors)

    # Mock the Figure constructor so only the processing logic is exercised
    @patch("plotly.graph_objects.Figure")
    def test_mocked_visualization(self, MockFigure):
        """Valid input builds exactly one figure, lays it out and returns it."""
        result = create_sankey(self.df, self.predefined_colors)

        # The figure is constructed once ...
        MockFigure.assert_called_once()
        # ... the layout is applied to it once ...
        MockFigure.return_value.update_layout.assert_called_once()
        # ... and that very figure is what the caller receives.
        self.assertIs(result, MockFigure.return_value)

if __name__ == "__main__":
    # Run the test suite when this module is executed as a script.
    unittest.main()

0 comments on commit 16a88c3

Please sign in to comment.