Skip to content

Commit

Permalink
mob next [ci-skip] [ci skip] [skip ci]
Browse files Browse the repository at this point in the history
lastFile:tests/unit/test_project.py
  • Loading branch information
augustebaum committed Sep 11, 2024
1 parent 33f3ef9 commit 6b4f6f0
Showing 1 changed file with 59 additions and 70 deletions.
129 changes: 59 additions & 70 deletions tests/unit/test_project.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,87 +205,76 @@ def savefig(*args, **kwargs):
# Add a scikit-learn model
model = RandomForestClassifier()
model.fit(numpy.array([[1, 2], [3, 4]]), [0, 1])
model_dump = skops.io.dumps(model)
serialized_model = base64.b64encode(model_dump).decode("ascii")
project.put("rf_model", model) # ScikitLearnModelItem

breakpoint()

assert project.storage.content == {
"string_item": {
"item_type": str(ItemType.JSON),
"serialized": '"Hello, World!"',
"media_type": None,
},
"int_item": {
"item_type": str(ItemType.JSON),
"serialized": "42",
"media_type": None,
},
"float_item": {
"item_type": str(ItemType.JSON),
"serialized": "3.14",
"media_type": None,
},
"bool_item": {
"item_type": str(ItemType.JSON),
"serialized": "true",
"media_type": None,
},
"list_item": {
"item_type": str(ItemType.JSON),
"serialized": "[1, 2, 3]",
"media_type": None,
},
"dict_item": {
"item_type": str(ItemType.JSON),
"serialized": '{"key": "value"}',
"media_type": None,
},
"pandas_df": {
"item_type": str(ItemType.PANDAS_DATAFRAME),
"serialized": '{"columns":["A","B"],"index":[0,1,2],"data":[[1,4],[2,5],[3,6]]}',
"media_type": None,
},
"numpy_array": {
"item_type": str(ItemType.NUMPY_ARRAY),
"serialized": "[1, 2, 3, 4, 5]",
"media_type": None,
},
"mpl_figure": {
"item_type": str(ItemType.MEDIA),
"serialized": "",
"media_type": "image/svg+xml",
},
"rf_model": {
"item_type": str(ItemType.SKLEARN_BASE_ESTIMATOR),
"serialized": f"{'skops': '{serialized_model}', 'html': ''}",
"media_type": "text/html",
},
"vega_chart": {
"item_type": str(ItemType.MEDIA),
"serialized": json.dumps(altair_chart.to_dict()),
"media_type": "application/vnd.vega.v5+json",
},
"pil_image": {
"item_type": str(ItemType.MEDIA),
"serialized": pil_image_str,
"media_type": "image/jpeg",
},
assert project.storage.content["string_item"]== {
"item_type": str(ItemType.JSON),
"serialized": '"Hello, World!"',
"media_type": None,
}
assert project.storage.content["int_item"]== {
"item_type": str(ItemType.JSON),
"serialized": "42",
"media_type": None,
}
assert project.storage.content["float_item"]== {
"item_type": str(ItemType.JSON),
"serialized": "3.14",
"media_type": None,
}
assert project.storage.content["bool_item"]== {
"item_type": str(ItemType.JSON),
"serialized": "true",
"media_type": None,
}
assert project.storage.content["list_item"]== {
"item_type": str(ItemType.JSON),
"serialized": "[1, 2, 3]",
"media_type": None,
}
assert project.storage.content["dict_item"]== {
"item_type": str(ItemType.JSON),
"serialized": '{"key": "value"}',
"media_type": None,
}
assert project.storage.content["pandas_df"]== {
"item_type": str(ItemType.PANDAS_DATAFRAME),
"serialized": '{"columns":["A","B"],"index":[0,1,2],"data":[[1,4],[2,5],[3,6]]}',
"media_type": None,
}
assert project.storage.content["numpy_array"]== {
"item_type": str(ItemType.NUMPY_ARRAY),
"serialized": "[1, 2, 3, 4, 5]",
"media_type": None,
}
assert project.storage.content["mpl_figure"]== {
"item_type": str(ItemType.MEDIA),
"serialized": "",
"media_type": "image/svg+xml",
}
assert project.storage.content["vega_chart"]== {
"item_type": str(ItemType.MEDIA),
"serialized": json.dumps(altair_chart.to_dict()),
"media_type": "application/vnd.vega.v5+json",
}
assert project.storage.content["pil_image"]== {
"item_type": str(ItemType.MEDIA),
"serialized": pil_image_str,
"media_type": "image/jpeg",
}

assert project.get("string_item") == "Hello, World!"
assert project.get("int_item") == 42
assert project.get("float_item") == 3.14
assert project.get("bool_item")
assert project.get("bool_item") is True
assert project.get("list_item") == [1, 2, 3]
assert project.get("dict_item") == {"key": "value"}
pandas.testing.assert_frame_equal(project.get("pandas_df"), df)
numpy.testing.assert_array_equal(project.get("numpy_array"), arr)
assert project.get("mpl_figure") == None
assert project.get("rf_model") == None
assert project.get("vega_chart") == None
assert project.get("pil_image") == None
assert isinstance(project.get("rf_model"), RandomForestClassifier)
assert project.get("mpl_figure") is None
assert project.get("vega_chart") is None
assert project.get("pil_image") is None


def test_api_get_items():
Expand Down

0 comments on commit 6b4f6f0

Please sign in to comment.