Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for DataFrames and Numpy Arrays for log_plot() #754

Merged
merged 12 commits into from
Dec 14, 2023
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ dependencies = [
[project.optional-dependencies]
image = ["numpy", "pillow"]
sklearn = ["scikit-learn"]
plots = ["scikit-learn"]
plots = ["scikit-learn", "pandas", "numpy"]
markdown = ["matplotlib"]
tests = [
"pytest>=7.2.0,<8.0",
Expand Down
13 changes: 11 additions & 2 deletions src/dvclive/live.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from __future__ import annotations
import glob
import json
import logging
Expand All @@ -6,7 +7,11 @@
import shutil
import tempfile
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Union
from typing import Any, Dict, List, Optional, Set, Union, TYPE_CHECKING

if TYPE_CHECKING:
import numpy as np
import pandas as pd

from dvc.exceptions import DvcException
from funcy import set_in
Expand Down Expand Up @@ -35,6 +40,7 @@
StrPath,
catch_and_warn,
clean_and_copy_into,
convert_datapoints_to_list_of_dicts,
env2bool,
inside_notebook,
matplotlib_installed,
Expand Down Expand Up @@ -391,14 +397,17 @@ def log_image(self, name: str, val):
def log_plot(
self,
name: str,
datapoints: List[Dict],
datapoints: pd.DataFrame | np.ndarray | List[Dict],
x: str,
y: str,
template: Optional[str] = None,
title: Optional[str] = None,
x_label: Optional[str] = None,
y_label: Optional[str] = None,
):
# Convert the given datapoints to List[Dict]
datapoints = convert_datapoints_to_list_of_dicts(datapoints=datapoints)

if not CustomPlot.could_log(datapoints):
raise InvalidDataTypeError(name, type(datapoints))

Expand Down
56 changes: 54 additions & 2 deletions src/dvclive/utils.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,31 @@
from __future__ import annotations
import csv
import json
import os
import re
import shutil
import webbrowser
from pathlib import Path
from platform import uname
from typing import Union
from typing import Union, List, Dict, TYPE_CHECKING
import webbrowser

if TYPE_CHECKING:
import numpy as np
import pandas as pd
else:
try:
import pandas as pd
except ImportError:
pd = None

try:
import numpy as np
except ImportError:
np = None

from .error import (
InvalidDataTypeError,
)
mnrozhkov marked this conversation as resolved.
Show resolved Hide resolved

StrPath = Union[str, Path]

Expand Down Expand Up @@ -194,3 +213,36 @@ def read_history(live, metric):
def read_latest(live, metric_name):
    """Return ``(step, value)`` for ``metric_name`` from the parsed metrics.

    Only the second element of the pair returned by ``parse_metrics(live)``
    is used; it is indexed by ``"step"`` and by ``metric_name``.
    """
    _history, latest_metrics = parse_metrics(live)
    step = latest_metrics["step"]
    value = latest_metrics[metric_name]
    return step, value


def convert_datapoints_to_list_of_dicts(
    datapoints: List[Dict] | pd.DataFrame | np.ndarray,
) -> List[Dict]:
    """
    Convert the given datapoints to a list of dictionaries.

    Args:
        datapoints: The input datapoints to be converted.
            - A list is returned unchanged (its items are not inspected).
            - A DataFrame yields one dict per row, keyed by column name.
            - A structured array yields one dict per record, keyed by
              field name.
            - A regular 2-D array yields one dict per row, keyed by the
              integer column index.

    Returns:
        A list of dictionaries representing the datapoints. Values taken
        from NumPy arrays are native Python scalars, so the result can be
        serialized with the standard `json` module.

    Raises:
        InvalidDataTypeError: `datapoints` is not a pd.DataFrame,
            np.ndarray, or list.
    """
    if isinstance(datapoints, list):
        return datapoints

    # `pd` / `np` are None when the optional dependency is not installed,
    # in which case the corresponding branch is skipped entirely.
    if pd is not None and isinstance(datapoints, pd.DataFrame):
        return datapoints.to_dict(orient="records")

    if np is not None and isinstance(datapoints, np.ndarray):
        # `tolist()` unboxes NumPy scalars (np.int64, np.float32, ...) into
        # plain Python objects; iterating the array directly would leak
        # NumPy scalar types that `json` cannot encode.
        rows = datapoints.tolist()

        # Structured arrays expose their field names; regular arrays do not.
        field_names = datapoints.dtype.names
        if field_names is not None:
            return [dict(zip(field_names, row)) for row in rows]

        return [dict(enumerate(row)) for row in rows]

    # Raise an error if the input is not a supported type
    raise InvalidDataTypeError("datapoints", type(datapoints))
33 changes: 32 additions & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import numpy as np
import pandas as pd
import pytest

from dvclive.utils import standardize_metric_name
from dvclive.utils import standardize_metric_name, convert_datapoints_to_list_of_dicts
from dvclive.error import InvalidDataTypeError


@pytest.mark.parametrize(
Expand All @@ -15,3 +18,31 @@
)
def test_standardize_metric_name(framework, logged, standardized):
assert standardize_metric_name(logged, framework) == standardized


class TestConvertDatapointsToListOfDicts:
    """Tests for `convert_datapoints_to_list_of_dicts`, one per input kind."""

    def test_dataframe(self):
        """A DataFrame becomes one dict per row, keyed by column name."""
        df = pd.DataFrame({"A": [1, 2], "B": [3, 4]})
        expected_output = [{"A": 1, "B": 3}, {"A": 2, "B": 4}]
        assert convert_datapoints_to_list_of_dicts(df) == expected_output

    def test_ndarray(self):
        """A regular 2-D array becomes one dict per row, keyed by index."""
        arr = np.array([[1, 3], [2, 4]])
        expected_output = [{0: 1, 1: 3}, {0: 2, 1: 4}]
        assert convert_datapoints_to_list_of_dicts(arr) == expected_output

    def test_structured_array(self):
        """A structured array becomes one dict per record, keyed by field."""
        dtype = [("A", "i4"), ("B", "i4")]
        structured_array = np.array([(1, 3), (2, 4)], dtype=dtype)
        expected_output = [{"A": 1, "B": 3}, {"A": 2, "B": 4}]
        assert convert_datapoints_to_list_of_dicts(structured_array) == expected_output

    def test_list_of_dicts(self):
        """A list of dicts is passed through as-is."""
        list_of_dicts = [{"A": 1, "B": 3}, {"A": 2, "B": 4}]
        assert convert_datapoints_to_list_of_dicts(list_of_dicts) == list_of_dicts

    def test_unsupported_format(self):
        """Any other input type raises InvalidDataTypeError."""
        with pytest.raises(InvalidDataTypeError) as exc_info:
            convert_datapoints_to_list_of_dicts("unsupported data format")

        assert "not supported type" in str(exc_info.value)
Loading