Skip to content

Commit

Permalink
pandas dataframes supported
Browse files Browse the repository at this point in the history
  • Loading branch information
tschm committed Dec 22, 2023
1 parent 403b861 commit e11d4ed
Show file tree
Hide file tree
Showing 3 changed files with 112 additions and 23 deletions.
29 changes: 6 additions & 23 deletions cvx/bson/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,16 @@
from typing import Any, Dict, Union

import numpy.typing as npt
import pyarrow as pa
import pandas as pd

# see https://github.com/microsoft/pylance-release/issues/2019
from typing_extensions import TypeAlias

import bson
from cvx.bson.io import decode, encode

# FILE = Union[str, bytes, PathLike[str], PathLike[bytes], int]
FILE = Union[str, bytes, PathLike]
MATRIX: TypeAlias = npt.NDArray[Any]
MATRIX: TypeAlias = npt.NDArray[Any] | pd.DataFrame
MATRICES = Dict[str, MATRIX]


Expand Down Expand Up @@ -58,9 +58,8 @@ def write_bson(file: FILE, data: MATRICES) -> int:
file: file
data: dictionary of numpy arrays or pandas DataFrames
"""
bson_str = to_bson(data=data)
with open(file=file, mode="wb") as bson_file:
return bson_file.write(bson_str)
return bson_file.write(to_bson(data=data))


def to_bson(data: MATRICES) -> bytes:
Expand All @@ -70,25 +69,9 @@ def to_bson(data: MATRICES) -> bytes:
Args:
data: dictionary of numpy arrays or pandas DataFrames
"""

def _encode_tensor(tensor: pa.lib.Tensor) -> bytes:
buffer = pa.BufferOutputStream()
pa.ipc.write_tensor(tensor, buffer)
return bytes(buffer.getvalue().to_pybytes())

return bytes(
bson.dumps(
{
name: _encode_tensor(pa.Tensor.from_numpy(obj=matrix))
for name, matrix in data.items()
}
)
)
return bytes(bson.dumps({name: encode(matrix) for name, matrix in data.items()}))


def from_bson(bson_str: bytes) -> MATRICES:
"""Convert BSON-encoded bytes into a dictionary of numpy arrays or pandas DataFrames"""
data = bson.loads(bson_str)

# for name, value in data.items():
return {name: pa.ipc.read_tensor(value).to_numpy() for name, value in data.items()}
return {name: decode(value) for name, value in bson.loads(bson_str).items()}
56 changes: 56 additions & 0 deletions cvx/bson/io.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# Copyright 2023 Stanford University Convex Optimization Group
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from io import BytesIO
from typing import Any

import numpy as np
import pandas as pd
import pyarrow as pa


def encode(data: np.ndarray | pd.DataFrame) -> Any:
"""
Encode a numpy array or a pandas DataFrame
Args:
data: The numpy array or pandas DataFrame
Returns: object converted into bytes
"""
if isinstance(data, np.ndarray):
tensor = pa.Tensor.from_numpy(obj=data)
buffer = pa.BufferOutputStream()
pa.ipc.write_tensor(tensor, buffer)
return bytes(buffer.getvalue().to_pybytes())

if isinstance(data, pd.DataFrame):
return data.to_parquet()

raise TypeError(f"Invalid Datatype {type(data)}")


def decode(data: bytes) -> np.ndarray | pd.DataFrame:
    """
    Restore a numpy array or a pandas DataFrame from its serialized bytes.

    The payload is first read as an Arrow IPC tensor. If pyarrow rejects it
    with ``ArrowInvalid``, the payload is read as Parquet instead.

    Args:
        data: the serialized bytes

    Returns:
        The array or the frame
    """
    try:
        # Conversion stays inside the try so its ArrowInvalid is handled too
        matrix = pa.ipc.read_tensor(data).to_numpy()
    except pa.ArrowInvalid:
        matrix = pd.read_parquet(BytesIO(data))
    return matrix
50 changes: 50 additions & 0 deletions tests/test_pandas.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import numpy as np
import pandas as pd
import pytest

from cvx.bson.file import from_bson, to_bson


@pytest.fixture()
def data():
    """
    Fixture exposing a dictionary of random test data.

    The dictionary holds a DataFrame with a default index, a plain numpy
    array, and a DataFrame indexed by timestamps.
    """
    plain_frame = pd.DataFrame(data=np.random.rand(5, 2))
    array = np.random.rand(5, 2)

    timestamps = [pd.Timestamp("2020-01-01"), pd.Timestamp("2022-01-01")]
    timed_frame = pd.DataFrame(data=np.random.rand(2, 2), index=timestamps)

    return {
        "frame": plain_frame,
        "numpy": array,
        "frame_with_time": timed_frame,
    }


def assert_equal(obj1, obj2):
    """
    Assert that two objects have exactly the same type and equal content.

    DataFrames are compared with ``pd.testing.assert_frame_equal``, numpy
    arrays with ``np.testing.assert_array_equal``, and any other object
    with ``==``.

    Args:
        obj1: the first object, e.g. the value obtained from decoding
        obj2: the second object, e.g. the value that was encoded

    Raises:
        AssertionError: if the types differ or the contents are not equal
    """
    # Exact type identity, so an ndarray can never pass for a DataFrame
    assert type(obj1) is type(obj2), f"{type(obj1)} is not {type(obj2)}"

    if isinstance(obj1, pd.DataFrame):
        pd.testing.assert_frame_equal(obj1, obj2)
    elif isinstance(obj1, np.ndarray):
        np.testing.assert_array_equal(obj1, obj2)
    else:
        # Other types used to pass unchecked once their types matched;
        # compare them by value so a mismatch cannot slip through.
        assert obj1 == obj2, f"{obj1!r} != {obj2!r}"


def test_roundtrip(data):
    """
    Test that the data survives a roundtrip through bson bytes unchanged.

    Args:
        data: Fixture exposing a dictionary of data
    """
    reproduced = from_bson(to_bson(data))

    # Compare the key sets first: looping over `reproduced` alone would
    # pass vacuously if entries got lost on the way.
    assert reproduced.keys() == data.keys()

    for key, value in reproduced.items():
        assert_equal(value, data[key])


def test_file(data, tmp_path):
    """
    Test that the data survives being written to and read from a file.

    Args:
        data: Fixture exposing a dictionary of data
        tmp_path: pytest fixture providing a temporary directory
    """
    path = tmp_path / "xxx.bson"

    with open(file=path, mode="wb") as bson_file:
        bson_file.write(to_bson(data))

    with open(file=path, mode="rb") as bson_file:
        reproduced = from_bson(bson_file.read())

    # Compare the key sets first: looping over `reproduced` alone would
    # pass vacuously if entries got lost on the way.
    assert reproduced.keys() == data.keys()

    for key, value in reproduced.items():
        assert_equal(value, data[key])

0 comments on commit e11d4ed

Please sign in to comment.