Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Slight update to TensorFlowAOT interface. #44586

Merged
merged 2 commits into from
Apr 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions PhysicsTools/TensorFlowAOT/interface/Batching.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@ namespace tfaot {
// constructor
explicit BatchRule(size_t batchSize, const std::vector<size_t>& sizes, size_t lastPadding = 0);

// constructor taking a string in the format "batchSize:size1,...,sizeN" with lastPadding being
// inferred from the sum of sizes
BatchRule(const std::string& ruleString);

// destructor
~BatchRule() = default;

Expand All @@ -43,6 +47,9 @@ namespace tfaot {
size_t batchSize_;
std::vector<size_t> sizes_;
size_t lastPadding_;

// validation helper
void validate() const;
};

// stream operator
Expand All @@ -60,6 +67,9 @@ namespace tfaot {
// registers a new rule for a batch size
void setRule(const BatchRule& rule) { rules_.insert_or_assign(rule.getBatchSize(), rule); }

// registers a new rule for a batch size, given a rule string (see BatchRule constructor)
void setRule(const std::string& ruleString) { this->setRule(BatchRule(ruleString)); }

// returns whether a rule was already registered for a certain batch size
bool hasRule(size_t batchSize) const { return rules_.find(batchSize) != rules_.end(); }

Expand Down
3 changes: 3 additions & 0 deletions PhysicsTools/TensorFlowAOT/interface/Model.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ namespace tfaot {
batchStrategy_.setRule(BatchRule(batchSize, sizes, lastPadding));
}

// adds a new batch rule to the strategy, given a rule string (see BatchRule constructor)
void setBatchRule(const std::string& batchRule) { batchStrategy_.setRule(BatchRule(batchRule)); }

// evaluates the model for multiple inputs and outputs of different types
template <typename... Outputs, typename... Inputs>
std::tuple<Outputs...> run(size_t batchSize, Inputs&&... inputs);
Expand Down
55 changes: 46 additions & 9 deletions PhysicsTools/TensorFlowAOT/src/Batching.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,33 +15,70 @@ namespace tfaot {

BatchRule::BatchRule(size_t batchSize, const std::vector<size_t>& sizes, size_t lastPadding)
: batchSize_(batchSize), sizes_(sizes), lastPadding_(lastPadding) {
validate();
}

BatchRule::BatchRule(const std::string& ruleString) {
// extract the target batch size from the front
std::string rule = ruleString;
auto pos = rule.find(":");
if (pos == std::string::npos) {
throw cms::Exception("InvalidBatchRule") << "invalid batch rule format: " << ruleString;
}
size_t batchSize = std::stoi(rule.substr(0, pos));
rule = rule.substr(pos + 1);

// loop through remaining comma-separated sizes
std::vector<size_t> sizes;
size_t sumSizes = 0;
while (!rule.empty()) {
pos = rule.find(",");
sizes.push_back(std::stoi(rule.substr(0, pos)));
sumSizes += sizes.back();
rule = pos == std::string::npos ? "" : rule.substr(pos + 1);
}

// the sum of composite batch sizes should never be smaller than the target batch size
if (sumSizes < batchSize) {
throw cms::Exception("InvalidBatchRule")
<< "sum of composite batch sizes is smaller than target batch size: " << ruleString;
}

// set members and validate
batchSize_ = batchSize;
sizes_ = sizes;
lastPadding_ = sumSizes - batchSize;
validate();
}

void BatchRule::validate() const {
// sizes must not be empty
if (sizes.size() == 0) {
if (sizes_.size() == 0) {
throw cms::Exception("EmptySizes") << "no batch sizes provided for stitching";
}

// the padding must be smaller than the last size
size_t lastSize = sizes[sizes.size() - 1];
if (lastPadding >= lastSize) {
size_t lastSize = sizes_[sizes_.size() - 1];
if (lastPadding_ >= lastSize) {
throw cms::Exception("WrongPadding")
<< "padding " << lastPadding << " must be smaller than last size " << lastSize;
<< "padding " << lastPadding_ << " must be smaller than last size " << lastSize;
}

// compute the covered batch size
size_t sizeSum = 0;
for (const size_t& s : sizes_) {
sizeSum += s;
}
if (lastPadding > sizeSum) {
if (lastPadding_ > sizeSum) {
throw cms::Exception("WrongPadding")
<< "padding " << lastPadding << " must not be larger than sum of sizes " << sizeSum;
<< "padding " << lastPadding_ << " must not be larger than sum of sizes " << sizeSum;
}
sizeSum -= lastPadding;
sizeSum -= lastPadding_;

// compare to given batch size
if (batchSize != sizeSum) {
if (batchSize_ != sizeSum) {
throw cms::Exception("WrongBatchSize")
<< "batch size " << batchSize << " does not match sum of sizes - padding " << sizeSum;
<< "batch size " << batchSize_ << " does not match sum of sizes - padding " << sizeSum;
}
}

Expand Down
1 change: 0 additions & 1 deletion PhysicsTools/TensorFlowAOT/test/testAOTTools.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ def test_dev_workflow(self, tmp_dir):
self.assertTrue(exists("include", "tfaot-model-test"))
self.assertTrue(exists("include", "tfaot-model-test", "test_simple_bs1.h"))
self.assertTrue(exists("include", "tfaot-model-test", "test_simple_bs2.h"))
self.assertTrue(exists("include", "tfaot-model-test", "test_simple.h"))
self.assertTrue(exists("include", "tfaot-model-test", "model.h"))
self.assertTrue(exists("lib", "test_simple_bs1.o"))
self.assertTrue(exists("lib", "test_simple_bs2.o"))
Expand Down
4 changes: 4 additions & 0 deletions PhysicsTools/TensorFlowAOT/test/testInterface.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ void testInterface::test_simple() {
// register (optional) batch rules
model.setBatchRule(1, {1});
model.setBatchRule(3, {2, 2}, 1);
model.setBatchRule("5:2,2,2");

// test batching strategies
CPPUNIT_ASSERT(model.getBatchStrategy().hasRule(1));
Expand All @@ -50,6 +51,9 @@ void testInterface::test_simple() {
CPPUNIT_ASSERT(model.getBatchStrategy().getRule(3).nSizes() == 2);
CPPUNIT_ASSERT(model.getBatchStrategy().getRule(3).getLastPadding() == 1);
CPPUNIT_ASSERT(!model.getBatchStrategy().hasRule(4));
CPPUNIT_ASSERT(model.getBatchStrategy().hasRule(5));
CPPUNIT_ASSERT(model.getBatchStrategy().getRule(5).nSizes() == 3);
CPPUNIT_ASSERT(model.getBatchStrategy().getRule(5).getLastPadding() == 1);

// evaluate batch size 1
tfaot::FloatArrays input_bs1 = {{0, 1, 2, 3}};
Expand Down