Skip to content

Commit

Permalink
Update RTA common.py for py3 (#2287)
Browse files Browse the repository at this point in the history
* add run-all argument and initial p2 conversion

* remove unicode

* format with black

(cherry picked from commit 0fc8006)
  • Loading branch information
brokensound77 authored and github-actions[bot] committed Sep 1, 2022
1 parent 03de32b commit 7b59f8e
Show file tree
Hide file tree
Showing 3 changed files with 143 additions and 94 deletions.
18 changes: 9 additions & 9 deletions rta/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,38 +3,38 @@
# 2.0; you may not use this file except in compliance with the Elastic License
# 2.0.

import glob
import importlib
import os
from pathlib import Path
from typing import List, Optional

from . import common

CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
CURRENT_DIR = Path(__file__).resolve().parent


def get_ttp_list(os_types=None):
def get_ttp_list(os_types: Optional[List[str]] = None) -> List[str]:
scripts = []
if os_types and not isinstance(os_types, (list, tuple)):
os_types = [os_types]

for script in sorted(glob.glob(os.path.join(CURRENT_DIR, "*.py"))):
base_name, _ = os.path.splitext(os.path.basename(script))
for script in CURRENT_DIR.glob("*.py"):
base_name = script.stem
if base_name not in ("common", "main") and not base_name.startswith("_"):
if os_types:
# Import it and skip it if it's not supported
importlib.import_module(__name__ + "." + base_name)
if not any(base_name in common.OS_MAPPING[os_type] for os_type in os_types):
continue

scripts.append(script)
scripts.append(str(script))

return scripts


def get_ttp_names(os_types=None):
def get_ttp_names(os_types: Optional[List[str]] = None) -> List[str]:
names = []
for script in get_ttp_list(os_types):
basename, ext = os.path.splitext(os.path.basename(script))
basename = Path(script).stem
names.append(basename)
return names

Expand Down
59 changes: 49 additions & 10 deletions rta/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,57 @@

import argparse
import importlib
import os
import subprocess
import sys
import time
from pathlib import Path

from . import get_ttp_names
from . import get_ttp_list, get_ttp_names
from .common import CURRENT_OS

parser = argparse.ArgumentParser("rta")
parser.add_argument("ttp_name")

parsed_args, remaining = parser.parse_known_args()
ttp_name, _ = os.path.splitext(os.path.basename(parsed_args.ttp_name))
DELAY = 1

if ttp_name not in get_ttp_names():
raise ValueError("Unknown RTA {}".format(ttp_name))

module = importlib.import_module("rta." + ttp_name)
exit(module.main(*remaining))
def run_all():
"""Run a single RTA."""
errors = []
for ttp_file in get_ttp_list(CURRENT_OS):
print(f"---- {Path(ttp_file).name} ----")
p = subprocess.Popen([sys.executable, ttp_file])
p.wait()
code = p.returncode

if p.returncode:
errors.append((ttp_file, code))

time.sleep(DELAY)
print("")

return len(errors)


def run(ttp_name: str, *args):
"""Run all RTAs compatible with OS."""
if ttp_name not in get_ttp_names():
raise ValueError(f"Unknown RTA {ttp_name}")

module = importlib.import_module("rta." + ttp_name)
return module.main(*args)


if __name__ == '__main__':
parser = argparse.ArgumentParser("rta")
parser.add_argument("--ttp-name")
parser.add_argument("--run-all", action="store_true")
parser.add_argument("--delay", type=int, help="For run-all, the delay between executions")
parsed_args, remaining = parser.parse_known_args()

if parsed_args.ttp_name and parsed_args.run_all:
raise ValueError(f"Pass --ttp-name or --run-all, not both")

if parsed_args.run_all:
exit(run_all())
else:
rta_name = Path(parsed_args.run).stem
exit(run(rta_name, *remaining))
Loading

0 comments on commit 7b59f8e

Please sign in to comment.