Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for DataFrames and Numpy Arrays for log_plot() #754

Merged
merged 12 commits into from
Dec 14, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion src/dvclive/live.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
import json
import logging
import math
import numpy as np
import os
import pandas as pd
mnrozhkov marked this conversation as resolved.
Show resolved Hide resolved
import shutil
import tempfile
from pathlib import Path
Expand Down Expand Up @@ -35,6 +37,7 @@
StrPath,
catch_and_warn,
clean_and_copy_into,
convert_datapoints_to_list_of_dicts,
env2bool,
inside_notebook,
matplotlib_installed,
Expand Down Expand Up @@ -391,14 +394,20 @@ def log_image(self, name: str, val):
def log_plot(
self,
name: str,
datapoints: List[Dict],
datapoints: pd.DataFrame | np.ndarray | List[Dict],
x: str,
y: str,
template: Optional[str] = None,
title: Optional[str] = None,
x_label: Optional[str] = None,
y_label: Optional[str] = None,
columns: Optional[List[str]] = None,
mnrozhkov marked this conversation as resolved.
Show resolved Hide resolved
):
# Convert the given datapoints to List[Dict]
datapoints = convert_datapoints_to_list_of_dicts(
datapoints=datapoints, columns=columns
)
mnrozhkov marked this conversation as resolved.
Show resolved Hide resolved

if not CustomPlot.could_log(datapoints):
raise InvalidDataTypeError(name, type(datapoints))

Expand Down
31 changes: 29 additions & 2 deletions src/dvclive/utils.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import csv
import json
import numpy as np
import os
import pandas as pd
mnrozhkov marked this conversation as resolved.
Show resolved Hide resolved
import re
import shutil
import webbrowser
from pathlib import Path
from platform import uname
from typing import Union
from typing import Union, List, Dict, Optional
import webbrowser

StrPath = Union[str, Path]

Expand Down Expand Up @@ -194,3 +196,28 @@ def read_history(live, metric):
def read_latest(live, metric_name):
_, latest = parse_metrics(live)
return latest["step"], latest[metric_name]


def convert_datapoints_to_list_of_dicts(
    datapoints: Union[pd.DataFrame, np.ndarray, List[Dict]],
    columns: Optional[List[str]] = None,
) -> List[Dict]:
    """
    Convert the given datapoints to a list of dictionaries.

    Args:
        datapoints: The input datapoints to be converted. A ``pd.DataFrame``
            or ``np.ndarray`` is converted to one dictionary per row; any
            other value is returned unchanged.
        columns: The column names for the datapoints. Applied only for
            ``np.ndarray`` inputs; ignored for every other input type.

    Returns:
        A list of dictionaries representing the datapoints.
    """
    # NOTE: the annotation uses ``Union[...]`` instead of ``A | B`` because
    # annotations are evaluated when the function is defined, and the ``|``
    # operator between classes is only available on Python 3.10+.
    if isinstance(datapoints, np.ndarray):
        # Route arrays through a DataFrame so both inputs share one
        # row-to-dict conversion; ``columns=None`` yields integer keys.
        datapoints = pd.DataFrame(datapoints, columns=columns)

    if isinstance(datapoints, pd.DataFrame):
        return datapoints.to_dict(orient="records")

    return datapoints
26 changes: 25 additions & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import numpy as np
import pandas as pd
import pytest

from dvclive.utils import standardize_metric_name
from dvclive.utils import standardize_metric_name, convert_datapoints_to_list_of_dicts


@pytest.mark.parametrize(
Expand All @@ -15,3 +17,25 @@
)
def test_standardize_metric_name(framework, logged, standardized):
    """Check that the logged name standardizes to the expected value."""
    actual = standardize_metric_name(logged, framework)
    assert actual == standardized


class TestConvertDatapointsToListOfDicts:
    """Exercise convert_datapoints_to_list_of_dicts for each input kind."""

    def test_dataframe(self):
        """A DataFrame becomes one dict per row, keyed by column name."""
        frame = pd.DataFrame({"A": [1, 2], "B": [3, 4]})
        result = convert_datapoints_to_list_of_dicts(frame)
        assert result == [{"A": 1, "B": 3}, {"A": 2, "B": 4}]

    def test_ndarray_with_columns(self):
        """An array with explicit column names is keyed by those names."""
        matrix = np.array([[1, 3], [2, 4]])
        names = ["A", "B"]
        result = convert_datapoints_to_list_of_dicts(matrix, names)
        assert result == [{"A": 1, "B": 3}, {"A": 2, "B": 4}]

    def test_ndarray_without_columns(self):
        """An array without column names is keyed by integer position."""
        matrix = np.array([[1, 3], [2, 4]])
        result = convert_datapoints_to_list_of_dicts(matrix)
        assert result == [{0: 1, 1: 3}, {0: 2, 1: 4}]

    def test_list_of_dicts(self):
        """A list of dicts compares equal to itself after conversion."""
        records = [{"A": 1, "B": 3}, {"A": 2, "B": 4}]
        result = convert_datapoints_to_list_of_dicts(records)
        assert result == records
Loading