diff --git a/sh.py b/sh.py index fb0f5fe7..7fd47a6c 100644 --- a/sh.py +++ b/sh.py @@ -70,7 +70,7 @@ def callable(ob): IS_OSX = platform.system() == "Darwin" THIS_DIR = os.path.dirname(os.path.realpath(__file__)) -SH_LOGGER_NAME = "sh" +SH_LOGGER_NAME = __name__ import errno @@ -153,19 +153,19 @@ class ErrorReturnCode(Exception): derived classes with the format: ErrorReturnCode_NNN where NNN is the exit code number. the reason for this is it reduces boiler plate code when testing error return codes: - + try: some_cmd() except ErrorReturnCode_12: print("couldn't do X") - + vs: try: some_cmd() except ErrorReturnCode as e: if e.exit_code == 12: print("couldn't do X") - + it's not much of a savings, but i believe it makes the code easier to read """ truncate_cap = 750 @@ -350,7 +350,7 @@ class Logger(object): script is done. with sh, it's easy to create loggers with unique names if we want our loggers to include our command arguments. for example, these are all unique loggers: - + ls -l ls -l /tmp ls /tmp @@ -364,7 +364,7 @@ def __init__(self, name, context=None): self.name = name if context: context = context.replace("%", "%%") - self.context = context + self.context = context self.log = logging.getLogger("%s.%s" % (SH_LOGGER_NAME, name)) def _format_msg(self, msg, *args): @@ -660,7 +660,7 @@ class Command(object): represents the program itself (and not a running instance of it), it should hold very little state. in fact, the only state it does hold is baked arguments. - + when a Command object is called, the result that is returned is a RunningCommand object, which represents the Command put into an execution state. """ @@ -787,7 +787,7 @@ def __init__(self, path): if not found: raise CommandNotFound(path) - self._path = encode_to_py3bytes_or_py2str(found) + self._path = encode_to_py3bytes_or_py2str(found) self._partial = False self._partial_baked_args = [] @@ -1624,7 +1624,7 @@ class NotYetReadyToRead(Exception): pass def determine_how_to_read_input(input_obj): """ given some kind of input object, return a function that knows how to read chunks of that input object. - + each reader function should return a chunk and raise a DoneReadingForever exception, or return None, when there's no more data to read @@ -2280,31 +2280,86 @@ def __init__(self, self_module, baked_args={}): # but it seems to be the only way to make reload() behave # nicely. if i make these attributes dynamic lookups in # __getattr__, reload sometimes chokes in weird ways... - for attr in ["__builtins__", "__doc__", "__name__", "__package__"]: + for attr in ["__builtins__", "__doc__", "__file__", "__name__", "__package__"]: setattr(self, attr, getattr(self_module, attr, None)) # python 3.2 (2.7 and 3.3 work fine) breaks on osx (not ubuntu) # if we set this to None. and 3.3 needs a value for __path__ self.__path__ = [] self.__self_module = self_module - self.__env = Environment(globals(), baked_args) - - def __setattr__(self, name, value): - if hasattr(self, "__env"): - self.__env[name] = value - else: - ModuleType.__setattr__(self, name, value) + self.__env = Environment(globals(), baked_args=baked_args) def __getattr__(self, name): - if name == "__env": - raise AttributeError return self.__env[name] - # accept special keywords argument to define defaults for all operations - # that will be processed with given by return SelfWrapper def __call__(self, **kwargs): - return SelfWrapper(self.__self_module, kwargs) + "DO NOT store the new module returned by this function in a variable named 'sh'" + baked_args = self.__env.baked_args.copy() + baked_args.update(kwargs) + return self.__class__(self.__self_module, baked_args) + + +class ModuleImporterFromVariables(object): + """ + Implements the Importer protocol. + This hook allow to import modules from variables. + Example: + mod = build_a_module() + import mod + """ + def __init__(self, restrict_to=None, forbid_mod_names=None): + # We can only store class names here, not class objects, + # as SelfWrapper class definition will be overriden next time this module is loaded + self.restrict_to = restrict_to + self.forbid_mod_names = forbid_mod_names + + def register_if_not_active(self): + # For the same reason as above, 'isinstance' won't work here, + # because if there is an instance of ModuleImporterFromVariables in that list, + # its class correspond to a previous definition of ModuleImporterFromVariables, + # before this file was reloaded + if not any(getattr(module, 'restrict_to', None) == self.restrict_to for module in sys.meta_path): + sys.meta_path.insert(0, self) + + def find_module(self, mod_fullname, path=None): + if self.forbid_mod_names and mod_fullname in self.forbid_mod_names: + return None + parent_frame = inspect.currentframe().f_back + # This function is called from the frozen importlib, + # so we go back up frame per frame until we're "out" of the importlib code + while parent_frame.f_code.co_filename == '': + parent_frame = parent_frame.f_back + if mod_fullname not in parent_frame.f_locals and mod_fullname not in parent_frame.f_globals: + return None + return self + def load_module(self, mod_fullname, parent_frame=None): + "We intentionnally do zero caching through sys.modules" + if not parent_frame: + parent_frame = inspect.currentframe().f_back + # This function is called from the frozen importlib, + # so we go back up frame per frame until we're "out" of the importlib code + while parent_frame.f_code.co_filename == '': + parent_frame = parent_frame.f_back + if mod_fullname in parent_frame.f_locals: + module = parent_frame.f_locals[mod_fullname] + elif mod_fullname in parent_frame.f_globals: + module = parent_frame.f_globals[mod_fullname] + else: + raise ImportError("%s not found in scope" % (mod_fullname,)) + if self.restrict_to and not any(str(type(module)) == classname for classname in self.restrict_to): + raise ImportError("%s (%s) does not belong to the list of allowed modules: %s" % (mod_fullname, module, self.restrict_to)) + if self.forbid_mod_names and mod_fullname in self.forbid_mod_names: + raise ImportError("%s (%s) is a forbidden module name" % (mod_fullname, module)) + module.__loader__ = self + return module + + # Methods needed for Python 3 + def create_module(self, spec): + return self.load_module(spec.name, parent_frame=inspect.currentframe().f_back) + + def exec_module(self, module): + pass # we're being run as a stand-alone script @@ -2350,3 +2405,4 @@ def run_test(version, locale): else: self = sys.modules[__name__] sys.modules[__name__] = SelfWrapper(self) + ModuleImporterFromVariables(restrict_to=[""], forbid_mod_names='sh').register_if_not_active() diff --git a/test.py b/test.py index f5dc23be..c702d76e 100644 --- a/test.py +++ b/test.py @@ -4,8 +4,8 @@ from os.path import exists, join, realpath import unittest import tempfile +import sh as _sh # importing under another name to avoid skewing tests import sys -import sh import platform from functools import wraps @@ -14,12 +14,16 @@ tempdir = realpath(tempfile.gettempdir()) IS_OSX = platform.system() == "Darwin" IS_PY3 = sys.version_info[0] == 3 + if IS_PY3: unicode = str - python = sh.Command(sh.which("python%d.%d" % sys.version_info[:2])) + from io import StringIO + from io import BytesIO as cStringIO + python = _sh.Command(_sh.which("python%d.%d" % sys.version_info[:2])) else: - from sh import python - + from StringIO import StringIO + from cStringIO import StringIO as cStringIO + python = _sh.python THIS_DIR = os.path.dirname(os.path.abspath(__file__)) @@ -38,7 +42,7 @@ def skip(*args, **kwargs): return wrapper requires_posix = skipUnless(os.name == "posix", "Requires POSIX") -requires_utf8 = skipUnless(sh.DEFAULT_ENCODING == "UTF-8", "System encoding must be UTF-8") +requires_utf8 = skipUnless(_sh.DEFAULT_ENCODING == "UTF-8", "System encoding must be UTF-8") def create_tmp_test(code, prefix="tmp", delete=True): @@ -59,7 +63,6 @@ def create_tmp_test(code, prefix="tmp", delete=True): @requires_posix class FunctionalTests(unittest.TestCase): - def test_print_command(self): from sh import ls, which actual_location = which("ls") @@ -393,6 +396,7 @@ def test_doesnt_execute_directories(self): h.write(bunk_header) os.chmod(gcc_file2, int(0o755)) + import sh from sh import gcc if IS_PY3: self.assertEqual(gcc._path, @@ -932,6 +936,7 @@ def agg(line, stdin, process): process.terminate() return True + import sh caught_signal = False try: p = python(py.name, _out=agg, u=True, _bg=True) @@ -948,7 +953,6 @@ def agg(line, stdin, process): def test_stdout_callback_kill(self): import signal - import sh py = create_tmp_test(""" import sys @@ -968,6 +972,7 @@ def agg(line, stdin, process): process.kill() return True + import sh caught_signal = False try: p = python(py.name, _out=agg, u=True, _bg=True) @@ -1217,13 +1222,6 @@ def test_tty_output(self): def test_stringio_output(self): from sh import echo - if IS_PY3: - from io import StringIO - from io import BytesIO as cStringIO - else: - from StringIO import StringIO - from cStringIO import StringIO as cStringIO - out = StringIO() echo("-n", "testing 123", _out=out) self.assertEqual(out.getvalue(), "testing 123") @@ -1235,14 +1233,6 @@ def test_stringio_output(self): def test_stringio_input(self): from sh import cat - - if IS_PY3: - from io import StringIO - from io import BytesIO as cStringIO - else: - from StringIO import StringIO - from cStringIO import StringIO as cStringIO - input = StringIO() input.write("herpderp") input.seek(0) @@ -1317,7 +1307,7 @@ def test_encoding(self): def test_timeout(self): - from sh import sleep + import sh from time import time # check that a normal sleep is more or less how long the whole process @@ -1499,24 +1489,6 @@ def s(fn): str(fn()) self.assertEqual(p, "test") - def test_shared_secial_args(self): - import sh - - if IS_PY3: - from io import StringIO - from io import BytesIO as cStringIO - else: - from StringIO import StringIO - from cStringIO import StringIO as cStringIO - - out1 = sh.ls('.') - out2 = StringIO() - sh_new = sh(_out=out2) - sh_new.ls('.') - self.assertEqual(out1, out2.getvalue()) - out2.close() - - def test_signal_exception(self): from sh import SignalException_15 @@ -1589,7 +1561,7 @@ def test_file_output_isnt_buffered(self): def test_pushd(self): """ test that pushd is just a specialized form of sh.args """ - import os + import os, sh old_wd = os.getcwd() with sh.pushd(tempdir): new_wd = sh.pwd().strip() @@ -1602,7 +1574,7 @@ def test_pushd(self): def test_args_context(self): """ test that we can use the args with-context to temporarily override command settings """ - import os + import os, sh old_wd = os.getcwd() with sh.args(_cwd=tempdir): @@ -1951,6 +1923,99 @@ def test_chunk_buffered(self): self.assertEqual(b.flush(), b"e\n") +@requires_posix +class CustomizingDefaultsTests(unittest.TestCase): + def test_defaults(self): + import sh + out = StringIO() + _sh = sh(_out=out) + _sh.echo('-n', 'TEST') + self.assertEqual('TEST', out.getvalue()) + + def test_defaults_with_import_from_after(self): + import sh + out = StringIO() + _sh = sh(_out=out) + from _sh import echo + echo('-n', 'TEST') + self.assertEqual('TEST', out.getvalue()) + + out.seek(0); out.truncate(0) # Emptying the StringIO + + sh.echo('-n', 'KO') + self.assertEqual('', out.getvalue()) + + def test_defaults_with_import_from_before(self): + import sh + out = StringIO() + from sh import echo + _sh = sh(_out=out) + echo('-n', 'TEST') + self.assertEqual('', out.getvalue()) + + def test_defaults_imported_as_sh(self): + import sh + out = StringIO() + sh = sh(_out=out) # this is forbidden + from sh import echo # hence this import the default 'echo' + echo('-n', 'TEST') + self.assertEqual('', out.getvalue()) + + def test_defaults_not_set_in_other_modules(self): + import sh + out = StringIO() + _sh = sh(_out=out) + from _sh import echo + from test_module_that_import_echo import echo + echo('-n', 'TEST') + self.assertEqual('', out.getvalue()) + + def test_defaults_set_in_parent_function(self): + import sh + out = StringIO() + _sh = sh(_out=out) + def nested1(): + _sh.echo('-n', 'TEST1') + def nested2(): + import sh + sh.echo('-n', 'TEST2') + nested1() + nested2() + self.assertEqual('TEST1', out.getvalue()) + + def test_defaults_with_reimport(self): + import sh + out = StringIO() + _sh = sh(_out=out) + import _sh # this reimport '_sh' from the eponymous local variable + _sh.echo('-n', 'TEST') + self.assertEqual('TEST', out.getvalue()) + + def test_module_importer_from_variables_ok(self): + # Import from module in global scope + from _sh import echo + # Import from local module + _ = _sh + from _ import python + + def test_module_importer_from_variables_ko(self): + def unallowed_import(): + _os = os + from _os import path + self.assertRaises(ImportError, unallowed_import) + + def test_defaults_with_global_module(self): + global _sh + old_sh = _sh + try: + out = StringIO() + _sh = _sh(_out=out) + from _sh import echo + echo('-n', 'TEST') + self.assertEqual('TEST', out.getvalue()) + finally: + _sh = old_sh + if __name__ == "__main__": # if we're running a specific test, we can let unittest framework figure out diff --git a/test_module_that_import_echo.py b/test_module_that_import_echo.py new file mode 100644 index 00000000..76eb4db0 --- /dev/null +++ b/test_module_that_import_echo.py @@ -0,0 +1 @@ +from sh import echo