Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Convert string types to Enums in Struct py:enum gen code #2825

Conversation

roshanjrajan-zip
Copy link
Contributor

@roshanjrajan-zip roshanjrajan-zip commented Jun 30, 2023

When serializing and deserializing Enums for Structs, the current code expects Enum member variables to be a string instead of an Enum. What this means is that the thrift file shows that an Enum is passed in the struct, when actually a string is passed in. This is very confusing can arbitrarily lead to KeyError in the deserialization logic.

However, changing this behavior is potentially breaking. To avoid breaking existing users, check if the type is an Enum when setting the member variable and convert the type to an Enum if it is a string. This allows the actual type to be an Enum and simplifies the serialization/deserialization logic to use Enums instead of converting strings.

Given a thrift file like the following:

enum TestEnum {
  TestEnum0 = 0,
  TestEnum1 = 1,
}

struct TestStruct {
    1: optional string param1
    2: optional TestEnum param2
}

this is how you would create a struct before and after the change but both are valid.

# Before
TestStruct(param1="test_string", param2=TestEnum.TestEnum1.name) # or pass in "TestEnum1"

# After
TestStruct(param1="test_string", param2=TestEnum.TestEnum1)

Serialization/Deserialization Code Change Example-

class TestStruct(object):
    """
    Attributes:
     - param1
     - param2

    """


    def __init__(self, param1=None, param2=None,):
        self.param1 = param1
        self.param2 = param2

++    def __setattr__(self, name, value):
++		if name == "param2":
++	        super().__setattr__(name, value if hasattr(value, 'value') else TestEnum.__members__.get(value))
++			return
++        super().__setattr__(name, value)

    def read(self, iprot):
        if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None:
            iprot._fast_decode(self, iprot, [self.__class__, self.thrift_spec])
            return
        iprot.readStructBegin()
        while True:
            (fname, ftype, fid) = iprot.readFieldBegin()
            if ftype == TType.STOP:
                break
            if fid == 1:
                if ftype == TType.STRING:
                    self.param1 = iprot.readString().decode('utf-8', errors='replace') if sys.version_info[0] == 2 else iprot.readString()
                else:
                    iprot.skip(ftype)
            elif fid == 2:
                if ftype == TType.I32:
--                    self.param2 = TestEnum(iprot.readI32()).name
++                    self.param2 = TestEnum(iprot.readI32())
                else:
                    iprot.skip(ftype)
            else:
                iprot.skip(ftype)
            iprot.readFieldEnd()
        iprot.readStructEnd()

    def write(self, oprot):
        if oprot._fast_encode is not None and self.thrift_spec is not None:
            oprot.trans.write(oprot._fast_encode(self, [self.__class__, self.thrift_spec]))
            return
        oprot.writeStructBegin('TestStruct')
        if self.param1 is not None:
            oprot.writeFieldBegin('param1', TType.STRING, 1)
            oprot.writeString(self.param1.encode('utf-8') if sys.version_info[0] == 2 else self.param1)
            oprot.writeFieldEnd()
        if self.param2 is not None:
            oprot.writeFieldBegin('param2', TType.I32, 2)
--            oprot.writeI32(TestEnum[self.param2].value)
++           oprot.writeI32(self.param2.value)
            oprot.writeFieldEnd()
        oprot.writeFieldStop()
        oprot.writeStructEnd()

    def validate(self):
        return

    def __repr__(self):
        L = ['%s=%r' % (key, value)
             for key, value in self.__dict__.items()]
        return '%s(%s)' % (self.__class__.__name__, ', '.join(L))

    def __eq__(self, other):
        return isinstance(other, self.__class__) and self.__dict__ == other.__dict__

    def __ne__(self, other):
        return not (self == other)
