Skip to content

Commit

Permalink
[Python] Fixed the issue with nested unions relying on InitFromBuf. (g…
Browse files Browse the repository at this point in the history
…oogle#7576)

* feat: Fixed the issue with nested unions relying on InitFromBuf.
Problem: Issue google#7569
Nested Unions were broken with the introduction of parsing buffers with an initial encoding offset.

Fix:
Revert the InitFromBuf method to the previous version and introduction of InitFromPackedBuf that allows
users to read types from packed buffers applying the offset automatically.

Test:
Added in TestNestedUnionTables to test the encoding and decoding ability using a nested table with a
union field.

* fix: Uncommented generate code command
  • Loading branch information
joshua-smith8 authored and Jochen Parmentier committed Oct 29, 2024
1 parent 7758c89 commit b7f34d1
Show file tree
Hide file tree
Showing 30 changed files with 766 additions and 65 deletions.
16 changes: 11 additions & 5 deletions scripts/generate_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def glob(path, pattern):
LOBSTER_OPTS = ["--lobster"]
SWIFT_OPTS = ["--swift", "--gen-json-emit", "--bfbs-filenames", str(tests_path)]
SWIFT_OPTS_CODE_GEN = [
"--swift",
"--swift",
"--gen-json-emit",
"--bfbs-filenames",
swift_code_gen
Expand Down Expand Up @@ -305,14 +305,14 @@ def glob(path, pattern):

# Generate the annotated binary of the monster_test binary schema.
flatc_annotate(
schema="../reflection/reflection.fbs",
file="monster_test.bfbs",
schema="../reflection/reflection.fbs",
file="monster_test.bfbs",
include="include_test"
)

flatc_annotate(
schema="monster_test.fbs",
file="monsterdata_test.mon",
schema="monster_test.fbs",
file="monsterdata_test.mon",
include="include_test"
)

Expand Down Expand Up @@ -371,6 +371,12 @@ def glob(path, pattern):
)


flatc(
BASE_OPTS + PYTHON_OPTS,
schema="nested_union_test.fbs",
)


# Optional Scalars
optional_scalars_schema = "optional_scalars.fbs"
flatc(["--java", "--kotlin", "--lobster"], schema=optional_scalars_schema)
Expand Down
18 changes: 16 additions & 2 deletions src/idl_gen_python.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1137,13 +1137,25 @@ class PythonGenerator : public BaseGenerator {

code += GenIndents(1) + "@classmethod";
code += GenIndents(1) + "def InitFromBuf(cls, buf, pos):";
code += GenIndents(2) + "n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, 0)";
code += GenIndents(2) + struct_var + " = " + struct_type + "()";
code += GenIndents(2) + struct_var + ".Init(buf, pos+n)";
code += GenIndents(2) + struct_var + ".Init(buf, pos)";
code += GenIndents(2) + "return cls.InitFromObj(" + struct_var + ")";
code += "\n";
}

void InitializeFromPackedBuf(const StructDef &struct_def,
std::string *code_ptr) const {
auto &code = *code_ptr;
const auto struct_var = namer_.Variable(struct_def);
const auto struct_type = namer_.Type(struct_def);

code += GenIndents(1) + "@classmethod";
code += GenIndents(1) + "def InitFromPackedBuf(cls, buf, pos=0):";
code += GenIndents(2) + "n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos)";
code += GenIndents(2) + "return cls.InitFromBuf(buf, pos+n)";
code += "\n";
}

void InitializeFromObjForObject(const StructDef &struct_def,
std::string *code_ptr) const {
auto &code = *code_ptr;
Expand Down Expand Up @@ -1636,6 +1648,8 @@ class PythonGenerator : public BaseGenerator {

InitializeFromBuf(struct_def, &code);

InitializeFromPackedBuf(struct_def, &code);

InitializeFromObjForObject(struct_def, &code);

GenUnPack(struct_def, &code);
Expand Down
8 changes: 6 additions & 2 deletions tests/MyGame/Example/Ability.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,15 @@ def __init__(self):

@classmethod
def InitFromBuf(cls, buf, pos):
n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, 0)
ability = Ability()
ability.Init(buf, pos+n)
ability.Init(buf, pos)
return cls.InitFromObj(ability)

@classmethod
def InitFromPackedBuf(cls, buf, pos=0):
n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos)
return cls.InitFromBuf(buf, pos+n)

@classmethod
def InitFromObj(cls, ability):
x = AbilityT()
Expand Down
8 changes: 6 additions & 2 deletions tests/MyGame/Example/ArrayStruct.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,11 +123,15 @@ def __init__(self):

@classmethod
def InitFromBuf(cls, buf, pos):
n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, 0)
arrayStruct = ArrayStruct()
arrayStruct.Init(buf, pos+n)
arrayStruct.Init(buf, pos)
return cls.InitFromObj(arrayStruct)

