Skip to content

Commit

Permalink
Refactor coverage and corpus into common state
Browse files Browse the repository at this point in the history
Closes #15
  • Loading branch information
senier committed Feb 10, 2024
1 parent fbf16c5 commit 0be8cba
Show file tree
Hide file tree
Showing 8 changed files with 320 additions and 230 deletions.
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,12 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [Unreleased]

### Added

- Saving fuzzer state to file (#15)

## [2.1.0] - 2024-02-10

### Added
Expand Down Expand Up @@ -66,6 +72,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Rename to cobrafuzz
- Enable GitHub CI

[Unreleased]: https://github.com/senier/cobrafuzz/compare/v2.1.0...main
[2.1.0]: https://github.com/senier/cobrafuzz/compare/v2.0.0...v2.1.0
[2.0.0]: https://github.com/senier/cobrafuzz/compare/v1.0.12...v2.0.0
[1.0.12]: https://github.com/senier/cobrafuzz/compare/v1.0.11...v1.0.12
Expand Down
48 changes: 0 additions & 48 deletions cobrafuzz/corpus.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
from __future__ import annotations

import hashlib
import secrets
import struct
from dataclasses import dataclass
from pathlib import Path
from typing import Optional

from . import util
Expand Down Expand Up @@ -207,49 +205,3 @@ def mutate(buf: bytearray, max_input_size: Optional[int] = None) -> bytearray:
if max_input_size and len(res) > max_input_size:
res = res[:max_input_size]
return res


class Corpus:
    """
    Collection of fuzzing inputs, optionally seeded from files or
    directories and persisted to a directory on demand.
    """

    def __init__(
        self,
        seeds: Optional[list[Path]] = None,
        max_input_size: int = 4096,
        save_dir: Optional[Path] = None,
    ):
        """
        Arguments:
        ---------
        seeds: Files and directories to seed the corpus from. For a
            directory, all files directly inside it are used (non-recursive).
        max_input_size: Maximum length of inputs produced by generate_input.
        save_dir: Directory to store corpus elements into on save().
        """
        self._max_input_size = max_input_size
        self._save_dir = save_dir
        self._seeds = seeds or []

        self._inputs: list[bytearray] = []
        for path in [p for p in self._seeds if p.is_file()] + [
            f for p in self._seeds if not p.is_file() for f in p.glob("*") if f.is_file()
        ]:
            with path.open("rb") as f:
                self._inputs.append(bytearray(f.read()))
        if not self._inputs:
            # Fall back to a single empty input if no seeds were found,
            # so generate_input always has something to mutate.
            self._inputs.append(bytearray(0))

    @property
    def length(self) -> int:
        """Number of inputs currently in the corpus."""
        return len(self._inputs)

    def put(self, buf: bytearray) -> None:
        """Add buf to the corpus."""
        self._inputs.append(buf)

    def save(self) -> None:
        """Write every input to save_dir, named by its SHA-256 digest.

        No-op when no save_dir was configured.
        """
        if not self._save_dir:
            return

        # parents/exist_ok: create missing intermediate directories and avoid
        # the check-then-create race of a separate exists()/mkdir() pair.
        self._save_dir.mkdir(parents=True, exist_ok=True)

        for buf in self._inputs:
            fname = self._save_dir / hashlib.sha256(buf).hexdigest()
            with fname.open("wb") as f:
                f.write(buf)

    def generate_input(self) -> bytearray:
        """Return a new input obtained by mutating a random corpus element."""
        return mutate(
            buf=self._inputs[util.rand(len(self._inputs))],
            max_input_size=self._max_input_size,
        )
71 changes: 29 additions & 42 deletions cobrafuzz/fuzzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

import dill as pickle # type: ignore[import-untyped]

from cobrafuzz import corpus, tracer
from cobrafuzz import state as st, tracer

logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
logging.getLogger().setLevel(logging.DEBUG)
Expand All @@ -21,25 +21,6 @@
MPProcess = Union[mp.context.ForkProcess, mp.context.ForkServerProcess, mp.context.SpawnProcess]


class Coverage:
    """Accumulates the set of covered code locations seen so far."""

    def __init__(self) -> None:
        # Each entry: (caller file, caller line, current file, current line).
        self._covered: set[tuple[Optional[str], Optional[int], str, int]] = set()

    def store_and_check_improvement(
        self,
        data: set[tuple[Optional[str], Optional[int], str, int]],
    ) -> bool:
        """Merge data into the stored coverage.

        Return True iff previously unseen locations were added.
        """
        previous = len(self._covered)
        self._covered |= data
        return len(self._covered) > previous

    @property
    def total(self) -> int:
        """Number of distinct covered locations."""
        return len(self._covered)


@dataclass
class Update:
data: bytes
Expand Down Expand Up @@ -89,9 +70,8 @@ def worker( # noqa: PLR0913
result_queue: mp.Queue[Status],
close_stdout: bool,
close_stderr: bool,
max_input_size: int,
stat_frequency: int,
seeds: list[Path],
state: st.State,
) -> None:
class NullFile(io.StringIO):
"""No-op to trash stdout away."""
Expand All @@ -109,9 +89,6 @@ def write(self, arg: str) -> int:

runs = 0
last_status = time.time()
corp = corpus.Corpus(seeds=seeds, max_input_size=max_input_size)
cov = Coverage()

target = cast(Callable[[bytes], None], pickle.loads(target_bytes)) # noqa: S301

tracer.initialize()
Expand All @@ -121,11 +98,11 @@ def write(self, arg: str) -> int:

while not update_queue.empty():
update = update_queue.get()
cov.store_and_check_improvement(update.covered)
corp.put(bytearray(update.data))
state.store_coverage(update.covered)
state.put_input(bytearray(update.data))

runs += 1
data = corp.generate_input()
data = state.get_input()

try:
target(data)
Expand All @@ -134,7 +111,7 @@ def write(self, arg: str) -> int:
runs = 0
last_status = time.time()
else:
new_path = cov.store_and_check_improvement(data=tracer.get_covered())
new_path = state.store_coverage(data=tracer.get_covered())
if new_path:
result_queue.put(
Report(wid=wid, runs=runs, data=data, covered=tracer.get_covered()),
Expand Down Expand Up @@ -163,6 +140,7 @@ def __init__( # noqa: PLR0913, ref:#2
regression: bool = False,
seeds: Optional[list[Path]] = None,
start_method: Optional[str] = None,
state_file: Optional[Path] = None,
):
"""
Fuzz-test target and store crash artifacts into crash_dir.
Expand All @@ -186,6 +164,8 @@ def __init__( # noqa: PLR0913, ref:#2
start_method: Multiprocessing start method to use (spawn, forkserver or fork).
Defaults to "spawn". Do not use "fork" as it is unreliable and may lead
to deadlocks.
state_file: File to load state from. Will be updated periodically. If no file is
specified, the state will be held in memory and discarded on exit.
"""

self._current_crashes = 0
Expand Down Expand Up @@ -216,6 +196,7 @@ def __init__( # noqa: PLR0913, ref:#2
self._max_time = max_time
self._num_workers: int = num_workers or self._mp_ctx.cpu_count() - 1
self._seeds = seeds or []
self._state_file = state_file

if regression:
for error_file in crash_dir.glob("*"):
Expand Down Expand Up @@ -269,7 +250,7 @@ def _write_sample(self, buf: bytes, prefix: str = "crash-") -> None:
if len(buf) < 200:
logging.info("sample = %s", buf.hex())

def _initialize_process(self, wid: int) -> tuple[MPProcess, mp.Queue[Update]]:
def _initialize_process(self, wid: int, state: st.State) -> tuple[MPProcess, mp.Queue[Update]]:
queue: mp.Queue[Update] = self._mp_ctx.Queue()
result = self._mp_ctx.Process(
target=worker,
Expand All @@ -280,26 +261,29 @@ def _initialize_process(self, wid: int) -> tuple[MPProcess, mp.Queue[Update]]:
self._result_queue,
self._close_stdout,
self._close_stderr,
self._max_input_size,
self._stat_frequency,
self._seeds,
state,
),
)
result.start()
return result, queue

def start(self) -> None: # noqa: PLR0912
start_time = time.time()
coverage = Coverage()
corp = corpus.Corpus(self._seeds, self._max_input_size)
state = st.State(self._seeds, self._max_input_size)

self._workers = [self._initialize_process(wid) for wid in range(self._num_workers)]
if self._state_file:
state.load(self._state_file)

self._workers = [
self._initialize_process(wid=wid, state=state) for wid in range(self._num_workers)
]

logging.info(
"#0 READ units: %d workers: %d seeds: %d",
corp.length,
state.size,
self._num_workers,
corp.length,
len(self._seeds),
)

while True:
Expand Down Expand Up @@ -327,16 +311,19 @@ def start(self) -> None: # noqa: PLR0912
self._current_runs += result.runs

if isinstance(result, Error):
improvement = coverage.store_and_check_improvement(result.covered)
improvement = state.store_coverage(result.covered)
if improvement:
self._current_crashes += 1
self._write_sample(result.data)

elif isinstance(result, Report):
improvement = coverage.store_and_check_improvement(result.covered)
improvement = state.store_coverage(result.covered)
if improvement:
self._log_stats("NEW", coverage.total, corp.length)
corp.put(bytearray(result.data))
self._log_stats("NEW", state.total_coverage, state.size)
state.put_input(bytearray(result.data))

if self._state_file:
state.save(self._state_file)

for wid, (_, queue) in enumerate(self._workers):
if wid != result.wid:
Expand All @@ -349,7 +336,7 @@ def start(self) -> None: # noqa: PLR0912
assert False, f"Unhandled result type: {type(result)}"

if (time.time() - self._last_stats_time) > self._stat_frequency:
self._log_stats("PULSE", coverage.total, corp.length)
self._log_stats("PULSE", state.total_coverage, state.size)

for _, queue in self._workers:
queue.cancel_join_thread()
Expand Down
100 changes: 100 additions & 0 deletions cobrafuzz/state.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
from __future__ import annotations

import ast
import json
import logging
from pathlib import Path
from typing import Optional

from cobrafuzz import corpus, util


class LoadError(Exception):
    """Raised when a persisted state file has an unexpected version."""

    pass


class State:
    """
    Shared fuzzer state: accumulated coverage and the input corpus.

    The state can be saved to and loaded from a JSON file so that a
    fuzzing campaign can be resumed later.
    """

    def __init__(
        self,
        seeds: Optional[list[Path]] = None,
        max_input_size: int = 4096,
    ):
        """
        Arguments:
        ---------
        seeds: Files and directories to seed the corpus from. For a
            directory, all files directly inside it are used (non-recursive).
        max_input_size: Maximum length of inputs produced by get_input.
        """
        self._VERSION = 1
        self._max_input_size = max_input_size
        self._covered: set[tuple[Optional[str], Optional[int], str, int]] = set()
        self._inputs: list[bytearray] = []

        for path in [p for p in seeds or [] if p.is_file()] + [
            f for p in seeds or [] if not p.is_file() for f in p.glob("*") if f.is_file()
        ]:
            with path.open("rb") as f:
                self._inputs.append(bytearray(f.read()))
        if not self._inputs:
            # Fall back to a single empty input if no seeds were found,
            # so get_input always has something to mutate.
            self._inputs.append(bytearray(0))

    def save(self, filename: Path) -> None:
        """Persist version, coverage and corpus to filename as JSON."""
        with filename.open(mode="w+") as sf:
            json.dump(
                obj={
                    "version": self._VERSION,
                    "coverage": list(self._covered),
                    # Encode every byte as a \xNN escape so load() can
                    # reconstruct the literal unambiguously. The previous
                    # str(bytes(i))[2:-1] form broke for inputs containing
                    # quote characters (repr switches quoting style).
                    "population": [
                        "".join(f"\\x{b:02x}" for b in i) for i in self._inputs
                    ],
                },
                fp=sf,
                ensure_ascii=True,
            )

    def load(self, filename: Path) -> None:
        """Merge previously saved state from filename into this state.

        A missing file is silently ignored; a malformed file is removed
        and logged. Raise LoadError on a version mismatch.
        """
        try:
            with filename.open() as sf:
                data = json.load(sf)
                if "version" not in data or data["version"] != self._VERSION:
                    raise LoadError(
                        f"Invalid version in state file"
                        f" (found {data.get('version', 'none')},"
                        f" expected {self._VERSION})",
                    )
                self._covered |= {tuple(e) for e in data["coverage"]}
                self._inputs.extend(
                    bytearray(ast.literal_eval(f"b'{i}'")) for i in data["population"]
                )
        except FileNotFoundError:
            pass
        except (json.JSONDecodeError, TypeError, KeyError, ValueError, SyntaxError):
            # KeyError: missing "coverage"/"population" entries;
            # ValueError/SyntaxError: unparsable population literals.
            filename.unlink()
            logging.info("Malformed state file: %s", filename)
        except OSError as e:
            logging.info("Error opening state file: %s", e)

    def store_coverage(
        self,
        data: set[tuple[Optional[str], Optional[int], str, int]],
    ) -> bool:
        """
        Store coverage information. Return true if coverage has increased.

        Arguments:
        ---------
        data: coverage information to store.
        """

        covered = len(self._covered)
        self._covered |= data
        return len(self._covered) > covered

    @property
    def total_coverage(self) -> int:
        """Number of distinct covered locations."""
        return len(self._covered)

    @property
    def size(self) -> int:
        """Number of inputs currently in the corpus."""
        return len(self._inputs)

    def put_input(self, buf: bytearray) -> None:
        """Add buf to the corpus."""
        self._inputs.append(buf)

    def get_input(self) -> bytearray:
        """Return a new input obtained by mutating a random corpus element."""
        # Index directly; the previous list(self._inputs) copied the whole
        # corpus on every call for no benefit.
        return corpus.mutate(
            buf=self._inputs[util.rand(len(self._inputs))],
            max_input_size=self._max_input_size,
        )
Loading

0 comments on commit 0be8cba

Please sign in to comment.