Skip to content

Commit

Permalink
feat(protobuf): added error wrappers for invalid length validation
Browse files Browse the repository at this point in the history
  • Loading branch information
Lorenzo Delgado committed Feb 16, 2023
1 parent e12b7cb commit 7c958a8
Show file tree
Hide file tree
Showing 9 changed files with 216 additions and 49 deletions.
1 change: 1 addition & 0 deletions tests/all_tests_common.nim
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@
import
./common/test_envvar_serialization,
./common/test_confutils_envvar,
./common/test_protobuf_validation,
./common/test_sqlite_migrations
103 changes: 103 additions & 0 deletions tests/common/test_protobuf_validation.nim
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@

{.used.}

import
testutils/unittests
import
../../waku/common/protobuf


## Fixtures

const MaxTestRpcFieldLen = 5

type TestRpc = object
testField*: string

proc init(T: type TestRpc, field: string): T =
T(testField: field)

proc encode(rpc: TestRpc): ProtoBuffer =
var pb = initProtoBuffer()
pb.write3(1, rpc.testField)
pb.finish3()
pb

proc encodeWithBadFieldId(rpc: TestRpc): ProtoBuffer =
var pb = initProtoBuffer()
pb.write3(666, rpc.testField)
pb.finish3()
pb

proc decode(T: type TestRpc, buf: seq[byte]): ProtobufResult[T] =
let pb = initProtoBuffer(buf)

var field: string
if not ?pb.getField(1, field):
return err(ProtobufError.missingRequiredField("test_field"))
if field.len > MaxTestRpcFieldLen:
return err(ProtobufError.invalidLengthField("test_field"))

ok(TestRpc.init(field))


## Tests

suite "Waku Common - libp2p minprotobuf wrapper":

test "serialize and deserialize - valid length field":
## Given
let field = "12345"

let rpc = TestRpc.init(field)

## When
let encodedRpc = rpc.encode()
let decodedRpcRes = TestRpc.decode(encodedRpc.buffer)

## Then
check:
decodedRpcRes.isOk()

let decodedRpc = decodedRpcRes.tryGet()
check:
decodedRpc.testField == field

test "serialize and deserialize - missing required field":
## Given
let field = "12345"

let rpc = TestRpc.init(field)

## When
let encodedRpc = rpc.encodeWithBadFieldId()
let decodedRpcRes = TestRpc.decode(encodedRpc.buffer)

## Then
check:
decodedRpcRes.isErr()

let error = decodedRpcRes.tryError()
check:
error.kind == ProtobufErrorKind.MissingRequiredField
error.field == "test_field"


test "serialize and deserialize - invalid length field":
## Given
let field = "123456" # field.len = MaxTestRpcFieldLen + 1

let rpc = TestRpc.init(field)

## When
let encodedRpc = rpc.encode()
let decodedRpcRes = TestRpc.decode(encodedRpc.buffer)

## Then
check:
decodedRpcRes.isErr()

let error = decodedRpcRes.tryError()
check:
error.kind == ProtobufErrorKind.InvalidLengthField
error.field == "test_field"
6 changes: 3 additions & 3 deletions tests/v2/test_waku_noise_sessions.nim
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
import
std/tables,
stew/[results, byteutils],
testutils/unittests,
libp2p/protobuf/minprotobuf
testutils/unittests
import
../../waku/common/protobuf,
../../waku/v2/utils/noise as waku_message_utils,
../../waku/v2/protocol/waku_noise/noise_types,
../../waku/v2/protocol/waku_noise/noise_utils,
Expand Down Expand Up @@ -82,7 +82,7 @@ procSuite "Waku Noise Sessions":
var
sentTransportMessage: seq[byte]
aliceStep, bobStep: HandshakeStepResult
msgFromPb: ProtoResult[WakuMessage]
msgFromPb: ProtobufResult[WakuMessage]
wakuMsg: Result[WakuMessage, cstring]
pb: ProtoBuffer
readPayloadV2: PayloadV2
Expand Down
37 changes: 36 additions & 1 deletion waku/common/protobuf.nim
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,47 @@ import
std/options,
libp2p/protobuf/minprotobuf,
libp2p/varint

