diff --git a/infra/stream/BufferingStreamReader.cpp b/infra/stream/BufferingStreamReader.cpp new file mode 100644 index 000000000..a2c2289ce --- /dev/null +++ b/infra/stream/BufferingStreamReader.cpp @@ -0,0 +1,107 @@ +#include "infra/stream/BufferingStreamReader.hpp" + +namespace infra +{ + BufferingStreamReader::BufferingStreamReader(infra::BoundedDeque& buffer, infra::StreamReaderWithRewinding& input) + : buffer(buffer) + , input(input) + {} + + BufferingStreamReader::~BufferingStreamReader() + { + StoreRemainder(); + } + + void BufferingStreamReader::Extract(infra::ByteRange range, infra::StreamErrorPolicy& errorPolicy) + { + if (index != buffer.size()) + { + Read(infra::Head(buffer.contiguous_range(buffer.begin() + index), range.size()), range); + // Perhaps the deque just wrapped around, try once more + Read(infra::Head(buffer.contiguous_range(buffer.begin() + index), range.size()), range); + } + + if (!range.empty()) + Read(input.ExtractContiguousRange(range.size()), range); + + errorPolicy.ReportResult(range.empty()); + } + + uint8_t BufferingStreamReader::Peek(infra::StreamErrorPolicy& errorPolicy) + { + auto range = PeekContiguousRange(0); + + errorPolicy.ReportResult(!range.empty()); + + if (range.empty()) + return 0; + else + return range.front(); + } + + infra::ConstByteRange BufferingStreamReader::ExtractContiguousRange(std::size_t max) + { + if (index < buffer.size()) + { + auto from = infra::Head(buffer.contiguous_range(buffer.begin() + index), max); + index += from.size(); + return from; + } + + return input.ExtractContiguousRange(max); + } + + infra::ConstByteRange BufferingStreamReader::PeekContiguousRange(std::size_t start) + { + if (index + start < buffer.size()) + return buffer.contiguous_range(buffer.begin() + index + start); + + return input.PeekContiguousRange(index + start - buffer.size()); + } + + bool BufferingStreamReader::Empty() const + { + return Available() == 0; + } + + std::size_t BufferingStreamReader::Available() const + { + return buffer.size() + input.Available(); + } + + std::size_t BufferingStreamReader::ConstructSaveMarker() const + { + return index; + } + + void BufferingStreamReader::Rewind(std::size_t marker) + { + if (index > buffer.size()) + { + auto rewindAmount = std::min(index - marker, index - buffer.size()); + input.Rewind(input.ConstructSaveMarker() - rewindAmount); + index -= rewindAmount; + } + + if (marker < buffer.size()) + index = marker; + } + + void BufferingStreamReader::Read(infra::ConstByteRange from, infra::ByteRange& to) + { + infra::Copy(from, infra::Head(to, from.size())); + to.pop_front(from.size()); + index += from.size(); + } + + void BufferingStreamReader::StoreRemainder() + { + std::size_t bufferDecrease = std::min(buffer.size(), index); + buffer.erase(buffer.begin(), buffer.begin() + bufferDecrease); + while (!input.Empty()) + { + auto range = input.ExtractContiguousRange(std::numeric_limits::max()); + buffer.insert(buffer.end(), range.begin(), range.end()); + } + } +} diff --git a/infra/stream/BufferingStreamReader.hpp b/infra/stream/BufferingStreamReader.hpp new file mode 100644 index 000000000..3df25964f --- /dev/null +++ b/infra/stream/BufferingStreamReader.hpp @@ -0,0 +1,39 @@ +#ifndef INFRA_BUFFERING_STREAM_READER_HPP +#define INFRA_BUFFERING_STREAM_READER_HPP + +#include "infra/stream/InputStream.hpp" +#include "infra/util/BoundedDeque.hpp" + +namespace infra +{ + // Usage: Everything that is not read from the inputData is stored into the buffer upon destruction of the BufferingStreamReader + // Any data already present in the buffer is read first from the reader + class BufferingStreamReader + : public infra::StreamReaderWithRewinding + { + public: + BufferingStreamReader(infra::BoundedDeque& buffer, infra::StreamReaderWithRewinding& input); + ~BufferingStreamReader() override; + + // Implementation of StreamReaderWithRewinding + void Extract(infra::ByteRange range, infra::StreamErrorPolicy& errorPolicy) override; + uint8_t Peek(infra::StreamErrorPolicy& errorPolicy) override; + infra::ConstByteRange ExtractContiguousRange(std::size_t max) override; + infra::ConstByteRange PeekContiguousRange(std::size_t start) override; + bool Empty() const override; + std::size_t Available() const override; + std::size_t ConstructSaveMarker() const override; + void Rewind(std::size_t marker) override; + + private: + void Read(infra::ConstByteRange range, infra::ByteRange& to); + void StoreRemainder(); + + private: + infra::BoundedDeque& buffer; + infra::StreamReaderWithRewinding& input; + std::size_t index = 0; + }; +} + +#endif diff --git a/infra/stream/BufferingStreamWriter.cpp b/infra/stream/BufferingStreamWriter.cpp new file mode 100644 index 000000000..c66445491 --- /dev/null +++ b/infra/stream/BufferingStreamWriter.cpp @@ -0,0 +1,65 @@ +#include "infra/stream/BufferingStreamWriter.hpp" + +namespace infra +{ + BufferingStreamWriter::BufferingStreamWriter(infra::BoundedDeque& buffer, infra::StreamWriter& output) + : buffer(buffer) + , output(output) + { + LoadRemainder(); + } + + void BufferingStreamWriter::Insert(infra::ConstByteRange range, infra::StreamErrorPolicy& errorPolicy) + { + auto first = infra::Head(range, output.Available()); + output.Insert(first, errorPolicy); + index += first.size(); + range.pop_front(first.size()); + + buffer.insert(buffer.end(), range.begin(), range.end()); + index += range.size(); + } + + std::size_t BufferingStreamWriter::Available() const + { + return output.Available() + buffer.max_size() - buffer.size(); + } + + std::size_t BufferingStreamWriter::ConstructSaveMarker() const + { + return index; + } + + std::size_t BufferingStreamWriter::GetProcessedBytesSince(std::size_t marker) const + { + return index - marker; + } + + [[noreturn]] infra::ByteRange BufferingStreamWriter::SaveState(std::size_t marker) + { + std::abort(); + } + + [[noreturn]] void BufferingStreamWriter::RestoreState(infra::ByteRange range) + { + std::abort(); + } + + [[noreturn]] infra::ByteRange BufferingStreamWriter::Overwrite(std::size_t marker) + { + std::abort(); + } + + void BufferingStreamWriter::LoadRemainder() + { + infra::StreamErrorPolicy errorPolicy; + auto from = infra::Head(buffer.contiguous_range(buffer.begin()), output.Available()); + output.Insert(from, errorPolicy); + buffer.erase(buffer.begin(), buffer.begin() + from.size()); + from = infra::Head(buffer.contiguous_range(buffer.begin()), output.Available()); + output.Insert(from, errorPolicy); + buffer.erase(buffer.begin(), buffer.begin() + from.size()); + + index = buffer.size(); + } +} diff --git a/infra/stream/BufferingStreamWriter.hpp b/infra/stream/BufferingStreamWriter.hpp new file mode 100644 index 000000000..79f2d4216 --- /dev/null +++ b/infra/stream/BufferingStreamWriter.hpp @@ -0,0 +1,36 @@ +#ifndef INFRA_BUFFERING_STREAM_WRITER_HPP +#define INFRA_BUFFERING_STREAM_WRITER_HPP + +#include "infra/stream/OutputStream.hpp" +#include "infra/util/BoundedDeque.hpp" + +namespace infra +{ + // Usage: Any data that does not fit into the output stream is written to the buffer + // Any data already present in the buffer is written to the output stream upon construction of BufferingStreamWriter + class BufferingStreamWriter + : public infra::StreamWriter + { + public: + BufferingStreamWriter(infra::BoundedDeque& buffer, infra::StreamWriter& output); + + // Implementation of StreamWriter + void Insert(infra::ConstByteRange range, infra::StreamErrorPolicy& errorPolicy) override; + std::size_t Available() const override; + std::size_t ConstructSaveMarker() const override; + std::size_t GetProcessedBytesSince(std::size_t marker) const override; + [[noreturn]] infra::ByteRange SaveState(std::size_t marker) override; + [[noreturn]] void RestoreState(infra::ByteRange range) override; + [[noreturn]] infra::ByteRange Overwrite(std::size_t marker) override; + + private: + void LoadRemainder(); + + private: + infra::BoundedDeque& buffer; + infra::StreamWriter& output; + std::size_t index = 0; + }; +} + +#endif diff --git a/infra/stream/CMakeLists.txt b/infra/stream/CMakeLists.txt index eab26fd3f..4b3537a71 100644 --- a/infra/stream/CMakeLists.txt +++ b/infra/stream/CMakeLists.txt @@ -13,6 +13,10 @@ target_sources(infra.stream PRIVATE BoundedVectorInputStream.hpp BoundedVectorOutputStream.cpp BoundedVectorOutputStream.hpp + BufferingStreamReader.cpp + BufferingStreamReader.hpp + BufferingStreamWriter.cpp + BufferingStreamWriter.hpp ByteInputStream.cpp ByteInputStream.hpp ByteOutputStream.cpp diff --git a/infra/stream/test/CMakeLists.txt b/infra/stream/test/CMakeLists.txt index 957be3ce4..c3bfcaece 100644 --- a/infra/stream/test/CMakeLists.txt +++ b/infra/stream/test/CMakeLists.txt @@ -12,6 +12,8 @@ target_sources(infra.stream_test PRIVATE StreamMock.hpp TestBoundedDequeInputStream.cpp TestBoundedVectorOutputStream.cpp + TestBufferingStreamReader.cpp + TestBufferingStreamWriter.cpp TestByteInputStream.cpp TestByteOutputStream.cpp TestCountingInputStream.cpp diff --git a/infra/stream/test/TestBufferingStreamReader.cpp b/infra/stream/test/TestBufferingStreamReader.cpp new file mode 100644 index 000000000..1d58c3b79 --- /dev/null +++ b/infra/stream/test/TestBufferingStreamReader.cpp @@ -0,0 +1,206 @@ +#include "infra/stream/BufferingStreamReader.hpp" +#include "infra/stream/test/StreamMock.hpp" +#include "gmock/gmock.h" + +class BufferingStreamReaderTest + : public testing::Test +{ +public: + ~BufferingStreamReaderTest() + { + EXPECT_CALL(input, Empty()).WillOnce(testing::Return(true)); + } + + void Extract(std::size_t amount) + { + std::vector data(amount, 0); + std::vector inputData; + EXPECT_CALL(input, ExtractContiguousRange(testing::_)).Times(testing::AnyNumber()).WillRepeatedly(testing::Invoke([&](std::size_t max) + { + inputData.resize(max); + return infra::MakeRange(inputData); + })); + reader.Extract(infra::MakeRange(data), errorPolicy); + } + + void ExpectBuffer(const std::vector contents) + { + EXPECT_EQ(infra::BoundedDeque::WithMaxSize<100>(contents.begin(), contents.end()), buffer); + } + + infra::BoundedDeque::WithMaxSize<4> buffer{ std::initializer_list{ 1, 2 } }; + testing::StrictMock input; + infra::StreamErrorPolicy errorPolicy{ infra::noFail }; + infra::BufferingStreamReader reader{ buffer, input }; +}; + +TEST_F(BufferingStreamReaderTest, Extract_from_empty_buffer) +{ + buffer.clear(); + std::array data; + EXPECT_CALL(input, ExtractContiguousRange(2)).WillOnce(testing::Return(infra::ConstByteRange())); + reader.Extract(data, errorPolicy); + EXPECT_TRUE(errorPolicy.Failed()); +} + +TEST_F(BufferingStreamReaderTest, Extract_from_buffer) +{ + std::array data; + reader.Extract(data, errorPolicy); + EXPECT_EQ((std::array{ 1, 2 }), data); + EXPECT_FALSE(errorPolicy.Failed()); +} + +TEST_F(BufferingStreamReaderTest, Extract_from_wrapped_buffer) +{ + buffer.push_back(3); + buffer.push_back(4); + buffer.pop_front(); + buffer.pop_front(); + buffer.push_back(5); + buffer.push_back(6); + + std::array data; + reader.Extract(data, errorPolicy); + EXPECT_EQ((std::array{ 3, 4, 5, 6 }), data); +} + +TEST_F(BufferingStreamReaderTest, Extract_from_buffer_and_input) +{ + std::array data; + std::array inputData{ 3, 4 }; + EXPECT_CALL(input, ExtractContiguousRange(2)).WillOnce(testing::Invoke([&](std::size_t max) + { + return infra::MakeRange(inputData); + })); + reader.Extract(data, errorPolicy); + EXPECT_EQ((std::array{ 1, 2, 3, 4 }), data); +} + +TEST_F(BufferingStreamReaderTest, PeekContiguous_range_from_buffer_head) +{ + EXPECT_EQ((std::array{ 1, 2 }), reader.PeekContiguousRange(0)); +} + +TEST_F(BufferingStreamReaderTest, PeekContiguous_range_from_buffer_1) +{ + EXPECT_EQ((std::array{ 2 }), reader.PeekContiguousRange(1)); +} + +TEST_F(BufferingStreamReaderTest, PeekContiguous_range_from_input) +{ + std::array inputData{ 3, 4 }; + EXPECT_CALL(input, PeekContiguousRange(0)).WillOnce(testing::Invoke([&](std::size_t start) + { + return infra::MakeRange(inputData); + })); + EXPECT_EQ((std::array{ 3, 4 }), reader.PeekContiguousRange(2)); +} + +TEST_F(BufferingStreamReaderTest, Peek_from_buffer) +{ + EXPECT_EQ(1, reader.Peek(errorPolicy)); + reader.ExtractContiguousRange(2); + + std::array inputData{ 3, 4 }; + EXPECT_CALL(input, PeekContiguousRange(0)).WillOnce(testing::Invoke([&](std::size_t start) + { + return infra::MakeRange(inputData); + })); + EXPECT_EQ(3, reader.Peek(errorPolicy)); +} + +TEST_F(BufferingStreamReaderTest, Available) +{ + EXPECT_CALL(input, Available()).WillOnce(testing::Return(1)); + EXPECT_EQ(3, reader.Available()); + + EXPECT_CALL(input, Available()).WillOnce(testing::Return(1)); + EXPECT_FALSE(reader.Empty()); +} + +TEST_F(BufferingStreamReaderTest, ConstructSaveMarker) +{ + EXPECT_EQ(0, reader.ConstructSaveMarker()); + + reader.ExtractContiguousRange(2); + EXPECT_EQ(2, reader.ConstructSaveMarker()); +} + +TEST_F(BufferingStreamReaderTest, Rewind_in_buffer) +{ + reader.ExtractContiguousRange(2); + reader.Rewind(1); + EXPECT_EQ(2, reader.Peek(errorPolicy)); +} + +TEST_F(BufferingStreamReaderTest, Rewind_from_input_to_input) +{ + Extract(4); + + EXPECT_CALL(input, ConstructSaveMarker()).WillOnce(testing::Return(100)); + EXPECT_CALL(input, Rewind(99)); + reader.Rewind(3); + std::array inputData2{ 4 }; + EXPECT_CALL(input, PeekContiguousRange(1)).WillOnce(testing::Invoke([&](std::size_t start) + { + return infra::MakeRange(inputData2); + })); + EXPECT_EQ(4, reader.Peek(errorPolicy)); +} + +TEST_F(BufferingStreamReaderTest, Rewind_from_input_to_buffer) +{ + Extract(4); + + EXPECT_CALL(input, ConstructSaveMarker()).WillOnce(testing::Return(100)); + EXPECT_CALL(input, Rewind(98)); + reader.Rewind(1); + EXPECT_EQ(2, reader.Peek(errorPolicy)); +} + +TEST_F(BufferingStreamReaderTest, destruction_reduces_buffer) +{ + Extract(1); + + EXPECT_CALL(input, Empty()).WillOnce(testing::Return(true)); + infra::ReConstruct(reader, buffer, input); + + ExpectBuffer({ 2 }); +} + +TEST_F(BufferingStreamReaderTest, destruction_consumes_buffer) +{ + Extract(3); + + EXPECT_CALL(input, Empty()).WillOnce(testing::Return(true)); + infra::ReConstruct(reader, buffer, input); + + ExpectBuffer({}); +} + +TEST_F(BufferingStreamReaderTest, destruction_stores_input) +{ + EXPECT_CALL(input, Empty()).WillOnce(testing::Return(false)).WillOnce(testing::Return(true)); + std::array inputData{ 3, 4 }; + EXPECT_CALL(input, ExtractContiguousRange(std::numeric_limits::max())).WillOnce(testing::Invoke([&](std::size_t max) + { + return infra::MakeRange(inputData); + })); + infra::ReConstruct(reader, buffer, input); + + ExpectBuffer({ 1, 2, 3, 4 }); +} + +TEST_F(BufferingStreamReaderTest, destruction_stores_input_until_empty) +{ + EXPECT_CALL(input, Empty()).WillOnce(testing::Return(false)).WillOnce(testing::Return(false)).WillOnce(testing::Return(true)); + std::array inputData{ 3 }; + EXPECT_CALL(input, ExtractContiguousRange(std::numeric_limits::max())).WillRepeatedly(testing::Invoke([&](std::size_t max) + { + return infra::MakeRange(inputData); + })); + infra::ReConstruct(reader, buffer, input); + + ExpectBuffer({ 1, 2, 3, 3 }); +} diff --git a/infra/stream/test/TestBufferingStreamWriter.cpp b/infra/stream/test/TestBufferingStreamWriter.cpp new file mode 100644 index 000000000..2bb6db290 --- /dev/null +++ b/infra/stream/test/TestBufferingStreamWriter.cpp @@ -0,0 +1,103 @@ +#include "infra/stream/BufferingStreamWriter.hpp" +#include "infra/stream/test/StreamMock.hpp" +#include "infra/util/test_helper/MemoryRangeMatcher.hpp" +#include "gmock/gmock.h" + +class BufferingStreamWriterTest + : public testing::Test +{ +public: + void ExpectBuffer(const std::vector contents) + { + EXPECT_EQ(infra::BoundedDeque::WithMaxSize<100>(contents.begin(), contents.end()), buffer); + } + + infra::BoundedDeque::WithMaxSize<4> buffer; + testing::StrictMock output; + infra::StreamErrorPolicy errorPolicy{ infra::noFail }; + infra::Execute execute{ [this]() + { + EXPECT_CALL(output, Available()).WillRepeatedly(testing::Return(0)); + EXPECT_CALL(output, Insert(testing::_, testing::_)).Times(2); + } }; + infra::BufferingStreamWriter writer{ buffer, output }; +}; + +TEST_F(BufferingStreamWriterTest, Insert_into_output) +{ + std::array data{ 3, 4 }; + EXPECT_CALL(output, Available()).WillOnce(testing::Return(10)); + EXPECT_CALL(output, Insert(infra::ContentsEqual(data), testing::Ref(errorPolicy))); + writer.Insert(data, errorPolicy); + + ExpectBuffer({}); + EXPECT_EQ(2, writer.ConstructSaveMarker()); +} + +TEST_F(BufferingStreamWriterTest, Insert_overflows_output) +{ + std::array data{ 3, 4 }; + EXPECT_CALL(output, Available()).WillOnce(testing::Return(1)); + EXPECT_CALL(output, Insert(infra::ByteRangeContentsEqual(infra::Head(infra::MakeRange(data), 1)), testing::Ref(errorPolicy))); + writer.Insert(data, errorPolicy); + + ExpectBuffer({ 4 }); + EXPECT_EQ(2, writer.ConstructSaveMarker()); +} + +TEST_F(BufferingStreamWriterTest, GetProcessedBytesSince) +{ + EXPECT_EQ(0, writer.GetProcessedBytesSince(0)); + + std::array data{ 3, 4 }; + EXPECT_CALL(output, Available()).WillOnce(testing::Return(10)); + EXPECT_CALL(output, Insert(infra::ContentsEqual(data), testing::Ref(errorPolicy))); + writer.Insert(data, errorPolicy); + + EXPECT_EQ(2, writer.GetProcessedBytesSince(0)); + EXPECT_EQ(1, writer.GetProcessedBytesSince(1)); +} + +TEST_F(BufferingStreamWriterTest, LoadRemainder_loads_from_buffer) +{ + std::array data{ 1, 2 }; + buffer.insert(buffer.end(), data.begin(), data.end()); + + EXPECT_CALL(output, Available()).WillOnce(testing::Return(2)).WillOnce(testing::Return(0)); + EXPECT_CALL(output, Insert(infra::ContentsEqual(data), testing::_)); + EXPECT_CALL(output, Insert(infra::ByteRangeContentsEqual(infra::ConstByteRange()), testing::_)); + infra::ReConstruct(writer, buffer, output); +} + +TEST_F(BufferingStreamWriterTest, LoadRemainder_loads_from_circular_buffer) +{ + std::array data1{ 1, 2 }; + std::array data2{ 3, 4 }; + std::array data3{ 5, 6 }; + buffer.insert(buffer.end(), data1.begin(), data1.end()); + buffer.insert(buffer.end(), data2.begin(), data2.end()); + buffer.pop_front(); + buffer.pop_front(); + buffer.insert(buffer.end(), data3.begin(), data3.end()); + + EXPECT_CALL(output, Available()).WillOnce(testing::Return(4)).WillOnce(testing::Return(2)); + EXPECT_CALL(output, Insert(infra::ContentsEqual(data2), testing::_)); + EXPECT_CALL(output, Insert(infra::ContentsEqual(data3), testing::_)); + infra::ReConstruct(writer, buffer, output); +} + +TEST_F(BufferingStreamWriterTest, LoadRemainder_constrained_by_output_writes_in_chunks) +{ + std::array data{ 1, 2 }; + buffer.insert(buffer.end(), data.begin(), data.end()); + + EXPECT_CALL(output, Available()).WillOnce(testing::Return(1)).WillOnce(testing::Return(0)); + EXPECT_CALL(output, Insert(infra::ByteRangeContentsEqual(infra::Head(infra::MakeRange(data), 1)), testing::_)); + EXPECT_CALL(output, Insert(infra::ByteRangeContentsEqual(infra::ConstByteRange()), testing::_)); + infra::ReConstruct(writer, buffer, output); + + EXPECT_CALL(output, Available()).WillOnce(testing::Return(1)).WillOnce(testing::Return(0)); + EXPECT_CALL(output, Insert(infra::ByteRangeContentsEqual(infra::DiscardHead(infra::MakeRange(data), 1)), testing::_)); + EXPECT_CALL(output, Insert(infra::ByteRangeContentsEqual(infra::ConstByteRange()), testing::_)); + infra::ReConstruct(writer, buffer, output); +} diff --git a/infra/syntax/ProtoFormatter.cpp b/infra/syntax/ProtoFormatter.cpp index efb630d04..7b3e9d99a 100644 --- a/infra/syntax/ProtoFormatter.cpp +++ b/infra/syntax/ProtoFormatter.cpp @@ -130,6 +130,12 @@ namespace infra PutBytes(bytes); } + void ProtoFormatter::PutLengthDelimitedSize(std::size_t size, uint32_t fieldNumber) + { + PutVarInt((fieldNumber << 3) | 2); + PutVarInt(size); + } + ProtoLengthDelimitedFormatter ProtoFormatter::LengthDelimitedFormatter(uint32_t fieldNumber) { return ProtoLengthDelimitedFormatter(*this, fieldNumber); diff --git a/infra/syntax/ProtoFormatter.hpp b/infra/syntax/ProtoFormatter.hpp index 45748075b..7c717ce3c 100644 --- a/infra/syntax/ProtoFormatter.hpp +++ b/infra/syntax/ProtoFormatter.hpp @@ -14,6 +14,7 @@ namespace infra { public: ProtoLengthDelimitedFormatter(ProtoFormatter& formatter, uint32_t fieldNumber); + ProtoLengthDelimitedFormatter(ProtoFormatter& formatter, uint32_t fieldNumber, std::size_t size); ProtoLengthDelimitedFormatter(const ProtoLengthDelimitedFormatter& other) = delete; ProtoLengthDelimitedFormatter(ProtoLengthDelimitedFormatter&& other) noexcept; ProtoLengthDelimitedFormatter& operator=(const ProtoLengthDelimitedFormatter& other) = delete; @@ -43,6 +44,7 @@ namespace infra void PutLengthDelimitedField(infra::ConstByteRange range, uint32_t fieldNumber); void PutStringField(infra::BoundedConstString string, uint32_t fieldNumber); void PutBytesField(infra::ConstByteRange bytes, uint32_t fieldNumber); + void PutLengthDelimitedSize(std::size_t size, uint32_t fieldNumber); ProtoLengthDelimitedFormatter LengthDelimitedFormatter(uint32_t fieldNumber); private: diff --git a/infra/syntax/ProtoParser.cpp b/infra/syntax/ProtoParser.cpp index 513ebf7a6..8b5e09fb5 100644 --- a/infra/syntax/ProtoParser.cpp +++ b/infra/syntax/ProtoParser.cpp @@ -116,7 +116,40 @@ namespace infra return result; } + struct MakeFullField + : infra::StaticVisitor + { + MakeFullField(infra::DataInputStream inputStream, infra::StreamErrorPolicy& formatErrorPolicy, uint32_t fieldNumber) + : inputStream(inputStream) + , formatErrorPolicy(formatErrorPolicy) + , fieldNumber(fieldNumber) + {} + + template + ProtoParser::Field operator()(T value) const + { + return { value, fieldNumber }; + } + + ProtoParser::Field operator()(PartialProtoLengthDelimited value) const + { + return { ProtoLengthDelimited(inputStream, formatErrorPolicy, value.length), fieldNumber }; + } + + private: + infra::DataInputStream inputStream; + infra::StreamErrorPolicy& formatErrorPolicy; + uint32_t fieldNumber; + }; + ProtoParser::Field ProtoParser::GetField() + { + auto [value, fieldNumber] = GetPartialField(); + MakeFullField visitor(input, formatErrorPolicy, fieldNumber); + return infra::ApplyVisitor(visitor, value); + } + + ProtoParser::PartialField ProtoParser::GetPartialField() { uint32_t x = static_cast(GetVarInt()); uint8_t type = x & 7; @@ -129,7 +162,7 @@ namespace infra case 1: return std::make_pair(GetFixed64(), fieldNumber); case 2: - return std::make_pair(ProtoLengthDelimited(input, formatErrorPolicy, static_cast(GetVarInt())), fieldNumber); + return std::make_pair(PartialProtoLengthDelimited{ static_cast(GetVarInt()) }, fieldNumber); case 5: return std::make_pair(GetFixed32(), fieldNumber); default: diff --git a/infra/syntax/ProtoParser.hpp b/infra/syntax/ProtoParser.hpp index 3cb6d8306..652041edc 100644 --- a/infra/syntax/ProtoParser.hpp +++ b/infra/syntax/ProtoParser.hpp @@ -34,10 +34,18 @@ namespace infra infra::StreamErrorPolicy& formatErrorPolicy; }; + struct PartialProtoLengthDelimited + { + uint32_t length; + }; + class ProtoParser { public: - using Field = std::pair, uint32_t>; + using FieldVariant = infra::Variant; + using Field = std::pair; + using PartialFieldVariant = infra::Variant; + using PartialField = std::pair; explicit ProtoParser(infra::DataInputStream inputStream); ProtoParser(infra::DataInputStream inputStream, infra::StreamErrorPolicy& formatErrorPolicy); @@ -48,6 +56,7 @@ namespace infra uint64_t GetFixed64(); Field GetField(); + PartialField GetPartialField(); void ReportFormatResult(bool ok); bool FormatFailed() const; diff --git a/infra/syntax/test/TestProtoFormatter.cpp b/infra/syntax/test/TestProtoFormatter.cpp index 6b34a148a..e42384893 100644 --- a/infra/syntax/test/TestProtoFormatter.cpp +++ b/infra/syntax/test/TestProtoFormatter.cpp @@ -76,6 +76,17 @@ TEST(ProtoFormatterTest, PutBytesField) EXPECT_EQ((std::array{ 4 << 3 | 2, 1, 5 }), stream.Writer().Processed()); } +TEST(ProtoFormatterTest, PutSubObjectOfKnownSize) +{ + infra::ByteOutputStream::WithStorage<20> stream; + infra::ProtoFormatter formatter(stream); + + formatter.PutLengthDelimitedSize(2, 4); + formatter.PutVarIntField(2, 4); + + EXPECT_EQ((std::array{ 4 << 3 | 2, 2, 4 << 3, 2 }), stream.Writer().Processed()); +} + TEST(ProtoFormatterTest, PutSubObject) { infra::ByteOutputStream::WithStorage<20> stream; diff --git a/infra/util/ConstructBin.cpp b/infra/util/ConstructBin.cpp index ce30b8fff..001ae1dce 100644 --- a/infra/util/ConstructBin.cpp +++ b/infra/util/ConstructBin.cpp @@ -32,6 +32,13 @@ namespace infra return *this; } + ConstructBin& ConstructBin::Repeat(std::size_t amount, std::initializer_list v) + { + for (size_t i = 0; i != amount; ++i) + contents.insert(contents.end(), v.begin(), v.end()); + return *this; + } + ConstructBin& ConstructBin::RepeatString(std::size_t amount, const std::string& v) { for (size_t i = 0; i != amount; ++i) diff --git a/infra/util/ConstructBin.hpp b/infra/util/ConstructBin.hpp index ba514fd4e..83631f5c4 100644 --- a/infra/util/ConstructBin.hpp +++ b/infra/util/ConstructBin.hpp @@ -17,6 +17,7 @@ namespace infra ConstructBin& operator()(std::initializer_list v); ConstructBin& Repeat(std::size_t amount, uint8_t v); + ConstructBin& Repeat(std::size_t amount, std::initializer_list v); ConstructBin& RepeatString(std::size_t amount, const std::string& v); template diff --git a/infra/util/WithStorage.hpp b/infra/util/WithStorage.hpp index 52f78eb9e..1b199ef7a 100644 --- a/infra/util/WithStorage.hpp +++ b/infra/util/WithStorage.hpp @@ -26,6 +26,8 @@ namespace infra WithStorage(); template WithStorage(InPlace, StorageArg&& storageArg, Args&&... args); + template + WithStorage(InPlace, std::initializer_list initializerList); template WithStorage(Arg&& arg, std::enable_if_t>>, std::nullptr_t> = nullptr); template @@ -89,6 +91,13 @@ namespace infra , Base(detail::StorageHolder::storage, std::forward(args)...) {} + template + template + WithStorage::WithStorage(InPlace, std::initializer_list initializerList) + : detail::StorageHolder(initializerList) + , Base(detail::StorageHolder::storage) + {} + template template WithStorage::WithStorage(Arg&& arg, std::enable_if_t>>, std::nullptr_t>) diff --git a/protobuf/CMakeLists.txt b/protobuf/CMakeLists.txt index da2d6f17d..e6fe5e065 100644 --- a/protobuf/CMakeLists.txt +++ b/protobuf/CMakeLists.txt @@ -1,5 +1,5 @@ -add_subdirectory(echo) add_subdirectory(echo_attributes) +add_subdirectory(echo) add_subdirectory(protoc_echo_plugin) add_subdirectory(protoc_echo_plugin_csharp) add_subdirectory(protoc_echo_plugin_java) diff --git a/protobuf/echo/CMakeLists.txt b/protobuf/echo/CMakeLists.txt index 168b25331..3f459aaeb 100644 --- a/protobuf/echo/CMakeLists.txt +++ b/protobuf/echo/CMakeLists.txt @@ -8,6 +8,11 @@ target_link_libraries(protobuf.echo PUBLIC target_sources(protobuf.echo PRIVATE Echo.cpp Echo.hpp + Proto.hpp + ProtoMessageReceiver.cpp + ProtoMessageReceiver.hpp + ProtoMessageSender.cpp + ProtoMessageSender.hpp ServiceForwarder.cpp ServiceForwarder.hpp TracingEcho.cpp diff --git a/protobuf/echo/Echo.hpp b/protobuf/echo/Echo.hpp index 81d04ba97..96a5e6a1d 100644 --- a/protobuf/echo/Echo.hpp +++ b/protobuf/echo/Echo.hpp @@ -1,12 +1,11 @@ #ifndef PROTOBUF_ECHO_HPP #define PROTOBUF_ECHO_HPP -#include "infra/syntax/ProtoFormatter.hpp" -#include "infra/syntax/ProtoParser.hpp" #include "infra/util/BoundedDeque.hpp" #include "infra/util/Compatibility.hpp" #include "infra/util/Function.hpp" #include "infra/util/Optional.hpp" +#include "protobuf/echo/Proto.hpp" #include "services/util/MessageCommunication.hpp" namespace services @@ -15,121 +14,6 @@ namespace services class Service; class ServiceProxy; - struct ProtoBool - {}; - - struct ProtoUInt32 - {}; - - struct ProtoInt32 - {}; - - struct ProtoUInt64 - {}; - - struct ProtoInt64 - {}; - - struct ProtoFixed32 - {}; - - struct ProtoFixed64 - {}; - - struct ProtoSFixed32 - {}; - - struct ProtoSFixed64 - {}; - - struct ProtoUnboundedString - {}; - - struct ProtoUnboundedBytes - {}; - - template - struct ProtoMessage - {}; - - template - struct ProtoEnum - {}; - - template - struct ProtoBytes - {}; - - template - struct ProtoString - {}; - - template - struct ProtoRepeated - {}; - - template - struct ProtoUnboundedRepeated - {}; - - void SerializeField(ProtoBool, infra::ProtoFormatter& formatter, bool value, uint32_t fieldNumber); - void SerializeField(ProtoUInt32, infra::ProtoFormatter& formatter, uint32_t value, uint32_t fieldNumber); - void SerializeField(ProtoInt32, infra::ProtoFormatter& formatter, int32_t value, uint32_t fieldNumber); - void SerializeField(ProtoUInt64, infra::ProtoFormatter& formatter, uint64_t value, uint32_t fieldNumber); - void SerializeField(ProtoInt64, infra::ProtoFormatter& formatter, int64_t value, uint32_t fieldNumber); - void SerializeField(ProtoFixed32, infra::ProtoFormatter& formatter, uint32_t value, uint32_t fieldNumber); - void SerializeField(ProtoFixed64, infra::ProtoFormatter& formatter, uint64_t value, uint32_t fieldNumber); - void SerializeField(ProtoSFixed32, infra::ProtoFormatter& formatter, int32_t value, uint32_t fieldNumber); - void SerializeField(ProtoSFixed64, infra::ProtoFormatter& formatter, int64_t value, uint32_t fieldNumber); - void SerializeField(ProtoUnboundedString, infra::ProtoFormatter& formatter, const std::string& value, uint32_t fieldNumber); - void SerializeField(ProtoUnboundedBytes, infra::ProtoFormatter& formatter, const std::vector& value, uint32_t fieldNumber); - - template - void SerializeField(ProtoRepeated, infra::ProtoFormatter& formatter, const infra::BoundedVector& value, uint32_t fieldNumber); - template - void SerializeField(ProtoUnboundedRepeated, infra::ProtoFormatter& formatter, const std::vector& value, uint32_t fieldNumber); - template - void SerializeField(ProtoUnboundedRepeated, infra::ProtoFormatter& formatter, const std::vector& value, uint32_t fieldNumber); - template - void SerializeField(ProtoMessage, infra::ProtoFormatter& formatter, const U& value, uint32_t fieldNumber); - template - void SerializeField(ProtoEnum, infra::ProtoFormatter& formatter, T value, uint32_t fieldNumber); - template - void SerializeField(ProtoBytes, infra::ProtoFormatter& formatter, const infra::BoundedVector& value, uint32_t fieldNumber); - template - void SerializeField(ProtoString, infra::ProtoFormatter& formatter, infra::BoundedConstString value, uint32_t fieldNumber); - - void DeserializeField(ProtoBool, infra::ProtoParser& parser, infra::ProtoParser::Field& field, bool& value); - void DeserializeField(ProtoUInt32, infra::ProtoParser& parser, infra::ProtoParser::Field& field, uint32_t& value); - void DeserializeField(ProtoInt32, infra::ProtoParser& parser, infra::ProtoParser::Field& field, int32_t& value); - void DeserializeField(ProtoUInt64, infra::ProtoParser& parser, infra::ProtoParser::Field& field, uint64_t& value); - void DeserializeField(ProtoInt64, infra::ProtoParser& parser, infra::ProtoParser::Field& field, int64_t& value); - void DeserializeField(ProtoFixed32, infra::ProtoParser& parser, infra::ProtoParser::Field& field, uint32_t& value); - void DeserializeField(ProtoFixed64, infra::ProtoParser& parser, infra::ProtoParser::Field& field, uint64_t& value); - void DeserializeField(ProtoSFixed32, infra::ProtoParser& parser, infra::ProtoParser::Field& field, int32_t& value); - void DeserializeField(ProtoSFixed64, infra::ProtoParser& parser, infra::ProtoParser::Field& field, int64_t& value); - void DeserializeField(ProtoUnboundedString, infra::ProtoParser& parser, infra::ProtoParser::Field& field, std::string& value); - void DeserializeField(ProtoUnboundedBytes, infra::ProtoParser& parser, infra::ProtoParser::Field& field, std::vector& value); - - template - void DeserializeField(ProtoRepeated, infra::ProtoParser& parser, infra::ProtoParser::Field& field, infra::BoundedVector& value); - template - void DeserializeField(ProtoUnboundedRepeated, infra::ProtoParser& parser, infra::ProtoParser::Field& field, std::vector& value); - template - void DeserializeField(ProtoUnboundedRepeated, infra::ProtoParser& parser, infra::ProtoParser::Field& field, std::vector& value); - template - void DeserializeField(ProtoMessage, infra::ProtoParser& parser, infra::ProtoParser::Field& field, U& value); - template - void DeserializeField(ProtoEnum, infra::ProtoParser& parser, infra::ProtoParser::Field& field, T& value); - template - void DeserializeField(ProtoBytes, infra::ProtoParser& parser, infra::ProtoParser::Field& field, infra::BoundedVector& value); - template - void DeserializeField(ProtoBytes, infra::ProtoParser& parser, infra::ProtoParser::Field& field, infra::ConstByteRange& value); - template - void DeserializeField(ProtoString, infra::ProtoParser& parser, infra::ProtoParser::Field& field, infra::BoundedString& value); - template - void DeserializeField(ProtoString, infra::ProtoParser& parser, infra::ProtoParser::Field& field, infra::BoundedConstString& value); - class EchoErrorPolicy { protected: @@ -275,261 +159,6 @@ namespace services //// Implementation //// - inline void SerializeField(ProtoBool, infra::ProtoFormatter& formatter, bool value, uint32_t fieldNumber) - { - formatter.PutVarIntField(value, fieldNumber); - } - - inline void SerializeField(ProtoUInt32, infra::ProtoFormatter& formatter, uint32_t value, uint32_t fieldNumber) - { - formatter.PutVarIntField(value, fieldNumber); - } - - inline void SerializeField(ProtoInt32, infra::ProtoFormatter& formatter, int32_t value, uint32_t fieldNumber) - { - formatter.PutVarIntField(value, fieldNumber); - } - - inline void SerializeField(ProtoUInt64, infra::ProtoFormatter& formatter, uint64_t value, uint32_t fieldNumber) - { - formatter.PutVarIntField(value, fieldNumber); - } - - inline void SerializeField(ProtoInt64, infra::ProtoFormatter& formatter, int64_t value, uint32_t fieldNumber) - { - formatter.PutVarIntField(value, fieldNumber); - } - - inline void SerializeField(ProtoFixed32, infra::ProtoFormatter& formatter, uint32_t value, uint32_t fieldNumber) - { - formatter.PutFixed32Field(value, fieldNumber); - } - - inline void SerializeField(ProtoFixed64, infra::ProtoFormatter& formatter, uint64_t value, uint32_t fieldNumber) - { - formatter.PutFixed64Field(value, fieldNumber); - } - - inline void SerializeField(ProtoSFixed32, infra::ProtoFormatter& formatter, int32_t value, uint32_t fieldNumber) - { - formatter.PutFixed32Field(static_cast(value), fieldNumber); - } - - inline void SerializeField(ProtoSFixed64, infra::ProtoFormatter& formatter, int64_t value, uint32_t fieldNumber) - { - formatter.PutFixed64Field(static_cast(value), fieldNumber); - } - - inline void SerializeField(ProtoUnboundedString, infra::ProtoFormatter& formatter, const std::string& value, uint32_t fieldNumber) - { - formatter.PutStringField(value, fieldNumber); - } - - inline void SerializeField(ProtoUnboundedBytes, infra::ProtoFormatter& formatter, const std::vector& value, uint32_t fieldNumber) - { - formatter.PutBytesField(value, fieldNumber); - } - - template - void SerializeField(ProtoRepeated, infra::ProtoFormatter& formatter, const infra::BoundedVector& value, uint32_t fieldNumber) - { - for (auto& v : value) - SerializeField(T(), formatter, v, fieldNumber); - } - - template - void SerializeField(ProtoUnboundedRepeated, infra::ProtoFormatter& formatter, const std::vector& value, uint32_t fieldNumber) - { - for (auto& v : value) - SerializeField(T(), formatter, v, fieldNumber); - } - - template - void SerializeField(ProtoUnboundedRepeated, infra::ProtoFormatter& formatter, const std::vector& value, uint32_t fieldNumber) - { - for (auto v : value) - SerializeField(T(), formatter, v, fieldNumber); - } - - template - void SerializeField(ProtoMessage, infra::ProtoFormatter& formatter, const U& value, uint32_t fieldNumber) - { - infra::ProtoLengthDelimitedFormatter nestedMessage(formatter, fieldNumber); - value.Serialize(formatter); - } - - template - void SerializeField(ProtoEnum, infra::ProtoFormatter& formatter, T value, uint32_t fieldNumber) - { - formatter.PutVarIntField(static_cast(value), fieldNumber); - } - - template - void SerializeField(ProtoBytes, infra::ProtoFormatter& formatter, const infra::BoundedVector& value, uint32_t fieldNumber) - { - formatter.PutBytesField(infra::MakeRange(value), fieldNumber); - } - - template - void SerializeField(ProtoString, infra::ProtoFormatter& formatter, infra::BoundedConstString value, uint32_t fieldNumber) - { - formatter.PutStringField(value, fieldNumber); - } - - inline void DeserializeField(ProtoBool, infra::ProtoParser& parser, infra::ProtoParser::Field& field, bool& value) - { - parser.ReportFormatResult(field.first.Is()); - if (field.first.Is()) - value = field.first.Get() != 0; - } - - inline void DeserializeField(ProtoUInt32, infra::ProtoParser& parser, infra::ProtoParser::Field& field, uint32_t& value) - { - parser.ReportFormatResult(field.first.Is()); - if (field.first.Is()) - value = static_cast(field.first.Get()); - } - - inline void DeserializeField(ProtoInt32, infra::ProtoParser& parser, infra::ProtoParser::Field& field, int32_t& value) - { - parser.ReportFormatResult(field.first.Is()); - if (field.first.Is()) - value = static_cast(field.first.Get()); - } - - inline void DeserializeField(ProtoUInt64, infra::ProtoParser& parser, infra::ProtoParser::Field& field, uint64_t& value) - { - parser.ReportFormatResult(field.first.Is()); - if (field.first.Is()) - value = field.first.Get(); - } - - inline void DeserializeField(ProtoInt64, infra::ProtoParser& parser, infra::ProtoParser::Field& field, int64_t& value) - { - parser.ReportFormatResult(field.first.Is()); - if (field.first.Is()) - value = static_cast(field.first.Get()); - } - - inline void DeserializeField(ProtoFixed32, infra::ProtoParser& parser, infra::ProtoParser::Field& field, uint32_t& value) - { - parser.ReportFormatResult(field.first.Is()); - if (field.first.Is()) - value = field.first.Get(); - } - - inline void DeserializeField(ProtoFixed64, infra::ProtoParser& parser, infra::ProtoParser::Field& field, uint64_t& value) - { - parser.ReportFormatResult(field.first.Is()); - if (field.first.Is()) - value = field.first.Get(); - } - - inline void DeserializeField(ProtoSFixed32, infra::ProtoParser& parser, infra::ProtoParser::Field& field, int32_t& value) - { - parser.ReportFormatResult(field.first.Is()); - if (field.first.Is()) - value = static_cast(field.first.Get()); - } - - inline void DeserializeField(ProtoSFixed64, infra::ProtoParser& parser, infra::ProtoParser::Field& field, int64_t& value) - { - parser.ReportFormatResult(field.first.Is()); - if (field.first.Is()) - value = static_cast(field.first.Get()); - } - - inline void DeserializeField(ProtoUnboundedString, infra::ProtoParser& parser, infra::ProtoParser::Field& field, std::string& value) - { - parser.ReportFormatResult(field.first.Is()); - if (field.first.Is()) - value = field.first.Get().GetStdString(); - } - - inline void DeserializeField(ProtoUnboundedBytes, infra::ProtoParser& parser, infra::ProtoParser::Field& field, std::vector& value) - { - parser.ReportFormatResult(field.first.Is()); - if (field.first.Is()) - value = field.first.Get().GetUnboundedBytes(); - } - - template - void DeserializeField(ProtoRepeated, infra::ProtoParser& parser, infra::ProtoParser::Field& field, infra::BoundedVector& value) - { - parser.ReportFormatResult(!value.full()); - if (!value.full()) - { - value.emplace_back(); - DeserializeField(T(), parser, field, value.back()); - } - } - - template - void DeserializeField(ProtoUnboundedRepeated, infra::ProtoParser& parser, infra::ProtoParser::Field& field, std::vector& value) - { - value.emplace_back(); - DeserializeField(T(), parser, field, value.back()); - } - - template - void DeserializeField(ProtoUnboundedRepeated, infra::ProtoParser& parser, infra::ProtoParser::Field& field, std::vector& value) - { - bool result{}; - DeserializeField(T(), parser, field, result); - value.push_back(result); - } - - template - void DeserializeField(ProtoMessage, infra::ProtoParser& parser, infra::ProtoParser::Field& field, U& value) - { - parser.ReportFormatResult(field.first.Is()); - if (field.first.Is()) - { - infra::ProtoParser nestedParser = field.first.Get().Parser(); - value.Deserialize(nestedParser); - } - } - - template - void DeserializeField(ProtoEnum, infra::ProtoParser& parser, infra::ProtoParser::Field& field, T& value) - { - parser.ReportFormatResult(field.first.Is()); - if (field.first.Is()) - value = static_cast(field.first.Get()); - } - - template - void DeserializeField(ProtoBytes, infra::ProtoParser& parser, infra::ProtoParser::Field& field, infra::BoundedVector& value) - { - parser.ReportFormatResult(field.first.Is()); - if (field.first.Is()) - field.first.Get().GetBytes(value); - } - - template - void DeserializeField(ProtoBytes, infra::ProtoParser& parser, infra::ProtoParser::Field& field, infra::ConstByteRange& value) - { - parser.ReportFormatResult(field.first.Is()); - if (field.first.Is()) - field.first.Get().GetBytesReference(value); - } - - template - void DeserializeField(ProtoString, infra::ProtoParser& parser, infra::ProtoParser::Field& field, infra::BoundedString& value) - { - parser.ReportFormatResult(field.first.Is()); - if (field.first.Is()) - field.first.Get().GetString(value); - } - - template - void DeserializeField(ProtoString, infra::ProtoParser& parser, infra::ProtoParser::Field& field, infra::BoundedConstString& value) - { - parser.ReportFormatResult(field.first.Is()); - if (field.first.Is()) - field.first.Get().GetStringReference(value); - } - template template ServiceProxyResponseQueue::ServiceProxyResponseQueue(Container& container, Args&&... args) diff --git a/protobuf/echo/Proto.hpp b/protobuf/echo/Proto.hpp new file mode 100644 index 000000000..3fb0ae479 --- /dev/null +++ b/protobuf/echo/Proto.hpp @@ -0,0 +1,541 @@ +#ifndef PROTOBUF_PROTO_HPP +#define PROTOBUF_PROTO_HPP + +#include "infra/syntax/ProtoFormatter.hpp" +#include "infra/syntax/ProtoParser.hpp" +#include + +namespace services +{ + struct ProtoBool + {}; + + struct ProtoUInt32 + {}; + + struct ProtoInt32 + {}; + + struct ProtoUInt64 + {}; + + struct ProtoInt64 + {}; + + struct ProtoFixed32 + {}; + + struct ProtoFixed64 + {}; + + struct ProtoSFixed32 + {}; + + struct ProtoSFixed64 + {}; + + struct ProtoUnboundedString + {}; + + struct ProtoUnboundedBytes + {}; + + template + struct ProtoMessage + {}; + + template + struct ProtoEnum + {}; + + struct ProtoBytesBase + {}; + + template + struct ProtoBytes + : ProtoBytesBase + {}; + + struct ProtoStringBase + {}; + + template + struct ProtoString + : ProtoStringBase + {}; + + template + struct ProtoRepeatedBase + {}; + + template + struct ProtoRepeated + : ProtoRepeatedBase + {}; + + template + struct ProtoUnboundedRepeated + {}; + + void SerializeField(ProtoBool, infra::ProtoFormatter& formatter, bool value, uint32_t fieldNumber); + void SerializeField(ProtoUInt32, infra::ProtoFormatter& formatter, uint32_t value, uint32_t fieldNumber); + void SerializeField(ProtoInt32, infra::ProtoFormatter& formatter, int32_t value, uint32_t fieldNumber); + void SerializeField(ProtoUInt64, infra::ProtoFormatter& formatter, uint64_t value, uint32_t fieldNumber); + void SerializeField(ProtoInt64, infra::ProtoFormatter& formatter, int64_t value, uint32_t fieldNumber); + void SerializeField(ProtoFixed32, infra::ProtoFormatter& formatter, uint32_t value, uint32_t fieldNumber); + void SerializeField(ProtoFixed64, infra::ProtoFormatter& formatter, uint64_t value, uint32_t fieldNumber); + void SerializeField(ProtoSFixed32, infra::ProtoFormatter& formatter, int32_t value, uint32_t fieldNumber); + void SerializeField(ProtoSFixed64, infra::ProtoFormatter& formatter, int64_t value, uint32_t fieldNumber); + void SerializeField(ProtoUnboundedString, infra::ProtoFormatter& formatter, const std::string& value, uint32_t fieldNumber); + void SerializeField(ProtoUnboundedBytes, infra::ProtoFormatter& formatter, const std::vector& value, uint32_t fieldNumber); + + template + void SerializeField(ProtoRepeated, infra::ProtoFormatter& formatter, const infra::BoundedVector& value, uint32_t fieldNumber); + template + void SerializeField(ProtoUnboundedRepeated, infra::ProtoFormatter& formatter, const std::vector& value, uint32_t fieldNumber); + template + void SerializeField(ProtoUnboundedRepeated, infra::ProtoFormatter& formatter, const std::vector& value, uint32_t fieldNumber); + template + void SerializeField(ProtoMessage, infra::ProtoFormatter& formatter, const U& value, uint32_t fieldNumber); + template + void SerializeField(ProtoEnum, infra::ProtoFormatter& formatter, T value, uint32_t fieldNumber); + template + void SerializeField(ProtoBytes, infra::ProtoFormatter& formatter, const infra::BoundedVector& value, uint32_t fieldNumber); + template + void SerializeField(ProtoString, infra::ProtoFormatter& formatter, infra::BoundedConstString value, uint32_t fieldNumber); + + void DeserializeField(ProtoBool, infra::ProtoParser& parser, infra::ProtoParser::FieldVariant& field, bool& value); + void DeserializeField(ProtoUInt32, infra::ProtoParser& parser, infra::ProtoParser::FieldVariant& field, uint32_t& value); + void DeserializeField(ProtoInt32, infra::ProtoParser& parser, infra::ProtoParser::FieldVariant& field, int32_t& value); + void DeserializeField(ProtoUInt64, infra::ProtoParser& parser, infra::ProtoParser::FieldVariant& field, uint64_t& value); + void DeserializeField(ProtoInt64, infra::ProtoParser& parser, infra::ProtoParser::FieldVariant& field, int64_t& value); + void DeserializeField(ProtoFixed32, infra::ProtoParser& parser, infra::ProtoParser::FieldVariant& field, uint32_t& value); + void DeserializeField(ProtoFixed64, infra::ProtoParser& parser, infra::ProtoParser::FieldVariant& field, uint64_t& value); + void DeserializeField(ProtoSFixed32, infra::ProtoParser& parser, infra::ProtoParser::FieldVariant& field, int32_t& value); + void DeserializeField(ProtoSFixed64, infra::ProtoParser& parser, infra::ProtoParser::FieldVariant& field, int64_t& value); + void DeserializeField(ProtoUnboundedString, infra::ProtoParser& parser, infra::ProtoParser::FieldVariant& field, std::string& value); + void DeserializeField(ProtoUnboundedBytes, infra::ProtoParser& parser, infra::ProtoParser::FieldVariant& field, std::vector& value); + + template + void DeserializeField(ProtoRepeated, infra::ProtoParser& parser, infra::ProtoParser::FieldVariant& field, infra::BoundedVector& value); + template + void DeserializeField(ProtoUnboundedRepeated, infra::ProtoParser& parser, infra::ProtoParser::FieldVariant& field, std::vector& value); + template + void DeserializeField(ProtoUnboundedRepeated, infra::ProtoParser& parser, infra::ProtoParser::FieldVariant& field, std::vector& value); + template + void DeserializeField(ProtoMessage, infra::ProtoParser& parser, infra::ProtoParser::FieldVariant& field, U& value); + template + void DeserializeField(ProtoEnum, infra::ProtoParser& parser, infra::ProtoParser::FieldVariant& field, T& value); + template + void DeserializeField(ProtoBytes, infra::ProtoParser& parser, infra::ProtoParser::FieldVariant& field, infra::BoundedVector& value); + template + void DeserializeField(ProtoBytes, infra::ProtoParser& parser, infra::ProtoParser::FieldVariant& field, infra::ConstByteRange& value); + template + void DeserializeField(ProtoString, infra::ProtoParser& parser, infra::ProtoParser::FieldVariant& field, infra::BoundedString& value); + template + void DeserializeField(ProtoString, infra::ProtoParser& parser, infra::ProtoParser::FieldVariant& field, infra::BoundedConstString& value); + + void DeserializeField(ProtoBool, infra::ProtoParser& parser, infra::ProtoParser::PartialFieldVariant& field, bool& value); + void DeserializeField(ProtoUInt32, infra::ProtoParser& parser, infra::ProtoParser::PartialFieldVariant& field, uint32_t& value); + void DeserializeField(ProtoInt32, infra::ProtoParser& parser, infra::ProtoParser::PartialFieldVariant& field, int32_t& value); + void DeserializeField(ProtoUInt64, infra::ProtoParser& parser, infra::ProtoParser::PartialFieldVariant& field, uint64_t& value); + void DeserializeField(ProtoInt64, infra::ProtoParser& parser, infra::ProtoParser::PartialFieldVariant& field, int64_t& value); + void DeserializeField(ProtoFixed32, infra::ProtoParser& parser, infra::ProtoParser::PartialFieldVariant& field, uint32_t& value); + void DeserializeField(ProtoFixed64, infra::ProtoParser& parser, infra::ProtoParser::PartialFieldVariant& field, uint64_t& value); + void DeserializeField(ProtoSFixed32, infra::ProtoParser& parser, infra::ProtoParser::PartialFieldVariant& field, int32_t& value); + void DeserializeField(ProtoSFixed64, infra::ProtoParser& parser, infra::ProtoParser::PartialFieldVariant& field, int64_t& value); + + template + struct MessageDepth + { + static constexpr uint32_t value = 0; + }; + + template<> + struct MessageDepth + { + static constexpr uint32_t value = 1; + }; + + template<> + struct MessageDepth + { + static constexpr uint32_t value = 1; + }; + + template + struct MessageDepth> + { + static constexpr uint32_t value = 1; + }; + + template + struct MessageDepth> + { + static constexpr uint32_t value = 1; + }; + + template + struct MessageDepth> + { + static constexpr uint32_t value = 1; + }; + + template + struct MessageDepth> + { + static constexpr uint32_t value = 1; + }; + + template + struct Max; + + template + struct Max + { + static constexpr uint32_t value = V; + }; + + template + struct Max + { + static constexpr uint32_t value = std::max(V, Max::value); + }; + + template + struct MaxFieldsDepth; + + template + struct MaxFieldsDepth> + { + static constexpr uint32_t value = Max>::value...>::value; + }; + + template + struct MessageDepth> + { + static constexpr uint32_t value = MaxFieldsDepth>::value + 1; + }; + + //// Implementation //// + + inline void SerializeField(ProtoBool, infra::ProtoFormatter& formatter, bool value, uint32_t fieldNumber) + { + formatter.PutVarIntField(value, fieldNumber); + } + + inline void SerializeField(ProtoUInt32, infra::ProtoFormatter& formatter, uint32_t value, uint32_t fieldNumber) + { + formatter.PutVarIntField(value, fieldNumber); + } + + inline void SerializeField(ProtoInt32, infra::ProtoFormatter& formatter, int32_t value, uint32_t fieldNumber) + { + formatter.PutVarIntField(value, fieldNumber); + } + + inline void SerializeField(ProtoUInt64, infra::ProtoFormatter& formatter, uint64_t value, uint32_t fieldNumber) + { + formatter.PutVarIntField(value, fieldNumber); + } + + inline void SerializeField(ProtoInt64, infra::ProtoFormatter& formatter, int64_t value, uint32_t fieldNumber) + { + formatter.PutVarIntField(value, fieldNumber); + } + + inline void SerializeField(ProtoFixed32, infra::ProtoFormatter& formatter, uint32_t value, uint32_t fieldNumber) + { + formatter.PutFixed32Field(value, fieldNumber); + } + + inline void SerializeField(ProtoFixed64, infra::ProtoFormatter& formatter, uint64_t value, uint32_t fieldNumber) + { + formatter.PutFixed64Field(value, fieldNumber); + } + + inline void SerializeField(ProtoSFixed32, infra::ProtoFormatter& formatter, int32_t value, uint32_t fieldNumber) + { + formatter.PutFixed32Field(static_cast(value), fieldNumber); + } + + inline void SerializeField(ProtoSFixed64, infra::ProtoFormatter& formatter, int64_t value, uint32_t fieldNumber) + { + formatter.PutFixed64Field(static_cast(value), fieldNumber); + } + + inline void SerializeField(ProtoUnboundedString, infra::ProtoFormatter& formatter, const std::string& value, uint32_t fieldNumber) + { + formatter.PutStringField(value, fieldNumber); + } + + inline void SerializeField(ProtoUnboundedBytes, infra::ProtoFormatter& formatter, const std::vector& value, uint32_t fieldNumber) + { + formatter.PutBytesField(value, fieldNumber); + } + + template + void SerializeField(ProtoRepeated, infra::ProtoFormatter& formatter, const infra::BoundedVector& value, uint32_t fieldNumber) + { + for (auto& v : value) + SerializeField(T(), formatter, v, fieldNumber); + } + + template + void SerializeField(ProtoUnboundedRepeated, infra::ProtoFormatter& formatter, const std::vector& value, uint32_t fieldNumber) + { + for (auto& v : value) + SerializeField(T(), formatter, v, fieldNumber); + } + + template + void SerializeField(ProtoUnboundedRepeated, infra::ProtoFormatter& formatter, const std::vector& value, uint32_t fieldNumber) + { + for (auto v : value) + SerializeField(T(), formatter, v, fieldNumber); + } + + template + void SerializeField(ProtoMessage, infra::ProtoFormatter& formatter, const U& value, uint32_t fieldNumber) + { + infra::ProtoLengthDelimitedFormatter nestedMessage(formatter, fieldNumber); + value.Serialize(formatter); + } + + template + void SerializeField(ProtoEnum, infra::ProtoFormatter& formatter, T value, uint32_t fieldNumber) + { + formatter.PutVarIntField(static_cast(value), fieldNumber); + } + + template + void SerializeField(ProtoBytes, infra::ProtoFormatter& formatter, const infra::BoundedVector& value, uint32_t fieldNumber) + { + formatter.PutBytesField(infra::MakeRange(value), fieldNumber); + } + + template + void SerializeField(ProtoString, infra::ProtoFormatter& formatter, infra::BoundedConstString value, uint32_t fieldNumber) + { + formatter.PutStringField(value, fieldNumber); + } + + inline void DeserializeField(ProtoBool, infra::ProtoParser& parser, infra::ProtoParser::FieldVariant& field, bool& value) + { + parser.ReportFormatResult(field.Is()); + if (field.Is()) + value = field.Get() != 0; + } + + inline void DeserializeField(ProtoUInt32, infra::ProtoParser& parser, infra::ProtoParser::FieldVariant& field, uint32_t& value) + { + parser.ReportFormatResult(field.Is()); + if (field.Is()) + value = static_cast(field.Get()); + } + + inline void DeserializeField(ProtoInt32, infra::ProtoParser& parser, infra::ProtoParser::FieldVariant& field, int32_t& value) + { + parser.ReportFormatResult(field.Is()); + if (field.Is()) + value = static_cast(field.Get()); + } + + inline void DeserializeField(ProtoUInt64, infra::ProtoParser& parser, infra::ProtoParser::FieldVariant& field, uint64_t& value) + { + parser.ReportFormatResult(field.Is()); + if (field.Is()) + value = field.Get(); + } + + inline void DeserializeField(ProtoInt64, infra::ProtoParser& parser, infra::ProtoParser::FieldVariant& field, int64_t& value) + { + parser.ReportFormatResult(field.Is()); + if (field.Is()) + value = static_cast(field.Get()); + } + + inline void DeserializeField(ProtoFixed32, infra::ProtoParser& parser, infra::ProtoParser::FieldVariant& field, uint32_t& value) + { + parser.ReportFormatResult(field.Is()); + if (field.Is()) + value = field.Get(); + } + + inline void DeserializeField(ProtoFixed64, infra::ProtoParser& parser, infra::ProtoParser::FieldVariant& field, uint64_t& value) + { + parser.ReportFormatResult(field.Is()); + if (field.Is()) + value = field.Get(); + } + + inline void DeserializeField(ProtoSFixed32, infra::ProtoParser& parser, infra::ProtoParser::FieldVariant& field, int32_t& value) + { + parser.ReportFormatResult(field.Is()); + if (field.Is()) + value = static_cast(field.Get()); + } + + inline void DeserializeField(ProtoSFixed64, infra::ProtoParser& parser, infra::ProtoParser::FieldVariant& field, int64_t& value) + { + parser.ReportFormatResult(field.Is()); + if (field.Is()) + value = static_cast(field.Get()); + } + + inline void DeserializeField(ProtoUnboundedString, infra::ProtoParser& parser, infra::ProtoParser::FieldVariant& field, std::string& value) + { + parser.ReportFormatResult(field.Is()); + if (field.Is()) + value = field.Get().GetStdString(); + } + + inline void DeserializeField(ProtoUnboundedBytes, infra::ProtoParser& parser, infra::ProtoParser::FieldVariant& field, std::vector& value) + { + parser.ReportFormatResult(field.Is()); + if (field.Is()) + value = field.Get().GetUnboundedBytes(); + } + + template + void DeserializeField(ProtoRepeated, infra::ProtoParser& parser, infra::ProtoParser::FieldVariant& field, infra::BoundedVector& value) + { + parser.ReportFormatResult(!value.full()); + if (!value.full()) + { + value.emplace_back(); + DeserializeField(T(), parser, field, value.back()); + } + } + + template + void DeserializeField(ProtoUnboundedRepeated, infra::ProtoParser& parser, infra::ProtoParser::FieldVariant& field, std::vector& value) + { + value.emplace_back(); + DeserializeField(T(), parser, field, value.back()); + } + + template + void DeserializeField(ProtoUnboundedRepeated, infra::ProtoParser& parser, infra::ProtoParser::FieldVariant& field, std::vector& value) + { + bool result{}; + DeserializeField(T(), parser, field, result); + value.push_back(result); + } + + template + void DeserializeField(ProtoMessage, infra::ProtoParser& parser, infra::ProtoParser::FieldVariant& field, U& value) + { + parser.ReportFormatResult(field.Is()); + if (field.Is()) + { + infra::ProtoParser nestedParser = field.Get().Parser(); + value.Deserialize(nestedParser); + } + } + + template + void DeserializeField(ProtoEnum, infra::ProtoParser& parser, infra::ProtoParser::FieldVariant& field, T& value) + { + parser.ReportFormatResult(field.Is()); + if (field.Is()) + value = static_cast(field.Get()); + } + + template + void DeserializeField(ProtoBytes, infra::ProtoParser& parser, infra::ProtoParser::FieldVariant& field, infra::BoundedVector& value) + { + parser.ReportFormatResult(field.Is()); + if (field.Is()) + field.Get().GetBytes(value); + } + + template + void DeserializeField(ProtoBytes, infra::ProtoParser& parser, infra::ProtoParser::FieldVariant& field, infra::ConstByteRange& value) + { + parser.ReportFormatResult(field.Is()); + if (field.Is()) + field.Get().GetBytesReference(value); + } + + template + void DeserializeField(ProtoString, infra::ProtoParser& parser, infra::ProtoParser::FieldVariant& field, infra::BoundedString& value) + { + parser.ReportFormatResult(field.Is()); + if (field.Is()) + field.Get().GetString(value); + } + + template + void DeserializeField(ProtoString, infra::ProtoParser& parser, infra::ProtoParser::FieldVariant& field, infra::BoundedConstString& value) + { + parser.ReportFormatResult(field.Is()); + if (field.Is()) + field.Get().GetStringReference(value); + } + + inline void DeserializeField(ProtoBool, infra::ProtoParser& parser, infra::ProtoParser::PartialFieldVariant& field, bool& value) + { + parser.ReportFormatResult(field.Is()); + if (field.Is()) + value = field.Get() != 0; + } + + inline void DeserializeField(ProtoUInt32, infra::ProtoParser& parser, infra::ProtoParser::PartialFieldVariant& field, uint32_t& value) + { + parser.ReportFormatResult(field.Is()); + if (field.Is()) + value = static_cast(field.Get()); + } + + inline void DeserializeField(ProtoInt32, infra::ProtoParser& parser, infra::ProtoParser::PartialFieldVariant& field, int32_t& value) + { + parser.ReportFormatResult(field.Is()); + if (field.Is()) + value = static_cast(field.Get()); + } + + inline void DeserializeField(ProtoUInt64, infra::ProtoParser& parser, infra::ProtoParser::PartialFieldVariant& field, uint64_t& value) + { + parser.ReportFormatResult(field.Is()); + if (field.Is()) + value = field.Get(); + } + + inline void DeserializeField(ProtoInt64, infra::ProtoParser& parser, infra::ProtoParser::PartialFieldVariant& field, int64_t& value) + { + parser.ReportFormatResult(field.Is()); + if (field.Is()) + value = static_cast(field.Get()); + } + + inline void DeserializeField(ProtoFixed32, infra::ProtoParser& parser, infra::ProtoParser::PartialFieldVariant& field, uint32_t& value) + { + parser.ReportFormatResult(field.Is()); + if (field.Is()) + value = field.Get(); + } + + inline void DeserializeField(ProtoFixed64, infra::ProtoParser& parser, infra::ProtoParser::PartialFieldVariant& field, uint64_t& value) + { + parser.ReportFormatResult(field.Is()); + if (field.Is()) + value = field.Get(); + } + + inline void DeserializeField(ProtoSFixed32, infra::ProtoParser& parser, infra::ProtoParser::PartialFieldVariant& field, int32_t& value) + { + parser.ReportFormatResult(field.Is()); + if (field.Is()) + value = static_cast(field.Get()); + } + + inline void DeserializeField(ProtoSFixed64, infra::ProtoParser& parser, infra::ProtoParser::PartialFieldVariant& field, int64_t& value) + { + parser.ReportFormatResult(field.Is()); + if (field.Is()) + value = static_cast(field.Get()); + } +} + +#endif diff --git a/protobuf/echo/ProtoMessageReceiver.cpp b/protobuf/echo/ProtoMessageReceiver.cpp new file mode 100644 index 000000000..1c2e14a9e --- /dev/null +++ b/protobuf/echo/ProtoMessageReceiver.cpp @@ -0,0 +1,142 @@ +#include "protobuf/echo/ProtoMessageReceiver.hpp" +#include "infra/stream/BufferingStreamReader.hpp" + +namespace services +{ + ProtoMessageReceiverBase::ProtoMessageReceiverBase(infra::BoundedVector>>& stack) + : stack(stack) + {} + + void ProtoMessageReceiverBase::Feed(infra::StreamReaderWithRewinding& data) + { + infra::BufferingStreamReader reader{ buffer, data }; + + while (true) + { + infra::LimitedStreamReaderWithRewinding limitedReader(reader, stack.back().first); + infra::DataInputStream::WithErrorPolicy stream{ limitedReader, infra::softFail }; + + auto available = limitedReader.Available(); + + auto& current = stack.back(); + current.second(stream); + + if (stream.Failed() || reader.Empty()) + break; + + if (¤t != &stack.front()) + { + current.first -= available - limitedReader.Available(); + + if (current.first == 0) + stack.pop_back(); + } + } + } + + void ProtoMessageReceiverBase::DeserializeField(ProtoBool, infra::ProtoParser& parser, infra::ProtoParser::PartialFieldVariant& field, bool& value) const + { + services::DeserializeField(ProtoBool(), parser, field, value); + } + + void ProtoMessageReceiverBase::DeserializeField(ProtoUInt32, infra::ProtoParser& parser, infra::ProtoParser::PartialFieldVariant& field, uint32_t& value) const + { + services::DeserializeField(ProtoUInt32(), parser, field, value); + } + + void ProtoMessageReceiverBase::DeserializeField(ProtoInt32, infra::ProtoParser& parser, infra::ProtoParser::PartialFieldVariant& field, int32_t& value) const + { + services::DeserializeField(ProtoInt32(), parser, field, value); + } + + void ProtoMessageReceiverBase::DeserializeField(ProtoUInt64, infra::ProtoParser& parser, infra::ProtoParser::PartialFieldVariant& field, uint64_t& value) const + { + services::DeserializeField(ProtoUInt64(), parser, field, value); + } + + void ProtoMessageReceiverBase::DeserializeField(ProtoInt64, infra::ProtoParser& parser, infra::ProtoParser::PartialFieldVariant& field, int64_t& value) const + { + services::DeserializeField(ProtoInt64(), parser, field, value); + } + + void ProtoMessageReceiverBase::DeserializeField(ProtoFixed32, infra::ProtoParser& parser, infra::ProtoParser::PartialFieldVariant& field, uint32_t& value) const + { + services::DeserializeField(ProtoFixed32(), parser, field, value); + } + + void ProtoMessageReceiverBase::DeserializeField(ProtoFixed64, infra::ProtoParser& parser, infra::ProtoParser::PartialFieldVariant& field, uint64_t& value) const + { + services::DeserializeField(ProtoFixed64(), parser, field, value); + } + + void ProtoMessageReceiverBase::DeserializeField(ProtoSFixed32, infra::ProtoParser& parser, infra::ProtoParser::PartialFieldVariant& field, int32_t& value) const + { + services::DeserializeField(ProtoSFixed32(), parser, field, value); + } + + void ProtoMessageReceiverBase::DeserializeField(ProtoSFixed64, infra::ProtoParser& parser, infra::ProtoParser::PartialFieldVariant& field, int64_t& value) const + { + services::DeserializeField(ProtoSFixed64(), parser, field, value); + } + + void ProtoMessageReceiverBase::DeserializeField(ProtoStringBase, infra::ProtoParser& parser, infra::ProtoParser::PartialFieldVariant& field, infra::BoundedString& value) + { + parser.ReportFormatResult(field.Is()); + if (field.Is()) + { + auto stringSize = field.Get().length; + stack.emplace_back(stringSize, [&value](const infra::DataInputStream& stream) + { + while (!stream.Empty()) + value.append(infra::ByteRangeAsString(stream.ContiguousRange())); + }); + value.clear(); + } + } + + void ProtoMessageReceiverBase::DeserializeField(ProtoUnboundedString, infra::ProtoParser& parser, infra::ProtoParser::PartialFieldVariant& field, std::string& value) + { + parser.ReportFormatResult(field.Is()); + if (field.Is()) + { + auto stringSize = field.Get().length; + stack.emplace_back(stringSize, [&value](const infra::DataInputStream& stream) + { + while (!stream.Empty()) + value.append(infra::ByteRangeAsStdString(stream.ContiguousRange())); + }); + value.clear(); + } + } + + void ProtoMessageReceiverBase::DeserializeField(ProtoBytesBase, infra::ProtoParser& parser, infra::ProtoParser::PartialFieldVariant& field, infra::BoundedVector& value) + { + parser.ReportFormatResult(field.Is()); + if (field.Is()) + { + auto bytesSize = field.Get().length; + stack.emplace_back(bytesSize, [&value](const infra::DataInputStream& stream) + { + while (!stream.Empty()) + { + auto range = stream.ContiguousRange(); + value.insert(value.end(), range.begin(), range.end()); + } + }); + value.clear(); + } + } + + void ProtoMessageReceiverBase::ConsumeUnknownField(infra::ProtoParser::PartialField& field) + { + if (field.first.Is()) + { + auto size = field.first.Get().length; + stack.emplace_back(size, [](const infra::DataInputStream& stream) + { + while (!stream.Empty()) + stream.ContiguousRange(); + }); + } + } +} diff --git a/protobuf/echo/ProtoMessageReceiver.hpp b/protobuf/echo/ProtoMessageReceiver.hpp new file mode 100644 index 000000000..b127695b4 --- /dev/null +++ b/protobuf/echo/ProtoMessageReceiver.hpp @@ -0,0 +1,173 @@ +#ifndef PROTOBUF_PROTO_MESSAGE_RECEIVER_HPP +#define PROTOBUF_PROTO_MESSAGE_RECEIVER_HPP + +#include "infra/syntax/ProtoParser.hpp" +#include "infra/util/BoundedDeque.hpp" +#include "infra/util/BoundedVector.hpp" +#include "protobuf/echo/Proto.hpp" + +namespace services +{ + class ProtoMessageReceiverBase + { + public: + explicit ProtoMessageReceiverBase(infra::BoundedVector>>& stack); + + void Feed(infra::StreamReaderWithRewinding& data); + + protected: + template + void FeedForMessage(const infra::DataInputStream& stream, Message& message); + + private: + template + bool DeserializeFields(infra::ProtoParser::PartialField& field, infra::ProtoParser& parser, Message& message, std::index_sequence); + + template + bool DeserializeSingleField(infra::ProtoParser::PartialField& field, infra::ProtoParser& parser, Message& message); + + void DeserializeField(ProtoBool, infra::ProtoParser& parser, infra::ProtoParser::PartialFieldVariant& field, bool& value) const; + void DeserializeField(ProtoUInt32, infra::ProtoParser& parser, infra::ProtoParser::PartialFieldVariant& field, uint32_t& value) const; + void DeserializeField(ProtoInt32, infra::ProtoParser& parser, infra::ProtoParser::PartialFieldVariant& field, int32_t& value) const; + void DeserializeField(ProtoUInt64, infra::ProtoParser& parser, infra::ProtoParser::PartialFieldVariant& field, uint64_t& value) const; + void DeserializeField(ProtoInt64, infra::ProtoParser& parser, infra::ProtoParser::PartialFieldVariant& field, int64_t& value) const; + void DeserializeField(ProtoFixed32, infra::ProtoParser& parser, infra::ProtoParser::PartialFieldVariant& field, uint32_t& value) const; + void DeserializeField(ProtoFixed64, infra::ProtoParser& parser, infra::ProtoParser::PartialFieldVariant& field, uint64_t& value) const; + void DeserializeField(ProtoSFixed32, infra::ProtoParser& parser, infra::ProtoParser::PartialFieldVariant& field, int32_t& value) const; + void DeserializeField(ProtoSFixed64, infra::ProtoParser& parser, infra::ProtoParser::PartialFieldVariant& field, int64_t& value) const; + + void DeserializeField(ProtoStringBase, infra::ProtoParser& parser, infra::ProtoParser::PartialFieldVariant& field, infra::BoundedString& value); + void DeserializeField(ProtoUnboundedString, infra::ProtoParser& parser, infra::ProtoParser::PartialFieldVariant& field, std::string& value); + void DeserializeField(ProtoBytesBase, infra::ProtoParser& parser, infra::ProtoParser::PartialFieldVariant& field, infra::BoundedVector& value); + + template + void DeserializeField(ProtoEnum, infra::ProtoParser& parser, infra::ProtoParser::PartialFieldVariant& field, Enum& value) const; + template + void DeserializeField(ProtoMessage, infra::ProtoParser& parser, infra::ProtoParser::PartialFieldVariant& field, Message& value); + template + void DeserializeField(ProtoRepeatedBase, infra::ProtoParser& parser, infra::ProtoParser::PartialFieldVariant& field, Type& value) const; + template + void DeserializeField(ProtoUnboundedRepeated, infra::ProtoParser& parser, infra::ProtoParser::PartialFieldVariant& field, Type& value) const; + + void ConsumeUnknownField(infra::ProtoParser::PartialField& field); + + private: + infra::BoundedDeque::WithMaxSize<32> buffer; + infra::BoundedVector>>& stack; + }; + + template + class ProtoMessageReceiver + : public ProtoMessageReceiverBase + { + public: + ProtoMessageReceiver(); + + Message message; + + private: + infra::BoundedVector>>::WithMaxSize>::value + 1> stack{ { std::pair>{ std::numeric_limits::max(), [this](const infra::DataInputStream& stream) + { + FeedForMessage(stream, message); + } } } }; + }; +} + +//// Implementation //// + +namespace services +{ + template + void ProtoMessageReceiverBase::FeedForMessage(const infra::DataInputStream& stream, Message& message) + { + infra::ProtoParser parser{ stream }; + + while (!stream.Empty()) + { + auto marker = static_cast(stream.Reader()).ConstructSaveMarker(); + auto field = parser.GetPartialField(); + + if (stream.Failed()) + { + static_cast(stream.Reader()).Rewind(marker); + break; + } + + auto stackSize = stack.size(); + if (!DeserializeFields(field, parser, message, std::make_index_sequence{})) + ConsumeUnknownField(field); + + if (stackSize != stack.size()) + { + stream.Failed(); + return; + } + } + } + + template + bool ProtoMessageReceiverBase::DeserializeFields(infra::ProtoParser::PartialField& field, infra::ProtoParser& parser, Message& message, std::index_sequence) + { + return (DeserializeSingleField(field, parser, message) || ...); + } + + template + bool ProtoMessageReceiverBase::DeserializeSingleField(infra::ProtoParser::PartialField& field, infra::ProtoParser& parser, Message& message) + { + if (field.second == Message::template fieldNumber) + { + DeserializeField(typename Message::template ProtoType(), parser, field.first, message.Get(std::integral_constant())); + return true; + } + + return false; + } + + template + void ProtoMessageReceiverBase::DeserializeField(ProtoEnum, infra::ProtoParser& parser, infra::ProtoParser::PartialFieldVariant& field, Enum& value) const + { + parser.ReportFormatResult(field.Is()); + if (field.Is()) + value = static_cast(field.Get()); + } + + template + void ProtoMessageReceiverBase::DeserializeField(ProtoMessage, infra::ProtoParser& parser, infra::ProtoParser::PartialFieldVariant& field, Message& value) + { + parser.ReportFormatResult(field.Is()); + if (field.Is()) + { + auto messageSize = field.Get().length; + stack.emplace_back(messageSize, [this, &value](const infra::DataInputStream& stream) + { + FeedForMessage(stream, value); + }); + infra::ReConstruct(value); + } + } + + template + void ProtoMessageReceiverBase::DeserializeField(ProtoRepeatedBase, infra::ProtoParser& parser, infra::ProtoParser::PartialFieldVariant& field, Type& value) const + { + parser.ReportFormatResult(!value.full()); + if (!value.full()) + { + value.emplace_back(); + DeserializeField(ProtoType(), parser, field, value.back()); + } + } + + template + void ProtoMessageReceiverBase::DeserializeField(ProtoUnboundedRepeated, infra::ProtoParser& parser, infra::ProtoParser::PartialFieldVariant& field, Type& value) const + { + value.emplace_back(); + DeserializeField(ProtoType(), parser, field, value.back()); + } + + template + ProtoMessageReceiver::ProtoMessageReceiver() + : ProtoMessageReceiverBase(stack) + {} +} + +#endif diff --git a/protobuf/echo/ProtoMessageSender.cpp b/protobuf/echo/ProtoMessageSender.cpp new file mode 100644 index 000000000..5f013d486 --- /dev/null +++ b/protobuf/echo/ProtoMessageSender.cpp @@ -0,0 +1,132 @@ +#include "protobuf/echo/ProtoMessageSender.hpp" +#include "infra/stream/BufferingStreamWriter.hpp" + +namespace services +{ + ProtoMessageSenderBase::ProtoMessageSenderBase(infra::BoundedVector>>& stack) + : stack(stack) + {} + + void ProtoMessageSenderBase::Fill(infra::DataOutputStream output) + { + infra::BufferingStreamWriter writer{ buffer, output.Writer() }; + + while (!stack.empty()) + { + infra::DataOutputStream::WithErrorPolicy stream{ writer }; + + auto& [index, callback] = stack.back(); + bool retry = false; + auto result = callback(stream, index, retry, output.Writer()); + if (result) + stack.pop_back(); + else if (!retry) + break; + } + } + + bool ProtoMessageSenderBase::SerializeField(ProtoBool, infra::ProtoFormatter& formatter, bool value, uint32_t fieldNumber, [[maybe_unused]] const bool& retry) const + { + services::SerializeField(ProtoBool(), formatter, value, fieldNumber); + return true; + } + + bool ProtoMessageSenderBase::SerializeField(ProtoUInt32, infra::ProtoFormatter& formatter, uint32_t value, uint32_t fieldNumber, [[maybe_unused]] const bool& retry) const + { + services::SerializeField(ProtoUInt32(), formatter, value, fieldNumber); + return true; + } + + bool ProtoMessageSenderBase::SerializeField(ProtoInt32, infra::ProtoFormatter& formatter, int32_t value, uint32_t fieldNumber, [[maybe_unused]] const bool& retry) const + { + services::SerializeField(ProtoInt32(), formatter, value, fieldNumber); + return true; + } + + bool ProtoMessageSenderBase::SerializeField(ProtoUInt64, infra::ProtoFormatter& formatter, uint64_t value, uint32_t fieldNumber, [[maybe_unused]] const bool& retry) const + { + services::SerializeField(ProtoUInt64(), formatter, value, fieldNumber); + return true; + } + + bool ProtoMessageSenderBase::SerializeField(ProtoInt64, infra::ProtoFormatter& formatter, int64_t value, uint32_t fieldNumber, [[maybe_unused]] const bool& retry) const + { + services::SerializeField(ProtoInt64(), formatter, value, fieldNumber); + return true; + } + + bool ProtoMessageSenderBase::SerializeField(ProtoFixed32, infra::ProtoFormatter& formatter, uint32_t value, uint32_t fieldNumber, [[maybe_unused]] const bool& retry) const + { + services::SerializeField(ProtoFixed32(), formatter, value, fieldNumber); + return true; + } + + bool ProtoMessageSenderBase::SerializeField(ProtoSFixed32, infra::ProtoFormatter& formatter, int32_t value, uint32_t fieldNumber, [[maybe_unused]] const bool& retry) const + { + services::SerializeField(ProtoSFixed32(), formatter, value, fieldNumber); + return true; + } + + bool ProtoMessageSenderBase::SerializeField(ProtoFixed64, infra::ProtoFormatter& formatter, uint64_t value, uint32_t fieldNumber, [[maybe_unused]] const bool& retry) const + { + services::SerializeField(ProtoFixed64(), formatter, value, fieldNumber); + return true; + } + + bool ProtoMessageSenderBase::SerializeField(ProtoSFixed64, infra::ProtoFormatter& formatter, int64_t value, uint32_t fieldNumber, [[maybe_unused]] const bool& retry) const + { + services::SerializeField(ProtoSFixed64(), formatter, value, fieldNumber); + return true; + } + + bool ProtoMessageSenderBase::SerializeField(ProtoStringBase, infra::ProtoFormatter& formatter, const infra::BoundedString& value, uint32_t fieldNumber, bool& retry) const + { + formatter.PutVarInt((fieldNumber << 3) | 2); + formatter.PutVarInt(value.size()); + + stack.emplace_back(0, [&value](infra::DataOutputStream& stream, uint32_t& index, const bool&, const infra::StreamWriter&) + { + auto range = infra::Head(infra::DiscardHead(infra::StringAsByteRange(value), index), stream.Available()); + stream << range; + index += range.size(); + return index == value.size(); + }); + + retry = true; + return true; + } + + bool ProtoMessageSenderBase::SerializeField(ProtoUnboundedString, infra::ProtoFormatter& formatter, const std::string& value, uint32_t fieldNumber, bool& retry) const + { + formatter.PutVarInt((fieldNumber << 3) | 2); + formatter.PutVarInt(value.size()); + + stack.emplace_back(0, [&value](infra::DataOutputStream& stream, uint32_t& index, const bool&, const infra::StreamWriter&) + { + auto range = infra::Head(infra::DiscardHead(infra::StdStringAsByteRange(value), index), stream.Available()); + stream << range; + index += range.size(); + return index == value.size(); + }); + + retry = true; + return true; + } + + bool ProtoMessageSenderBase::SerializeField(ProtoBytesBase, infra::ProtoFormatter& formatter, const infra::BoundedVector& value, uint32_t fieldNumber, bool& retry) const + { + formatter.PutVarInt((fieldNumber << 3) | 2); + formatter.PutVarInt(value.size()); + + stack.emplace_back(0, [&value](infra::DataOutputStream& stream, uint32_t& index, const bool&, const infra::StreamWriter&) + { + auto range = infra::Head(infra::DiscardHead(infra::MakeRange(value), index), stream.Available()); + stream << range; + index += range.size(); + return index == value.size(); + }); + + retry = true; + return true; + } +} diff --git a/protobuf/echo/ProtoMessageSender.hpp b/protobuf/echo/ProtoMessageSender.hpp new file mode 100644 index 000000000..930c3df9f --- /dev/null +++ b/protobuf/echo/ProtoMessageSender.hpp @@ -0,0 +1,180 @@ +#ifndef PROTOBUF_PROTO_MESSAGE_SENDER_HPP +#define PROTOBUF_PROTO_MESSAGE_SENDER_HPP + +#include "infra/stream/CountingOutputStream.hpp" +#include "infra/syntax/ProtoFormatter.hpp" +#include "infra/util/BoundedDeque.hpp" +#include "infra/util/BoundedVector.hpp" +#include "protobuf/echo/Proto.hpp" + +namespace services +{ + class ProtoMessageSenderBase + { + public: + explicit ProtoMessageSenderBase(infra::BoundedVector>>& stack); + + void Fill(infra::DataOutputStream output); + + protected: + template + bool FillForMessage(infra::DataOutputStream& stream, const Message& message, uint32_t& index, bool& retry, const infra::StreamWriter& finalWriter) const; + + private: + template + bool SerializeFields(infra::ProtoFormatter& formatter, const Message& message, uint32_t& index, bool& retry, const infra::StreamWriter& finalWriter, std::index_sequence) const; + + template + bool SerializeSingleField(infra::ProtoFormatter& formatter, const Message& message, uint32_t& index, bool& retry, const infra::StreamWriter& finalWriter) const; + + bool SerializeField(ProtoBool, infra::ProtoFormatter& formatter, bool value, uint32_t fieldNumber, const bool& retry) const; + bool SerializeField(ProtoUInt32, infra::ProtoFormatter& formatter, uint32_t value, uint32_t fieldNumber, const bool& retry) const; + bool SerializeField(ProtoInt32, infra::ProtoFormatter& formatter, int32_t value, uint32_t fieldNumber, const bool& retry) const; + bool SerializeField(ProtoUInt64, infra::ProtoFormatter& formatter, uint64_t value, uint32_t fieldNumber, const bool& retry) const; + bool SerializeField(ProtoInt64, infra::ProtoFormatter& formatter, int64_t value, uint32_t fieldNumber, const bool& retry) const; + bool SerializeField(ProtoFixed32, infra::ProtoFormatter& formatter, uint32_t value, uint32_t fieldNumber, const bool& retry) const; + bool SerializeField(ProtoSFixed32, infra::ProtoFormatter& formatter, int32_t value, uint32_t fieldNumber, const bool& retry) const; + bool SerializeField(ProtoFixed64, infra::ProtoFormatter& formatter, uint64_t value, uint32_t fieldNumber, const bool& retry) const; + bool SerializeField(ProtoSFixed64, infra::ProtoFormatter& formatter, int64_t value, uint32_t fieldNumber, const bool& retry) const; + + bool SerializeField(ProtoStringBase, infra::ProtoFormatter& formatter, const infra::BoundedString& value, uint32_t fieldNumber, bool& retry) const; + bool SerializeField(ProtoUnboundedString, infra::ProtoFormatter& formatter, const std::string& value, uint32_t fieldNumber, bool& retry) const; + bool SerializeField(ProtoBytesBase, infra::ProtoFormatter& formatter, const infra::BoundedVector& value, uint32_t fieldNumber, bool& retry) const; + + template + bool SerializeField(ProtoEnum, infra::ProtoFormatter& formatter, const Enum& value, uint32_t fieldNumber, const bool& retry) const; + template + bool SerializeField(ProtoMessage, infra::ProtoFormatter& formatter, const Message& value, uint32_t fieldNumber, bool& retry) const; + template + bool SerializeField(ProtoRepeatedBase, const infra::ProtoFormatter& formatter, const infra::BoundedVector& value, uint32_t fieldNumber, bool& retry) const; + template + bool SerializeField(ProtoUnboundedRepeated, const infra::ProtoFormatter& formatter, const std::vector& value, uint32_t fieldNumber, bool& retry) const; + + private: + infra::BoundedDeque::WithMaxSize<32> buffer; + infra::BoundedVector>>& stack; + }; + + template + class ProtoMessageSender + : public ProtoMessageSenderBase + { + public: + explicit ProtoMessageSender(const Message& message); + + private: + const Message& message; + infra::BoundedVector>>::WithMaxSize>::value + 1> stack{ { std::pair>{ 0, [this](infra::DataOutputStream& stream, uint32_t& index, bool& retry, const infra::StreamWriter& finalWriter) + { + return FillForMessage(stream, message, index, retry, finalWriter); + } } } }; + }; + + //// Implementation //// + + template + bool ProtoMessageSenderBase::FillForMessage(infra::DataOutputStream& stream, const Message& message, uint32_t& index, bool& retry, const infra::StreamWriter& finalWriter) const + { + infra::ProtoFormatter formatter{ stream }; + + return SerializeFields(formatter, message, index, retry, finalWriter, std::make_index_sequence{}); + } + + template + bool ProtoMessageSenderBase::SerializeFields(infra::ProtoFormatter& formatter, const Message& message, uint32_t& index, bool& retry, const infra::StreamWriter& finalWriter, std::index_sequence) const + { + return (SerializeSingleField(formatter, message, index, retry, finalWriter) && ...); + } + + template + bool ProtoMessageSenderBase::SerializeSingleField(infra::ProtoFormatter& formatter, const Message& message, uint32_t& index, bool& retry, const infra::StreamWriter& finalWriter) const + { + if (finalWriter.Available() == 0) + return false; + + if (index == I) + { + if (!SerializeField(typename Message::template ProtoType(), formatter, message.Get(std::integral_constant()), Message::template fieldNumber, retry)) + return false; + + index = I + 1; + return !retry; + } + + return true; + } + + template + bool ProtoMessageSenderBase::SerializeField(ProtoEnum, infra::ProtoFormatter& formatter, const Enum& value, uint32_t fieldNumber, [[maybe_unused]] const bool& retry) const + { + services::SerializeField(ProtoEnum(), formatter, value, fieldNumber); + return true; + } + + template + bool ProtoMessageSenderBase::SerializeField(ProtoMessage, infra::ProtoFormatter& formatter, const Message& value, uint32_t fieldNumber, bool& retry) const + { + infra::DataOutputStream::WithWriter countingStream; + infra::ProtoFormatter countingFormatter{ countingStream }; + value.Serialize(countingFormatter); + formatter.PutLengthDelimitedSize(countingStream.Writer().Processed(), fieldNumber); + + stack.emplace_back(0, [this, &value](infra::DataOutputStream& stream, uint32_t& index, bool& retry2, const infra::StreamWriter& finalWriter) + { + return FillForMessage(stream, value, index, retry2, finalWriter); + }); + + retry = true; + return true; + } + + template + bool ProtoMessageSenderBase::SerializeField(ProtoRepeatedBase, const infra::ProtoFormatter&, const infra::BoundedVector& value, uint32_t fieldNumber, bool& retry) const + { + stack.emplace_back(0, [this, &value, fieldNumber](infra::DataOutputStream& stream, uint32_t& index, const bool&, const infra::StreamWriter& finalWriter) + { + infra::ProtoFormatter formatter{ stream }; + + for (; index != value.size(); ++index) + { + if (finalWriter.Available() == 0) + return false; + services::SerializeField(ProtoType(), formatter, value[index], fieldNumber); + } + + return true; + }); + + retry = true; + return true; + } + + template + bool ProtoMessageSenderBase::SerializeField(ProtoUnboundedRepeated, const infra::ProtoFormatter&, const std::vector& value, uint32_t fieldNumber, bool& retry) const + { + stack.emplace_back(0, [this, &value, fieldNumber](infra::DataOutputStream& stream, uint32_t& index, const bool&, const infra::StreamWriter& finalWriter) + { + infra::ProtoFormatter formatter{ stream }; + + for (; index != value.size(); ++index) + { + if (finalWriter.Available() == 0) + return false; + services::SerializeField(ProtoType(), formatter, value[index], fieldNumber); + } + + return true; + }); + + retry = true; + return true; + } + + template + ProtoMessageSender::ProtoMessageSender(const Message& message) + : ProtoMessageSenderBase(stack) + , message(message) + {} +} + +#endif diff --git a/protobuf/echo/test/CMakeLists.txt b/protobuf/echo/test/CMakeLists.txt index 3176af5dc..9dd11414e 100644 --- a/protobuf/echo/test/CMakeLists.txt +++ b/protobuf/echo/test/CMakeLists.txt @@ -2,8 +2,12 @@ add_executable(protobuf.echo_test) emil_build_for(protobuf.echo_test BOOL EMIL_BUILD_TESTS) emil_add_test(protobuf.echo_test) +protocol_buffer_echo_cpp(protobuf.echo_test TestMessages.proto) + target_sources(protobuf.echo_test PRIVATE TestEchoServiceResponseQueue.cpp + TestProtoMessageReceiver.cpp + TestProtoMessageSender.cpp TestServiceForwarder.cpp ) diff --git a/protobuf/protoc_echo_plugin/test/TestMessages.proto b/protobuf/echo/test/TestMessages.proto similarity index 92% rename from protobuf/protoc_echo_plugin/test/TestMessages.proto rename to protobuf/echo/test/TestMessages.proto index 8f931978e..5588d0a11 100644 --- a/protobuf/protoc_echo_plugin/test/TestMessages.proto +++ b/protobuf/echo/test/TestMessages.proto @@ -64,7 +64,7 @@ message TestRepeatedString { } message TestBytes { - bytes value = 1 [(bytes_size) = 10]; + bytes value = 1 [(bytes_size) = 50]; } message TestUnboundedBytes { @@ -72,7 +72,11 @@ message TestUnboundedBytes { } message TestRepeatedUInt32 { - repeated uint32 value = 1 [(array_size) = 10]; + repeated uint32 value = 1 [(array_size) = 50]; +} + +message TestUnboundedRepeatedUInt32 { + repeated uint32 value = 1; } message TestMessageWithMessageField { @@ -168,3 +172,8 @@ service TestService2 rpc Search (TestString) returns (Nothing) { option (method_id) = 1; } } + +message TestBoolWithBytes { + bytes b = 1 [(bytes_size) = 30]; + bool value = 2; +} diff --git a/protobuf/echo/test/TestProtoMessageReceiver.cpp b/protobuf/echo/test/TestProtoMessageReceiver.cpp new file mode 100644 index 000000000..fe7a83fc0 --- /dev/null +++ b/protobuf/echo/test/TestProtoMessageReceiver.cpp @@ -0,0 +1,222 @@ +#include "generated/echo/TestMessages.pb.hpp" +#include "infra/stream/StdVectorInputStream.hpp" +#include "protobuf/echo/ProtoMessageReceiver.hpp" +#include + +TEST(ProtoMessageReceiverTest, parse_bool) +{ + services::ProtoMessageReceiver receiver; + + infra::StdVectorInputStreamReader::WithStorage data(infra::inPlace, std::initializer_list{ 1 << 3, 1 }); + receiver.Feed(data); + + EXPECT_EQ(true, receiver.message.value); +} + +TEST(ProtoMessageReceiverTest, parse_unknown_simple_field) +{ + services::ProtoMessageReceiver receiver; + + infra::StdVectorInputStreamReader::WithStorage data(infra::inPlace, std::initializer_list{ 2 << 3, 1, 1 << 3, 1 }); + receiver.Feed(data); + + EXPECT_EQ(true, receiver.message.value); +} + +TEST(ProtoMessageReceiverTest, parse_unknown_length_delimited) +{ + services::ProtoMessageReceiver receiver; + + infra::StdVectorInputStreamReader::WithStorage data(infra::inPlace, std::initializer_list{ 1 << 3, 1, (2 << 3) | 2, 2, 1 << 3, 0 }); + receiver.Feed(data); + + EXPECT_EQ(true, receiver.message.value); +} + +TEST(ProtoMessageReceiverTest, parse_incomplete_bool) +{ + services::ProtoMessageReceiver receiver; + + infra::StdVectorInputStreamReader::WithStorage data(infra::inPlace, std::initializer_list{ 8 }); + receiver.Feed(data); + + EXPECT_EQ(false, receiver.message.value); + + infra::StdVectorInputStreamReader::WithStorage data2(infra::inPlace, std::initializer_list{ 1 }); + receiver.Feed(data2); + + EXPECT_EQ(true, receiver.message.value); +} + +TEST(ProtoMessageReceiverTest, parse_uint32) +{ + services::ProtoMessageReceiver receiver; + + infra::StdVectorInputStreamReader::WithStorage data(infra::inPlace, std::initializer_list{ 1 << 3, 5 }); + receiver.Feed(data); + + EXPECT_EQ(5, receiver.message.value); +} + +TEST(ProtoMessageReceiverTest, parse_int32) +{ + services::ProtoMessageReceiver receiver; + + infra::StdVectorInputStreamReader::WithStorage data(infra::inPlace, std::initializer_list{ 8, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 1 }); + receiver.Feed(data); + + EXPECT_EQ(-1, receiver.message.value); +} + +TEST(ProtoMessageReceiverTest, parse_uint64) +{ + services::ProtoMessageReceiver receiver; + + infra::StdVectorInputStreamReader::WithStorage data(infra::inPlace, std::initializer_list{ 1 << 3, 5 }); + receiver.Feed(data); + + EXPECT_EQ(5, receiver.message.value); +} + +TEST(ProtoMessageReceiverTest, parse_int64) +{ + services::ProtoMessageReceiver receiver; + + infra::StdVectorInputStreamReader::WithStorage data(infra::inPlace, std::initializer_list{ 8, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 1 }); + receiver.Feed(data); + + EXPECT_EQ(-1, receiver.message.value); +} + +TEST(ProtoMessageReceiverTest, parse_fixed32) +{ + services::ProtoMessageReceiver receiver; + + infra::StdVectorInputStreamReader::WithStorage data(infra::inPlace, std::initializer_list{ 13, 0x90, 0x91, 0x0f, 0 }); + receiver.Feed(data); + + EXPECT_EQ(1020304, receiver.message.value); +} + +TEST(ProtoMessageReceiverTest, parse_fixed64) +{ + services::ProtoMessageReceiver receiver; + + infra::StdVectorInputStreamReader::WithStorage data(infra::inPlace, std::initializer_list{ 9, 0x64, 0xc8, 0x0c, 0xce, 0xcb, 0x5c, 0x00, 0x00 }); + receiver.Feed(data); + + EXPECT_EQ(102030405060708, receiver.message.value); +} + +TEST(ProtoMessageReceiverTest, parse_sfixed32) +{ + services::ProtoMessageReceiver receiver; + + infra::StdVectorInputStreamReader::WithStorage data(infra::inPlace, std::initializer_list{ 13, 0xff, 0xff, 0xff, 0xff }); + receiver.Feed(data); + + EXPECT_EQ(-1, receiver.message.value); +} + +TEST(ProtoMessageReceiverTest, parse_sfixed64) +{ + services::ProtoMessageReceiver receiver; + + infra::StdVectorInputStreamReader::WithStorage data(infra::inPlace, std::initializer_list{ 9, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff }); + receiver.Feed(data); + + EXPECT_EQ(-1, receiver.message.value); +} + +TEST(ProtoMessageReceiverTest, parse_enum) +{ + services::ProtoMessageReceiver receiver; + + infra::StdVectorInputStreamReader::WithStorage data(infra::inPlace, std::initializer_list{ 1 << 3, 1 }); + receiver.Feed(data); + + EXPECT_EQ(test_messages::Enumeration::val1, receiver.message.e); +} + +TEST(ProtoMessageReceiverTest, parse_repeated_uint32) +{ + services::ProtoMessageReceiver receiver; + + infra::StdVectorInputStreamReader::WithStorage data(infra::inPlace, std::initializer_list{ 1 << 3, 5, 1 << 3, 6 }); + receiver.Feed(data); + + EXPECT_EQ((infra::BoundedVector::WithMaxSize<10>{ { 5, 6 } }), receiver.message.value); +} + +TEST(ProtoMessageReceiverTest, parse_unbounded_repeated_uint32) +{ + services::ProtoMessageReceiver receiver; + + infra::StdVectorInputStreamReader::WithStorage data(infra::inPlace, std::initializer_list{ 1 << 3, 5, 1 << 3, 6 }); + receiver.Feed(data); + + EXPECT_EQ((std::vector{ { 5, 6 } }), receiver.message.value); +} + +TEST(ProtoMessageReceiverTest, parse_incomplete_int) +{ + services::ProtoMessageReceiver receiver; + + infra::StdVectorInputStreamReader::WithStorage data(infra::inPlace, std::initializer_list{ 8, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff }); + receiver.Feed(data); + EXPECT_EQ(0, receiver.message.value); + + infra::StdVectorInputStreamReader::WithStorage data2(infra::inPlace, std::initializer_list{ 0xff, 1 }); + receiver.Feed(data2); + EXPECT_EQ(-1, receiver.message.value); +} + +TEST(ProtoMessageReceiverTest, parse_string) +{ + services::ProtoMessageReceiver receiver; + + infra::StdVectorInputStreamReader::WithStorage data(infra::inPlace, std::initializer_list{ 10, 4, 'a', 'b', 'c', 'd' }); + receiver.Feed(data); + + EXPECT_EQ("abcd", receiver.message.value); +} + +TEST(ProtoMessageReceiverTest, parse_std_string) +{ + services::ProtoMessageReceiver receiver; + + infra::StdVectorInputStreamReader::WithStorage data(infra::inPlace, std::initializer_list{ 10, 4, 'a', 'b', 'c', 'd' }); + receiver.Feed(data); + + EXPECT_EQ("abcd", receiver.message.value); +} + +TEST(ProtoMessageReceiverTest, parse_bytes) +{ + services::ProtoMessageReceiver receiver; + + infra::StdVectorInputStreamReader::WithStorage data(infra::inPlace, std::initializer_list{ 10, 2, 5, 6 }); + receiver.Feed(data); + + EXPECT_EQ((infra::BoundedVector::WithMaxSize<10>{ { 5, 6 } }), receiver.message.value); +} + +TEST(ProtoMessageReceiverTest, parse_message) +{ + services::ProtoMessageReceiver receiver; + + infra::StdVectorInputStreamReader::WithStorage data(infra::inPlace, std::initializer_list{ (1 << 3) | 2, 2, 1 << 3, 5 }); + receiver.Feed(data); + + EXPECT_EQ((test_messages::TestMessageWithMessageField(test_messages::TestUInt32(5))), receiver.message.message); +} + +TEST(ProtoMessageReceiverTest, parse_more_nested_message) +{ + services::ProtoMessageReceiver receiver; + + infra::StdVectorInputStreamReader::WithStorage data(infra::inPlace, std::initializer_list{ (1 << 3) | 2, 2, 1 << 3, 5, (2 << 3) | 2, 2, 2 << 3, 10 }); + receiver.Feed(data); + + EXPECT_EQ((test_messages::TestMoreNestedMessage({ 5 }, { 10 })), receiver.message); +} diff --git a/protobuf/echo/test/TestProtoMessageSender.cpp b/protobuf/echo/test/TestProtoMessageSender.cpp new file mode 100644 index 000000000..3cefc4282 --- /dev/null +++ b/protobuf/echo/test/TestProtoMessageSender.cpp @@ -0,0 +1,222 @@ +#include "generated/echo/TestMessages.pb.hpp" +#include "infra/stream/BoundedVectorOutputStream.hpp" +#include "infra/stream/ByteOutputStream.hpp" +#include "infra/stream/StdVectorOutputStream.hpp" +#include "infra/util/ConstructBin.hpp" +#include "protobuf/echo/ProtoMessageSender.hpp" +#include + +class ProtoMessageSenderTest + : public testing::Test +{ +public: + template + void ExpectFill(const std::vector& data, services::ProtoMessageSender& sender) + { + infra::StdVectorOutputStream::WithStorage stream; + sender.Fill(stream); + + EXPECT_EQ(data, stream.Storage()); + } +}; + +TEST_F(ProtoMessageSenderTest, format_bool) +{ + test_messages::TestBool message{ true }; + services::ProtoMessageSender sender(message); + + ExpectFill({ 1 << 3, 1 }, sender); +} + +TEST_F(ProtoMessageSenderTest, format_uint32) +{ + test_messages::TestUInt32 message(5); + services::ProtoMessageSender sender{ message }; + + ExpectFill({ 1 << 3, 5 }, sender); +} + +TEST_F(ProtoMessageSenderTest, format_int32) +{ + test_messages::TestInt32 message(-1); + services::ProtoMessageSender sender{ message }; + + ExpectFill({ 8, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 1 }, sender); +} + +TEST_F(ProtoMessageSenderTest, format_uint64) +{ + test_messages::TestUInt64 message(5); + services::ProtoMessageSender sender{ message }; + + ExpectFill({ 1 << 3, 5 }, sender); +} + +TEST_F(ProtoMessageSenderTest, format_int64) +{ + test_messages::TestInt64 message(-1); + services::ProtoMessageSender sender{ message }; + + ExpectFill({ 8, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 1 }, sender); +} + +TEST_F(ProtoMessageSenderTest, format_fixed32) +{ + test_messages::TestFixed32 message(1020304); + services::ProtoMessageSender sender{ message }; + + ExpectFill({ 13, 0x90, 0x91, 0x0f, 0 }, sender); +} + +TEST_F(ProtoMessageSenderTest, format_fixed64) +{ + test_messages::TestFixed64 message(102030405060708); + services::ProtoMessageSender sender{ message }; + + ExpectFill({ 9, 0x64, 0xc8, 0x0c, 0xce, 0xcb, 0x5c, 0x00, 0x00 }, sender); +} + +TEST_F(ProtoMessageSenderTest, format_sfixed32) +{ + test_messages::TestSFixed32 message(-1); + services::ProtoMessageSender sender{ message }; + + ExpectFill({ 13, 0xff, 0xff, 0xff, 0xff }, sender); +} + +TEST_F(ProtoMessageSenderTest, format_sfixed64) +{ + test_messages::TestSFixed64 message(-1); + services::ProtoMessageSender sender{ message }; + + ExpectFill({ 9, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff }, sender); +} + +TEST_F(ProtoMessageSenderTest, partially_format_bool) +{ + test_messages::TestBool message{ true }; + services::ProtoMessageSender sender(message); + + infra::ByteOutputStream::WithStorage<1> partialStream; + sender.Fill(partialStream); + infra::StdVectorOutputStream::WithStorage stream; + sender.Fill(stream); + + EXPECT_EQ((std::array{ 1 << 3 }), partialStream.Storage()); + EXPECT_EQ((std::vector{ 1 }), stream.Storage()); +} + +TEST_F(ProtoMessageSenderTest, format_enum) +{ + test_messages::TestEnum message{ test_messages::Enumeration::val1 }; + services::ProtoMessageSender sender{ message }; + + ExpectFill({ 1 << 3, 1 }, sender); +} + +TEST_F(ProtoMessageSenderTest, format_repeated_uint32) +{ + test_messages::TestRepeatedUInt32 message; + message.value.push_back(5); + message.value.push_back(6); + services::ProtoMessageSender sender{ message }; + + ExpectFill({ 1 << 3, 5, 1 << 3, 6 }, sender); +} + +TEST_F(ProtoMessageSenderTest, format_unbounded_repeated_uint32) +{ + test_messages::TestUnboundedRepeatedUInt32 message; + message.value.push_back(5); + message.value.push_back(6); + services::ProtoMessageSender sender{ message }; + + ExpectFill({ 1 << 3, 5, 1 << 3, 6 }, sender); +} + +TEST_F(ProtoMessageSenderTest, format_string) +{ + test_messages::TestString message{ "abcd" }; + services::ProtoMessageSender sender{ message }; + + ExpectFill({ 10, 4, 'a', 'b', 'c', 'd' }, sender); +} + +TEST_F(ProtoMessageSenderTest, format_std_string) +{ + test_messages::TestStdString message{ "abcd" }; + services::ProtoMessageSender sender{ message }; + + ExpectFill({ 10, 4, 'a', 'b', 'c', 'd' }, sender); +} + +TEST_F(ProtoMessageSenderTest, format_bytes) +{ + test_messages::TestBytes message; + message.value.push_back(5); + message.value.push_back(6); + services::ProtoMessageSender sender{ message }; + + ExpectFill({ 10, 2, 5, 6 }, sender); +} + +TEST_F(ProtoMessageSenderTest, format_message) +{ + test_messages::TestMessageWithMessageField message{ 5 }; + services::ProtoMessageSender sender{ message }; + + ExpectFill({ 1 << 3 | 2, 2, 1 << 3, 5 }, sender); +} + +TEST_F(ProtoMessageSenderTest, format_more_nested_message) +{ + test_messages::TestMoreNestedMessage message{ { 5 }, { 10 } }; + services::ProtoMessageSender sender{ message }; + + ExpectFill({ (1 << 3) | 2, 2, 1 << 3, 5, (2 << 3) | 2, 2, 2 << 3, 10 }, sender); +} + +TEST_F(ProtoMessageSenderTest, dont_format_on_buffer_full) +{ + test_messages::TestBoolWithBytes message; + message.b.insert(message.b.begin(), 30, static_cast(5)); + message.value = true; + services::ProtoMessageSender sender{ message }; + + infra::ByteOutputStream::WithStorage<1> partialStream; + sender.Fill(partialStream); + + infra::StdVectorOutputStream::WithStorage stream; + sender.Fill(stream); + EXPECT_EQ((std::vector{ 30, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 2 << 3, 1 }), stream.Storage()); +} + +TEST_F(ProtoMessageSenderTest, format_many_repeated_uint32) +{ + test_messages::TestRepeatedUInt32 message; + message.value.insert(message.value.end(), 50, static_cast(5)); + services::ProtoMessageSender sender{ message }; + + infra::ByteOutputStream::WithStorage<10> partialStream; + sender.Fill(partialStream); + EXPECT_EQ((std::array{ 1 << 3, 5, 1 << 3, 5, 1 << 3, 5, 1 << 3, 5, 1 << 3, 5 }), partialStream.Storage()); + + infra::StdVectorOutputStream::WithStorage stream; + sender.Fill(stream); + EXPECT_EQ(infra::ConstructBin().Repeat(45, { 1 << 3, 5 }).Vector(), stream.Storage()); +} + +TEST_F(ProtoMessageSenderTest, format_many_bytes) +{ + test_messages::TestBytes message; + message.value.insert(message.value.end(), 50, static_cast(5)); + services::ProtoMessageSender sender{ message }; + + infra::ByteOutputStream::WithStorage<10> partialStream; + sender.Fill(partialStream); + EXPECT_EQ((std::array{ 10, 50, 5, 5, 5, 5, 5, 5, 5, 5 }), partialStream.Storage()); + + infra::StdVectorOutputStream::WithStorage stream; + sender.Fill(stream); + EXPECT_EQ((std::vector{ 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5 }), stream.Storage()); +} diff --git a/protobuf/echo/test_doubles/ServiceStub.cpp b/protobuf/echo/test_doubles/ServiceStub.cpp index a9216d6ee..90a16384e 100644 --- a/protobuf/echo/test_doubles/ServiceStub.cpp +++ b/protobuf/echo/test_doubles/ServiceStub.cpp @@ -19,7 +19,7 @@ namespace services switch (field.second) { case 1: - DeserializeField(services::ProtoUInt32(), parser, field, value); + DeserializeField(services::ProtoUInt32(), parser, field.first, value); break; default: if (field.first.Is()) diff --git a/protobuf/protoc_echo_plugin/ProtoCEchoPlugin.cpp b/protobuf/protoc_echo_plugin/ProtoCEchoPlugin.cpp index 493f2b7bc..3a35ff3a0 100644 --- a/protobuf/protoc_echo_plugin/ProtoCEchoPlugin.cpp +++ b/protobuf/protoc_echo_plugin/ProtoCEchoPlugin.cpp @@ -319,7 +319,7 @@ namespace application , prefix(prefix) {} - void MessageTypeMapGenerator::Run(Entities& formatter) + void MessageTypeMapGenerator::Run(Entities& formatter) const { auto typeMapNamespace = std::make_shared("detail"); @@ -331,15 +331,21 @@ namespace application { auto typeMapSpecialization = std::make_shared(MessageName() + "TypeMap"); typeMapSpecialization->TemplateSpecialization(google::protobuf::SimpleItoa(std::distance(message->fields.data(), &field))); - typeMapSpecialization->Add(std::make_shared("ProtoType", field->protoType)); + AddTypeMapProtoType(*field, *typeMapSpecialization); AddTypeMapType(*field, *typeMapSpecialization); + AddTypeMapFieldNumber(*field, *typeMapSpecialization); typeMapNamespace->Add(typeMapSpecialization); } formatter.Add(typeMapNamespace); } - void MessageTypeMapGenerator::AddTypeMapType(EchoField& field, Entities& entities) + void MessageTypeMapGenerator::AddTypeMapProtoType(const EchoField& field, Entities& entities) const + { + entities.Add(std::make_shared("ProtoType", field.protoType)); + } + + void MessageTypeMapGenerator::AddTypeMapType(const EchoField& field, Entities& entities) const { std::string result; StorageTypeVisitor visitor(result); @@ -347,6 +353,11 @@ namespace application entities.Add(std::make_shared("Type", result)); } + void MessageTypeMapGenerator::AddTypeMapFieldNumber(const EchoField& field, Entities& entities) const + { + entities.Add(std::make_shared("fieldNumber", "static const uint32_t", google::protobuf::SimpleItoa(field.number))); + } + std::string MessageTypeMapGenerator::MessageName() const { return prefix + message->name + MessageSuffix(); @@ -357,27 +368,12 @@ namespace application return ""; } - void MessageReferenceTypeMapGenerator::Run(Entities& formatter) + void MessageReferenceTypeMapGenerator::AddTypeMapProtoType(const EchoField& field, Entities& entities) const { - auto typeMapNamespace = std::make_shared("detail"); - - auto typeMapDeclaration = std::make_shared(MessageName() + "TypeMap"); - typeMapDeclaration->TemplateParameter("std::size_t fieldIndex"); - typeMapNamespace->Add(typeMapDeclaration); - - for (auto& field : message->fields) - { - auto typeMapSpecialization = std::make_shared(MessageName() + "TypeMap"); - typeMapSpecialization->TemplateSpecialization(google::protobuf::SimpleItoa(std::distance(message->fields.data(), &field))); - typeMapSpecialization->Add(std::make_shared("ProtoType", field->protoReferenceType)); - AddTypeMapType(*field, *typeMapSpecialization); - typeMapNamespace->Add(typeMapSpecialization); - } - - formatter.Add(typeMapNamespace); + entities.Add(std::make_shared("ProtoType", field.protoReferenceType)); } - void MessageReferenceTypeMapGenerator::AddTypeMapType(EchoField& field, Entities& entities) + void MessageReferenceTypeMapGenerator::AddTypeMapType(const EchoField& field, Entities& entities) const { std::string result; ParameterReferenceTypeVisitor visitor(result); @@ -476,12 +472,16 @@ namespace application { auto typeMap = std::make_shared("public"); + auto numberOfFields = std::make_shared("numberOfFields", "static const uint32_t", google::protobuf::SimpleItoa(message->fields.size())); + typeMap->Add(numberOfFields); auto protoTypeUsing = std::make_shared("ProtoType", "typename " + TypeMapName() + "::ProtoType"); protoTypeUsing->TemplateParameter("std::size_t fieldIndex"); typeMap->Add(protoTypeUsing); auto typeUsing = std::make_shared("Type", "typename " + TypeMapName() + "::Type"); typeUsing->TemplateParameter("std::size_t fieldIndex"); typeMap->Add(typeUsing); + auto fieldNumber = std::make_shared("fieldNumber", "template static const uint32_t", TypeMapName() + "::fieldNumber"); + typeMap->Add(fieldNumber); classFormatter->Add(typeMap); } @@ -493,9 +493,12 @@ namespace application for (auto& field : message->fields) { auto index = std::distance(message->fields.data(), &field); - auto function = std::make_shared("Get", "return " + field->name + ";\n", ClassName() + "::Type<" + google::protobuf::SimpleItoa(index) + ">&", 0); - function->Parameter("std::integral_constant"); - getters->Add(function); + auto functionGet = std::make_shared("Get", "return " + field->name + ";\n", ClassName() + "::Type<" + google::protobuf::SimpleItoa(index) + ">&", 0); + functionGet->Parameter("std::integral_constant"); + getters->Add(functionGet); + auto functionConstGet = std::make_shared("Get", "return " + field->name + ";\n", "const " + ClassName() + "::Type<" + google::protobuf::SimpleItoa(index) + ">&", Function::fConst); + functionConstGet->Parameter("std::integral_constant"); + getters->Add(functionConstGet); } classFormatter->Add(getters); @@ -618,7 +621,7 @@ namespace application for (auto& field : message->fields) printer.Print(R"(case $constant$: - DeserializeField($type$(), parser, field, $name$); + DeserializeField($type$(), parser, field.first, $name$); break; )", "constant", field->constantName, "type", field->protoType, "name", field->name); @@ -740,9 +743,12 @@ namespace application for (auto& field : message->fields) { auto index = std::distance(message->fields.data(), &field); - auto function = std::make_shared("Get", "return " + field->name + ";\n", ClassName() + "::Type<" + google::protobuf::SimpleItoa(index) + ">&", 0); - function->Parameter("std::integral_constant"); - getters->Add(function); + auto functionGet = std::make_shared("Get", "return " + field->name + ";\n", ClassName() + "::Type<" + google::protobuf::SimpleItoa(index) + ">&", 0); + functionGet->Parameter("std::integral_constant"); + getters->Add(functionGet); + auto functionConstGet = std::make_shared("Get", "return " + field->name + ";\n", ClassName() + "::Type<" + google::protobuf::SimpleItoa(index) + ">", Function::fConst); + functionConstGet->Parameter("std::integral_constant"); + getters->Add(functionConstGet); } classFormatter->Add(getters); diff --git a/protobuf/protoc_echo_plugin/ProtoCEchoPlugin.hpp b/protobuf/protoc_echo_plugin/ProtoCEchoPlugin.hpp index 81c3676f6..c611e71c2 100644 --- a/protobuf/protoc_echo_plugin/ProtoCEchoPlugin.hpp +++ b/protobuf/protoc_echo_plugin/ProtoCEchoPlugin.hpp @@ -55,10 +55,12 @@ namespace application MessageTypeMapGenerator& operator=(const MessageTypeMapGenerator& other) = delete; ~MessageTypeMapGenerator() = default; - void Run(Entities& formatter); + void Run(Entities& formatter) const; protected: - virtual void AddTypeMapType(EchoField& field, Entities& entities); + virtual void AddTypeMapProtoType(const EchoField& field, Entities& entities) const; + virtual void AddTypeMapType(const EchoField& field, Entities& entities) const; + void AddTypeMapFieldNumber(const EchoField& field, Entities& entities) const; std::string MessageName() const; virtual std::string MessageSuffix() const; @@ -73,10 +75,9 @@ namespace application public: using MessageTypeMapGenerator::MessageTypeMapGenerator; - void Run(Entities& formatter); - protected: - void AddTypeMapType(EchoField& field, Entities& entities) override; + void AddTypeMapProtoType(const EchoField& field, Entities& entities) const override; + void AddTypeMapType(const EchoField& field, Entities& entities) const override; std::string MessageSuffix() const override; }; diff --git a/protobuf/protoc_echo_plugin/test/CMakeLists.txt b/protobuf/protoc_echo_plugin/test/CMakeLists.txt index e301e87a6..9e2eed725 100644 --- a/protobuf/protoc_echo_plugin/test/CMakeLists.txt +++ b/protobuf/protoc_echo_plugin/test/CMakeLists.txt @@ -2,7 +2,7 @@ add_executable(protobuf.protoc_echo_plugin_test) emil_build_for(protobuf.protoc_echo_plugin_test BOOL EMIL_BUILD_TESTS) emil_add_test(protobuf.protoc_echo_plugin_test) -protocol_buffer_echo_cpp(protobuf.protoc_echo_plugin_test TestMessages.proto) +protocol_buffer_echo_cpp(protobuf.protoc_echo_plugin_test ../../echo/test/TestMessages.proto) target_sources(protobuf.protoc_echo_plugin_test PRIVATE TestCppFormatter.cpp diff --git a/protobuf/protoc_echo_plugin/test/TestProtoCEchoPlugin.cpp b/protobuf/protoc_echo_plugin/test/TestProtoCEchoPlugin.cpp index 83aa9272e..5bf5bbf8f 100644 --- a/protobuf/protoc_echo_plugin/test/TestProtoCEchoPlugin.cpp +++ b/protobuf/protoc_echo_plugin/test/TestProtoCEchoPlugin.cpp @@ -159,6 +159,28 @@ TEST(ProtoCEchoPluginTest, deserialize_bool) EXPECT_EQ(true, message.value); } +TEST(ProtoCEchoPluginTest, serialize_enum) +{ + test_messages::TestEnum message; + message.e = test_messages::Enumeration::val1; + + infra::ByteOutputStream::WithStorage<100> stream; + infra::ProtoFormatter formatter(stream); + message.Serialize(formatter); + + EXPECT_EQ((std::array{ 8, 1 }), stream.Writer().Processed()); +} + +TEST(ProtoCEchoPluginTest, deserialize_enum) +{ + std::array data{ 8, 1 }; + infra::ByteInputStream stream(data); + infra::ProtoParser parser(stream); + + test_messages::TestEnum message(parser); + EXPECT_EQ(test_messages::Enumeration::val1, message.e); +} + TEST(ProtoCEchoPluginTest, serialize_string) { test_messages::TestString message; @@ -265,6 +287,32 @@ TEST(ProtoCEchoPluginTest, deserialize_bytes) EXPECT_EQ(value, message.value); } +TEST(ProtoCEchoPluginTest, serialize_unbounded_bytes) +{ + test_messages::TestUnboundedBytes message; + message.value.push_back(5); + message.value.push_back(6); + + infra::ByteOutputStream::WithStorage<100> stream; + infra::ProtoFormatter formatter(stream); + message.Serialize(formatter); + + EXPECT_EQ((std::array{ 10, 2, 5, 6 }), stream.Writer().Processed()); +} + +TEST(ProtoCEchoPluginTest, deserialize_unbounded_bytes) +{ + std::array data{ 10, 2, 5, 6 }; + infra::ByteInputStream stream(data); + infra::ProtoParser parser(stream); + + test_messages::TestUnboundedBytes message(parser); + std::vector value; + value.push_back(5); + value.push_back(6); + EXPECT_EQ(value, message.value); +} + TEST(ProtoCEchoPluginTest, serialize_uint32) { test_messages::TestUInt32 message; @@ -334,6 +382,31 @@ TEST(ProtoCEchoPluginTest, deserialize_repeated_uint32) EXPECT_EQ(6, message.value[1]); } +TEST(ProtoCEchoPluginTest, serialize_unbounded_repeated_uint32) +{ + test_messages::TestUnboundedRepeatedUInt32 message; + message.value.push_back(5); + message.value.push_back(6); + + infra::ByteOutputStream::WithStorage<100> stream; + infra::ProtoFormatter formatter(stream); + message.Serialize(formatter); + + EXPECT_EQ((std::array{ 1 << 3, 5, 1 << 3, 6 }), stream.Writer().Processed()); +} + +TEST(ProtoCEchoPluginTest, deserialize_unbounded_repeated_uint32) +{ + std::array data{ 1 << 3, 5, 1 << 3, 6 }; + infra::ByteInputStream stream(data); + infra::ProtoParser parser(stream); + + test_messages::TestUnboundedRepeatedUInt32 message(parser); + EXPECT_EQ(2, message.value.size()); + EXPECT_EQ(5, message.value[0]); + EXPECT_EQ(6, message.value[1]); +} + TEST(ProtoCEchoPluginTest, serialize_message) { test_messages::TestMessageWithMessageField message; diff --git a/services/network/test_doubles/Certificates.cpp b/services/network/test_doubles/Certificates.cpp index 304f07a74..bffa0ff77 100644 --- a/services/network/test_doubles/Certificates.cpp +++ b/services/network/test_doubles/Certificates.cpp @@ -53,7 +53,7 @@ namespace services "-----END CERTIFICATE-----\r\n"; const char testServerKeyData[] = - "-----BEGIN RSA PRIVATE KEY-----\r\n" + "-----BEGIN RSA PRIVATE KEY-----\r\n" //NOSONAR "MIIEogIBAAKCAQEArcYgmszNxJu5pfLpploIw12iIfAyZWK6PZ1/6zlifaR1v0KL\r\n" "x08aLSQAOAvnb8jmfVzkBwpt3hYVm69gWdTJcFpRXMuZ1nIur7HEDQXtzrgvmV6v\r\n" "d8Eu7ngoGfHijP9eiuwxoYNiMmbyvTM2hYnRn5WllOnbK5OBldnh0Vawh/XErMO/\r\n" @@ -105,7 +105,7 @@ namespace services "-----END CERTIFICATE-----\r\n"; const char testClientKeyData[] = - "-----BEGIN RSA PRIVATE KEY-----\r\n" + "-----BEGIN RSA PRIVATE KEY-----\r\n" //NOSONAR "MIIEowIBAAKCAQEAqNdp/w5eM4io5Ki/amQC8FAuhhdk1H3ELm11MmkYTLXJBsSc\r\n" "RCLIdiyM3Ten/I3bZtI0Ak7yvubQnozeBZig3usBBapfOvx5adCSsZfwGjgsqLF+\r\n" "zQ2maEd15pHSRPPXBAJ++4ZpxdVkfr0r/ZRIxWJwvyxtK3N0BJfUcop1J2fa3cZx\r\n"