all_structs.append(TestStruct)
TestStruct.thrift_spec = (
    None,  # 0
    (1, TType.STRING, 'param1', 'UTF8', None, ),  # 1
    (2, TType.I32, 'param2', None, None, ),  # 2
)
fix_spec(all_structs)
del all_structs
  • Did you create an Apache Jira ticket? (Request account here, not required for trivial changes)
  • If a ticket exists: Does your pull request title follow the pattern "THRIFT-NNNN: describe my issue"?
  • Did you squash your changes to a single commit? (not required, but preferred)
  • Did you do your best to avoid breaking changes? If one was needed, did you label the Jira ticket with "Breaking-Change"?
  • If your change does not involve any code, include [skip ci] anywhere in the commit message to free up build resources.

@roshanjrajan-zip roshanjrajan-zip changed the title Use Enum not Enum.name in Struct serialization/deserialization Use Enum or string in Struct serialization/deserialization for py:enum gen code Jun 30, 2023
@roshanjrajan-zip roshanjrajan-zip changed the title Use Enum or string in Struct serialization/deserialization for py:enum gen code Convert string types to Enums in Struct py:enum gen code Jun 30, 2023
@roshanjrajan-zip roshanjrajan-zip force-pushed the roshanrajan-fix_enum_serialize_deserialize2 branch from b3c774d to 51eca6c Compare June 30, 2023 05:59
out << endl;
indent(out) << "def __setattr__(self, name, value):" << endl;
indent_up();
for (m_iter = members.begin(); m_iter != members.end(); ++m_iter) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you add a test (and also paste the generated python code into PR discussion) on when a struct has 2 fields that's enum (and they are different enum types)?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated the test! Here is the entire generated code file

#
# Autogenerated by Thrift Compiler (0.19.0)
#
# DO NOT EDIT UNLESS YOU ARE SURE THAT YOU KNOW WHAT YOU ARE DOING
#
#  options string: py:enum
#

from thrift.Thrift import TType, TMessageType, TFrozenDict, TException, TApplicationException
from thrift.protocol.TProtocol import TProtocolException
from thrift.TRecursive import fix_spec
from enum import IntEnum

import sys
import shared_types.ttypes

from thrift.transport import TTransport
all_structs = []


class TestEnum(IntEnum):
    TestEnum0 = 0
    TestEnum1 = 1



class TestStruct(object):
    """
    Attributes:
     - param1
     - param2
     - param3

    """


    def __init__(self, param1=None, param2=None, param3=None,):
        self.param1 = param1
        self.param2 = param2
        self.param3 = param3

    def __setattr__(self, name, value):
        if name == "param2":
            super().__setattr__(name, value if hasattr(value, 'value') else TestEnum.__members__.get(value))
            return
        if name == "param3":
            super().__setattr__(name, value if hasattr(value, 'value') else shared_types.ttypes.SharedEnum.__members__.get(value))
            return
        super().__setattr__(name, value)


    def read(self, iprot):
        if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None:
            iprot._fast_decode(self, iprot, [self.__class__, self.thrift_spec])
            return
        iprot.readStructBegin()
        while True:
            (fname, ftype, fid) = iprot.readFieldBegin()
            if ftype == TType.STOP:
                break
            if fid == 1:
                if ftype == TType.STRING:
                    self.param1 = iprot.readString().decode('utf-8', errors='replace') if sys.version_info[0] == 2 else iprot.readString()
                else:
                    iprot.skip(ftype)
            elif fid == 2:
                if ftype == TType.I32:
                    self.param2 = TestEnum(iprot.readI32())
                else:
                    iprot.skip(ftype)
            elif fid == 3:
                if ftype == TType.I32:
                    self.param3 = shared_types.ttypes.SharedEnum(iprot.readI32())
                else:
                    iprot.skip(ftype)
            else:
                iprot.skip(ftype)
            iprot.readFieldEnd()
        iprot.readStructEnd()

    def write(self, oprot):
        if oprot._fast_encode is not None and self.thrift_spec is not None:
            oprot.trans.write(oprot._fast_encode(self, [self.__class__, self.thrift_spec]))
            return
        oprot.writeStructBegin('TestStruct')
        if self.param1 is not None:
            oprot.writeFieldBegin('param1', TType.STRING, 1)
            oprot.writeString(self.param1.encode('utf-8') if sys.version_info[0] == 2 else self.param1)
            oprot.writeFieldEnd()
        if self.param2 is not None:
            oprot.writeFieldBegin('param2', TType.I32, 2)
            oprot.writeI32(self.param2.value)
            oprot.writeFieldEnd()
        if self.param3 is not None:
            oprot.writeFieldBegin('param3', TType.I32, 3)
            oprot.writeI32(self.param3.value)
            oprot.writeFieldEnd()
        oprot.writeFieldStop()
        oprot.writeStructEnd()

    def validate(self):
        return

    def __repr__(self):
        L = ['%s=%r' % (key, value)
             for key, value in self.__dict__.items()]
        return '%s(%s)' % (self.__class__.__name__, ', '.join(L))

    def __eq__(self, other):
        return isinstance(other, self.__class__) and self.__dict__ == other.__dict__

    def __ne__(self, other):
        return not (self == other)
