diff --git a/jupyter_forward/core.py b/jupyter_forward/core.py index db286a2..3bc813d 100644 --- a/jupyter_forward/core.py +++ b/jupyter_forward/core.py @@ -6,6 +6,7 @@ import pathlib import socket import sys +import textwrap import time from typing import Callable @@ -118,6 +119,12 @@ def _check_shell(self): self.shell = self.run_command(f'which {self.shell}').stdout.strip() console.print(f'[bold cyan]:white_check_mark: Using shell: {self.shell}') + def put_file(self, remote_path, content): + client = self.session.client + with client.get_transport().open_channel(kind='session') as channel: + channel.exec_command(f'cat > {remote_path}') + channel.sendall(content.encode()) + def run_command( self, command, @@ -182,7 +189,7 @@ def start(self): def _get_hostname(self): if self.launch_command: - return r'\$(hostname -f)' + return '$(hostname -f)' else: return self.session.run('hostname -f').stdout.strip() @@ -256,18 +263,23 @@ def _parse_log_file(self): return parse_stdout(stdout) def _prepare_batch_job_script(self, command): + from rich.syntax import Syntax + console.rule('[bold green]Preparing Batch Job script', characters='*') script_file = f'{self.log_dir}/batch_job_script_{timestamp}' shell = self.shell if 'csh' not in shell: shell = f'{shell} -l' - for command in [ - f"echo -n '#!' > {script_file}", - f'echo {shell} >> {script_file}', - f"echo '{command}' >> {script_file}", - f'chmod +x {script_file}', - ]: - self.run_command(command=command, exit=True) + + script = textwrap.dedent( + f"""\ + #!{shell} + {command} + """ + ) + console.print(Syntax(script, 'bash', line_numbers=True)) + self.put_file(script_file, script) + self.run_command(f'chmod +x {script_file}', exit=True) console.print(f'[bold cyan]:white_check_mark: Batch Job script resides in {script_file}') return script_file diff --git a/tests/test_core.py b/tests/test_core.py index 0177354..86e38c4 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -1,6 +1,7 @@ import datetime import json import os +from contextlib import contextmanager import pytest @@ -14,13 +15,31 @@ ON_GITHUB_ACTIONS = os.environ.get('GITHUB_ACTIONS') is not None +@contextmanager +def tempfile(session): + out = session.run('mktemp') + path = out.stdout.strip() + try: + yield path + finally: + session.run(f'rm {path}') + + +def dummy_auth_handler(title, instructions, prompt_list): + return ['Loremipsumdolorsitamet'] * len(prompt_list) + + +def dummy_fallback_auth_handler(): + return 'Loremipsumdolorsitamet' + + @pytest.fixture(scope='package') def runner(request): remote = jupyter_forward.RemoteRunner( f"{os.environ['JUPYTER_FORWARD_SSH_TEST_USER']}@{os.environ['JUPYTER_FORWARD_SSH_TEST_HOSTNAME']}", shell=request.param, - auth_handler=lambda t, i, p: ['Loremipsumdolorsitamet'] * len(p), - fallback_auth_handler=lambda: 'Loremipsumdolorsitamet', + auth_handler=dummy_auth_handler, + fallback_auth_handler=dummy_fallback_auth_handler, ) try: yield remote @@ -48,8 +67,8 @@ def test_runner_init(port, conda_env, notebook, notebook_dir, port_forwarding, i identity=identity, port_forwarding=port_forwarding, shell=shell, - auth_handler=lambda t, i, p: ['Loremipsumdolorsitamet'] * len(p), - fallback_auth_handler=lambda: 'Loremipsumdolorsitamet', + auth_handler=dummy_auth_handler, + fallback_auth_handler=dummy_fallback_auth_handler, ) assert remote_runner.port == port @@ -80,6 +99,8 @@ def test_runner_authentication_error(): with pytest.raises(SystemExit): jupyter_forward.RemoteRunner( f"foobar@{os.environ['JUPYTER_FORWARD_SSH_TEST_HOSTNAME']}", + auth_handler=dummy_auth_handler, + fallback_auth_handler=dummy_fallback_auth_handler, ) @@ -120,6 +141,17 @@ def test_run_command_failure(runner, command): runner.run_command(command) +@requires_ssh +@pytest.mark.parametrize('content', ['echo $HOME', 'echo $(hostname -f)']) +@pytest.mark.parametrize('runner', [None], indirect=True) +def test_put_file(runner, content): + with tempfile(runner.session) as path: + runner.put_file(path, content) + + out = runner.session.run(f'cat {path}') + assert content == out.stdout + + @requires_ssh @pytest.mark.parametrize('runner', SHELLS, indirect=True) def test_set_logs(runner):