@classmethod
def InitFromPackedBuf(cls, buf, pos=0):
n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos)
return cls.InitFromBuf(buf, pos+n)

@classmethod
def InitFromObj(cls, arrayStruct):
x = ArrayStructT()
Expand Down
8 changes: 6 additions & 2 deletions tests/MyGame/Example/ArrayTable.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,15 @@ def __init__(self):

@classmethod
def InitFromBuf(cls, buf, pos):
n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, 0)
arrayTable = ArrayTable()
arrayTable.Init(buf, pos+n)
arrayTable.Init(buf, pos)
return cls.InitFromObj(arrayTable)

@classmethod
def InitFromPackedBuf(cls, buf, pos=0):
n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos)
return cls.InitFromBuf(buf, pos+n)

@classmethod
def InitFromObj(cls, arrayTable):
x = ArrayTableT()
Expand Down
8 changes: 6 additions & 2 deletions tests/MyGame/Example/Monster.py
Original file line number Diff line number Diff line change
Expand Up @@ -1131,11 +1131,15 @@ def __init__(self):

@classmethod
def InitFromBuf(cls, buf, pos):
n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, 0)
monster = Monster()
monster.Init(buf, pos+n)
monster.Init(buf, pos)
return cls.InitFromObj(monster)

@classmethod
def InitFromPackedBuf(cls, buf, pos=0):
n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos)
return cls.InitFromBuf(buf, pos+n)

@classmethod
def InitFromObj(cls, monster):
x = MonsterT()
Expand Down
8 changes: 6 additions & 2 deletions tests/MyGame/Example/NestedStruct.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,11 +111,15 @@ def __init__(self):

@classmethod
def InitFromBuf(cls, buf, pos):
n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, 0)
nestedStruct = NestedStruct()
nestedStruct.Init(buf, pos+n)
nestedStruct.Init(buf, pos)
return cls.InitFromObj(nestedStruct)

@classmethod
def InitFromPackedBuf(cls, buf, pos=0):
n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos)
return cls.InitFromBuf(buf, pos+n)

@classmethod
def InitFromObj(cls, nestedStruct):
x = NestedStructT()
Expand Down
20 changes: 20 additions & 0 deletions tests/MyGame/Example/NestedUnion/Any.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# automatically generated by the FlatBuffers compiler, do not modify

# namespace: NestedUnion

class Any(object):
NONE = 0
Vec3 = 1
TestSimpleTableWithEnum = 2

def AnyCreator(unionType, table):
from flatbuffers.table import Table
if not isinstance(table, Table):
return None
if unionType == Any().Vec3:
import MyGame.Example.NestedUnion.Vec3
return MyGame.Example.NestedUnion.Vec3.Vec3T.InitFromBuf(table.Bytes, table.Pos)
if unionType == Any().TestSimpleTableWithEnum:
import MyGame.Example.NestedUnion.TestSimpleTableWithEnum
return MyGame.Example.NestedUnion.TestSimpleTableWithEnum.TestSimpleTableWithEnumT.InitFromBuf(table.Bytes, table.Pos)
return None
12 changes: 12 additions & 0 deletions tests/MyGame/Example/NestedUnion/Color.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# automatically generated by the FlatBuffers compiler, do not modify

# namespace: NestedUnion

# Composite components of Monster color.
class Color(object):
Red = 1
# \brief color Green
# Green is bit_flag with value (1u << 1)
Green = 2
# \brief color Blue (1u << 3)
Blue = 8
133 changes: 133 additions & 0 deletions tests/MyGame/Example/NestedUnion/NestedUnionTest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
# automatically generated by the FlatBuffers compiler, do not modify

# namespace: NestedUnion

import flatbuffers
from flatbuffers.compat import import_numpy
np = import_numpy()