all_structs.append(TestStruct)
TestStruct.thrift_spec = (
    None,  # 0
    (1, TType.STRING, 'param1', 'UTF8', None, ),  # 1
    (2, TType.I32, 'param2', None, None, ),  # 2
    (3, TType.I32, 'param3', None, None, ),  # 3
)
fix_spec(all_structs)
del all_structs

@fishy fishy added the python label Jun 30, 2023
@roshanjrajan-zip roshanjrajan-zip force-pushed the roshanrajan-fix_enum_serialize_deserialize2 branch from 51eca6c to d689e54 Compare June 30, 2023 16:02
@fishy
Copy link
Member

fishy commented Jun 30, 2023

actually I just remembered another corner case, we also override __setattr__ on all exceptions, so we need to make sure that an exception with an enum field still works. please also add a test for that.

(you can check https://issues.apache.org/jira/browse/THRIFT-5715 & https://issues.apache.org/jira/browse/THRIFT-4002 for more context)

@roshanjrajan-zip roshanjrajan-zip force-pushed the roshanrajan-fix_enum_serialize_deserialize2 branch from d689e54 to 443e243 Compare June 30, 2023 17:41
@roshanjrajan-zip
Copy link
Contributor Author

roshanjrajan-zip commented Jun 30, 2023

@fishy Updated the PR with handling the Immutable struct/exception case and added tests.

Here is the code generated from the test case.

class TestException(TException):
    """
    Structs can also be exceptions, if they are nasty.

    Attributes:
     - whatOp
     - why
     - who

    """


    def __init__(self, whatOp=None, why=None, who=None,):
        super(TestException, self).__setattr__('whatOp', whatOp)
        super(TestException, self).__setattr__('why', why if hasattr(why, 'value') else shared_types.ttypes.SharedEnum.__members__.get(why))
        super(TestException, self).__setattr__('who', who if hasattr(who, 'value') else TestEnum.__members__.get(who))

    def __setattr__(self, *args):
        raise TypeError("can't modify immutable instance")

    def __delattr__(self, *args):
        raise TypeError("can't modify immutable instance")

    def __hash__(self):
        return hash(self.__class__) ^ hash((self.whatOp, self.why, self.who, ))

    @classmethod
    def read(cls, iprot):
        if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and cls.thrift_spec is not None:
            return iprot._fast_decode(None, iprot, [cls, cls.thrift_spec])
        iprot.readStructBegin()
        whatOp = None
        why = None
        who = None
        while True:
            (fname, ftype, fid) = iprot.readFieldBegin()
            if ftype == TType.STOP:
                break
            if fid == 1:
                if ftype == TType.I32:
                    whatOp = iprot.readI32()
                else:
                    iprot.skip(ftype)
            elif fid == 2:
                if ftype == TType.I32:
                    why = shared_types.ttypes.SharedEnum(iprot.readI32())
                else:
                    iprot.skip(ftype)
            elif fid == 3:
                if ftype == TType.I32:
                    who = TestEnum(iprot.readI32())
                else:
                    iprot.skip(ftype)
            else:
                iprot.skip(ftype)
            iprot.readFieldEnd()
        iprot.readStructEnd()
        return cls(
            whatOp=whatOp,
            why=why,
            who=who,
        )

    def write(self, oprot):
        if oprot._fast_encode is not None and self.thrift_spec is not None:
            oprot.trans.write(oprot._fast_encode(self, [self.__class__, self.thrift_spec]))
            return
        oprot.writeStructBegin('TestException')
        if self.whatOp is not None:
            oprot.writeFieldBegin('whatOp', TType.I32, 1)
            oprot.writeI32(self.whatOp)
            oprot.writeFieldEnd()
        if self.why is not None:
            oprot.writeFieldBegin('why', TType.I32, 2)
            oprot.writeI32(self.why.value)
            oprot.writeFieldEnd()
        if self.who is not None:
            oprot.writeFieldBegin('who', TType.I32, 3)
            oprot.writeI32(self.who.value)
            oprot.writeFieldEnd()
        oprot.writeFieldStop()
        oprot.writeStructEnd()

    def validate(self):
        return

    def __str__(self):
        return repr(self)

    def __repr__(self):
        L = ['%s=%r' % (key, value)
             for key, value in self.__dict__.items()]
        return '%s(%s)' % (self.__class__.__name__, ', '.join(L))

    def __eq__(self, other):
        return isinstance(other, self.__class__) and self.__dict__ == other.__dict__

    def __ne__(self, other):
        return not (self == other)

@roshanjrajan-zip roshanjrajan-zip force-pushed the roshanrajan-fix_enum_serialize_deserialize2 branch from 443e243 to 10475b4 Compare June 30, 2023 17:50
@@ -871,7 +871,12 @@ void t_py_generator::generate_py_struct_definition(ostream& out,
}

if (is_immutable(tstruct)) {
if (gen_newstyle_ || gen_dynamic_) {
if (gen_enum_ && type->is_enum()) {
indent(out) << "super(" << tstruct->get_name() << ", self).__setattr__('"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is probably too much magic to my preference, makes me start to prefer the approach in #2824 again 😅

Copy link
Contributor Author

@roshanjrajan-zip roshanjrajan-zip Jun 30, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I was thinking about this and then I realized that the fix was actually broken because I wasn't running one of the tests that I added.... 😓 . So actually this fix or something else is needed to make sure existing users aren't broken.

if __name__ == "__main__":
    serialization_deserialization_enum_test()
    serialization_deserialization_string_test

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so this code:

    def __init__(self, whatOp=None, why=None, who=None,):
        super(TestException, self).__setattr__('whatOp', whatOp)
        super(TestException, self).__setattr__('why', why if hasattr(why, 'value') else shared_types.ttypes.SharedEnum.__members__.get(why))
        super(TestException, self).__setattr__('who', who if hasattr(who, 'value') else TestEnum.__members__.get(who))

calls the parent's __setattr__, doesn't that essentially set those fields on the parent instead? I guess it still "works" as when accessing them you just inherit from the parent, but this feels wrong to me.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The code that generates super(TestException, self).__setattr__('whatOp', whatOp) already exists before my change and I am extending this pattern. In python, super is implemented as part of the binding process for explicit dotted attribute lookups such as setattr and such. So I don't think this is a concern because this is how Python does things under the hood.

test/py/explicit_module/runtest.sh Show resolved Hide resolved
@roshanjrajan-zip roshanjrajan-zip force-pushed the roshanrajan-fix_enum_serialize_deserialize2 branch from 10475b4 to 10d0fd2 Compare July 2, 2023 02:46
@roshanjrajan-zip roshanjrajan-zip force-pushed the roshanrajan-fix_enum_serialize_deserialize2 branch from 10d0fd2 to 0819a29 Compare July 2, 2023 03:04
Copy link
Member

@fishy fishy left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it's holiday/long weekend here, I'll probably merge this later next week if no one objects.

@fishy fishy merged commit 284e6b3 into apache:master Jul 6, 2023
10 of 11 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants