-
Notifications
You must be signed in to change notification settings - Fork 165
/
Copy pathtest_data_profiler.py
127 lines (97 loc) · 3.94 KB
/
test_data_profiler.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import os
import unittest
from unittest import mock
from dataprofiler import Data, Profiler
from . import test_utils
# Absolute directory that contains this test module; fixture paths are
# built relative to it.
MODULE_PATH = os.path.dirname(os.path.abspath(__file__))

# Parent directory of this module's directory, with symlinks resolved.
project_path = os.path.dirname(
    os.path.dirname(os.path.realpath(__file__))
)

# Replace the ``__enter__`` of unittest's assertWarns context manager with
# the patched version shipped in ``test_utils``.
# Adapted from https://github.com/rlworkgroup/dowel/pull/36/files;
# remove this override once cpython#4800 is merged.
unittest.case._AssertWarnsContext.__enter__ = test_utils.patched_assert_warns
class TestDataProfiler(unittest.TestCase):
    """Smoke tests for the top-level ``dataprofiler`` API.

    Covers seeding via ``set_seed``, loading a file through ``Data``,
    building a ``Profiler``, and the warnings emitted when importing
    ``snappy`` or ``tensorflow`` fails.
    """

    @classmethod
    def setUpClass(cls):
        """Record the fixture files (path and expected data type) once."""
        test_dir = os.path.join(MODULE_PATH, "data")
        cls.input_file_names = [
            dict(
                path=os.path.join(test_dir, "csv/aws_honeypot_marx_geo.csv"),
                type="csv",
            ),
        ]

    def test_set_seed(self):
        """``set_seed`` stores a valid seed and rejects negative/float values."""
        import dataprofiler as dp

        # Restore whatever seed was configured before this test ran so the
        # value written below cannot leak into tests executed afterwards.
        self.addCleanup(setattr, dp.settings, "_seed", dp.settings._seed)

        self.assertIsNone(dp.settings._seed)

        dp.set_seed(5)
        self.assertEqual(dp.settings._seed, 5)

        # Each invalid seed is reported separately if it fails.
        for invalid_seed in (-5, 5.2):
            with self.subTest(seed=invalid_seed):
                with self.assertRaisesRegex(
                    ValueError, "Seed should be a non-negative integer."
                ):
                    dp.set_seed(invalid_seed)

    def test_data_import(self):
        """``Data`` reports the expected ``data_type`` for each fixture."""
        for file in self.input_file_names:
            with self.subTest(path=file["path"]):
                data = Data(file["path"])
                self.assertEqual(data.data_type, file["type"])

    def test_data_profiling(self):
        """A ``Profiler`` built from each fixture exposes a profile and report."""
        for file in self.input_file_names:
            with self.subTest(path=file["path"]):
                data = Data(file["path"])
                profile = Profiler(data)
                self.assertIsNotNone(profile.profile)
                self.assertIsNotNone(profile.report())

    def test_no_snappy(self):
        """Reloading ``dataprofiler`` while ``snappy`` fails to import warns."""
        import importlib
        import sys
        import types

        # Keep a handle on the real import hook before it gets patched.
        orig_import = __import__

        # necessary for any wrapper around the library to test if snappy caught
        # as an issue
        def reload_data_profiler():
            """Reload the loaded ``dataprofiler`` modules in a single pass.

            Only modules whose name contains ``"dataprofiler"`` and has fewer
            than three dot-separated parts are reloaded.
            """
            # Iterate over a snapshot: reloading may add entries to
            # ``sys.modules`` while we are looping.
            sys_modules = sys.modules.copy()
            for module_name, module in sys_modules.items():
                if "dataprofiler" in module_name and len(module_name.split(".")) < 3:
                    if isinstance(module, types.ModuleType):
                        importlib.reload(module)

        def import_mock(name, *args, **kwargs):
            """Fail for ``snappy``; defer to the real import otherwise."""
            if name == "snappy":
                raise ImportError("test")
            return orig_import(name, *args, **kwargs)

        with mock.patch("builtins.__import__", side_effect=import_mock):
            with self.assertWarns(ImportWarning) as w:
                import dataprofiler

                reload_data_profiler()

        # ``w.warning`` is only populated once the assertWarns block exits.
        self.assertEqual(
            str(w.warning),
            "Snappy must be installed to use parquet/avro datasets."
            "\n\n"
            "For macOS use Homebrew:\n"
            "\t`brew install snappy`"
            "\n\n"
            "For linux use apt-get:\n`"
            "\tsudo apt-get -y install libsnappy-dev`\n",
        )

    def test_no_tensorflow(self):
        """Profiling while ``tensorflow`` fails to import raises a warning."""
        import sys

        import pandas

        # Keep a handle on the real import hook before it gets patched.
        orig_import = __import__

        # necessary for any wrapper around the library to test if a missing
        # tensorflow is caught as an issue
        def import_mock(name, *args, **kwargs):
            """Fail for ``tensorflow``; defer to the real import otherwise."""
            if name == "tensorflow":
                raise ImportError("test")
            return orig_import(name, *args, **kwargs)

        with mock.patch("builtins.__import__", side_effect=import_mock):
            with self.assertWarnsRegex(RuntimeWarning, "Partial Profiler Failure"):
                modules_with_tf = [
                    "dataprofiler.labelers.character_level_cnn_model",
                ]
                # Evict the already-imported modules so they are imported
                # again (and hit the mocked failure) during profiling.
                for module in modules_with_tf:
                    if module in sys.modules:
                        evicted = sys.modules.pop(module)
                        # Re-register the original module after the test so
                        # the eviction does not leak into later tests.
                        self.addCleanup(sys.modules.__setitem__, module, evicted)
                df = pandas.DataFrame([[1, 2.0], [1, 2.2], [-1, 3]])
                Profiler(df)
if __name__ == "__main__":
    # Run this module's tests when the file is executed directly.
    unittest.main()