diff --git a/docs/conf.py b/docs/conf.py index 59f6672e..da88c1ca 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -7,8 +7,8 @@ # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information project = 'DSI' -copyright = '2023, Terry Turton' -author = 'Terry Turton' +copyright = '2023, Triad National Security, LLC' +author = 'Triad National Security, LLC' release = '0.0.0' # -- General configuration --------------------------------------------------- diff --git a/docs/index.rst b/docs/index.rst index 5b3df85d..ed68746f 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -16,7 +16,7 @@ The Data Science Infrastructure Project (DSI) core plugins drivers - + permissions Indices and tables ================== diff --git a/docs/permissions.rst b/docs/permissions.rst new file mode 100644 index 00000000..5a2cbfbf --- /dev/null +++ b/docs/permissions.rst @@ -0,0 +1,55 @@ +Permissions +=================== +DSI is capable of consuming information from files, environments, and in-situ processes which may or may not have the same permissions authority. To track this information for the purposes of returning user queries into DSI storage, we utilize a permissions handler. The permissions handler bundles the authority by which information is read and adds this to each column data structure. Most relational database systems require that types are enforced by column, and DSI extends this idea to require that permissions are enforced by column. By tracking the permissions associated with each column, DSI can save files using the same POSIX permissions authority that initially granted access to the information, therefore preserving POSIX permssions as files are saved. + +By default, DSI will stop users from saving any metadata if the length of the union of the set of column permissions is greater than one. This prevents users from saving files that might have complex security implications. 
If a user enables the ``allow_multiple_permissions`` parameter of the ``PermissionsManager``, then the number of files that will be saved is equal to the length of the union of the set of column permissions in the middleware data structures being written (an example of this behavior follows). There will be one file for each set of columns read by the same permissions authority. + +By default, DSI will always respect the POSIX permissions authority by which information was read. If the user wishes to override this behavior and write all of their metadata to the same file with a unified UID and GID, they can enable the ``squash_permissions`` parameter of the ``PermissionsManager``. The user should be very certain that the information they are writing is protected appropriately in this case. + +An example helps illustrate these scenarios: + ++----------+----------+----------+ +| Col A | Col B | Col C | ++----------+----------+----------+ +| *Perm D* | *Perm D* | *Perm F* | ++----------+----------+----------+ +| Row A1 | Row B1 | Row C1 | ++----------+----------+----------+ +| Row A2 | Row B2 | Row C2 | ++----------+----------+----------+ + +By default, DSI will refuse to write this data structure to disk because ``len(union({D,D,F})) > 1`` + +If a user enables the ``allow_multiple_permissions`` parameter, two files will be saved: + +>>> $ cat file1 +>>> | Col A | Col B | +>>> =================== +>>> | Perm D | Perm D | +>>> | Row A1 | Row B1 | +>>> | Row A2 | Row B2 | +>>> $ get_perms(file1) +>>> Perm D +>>> $ cat file2 +>>> | Col C | +>>> ========== +>>> | Perm F | +>>> | Row C1 | +>>> | Row C2 | +>>> $ get_perms(file2) +>>> Perm F + +If a user enables ``allow_multiple_permissions`` and ``squash_permissions``, then a single file will be written with the user's UID and effective GID and 660 access: + +>>> $ cat file +>>> | Col A | Col B | Col C | +>>> ============================ +>>> | Perm D | Perm D | Perm F | +>>> | Row A1 | Row B1 | Row C1 | +>>> | Row A2 | Row 
B2 | Row C2 | +>>> $ get_perms(file) +>>> My UID and Effective GID, with 660 access controls. + + +.. automodule:: dsi.permissions.permissions + :members: diff --git a/dsi/core.py b/dsi/core.py index bceb025c..6bd72cfc 100644 --- a/dsi/core.py +++ b/dsi/core.py @@ -3,8 +3,10 @@ from collections import OrderedDict from itertools import product +from dsi.permissions.permissions import PermissionsManager -class Terminal(): + +class Terminal: """ An instantiated Terminal is the DSI human/machine interface. @@ -12,41 +14,49 @@ class Terminal(): front-ends or back-ends. Plugins may be producers or consumers. See documentation for more information. """ - DRIVER_PREFIX = ['dsi.drivers'] - DRIVER_IMPLEMENTATIONS = ['gufi', 'sqlite', 'parquet'] - PLUGIN_PREFIX = ['dsi.plugins'] - PLUGIN_IMPLEMENTATIONS = ['env', 'file_consumer'] - VALID_PLUGINS = ['Hostname', 'SystemKernel', 'GitInfo', 'Bueno', 'Csv'] - VALID_DRIVERS = ['Gufi', 'Sqlite', 'Parquet'] + DRIVER_PREFIX = ["dsi.drivers"] + DRIVER_IMPLEMENTATIONS = ["gufi", "sqlite", "parquet"] + PLUGIN_PREFIX = ["dsi.plugins"] + PLUGIN_IMPLEMENTATIONS = ["env", "file_consumer"] + VALID_PLUGINS = ["Hostname", "SystemKernel", "Bueno", "Csv", "GitInfo"] + VALID_DRIVERS = ["Gufi", "Sqlite", "Parquet"] VALID_MODULES = VALID_PLUGINS + VALID_DRIVERS - VALID_MODULE_FUNCTIONS = {'plugin': [ - 'producer', 'consumer'], 'driver': ['front-end', 'back-end']} - VALID_ARTIFACT_INTERACTION_TYPES = ['get', 'set', 'put', 'inspect'] + VALID_MODULE_FUNCTIONS = { + "plugin": ["producer", "consumer"], + "driver": ["front-end", "back-end"], + } + VALID_ARTIFACT_INTERACTION_TYPES = ["get", "set", "put", "inspect"] + + def __init__(self, allow_multiple_permissions=False, squash_permissions=False): + self.allow_multiple_permissions = allow_multiple_permissions + self.squash_permissions = squash_permissions - def __init__(self): # Helper function to get parent module names. 
def static_munge(prefix, implementations): - return (['.'.join(i) for i in product(prefix, implementations)]) + return [".".join(i) for i in product(prefix, implementations)] self.module_collection = {} - driver_modules = static_munge( - self.DRIVER_PREFIX, self.DRIVER_IMPLEMENTATIONS) - self.module_collection['driver'] = {} + driver_modules = static_munge(self.DRIVER_PREFIX, self.DRIVER_IMPLEMENTATIONS) + self.module_collection["driver"] = {} for module in driver_modules: - self.module_collection['driver'][module] = import_module(module) + self.module_collection["driver"][module] = import_module(module) - plugin_modules = static_munge( - self.PLUGIN_PREFIX, self.PLUGIN_IMPLEMENTATIONS) - self.module_collection['plugin'] = {} + plugin_modules = static_munge(self.PLUGIN_PREFIX, self.PLUGIN_IMPLEMENTATIONS) + self.module_collection["plugin"] = {} for module in plugin_modules: - self.module_collection['plugin'][module] = import_module(module) + self.module_collection["plugin"][module] = import_module(module) self.active_modules = {} - valid_module_functions_flattened = self.VALID_MODULE_FUNCTIONS['plugin'] + \ - self.VALID_MODULE_FUNCTIONS['driver'] + valid_module_functions_flattened = ( + self.VALID_MODULE_FUNCTIONS["plugin"] + + self.VALID_MODULE_FUNCTIONS["driver"] + ) for valid_function in valid_module_functions_flattened: self.active_modules[valid_function] = [] self.active_metadata = OrderedDict() + self.perms_manager = PermissionsManager( + allow_multiple_permissions, squash_permissions + ) self.transload_lock = False def list_available_modules(self, mod_type): @@ -64,8 +74,9 @@ def list_available_modules(self, mod_type): for python_module, classlist in self.module_collection[mod_type].items(): # In the next line, both "class" and VALID_MODULES refer to DSI modules. 
class_collector.extend( - [x for x in dir(classlist) if x in self.VALID_MODULES]) - return (class_collector) + [x for x in dir(classlist) if x in self.VALID_MODULES] + ) + return class_collector def load_module(self, mod_type, mod_name, mod_function, **kwargs): """ @@ -76,16 +87,24 @@ def load_module(self, mod_type, mod_name, mod_function, **kwargs): We expect most users will work with module implementations rather than templates, but but all high level class abstractions are accessible with this method. """ - if self.transload_lock and mod_type == 'plugin': - print('Plugin module loading is prohibited after transload. No action taken.') + if self.transload_lock and mod_type == "plugin": + print( + "Plugin module loading is prohibited after transload. No action taken." + ) return if mod_function not in self.VALID_MODULE_FUNCTIONS[mod_type]: print( - 'Hint: Did you declare your Module Function in the Terminal Global vars?') + "Hint: Did you declare your Module Function in the Terminal Global vars?" + ) raise NotImplementedError - if mod_name in [obj.__class__.__name__ for obj in self.active_modules[mod_function]]: - print('{} {} already loaded as {}. Nothing to do.'.format( - mod_name, mod_type, mod_function)) + if mod_name in [ + obj.__class__.__name__ for obj in self.active_modules[mod_function] + ]: + print( + "{} {} already loaded as {}. Nothing to do.".format( + mod_name, mod_type, mod_function + ) + ) return # DSI Modules are Python classes. 
class_name = mod_name @@ -94,47 +113,56 @@ def load_module(self, mod_type, mod_name, mod_function, **kwargs): try: this_module = import_module(python_module) class_ = getattr(this_module, class_name) - self.active_modules[mod_function].append(class_(**kwargs)) + instance = class_(perms_manager=self.perms_manager, **kwargs) + self.active_modules[mod_function].append(instance) load_success = True except AttributeError: continue if load_success: - print('{} {} {} loaded successfully.'.format( - mod_name, mod_type, mod_function)) + print( + "{} {} {} loaded successfully.".format(mod_name, mod_type, mod_function) + ) else: - print('Hint: Did you declare your Plugin/Driver in the Terminal Global vars?') + print( + "Hint: Did you declare your Plugin/Driver in the Terminal Global vars?" + ) raise NotImplementedError def unload_module(self, mod_type, mod_name, mod_function): """ Unloads a DSI module from the active_modules collection """ - if self.transload_lock and mod_type == 'plugin': + if self.transload_lock and mod_type == "plugin": print( - 'Plugin module unloading is prohibited after transload. No action taken.') + "Plugin module unloading is prohibited after transload. No action taken." + ) return for i, mod in enumerate(self.active_modules[mod_function]): if mod.__class__.__name__ == mod_name: self.active_modules[mod_function].pop(i) - print("{} {} {} unloaded successfully.".format( - mod_name, mod_type, mod_function)) + print( + "{} {} {} unloaded successfully.".format( + mod_name, mod_type, mod_function + ) + ) return - print("{} {} {} could not be found in active_modules. No action taken.".format( - mod_name, mod_type, mod_function)) + print( + "{} {} {} could not be found in active_modules. No action taken.".format( + mod_name, mod_type, mod_function + ) + ) def add_external_python_module(self, mod_type, mod_name, mod_path): """ - Adds an external, meaning not from the DSI repo, Python module to the module_collection. 
+ Adds a given external, meaning not from the DSI repo, Python module to the module_collection Afterwards, load_module can be used to load a DSI module from the added Python module. - Note: mod_type is needed because each Python module only implements plugins or drivers. + Note: mod_type is needed because each Python module should only implement plugins or drivers For example, term = Terminal() - term.add_external_python_module('plugin', 'my_python_file', - - '/the/path/to/my_python_file.py') + term.add_external_python_module('plugin', 'my_python_file', '/the/path/to/my_python_file.py') term.load_module('plugin', 'MyPlugin', 'consumer') @@ -149,7 +177,7 @@ def list_loaded_modules(self): These Plugins and Drivers are active or ready to execute a post-processing task. """ - return (self.active_modules) + return self.active_modules def transload(self, **kwargs): """ @@ -160,7 +188,8 @@ def transload(self, **kwargs): data sources to a single DSI Core Middleware data structure. """ selected_function_modules = dict( - (k, self.active_modules[k]) for k in ('producer', 'consumer')) + (k, self.active_modules[k]) for k in ("producer", "consumer") + ) # Note this transload supports plugin.env Environment types now. 
for module_type, objs in selected_function_modules.items(): for obj in objs: @@ -177,52 +206,115 @@ def transload(self, **kwargs): max_len = max([len(col) for col in self.active_metadata.values()]) for colname, coldata in self.active_metadata.items(): if len(coldata) != max_len: - self.active_metadata[colname].extend( # add None's until reaching max_len - [None] * (max_len - len(coldata))) + self.active_metadata[ + colname + ].extend( # add None's until reaching max_len + [None] * (max_len - len(coldata)) + ) - assert all([len(col) == max_len for col in self.active_metadata.values( - )]), "All columns must have the same number of rows" + assert all( + [len(col) == max_len for col in self.active_metadata.values()] + ), "All columns must have the same number of rows" self.transload_lock = True - def artifact_handler(self, interaction_type, **kwargs): + def artifact_handler(self, interaction_type, **kwargs) -> bool: """ Store or retrieve using all loaded DSI Drivers with back-end functionality. A DSI Core Terminal may load zero or more Drivers with back-end storage functionality. Calling artifact_handler will execute all back-end functionality currently loaded, given the provided ``interaction_type``. + + Returns whether the interaction was successful or not. """ if interaction_type not in self.VALID_ARTIFACT_INTERACTION_TYPES: print( - 'Hint: Did you declare your artifact interaction type in the Terminal Global vars?') + "Hint: Did you declare your artifact interaction type in the Terminal Global vars?" + ) raise NotImplementedError + + if interaction_type == "put" or interaction_type == "set": + should_continue = self.handle_permissions_interactions() + if not should_continue: + return False + operation_success = False # Perform artifact movement first, because inspect implementation may rely on # self.active_metadata or some stored artifact. 
selected_function_modules = dict( - (k, self.active_modules[k]) for k in (['back-end'])) + (k, self.active_modules[k]) for k in (["back-end"]) + ) for module_type, objs in selected_function_modules.items(): for obj in objs: - if interaction_type == 'put' or interaction_type == 'set': - obj.put_artifacts( - collection=self.active_metadata, **kwargs) + if interaction_type == "put" or interaction_type == "set": + obj.put_artifacts(collection=self.active_metadata, **kwargs) operation_success = True - elif interaction_type == 'get': - self.active_metadata = obj.get_artifacts(**kwargs) + elif interaction_type == "get": + self.active_metadata.update(obj.get_artifacts(**kwargs)) operation_success = True - if interaction_type == 'inspect': + if interaction_type == "inspect": for module_type, objs in selected_function_modules.items(): for obj in objs: - obj.put_artifacts( - collection=self.active_metadata, **kwargs) + obj.put_artifacts(collection=self.active_metadata, **kwargs) self.active_metadata = obj.inspect_artifacts( - collection=self.active_metadata, **kwargs) + collection=self.active_metadata, **kwargs + ) operation_success = True if operation_success: - return + return operation_success else: print( - 'Hint: Did you implement a case for your artifact interaction in the \ - artifact_handler loop?') + "Hint: Did you implement a case for your artifact interaction in the \ + artifact_handler loop?" + ) raise NotImplementedError + + def handle_permissions_interactions(self) -> bool: + """ + Presents the user with information on how permissions are being handled + and receives consent to the operations through input. + Returns whether or not the user wants the operation to be carried out. + """ + if ( + self.perms_manager.has_multiple_permissions() + and not self.allow_multiple_permissions + ): + print( + "Data has multiple permissions as shown here: \n" + + self.put_report() + + "However, allow_multiple_permissions is not true.\n" + + "No action taken." 
+ ) + return False + elif self.squash_permissions: + msg = ( + "WARNING: One file will be written, throwing out all " + + "permissions attached as shown below:\n" + + self.put_report() + + "THIS SHOULD BE DONE WITH EXTREME CAUTION! Continue? (y/n): " + ) + if input(msg).lower().strip() != "y": + print("No action taken.") + return False + elif self.allow_multiple_permissions: + msg = ( + "WARNING: One file will be written for each POSIX permission " + + "present in read files as shown below:\n" + + self.put_report() + + "Continue? (y/n)" + ) + if input(msg).lower().strip() != "y": + print("No action taken.") + return False + return True + + def put_report(self) -> str: + """ + Generates a report of which columns are registered to which permissions + ex. col1: 444-444-0o660 + """ + report = "" + for col, perm in self.perms_manager.column_perms.items(): + report += f"{col}: {perm}\n" + return report diff --git a/dsi/drivers/filesystem.py b/dsi/drivers/filesystem.py index a2bd161d..64b2caf7 100644 --- a/dsi/drivers/filesystem.py +++ b/dsi/drivers/filesystem.py @@ -1,4 +1,6 @@ from abc import ABCMeta, abstractmethod +from collections.abc import Callable +from dsi.permissions.permissions import PermissionsManager class Driver(metaclass=ABCMeta): @@ -12,7 +14,7 @@ def git_commit_sha(self): pass @abstractmethod - def put_artifacts(self, artifacts, kwargs) -> None: + def put_artifacts(self, artifacts, **kwargs) -> None: pass @abstractmethod @@ -25,7 +27,7 @@ def inspect_artifacts(self): class Filesystem(Driver): - git_commit_sha = '5d79e08d4a6c1570ceb47cdd61d2259505c05de9' + git_commit_sha: str = "5d79e08d4a6c1570ceb47cdd61d2259505c05de9" # Declare named types DOUBLE = "DOUBLE" STRING = "VARCHAR" @@ -42,10 +44,11 @@ class Filesystem(Driver): LT = "<" EQ = "=" - def __init__(self, filename) -> None: - pass + def __init__(self, filename: str, perms_manager: PermissionsManager) -> None: + self.filename = filename + self.perms_manager = perms_manager - def put_artifacts(self, 
artifacts, kwargs) -> None: + def put_artifacts(self, artifacts, **kwargs) -> None: pass def get_artifacts(self, query): @@ -53,3 +56,34 @@ def inspect_artifacts(self): pass + + def write_files( + self, + collection, + write_func: Callable[[dict[str, list], str], None], + f_basename: str, + f_ext: str, + ) -> None: + """ + Write out a given collection to multiple files, one per + unique permission. Files are written with `write_func` + and those files are set to their corresponding permission. + """ + f_map = self.get_output_file_mapping(f_basename, f_ext) + for f, cols in f_map.items(): # Write one file for each unique permission + col_to_data = {col: collection[col] for col in cols} + write_func(col_to_data, f) + self.perms_manager.set_file_permissions(f_map) + + def get_output_file_mapping( + self, base_filename: str, file_ext: str + ) -> dict[str, list[str]]: + """ + Given a base filename and extension, returns a mapping from filename + to a list of corresponding columns. Each filename encodes permissions. 
+ """ + perms_to_cols = self.perms_manager.get_permission_columnlist_mapping() + return { + (base_filename + "-" + str(perm) + file_ext): cols + for perm, cols in perms_to_cols.items() + } diff --git a/dsi/drivers/gufi.py b/dsi/drivers/gufi.py index ff6bdbfb..ecddcdcf 100644 --- a/dsi/drivers/gufi.py +++ b/dsi/drivers/gufi.py @@ -31,7 +31,7 @@ class Gufi(Filesystem): column: column name from the DSI db to join on """ - def __init__(self, prefix, index, dbfile, table, column, verbose=False): + def __init__(self, prefix, index, dbfile, table, column, verbose=False, **kwargs): ''' prefix: prefix to GUFI commands index: directory with GUFI indexes @@ -41,7 +41,7 @@ def __init__(self, prefix, index, dbfile, table, column, verbose=False): verbose: print debugging statements or not ''' - super().__init__(dbfile) + super().__init__(dbfile, **kwargs) # prefix is the prefix to the GUFI installation self.prefix = prefix # index is the directory where the GUFI indexes are stored @@ -65,6 +65,7 @@ def get_artifacts(self, query): resout = self._run_gufi_query(query) if self.isVerbose: print(resout) + # TODO: Needs to register permissions for columns return resout diff --git a/dsi/drivers/parquet.py b/dsi/drivers/parquet.py index 71ab62c9..798a6240 100644 --- a/dsi/drivers/parquet.py +++ b/dsi/drivers/parquet.py @@ -14,8 +14,7 @@ class Parquet(Filesystem): """ def __init__(self, filename, **kwargs): - super().__init__(filename=filename) - self.filename = filename + super().__init__(filename=filename, **kwargs) try: self.compression = kwargs['compression'] except KeyError: @@ -25,12 +24,18 @@ def get_artifacts(self): """Get Parquet data from filename.""" table = pq.read_table(self.filename) resout = table.to_pydict() + self.perms_manager.register_columns_with_file( + list(resout.keys()), self.filename) return resout def put_artifacts(self, collection): """Put artifacts into file at filename path.""" - table = pa.table(collection) - pq.write_table(table, self.filename, 
compression=self.compression) + def write_dict(collection, fname): + table = pa.table(collection) + pq.write_table(table, fname, compression=self.compression) + + self.write_files(collection, write_func=write_dict, + f_basename=self.filename[:-3], f_ext='.pq') @staticmethod def get_cmd_output(cmd: list) -> str: diff --git a/dsi/drivers/sqlite.py b/dsi/drivers/sqlite.py index 1597c6b9..b8ee9f42 100644 --- a/dsi/drivers/sqlite.py +++ b/dsi/drivers/sqlite.py @@ -34,7 +34,8 @@ class Sqlite(Filesystem): con = None cur = None - def __init__(self, filename): + def __init__(self, filename, **kwargs): + super().__init__(filename, **kwargs) self.filename = filename self.con = sqlite3.connect(filename) self.cur = self.con.cursor() @@ -191,7 +192,8 @@ def get_artifact_list(self, isVerbose=False): # Returns reference from query def get_artifacts(self, query): - self.get_artifacts_list() + # TODO: Needs to register permissions by column + return self.get_artifacts_list() # Closes connection to server def close(self): diff --git a/dsi/drivers/tests/test_gufi.py b/dsi/drivers/tests/test_gufi.py index 83033be3..78ab9f91 100644 --- a/dsi/drivers/tests/test_gufi.py +++ b/dsi/drivers/tests/test_gufi.py @@ -1,4 +1,5 @@ from dsi.drivers.gufi import Gufi +from dsi.permissions.permissions import PermissionsManager isVerbose = False @@ -9,7 +10,9 @@ def test_artifact_query(): prefix = "/usr/local/bin" table = "sample" column = "sample_col" - store = Gufi(prefix, index, dbpath, table, column, isVerbose) + mock_pm = PermissionsManager() + store = Gufi(prefix, index, dbpath, table, column, + isVerbose, perms_manager=mock_pm) sqlstr = "select * from dsi_entries" rows = store.get_artifacts(sqlstr) store.close() diff --git a/dsi/drivers/tests/test_sqlite.py b/dsi/drivers/tests/test_sqlite.py index 14ae0f5b..7450ff70 100644 --- a/dsi/drivers/tests/test_sqlite.py +++ b/dsi/drivers/tests/test_sqlite.py @@ -1,6 +1,7 @@ import git from dsi.drivers.sqlite import Sqlite, DataType +from 
dsi.permissions.permissions import PermissionsManager isVerbose = True @@ -13,7 +14,8 @@ def get_git_root(path): def test_wildfire_data_sql_artifact(): dbpath = "wildfire.db" - store = Sqlite(dbpath) + mock_pm = PermissionsManager() + store = Sqlite(dbpath, perms_manager=mock_pm) store.close() # No error implies success assert True @@ -22,7 +24,8 @@ def test_wildfire_data_sql_artifact(): def test_wildfire_data_csv_artifact(): csvpath = '/'.join([get_git_root('.'), 'dsi/data/wildfiredata.csv']) dbpath = "wildfire.db" - store = Sqlite(dbpath) + mock_pm = PermissionsManager() + store = Sqlite(dbpath, perms_manager=mock_pm) store.put_artifacts_csv(csvpath, "simulation", isVerbose=isVerbose) store.close() # No error implies success @@ -32,7 +35,8 @@ def test_wildfire_data_csv_artifact(): def test_yosemite_data_csv_artifact(): csvpath = '/'.join([get_git_root('.'), 'dsi/data/yosemite5.csv']) dbpath = "yosemite.db" - store = Sqlite(dbpath) + mock_pm = PermissionsManager() + store = Sqlite(dbpath, perms_manager=mock_pm) store.put_artifacts_csv(csvpath, "vision", isVerbose=isVerbose) store.close() # No error implies success @@ -41,7 +45,8 @@ def test_yosemite_data_csv_artifact(): def test_artifact_query(): dbpath = "wildfire.db" - store = Sqlite(dbpath) + mock_pm = PermissionsManager() + store = Sqlite(dbpath, perms_manager=mock_pm) _ = store.get_artifact_list(isVerbose=isVerbose) data_type = DataType() data_type.name = "simulation" diff --git a/dsi/permissions/permissions.py b/dsi/permissions/permissions.py new file mode 100644 index 00000000..48907f93 --- /dev/null +++ b/dsi/permissions/permissions.py @@ -0,0 +1,130 @@ +from collections import defaultdict +from dataclasses import dataclass +from stat import S_IRWXU, S_IRWXG, S_IRWXO +from os import stat, getuid, getgid, chown, chmod + + +@dataclass(eq=True, frozen=True) +class Permission: + """A simple dataclass to represent POSIX file permissions""" + + uid: int + gid: int + settings: str + + def __iter__(self): + 
"""enables conversion to tuple, list, etc.""" + for v in (self.uid, self.gid, self.settings): + yield v + + def __str__(self): + return f"{self.uid}-{self.gid}-{self.settings}" + + +class PermissionsManager: + """ + A class to handle and register the mapping from columns + to their permissions. Uses flyweights so that each unique + permission is shared and only stored in memory once. + """ + + def __init__(self, allow_multiple_permissions=False, squash_permissions=False): + self.perms_collection = {} + self.column_perms = {} + self.allow_multiple_permissions = allow_multiple_permissions + self.squash_permissions = squash_permissions + + def get_perm(self, uid, gid, settings) -> Permission: + """If a perm already exists, return it. Else create it.""" + if (uid, gid, settings) in self.perms_collection: + return self.perms_collection[(uid, gid, settings)] + perm = ( + Permission(uid, gid, settings) + if not self.squash_permissions + else Permission(*self.get_process_permissions()) + ) + self.perms_collection[(uid, gid, settings)] = perm + return perm + + def register_columns(self, keys: list[str], perm: Permission) -> None: + """Links a list of column names to a given permission""" + if tuple(perm) not in self.perms_collection: + raise PermissionNotFoundError(perm) + for key in keys: + self.column_perms[key] = perm + + def register_columns_with_file(self, keys: list[str], fp: str) -> None: + """Gets a file's Permission and links it to the given columns""" + uid, gid, settings = ( + self.get_process_permissions() + if fp is None or self.squash_permissions + else self.get_file_permissions(fp) + ) + perm = self.get_perm(uid, gid, settings) + self.register_columns(keys, perm) + + def get_permission_columnlist_mapping(self) -> dict[Permission, list[str]]: + """ + Returns a mapping from unique Permission to list of columns. 
+ """ + mapping = defaultdict(list) + for col, perm in self.column_perms.items(): + mapping[perm].append(col) + return mapping + + def get_column_perms(self, key: str) -> Permission: + """Get the Permission of a given column""" + try: + return self.column_perms[key] + except KeyError: + raise ColumnNotRegisteredError(key) + + def get_file_permissions(self, fpath: str) -> tuple[int, int, str]: + """Given a filepath, returns (uid, gid, settings)""" + st = stat(fpath) # includes info on filetype, perms, etc. + uid = st.st_uid + gid = st.st_gid + perm_mask = S_IRWXU | S_IRWXG | S_IRWXO # user | group | other + settings = oct(st.st_mode & perm_mask) # select perm bits from st_mode + return (uid, gid, settings) + + def set_file_permissions(self, file_mapping: dict[str, list[str]]) -> None: + """ + Given a mapping from filename to list of columns, set each file + to its column's permissions. (All columns of a file share perms) + """ + for filename, cols in file_mapping.items(): + f_perm = self.get_column_perms(cols[0]) # cols share same perms + uid, gid, settings = tuple(f_perm) + chown(filename, uid, gid) # type: ignore + chmod(filename, int(settings, base=8)) # type: ignore + + def get_process_permissions(self) -> tuple[int, int, str]: + """ + In the event of data not coming from a file, + default to (uid, egid, 660) + """ + uid = getuid() + egid = getgid() + return (uid, egid, "0o660") + + def has_multiple_permissions(self) -> bool: + """ + Returns whether or not the collection has multiple permissions. + """ + return len(self.perms_collection.keys()) > 1 + + +class PermissionNotFoundError(Exception): + def __init__(self, perm: Permission) -> None: + self.msg = ( + f"Permission {perm} not found. Make sure to use get_perm instead of " + + "manually instantiating a Permission to register." + ) + super().__init__(self.msg) + + +class ColumnNotRegisteredError(Exception): + def __init__(self, key: str) -> None: + self.msg = f"Permission for column {key} not registered. 
Be sure to `register_columns`." + super().__init__(self.msg) diff --git a/dsi/permissions/tests/test_permissions.py b/dsi/permissions/tests/test_permissions.py new file mode 100644 index 00000000..3f13dcca --- /dev/null +++ b/dsi/permissions/tests/test_permissions.py @@ -0,0 +1,115 @@ +import git +import os +from glob import glob + +from dsi.core import Terminal +from dsi.permissions.permissions import PermissionsManager + + +def get_git_root(path): + git_repo = git.Repo(path, search_parent_directories=True) + git_root = git_repo.git.rev_parse("--show-toplevel") + return git_root + + +def print_and_y(s): + print(s) + return "y" + + +def test_multiple_perms_fails_by_default(monkeypatch): + monkeypatch.setattr("builtins.input", print_and_y) # mock input + term = Terminal() + bueno_path = "/".join([get_git_root("."), "dsi/data", "bueno1.data"]) + os.chmod(bueno_path, 0o664) + term.load_module("plugin", "Bueno", "consumer", filenames=bueno_path) + term.load_module("plugin", "Hostname", "consumer") + term.transload() + pq_path = "/".join([get_git_root("."), "dsi/data", "dummy_data.pq"]) + term.load_module("driver", "Parquet", "back-end", filename=pq_path) + assert not term.artifact_handler(interaction_type="put") + + +def test_multiple_permissions_register_correctly(monkeypatch): + monkeypatch.setattr("builtins.input", print_and_y) # mock input + term = Terminal(allow_multiple_permissions=True) + bueno_path = "/".join([get_git_root("."), "dsi/data", "bueno1.data"]) + os.chmod(bueno_path, 0o664) + term.load_module("plugin", "Bueno", "consumer", filenames=bueno_path) + term.load_module("plugin", "Hostname", "consumer") + + term.transload() + + for env_col in ("uid", "effective_gid", "moniker", "gid_list"): + uid, gid, settings = tuple(term.perms_manager.column_perms[env_col]) + assert isinstance(uid, int) + assert isinstance(gid, int) + assert settings == "0o660" + + for bueno_col in ("foo", "bar", "baz"): + uid, gid, settings = 
tuple(term.perms_manager.column_perms[bueno_col]) + assert isinstance(uid, int) + assert isinstance(gid, int) + assert settings == "0o664" + + pq_path = "/".join([get_git_root("."), "dsi/data", "dummy_data.pq"]) + term.load_module("driver", "Parquet", "back-end", filename=pq_path) + assert term.artifact_handler(interaction_type="put") + + +def test_squash_permissions_register_correctly(monkeypatch): + monkeypatch.setattr("builtins.input", print_and_y) # mock input + term = Terminal(squash_permissions=True) + bueno_path = "/".join([get_git_root("."), "dsi/data", "bueno1.data"]) + os.chmod(bueno_path, 0o664) + term.load_module("plugin", "Bueno", "consumer", filenames=bueno_path) + term.load_module("plugin", "Hostname", "consumer") + + term.transload() + + for env_col in ("uid", "effective_gid", "moniker", "gid_list"): + uid, gid, settings = tuple(term.perms_manager.column_perms[env_col]) + assert isinstance(uid, int) + assert isinstance(gid, int) + assert settings == "0o660" + + for bueno_col in ("foo", "bar", "baz"): + uid, gid, settings = tuple(term.perms_manager.column_perms[bueno_col]) + assert isinstance(uid, int) + assert isinstance(gid, int) + assert settings == "0o660" + + pq_path = "/".join([get_git_root("."), "dsi/data", "dummy_data.pq"]) + term.load_module("driver", "Parquet", "back-end", filename=pq_path) + assert term.artifact_handler(interaction_type="put") + + +def test_permissions_output_correctly(monkeypatch): + monkeypatch.setattr("builtins.input", print_and_y) # mock input + term = Terminal(allow_multiple_permissions=True) + bueno_path = "/".join([get_git_root("."), "dsi/data", "bueno1.data"]) + os.chmod(bueno_path, 0o664) + + term.load_module("plugin", "Bueno", "consumer", filenames=bueno_path) + term.load_module("plugin", "Hostname", "consumer") + term.transload() + + pq_path = "/".join([get_git_root("."), "dsi/data", "dummy_data.pq"]) + term.load_module("driver", "Parquet", "back-end", filename=pq_path) + + assert 
term.artifact_handler(interaction_type="put") + + pm = PermissionsManager() + written_paths = glob("/".join([get_git_root("."), "dsi/data"]) + "/dummy_data*.pq") + for path in written_paths: + uid, gid, settings = pm.get_file_permissions(path) + if settings == "0o664": # the bueno file + old_uid, old_gid, old_settings = pm.get_file_permissions(bueno_path) + assert uid == old_uid + assert gid == old_gid + assert settings == old_settings + assert ( + path.find(str(uid)) != -1 + and path.find(str(gid)) != -1 + and path.find(settings) != -1 + ) diff --git a/dsi/plugins/env.py b/dsi/plugins/env.py index 90d349a4..27af35d7 100644 --- a/dsi/plugins/env.py +++ b/dsi/plugins/env.py @@ -8,9 +8,7 @@ from json import dumps from dsi.plugins.metadata import StructuredMetadata -from dsi.plugins.plugin_models import ( - GitInfoModel, HostnameModel, SystemKernelModel -) +from dsi.plugins.plugin_models import GitInfoModel, HostnameModel, SystemKernelModel class Environment(StructuredMetadata): @@ -21,16 +19,16 @@ class Environment(StructuredMetadata): information. 
""" - def __init__(self): - super().__init__() + def __init__(self, **kwargs): + super().__init__(**kwargs) # Get POSIX info self.posix_info = OrderedDict() - self.posix_info['uid'] = os.getuid() - self.posix_info['effective_gid'] = os.getgid() - egid = self.posix_info['effective_gid'] - self.posix_info['moniker'] = getuser() - moniker = self.posix_info['moniker'] - self.posix_info['gid_list'] = os.getgrouplist(moniker, egid) + self.posix_info["uid"] = os.getuid() + self.posix_info["effective_gid"] = os.getgid() + egid = self.posix_info["effective_gid"] + self.posix_info["moniker"] = getuser() + moniker = self.posix_info["moniker"] + self.posix_info["gid_list"] = os.getgrouplist(moniker, egid) class Hostname(Environment): @@ -42,12 +40,13 @@ class Hostname(Environment): """ def __init__(self, **kwargs) -> None: - super().__init__() + super().__init__(**kwargs) def pack_header(self) -> None: """Set schema with keys of prov_info.""" column_names = list(self.posix_info.keys()) + ["hostname"] - self.set_schema(column_names, validation_model=HostnameModel) + self.set_schema(column_names) + self.perms_manager.register_columns_with_file(column_names, None) def add_rows(self) -> None: """Parses environment provenance data and adds the row.""" @@ -65,36 +64,39 @@ class GitInfo(Environment): Adds the current git remote and git commit to metadata. 
""" - def __init__(self, git_repo_path='./') -> None: - """ Initializes the git repo in the given directory and access to git commands """ - super().__init__() + def __init__(self, git_repo_path="./", **kwargs) -> None: + """Initializes the git repo in the given directory and access to git commands""" + super().__init__(**kwargs) try: self.repo = Repo(git_repo_path, search_parent_directories=True) except git.exc.InvalidGitRepositoryError: - raise Exception(f"Git could not find .git/ in {git_repo_path}, " + - "GitInfo Plugin must be given a repo base path " + - "(default is working dir)") + raise Exception( + f"Git could not find .git/ in {git_repo_path}, " + + "GitInfo Plugin must be given a repo base path " + + "(default is working dir)" + ) self.git_info = { "git_remote": lambda: self.repo.git.remote("-v"), - "git_commit": lambda: self.repo.git.rev_parse("--short", "HEAD") + "git_commit": lambda: self.repo.git.rev_parse("--short", "HEAD"), } def pack_header(self) -> None: - """ Set schema with POSIX and Git columns """ - column_names = list(self.posix_info.keys()) + \ - list(self.git_info.keys()) + """Set schema with POSIX and Git columns""" + column_names = list(self.posix_info.keys()) + list(self.git_info.keys()) self.set_schema(column_names, validation_model=GitInfoModel) + self.perms_manager.register_columns_with_file(column_names, None) def add_rows(self) -> None: - """ Adds a row to the output with POSIX info, git remote, and git commit """ + """Adds a row to the output with POSIX info, git remote, and git commit""" if not self.schema_is_set(): self.pack_header() - row = list(self.posix_info.values()) + \ - [self.git_info["git_remote"](), self.git_info["git_commit"]()] + row = list(self.posix_info.values()) + [ + self.git_info["git_remote"](), + self.git_info["git_commit"](), + ] self.add_to_output(row) - class SystemKernel(Environment): """ Plugin for reading environment provenance data. @@ -108,9 +110,9 @@ class SystemKernel(Environment): 6. 
Container information, if containerized """ - def __init__(self) -> None: + def __init__(self, **kwargs) -> None: """Initialize SystemKernel with inital provenance info.""" - super().__init__() + super().__init__(**kwargs) self.prov_info = self.get_prov_info() self.column_names = ["kernel_info"] @@ -118,6 +120,7 @@ def pack_header(self) -> None: """Set schema with keys of prov_info.""" column_names = list(self.posix_info.keys()) + self.column_names self.set_schema(column_names, validation_model=SystemKernelModel) + self.perms_manager.register_columns_with_file(column_names, None) def add_rows(self) -> None: """Parses environment provenance data and adds the row.""" @@ -139,7 +142,7 @@ def get_prov_info(self) -> str: return blob def get_kernel_version(self) -> dict: - """Kernel version is obtained by the "uname -r" command, returns it in a dict. """ + """Kernel version is obtained by the "uname -r" command, returns it in a dict.""" return {"kernel version": self.get_cmd_output(["uname -r"])} def get_kernel_ct_config(self) -> dict: @@ -183,8 +186,7 @@ def get_kernel_rt_config(self) -> dict: rt_config = {} for line in lines: if "=" in line: # if the line is not permission denied - option, value = line.split( - " = ", maxsplit=1) # note the spaces + option, value = line.split(" = ", maxsplit=1) # note the spaces rt_config[option] = value # If line is permission denied, ignore it. 
return rt_config @@ -199,12 +201,15 @@ def get_kernel_mod_config(self) -> dict: modules = self.get_cmd_output(lsmod_command).split("\n") sep = "END MODINFO" - modinfo_command = ["/sbin/modinfo $(lsmod | tail -n +2 | awk '{print $1}' | \ + modinfo_command = [ + "/sbin/modinfo $(lsmod | tail -n +2 | awk '{print $1}' | \ sed 's/nvidia_/nvidia-current-/g' | \ sed 's/^nvidia$/nvidia-current/g') | " - f"sed -e 's/filename:/{sep}filename:/g'"] - modinfos = self.get_cmd_output( - modinfo_command, ignore_stderr=True).split("\n" + sep) + f"sed -e 's/filename:/{sep}filename:/g'" + ] + modinfos = self.get_cmd_output(modinfo_command, ignore_stderr=True).split( + "\n" + sep + ) mod_configs = {} for mod, info in zip(modules, modinfos): diff --git a/dsi/plugins/file_consumer.py b/dsi/plugins/file_consumer.py index 6645f663..110c1136 100644 --- a/dsi/plugins/file_consumer.py +++ b/dsi/plugins/file_consumer.py @@ -24,7 +24,7 @@ def __init__(self, filenames, **kwargs): raise TypeError self.file_info = {} for filename in self.filenames: - sha = sha1(open(filename, 'rb').read()) + sha = sha1(open(filename, "rb").read()) self.file_info[abspath(filename)] = sha.hexdigest() @@ -42,13 +42,13 @@ def __init__(self, filenames, **kwargs): self.csv_data = {} def pack_header(self) -> None: - """ Set schema based on the CSV columns """ + """Set schema based on the CSV columns""" column_names = list(self.file_info.keys()) + list(self.csv_data.keys()) self.set_schema(column_names) def add_rows(self) -> None: - """ Adds a list containing one or more rows of the CSV along with file_info to output. """ + """Adds a list containing one or more rows of the CSV along with file_info to output.""" if not self.schema_is_set(): # use Pandas to append all CSVs together as a @@ -65,8 +65,14 @@ def add_rows(self) -> None: temp_df = read_csv(filename) # raise exception if schemas do not match if any([set(temp_df.columns) != set(df.columns) for df in dfs]): - print('Error: Strict schema mode is on. 
Schemas do not match.') + print( + "Error: Strict schema mode is on. Schemas do not match." + ) raise TypeError + + self.perms_manager.register_columns_with_file( + temp_df.columns, filename + ) dfs.append(temp_df) total_df = concat([total_df, temp_df]) @@ -75,14 +81,19 @@ def add_rows(self) -> None: total_df = DataFrame() for filename in self.filenames: temp_df = read_csv(filename) + self.perms_manager.register_columns_with_file( + temp_df.columns, filename + ) total_df = concat([total_df, temp_df]) # Columns are present in the middleware already (schema_is_set==True). # TODO: Can this go under the else block at line #79? - self.csv_data = total_df.to_dict('list') + self.csv_data = total_df.to_dict("list") for col, coldata in self.csv_data.items(): # replace NaNs with None - self.csv_data[col] = [None if type(item) == float and isnan(item) else item - for item in coldata] + self.csv_data[col] = [ + None if type(item) == float and isnan(item) else item + for item in coldata + ] self.pack_header() total_length = len(self.csv_data[list(self.csv_data.keys())[0]]) @@ -113,16 +124,18 @@ def add_rows(self) -> None: """Parses Bueno data and adds a list containing 1 or more rows.""" if not self.schema_is_set(): for idx, filename in enumerate(self.filenames): - with open(filename, 'r') as fh: + with open(filename, "r") as fh: file_content = json.load(fh) for key, val in file_content.items(): # Check if column already exists if key not in self.bueno_data: # Initialize empty column if first time seeing it - self.bueno_data[key] = [None] \ - * len(self.filenames) + self.bueno_data[key] = [None] * len(self.filenames) # Set the appropriate row index value for this keyval_pair self.bueno_data[key][idx] = val + self.perms_manager.register_columns_with_file( + list(self.bueno_data.keys()), self.filenames[0] + ) # TODO: Each row comes from a different file, not sure what to do self.pack_header() rows = list(self.bueno_data.values()) diff --git a/dsi/plugins/metadata.py 
b/dsi/plugins/metadata.py index 5dd41451..779251bd 100644 --- a/dsi/plugins/metadata.py +++ b/dsi/plugins/metadata.py @@ -40,6 +40,7 @@ def __init__(self, **kwargs): """ self.output_collector = OrderedDict() self.column_cnt = None # schema not set until pack_header + self.perms_manager = kwargs['perms_manager'] self.validation_model = None # optional pydantic Model # Check for strict_mode option if 'strict_mode' in kwargs: @@ -57,6 +58,7 @@ def set_schema(self, column_names: list, validation_model=None) -> None: """ Initializes columns in the output_collector and column_cnt. Useful in a plugin's pack_header method. + Also registers column permissions if filename is set. """ # Strict mode | SMLock | relation diff --git a/dsi/plugins/tests/test_env.py b/dsi/plugins/tests/test_env.py index 4e131c75..880ebe1f 100644 --- a/dsi/plugins/tests/test_env.py +++ b/dsi/plugins/tests/test_env.py @@ -1,6 +1,7 @@ import collections from dsi.plugins.env import Hostname, SystemKernel, GitInfo +from dsi.permissions.permissions import PermissionsManager import git from json import loads @@ -8,18 +9,20 @@ def get_git_root(path): git_repo = git.Repo(path, search_parent_directories=True) git_root = git_repo.git.rev_parse("--show-toplevel") - return (git_root) + return git_root def test_hostname_plugin_type(): - a = Hostname() + mock_pm = PermissionsManager() + a = Hostname(perms_manager=mock_pm) a.add_rows() a.add_rows() assert type(a.output_collector) == collections.OrderedDict def test_hostname_plugin_col_shape(): - a = Hostname() + mock_pm = PermissionsManager() + a = Hostname(perms_manager=mock_pm) a.add_rows() a.add_rows() assert len(a.output_collector.keys()) == len(a.output_collector.values()) @@ -27,7 +30,8 @@ def test_hostname_plugin_col_shape(): def test_hostname_plugin_row_shape(): for row_cnt in range(1, 10): - a = Hostname() + mock_pm = PermissionsManager() + a = Hostname(perms_manager=mock_pm) for _ in range(row_cnt): a.add_rows() column_values = 
list(a.output_collector.values()) @@ -36,13 +40,15 @@ def test_hostname_plugin_row_shape(): assert len(col) == row_shape == row_cnt -def test_systemkernel_plugin_type(): - plug = SystemKernel() +def test_envprov_plugin_type(): + mock_pm = PermissionsManager() + plug = SystemKernel(perms_manager=mock_pm) assert type(plug.output_collector) == collections.OrderedDict -def test_systemkernel_plugin_adds_rows(): - plug = SystemKernel() +def test_envprov_plugin_adds_rows(): + mock_pm = PermissionsManager() + plug = SystemKernel(perms_manager=mock_pm) plug.add_rows() plug.add_rows() @@ -54,26 +60,32 @@ def test_systemkernel_plugin_adds_rows(): def test_systemkernel_plugin_blob_is_big(): - plug = SystemKernel() + mock_pm = PermissionsManager() + plug = SystemKernel(perms_manager=mock_pm) plug.add_rows() blob = plug.output_collector["kernel_info"][0] info_dict = loads(blob) + # 1 systemkernel col + 4 inherited Env cols + assert len(plug.output_collector.keys()) == 5 + # dict should have more than 1000 (~7000) keys assert len(info_dict.keys()) > 1000 def test_git_plugin_type(): - root = get_git_root('.') - plug = GitInfo(git_repo_path=root) + mock_pm = PermissionsManager() + root = get_git_root(".") + plug = GitInfo(git_repo_path=root, perms_manager=mock_pm) plug.add_rows() assert type(plug.output_collector) == collections.OrderedDict def test_git_plugin_adds_rows(): - root = get_git_root('.') - plug = GitInfo(git_repo_path=root) + mock_pm = PermissionsManager() + root = get_git_root(".") + plug = GitInfo(git_repo_path=root, perms_manager=mock_pm) plug.add_rows() plug.add_rows() @@ -85,8 +97,9 @@ def test_git_plugin_adds_rows(): def test_git_plugin_infos_are_str(): - root = get_git_root('.') - plug = GitInfo(git_repo_path=root) + mock_pm = PermissionsManager() + root = get_git_root(".") + plug = GitInfo(git_repo_path=root, perms_manager=mock_pm) plug.add_rows() assert type(plug.output_collector["git_remote"][0]) == str diff --git a/dsi/plugins/tests/test_file_consumer.py 
b/dsi/plugins/tests/test_file_consumer.py index 5a33e8f1..40e9fc0c 100644 --- a/dsi/plugins/tests/test_file_consumer.py +++ b/dsi/plugins/tests/test_file_consumer.py @@ -3,25 +3,28 @@ import git from dsi.plugins.file_consumer import Bueno, Csv +from dsi.permissions.permissions import PermissionsManager def get_git_root(path): git_repo = git.Repo(path, search_parent_directories=True) git_root = git_repo.git.rev_parse("--show-toplevel") - return (git_root) + return git_root def test_bueno_plugin_type(): - path = '/'.join([get_git_root('.'), 'dsi/data', 'bueno1.data']) - plug = Bueno(filenames=path) + mock_pm = PermissionsManager() + path = "/".join([get_git_root("."), "dsi/data", "bueno1.data"]) + plug = Bueno(filenames=path, perms_manager=mock_pm) plug.add_rows() assert type(plug.output_collector) == OrderedDict def test_bueno_plugin_adds_rows(): - path1 = '/'.join([get_git_root('.'), 'dsi/data', 'bueno1.data']) - path2 = '/'.join([get_git_root('.'), 'dsi/data', 'bueno2.data']) - plug = Bueno(filenames=[path1, path2]) + mock_pm = PermissionsManager() + path1 = "/".join([get_git_root("."), "dsi/data", "bueno1.data"]) + path2 = "/".join([get_git_root("."), "dsi/data", "bueno2.data"]) + plug = Bueno(filenames=[path1, path2], perms_manager=mock_pm) plug.add_rows() plug.add_rows() @@ -33,15 +36,17 @@ def test_bueno_plugin_adds_rows(): def test_csv_plugin_type(): - path = '/'.join([get_git_root('.'), 'dsi/data', 'wildfiredata.csv']) - plug = Csv(filenames=path) + mock_pm = PermissionsManager() + path = "/".join([get_git_root("."), "dsi/data", "wildfiredata.csv"]) + plug = Csv(filenames=path, perms_manager=mock_pm) plug.add_rows() assert type(plug.output_collector) == OrderedDict def test_csv_plugin_adds_rows(): - path = '/'.join([get_git_root('.'), 'dsi/data', 'wildfiredata.csv']) - plug = Csv(filenames=path) + mock_pm = PermissionsManager() + path = "/".join([get_git_root("."), "dsi/data", "wildfiredata.csv"]) + plug = Csv(filenames=path, perms_manager=mock_pm) 
plug.add_rows() for key, val in plug.output_collector.items(): @@ -52,10 +57,11 @@ def test_csv_plugin_adds_rows(): def test_csv_plugin_adds_rows_multiple_files(): - path1 = '/'.join([get_git_root('.'), 'dsi/data', 'wildfiredata.csv']) - path2 = '/'.join([get_git_root('.'), 'dsi/data', 'yosemite5.csv']) + mock_pm = PermissionsManager() + path1 = "/".join([get_git_root("."), "dsi/data", "wildfiredata.csv"]) + path2 = "/".join([get_git_root("."), "dsi/data", "yosemite5.csv"]) - plug = Csv(filenames=[path1, path2]) + plug = Csv(filenames=[path1, path2], perms_manager=mock_pm) plug.add_rows() for key, val in plug.output_collector.items(): @@ -66,10 +72,11 @@ def test_csv_plugin_adds_rows_multiple_files(): def test_csv_plugin_adds_rows_multiple_files_strict_mode(): - path1 = '/'.join([get_git_root('.'), 'dsi/data', 'wildfiredata.csv']) - path2 = '/'.join([get_git_root('.'), 'dsi/data', 'yosemite5.csv']) + mock_pm = PermissionsManager() + path1 = "/".join([get_git_root("."), "dsi/data", "wildfiredata.csv"]) + path2 = "/".join([get_git_root("."), "dsi/data", "yosemite5.csv"]) - plug = Csv(filenames=[path1, path2], strict_mode=True) + plug = Csv(filenames=[path1, path2], strict_mode=True, perms_manager=mock_pm) try: plug.add_rows() except TypeError: @@ -78,13 +85,12 @@ def test_csv_plugin_adds_rows_multiple_files_strict_mode(): def test_csv_plugin_leaves_active_metadata_wellformed(): - path = '/'.join([get_git_root('.'), 'dsi/data', 'wildfiredata.csv']) + path = "/".join([get_git_root("."), "dsi/data", "wildfiredata.csv"]) term = Terminal() - term.load_module('plugin', 'Csv', 'consumer', filenames=[path]) - term.load_module('plugin', 'Hostname', 'producer') + term.load_module("plugin", "Csv", "consumer", filenames=[path]) + term.load_module("plugin", "Hostname", "producer") term.transload() columns = list(term.active_metadata.values()) - assert all([len(columns[0]) == len(col) - for col in columns]) # all same length + assert all([len(columns[0]) == len(col) for col in 
columns]) # all same length diff --git a/dsi/tests/test_core.py b/dsi/tests/test_core.py index 9e1e6d5b..13b1af80 100644 --- a/dsi/tests/test_core.py +++ b/dsi/tests/test_core.py @@ -1,6 +1,13 @@ +import git from dsi.core import Terminal +def get_git_root(path): + git_repo = git.Repo(path, search_parent_directories=True) + git_root = git_repo.git.rev_parse("--show-toplevel") + return (git_root) + + def test_terminal_module_getter(): a = Terminal() plugins = a.list_available_modules('plugin') @@ -10,7 +17,8 @@ def test_terminal_module_getter(): def test_unload_module(): a = Terminal() - a.load_module('plugin', 'GitInfo', 'producer') + a.load_module('plugin', 'GitInfo', 'producer', + git_repo_path=get_git_root('.')) assert len(a.list_loaded_modules()['producer']) == 1 a.unload_module('plugin', 'GitInfo', 'producer') assert len(a.list_loaded_modules()['producer']) == 0