diff --git a/Orange/data/io.py b/Orange/data/io.py index 07592d72bc0..0959bb725c2 100644 --- a/Orange/data/io.py +++ b/Orange/data/io.py @@ -24,7 +24,7 @@ import xlsxwriter import openpyxl -from Orange.data import _io, Table, Domain, ContinuousVariable +from Orange.data import _io, Table, Domain, ContinuousVariable, update_origin from Orange.data import Compression, open_compressed, detect_encoding, \ isnastr, guess_data_type, sanitize_variable from Orange.data.io_base import FileFormatBase, Flags, DataTableMixin, PICKLE_PROTOCOL @@ -164,14 +164,7 @@ def read(self): skipinitialspace=True, ) data = self.data_table(reader) - - # TODO: Name can be set unconditionally when/if - # self.filename will always be a string with the file name. - # Currently, some tests pass StringIO instead of - # the file name to a reader. - if isinstance(self.filename, str): - data.name = path.splitext( - path.split(self.filename)[-1])[0] + data.name = path.splitext(path.split(self.filename)[-1])[0] if error and isinstance(error, UnicodeDecodeError): pos, endpos = error.args[2], error.args[3] warning = ('Skipped invalid byte(s) in position ' @@ -179,6 +172,7 @@ def read(self): ('-' + str(endpos)) if (endpos - pos) > 1 else '') warnings.warn(warning) self.set_table_metadata(self.filename, data) + update_origin(data, self.filename) return data except Exception as e: error = e @@ -215,6 +209,7 @@ def read(self): if not isinstance(table, Table): raise TypeError("file does not contain a data table") else: + update_origin(table, self.filename) return table @classmethod @@ -264,6 +259,7 @@ def read(self): try: cells = self.get_cells() table = self.data_table(cells) + update_origin(table, self.filename) table.name = path.splitext(path.split(self.filename)[-1])[0] if self.sheet and len(self.sheets) > 1: table.name = '-'.join((table.name, self.sheet)) diff --git a/Orange/data/io_util.py b/Orange/data/io_util.py index 163c0354abe..c1c8de402c5 100644 --- a/Orange/data/io_util.py +++ b/Orange/data/io_util.py @@ -1,17 +1,27 @@ +import os.path import subprocess from collections import defaultdict +from typing import Tuple, Optional import numpy as np +import pandas as pd from chardet.universaldetector import UniversalDetector from Orange.data import ( is_discrete_values, MISSING_VALUES, Variable, - DiscreteVariable, StringVariable, ContinuousVariable, TimeVariable, + DiscreteVariable, StringVariable, ContinuousVariable, TimeVariable, Table, ) from Orange.misc.collections import natural_sorted -__all__ = ["Compression", "open_compressed", "detect_encoding", "isnastr", - "guess_data_type", "sanitize_variable"] +__all__ = [ + "Compression", + "open_compressed", + "detect_encoding", + "isnastr", + "guess_data_type", + "sanitize_variable", + "update_origin", +] class Compression: @@ -207,3 +217,69 @@ def mapvalues(arr): values = [_var.parse(i) for i in orig_values] return values, var + + +def _extract_new_origin(attr: Variable, table: Table, lookup_dirs: Tuple[str]) -> Optional[str]: + # origin exists + if os.path.exists(attr.attributes["origin"]): + return attr.attributes["origin"] + + # last dir of origin in lookup dirs + dir_ = os.path.basename(os.path.normpath(attr.attributes["origin"])) + for ld in lookup_dirs: + new_dir = os.path.join(ld, dir_) + if os.path.isdir(new_dir): + return new_dir + + # all column paths in lookup dirs + for ld in lookup_dirs: + if all( + os.path.exists(os.path.join(ld, attr.str_val(v))) + for v in table.get_column(attr) + if v and not pd.isna(v) + ): + return ld + + return None + + +def update_origin(table: Table, file_path: str): + """ + When a dataset with file paths in the column is moved to another computer, + the absolute path may not be correct. This function updates the path for all + columns with an "origin" attribute. + + The process consists of two steps. First, we identify directories to search + for files, and in the second step, we check if paths exist. + + Lookup directories: + 1. The directory where the file from file_path is placed + 2. The parent directory of 1. The situation when the user places dataset + file in the directory with files (for example, workflow in a directory + with images) + + Possible situations for file search: + 1. The last directory of origin (basedir) is in one of the lookup directories + 2. Origin doesn't exist in any lookup directories, but paths in a column can + be found in one of the lookup directories. This is usually a situation + when paths in a column are complex (e.g. a/b/c/d/file.txt). + + Note: This function updates the existing table + + Parameters + ---------- + table + Orange Table to be updated if origin exits in any column + file_path + Path of the loaded dataset for reference. Only paths inside datasets + directory or its parent directory will be considered for new origin. + """ + file_dir = os.path.dirname(file_path) + parent_dir = os.path.dirname(file_dir) + # if file_dir already root file_dir == parent_dir + lookup_dirs = tuple({file_dir: 0, parent_dir: 0}) + for attr in table.domain.metas: + if "origin" in attr.attributes and (attr.is_string or attr.is_discrete): + new_orig = _extract_new_origin(attr, table, lookup_dirs) + if new_orig: + attr.attributes["origin"] = new_orig diff --git a/Orange/data/tests/test_io_util.py b/Orange/data/tests/test_io_util.py index 683132da8c5..8d6ec273768 100644 --- a/Orange/data/tests/test_io_util.py +++ b/Orange/data/tests/test_io_util.py @@ -1,6 +1,18 @@ +import os.path import unittest +from tempfile import TemporaryDirectory -from Orange.data import ContinuousVariable, guess_data_type +import numpy as np + +from Orange.data import ( + ContinuousVariable, + guess_data_type, + Table, + Domain, + StringVariable, + DiscreteVariable, +) +from Orange.data.io_util import update_origin class TestIoUtil(unittest.TestCase): @@ -10,5 +22,115 @@ def test_guess_continuous_w_nans(self): ContinuousVariable) +class TestUpdateOrigin(unittest.TestCase): + FILE_NAMES = ["file1.txt", "file2.txt", "file3.txt"] + + def setUp(self) -> None: + self.alt_dir = TemporaryDirectory() # pylint: disable=consider-using-with + + self.var_string = var = StringVariable("Files") + files = self.FILE_NAMES + [var.Unknown] + self.table_string = Table.from_list( + Domain([], metas=[var]), np.array(files).reshape((-1, 1)) + ) + self.var_discrete = var = DiscreteVariable("Files", values=self.FILE_NAMES) + files = self.FILE_NAMES + [var.Unknown] + self.table_discrete = Table.from_list( + Domain([], metas=[var]), np.array(files).reshape((-1, 1)) + ) + + def tearDown(self) -> None: + self.alt_dir.cleanup() + + def __create_files(self): + for f in self.FILE_NAMES: + f = os.path.join(self.alt_dir.name, f) + with open(f, "w", encoding="utf8"): + pass + self.assertTrue(os.path.exists(f)) + + def test_origin_not_changed(self): + """ + Origin exist; keep it unchanged, even though dataset path also includes + files from column. + """ + with TemporaryDirectory() as dir_name: + self.var_string.attributes["origin"] = dir_name + update_origin(self.table_string, self.alt_dir.name) + self.assertEqual( + self.table_string.domain[self.var_string].attributes["origin"], dir_name + ) + + def test_origin_subdir(self): + """ + Origin is wrong but last dir in origin exit in the dataset file's path + """ + images_dir = os.path.join(self.alt_dir.name, "subdir") + os.mkdir(images_dir) + + self.var_string.attributes["origin"] = "/a/b/subdir" + update_origin(self.table_string, os.path.join(self.alt_dir.name, "data.csv")) + self.assertEqual( + self.table_string.domain[self.var_string].attributes["origin"], images_dir + ) + + def test_origin_parents_subdir(self): + """ + Origin is wrong but last dir in origin exit in the dataset file + parent's directory + """ + # make the dir where dataset is placed + images_dir = os.path.join(self.alt_dir.name, "subdir") + os.mkdir(images_dir) + + self.var_string.attributes["origin"] = "/a/b/subdir" + update_origin(self.table_string, os.path.join(images_dir, "data.csv")) + self.assertEqual( + self.table_string.domain[self.var_string].attributes["origin"], images_dir + ) + + def test_column_paths_subdir(self): + """ + Origin dir not exiting but paths from column exist in dataset's dir + """ + self.__create_files() + + self.var_string.attributes["origin"] = "/a/b/non-exiting-dir" + update_origin(self.table_string, os.path.join(self.alt_dir.name, "data.csv")) + self.assertEqual( + self.table_string.domain[self.var_string].attributes["origin"], + self.alt_dir.name, + ) + + self.var_discrete.attributes["origin"] = "/a/b/non-exiting-dir" + update_origin(self.table_discrete, os.path.join(self.alt_dir.name, "data.csv")) + self.assertEqual( + self.table_discrete.domain[self.var_discrete].attributes["origin"], + self.alt_dir.name, + ) + + def test_column_paths_parents_subdir(self): + """ + Origin dir not exiting but paths from column exist in dataset parent's dir + """ + # make the dir where dataset is placed + dataset_dir = os.path.join(self.alt_dir.name, "subdir") + self.__create_files() + + self.var_string.attributes["origin"] = "/a/b/non-exiting-dir" + update_origin(self.table_string, os.path.join(dataset_dir, "data.csv")) + self.assertEqual( + self.table_string.domain[self.var_string].attributes["origin"], + self.alt_dir.name, + ) + + self.var_discrete.attributes["origin"] = "/a/b/non-exiting-dir" + update_origin(self.table_discrete, os.path.join(dataset_dir, "data.csv")) + self.assertEqual( + self.table_discrete.domain[self.var_discrete].attributes["origin"], + self.alt_dir.name, + ) + + if __name__ == '__main__': unittest.main() diff --git a/Orange/data/tests/test_variable.py b/Orange/data/tests/test_variable.py index c8d9b027cd9..4021acafcd4 100644 --- a/Orange/data/tests/test_variable.py +++ b/Orange/data/tests/test_variable.py @@ -1,6 +1,7 @@ # Test methods with long descriptive names can omit docstrings # pylint: disable=missing-docstring # pylint: disable=protected-access +import csv import os import sys import math @@ -10,7 +11,7 @@ import warnings from datetime import datetime, timezone -from io import StringIO +from tempfile import NamedTemporaryFile, TemporaryDirectory import numpy as np import pandas as pd @@ -714,27 +715,35 @@ def test_no_date_no_time(self): self.assertEqual(TimeVariable('relative time').repr_val(1.6), '1.6') def test_readwrite_timevariable(self): - output_csv = StringIO() - input_csv = StringIO("""\ -Date,Feature -time,continuous -, -1920-12-12,1.0 -1920-12-13,3.0 -1920-12-14,5.5 -""") - for stream in (output_csv, input_csv): - stream.close = lambda: None # HACK: Prevent closing of streams - - table = CSVReader(input_csv).read() - self.assertIsInstance(table.domain['Date'], TimeVariable) - self.assertEqual(table[0, 'Date'], '1920-12-12') + content = [ + ("Date", "Feature"), + ("time", "continuous"), + ("", ""), + ("1920-12-12", 1.0), + ("1920-12-13", 3.0), + ("1920-12-14", 5.5), + ] + with NamedTemporaryFile( + mode="w", delete=False, newline="", encoding="utf-8" + ) as input_csv: + csv.writer(input_csv, delimiter=",").writerows(content) + + table = CSVReader(input_csv.name).read() + self.assertIsInstance(table.domain["Date"], TimeVariable) + self.assertEqual(table[0, "Date"], "1920-12-12") # Dates before 1970 are negative - self.assertTrue(all(inst['Date'] < 0 for inst in table)) + self.assertTrue(all(inst["Date"] < 0 for inst in table)) - CSVReader.write_file(output_csv, table) - self.assertEqual(input_csv.getvalue().splitlines(), - output_csv.getvalue().splitlines()) + with NamedTemporaryFile(mode="w", delete=False) as output_csv: + pass + CSVReader.write_file(output_csv.name, table) + + with open(input_csv.name, encoding="utf-8") as in_f: + with open(output_csv.name, encoding="utf-8") as out_f: + self.assertEqual(in_f.read(), out_f.read()) + + os.unlink(input_csv.name) + os.unlink(output_csv.name) def test_repr_value(self): # https://github.com/biolab/orange3/pull/1760 diff --git a/Orange/tests/test_io.py b/Orange/tests/test_io.py index 3bcb71d9155..557f411d87c 100644 --- a/Orange/tests/test_io.py +++ b/Orange/tests/test_io.py @@ -1,7 +1,5 @@ # Test methods with long descriptive names can omit docstrings # pylint: disable=missing-docstring - -import io import os import pickle import shutil @@ -12,10 +10,10 @@ from Orange import data -from Orange.data.io import FileFormat, TabReader, CSVReader, PickleReader +from Orange.data.io import FileFormat, TabReader, CSVReader, PickleReader, ExcelReader from Orange.data.io_base import PICKLE_PROTOCOL from Orange.data.table import get_sample_datasets_dir -from Orange.data import Table +from Orange.data import Table, StringVariable, Domain from Orange.tests import test_dirname from Orange.util import OrangeDeprecationWarning @@ -124,12 +122,13 @@ def test_empty_columns(self): 1, 0, 1, 2, """ - c = io.StringIO(samplefile) + with tempfile.NamedTemporaryFile(mode="w+", delete=False) as tmp: + tmp.write(samplefile) with self.assertWarns(UserWarning) as cm: - table = CSVReader(c).read() + table = CSVReader(tmp.name).read() + os.unlink(tmp.name) self.assertEqual(len(table.domain.attributes), 2) - self.assertEqual(cm.warning.args[0], - "Columns with no headers were removed.") + self.assertEqual(cm.warning.args[0], "Columns with no headers were removed.") def test_type_annotations(self): class FooFormat(FileFormat): @@ -206,6 +205,32 @@ def test_pickle_version(self): # we should not use a version that is not supported self.assertLessEqual(PICKLE_PROTOCOL, pickle.HIGHEST_PROTOCOL) + def test_update_origin(self): + """ + Test if origin attributes is changed if path doesn't exist. For example + when file moved to another computer. It tested only one scenario + all other scenarios are tested as part of update_origin function tests. + """ + with tempfile.TemporaryDirectory() as dir_name: + os.mkdir(os.path.join(dir_name, "subdir")) + + var = StringVariable("Files") + var.attributes["origin"] = "/a/b/c/d/subdir" + table = Table.from_list(Domain([], metas=[var]), ["f1", "f2"]) + + for reader in (CSVReader, TabReader, PickleReader, ExcelReader): + dataset = os.path.join(dir_name, f"dataset{reader.EXTENSIONS[0]}") + if reader is PickleReader: + reader.write_file(dataset, table) + else: + reader.write_file(dataset, table, with_annotations=True) + + table = Table.from_file(dataset) + self.assertEqual( + os.path.join(dir_name, "subdir"), + table.domain["Files"].attributes["origin"], + ) + if __name__ == "__main__": unittest.main() diff --git a/Orange/tests/test_tab_reader.py b/Orange/tests/test_tab_reader.py index 85c1a8aa846..f0e98333b5d 100644 --- a/Orange/tests/test_tab_reader.py +++ b/Orange/tests/test_tab_reader.py @@ -2,12 +2,13 @@ # pylint: disable=missing-docstring import io +import os from os import path, remove import unittest -import tempfile import shutil import time from collections import OrderedDict +from tempfile import NamedTemporaryFile, mkdtemp import numpy as np @@ -34,8 +35,10 @@ def test_read_easy(self): 2.0 \tM \t4 \t """ - file = io.StringIO(simplefile) - table = read_tab_file(file) + with NamedTemporaryFile(mode="w", delete=False) as tmp: + tmp.write(simplefile) + table = read_tab_file(tmp.name) + os.unlink(tmp.name) f1, f2, c1, c2 = table.domain.variables self.assertIsInstance(f1, DiscreteVariable) @@ -60,15 +63,24 @@ def test_read_save_quoted(self): """c\td"""\tk ''' expected = ['"a"', '"b"', '"c\td"'] - f = io.StringIO(quoted) - table = read_tab_file(f) + with NamedTemporaryFile(mode="w", delete=False) as tmp: + tmp.write(quoted) + table = read_tab_file(tmp.name) + os.unlink(tmp.name) self.assertSequenceEqual(table.metas[:, 0].tolist(), expected) - f = io.StringIO() - f.close = lambda: None - TabReader.write_file(f, table) - saved = f.getvalue() - table1 = read_tab_file(io.StringIO(saved)) + with NamedTemporaryFile(mode="w", delete=False) as tmp: + pass + TabReader.write_file(tmp.name, table) + with open(tmp.name, encoding="utf-8") as f: + saved = f.read() + os.unlink(tmp.name) + + with NamedTemporaryFile(mode="w", delete=False) as tmp: + tmp.write(saved) + table1 = read_tab_file(tmp.name) + os.unlink(tmp.name) + self.assertSequenceEqual(table1.metas[:, 0].tolist(), expected) def test_read_and_save_attributes(self): @@ -78,8 +90,10 @@ def test_read_and_save_attributes(self): \ta=1 b=2 \tclass x=a\\ longer\\ string \tclass 1.0 \tM \t5 \trich """ - file = io.StringIO(samplefile) - table = read_tab_file(file) + with NamedTemporaryFile(mode="w", delete=False) as tmp: + tmp.write(samplefile) + table = read_tab_file(tmp.name) + os.unlink(tmp.name) _, f2, c1, _ = table.domain.variables self.assertIsInstance(f2, DiscreteVariable) @@ -89,13 +103,18 @@ def test_read_and_save_attributes(self): self.assertIsInstance(c1, DiscreteVariable) self.assertEqual(c1.name, "Class 1") self.assertEqual(c1.attributes, {'x': 'a longer string'}) - outf = io.StringIO() - outf.close = lambda: None - TabReader.write_file(outf, table) - saved = outf.getvalue() - file = io.StringIO(saved) - table = read_tab_file(file) + with NamedTemporaryFile(mode="w", delete=False) as tmp: + pass + TabReader.write_file(tmp.name, table) + with open(tmp.name, encoding="utf-8") as f: + saved = f.read() + os.unlink(tmp.name) + + with NamedTemporaryFile(mode="w", delete=False) as tmp: + tmp.write(saved) + table = read_tab_file(tmp.name) + os.unlink(tmp.name) _, f2, c1, _ = table.domain.variables self.assertIsInstance(f2, DiscreteVariable) @@ -108,12 +127,11 @@ def test_read_and_save_attributes(self): spath = "/path/to/somewhere" c1.attributes["path"] = spath - outf = io.StringIO() - outf.close = lambda: None - TabReader.write_file(outf, table) - outf.seek(0) - - table = read_tab_file(outf) + with NamedTemporaryFile(mode="w", delete=False) as tmp: + pass + TabReader.write_file(tmp.name, table) + table = read_tab_file(tmp.name) + os.unlink(tmp.name) _, _, c1, _ = table.domain.variables self.assertEqual(c1.attributes["path"], spath) @@ -123,8 +141,10 @@ def test_read_data_oneline_header(self): 0.1\t0.2\t0.3 1.1\t1.2\t1.5 """ - file = io.StringIO(samplefile) - table = read_tab_file(file) + with NamedTemporaryFile(mode="w", delete=False) as tmp: + tmp.write(samplefile) + table = read_tab_file(tmp.name) + os.unlink(tmp.name) self.assertEqual(len(table), 2) self.assertEqual(len(table.domain.variables), 3) @@ -135,8 +155,10 @@ def test_read_data_no_header(self): 0.1\t0.2\t0.3 1.1\t1.2\t1.5 """ - file = io.StringIO(samplefile) - table = read_tab_file(file) + with NamedTemporaryFile(mode="w", delete=False) as tmp: + tmp.write(samplefile) + table = read_tab_file(tmp.name) + os.unlink(tmp.name) self.assertEqual(len(table), 2) self.assertEqual(len(table.domain.variables), 3) @@ -148,10 +170,14 @@ def test_read_data_no_header_feature_reuse(self): 0.1\t0.2\t0.3 1.1\t1.2\t1.5 """ - file = io.StringIO(samplefile) - t1 = read_tab_file(file) - file = io.StringIO(samplefile) - t2 = read_tab_file(file) + with NamedTemporaryFile(mode="w", delete=False) as tmp: + tmp.write(samplefile) + t1 = read_tab_file(tmp.name) + os.unlink(tmp.name) + with NamedTemporaryFile(mode="w", delete=False) as tmp: + tmp.write(samplefile) + t2 = read_tab_file(tmp.name) + os.unlink(tmp.name) self.assertEqual(t1.domain[0], t2.domain[0]) def test_renaming(self): @@ -160,7 +186,7 @@ def test_renaming(self): c\t c\t c\t c\t c\t c\t c\t c\t c \t \t \t \t class\t class\t \t \t meta 0\t 0\t 0\t 0\t 0\t 0\t 0\t 0 """ - file = tempfile.NamedTemporaryFile("wt", delete=False, suffix=".tab") + file = NamedTemporaryFile("wt", delete=False, suffix=".tab") filename = file.name try: file.write(simplefile) @@ -198,7 +224,7 @@ def test_sheets(self): self.assertEqual(reader.sheets, []) def test_attributes_saving(self): - tempdir = tempfile.mkdtemp() + tempdir = mkdtemp() try: self.assertEqual(self.data.attributes, {}) self.data.attributes[1] = "test" @@ -209,7 +235,7 @@ def test_attributes_saving(self): shutil.rmtree(tempdir) def test_attributes_saving_as_txt(self): - tempdir = tempfile.mkdtemp() + tempdir = mkdtemp() try: self.data.attributes = OrderedDict() self.data.attributes["a"] = "aa" @@ -229,7 +255,7 @@ def test_data_name(self): self.assertEqual(table2.name, 'iris') def test_metadata(self): - tempdir = tempfile.mkdtemp() + tempdir = mkdtemp() try: self.data.attributes = OrderedDict() self.data.attributes["a"] = "aa" @@ -241,7 +267,7 @@ def test_metadata(self): shutil.rmtree(tempdir) def test_no_metadata(self): - tempdir = tempfile.mkdtemp() + tempdir = mkdtemp() try: self.data.attributes = OrderedDict() fname = path.join(tempdir, "out.tab") @@ -251,7 +277,7 @@ def test_no_metadata(self): shutil.rmtree(tempdir) def test_had_metadata_now_there_is_none(self): - tempdir = tempfile.mkdtemp() + tempdir = mkdtemp() try: self.data.attributes["a"] = "aa" fname = path.join(tempdir, "out.tab") @@ -275,11 +301,16 @@ def test_number_of_decimals(self): @staticmethod def test_many_discrete(): - b = io.StringIO() - b.write("Poser\nd\n\n") - b.writelines("K" + str(i) + "\n" for i in range(30000)) + with NamedTemporaryFile(mode="w", delete=False) as tmp: + tmp.write("Poser\nd\n\n") + tmp.writelines("K" + str(i) + "\n" for i in range(30000)) start = time.time() - _ = TabReader(b).read() + _ = TabReader(tmp.name).read() elapsed = time.time() - start + os.unlink(tmp.name) if elapsed > 2: raise AssertionError() + + +if __name__ == "__main__": + unittest.main() diff --git a/Orange/tests/test_txt_reader.py b/Orange/tests/test_txt_reader.py index dcd0dfc451d..2111143ee77 100644 --- a/Orange/tests/test_txt_reader.py +++ b/Orange/tests/test_txt_reader.py @@ -4,7 +4,6 @@ import unittest from tempfile import NamedTemporaryFile import os -import io import warnings from Orange.data import Table, ContinuousVariable, DiscreteVariable @@ -80,8 +79,11 @@ def test_read_csv(self): self.read_easy(csv_file_nh, "Feature ") def test_read_csv_with_na(self): - c = io.StringIO(csv_file_missing) - table = CSVReader(c).read() + with NamedTemporaryFile(mode="w", delete=False) as tmp: + tmp.write(csv_file_missing) + + table = CSVReader(tmp.name).read() + os.unlink(tmp.name) f1, f2 = table.domain.variables self.assertIsInstance(f1, ContinuousVariable) self.assertIsInstance(f2, DiscreteVariable) @@ -130,3 +132,7 @@ def test_csv_sniffer(self): data = reader.read() self.assertEqual(len(data), 8) self.assertEqual(len(data.domain.variables) + len(data.domain.metas), 15) + + +if __name__ == "__main__": + unittest.main() diff --git a/Orange/widgets/evaluate/tests/test_owpredictions.py b/Orange/widgets/evaluate/tests/test_owpredictions.py index f48a5eb884b..762f4d2068b 100644 --- a/Orange/widgets/evaluate/tests/test_owpredictions.py +++ b/Orange/widgets/evaluate/tests/test_owpredictions.py @@ -1,8 +1,9 @@ """Tests for OWPredictions""" # pylint: disable=protected-access -import io +import os import unittest from functools import partial +from tempfile import NamedTemporaryFile from typing import Optional from unittest.mock import Mock, patch @@ -206,8 +207,10 @@ def test_bad_data(self): child\tmale\tyes child\tfemale\tyes """ - file1 = io.StringIO(filestr1) - table = TabReader(file1).read() + with NamedTemporaryFile(mode="w", delete=False) as tmp: + tmp.write(filestr1) + table = TabReader(tmp.name).read() + os.unlink(tmp.name) learner = TreeLearner() tree = learner(table) @@ -220,9 +223,11 @@ def test_bad_data(self): child\tmale\tyes child\tfemale\tunknown """ - file2 = io.StringIO(filestr2) - bad_table = TabReader(file2).read() + with NamedTemporaryFile(mode="w", delete=False) as tmp: + tmp.write(filestr2) + bad_table = TabReader(tmp.name).read() + os.unlink(tmp.name) self.send_signal(self.widget.Inputs.predictors, tree, 1) with excepthook_catch():