From f972afd2d38d4d17d2899cd48dae7b3a4cd7647c Mon Sep 17 00:00:00 2001 From: Brian Wellington Date: Fri, 29 Nov 2024 12:23:24 -0800 Subject: [PATCH] Support registering new types with classes. (#1167) * Support registering new types with classes. Previously, dns.rdata.register_type() required passing a module which contained the implementation of the new type, and it would extract the class from the module. This change allows passing the class directly. --- dns/rdata.py | 12 +++++++----- tests/test_rdata.py | 14 ++++++++++++++ 2 files changed, 21 insertions(+), 5 deletions(-) diff --git a/dns/rdata.py b/dns/rdata.py index bcdac094..1913dd6c 100644 --- a/dns/rdata.py +++ b/dns/rdata.py @@ -891,8 +891,8 @@ def register_type( ) -> None: """Dynamically register a module to handle an rdatatype. - *implementation*, a module implementing the type in the usual dnspython - way. + *implementation*, a subclass of ``dns.rdata.Rdata`` implementing the type, + or a module containing such a class named by its text form. *rdtype*, an ``int``, the rdatatype to register. @@ -909,7 +909,9 @@ def register_type( existing_cls = get_rdata_class(rdclass, rdtype) if existing_cls != GenericRdata or dns.rdatatype.is_metatype(rdtype): raise RdatatypeExists(rdclass=rdclass, rdtype=rdtype) - _rdata_classes[(rdclass, rdtype)] = getattr( - implementation, rdtype_text.replace("-", "_") - ) + if isinstance(implementation, type) and issubclass(implementation, Rdata): + impclass = implementation + else: + impclass = getattr(implementation, rdtype_text.replace("-", "_")) + _rdata_classes[(rdclass, rdtype)] = impclass dns.rdatatype.register_type(rdtype, rdtype_text, is_singleton) diff --git a/tests/test_rdata.py b/tests/test_rdata.py index 4c62aa1d..c1d3416c 100644 --- a/tests/test_rdata.py +++ b/tests/test_rdata.py @@ -63,6 +63,20 @@ def test_module_registration(self): self.assertEqual(dns.rdatatype.from_text("ttxt"), TTXT) self.assertEqual(dns.rdatatype.RdataType.make("ttxt"), TTXT) + def test_class_registration(self): + CTXT = 64003 + class CTXTImp(dns.rdtypes.txtbase.TXTBase): + """Test TXT-like record""" + + dns.rdata.register_type(CTXTImp, CTXT, "CTXT") + rdata = dns.rdata.from_text(dns.rdataclass.IN, CTXT, "hello world") + self.assertEqual(rdata.strings, (b"hello", b"world")) + self.assertEqual(dns.rdatatype.to_text(CTXT), "CTXT") + self.assertEqual(dns.rdatatype.from_text("CTXT"), CTXT) + self.assertEqual(dns.rdatatype.RdataType.make("CTXT"), CTXT) + self.assertEqual(dns.rdatatype.from_text("ctxt"), CTXT) + self.assertEqual(dns.rdatatype.RdataType.make("ctxt"), CTXT) + def test_module_reregistration(self): def bad(): TTXTTWO = dns.rdatatype.TXT