Skip to content

Commit

Permalink
binexport: add typing where applicable (#2106)
Browse files (browse the repository at this point in the history)
  • Loading branch information
mike-hunhoff authored May 31, 2024
1 parent 1d25c45 commit cbe83dd
Show file tree
Hide file tree
Showing 7 changed files with 163 additions and 150 deletions.
94 changes: 47 additions & 47 deletions capa/features/extractors/binexport2/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,13 @@
import capa.features.extractors.common
import capa.features.extractors.binexport2.helpers
from capa.features.extractors.binexport2.binexport2_pb2 import BinExport2
from capa.features.extractors.binexport2.binexport2_pb2.BinExport2 import CallGraph, FlowGraph

logger = logging.getLogger(__name__)


def get_binexport2(sample: Path) -> BinExport2:
    """
    parse a serialized BinExport2 message from the given file.

    args:
      sample: path to the file containing the serialized BinExport2 data.

    returns:
      a BinExport2 instance populated from the file's bytes.
    """
    be2: BinExport2 = BinExport2()
    # the entire file is read into memory and passed to the parser in one call.
    be2.ParseFromString(sample.read_bytes())
    return be2

Expand All @@ -54,15 +55,15 @@ def get_sample_from_binexport2(input_file: Path, be2: BinExport2, search_paths:
searches in the same directory as the BinExport2 file, and then in search_paths.
"""

def filename_similarity_key(p: Path):
def filename_similarity_key(p: Path) -> Tuple[int, str]:
# note closure over input_file.
# sort first by length of common prefix, then by name (for stability)
return (compute_common_prefix_length(p.name, input_file.name), p.name)

wanted_sha256 = be2.meta_information.executable_id.lower()
wanted_sha256: str = be2.meta_information.executable_id.lower()

input_directory = input_file.parent
siblings = [p for p in input_directory.iterdir() if p.is_file()]
input_directory: Path = input_file.parent
siblings: List[Path] = [p for p in input_directory.iterdir() if p.is_file()]
siblings.sort(key=filename_similarity_key, reverse=True)
for sibling in siblings:
# e.g. with open IDA files in the same directory on Windows
Expand All @@ -71,7 +72,7 @@ def filename_similarity_key(p: Path):
return sibling

for search_path in search_paths:
candidates = [p for p in search_path.iterdir() if p.is_file()]
candidates: List[Path] = [p for p in search_path.iterdir() if p.is_file()]
candidates.sort(key=filename_similarity_key, reverse=True)
for candidate in candidates:
with contextlib.suppress(PermissionError):
Expand All @@ -83,7 +84,7 @@ def filename_similarity_key(p: Path):

class BinExport2Index:
def __init__(self, be2: BinExport2):
self.be2 = be2
self.be2: BinExport2 = be2

self.callers_by_vertex_index: Dict[int, List[int]] = defaultdict(list)
self.callees_by_vertex_index: Dict[int, List[int]] = defaultdict(list)
Expand All @@ -93,9 +94,9 @@ def __init__(self, be2: BinExport2):
self.flow_graph_address_by_index: Dict[int, int] = {}

# edges that come from the given basic block
self.source_edges_by_basic_block_index: Dict[int, List[BinExport2.FlowGraph.Edge]] = defaultdict(list)
self.source_edges_by_basic_block_index: Dict[int, List[FlowGraph.Edge]] = defaultdict(list)
# edges that end up at the given basic block
self.target_edges_by_basic_block_index: Dict[int, List[BinExport2.FlowGraph.Edge]] = defaultdict(list)
self.target_edges_by_basic_block_index: Dict[int, List[FlowGraph.Edge]] = defaultdict(list)

self.vertex_index_by_address: Dict[int, int] = {}

Expand All @@ -119,9 +120,8 @@ def get_insn_address(self, insn_index: int) -> int:
return self.insn_address_by_index[insn_index]

def get_basic_block_address(self, basic_block_index: int) -> int:
    """
    resolve the address of a basic block, given its index into be2.basic_block.

    the address of a basic block is taken to be the address of its first instruction.

    args:
      basic_block_index: index into self.be2.basic_block.

    returns:
      the address of the first instruction of the basic block,
      as reported by self.get_insn_address.
    """
    basic_block: BinExport2.BasicBlock = self.be2.basic_block[basic_block_index]
    # only the first instruction index is needed, so stop iterating after one item.
    # NOTE(review): next() raises StopIteration if the block yields no instruction indices.
    first_instruction_index: int = next(self.instruction_indices(basic_block))
    return self.get_insn_address(first_instruction_index)

def _index_vertex_edges(self):
Expand All @@ -136,7 +136,7 @@ def _index_vertex_edges(self):

def _index_flow_graph_nodes(self):
for flow_graph_index, flow_graph in enumerate(self.be2.flow_graph):
function_address = self.get_basic_block_address(flow_graph.entry_basic_block_index)
function_address: int = self.get_basic_block_address(flow_graph.entry_basic_block_index)
self.flow_graph_index_by_address[function_address] = flow_graph_index
self.flow_graph_address_by_index[flow_graph_index] = function_address

Expand All @@ -154,7 +154,7 @@ def _index_call_graph_vertices(self):
if not vertex.HasField("address"):
continue

vertex_address = vertex.address
vertex_address: int = vertex.address
self.vertex_index_by_address[vertex_address] = vertex_index

def _index_data_references(self):
Expand All @@ -177,8 +177,8 @@ def _index_insn_addresses(self):

assert self.be2.instruction[0].HasField("address"), "first insn must have explicit address"

addr = 0
next_addr = 0
addr: int = 0
next_addr: int = 0
for idx, insn in enumerate(self.be2.instruction):
if insn.HasField("address"):
addr = insn.address
Expand Down Expand Up @@ -208,22 +208,22 @@ def basic_block_instructions(
the instruction instances, and their addresses.
"""
for instruction_index in self.instruction_indices(basic_block):
instruction = self.be2.instruction[instruction_index]
instruction_address = self.get_insn_address(instruction_index)
instruction: BinExport2.Instruction = self.be2.instruction[instruction_index]
instruction_address: int = self.get_insn_address(instruction_index)

yield instruction_index, instruction, instruction_address

def get_function_name_by_vertex(self, vertex_index: int) -> str:
vertex = self.be2.call_graph.vertex[vertex_index]
name = f"sub_{vertex.address:x}"
vertex: CallGraph.Vertex = self.be2.call_graph.vertex[vertex_index]
name: str = f"sub_{vertex.address:x}"
if vertex.HasField("mangled_name"):
name = vertex.mangled_name

if vertex.HasField("demangled_name"):
name = vertex.demangled_name

if vertex.HasField("library_index"):
library = self.be2.library[vertex.library_index]
library: BinExport2.Library = self.be2.library[vertex.library_index]
if library.HasField("name"):
name = f"{library.name}!{name}"

Expand All @@ -233,23 +233,25 @@ def get_function_name_by_address(self, address: int) -> str:
if address not in self.vertex_index_by_address:
return ""

vertex_index = self.vertex_index_by_address[address]
vertex_index: int = self.vertex_index_by_address[address]
return self.get_function_name_by_vertex(vertex_index)


class BinExport2Analysis:
def __init__(self, be2: BinExport2, idx: BinExport2Index, buf: bytes):
    """
    capture the inputs and run the analysis passes.

    args:
      be2: the parsed BinExport2 message.
      idx: the index built over be2.
      buf: bytes stored alongside the analysis (presumably the raw sample; confirm against callers).
    """
    self.be2: BinExport2 = be2
    self.idx: BinExport2Index = idx
    self.buf: bytes = buf
    # defaults; the passes invoked below are expected to fill these in.
    self.base_address: int = 0
    self.thunks: Dict[int, int] = {}

    self._find_base_address()
    self._compute_thunks()

def _find_base_address(self):
sections_with_perms = filter(lambda s: s.flag_r or s.flag_w or s.flag_x, self.be2.section)
sections_with_perms: Iterator[BinExport2.Section] = filter(
lambda s: s.flag_r or s.flag_w or s.flag_x, self.be2.section
)
# assume the lowest address is the base address.
# this works as long as BinExport doesn't record other
# libraries mapped into memory.
Expand All @@ -259,15 +261,13 @@ def _find_base_address(self):

def _compute_thunks(self):
for addr, idx in self.idx.vertex_index_by_address.items():
vertex = self.be2.call_graph.vertex[idx]
if not capa.features.extractors.binexport2.helpers.is_vertex_type(
vertex, BinExport2.CallGraph.Vertex.Type.THUNK
):
vertex: CallGraph.Vertex = self.be2.call_graph.vertex[idx]
if not capa.features.extractors.binexport2.helpers.is_vertex_type(vertex, CallGraph.Vertex.Type.THUNK):
continue

curr_idx = idx
curr_idx: int = idx
for _ in range(capa.features.common.THUNK_CHAIN_DEPTH_DELTA):
thunk_callees = self.idx.callees_by_vertex_index[curr_idx]
thunk_callees: List[int] = self.idx.callees_by_vertex_index[curr_idx]
# if this doesn't hold, then it doesn't seem like this is a thunk,
# because either, len is:
# 0 and the thunk doesn't point to anything, or
Expand All @@ -280,11 +280,11 @@ def _compute_thunks(self):

assert len(thunk_callees) == 1

thunked_idx = thunk_callees[0]
thunked_vertex = self.be2.call_graph.vertex[thunked_idx]
thunked_idx: int = thunk_callees[0]
thunked_vertex: CallGraph.Vertex = self.be2.call_graph.vertex[thunked_idx]

if not capa.features.extractors.binexport2.helpers.is_vertex_type(
thunked_vertex, BinExport2.CallGraph.Vertex.Type.THUNK
thunked_vertex, CallGraph.Vertex.Type.THUNK
):
assert thunked_vertex.HasField("address")

Expand Down Expand Up @@ -321,21 +321,21 @@ class AddressSpace:
memory_regions: Tuple[MemoryRegion, ...]

def read_memory(self, address: int, length: int) -> bytes:
    """
    read bytes from the first memory region that contains the given address.

    args:
      address: the address to read from; self.base_address is subtracted
        to get the offset (rva) that is matched against the regions.
      length: the maximum number of bytes to return.

    returns:
      up to `length` bytes, starting at the address.
      the result is shorter when the slice runs past the end of the region's buffer.

    raises:
      AddressNotMappedError: when no region contains the address.
    """
    rva: int = address - self.base_address
    for region in self.memory_regions:
        if region.contains(rva):
            # translate the rva into an offset within this region's buffer.
            offset: int = rva - region.address
            return region.buf[offset : offset + length]

    raise AddressNotMappedError(address)

@classmethod
def from_pe(cls, pe: PE, base_address: int):
regions = []
regions: List[MemoryRegion] = []
for section in pe.sections:
address = section.VirtualAddress
size = section.Misc_VirtualSize
buf = section.get_data()
address: int = section.VirtualAddress
size: int = section.Misc_VirtualSize
buf: bytes = section.get_data()

if len(buf) != size:
# pad the section with NULLs
Expand All @@ -349,16 +349,16 @@ def from_pe(cls, pe: PE, base_address: int):

@classmethod
def from_elf(cls, elf: ELFFile, base_address: int):
regions = []
regions: List[MemoryRegion] = []

# ELF segments are for runtime data,
# ELF sections are for link-time data.
for segment in elf.iter_segments():
# assume p_align is consistent with addresses here.
# otherwise, should harden this loader.
segment_rva = segment.header.p_vaddr
segment_size = segment.header.p_memsz
segment_data = segment.data()
segment_rva: int = segment.header.p_vaddr
segment_size: int = segment.header.p_memsz
segment_data: bytes = segment.data()

if len(segment_data) < segment_size:
# pad the section with NULLs
Expand All @@ -373,10 +373,10 @@ def from_elf(cls, elf: ELFFile, base_address: int):
@classmethod
def from_buf(cls, buf: bytes, base_address: int):
    """
    build an address space from the raw bytes of a file, dispatching on its leading magic bytes.

    args:
      buf: the raw file contents.
      base_address: passed through to from_pe/from_elf.

    returns:
      the result of cls.from_pe or cls.from_elf, depending on the detected format.

    raises:
      NotImplementedError: when buf starts with neither MATCH_PE nor MATCH_ELF.
    """
    if buf.startswith(capa.features.extractors.common.MATCH_PE):
        pe: PE = PE(data=buf)
        return cls.from_pe(pe, base_address)
    elif buf.startswith(capa.features.extractors.common.MATCH_ELF):
        # ELFFile wants a file-like object, so wrap the bytes.
        elf: ELFFile = ELFFile(io.BytesIO(buf))
        return cls.from_elf(elf, base_address)
    else:
        raise NotImplementedError("file format address space")
Expand Down
9 changes: 5 additions & 4 deletions capa/features/extractors/binexport2/basicblock.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,14 @@
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and limitations under the License.

from typing import Tuple, Iterator
from typing import List, Tuple, Iterator

from capa.features.common import Feature, Characteristic
from capa.features.address import Address, AbsoluteVirtualAddress
from capa.features.basicblock import BasicBlock
from capa.features.extractors.binexport2 import FunctionContext, BasicBlockContext
from capa.features.extractors.base_extractor import BBHandle, FunctionHandle
from capa.features.extractors.binexport2.binexport2_pb2.BinExport2 import FlowGraph


def extract_bb_tight_loop(fh: FunctionHandle, bbh: BBHandle) -> Iterator[Tuple[Feature, Address]]:
Expand All @@ -21,10 +22,10 @@ def extract_bb_tight_loop(fh: FunctionHandle, bbh: BBHandle) -> Iterator[Tuple[F

idx = fhi.ctx.idx

basic_block_index = bbi.basic_block_index
target_edges = idx.target_edges_by_basic_block_index[basic_block_index]
basic_block_index: int = bbi.basic_block_index
target_edges: List[FlowGraph.Edge] = idx.target_edges_by_basic_block_index[basic_block_index]
if basic_block_index in (e.source_basic_block_index for e in target_edges):
basic_block_address = idx.get_basic_block_address(basic_block_index)
basic_block_address: int = idx.get_basic_block_address(basic_block_index)
yield Characteristic("tight loop"), AbsoluteVirtualAddress(basic_block_address)


Expand Down
41 changes: 21 additions & 20 deletions capa/features/extractors/binexport2/extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,19 +34,20 @@
StaticFeatureExtractor,
)
from capa.features.extractors.binexport2.binexport2_pb2 import BinExport2
from capa.features.extractors.binexport2.binexport2_pb2.BinExport2 import CallGraph

logger = logging.getLogger(__name__)


class BinExport2FeatureExtractor(StaticFeatureExtractor):
def __init__(self, be2: BinExport2, buf: bytes):
super().__init__(hashes=SampleHashes.from_bytes(buf))
self.be2 = be2
self.buf = buf
self.idx = BinExport2Index(self.be2)
self.analysis = BinExport2Analysis(self.be2, self.idx, self.buf)
address_space = AddressSpace.from_buf(buf, self.analysis.base_address)
self.ctx = AnalysisContext(self.buf, self.be2, self.idx, self.analysis, address_space)
self.be2: BinExport2 = be2
self.buf: bytes = buf
self.idx: BinExport2Index = BinExport2Index(self.be2)
self.analysis: BinExport2Analysis = BinExport2Analysis(self.be2, self.idx, self.buf)
address_space: AddressSpace = AddressSpace.from_buf(buf, self.analysis.base_address)
self.ctx: AnalysisContext = AnalysisContext(self.buf, self.be2, self.idx, self.analysis, address_space)

self.global_features: List[Tuple[Feature, Address]] = []
self.global_features.extend(list(capa.features.extractors.common.extract_format(self.buf)))
Expand All @@ -57,27 +58,25 @@ def __init__(self, be2: BinExport2, buf: bytes):
# and gradually relax restrictions as they're tested.
# https://github.com/mandiant/capa/issues/1755

def get_base_address(self) -> AbsoluteVirtualAddress:
    """return self.analysis.base_address wrapped as an AbsoluteVirtualAddress."""
    return AbsoluteVirtualAddress(self.analysis.base_address)

def extract_global_features(self) -> Iterator[Tuple[Feature, Address]]:
    """yield each (feature, address) pair stored in self.global_features."""
    yield from self.global_features

def extract_file_features(self) -> Iterator[Tuple[Feature, Address]]:
    """yield the file-scope features produced by the binexport2 file extractor for self.be2 and self.buf."""
    yield from capa.features.extractors.binexport2.file.extract_features(self.be2, self.buf)

def get_functions(self) -> Iterator[FunctionHandle]:
for flow_graph_index, flow_graph in enumerate(self.be2.flow_graph):
entry_basic_block_index = flow_graph.entry_basic_block_index
flow_graph_address = self.idx.get_basic_block_address(entry_basic_block_index)
entry_basic_block_index: int = flow_graph.entry_basic_block_index
flow_graph_address: int = self.idx.get_basic_block_address(entry_basic_block_index)

vertex_idx = self.idx.vertex_index_by_address[flow_graph_address]
be2_vertex = self.be2.call_graph.vertex[vertex_idx]
vertex_idx: int = self.idx.vertex_index_by_address[flow_graph_address]
be2_vertex: CallGraph.Vertex = self.be2.call_graph.vertex[vertex_idx]

# skip thunks
if capa.features.extractors.binexport2.helpers.is_vertex_type(
be2_vertex, BinExport2.CallGraph.Vertex.Type.THUNK
):
if capa.features.extractors.binexport2.helpers.is_vertex_type(be2_vertex, CallGraph.Vertex.Type.THUNK):
continue

yield FunctionHandle(
Expand All @@ -90,11 +89,11 @@ def extract_function_features(self, fh: FunctionHandle) -> Iterator[Tuple[Featur

def get_basic_blocks(self, fh: FunctionHandle) -> Iterator[BBHandle]:
fhi: FunctionContext = fh.inner
flow_graph_index = fhi.flow_graph_index
flow_graph = self.be2.flow_graph[flow_graph_index]
flow_graph_index: int = fhi.flow_graph_index
flow_graph: BinExport2.FlowGraph = self.be2.flow_graph[flow_graph_index]

for basic_block_index in flow_graph.basic_block_index:
basic_block_address = self.idx.get_basic_block_address(basic_block_index)
basic_block_address: int = self.idx.get_basic_block_address(basic_block_index)
yield BBHandle(
address=AbsoluteVirtualAddress(basic_block_address),
inner=BasicBlockContext(basic_block_index),
Expand All @@ -112,5 +111,7 @@ def get_instructions(self, fh: FunctionHandle, bbh: BBHandle) -> Iterator[InsnHa
inner=InstructionContext(instruction_index),
)

def extract_insn_features(
    self, fh: FunctionHandle, bbh: BBHandle, ih: InsnHandle
) -> Iterator[Tuple[Feature, Address]]:
    """
    yield the instruction-scope features for the given instruction.

    args:
      fh: handle of the function containing the instruction.
      bbh: handle of the basic block containing the instruction.
      ih: handle of the instruction itself.

    all three handles are passed unchanged to the binexport2 insn extractor.
    """
    yield from capa.features.extractors.binexport2.insn.extract_features(fh, bbh, ih)
Loading

0 comments on commit cbe83dd

Please sign in to comment.