diff --git a/src/runtime/hexagon/hexagon_device_api.cc b/src/runtime/hexagon/hexagon_device_api.cc index 92a7b22784fb..f22afca10bfa 100644 --- a/src/runtime/hexagon/hexagon_device_api.cc +++ b/src/runtime/hexagon/hexagon_device_api.cc @@ -170,7 +170,8 @@ TVM_REGISTER_GLOBAL("device_api.hexagon.mem_copy").set_body([](TVMArgs args, TVM void* src = args[1]; int size = args[2]; - hexagon_user_dma_1d_sync(dst, src, size); + int error_code = hexagon_user_dma_1d_sync(dst, src, size); + CHECK_EQ(error_code, 0); *rv = static_cast(0); }); diff --git a/src/runtime/hexagon/hexagon_user_dma.cc b/src/runtime/hexagon/hexagon_user_dma.cc index 0e3fbd2048f6..8d45b7590bc4 100644 --- a/src/runtime/hexagon/hexagon_user_dma.cc +++ b/src/runtime/hexagon/hexagon_user_dma.cc @@ -17,66 +17,47 @@ * under the License. */ -#include +#include "hexagon_user_dma.h" -#include "hexagon_common.h" -#include "hexagon_user_dma_descriptors.h" -#include "hexagon_user_dma_instructions.h" -#include "hexagon_user_dma_registers.h" +#include namespace tvm { namespace runtime { namespace hexagon { -int init_hexagon_user_dma() { -#if __HEXAGON_ARCH__ >= 68 - // reset DMA engine +unsigned int HexagonUserDMA::Init() { unsigned int status = dmpause() & DM0_STATUS_MASK; - if (status != DM0_STATUS_IDLE) { - return DMA_FAILURE; - } -#endif - return DMA_SUCCESS; + return status; } -int hexagon_user_dma_1d_sync_helper(void* dst, void* src, uint32_t length) { -#if __HEXAGON_ARCH__ >= 68 - static int config_dma = init_hexagon_user_dma(); - if (config_dma != DMA_SUCCESS) { +int HexagonUserDMA::Copy(void* dst, void* src, uint32_t length) { + // length limited to 24 bits + if (length > DESC_LENGTH_MASK) { return DMA_FAILURE; } - uint64_t src64 = reinterpret_cast(src); // source address limited to 32 bits - if (src64 > DESC_SRC_MASK) { + uint64_t src64 = reinterpret_cast(src); + if (!src64 || src64 > DESC_SRC_MASK) { return DMA_FAILURE; } - uint64_t dst64 = reinterpret_cast(dst); // destination address limited to 32 bits - 
if (dst64 > DESC_DST_MASK) { - return DMA_FAILURE; - } - - // length limited to 24 bits - if (length > DESC_LENGTH_MASK) { + uint64_t dst64 = reinterpret_cast(dst); + if (!dst64 || dst64 > DESC_DST_MASK) { return DMA_FAILURE; } - uint32_t src32 = src64 & DESC_SRC_MASK; - uint32_t dst32 = dst64 & DESC_DST_MASK; - - void* dma_desc = nullptr; - - int ret = posix_memalign(&dma_desc, DMA_DESC_2D_SIZE, DMA_DESC_2D_SIZE); - if (ret) { - return DMA_FAILURE; - } + uint32_t src32 = static_cast(src64); + uint32_t dst32 = static_cast(dst64); + // get pointer to next descriptor + dma_desc_2d_t* dma_desc = descriptors_->Next(); if (!dma_desc) { - return DMA_FAILURE; + return DMA_RETRY; } + // populate descriptor fields dma_desc_set_state(dma_desc, DESC_STATE_READY); dma_desc_set_next(dma_desc, DMA_NULL_PTR); dma_desc_set_length(dma_desc, length); @@ -90,23 +71,60 @@ int hexagon_user_dma_1d_sync_helper(void* dst, void* src, uint32_t length) { dma_desc_set_src(dma_desc, src32); dma_desc_set_dst(dma_desc, dst32); - dmstart(dma_desc); - unsigned int status = dmwait() & DM0_STATUS_MASK; - unsigned int done = dma_desc_get_done(dma_desc); + if (first_dma_) { + // `dmstart` first descriptor + dmstart(dma_desc); + first_dma_ = false; + } else { + // `dmlink` descriptor to tail descriptor + dmlink(tail_dma_desc_, dma_desc); + } - free(dma_desc); + // update tail + tail_dma_desc_ = dma_desc; + return DMA_SUCCESS; +} - if (status == DM0_STATUS_IDLE && done == DESC_DONE_COMPLETE) { - return DMA_SUCCESS; +void HexagonUserDMA::Wait(uint32_t max_dmas_in_flight) { + // wait (forever) until actual DMAs in flight <= max DMAs in flight + while (DMAsInFlight() > max_dmas_in_flight) { } -#endif - return DMA_FAILURE; +} + +uint32_t HexagonUserDMA::Poll() { return DMAsInFlight(); } + +uint32_t HexagonUserDMA::DMAsInFlight() { + dmpoll(); // update DMA engine status + return descriptors_->InFlight(); +} + +HexagonUserDMA::HexagonUserDMA() { + // reset DMA engine + unsigned int status = Init(); + 
CHECK_EQ(status, DM0_STATUS_IDLE); + + auto desc_in_flight = [](dma_desc_2d_t* dma_desc) { + unsigned int done = dma_desc_get_done(dma_desc); + return (done != DESC_DONE_COMPLETE); + }; + descriptors_ = new RingBuffer(MAX_DMA_DESCRIPTORS, desc_in_flight); +} + +HexagonUserDMA::~HexagonUserDMA() { + Init(); // stop DMA engine + delete descriptors_; } int hexagon_user_dma_1d_sync(void* dst, void* src, uint32_t length) { // One DMA transfer can copy at most DESC_LENGTH_MASK bytes. // Make the common case quick. - if (length <= DESC_LENGTH_MASK) return hexagon_user_dma_1d_sync_helper(dst, src, length); + if (length <= DESC_LENGTH_MASK) { + // sync DMA -> `Copy` and then `Wait(0)` + int ret_val = HexagonUserDMA::Get().Copy(dst, src, length); + if (ret_val != DMA_SUCCESS) return ret_val; + HexagonUserDMA::Get().Wait(0); + return DMA_SUCCESS; + } // Split big transfers into smaller transfers. char* cast_src = static_cast(src); @@ -114,8 +132,10 @@ int hexagon_user_dma_1d_sync(void* dst, void* src, uint32_t length) { for (uint32_t i = 0; i < length;) { // Ensure there is no overflow while updating i uint32_t cur_len = std::min(length - i, DESC_LENGTH_MASK); - int ret_val = hexagon_user_dma_1d_sync_helper(&cast_dst[i], &cast_src[i], cur_len); + // sync DMA -> `Copy` and then `Wait(0)` + int ret_val = HexagonUserDMA::Get().Copy(&cast_dst[i], &cast_src[i], cur_len); if (ret_val != DMA_SUCCESS) return ret_val; + HexagonUserDMA::Get().Wait(0); // 2 cases for new val for i: // 1. length - i <= DESC_LENGTH_MASK (<= MAX_UINT) // new_i = i + (length - i) = length, no more iter diff --git a/src/runtime/hexagon/hexagon_user_dma.h b/src/runtime/hexagon/hexagon_user_dma.h new file mode 100644 index 000000000000..aa00df79c4d0 --- /dev/null +++ b/src/runtime/hexagon/hexagon_user_dma.h @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef TVM_RUNTIME_HEXAGON_HEXAGON_USER_DMA_H_ +#define TVM_RUNTIME_HEXAGON_HEXAGON_USER_DMA_H_ + +#include "hexagon_common.h" +#include "hexagon_user_dma_descriptors.h" +#include "hexagon_user_dma_instructions.h" +#include "hexagon_user_dma_registers.h" +#include "ring_buffer.h" + +namespace tvm { +namespace runtime { +namespace hexagon { + +#define DMA_SUCCESS 0 +#define DMA_FAILURE -1 +#define DMA_RETRY 1 +#define MAX_DMA_DESCRIPTORS 100 + +class HexagonUserDMA { + public: + /*! + * \brief Initiate DMA to copy memory from source to destination address + * \param dst Destination address + * \param src Source address + * \param length Length in bytes to copy + * \returns Status: DMA_SUCCESS, DMA_FAILURE or DMA_RETRY + */ + int Copy(void* dst, void* src, uint32_t length); + + /*! + * \brief Wait until the number of DMAs in flight is less than or equal to some maximum + * \param max_dmas_in_flight Maximum number of DMAs allowed to be in flight + * to satisfy the `Wait` e.g. use `Wait(0)` to wait on "all" outstanding DMAs to complete + */ + void Wait(uint32_t max_dmas_in_flight); + + /*! + * \brief Poll the number of DMAs in flight + * \returns Number of DMAs in flight + */ + uint32_t Poll(); + + //! 
\brief HexagonUserDMA uses the singleton pattern + static HexagonUserDMA& Get() { + static HexagonUserDMA* hud = new HexagonUserDMA(); + return *hud; + } + + private: + // HexagonUserDMA uses the singleton pattern + HexagonUserDMA(); + ~HexagonUserDMA(); + HexagonUserDMA(const HexagonUserDMA&) = delete; + HexagonUserDMA& operator=(const HexagonUserDMA&) = delete; + HexagonUserDMA(HexagonUserDMA&&) = delete; + HexagonUserDMA& operator=(HexagonUserDMA&&) = delete; + + //! \brief Initializes the Hexagon User DMA engine + unsigned int Init(); + + //! \brief Calculates and returns the number of DMAs in flight + uint32_t DMAsInFlight(); + + //! \brief Tracks whether the very first DMA has been executed + bool first_dma_{true}; + + //! \brief Tracks the tail DMA descriptor + void* tail_dma_desc_{nullptr}; + + //! \brief Storage for all DMA descriptors + RingBuffer* descriptors_{nullptr}; +}; + +} // namespace hexagon +} // namespace runtime +} // namespace tvm + +#endif // TVM_RUNTIME_HEXAGON_HEXAGON_USER_DMA_H_ diff --git a/src/runtime/hexagon/hexagon_user_dma_descriptors.h b/src/runtime/hexagon/hexagon_user_dma_descriptors.h index 643dbc5e8bf5..913b025df138 100644 --- a/src/runtime/hexagon/hexagon_user_dma_descriptors.h +++ b/src/runtime/hexagon/hexagon_user_dma_descriptors.h @@ -126,8 +126,6 @@ namespace hexagon { #define DESC_DSTWIDTHOFFSET_MASK 0xFFFF0000 #define DESC_DSTWIDTHOFFSET_SHIFT 16 -#define DMA_SUCCESS 0 -#define DMA_FAILURE -1 #define DMA_NULL_PTR 0 /**************************/ diff --git a/src/runtime/hexagon/hexagon_user_dma_instructions.h b/src/runtime/hexagon/hexagon_user_dma_instructions.h index e160b7395658..2345d4daaf21 100644 --- a/src/runtime/hexagon/hexagon_user_dma_instructions.h +++ b/src/runtime/hexagon/hexagon_user_dma_instructions.h @@ -24,8 +24,6 @@ namespace tvm { namespace runtime { namespace hexagon { -#if __HEXAGON_ARCH__ >= 68 - inline unsigned int dmpause() { unsigned int dm0 = 0; asm volatile(" %0 = dmpause" : "=r"(dm0)); @@ -34,6 
+32,10 @@ inline unsigned int dmpause() { inline void dmstart(void* next) { asm volatile(" dmstart(%0)" : : "r"(next)); } +inline void dmlink(void* tail, void* next) { + asm volatile(" dmlink(%0, %1)" : : "r"(tail), "r"(next)); +} + inline unsigned int dmpoll() { unsigned int dm0 = 0; asm volatile(" %0 = dmpoll" : "=r"(dm0)); @@ -70,8 +72,6 @@ inline void dmcfgwr(unsigned int dmindex, unsigned int data) { asm volatile(" dmcfgwr(%0, %1)" : : "r"(dmindex), "r"(data)); } -#endif - } // namespace hexagon } // namespace runtime } // namespace tvm diff --git a/src/runtime/hexagon/ring_buffer.h b/src/runtime/hexagon/ring_buffer.h new file mode 100644 index 000000000000..d21b2b9953c2 --- /dev/null +++ b/src/runtime/hexagon/ring_buffer.h @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef TVM_RUNTIME_HEXAGON_RING_BUFFER_H_ +#define TVM_RUNTIME_HEXAGON_RING_BUFFER_H_ + +#include + +#include "hexagon_common.h" + +namespace tvm { +namespace runtime { +namespace hexagon { + +template +class RingBuffer { + public: + //! 
\brief Returns the number of Ts in flight + uint32_t InFlight() { + while (id_oldest_ < id_next_ && !in_flight_(GetAddr(id_oldest_))) { + id_oldest_++; + } + return id_next_ - id_oldest_; + } + + //! \brief Returns pointer to next T; null if ring buffer is full + T* Next() { + if (InFlight() == ring_buff_size_) { + return nullptr; + } + T* next = GetAddr(id_next_); + id_next_++; + return next; + } + + /*! \brief Creates a ring buffer for storing items of type T + * \param ring_buff_size Size of the ring buffer in number of Ts + * \param in_flight Function that determines whether a T is in flight + */ + RingBuffer(uint32_t ring_buff_size, std::function in_flight) + : ring_buff_size_(ring_buff_size), in_flight_(in_flight) { + CHECK_NE(ring_buff_size, 0); + int ret = posix_memalign(reinterpret_cast(&ring_buff_ptr_), sizeof(T), + sizeof(T) * ring_buff_size_); + CHECK_EQ(ret, 0); + CHECK_NE(ring_buff_ptr_, nullptr); + } + + ~RingBuffer() { free(ring_buff_ptr_); } + + private: + //! \brief Returns the address of a T given its index + T* GetAddr(uint32_t id) const { + uint32_t ring_buff_index = id % ring_buff_size_; + return ring_buff_ptr_ + ring_buff_index; + } + + //! \brief Pointer to the ring buffer + T* ring_buff_ptr_{nullptr}; + + //! \brief Size of the ring buffer in number of Ts + const uint32_t ring_buff_size_; + + //! \brief Function that determines whether a T is in flight + const std::function in_flight_; + + //! \brief Tracks the ID of the next T to be added to the ring buffer + uint32_t id_next_{0}; + + //! 
\brief Tracks the ID of the oldest T in flight + uint32_t id_oldest_{0}; +}; + +} // namespace hexagon +} // namespace runtime +} // namespace tvm + +#endif // TVM_RUNTIME_HEXAGON_RING_BUFFER_H_ diff --git a/tests/cpp-runtime/hexagon/hexagon_user_dma_tests.cc b/tests/cpp-runtime/hexagon/hexagon_user_dma_tests.cc new file mode 100644 index 000000000000..bf7a23712d7d --- /dev/null +++ b/tests/cpp-runtime/hexagon/hexagon_user_dma_tests.cc @@ -0,0 +1,178 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +#include + +#include "../src/runtime/hexagon/hexagon_user_dma.h" + +using namespace tvm::runtime; +using namespace tvm::runtime::hexagon; + +class HexagonUserDMATest : public ::testing::Test { + void SetUp() override { + src = malloc(length); + dst = malloc(length); + ASSERT_NE(src, nullptr); + ASSERT_NE(dst, nullptr); + + src_char = static_cast(src); + dst_char = static_cast(dst); + for (uint32_t i = 0; i < length; ++i) { + src_char[i] = 1; + dst_char[i] = 0; + } + } + void TearDown() override { + free(src); + free(dst); + } + + public: + int ret{0}; + void* src{nullptr}; + void* dst{nullptr}; + char* src_char{nullptr}; + char* dst_char{nullptr}; + uint32_t length{0x4000}; // 16KB +}; + +TEST_F(HexagonUserDMATest, wait) { + HexagonUserDMA::Get().Wait(0); + HexagonUserDMA::Get().Wait(10); +} + +TEST_F(HexagonUserDMATest, poll) { ASSERT_EQ(HexagonUserDMA::Get().Poll(), 0); } + +TEST_F(HexagonUserDMATest, bad_copy) { + uint64_t bigaddr = 0x100000000; + void* src64 = reinterpret_cast(bigaddr); + void* dst64 = reinterpret_cast(bigaddr); + uint32_t biglength = 0x1000000; + ASSERT_NE(HexagonUserDMA::Get().Copy(dst64, src, length), DMA_SUCCESS); + ASSERT_NE(HexagonUserDMA::Get().Copy(dst, src64, length), DMA_SUCCESS); + ASSERT_NE(HexagonUserDMA::Get().Copy(dst, src, biglength), DMA_SUCCESS); +} + +TEST_F(HexagonUserDMATest, sync_dma) { + // kick off 1 DMA + ret = HexagonUserDMA::Get().Copy(dst, src, length); + ASSERT_EQ(ret, DMA_SUCCESS); + + // wait for DMA to complete + HexagonUserDMA::Get().Wait(0); + + // verify + for (uint32_t i = 0; i < length; ++i) { + ASSERT_EQ(src_char[i], dst_char[i]); + } +} + +TEST_F(HexagonUserDMATest, async_dma_wait) { + // kick off 10x duplicate DMAs + for (uint32_t i = 0; i < 10; ++i) { + ret = HexagonUserDMA::Get().Copy(dst, src, length); + ASSERT_EQ(ret, DMA_SUCCESS); + } + + // wait for at least 1 DMA to complete + HexagonUserDMA::Get().Wait(9); + + // verify + for (uint32_t i = 0; i < length; ++i) { + ASSERT_EQ(src_char[i], 
dst_char[i]); + } + + // empty the DMA queue + HexagonUserDMA::Get().Wait(0); +} + +TEST_F(HexagonUserDMATest, async_dma_poll) { + // kick off 10x duplicate DMAs + for (uint32_t i = 0; i < 10; ++i) { + ret = HexagonUserDMA::Get().Copy(dst, src, length); + ASSERT_EQ(ret, DMA_SUCCESS); + } + + // poll until at least 1 DMA is complete + while (HexagonUserDMA::Get().Poll() == 10) { + }; + + // verify + for (uint32_t i = 0; i < length; ++i) { + ASSERT_EQ(src_char[i], dst_char[i]); + } + + // empty the DMA queue + HexagonUserDMA::Get().Wait(0); +} + +// TODO: Run non-pipelined case with sync DMA and execution time vs. pipelined case +TEST_F(HexagonUserDMATest, pipeline) { + uint32_t pipeline_depth = 4; + uint32_t pipeline_length = length / pipeline_depth; + + for (uint32_t i = 0; i < pipeline_depth; ++i) { + ret |= HexagonUserDMA::Get().Copy(dst_char + i * pipeline_length, + src_char + i * pipeline_length, pipeline_length); + } + + HexagonUserDMA::Get().Wait(3); + for (uint32_t i = 0; i < pipeline_length; ++i) { + dst_char[i]++; + } + + HexagonUserDMA::Get().Wait(2); + for (uint32_t i = pipeline_length; i < 2 * pipeline_length; ++i) { + dst_char[i]++; + } + + HexagonUserDMA::Get().Wait(1); + for (uint32_t i = 2 * pipeline_length; i < 3 * pipeline_length; ++i) { + dst_char[i]++; + } + + HexagonUserDMA::Get().Wait(0); + for (uint32_t i = 3 * pipeline_length; i < 4 * pipeline_length; ++i) { + dst_char[i]++; + } + + // verify + ASSERT_EQ(ret, DMA_SUCCESS); + for (uint32_t i = 0; i < length; ++i) { + ASSERT_EQ(2, dst_char[i]); + } +} + +TEST_F(HexagonUserDMATest, overflow_ring_buffer) { + uint32_t number_of_dmas = 0x400; // 1k + uint32_t length_of_each_dma = length / number_of_dmas; + + for (uint32_t i = 0; i < number_of_dmas; ++i) { + do { + ret = HexagonUserDMA::Get().Copy(dst_char + i * length_of_each_dma, + src_char + i * length_of_each_dma, length_of_each_dma); + } while (ret == DMA_RETRY); + ASSERT_EQ(ret, DMA_SUCCESS); + } + + // verify + for (uint32_t i = 0; i < 
length; ++i) { + ASSERT_EQ(src_char[i], dst_char[i]); + } +} \ No newline at end of file diff --git a/tests/cpp-runtime/hexagon/ring_buffer_tests.cc b/tests/cpp-runtime/hexagon/ring_buffer_tests.cc new file mode 100644 index 000000000000..cd40dca87b02 --- /dev/null +++ b/tests/cpp-runtime/hexagon/ring_buffer_tests.cc @@ -0,0 +1,190 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +#include + +#include "../src/runtime/hexagon/ring_buffer.h" + +using namespace tvm::runtime; +using namespace tvm::runtime::hexagon; + +class RingBufferTest : public ::testing::Test { + void SetUp() override { ring_buff = new RingBuffer(size, in_flight); } + void TearDown() override { delete ring_buff; } + + public: + std::function in_flight = [](int* ptr) { + if (*ptr == 42) { + // finished + return false; + } + // in flight + return true; + }; + + int finished = 42; + int inflight = 43; + uint32_t size = 4; + uint32_t half = size / 2; + RingBuffer* ring_buff; +}; + +TEST_F(RingBufferTest, zero_size_ring_buffer) { + ASSERT_THROW(RingBuffer(0, in_flight), InternalError); +} + +TEST_F(RingBufferTest, in_flight) { ASSERT_EQ(ring_buff->InFlight(), 0); } + +TEST_F(RingBufferTest, next) { + // get pointer to first item + int* ptr = ring_buff->Next(); + ASSERT_NE(ptr, nullptr); + + // mark it in flight and check + *ptr = inflight; + ASSERT_EQ(ring_buff->InFlight(), 1); + + // mark it finished and check + *ptr = finished; + ASSERT_EQ(ring_buff->InFlight(), 0); +} + +TEST_F(RingBufferTest, full) { + // fill the ring buffer + for (int i = 0; i < size; ++i) { + int* ptr = ring_buff->Next(); + ASSERT_NE(ptr, nullptr); + + // mark in flight and check + *ptr = inflight; + ASSERT_EQ(ring_buff->InFlight(), i + 1); + } + + // check that the ring buffer is full + ASSERT_EQ(ring_buff->Next(), nullptr); + ASSERT_EQ(ring_buff->InFlight(), size); +} + +TEST_F(RingBufferTest, wrap) { + // fill the ring buffer, but mark each finished + bool first = true; + int* firstptr = nullptr; + for (int i = 0; i < size; ++i) { + int* ptr = ring_buff->Next(); + ASSERT_NE(ptr, nullptr); + + // save first ptr for later comparison + if (first) { + firstptr = ptr; + first = false; + } + + // mark finished and check + *ptr = finished; + ASSERT_EQ(ring_buff->InFlight(), 0); + } + + // reuse the first ring buffer entry + int* ptr = ring_buff->Next(); + ASSERT_EQ(ptr, firstptr); + + // mark it in 
flight and check + *ptr = inflight; + ASSERT_EQ(ring_buff->InFlight(), 1); + + // mark it finished and check + *ptr = finished; + ASSERT_EQ(ring_buff->InFlight(), 0); +} + +TEST_F(RingBufferTest, wrap_corner) { + for (int i = 0; i < size; ++i) { + int* ptr = ring_buff->Next(); + *ptr = finished; + } + + // reuse the first ring buffer entry + int* ptr = ring_buff->Next(); + ASSERT_NE(ptr, nullptr); + + // user must mark the item "inflight" before checking in flight count + // here the "finished" status is inherited from the reused ring buffer entry + // thus the in flight count is zero rather than the one the user might expect + ASSERT_EQ(ring_buff->InFlight(), 0); + + // marking the item "inflight" after checking the in flight count + // will not change the outcome; the ring buffer considers the item "finished" + *ptr = inflight; + ASSERT_EQ(ring_buff->InFlight(), 0); +} + +TEST_F(RingBufferTest, half_in_flight) { + // these will complete + for (int i = 0; i < half; ++i) { + int* ptr = ring_buff->Next(); + ASSERT_NE(ptr, nullptr); + *ptr = finished; + ASSERT_EQ(ring_buff->InFlight(), 0); + } + + // these will not complete + for (int i = 0; i < half; ++i) { + int* ptr = ring_buff->Next(); + ASSERT_NE(ptr, nullptr); + *ptr = inflight; + ASSERT_EQ(ring_buff->InFlight(), i + 1); + } + + // check half in flight + ASSERT_EQ(ring_buff->InFlight(), half); + + // get pointer to next item + int* ptr = ring_buff->Next(); + ASSERT_NE(ptr, nullptr); + + // mark it inflight and check + *ptr = inflight; + ASSERT_EQ(ring_buff->InFlight(), 3); + + // mark it finished and check also blocked + *ptr = finished; + ASSERT_EQ(ring_buff->InFlight(), 3); +} + +TEST_F(RingBufferTest, half_in_flight_blocked) { + // these will not complete + for (int i = 0; i < half; ++i) { + int* ptr = ring_buff->Next(); + ASSERT_NE(ptr, nullptr); + *ptr = inflight; + ASSERT_EQ(ring_buff->InFlight(), i + 1); + } + + // these would complete, but they are blocked + for (int i = half; i < size; ++i) { + int* ptr 
= ring_buff->Next(); + ASSERT_NE(ptr, nullptr); + *ptr = finished; + ASSERT_EQ(ring_buff->InFlight(), i + 1); + } + + // check that the ring buffer is full + ASSERT_EQ(ring_buff->Next(), nullptr); + ASSERT_EQ(ring_buff->InFlight(), size); +}