Allow for more complex logic for subclass Discriminator field matching #184

Closed
samongoose opened this issue Dec 21, 2023 · 3 comments · Fixed by #185
Labels
enhancement New feature or request

Comments

@samongoose

Is your feature request related to a problem? Please describe.
In building a deserializer for a large complex configuration file, some subclasses have identical shapes, but it would still be useful to distinguish by a type field. Unfortunately, in our situation there are multiple values for type that map to the same subclass.

Describe the solution you'd like
I haven't dug into the existing code deep enough to know what's feasible, so there are probably a number of possible solutions (or maybe none, I suppose).

The least intrusive I could see would be having variant_tagger_fn return a list[str] instead of str.

A more involved solution might be adding an inverse of variant_tagger_fn that takes a str and returns a class.
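To make the difference between these two proposals concrete, here is a plain-Python sketch that does not use mashumaro at all. `tags_for` plays the role of a tagger returning several tags per class, and `class_for` is the inverse (tag to class) lookup; both names are hypothetical, not library API. It also shows that the inverse can be derived mechanically from the list-returning form.

```python
from dataclasses import dataclass
from typing import Dict, List, Type


@dataclass
class ClientEvent:
    client_ip: str
    type: str


@dataclass
class ClientConnectedEvent(ClientEvent):
    pass


@dataclass
class ClientDisconnectedEvent(ClientEvent):
    pass


# Proposal 1 (hypothetical): the tagger returns every tag a class accepts.
def tags_for(cls: Type[ClientEvent]) -> List[str]:
    return {
        ClientConnectedEvent: ["connected"],
        ClientDisconnectedEvent: ["disconnected", "d/c"],
    }.get(cls, [])


# Proposal 2 (hypothetical): the inverse, mapping a tag straight to a class.
# It is built from tags_for here, so both proposals encode the same mapping.
_TAG_TO_CLASS: Dict[str, Type[ClientEvent]] = {
    tag: cls
    for cls in (ClientConnectedEvent, ClientDisconnectedEvent)
    for tag in tags_for(cls)
}


def class_for(tag: str) -> Type[ClientEvent]:
    # Unknown tags fall back to the base class.
    return _TAG_TO_CLASS.get(tag, ClientEvent)


def decode(d: dict) -> ClientEvent:
    return class_for(d["type"])(**d)
```

For example, `decode({"type": "d/c", "client_ip": "10.0.0.42"})` yields a `ClientDisconnectedEvent` that still carries `type='d/c'`.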

Finally, the workaround below could be improved upon if there was a hook in the base class that could accomplish the same thing.

Describe alternatives you've considered
I have been able to work around this using __pre_deserialize__:

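from dataclasses import dataclass
from typing import Annotated, Any, List

from mashumaro import DataClassDictMixin
from mashumaro.types import Discriminator

@dataclass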
class ClientEvent(DataClassDictMixin):
    client_ip: str
    type: str
    _type = "unknown"

@dataclass
class ClientConnectedEvent(ClientEvent):
    _type = "connected"

@dataclass
class ClientDisconnectedEvent(ClientEvent):
    _type = "disconnected"

def get_type(typ: str):
    if typ in ["disconnected", "connected"]:
        return typ
    if typ == "d/c":
        return "disconnected"
    return "unknown"

@dataclass
class AggregatedEvents(DataClassDictMixin):
    list: List[
        Annotated[
            ClientEvent, Discriminator(field="_type", include_subtypes=True, include_supertypes=True)
        ]
    ]
    @classmethod
    def __pre_deserialize__(cls, d: dict[Any, Any]) -> dict[Any, Any]:
        d["list"] = [dict({"_type": get_type(event["type"])}, **event) for event in d["list"]]
        return d

events = AggregatedEvents.from_dict(
    {
        "list": [
            {"type": "connected", "client_ip": "10.0.0.42"},
            {"type": "disconnected", "client_ip": "10.0.0.42"},
            {"type": "N/A", "client_ip": "10.0.0.42"},
            {"type": "d/c", "client_ip": "10.0.0.42"},
        ]
    }
)

# Produces:
AggregatedEvents(list=[ClientConnectedEvent(client_ip='10.0.0.42', type='connected'), ClientDisconnectedEvent(client_ip='10.0.0.42', type='disconnected'), ClientEvent(client_ip='10.0.0.42', type='N/A'), ClientDisconnectedEvent(client_ip='10.0.0.42', type='d/c')])

This works reasonably well (and preserves the original type which can be useful). The major downside is that the __pre_deserialize__ method needs to be implemented in each class that makes use of ClientEvent, and in our situation that ends up being several. It would be more convenient if there was a hook in the ClientEvent base class that could accomplish the same thing.
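One way to reduce the duplication, still without any library hook, is to move the normalization into a single module-level function so that each container's `__pre_deserialize__` shrinks to one call. This is a minimal plain-Python sketch; `add_type_tag` and the `ALIASES` table are hypothetical names, and the mashumaro hook itself is only indicated in a comment.

```python
from typing import Any, Dict, List

# Hypothetical alias table: raw "type" value -> discriminator value.
ALIASES: Dict[str, str] = {
    "connected": "connected",
    "disconnected": "disconnected",
    "d/c": "disconnected",
}


def add_type_tag(events: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Return copies of the event dicts with a "_type" key derived from "type"."""
    return [
        {**event, "_type": ALIASES.get(event["type"], "unknown")}
        for event in events
    ]


# Inside each container class the hook then becomes a single line, e.g.:
#
#     @classmethod
#     def __pre_deserialize__(cls, d):
#         d["list"] = add_type_tag(d["list"])
#         return d
```

The alias table also replaces the `if` chain in `get_type`, so adding a new spelling is a one-line change.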

Additional context
I'm just getting started with this library and it's great so far. It's possible I'm missing something and this is already doable. Alternatively, the workaround, well, works, so feel free to close this if this is not a feature you're looking to add. Thanks!

Fatal1ty added the enhancement (New feature or request) label on Dec 22, 2023
@Fatal1ty
Owner

> The least intrusive I could see would be having variant_tagger_fn return a list[str] instead of str.

At first glance, this is the best way to go. The only difficulty is deciding what code should be generated when variant_tagger_fn is used. Right now it looks like this:

...
for variant in (*iter_all_subclasses(__main__.ClientEvent), __main__.ClientEvent):
    try:
        variants_map[variant_tagger_fn(variant)] = variant
    except KeyError:
        continue
...

The important part here is variants_map[variant_tagger_fn(variant)] = variant. If variant_tagger_fn returns a list, we will have to iterate through all the items and register the variant under each of them. If we decide to do some introspection at runtime (isinstance, for example), it will lead to a decrease in performance. We could add a new parameter to Discriminator to configure this, but that seems overcomplicated.
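To see why the existing line cannot simply receive a list: a Python list is unhashable, so using it directly as a dictionary key raises `TypeError`. A small standalone illustration in plain Python, not mashumaro's generated code; `register` is a hypothetical stand-in for the loop body.

```python
from typing import Any, Callable, Dict


def register(
    variants_map: Dict[Any, type], variant: type, tagger: Callable[[type], Any]
) -> None:
    # Mirrors the generated line: variants_map[variant_tagger_fn(variant)] = variant
    variants_map[tagger(variant)] = variant


class Connected:
    pass


variants_map: Dict[Any, type] = {}
register(variants_map, Connected, lambda cls: "connected")  # a str tag works

try:
    register(variants_map, Connected, lambda cls: ["connected", "c"])
except TypeError as exc:
    # The list cannot be used as a dict key, so nothing is registered.
    print(f"TypeError: {exc}")
```

So some per-item loop (or a type check deciding between the two shapes) is unavoidable once a tagger may return several tags.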

As an alternative, I'd like you to consider creating an explicit ClientUnknownEvent and using a class-level discriminator. There is an undocumented attribute, __mashumaro_subtype_variants__, that is set on a class with a class-level discriminator. You can manually register all "type" aliases in it.

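from dataclasses import dataclass
from typing import List

from mashumaro import DataClassDictMixin
from mashumaro.types import Discriminator

@dataclass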
class ClientEvent(DataClassDictMixin):
    client_ip: str
    type: str

    class Config:
        debug = True
        discriminator = Discriminator(field="type", include_subtypes=True)


@dataclass
class ClientUnknownEvent(ClientEvent):
    type = "unknown"


@dataclass
class ClientConnectedEvent(ClientEvent):
    type = "connected"

@dataclass
class ClientDisconnectedEvent(ClientEvent):
    type = "disconnected"


for key, value in (("N/A", ClientUnknownEvent), ('d/c', ClientDisconnectedEvent)):
    ClientEvent.__mashumaro_subtype_variants__[key] = value


@dataclass
class AggregatedEvents(DataClassDictMixin):
    list: List[ClientEvent]

events = AggregatedEvents.from_dict(
    {
        "list": [
            {"type": "connected", "client_ip": "10.0.0.42"},
            {"type": "disconnected", "client_ip": "10.0.0.42"},
            {"type": "N/A", "client_ip": "10.0.0.42"},
            {"type": "d/c", "client_ip": "10.0.0.42"},
        ]
    }
)

# Produces:
AggregatedEvents(list=[ClientConnectedEvent(client_ip='10.0.0.42', type='connected'), ClientDisconnectedEvent(client_ip='10.0.0.42', type='disconnected'), ClientUnknownEvent(client_ip='10.0.0.42', type='N/A'), ClientDisconnectedEvent(client_ip='10.0.0.42', type='d/c')])

@mishamsk
Contributor

btw, I had the same case where multiple tag values map to the same class (the tag field is a literal with multiple possible values). I was thinking about opening a PR to allow variant_tagger_fn to return multiple tags, but for now I figured I have enough PRs open already :-) Also, in my use case the number of classes is so small that I actually built a fully manual deserializer; that seemed quicker than doing a PR.
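For the case where the tag field is annotated as a `Literal` with several allowed values, the tags can be read straight from the annotation with `typing.get_args`. The sketch below is plain Python and assumes each variant class annotates its own tag field; `literal_tags` and `build_registry` are hypothetical helpers, not part of mashumaro.

```python
from dataclasses import dataclass
from typing import Dict, List, Literal, Type, get_args, get_type_hints


@dataclass
class Event:
    client_ip: str


@dataclass
class Connected(Event):
    type: Literal["connected", "up"] = "connected"


@dataclass
class Disconnected(Event):
    type: Literal["disconnected", "d/c", "down"] = "disconnected"


def literal_tags(cls: type, field: str = "type") -> List[str]:
    """Return every value allowed by the Literal annotation of `field`."""
    return list(get_args(get_type_hints(cls)[field]))


def build_registry(*variants: Type[Event]) -> Dict[str, Type[Event]]:
    # One entry per allowed tag value, all pointing at the same class.
    return {tag: cls for cls in variants for tag in literal_tags(cls)}


registry = build_registry(Connected, Disconnected)
event = registry["d/c"](client_ip="10.0.0.42", type="d/c")
```

A function shaped like `literal_tags` is the kind of tagger that only becomes usable once a list of tags per class is accepted.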

@Fatal1ty
Owner

> If we decide to do some introspection at runtime (like isinstance for example), it will lead to a decrease in performance.

I've been thinking about it and came to the conclusion that this will not have much of an impact: we iterate over the variants and register them only when the tag is not found in the registry. I'm going to allow variant_tagger_fn to return a list, so that the following code will handle it:

variant_tags = variant_tagger_fn(variant)
if type(variant_tags) is list:
    for variant_tag in variant_tags:
        variants_map[variant_tag] = variant
else:
    variants_map[variant_tags] = variant
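A self-contained version of that registration step, runnable outside mashumaro. `iter_all_subclasses` here is a local reimplementation (mashumaro's own helper may differ), the `TAGS` table and its tagger are a hypothetical example, and `build_variants_map` simply wraps the list-aware branch together with the existing `try/except KeyError` skip.

```python
from typing import Callable, Dict, Iterator, List, Union


def iter_all_subclasses(cls: type) -> Iterator[type]:
    # Depth-first walk over every direct and indirect subclass.
    for sub in cls.__subclasses__():
        yield sub
        yield from iter_all_subclasses(sub)


def build_variants_map(
    base: type, variant_tagger_fn: Callable[[type], Union[str, List[str]]]
) -> Dict[str, type]:
    variants_map: Dict[str, type] = {}
    for variant in (*iter_all_subclasses(base), base):
        try:
            variant_tags = variant_tagger_fn(variant)
        except KeyError:
            continue  # the variant has no tag, so it is not registered
        if type(variant_tags) is list:
            for variant_tag in variant_tags:
                variants_map[variant_tag] = variant
        else:
            variants_map[variant_tags] = variant
    return variants_map


class ClientEvent:
    pass


class ClientConnectedEvent(ClientEvent):
    pass


class ClientDisconnectedEvent(ClientEvent):
    pass


# Hypothetical tagger table: one class can now claim several tags.
TAGS: Dict[type, Union[str, List[str]]] = {
    ClientConnectedEvent: "connected",
    ClientDisconnectedEvent: ["disconnected", "d/c"],
}

variants_map = build_variants_map(ClientEvent, lambda cls: TAGS[cls])
```

The `type(...) is list` check runs once per variant while the map is being filled, not once per deserialized object, which matches the argument above that the extra cost is small.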

3 participants