Skip to content

Commit

Permalink
Apply suggestions from code review
Browse files Browse the repository at this point in the history
Co-authored-by: Eryk Sun <eryksun@gmail.com>
  • Loading branch information
Dobatymo and eryksun authored May 7, 2024
1 parent 6aeaa32 commit 9737f22
Showing 1 changed file with 56 additions and 32 deletions.
88 changes: 56 additions & 32 deletions Modules/mmapmodule.c
Original file line number Diff line number Diff line change
Expand Up @@ -256,28 +256,31 @@ do { \
} while (0)
#endif /* UNIX */

#if defined(MS_WIN32) && !defined(DONT_USE_SEH)
#if defined(MS_WINDOWS) && !defined(DONT_USE_SEH)
static DWORD
filter_page_exception(EXCEPTION_POINTERS *ptrs, EXCEPTION_RECORD *record)
{
*record = *ptrs->ExceptionRecord;
if (record->ExceptionCode == EXCEPTION_IN_PAGE_ERROR
|| record->ExceptionCode == EXCEPTION_ACCESS_VIOLATION) {
if (record->ExceptionCode == EXCEPTION_IN_PAGE_ERROR ||
record->ExceptionCode == EXCEPTION_ACCESS_VIOLATION)
{
return EXCEPTION_EXECUTE_HANDLER;
}
return EXCEPTION_CONTINUE_SEARCH;
}

static DWORD
filter_page_exception_method(mmap_object *self, EXCEPTION_POINTERS *ptrs, EXCEPTION_RECORD *record)
filter_page_exception_method(mmap_object *self, EXCEPTION_POINTERS *ptrs,
EXCEPTION_RECORD *record)
{
*record = *ptrs->ExceptionRecord;
if (record->ExceptionCode == EXCEPTION_IN_PAGE_ERROR
|| record->ExceptionCode == EXCEPTION_ACCESS_VIOLATION) {
if (record->ExceptionCode == EXCEPTION_IN_PAGE_ERROR ||
record->ExceptionCode == EXCEPTION_ACCESS_VIOLATION)
{

ULONG_PTR address = record->ExceptionInformation[1];
if (address >= (ULONG_PTR) self->data
&& address < (ULONG_PTR) self->data + (ULONG_PTR) self->size)
if (address >= (ULONG_PTR) self->data &&
address < (ULONG_PTR) self->data + (ULONG_PTR) self->size)
{
return EXCEPTION_EXECUTE_HANDLER;
}
Expand All @@ -286,16 +289,16 @@ filter_page_exception_method(mmap_object *self, EXCEPTION_POINTERS *ptrs, EXCEPT
}
#endif

#if defined(MS_WIN32) && !defined(DONT_USE_SEH)
#if defined(MS_WINDOWS) && !defined(DONT_USE_SEH)
#define HANDLE_INVALID_MEM(sourcecode) \
do { \
EXCEPTION_RECORD record; \
__try { \
sourcecode \
} \
__except (filter_page_exception(GetExceptionInformation(), &record)) { \
assert(record.ExceptionCode == EXCEPTION_IN_PAGE_ERROR \
|| record.ExceptionCode == EXCEPTION_ACCESS_VIOLATION); \
assert(record.ExceptionCode == EXCEPTION_IN_PAGE_ERROR || \
record.ExceptionCode == EXCEPTION_ACCESS_VIOLATION); \
if (record.ExceptionCode == EXCEPTION_IN_PAGE_ERROR) { \
NTSTATUS status = (NTSTATUS) record.ExceptionInformation[2]; \
ULONG code = LsaNtStatusToWinError(status); \
Expand All @@ -314,16 +317,16 @@ do { \
} while (0)
#endif

#if defined(MS_WIN32) && !defined(DONT_USE_SEH)
#if defined(MS_WINDOWS) && !defined(DONT_USE_SEH)
#define HANDLE_INVALID_MEM_METHOD(self, sourcecode) \
do { \
EXCEPTION_RECORD record; \
__try { \
sourcecode \
} \
__except (filter_page_exception_method(self, GetExceptionInformation(), &record)) { \
assert(record.ExceptionCode == EXCEPTION_IN_PAGE_ERROR \
|| record.ExceptionCode == EXCEPTION_ACCESS_VIOLATION); \
assert(record.ExceptionCode == EXCEPTION_IN_PAGE_ERROR || \
record.ExceptionCode == EXCEPTION_ACCESS_VIOLATION); \
if (record.ExceptionCode == EXCEPTION_IN_PAGE_ERROR) { \
NTSTATUS status = (NTSTATUS) record.ExceptionInformation[2]; \
ULONG code = LsaNtStatusToWinError(status); \
Expand All @@ -343,39 +346,45 @@ do {
#endif

int
safe_memcpy(void *restrict dest, const void *restrict src, size_t count) {
safe_memcpy(void *restrict dest, const void *restrict src, size_t count)
{
HANDLE_INVALID_MEM(
memcpy(dest, src, count);
);
return 0;
}

int
safe_byte_copy(char *dest, const char *src) {
safe_byte_copy(char *dest, const char *src)
{
HANDLE_INVALID_MEM(
*dest = *src;
);
return 0;
}

int
safe_memchr(char **out, const void *ptr, int ch, size_t count) {
safe_memchr(char **out, const void *ptr, int ch, size_t count)
{
HANDLE_INVALID_MEM(
*out = (char *) memchr(ptr, ch, count);
);
return 0;
}

int
safe_memmove(void *dest, const void *src, size_t count) {
safe_memmove(void *dest, const void *src, size_t count)
{
HANDLE_INVALID_MEM(
memmove(dest, src, count);
);
return 0;
}

int
safe_copy_from_slice(char *dest, const char *src, Py_ssize_t start, Py_ssize_t step, Py_ssize_t slicelen) {
safe_copy_from_slice(char *dest, const char *src, Py_ssize_t start,
Py_ssize_t step, Py_ssize_t slicelen)
{
HANDLE_INVALID_MEM(
size_t cur;
Py_ssize_t i;
Expand All @@ -387,7 +396,9 @@ safe_copy_from_slice(char *dest, const char *src, Py_ssize_t start, Py_ssize_t s
}

int
safe_copy_to_slice(char *dest, const char *src, Py_ssize_t start, Py_ssize_t step, Py_ssize_t slicelen) {
safe_copy_to_slice(char *dest, const char *src, Py_ssize_t start,
Py_ssize_t step, Py_ssize_t slicelen)
{
HANDLE_INVALID_MEM(
size_t cur;
Py_ssize_t i;
Expand All @@ -401,20 +412,24 @@ safe_copy_to_slice(char *dest, const char *src, Py_ssize_t start, Py_ssize_t ste

int
_safe_PyBytes_Find(Py_ssize_t *out, mmap_object *self, const char *haystack,
Py_ssize_t len_haystack, const char *needle, Py_ssize_t len_needle,
Py_ssize_t offset) {
Py_ssize_t len_haystack, const char *needle,
Py_ssize_t len_needle, Py_ssize_t offset)
{
HANDLE_INVALID_MEM_METHOD(self,
*out = _PyBytes_Find(haystack, len_haystack, needle, len_needle, offset);
);
return 0;
}

int
_safe_PyBytes_ReverseFind(Py_ssize_t *out, mmap_object *self, const char *haystack,
Py_ssize_t len_haystack, const char *needle, Py_ssize_t len_needle,
Py_ssize_t offset) {
_safe_PyBytes_ReverseFind(Py_ssize_t *out, mmap_object *self,
const char *haystack, Py_ssize_t len_haystack,
const char *needle, Py_ssize_t len_needle,
Py_ssize_t offset)
{
HANDLE_INVALID_MEM_METHOD(self,
*out = _PyBytes_ReverseFind(haystack, len_haystack, needle, len_needle, offset);
*out = _PyBytes_ReverseFind(haystack, len_haystack, needle, len_needle,
offset);
);
return 0;
}
Expand Down Expand Up @@ -505,7 +520,8 @@ mmap_read_method(mmap_object *self,
if (num_bytes < 0 || num_bytes > remaining)
num_bytes = remaining;

PyObject *result = _safe_PyBytes_FromStringAndSize(self->data + self->pos, num_bytes);
PyObject *result = _safe_PyBytes_FromStringAndSize(self->data + self->pos,
num_bytes);
if (result != NULL) {
self->pos += num_bytes;
}
Expand Down Expand Up @@ -551,7 +567,8 @@ mmap_gfind(mmap_object *self,
assert(0 <= start && start <= end && end <= self->size);
if (_safe_PyBytes_ReverseFind(&index, self,
self->data + start, end - start,
view.buf, view.len, start) < 0) {
view.buf, view.len, start) < 0)
{
result = NULL;
}
else {
Expand All @@ -562,7 +579,8 @@ mmap_gfind(mmap_object *self,
assert(0 <= start && start <= end && end <= self->size);
if (_safe_PyBytes_Find(&index, self,
self->data + start, end - start,
view.buf, view.len, start) < 0) {
view.buf, view.len, start) < 0)
{
result = NULL;
}
else {
Expand Down Expand Up @@ -1087,7 +1105,9 @@ mmap_protect_method(mmap_object *self, PyObject *args) {
return NULL;
}

if (VirtualProtect((void *) (self->data + start), length, flNewProtect, &flOldProtect) == 0) {
if (!VirtualProtect((void *) (self->data + start), length, flNewProtect,
&flOldProtect))
{
PyErr_SetFromWindowsErr(GetLastError());
return NULL;
}
Expand Down Expand Up @@ -1263,7 +1283,9 @@ mmap_subscript(mmap_object *self, PyObject *item)
if (result_buf == NULL)
return PyErr_NoMemory();

if (safe_copy_to_slice(result_buf, self->data, start, step, slicelen) < 0) {
if (safe_copy_to_slice(result_buf, self->data, start, step,
slicelen) < 0)
{
result = NULL;
}
else {
Expand Down Expand Up @@ -1390,7 +1412,9 @@ mmap_ass_subscript(mmap_object *self, PyObject *item, PyObject *value)
}
}
else {
if (safe_copy_from_slice(self->data, (char *)vbuf.buf, start, step, slicelen) < 0) {
if (safe_copy_from_slice(self->data, (char *)vbuf.buf, start, step,
slicelen) < 0)
{
result = -1;
}
}
Expand Down

0 comments on commit 9737f22

Please sign in to comment.