From 16a88c3a54c5cf911967ba88d119494e1aee22fb Mon Sep 17 00:00:00 2001
From: agaperste
Date: Fri, 29 Sep 2023 13:23:22 -0400
Subject: [PATCH] adding charting function for Sankey diagram

---
 dune_client/viz/__init__.py   |   0
 dune_client/viz/graphs.py     | 101 ++++++++++++++++++++++++++++++++++
 tests/unit/test_viz_sankey.py |  61 ++++++++++++++++++++
 3 files changed, 162 insertions(+)
 create mode 100644 dune_client/viz/__init__.py
 create mode 100644 dune_client/viz/graphs.py
 create mode 100644 tests/unit/test_viz_sankey.py

diff --git a/dune_client/viz/__init__.py b/dune_client/viz/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/dune_client/viz/graphs.py b/dune_client/viz/graphs.py
new file mode 100644
index 0000000..be6f922
--- /dev/null
+++ b/dune_client/viz/graphs.py
@@ -0,0 +1,101 @@
+"""
+Functions you can call to make different graphs
+"""
+
+from typing import Dict
+
+import plotly.graph_objects as go
+import colorlover as cl
+import pandas as pd
+
+
+def create_sankey(
+    query_result: pd.DataFrame,
+    predefined_colors: Dict[str, str],
+    source: str = "source",
+    target: str = "target",
+    value: str = "value",
+    title: str = "unnamed",
+    node_pad: int = 15,
+    node_thickness: int = 20,
+    node_line_width: float = 0.5,
+    font_size: int = 10,
+    figure_height: int = 1000,
+    figure_width: int = 1500,
+) -> go.Figure:
+    """
+    Create a Sankey diagram from a dataframe of weighted links.
+
+    Each row of query_result describes one link: the `source` column holds
+    the label of the origin node, the `target` column the label of the
+    destination node and the `value` column the numeric weight of the link.
+    The input dataframe is not modified.
+
+    predefined_colors maps a name to a color string. A node takes the color
+    of the first name that occurs, ignoring case, as a substring of its label;
+    nodes matching no name are given a color from the colorlover "Set3"
+    palette. Every link is drawn in the color of its source node.
+
+    node_pad, node_thickness and node_line_width style the nodes, while
+    title, font_size, figure_height and figure_width are applied to the
+    figure layout.
+
+    Returns the plotly figure holding the Sankey trace.
+
+    Raises ValueError if one of the source, target or value columns is
+    missing, or if the value column is not numeric.
+    """
+    # Check if the dataframe contains required columns
+    for col in (source, target, value):
+        if col not in query_result.columns:
+            raise ValueError(f"Error: The dataframe is missing the '{col}' column")
+
+    # Check if the value column is numeric
+    if not pd.api.types.is_numeric_dtype(query_result[value]):
+        raise ValueError(f"Error: The '{value}' column must be numeric")
+
+    # In Sankey, 'source' and 'target' must be indices, so every distinct node
+    # label is given a position. A dict makes each lookup constant-time.
+    all_nodes = list(pd.concat([query_result[source], query_result[target]]).unique())
+    node_index = {node: idx for idx, node in enumerate(all_nodes)}
+    # Kept in local variables so that the caller's dataframe is left untouched
+    source_idx = query_result[source].map(node_index)
+    target_idx = query_result[target].map(node_index)
+
+    # create color map for Sankey
+    default_colors = cl.scales["12"]["qual"]["Set3"]
+    # Lower-case the predefined names once rather than once per node
+    named_colors = [(name.lower(), color) for name, color in predefined_colors.items()]
+    color_map = {}
+    for node in all_nodes:
+        # str() lets non-string labels (e.g. numeric ids) be matched as well
+        label = str(node).lower()
+        color_map[node] = next(
+            (color for name, color in named_colors if name in label),
+            # no predefined name matched: cycle through the default palette
+            default_colors[len(color_map) % len(default_colors)],
+        )
+
+    fig = go.Figure(
+        go.Sankey(
+            node=dict(
+                pad=node_pad,
+                thickness=node_thickness,
+                line=dict(color="black", width=node_line_width),
+                label=all_nodes,
+                color=[color_map[node] for node in all_nodes],
+            ),
+            link=dict(
+                source=source_idx,
+                target=target_idx,
+                value=query_result[value],
+                # customize link color: each link takes its source node's color
+                color=[color_map.get(node, "black") for node in query_result[source]],
+            ),
+        )
+    )
+    fig.update_layout(
+        title_text=title, font_size=font_size, height=figure_height, width=figure_width
+    )
+
+    return fig
diff --git a/tests/unit/test_viz_sankey.py b/tests/unit/test_viz_sankey.py
new file mode 100644
index 0000000..879ee89
--- /dev/null
+++ b/tests/unit/test_viz_sankey.py
@@ -0,0 +1,61 @@
+import unittest
+from unittest.mock import patch
+import pandas as pd
+from dune_client.viz.graphs import create_sankey
+
+
+class TestCreateSankey(unittest.TestCase):
+    """Unit tests for the create_sankey charting function."""
+
+    # Setting up a dataframe for testing
+    def setUp(self):
+        self.df = pd.DataFrame({
+            'source': ['WBTC', 'USDC', 'USDC', 'USDC', 'USDC', 'COMP', 'DAI', 'DAI', 'USDT', 'WBTC', 'DAI', 'DAI', 'USDC', 'MKR', 'DAI', 'USDT', 'UNI', 'USDT', 'USDT', 'WBTC', 'USDC'],
+            'target': ['WETH', 'WBTC', 'COMP', 'MKR', 'DAI', 'WETH', 'COMP', 'WETH', 'MKR', 'USDT', 'MKR', 'USDT', 'USDT', 'WETH', 'WBTC', 'DAI', 'WETH', 'WETH', 'WBTC', 'DAI', 'UNI'],
+            'value': [2184, 2076, 447, 158, 4294, 519, 72, 4070, 123, 99, 85, 188, 4675, 352, 281, 230, 59, 4482, 103, 171, 54]
+        })
+
+        self.predefined_colors = {
+            "USDC": "rgb(38, 112, 196)",
+            "USDT": "rgb(0, 143, 142)",
+            "WETH": "rgb(144, 144, 144)",
+            "WBTC": "rgb(247, 150, 38)",
+            "COMP": "rgb(32, 217, 152)",
+            "DAI": "rgb(254, 175, 48)",
+            "MKR": "rgb(38, 173, 158)",
+            "UNI": "rgb(255, 21, 126)",
+        }
+
+    def test_missing_column(self):
+        # Remove a required column from dataframe
+        df_without_target = self.df.drop(columns=['target'])
+        with self.assertRaises(ValueError):
+            create_sankey(df_without_target, self.predefined_colors)
+
+    def test_value_column_not_numeric(self):
+        # Change the 'value' column to a non-numeric type
+        df_with_str_values = self.df.copy()
+        df_with_str_values['value'] = ['10'] * len(df_with_str_values)
+        with self.assertRaises(ValueError):
+            create_sankey(df_with_str_values, self.predefined_colors)
+
+    # Mocking the figure creation and just testing the processing logic
+    @patch('plotly.graph_objects.Figure')
+    def test_mocked_visualization(self, MockFigure):
+        result = create_sankey(self.df, self.predefined_colors)
+
+        # Exactly one figure is built and it is the object handed back
+        MockFigure.assert_called_once()
+        self.assertIs(result, MockFigure.return_value)
+
+    @patch('plotly.graph_objects.Figure')
+    def test_input_dataframe_not_modified(self, MockFigure):
+        df_before = self.df.copy()
+        create_sankey(self.df, self.predefined_colors)
+
+        # Building the chart must not add or change columns on the input
+        pd.testing.assert_frame_equal(self.df, df_before)
+
+
+if __name__ == '__main__':
+    unittest.main()