From c6102615907715249d02390a51bc20132afc4192 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C5=A0imon=20Luka=C5=A1=C3=ADk?= Date: Sun, 17 Aug 2025 12:08:46 +0200 Subject: [PATCH 1/2] perf: do not search and dlopen on every sql cmd This extracts a handle of zen-internal FFI to a thread safe singleton. Ensuring the DLL is searched for and loaded only once. Previously, we would run multiple syscalls to get the location of the dll/so and then built the python datastructures (CDLL object) that has to be later GCed. I haven't microbenchmarked the improvements rigorously, as I feel GC savings may contribute considerably in real-world scenario. Although, poetry run pytest -rP ./aikido_zen/vulnerabilities/sql_injection/ improves by over 4.3% on my set-up. --- .../vulnerabilities/sql_injection/__init__.py | 32 +----------- .../sql_injection/zen_internal_ffi.py | 49 +++++++++++++++++++ 2 files changed, 51 insertions(+), 30 deletions(-) create mode 100644 aikido_zen/vulnerabilities/sql_injection/zen_internal_ffi.py diff --git a/aikido_zen/vulnerabilities/sql_injection/__init__.py b/aikido_zen/vulnerabilities/sql_injection/__init__.py index 745c8b5b6..745241fc3 100644 --- a/aikido_zen/vulnerabilities/sql_injection/__init__.py +++ b/aikido_zen/vulnerabilities/sql_injection/__init__.py @@ -3,11 +3,8 @@ """ import re -import ctypes from aikido_zen.helpers.logging import logger -from .map_dialect_to_rust_int import map_dialect_to_rust_int -from .get_lib_path import get_binary_path -from ...helpers.encode_safely import encode_safely +from .zen_internal_ffi import ZenInternal def detect_sql_injection(query, user_input, dialect): @@ -20,32 +17,7 @@ def detect_sql_injection(query, user_input, dialect): if should_return_early(query_l, userinput_l): return False - internals_lib = ctypes.CDLL(get_binary_path()) - internals_lib.detect_sql_injection.argtypes = [ - ctypes.POINTER(ctypes.c_uint8), - ctypes.c_size_t, - ctypes.POINTER(ctypes.c_uint8), - ctypes.c_size_t, - ctypes.c_int, - ] - internals_lib.detect_sql_injection.restype = ctypes.c_int - - # Parse input variables for rust function - query_bytes = encode_safely(query_l) - userinput_bytes = encode_safely(userinput_l) - query_buffer = (ctypes.c_uint8 * len(query_bytes)).from_buffer_copy(query_bytes) - userinput_buffer = (ctypes.c_uint8 * len(userinput_bytes)).from_buffer_copy( - userinput_bytes - ) - dialect_int = map_dialect_to_rust_int(dialect) - - c_int_res = internals_lib.detect_sql_injection( - query_buffer, - len(query_bytes), - userinput_buffer, - len(userinput_bytes), - dialect_int, - ) + c_int_res = ZenInternal().detect_sql_injection(query_l, userinput_l, dialect) # This means that an error occurred in the library if c_int_res == 2: diff --git a/aikido_zen/vulnerabilities/sql_injection/zen_internal_ffi.py b/aikido_zen/vulnerabilities/sql_injection/zen_internal_ffi.py new file mode 100644 index 000000000..985de95fa --- /dev/null +++ b/aikido_zen/vulnerabilities/sql_injection/zen_internal_ffi.py @@ -0,0 +1,49 @@ +""" +Interface for calling zen-internal shared library +""" + +import ctypes +import threading +from .get_lib_path import get_binary_path +from .map_dialect_to_rust_int import map_dialect_to_rust_int +from ...helpers.encode_safely import encode_safely + + +class __Singleton(type): + _instances = {} + _lock = threading.Lock() # Ensures thread safety + + def __call__(cls, *args, **kwargs): + with cls._lock: # Lock to make the check-and-create operation atomic + if cls not in cls._instances: + cls._instances[cls] = super().__call__(*args, **kwargs) + return cls._instances[cls] + + +class ZenInternal(metaclass=__Singleton): + def __init__(self): + self._lib = ctypes.CDLL(get_binary_path()) + self._lib.detect_sql_injection.argtypes = [ + ctypes.POINTER(ctypes.c_uint8), + ctypes.c_size_t, + ctypes.POINTER(ctypes.c_uint8), + ctypes.c_size_t, + ctypes.c_int, + ] + self._lib.detect_sql_injection.restype = ctypes.c_int + + def detect_sql_injection(self, query, user_input, dialect): + query_bytes = encode_safely(query) + userinput_bytes = encode_safely(user_input) + query_buffer = (ctypes.c_uint8 * len(query_bytes)).from_buffer_copy(query_bytes) + userinput_buffer = (ctypes.c_uint8 * len(userinput_bytes)).from_buffer_copy( + userinput_bytes + ) + dialect_int = map_dialect_to_rust_int(dialect) + return self._lib.detect_sql_injection( + query_buffer, + len(query_bytes), + userinput_buffer, + len(userinput_bytes), + dialect_int, + ) From 171054081b036b165e6d834d80e1001f6b5fef03 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C5=A0imon=20Luka=C5=A1=C3=ADk?= Date: Sun, 17 Aug 2025 12:19:10 +0200 Subject: [PATCH 2/2] refactor: expand map_dialect_to_rust_int fn to its natural habitat --- .../sql_injection/map_dialect_to_rust_int.py | 19 ------------------- .../sql_injection/zen_internal_ffi.py | 12 ++++++++++-- 2 files changed, 10 insertions(+), 21 deletions(-) delete mode 100644 aikido_zen/vulnerabilities/sql_injection/map_dialect_to_rust_int.py diff --git a/aikido_zen/vulnerabilities/sql_injection/map_dialect_to_rust_int.py b/aikido_zen/vulnerabilities/sql_injection/map_dialect_to_rust_int.py deleted file mode 100644 index b18b96b6a..000000000 --- a/aikido_zen/vulnerabilities/sql_injection/map_dialect_to_rust_int.py +++ /dev/null @@ -1,19 +0,0 @@ -""" -Exports map_dialect_to_rust_int -""" - -DIALECTS = { - "generic": 0, - "clickhouse": 3, - "mysql": 8, - "postgres": 9, - "sqlite": 12, -} - - -def map_dialect_to_rust_int(dialect): - """ - This takes the string dialect as input and maps it to a rust integer - Reference : [rust lib]/src/sql_injection/helpers/select_dialect_based_on_enum.rs - """ - return DIALECTS[dialect] diff --git a/aikido_zen/vulnerabilities/sql_injection/zen_internal_ffi.py b/aikido_zen/vulnerabilities/sql_injection/zen_internal_ffi.py index 985de95fa..b046c2994 100644 --- a/aikido_zen/vulnerabilities/sql_injection/zen_internal_ffi.py +++ b/aikido_zen/vulnerabilities/sql_injection/zen_internal_ffi.py @@ -5,7 +5,6 @@ import ctypes import threading from .get_lib_path import get_binary_path -from .map_dialect_to_rust_int import map_dialect_to_rust_int from ...helpers.encode_safely import encode_safely @@ -21,6 +20,15 @@ def __call__(cls, *args, **kwargs): class ZenInternal(metaclass=__Singleton): + # Reference : [rust lib]/src/sql_injection/helpers/select_dialect_based_on_enum.rs + SQL_DIALECTS = { + "generic": 0, + "clickhouse": 3, + "mysql": 8, + "postgres": 9, + "sqlite": 12, + } + def __init__(self): self._lib = ctypes.CDLL(get_binary_path()) self._lib.detect_sql_injection.argtypes = [ @@ -39,7 +47,7 @@ def detect_sql_injection(self, query, user_input, dialect): userinput_buffer = (ctypes.c_uint8 * len(userinput_bytes)).from_buffer_copy( userinput_bytes ) - dialect_int = map_dialect_to_rust_int(dialect) + dialect_int = self.SQL_DIALECTS[dialect] return self._lib.detect_sql_injection( query_buffer, len(query_bytes),