Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 55 additions & 22 deletions bash_kernel/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,31 @@

from .display import (extract_contents, build_cmds)

# Special command patterns
su = re.compile("(sudo )? *((\/usr)?\/bin\/)?su( +|$).*")
env = re.compile("(sudo )? *((\/usr)?\/bin\/)?(chroot |env |exec )(.* )?((\/usr)?\/bin\/)?bash( +|$).*")
bash = re.compile("(sudo )? *((\/usr)?\/bin\/)?bash( +|$).*")
passwd = re.compile("(sudo )? *((\/usr)?\/bin\/)?passwd( +|$).*")
sudo = re.compile("sudo .+")
special_commands = [su, env, bash, passwd, sudo] if os.getenv("BASH_KERNEL_SPECIAL_COMMANDS") is not None else []

class IREPLWrapper(replwrap.REPLWrapper):
"""A subclass of REPLWrapper that gives incremental output
specifically for bash_kernel.

The parameters are the same as for REPLWrapper, except for one
extra parameter:

:param line_output_callback: a callback method to receive each batch
of incremental output. It takes one string parameter.
:param kernel: the kernel object that provides at least the methods
`process_output` to send the response to the frontend, and the
`getpass` to ask the user for secrets .
"""
def __init__(self, cmd_or_spawn, orig_prompt, prompt_change, unique_prompt,
extra_init_cmd=None, line_output_callback=None):
extra_init_cmd=None, kernel=None):
self.unique_prompt = unique_prompt
self.line_output_callback = line_output_callback
self.prompt_change = prompt_change
self.extra_init_cmd = extra_init_cmd
self.kernel = kernel
# The extra regex at the start of PS1 below is designed to catch the
# `(envname) ` which conda/mamba add to the start of PS1 by default.
# Obviously anything else that looks like this, including user output,
Expand All @@ -42,33 +53,54 @@ def __init__(self, cmd_or_spawn, orig_prompt, prompt_change, unique_prompt,
# through a cell.
self.ps1_re = r"(\(\w+\) )?" + re.escape(self.unique_prompt + ">")
self.ps2_re = re.escape(self.unique_prompt + "+")
self.all_prompts = [re.compile(x) for x in [self.ps1_re, self.ps2_re, '\r?\n', '\r',
u"((Retype )?[Nn]ew )?[Pp]assword:",
u"\[sudo\] password for .*:",
u"su: .*\n", u"sudo: .*\n",
u"chroot: .*\n", u"passwd: .*\n",
"\$", "\#"]]
replwrap.REPLWrapper.__init__(self, cmd_or_spawn, orig_prompt,
prompt_change, new_prompt=self.ps1_re,
continuation_prompt=self.ps2_re, extra_init_cmd=extra_init_cmd)

def run_command(self, command, timeout=-1, async_=False):

self.prompts = self.all_prompts if True in [cmd.match(command) is not None for cmd in special_commands] else self.all_prompts[:4]
command = command + " -s /bin/bash" if su.match(command) else command
res = super().run_command(command, timeout=timeout, async_=async_)

# Initialization
if su.match(command) or bash.match(command) or env.match(command):
self.run_command(self.extra_init_cmd)
self.run_command("bind 'set enable-bracketed-paste off' >/dev/null 2>&1 || true")
# self.run_command(build_cmds())
return res

def _expect_prompt(self, timeout=-1):
prompts = [self.ps1_re, self.ps2_re]

if timeout == None:
# "None" means we are executing code from a Jupyter cell by way of the run_command
# in the do_execute() code below, so do incremental output, i.e.
# also look for end of line or carridge return
prompts.extend(['\r?\n', '\r'])
while True:
pos = self.child.expect_list([re.compile(x) for x in prompts], timeout=None)
if pos == 2:
# End of line received.
self.line_output_callback(self.child.before + '\n')
elif pos == 3:
# Carriage return ('\r') received.
self.line_output_callback(self.child.before + '\r')
else:
if len(self.child.before) != 0:
# Prompt received, but partial line precedes it.
self.line_output_callback(self.child.before)
pos = self.child.expect_list(self.prompts, timeout=None)
if pos in [0, 1]:
if len(self.child.before) > 0 and len(self.prompts) == 4:
self.kernel.process_output(self.child.before)
break
elif pos in [2, 3] + [6, 7, 8, 9]:
self.kernel.process_output(self.child.before + self.child.after)
elif pos in [4, 5]:
self.kernel.process_output(self.child.before + self.child.after)
password = self.kernel.getpass()
self.child.sendline(password)
elif pos in [10, 11]:
self.child.sendline(self.prompt_change)
else:
raise Exception("Unexpected prompt")
else:
# Otherwise, wait (with timeout) until the next prompt
prompts = [self.ps1_re, self.ps2_re]
pos = self.child.expect_list([re.compile(x) for x in prompts], timeout=timeout)

# Prompt received, so return normally
Expand Down Expand Up @@ -103,6 +135,8 @@ def __init__(self, **kwargs):
Kernel.__init__(self, **kwargs)
self._start_bash()
self._known_display_ids = set()
# Enable this to allow calling Kernel.getpass() for passwords.
self._allow_stdin = True

def _start_bash(self):
# Signal handlers are inherited by forked processes, and we can't easily
Expand Down Expand Up @@ -132,8 +166,7 @@ def _start_bash(self):
prompt_change = u"PS1='{0}' PS2='{1}' PROMPT_COMMAND=''".format(ps1, ps2)
# Using IREPLWrapper to get incremental output
self.bashwrapper = IREPLWrapper(child, u'\$', prompt_change, self.unique_prompt,
extra_init_cmd="export PAGER=cat",
line_output_callback=self.process_output)
extra_init_cmd="export PAGER=cat", kernel=self)
finally:
signal.signal(signal.SIGINT, old_sigint_handler)
signal.signal(signal.SIGPIPE, old_sigpipe_handler)
Expand All @@ -145,7 +178,7 @@ def _start_bash(self):


def process_output(self, output):
if not self.silent:
if hasattr(self, "silent") and not self.silent:
plain_output, rich_contents = extract_contents(output)

# Send standard output
Expand All @@ -156,7 +189,7 @@ def process_output(self, output):
# Send rich contents, if any:
for content in rich_contents:
if isinstance(content, Exception):
message = {'name': 'stderr', 'text': str(e)}
message = {'name': 'stderr', 'text': str(content)}
self.send_response(self.iopub_socket, 'stream', message)
else:
if 'transient' in content and 'display_id' in content['transient']:
Expand Down Expand Up @@ -190,7 +223,7 @@ def do_execute(self, code, silent, store_history=True,
return {'status': 'ok', 'execution_count': self.execution_count,
'payload': [], 'user_expressions': {}}


if code.strip().endswith("\\"):
error_content = {
'ename': '',
Expand Down