From 5e12eeb8a2b3190bc511d7d4430cf3f55bd528b3 Mon Sep 17 00:00:00 2001 From: Mikhail Rozhkov Date: Tue, 12 Dec 2023 17:45:24 +0100 Subject: [PATCH 01/11] Add support for DataFrames and np.ndarray for log_plot() --- src/dvclive/live.py | 10 +++++++++- src/dvclive/utils.py | 32 ++++++++++++++++++++++++++++++-- tests/test_utils.py | 28 +++++++++++++++++++++++++++- 3 files changed, 66 insertions(+), 4 deletions(-) diff --git a/src/dvclive/live.py b/src/dvclive/live.py index 1c9b725c..8f8dfb7f 100644 --- a/src/dvclive/live.py +++ b/src/dvclive/live.py @@ -2,7 +2,9 @@ import json import logging import math +import numpy as np import os +import pandas as pd import shutil import tempfile from pathlib import Path @@ -35,6 +37,7 @@ StrPath, catch_and_warn, clean_and_copy_into, + convert_datapoints_to_list_of_dicts, env2bool, inside_notebook, matplotlib_installed, @@ -391,14 +394,19 @@ def log_image(self, name: str, val): def log_plot( self, name: str, - datapoints: List[Dict], + datapoints: pd.DataFrame | np.ndarray | List[Dict], x: str, y: str, template: Optional[str] = None, title: Optional[str] = None, x_label: Optional[str] = None, y_label: Optional[str] = None, + columns: Optional[List[str]] = None, ): + + # Convert the given datapoints to List[Dict] + datapoints = convert_datapoints_to_list_of_dicts(datapoints=datapoints, columns=columns) + if not CustomPlot.could_log(datapoints): raise InvalidDataTypeError(name, type(datapoints)) diff --git a/src/dvclive/utils.py b/src/dvclive/utils.py index 65f0266c..1b4d09d7 100644 --- a/src/dvclive/utils.py +++ b/src/dvclive/utils.py @@ -1,12 +1,14 @@ import csv import json +import numpy as np import os +import pandas as pd import re import shutil -import webbrowser from pathlib import Path from platform import uname -from typing import Union +from typing import Union, List, Dict, Optional +import webbrowser StrPath = Union[str, Path] @@ -194,3 +196,29 @@ def read_history(live, metric): def read_latest(live, metric_name): _, latest = parse_metrics(live) return latest["step"], latest[metric_name] + + +def convert_datapoints_to_list_of_dicts( + datapoints: pd.DataFrame | np.ndarray | List[Dict], + columns: Optional[List[str]] = None + ) -> List[Dict]: + """ + Convert the given datapoints to a list of dictionaries. + + Args: + datapoints: The input datapoints to be converted. + columns: The column columns for the datapoints. Applied only for np.ndarray inputs. + + Returns: + A list of dictionaries representing the datapoints. + + Raises: + None + """ + if isinstance(datapoints, pd.DataFrame): + return datapoints.to_dict(orient='records') + elif isinstance(datapoints, np.ndarray): + return pd.DataFrame(datapoints, columns=columns).to_dict(orient='records') + else: + return datapoints + diff --git a/tests/test_utils.py b/tests/test_utils.py index cb452125..d3b764c9 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,6 +1,9 @@ +import numpy as np +import pandas as pd import pytest +from typing import List, Dict, Optional -from dvclive.utils import standardize_metric_name +from dvclive.utils import standardize_metric_name, convert_datapoints_to_list_of_dicts @pytest.mark.parametrize( @@ -15,3 +18,26 @@ ) def test_standardize_metric_name(framework, logged, standardized): assert standardize_metric_name(logged, framework) == standardized + + + +class TestConvertDatapointsToListOfDicts: + def test_dataframe(self): + df = pd.DataFrame({'A': [1, 2], 'B': [3, 4]}) + expected_output = [{'A': 1, 'B': 3}, {'A': 2, 'B': 4}] + assert convert_datapoints_to_list_of_dicts(df) == expected_output + + def test_ndarray_with_columns(self): + arr = np.array([[1, 3], [2, 4]]) + columns = ['A', 'B'] + expected_output = [{'A': 1, 'B': 3}, {'A': 2, 'B': 4}] + assert convert_datapoints_to_list_of_dicts(arr, columns) == expected_output + + def test_ndarray_without_columns(self): + arr = np.array([[1, 3], [2, 4]]) + expected_output = [{0: 1, 1: 3}, {0: 2, 1: 4}] + assert convert_datapoints_to_list_of_dicts(arr) == expected_output + + def test_list_of_dicts(self): + list_of_dicts = [{'A': 1, 'B': 3}, {'A': 2, 'B': 4}] + assert convert_datapoints_to_list_of_dicts(list_of_dicts) == list_of_dicts \ No newline at end of file From 6e48930e016f28d8dfa1c3a3da85cc15ee03e2f6 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 12 Dec 2023 17:20:06 +0000 Subject: [PATCH 02/11] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/dvclive/live.py | 5 +++-- src/dvclive/utils.py | 11 +++++------ tests/test_utils.py | 14 ++++++-------- 3 files changed, 14 insertions(+), 16 deletions(-) diff --git a/src/dvclive/live.py b/src/dvclive/live.py index 8f8dfb7f..2726588f 100644 --- a/src/dvclive/live.py +++ b/src/dvclive/live.py @@ -403,9 +403,10 @@ def log_plot( y_label: Optional[str] = None, columns: Optional[List[str]] = None, ): - # Convert the given datapoints to List[Dict] - datapoints = convert_datapoints_to_list_of_dicts(datapoints=datapoints, columns=columns) + datapoints = convert_datapoints_to_list_of_dicts( + datapoints=datapoints, columns=columns + ) if not CustomPlot.could_log(datapoints): raise InvalidDataTypeError(name, type(datapoints)) diff --git a/src/dvclive/utils.py b/src/dvclive/utils.py index 1b4d09d7..56d4ad5f 100644 --- a/src/dvclive/utils.py +++ b/src/dvclive/utils.py @@ -199,9 +199,9 @@ def read_latest(live, metric_name): def convert_datapoints_to_list_of_dicts( - datapoints: pd.DataFrame | np.ndarray | List[Dict], - columns: Optional[List[str]] = None - ) -> List[Dict]: + datapoints: pd.DataFrame | np.ndarray | List[Dict], + columns: Optional[List[str]] = None, +) -> List[Dict]: """ Convert the given datapoints to a list of dictionaries. @@ -216,9 +216,8 @@ def convert_datapoints_to_list_of_dicts( None """ if isinstance(datapoints, pd.DataFrame): - return datapoints.to_dict(orient='records') + return datapoints.to_dict(orient="records") elif isinstance(datapoints, np.ndarray): - return pd.DataFrame(datapoints, columns=columns).to_dict(orient='records') + return pd.DataFrame(datapoints, columns=columns).to_dict(orient="records") else: return datapoints - diff --git a/tests/test_utils.py b/tests/test_utils.py index d3b764c9..ae563294 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,7 +1,6 @@ import numpy as np import pandas as pd import pytest -from typing import List, Dict, Optional from dvclive.utils import standardize_metric_name, convert_datapoints_to_list_of_dicts @@ -20,17 +19,16 @@ def test_standardize_metric_name(framework, logged, standardized): assert standardize_metric_name(logged, framework) == standardized - class TestConvertDatapointsToListOfDicts: def test_dataframe(self): - df = pd.DataFrame({'A': [1, 2], 'B': [3, 4]}) - expected_output = [{'A': 1, 'B': 3}, {'A': 2, 'B': 4}] + df = pd.DataFrame({"A": [1, 2], "B": [3, 4]}) + expected_output = [{"A": 1, "B": 3}, {"A": 2, "B": 4}] assert convert_datapoints_to_list_of_dicts(df) == expected_output def test_ndarray_with_columns(self): arr = np.array([[1, 3], [2, 4]]) - columns = ['A', 'B'] - expected_output = [{'A': 1, 'B': 3}, {'A': 2, 'B': 4}] + columns = ["A", "B"] + expected_output = [{"A": 1, "B": 3}, {"A": 2, "B": 4}] assert convert_datapoints_to_list_of_dicts(arr, columns) == expected_output def test_ndarray_without_columns(self): @@ -39,5 +37,5 @@ def test_ndarray_without_columns(self): assert convert_datapoints_to_list_of_dicts(arr) == expected_output def test_list_of_dicts(self): - list_of_dicts = [{'A': 1, 'B': 3}, {'A': 2, 'B': 4}] - assert convert_datapoints_to_list_of_dicts(list_of_dicts) == list_of_dicts \ No newline at end of file + list_of_dicts = [{"A": 1, "B": 3}, {"A": 2, "B": 4}] + assert convert_datapoints_to_list_of_dicts(list_of_dicts) == list_of_dicts From b12d1bbfa9568b10f56ce0073f0fee3937b0eeba Mon Sep 17 00:00:00 2001 From: Mikhail Rozhkov Date: Wed, 13 Dec 2023 15:34:08 +0100 Subject: [PATCH 03/11] Update convert_datapoints_to_list_of_dicts() and dependency imports --- pyproject.toml | 2 +- src/dvclive/live.py | 14 +++++++------- src/dvclive/utils.py | 34 ++++++++++++++++++++++------------ tests/test_utils.py | 20 +++++++++++++------- 4 files changed, 43 insertions(+), 27 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 48a8ba52..56b89403 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,7 +43,7 @@ dependencies = [ [project.optional-dependencies] image = ["numpy", "pillow"] sklearn = ["scikit-learn"] -plots = ["scikit-learn"] +plots = ["scikit-learn", "pandas", "numpy"] markdown = ["matplotlib"] tests = [ "pytest>=7.2.0,<8.0", diff --git a/src/dvclive/live.py b/src/dvclive/live.py index 2726588f..0557d016 100644 --- a/src/dvclive/live.py +++ b/src/dvclive/live.py @@ -1,14 +1,17 @@ +from __future__ import annotations import glob import json import logging import math -import numpy as np import os -import pandas as pd import shutil import tempfile from pathlib import Path -from typing import Any, Dict, List, Optional, Set, Union +from typing import Any, Dict, List, Optional, Set, Union, TYPE_CHECKING + +if TYPE_CHECKING: + import numpy as np + import pandas as pd from dvc.exceptions import DvcException from funcy import set_in @@ -401,12 +404,9 @@ def log_plot( title: Optional[str] = None, x_label: Optional[str] = None, y_label: Optional[str] = None, - columns: Optional[List[str]] = None, ): # Convert the given datapoints to List[Dict] - datapoints = convert_datapoints_to_list_of_dicts( - datapoints=datapoints, columns=columns - ) + datapoints = convert_datapoints_to_list_of_dicts(datapoints=datapoints) if not CustomPlot.could_log(datapoints): raise InvalidDataTypeError(name, type(datapoints)) diff --git a/src/dvclive/utils.py b/src/dvclive/utils.py index 56d4ad5f..442169b3 100644 --- a/src/dvclive/utils.py +++ b/src/dvclive/utils.py @@ -1,15 +1,18 @@ +from __future__ import annotations import csv import json -import numpy as np import os -import pandas as pd import re import shutil from pathlib import Path from platform import uname -from typing import Union, List, Dict, Optional +from typing import Union, List, Dict, Optional, TYPE_CHECKING import webbrowser +if TYPE_CHECKING: + import numpy as np + import pandas as pd + StrPath = Union[str, Path] @@ -199,25 +202,32 @@ def read_latest(live, metric_name): def convert_datapoints_to_list_of_dicts( - datapoints: pd.DataFrame | np.ndarray | List[Dict], - columns: Optional[List[str]] = None, -) -> List[Dict]: + datapoints: List[Dict] | pd.DataFrame | np.ndarray, + ) -> List[Dict]: """ Convert the given datapoints to a list of dictionaries. Args: datapoints: The input datapoints to be converted. - columns: The column columns for the datapoints. Applied only for np.ndarray inputs. Returns: A list of dictionaries representing the datapoints. Raises: - None + ValueError: If the `datapoints` argument is not of type pd.DataFrame, np.ndarray, or List[Dict]. """ + if isinstance(datapoints, list): + return datapoints + + import pandas as pd if isinstance(datapoints, pd.DataFrame): - return datapoints.to_dict(orient="records") - elif isinstance(datapoints, np.ndarray): - return pd.DataFrame(datapoints, columns=columns).to_dict(orient="records") + return datapoints.to_dict(orient='records') + + import numpy as np + if isinstance(datapoints, np.ndarray): + return pd.DataFrame(datapoints).to_dict(orient='records') else: - return datapoints + raise ValueError(""" + Unexpected format for `datapoints`. \ + Supported formats: pd.DataFrame, np.ndarray, or List[Dict]. + """) diff --git a/tests/test_utils.py b/tests/test_utils.py index ae563294..ded8785f 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -25,17 +25,23 @@ def test_dataframe(self): expected_output = [{"A": 1, "B": 3}, {"A": 2, "B": 4}] assert convert_datapoints_to_list_of_dicts(df) == expected_output - def test_ndarray_with_columns(self): - arr = np.array([[1, 3], [2, 4]]) - columns = ["A", "B"] - expected_output = [{"A": 1, "B": 3}, {"A": 2, "B": 4}] - assert convert_datapoints_to_list_of_dicts(arr, columns) == expected_output - - def test_ndarray_without_columns(self): + def test_ndarray(self): arr = np.array([[1, 3], [2, 4]]) expected_output = [{0: 1, 1: 3}, {0: 2, 1: 4}] assert convert_datapoints_to_list_of_dicts(arr) == expected_output + def test_structured_array(self): + dtype = [('A', 'i4'), ('B', 'i4')] + structured_array = np.array([(1, 3), (2, 4)], dtype=dtype) + expected_output = [{'A': 1, 'B': 3}, {'A': 2, 'B': 4}] + assert convert_datapoints_to_list_of_dicts(structured_array) == expected_output + def test_list_of_dicts(self): list_of_dicts = [{"A": 1, "B": 3}, {"A": 2, "B": 4}] assert convert_datapoints_to_list_of_dicts(list_of_dicts) == list_of_dicts + + def test_unsupported_format(self): + with pytest.raises(ValueError) as exc_info: + convert_datapoints_to_list_of_dicts("unsupported data format") + + assert "Unexpected format for `datapoints`" in str(exc_info.value) From 91adb248e82dbe591cfea4e609ee29b1a21f7bb0 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 13 Dec 2023 14:43:06 +0000 Subject: [PATCH 04/11] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/dvclive/utils.py | 18 +++++++++++------- tests/test_utils.py | 4 ++-- 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/src/dvclive/utils.py b/src/dvclive/utils.py index 442169b3..57df99b9 100644 --- a/src/dvclive/utils.py +++ b/src/dvclive/utils.py @@ -6,7 +6,7 @@ import shutil from pathlib import Path from platform import uname -from typing import Union, List, Dict, Optional, TYPE_CHECKING +from typing import Union, List, Dict, TYPE_CHECKING import webbrowser if TYPE_CHECKING: @@ -202,8 +202,8 @@ def read_latest(live, metric_name): def convert_datapoints_to_list_of_dicts( - datapoints: List[Dict] | pd.DataFrame | np.ndarray, - ) -> List[Dict]: + datapoints: List[Dict] | pd.DataFrame | np.ndarray, +) -> List[Dict]: """ Convert the given datapoints to a list of dictionaries. @@ -220,14 +220,18 @@ def convert_datapoints_to_list_of_dicts( return datapoints import pandas as pd + if isinstance(datapoints, pd.DataFrame): - return datapoints.to_dict(orient='records') + return datapoints.to_dict(orient="records") import numpy as np + if isinstance(datapoints, np.ndarray): - return pd.DataFrame(datapoints).to_dict(orient='records') + return pd.DataFrame(datapoints).to_dict(orient="records") else: - raise ValueError(""" + raise ValueError( + """ Unexpected format for `datapoints`. \ Supported formats: pd.DataFrame, np.ndarray, or List[Dict]. - """) + """ + ) diff --git a/tests/test_utils.py b/tests/test_utils.py index ded8785f..d8610261 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -31,9 +31,9 @@ def test_ndarray(self): assert convert_datapoints_to_list_of_dicts(arr) == expected_output def test_structured_array(self): - dtype = [('A', 'i4'), ('B', 'i4')] + dtype = [("A", "i4"), ("B", "i4")] structured_array = np.array([(1, 3), (2, 4)], dtype=dtype) - expected_output = [{'A': 1, 'B': 3}, {'A': 2, 'B': 4}] + expected_output = [{"A": 1, "B": 3}, {"A": 2, "B": 4}] assert convert_datapoints_to_list_of_dicts(structured_array) == expected_output def test_list_of_dicts(self): From 83848743ec7e0796176628926aec25935e180dd4 Mon Sep 17 00:00:00 2001 From: Mikhail Rozhkov Date: Wed, 13 Dec 2023 16:25:47 +0100 Subject: [PATCH 05/11] Improve code style and remove redundant exception in src/dvclive/utils.py --- src/dvclive/utils.py | 21 +++++++++++++-------- tests/test_utils.py | 6 ------ 2 files changed, 13 insertions(+), 14 deletions(-) diff --git a/src/dvclive/utils.py b/src/dvclive/utils.py index 57df99b9..0432e0d4 100644 --- a/src/dvclive/utils.py +++ b/src/dvclive/utils.py @@ -7,12 +7,17 @@ from pathlib import Path from platform import uname from typing import Union, List, Dict, TYPE_CHECKING +from typing import Union, List, Dict, TYPE_CHECKING import webbrowser if TYPE_CHECKING: import numpy as np import pandas as pd +from .error import ( + InvalidDataTypeError, +) + StrPath = Union[str, Path] @@ -203,6 +208,8 @@ def read_latest(live, metric_name): def convert_datapoints_to_list_of_dicts( datapoints: List[Dict] | pd.DataFrame | np.ndarray, +) -> List[Dict]: + datapoints: List[Dict] | pd.DataFrame | np.ndarray, ) -> List[Dict]: """ Convert the given datapoints to a list of dictionaries. @@ -214,24 +221,22 @@ def convert_datapoints_to_list_of_dicts( A list of dictionaries representing the datapoints. Raises: - ValueError: If the `datapoints` argument is not of type pd.DataFrame, np.ndarray, or List[Dict]. + TypeError: `datapoints` must be pd.DataFrame, np.ndarray, or List[Dict] """ if isinstance(datapoints, list): return datapoints import pandas as pd + if isinstance(datapoints, pd.DataFrame): return datapoints.to_dict(orient="records") + return datapoints.to_dict(orient="records") import numpy as np + if isinstance(datapoints, np.ndarray): return pd.DataFrame(datapoints).to_dict(orient="records") - else: - raise ValueError( - """ - Unexpected format for `datapoints`. \ - Supported formats: pd.DataFrame, np.ndarray, or List[Dict]. - """ - ) + + return datapoints diff --git a/tests/test_utils.py b/tests/test_utils.py index d8610261..31c321db 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -39,9 +39,3 @@ def test_structured_array(self): def test_list_of_dicts(self): list_of_dicts = [{"A": 1, "B": 3}, {"A": 2, "B": 4}] assert convert_datapoints_to_list_of_dicts(list_of_dicts) == list_of_dicts - - def test_unsupported_format(self): - with pytest.raises(ValueError) as exc_info: - convert_datapoints_to_list_of_dicts("unsupported data format") - - assert "Unexpected format for `datapoints`" in str(exc_info.value) From baa23403cac62dff00ec957fb0046a52f454e9d3 Mon Sep 17 00:00:00 2001 From: Mikhail Rozhkov Date: Wed, 13 Dec 2023 18:25:51 +0100 Subject: [PATCH 06/11] Check instance types without importing the packages --- src/dvclive/utils.py | 21 +++++++++------------ tests/test_utils.py | 7 +++++++ 2 files changed, 16 insertions(+), 12 deletions(-) diff --git a/src/dvclive/utils.py b/src/dvclive/utils.py index 0432e0d4..67335285 100644 --- a/src/dvclive/utils.py +++ b/src/dvclive/utils.py @@ -7,7 +7,6 @@ from pathlib import Path from platform import uname from typing import Union, List, Dict, TYPE_CHECKING -from typing import Union, List, Dict, TYPE_CHECKING import webbrowser if TYPE_CHECKING: @@ -208,8 +207,6 @@ def read_latest(live, metric_name): def convert_datapoints_to_list_of_dicts( datapoints: List[Dict] | pd.DataFrame | np.ndarray, -) -> List[Dict]: - datapoints: List[Dict] | pd.DataFrame | np.ndarray, ) -> List[Dict]: """ Convert the given datapoints to a list of dictionaries. @@ -226,17 +223,17 @@ def convert_datapoints_to_list_of_dicts( if isinstance(datapoints, list): return datapoints - import pandas as pd - - - if isinstance(datapoints, pd.DataFrame): + if isinstance_without_import(datapoints, "pandas.core.frame", "DataFrame"): return datapoints.to_dict(orient="records") return datapoints.to_dict(orient="records") - import numpy as np - + if isinstance_without_import(datapoints, "numpy", "ndarray"): + # This is a structured array + if datapoints.dtype.names is not None: + return [dict(zip(datapoints.dtype.names, row)) for row in datapoints] - if isinstance(datapoints, np.ndarray): - return pd.DataFrame(datapoints).to_dict(orient="records") + # This is a regular array + return [dict(enumerate(row)) for row in datapoints] - return datapoints + # Raise an error if the input is not a supported type + raise InvalidDataTypeError("datapoints", type(datapoints)) diff --git a/tests/test_utils.py b/tests/test_utils.py index 31c321db..9e6eb3ce 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -3,6 +3,7 @@ import pytest from dvclive.utils import standardize_metric_name, convert_datapoints_to_list_of_dicts +from dvclive.error import InvalidDataTypeError @pytest.mark.parametrize( @@ -39,3 +40,9 @@ def test_structured_array(self): def test_list_of_dicts(self): list_of_dicts = [{"A": 1, "B": 3}, {"A": 2, "B": 4}] assert convert_datapoints_to_list_of_dicts(list_of_dicts) == list_of_dicts + + def test_unsupported_format(self): + with pytest.raises(InvalidDataTypeError) as exc_info: + convert_datapoints_to_list_of_dicts("unsupported data format") + + assert "not supported type" in str(exc_info.value) From 3ce7afb83e343a554fb331e2fc89fd1cae68d749 Mon Sep 17 00:00:00 2001 From: Mikhail Rozhkov Date: Thu, 14 Dec 2023 11:58:16 +0100 Subject: [PATCH 07/11] Minor fix: remove duplicated line --- src/dvclive/utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/dvclive/utils.py b/src/dvclive/utils.py index 67335285..0b412f30 100644 --- a/src/dvclive/utils.py +++ b/src/dvclive/utils.py @@ -225,7 +225,6 @@ def convert_datapoints_to_list_of_dicts( if isinstance_without_import(datapoints, "pandas.core.frame", "DataFrame"): return datapoints.to_dict(orient="records") - return datapoints.to_dict(orient="records") if isinstance_without_import(datapoints, "numpy", "ndarray"): # This is a structured array From 8da844ff8c51164dd5708dcb460041aa43fcb16a Mon Sep 17 00:00:00 2001 From: Mikhail Rozhkov Date: Thu, 14 Dec 2023 14:22:49 +0100 Subject: [PATCH 08/11] Update import to fix mypy errors --- src/dvclive/utils.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/src/dvclive/utils.py b/src/dvclive/utils.py index 0b412f30..919b7021 100644 --- a/src/dvclive/utils.py +++ b/src/dvclive/utils.py @@ -12,6 +12,16 @@ if TYPE_CHECKING: import numpy as np import pandas as pd +else: + try: + import pandas as pd + except ImportError: + pd = None + + try: + import numpy as np + except ImportError: + np = None from .error import ( InvalidDataTypeError, @@ -223,10 +233,10 @@ def convert_datapoints_to_list_of_dicts( if isinstance(datapoints, list): return datapoints - if isinstance_without_import(datapoints, "pandas.core.frame", "DataFrame"): + if pd and isinstance(datapoints, pd.DataFrame): return datapoints.to_dict(orient="records") - if isinstance_without_import(datapoints, "numpy", "ndarray"): + if np and isinstance(datapoints, np.ndarray): # This is a structured array if datapoints.dtype.names is not None: return [dict(zip(datapoints.dtype.names, row)) for row in datapoints] From 529352f5f1011aa4a93cda8e128ed4133a536730 Mon Sep 17 00:00:00 2001 From: Mikhail Rozhkov Date: Thu, 14 Dec 2023 15:37:02 +0100 Subject: [PATCH 09/11] Update imports: move InvalidDataTypeError before if TYPE_CHECKING --- src/dvclive/utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/dvclive/utils.py b/src/dvclive/utils.py index 919b7021..a06c3639 100644 --- a/src/dvclive/utils.py +++ b/src/dvclive/utils.py @@ -9,6 +9,8 @@ from typing import Union, List, Dict, TYPE_CHECKING import webbrowser +from .error import InvalidDataTypeError + if TYPE_CHECKING: import numpy as np import pandas as pd @@ -23,9 +25,7 @@ except ImportError: np = None -from .error import ( - InvalidDataTypeError, -) + StrPath = Union[str, Path] From 31b899d8b7a12192e9eda593b4b653579d5846fd Mon Sep 17 00:00:00 2001 From: Mikhail Rozhkov Date: Thu, 14 Dec 2023 16:35:54 +0100 Subject: [PATCH 10/11] Remove empty line --- src/dvclive/utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/dvclive/utils.py b/src/dvclive/utils.py index a06c3639..dee42911 100644 --- a/src/dvclive/utils.py +++ b/src/dvclive/utils.py @@ -26,7 +26,6 @@ np = None - StrPath = Union[str, Path] From c94c34f084feb2915c04852101a10d74d5bf9184 Mon Sep 17 00:00:00 2001 From: Mikhail Rozhkov Date: Thu, 14 Dec 2023 16:36:44 +0100 Subject: [PATCH 11/11] Parametrize tests for convert_datapoints_to_list_of_dicts() --- tests/test_utils.py | 51 ++++++++++++++++++++++----------------------- 1 file changed, 25 insertions(+), 26 deletions(-) diff --git a/tests/test_utils.py b/tests/test_utils.py index 9e6eb3ce..2b5d56ca 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -20,29 +20,28 @@ def test_standardize_metric_name(framework, logged, standardized): assert standardize_metric_name(logged, framework) == standardized -class TestConvertDatapointsToListOfDicts: - def test_dataframe(self): - df = pd.DataFrame({"A": [1, 2], "B": [3, 4]}) - expected_output = [{"A": 1, "B": 3}, {"A": 2, "B": 4}] - assert convert_datapoints_to_list_of_dicts(df) == expected_output - - def test_ndarray(self): - arr = np.array([[1, 3], [2, 4]]) - expected_output = [{0: 1, 1: 3}, {0: 2, 1: 4}] - assert convert_datapoints_to_list_of_dicts(arr) == expected_output - - def test_structured_array(self): - dtype = [("A", "i4"), ("B", "i4")] - structured_array = np.array([(1, 3), (2, 4)], dtype=dtype) - expected_output = [{"A": 1, "B": 3}, {"A": 2, "B": 4}] - assert convert_datapoints_to_list_of_dicts(structured_array) == expected_output - - def test_list_of_dicts(self): - list_of_dicts = [{"A": 1, "B": 3}, {"A": 2, "B": 4}] - assert convert_datapoints_to_list_of_dicts(list_of_dicts) == list_of_dicts - - def test_unsupported_format(self): - with pytest.raises(InvalidDataTypeError) as exc_info: - convert_datapoints_to_list_of_dicts("unsupported data format") - - assert "not supported type" in str(exc_info.value) +# Tests for convert_datapoints_to_list_of_dicts() +@pytest.mark.parametrize( + ("input_data", "expected_output"), + [ + ( + pd.DataFrame({"A": [1, 2], "B": [3, 4]}), + [{"A": 1, "B": 3}, {"A": 2, "B": 4}], + ), + (np.array([[1, 3], [2, 4]]), [{0: 1, 1: 3}, {0: 2, 1: 4}]), + ( + np.array([(1, 3), (2, 4)], dtype=[("A", "i4"), ("B", "i4")]), + [{"A": 1, "B": 3}, {"A": 2, "B": 4}], + ), + ([{"A": 1, "B": 3}, {"A": 2, "B": 4}], [{"A": 1, "B": 3}, {"A": 2, "B": 4}]), + ], +) +def test_convert_datapoints_to_list_of_dicts(input_data, expected_output): + assert convert_datapoints_to_list_of_dicts(input_data) == expected_output + + +def test_unsupported_format(): + with pytest.raises(InvalidDataTypeError) as exc_info: + convert_datapoints_to_list_of_dicts("unsupported data format") + + assert "not supported type" in str(exc_info.value)