export
minprotobuf,
varint


## Custom errors

type
ProtobufErrorKind* {.pure.} = enum
DecodeFailure
MissingRequiredField
InvalidLengthField

ProtobufError* = object
case kind*: ProtobufErrorKind
of DecodeFailure:
error*: minprotobuf.ProtoError
of MissingRequiredField, InvalidLengthField:
field*: string

ProtobufResult*[T] = Result[T, ProtobufError]


converter toProtobufError*(err: minprotobuf.ProtoError): ProtobufError =
case err:
of minprotobuf.ProtoError.RequiredFieldMissing:
ProtobufError(kind: ProtobufErrorKind.MissingRequiredField, field: "unknown")
else:
ProtobufError(kind: ProtobufErrorKind.DecodeFailure, error: err)


proc missingRequiredField*(T: type ProtobufError, field: string): T =
ProtobufError(kind: ProtobufErrorKind.MissingRequiredField, field: field)

proc invalidLengthField*(T: type ProtobufError, field: string): T =
ProtobufError(kind: ProtobufErrorKind.InvalidLengthField, field: field)


## Extension methods

proc write3*(proto: var ProtoBuffer, field: int, value: auto) =
when value is Option:
if value.isSome():
Expand Down
40 changes: 23 additions & 17 deletions waku/v2/protocol/waku_filter/rpc_codec.nim
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,18 @@ proc encode*(filter: ContentFilter): ProtoBuffer =

pb

proc decode*(T: type ContentFilter, buffer: seq[byte]): ProtoResult[T] =
proc decode*(T: type ContentFilter, buffer: seq[byte]): ProtobufResult[T] =
let pb = initProtoBuffer(buffer)
var rpc = ContentFilter()

var contentTopic: string
if not ?pb.getField(1, contentTopic):
return err(ProtoError.RequiredFieldMissing)
var topic: string
if not ?pb.getField(1, topic):
return err(ProtobufError.missingRequiredField("content_topic"))
else:
rpc.contentTopic = contentTopic
if topic.len > MaxContentTopicLen:
return err(ProtobufError.invalidLengthField("content_topic"))

rpc.contentTopic = topic

ok(rpc)

Expand All @@ -50,25 +53,28 @@ proc encode*(rpc: FilterRequest): ProtoBuffer =

pb

proc decode*(T: type FilterRequest, buffer: seq[byte]): ProtoResult[T] =
proc decode*(T: type FilterRequest, buffer: seq[byte]): ProtobufResult[T] =
let pb = initProtoBuffer(buffer)
var rpc = FilterRequest()

var subflag: uint64
if not ?pb.getField(1, subflag):
return err(ProtoError.RequiredFieldMissing)
return err(ProtobufError.missingRequiredField("subscribe"))
else:
rpc.subscribe = bool(subflag)

var pubsubTopic: string
if not ?pb.getField(2, pubsubTopic):
return err(ProtoError.RequiredFieldMissing)
var topic: string
if not ?pb.getField(2, topic):
return err(ProtobufError.missingRequiredField("topic"))
else:
rpc.pubsubTopic = pubsubTopic
if topic.len > MaxPubsubTopicLen:
return err(ProtobufError.invalidLengthField("topic"))

rpc.pubsubTopic = topic

var buffs: seq[seq[byte]]
if not ?pb.getRepeatedField(3, buffs):
return err(ProtoError.RequiredFieldMissing)
return err(ProtobufError.missingRequiredField("content_filters"))
else:
for buf in buffs:
let filter = ?ContentFilter.decode(buf)
Expand All @@ -87,13 +93,13 @@ proc encode*(push: MessagePush): ProtoBuffer =

pb

