[RFC] Refactor CPUFunction and InterpreterFunction to remove per-run state #2274
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
@@ -30,63 +30,55 @@ CPUFunction::~CPUFunction() {
   tearDownRuns();
 }
 
-void CPUFunction::setupRuns() {
-  if (!runsSetup_) {
-    if (runtimeBundle_.getActivationsSize() != 0) {
-      baseActivationsAddress_ = (uint8_t *)alignedAlloc(
-          runtimeBundle_.getActivationsSize(), TensorAlignment);
-    }
-
-    if (runtimeBundle_.getMutableWeightSize() != 0) {
-      baseMutableWeightVarsAddress_ = (uint8_t *)alignedAlloc(
-          runtimeBundle_.getMutableWeightSize(), TensorAlignment);
-    }
-    runsSetup_ = true;
-  }
-}
-
 void CPUFunction::collectConstants(IRFunction *F) {
   runtimeBundle_.collectConstants(F);
 }
 
-void CPUFunction::beforeRun(const Context &ctx) {
+void CPUFunction::loadPlaceholders(Context *ctx,
+                                   uint8_t *baseMutableWeightVarsAddress) {
   // Copy Placeholders into allocated memory.
-  for (auto PH : ctx.pairs()) {
+  for (auto PH : ctx->pairs()) {
     auto payload = PH.second->getUnsafePtr();
     auto symbolInfo = runtimeBundle_.getSymbolInfo(PH.first);
     auto addr = symbolInfo.offset;
     auto numBytes = symbolInfo.size;
     // copy PH to allocated memory.
-    memcpy(baseMutableWeightVarsAddress_ + addr, payload, numBytes);
+    memcpy(baseMutableWeightVarsAddress + addr, payload, numBytes);
   }
 }
 
-void CPUFunction::afterRun(const Context &ctx) {
+void CPUFunction::updatePlaceholders(Context *ctx,
+                                     uint8_t *baseMutableWeightVarsAddress) {
   // Copy placeholders from device back into context.
-  for (auto PH : ctx.pairs()) {
+  for (auto PH : ctx->pairs()) {
     auto symbolInfo = runtimeBundle_.getSymbolInfo(PH.first);
-    auto payload = baseMutableWeightVarsAddress_ + symbolInfo.offset;
+    auto payload = baseMutableWeightVarsAddress + symbolInfo.offset;
     auto numBytes = symbolInfo.size;
     auto addr = PH.second->getUnsafePtr();
     // copy PH from allocated memory.
     memcpy(addr, payload, numBytes);
   }
 }
 
-void CPUFunction::tearDownRuns() {
-  if (baseMutableWeightVarsAddress_) {
-    alignedFree(baseMutableWeightVarsAddress_);
-    baseMutableWeightVarsAddress_ = nullptr;
-  }
-
-  if (baseActivationsAddress_) {
-    alignedFree(baseActivationsAddress_);
-    baseActivationsAddress_ = nullptr;
-  }
-  runsSetup_ = false;
-}
-
-void CPUFunction::execute() {
+void CPUFunction::execute(Context *ctx) {
+  /// Base address for Activations memory block.
+  uint8_t *baseActivationsAddress{nullptr};
+
+  /// Base address for Mutable weights memory block, Inputs and Outputs.
+  uint8_t *baseMutableWeightVarsAddress{nullptr};
+
+  if (runtimeBundle_.getActivationsSize() != 0) {
+    baseActivationsAddress = (uint8_t *)alignedAlloc(
+        runtimeBundle_.getActivationsSize(), TensorAlignment);
+  }
+
+  if (runtimeBundle_.getMutableWeightSize() != 0) {
+    baseMutableWeightVarsAddress = (uint8_t *)alignedAlloc(
+        runtimeBundle_.getMutableWeightSize(), TensorAlignment);
+  }
+
+  loadPlaceholders(ctx, baseMutableWeightVarsAddress);
+
   auto sym = JIT_->findSymbol("jitmain");
   assert(sym && "Unable to JIT the code!");
   using JitFuncType =

Review comment (on the new `/// Base address for Activations memory block.` line): i'd remove comment, does not provide any additional info on top of the var name.
@@ -95,9 +87,14 @@ void CPUFunction::execute() {
   auto address = sym.getAddress();
   if (address) {
     JitFuncType funcPtr = reinterpret_cast<JitFuncType>(address.get());
-    funcPtr(runtimeBundle_.getConstants(), baseMutableWeightVarsAddress_,
-            baseActivationsAddress_);
+    funcPtr(runtimeBundle_.getConstants(), baseMutableWeightVarsAddress,
+            baseActivationsAddress);
   } else {
     GLOW_ASSERT(false && "Error getting address.");
   }
+
+  updatePlaceholders(ctx, baseMutableWeightVarsAddress);
+
+  alignedFree(baseMutableWeightVarsAddress);
+  alignedFree(baseActivationsAddress);
 }

Review comment (on the new `alignedFree(baseMutableWeightVarsAddress);` line): this should be fine for now (alloc and dealloc every inference request). But technically we could store this in thread local and reuse buffers.
Reply: Yeah, if we run into CPU backend perf concerns we should add a memory pool here.
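To make the reviewers' suggestion above concrete, here is a minimal, self-contained sketch of reusing per-thread scratch buffers instead of calling alignedAlloc/alignedFree on every execute(). It is not part of this PR; the names ScratchBuffer, activationsScratch, and mutableWeightsScratch are hypothetical.

#include <cstddef>
#include <cstdlib>

namespace {
// One growable scratch allocation, reused across inference requests.
struct ScratchBuffer {
  void *ptr = nullptr;
  size_t capacity = 0;

  // Return a buffer of at least `size` bytes with the given alignment,
  // reusing the previous allocation whenever it is already big enough.
  void *get(size_t size, size_t alignment) {
    if (size > capacity) {
      std::free(ptr);
      // std::aligned_alloc requires the size to be a multiple of alignment.
      size_t rounded = ((size + alignment - 1) / alignment) * alignment;
      ptr = std::aligned_alloc(alignment, rounded);
      capacity = rounded;
    }
    return ptr;
  }

  ~ScratchBuffer() { std::free(ptr); }
};

// Per-thread buffers: concurrent runs on different threads never share them.
thread_local ScratchBuffer activationsScratch;
thread_local ScratchBuffer mutableWeightsScratch;
} // namespace

Under that scheme, execute() would call something like activationsScratch.get(runtimeBundle_.getActivationsSize(), TensorAlignment) instead of alignedAlloc, and the two trailing alignedFree calls would go away; a real memory pool would additionally bound capacity and share buffers across threads.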
@@ -28,12 +28,6 @@ class CPUFunction final : public CompiledFunction {
   /// initializes the LLVM backends.
   std::unique_ptr<llvm::orc::GlowJIT> JIT_;
 
-  /// Base address for Activations memory block.
-  uint8_t *baseActivationsAddress_{};
-
-  /// Base address for Mutable weights memory block, Inputs and Outputs.
-  uint8_t *baseMutableWeightVarsAddress_{};
-
 public:
   /// Ctor.
   CPUFunction(std::unique_ptr<llvm::orc::GlowJIT> JIT,
@@ -42,24 +36,19 @@ class CPUFunction final : public CompiledFunction {
   /// Collects constants for runtime.
   void collectConstants(IRFunction *F);
 
-  /// Allocate Mutable buffers on device this includes Activations and
-  /// Placeholders.
-  void setupRuns() override;
-
-  /// Copy Input Placeholder data to position.
-  void beforeRun(const Context &ctx) override;
-
-  /// Copy Outputs to Placeholders in \p ctx.
-  void afterRun(const Context &ctx) override;
-
-  /// Final cleanup, free all allocations.
-  void tearDownRuns() override;
-
   /// \name CompiledFunction interface
   ///@{
   ~CPUFunction() override;
-  void execute() override;
+  void execute(Context *ctx) override;
   ///@}
 private:
+  /// Load constant tensors from \p ctx into \p weightsAddress, as defined by
+  /// the RuntimeBundle (pre-run).
+  void loadPlaceholders(Context *ctx, uint8_t *weightsAddress);
+
+  /// Load weights from \p weightsAddress into applicable backing tensors in
+  /// \p ctx, as defined by the RuntimeBundle (post-run).
+  void updatePlaceholders(Context *ctx, uint8_t *weightsAddress);
 };
 } // end namespace glow

Review comment (on the new `loadPlaceholders` doc comment): hm, I'm a bit confused. Placeholders are for inputs/outputs but not for constant tensors.
Reply: I'm maintaining existing behaviour of CPUFunction in this diff, which is that all constants & placeholders have their space allocated in the RuntimeBundle and then we copy them into the per-run memory block for execution. The memory should be uninitialized so we don't need to memcpy it, but figured we could fix that when we get to it. It is a known issue with the RuntimeBundle.
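For readers skimming the header change: the caller-facing run protocol collapses from five stateful calls into one stateless call that carries everything in the Context argument. The stub below is not Glow's real API; it is a self-contained sketch that only mirrors the signatures removed and added above, with StubContext standing in for glow::Context.

#include <map>
#include <string>
#include <vector>

// Stand-in for glow::Context: a mapping from placeholder names to buffers.
using StubContext = std::map<std::string, std::vector<float>>;

// Before this PR: per-run state lived on the compiled function, so a run was
// a five-step protocol and two runs could not safely overlap on one function.
struct OldStyleCompiledFunction {
  virtual ~OldStyleCompiledFunction() = default;
  virtual void setupRuns() = 0;
  virtual void beforeRun(const StubContext &ctx) = 0;
  virtual void execute() = 0;
  virtual void afterRun(const StubContext &ctx) = 0;
  virtual void tearDownRuns() = 0;
};

// After this PR: all per-run state is local to execute(), so a run is a
// single call and the function object keeps no mutable run-to-run state.
struct NewStyleCompiledFunction {
  virtual ~NewStyleCompiledFunction() = default;
  virtual void execute(StubContext *ctx) = 0;
};

The CPU backend implements the new shape by allocating the activation and mutable-weight blocks inside execute() and freeing them before returning, as shown in the .cpp diff above.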
@@ -22,29 +22,26 @@
 
 #include "llvm/Support/Casting.h"
 
+#include "llvm/Support/raw_ostream.h"
+
 using namespace glow;
 
 InterpreterFunction::InterpreterFunction(std::unique_ptr<IRFunction> F,
                                          const runtime::RuntimeBundle &bundle)
     : CompiledFunction(bundle), F_(std::move(F)) {}
 
 InterpreterFunction::~InterpreterFunction() {
   // Delete the tensors that are owned by this backend.
-  for (const auto &p : tensors_) {
+  for (const auto &p : constants_) {
     delete p.second;
   }
-  tensors_.clear();
-  externalTensors_.clear();
+  constants_.clear();
 
   alignedFree(runtimeBundle_.getConstants());
-  tearDownRuns();
 }
 
 void InterpreterFunction::collectConstants(IRFunction *F) {
   runtimeBundle_.collectConstants(F);
-}
-
-void InterpreterFunction::setupRuns() {
-  if (!runsSetup_) {
+  if (constants_.empty()) {
     if (runtimeBundle_.getConstantWeightSize()) {
       for (const auto &v : F_->getGraph()->getParent()->getConstants()) {
         auto symbolInfo = runtimeBundle_.getSymbolInfo(v);

Review comment (on the new `#include "llvm/Support/raw_ostream.h"`): not used.
@@ -53,36 +50,27 @@ void InterpreterFunction::setupRuns() {
         constants_.emplace(std::string(v->getName()), tensor);
       }
     }
-    runsSetup_ = true;
   }
 }
 
-void InterpreterFunction::beforeRun(const Context &ctx) {
-  // Register the concrete tensors that back the placeholder tensors.
-  for (auto &ph : ctx.pairs()) {
-    auto *w = F_->getWeightForNode(ph.first);
-    assert(!externalTensors_.count(w) && "The tensor is already registered");
-    externalTensors_[w] = ph.second;
-  }
-}
-
-void InterpreterFunction::afterRun(const Context &ctx) {
-  // Remove the concrete tensors that back the placeholder tensors.
-  for (auto &ph : ctx.pairs()) {
-    auto *w = F_->getWeightForNode(ph.first);
-    externalTensors_.erase(w);
+void InterpreterFunction::execute(Context *ctx) {
+  if (constants_.empty()) {
+    collectConstants(F_.get());
   }
+  BoundInterpreterFunction boundFunc(constants_);
+  boundFunc.execute(F_.get(), ctx);
 }
 
-void InterpreterFunction::tearDownRuns() {
-  for (const auto &p : constants_) {
+BoundInterpreterFunction::~BoundInterpreterFunction() {
+  // Delete the tensors that are owned by this backend.
+  for (const auto &p : tensors_) {
     delete p.second;
   }
-  constants_.clear();
-  runsSetup_ = false;
+  tensors_.clear();
+  externalTensors_.clear();
 }
 
-Tensor *InterpreterFunction::getTensor(const Value *v) const {
+Tensor *BoundInterpreterFunction::getTensor(const Value *v) const {
   auto it = tensors_.find(v);
   if (it != tensors_.end()) {
     return it->second;
@@ -97,7 +85,7 @@ Tensor *InterpreterFunction::getTensor(const Value *v) const {
   return ie->second;
 }
 
-Tensor *InterpreterFunction::getOrCreateTensor(const Value *v) {
+Tensor *BoundInterpreterFunction::getOrCreateTensor(const Value *v) {
   auto ie = externalTensors_.find(v);
   if (ie != externalTensors_.end()) {
     return ie->second;

@@ -117,9 +105,8 @@ Tensor *InterpreterFunction::getOrCreateTensor(const Value *v) {
   return it->second;
 }
 
-Tensor *
-InterpreterFunction::getOrCreateUnownedTensor(const Value *v, const Value *src,
-                                              llvm::ArrayRef<size_t> offsets) {
+Tensor *BoundInterpreterFunction::getOrCreateUnownedTensor(
+    const Value *v, const Value *src, llvm::ArrayRef<size_t> offsets) {
   assert(llvm::isa<TensorViewInst>(v) && "Expected a tensor view");
 
   // Pick the tensor.

@@ -136,7 +123,7 @@ InterpreterFunction::getOrCreateUnownedTensor(const Value *v, const Value *src,
   return T;
 }
 
-void InterpreterFunction::deleteTensor(const Value *v) {
+void BoundInterpreterFunction::deleteTensor(const Value *v) {
   auto it = tensors_.find(v);
   if (it == tensors_.end()) {
     return;

@@ -146,7 +133,14 @@ void InterpreterFunction::deleteTensor(const Value *v) {
   tensors_.erase(it);
 }
 
-void InterpreterFunction::execute() {
+void BoundInterpreterFunction::execute(IRFunction *F, Context *ctx) {
+  // Register the concrete tensors that back the placeholder tensors.
+  for (auto &ph : ctx->pairs()) {
+    auto *w = F->getWeightForNode(ph.first);
+    assert(!externalTensors_.count(w) && "The tensor is already registered");
+    externalTensors_[w] = ph.second;
+  }
+
   // Do the forward pass.
 #define DEF_VALUE(CLASS, NAME)
 #define DEF_INSTR(CLASS, NAME)                                                 \

@@ -156,12 +150,18 @@ void InterpreterFunction::execute() {
   }
 #define DEF_BACKEND_SPECIFIC_INSTR(CLASS, NAME)
   // Dispatch the interpreter on each instruction in the program:
-  for (const auto &I : F_->getInstrs()) {
+  for (const auto &I : F->getInstrs()) {
     switch (I.getKind()) {
 #include "glow/AutoGenInstr.def"
 
     default:
       llvm_unreachable("Invalid instruction.");
     }
   }
+
+  // Remove the concrete tensors that back the placeholder tensors.
+  for (auto &ph : ctx->pairs()) {
+    auto *w = F->getWeightForNode(ph.first);
+    externalTensors_.erase(w);
+  }
 }
Review comment (on BoundInterpreterFunction::execute): magic :) comment about ctx was already in place.
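The interpreter side of the diff splits the old stateful class into a long-lived InterpreterFunction (constants only) and a per-run BoundInterpreterFunction (tensors_ and externalTensors_). The sketch below is not Glow code; it is a self-contained illustration of that split, under the assumption that allowing runs to overlap is the eventual point of removing per-run state, and all names in it (BoundFunction, CompiledFn, Buffer) are hypothetical.

#include <map>
#include <string>
#include <thread>
#include <vector>

using Buffer = std::vector<float>;

// Per-run object: owns the mutable tensor maps, lives for one execute() call.
class BoundFunction {
public:
  explicit BoundFunction(const std::map<std::string, Buffer> &constants)
      : constants_(constants) {}

  void run(std::map<std::string, Buffer *> &placeholders) {
    // Register placeholder-backing buffers, run the program, then drop them.
    externalTensors_ = placeholders;
    // ... interpret instructions here, reading constants_ and tensors_ ...
    externalTensors_.clear();
  }

private:
  const std::map<std::string, Buffer> &constants_;  // shared, read-only
  std::map<std::string, Buffer *> externalTensors_; // per-run
  std::map<std::string, Buffer> tensors_;           // per-run scratch
};

// Long-lived object: only immutable state, so execute() keeps no state of its
// own between calls; each call builds its own BoundFunction.
class CompiledFn {
public:
  void execute(std::map<std::string, Buffer *> &placeholders) const {
    BoundFunction bound(constants_);
    bound.run(placeholders);
  }

private:
  std::map<std::string, Buffer> constants_;
};

int main() {
  CompiledFn fn;
  std::map<std::string, Buffer *> ctxA, ctxB;
  // Two runs with independent contexts can proceed on separate threads.
  std::thread t1([&] { fn.execute(ctxA); });
  std::thread t2([&] { fn.execute(ctxB); });
  t1.join();
  t2.join();
  return 0;
}

In the actual diff, InterpreterFunction::execute still lazily calls collectConstants() on first use, so fully concurrent first runs would need that step to happen at compile time or under a lock.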