diff --git a/HeterogeneousCore/SonicCore/BuildFile.xml b/HeterogeneousCore/SonicCore/BuildFile.xml index b0d5e2a08b98f..5208c91638f37 100644 --- a/HeterogeneousCore/SonicCore/BuildFile.xml +++ b/HeterogeneousCore/SonicCore/BuildFile.xml @@ -2,6 +2,7 @@ + diff --git a/HeterogeneousCore/SonicCore/interface/RetryActionBase.h b/HeterogeneousCore/SonicCore/interface/RetryActionBase.h new file mode 100644 index 0000000000000..e3fc0bbb8af9a --- /dev/null +++ b/HeterogeneousCore/SonicCore/interface/RetryActionBase.h @@ -0,0 +1,35 @@ +#ifndef HeterogeneousCore_SonicCore_RetryActionBase +#define HeterogeneousCore_SonicCore_RetryActionBase + +#include "FWCore/PluginManager/interface/PluginFactory.h" +#include "FWCore/ParameterSet/interface/ParameterSet.h" +#include "HeterogeneousCore/SonicCore/interface/SonicClientBase.h" +#include +#include + +// Base class for retry actions +class RetryActionBase { +public: + RetryActionBase(const edm::ParameterSet& conf, SonicClientBase* client); + virtual ~RetryActionBase() = default; + + bool shouldRetry() const { return shouldRetry_; } // Getter for shouldRetry_ + + virtual void retry() = 0; // Pure virtual function for execution logic + virtual void start() = 0; // Pure virtual function for execution logic for initialization + +protected: + void eval(); // interface for calling evaluate in client + +protected: + SonicClientBase* client_; + bool shouldRetry_; // Flag to track if further retries should happen +}; + +// Define the factory for creating retry actions +using RetryActionFactory = + edmplugin::PluginFactory; + +#endif + +#define DEFINE_RETRY_ACTION(type) DEFINE_EDM_PLUGIN(RetryActionFactory, type, #type); diff --git a/HeterogeneousCore/SonicCore/interface/SonicClientBase.h b/HeterogeneousCore/SonicCore/interface/SonicClientBase.h index 47caaae8b2052..45a089701ed12 100644 --- a/HeterogeneousCore/SonicCore/interface/SonicClientBase.h +++ b/HeterogeneousCore/SonicCore/interface/SonicClientBase.h @@ -9,12 +9,15 @@ #include "HeterogeneousCore/SonicCore/interface/SonicDispatcherPseudoAsync.h" #include +#include #include #include #include enum class SonicMode { Sync = 1, Async = 2, PseudoAsync = 3 }; +class RetryActionBase; + class SonicClientBase { public: //constructor @@ -54,14 +57,23 @@ class SonicClientBase { SonicMode mode_; bool verbose_; std::unique_ptr dispatcher_; - unsigned allowedTries_, tries_; + unsigned totalTries_; std::optional holder_; + // Use a unique_ptr with a custom deleter to avoid incomplete type issues + struct RetryDeleter { + void operator()(RetryActionBase* ptr) const; + }; + + using RetryActionPtr = std::unique_ptr; + std::vector retryActions_; + //for logging/debugging std::string debugName_, clientName_, fullDebugName_; friend class SonicDispatcher; friend class SonicDispatcherPseudoAsync; + friend class RetryActionBase; }; #endif diff --git a/HeterogeneousCore/SonicCore/plugins/BuildFile.xml b/HeterogeneousCore/SonicCore/plugins/BuildFile.xml new file mode 100644 index 0000000000000..eaff0919e46bc --- /dev/null +++ b/HeterogeneousCore/SonicCore/plugins/BuildFile.xml @@ -0,0 +1,6 @@ + + + + + + diff --git a/HeterogeneousCore/SonicCore/plugins/RetrySameServerAction.cc b/HeterogeneousCore/SonicCore/plugins/RetrySameServerAction.cc new file mode 100644 index 0000000000000..9877013b93d5b --- /dev/null +++ b/HeterogeneousCore/SonicCore/plugins/RetrySameServerAction.cc @@ -0,0 +1,30 @@ +#include "HeterogeneousCore/SonicCore/interface/RetryActionBase.h" +#include "HeterogeneousCore/SonicCore/interface/SonicClientBase.h" + +class RetrySameServerAction : public RetryActionBase { +public: + RetrySameServerAction(const edm::ParameterSet& pset, SonicClientBase* client) + : RetryActionBase(pset, client), allowedTries_(pset.getUntrackedParameter("allowedTries", 0)) {} + + void start() override { tries_ = 0; }; + +protected: + void retry() override; + +private: + unsigned allowedTries_, tries_; +}; + +void RetrySameServerAction::retry() { + ++tries_; + //if max retries has not been exceeded, call evaluate again + if (tries_ < allowedTries_) { + eval(); + return; + } else { + shouldRetry_ = false; // Flip flag when max retries are reached + edm::LogInfo("RetrySameServerAction") << "Max retry attempts reached. No further retries."; + } +} + +DEFINE_RETRY_ACTION(RetrySameServerAction) diff --git a/HeterogeneousCore/SonicCore/src/RetryActionBase.cc b/HeterogeneousCore/SonicCore/src/RetryActionBase.cc new file mode 100644 index 0000000000000..41b9a6186da2b --- /dev/null +++ b/HeterogeneousCore/SonicCore/src/RetryActionBase.cc @@ -0,0 +1,15 @@ +#include "HeterogeneousCore/SonicCore/interface/RetryActionBase.h" + +// Constructor implementation +RetryActionBase::RetryActionBase(const edm::ParameterSet& conf, SonicClientBase* client) + : client_(client), shouldRetry_(true) {} + +void RetryActionBase::eval() { + if (client_) { + client_->evaluate(); + } else { + edm::LogError("RetryActionBase") << "Client pointer is null, cannot evaluate."; + } +} + +EDM_REGISTER_PLUGINFACTORY(RetryActionFactory, "RetryActionFactory"); diff --git a/HeterogeneousCore/SonicCore/src/SonicClientBase.cc b/HeterogeneousCore/SonicCore/src/SonicClientBase.cc index 745c51f17aaf3..739e7e6fe7913 100644 --- a/HeterogeneousCore/SonicCore/src/SonicClientBase.cc +++ b/HeterogeneousCore/SonicCore/src/SonicClientBase.cc @@ -1,18 +1,34 @@ #include "HeterogeneousCore/SonicCore/interface/SonicClientBase.h" +#include "HeterogeneousCore/SonicCore/interface/RetryActionBase.h" #include "FWCore/Utilities/interface/Exception.h" #include "FWCore/ParameterSet/interface/allowedValues.h" +// Custom deleter implementation +void SonicClientBase::RetryDeleter::operator()(RetryActionBase* ptr) const { delete ptr; } + SonicClientBase::SonicClientBase(const edm::ParameterSet& params, const std::string& debugName, const std::string& clientName) - : allowedTries_(params.getUntrackedParameter("allowedTries", 0)), - debugName_(debugName), - clientName_(clientName), - fullDebugName_(debugName_) { + : debugName_(debugName), clientName_(clientName), fullDebugName_(debugName_) { if (!clientName_.empty()) fullDebugName_ += ":" + clientName_; + const auto& retryPSetList = params.getParameter>("Retry"); std::string modeName(params.getParameter("mode")); + + for (const auto& retryPSet : retryPSetList) { + const std::string& actionType = retryPSet.getParameter("retryType"); + + auto retryAction = RetryActionFactory::get()->create(actionType, retryPSet, this); + if (retryAction) { + //Convert to RetryActionPtr Type from raw pointer of retryAction + retryActions_.emplace_back(RetryActionPtr(retryAction.release())); + } else { + throw cms::Exception("Configuration") + << "Unknown Retry type " << actionType << " for SonicClient: " << fullDebugName_; + } + } + if (modeName == "Sync") setMode(SonicMode::Sync); else if (modeName == "Async") @@ -40,24 +56,30 @@ void SonicClientBase::start(edm::WaitingTaskWithArenaHolder holder) { holder_ = std::move(holder); } -void SonicClientBase::start() { tries_ = 0; } +void SonicClientBase::start() { + totalTries_ = 0; + // initialize all actions + for (auto& action : retryActions_) { + action->start(); + } +} void SonicClientBase::finish(bool success, std::exception_ptr eptr) { //retries are only allowed if no exception was raised if (!success and !eptr) { - ++tries_; - //if max retries has not been exceeded, call evaluate again - if (tries_ < allowedTries_) { - evaluate(); - //avoid calling doneWaiting() twice - return; - } - //prepare an exception if exceeded - else { - edm::Exception ex(edm::errors::ExternalFailure); - ex << "SonicCallFailed: call failed after max " << tries_ << " tries"; - eptr = make_exception_ptr(ex); + ++totalTries_; + for (const auto& action : retryActions_) { + if (action->shouldRetry()) { + action->retry(); // Call retry only if shouldRetry_ is true + return; + } } + //prepare an exception if no more retries left + edm::LogInfo("SonicClientBase") << "SonicCallFailed: call failed, no retry actions available after " << totalTries_ + << " tries."; + edm::Exception ex(edm::errors::ExternalFailure); + ex << "SonicCallFailed: call failed, no retry actions available after " << totalTries_ << " tries."; + eptr = make_exception_ptr(ex); } if (holder_) { holder_->doneWaiting(eptr); @@ -74,7 +96,20 @@ void SonicClientBase::fillBasePSetDescription(edm::ParameterSetDescription& desc //restrict allowed values desc.ifValue(edm::ParameterDescription("mode", "PseudoAsync", true), edm::allowedValues("Sync", "Async", "PseudoAsync")); - if (allowRetry) - desc.addUntracked("allowedTries", 0); + if (allowRetry) { + // Defines the structure of each entry in the VPSet + edm::ParameterSetDescription retryDesc; + retryDesc.add("retryType", "RetrySameServerAction"); + retryDesc.addUntracked("allowedTries", 0); + + // Define a default retry action + edm::ParameterSet defaultRetry; + defaultRetry.addParameter("retryType", "RetrySameServerAction"); + defaultRetry.addUntrackedParameter("allowedTries", 0); + + // Add the VPSet with the default retry action + desc.addVPSet("Retry", retryDesc, {defaultRetry}); + } + desc.add("sonicClientBase", desc); desc.addUntracked("verbose", false); } diff --git a/HeterogeneousCore/SonicCore/test/DummyClient.h b/HeterogeneousCore/SonicCore/test/DummyClient.h index ccef888ad9f7d..6504843926c0a 100644 --- a/HeterogeneousCore/SonicCore/test/DummyClient.h +++ b/HeterogeneousCore/SonicCore/test/DummyClient.h @@ -36,7 +36,7 @@ class DummyClient : public SonicClient { this->output_ = this->input_ * factor_; //simulate a failure - if (this->tries_ < fails_) + if (this->totalTries_ < fails_) this->finish(false); else this->finish(true); diff --git a/HeterogeneousCore/SonicCore/test/sonicTestAna_cfg.py b/HeterogeneousCore/SonicCore/test/sonicTestAna_cfg.py index 11c23c6cdfcc9..35cf42fa2b5ae 100644 --- a/HeterogeneousCore/SonicCore/test/sonicTestAna_cfg.py +++ b/HeterogeneousCore/SonicCore/test/sonicTestAna_cfg.py @@ -16,16 +16,27 @@ mode = cms.string("Sync"), factor = cms.int32(-1), wait = cms.int32(10), - allowedTries = cms.untracked.uint32(0), fails = cms.uint32(0), + Retry = cms.VPSet( + cms.PSet( + retryType = cms.string('RetrySameServerAction'), + allowedTries = cms.untracked.uint32(0), + ) + ) ), ) process.dummySyncAnaRetry = process.dummySyncAna.clone( Client = dict( wait = 2, - allowedTries = 2, fails = 1, + Retry = cms.VPSet( + cms.PSet( + retryType = cms.string('RetrySameServerAction'), + allowedTries = cms.untracked.uint32(2), + ) + ) + ) ) diff --git a/HeterogeneousCore/SonicCore/test/sonicTest_cfg.py b/HeterogeneousCore/SonicCore/test/sonicTest_cfg.py index 614297d86e3bb..bcbe820030440 100644 --- a/HeterogeneousCore/SonicCore/test/sonicTest_cfg.py +++ b/HeterogeneousCore/SonicCore/test/sonicTest_cfg.py @@ -17,15 +17,19 @@ process.options.numberOfThreads = 2 process.options.numberOfStreams = 0 - process.dummySync = _moduleClass(_moduleName, input = cms.int32(1), Client = cms.PSet( mode = cms.string("Sync"), factor = cms.int32(-1), wait = cms.int32(10), - allowedTries = cms.untracked.uint32(0), fails = cms.uint32(0), + Retry = cms.VPSet( + cms.PSet( + retryType = cms.string('RetrySameServerAction'), + allowedTries = cms.untracked.uint32(0) + ) + ) ), ) @@ -35,8 +39,14 @@ mode = cms.string("PseudoAsync"), factor = cms.int32(2), wait = cms.int32(10), - allowedTries = cms.untracked.uint32(0), fails = cms.uint32(0), + Retry = cms.VPSet( + cms.PSet( + retryType = cms.string('RetrySameServerAction'), + allowedTries = cms.untracked.uint32(0) + ) + ) + ), ) @@ -46,32 +56,53 @@ mode = cms.string("Async"), factor = cms.int32(5), wait = cms.int32(10), - allowedTries = cms.untracked.uint32(0), fails = cms.uint32(0), + Retry = cms.VPSet( + cms.PSet( + retryType = cms.string('RetrySameServerAction'), + allowedTries = cms.untracked.uint32(0) + ) + ) ), ) process.dummySyncRetry = process.dummySync.clone( Client = dict( wait = 2, - allowedTries = 2, fails = 1, + Retry = cms.VPSet( + cms.PSet( + retryType = cms.string('RetrySameServerAction'), + allowedTries = cms.untracked.uint32(2) + ) + ) + ) ) process.dummyPseudoAsyncRetry = process.dummyPseudoAsync.clone( Client = dict( wait = 2, - allowedTries = 2, fails = 1, + Retry = cms.VPSet( + cms.PSet( + retryType = cms.string('RetrySameServerAction'), + allowedTries = cms.untracked.uint32(2) + ) + ) ) ) process.dummyAsyncRetry = process.dummyAsync.clone( Client = dict( wait = 2, - allowedTries = 2, fails = 1, + Retry = cms.VPSet( + cms.PSet( + allowedTries = cms.untracked.uint32(2), + retryType = cms.string('RetrySameServerAction') + ) + ) ) ) diff --git a/HeterogeneousCore/SonicTriton/BuildFile.xml b/HeterogeneousCore/SonicTriton/BuildFile.xml index b93d51e711e87..4af38d69d89e9 100644 --- a/HeterogeneousCore/SonicTriton/BuildFile.xml +++ b/HeterogeneousCore/SonicTriton/BuildFile.xml @@ -10,6 +10,7 @@ + - + diff --git a/HeterogeneousCore/SonicTriton/interface/RetryActionDiffServer.h b/HeterogeneousCore/SonicTriton/interface/RetryActionDiffServer.h new file mode 100644 index 0000000000000..e992e9631f92c --- /dev/null +++ b/HeterogeneousCore/SonicTriton/interface/RetryActionDiffServer.h @@ -0,0 +1,29 @@ +#ifndef HeterogeneousCore_SonicTriton_RetryActionDiffServer_h +#define HeterogeneousCore_SonicTriton_RetryActionDiffServer_h + +#include "HeterogeneousCore/SonicCore/interface/RetryActionBase.h" + +/** + * @class RetryActionDiffServer + * @brief A concrete implementation of RetryActionBase that attempts to retry an inference + * request on a different Triton server. + * + * This class provides a fallback mechanism. If an initial inference request fails + * (e.g., due to server unavailability or a model-specific error), this action will be + * triggered. It queries the central TritonService to select an alternative server (e.g., + * the fallback server when available) and instructs the TritonClient to reconnect to + * that server for the retry attempt. This action is designed for one-time use per + * inference call; after the retry attempt, it disables itself until the next `start()` + * call. + */ + +class RetryActionDiffServer : public RetryActionBase { +public: + RetryActionDiffServer(const edm::ParameterSet& conf, SonicClientBase* client); + ~RetryActionDiffServer() override = default; + + void retry() override; + void start() override; +}; + +#endif diff --git a/HeterogeneousCore/SonicTriton/interface/TritonClient.h b/HeterogeneousCore/SonicTriton/interface/TritonClient.h index df8f9b559427c..2dd6205442fe1 100644 --- a/HeterogeneousCore/SonicTriton/interface/TritonClient.h +++ b/HeterogeneousCore/SonicTriton/interface/TritonClient.h @@ -1,102 +1,110 @@ -#ifndef HeterogeneousCore_SonicTriton_TritonClient -#define HeterogeneousCore_SonicTriton_TritonClient - -#include "FWCore/ParameterSet/interface/ParameterSet.h" -#include "FWCore/ParameterSet/interface/ParameterSetDescription.h" -#include "FWCore/ServiceRegistry/interface/ServiceToken.h" -#include "HeterogeneousCore/SonicCore/interface/SonicClient.h" -#include "HeterogeneousCore/SonicTriton/interface/TritonData.h" -#include "HeterogeneousCore/SonicTriton/interface/TritonService.h" - -#include -#include -#include -#include -#include - -#include "grpc_client.h" -#include "grpc_service.pb.h" - -enum class TritonBatchMode { Rectangular = 1, Ragged = 2 }; - -class TritonClient : public SonicClient { -public: - struct ServerSideStats { - uint64_t inference_count_; - uint64_t execution_count_; - uint64_t success_count_; - uint64_t cumm_time_ns_; - uint64_t queue_time_ns_; - uint64_t compute_input_time_ns_; - uint64_t compute_infer_time_ns_; - uint64_t compute_output_time_ns_; - }; - - //constructor - TritonClient(const edm::ParameterSet& params, const std::string& debugName); - - //destructor - ~TritonClient() override; - - //accessors - unsigned batchSize() const; - TritonBatchMode batchMode() const { return batchMode_; } - bool verbose() const { return verbose_; } - bool useSharedMemory() const { return useSharedMemory_; } - void setUseSharedMemory(bool useShm) { useSharedMemory_ = useShm; } - bool setBatchSize(unsigned bsize); - void setBatchMode(TritonBatchMode batchMode); - void resetBatchMode(); - void reset() override; - TritonServerType serverType() const { return serverType_; } - bool isLocal() const { return isLocal_; } - - //for fillDescriptions - static void fillPSetDescription(edm::ParameterSetDescription& iDesc); - -protected: - //helpers - bool noOuterDim() const { return noOuterDim_; } - unsigned outerDim() const { return outerDim_; } - unsigned nEntries() const; - void getResults(const std::vector>& results); - void evaluate() override; - template - bool handle_exception(F&& call); - - void reportServerSideStats(const ServerSideStats& stats) const; - ServerSideStats summarizeServerStats(const inference::ModelStatistics& start_status, - const inference::ModelStatistics& end_status) const; - - inference::ModelStatistics getServerSideStatus() const; - - //members - unsigned maxOuterDim_; - unsigned outerDim_; - bool noOuterDim_; - unsigned nEntries_; - TritonBatchMode batchMode_; - bool manualBatchMode_; - bool verbose_; - bool useSharedMemory_; - TritonServerType serverType_; - bool isLocal_; - grpc_compression_algorithm compressionAlgo_; - triton::client::Headers headers_; - - std::unique_ptr client_; - //stores timeout, model name and version - std::vector options_; - edm::ServiceToken token_; - -private: - friend TritonInputData; - friend TritonOutputData; - - //private accessors only used by data - auto client() { return client_.get(); } - void addEntry(unsigned entry); - void resizeEntries(unsigned entry); -}; - -#endif +#ifndef HeterogeneousCore_SonicTriton_TritonClient +#define HeterogeneousCore_SonicTriton_TritonClient + +#include "FWCore/ParameterSet/interface/ParameterSet.h" +#include "FWCore/ParameterSet/interface/ParameterSetDescription.h" +#include "FWCore/ServiceRegistry/interface/ServiceToken.h" +#include "HeterogeneousCore/SonicCore/interface/SonicClient.h" +#include "HeterogeneousCore/SonicTriton/interface/TritonData.h" +#include "HeterogeneousCore/SonicTriton/interface/TritonService.h" + +#include +#include +#include +#include +#include + +#include "grpc_client.h" +#include "grpc_service.pb.h" + +enum class TritonBatchMode { Rectangular = 1, Ragged = 2 }; + +class TritonClient : public SonicClient { +public: + struct ServerSideStats { + uint64_t inference_count_; + uint64_t execution_count_; + uint64_t success_count_; + uint64_t cumm_time_ns_; + uint64_t queue_time_ns_; + uint64_t compute_input_time_ns_; + uint64_t compute_infer_time_ns_; + uint64_t compute_output_time_ns_; + }; + + //constructor + TritonClient(const edm::ParameterSet& params, const std::string& debugName); + + //destructor + ~TritonClient() override; + + //accessors + unsigned batchSize() const; + TritonBatchMode batchMode() const { return batchMode_; } + bool verbose() const { return verbose_; } + bool useSharedMemory() const { return useSharedMemory_; } + void setUseSharedMemory(bool useShm) { useSharedMemory_ = useShm; } + bool setBatchSize(unsigned bsize); + void setBatchMode(TritonBatchMode batchMode); + void resetBatchMode(); + void reset() override; + TritonServerType serverType() const { return serverType_; } + bool isLocal() const { return isLocal_; } + std::string modelName() const { return options_[0].model_name_; } + std::string serverName() const { return serverName_; } + virtual void connectToServer(const std::string& url); + virtual void updateServer(const std::string& serverName); + + //for fillDescriptions + static void fillPSetDescription(edm::ParameterSetDescription& iDesc); + +protected: + // Protected default constructor for unit testing (no framework services) + TritonClient(); + + //helpers + bool noOuterDim() const { return noOuterDim_; } + unsigned outerDim() const { return outerDim_; } + unsigned nEntries() const; + void getResults(const std::vector>& results); + virtual void evaluate() override; + template + bool handle_exception(F&& call); + + void reportServerSideStats(const ServerSideStats& stats) const; + ServerSideStats summarizeServerStats(const inference::ModelStatistics& start_status, + const inference::ModelStatistics& end_status) const; + + inference::ModelStatistics getServerSideStatus() const; + + //members + unsigned maxOuterDim_; + unsigned outerDim_; + bool noOuterDim_; + unsigned nEntries_; + TritonBatchMode batchMode_; + bool manualBatchMode_; + bool verbose_; + bool useSharedMemory_; + TritonServerType serverType_; + bool isLocal_; + std::string serverName_; + grpc_compression_algorithm compressionAlgo_; + triton::client::Headers headers_; + + std::unique_ptr client_; + //stores timeout, model name and version + std::vector options_; + edm::ServiceToken token_; + +private: + friend TritonInputData; + friend TritonOutputData; + + //private accessors only used by data + auto client() { return client_.get(); } + void addEntry(unsigned entry); + void resizeEntries(unsigned entry); +}; + +#endif diff --git a/HeterogeneousCore/SonicTriton/interface/TritonService.h b/HeterogeneousCore/SonicTriton/interface/TritonService.h index 8f36f73566e06..44b3b98a18932 100644 --- a/HeterogeneousCore/SonicTriton/interface/TritonService.h +++ b/HeterogeneousCore/SonicTriton/interface/TritonService.h @@ -3,6 +3,7 @@ #include "FWCore/ParameterSet/interface/ParameterSet.h" #include "FWCore/Utilities/interface/GlobalIdentifier.h" +#include "oneapi/tbb/concurrent_hash_map.h" #include #include @@ -11,6 +12,7 @@ #include #include #include +#include #include "grpc_client.h" @@ -90,6 +92,16 @@ class TritonService { static const std::string fallbackAddress; static const std::string siteconfName; }; + //Dynamic quantities of servers + struct ServerHealth { + bool live{false}; + bool ready{false}; + + uint64_t inferenceCount{0}; + uint64_t failureCount{0}; + double avgQueueTimeMs{0.0}; + double avgInferTimeMs{0.0}; + }; struct Model { Model(const std::string& path_ = "") : path(path_) {} @@ -111,7 +123,23 @@ class TritonService { //accessors void addModel(const std::string& modelName, const std::string& path); - Server serverInfo(const std::string& model, const std::string& preferred = "") const; + + const std::pair& serverInfo(const std::string& model, const std::string& preferred = "") const; + + // update health stats of all servers + void updateServerHealth(const std::string& modelName = ""); + + // return the best server for retry, ignore the current server + std::optional getBestServer(const std::string& modelName, const std::string& IgnoreServer = ""); + + // helper functions to get server statistics? + // - getServerSideStatus() + // - updateServerStatus() + // - loop over servers_ get statistics + // - getBestServer(model) + // - call updateServerStatus() + // - loop over servers_ get their statistics, compute metric, return server name + const std::string& pid() const { return pid_; } void notifyCallStatus(bool status) const; @@ -139,6 +167,8 @@ class TritonService { std::unordered_map unservedModels_; //this represents a many:many:many map std::unordered_map servers_; + //server health needs concurrent-safe edits + tbb::concurrent_hash_map serversHealth_; std::unordered_map models_; std::unordered_map modules_; int numberOfThreads_; diff --git a/HeterogeneousCore/SonicTriton/python/customize.py b/HeterogeneousCore/SonicTriton/python/customize.py index b4f9943423133..63618b6c155a5 100644 --- a/HeterogeneousCore/SonicTriton/python/customize.py +++ b/HeterogeneousCore/SonicTriton/python/customize.py @@ -35,6 +35,7 @@ def getParser(): parser.add_argument("--fallbackName", default="", type=str, help="name for fallback server") parser.add_argument("--imageName", default="", type=str, help="container image name for fallback server") parser.add_argument("--tempDir", default="", type=str, help="temp directory for fallback server") + parser.add_argument("--retryAction", default="same", type=str, choices=["same","diff"], help="retry policy: same server or different server") return parser @@ -90,12 +91,18 @@ def applyOptions(process, options, applyToModules=False): return process def getClientOptions(options): + action = cms.PSet( + retryType = cms.string('RetrySameServerAction'), + allowedTries = cms.untracked.uint32(options.tries)) + if options.retryAction != 'same': + action.retryType = cms.string('RetryActionDiffServer') + return dict( compression = cms.untracked.string(options.compression), useSharedMemory = cms.untracked.bool(not options.noShm), timeout = cms.untracked.uint32(options.timeout), timeoutUnit = cms.untracked.string(options.timeoutUnit), - allowedTries = cms.untracked.uint32(options.tries), + Retry = cms.VPSet(action) ) def applyClientOptions(client, options): diff --git a/HeterogeneousCore/SonicTriton/src/RetryActionDiffServer.cc b/HeterogeneousCore/SonicTriton/src/RetryActionDiffServer.cc new file mode 100644 index 0000000000000..0abd2da6ca569 --- /dev/null +++ b/HeterogeneousCore/SonicTriton/src/RetryActionDiffServer.cc @@ -0,0 +1,45 @@ +#include "HeterogeneousCore/SonicTriton/interface/RetryActionDiffServer.h" +#include "HeterogeneousCore/SonicTriton/interface/TritonClient.h" +#include "HeterogeneousCore/SonicTriton/interface/TritonService.h" +#include "FWCore/MessageLogger/interface/MessageLogger.h" +#include "FWCore/ServiceRegistry/interface/Service.h" + +RetryActionDiffServer::RetryActionDiffServer(const edm::ParameterSet& conf, SonicClientBase* client) + : RetryActionBase(conf, client) {} + +void RetryActionDiffServer::start() { this->shouldRetry_ = true; } + +void RetryActionDiffServer::retry() { + if (!this->shouldRetry_) { + this->shouldRetry_ = false; + edm::LogInfo("RetryActionDiffServer") << "Retry not armed; skipping."; + return; + } + + try { + auto* tritonClient = static_cast(client_); + edm::LogInfo("RetryActionDiffServer") << "Attempting retry by switching to fallback server"; + // TODO: Get the server name from TritonService, use fallback for testing + edm::Service ts; + + // get best server, ignoring the current server + auto bestServerName = ts->getBestServer(tritonClient->modelName(),tritonClient->serverName()); + + if (bestServerName) { + tritonClient->updateServer(*bestServerName); + eval(); + } else { + edm::LogWarning("RetryActionDiffServer") + << "No alternative server found for model " << tritonClient->modelName(); + } + } catch (TritonException& e) { + e.convertToWarning(); + } catch (std::exception& e) { + edm::LogError("RetryActionDiffServer") << "Failed to retry with alternative server: " << e.what(); + } catch (...) { + edm::LogError("RetryActionDiffServer: UnknownFailure") << "An unknown exception was thrown"; + } + this->shouldRetry_ = false; +} + +DEFINE_RETRY_ACTION(RetryActionDiffServer); diff --git a/HeterogeneousCore/SonicTriton/src/TritonClient.cc b/HeterogeneousCore/SonicTriton/src/TritonClient.cc index ddcdff83448d0..0becdb31758f7 100644 --- a/HeterogeneousCore/SonicTriton/src/TritonClient.cc +++ b/HeterogeneousCore/SonicTriton/src/TritonClient.cc @@ -28,6 +28,19 @@ namespace tc = triton::client; namespace { + // Minimal ParameterSet to satisfy SonicClientBase requirements during unit tests + edm::ParameterSet makeMinimalSonicParamsForTest() { + edm::ParameterSet params; + params.addParameter("mode", "PseudoAsync"); + + edm::ParameterSet defaultRetry; + defaultRetry.addParameter("retryType", "RetrySameServerAction"); + defaultRetry.addUntrackedParameter("allowedTries", 0u); + std::vector retryVec{defaultRetry}; + params.addParameter>("Retry", retryVec); + + return params; + } grpc_compression_algorithm getCompressionAlgo(const std::string& name) { if (name.empty() or name.compare("none") == 0) return grpc_compression_algorithm::GRPC_COMPRESS_NONE; @@ -61,7 +74,7 @@ TritonClient::TritonClient(const edm::ParameterSet& params, const std::string& d useSharedMemory_(params.getUntrackedParameter("useSharedMemory")), compressionAlgo_(getCompressionAlgo(params.getUntrackedParameter("compression"))) { options_.emplace_back(params.getParameter("modelName")); - //get appropriate server for this model + edm::Service ts; // We save the token to be able to notify the service in case of an exception in the evaluate method. @@ -70,21 +83,8 @@ TritonClient::TritonClient(const edm::ParameterSet& params, const std::string& d // create the context. token_ = edm::ServiceRegistry::instance().presentToken(); - const auto& server = - ts->serverInfo(options_[0].model_name_, params.getUntrackedParameter("preferredServer")); - serverType_ = server.type; - edm::LogInfo("TritonDiscovery") << debugName_ << " assigned server: " << server.url; - //enforce sync mode for fallback CPU server to avoid contention - //todo: could enforce async mode otherwise (unless mode was specified by user?) - if (serverType_ == TritonServerType::LocalCPU) - setMode(SonicMode::Sync); - isLocal_ = serverType_ == TritonServerType::LocalCPU or serverType_ == TritonServerType::LocalGPU; - - //connect to the server - TRITON_THROW_IF_ERROR( - tc::InferenceServerGrpcClient::Create(&client_, server.url, false, server.useSsl, server.sslOptions), - "TritonClient(): unable to create inference context", - isLocal_); + //Connect to server + updateServer(params.getUntrackedParameter("preferredServer")); //set options options_[0].model_version_ = params.getParameter("modelVersion"); @@ -369,7 +369,7 @@ void TritonClient::getResults(const std::vector //default case for sync and pseudo async void TritonClient::evaluate() { //undo previous signal from TritonException - if (tries_ > 0) { + if (totalTries_ > 0) { // If we are retrying then the evaluate method is called outside the frameworks TBB thread pool. // So we need to setup the service token for the current thread to access the service registry. edm::ServiceRegistry::Operate op(token_); @@ -574,6 +574,32 @@ inference::ModelStatistics TritonClient::getServerSideStatus() const { return inference::ModelStatistics{}; } +void TritonClient::updateServer(const std::string& serverName) { + //get appropriate server for this model + edm::Service ts; + + const auto& serverMap = ts->serverInfo(options_[0].model_name_, serverName); + + const auto& server = serverMap.second; + + //update server name + serverName_ = serverMap.first; + + serverType_ = server.type; + edm::LogInfo("TritonDiscovery") << debugName_ << " assigned server: " << server.url; + //enforce sync mode for fallback CPU server to avoid contention + //todo: could enforce async mode otherwise (unless mode was specified by user?) + if (serverType_ == TritonServerType::LocalCPU) + setMode(SonicMode::Sync); + isLocal_ = serverType_ == TritonServerType::LocalCPU or serverType_ == TritonServerType::LocalGPU; + + //connect to the server + TRITON_THROW_IF_ERROR( + tc::InferenceServerGrpcClient::Create(&client_, server.url, false, server.useSsl, server.sslOptions), + "TritonClient(): unable to create inference context", + isLocal_); +} + //for fillDescriptions void TritonClient::fillPSetDescription(edm::ParameterSetDescription& iDesc) { edm::ParameterSetDescription descClient; @@ -591,3 +617,24 @@ void TritonClient::fillPSetDescription(edm::ParameterSetDescription& iDesc) { descClient.addUntracked>("outputs", {}); iDesc.add("Client", descClient); } + +void TritonClient::connectToServer(const std::string& url) { + // Update client state for a generic remote server + serverType_ = TritonServerType::Remote; + isLocal_ = false; + + edm::LogInfo("TritonDiscovery") << debugName_ << " connecting to server: " << url; + + // Use default SSL options + triton::client::SslOptions sslOptions; + bool useSsl = false; // Assuming no SSL for direct URL connection + + // Connect to the server + TRITON_THROW_IF_ERROR(triton::client::InferenceServerGrpcClient::Create(&client_, url, false, useSsl, sslOptions), + "TritonClient::connectToServer(): unable to create inference context", + false // isLocal is false + ); +} + +//constructor for testing +TritonClient::TritonClient() : SonicClient(makeMinimalSonicParamsForTest(), "TritonClient_test", "TritonClient") {} diff --git a/HeterogeneousCore/SonicTriton/src/TritonService.cc b/HeterogeneousCore/SonicTriton/src/TritonService.cc index d0d82bbaa9efc..fd5e0368bc80c 100644 --- a/HeterogeneousCore/SonicTriton/src/TritonService.cc +++ b/HeterogeneousCore/SonicTriton/src/TritonService.cc @@ -119,11 +119,14 @@ TritonService::TritonService(const edm::ParameterSet& pset, edm::ActivityRegistr << "TritonService: Not allowed to specify more than one server with same name (" << serverName << ")"; } - //loop over all servers: check which models they have + //loop over all servers: check which models they have, populate serverHealth std::string msg; if (verbose_) msg = "List of models for each server:\n"; for (auto& [serverName, server] : servers_) { + //populate serverHealth + serversHealth_.emplace(serverName, ServerHealth{}); + std::unique_ptr client; TRITON_THROW_IF_ERROR( tc::InferenceServerGrpcClient::Create(&client, server.url, false, server.useSsl, server.sslOptions), @@ -223,7 +226,7 @@ void TritonService::preModuleDestruction(edm::ModuleDescription const& desc) { } //second return value is only true if fallback CPU server is being used -TritonService::Server TritonService::serverInfo(const std::string& model, const std::string& preferred) const { +const std::pair& TritonService::serverInfo(const std::string& model, const std::string& preferred) const { auto mit = models_.find(model); if (mit == models_.end()) throw cms::Exception("MissingModel") << "TritonService: There are no servers that provide model " << model; @@ -241,8 +244,116 @@ TritonService::Server TritonService::serverInfo(const std::string& model, const const auto& serverName(msit == modelServers.end() ? *modelServers.begin() : preferred); //todo: use some algorithm to select server rather than just picking arbitrarily - const auto& server(servers_.find(serverName)->second); - return server; + const auto serverPair = servers_.find(serverName); + return *serverPair; +} + +void TritonService::updateServerHealth(const std::string& modelName) { + for (auto& [serverName, server] : servers_) { + try { + std::unique_ptr client; + TRITON_THROW_IF_ERROR( + tc::InferenceServerGrpcClient::Create(&client, server.url, false, server.useSsl, server.sslOptions), + "TritonService(): unable to create inference context for " + serverName + " (" + server.url + ")", + false); + + bool live = false, ready = false; + client->IsServerLive(&live); + client->IsServerReady(&ready); + + inference::ModelStatisticsResponse stats; + if (!modelName.empty()) { + client->ModelInferenceStatistics(&stats, modelName); + } else { + for (const auto& m : server.models) { + client->ModelInferenceStatistics(&stats, m); + } + } + + uint64_t infer_count = 0, queue_count = 0, failures = 0; + double avgQueueTimeMs = 0.0; + double avgInferTimeMs = 0.0; + + for (const auto& mstat : stats.model_stats()) { + if (modelName.empty() || mstat.name() == modelName) { + const auto& infer = mstat.inference_stats(); + + infer_count += infer.compute_infer().count(); + avgInferTimeMs += infer.compute_infer().ns() / 1e3; + queue_count += infer.queue().count(); + avgQueueTimeMs += infer.queue().ns() / 1e3; + failures += infer.fail().count(); + } + } + // Update health map safely with accessor + tbb::concurrent_hash_map::accessor acc; + serversHealth_.find(acc, serverName); + + ServerHealth& health = acc->second; + health.live = live; + health.ready = ready; + health.failureCount = failures; + health.avgQueueTimeMs = avgQueueTimeMs / queue_count; + health.avgInferTimeMs = avgInferTimeMs / infer_count; + + } catch (const TritonException& e) { + // mark existing entry unhealthy if present + tbb::concurrent_hash_map::accessor acc; + if (serversHealth_.find(acc, serverName)) { + ServerHealth& health = acc->second; + health.live = false; + health.ready = false; + } + } catch (const std::exception& e) { + // fallback for other exceptions + tbb::concurrent_hash_map::accessor acc; + if (serversHealth_.find(acc, serverName)) { + ServerHealth& health = acc->second; + health.live = false; + health.ready = false; + } + } + } +} + +std::optional TritonService::getBestServer(const std::string& modelName, + const std::string& IgnoreServer) { + std::optional bestServerName; + ServerHealth bestHealth; + + // get fresh ServerHealth statistics + updateServerHealth(modelName); + + for (auto& [serverName, server] : servers_) { + if (serverName == IgnoreServer) + continue; // skip ignored server + if (server.models.find(modelName) == server.models.end()) + continue; // server doesn't have model + + tbb::concurrent_hash_map::const_accessor acc; + if (!serversHealth_.find(acc, serverName)) + continue; // no health info + + const ServerHealth& health = acc->second; + + if (!health.live || !health.ready) + continue; // skip unhealthy + + // Select server according to rules: + // 1) lowest failureCount + // 2) tie-breaker: lowest avgQueueTimeMs + if (!bestServerName || health.failureCount < bestHealth.failureCount || + (health.failureCount == bestHealth.failureCount && health.avgQueueTimeMs < bestHealth.avgQueueTimeMs)) { + bestServerName = serverName; + bestHealth = health; + } + } + if (verbose_ && bestServerName) { + edm::LogInfo("Chosen server for model '" + modelName + "': " + *bestServerName + + " (failures=" + std::to_string(bestHealth.failureCount) + + ", avgQueueTime=" + std::to_string(bestHealth.avgQueueTimeMs) + " ms)"); + } + return bestServerName; } void TritonService::preBeginJob(edm::ProcessContext const&) { diff --git a/HeterogeneousCore/SonicTriton/test/BuildFile.xml b/HeterogeneousCore/SonicTriton/test/BuildFile.xml index e4ff7a0bb56f3..2f6def462c7b0 100644 --- a/HeterogeneousCore/SonicTriton/test/BuildFile.xml +++ b/HeterogeneousCore/SonicTriton/test/BuildFile.xml @@ -1,11 +1,22 @@ + + - + + + + + + + + + + diff --git a/HeterogeneousCore/SonicTriton/test/test_RetryActionDiffServer.cc b/HeterogeneousCore/SonicTriton/test/test_RetryActionDiffServer.cc new file mode 100644 index 0000000000000..c501e8f9c9e3c --- /dev/null +++ b/HeterogeneousCore/SonicTriton/test/test_RetryActionDiffServer.cc @@ -0,0 +1,71 @@ +#define CATCH_CONFIG_MAIN +#include "catch.hpp" + +#include "HeterogeneousCore/SonicTriton/interface/RetryActionDiffServer.h" +#include "HeterogeneousCore/SonicTriton/interface/TritonClient.h" +#include "HeterogeneousCore/SonicTriton/interface/TritonService.h" +#include "HeterogeneousCore/SonicCore/interface/RetryActionBase.h" + +#include "FWCore/ParameterSet/interface/ParameterSet.h" + +#include + +// Test double for TritonClient to observe updateServer calls without framework/services +class TestTritonClient : public TritonClient { +public: + TestTritonClient() : TritonClient() {} + + void connectToServer(const std::string& url) override { lastConnectedUrl = url; } + + void updateServer(const std::string& serverName) override { lastUpdatedServerName = serverName; } + + const std::string& lastUrl() const { return lastConnectedUrl; } + const std::string& lastServerName() const { return lastUpdatedServerName; } + +protected: + void evaluate() override {} + +private: + std::string lastConnectedUrl; + std::string lastUpdatedServerName; +}; + +TEST_CASE("RetryActionDiffServer switches to fallback via updateServer", "[RetryActionDiffServer]") { + edm::ParameterSet empty; + TestTritonClient client; + + RetryActionDiffServer action(empty, static_cast(&client)); + + // start should arm the action + action.start(); + REQUIRE(action.shouldRetry()); + + // retry should call updateServer with fallback name then disarm + action.retry(); + REQUIRE(client.lastServerName() == TritonService::Server::fallbackName); + + // second retry without re-arming should be a no-op: lastServerName unchanged + std::string afterFirst = client.lastServerName(); + action.retry(); + REQUIRE(client.lastServerName() == afterFirst); +} + +// A client that throws during updateServer to exercise error handling path +class ThrowingTritonClient : public TritonClient { +public: + ThrowingTritonClient() : TritonClient() {} + void updateServer(const std::string&) override { throw TritonException("updateServer failure"); } + +protected: + void evaluate() override {} +}; + +TEST_CASE("RetryActionDiffServer catches exceptions from updateServer", "[RetryActionDiffServer]") { + edm::ParameterSet empty; + ThrowingTritonClient client; + RetryActionDiffServer action(empty, static_cast(&client)); + action.start(); + + // Should not throw despite client throwing internally; action disarms afterward + REQUIRE_NOTHROW(action.retry()); +} diff --git a/HeterogeneousCore/SonicTriton/test/tritonTest_cfg.py b/HeterogeneousCore/SonicTriton/test/tritonTest_cfg.py index 33d6a9c60aad4..0bfd04d095128 100644 --- a/HeterogeneousCore/SonicTriton/test/tritonTest_cfg.py +++ b/HeterogeneousCore/SonicTriton/test/tritonTest_cfg.py @@ -21,6 +21,7 @@ parser.add_argument("--brief", default=False, action="store_true", help="briefer output for graph modules") parser.add_argument("--unittest", default=False, action="store_true", help="unit test mode: reduce input sizes") parser.add_argument("--testother", default=False, action="store_true", help="also test gRPC communication if shared memory enabled, or vice versa") +options = parser.parse_args() options = getOptions(parser, verbose=True)