Skip to content

Commit dfb8c8b

Browse files
fantix1st1
andauthored
Implement dataclass for EdgeObject (#359)
Co-authored-by: Yury Selivanov <yury@edgedb.com>
1 parent 241c80d commit dfb8c8b

File tree

5 files changed

+78
-2
lines changed

5 files changed

+78
-2
lines changed

edgedb/datatypes/datatypes.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ typedef struct {
5959
EdgeRecordFieldDesc *descs;
6060
Py_ssize_t idpos;
6161
Py_ssize_t size;
62+
PyObject *get_dataclass_fields_func;
6263
} EdgeRecordDescObject;
6364

6465
typedef enum {
@@ -82,6 +83,7 @@ EdgeFieldCardinality EdgeRecordDesc_PointerCardinality(PyObject *, Py_ssize_t);
8283
Py_ssize_t EdgeRecordDesc_GetSize(PyObject *);
8384
edge_attr_lookup_t EdgeRecordDesc_Lookup(PyObject *, PyObject *, Py_ssize_t *);
8485
PyObject * EdgeRecordDesc_List(PyObject *, uint8_t, uint8_t);
86+
PyObject * EdgeRecordDesc_GetDataclassFields(PyObject *);
8587

8688

8789

edgedb/datatypes/object.c

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -196,8 +196,18 @@ object_getattr(EdgeObject *o, PyObject *name)
196196
case L_ERROR:
197197
return NULL;
198198

199-
case L_LINKPROP:
200199
case L_NOT_FOUND:
200+
// Used in `dataclasses.as_dict()`
201+
if (
202+
PyUnicode_CompareWithASCIIString(
203+
name, "__dataclass_fields__"
204+
) == 0
205+
) {
206+
return EdgeRecordDesc_GetDataclassFields((PyObject *)o->desc);
207+
}
208+
return PyObject_GenericGetAttr((PyObject *)o, name);
209+
210+
case L_LINKPROP:
201211
return PyObject_GenericGetAttr((PyObject *)o, name);
202212

203213
case L_LINK:
@@ -365,6 +375,16 @@ EdgeObject_InitType(void)
365375
return NULL;
366376
}
367377

378+
// Pass the `dataclasses.is_dataclass(obj)` check - which then checks
379+
// `hasattr(type(obj), "__dataclass_fields__")`, the dict is always empty
380+
PyObject *default_fields = PyDict_New();
381+
if (default_fields == NULL) {
382+
return NULL;
383+
}
384+
PyDict_SetItemString(
385+
EdgeObject_Type.tp_dict, "__dataclass_fields__", default_fields
386+
);
387+
368388
base_hash = _EdgeGeneric_HashString("edgedb.Object");
369389
if (base_hash == -1) {
370390
return NULL;

edgedb/datatypes/record_desc.c

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ record_desc_dealloc(EdgeRecordDescObject *o)
3030
PyObject_GC_UnTrack(o);
3131
Py_CLEAR(o->index);
3232
Py_CLEAR(o->names);
33+
Py_CLEAR(o->get_dataclass_fields_func);
3334
PyMem_RawFree(o->descs);
3435
PyObject_GC_Del(o);
3536
}
@@ -177,12 +178,24 @@ record_desc_dir(EdgeRecordDescObject *o, PyObject *args)
177178
}
178179

179180

181+
static PyObject *
182+
record_set_dataclass_fields_func(EdgeRecordDescObject *o, PyObject *arg)
183+
{
184+
Py_CLEAR(o->get_dataclass_fields_func);
185+
o->get_dataclass_fields_func = arg;
186+
Py_INCREF(arg);
187+
Py_RETURN_NONE;
188+
}
189+
190+
180191
static PyMethodDef record_desc_methods[] = {
181192
{"is_linkprop", (PyCFunction)record_desc_is_linkprop, METH_O, NULL},
182193
{"is_link", (PyCFunction)record_desc_is_link, METH_O, NULL},
183194
{"is_implicit", (PyCFunction)record_desc_is_implicit, METH_O, NULL},
184195
{"get_pos", (PyCFunction)record_desc_get_pos, METH_O, NULL},
185196
{"__dir__", (PyCFunction)record_desc_dir, METH_NOARGS, NULL},
197+
{"set_dataclass_fields_func",
198+
(PyCFunction)record_set_dataclass_fields_func, METH_O, NULL},
186199
{NULL, NULL}
187200
};
188201

@@ -349,6 +362,7 @@ EdgeRecordDesc_New(PyObject *names, PyObject *flags, PyObject *cards)
349362

350363
o->size = size;
351364
o->idpos = idpos;
365+
o->get_dataclass_fields_func = NULL;
352366

353367
PyObject_GC_Track(o);
354368
return (PyObject *)o;
@@ -537,6 +551,25 @@ EdgeRecordDesc_List(PyObject *ob, uint8_t include_mask, uint8_t exclude_mask)
537551
}
538552

539553

554+
PyObject *
555+
EdgeRecordDesc_GetDataclassFields(PyObject *ob)
556+
{
557+
if (!EdgeRecordDesc_Check(ob)) {
558+
PyErr_BadInternalCall();
559+
return NULL;
560+
}
561+
562+
EdgeRecordDescObject *o = (EdgeRecordDescObject *)ob;
563+
564+
// bpo-37194 added PyObject_CallNoArgs() to Python 3.9.0a1
565+
#if PY_VERSION_HEX < 0x030900A1
566+
return PyObject_CallFunctionObjArgs(o->get_dataclass_fields_func, NULL);
567+
#else
568+
return PyObject_CallNoArgs(o->get_dataclass_fields_func);
569+
#endif
570+
}
571+
572+
540573
PyObject *
541574
EdgeRecordDesc_InitType(void)
542575
{

edgedb/protocol/codecs/object.pxd

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,9 @@
1919

2020
@cython.final
2121
cdef class ObjectCodec(BaseNamedRecordCodec):
22-
cdef bint is_sparse
22+
cdef:
23+
bint is_sparse
24+
object cached_dataclass_fields
2325

2426
cdef encode_args(self, WriteBuffer buf, dict obj)
2527

edgedb/protocol/codecs/object.pyx

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
# limitations under the License.
1717
#
1818

19+
import dataclasses
20+
1921

2022
@cython.final
2123
cdef class ObjectCodec(BaseNamedRecordCodec):
@@ -180,6 +182,22 @@ cdef class ObjectCodec(BaseNamedRecordCodec):
180182

181183
return result
182184

185+
def get_dataclass_fields(self):
186+
cdef descriptor = (<BaseNamedRecordCodec>self).descriptor
187+
188+
rv = self.cached_dataclass_fields
189+
if rv is None:
190+
rv = {}
191+
192+
for i in range(len(self.fields_codecs)):
193+
name = datatypes.record_desc_pointer_name(descriptor, i)
194+
field = rv[name] = dataclasses.field()
195+
field.name = name
196+
field._field_type = dataclasses._FIELD
197+
198+
self.cached_dataclass_fields = rv
199+
return rv
200+
183201
@staticmethod
184202
cdef BaseCodec new(bytes tid, tuple names, tuple flags, tuple cards,
185203
tuple codecs, bint is_sparse):
@@ -195,6 +213,7 @@ cdef class ObjectCodec(BaseNamedRecordCodec):
195213
codec.name = 'Object'
196214
codec.is_sparse = is_sparse
197215
codec.descriptor = datatypes.record_desc_new(names, flags, cards)
216+
codec.descriptor.set_dataclass_fields_func(codec.get_dataclass_fields)
198217
codec.fields_codecs = codecs
199218

200219
return codec

0 commit comments

Comments
 (0)