From f6a05a2fb4cd7519ee7275294c64972a047401ee Mon Sep 17 00:00:00 2001 From: Marcel R Date: Tue, 2 Apr 2024 10:51:24 +0200 Subject: [PATCH 1/2] Adapt to updated aot workflow. --- PhysicsTools/TensorFlowAOT/test/testAOTTools.py | 1 - 1 file changed, 1 deletion(-) diff --git a/PhysicsTools/TensorFlowAOT/test/testAOTTools.py b/PhysicsTools/TensorFlowAOT/test/testAOTTools.py index cd598c0499544..72e43fe65183d 100644 --- a/PhysicsTools/TensorFlowAOT/test/testAOTTools.py +++ b/PhysicsTools/TensorFlowAOT/test/testAOTTools.py @@ -59,7 +59,6 @@ def test_dev_workflow(self, tmp_dir): self.assertTrue(exists("include", "tfaot-model-test")) self.assertTrue(exists("include", "tfaot-model-test", "test_simple_bs1.h")) self.assertTrue(exists("include", "tfaot-model-test", "test_simple_bs2.h")) - self.assertTrue(exists("include", "tfaot-model-test", "test_simple.h")) self.assertTrue(exists("include", "tfaot-model-test", "model.h")) self.assertTrue(exists("lib", "test_simple_bs1.o")) self.assertTrue(exists("lib", "test_simple_bs2.o")) From 8e69cadc226c94034f10afe288146af4560b3a16 Mon Sep 17 00:00:00 2001 From: Marcel R Date: Tue, 2 Apr 2024 10:52:01 +0200 Subject: [PATCH 2/2] Add second BatchRule ctor taking string format. --- .../TensorFlowAOT/interface/Batching.h | 10 ++++ PhysicsTools/TensorFlowAOT/interface/Model.h | 3 + PhysicsTools/TensorFlowAOT/src/Batching.cc | 55 ++++++++++++++++--- .../TensorFlowAOT/test/testInterface.cc | 4 ++ 4 files changed, 63 insertions(+), 9 deletions(-) diff --git a/PhysicsTools/TensorFlowAOT/interface/Batching.h b/PhysicsTools/TensorFlowAOT/interface/Batching.h index 0989b2147a4be..ed5016a3fc3af 100644 --- a/PhysicsTools/TensorFlowAOT/interface/Batching.h +++ b/PhysicsTools/TensorFlowAOT/interface/Batching.h @@ -21,6 +21,10 @@ namespace tfaot { // constructor explicit BatchRule(size_t batchSize, const std::vector& sizes, size_t lastPadding = 0); + // constructor taking a string in the format "batchSize:size1,...,sizeN" with lastPadding being + // inferred from the sum of sizes + BatchRule(const std::string& ruleString); + // destructor ~BatchRule() = default; @@ -43,6 +47,9 @@ namespace tfaot { size_t batchSize_; std::vector sizes_; size_t lastPadding_; + + // validation helper + void validate() const; }; // stream operator @@ -60,6 +67,9 @@ namespace tfaot { // registers a new rule for a batch size void setRule(const BatchRule& rule) { rules_.insert_or_assign(rule.getBatchSize(), rule); } + // registers a new rule for a batch size, given a rule string (see BatchRule constructor) + void setRule(const std::string& ruleString) { this->setRule(BatchRule(ruleString)); } + // returns whether a rule was already registered for a certain batch size bool hasRule(size_t batchSize) const { return rules_.find(batchSize) != rules_.end(); } diff --git a/PhysicsTools/TensorFlowAOT/interface/Model.h b/PhysicsTools/TensorFlowAOT/interface/Model.h index e4c1a7e1db387..2cc376195d1e8 100644 --- a/PhysicsTools/TensorFlowAOT/interface/Model.h +++ b/PhysicsTools/TensorFlowAOT/interface/Model.h @@ -38,6 +38,9 @@ namespace tfaot { batchStrategy_.setRule(BatchRule(batchSize, sizes, lastPadding)); } + // adds a new batch rule to the strategy, given a rule string (see BatchRule constructor) + void setBatchRule(const std::string& batchRule) { batchStrategy_.setRule(BatchRule(batchRule)); } + // evaluates the model for multiple inputs and outputs of different types template std::tuple run(size_t batchSize, Inputs&&... inputs); diff --git a/PhysicsTools/TensorFlowAOT/src/Batching.cc b/PhysicsTools/TensorFlowAOT/src/Batching.cc index 1f09205d3e0d0..bc0c619bb4903 100644 --- a/PhysicsTools/TensorFlowAOT/src/Batching.cc +++ b/PhysicsTools/TensorFlowAOT/src/Batching.cc @@ -15,16 +15,53 @@ namespace tfaot { BatchRule::BatchRule(size_t batchSize, const std::vector& sizes, size_t lastPadding) : batchSize_(batchSize), sizes_(sizes), lastPadding_(lastPadding) { + validate(); + } + + BatchRule::BatchRule(const std::string& ruleString) { + // extract the target batch size from the front + std::string rule = ruleString; + auto pos = rule.find(":"); + if (pos == std::string::npos) { + throw cms::Exception("InvalidBatchRule") << "invalid batch rule format: " << ruleString; + } + size_t batchSize = std::stoi(rule.substr(0, pos)); + rule = rule.substr(pos + 1); + + // loop through remaining comma-separated sizes + std::vector sizes; + size_t sumSizes = 0; + while (!rule.empty()) { + pos = rule.find(","); + sizes.push_back(std::stoi(rule.substr(0, pos))); + sumSizes += sizes.back(); + rule = pos == std::string::npos ? "" : rule.substr(pos + 1); + } + + // the sum of composite batch sizes should never be smaller than the target batch size + if (sumSizes < batchSize) { + throw cms::Exception("InvalidBatchRule") + << "sum of composite batch sizes is smaller than target batch size: " << ruleString; + } + + // set members and validate + batchSize_ = batchSize; + sizes_ = sizes; + lastPadding_ = sumSizes - batchSize; + validate(); + } + + void BatchRule::validate() const { // sizes must not be empty - if (sizes.size() == 0) { + if (sizes_.size() == 0) { throw cms::Exception("EmptySizes") << "no batch sizes provided for stitching"; } // the padding must be smaller than the last size - size_t lastSize = sizes[sizes.size() - 1]; - if (lastPadding >= lastSize) { + size_t lastSize = sizes_[sizes_.size() - 1]; + if (lastPadding_ >= lastSize) { throw cms::Exception("WrongPadding") - << "padding " << lastPadding << " must be smaller than last size " << lastSize; + << "padding " << lastPadding_ << " must be smaller than last size " << lastSize; } // compute the covered batch size @@ -32,16 +69,16 @@ namespace tfaot { for (const size_t& s : sizes_) { sizeSum += s; } - if (lastPadding > sizeSum) { + if (lastPadding_ > sizeSum) { throw cms::Exception("WrongPadding") - << "padding " << lastPadding << " must not be larger than sum of sizes " << sizeSum; + << "padding " << lastPadding_ << " must not be larger than sum of sizes " << sizeSum; } - sizeSum -= lastPadding; + sizeSum -= lastPadding_; // compare to given batch size - if (batchSize != sizeSum) { + if (batchSize_ != sizeSum) { throw cms::Exception("WrongBatchSize") - << "batch size " << batchSize << " does not match sum of sizes - padding " << sizeSum; + << "batch size " << batchSize_ << " does not match sum of sizes - padding " << sizeSum; } } diff --git a/PhysicsTools/TensorFlowAOT/test/testInterface.cc b/PhysicsTools/TensorFlowAOT/test/testInterface.cc index 10c19c746a790..70d52248f9c13 100644 --- a/PhysicsTools/TensorFlowAOT/test/testInterface.cc +++ b/PhysicsTools/TensorFlowAOT/test/testInterface.cc @@ -40,6 +40,7 @@ void testInterface::test_simple() { // register (optional) batch rules model.setBatchRule(1, {1}); model.setBatchRule(3, {2, 2}, 1); + model.setBatchRule("5:2,2,2"); // test batching strategies CPPUNIT_ASSERT(model.getBatchStrategy().hasRule(1)); @@ -50,6 +51,9 @@ void testInterface::test_simple() { CPPUNIT_ASSERT(model.getBatchStrategy().getRule(3).nSizes() == 2); CPPUNIT_ASSERT(model.getBatchStrategy().getRule(3).getLastPadding() == 1); CPPUNIT_ASSERT(!model.getBatchStrategy().hasRule(4)); + CPPUNIT_ASSERT(model.getBatchStrategy().hasRule(5)); + CPPUNIT_ASSERT(model.getBatchStrategy().getRule(5).nSizes() == 3); + CPPUNIT_ASSERT(model.getBatchStrategy().getRule(5).getLastPadding() == 1); // evaluate batch size 1 tfaot::FloatArrays input_bs1 = {{0, 1, 2, 3}};