From 36fb14b68bf6c2b95d875cb272c814f7cc059721 Mon Sep 17 00:00:00 2001 From: Supriya Rao Date: Thu, 23 Jul 2020 18:35:24 -0700 Subject: [PATCH] [quant] Add Graph Mode Passes to quantize EmbeddingBag operators (#41612) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/41612 This change adds preliminary support to quantize the EmbeddingBag operators. We currently support 4-bit and 8-bit quantization+packing of the weights. To quantize these operators, specify the operator name in the `custom_op_name` field of the NoopObserver. Based on the op name (4bit or 8bit) we call the corresponding quantization functions. Refer to the testplan for how to invoke the qconfig for the embedding_bag ops. Future versions of this will support 4-bit and 2-bit qtensors with native support to observe and quantize it. NB - This version assumes that the weights in the EmbeddingBag Module reside on the same device. Test Plan: python test/test_quantization.py TestQuantizeDynamicJitOps.test_embedding_bag Imported from OSS Reviewed By: vkuzo, jerryzh168 Differential Revision: D22609342 fbshipit-source-id: 23e33f44a451c26719e6e283e87fbf09b584c0e6 --- test/quantization/test_quantize_jit.py | 37 +++ torch/csrc/jit/passes/quantization/helper.cpp | 8 +- torch/csrc/jit/passes/quantization/helper.h | 7 + .../quantization/insert_quant_dequant.cpp | 230 ++++++++++++++---- torch/quantization/observer.py | 8 +- 5 files changed, 238 insertions(+), 52 deletions(-) diff --git a/test/quantization/test_quantize_jit.py b/test/quantization/test_quantize_jit.py index 9527689274..d2132f2e51 100644 --- a/test/quantization/test_quantize_jit.py +++ b/test/quantization/test_quantize_jit.py @@ -2813,6 +2813,43 @@ def forward(self, x): FunctionalLinear(weight, bias), x, "quantized::linear_dynamic", tracing=tracing, dynamic=True) + def test_embedding_bag(self): + class M(torch.nn.Module): + def __init__(self, weights): + super(M, self).__init__() + self.embedding1 = torch.nn.EmbeddingBag(num_embeddings=10, + embedding_dim=12, + include_last_offset=True, + _weight=weights, + mode='sum') + + self.embedding2 = torch.nn.EmbeddingBag(num_embeddings=10, + embedding_dim=12, + include_last_offset=True, + _weight=weights, + mode='sum') + + def forward(self, indices1, offsets1, indices2, offsets2): + e1 = self.embedding1(indices1, offsets1) + e2 = self.embedding2(indices2, offsets2) + return e1, e2 + + weights = torch.randn(10, 12, dtype=torch.float32) + module = M(weights) + m = torch.jit.script(module) + indices = torch.tensor([9, 6, 5, 7, 8, 8, 9, 2, 8, 6, 6, 9, 1, 6, 8, 8, 3, 2, 3, 6, 3, 6, 5, 7, 0, 8, 4, 6, 5, 8, 2, 3]) + offsets = torch.tensor([0, 19, 20, 28, 28, 32]) + + from torch.quantization import QConfigDynamic, NoopObserver + int4_dynamic_qconfig = QConfigDynamic(activation=NoopObserver.with_args(custom_op_name="embedding_bag_4bit"), + weight=NoopObserver.with_args(custom_op_name="embedding_bag_4bit")) + int8_dynamic_qconfig = QConfigDynamic(activation=NoopObserver.with_args(custom_op_name="embedding_bag_byte"), + weight=NoopObserver.with_args(custom_op_name="embedding_bag_byte")) + m = quantize_dynamic_jit(m, {'embedding1' : int4_dynamic_qconfig, 'embedding2' : int8_dynamic_qconfig}) + FileCheck().check("quantized::embedding_bag_4bit_rowwise_offsets") \ + .check_next("quantized::embedding_bag_byte_rowwise_offsets") \ + .run(m.graph) + class TestQuantizeJit(QuantizationTestCase): @override_qengines def test_single_linear(self): diff --git a/torch/csrc/jit/passes/quantization/helper.cpp b/torch/csrc/jit/passes/quantization/helper.cpp index ef7556c501..a89981f945 100644 --- a/torch/csrc/jit/passes/quantization/helper.cpp +++ b/torch/csrc/jit/passes/quantization/helper.cpp @@ -46,6 +46,7 @@ std::vector _static_quantizable_aten_funcs = { std::vector _dynamic_quantizable_call_funcs = { "linear", + "embedding_bag", }; std::vector _dynamic_quantizable_aten_funcs = { @@ -221,8 +222,6 @@ bool matchAtenFuncToUse( (!n.has_value() || n.value() == use.offset); } -// Check if `use` is a CallFunction of name `func_name` and if value -// `v` is the nth argument (if provided) of the function bool matchCallFuncToUse( const Use& use, const std::string& func_name, @@ -255,12 +254,15 @@ bool matchArgPattern( return false; } +// TODO add other op signatures. bool isWeight(Value* v) { bool result = matchArgPattern( v, AtenFuncArgs( {{"conv1d", 1}, {"conv2d", 1}, {"conv3d", 1}, {"linear", 1}}), - CallFuncArgs({{"linear", 2}})); + // embedding_bag - prim::CallFunction(%func, %input.1, %weight, + // %offsets.1, %7, %8, %9, %10, %9, %per_sample_weights.1, %13) + CallFuncArgs({{"linear", 2}, {"embedding_bag", 2}})); return result; } diff --git a/torch/csrc/jit/passes/quantization/helper.h b/torch/csrc/jit/passes/quantization/helper.h index 8b933e4c89..b7cd56115a 100644 --- a/torch/csrc/jit/passes/quantization/helper.h +++ b/torch/csrc/jit/passes/quantization/helper.h @@ -99,6 +99,13 @@ TORCH_API bool useQuantizable(const Use& use, QuantType quant_type); // Given a CallFunction node, extract the graph of the called function TORCH_API std::shared_ptr getCallFunctionGraph(Node* n); +// Check if `use` is a CallFunction of name `func_name` and if value +// `v` is the nth argument (if provided) of the function +bool matchCallFuncToUse( + const Use& use, + const std::string& func_name, + c10::optional nth_arg); + // =========== helper functions for Block ========= // checks if a block will always raise an Exception TORCH_API bool alwaysRaisesException(Block* block); diff --git a/torch/csrc/jit/passes/quantization/insert_quant_dequant.cpp b/torch/csrc/jit/passes/quantization/insert_quant_dequant.cpp index 7cc4efa2c2..b2c6e06a1a 100644 --- a/torch/csrc/jit/passes/quantization/insert_quant_dequant.cpp +++ b/torch/csrc/jit/passes/quantization/insert_quant_dequant.cpp @@ -218,6 +218,23 @@ Node* insertFP16CastOps(Graph* graph, Value* observer_out) { return cast_to_fp32; } +// find the observer for Value `v` and return the name of the observer +c10::optional findObserverName(Value* v) { + // Note that here we just check for the name of observer, but the ideally + // we should be comparing the type of observer, this is a temporary + // work around until data only clone of module.clone is supported. + Node* n = v->node(); + if (n->kind() == prim::CallMethod && n->s(attr::name) == "forward") { + auto module_instance = n->inputs().at(0); + if (module_instance->node()->kind() == prim::GetAttr && + module_instance->node()->s(attr::name).find("_observer_") != + std::string::npos) { + return module_instance->node()->s(attr::name); + } + } + return c10::nullopt; +} + bool isNoopObserver(Value* observer) { if (getModuleName(observer).has_value()) { auto name = getModuleName(observer).value(); @@ -228,6 +245,126 @@ bool isNoopObserver(Value* observer) { return false; } +bool isFP16NoopObserver(script::Module& module, Node* n) { + Value* v = n->output(); + auto observer = n->input(0); + auto observer_module = module.attr(findObserverName(v).value()).toModule(); + return (observer_module.attr("dtype") == at::ScalarType::Half) && + isNoopObserver(observer); +} + +c10::optional getEmbeddingBagObsName( + script::Module& module, + Node* n) { + Value* v = n->output(); + auto observer = n->input(0); + auto observer_module = module.attr(findObserverName(v).value()).toModule(); + if (observer_module.hasattr("custom_op")) { + auto op_name = observer_module.attr("custom_op").toStringRef(); + return isNoopObserver(observer) ? op_name : ""; + } + return c10::nullopt; +} + +bool isEmbeddingBagOp( + Node* observer, + c10::optional embedding_bag_name) { + return embedding_bag_name && + embedding_bag_name.value().find("embedding_bag_") != std::string::npos; +} + +// Insert quant and dequant nodes into the graph for both static and dynamic +// quant. +Node* insertQuantDequantNodes( + Value* self, + Node* observer, + const std::vector& qparam_names, + const std::string& quantize_func) { + Graph* g = observer->owningGraph(); + Value* observer_out = observer->output(); + Value* original_val = observer->input(1); + std::vector inputs = {observer_out}; + // Insert GetAttr nodes for quantization parameters + for (const auto& qparam_name : qparam_names) { + inputs.push_back(g->insertGetAttr(self, qparam_name)); + } + Node* quant = insertQuant( + g, + inputs, + at::Symbol::aten(quantize_func), + original_val->debugName() + ".quant"); + Node* dequant = insertDeQuant(g, quant->output(), original_val); + return dequant; +} + +Node* insertEmbeddingBagOps(Node* observer, const std::string& op_name) { + Graph* g = observer->owningGraph(); + auto observer_out = observer->output(); + + std::string prepack_fn, quant_fn; + if (op_name == "embedding_bag_4bit") { + prepack_fn = "quantized::embedding_bag_4bit_prepack"; + quant_fn = "quantized::embedding_bag_4bit_rowwise_offsets"; + } else if (op_name == "embedding_bag_byte") { + prepack_fn = "quantized::embedding_bag_byte_prepack"; + quant_fn = "quantized::embedding_bag_byte_rowwise_offsets"; + } else { + TORCH_INTERNAL_ASSERT( + "Graph Mode Quantization currently supports 4-bit and 8-bit embedding bag quantization."); + } + + std::vector prepack_inputs = {observer_out}; + std::vector uses = observer_out->uses(); + Node* embedding_bag_float_op; + // We expect that the output of the weight observer will be consumed by the + // embedding_bag operator. + for (const Use& use : uses) { + if (matchCallFuncToUse(use, "embedding_bag", 2)) { + embedding_bag_float_op = use.user; + } + } + TORCH_CHECK( + embedding_bag_float_op->inputs().size() == 11, + "Expecting FP EmbeddingBag operator to have 11 inputs"); + // Insert prepack op + Node* prepack = g->create(Symbol::fromQualString(prepack_fn), prepack_inputs); + g->insertNode(prepack); + + std::vector embedding_bag_inputs = + embedding_bag_float_op->inputs().vec(); + + // Create and insert quantized embedding op. + Value* none = g->insertConstant(IValue()); + Value* zero = g->insertConstant(IValue(0)); + embedding_bag_inputs[3]->setType(TensorType::get()); + + std::vector qembedding_bag_inputs = { + /* weight */ prepack->output(), + /* indices */ embedding_bag_inputs[1], + /* offsets */ embedding_bag_inputs[3], + /* scale_grad_by_freq */ embedding_bag_inputs[6], + /* mode */ zero, + /* sparse */ embedding_bag_inputs[8], + /* per_sample_weights_ */ embedding_bag_inputs[9]}; + + if (op_name == "embedding_bag_4bit") { + // 4-bit op has an extra input compressed_indices_mapping + qembedding_bag_inputs.push_back(none); + } + qembedding_bag_inputs.push_back(embedding_bag_inputs[10]); + + Node* qembedding_bag = + g->create(Symbol::fromQualString(quant_fn), qembedding_bag_inputs); + g->insertNode(qembedding_bag); + + embedding_bag_float_op->output()->replaceAllUsesWith( + qembedding_bag->output()); + embedding_bag_float_op->removeAllInputs(); + embedding_bag_float_op->destroy(); + g->lint(); + return qembedding_bag; +} + void insertQuantizationOps( Module& module, Value* self, @@ -249,50 +386,47 @@ void insertQuantizationOps( } Value* original_val = observer->input(1); Node *quant, *choose_qparams, *dequant; - if (quant_type == QuantType::DYNAMIC && isNoopObserver(observer->input(0))) { - dequant = insertFP16CastOps(g, observer_out); - } else if ( - quant_type == QuantType::DYNAMIC && !isWeight(module, observer_out)) { - Value* dtype = g->insertGetAttr(self, qparam_names.back()); - std::tie(choose_qparams, quant, dequant) = insertChooseQParamQuantDequant( - g, observer_out, dtype, at::Symbol::aten(quantize_func)); - } else { - // Else branch is executed for dynamic weight observers and all observers - // for static quant. - std::vector inputs = {observer_out}; - // Insert GetAttr nodes for quantization parameters - for (const auto& qparam_name : qparam_names) { - inputs.push_back(g->insertGetAttr(self, qparam_name)); + // Temporary solution to quantize embedding_bag operators. Will be re-written + // once we support quantization of embedding_bag weights. + auto embedding_bag_name = getEmbeddingBagObsName(module, observer); + if (quant_type == QuantType::DYNAMIC && + isEmbeddingBagOp(observer, embedding_bag_name)) { + if (isWeight(module, observer_out)) { + auto op_name = embedding_bag_name.value(); + Node* dequant = insertEmbeddingBagOps(observer, op_name); + observer_out->replaceAllUsesWith(original_val); + original_val->replaceAllUsesAfterNodeWith(dequant, dequant->output()); + } else { + // Special case for embedding bag operators indices input - we don't + // quantize the input but we still need to insert observers for it because + // the order of input and weight can be changed in the module code. + observer_out->replaceAllUsesWith(original_val); } - quant = insertQuant( - g, - inputs, - at::Symbol::aten(quantize_func), - original_val->debugName() + ".quant"); - dequant = insertDeQuant(g, quant->output(), original_val); + return; } + if (quant_type == QuantType::DYNAMIC) { + if (isFP16NoopObserver(module, observer)) { + dequant = insertFP16CastOps(g, observer_out); + } else if (!isWeight(module, observer_out)) { + // For activation tensors we insert choose_qparams, quant, dequant ops. + Value* dtype = g->insertGetAttr(self, qparam_names.back()); + std::tie(choose_qparams, quant, dequant) = insertChooseQParamQuantDequant( + g, observer_out, dtype, at::Symbol::aten(quantize_func)); + } else { + // For weight tensors we insert quant-dequant ops. + dequant = + insertQuantDequantNodes(self, observer, qparam_names, quantize_func); + } + } else { // Static quant + dequant = + insertQuantDequantNodes(self, observer, qparam_names, quantize_func); + } + observer_out->replaceAllUsesWith(original_val); original_val->replaceAllUsesAfterNodeWith(dequant, dequant->output()); } -// find the observer for Value `v` and return the name of the observer -c10::optional findObserverName(Value* v) { - // Note that here we just check for the name of observer, but the ideally - // we should be comparing the type of observer, this is a temporary - // work around until data only clone of module.clone is supported. - Node* n = v->node(); - if (n->kind() == prim::CallMethod && n->s(attr::name) == "forward") { - auto module_instance = n->inputs().at(0); - if (module_instance->node()->kind() == prim::GetAttr && - module_instance->node()->s(attr::name).find("_observer_") != - std::string::npos) { - return module_instance->node()->s(attr::name); - } - } - return c10::nullopt; -} - void ReplicateChooseQParamsQuantDequant(std::shared_ptr& graph) { const PatternInfo& dynamic_quant_pattern = PatternInfo::parse_from_str(R"( graph(%a, %reduce_range, %a_dtype): @@ -1186,15 +1320,19 @@ void InsertQuantDeQuantHelper::run( auto tp = getQSchemeAndQParamVector(module, n); checkQScheme(graph.get(), std::get<0>(tp)); auto qparam_map = std::get<1>(tp); - TORCH_INTERNAL_ASSERT( - qparam_name_map_for_node_.count(n), - "Expected to have a qparam_name_map for node:", - *n); - auto qparam_name_map = qparam_name_map_for_node_.at(n); - for (auto& pr : qparam_map) { - const auto& name = pr.first; - const auto& qparam = pr.second; - module._ivalue()->setAttr(qparam_name_map.at(name), qparam); + // We check the size here because for some observers (like NoopObserver) + // the qparams might be empty. + if (qparam_map.size() > 0) { + TORCH_INTERNAL_ASSERT( + qparam_name_map_for_node_.count(n), + "Expected to have a qparam_name_map for node:", + *n); + auto qparam_name_map = qparam_name_map_for_node_.at(n); + for (auto& pr : qparam_map) { + const auto& name = pr.first; + const auto& qparam = pr.second; + module._ivalue()->setAttr(qparam_name_map.at(name), qparam); + } } } return; diff --git a/torch/quantization/observer.py b/torch/quantization/observer.py index d34de3c244..60e6a5557b 100644 --- a/torch/quantization/observer.py +++ b/torch/quantization/observer.py @@ -1023,11 +1023,13 @@ class NoopObserver(ObserverBase): Args: dtype: Quantized data type + custom_op_name: (temporary) specify this observer for an operator that doesn't require any observation + (Can be used in Graph Mode Passes for special case ops). """ - def __init__(self, dtype=torch.float16): - if dtype != torch.float16: - raise ValueError("Only float16 quantization can be used without calibration process") + def __init__(self, dtype=torch.float16, custom_op_name=""): super(NoopObserver, self).__init__(dtype=dtype) + self.dtype = dtype + self.custom_op = custom_op_name def forward(self, x): return x