Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add extension ordering #242

Merged
merged 14 commits into from
Oct 3, 2023
9 changes: 6 additions & 3 deletions src/rocker/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from .core import get_rocker_version
from .core import RockerExtensionManager
from .core import DependencyMissing
from .core import ExtensionError

from .os_detector import detect_os

Expand Down Expand Up @@ -54,9 +55,11 @@ def main():
args_dict['mode'] = OPERATIONS_DRY_RUN
print('DEPRECATION Warning: --noexecute is deprecated for --mode dry-run please switch your usage by December 2020')

active_extensions = extension_manager.get_active_extensions(args_dict)
# Force user to end if present otherwise it will break other extensions
active_extensions.sort(key=lambda e:e.get_name().startswith('user'))
try:
active_extensions = extension_manager.get_active_extensions(args_dict)
except ExtensionError as e:
print(f"ERROR! {str(e)}")
return 1
print("Active extensions %s" % [e.get_name() for e in active_extensions])

base_image = args.image
Expand Down
85 changes: 81 additions & 4 deletions src/rocker/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import signal
import struct
import termios
import typing

SYS_STDOUT = sys.stdout

Expand All @@ -45,6 +46,10 @@ class DependencyMissing(RuntimeError):
pass


class ExtensionError(RuntimeError):
pass


class RockerExtension(object):
"""The base class for Rocker extension points"""

