Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

size computations #40

Merged
merged 4 commits into from
Mar 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions protobuf_serialization.nim
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ import sets
import serialization
export serialization

import protobuf_serialization/[internal, types, reader, writer]
export types, reader, writer
import protobuf_serialization/[internal, types, reader, sizer, writer]
export types, reader, sizer, writer

serializationFormat Protobuf

Expand All @@ -27,3 +27,7 @@ func supports*[T](_: type Protobuf, ty: typedesc[T]): bool =
# TODO return false when not supporting, instead of crashing compiler
static: supportsCompileTime(T)
true

func computeSize*[T: object](_: type Protobuf, value: T): int =
  ## Encoded size of `value` for the `Protobuf` format, as reported by
  ## `computeObjectSize`.
  result = computeObjectSize(value)
23 changes: 21 additions & 2 deletions protobuf_serialization/codec.nim
Original file line number Diff line number Diff line change
Expand Up @@ -156,10 +156,26 @@ template toBytes*(x: pfloat): openArray[byte] =
template toBytes*(header: FieldHeader): openArray[byte] =
  ## Encodes a field header: the header is converted to `uint32`, passed to the
  ## `Leb128` overload of `toBytes`, and the result is exposed through
  ## `toOpenArray()`.
  toBytes(uint32(header), Leb128).toOpenArray()

proc vsizeof*(x: SomeVarint): int =
func computeSize*(x: SomeVarint): int =
  ## Number of bytes needed to encode ``x`` as a varint: the `Leb128` length
  ## of ``toUleb(x)``.
  let unsignedForm = toUleb(x)
  Leb128.len(unsignedForm)

func computeSize*(x: SomeFixed64 | SomeFixed32): int =
  ## Returns the number of bytes required to encode the fixed-width value
  ## ``x``. This is simply ``sizeof(x)`` and does not depend on the value
  ## itself (unlike the varint overload).
  sizeof(x)

func computeSize*(x: pstring | pbytes): int =
  ## Size of a string / byte-sequence value: the varint-encoded length
  ## (as ``puint64``) followed by the payload bytes themselves.
  let payloadLen = distinctBase(x).len()
  result = computeSize(puint64(payloadLen)) + payloadLen

func computeSize*(x: FieldHeader): int =
  ## Number of bytes the field header ``x`` occupies, measured as the varint
  ## size of its ``puint32`` conversion.
  let headerValue = puint32(x)
  computeSize(headerValue)

func computeSize*(field: int, x: SomeScalar): int =
  ## Size of a complete scalar field: the header built from the field number
  ## and the wire kind of ``x``'s type, plus the encoded value.
  let
    header = FieldHeader.init(field, wireKind(typeof(x)))
    headerLen = computeSize(header)
    valueLen = computeSize(x)
  headerLen + valueLen

proc writeValue*(output: OutputStream, value: SomeVarint) =
  ## Writes the `toBytes` encoding of a varint value to `output`.
  write(output, toBytes(value))

Expand All @@ -177,8 +193,11 @@ proc writeValue*(output: OutputStream, value: pbytes) =
proc writeValue*(output: OutputStream, value: SomeFixed32) =
  ## Writes the `toBytes` encoding of a 32-bit fixed-width value to `output`.
  write(output, toBytes(value))

proc writeValue*(output: OutputStream, value: FieldHeader) =
  ## Writes the `toBytes` encoding of a field header to `output`.
  write(output, toBytes(value))

proc writeField*(output: OutputStream, field: int, value: SomeScalar) =
  ## Writes a complete scalar field: one header (field number plus the wire
  ## kind of `value`'s type) followed by the encoded value.
  # The header must be emitted exactly once; writing it both through
  # `output.write(toBytes(...))` and `output.writeValue(...)` would duplicate
  # it in the stream.
  let header = FieldHeader.init(field, wireKind(typeof(value)))
  output.writeValue(header)
  output.writeValue(value)

