Skip to content

Commit

Permalink
rebased ssh branch into main
Browse files Browse the repository at this point in the history
  • Loading branch information
Acesif committed Sep 9, 2024
1 parent 2f970f7 commit f30f9fa
Showing 1 changed file with 98 additions and 25 deletions.
123 changes: 98 additions & 25 deletions probe_src/python/probe_py/manual/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@
from probe_py.manual import util

rich.traceback.install(show_locals=False)
from typing import List
from . import parse_probe_log
from . import analysis
from . import util


project_root = pathlib.Path(__file__).resolve().parent.parent.parent.parent
Expand Down Expand Up @@ -181,7 +185,58 @@ def ssh(
"""
Wrap SSH and record provenance of the remote command.
"""
libprobe = project_root / "libprobe/build" / ("libprobe-dbg.so" if debug else "libprobe.so")

one_arg_options = set("BbcDEeFIiJLlmOoPpRSWw")
no_arg_options = set("46AaCfGgKkMNnqsTtVvXxYy")

# fsm to figure out the flags, destination and remote cmds
state = 'start'
i = 0
flags = []
destination = None
remote_host = []

while i < len(ssh_args):
curr_arg = ssh_args[i]

if state == 'start':
if curr_arg.startswith("-"):
state = 'flag'
elif destination != None:
state = 'cmd'
else:
state = 'destination'

elif state == 'flag':
opt = curr_arg[-1]
if opt in one_arg_options:
state = 'one_arg'
elif opt in no_arg_options:
flags.append(curr_arg)
state = 'start'
i+=1

elif state == 'one_arg':
flags.extend([ssh_args[i-1],curr_arg])
state = 'start'
i+=1

elif state == 'destination':
if destination == None:
destination = curr_arg
state = 'start'
else:
state = 'cmd'
continue
i+=1

elif state == 'cmd':
remote_host.extend(ssh_args[i:])
break

ssh_cmd = ["ssh"] + flags

libprobe = pathlib.Path(os.environ["__PROBE_LIB"]) / ("libprobe-dbg.so" if debug else "libprobe.so")
if not libprobe.exists():
typer.secho(f"Libprobe not found at {libprobe}", fg=typer.colors.RED)
raise typer.Abort()
Expand All @@ -190,46 +245,64 @@ def ssh(
local_temp_dir = pathlib.Path(tempfile.mkdtemp(prefix=f"probe_log_{os.getpid()}"))

# Check if remote platform matches local platform
remote_gcc_machine_cmd = ["ssh"] + ssh_args + ["gcc", "-dumpmachine"]
remote_gcc_machine_cmd = ssh_cmd + ["gcc", "-dumpmachine"]
local_gcc_machine_cmd = ["gcc", "-dumpmachine"]
remote_gcc_machine = subprocess.check_output(remote_gcc_machine_cmd).decode().strip()
local_gcc_machine = subprocess.check_output(local_gcc_machine_cmd).decode().strip()

remote_gcc_machine = subprocess.check_output(remote_gcc_machine_cmd)
local_gcc_machine = subprocess.check_output(local_gcc_machine_cmd)

if remote_gcc_machine != local_gcc_machine:
raise NotImplementedError("Remote platform is different from local platform")

# Upload libprobe.so to the remote temporary directory
remote_temp_dir_cmd = ["ssh"] + ssh_args + ["mktemp", "-d", "/tmp/probe_log_XXXXXX"]
remote_temp_dir_cmd = ssh_cmd + [destination] + ["mktemp", "-d", "/tmp/probe_log_XXXXXX"]
remote_temp_dir = subprocess.check_output(remote_temp_dir_cmd).decode().strip()
remote_probe_dir = f"{remote_temp_dir}/probe_dir"

scp_cmd = ["scp", str(libprobe), f"{ssh_args[0]}:{remote_temp_dir}/"]
subprocess.run(scp_cmd, check=True)


ssh_g = subprocess.run(ssh_cmd + [destination] + ['-G'],stdout=subprocess.PIPE)
ssh_g_op = ssh_g.stdout.decode().strip().splitlines()

ssh_pair = []
for pair in ssh_g_op:
ssh_pair.append(pair.split())

scp_cmd = ["scp"]
for option in ssh_g_op:
key_value = option.split(' ', 1)
if len(key_value) == 2:
key, value = key_value
scp_cmd.append(f"-o {key}={value}")

scp_args =[str(libprobe),f"{destination}:{remote_temp_dir}"]
scp_cmd.extend(scp_args)

subprocess.run(scp_cmd,check=True)

# Prepare the remote command with LD_PRELOAD and __PROBE_DIR
ld_preload = f"{remote_temp_dir}/{libprobe.name}"
remote_cmd = f"env LD_PRELOAD={ld_preload} __PROBE_DIR={remote_probe_dir} {' '.join(ssh_args[1:])}"

if debug:
typer.secho(f"Running remote command: {remote_cmd}", fg=typer.colors.GREEN)

# Run the remote command
ssh_cmd = ["ssh"] + ssh_args + [remote_cmd]
proc = subprocess.run(ssh_cmd)
env = ["env", f"LD_PRELOAD={ld_preload}", f"__PROBE_DIR={remote_probe_dir}"]
proc = subprocess.run(ssh_cmd + [destination] + env + remote_host)

# Download the provenance log from the remote machine
remote_tar_cmd = ["ssh"] + ssh_args + [f"tar -czf - -C {remote_temp_dir} probe_dir"]
with open(local_temp_dir / "probe_log.tar.gz", "wb") as f:
subprocess.run(remote_tar_cmd, stdout=f, check=True)


remote_tar_file = f"{remote_temp_dir}.tar.gz"
tar_cmd = ssh_cmd + [destination] + ["tar", "-czf", remote_tar_file, "-C", remote_temp_dir, "."]
subprocess.run(tar_cmd, check=True)

# Download the tarball to the local machine
local_tar_file = local_temp_dir / f"{remote_temp_dir.split('/')[-1]}.tar.gz"
scp_download_cmd = ["scp"] + scp_cmd[1:-2] + [f"{destination}:{remote_tar_file}", str(local_tar_file)]
typer.secho(f"PROBE log downloaded at: {scp_download_cmd[-1]}",fg=typer.colors.GREEN)
subprocess.run(scp_download_cmd, check=True)

# Clean up the remote temporary directory
remote_cleanup_cmd = ["ssh"] + ssh_args + [f"rm -rf {remote_temp_dir}"]
remote_cleanup_cmd = ssh_cmd + [destination] + [f"rm -rf {remote_temp_dir}"]
subprocess.run(remote_cleanup_cmd, check=True)

# Clean up the local temporary directory
shutil.rmtree(local_temp_dir)

raise typer.Exit(proc.returncode)

if __name__ == "__main__":
Expand Down

0 comments on commit f30f9fa

Please sign in to comment.