diff --git a/docs/todo_list.md b/docs/todo_list.md index 10f309e..8bfaf95 100644 --- a/docs/todo_list.md +++ b/docs/todo_list.md @@ -2,8 +2,6 @@ The list of expected/future improvements: -0. Extend columns in function '_check_csv_header' - 1. Add ability to select type of output format (CSV or parquet) in commandline - current solution generate both formats without ability to choose preferences diff --git a/tests/test_generator.py b/tests/test_generator.py index 9925d7f..bef8915 100644 --- a/tests/test_generator.py +++ b/tests/test_generator.py @@ -128,10 +128,11 @@ def test_generate_bigbulk_repeat(self): self.assertTrue(os.path.exists(path.join(dir, f"{basic_contact.BasicContact.NAME}.csv"))) - def _check_csv_header(self, filename, key_text): + def _check_csv_header(self, filename, key_texts: list): if os.path.exists(filename): - df = pd.read_csv(filename) - self.assertTrue(df.to_string().find(key_text) >= 0) + content = pd.read_csv(filename).to_string() + for key_text in key_texts: + self.assertTrue(content.find(key_text) >= 0) def test_csv_structure(self): """All csv have header""" @@ -142,12 +143,19 @@ def test_csv_structure(self): dir = path.join(TestGenerator.OUTPUT_ADR, lbl) self.assertTrue(os.path.exists(dir)) - self._check_csv_header(path.join(dir, f"{basic_party.BasicParty.NAME}.csv"), "party_id") - self._check_csv_header(path.join(dir, f"{basic_contact.BasicContact.NAME}.csv"), "party_id") - self._check_csv_header(path.join(dir, f"{basic_relation.BasicRelation.NAME}.csv"), "party_id") - self._check_csv_header(path.join(dir, f"{basic_account.BasicAccount.NAME}.csv"), "party_id") - self._check_csv_header(path.join(dir, f"{basic_transaction.BasicTransaction.NAME}.csv"), "account_id") - self._check_csv_header(path.join(dir, f"{basic_event.BasicEvent.NAME}.csv"), "party_id") - self._check_csv_header(path.join(dir, f"{basic_communication.BasicCommunication.NAME}.csv"), "party_id") + self._check_csv_header(path.join(dir, f"{basic_party.BasicParty.NAME}.csv"), + ["party_id", "party_gender"]) + self._check_csv_header(path.join(dir, f"{basic_contact.BasicContact.NAME}.csv"), + ["party_id", "contact_id", "contact_state"]) + self._check_csv_header(path.join(dir, f"{basic_relation.BasicRelation.NAME}.csv"), + ["party_id", "relation_id", "relation_type"]) + self._check_csv_header(path.join(dir, f"{basic_account.BasicAccount.NAME}.csv"), + ["party_id", "account_id", "account_state"]) + self._check_csv_header(path.join(dir, f"{basic_transaction.BasicTransaction.NAME}.csv"), + ["account_id", "transaction_id", "transaction_direction"]) + self._check_csv_header(path.join(dir, f"{basic_event.BasicEvent.NAME}.csv"), + ["party_id", "event_id", "session_id"]) + self._check_csv_header(path.join(dir, f"{basic_communication.BasicCommunication.NAME}.csv"), + ["party_id", "communication_id", "content", "content_sentiment"]) # TODO: Add batch size under limit, it will generate wrong dataset \ No newline at end of file