Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use @-prefixed keys in object codec for link properties #384

Merged
merged 2 commits into from
Oct 21, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion edgedb/codegen/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,7 +442,7 @@ def _generate_code(
print(f"{INDENT}@typing.overload", file=buf)
print(
f'{INDENT}def __getitem__'
f'(self, key: {typing_literal}["@{el_name}"]) '
f'(self, key: {typing_literal}["{el_name}"]) '
f'-> {el_code}:',
file=buf,
)
Expand Down
16 changes: 10 additions & 6 deletions edgedb/datatypes/datatypes.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -54,29 +54,33 @@ def create_object_factory(**pointers):
names = ()
fields = {}
for pname, ptype in pointers.items():
names += (pname,)

if not isinstance(ptype, set):
ptype = {ptype}

flag = 0
is_linkprop = False
for pt in ptype:
if pt == 'link':
flag |= EDGE_POINTER_IS_LINK
elif pt == 'property':
pass
elif pt == 'link-property':
flag |= EDGE_POINTER_IS_LINKPROP
is_linkprop = True
elif pt == 'implicit':
flag |= EDGE_POINTER_IS_IMPLICIT
else:
raise ValueError(f'unknown pointer type {pt}')
if is_linkprop:
names += ("@" + pname,)
else:
names += (pname,)
field = dataclasses.field()
field.name = pname
field._field_type = dataclasses._FIELD
fields[pname] = field

flags += (flag,)
field = dataclasses.field()
field.name = pname
field._field_type = dataclasses._FIELD
fields[pname] = field

desc = EdgeRecordDesc_New(names, flags, <object>NULL)
size = len(pointers)
Expand Down
20 changes: 19 additions & 1 deletion edgedb/datatypes/link.c
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

static int init_type_called = 0;
static Py_hash_t base_hash = -1;
extern PyObject* at_sign_ptr;


PyObject *
Expand Down Expand Up @@ -276,7 +277,12 @@ link_getattr(EdgeLinkObject *o, PyObject *name)
assert(EdgeRecordDesc_Check(desc));

Py_ssize_t pos;
edge_attr_lookup_t ret = EdgeRecordDesc_Lookup(desc, name, &pos);
PyObject *prefixed_name = PyUnicode_Concat(at_sign_ptr, name);
if (prefixed_name == NULL) {
return NULL;
}
edge_attr_lookup_t ret = EdgeRecordDesc_Lookup(desc, prefixed_name, &pos);
Py_DECREF(prefixed_name);
switch (ret) {
case L_ERROR:
return NULL;
Expand Down Expand Up @@ -313,6 +319,18 @@ link_dir(EdgeLinkObject *o, PyObject *args)
return NULL;
}

PyObject *name, *stripped;
for (Py_ssize_t i = 0; i < PyList_GET_SIZE(ret); i++) {
name = PyList_GET_ITEM(ret, i);
stripped = PyUnicode_Substring(name, 1, PyUnicode_GET_LENGTH(name));
if (stripped == NULL) {
Py_DECREF(ret);
return NULL;
}
PyList_SET_ITEM(ret, i, stripped);
Py_DECREF(name);
}

PyObject *str = PyUnicode_FromString("source");
if (str == NULL) {
Py_DECREF(ret);
Expand Down
114 changes: 24 additions & 90 deletions edgedb/datatypes/object.c
Original file line number Diff line number Diff line change
Expand Up @@ -196,50 +196,10 @@ object_getattr(EdgeObject *o, PyObject *name)
) {
return EdgeRecordDesc_GetDataclassFields((PyObject *)o->desc);
}

// getattr(obj, "@...") for link property
int prefixed = PyUnicode_Tailmatch(
name, at_sign_ptr, 0, PY_SSIZE_T_MAX, -1
);
if (prefixed == -1) {
return NULL;
}
if (prefixed) {
PyObject *stripped = PyUnicode_Substring(
name, 1, PyUnicode_GET_LENGTH(name)
);
if (stripped == NULL) {
return NULL;
}
ret = EdgeRecordDesc_Lookup(
(PyObject *)o->desc, stripped, &pos);
Py_DECREF(stripped);
switch (ret) {
case L_ERROR:
return NULL;

case L_NOT_FOUND:
case L_LINK:
case L_PROPERTY:
return PyObject_GenericGetAttr((PyObject *)o, name);

case L_LINKPROP: {
PyObject *val = EdgeObject_GET_ITEM(o, pos);
Py_INCREF(val);
return val;
}

default:
abort();
}
}

return PyObject_GenericGetAttr((PyObject *)o, name);
}