proc readValue*[T: SomeVarint](input: InputStream, _: type T): T =
Expand Down
28 changes: 26 additions & 2 deletions protobuf_serialization/internal.nim
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,28 @@ proc fieldNumberOf*(T: type, fieldName: static string): int {.compileTime.} =
else:
fieldNum

template tableObject*(TableObject, K, V) =
  ## Declares a `proto3` object type named `TableObject` with two fields:
  ## `key` of type `K` (field number 1) and `value` of type `V` (field
  ## number 2). A field additionally carries the `pint` pragma when its type
  ## is `SomePBInt`; the four branches cover every key/value combination.
  when K is SomePBInt and V is SomePBInt:
    # Both key and value are integers: both get `pint`.
    type
      TableObject {.proto3.} = object
        key {.fieldNumber: 1, pint.}: K
        value {.fieldNumber: 2, pint.}: V
  elif K is SomePBInt:
    # Only the key is an integer.
    type
      TableObject {.proto3.} = object
        key {.fieldNumber: 1, pint.}: K
        value {.fieldNumber: 2.}: V
  elif V is SomePBInt:
    # Only the value is an integer.
    type
      TableObject {.proto3.} = object
        key {.fieldNumber: 1.}: K
        value {.fieldNumber: 2, pint.}: V
  else:
    # Neither side is an integer: no `pint` pragma at all.
    type
      TableObject {.proto3.} = object
        key {.fieldNumber: 1.}: K
        value {.fieldNumber: 2.}: V

template protoType*(InnerType, RootType, FieldType: untyped, fieldName: untyped) =
mixin flatType

Expand Down Expand Up @@ -117,8 +139,10 @@ template protoType*(InnerType, RootType, FieldType: untyped, fieldName: untyped)
type InnerType = pbytes
elif FlatType is enum:
type InnerType = penum
elif FlatType is object or FlatType is ref:
type InnerType = FieldType
elif FlatType is object:
type InnerType = pbytes
elif FlatType is ref and defined(ConformanceTest):
arnetheduck marked this conversation as resolved.
Show resolved Hide resolved
type InnerType = pbytes
else:
type InnerType = UnsupportedType[FieldType, RootType, fieldName]