proc decode*(T: type MessagePush, buffer: seq[byte]): ProtoResult[T] =
proc decode*(T: type MessagePush, buffer: seq[byte]): ProtobufResult[T] =
let pb = initProtoBuffer(buffer)
var rpc = MessagePush()

var messages: seq[seq[byte]]
if not ?pb.getRepeatedField(1, messages):
return err(ProtoError.RequiredFieldMissing)
return err(ProtobufError.missingRequiredField("messages"))
else:
for buf in messages:
let msg = ?WakuMessage.decode(buf)
Expand All @@ -112,13 +118,13 @@ proc encode*(rpc: FilterRPC): ProtoBuffer =

pb

proc decode*(T: type FilterRPC, buffer: seq[byte]): ProtoResult[T] =
let pb = initProtoBuffer(buffer)
proc decode*(T: type FilterRPC, buffer: seq[byte]): ProtobufResult[T] =
let pb = initProtoBuffer(buffer)
var rpc = FilterRPC()

var requestId: string
if not ?pb.getField(1, requestId):
return err(ProtoError.RequiredFieldMissing)
return err(ProtobufError.missingRequiredField("request_id"))
else:
rpc.requestId = requestId

Expand Down
20 changes: 10 additions & 10 deletions waku/v2/protocol/waku_lightpush/rpc_codec.nim
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,19 @@ proc encode*(rpc: PushRequest): ProtoBuffer =

pb

proc decode*(T: type PushRequest, buffer: seq[byte]): ProtoResult[T] =
proc decode*(T: type PushRequest, buffer: seq[byte]): ProtobufResult[T] =
let pb = initProtoBuffer(buffer)
var rpc = PushRequest()

var pubSubTopic: PubsubTopic
if not ?pb.getField(1, pubSubTopic):
return err(ProtoError.RequiredFieldMissing)
return err(ProtobufError.missingRequiredField("pubsub_topic"))
else:
rpc.pubSubTopic = pubSubTopic

var messageBuf: seq[byte]
if not ?pb.getField(2, messageBuf):
return err(ProtoError.RequiredFieldMissing)
return err(ProtobufError.missingRequiredField("message"))
else:
rpc.message = ?WakuMessage.decode(messageBuf)

Expand All @@ -52,42 +52,42 @@ proc encode*(rpc: PushResponse): ProtoBuffer =

pb

proc decode*(T: type PushResponse, buffer: seq[byte]): ProtoResult[T] =
proc decode*(T: type PushResponse, buffer: seq[byte]): ProtobufResult[T] =
let pb = initProtoBuffer(buffer)
var rpc = PushResponse()

var isSuccess: uint64
if not ?pb.getField(1, isSuccess):
return err(ProtoError.RequiredFieldMissing)
return err(ProtobufError.missingRequiredField("is_success"))
else:
rpc.isSuccess = bool(isSuccess)

var info: string
if not ?pb.getField(2, info):
rpc.info = none(string)
else:
else:
rpc.info = some(info)

ok(rpc)


proc encode*(rpc: PushRPC): ProtoBuffer =
var pb = initProtoBuffer()

pb.write3(1, rpc.requestId)
pb.write3(2, rpc.request.map(encode))
pb.write3(3, rpc.response.map(encode))
pb.finish3()

pb

proc decode*(T: type PushRPC, buffer: seq[byte]): ProtoResult[T] =
proc decode*(T: type PushRPC, buffer: seq[byte]): ProtobufResult[T] =
let pb = initProtoBuffer(buffer)
var rpc = PushRPC()

var requestId: string
if not ?pb.getField(1, requestId):
return err(ProtoError.RequiredFieldMissing)
return err(ProtobufError.missingRequiredField("request_id"))
else:
rpc.requestId = requestId

Expand All @@ -105,4 +105,4 @@ proc decode*(T: type PushRPC, buffer: seq[byte]): ProtoResult[T] =
let response = ?PushResponse.decode(responseBuffer)
rpc.response = some(response)

ok(rpc)
ok(rpc)
Loading

0 comments on commit 7c958a8

Please sign in to comment.