case L_LINKPROP:
return PyObject_GenericGetAttr((PyObject *)o, name);

case L_LINK:
case L_PROPERTY: {
PyObject *val = EdgeObject_GET_ITEM(o, pos);
Expand All @@ -256,77 +216,51 @@ static PyObject *
object_getitem(EdgeObject *o, PyObject *name)
{
Py_ssize_t pos;
int prefixed = 0;
PyObject *stripped = name;
if (PyUnicode_Check(name)) {
prefixed = PyUnicode_Tailmatch(
name, at_sign_ptr, 0, PY_SSIZE_T_MAX, -1
);
if (prefixed == -1) {
return NULL;
}
if (prefixed) {
stripped = PyUnicode_Substring(
name, 1, PyUnicode_GET_LENGTH(name)
);
if (stripped == NULL) {
return NULL;
}
}
}

edge_attr_lookup_t ret = EdgeRecordDesc_Lookup(
(PyObject *)o->desc, stripped, &pos
(PyObject *)o->desc, name, &pos
);
if (prefixed) {
Py_DECREF(stripped);
}
switch (ret) {
case L_ERROR:
return NULL;

case L_PROPERTY:
PyErr_Format(
PyExc_TypeError,
"property %R should be accessed via dot notation",
name);
return NULL;

case L_LINKPROP: {
PyObject *val = EdgeObject_GET_ITEM(o, pos);
Py_INCREF(val);
return val;
}

case L_NOT_FOUND: {
int prefixed = 0;
if (PyUnicode_Check(name)) {
prefixed = PyUnicode_Tailmatch(
name, at_sign_ptr, 0, PY_SSIZE_T_MAX, -1
);
if (prefixed == -1) {
return NULL;
}
}
if (prefixed) {
PyErr_Format(
PyExc_KeyError,
"link property %R does not exist",
name);
} else {
PyErr_Format(
PyExc_TypeError,
"property %R should be accessed via dot notation",
name);
}
return NULL;

case L_LINKPROP:
if (prefixed) {
PyObject *val = EdgeObject_GET_ITEM(o, pos);
Py_INCREF(val);
return val;
} else {
PyErr_Format(
PyExc_TypeError,
"link property %R should be accessed with '@' prefix",
name);
return NULL;
}

case L_NOT_FOUND:
PyErr_Format(
PyExc_KeyError,
"link property %R does not exist",
name);
return NULL;
}

case L_LINK: {
if (prefixed) {
PyErr_Format(
PyExc_KeyError,
"link property %R does not exist",
name);
return NULL;
}
int res = PyErr_WarnEx(
PyExc_DeprecationWarning,
"getting link on object is deprecated since 1.0, "
Expand Down
7 changes: 1 addition & 6 deletions edgedb/datatypes/repr.c
Original file line number Diff line number Diff line change
Expand Up @@ -118,12 +118,7 @@ _EdgeGeneric_RenderItems(_PyUnicodeWriter *writer,
}

if (is_linkprop) {
if (include_link_props) {
if (_PyUnicodeWriter_WriteChar(writer, '@') < 0) {
goto error;
}
}
else {
if (!include_link_props) {
continue;
}
}
Expand Down
2 changes: 2 additions & 0 deletions edgedb/protocol/codecs/codecs.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,8 @@ cdef class CodecsRegistry:
frb_read(spec, str_len), str_len)
pos = <uint16_t>hton.unpack_int16(frb_read(spec, 2))

if flag & datatypes._EDGE_POINTER_IS_LINKPROP:
name = "@" + name
cpython.Py_INCREF(name)
cpython.PyTuple_SetItem(names, i, name)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ class LinkPropResult:
class LinkPropResultFriendsItem:
id: uuid.UUID
name: str
created_at: datetime.datetime | None

@typing.overload
def __getitem__(self, key: typing.Literal["@created_at"]) -> datetime.datetime | None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ class LinkPropResult(NoPydanticValidation):
class LinkPropResultFriendsItem(NoPydanticValidation):
id: uuid.UUID
name: str
created_at: typing.Optional[datetime.datetime]

@typing.overload
def __getitem__(self, key: typing_extensions.Literal["@created_at"]) -> typing.Optional[datetime.datetime]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ class LinkPropResult:
class LinkPropResultFriendsItem:
id: uuid.UUID
name: str
created_at: typing.Optional[datetime.datetime]

@typing.overload
def __getitem__(self, key: typing.Literal["@created_at"]) -> typing.Optional[datetime.datetime]:
Expand Down
18 changes: 10 additions & 8 deletions tests/datatypes/test_datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,24 +99,24 @@ def test_recorddesc_3(self):
o = f(1, 2, 3, 4)

desc = private.get_object_descriptor(o)
self.assertEqual(set(dir(desc)), set(('id', 'lb', 'c', 'd')))
self.assertEqual(set(dir(desc)), set(('id', '@lb', 'c', 'd')))

self.assertTrue(desc.is_linkprop('lb'))
self.assertTrue(desc.is_linkprop('@lb'))
self.assertFalse(desc.is_linkprop('id'))
self.assertFalse(desc.is_linkprop('c'))
self.assertFalse(desc.is_linkprop('d'))

self.assertFalse(desc.is_link('lb'))
self.assertFalse(desc.is_link('@lb'))
self.assertFalse(desc.is_link('id'))
self.assertFalse(desc.is_link('c'))
self.assertTrue(desc.is_link('d'))

self.assertFalse(desc.is_implicit('lb'))
self.assertFalse(desc.is_implicit('@lb'))
self.assertTrue(desc.is_implicit('id'))
self.assertFalse(desc.is_implicit('c'))
self.assertFalse(desc.is_implicit('d'))

self.assertEqual(desc.get_pos('lb'), 1)
self.assertEqual(desc.get_pos('@lb'), 1)
self.assertEqual(desc.get_pos('id'), 0)
self.assertEqual(desc.get_pos('c'), 2)
self.assertEqual(desc.get_pos('d'), 3)
Expand Down Expand Up @@ -509,7 +509,7 @@ def test_object_1(self):
with self.assertRaises(TypeError):
len(o)

with self.assertRaises(KeyError):
with self.assertRaises(TypeError):
o[0]

with self.assertRaises(TypeError):
Expand Down Expand Up @@ -681,9 +681,9 @@ def test_object_links_4(self):
u = User(1, None)

with self.assertRaisesRegex(
KeyError, "link property 'error_key' does not exist"
KeyError, "link property '@error_key' does not exist"
):
u['error_key']
u['@error_key']

def test_object_link_property_1(self):
O2 = private.create_object_factory(
Expand Down Expand Up @@ -743,13 +743,15 @@ def test_object_dataclass_1(self):
name='property',
tuple='property',
namedtuple='property',
linkprop="link-property",
)

u = User(
1,
'Bob',
edgedb.Tuple((1, 2.0, '3')),
edgedb.NamedTuple(a=1, b="Y"),
123,
)
self.assertTrue(dataclasses.is_dataclass(u))
self.assertEqual(
Expand Down
32 changes: 32 additions & 0 deletions tests/test_async_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -992,3 +992,35 @@ async def test_async_banned_transaction(self):
edgedb.CapabilityError,
r'cannot execute transaction control commands'):
await self.client.execute('start transaction')

async def test_dup_link_prop_name(self):
obj = await self.client.query_single('''
CREATE TYPE test::dup_link_prop_name {
CREATE PROPERTY val -> str;
};
CREATE TYPE test::dup_link_prop_name_p {
CREATE LINK l -> test::dup_link_prop_name {
CREATE PROPERTY val -> int32;
}
};
INSERT test::dup_link_prop_name_p {
l := (INSERT test::dup_link_prop_name {
val := "hello",
@val := 42,
})
};
SELECT test::dup_link_prop_name_p {
l: {
val,
@val
}
} LIMIT 1;
''')

self.assertEqual(obj.l.val, "hello")
self.assertEqual(obj.l["@val"], 42)

await self.client.execute('''
DROP TYPE test::dup_link_prop_name_p;
DROP TYPE test::dup_link_prop_name;
''')