
Commit

[quant] Add Graph Mode Passes to quantize EmbeddingBag operators (#41612)

Summary:
Pull Request resolved: pytorch/pytorch#41612

This change adds preliminary support to quantize the EmbeddingBag operators. We currently support 4-bit and 8-bit quantization+packing of the weights.

To quantize these operators, specify the operator name in the `custom_op_name` field of the NoopObserver. Based on the op name (4-bit or 8-bit), we call the corresponding quantization functions.
Refer to the test plan for how to invoke the qconfig for the embedding_bag ops.
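
For example, a minimal sketch of the qconfig setup (mirroring the test below; `float_model` and the submodule names are placeholders, and the import path of `quantize_dynamic_jit` is assumed for this version):

    import torch
    from torch.quantization import QConfigDynamic, NoopObserver
    # import path assumed for this version of the graph mode API
    from torch.quantization.quantize_jit import quantize_dynamic_jit

    int4_qconfig = QConfigDynamic(
        activation=NoopObserver.with_args(custom_op_name="embedding_bag_4bit"),
        weight=NoopObserver.with_args(custom_op_name="embedding_bag_4bit"))
    int8_qconfig = QConfigDynamic(
        activation=NoopObserver.with_args(custom_op_name="embedding_bag_byte"),
        weight=NoopObserver.with_args(custom_op_name="embedding_bag_byte"))

    # float_model is a placeholder for a module with two EmbeddingBag submodules
    # named 'embedding1' and 'embedding2' (see the test below)
    m = torch.jit.script(float_model)
    m = quantize_dynamic_jit(m, {'embedding1': int4_qconfig,
                                 'embedding2': int8_qconfig})

After the pass runs, the embedding_bag calls dispatch to quantized::embedding_bag_4bit_rowwise_offsets and quantized::embedding_bag_byte_rowwise_offsets respectively, as checked in the test below.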

Future versions will support 4-bit and 2-bit qtensors with native support to observe and quantize them.

NB: This version assumes that the weights in the EmbeddingBag module reside on the same device.

Test Plan:
python test/test_quantization.py TestQuantizeDynamicJitOps.test_embedding_bag

Imported from OSS

Reviewed By: vkuzo, jerryzh168

Differential Revision: D22609342

fbshipit-source-id: 23e33f44a451c26719e6e283e87fbf09b584c0e6
supriyar authored and facebook-github-bot committed Jul 24, 2020
1 parent 401ac2d commit 36fb14b
Showing 5 changed files with 238 additions and 52 deletions.
37 changes: 37 additions & 0 deletions test/quantization/test_quantize_jit.py
@@ -2813,6 +2813,43 @@ def forward(self, x):
FunctionalLinear(weight, bias), x,
"quantized::linear_dynamic", tracing=tracing, dynamic=True)

def test_embedding_bag(self):
class M(torch.nn.Module):
def __init__(self, weights):
super(M, self).__init__()
self.embedding1 = torch.nn.EmbeddingBag(num_embeddings=10,
embedding_dim=12,
include_last_offset=True,
_weight=weights,
mode='sum')

self.embedding2 = torch.nn.EmbeddingBag(num_embeddings=10,
embedding_dim=12,
include_last_offset=True,
_weight=weights,
mode='sum')

def forward(self, indices1, offsets1, indices2, offsets2):
e1 = self.embedding1(indices1, offsets1)
e2 = self.embedding2(indices2, offsets2)
return e1, e2

weights = torch.randn(10, 12, dtype=torch.float32)
module = M(weights)
m = torch.jit.script(module)
indices = torch.tensor([9, 6, 5, 7, 8, 8, 9, 2, 8, 6, 6, 9, 1, 6, 8, 8, 3, 2, 3, 6, 3, 6, 5, 7, 0, 8, 4, 6, 5, 8, 2, 3])
offsets = torch.tensor([0, 19, 20, 28, 28, 32])

from torch.quantization import QConfigDynamic, NoopObserver
int4_dynamic_qconfig = QConfigDynamic(activation=NoopObserver.with_args(custom_op_name="embedding_bag_4bit"),
weight=NoopObserver.with_args(custom_op_name="embedding_bag_4bit"))
int8_dynamic_qconfig = QConfigDynamic(activation=NoopObserver.with_args(custom_op_name="embedding_bag_byte"),
weight=NoopObserver.with_args(custom_op_name="embedding_bag_byte"))
m = quantize_dynamic_jit(m, {'embedding1' : int4_dynamic_qconfig, 'embedding2' : int8_dynamic_qconfig})
FileCheck().check("quantized::embedding_bag_4bit_rowwise_offsets") \
.check_next("quantized::embedding_bag_byte_rowwise_offsets") \
.run(m.graph)

class TestQuantizeJit(QuantizationTestCase):
@override_qengines
def test_single_linear(self):
8 changes: 5 additions & 3 deletions torch/csrc/jit/passes/quantization/helper.cpp
@@ -46,6 +46,7 @@ std::vector<std::string> _static_quantizable_aten_funcs = {

std::vector<std::string> _dynamic_quantizable_call_funcs = {
"linear",
"embedding_bag",
};

std::vector<std::string> _dynamic_quantizable_aten_funcs = {
@@ -221,8 +222,6 @@ bool matchAtenFuncToUse(
(!n.has_value() || n.value() == use.offset);
}

// Check if `use` is a CallFunction of name `func_name` and if value
// `v` is the nth argument (if provided) of the function
bool matchCallFuncToUse(
const Use& use,
const std::string& func_name,
@@ -255,12 +254,15 @@ bool matchArgPattern(
return false;
}

// TODO add other op signatures.
bool isWeight(Value* v) {
bool result = matchArgPattern(
v,
AtenFuncArgs(
{{"conv1d", 1}, {"conv2d", 1}, {"conv3d", 1}, {"linear", 1}}),
CallFuncArgs({{"linear", 2}}));
// embedding_bag - prim::CallFunction(%func, %input.1, %weight,
// %offsets.1, %7, %8, %9, %10, %9, %per_sample_weights.1, %13)
CallFuncArgs({{"linear", 2}, {"embedding_bag", 2}}));
return result;
}

7 changes: 7 additions & 0 deletions torch/csrc/jit/passes/quantization/helper.h
@@ -99,6 +99,13 @@ TORCH_API bool useQuantizable(const Use& use, QuantType quant_type);
// Given a CallFunction node, extract the graph of the called function
TORCH_API std::shared_ptr<Graph> getCallFunctionGraph(Node* n);

// Check if `use` is a CallFunction of name `func_name` and if value
// `v` is the nth argument (if provided) of the function
bool matchCallFuncToUse(
const Use& use,
const std::string& func_name,
c10::optional<int> nth_arg);

// =========== helper functions for Block =========
// checks if a block will always raise an Exception
TORCH_API bool alwaysRaisesException(Block* block);
230 changes: 184 additions & 46 deletions torch/csrc/jit/passes/quantization/insert_quant_dequant.cpp
@@ -218,6 +218,23 @@ Node* insertFP16CastOps(Graph* graph, Value* observer_out) {
return cast_to_fp32;
}

// find the observer for Value `v` and return the name of the observer
c10::optional<std::string> findObserverName(Value* v) {
// Note that here we just check for the name of observer, but the ideally
// we should be comparing the type of observer, this is a temporary
// work around until data only clone of module.clone is supported.
Node* n = v->node();
if (n->kind() == prim::CallMethod && n->s(attr::name) == "forward") {
auto module_instance = n->inputs().at(0);
if (module_instance->node()->kind() == prim::GetAttr &&
module_instance->node()->s(attr::name).find("_observer_") !=
std::string::npos) {
return module_instance->node()->s(attr::name);
}
}
return c10::nullopt;
}

bool isNoopObserver(Value* observer) {
if (getModuleName(observer).has_value()) {
auto name = getModuleName(observer).value();
@@ -228,6 +245,126 @@ bool isNoopObserver(Value* observer) {
return false;
}

bool isFP16NoopObserver(script::Module& module, Node* n) {
Value* v = n->output();
auto observer = n->input(0);
auto observer_module = module.attr(findObserverName(v).value()).toModule();
return (observer_module.attr("dtype") == at::ScalarType::Half) &&
isNoopObserver(observer);
}

c10::optional<std::string> getEmbeddingBagObsName(
script::Module& module,
Node* n) {
Value* v = n->output();
auto observer = n->input(0);
auto observer_module = module.attr(findObserverName(v).value()).toModule();
if (observer_module.hasattr("custom_op")) {
auto op_name = observer_module.attr("custom_op").toStringRef();
return isNoopObserver(observer) ? op_name : "";
}
return c10::nullopt;
}

bool isEmbeddingBagOp(
Node* observer,
c10::optional<std::string> embedding_bag_name) {
return embedding_bag_name &&
embedding_bag_name.value().find("embedding_bag_") != std::string::npos;
}

// Insert quant and dequant nodes into the graph for both static and dynamic
// quant.
Node* insertQuantDequantNodes(
Value* self,
Node* observer,
const std::vector<std::string>& qparam_names,
const std::string& quantize_func) {
Graph* g = observer->owningGraph();
Value* observer_out = observer->output();
Value* original_val = observer->input(1);
std::vector<Value*> inputs = {observer_out};
// Insert GetAttr nodes for quantization parameters
for (const auto& qparam_name : qparam_names) {
inputs.push_back(g->insertGetAttr(self, qparam_name));
}
Node* quant = insertQuant(
g,
inputs,
at::Symbol::aten(quantize_func),
original_val->debugName() + ".quant");
Node* dequant = insertDeQuant(g, quant->output(), original_val);
return dequant;
}

Node* insertEmbeddingBagOps(Node* observer, const std::string& op_name) {
Graph* g = observer->owningGraph();
auto observer_out = observer->output();

std::string prepack_fn, quant_fn;
if (op_name == "embedding_bag_4bit") {
prepack_fn = "quantized::embedding_bag_4bit_prepack";
quant_fn = "quantized::embedding_bag_4bit_rowwise_offsets";
} else if (op_name == "embedding_bag_byte") {
prepack_fn = "quantized::embedding_bag_byte_prepack";
quant_fn = "quantized::embedding_bag_byte_rowwise_offsets";
} else {
TORCH_INTERNAL_ASSERT(
"Graph Mode Quantization currently supports 4-bit and 8-bit embedding bag quantization.");
}

std::vector<Value*> prepack_inputs = {observer_out};
std::vector<Use> uses = observer_out->uses();
Node* embedding_bag_float_op;
// We expect that the output of the weight observer will be consumed by the
// embedding_bag operator.
for (const Use& use : uses) {
if (matchCallFuncToUse(use, "embedding_bag", 2)) {
embedding_bag_float_op = use.user;
}
}
TORCH_CHECK(
embedding_bag_float_op->inputs().size() == 11,
"Expecting FP EmbeddingBag operator to have 11 inputs");
// Insert prepack op
Node* prepack = g->create(Symbol::fromQualString(prepack_fn), prepack_inputs);
g->insertNode(prepack);

std::vector<Value*> embedding_bag_inputs =
embedding_bag_float_op->inputs().vec();

// Create and insert quantized embedding op.
Value* none = g->insertConstant(IValue());
Value* zero = g->insertConstant(IValue(0));
embedding_bag_inputs[3]->setType(TensorType::get());

std::vector<Value*> qembedding_bag_inputs = {
/* weight */ prepack->output(),
/* indices */ embedding_bag_inputs[1],
/* offsets */ embedding_bag_inputs[3],
/* scale_grad_by_freq */ embedding_bag_inputs[6],
/* mode */ zero,
/* sparse */ embedding_bag_inputs[8],
/* per_sample_weights_ */ embedding_bag_inputs[9]};

if (op_name == "embedding_bag_4bit") {
// 4-bit op has an extra input compressed_indices_mapping
qembedding_bag_inputs.push_back(none);
}
qembedding_bag_inputs.push_back(embedding_bag_inputs[10]);

Node* qembedding_bag =
g->create(Symbol::fromQualString(quant_fn), qembedding_bag_inputs);
g->insertNode(qembedding_bag);

embedding_bag_float_op->output()->replaceAllUsesWith(
qembedding_bag->output());
embedding_bag_float_op->removeAllInputs();
embedding_bag_float_op->destroy();
g->lint();
return qembedding_bag;
}

void insertQuantizationOps(
Module& module,
Value* self,
@@ -249,50 +386,47 @@ void insertQuantizationOps(
}
Value* original_val = observer->input(1);
Node *quant, *choose_qparams, *dequant;
if (quant_type == QuantType::DYNAMIC && isNoopObserver(observer->input(0))) {
dequant = insertFP16CastOps(g, observer_out);
} else if (
quant_type == QuantType::DYNAMIC && !isWeight(module, observer_out)) {
Value* dtype = g->insertGetAttr(self, qparam_names.back());
std::tie(choose_qparams, quant, dequant) = insertChooseQParamQuantDequant(
g, observer_out, dtype, at::Symbol::aten(quantize_func));
} else {
// Else branch is executed for dynamic weight observers and all observers
// for static quant.
std::vector<Value*> inputs = {observer_out};
// Insert GetAttr nodes for quantization parameters
for (const auto& qparam_name : qparam_names) {
inputs.push_back(g->insertGetAttr(self, qparam_name));
// Temporary solution to quantize embedding_bag operators. Will be re-written
// once we support quantization of embedding_bag weights.
auto embedding_bag_name = getEmbeddingBagObsName(module, observer);
if (quant_type == QuantType::DYNAMIC &&
isEmbeddingBagOp(observer, embedding_bag_name)) {
if (isWeight(module, observer_out)) {
auto op_name = embedding_bag_name.value();
Node* dequant = insertEmbeddingBagOps(observer, op_name);
observer_out->replaceAllUsesWith(original_val);
original_val->replaceAllUsesAfterNodeWith(dequant, dequant->output());
} else {
// Special case for embedding bag operators indices input - we don't
// quantize the input but we still need to insert observers for it because
// the order of input and weight can be changed in the module code.
observer_out->replaceAllUsesWith(original_val);
}
quant = insertQuant(
g,
inputs,
at::Symbol::aten(quantize_func),
original_val->debugName() + ".quant");
dequant = insertDeQuant(g, quant->output(), original_val);
return;
}
if (quant_type == QuantType::DYNAMIC) {
if (isFP16NoopObserver(module, observer)) {
dequant = insertFP16CastOps(g, observer_out);
} else if (!isWeight(module, observer_out)) {
// For activation tensors we insert choose_qparams, quant, dequant ops.
Value* dtype = g->insertGetAttr(self, qparam_names.back());
std::tie(choose_qparams, quant, dequant) = insertChooseQParamQuantDequant(
g, observer_out, dtype, at::Symbol::aten(quantize_func));
} else {
// For weight tensors we insert quant-dequant ops.
dequant =
insertQuantDequantNodes(self, observer, qparam_names, quantize_func);
}
} else { // Static quant
dequant =
insertQuantDequantNodes(self, observer, qparam_names, quantize_func);
}

observer_out->replaceAllUsesWith(original_val);

original_val->replaceAllUsesAfterNodeWith(dequant, dequant->output());
}

// find the observer for Value `v` and return the name of the observer
c10::optional<std::string> findObserverName(Value* v) {
// Note that here we just check for the name of observer, but the ideally
// we should be comparing the type of observer, this is a temporary
// work around until data only clone of module.clone is supported.
Node* n = v->node();
if (n->kind() == prim::CallMethod && n->s(attr::name) == "forward") {
auto module_instance = n->inputs().at(0);
if (module_instance->node()->kind() == prim::GetAttr &&
module_instance->node()->s(attr::name).find("_observer_") !=
std::string::npos) {
return module_instance->node()->s(attr::name);
}
}
return c10::nullopt;
}

void ReplicateChooseQParamsQuantDequant(std::shared_ptr<Graph>& graph) {
const PatternInfo& dynamic_quant_pattern = PatternInfo::parse_from_str(R"(
graph(%a, %reduce_range, %a_dtype):
@@ -1186,15 +1320,19 @@ void InsertQuantDeQuantHelper::run(
auto tp = getQSchemeAndQParamVector(module, n);
checkQScheme(graph.get(), std::get<0>(tp));
auto qparam_map = std::get<1>(tp);
TORCH_INTERNAL_ASSERT(
qparam_name_map_for_node_.count(n),
"Expected to have a qparam_name_map for node:",
*n);
auto qparam_name_map = qparam_name_map_for_node_.at(n);
for (auto& pr : qparam_map) {
const auto& name = pr.first;
const auto& qparam = pr.second;
module._ivalue()->setAttr(qparam_name_map.at(name), qparam);
// We check the size here because for some observers (like NoopObserver)
// the qparams might be empty.
if (qparam_map.size() > 0) {
TORCH_INTERNAL_ASSERT(
qparam_name_map_for_node_.count(n),
"Expected to have a qparam_name_map for node:",
*n);
auto qparam_name_map = qparam_name_map_for_node_.at(n);
for (auto& pr : qparam_map) {
const auto& name = pr.first;
const auto& qparam = pr.second;
module._ivalue()->setAttr(qparam_name_map.at(name), qparam);
}
}
}
return;
8 changes: 5 additions & 3 deletions torch/quantization/observer.py
@@ -1023,11 +1023,13 @@ class NoopObserver(ObserverBase):
Args:
dtype: Quantized data type
custom_op_name: (temporary) specify this observer for an operator that doesn't require any observation
(Can be used in Graph Mode Passes for special case ops).
"""
def __init__(self, dtype=torch.float16):
if dtype != torch.float16:
raise ValueError("Only float16 quantization can be used without calibration process")
def __init__(self, dtype=torch.float16, custom_op_name=""):
super(NoopObserver, self).__init__(dtype=dtype)
self.dtype = dtype
self.custom_op = custom_op_name

def forward(self, x):
return x
