diff --git a/dune_client/api/extensions.py b/dune_client/api/extensions.py
index c3e6424..918f183 100644
--- a/dune_client/api/extensions.py
+++ b/dune_client/api/extensions.py
@@ -77,7 +77,7 @@ def run_query_dataframe(
         This is a convenience method that uses run_query_csv() + pandas.read_csv() underneath
         """
         try:
-            import pandas  # type: ignore # pylint: disable=import-outside-toplevel
+            import pandas  # pylint: disable=import-outside-toplevel
         except ImportError as exc:
             raise ImportError(
                 "dependency failure, pandas is required but missing"
diff --git a/dune_client/client_async.py b/dune_client/client_async.py
index 0b042e4..7a44c8b 100644
--- a/dune_client/client_async.py
+++ b/dune_client/client_async.py
@@ -271,7 +271,7 @@ async def refresh_into_dataframe(
         This is a convenience method that uses refresh_csv underneath
         """
         try:
-            import pandas  # type: ignore # pylint: disable=import-outside-toplevel
+            import pandas  # pylint: disable=import-outside-toplevel
         except ImportError as exc:
             raise ImportError(
                 "dependency failure, pandas is required but missing"
diff --git a/dune_client/viz/__init__.py b/dune_client/viz/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/dune_client/viz/graphs.py b/dune_client/viz/graphs.py
new file mode 100644
index 0000000..d7a8eb6
--- /dev/null
+++ b/dune_client/viz/graphs.py
@@ -0,0 +1,96 @@
+"""
+Functions you can call to make different graphs
+"""
+
+from typing import Dict, Union
+
+# https://github.com/plotly/colorlover/issues/35
+import colorlover as cl  # type: ignore[import]
+import pandas as pd
+import plotly.graph_objects as go  # type: ignore[import]
+from plotly.graph_objs import Figure  # type: ignore[import]
+
+
+# function to create Sankey diagram
+def create_sankey(
+    query_result: pd.DataFrame,
+    predefined_colors: Dict[str, str],
+    columns: Dict[str, str],
+    viz_config: Dict[str, Union[int, float]],
+    title: str = "unnamed",
+) -> Figure:
+    """
+    Creates a Sankey diagram based on input query_result,
+    which must contain source, target, value columns.
+    Column names don't have to be exact same but there must be
+    these three columns conceptually and value column must be numeric.
+
+    `columns` maps "source", "target" and "value" to the actual column names.
+    A node gets a color from `predefined_colors` when one of its keys is a
+    case-insensitive substring of the node name, otherwise a default color.
+    `query_result` is only read; no columns are added to it.
+    Raises ValueError if a required column is missing or value is not numeric.
+    """
+    # Check if the dataframe contains required columns
+    required_columns = [columns["source"], columns["target"], columns["value"]]
+    for col in required_columns:
+        if col not in query_result.columns:
+            raise ValueError(f"Error: The dataframe is missing the '{col}' column")
+
+    # Check if 'value' column is numeric
+    if not pd.api.types.is_numeric_dtype(query_result[columns["value"]]):
+        raise ValueError("Error: The 'value' column must be numeric")
+
+    sources = query_result[columns["source"]]
+    targets = query_result[columns["target"]]
+    all_nodes = list(pd.concat([sources, targets]).unique())
+
+    # In Sankey, 'source' and 'target' must be indices into the node list.
+    # The lookup table avoids a linear list.index() scan per row, and keeping
+    # the index series local leaves the caller's dataframe untouched.
+    node_index = {node: idx for idx, node in enumerate(all_nodes)}
+    source_idx = sources.map(node_index)
+    target_idx = targets.map(node_index)
+
+    # create color map for Sankey
+    colors = cl.scales["12"]["qual"]["Set3"]  # default color
+    color_map = {}
+    for node in all_nodes:
+        node_name = str(node).lower()  # node labels are not always strings
+        for name, color in predefined_colors.items():
+            if name.lower() in node_name:  # check if name exists in the node name
+                color_map[node] = color
+                break
+        else:
+            # no predefined color matched, cycle through the default colors
+            color_map[node] = colors[len(color_map) % len(colors)]
+
+    fig = go.Figure(
+        go.Sankey(
+            node={
+                "pad": viz_config["node_pad"],
+                "thickness": viz_config["node_thickness"],
+                "line": {"color": "black", "width": viz_config["node_line_width"]},
+                "label": all_nodes,
+                "color": [
+                    color_map.get(node, "blue") for node in all_nodes
+                ],  # customize node color
+            },
+            link={
+                "source": source_idx,
+                "target": target_idx,
+                "value": query_result[columns["value"]],
+                "color": [
+                    color_map.get(node, "black") for node in sources
+                ],  # customize link color
+            },
+        )
+    )
+    fig.update_layout(
+        title_text=title,
+        font_size=viz_config["font_size"],
+        height=viz_config["figure_height"],
+        width=viz_config["figure_width"],
+    )
+
+    return fig
diff --git a/requirements/dev.txt b/requirements/dev.txt
index a3583c5..ede131a 100644
--- a/requirements/dev.txt
+++ b/requirements/dev.txt
@@ -1,8 +1,11 @@
 -r prod.txt
 black>=23.7.0
 pandas>=1.0.0
+pandas-stubs>=1.0.0
 pylint>=2.17.5
 pytest>=7.4.1
 python-dotenv>=1.0.0
 mypy>=1.5.1
 aiounittest>=1.4.2
+colorlover>=0.3.0
+plotly>=5.9.0
diff --git a/tests/unit/test_viz_sankey.py b/tests/unit/test_viz_sankey.py
new file mode 100644
index 0000000..b5d78dc
--- /dev/null
+++ b/tests/unit/test_viz_sankey.py
@@ -0,0 +1,77 @@
+import unittest
+from unittest.mock import patch
+import pandas as pd
+from dune_client.viz.graphs import create_sankey
+
+
+class TestCreateSankey(unittest.TestCase):
+    # Setting up a dataframe for testing
+    def setUp(self):
+        self.df = pd.DataFrame(
+            {
+                "source_col": ["WBTC", "USDC"],
+                "target_col": ["USDC", "WBTC"],
+                "value_col": [2184, 2076],
+            }
+        )
+
+        self.predefined_colors = {
+            "USDC": "rgb(38, 112, 196)",
+            "WBTC": "rgb(247, 150, 38)",
+        }
+
+        self.columns = {
+            "source": "source_col",
+            "target": "target_col",
+            "value": "value_col",
+        }
+        self.viz_config: dict = {
+            "node_pad": 15,
+            "node_thickness": 20,
+            "node_line_width": 0.5,
+            "font_size": 10,
+            "figure_height": 1000,
+            "figure_width": 1500,
+        }
+
+    def test_missing_column(self):
+        # Remove a required column from dataframe
+        df_without_target = self.df.drop(columns=["target_col"])
+        with self.assertRaises(ValueError):
+            create_sankey(
+                df_without_target, self.predefined_colors, self.columns, self.viz_config
+            )
+
+    def test_value_column_not_numeric(self):
+        # Change the 'value' column to a non-numeric type
+        df_with_str_values = self.df.copy()
+        df_with_str_values["value_col"] = ["10", "11"]
+        with self.assertRaises(ValueError):
+            create_sankey(
+                df_with_str_values,
+                self.predefined_colors,
+                self.columns,
+                self.viz_config,
+            )
+
+    # Mocking the visualization creation and just testing the processing logic
+    @patch("plotly.graph_objects.Figure")
+    def test_mocked_visualization(self, MockFigure):
+        create_sankey(
+            self.df, self.predefined_colors, self.columns, self.viz_config, "test"
+        )
+
+        # Ensuring our mocked Figure was called with the correct parameters
+        MockFigure.assert_called_once()
+
+    def test_input_dataframe_not_modified(self):
+        # create_sankey must not add helper columns to the caller's dataframe
+        df_before = self.df.copy()
+        create_sankey(
+            self.df, self.predefined_colors, self.columns, self.viz_config, "test"
+        )
+        pd.testing.assert_frame_equal(self.df, df_before)
+
+
+if __name__ == "__main__":
+    unittest.main()