diff --git a/Pcap++/header/PcapLiveDevice.h b/Pcap++/header/PcapLiveDevice.h index 5c26aa4baf..b68290b8b0 100644 --- a/Pcap++/header/PcapLiveDevice.h +++ b/Pcap++/header/PcapLiveDevice.h @@ -103,8 +103,6 @@ namespace pcpp // Should be set to true by the Callee for the Caller std::atomic m_CaptureThreadStarted; - OnPacketArrivesStopBlocking m_cbOnPacketArrivesBlockingMode; - void* m_cbOnPacketArrivesBlockingModeUserCookie; LinkLayerType m_LinkType; bool m_UsePoll; @@ -120,8 +118,6 @@ namespace pcpp void setDeviceMacAddress(); void setDefaultGateway(); - static void onPacketArrivesBlockingMode(uint8_t* user, const struct pcap_pkthdr* pkthdr, const uint8_t* packet); - public: /// The type of the live device enum LiveDeviceType diff --git a/Pcap++/src/PcapLiveDevice.cpp b/Pcap++/src/PcapLiveDevice.cpp index 89f437cf4e..4929f3151b 100644 --- a/Pcap++/src/PcapLiveDevice.cpp +++ b/Pcap++/src/PcapLiveDevice.cpp @@ -286,6 +286,14 @@ namespace pcpp RawPacketVector* capturedPackets = nullptr; }; + struct CaptureContextWithCancellation + { + PcapLiveDevice* device = nullptr; + OnPacketArrivesStopBlocking callback; + void* userCookie = nullptr; + bool requestStop = false; + }; + // A noop function to be used when no callback is set void onPacketArrivesNoop(uint8_t* user, const pcap_pkthdr* pkthdr, const uint8_t* packet) {} @@ -310,6 +318,46 @@ namespace pcpp context->callback(&rawPacket, context->device, context->userCookie); } + // @brief Wraps the raw packet data into a RawPacket instance and calls the user callback with stop indication + void onPacketArrivesCallbackWithCancellation(uint8_t* user, const pcap_pkthdr* pkthdr, const uint8_t* packet) + { + auto* context = reinterpret_cast(user); + if (context == nullptr || context->device == nullptr || context->callback == nullptr) + { + PCPP_LOG_ERROR("Unable to extract PcapLiveDevice instance or callback"); + return; + } + + if (context->requestStop) + { + // If requestStop is true, there is no need to process the packet + PCPP_LOG_DEBUG("Capture request stop is set, skipping packet processing"); + return; + } + + RawPacket rawPacket(packet, pkthdr->caplen, pkthdr->ts, false, context->device->getLinkType()); + + try + { + if (context->callback(&rawPacket, context->device, context->userCookie)) + { + // If the callback returns true, it means that the user wants to stop the capture + PCPP_LOG_DEBUG("Capture callback requested to stop capturing"); + context->requestStop = true; + } + } + catch (const std::exception& ex) + { + PCPP_LOG_ERROR("Exception occurred while invoking packet arrival callback: " << ex.what()); + context->requestStop = true; // Stop capture on exception + } + catch (...) + { + PCPP_LOG_ERROR("Unknown exception occurred while invoking packet arrival callback"); + context->requestStop = true; // Stop capture on unknown exception + } + } + /// @brief Wraps the raw packet data into a RawPacket instance and adds it to the captured packets vector /// @param user A pointer to an AccumulatorCaptureContext instance /// @param pkthdr A pointer to the pcap_pkthdr struct @@ -429,8 +477,6 @@ namespace pcpp m_CaptureThreadStarted = false; m_StopThread = false; m_CaptureThread = {}; - m_cbOnPacketArrivesBlockingMode = nullptr; - m_cbOnPacketArrivesBlockingModeUserCookie = nullptr; if (calculateMacAddress) { setDeviceMacAddress(); @@ -438,24 +484,6 @@ namespace pcpp } } - void PcapLiveDevice::onPacketArrivesBlockingMode(uint8_t* user, const struct pcap_pkthdr* pkthdr, - const uint8_t* packet) - { - PcapLiveDevice* pThis = reinterpret_cast(user); - if (pThis == nullptr) - { - PCPP_LOG_ERROR("Unable to extract PcapLiveDevice instance"); - return; - } - - RawPacket rawPacket(packet, pkthdr->caplen, pkthdr->ts, false, pThis->getLinkType()); - - if (pThis->m_cbOnPacketArrivesBlockingMode != nullptr) - if (pThis->m_cbOnPacketArrivesBlockingMode(&rawPacket, pThis, - pThis->m_cbOnPacketArrivesBlockingModeUserCookie)) - pThis->m_StopThread = true; - } - internal::PcapHandle PcapLiveDevice::doOpen(const DeviceConfiguration& config) { char errbuf[PCAP_ERRBUF_SIZE] = { '\0' }; @@ -813,14 +841,15 @@ namespace pcpp return 0; } - m_cbOnPacketArrivesBlockingMode = std::move(onPacketArrives); - m_cbOnPacketArrivesBlockingModeUserCookie = userCookie; - m_CaptureThreadStarted = true; m_StopThread = false; - const int64_t timeoutMs = timeout * 1000; // timeout unit is seconds, let's change it to milliseconds + // A valid timeout is only generated when timeout is positive. + // This means that the timeout timepoint should be after the start time. + const bool hasTimeout = timeout > 0; auto startTime = std::chrono::steady_clock::now(); + // Calculate the timeout timepoint, cast the double timeout (in seconds) to milliseconds for greater precision + auto timeoutTime = startTime + std::chrono::milliseconds(static_cast(timeout * 1000)); auto currentTime = startTime; #if !defined(_WIN32) @@ -832,30 +861,40 @@ namespace pcpp bool shouldReturnError = false; - if (timeoutMs <= 0) + CaptureContextWithCancellation context; + context.device = this; + context.callback = std::move(onPacketArrives); + context.userCookie = userCookie; + context.requestStop = false; + + // No timeout specified, run until stopped + if (!hasTimeout) { while (!m_StopThread) { - if (pcap_dispatch(m_PcapDescriptor.get(), -1, onPacketArrivesBlockingMode, - reinterpret_cast(this)) == -1) + if (pcap_dispatch(m_PcapDescriptor.get(), -1, onPacketArrivesCallbackWithCancellation, + reinterpret_cast(&context)) == -1) { PCPP_LOG_ERROR("pcap_dispatch returned an error: " << m_PcapDescriptor.getLastError()); shouldReturnError = true; m_StopThread = true; } + else if (context.requestStop) + { + // If the callback requested to stop the capture, we break the loop + m_StopThread = true; + } } } else { - auto const timeoutTimepoint = startTime + std::chrono::milliseconds(timeoutMs); - - while (!m_StopThread && currentTime < timeoutTimepoint) + while (!m_StopThread && currentTime < timeoutTime) { if (m_UsePoll) { #if !defined(_WIN32) int64_t pollTimeoutMs = - std::chrono::duration_cast(timeoutTimepoint - currentTime).count(); + std::chrono::duration_cast(timeoutTime - currentTime).count(); // poll will be in blocking mode if negative value pollTimeoutMs = std::max(pollTimeoutMs, static_cast(0)); @@ -864,13 +903,18 @@ namespace pcpp if (ready > 0) { - if (pcap_dispatch(m_PcapDescriptor.get(), -1, onPacketArrivesBlockingMode, - reinterpret_cast(this)) == -1) + if (pcap_dispatch(m_PcapDescriptor.get(), -1, onPacketArrivesCallbackWithCancellation, + reinterpret_cast(&context)) == -1) { PCPP_LOG_ERROR("pcap_dispatch returned an error: " << m_PcapDescriptor.getLastError()); shouldReturnError = true; m_StopThread = true; } + else if (context.requestStop) + { + // If the callback requested to stop the capture, we break the loop + m_StopThread = true; + } } else if (ready < 0) { @@ -886,13 +930,18 @@ namespace pcpp } else { - if (pcap_dispatch(m_PcapDescriptor.get(), -1, onPacketArrivesBlockingMode, - reinterpret_cast(this)) == -1) + if (pcap_dispatch(m_PcapDescriptor.get(), -1, onPacketArrivesCallbackWithCancellation, + reinterpret_cast(&context)) == -1) { PCPP_LOG_ERROR("pcap_dispatch returned an error: " << m_PcapDescriptor.getLastError()); shouldReturnError = true; m_StopThread = true; } + else if (context.requestStop) + { + // If the callback requested to stop the capture, we break the loop + m_StopThread = true; + } } currentTime = std::chrono::steady_clock::now(); } @@ -900,26 +949,27 @@ namespace pcpp m_CaptureThreadStarted = false; m_StopThread = false; - m_cbOnPacketArrivesBlockingMode = nullptr; - m_cbOnPacketArrivesBlockingModeUserCookie = nullptr; if (shouldReturnError) { return 0; } - if (std::chrono::duration_cast(currentTime - startTime).count() >= timeoutMs) + // Check the time only if a valid timeout was specified. Otherwise it would always be true. + if (hasTimeout && currentTime >= timeoutTime) { - return -1; + return -1; // If we are past the timeout time, return -1 } return 1; } void PcapLiveDevice::stopCapture() { - // in blocking mode stop capture isn't relevant - if (m_cbOnPacketArrivesBlockingMode != nullptr) + // In blocking mode, there is no capture thread, so we don't need to stop it + if (!m_CaptureThread.joinable()) + { return; + } if (m_CaptureThread.get_id() != std::thread::id{} && m_CaptureThread.get_id() == std::this_thread::get_id()) {