Skip to content

Commit

Permalink
replace cudf assert_eq (#1693)
Browse files Browse the repository at this point in the history
replace `cudf assert_eq` by `assert_frame_equal` which is more stable

Authors:
  - Joseph Nke (https://github.com/jnke2016)

Approvers:
  - Brad Rees (https://github.com/BradReesWork)

URL: #1693
  • Loading branch information
jnke2016 authored Jul 13, 2021
1 parent 04f73b8 commit 2b7d02f
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 30 deletions.
4 changes: 2 additions & 2 deletions python/cugraph/tests/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

import scipy
import cudf
from cudf.testing._utils import assert_eq
from cudf.testing.testing import assert_frame_equal
import cugraph
from cugraph.tests import utils

Expand Down Expand Up @@ -327,7 +327,7 @@ def test_edges_for_Graph(graph_file):
else:
edges.append([edge[0], edge[1]])
nx_edge_list = cudf.DataFrame(list(edges), columns=['src', 'dst'])
assert_eq(
assert_frame_equal(
nx_edge_list.sort_values(by=['src', 'dst']).reset_index(drop=True),
cu_edge_list.sort_values(by=['src', 'dst']).reset_index(drop=True),
check_dtype=False
Expand Down
51 changes: 23 additions & 28 deletions python/cugraph/tests/test_hypergraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,7 @@
import pandas as pd
import pytest
import cudf
from cudf.testing._utils import assert_eq

from cudf.testing.testing import assert_frame_equal
import cugraph


Expand Down Expand Up @@ -109,11 +108,10 @@ def test_hyperedges(categorical_metadata):
h = cugraph.hypergraph(simple_df,
categorical_metadata=categorical_metadata)

assert_eq(
len(h.keys()), len(["entities", "nodes", "edges", "events", "graph"])
)
assert len(h.keys()) == len(
["entities", "nodes", "edges", "events", "graph"])

edges = pd.DataFrame({
edges = cudf.from_pandas(pd.DataFrame({
"event_id": [
"event_id::0",
"event_id::1",
Expand Down Expand Up @@ -160,25 +158,24 @@ def test_hyperedges(categorical_metadata):
"a1": [1, 2, 3] * 4,
"a2": ["red", "blue", "green"] * 4,
"🙈": ["æski ēˈmōjē", "😋", "s"] * 4,
})
}))

if categorical_metadata:
edges = edges.astype({"edge_type": "category"})

assert_eq(edges, h["edges"])

assert_frame_equal(edges, h["edges"], check_dtype=False)
for (k, v) in [
("entities", 12), ("nodes", 15), ("edges", 12), ("events", 3)
]:
assert_eq(len(h[k]), v)
assert len(h[k]) == v


def test_hyperedges_direct():

h = cugraph.hypergraph(hyper_df, direct=True)

assert_eq(len(h["edges"]), 9)
assert_eq(len(h["nodes"]), 9)
assert len(h["edges"]) == 9
assert len(h["nodes"]) == 9


def test_hyperedges_direct_categories():
Expand All @@ -193,8 +190,8 @@ def test_hyperedges_direct_categories():
},
)

assert_eq(len(h["edges"]), 9)
assert_eq(len(h["nodes"]), 6)
assert len(h["edges"]) == 9
assert len(h["nodes"]) == 6


def test_hyperedges_direct_manual_shaping():
Expand All @@ -204,14 +201,14 @@ def test_hyperedges_direct_manual_shaping():
direct=True,
EDGES={"aa": ["cc"], "cc": ["cc"]},
)
assert_eq(len(h1["edges"]), 6)
assert len(h1["edges"]) == 6

h2 = cugraph.hypergraph(
hyper_df,
direct=True,
EDGES={"aa": ["cc", "bb", "aa"], "cc": ["cc"]},
)
assert_eq(len(h2["edges"]), 12)
assert len(h2["edges"]) == 12


@pytest.mark.parametrize("categorical_metadata", [False, True])
Expand All @@ -222,9 +219,8 @@ def test_drop_edge_attrs(categorical_metadata):
drop_edge_attrs=True,
categorical_metadata=categorical_metadata)

assert_eq(
len(h.keys()), len(["entities", "nodes", "edges", "events", "graph"])
)
assert len(h.keys()) == len(
["entities", "nodes", "edges", "events", "graph"])

edges = cudf.DataFrame.from_pandas(pd.DataFrame({
"event_id": [
Expand Down Expand Up @@ -257,12 +253,12 @@ def test_drop_edge_attrs(categorical_metadata):
if categorical_metadata:
edges = edges.astype({"edge_type": "category"})

assert_eq(edges, h["edges"])
assert_frame_equal(edges, h["edges"], check_dtype=False)

for (k, v) in [
("entities", 9), ("nodes", 12), ("edges", 9), ("events", 3)
]:
assert_eq(len(h[k]), v)
assert len(h[k]) == v


@pytest.mark.parametrize("categorical_metadata", [False, True])
Expand All @@ -277,9 +273,8 @@ def test_drop_edge_attrs_direct(categorical_metadata):
categorical_metadata=categorical_metadata,
)

assert_eq(
len(h.keys()), len(["entities", "nodes", "edges", "events", "graph"])
)
assert len(h.keys()) == len(
["entities", "nodes", "edges", "events", "graph"])

edges = cudf.DataFrame.from_pandas(pd.DataFrame({
"event_id": [
Expand All @@ -300,10 +295,10 @@ def test_drop_edge_attrs_direct(categorical_metadata):
if categorical_metadata:
edges = edges.astype({"edge_type": "category"})

assert_eq(edges, h["edges"])
assert_frame_equal(edges, h["edges"], check_dtype=False)

for (k, v) in [("entities", 9), ("nodes", 9), ("edges", 6), ("events", 0)]:
assert_eq(len(h[k]), v)
assert len(h[k]) == v


def test_skip_hyper():
Expand Down Expand Up @@ -399,10 +394,10 @@ def test_skip_na_hyperedge():
nans_df, drop_edge_attrs=True
)["edges"]

assert_eq(len(skip_attr_h_edges), len(expected_hits))
assert len(skip_attr_h_edges) == len(expected_hits)

default_h_edges = cugraph.hypergraph(nans_df)["edges"]
assert_eq(len(default_h_edges), len(expected_hits))
assert len(default_h_edges) == len(expected_hits)


def test_hyper_to_pa_vanilla():
Expand Down

0 comments on commit 2b7d02f

Please sign in to comment.