-
Notifications
You must be signed in to change notification settings - Fork 23
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
adding charting function for Sankey diagram
- Loading branch information
Showing
3 changed files
with
132 additions
and
0 deletions.
There are no files selected for viewing
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
""" | ||
Functions you can call to make different graphs | ||
""" | ||
|
||
import plotly.graph_objects as go | ||
import colorlover as cl | ||
import pandas as pd | ||
|
||
# function to create Sankey diagram | ||
def create_sankey( | ||
query_result: pd.DataFrame, | ||
predefined_colors: {}, | ||
source: str = "source", | ||
target: str = "target", | ||
value: str = "value", | ||
title: str = "unnamed", | ||
node_pad: int = 15, | ||
node_thickness: int = 20, | ||
node_line_width: int = 0.5, | ||
font_size: int = 10, | ||
figure_height: int = 1000, | ||
figure_width: int = 1500, | ||
): | ||
""" | ||
Creates a Sankey diagram based on input query_result | ||
, which must contain source, target, value columns | ||
""" | ||
# Check if the dataframe contains required columns | ||
required_columns = [source, target, value] | ||
for col in required_columns: | ||
if col not in query_result.columns: | ||
raise ValueError(f"Error: The dataframe is missing the '{col}' column") | ||
|
||
# Check if 'value' column is numeric | ||
if not pd.api.types.is_numeric_dtype(query_result[value]): | ||
raise ValueError("Error: The 'value' column must be numeric") | ||
|
||
# preprocess query result dataframe | ||
all_nodes = list(pd.concat([query_result[source], query_result[target]]).unique()) | ||
# In Sankey, 'source' and 'target' must be indices. Thus, you need to map projects to indices. | ||
query_result["source_idx"] = query_result[source].map(all_nodes.index) | ||
query_result["target_idx"] = query_result[target].map(all_nodes.index) | ||
|
||
# create color map for Sankey | ||
colors = cl.scales["12"]["qual"]["Set3"] # default color | ||
color_map = {} | ||
for node in all_nodes: | ||
for name, color in predefined_colors.items(): | ||
if name.lower() in node.lower(): # check if name exists in the node name | ||
color_map[node] = color | ||
break | ||
else: | ||
color_map[node] = colors[ | ||
len(color_map) % len(colors) | ||
] # default color assignment | ||
|
||
fig = go.Figure( | ||
go.Sankey( | ||
node=dict( | ||
pad=node_pad, | ||
thickness=node_thickness, | ||
line=dict(color="black", width=node_line_width), | ||
label=all_nodes, | ||
color=[ | ||
color_map.get(node, "blue") for node in all_nodes | ||
], # customize node color | ||
), | ||
link=dict( | ||
source=query_result["source_idx"], | ||
target=query_result["target_idx"], | ||
value=query_result[value], | ||
color=[ | ||
color_map.get(query_result[source].iloc[i], "black") | ||
for i in range(len(query_result)) | ||
], # customize link color | ||
), | ||
) | ||
) | ||
fig.update_layout( | ||
title_text=title, font_size=font_size, height=figure_height, width=figure_width | ||
) | ||
|
||
return fig |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
import unittest | ||
from unittest.mock import patch | ||
import pandas as pd | ||
from dune_client.viz.graphs import create_sankey | ||
|
||
class TestCreateSankey(unittest.TestCase): | ||
|
||
# Setting up a dataframe for testing | ||
def setUp(self): | ||
self.df = pd.DataFrame({ | ||
'source': ['WBTC', 'USDC', 'USDC', 'USDC', 'USDC', 'COMP', 'DAI', 'DAI', 'USDT', 'WBTC', 'DAI', 'DAI', 'USDC', 'MKR', 'DAI', 'USDT', 'UNI', 'USDT', 'USDT', 'WBTC', 'USDC'], | ||
'target': ['WETH', 'WBTC', 'COMP', 'MKR', 'DAI', 'WETH', 'COMP', 'WETH', 'MKR', 'USDT', 'MKR', 'USDT', 'USDT', 'WETH', 'WBTC', 'DAI', 'WETH', 'WETH', 'WBTC', 'DAI', 'UNI'], | ||
'value': [2184, 2076, 447, 158, 4294, 519, 72, 4070, 123, 99, 85, 188, 4675, 352, 281, 230, 59, 4482, 103, 171, 54] | ||
}) | ||
|
||
self.predefined_colors = { | ||
"USDC": "rgb(38, 112, 196)", | ||
"USDT": "rgb(0, 143, 142)", | ||
"WETH": "rgb(144, 144, 144)", | ||
"WBTC": "rgb(247, 150, 38)", | ||
"COMP": "rgb(32, 217, 152)", | ||
"DAI": "rgb(254, 175, 48)", | ||
"MKR": "rgb(38, 173, 158)", | ||
"UNI": "rgb(255, 21, 126)", | ||
} | ||
|
||
def test_missing_column(self): | ||
# Remove a required column from dataframe | ||
df_without_target = self.df.drop(columns=['target']) | ||
with self.assertRaises(ValueError): | ||
create_sankey(df_without_target, self.predefined_colors) | ||
|
||
def test_value_column_not_numeric(self): | ||
# Change the 'value' column to a non-numeric type | ||
df_with_str_values = self.df.copy() | ||
df_with_str_values['value'] = ['10'] * len(df_with_str_values) | ||
with self.assertRaises(ValueError): | ||
create_sankey(df_with_str_values, self.predefined_colors) | ||
|
||
# Mocking the visualization creation and just testing the processing logic | ||
@patch('plotly.graph_objects.Figure') | ||
def test_mocked_visualization(self, MockFigure): | ||
result = create_sankey(self.df, self.predefined_colors) | ||
|
||
# Ensuring our mocked Figure was called with the correct parameters | ||
MockFigure.assert_called_once() | ||
|
||
if __name__ == '__main__': | ||
unittest.main() |