Support parallel fuzzing and crash limit
Closes #11
Closes #12
senier committed Jan 28, 2024
1 parent 4e1bfbd commit 36366db
Showing 4 changed files with 260 additions and 99 deletions.
249 changes: 184 additions & 65 deletions cobrafuzz/fuzzer.py
@@ -4,9 +4,9 @@
 import io
 import logging
 import multiprocessing as mp
-import os
 import sys
 import time
+from dataclasses import dataclass
 from pathlib import Path
 from typing import Callable, Optional

@@ -23,9 +23,59 @@
 SAMPLING_WINDOW = 5  # IN SECONDS
 
 
+class Coverage:
+    def __init__(self) -> None:
+        self._covered: set[tuple[Optional[str], Optional[int], str, int]] = set()
+
+    def store_and_check_improvement(
+        self,
+        data: set[tuple[Optional[str], Optional[int], str, int]],
+    ) -> bool:
+        covered = len(self._covered)
+        self._covered |= data
+        if len(self._covered) > covered:
+            return True
+        return False
+
+    @property
+    def total(self) -> int:
+        return len(self._covered)
+
+
+@dataclass
+class Job:
+    jid: int
+    data: bytes
+
+
+@dataclass
+class JobData:
+    data: bytes
+    timestamp: Optional[float] = None
+    wid: Optional[int] = None
+
+
+@dataclass
+class Status:
+    wid: int
+    jid: int
+
+
+@dataclass
+class Report(Status):
+    covered: set[tuple[Optional[str], Optional[int], str, int]]
+
+
+@dataclass
+class Error(Status):
+    error: str
+
+
 def worker(
+    wid: int,
     target: Callable[[bytes], None],
-    child_conn: mp.connection.Connection,
+    request_queue: mp.Queue[Job],
+    result_queue: mp.Queue[Status],
     close_fd_mask: int,
 ) -> None:
     # Silence the fuzzee's noise
@@ -43,31 +93,36 @@ def write(self, arg: str) -> int:
     sys.stderr = DummyFile()
 
     tracer.initialize()
 
     while True:
-        buf = child_conn.recv_bytes()
+        job = request_queue.get()
 
+        result_queue.put(Status(wid=wid, jid=job.jid))
+        tracer.reset()
+
         try:
-            target(buf)
-        except Exception as e:
-            logging.exception(buf)
-            child_conn.send(e)
+            target(job.data)
+        except Exception as e:  # noqa: BLE001
+            result_queue.put(Error(wid=wid, jid=job.jid, error=str(e)))
             break
         else:
-            child_conn.send_bytes(b"%d" % tracer.get_coverage())
+            result_queue.put(Report(wid=wid, jid=job.jid, covered=tracer.get_covered()))
 
 
 class Fuzzer:
     def __init__(  # noqa: PLR0913, ref:#2
         self,
         target: Callable[[bytes], None],
         crash_dir: Path,
+        num_workers: int = 1,
         dirs: Optional[list[Path]] = None,
         artifact_name: Optional[str] = None,
         rss_limit_mb: int = 2048,
         timeout: int = 120,
         regression: bool = False,
         max_input_size: int = 4096,
         close_fd_mask: int = 0,
-        runs: int = -1,
+        runs: Optional[int] = None,
+        max_crashes: Optional[int] = None,
     ):
         self._target = target
         self._crash_dir = crash_dir
@@ -78,36 +133,42 @@ def __init__(  # noqa: PLR0913, ref:#2
         self._regression = regression
         self._close_fd_mask = close_fd_mask
         self._corpus = corpus.Corpus(self._dirs, max_input_size)
-        self._total_executions = 0
-        self._executions_in_sample = 0
+        self._jid = 0
+        self._last_jid = 0
         self._last_sample_time = time.time()
-        self._total_coverage = 0
-        self._p: Optional[mp.Process] = None
+        self._workers: list[mp.Process] = []
+        self._num_workers: int = num_workers
+
+        # TODO(senier): make queue sizes configurable
+        self._request_queue: mp.Queue[Job] = mp.Queue()
+        self._result_queue: mp.Queue[Status] = mp.Queue()
 
         self.runs = runs
 
-    def log_stats(self, log_type: str) -> int:
-        assert self._p is not None
+        self._max_crashes = max_crashes
+        self.crashes = 0
+
+    def log_stats(self, log_type: str, total_coverage: int) -> int:
         rss: int = (
-            (
-                psutil.Process(self._p.pid).memory_info().rss
-                + psutil.Process(os.getpid()).memory_info().rss
-            )
-            / 1024
-            / 1024
+            sum(psutil.Process(p.pid).memory_info().rss for p in self._workers) // 1024 // 1024
        )
 
         end_time = time.time()
-        execs_per_second = int(self._executions_in_sample / (end_time - self._last_sample_time))
+        execs_per_second = int(
+            (self._jid - self._last_jid) / (end_time - self._last_sample_time),
+        )
 
         self._last_sample_time = time.time()
-        self._executions_in_sample = 0
+        self._last_jid = self._jid
         logging.info(
-            "#%d %s cov: %d corp: %d exec/s: %d rss: %d MB",
-            self._total_executions,
+            "#%d %s cov: %d corp: %d exec/s: %d rss: %d MB crashes: %d",
+            self._jid,
             log_type,
-            self._total_coverage,
+            total_coverage,
             self._corpus.length,
             execs_per_second,
             rss,
+            self.crashes,
         )
         return rss

@@ -127,51 +188,109 @@ def write_sample(self, buf: bytes, prefix: str = "crash-") -> None:
         if len(buf) < 200:
             logging.info("sample = %s", buf.hex())
 
-    def start(self) -> None:
-        logging.info("#0 READ units: %d", self._corpus.length)
+    def initialize_process(self, wid: int) -> mp.Process:
+        result = mp.Process(
+            target=worker,
+            args=(
+                wid,
+                self._target,
+                self._request_queue,
+                self._result_queue,
+                self._close_fd_mask,
+            ),
+        )
+        result.start()
+        return result
+
+    def rss(self, wid: int) -> int:
+        return int(psutil.Process(self._workers[wid].pid).memory_info().rss) // (1024 * 1024)
+
+    def start(self) -> None:  # noqa: PLR0912, PLR0915
+        logging.info("#0 READ units: %d workers: %d", self._corpus.length, self._num_workers)
         exit_code = 0
-        parent_conn, child_conn = mp.Pipe()
-        self._p = mp.Process(target=worker, args=(self._target, child_conn, self._close_fd_mask))
-        self._p.start()
 
+        coverage = Coverage()
+        jobs: dict[int, JobData] = {}
+        self._workers = [self.initialize_process(wid) for wid in range(self._num_workers)]
+
         while True:
-            if self.runs != -1 and self._total_executions >= self.runs:
-                self._p.terminate()
-                logging.info("did %d runs, stopping now.", self.runs)
+            if self.runs is not None and self._jid >= self.runs:
+                for p in self._workers:
+                    p.terminate()
+                logging.info("Performed %d runs, stopping.", self.runs)
                 break
 
-            buf = self._corpus.generate_input()
-            parent_conn.send_bytes(buf)
-            if not parent_conn.poll(self._timeout):
-                self._p.kill()
-                logging.info("=================================================================")
-                logging.info("timeout reached. testcase took: %d", self._timeout)
-                self.write_sample(buf, prefix="timeout-")
+            if self._max_crashes is not None and self.crashes >= self._max_crashes:
+                for p in self._workers:
+                    p.terminate()
+                logging.info("Found %d crashes, stopping.", self.crashes)
                 break
 
-            try:
-                total_coverage = int(parent_conn.recv_bytes())
-            except ValueError:
-                self.write_sample(buf)
-                exit_code = 76
-                break
+            self._jid += 1
+            request = Job(jid=self._jid, data=self._corpus.generate_input())
+            jobs[self._jid] = JobData(data=request.data)
 
-            self._total_executions += 1
-            self._executions_in_sample += 1
-            rss = 0
-            if total_coverage > self._total_coverage:
-                rss = self.log_stats("NEW")
-                self._total_coverage = total_coverage
-                self._corpus.put(buf)
-            else:
-                if (time.time() - self._last_sample_time) > SAMPLING_WINDOW:
-                    rss = self.log_stats("PULSE")
 
-            if rss > self._rss_limit_mb:
-                logging.info("MEMORY OOM: exceeded %d MB. Killing worker", self._rss_limit_mb)
-                self.write_sample(buf)
-                self._p.kill()
-                break
+            if self._request_queue.empty():
+                self._request_queue.put(request)
 
+            if not self._result_queue.empty():
+                result = self._result_queue.get()
+
+                if isinstance(result, Report):
+                    improvement = coverage.store_and_check_improvement(result.covered)
+                    if improvement:
+                        self.log_stats("NEW", coverage.total)
+                        self._corpus.put(bytearray(jobs[result.jid].data))
+                    del jobs[result.jid]
+
+                elif isinstance(result, Error):
+                    # TODO(senier): Extend to write error message
+                    self.write_sample(jobs[result.jid].data)
+                    self.crashes += 1
+                    del jobs[result.jid]
+
+                elif isinstance(result, Status):
+                    assert result.jid in jobs
+                    jobs[result.jid].wid = result.wid
+                    jobs[result.jid].timestamp = time.time()
+
+                else:
+                    assert False, f"Unhandled result type: {type(result)}"
+
+            if (time.time() - self._last_sample_time) > SAMPLING_WINDOW:
+                self.log_stats("PULSE", coverage.total)
+
+            for job in jobs.values():
+                if job.wid is None:
+                    continue
+
+                timeout = job.timestamp and (time.time() - job.timestamp) > self._timeout
+                oom = self.rss(job.wid) > self._rss_limit_mb
+
+                if not timeout and not oom:
+                    continue
+
+                self.crashes += 1
+
+                if timeout:
+                    logging.info("Timeout reached. Testcase took: %d s", self._timeout)
+                    prefix = "timeout-"
+
+                if oom:
+                    logging.info(
+                        "OOM: Worker %d exceeded %d MB. Killing.",
+                        job.wid,
+                        self._rss_limit_mb,
+                    )
+                    prefix = "oom-"
+
+                self._workers[job.wid].kill()
+                self._workers[job.wid] = self.initialize_process(job.wid)
+
+                self.write_sample(job.data, prefix=prefix)
+                self._workers[job.wid].kill()
+                self._workers[job.wid] = self.initialize_process(job.wid)
 
-        self._p.join()
+        for p in self._workers:
+            p.join()
         sys.exit(exit_code)
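
For orientation, here is a minimal sketch of driving the updated Fuzzer API directly. The target function check and the crashes directory are placeholder assumptions; the parameters mirror the new constructor signature shown in the diff above. Note that start() runs until a stop condition is reached and then terminates the process via sys.exit().

from pathlib import Path

from cobrafuzz import fuzzer


def check(data: bytes) -> None:
    # Placeholder target: any Callable[[bytes], None] works; an uncaught
    # exception is treated as a crash and written to the crash directory.
    if data.startswith(b"FUZZ"):
        raise ValueError("boom")


f = fuzzer.Fuzzer(
    target=check,
    crash_dir=Path("crashes"),
    num_workers=4,    # new: run four fuzzing workers in parallel
    max_crashes=1,    # new: stop after the first crash
    runs=None,        # new default: no fixed limit on the number of runs
)
f.start()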
17 changes: 15 additions & 2 deletions cobrafuzz/main.py
@@ -26,6 +26,18 @@ def __call__(self) -> None:
             required=True,
             help="crash output directory",
         )
+        parser.add_argument(
+            "-j",
+            "--num-workers",
+            type=int,
+            default=1,
+            help="number of parallel workers",
+        )
+        parser.add_argument(
+            "--max-crashes",
+            type=int,
+            help="maximum number of crashes before exiting",
+        )
         parser.add_argument(
             "--artifact-name",
             type=str,
@@ -53,8 +65,7 @@ def __call__(self) -> None:
         parser.add_argument(
             "--runs",
             type=int,
-            default=-1,
-            help="Number of individual test runs, -1 (the default) to run indefinitely.",
+            help="Number of individual test runs",
         )
         parser.add_argument(
             "--timeout",
@@ -66,6 +77,7 @@
         f = fuzzer.Fuzzer(
             target=self.function,
             crash_dir=args.crash_dir,
+            num_workers=args.num_workers,
             dirs=args.dirs,
             artifact_name=args.artifact_name,
             rss_limit_mb=args.rss_limit_mb,
@@ -74,5 +86,6 @@
             max_input_size=args.max_input_size,
             close_fd_mask=args.close_fd_mask,
             runs=args.runs,
+            max_crashes=args.max_crashes,
         )
         f.start()
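
A hedged sketch of how the new command-line options might be used from a fuzz driver. The decorator name CobraFuzz and the file name fuzz.py are assumptions (the decorator class itself is only partially visible in this diff); the flags correspond to the argparse options added above.

from cobrafuzz.main import CobraFuzz


@CobraFuzz
def fuzz(buf: bytes) -> None:
    # Example target: exercise UTF-8 decoding, ignoring expected errors.
    try:
        buf.decode("utf-8")
    except UnicodeDecodeError:
        pass


if __name__ == "__main__":
    # Invoked e.g. as: python fuzz.py --crash-dir crashes -j 4 --max-crashes 1 --runs 100000
    fuzz()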