Skip to content

Commit

Permalink
fix segfaults on tests in debug builds for PyPy
Browse files Browse the repository at this point in the history
  • Loading branch information
davidhewitt committed Jul 5, 2023
1 parent d08d269 commit ba938b7
Show file tree
Hide file tree
Showing 9 changed files with 64 additions and 41 deletions.
10 changes: 9 additions & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -142,14 +142,22 @@ jobs:

# test with a debug build as it picks up errors which optimised release builds do not
test-debug:
name: test-debug ${{ matrix.python-version }}
runs-on: ubuntu-latest

strategy:
fail-fast: false
matrix:
python-version:
- '3.11'
- 'pypy3.9'

steps:
- uses: actions/checkout@v3
- name: set up python
uses: actions/setup-python@v4
with:
python-version: '3.11'
python-version: ${{ matrix.python-version }}

- run: pip install -r tests/requirements.txt
- run: pip install -e . --config-settings=build-args='--profile dev'
Expand Down
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ pub fn get_version() -> String {
fn _pydantic_core(py: Python, m: &PyModule) -> PyResult<()> {
m.add("__version__", get_version())?;
m.add("build_profile", env!("PROFILE"))?;
m.add("_recursion_limit", recursion_guard::RECURSION_GUARD_LIMIT)?;
m.add("PydanticUndefined", PydanticUndefinedType::new(py))?;
m.add_class::<PydanticUndefinedType>()?;
m.add_class::<PySome>()?;
Expand Down
9 changes: 6 additions & 3 deletions src/recursion_guard.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,14 @@ type RecursionKey = (
/// State used to detect runaway recursion.
///
/// Two mechanisms are combined: a set of ids checked by `contains_or_insert`,
/// and a plain depth counter that acts as a backup if the id check fails (see #143).
#[derive(Debug, Clone, Default)]
pub struct RecursionGuard {
// ids recorded by `contains_or_insert`
ids: Option<AHashSet<RecursionKey>>,
// backup depth counter, compared against `RECURSION_GUARD_LIMIT` by `incr_depth`
// depth could be a hashmap {validator_id => depth} but for simplicity and performance it's easier to just
// use one number for all validators
depth: u16,
}

/// Hard cap on recursion depth, used to avoid stack overflows when rampant recursion occurs.
///
/// The value is 50 on wasm targets and 255 everywhere else.
#[cfg(target_family = "wasm")]
pub const RECURSION_GUARD_LIMIT: u16 = 50;
#[cfg(not(target_family = "wasm"))]
pub const RECURSION_GUARD_LIMIT: u16 = 255;

impl RecursionGuard {
// insert a new id into the set, return whether the set already had the id in it
pub fn contains_or_insert(&mut self, obj_id: usize, node_id: usize) -> bool {
Expand All @@ -37,9 +39,10 @@ impl RecursionGuard {
}

// see #143 this is used as a backup in case the identity check recursion guard fails
pub fn incr_depth(&mut self) -> u16 {
#[must_use]
pub fn incr_depth(&mut self) -> bool {
self.depth += 1;
self.depth
self.depth >= RECURSION_GUARD_LIMIT
}

pub fn decr_depth(&mut self) {
Expand Down
30 changes: 9 additions & 21 deletions src/serializers/extra.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@ use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use pyo3::{intern, AsPyPointer};

use ahash::AHashSet;
use serde::ser::Error;

use super::config::SerializationConfig;
use super::errors::{PydanticSerializationUnexpectedValue, UNEXPECTED_TYPE_SER_MARKER};
use super::ob_type::ObTypeLookup;
use super::shared::CombinedSerializer;
use crate::definitions::Definitions;
use crate::recursion_guard::RecursionGuard;

/// this is ugly, would be much better if extra could be stored in `SerializationState`
/// then `SerializationState` got a `serialize_infer` method, but I couldn't get it to work
Expand Down Expand Up @@ -347,43 +347,31 @@ impl CollectWarnings {
}
}

/// we have `RecursionInfo` then a `RefCell` since `SerializeInfer.serialize` can't take a `&mut self`
#[derive(Default, Clone)]
#[cfg_attr(debug_assertions, derive(Debug))]
pub struct RecursionInfo {
ids: AHashSet<(usize, usize)>, // first element is the object's id, the second is the serializer's id
/// as with `src/recursion_guard.rs` this is used as a backup in case the identity check recursion guard fails
/// see #143
depth: u16,
}

#[derive(Default, Clone)]
#[cfg_attr(debug_assertions, derive(Debug))]
pub struct SerRecursionGuard {
info: RefCell<RecursionInfo>,
guard: RefCell<RecursionGuard>,
}

impl SerRecursionGuard {
const MAX_DEPTH: u16 = 200;

pub fn add(&self, value: &PyAny, def_ref_id: usize) -> PyResult<usize> {
// https://doc.rust-lang.org/std/collections/struct.HashSet.html#method.insert
// "If the set did not have this value present, `true` is returned."
let id = value.as_ptr() as usize;
let mut info = self.info.borrow_mut();
if !info.ids.insert((id, def_ref_id)) {
let mut guard = self.guard.borrow_mut();

if guard.contains_or_insert(id, def_ref_id) {
Err(PyValueError::new_err("Circular reference detected (id repeated)"))
} else if info.depth > Self::MAX_DEPTH {
} else if guard.incr_depth() {
Err(PyValueError::new_err("Circular reference detected (depth exceeded)"))
} else {
info.depth += 1;
Ok(id)
}
}

pub fn pop(&self, id: usize, def_ref_id: usize) {
let mut info = self.info.borrow_mut();
info.depth -= 1;
info.ids.remove(&(id, def_ref_id));
let mut guard = self.guard.borrow_mut();
guard.decr_depth();
guard.remove(id, def_ref_id);
}
}
13 changes: 2 additions & 11 deletions src/validators/definitions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ impl Validator for DefinitionRefValidator {
// we don't remove id here, we leave that to the validator which originally added id to `recursion_guard`
Err(ValError::new(ErrorType::RecursionLoop, input))
} else {
if recursion_guard.incr_depth() > BACKUP_GUARD_LIMIT {
if recursion_guard.incr_depth() {
return Err(ValError::new(ErrorType::RecursionLoop, input));
}
let output = validate(self.validator_id, py, input, extra, definitions, recursion_guard);
Expand Down Expand Up @@ -112,7 +112,7 @@ impl Validator for DefinitionRefValidator {
// we don't remove id here, we leave that to the validator which originally added id to `recursion_guard`
Err(ValError::new(ErrorType::RecursionLoop, obj))
} else {
if recursion_guard.incr_depth() > BACKUP_GUARD_LIMIT {
if recursion_guard.incr_depth() {
return Err(ValError::new(ErrorType::RecursionLoop, obj));
}
let output = validate_assignment(
Expand Down Expand Up @@ -169,15 +169,6 @@ impl Validator for DefinitionRefValidator {
}
}

// see #143 this is a backup in case the identity check recursion guard fails
// if a single validator "depth" (how many times it's called inside itself) exceeds the limit,
// we raise a recursion error.
const BACKUP_GUARD_LIMIT: u16 = if cfg!(PyPy) || cfg!(target_family = "wasm") {
123
} else {
255
};

fn validate<'s, 'data>(
validator_id: usize,
py: Python<'data>,
Expand Down
17 changes: 15 additions & 2 deletions tests/benchmarks/test_micro_benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import json
import os
import platform
import sys
from datetime import date, datetime, timedelta, timezone
from decimal import Decimal
from enum import Enum
Expand All @@ -12,6 +13,7 @@
import pytest
from dirty_equals import IsStr

import pydantic_core
from pydantic_core import ArgsKwargs, PydanticCustomError, SchemaValidator, ValidationError, core_schema
from pydantic_core import ValidationError as CoreValidationError

Expand All @@ -26,6 +28,15 @@

skip_pydantic = pytest.mark.skipif(BaseModel is None, reason='skipping benchmarks vs. pydantic')

# Skip marker for tests that recurse deeply: applied only when the interpreter is PyPy
# and the extension module reports `build_profile == 'debug'`.
skip_pypy_deep_stack = pytest.mark.skipif(
platform.python_implementation() == 'PyPy' and pydantic_core._pydantic_core.build_profile == 'debug',
reason='PyPy does not have enough stack space for Rust debug builds to recurse very deep',
)

# Skip marker for tests that recurse deeply when running under Emscripten (wasm).
# CPython reports `sys.platform == 'emscripten'` on that target; the previous
# spelling 'emscriptem' never matched, so the marker never skipped anything.
skip_wasm_deep_stack = pytest.mark.skipif(
    sys.platform == 'emscripten', reason='wasm does not have enough stack space to recurse very deep'
)


class TestBenchmarkSimpleModel:
@pytest.fixture(scope='class')
Expand Down Expand Up @@ -328,7 +339,6 @@ def definition_model_data():
return data


@pytest.mark.skipif(platform.python_implementation() == 'PyPy', reason='crashes on pypy due to recursion depth')
@skip_pydantic
@pytest.mark.benchmark(group='recursive model')
def test_definition_model_pyd(definition_model_data, benchmark):
Expand All @@ -339,7 +349,8 @@ class PydanticBranch(BaseModel):
benchmark(PydanticBranch.parse_obj, definition_model_data)


@pytest.mark.skipif(platform.python_implementation() == 'PyPy', reason='crashes on pypy due to recursion depth')
@skip_pypy_deep_stack
@skip_wasm_deep_stack
@pytest.mark.benchmark(group='recursive model')
def test_definition_model_core(definition_model_data, benchmark):
class CoreBranch:
Expand Down Expand Up @@ -1452,6 +1463,8 @@ def test_tagged_union_int_keys_json(benchmark):
benchmark(v.validate_json, payload)


@skip_pypy_deep_stack
@skip_wasm_deep_stack
@pytest.mark.benchmark(group='field_function_validator')
def test_field_function_validator(benchmark) -> None:
def f(v: int, info: core_schema.FieldValidationInfo) -> int:
Expand Down
11 changes: 9 additions & 2 deletions tests/serializers/test_any.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import dataclasses
import json
import platform
import sys
from collections import namedtuple
from datetime import date, datetime, time, timedelta, timezone
Expand All @@ -11,6 +12,7 @@
import pytest
from dirty_equals import HasRepr, IsList

import pydantic_core
from pydantic_core import PydanticSerializationError, SchemaSerializer, SchemaValidator, core_schema, to_json

from ..conftest import plain_repr
Expand Down Expand Up @@ -343,6 +345,10 @@ def __repr__(self):
return f'<FoobarCount {self.v} repr>'


@pytest.mark.skipif(
platform.python_implementation() == 'PyPy' and pydantic_core._pydantic_core.build_profile == 'debug',
reason='PyPy does not have enough stack space for Rust debug builds to recurse very deep',
)
def test_fallback_cycle_change(any_serializer: SchemaSerializer):
v = 1

Expand All @@ -359,8 +365,9 @@ def fallback_func(obj):

f = FoobarCount(0)
v = 0
# because when recursion is detected and we're in mode python, we just return the value
assert any_serializer.to_python(f, fallback=fallback_func) == HasRepr('<FoobarCount 201 repr>')
# when recursion is detected and we're in mode python, we just return the value
expected_visits = pydantic_core._pydantic_core._recursion_limit - 1
assert any_serializer.to_python(f, fallback=fallback_func) == HasRepr(f'<FoobarCount {expected_visits} repr>')

with pytest.raises(ValueError, match=r'Circular reference detected \(depth exceeded\)'):
any_serializer.to_json(f, fallback=fallback_func)
Expand Down
6 changes: 6 additions & 0 deletions tests/test_json.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import json
import platform
import re
from typing import List

import pytest
from dirty_equals import IsList

import pydantic_core
from pydantic_core import (
PydanticSerializationError,
SchemaSerializer,
Expand Down Expand Up @@ -285,6 +287,10 @@ def fallback_func_passthrough(obj):
to_json(f, fallback=fallback_func_passthrough)


@pytest.mark.skipif(
platform.python_implementation() == 'PyPy' and pydantic_core._pydantic_core.build_profile == 'debug',
reason='PyPy does not have enough stack space for Rust debug builds to recurse very deep',
)
def test_cycle_change():
def fallback_func_change_id(obj):
return Foobar()
Expand Down
8 changes: 7 additions & 1 deletion tests/validators/test_definitions_recursive.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import platform
from dataclasses import dataclass
from typing import List, Optional

import pytest
from dirty_equals import AnyThing, HasAttributes, IsList, IsPartialDict, IsStr, IsTuple

import pydantic_core
from pydantic_core import SchemaError, SchemaValidator, ValidationError, __version__, core_schema

from ..conftest import Err, plain_repr
Expand Down Expand Up @@ -714,6 +716,10 @@ def f(input_value, info):
]


@pytest.mark.skipif(
platform.python_implementation() == 'PyPy' and pydantic_core._pydantic_core.build_profile == 'debug',
reason='PyPy does not have enough stack space for Rust debug builds to recurse very deep',
)
@pytest.mark.parametrize('strict', [True, False], ids=lambda s: f'strict={s}')
def test_function_change_id(strict: bool):
def f(input_value, info):
Expand Down Expand Up @@ -750,7 +756,7 @@ def f(input_value, info):


def test_many_uses_of_ref():
# check we can safely exceed BACKUP_GUARD_LIMIT without upsetting the backup recursion guard
# check we can safely exceed RECURSION_GUARD_LIMIT without upsetting the recursion guard
v = SchemaValidator(
{
'type': 'typed-dict',
Expand Down

0 comments on commit ba938b7

Please sign in to comment.