Expand Down
28 changes: 4 additions & 24 deletions protobuf_serialization/reader.nim
Original file line number Diff line number Diff line change
Expand Up @@ -64,28 +64,7 @@ when defined(ConformanceTest):
header: FieldHeader,
ProtoType: type
) =
# I know it's ugly, but I cannot find a clean way to do it
# ... And nobody cares about map
when K is SomePBInt and V is SomePBInt:
type
TableObject {.proto3.} = object
key {.fieldNumber: 1, pint.}: K
value {.fieldNumber: 2, pint.}: V
elif K is SomePBInt:
type
TableObject {.proto3.} = object
key {.fieldNumber: 1, pint.}: K
value {.fieldNumber: 2.}: V
elif V is SomePBInt:
type
TableObject {.proto3.} = object
key {.fieldNumber: 1.}: K
value {.fieldNumber: 2, pint.}: V
else:
type
TableObject {.proto3.} = object
key {.fieldNumber: 1.}: K
value {.fieldNumber: 2.}: V
tableObject(TableObject, K, V)
var tmp = default(TableObject)
stream.readFieldInto(tmp, header, ProtoType)
value[tmp.key] = tmp.value
Expand Down Expand Up @@ -146,6 +125,7 @@ proc readFieldPackedInto[T](
elif ProtoType is SomeFixed32:
WireKind.Fixed32
else:
static: doAssert ProtoType is SomeFixed64
WireKind.Fixed64

inner.readFieldInto(value[^1], FieldHeader.init(header.number, kind), ProtoType)
Expand Down Expand Up @@ -184,8 +164,8 @@ proc readValueInternal[T: object](stream: InputStream, value: var T, silent: boo
stream.readFieldPackedInto(fieldVar, header, ProtoType)
else:
stream.readFieldInto(fieldVar, header, ProtoType)
elif ProtoType is ref and defined(ConformanceTest):
fieldVar = new ProtoType
elif typeof(fieldVar) is ref and defined(ConformanceTest):
fieldVar = new typeof(fieldVar)
stream.readFieldInto(fieldVar[], header, ProtoType)
else:
stream.readFieldInto(fieldVar, header, ProtoType)
Expand Down
127 changes: 127 additions & 0 deletions protobuf_serialization/sizer.nim
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
import
std/[typetraits, tables],
stew/shims/macros,
serialization,
"."/[codec, internal, types]

# Forward declaration: the nested-object `computeFieldSize` overload below
# calls `computeObjectSize` before its body is defined at the end of this
# module.
func computeObjectSize*[T: object](value: T): int

func computeFieldSize(
    fieldNum: int, fieldVal: auto, ProtoType: type UnsupportedType,
    _: static bool) =
  ## Overload chosen when a field's `ProtoType` is `UnsupportedType`. It does
  ## no size computation; it only hands the offending field type, root type
  ## and field name to `unsupportedProtoType`.
  # NOTE(review): `unsupportedProtoType` presumably reports a compile-time
  # error - confirm against its definition.
  # TODO turn this into an extension point
  unsupportedProtoType ProtoType.FieldType, ProtoType.RootType, ProtoType.fieldName

func computeFieldSize[T: object and not PBOption](
    fieldNum: int, fieldVal: T, ProtoType: type pbytes,
    skipDefault: static bool): int =
  ## Size of a nested object encoded as a length-delimited field: header,
  ## varint length prefix, then the object's own encoded bytes. With
  ## `skipDefault`, an object whose encoded size is zero contributes nothing.
  let payloadLen = computeObjectSize(fieldVal)

  when skipDefault:
    if payloadLen == 0:
      return 0

  let
    headerLen = computeSize(FieldHeader.init(fieldNum, ProtoType.wireKind()))
    prefixLen = computeSize(puint64(payloadLen))
  headerLen + prefixLen + payloadLen

proc computeFieldSize*[T: not object](
    fieldNum: int, fieldVal: T,
    ProtoType: type SomeScalar, skipDefault: static bool): int =
  ## Size of a scalar field (header plus value) after converting `fieldVal`
  ## to `ProtoType`. With `skipDefault`, a value equal to the default of its
  ## type contributes zero bytes.
  when skipDefault:
    const zeroValue = default(typeof(fieldVal))
    if fieldVal == zeroValue:
      return 0

  let converted = ProtoType(fieldVal)
  computeSize(fieldNum, converted)

proc computeFieldSize*(
    fieldNum: int, fieldVal: PBOption, ProtoType: type,
    skipDefault: static bool): int =
  ## Size of an optional field: zero when no value is held, otherwise the
  ## size of the wrapped value computed by the matching overload.
  # TODO required field checking
  if not fieldVal.isSome():
    return 0
  computeFieldSize(fieldNum, fieldVal.get(), ProtoType, skipDefault)

# Overloads that are only compiled when `ConformanceTest` is defined.
# NOTE(review): indentation is not visible in this view; all three procs are
# assumed to be nested under this `when` - confirm against the real file.
when defined(ConformanceTest):
  proc computeFieldSize*[T](
      fieldNum: int, fieldVal: ref T,
      ProtoType: type pbytes, skipDefault: static bool): int =
    ## Size of a `ref` field: zero for `nil`, otherwise the size of the
    ## referenced value as computed by the matching overload.
    if not fieldVal.isNil():
      computeFieldSize(fieldNum, fieldVal[], ProtoType, skipDefault)
    else:
      0

  # NOTE(review): this unexported proc writes to a stream rather than
  # computing a size, and nothing in this module calls it - it looks like a
  # leftover from the writer; confirm whether an enum `computeFieldSize`
  # overload was intended instead.
  proc writeField[T: enum](
      stream: OutputStream, fieldNum: int, fieldVal: T, ProtoType: type) =
    ## Writes an enum field as the `pint32` of its ordinal. Compilation fails
    ## for enum types that have no member mapping to zero.
    when 0 notin T:
      {.fatal: $T & " definition must contain a constant that maps to zero".}
    stream.writeField(fieldNum, pint32(fieldVal.ord()))

  proc computeFieldSize*[K, V](
      fieldNum: int, fieldVal: Table[K, V], ProtoType: type pbytes,
      skipDefault: static bool): int =
    ## Size of a table field: every key/value pair is wrapped in a
    ## `TableObject` (declared via `tableObject`) and sized as its own field
    ## with number `fieldNum`; the per-entry sizes are summed.
    tableObject(TableObject, K, V)
    for k, v in fieldVal.pairs():
      let tmp = TableObject(key: k, value: v)
      # `false`: entries are never skipped, regardless of `skipDefault`.
      result += computeFieldSize(fieldNum, tmp, ProtoType, false)

proc computeSizePacked*[T: not byte, ProtoType: SomePrimitive](
    values: openArray[T], _: type ProtoType): int =
  ## Total number of payload bytes for `values` in packed form (no header,
  ## no length prefix). Fixed-width and `pbool` elements are sized by element
  ## count; every other element is converted to `ProtoType` and sized
  ## individually.
  # NOTE(review): the fixed-width branch multiplies by `sizeof(T)`, i.e. it
  # assumes `T` and `ProtoType` have the same width - confirm.
  when ProtoType is SomeFixed32 or ProtoType is SomeFixed64 or
      ProtoType is pbool:
    result = values.len() * sizeof(T)
  else:
    for element in values:
      result += computeSize(ProtoType(element))

proc computeFieldSizePacked*[ProtoType: SomePrimitive](
    field: int, values: openArray, _: type ProtoType): int =
  ## Size of a packed repeated field: a length-delimited header, the varint
  ## byte length of the packed data, then the header-free element data.
  let
    payloadLen = computeSizePacked(values, ProtoType)
    headerLen = computeSize(FieldHeader.init(field, WireKind.LengthDelim))
    prefixLen = computeSize(puint64(payloadLen))
  headerLen + prefixLen + payloadLen

func computeObjectSize*[T: object](value: T): int =
  ## Returns the sum of the encoded sizes of all serialized fields of
  ## `value`. `T` must be exactly one of proto2 / proto3; this is asserted at
  ## compile time.
  const
    isProto2: bool = T.isProto2()
    isProto3: bool = T.isProto3()
  static:
    doAssert isProto2 xor isProto3

  var total = 0
  enumInstanceSerializedFields(value, fieldName, fieldVal):
    const
      fieldNum = T.fieldNumberOf(fieldName)

    type
      FlatType = flatType(fieldVal)

    # Declares `ProtoType`, the protobuf-level type used to size this field.
    protoType(ProtoType, T, FlatType, fieldName)

    # Repeated fields (any seq other than seq[byte]) are handled separately
    # from single-valued fields.
    let fieldSize = when FlatType is seq and FlatType isnot seq[byte]:
      const
        # `isProto3` is passed as the fallback to `get`.
        # NOTE(review): presumably used when the field carries no explicit
        # packed setting - confirm against `isPacked`.
        isPacked = T.isPacked(fieldName).get(isProto3)
      when isPacked and ProtoType is SomePrimitive:
        computeFieldSizePacked(fieldNum, fieldVal, ProtoType)
      else:
        # Non-packed: every element is sized as its own field.
        var dataSize = 0
        for i in 0..<fieldVal.len:
          # don't skip defaults so as to preserve length
          dataSize += computeFieldSize(fieldNum, fieldVal[i], ProtoType, false)
        dataSize

    else:
      # Single-valued field: default values are skipped only for proto3.
      computeFieldSize(fieldNum, fieldVal, ProtoType, isProto3)

    total += fieldSize

  total
4 changes: 2 additions & 2 deletions protobuf_serialization/types.nim
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ type
ProtobufEOFError* = object of ProtobufReadError
ProtobufMessageError* = object of ProtobufReadError

ProtobufFlags* = uint8 # enum
# VarIntLengthPrefix, # TODO needs fixing
ProtobufFlags* = enum
VarIntLengthPrefix

ProtobufWriter* = object
stream*: OutputStream
Expand Down
Loading