diff --git a/include/glow/Backends/CompiledFunction.h b/include/glow/Backends/CompiledFunction.h index a5197c8906..27483a03ba 100644 --- a/include/glow/Backends/CompiledFunction.h +++ b/include/glow/Backends/CompiledFunction.h @@ -38,21 +38,21 @@ class CompiledFunction { virtual ~CompiledFunction() = default; /// Execute the network and allocate Placeholder memory with given /// \p ctx providing mapping between Placeholder and populated tensor. - virtual void execute() = 0; + virtual void execute(Context *ctx) = 0; /// Does any needed initialization work for the Backend. /// This includes device init constant memory allocation and copying to - /// device. - virtual void setupRuns() = 0; + /// device. \deprecated + virtual void setupRuns() { runsSetup_ = true; } - /// Per run setup. Copy inputs to device. - virtual void beforeRun(const Context &ctx) = 0; + /// Per run setup. Copy inputs to device. \deprecated + virtual void beforeRun(const Context &ctx) {} - /// Per run cleanup. Copy outputs from device. - virtual void afterRun(const Context &ctx) = 0; + /// Per run cleanup. Copy outputs from device. \deprecated + virtual void afterRun(const Context &ctx) {} - /// Final cleanup. Release memory, reset device. - virtual void tearDownRuns() = 0; + /// Final cleanup. Release memory, reset device. \deprecated + virtual void tearDownRuns() { runsSetup_ = false; } /// Getter for the runtimeBundle. const runtime::RuntimeBundle &getRuntimeBundle() const { diff --git a/lib/Backends/CPU/CPUDeviceManager.cpp b/lib/Backends/CPU/CPUDeviceManager.cpp index a3f63a29e2..37e33514c1 100644 --- a/lib/Backends/CPU/CPUDeviceManager.cpp +++ b/lib/Backends/CPU/CPUDeviceManager.cpp @@ -92,11 +92,7 @@ void CPUDeviceManager::runFunctionImpl(RunIdentifierTy id, std::string function, CompiledFunction *func = funcIt->second; // Run that function. - func->setupRuns(); - func->beforeRun(*ctx); - func->execute(); - func->afterRun(*ctx); - func->tearDownRuns(); + func->execute(ctx.get()); // Fire the resultCB. resultCB(id, ResultCode::Executed, std::move(ctx)); diff --git a/lib/Backends/CPU/CPUFunction.cpp b/lib/Backends/CPU/CPUFunction.cpp index 16d00439a1..ffc36f10f8 100644 --- a/lib/Backends/CPU/CPUFunction.cpp +++ b/lib/Backends/CPU/CPUFunction.cpp @@ -30,42 +30,29 @@ CPUFunction::~CPUFunction() { tearDownRuns(); } -void CPUFunction::setupRuns() { - if (!runsSetup_) { - if (runtimeBundle_.getActivationsSize() != 0) { - baseActivationsAddress_ = (uint8_t *)alignedAlloc( - runtimeBundle_.getActivationsSize(), TensorAlignment); - } - - if (runtimeBundle_.getMutableWeightSize() != 0) { - baseMutableWeightVarsAddress_ = (uint8_t *)alignedAlloc( - runtimeBundle_.getMutableWeightSize(), TensorAlignment); - } - runsSetup_ = true; - } -} - void CPUFunction::collectConstants(IRFunction *F) { runtimeBundle_.collectConstants(F); } -void CPUFunction::beforeRun(const Context &ctx) { +void CPUFunction::loadPlaceholders(Context *ctx, + uint8_t *baseMutableWeightVarsAddress) { // Copy Placeholders into allocated memory. - for (auto PH : ctx.pairs()) { + for (auto PH : ctx->pairs()) { auto payload = PH.second->getUnsafePtr(); auto symbolInfo = runtimeBundle_.getSymbolInfo(PH.first); auto addr = symbolInfo.offset; auto numBytes = symbolInfo.size; // copy PH to allocated memory. - memcpy(baseMutableWeightVarsAddress_ + addr, payload, numBytes); + memcpy(baseMutableWeightVarsAddress + addr, payload, numBytes); } } -void CPUFunction::afterRun(const Context &ctx) { +void CPUFunction::updatePlaceholders(Context *ctx, + uint8_t *baseMutableWeightVarsAddress) { // Copy placeholders from device back into context. - for (auto PH : ctx.pairs()) { + for (auto PH : ctx->pairs()) { auto symbolInfo = runtimeBundle_.getSymbolInfo(PH.first); - auto payload = baseMutableWeightVarsAddress_ + symbolInfo.offset; + auto payload = baseMutableWeightVarsAddress + symbolInfo.offset; auto numBytes = symbolInfo.size; auto addr = PH.second->getUnsafePtr(); // copy PH from allocated memory. @@ -73,20 +60,25 @@ void CPUFunction::afterRun(const Context &ctx) { } } -void CPUFunction::tearDownRuns() { - if (baseMutableWeightVarsAddress_) { - alignedFree(baseMutableWeightVarsAddress_); - baseMutableWeightVarsAddress_ = nullptr; +void CPUFunction::execute(Context *ctx) { + /// Base address for Activations memory block. + uint8_t *baseActivationsAddress{nullptr}; + + /// Base address for Mutable weights memory block, Inputs and Outputs. + uint8_t *baseMutableWeightVarsAddress{nullptr}; + + if (runtimeBundle_.getActivationsSize() != 0) { + baseActivationsAddress = (uint8_t *)alignedAlloc( + runtimeBundle_.getActivationsSize(), TensorAlignment); } - if (baseActivationsAddress_) { - alignedFree(baseActivationsAddress_); - baseActivationsAddress_ = nullptr; + if (runtimeBundle_.getMutableWeightSize() != 0) { + baseMutableWeightVarsAddress = (uint8_t *)alignedAlloc( + runtimeBundle_.getMutableWeightSize(), TensorAlignment); } - runsSetup_ = false; -} -void CPUFunction::execute() { + loadPlaceholders(ctx, baseMutableWeightVarsAddress); + auto sym = JIT_->findSymbol("jitmain"); assert(sym && "Unable to JIT the code!"); using JitFuncType = @@ -95,9 +87,14 @@ void CPUFunction::execute() { auto address = sym.getAddress(); if (address) { JitFuncType funcPtr = reinterpret_cast(address.get()); - funcPtr(runtimeBundle_.getConstants(), baseMutableWeightVarsAddress_, - baseActivationsAddress_); + funcPtr(runtimeBundle_.getConstants(), baseMutableWeightVarsAddress, + baseActivationsAddress); } else { GLOW_ASSERT(false && "Error getting address."); } + + updatePlaceholders(ctx, baseMutableWeightVarsAddress); + + alignedFree(baseMutableWeightVarsAddress); + alignedFree(baseActivationsAddress); } diff --git a/lib/Backends/CPU/CPUFunction.h b/lib/Backends/CPU/CPUFunction.h index c1ca4fdf4b..d35277f39b 100644 --- a/lib/Backends/CPU/CPUFunction.h +++ b/lib/Backends/CPU/CPUFunction.h @@ -28,12 +28,6 @@ class CPUFunction final : public CompiledFunction { /// initializes the LLVM backends. std::unique_ptr JIT_; - /// Base address for Activations memory block. - uint8_t *baseActivationsAddress_{}; - - /// Base address for Mutable weights memory block, Inputs and Outputs. - uint8_t *baseMutableWeightVarsAddress_{}; - public: /// Ctor. CPUFunction(std::unique_ptr JIT, @@ -42,24 +36,19 @@ class CPUFunction final : public CompiledFunction { /// Collects constants for runtime. void collectConstants(IRFunction *F); - /// Allocate Mutable buffers on device this includes Activations and - /// Placeholders. - void setupRuns() override; - - /// Copy Input Placeholder data to position. - void beforeRun(const Context &ctx) override; - - /// Copy Outputs to Placeholders in \p ctx. - void afterRun(const Context &ctx) override; - - /// Final cleanup, free all allocations. - void tearDownRuns() override; - /// \name CompiledFunction interface ///@{ ~CPUFunction() override; - void execute() override; + void execute(Context *ctx) override; ///@} +private: + /// Load constant tensors from \p ctx into \p weightsAddress, as defined by + /// the RuntimeBundle (pre-run). + void loadPlaceholders(Context *ctx, uint8_t *weightsAddress); + + /// Load weights from \p weightsAddress into applicable backing tensors in + /// \p ctx, as defined by the RuntimeBundle (post-run). + void updatePlaceholders(Context *ctx, uint8_t *weightsAddress); }; } // end namespace glow diff --git a/lib/Backends/Interpreter/InterpreterFunction.cpp b/lib/Backends/Interpreter/InterpreterFunction.cpp index c95b7386ad..6ab0f503bc 100644 --- a/lib/Backends/Interpreter/InterpreterFunction.cpp +++ b/lib/Backends/Interpreter/InterpreterFunction.cpp @@ -22,6 +22,7 @@ #include "llvm/Support/Casting.h" +#include "llvm/Support/raw_ostream.h" using namespace glow; InterpreterFunction::InterpreterFunction(std::unique_ptr F, @@ -29,22 +30,18 @@ InterpreterFunction::InterpreterFunction(std::unique_ptr F, : CompiledFunction(bundle), F_(std::move(F)) {} InterpreterFunction::~InterpreterFunction() { - // Delete the tensors that are owned by this backend. - for (const auto &p : tensors_) { + for (const auto &p : constants_) { delete p.second; } - tensors_.clear(); - externalTensors_.clear(); + constants_.clear(); + alignedFree(runtimeBundle_.getConstants()); tearDownRuns(); } void InterpreterFunction::collectConstants(IRFunction *F) { runtimeBundle_.collectConstants(F); -} - -void InterpreterFunction::setupRuns() { - if (!runsSetup_) { + if (constants_.empty()) { if (runtimeBundle_.getConstantWeightSize()) { for (const auto &v : F_->getGraph()->getParent()->getConstants()) { auto symbolInfo = runtimeBundle_.getSymbolInfo(v); @@ -53,36 +50,27 @@ void InterpreterFunction::setupRuns() { constants_.emplace(std::string(v->getName()), tensor); } } - runsSetup_ = true; - } -} - -void InterpreterFunction::beforeRun(const Context &ctx) { - // Register the concrete tensors that back the placeholder tensors. - for (auto &ph : ctx.pairs()) { - auto *w = F_->getWeightForNode(ph.first); - assert(!externalTensors_.count(w) && "The tensor is already registered"); - externalTensors_[w] = ph.second; } } -void InterpreterFunction::afterRun(const Context &ctx) { - // Remove the concrete tensors that back the placeholder tensors. - for (auto &ph : ctx.pairs()) { - auto *w = F_->getWeightForNode(ph.first); - externalTensors_.erase(w); +void InterpreterFunction::execute(Context *ctx) { + if (constants_.empty()) { + collectConstants(F_.get()); } + BoundInterpreterFunction boundFunc(constants_); + boundFunc.execute(F_.get(), ctx); } -void InterpreterFunction::tearDownRuns() { - for (const auto &p : constants_) { +BoundInterpreterFunction::~BoundInterpreterFunction() { + // Delete the tensors that are owned by this backend. + for (const auto &p : tensors_) { delete p.second; } - constants_.clear(); - runsSetup_ = false; + tensors_.clear(); + externalTensors_.clear(); } -Tensor *InterpreterFunction::getTensor(const Value *v) const { +Tensor *BoundInterpreterFunction::getTensor(const Value *v) const { auto it = tensors_.find(v); if (it != tensors_.end()) { return it->second; @@ -97,7 +85,7 @@ Tensor *InterpreterFunction::getTensor(const Value *v) const { return ie->second; } -Tensor *InterpreterFunction::getOrCreateTensor(const Value *v) { +Tensor *BoundInterpreterFunction::getOrCreateTensor(const Value *v) { auto ie = externalTensors_.find(v); if (ie != externalTensors_.end()) { return ie->second; @@ -117,9 +105,8 @@ Tensor *InterpreterFunction::getOrCreateTensor(const Value *v) { return it->second; } -Tensor * -InterpreterFunction::getOrCreateUnownedTensor(const Value *v, const Value *src, - llvm::ArrayRef offsets) { +Tensor *BoundInterpreterFunction::getOrCreateUnownedTensor( + const Value *v, const Value *src, llvm::ArrayRef offsets) { assert(llvm::isa(v) && "Expected a tensor view"); // Pick the tensor. @@ -136,7 +123,7 @@ InterpreterFunction::getOrCreateUnownedTensor(const Value *v, const Value *src, return T; } -void InterpreterFunction::deleteTensor(const Value *v) { +void BoundInterpreterFunction::deleteTensor(const Value *v) { auto it = tensors_.find(v); if (it == tensors_.end()) { return; @@ -146,7 +133,14 @@ void InterpreterFunction::deleteTensor(const Value *v) { tensors_.erase(it); } -void InterpreterFunction::execute() { +void BoundInterpreterFunction::execute(IRFunction *F, Context *ctx) { + // Register the concrete tensors that back the placeholder tensors. + for (auto &ph : ctx->pairs()) { + auto *w = F->getWeightForNode(ph.first); + assert(!externalTensors_.count(w) && "The tensor is already registered"); + externalTensors_[w] = ph.second; + } + // Do the forward pass. #define DEF_VALUE(CLASS, NAME) #define DEF_INSTR(CLASS, NAME) \ @@ -156,7 +150,7 @@ void InterpreterFunction::execute() { } #define DEF_BACKEND_SPECIFIC_INSTR(CLASS, NAME) // Dispatch the interpreter on each instruction in the program: - for (const auto &I : F_->getInstrs()) { + for (const auto &I : F->getInstrs()) { switch (I.getKind()) { #include "glow/AutoGenInstr.def" @@ -164,4 +158,10 @@ void InterpreterFunction::execute() { llvm_unreachable("Invalid instruction."); } } + + // Remove the concrete tensors that back the placeholder tensors. + for (auto &ph : ctx->pairs()) { + auto *w = F->getWeightForNode(ph.first); + externalTensors_.erase(w); + } } diff --git a/lib/Backends/Interpreter/InterpreterFunction.h b/lib/Backends/Interpreter/InterpreterFunction.h index 9fb49aa187..8ac6832e1a 100644 --- a/lib/Backends/Interpreter/InterpreterFunction.h +++ b/lib/Backends/Interpreter/InterpreterFunction.h @@ -45,10 +45,7 @@ class Constant; class InterpreterFunction final : public CompiledFunction { /// The IR to be executed. std::unique_ptr F_; - /// Maps values to Tensors, that are owned by this class. - std::unordered_map tensors_; - /// Maps values to Tensors, that are *not* owned by this class. - std::unordered_map externalTensors_; + /// Maps Value.name to tensors for constants. std::unordered_map constants_; @@ -60,29 +57,34 @@ class InterpreterFunction final : public CompiledFunction { ///@{ ~InterpreterFunction() override; - /// Does any needed initialization work for the Backend, creates tensors from - /// constants. - /// Collects constants for runtime. void collectConstants(IRFunction *F); - void setupRuns() override; - - /// Per run setup, adds references for tensors from \p ctx to - /// externalTensors_. - void beforeRun(const Context &ctx) override; - - /// Per run cleanup, removes references for tensors from \p ctx from - /// externalTensors_. - void afterRun(const Context &ctx) override; + void execute(Context *ctx) override; - /// Final cleanup, remove created constant Tensors. - void tearDownRuns() override; - - void execute() override; /// Get reference to IR function. IRFunction *getIR() { return F_.get(); } ///@} +}; + +/// An InterpreterFunction bound to a specific invocation. +class BoundInterpreterFunction { + /// Maps values to Tensors, that are owned by this class. + std::unordered_map tensors_; + + /// Maps values to Tensors, that are *not* owned by this class. + std::unordered_map externalTensors_; + + /// A reference to the constant map from the owning InterpreterFunction. + const std::unordered_map &constants_; + +public: + BoundInterpreterFunction( + const std::unordered_map &constants) + : constants_(constants) {} + ~BoundInterpreterFunction(); + + void execute(IRFunction *F, Context *ctx); private: /// \returns a pointer to the tensor that is saved under \p v. @@ -108,8 +110,9 @@ class InterpreterFunction final : public CompiledFunction { return getTensor(v)->getHandle(); } - /// @name Interpreter methods. This is a list of method declerations that are - /// used by the interpreter to dispatch different instructions. + /// @name BoundInterpreterFunction methods. This is a list of method + /// declerations that are used by the interpreter to dispatch different + /// instructions. ///@{ #define DEF_VALUE(CLASS, NAME) diff --git a/lib/Backends/Interpreter/InterpreterNodes.cpp b/lib/Backends/Interpreter/InterpreterNodes.cpp index 03cf45aaea..7f3648b0d5 100644 --- a/lib/Backends/Interpreter/InterpreterNodes.cpp +++ b/lib/Backends/Interpreter/InterpreterNodes.cpp @@ -78,7 +78,7 @@ using namespace glow; /// This is the floating point implementation of Convolution. template -void InterpreterFunction::fwdConvolutionInstFloatImpl( +void BoundInterpreterFunction::fwdConvolutionInstFloatImpl( Value *inV, Value *outV, Value *filterV, Value *biasV, llvm::ArrayRef kernelSizes, llvm::ArrayRef strides, llvm::ArrayRef pads, size_t group) { @@ -148,7 +148,7 @@ void InterpreterFunction::fwdConvolutionInstFloatImpl( /// This is the quantized implementation of Convolution. /// For bias, we support int32 quantization. template -void InterpreterFunction::fwdConvolutionInstQuantizedImpl( +void BoundInterpreterFunction::fwdConvolutionInstQuantizedImpl( Value *inV, Value *outV, Value *filterV, Value *biasV, llvm::ArrayRef kernelSizes, llvm::ArrayRef strides, llvm::ArrayRef pads, size_t group) { @@ -242,7 +242,7 @@ void InterpreterFunction::fwdConvolutionInstQuantizedImpl( } // N } -void InterpreterFunction::fwdConvolutionInst(const ConvolutionInst *I) { +void BoundInterpreterFunction::fwdConvolutionInst(const ConvolutionInst *I) { auto kernelSizes = I->getKernels(); auto pads = I->getPads(); auto strides = I->getStrides(); @@ -262,7 +262,8 @@ void InterpreterFunction::fwdConvolutionInst(const ConvolutionInst *I) { kernelSizes, strides, pads, group); } -void InterpreterFunction::fwdConvolutionGradInst(const ConvolutionGradInst *I) { +void BoundInterpreterFunction::fwdConvolutionGradInst( + const ConvolutionGradInst *I) { auto inW = getWeightHandle(I->getSrc()); auto inG = getWeightHandle(I->getSrcGrad()); auto outG = getWeightHandle(I->getDestGrad()); @@ -400,7 +401,7 @@ static void fwdMaxPool(Tensor *inW, Tensor *outW, Handle *SXY, } // N } -void InterpreterFunction::fwdMaxPoolInst(const MaxPoolInst *I) { +void BoundInterpreterFunction::fwdMaxPoolInst(const MaxPoolInst *I) { auto inW = getTensor(I->getSrc()); auto outW = getTensor(I->getDest()); @@ -416,7 +417,8 @@ void InterpreterFunction::fwdMaxPoolInst(const MaxPoolInst *I) { I->getPads()); } -void InterpreterFunction::fwdMaxPoolWithXYInst(const MaxPoolWithXYInst *I) { +void BoundInterpreterFunction::fwdMaxPoolWithXYInst( + const MaxPoolWithXYInst *I) { auto inW = getTensor(I->getSrc()); auto outW = getTensor(I->getDest()); auto SXY = getWeightHandle(I->getSrcXY()); @@ -433,7 +435,7 @@ void InterpreterFunction::fwdMaxPoolWithXYInst(const MaxPoolWithXYInst *I) { } template -void InterpreterFunction::fwdAvgPoolInstFloatImpl(const AvgPoolInst *I) { +void BoundInterpreterFunction::fwdAvgPoolInstFloatImpl(const AvgPoolInst *I) { staticAssertFloatingPointType(ElemTy); ShapeNHWC odim(I->getDest()->dims()); @@ -481,7 +483,7 @@ void InterpreterFunction::fwdAvgPoolInstFloatImpl(const AvgPoolInst *I) { } // N } -void InterpreterFunction::fwdAvgPoolInstI8Impl(const AvgPoolInst *I) { +void BoundInterpreterFunction::fwdAvgPoolInstI8Impl(const AvgPoolInst *I) { ShapeNHWC odim(I->getDest()->dims()); ShapeNHWC idim(I->getSrc()->dims()); @@ -534,7 +536,7 @@ void InterpreterFunction::fwdAvgPoolInstI8Impl(const AvgPoolInst *I) { } // N } -void InterpreterFunction::fwdAvgPoolInst(const AvgPoolInst *I) { +void BoundInterpreterFunction::fwdAvgPoolInst(const AvgPoolInst *I) { if (I->getSrc()->getType()->isQuantizedType()) { fwdAvgPoolInstI8Impl(I); return; @@ -544,7 +546,7 @@ void InterpreterFunction::fwdAvgPoolInst(const AvgPoolInst *I) { I->getSrc()->getElementType(), I); } -void InterpreterFunction::fwdMaxPoolWithXYGradInst( +void BoundInterpreterFunction::fwdMaxPoolWithXYGradInst( const MaxPoolWithXYGradInst *I) { auto inG = getWeightHandle(I->getSrcGrad()); auto outW = getWeightHandle(I->getDest()); @@ -578,7 +580,7 @@ void InterpreterFunction::fwdMaxPoolWithXYGradInst( } // N } -void InterpreterFunction::fwdAvgPoolGradInst(const AvgPoolGradInst *I) { +void BoundInterpreterFunction::fwdAvgPoolGradInst(const AvgPoolGradInst *I) { auto inG = getWeightHandle(I->getSrcGrad()); auto outW = getWeightHandle(I->getDest()); auto outG = getWeightHandle(I->getDestGrad()); @@ -630,7 +632,7 @@ void InterpreterFunction::fwdAvgPoolGradInst(const AvgPoolGradInst *I) { // Activation functions //===----------------------------------------------------------------------===// template -void InterpreterFunction::fwdSigmoidInstFloatImpl(const SigmoidInst *I) { +void BoundInterpreterFunction::fwdSigmoidInstFloatImpl(const SigmoidInst *I) { staticAssertFloatingPointType(ElemTy); auto inW = getWeightHandle(I->getSrc()); @@ -642,13 +644,13 @@ void InterpreterFunction::fwdSigmoidInstFloatImpl(const SigmoidInst *I) { } } -void InterpreterFunction::fwdSigmoidInst(const SigmoidInst *I) { +void BoundInterpreterFunction::fwdSigmoidInst(const SigmoidInst *I) { dispatchFloatingPointImpl(fwdSigmoidInstFloatImpl, I->getSrc()->getElementType(), I); } template -void InterpreterFunction::fwdTanhInstFloatImpl(const TanhInst *I) { +void BoundInterpreterFunction::fwdTanhInstFloatImpl(const TanhInst *I) { staticAssertFloatingPointType(ElemTy); auto inW = getWeightHandle(I->getSrc()); @@ -660,7 +662,7 @@ void InterpreterFunction::fwdTanhInstFloatImpl(const TanhInst *I) { } } -void InterpreterFunction::fwdTanhInst(const TanhInst *I) { +void BoundInterpreterFunction::fwdTanhInst(const TanhInst *I) { dispatchFloatingPointImpl(fwdTanhInstFloatImpl, I->getSrc()->getElementType(), I); } @@ -670,7 +672,7 @@ void InterpreterFunction::fwdTanhInst(const TanhInst *I) { //===----------------------------------------------------------------------===// template -void InterpreterFunction::fwdSoftMaxInstImpl(const SoftMaxInst *I) { +void BoundInterpreterFunction::fwdSoftMaxInstImpl(const SoftMaxInst *I) { staticAssertFloatingPointType(ElemTy); auto inW = getWeightHandle(I->getSrc()); @@ -699,12 +701,12 @@ void InterpreterFunction::fwdSoftMaxInstImpl(const SoftMaxInst *I) { } // N } -void InterpreterFunction::fwdSoftMaxInst(const SoftMaxInst *I) { +void BoundInterpreterFunction::fwdSoftMaxInst(const SoftMaxInst *I) { dispatchFloatingPointImpl(fwdSoftMaxInstImpl, I->getSrc()->getElementType(), I); } -void InterpreterFunction::fwdSoftMaxGradInst(const SoftMaxGradInst *I) { +void BoundInterpreterFunction::fwdSoftMaxGradInst(const SoftMaxGradInst *I) { auto inG = getWeightHandle(I->getSrcGrad()); auto idim = inG.dims(); auto outW = getWeightHandle(I->getOrigDest()); @@ -723,7 +725,7 @@ void InterpreterFunction::fwdSoftMaxGradInst(const SoftMaxGradInst *I) { } template -void InterpreterFunction::fwdCrossEntropyLossInstFloatImpl( +void BoundInterpreterFunction::fwdCrossEntropyLossInstFloatImpl( const CrossEntropyLossInst *I) { staticAssertFloatingPointType(ElemTy); @@ -739,13 +741,13 @@ void InterpreterFunction::fwdCrossEntropyLossInstFloatImpl( } } -void InterpreterFunction::fwdCrossEntropyLossInst( +void BoundInterpreterFunction::fwdCrossEntropyLossInst( const CrossEntropyLossInst *I) { dispatchFloatingPointImpl(fwdCrossEntropyLossInstFloatImpl, I->getP()->getElementType(), I); } -void InterpreterFunction::fwdCrossEntropyLossGradInst( +void BoundInterpreterFunction::fwdCrossEntropyLossGradInst( const CrossEntropyLossGradInst *I) { auto P = getWeightHandle(I->getP()); auto Labels = getWeightHandle(I->getLabels()); @@ -763,13 +765,13 @@ void InterpreterFunction::fwdCrossEntropyLossGradInst( // Tensor shape (copy/transpose/concat/...) //===----------------------------------------------------------------------===// -void InterpreterFunction::fwdCopyInst(const CopyInst *I) { +void BoundInterpreterFunction::fwdCopyInst(const CopyInst *I) { auto inT = getTensor(I->getSrc()); auto outT = getTensor(I->getDest()); outT->copyRawFrom(inT); } -void InterpreterFunction::fwdTransposeInst(const TransposeInst *I) { +void BoundInterpreterFunction::fwdTransposeInst(const TransposeInst *I) { auto inT = getTensor(I->getSrc()); (void)inT; auto outT = getTensor(I->getDest()); @@ -783,11 +785,11 @@ void InterpreterFunction::fwdTransposeInst(const TransposeInst *I) { } } -void InterpreterFunction::fwdTensorViewInst(const TensorViewInst *I) { +void BoundInterpreterFunction::fwdTensorViewInst(const TensorViewInst *I) { getOrCreateUnownedTensor(I, I->getSrc(), I->getOffsets()); } -void InterpreterFunction::fwdSplatInst(const glow::SplatInst *I) { +void BoundInterpreterFunction::fwdSplatInst(const glow::SplatInst *I) { auto *T = getTensor(I->getDest()); ElemKind k = T->getElementType(); @@ -815,7 +817,8 @@ void InterpreterFunction::fwdSplatInst(const glow::SplatInst *I) { llvm_unreachable("Unsupported tensor type"); } -void InterpreterFunction::fwdInsertTensorInst(const glow::InsertTensorInst *I) { +void BoundInterpreterFunction::fwdInsertTensorInst( + const glow::InsertTensorInst *I) { Tensor *outT = getTensor(I->getDest()); Tensor *inT = getTensor(I->getSrc()); ElemKind k = outT->getElementType(); @@ -835,7 +838,7 @@ void InterpreterFunction::fwdInsertTensorInst(const glow::InsertTensorInst *I) { llvm_unreachable("Unsupported tensor type"); } -void InterpreterFunction::fwdExtractTensorInst( +void BoundInterpreterFunction::fwdExtractTensorInst( const glow::ExtractTensorInst *I) { Tensor *outT = getTensor(I->getDest()); Tensor *inT = getTensor(I->getSrc()); @@ -857,7 +860,7 @@ void InterpreterFunction::fwdExtractTensorInst( } template -void InterpreterFunction::fwdGatherInstImpl(const glow::GatherInst *I) { +void BoundInterpreterFunction::fwdGatherInstImpl(const glow::GatherInst *I) { Tensor *dataT = getTensor(I->getData()); auto &dataTy = dataT->getType(); Tensor *indicesT = getTensor(I->getIndices()); @@ -895,7 +898,7 @@ void InterpreterFunction::fwdGatherInstImpl(const glow::GatherInst *I) { } } -void InterpreterFunction::fwdGatherInst(const glow::GatherInst *I) { +void BoundInterpreterFunction::fwdGatherInst(const glow::GatherInst *I) { switch (I->getIndices()->getElementType()) { case ElemKind::Int64ITy: fwdGatherInstImpl(I); @@ -909,7 +912,7 @@ void InterpreterFunction::fwdGatherInst(const glow::GatherInst *I) { } template -void InterpreterFunction::fwdGatherRangesInstImpl( +void BoundInterpreterFunction::fwdGatherRangesInstImpl( const glow::GatherRangesInst *I) { Tensor *dataT = getTensor(I->getData()); auto &dataTy = dataT->getType(); @@ -977,7 +980,8 @@ void InterpreterFunction::fwdGatherRangesInstImpl( assert(grandTotalLen == (outP / dataElementSize)); } -void InterpreterFunction::fwdGatherRangesInst(const glow::GatherRangesInst *I) { +void BoundInterpreterFunction::fwdGatherRangesInst( + const glow::GatherRangesInst *I) { switch (I->getRanges()->getElementType()) { case ElemKind::Int64ITy: fwdGatherRangesInstImpl(I); @@ -990,7 +994,7 @@ void InterpreterFunction::fwdGatherRangesInst(const glow::GatherRangesInst *I) { } } -void InterpreterFunction::fwdScatterAssignInst( +void BoundInterpreterFunction::fwdScatterAssignInst( const glow::ScatterAssignInst *I) { Tensor *dataT = getTensor(I->getData()); Tensor *indicesT = getTensor(I->getIndices()); @@ -1010,7 +1014,8 @@ void InterpreterFunction::fwdScatterAssignInst( } template -void InterpreterFunction::fwdBatchOneHotImpl(const glow::BatchOneHotInst *I) { +void BoundInterpreterFunction::fwdBatchOneHotImpl( + const glow::BatchOneHotInst *I) { auto dataH = getWeightHandle(I->getData()); auto lengthsH = getWeightHandle(I->getLengths()); auto valuesH = getWeightHandle(I->getValues()); @@ -1034,7 +1039,8 @@ void InterpreterFunction::fwdBatchOneHotImpl(const glow::BatchOneHotInst *I) { } } -void InterpreterFunction::fwdBatchOneHotInst(const glow::BatchOneHotInst *I) { +void BoundInterpreterFunction::fwdBatchOneHotInst( + const glow::BatchOneHotInst *I) { switch (I->getData()->getElementType()) { case ElemKind::Int64ITy: fwdBatchOneHotImpl(I); @@ -1053,7 +1059,7 @@ void InterpreterFunction::fwdBatchOneHotInst(const glow::BatchOneHotInst *I) { //===----------------------------------------------------------------------===// template -void InterpreterFunction::fwdLocalResponseNormalizationInstFloatImpl( +void BoundInterpreterFunction::fwdLocalResponseNormalizationInstFloatImpl( const glow::LocalResponseNormalizationInst *I) { staticAssertFloatingPointType(ElemTy); @@ -1111,13 +1117,13 @@ void InterpreterFunction::fwdLocalResponseNormalizationInstFloatImpl( } } -void InterpreterFunction::fwdLocalResponseNormalizationInst( +void BoundInterpreterFunction::fwdLocalResponseNormalizationInst( const LocalResponseNormalizationInst *I) { dispatchFloatingPointImpl(fwdLocalResponseNormalizationInstFloatImpl, I->getSrc()->getElementType(), I); } -void InterpreterFunction::fwdLocalResponseNormalizationGradInst( +void BoundInterpreterFunction::fwdLocalResponseNormalizationGradInst( const glow::LocalResponseNormalizationGradInst *I) { auto inW = getWeightHandle(I->getSrc()); auto inG = getWeightHandle(I->getSrcGrad()); @@ -1190,7 +1196,8 @@ void InterpreterFunction::fwdLocalResponseNormalizationGradInst( //===----------------------------------------------------------------------===// // Arithmetic operations //===----------------------------------------------------------------------===// -void InterpreterFunction::fwdElementAddInstI8Impl(const ElementAddInst *I) { +void BoundInterpreterFunction::fwdElementAddInstI8Impl( + const ElementAddInst *I) { assert(getTensor(I->getLHS())->getType().isQuantizedType() && "Wrong function"); auto lhsTy = I->getLHS()->getType(); @@ -1224,7 +1231,8 @@ void InterpreterFunction::fwdElementAddInstI8Impl(const ElementAddInst *I) { } template -void InterpreterFunction::fwdElementAddInstFloatImpl(const ElementAddInst *I) { +void BoundInterpreterFunction::fwdElementAddInstFloatImpl( + const ElementAddInst *I) { staticAssertFloatingPointType(ElemTy); auto outW = getWeightHandle(I->getDest()); @@ -1235,7 +1243,7 @@ void InterpreterFunction::fwdElementAddInstFloatImpl(const ElementAddInst *I) { } } -void InterpreterFunction::fwdElementAddInst(const ElementAddInst *I) { +void BoundInterpreterFunction::fwdElementAddInst(const ElementAddInst *I) { if (getTensor(I->getLHS())->getType().isQuantizedType()) { fwdElementAddInstI8Impl(I); return; @@ -1246,7 +1254,8 @@ void InterpreterFunction::fwdElementAddInst(const ElementAddInst *I) { } template -void InterpreterFunction::fwdElementSubInstFloatImpl(const ElementSubInst *I) { +void BoundInterpreterFunction::fwdElementSubInstFloatImpl( + const ElementSubInst *I) { staticAssertFloatingPointType(ElemTy); auto outW = getWeightHandle(I->getDest()); @@ -1257,7 +1266,7 @@ void InterpreterFunction::fwdElementSubInstFloatImpl(const ElementSubInst *I) { } } -void InterpreterFunction::fwdElementSubInst(const ElementSubInst *I) { +void BoundInterpreterFunction::fwdElementSubInst(const ElementSubInst *I) { if (getTensor(I->getLHS())->getType().isQuantizedType()) { auto destTy = I->getDest()->getType(); auto lhsTy = I->getLHS()->getType(); @@ -1290,7 +1299,8 @@ void InterpreterFunction::fwdElementSubInst(const ElementSubInst *I) { } template -void InterpreterFunction::fwdElementMulInstFloatImpl(const ElementMulInst *I) { +void BoundInterpreterFunction::fwdElementMulInstFloatImpl( + const ElementMulInst *I) { staticAssertFloatingPointType(ElemTy); auto outW = getWeightHandle(I->getDest()); @@ -1301,7 +1311,7 @@ void InterpreterFunction::fwdElementMulInstFloatImpl(const ElementMulInst *I) { } } -void InterpreterFunction::fwdElementMulInst(const ElementMulInst *I) { +void BoundInterpreterFunction::fwdElementMulInst(const ElementMulInst *I) { if (getTensor(I->getLHS())->getType().isQuantizedType()) { auto lhsTy = I->getLHS()->getType(); auto rhsTy = I->getRHS()->getType(); @@ -1327,7 +1337,7 @@ void InterpreterFunction::fwdElementMulInst(const ElementMulInst *I) { I->getDest()->getElementType(), I); } -void InterpreterFunction::fwdElementDivInst(const ElementDivInst *I) { +void BoundInterpreterFunction::fwdElementDivInst(const ElementDivInst *I) { if (getTensor(I->getLHS())->getType().isQuantizedType()) { auto destTy = I->getDest()->getType(); auto lhsTy = I->getLHS()->getType(); @@ -1382,7 +1392,8 @@ void InterpreterFunction::fwdElementDivInst(const ElementDivInst *I) { } } -void InterpreterFunction::fwdElementMaxInstI8Impl(const ElementMaxInst *I) { +void BoundInterpreterFunction::fwdElementMaxInstI8Impl( + const ElementMaxInst *I) { assert(getTensor(I->getLHS())->getType().isQuantizedType() && "Wrong function"); auto lhsTy = I->getLHS()->getType(); @@ -1408,7 +1419,8 @@ void InterpreterFunction::fwdElementMaxInstI8Impl(const ElementMaxInst *I) { } template -void InterpreterFunction::fwdElementMaxInstFloatImpl(const ElementMaxInst *I) { +void BoundInterpreterFunction::fwdElementMaxInstFloatImpl( + const ElementMaxInst *I) { staticAssertFloatingPointType(ElemTy); auto outW = getWeightHandle(I->getDest()); @@ -1419,7 +1431,7 @@ void InterpreterFunction::fwdElementMaxInstFloatImpl(const ElementMaxInst *I) { } } -void InterpreterFunction::fwdElementMaxInst(const ElementMaxInst *I) { +void BoundInterpreterFunction::fwdElementMaxInst(const ElementMaxInst *I) { if (getTensor(I->getLHS())->getType().isQuantizedType()) { fwdElementMaxInstI8Impl(I); return; @@ -1430,7 +1442,8 @@ void InterpreterFunction::fwdElementMaxInst(const ElementMaxInst *I) { } template -void InterpreterFunction::fwdElementMinInstFloatImpl(const ElementMinInst *I) { +void BoundInterpreterFunction::fwdElementMinInstFloatImpl( + const ElementMinInst *I) { staticAssertFloatingPointType(ElemTy); auto outW = getWeightHandle(I->getDest()); @@ -1441,7 +1454,7 @@ void InterpreterFunction::fwdElementMinInstFloatImpl(const ElementMinInst *I) { } } -void InterpreterFunction::fwdElementMinInst(const ElementMinInst *I) { +void BoundInterpreterFunction::fwdElementMinInst(const ElementMinInst *I) { if (getTensor(I->getLHS())->getType().isQuantizedType()) { auto lhsTy = I->getLHS()->getType(); auto rhsTy = I->getRHS()->getType(); @@ -1471,7 +1484,7 @@ void InterpreterFunction::fwdElementMinInst(const ElementMinInst *I) { } template -void InterpreterFunction::fwdElementCmpLTEInstFloatImpl( +void BoundInterpreterFunction::fwdElementCmpLTEInstFloatImpl( const ElementCmpLTEInst *I) { staticAssertFloatingPointType(ElemTy); @@ -1485,7 +1498,8 @@ void InterpreterFunction::fwdElementCmpLTEInstFloatImpl( // For both quantized and non-quantized CmpLTE, we set the result to 1.0/0.0. // In the quantized case, we assume that the scale params are (1.0, 0). -void InterpreterFunction::fwdElementCmpLTEInst(const ElementCmpLTEInst *I) { +void BoundInterpreterFunction::fwdElementCmpLTEInst( + const ElementCmpLTEInst *I) { if (getTensor(I->getLHS())->getType().isQuantizedType()) { auto lhsTy = I->getLHS()->getType(); auto rhsTy = I->getRHS()->getType(); @@ -1513,7 +1527,8 @@ void InterpreterFunction::fwdElementCmpLTEInst(const ElementCmpLTEInst *I) { } template -void InterpreterFunction::fwdElementCmpEQInstImpl(const ElementCmpEQInst *I) { +void BoundInterpreterFunction::fwdElementCmpEQInstImpl( + const ElementCmpEQInst *I) { auto outW = getWeightHandle(I->getDest()); auto lhsW = getWeightHandle(I->getLHS()); auto rhsW = getWeightHandle(I->getRHS()); @@ -1522,7 +1537,7 @@ void InterpreterFunction::fwdElementCmpEQInstImpl(const ElementCmpEQInst *I) { } } -void InterpreterFunction::fwdElementCmpEQInst(const ElementCmpEQInst *I) { +void BoundInterpreterFunction::fwdElementCmpEQInst(const ElementCmpEQInst *I) { auto *T = getTensor(I->getDest()); switch (T->getElementType()) { @@ -1536,7 +1551,8 @@ void InterpreterFunction::fwdElementCmpEQInst(const ElementCmpEQInst *I) { } template -void InterpreterFunction::fwdElementPowInstFloatImpl(const ElementPowInst *I) { +void BoundInterpreterFunction::fwdElementPowInstFloatImpl( + const ElementPowInst *I) { staticAssertFloatingPointType(ElemTy); auto baseW = getWeightHandle(I->getLHS()); @@ -1547,13 +1563,14 @@ void InterpreterFunction::fwdElementPowInstFloatImpl(const ElementPowInst *I) { } } -void InterpreterFunction::fwdElementPowInst(const glow::ElementPowInst *I) { +void BoundInterpreterFunction::fwdElementPowInst( + const glow::ElementPowInst *I) { dispatchFloatingPointImpl(fwdElementPowInstFloatImpl, I->getLHS()->getElementType(), I); } template -void InterpreterFunction::fwdElementIsNaNInstFloatImpl( +void BoundInterpreterFunction::fwdElementIsNaNInstFloatImpl( const ElementIsNaNInst *I) { staticAssertFloatingPointType(ElemTy); @@ -1565,13 +1582,15 @@ void InterpreterFunction::fwdElementIsNaNInstFloatImpl( } } -void InterpreterFunction::fwdElementIsNaNInst(const glow::ElementIsNaNInst *I) { +void BoundInterpreterFunction::fwdElementIsNaNInst( + const glow::ElementIsNaNInst *I) { dispatchFloatingPointImpl(fwdElementIsNaNInstFloatImpl, I->getSrc()->getElementType(), I); } template -void InterpreterFunction::fwdElementLogInstFloatImpl(const ElementLogInst *I) { +void BoundInterpreterFunction::fwdElementLogInstFloatImpl( + const ElementLogInst *I) { staticAssertFloatingPointType(ElemTy); auto inW = getWeightHandle(I->getSrc()); @@ -1582,13 +1601,13 @@ void InterpreterFunction::fwdElementLogInstFloatImpl(const ElementLogInst *I) { } } -void InterpreterFunction::fwdElementLogInst(const ElementLogInst *I) { +void BoundInterpreterFunction::fwdElementLogInst(const ElementLogInst *I) { dispatchFloatingPointImpl(fwdElementLogInstFloatImpl, I->getSrc()->getElementType(), I); } template -void InterpreterFunction::fwdElementSelectInstFloatImpl( +void BoundInterpreterFunction::fwdElementSelectInstFloatImpl( const glow::ElementSelectInst *I) { staticAssertFloatingPointType(ElemTy); @@ -1601,7 +1620,7 @@ void InterpreterFunction::fwdElementSelectInstFloatImpl( } } -void InterpreterFunction::fwdElementSelectInst( +void BoundInterpreterFunction::fwdElementSelectInst( const glow::ElementSelectInst *I) { if (getTensor(I->getLHS())->getType().isQuantizedType()) { auto destTy = I->getDest()->getType(); @@ -1637,7 +1656,7 @@ void InterpreterFunction::fwdElementSelectInst( // Mat Mul //===----------------------------------------------------------------------===// template -void InterpreterFunction::fwdMatMulInstQuantizedImpl( +void BoundInterpreterFunction::fwdMatMulInstQuantizedImpl( const glow::MatMulInst *I) { assert(getTensor(I->getLHS())->getType().isQuantizedType()); auto lhs = getWeightHandle(I->getLHS()); @@ -1683,7 +1702,7 @@ void InterpreterFunction::fwdMatMulInstQuantizedImpl( } template -void InterpreterFunction::fwdMatMulInstFloatImpl(const MatMulInst *I) { +void BoundInterpreterFunction::fwdMatMulInstFloatImpl(const MatMulInst *I) { staticAssertFloatingPointType(ElemTy); auto lhs = getWeightHandle(I->getLHS()); @@ -1709,7 +1728,7 @@ void InterpreterFunction::fwdMatMulInstFloatImpl(const MatMulInst *I) { } } -void InterpreterFunction::fwdMatMulInst(const glow::MatMulInst *I) { +void BoundInterpreterFunction::fwdMatMulInst(const glow::MatMulInst *I) { if (getTensor(I->getLHS())->getType().isQuantizedType()) { dispatchQuantizedWithAccumulationImpl(fwdMatMulInstQuantizedImpl, I->getLHS()->getElementType(), I); @@ -1723,7 +1742,7 @@ void InterpreterFunction::fwdMatMulInst(const glow::MatMulInst *I) { //===----------------------------------------------------------------------===// // Row-wise quantized FC //===----------------------------------------------------------------------===// -void InterpreterFunction::fwdRowwiseQuantizedFullyConnectedInst( +void BoundInterpreterFunction::fwdRowwiseQuantizedFullyConnectedInst( const RowwiseQuantizedFullyConnectedInst *I) { auto inW = getWeightHandle(I->getSrc()); auto outW = getWeightHandle(I->getDest()); @@ -1811,7 +1830,7 @@ static void fwdBatchedAdd(Tensor *batch, Tensor *slice, Tensor *dest) { } template -void InterpreterFunction::fwdBatchedAddInstFloatImpl( +void BoundInterpreterFunction::fwdBatchedAddInstFloatImpl( const glow::BatchedAddInst *I) { staticAssertFloatingPointType(ElemTy); @@ -1834,7 +1853,8 @@ void InterpreterFunction::fwdBatchedAddInstFloatImpl( } } -void InterpreterFunction::fwdBatchedAddInst(const glow::BatchedAddInst *I) { +void BoundInterpreterFunction::fwdBatchedAddInst( + const glow::BatchedAddInst *I) { if (getTensor(I->getBatch())->getType().isQuantizedType()) { dispatchQuantizedImpl(fwdBatchedAdd, I->getSlice()->getElementType(), getTensor(I->getBatch()), getTensor(I->getSlice()), @@ -1846,7 +1866,7 @@ void InterpreterFunction::fwdBatchedAddInst(const glow::BatchedAddInst *I) { } template -void InterpreterFunction::fwdBatchedReduceAddInstFloatImpl( +void BoundInterpreterFunction::fwdBatchedReduceAddInstFloatImpl( Value *batch, Value *dest, unsigned_t axis, const ShapeVector &eBatchDims, const ShapeVector &eDestDims) { staticAssertFloatingPointType(ElemTy); @@ -1878,7 +1898,7 @@ void InterpreterFunction::fwdBatchedReduceAddInstFloatImpl( } } -void InterpreterFunction::fwdBatchedReduceAddInst( +void BoundInterpreterFunction::fwdBatchedReduceAddInst( const glow::BatchedReduceAddInst *I) { static_assert(max_tensor_dimensions == 6, "Loops below assume max_tensor_dimensions = 6."); @@ -1956,7 +1976,8 @@ void InterpreterFunction::fwdBatchedReduceAddInst( } template -void InterpreterFunction::fwdLengthsSumInstFloatImpl(const LengthsSumInst *I) { +void BoundInterpreterFunction::fwdLengthsSumInstFloatImpl( + const LengthsSumInst *I) { staticAssertFloatingPointType(ElemTy); auto out = getTensor(I->getDest()); @@ -1989,13 +2010,13 @@ void InterpreterFunction::fwdLengthsSumInstFloatImpl(const LengthsSumInst *I) { assert(offsetOut == out->size() && "All values in Dest should be written to"); } -void InterpreterFunction::fwdLengthsSumInst(const LengthsSumInst *I) { +void BoundInterpreterFunction::fwdLengthsSumInst(const LengthsSumInst *I) { dispatchFloatingPointImpl(fwdLengthsSumInstFloatImpl, I->getData()->getElementType(), I) } template -void InterpreterFunction::fwdSparseLengthsWeightedSumInstFloatImpl( +void BoundInterpreterFunction::fwdSparseLengthsWeightedSumInstFloatImpl( const SparseLengthsWeightedSumInst *I) { staticAssertFloatingPointType(ElemTy); @@ -2036,7 +2057,7 @@ void InterpreterFunction::fwdSparseLengthsWeightedSumInstFloatImpl( } } -void InterpreterFunction::fwdSparseLengthsWeightedSumInstI8Impl( +void BoundInterpreterFunction::fwdSparseLengthsWeightedSumInstI8Impl( const SparseLengthsWeightedSumInst *I) { auto out = getTensor(I->getDest()); @@ -2088,7 +2109,7 @@ void InterpreterFunction::fwdSparseLengthsWeightedSumInstI8Impl( } } -void InterpreterFunction::fwdSparseLengthsWeightedSumInst( +void BoundInterpreterFunction::fwdSparseLengthsWeightedSumInst( const SparseLengthsWeightedSumInst *I) { if (I->getDest()->getType()->isQuantizedType()) { return fwdSparseLengthsWeightedSumInstI8Impl(I); @@ -2097,7 +2118,7 @@ void InterpreterFunction::fwdSparseLengthsWeightedSumInst( I->getData()->getElementType(), I); } -void InterpreterFunction::fwdRowwiseQuantizedSparseLengthsWeightedSumInst( +void BoundInterpreterFunction::fwdRowwiseQuantizedSparseLengthsWeightedSumInst( const RowwiseQuantizedSparseLengthsWeightedSumInst *I) { auto *out = getTensor(I->getDest()); auto *data = getTensor(I->getData()); @@ -2146,7 +2167,8 @@ void InterpreterFunction::fwdRowwiseQuantizedSparseLengthsWeightedSumInst( } } -void InterpreterFunction::fwdLengthsToRangesInst(const LengthsToRangesInst *I) { +void BoundInterpreterFunction::fwdLengthsToRangesInst( + const LengthsToRangesInst *I) { auto ranges = getTensor(I->getDest())->getHandle(); auto lengths = getTensor(I->getLengths())->getHandle(); int32_t offset = 0; @@ -2159,7 +2181,7 @@ void InterpreterFunction::fwdLengthsToRangesInst(const LengthsToRangesInst *I) { } template -void InterpreterFunction::fwdSparseToDenseInstFloatImpl( +void BoundInterpreterFunction::fwdSparseToDenseInstFloatImpl( const SparseToDenseInst *I) { staticAssertFloatingPointType(ElemTy); @@ -2207,7 +2229,8 @@ void InterpreterFunction::fwdSparseToDenseInstFloatImpl( } } -void InterpreterFunction::fwdSparseToDenseInst(const SparseToDenseInst *I) { +void BoundInterpreterFunction::fwdSparseToDenseInst( + const SparseToDenseInst *I) { dispatchFloatingPointImpl(fwdSparseToDenseInstFloatImpl, I->getDest()->getElementType(), I); } @@ -2250,7 +2273,7 @@ static void fwdTopK(Tensor *outW, Tensor *indW, Tensor *inW, size_t k) { // Sorting operators //===----------------------------------------------------------------------===// -void InterpreterFunction::fwdTopKInst(const TopKInst *I) { +void BoundInterpreterFunction::fwdTopKInst(const TopKInst *I) { auto outW = getTensor(I->getValues()); auto indW = getTensor(I->getIndices()); auto inW = getTensor(I->getInput()); @@ -2268,11 +2291,12 @@ void InterpreterFunction::fwdTopKInst(const TopKInst *I) { // Tensor allocation operations //===----------------------------------------------------------------------===// -void InterpreterFunction::fwdAllocActivationInst(const AllocActivationInst *I) { +void BoundInterpreterFunction::fwdAllocActivationInst( + const AllocActivationInst *I) { getOrCreateTensor(I); } -void InterpreterFunction::fwdDeallocActivationInst( +void BoundInterpreterFunction::fwdDeallocActivationInst( const DeallocActivationInst *I) { deleteTensor(I->getSrc()); } @@ -2284,7 +2308,7 @@ void InterpreterFunction::fwdDeallocActivationInst( /// Prints a value of the instruction's operand. /// In most cases it will be the name of the variable and the value of the /// tensor. -void InterpreterFunction::fwdDebugPrintInst(const DebugPrintInst *I) { +void BoundInterpreterFunction::fwdDebugPrintInst(const DebugPrintInst *I) { auto *V = I->getSrc(); llvm::outs() << I->getName() << ": "; // Dump the content of a value. @@ -2297,7 +2321,7 @@ void InterpreterFunction::fwdDebugPrintInst(const DebugPrintInst *I) { //===----------------------------------------------------------------------===// // Instructions used by Quantization //===----------------------------------------------------------------------===// -void InterpreterFunction::fwdQuantizationProfileInst( +void BoundInterpreterFunction::fwdQuantizationProfileInst( const glow::QuantizationProfileInst *I) { auto inputTensor = getWeightHandle(I->getInputTensor()); auto currentHistogram = getWeightHandle(I->getHistogram()); @@ -2313,7 +2337,7 @@ void InterpreterFunction::fwdQuantizationProfileInst( /// Quantize floating point tensor. Scale and Offset are based on return type /// of the instruction \p I. -void InterpreterFunction::fwdQuantizeInst(const glow::QuantizeInst *I) { +void BoundInterpreterFunction::fwdQuantizeInst(const glow::QuantizeInst *I) { auto *srcTensor = getTensor(I->getSrc()); auto *destTensor = getTensor(I->getDest()); auto destTy = destTensor->getType(); @@ -2325,7 +2349,8 @@ void InterpreterFunction::fwdQuantizeInst(const glow::QuantizeInst *I) { /// Dequantize integer tensor. Scale and Offset are based /// on the source tensor type. -void InterpreterFunction::fwdDequantizeInst(const glow::DequantizeInst *I) { +void BoundInterpreterFunction::fwdDequantizeInst( + const glow::DequantizeInst *I) { auto *srcTensor = getTensor(I->getSrc()); auto *destTensor = getTensor(I->getDest()); auto destTy = destTensor->getType(); @@ -2335,7 +2360,7 @@ void InterpreterFunction::fwdDequantizeInst(const glow::DequantizeInst *I) { } template -void InterpreterFunction::fwdRescaleQuantizedInstImpl( +void BoundInterpreterFunction::fwdRescaleQuantizedInstImpl( Value *src, Value *dest, TensorQuantizationParams &srcQ, TensorQuantizationParams &destQ) { @@ -2348,7 +2373,7 @@ void InterpreterFunction::fwdRescaleQuantizedInstImpl( } } -void InterpreterFunction::fwdRescaleQuantizedInst( +void BoundInterpreterFunction::fwdRescaleQuantizedInst( const glow::RescaleQuantizedInst *I) { auto src = I->getSrc(); auto dest = I->getDest(); @@ -2362,7 +2387,8 @@ void InterpreterFunction::fwdRescaleQuantizedInst( src, dest, srcQ, destQ); } -void InterpreterFunction::fwdIntLookupTableInst(const IntLookupTableInst *I) { +void BoundInterpreterFunction::fwdIntLookupTableInst( + const IntLookupTableInst *I) { auto srcH = getWeightHandle(I->getSrc()); auto destH = getWeightHandle(I->getDest()); auto mappingH = getWeightHandle(I->getMapping()); @@ -2372,7 +2398,7 @@ void InterpreterFunction::fwdIntLookupTableInst(const IntLookupTableInst *I) { } } -void InterpreterFunction::fwdConvertToInst(const glow::ConvertToInst *I) { +void BoundInterpreterFunction::fwdConvertToInst(const glow::ConvertToInst *I) { Tensor *source = getTensor(I->getInput()); Tensor *dest = getTensor(I->getResult()); if (source->getType() == dest->getType()) { diff --git a/lib/Backends/OpenCL/OpenCL.cpp b/lib/Backends/OpenCL/OpenCL.cpp index 7038824159..6bfef00d11 100644 --- a/lib/Backends/OpenCL/OpenCL.cpp +++ b/lib/Backends/OpenCL/OpenCL.cpp @@ -621,7 +621,9 @@ static void topK(Tensor &outW, Tensor &indW, Tensor &inW, size_t k) { } } -void OpenCLFunction::execute() { +void OpenCLFunction::execute(Context *ctx) { + (void)ctx; + for (const auto &I : F_->getInstrs()) { // The kernels are named after the name of the instruction, plus the "W" // suffix to prevent name colissions for functions like 'tanh' that are also diff --git a/lib/Backends/OpenCL/OpenCL.h b/lib/Backends/OpenCL/OpenCL.h index 5226dbe271..83d5e31fd7 100644 --- a/lib/Backends/OpenCL/OpenCL.h +++ b/lib/Backends/OpenCL/OpenCL.h @@ -96,7 +96,7 @@ class OpenCLFunction final : public CompiledFunction { ///@{ ~OpenCLFunction() override; - void execute() override; + void execute(Context *ctx) override; ///@} /// Allocates on device buffer and copies Constant weights to device. void setupRuns() override; diff --git a/lib/ExecutionEngine/ExecutionEngine.cpp b/lib/ExecutionEngine/ExecutionEngine.cpp index 96d558ef80..cdf5b9f47b 100644 --- a/lib/ExecutionEngine/ExecutionEngine.cpp +++ b/lib/ExecutionEngine/ExecutionEngine.cpp @@ -82,7 +82,7 @@ void ExecutionEngine::run(Context &ctx) { ctx.allocate(M_.getPlaceholders()); function_->setupRuns(); function_->beforeRun(ctx); - function_->execute(); + function_->execute(&ctx); function_->afterRun(ctx); } diff --git a/tests/unittests/BackendCorrectnessTest.cpp b/tests/unittests/BackendCorrectnessTest.cpp index f96d8f68ff..b06d1dafa4 100644 --- a/tests/unittests/BackendCorrectnessTest.cpp +++ b/tests/unittests/BackendCorrectnessTest.cpp @@ -316,7 +316,7 @@ TEST_P(CPUOnly, dataParallelStackingTest) { auto function = backend.compileIR(std::move(M)); function->setupRuns(); function->beforeRun(ctx); - function->execute(); + function->execute(&ctx); function->afterRun(ctx); function->tearDownRuns(); auto H = outputTensor->getHandle(); diff --git a/tests/unittests/BackendTest.cpp b/tests/unittests/BackendTest.cpp index b96dfc87ec..cc4d0682e8 100644 --- a/tests/unittests/BackendTest.cpp +++ b/tests/unittests/BackendTest.cpp @@ -163,7 +163,7 @@ TEST_P(BackendTest, debugPrint) { auto function = backend->compileIR(std::move(IR)); function->setupRuns(); function->beforeRun(ctx); - function->execute(); + function->execute(&ctx); function->afterRun(ctx); function->tearDownRuns(); } diff --git a/tests/unittests/BackendTestUtils.h b/tests/unittests/BackendTestUtils.h index 9685374bbc..93ad2dcde7 100644 --- a/tests/unittests/BackendTestUtils.h +++ b/tests/unittests/BackendTestUtils.h @@ -23,11 +23,7 @@ namespace glow { /// MockBackend used only for unit testing. class MockBackend : public Backend { class MockFunction : public CompiledFunction { - void execute() override{}; - void setupRuns() override{}; - void beforeRun(const Context &ctx) override{}; - void afterRun(const Context &ctx) override{}; - void tearDownRuns() override{}; + void execute(Context *) override{}; }; std::unique_ptr compile(Function *F) const override { return llvm::make_unique();