Expand All @@ -58,6 +63,22 @@ def validate_environment(self, cliargs):
necessary resources are available, like hardware."""
pass

def invoke_after(self, cliargs) -> typing.Set[str]:
"""
This extension should be loaded after the extensions in the returned
set. These extensions are not required to be present, but if they are,
they will be loaded before this extension.
"""
return set()

def required(self, cliargs) -> typing.Set[str]:
"""
Ensures the specified extensions are present and combined with
this extension. If the required extension should be loaded before
this extension, it should also be added to the `invoke_after` set.
"""
return set()

def get_preamble(self, cliargs):
return ''

Expand Down Expand Up @@ -106,13 +127,70 @@ def extend_cli_parser(self, parser, default_args={}):
parser.add_argument('--extension-blacklist', nargs='*',
default=[],
help='Prevent any of these extensions from being loaded.')
parser.add_argument('--strict-extension-selection', action='store_true',
help='When enabled, causes an error if required extensions are not explicitly '
'called out on the command line. Otherwise, the required extensions will '
'automatically be loaded if available.')


def get_active_extensions(self, cli_args):
active_extensions = [e() for e in self.available_plugins.values() if e.check_args_for_activation(cli_args) and e.get_name() not in cli_args['extension_blacklist']]
active_extensions.sort(key=lambda e:e.get_name().startswith('user'))
return active_extensions
"""
Checks for missing dependencies (specified by each extension's
required() method) and additionally sorts them.
"""
def sort_extensions(extensions, cli_args):

def topological_sort(source: typing.Dict[str, typing.Set[str]]) -> typing.List[str]:
"""Perform a topological sort on names and dependencies and returns the sorted list of names."""
names = set(source.keys())
# prune optional dependencies if they are not present (at this point the required check has already occurred)
pending = [(name, dependencies.intersection(names)) for name, dependencies in source.items()]
emitted = []
while pending:
next_pending = []
next_emitted = []
for entry in pending:
name, deps = entry
deps.difference_update(emitted) # remove dependencies already emitted
if deps: # still has dependencies? recheck during next pass
next_pending.append(entry)
else: # no more dependencies? time to emit
yield name
next_emitted.append(name) # remember what was emitted for difference_update()
if not next_emitted:
raise ExtensionError("Cyclic dependancy detected: %r" % (next_pending,))
pending = next_pending
emitted = next_emitted

extension_graph = {name: cls.invoke_after(cli_args) for name, cls in sorted(extensions.items())}
active_extension_list = [extensions[name] for name in topological_sort(extension_graph)]
return active_extension_list

active_extensions = {}
find_reqs = set([name for name, cls in self.available_plugins.items() if cls.check_args_for_activation(cli_args)])
while find_reqs:
name = find_reqs.pop()

if name in self.available_plugins.keys():
if name not in cli_args['extension_blacklist']:
ext = self.available_plugins[name]()
active_extensions[name] = ext
else:
raise ExtensionError(f"Extension '{name}' is blacklisted.")
else:
raise ExtensionError(f"Extension '{name}' not found. Is it installed?")

# add additional reqs for processing not already known about
known_reqs = set(active_extensions.keys()).union(find_reqs)
missing_reqs = ext.required(cli_args).difference(known_reqs)
if missing_reqs:
if cli_args['strict_extension_selection']:
raise ExtensionError(f"Extension '{name}' is missing required extension(s) {list(missing_reqs)}")
else:
print(f"Adding implicilty required extension(s) {list(missing_reqs)} required by extension '{name}'")
find_reqs = find_reqs.union(missing_reqs)

return sort_extensions(active_extensions, cli_args)

def get_docker_client():
"""Simple helper function for pre 2.0 imports"""
Expand Down Expand Up @@ -254,7 +332,6 @@ def get_operating_mode(self, args):
print("No tty detected for stdin forcing non-interactive")
return operating_mode


def generate_docker_cmd(self, command='', **kwargs):
docker_args = ''

Expand Down
82 changes: 78 additions & 4 deletions test/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@
from rocker.core import list_plugins
from rocker.core import get_docker_client
from rocker.core import get_rocker_version
from rocker.core import RockerExtension
from rocker.core import RockerExtensionManager
from rocker.core import ExtensionError

class RockerCoreTest(unittest.TestCase):

Expand Down Expand Up @@ -128,9 +130,82 @@ def test_extension_manager(self):
self.assertIn('non-interactive', help_str)
self.assertIn('--extension-blacklist', help_str)

active_extensions = active_extensions = extension_manager.get_active_extensions({'user': True, 'ssh': True, 'extension_blacklist': ['ssh']})
self.assertEqual(len(active_extensions), 1)
self.assertEqual(active_extensions[0].get_name(), 'user')
self.assertRaises(ExtensionError,
extension_manager.get_active_extensions,
{'user': True, 'ssh': True, 'extension_blacklist': ['ssh']})

def test_strict_required_extensions(self):
class Foo(RockerExtension):
@classmethod
def get_name(cls):
return 'foo'

class Bar(RockerExtension):
@classmethod
def get_name(cls):
return 'bar'

def required(self, cli_args):
return {'foo'}

extension_manager = RockerExtensionManager()
extension_manager.available_plugins = {'foo': Foo, 'bar': Bar}

correct_extensions_args = {'strict_extension_selection': True, 'bar': True, 'foo': True, 'extension_blacklist': []}
extension_manager.get_active_extensions(correct_extensions_args)

incorrect_extensions_args = {'strict_extension_selection': True, 'bar': True, 'extension_blacklist': []}
self.assertRaises(ExtensionError,
extension_manager.get_active_extensions, incorrect_extensions_args)

def test_implicit_required_extensions(self):
class Foo(RockerExtension):
@classmethod
def get_name(cls):
return 'foo'

class Bar(RockerExtension):
@classmethod
def get_name(cls):
return 'bar'

def required(self, cli_args):
return {'foo'}

extension_manager = RockerExtensionManager()
extension_manager.available_plugins = {'foo': Foo, 'bar': Bar}

implicit_extensions_args = {'strict_extension_selection': False, 'bar': True, 'extension_blacklist': []}
active_extensions = extension_manager.get_active_extensions(implicit_extensions_args)
self.assertEqual(len(active_extensions), 2)
# required extensions are not ordered, just check to make sure they are both present
if active_extensions[0].get_name() == 'foo':
self.assertEqual(active_extensions[1].get_name(), 'bar')
else:
self.assertEqual(active_extensions[0].get_name(), 'bar')
self.assertEqual(active_extensions[1].get_name(), 'foo')

def test_extension_sorting(self):
class Foo(RockerExtension):
@classmethod
def get_name(cls):
return 'foo'

class Bar(RockerExtension):
@classmethod
def get_name(cls):
return 'bar'

def invoke_after(self, cli_args):
return {'foo', 'absent_extension'}

extension_manager = RockerExtensionManager()
extension_manager.available_plugins = {'foo': Foo, 'bar': Bar}

args = {'bar': True, 'foo': True, 'extension_blacklist': []}
active_extensions = extension_manager.get_active_extensions(args)
self.assertEqual(active_extensions[0].get_name(), 'foo')
self.assertEqual(active_extensions[1].get_name(), 'bar')

def test_docker_cmd_interactive(self):
dig = DockerImageGenerator([], {}, 'ubuntu:bionic')
Expand All @@ -148,7 +223,6 @@ def test_docker_cmd_interactive(self):

self.assertNotIn('-it', dig.generate_docker_cmd(mode='non-interactive'))


def test_docker_cmd_nocleanup(self):
dig = DockerImageGenerator([], {}, 'ubuntu:bionic')

Expand Down