class NestedUnionTest(object):
__slots__ = ['_tab']

@classmethod
def GetRootAs(cls, buf, offset=0):
n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
x = NestedUnionTest()
x.Init(buf, n + offset)
return x

@classmethod
def GetRootAsNestedUnionTest(cls, buf, offset=0):
"""This method is deprecated. Please switch to GetRootAs."""
return cls.GetRootAs(buf, offset)
# NestedUnionTest
def Init(self, buf, pos):
self._tab = flatbuffers.table.Table(buf, pos)

# NestedUnionTest
def Name(self):
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
if o != 0:
return self._tab.String(o + self._tab.Pos)
return None

# NestedUnionTest
def DataType(self):
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
if o != 0:
return self._tab.Get(flatbuffers.number_types.Uint8Flags, o + self._tab.Pos)
return 0

# NestedUnionTest
def Data(self):
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
if o != 0:
from flatbuffers.table import Table
obj = Table(bytearray(), 0)
self._tab.Union(obj, o)
return obj
return None

# NestedUnionTest
def Id(self):
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10))
if o != 0:
return self._tab.Get(flatbuffers.number_types.Int16Flags, o + self._tab.Pos)
return 0

def NestedUnionTestStart(builder): builder.StartObject(4)
def Start(builder):
return NestedUnionTestStart(builder)
def NestedUnionTestAddName(builder, name): builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(name), 0)
def AddName(builder, name):
return NestedUnionTestAddName(builder, name)
def NestedUnionTestAddDataType(builder, dataType): builder.PrependUint8Slot(1, dataType, 0)
def AddDataType(builder, dataType):
return NestedUnionTestAddDataType(builder, dataType)
def NestedUnionTestAddData(builder, data): builder.PrependUOffsetTRelativeSlot(2, flatbuffers.number_types.UOffsetTFlags.py_type(data), 0)
def AddData(builder, data):
return NestedUnionTestAddData(builder, data)
def NestedUnionTestAddId(builder, id): builder.PrependInt16Slot(3, id, 0)
def AddId(builder, id):
return NestedUnionTestAddId(builder, id)
def NestedUnionTestEnd(builder): return builder.EndObject()
def End(builder):
return NestedUnionTestEnd(builder)
import MyGame.Example.NestedUnion.Any
import MyGame.Example.NestedUnion.TestSimpleTableWithEnum
import MyGame.Example.NestedUnion.Vec3
try:
from typing import Union
except:
pass

class NestedUnionTestT(object):

# NestedUnionTestT
def __init__(self):
self.name = None # type: str
self.dataType = 0 # type: int
self.data = None # type: Union[None, MyGame.Example.NestedUnion.Vec3.Vec3T, MyGame.Example.NestedUnion.TestSimpleTableWithEnum.TestSimpleTableWithEnumT]
self.id = 0 # type: int

@classmethod
def InitFromBuf(cls, buf, pos):
nestedUnionTest = NestedUnionTest()
nestedUnionTest.Init(buf, pos)
return cls.InitFromObj(nestedUnionTest)

@classmethod
def InitFromPackedBuf(cls, buf, pos=0):
n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos)
return cls.InitFromBuf(buf, pos+n)

@classmethod
def InitFromObj(cls, nestedUnionTest):
x = NestedUnionTestT()
x._UnPack(nestedUnionTest)
return x

# NestedUnionTestT
def _UnPack(self, nestedUnionTest):
if nestedUnionTest is None:
return
self.name = nestedUnionTest.Name()
self.dataType = nestedUnionTest.DataType()
self.data = MyGame.Example.NestedUnion.Any.AnyCreator(self.dataType, nestedUnionTest.Data())
self.id = nestedUnionTest.Id()

# NestedUnionTestT
def Pack(self, builder):
if self.name is not None:
name = builder.CreateString(self.name)
if self.data is not None:
data = self.data.Pack(builder)
NestedUnionTestStart(builder)
if self.name is not None:
NestedUnionTestAddName(builder, name)
NestedUnionTestAddDataType(builder, self.dataType)
if self.data is not None:
NestedUnionTestAddData(builder, data)
NestedUnionTestAddId(builder, self.id)
nestedUnionTest = NestedUnionTestEnd(builder)
return nestedUnionTest
Loading

0 comments on commit b7f34d1

Please sign in to comment.