diff --git a/test/quantization/test_quantize_jit.py b/test/quantization/test_quantize_jit.py
index 9527689274..d2132f2e51 100644
--- a/test/quantization/test_quantize_jit.py
+++ b/test/quantization/test_quantize_jit.py
@@ -2813,6 +2813,43 @@ def forward(self, x):
                 FunctionalLinear(weight, bias), x, "quantized::linear_dynamic", tracing=tracing, dynamic=True)
 
+    def test_embedding_bag(self):
+        class M(torch.nn.Module):
+            def __init__(self, weights):
+                super(M, self).__init__()
+                self.embedding1 = torch.nn.EmbeddingBag(num_embeddings=10,
+                                                        embedding_dim=12,
+                                                        include_last_offset=True,
+                                                        _weight=weights,
+                                                        mode='sum')
+
+                self.embedding2 = torch.nn.EmbeddingBag(num_embeddings=10,
+                                                        embedding_dim=12,
+                                                        include_last_offset=True,
+                                                        _weight=weights,
+                                                        mode='sum')
+
+            def forward(self, indices1, offsets1, indices2, offsets2):
+                e1 = self.embedding1(indices1, offsets1)
+                e2 = self.embedding2(indices2, offsets2)
+                return e1, e2
+
+        weights = torch.randn(10, 12, dtype=torch.float32)
+        module = M(weights)
+        m = torch.jit.script(module)
+        indices = torch.tensor([9, 6, 5, 7, 8, 8, 9, 2, 8, 6, 6, 9, 1, 6, 8, 8, 3, 2, 3, 6, 3, 6, 5, 7, 0, 8, 4, 6, 5, 8, 2, 3])
+        offsets = torch.tensor([0, 19, 20, 28, 28, 32])
+
+        from torch.quantization import QConfigDynamic, NoopObserver
+        int4_dynamic_qconfig = QConfigDynamic(activation=NoopObserver.with_args(custom_op_name="embedding_bag_4bit"),
+                                              weight=NoopObserver.with_args(custom_op_name="embedding_bag_4bit"))
+        int8_dynamic_qconfig = QConfigDynamic(activation=NoopObserver.with_args(custom_op_name="embedding_bag_byte"),
+                                              weight=NoopObserver.with_args(custom_op_name="embedding_bag_byte"))
+        m = quantize_dynamic_jit(m, {'embedding1': int4_dynamic_qconfig, 'embedding2': int8_dynamic_qconfig})
+        FileCheck().check("quantized::embedding_bag_4bit_rowwise_offsets") \
+                   .check_next("quantized::embedding_bag_byte_rowwise_offsets") \
+                   .run(m.graph)
+
 class TestQuantizeJit(QuantizationTestCase):
     @override_qengines
     def test_single_linear(self):
diff --git a/torch/csrc/jit/passes/quantization/helper.cpp b/torch/csrc/jit/passes/quantization/helper.cpp
index ef7556c501..a89981f945 100644
--- a/torch/csrc/jit/passes/quantization/helper.cpp
+++ b/torch/csrc/jit/passes/quantization/helper.cpp
@@ -46,6 +46,7 @@ std::vector<std::string> _static_quantizable_aten_funcs = {
 
 std::vector<std::string> _dynamic_quantizable_call_funcs = {
     "linear",
+    "embedding_bag",
 };
 
 std::vector<std::string> _dynamic_quantizable_aten_funcs = {
@@ -221,8 +222,6 @@ bool matchAtenFuncToUse(
       (!n.has_value() || n.value() == use.offset);
 }
 
-// Check if `use` is a CallFunction of name `func_name` and if value
-// `v` is the nth argument (if provided) of the function
 bool matchCallFuncToUse(
     const Use& use,
     const std::string& func_name,
@@ -255,12 +254,15 @@ bool matchArgPattern(
   return false;
 }
 
+// TODO: add other op signatures.
 bool isWeight(Value* v) {
   bool result = matchArgPattern(
       v,
       AtenFuncArgs(
           {{"conv1d", 1}, {"conv2d", 1}, {"conv3d", 1}, {"linear", 1}}),
-      CallFuncArgs({{"linear", 2}}));
+      // embedding_bag - prim::CallFunction(%func, %input.1, %weight,
+      // %offsets.1, %7, %8, %9, %10, %9, %per_sample_weights.1, %13)
+      CallFuncArgs({{"linear", 2}, {"embedding_bag", 2}}));
   return result;
 }
diff --git a/torch/csrc/jit/passes/quantization/helper.h b/torch/csrc/jit/passes/quantization/helper.h
index 8b933e4c89..b7cd56115a 100644
--- a/torch/csrc/jit/passes/quantization/helper.h
+++ b/torch/csrc/jit/passes/quantization/helper.h
@@ -99,6 +99,13 @@ TORCH_API bool useQuantizable(const Use& use, QuantType quant_type);
 // Given a CallFunction node, extract the graph of the called function
 TORCH_API std::shared_ptr<Graph> getCallFunctionGraph(Node* n);
 
+// Check if `use` is a CallFunction of name `func_name` and if value
+// `v` is the nth argument (if provided) of the function
+bool matchCallFuncToUse(
+    const Use& use,
+    const std::string& func_name,
+    c10::optional<int> nth_arg);
+
 // =========== helper functions for Block =========
 // checks if a block will always raise an Exception
 TORCH_API bool alwaysRaisesException(Block* block);
diff --git a/torch/csrc/jit/passes/quantization/insert_quant_dequant.cpp b/torch/csrc/jit/passes/quantization/insert_quant_dequant.cpp
index 7cc4efa2c2..b2c6e06a1a 100644
--- a/torch/csrc/jit/passes/quantization/insert_quant_dequant.cpp
+++ b/torch/csrc/jit/passes/quantization/insert_quant_dequant.cpp
@@ -218,6 +218,23 @@ Node* insertFP16CastOps(Graph* graph, Value* observer_out) {
   return cast_to_fp32;
 }
 
+// find the observer for Value `v` and return the name of the observer
+c10::optional<std::string> findObserverName(Value* v) {
+  // Note that here we just check the name of the observer, but ideally
+  // we should be comparing the type of the observer; this is a temporary
+  // workaround until a data-only clone in module.clone is supported.
+  Node* n = v->node();
+  if (n->kind() == prim::CallMethod && n->s(attr::name) == "forward") {
+    auto module_instance = n->inputs().at(0);
+    if (module_instance->node()->kind() == prim::GetAttr &&
+        module_instance->node()->s(attr::name).find("_observer_") !=
+            std::string::npos) {
+      return module_instance->node()->s(attr::name);
+    }
+  }
+  return c10::nullopt;
+}
+
 bool isNoopObserver(Value* observer) {
   if (getModuleName(observer).has_value()) {
     auto name = getModuleName(observer).value();
@@ -228,6 +245,126 @@ bool isNoopObserver(Value* observer) {
   return false;
 }
 
+bool isFP16NoopObserver(script::Module& module, Node* n) {
+  Value* v = n->output();
+  auto observer = n->input(0);
+  auto observer_module = module.attr(findObserverName(v).value()).toModule();
+  return (observer_module.attr("dtype") == at::ScalarType::Half) &&
+      isNoopObserver(observer);
+}
+
+c10::optional<std::string> getEmbeddingBagObsName(
+    script::Module& module,
+    Node* n) {
+  Value* v = n->output();
+  auto observer = n->input(0);
+  auto observer_module = module.attr(findObserverName(v).value()).toModule();
+  if (observer_module.hasattr("custom_op")) {
+    auto op_name = observer_module.attr("custom_op").toStringRef();
+    return isNoopObserver(observer) ? op_name : "";
+  }
+  return c10::nullopt;
+}
+
+bool isEmbeddingBagOp(
+    Node* observer,
+    c10::optional<std::string> embedding_bag_name) {
+  return embedding_bag_name &&
+      embedding_bag_name.value().find("embedding_bag_") != std::string::npos;
+}
+
+// Insert quant and dequant nodes into the graph for both static and dynamic
+// quant.
+Node* insertQuantDequantNodes(
+    Value* self,
+    Node* observer,
+    const std::vector<std::string>& qparam_names,
+    const std::string& quantize_func) {
+  Graph* g = observer->owningGraph();
+  Value* observer_out = observer->output();
+  Value* original_val = observer->input(1);
+  std::vector<Value*> inputs = {observer_out};
+  // Insert GetAttr nodes for quantization parameters
+  for (const auto& qparam_name : qparam_names) {
+    inputs.push_back(g->insertGetAttr(self, qparam_name));
+  }
+  Node* quant = insertQuant(
+      g,
+      inputs,
+      at::Symbol::aten(quantize_func),
+      original_val->debugName() + ".quant");
+  Node* dequant = insertDeQuant(g, quant->output(), original_val);
+  return dequant;
+}
+
+Node* insertEmbeddingBagOps(Node* observer, const std::string& op_name) {
+  Graph* g = observer->owningGraph();
+  auto observer_out = observer->output();
+
+  std::string prepack_fn, quant_fn;
+  if (op_name == "embedding_bag_4bit") {
+    prepack_fn = "quantized::embedding_bag_4bit_prepack";
+    quant_fn = "quantized::embedding_bag_4bit_rowwise_offsets";
+  } else if (op_name == "embedding_bag_byte") {
+    prepack_fn = "quantized::embedding_bag_byte_prepack";
+    quant_fn = "quantized::embedding_bag_byte_rowwise_offsets";
+  } else {
+    TORCH_INTERNAL_ASSERT(
+        false,
+        "Graph Mode Quantization currently supports 4-bit and 8-bit embedding bag quantization.");
+  }
+
+  std::vector<Value*> prepack_inputs = {observer_out};
+  std::vector<Use> uses = observer_out->uses();
+  Node* embedding_bag_float_op = nullptr;
+  // We expect that the output of the weight observer will be consumed by the
+  // embedding_bag operator.
+  for (const Use& use : uses) {
+    if (matchCallFuncToUse(use, "embedding_bag", 2)) {
+      embedding_bag_float_op = use.user;
+    }
+  }
+  TORCH_CHECK(
+      embedding_bag_float_op != nullptr &&
+          embedding_bag_float_op->inputs().size() == 11,
+      "Expecting FP EmbeddingBag operator to have 11 inputs");
+  // Insert prepack op
+  Node* prepack = g->create(Symbol::fromQualString(prepack_fn), prepack_inputs);
+  g->insertNode(prepack);
+
+  std::vector<Value*> embedding_bag_inputs =
+      embedding_bag_float_op->inputs().vec();
+
+  // Create and insert quantized embedding op.
+  Value* none = g->insertConstant(IValue());
+  Value* zero = g->insertConstant(IValue(0));
+  embedding_bag_inputs[3]->setType(TensorType::get());
+
+  std::vector<Value*> qembedding_bag_inputs = {
+      /* weight */ prepack->output(),
+      /* indices */ embedding_bag_inputs[1],
+      /* offsets */ embedding_bag_inputs[3],
+      /* scale_grad_by_freq */ embedding_bag_inputs[6],
+      /* mode */ zero,
+      /* sparse */ embedding_bag_inputs[8],
+      /* per_sample_weights_ */ embedding_bag_inputs[9]};
+
+  if (op_name == "embedding_bag_4bit") {
+    // The 4-bit op has an extra input, compressed_indices_mapping.
+    qembedding_bag_inputs.push_back(none);
+  }
+  qembedding_bag_inputs.push_back(embedding_bag_inputs[10]);
+
+  Node* qembedding_bag =
+      g->create(Symbol::fromQualString(quant_fn), qembedding_bag_inputs);
+  g->insertNode(qembedding_bag);
+
+  embedding_bag_float_op->output()->replaceAllUsesWith(
+      qembedding_bag->output());
+  embedding_bag_float_op->removeAllInputs();
+  embedding_bag_float_op->destroy();
+  g->lint();
+  return qembedding_bag;
+}
+
 void insertQuantizationOps(
     Module& module,
     Value* self,
@@ -249,50 +386,47 @@ void insertQuantizationOps(
   }
   Value* original_val = observer->input(1);
   Node *quant, *choose_qparams, *dequant;
-  if (quant_type == QuantType::DYNAMIC && isNoopObserver(observer->input(0))) {
-    dequant = insertFP16CastOps(g, observer_out);
-  } else if (
-      quant_type == QuantType::DYNAMIC && !isWeight(module, observer_out)) {
-    Value* dtype = g->insertGetAttr(self, qparam_names.back());
-    std::tie(choose_qparams, quant, dequant) = insertChooseQParamQuantDequant(
-        g, observer_out, dtype, at::Symbol::aten(quantize_func));
-  } else {
-    // Else branch is executed for dynamic weight observers and all observers
-    // for static quant.
-    std::vector<Value*> inputs = {observer_out};
-    // Insert GetAttr nodes for quantization parameters
-    for (const auto& qparam_name : qparam_names) {
-      inputs.push_back(g->insertGetAttr(self, qparam_name));
+  // Temporary solution to quantize embedding_bag operators. Will be re-written
+  // once we support quantization of embedding_bag weights.
+  auto embedding_bag_name = getEmbeddingBagObsName(module, observer);
+  if (quant_type == QuantType::DYNAMIC &&
+      isEmbeddingBagOp(observer, embedding_bag_name)) {
+    if (isWeight(module, observer_out)) {
+      auto op_name = embedding_bag_name.value();
+      Node* dequant = insertEmbeddingBagOps(observer, op_name);
+      observer_out->replaceAllUsesWith(original_val);
+      original_val->replaceAllUsesAfterNodeWith(dequant, dequant->output());
+    } else {
+      // Special case for the embedding_bag operator's indices input - we
+      // don't quantize the input, but we still need to insert observers for
+      // it because the order of input and weight can be changed in the
+      // module code.
+      observer_out->replaceAllUsesWith(original_val);
     }
-    quant = insertQuant(
-        g,
-        inputs,
-        at::Symbol::aten(quantize_func),
-        original_val->debugName() + ".quant");
-    dequant = insertDeQuant(g, quant->output(), original_val);
+    return;
   }
+  if (quant_type == QuantType::DYNAMIC) {
+    if (isFP16NoopObserver(module, observer)) {
+      dequant = insertFP16CastOps(g, observer_out);
+    } else if (!isWeight(module, observer_out)) {
+      // For activation tensors we insert choose_qparams, quant, dequant ops.
+      Value* dtype = g->insertGetAttr(self, qparam_names.back());
+      std::tie(choose_qparams, quant, dequant) = insertChooseQParamQuantDequant(
+          g, observer_out, dtype, at::Symbol::aten(quantize_func));
+    } else {
+      // For weight tensors we insert quant-dequant ops.
+      dequant =
+          insertQuantDequantNodes(self, observer, qparam_names, quantize_func);
+    }
+  } else { // Static quant
+    dequant =
+        insertQuantDequantNodes(self, observer, qparam_names, quantize_func);
+  }
+
   observer_out->replaceAllUsesWith(original_val);
   original_val->replaceAllUsesAfterNodeWith(dequant, dequant->output());
 }
 
-// find the observer for Value `v` and return the name of the observer
-c10::optional<std::string> findObserverName(Value* v) {
-  // Note that here we just check for the name of observer, but the ideally
-  // we should be comparing the type of observer, this is a temporary
-  // work around until data only clone of module.clone is supported.
-  Node* n = v->node();
-  if (n->kind() == prim::CallMethod && n->s(attr::name) == "forward") {
-    auto module_instance = n->inputs().at(0);
-    if (module_instance->node()->kind() == prim::GetAttr &&
-        module_instance->node()->s(attr::name).find("_observer_") !=
-            std::string::npos) {
-      return module_instance->node()->s(attr::name);
-    }
-  }
-  return c10::nullopt;
-}
-
 void ReplicateChooseQParamsQuantDequant(std::shared_ptr<Graph>& graph) {
   const PatternInfo& dynamic_quant_pattern = PatternInfo::parse_from_str(R"(
 graph(%a, %reduce_range, %a_dtype):
@@ -1186,15 +1320,19 @@ void InsertQuantDeQuantHelper::run(
     auto tp = getQSchemeAndQParamVector(module, n);
     checkQScheme(graph.get(), std::get<0>(tp));
     auto qparam_map = std::get<1>(tp);
-    TORCH_INTERNAL_ASSERT(
-        qparam_name_map_for_node_.count(n),
-        "Expected to have a qparam_name_map for node:",
-        *n);
-    auto qparam_name_map = qparam_name_map_for_node_.at(n);
-    for (auto& pr : qparam_map) {
-      const auto& name = pr.first;
-      const auto& qparam = pr.second;
-      module._ivalue()->setAttr(qparam_name_map.at(name), qparam);
+    // We check the size here because for some observers (like NoopObserver)
+    // the qparams might be empty.
+    if (qparam_map.size() > 0) {
+      TORCH_INTERNAL_ASSERT(
+          qparam_name_map_for_node_.count(n),
+          "Expected to have a qparam_name_map for node:",
+          *n);
+      auto qparam_name_map = qparam_name_map_for_node_.at(n);
+      for (auto& pr : qparam_map) {
+        const auto& name = pr.first;
+        const auto& qparam = pr.second;
+        module._ivalue()->setAttr(qparam_name_map.at(name), qparam);
+      }
     }
   }
   return;
diff --git a/torch/quantization/observer.py b/torch/quantization/observer.py
index d34de3c244..60e6a5557b 100644
--- a/torch/quantization/observer.py
+++ b/torch/quantization/observer.py
@@ -1023,11 +1023,13 @@ class NoopObserver(ObserverBase):
 
     Args:
         dtype: Quantized data type
+        custom_op_name: (temporary) name of the op this observer is attached to, for operators
+            that don't require any observation (used by graph mode passes for special-case ops).
     """
-    def __init__(self, dtype=torch.float16):
-        if dtype != torch.float16:
-            raise ValueError("Only float16 quantization can be used without calibration process")
+    def __init__(self, dtype=torch.float16, custom_op_name=""):
         super(NoopObserver, self).__init__(dtype=dtype)
+        self.dtype = dtype
+        self.custom_op = custom_op_name
 
     def forward(self, x):
         return x
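---

Usage sketch (not part of the patch): the new test above exercises both variants end to end; a
minimal standalone version of the flow looks like the sketch below. The module name `Lookup` and
the qconfig-dict key `'embedding'` are illustrative, not taken from the patch.

    import torch
    from torch.quantization import QConfigDynamic, NoopObserver
    from torch.quantization.quantize_jit import quantize_dynamic_jit

    class Lookup(torch.nn.Module):
        def __init__(self):
            super(Lookup, self).__init__()
            # include_last_offset=True and mode='sum' match the layout expected
            # by the quantized::embedding_bag_*_rowwise_offsets operators.
            self.embedding = torch.nn.EmbeddingBag(num_embeddings=10,
                                                   embedding_dim=12,
                                                   include_last_offset=True,
                                                   mode='sum')

        def forward(self, indices, offsets):
            return self.embedding(indices, offsets)

    m = torch.jit.script(Lookup())
    # custom_op_name="embedding_bag_byte" selects the 8-bit rowwise op;
    # custom_op_name="embedding_bag_4bit" selects the 4-bit variant.
    qconfig = QConfigDynamic(
        activation=NoopObserver.with_args(custom_op_name="embedding_bag_byte"),
        weight=NoopObserver.with_args(custom_op_name="embedding_bag_byte"))
    m = quantize_dynamic_jit(m, {'embedding': qconfig})
    print(m.graph)  # should now contain quantized::embedding_bag_byte_rowwise_offsets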