diff --git a/torchao/prototype/awq/__init__.py b/torchao/prototype/awq/__init__.py
index 570b0821d4..4f34d5375a 100644
--- a/torchao/prototype/awq/__init__.py
+++ b/torchao/prototype/awq/__init__.py
@@ -1,8 +1,9 @@
-from .api import awq_uintx, insert_awq_observer_
+from .api import AWQConfig, awq_uintx, insert_awq_observer_
 from .core import AWQObservedLinear
 
 __all__ = [
     "awq_uintx",
     "insert_awq_observer_",
     "AWQObservedLinear",
+    "AWQConfig",
 ]
diff --git a/torchao/prototype/awq/api.py b/torchao/prototype/awq/api.py
index 5806c29ce6..0741c1c7ba 100644
--- a/torchao/prototype/awq/api.py
+++ b/torchao/prototype/awq/api.py
@@ -30,12 +30,15 @@
     ZeroPointDomain,
 )
 from torchao.quantization.transform_module import (
+    _QUANTIZE_CONFIG_HANDLER,
     register_quantize_module_handler,
 )
+from torchao.utils import DummyModule
 
 from .core import (
     AWQObservedLinear,
     AWQObserver,
+    AWQObserver2,
 )
 
 assert len(_DTYPE_TO_BIT_WIDTH) > 0, (
@@ -50,6 +53,7 @@ def insert_awq_observer_(
     quant_dtype: torch.dtype = torch.uint4,
     scale_search_space_size: int = 20,
     group_size: int = 128,
+    base_config: Optional[AOBaseConfig] = None,
 ):
     """
     Inserts AWQObserver into Linear layers of a given model.
@@ -80,22 +84,32 @@ def insert_awq_observer_(
 
     def replace_with_observer(layer):
         # creates observer and replaces linear layers with AWQObservedLinear layers
-        observer = AWQObserver(
-            layer.weight,
-            layer.bias,
-            quantization_granularity,
-            mapping_type,
-            quant_dtype,
-            n_validation_examples,
-            validation_sequence_len,
-            scale_search_space_size,
-            preserve_zero=preserve_zero,
-            zero_point_domain=zero_point_domain,
-            zero_point_dtype=zero_point_dtype,
-            quant_min=quant_min,
-            quant_max=quant_max,
-            eps=eps,
-        )
+        if base_config is None:
+            observer = AWQObserver(
+                layer.weight,
+                layer.bias,
+                quantization_granularity,
+                mapping_type,
+                quant_dtype,
+                n_validation_examples,
+                validation_sequence_len,
+                scale_search_space_size,
+                preserve_zero=preserve_zero,
+                zero_point_domain=zero_point_domain,
+                zero_point_dtype=zero_point_dtype,
+                quant_min=quant_min,
+                quant_max=quant_max,
+                eps=eps,
+            )
+        else:
+            observer = AWQObserver2(
+                layer.weight,
+                layer.bias,
+                base_config,
+                n_validation_examples,
+                validation_sequence_len,
+                scale_search_space_size,
+            )
         return AWQObservedLinear.from_float(layer, observer)
 
     _replace_with_custom_fn_if_matches_filter(model, replace_with_observer, _is_linear)
@@ -194,3 +208,50 @@ def _awq_uintx_transform(
     linear.extra_repr = types.MethodType(_linear_extra_repr, module)
     linear.bias = observed_linear.bias
     return linear
+
+
+@dataclass
+class AWQConfig(AOBaseConfig):
+    """
+    Configuration for applying AWQ to linear layers when passed into quantize_()
+
+    Args:
+        base_config: The base quantization config that defines how each weight is quantized
+            after the AWQ equalization scale has been applied (e.g. Int8DynamicActivationIntxWeightConfig)
+        set_inductor_config: if True, adjusts `torchinductor` settings to recommended values.
+    """
+
+    base_config: AOBaseConfig
+    set_inductor_config: bool = True
+
+
+@register_quantize_module_handler(AWQConfig)
+def _awq_transform(
+    module: torch.nn.Module,
+    config: AWQConfig,
+) -> torch.nn.Module:
+    if config.set_inductor_config:
+        torchao.quantization.utils.recommended_inductor_config_setter()
+
+    observed_linear = module
+    equalization_scale = observed_linear.act_obs.calculate_qparams()
+
+    base_config_handler = _QUANTIZE_CONFIG_HANDLER[type(config.base_config)]
+    dummy_mod = DummyModule(observed_linear.weight * equalization_scale)
+    quant_mod = base_config_handler(dummy_mod, config.base_config)
+    qw = quant_mod.weight
+    qw = to_weight_tensor_with_linear_activation_scale_metadata(qw, equalization_scale)
+
+    linear = torch.nn.Linear(
+        observed_linear.in_features,
+        observed_linear.out_features,
+        observed_linear.bias is not None,
+        device=observed_linear.weight.device,
+        dtype=observed_linear.weight.dtype,
+    )
+    linear.weight = torch.nn.Parameter(qw, requires_grad=False)
+    linear.extra_repr = types.MethodType(_linear_extra_repr, module)
+    linear.bias = observed_linear.bias
+    return linear
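For reviewers, here is roughly how the new `base_config` path is meant to be driven end to end (a minimal sketch that mirrors example2.py below; the toy model, dtype, and granularity choices are illustrative, not part of this change):

import torch
from torchao.prototype.awq import AWQConfig, AWQObservedLinear, insert_awq_observer_
from torchao.quantization import Int8DynamicActivationIntxWeightConfig, quantize_
from torchao.quantization.granularity import PerGroup

model = torch.nn.Sequential(torch.nn.Linear(128, 128)).eval()

# the base config decides how each weight is quantized once the AWQ scale is found
base_config = Int8DynamicActivationIntxWeightConfig(
    weight_dtype=torch.int4, weight_granularity=PerGroup(32)
)

# 1. replace Linear layers with AWQObservedLinear (AWQObserver2 is used when base_config is given)
insert_awq_observer_(model, 1, 16, base_config=base_config)

# 2. run calibration data through the model so the observers see activations
for _ in range(4):
    model(torch.randn(1, 16, 128))

# 3. quantize only the observed linears, wrapping the base config in AWQConfig
quantize_(model, AWQConfig(base_config), lambda m, fqn: isinstance(m, AWQObservedLinear))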
+ """ + + base_config: AOBaseConfig + set_inductor_config: bool = True + + +@register_quantize_module_handler(AWQConfig) +def _awq_transform( + module: torch.nn.Module, + config: AWQUIntXConfig, +) -> torch.nn.Module: + if config.set_inductor_config: + torchao.quantization.utils.recommended_inductor_config_setter() + + observed_linear = module + equalization_scale = observed_linear.act_obs.calculate_qparams() + + base_config_handler = _QUANTIZE_CONFIG_HANDLER[type(config.base_config)] + dummy_mod = DummyModule(observed_linear.weight * equalization_scale) + quant_mod = base_config_handler(dummy_mod, config.base_config) + qw = quant_mod.weight + qw = to_weight_tensor_with_linear_activation_scale_metadata(qw, equalization_scale) + + linear = torch.nn.Linear( + observed_linear.in_features, + observed_linear.out_features, + observed_linear.bias != None, + device=observed_linear.weight.device, + dtype=observed_linear.weight.dtype, + ) + linear.weight = torch.nn.Parameter(qw, requires_grad=False) + linear.extra_repr = types.MethodType(_linear_extra_repr, module) + linear.bias = observed_linear.bias + return linear diff --git a/torchao/prototype/awq/core.py b/torchao/prototype/awq/core.py index e5ee96fea2..0524af64a9 100644 --- a/torchao/prototype/awq/core.py +++ b/torchao/prototype/awq/core.py @@ -8,8 +8,10 @@ import torch import torch.nn.functional as F +from torchao.core.config import AOBaseConfig from torchao.dtypes import to_affine_quantized_intx from torchao.dtypes.uintx.uintx_layout import UintxLayout +from torchao.quantization import Int8DynamicActivationIntxWeightConfig from torchao.quantization.granularity import Granularity from torchao.quantization.observer import ( AffineQuantizedObserverBase, @@ -18,6 +20,10 @@ MappingType, ZeroPointDomain, ) +from torchao.quantization.transform_module import ( + _QUANTIZE_CONFIG_HANDLER, +) +from torchao.utils import DummyModule class AWQObserver(AffineQuantizedObserverBase): @@ -145,6 +151,134 @@ def calculate_qparams(self): return best_scales.detach() +class AWQObserver2(AffineQuantizedObserverBase): + def __init__( + self, + weight: torch.Tensor, + bias: torch.Tensor, + config: AOBaseConfig, + n_validation_examples: int, + validation_sequence_len: int, + scale_search_space_size: int = 20, + base_config: Optional[AOBaseConfig] = None, + ): + """ + A custom observer for Activation aware Weight Quantization (AWQ) + + Args: + weight: The weight tensor to be observed. + bias: The bias tensor to be observed. + quantization_granularity: Granularity which specifies how many weights share the same scale/zero point + input_dtype: The data type of the input tensor. + mapping_type: Always set to asymmetric + target_dtype: The target data type of the quantized tensor + n_validation_examples: Number of examples used to calibrate observer + validation_sequence_len: Number of tokens in each example + scale_search_space_size: The number of scales to search for. + quant_min: The minimum quantized value + quant_max: The maximum quantized value + eps: The minimum scale. + scale_dtype: The data type of the scale tensor. + zero_point_dtype: The data type of the zero point tensor. + preserve_zero: A flag to indicate whether we need zero to be exactly + representable or not. + zero_point_domain: The domain of the zero point. 
+ """ + self.base_config = base_config + quant_min = getattr(config, "quant_min", None) + quant_max = getattr(config, "quant_max", None) + + assert isinstance(base_config, Int8DynamicActivationIntxWeightConfig) + # TODO: + quantization_granularity = base_config.weight_granularity + target_dtype = base_config.weight_dtype + mapping_type = base_config.weight_mapping_type + + # TODO: + super().__init__( + mapping_type, + target_dtype, + quantization_granularity, + quant_min=quant_min, + quant_max=quant_max, + ) + self.quantization_granularity = quantization_granularity + self.weight = weight + self.bias = bias + self.n_validation_examples = n_validation_examples + self.validation_sequence_len = validation_sequence_len + self.calibration_token_count = 0 + self.inputs = [] + self.outputs = [] + self.scale_options = scale_search_space_size + self.device = self.weight.device + self.average = torch.zeros((1, weight.shape[1]), device=self.device) + if self.bias is not None: + self.bias.to(self.device) + + @torch.no_grad() + def forward(self, input: torch.Tensor, output: torch.Tensor): + # import pdb + # pdb.set_trace() + # print(input.shape, input.abs().sum(1).shape, self.average.shape) + if len(self.inputs) < self.n_validation_examples: + self.inputs.append(input.to("cpu")) + self.outputs.append(output.to("cpu")) + self.calibration_token_count += input.shape[-2] + self.average += input.abs().sum(-2) + + def calculate_qparams(self): + # import pdb + # pdb.set_trace() + assert self.outputs != None, ( + "calibrate observer first by running model on exemplar data" + ) + self.average /= self.calibration_token_count + for i in range(self.n_validation_examples): + self.inputs[i] = self.inputs[i].to(self.device) + self.outputs[i] = self.outputs[i].to(self.device) + + best_loss = float("inf") + best_scales = None + for i in range(self.scale_options): + ratio = i * 1 / self.scale_options + scales = self.average.pow(ratio).to(self.weight.dtype) + scales = scales / (scales.max() * scales.min()).sqrt() + # layout = UintxLayout(self.target_dtype) + # # regardless of weight dtype, we have to store as packed uint8 tensors + # tensor_dtype = torch.uint8 + # w = to_affine_quantized_intx( + # self.weight * scales, + # self.mapping_type, + # (1, self.quantization_granularity.group_size), + # tensor_dtype, + # quant_min=self.quant_min, + # quant_max=self.quant_max, + # eps=self.eps, + # scale_dtype=self.scale_dtype, + # zero_point_dtype=self.zero_point_dtype, + # preserve_zero=self.preserve_zero, + # zero_point_domain=self.zero_point_domain, + # _layout=layout, + # ) + base_config_handler = _QUANTIZE_CONFIG_HANDLER[type(self.base_config)] + dummy_mod = DummyModule(self.weight * scales) + quant_mod = base_config_handler(dummy_mod, self.base_config) + w = quant_mod.weight + + loss = 0 + for i in range(self.n_validation_examples): + q_out = F.linear(self.inputs[i] / scales, w, self.bias) + loss += (self.outputs[i] - q_out).pow(2).mean().item() + if loss < best_loss: + best_scales = scales + best_loss = loss + for i in range(self.n_validation_examples): + self.inputs[i].to("cpu") + self.outputs[i].to("cpu") + return best_scales.detach() + + class AWQObservedLinear(torch.nn.Linear): def __init__( self, diff --git a/torchao/prototype/awq/example2.py b/torchao/prototype/awq/example2.py new file mode 100644 index 0000000000..4be2ed74e3 --- /dev/null +++ b/torchao/prototype/awq/example2.py @@ -0,0 +1,351 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+
+
+# adapted from Hicham Badri (@mobicham)
+def benchmark(model, tokenizer, max_length, tasks=None, device="cuda"):
+    import lm_eval
+    import numpy as np
+
+    model.eval()
+    model.config.use_cache = False
+    try:
+        lm_eval.tasks.initialize_tasks()
+    except Exception:
+        pass
+    model_eval = lm_eval.models.huggingface.HFLM(pretrained=model, tokenizer=tokenizer)
+    eval_batch_size = 1  # 8
+    if tasks is None:
+        tasks = [
+            "PPL",
+            "truthfulqa_mc2",
+            "winogrande",
+            "arc_challenge",
+            "hellaswag",
+            "gsm8k",
+            "mmlu",
+        ]
+    results = {}
+    if "PPL" in tasks:
+        results["perplexity"] = wiki2_eval(
+            model, tokenizer, 512, verbose=True, device=device
+        )
+
+    ############################################
+    if "truthfulqa_mc2" in tasks:
+        for task in [("truthfulqa_mc2", 0)]:
+            tag, fewshot = task
+            results[tag] = lm_eval.evaluator.simple_evaluate(
+                model_eval, tasks=[tag], num_fewshot=fewshot, batch_size=eval_batch_size
+            )["results"]
+            print(tag, results[tag])
+    if "winogrande" in tasks:
+        for task in [("winogrande", 5)]:
+            tag, fewshot = task
+            results[tag] = lm_eval.evaluator.simple_evaluate(
+                model_eval, tasks=[tag], num_fewshot=fewshot, batch_size=eval_batch_size
+            )["results"]
+            print(tag, results[tag])
+    if "arc_challenge" in tasks:
+        for task in [("arc_challenge", 25)]:
+            tag, fewshot = task
+            results[tag] = lm_eval.evaluator.simple_evaluate(
+                model_eval, tasks=[tag], num_fewshot=fewshot, batch_size=eval_batch_size
+            )["results"]
+            print(tag, results[tag])
+
+    # ############################################
+    if "hellaswag" in tasks:
+        for task in [("hellaswag", 10)]:
+            tag, fewshot = task
+            results[tag] = lm_eval.evaluator.simple_evaluate(
+                model_eval, tasks=[tag], num_fewshot=fewshot, batch_size=eval_batch_size
+            )["results"]
+            print(tag, results[tag])
+    if "gsm8k" in tasks:
+        for task in [("gsm8k", 5)]:
+            tag, fewshot = task
+            results[tag] = lm_eval.evaluator.simple_evaluate(
+                model_eval, tasks=[tag], num_fewshot=fewshot, batch_size=eval_batch_size
+            )["results"]
+            print(tag, results[tag])
+    # ############################################
+
+    if "mmlu" in tasks:
+        # MMLU
+        results_mmlu = {}
+        for task in [("mmlu", 5)]:
+            tag, fewshot = task
+            results_mmlu[tag] = lm_eval.evaluator.simple_evaluate(
+                model_eval, tasks=[tag], num_fewshot=fewshot, batch_size=eval_batch_size
+            )["results"]
+            print(tag, results_mmlu[tag])
+
+        mmlu_list = "hendrycksTest-abstract_algebra,hendrycksTest-anatomy,hendrycksTest-astronomy,hendrycksTest-business_ethics,hendrycksTest-clinical_knowledge,hendrycksTest-college_biology,hendrycksTest-college_chemistry,hendrycksTest-college_computer_science,hendrycksTest-college_mathematics,hendrycksTest-college_medicine,hendrycksTest-college_physics,hendrycksTest-computer_security,hendrycksTest-conceptual_physics,hendrycksTest-econometrics,hendrycksTest-electrical_engineering,hendrycksTest-elementary_mathematics,hendrycksTest-formal_logic,hendrycksTest-global_facts,hendrycksTest-high_school_biology,hendrycksTest-high_school_chemistry,hendrycksTest-high_school_computer_science,hendrycksTest-high_school_european_history,hendrycksTest-high_school_geography,hendrycksTest-high_school_government_and_politics,hendrycksTest-high_school_macroeconomics,hendrycksTest-high_school_mathematics,hendrycksTest-high_school_microeconomics,hendrycksTest-high_school_physics,hendrycksTest-high_school_psychology,hendrycksTest-high_school_statistics,hendrycksTest-high_school_us_history,hendrycksTest-high_school_world_history,hendrycksTest-human_aging,hendrycksTest-human_sexuality,hendrycksTest-international_law,hendrycksTest-jurisprudence,hendrycksTest-logical_fallacies,hendrycksTest-machine_learning,hendrycksTest-management,hendrycksTest-marketing,hendrycksTest-medical_genetics,hendrycksTest-miscellaneous,hendrycksTest-moral_disputes,hendrycksTest-moral_scenarios,hendrycksTest-nutrition,hendrycksTest-philosophy,hendrycksTest-prehistory,hendrycksTest-professional_accounting,hendrycksTest-professional_law,hendrycksTest-professional_medicine,hendrycksTest-professional_psychology,hendrycksTest-public_relations,hendrycksTest-security_studies,hendrycksTest-sociology,hendrycksTest-us_foreign_policy,hendrycksTest-virology,hendrycksTest-world_religions"
+        mmlu_list = [l.replace("hendrycksTest-", "") for l in mmlu_list.split(",")]
+        results_mmlu = results_mmlu["mmlu"]
+
+        k = []
+        for r in results_mmlu:
+            if np.any([(l in r) for l in mmlu_list]):
+                k.append(results_mmlu[r]["acc,none"])
+
+        assert len(k) == 57
+        print("MMLU avg acc", np.mean(k))
+
+        results["mmlu"] = np.mean(k)
+    return results
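The task names and few-shot counts are fixed inside benchmark(); a subset can be selected through the tasks argument. An illustrative call (the model id is a placeholder):

from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "Qwen/Qwen2.5-0.5B"  # placeholder
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id).eval().to("cuda")
# run only perplexity and 5-shot winogrande
results = benchmark(model, tokenizer, 512, tasks=["PPL", "winogrande"], device="cuda")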
+
+
+def wikitext2_ppl(
+    repo_id: str,
+    quant: str,
+    tasks: list[str],
+    calibration_size: int,
+    validation_size: int,
+    device: str,
+    precision: torch.dtype,
+    sequence_length: int,
+    compile: bool,
+    model_save_path: str,
+):
+    print(f"Loading model on {device}...")
+    torch.manual_seed(34)
+    t0 = time.time()
+    # load any model with torch.nn.linear layers
+    tokenizer = AutoTokenizer.from_pretrained(repo_id)
+    model = (
+        AutoModelForCausalLM.from_pretrained(repo_id, torch_dtype=precision)
+        .eval()
+        .to(device)
+    )
+    print(f"Time to load model: {time.time() - t0:.02f} seconds")
+    if quant.startswith("awq"):
+        quant_dtype = quant.split("-")[1]
+        group_size = int(quant.split("-")[2])
+        quant_dtype = getattr(torch, quant_dtype, torch.bfloat16)
+        base_config = Int8DynamicActivationIntxWeightConfig(
+            weight_dtype=quant_dtype,
+            weight_granularity=PerGroup(group_size),
+            weight_scale_dtype=torch.bfloat16,
+            layout=QDQLayout(),
+        )
+        print(f"running {quant_dtype} calibration")
+        t0 = time.time()
+        # insert observers to find average magnitude and calculate scales
+        insert_awq_observer_(
+            model, validation_size, sequence_length, base_config=base_config
+        )
+        calibration_data = get_calib_dataset(
+            tokenizer=tokenizer, n_samples=calibration_size, block_size=sequence_length
+        )
+        for batch in calibration_data:
+            model(batch.to(device))
+            batch.to("cpu")
+        print(f"time for calibration: {time.time() - t0:.02f} seconds")
+
+        is_observed_linear = lambda m, fqn: isinstance(m, AWQObservedLinear)
+        use_hqq = "hqq" in quant
+        print(f"running {quant_dtype} quantization")
+        t0 = time.time()
+        # awq_uintx_config = awq_uintx(
+        #     quant_dtype=quant_dtype, group_size=group_size, use_hqq=use_hqq
+        # )
+        if "xpu" in device:
+            base_config.layout = Int4XPULayout()
+        awq_config = AWQConfig(base_config)
+        quantize_(
+            model,
+            awq_config,
+            is_observed_linear,
+        )
+        print(f"time for quantization: {time.time() - t0:.02f} seconds")
+        if model_save_path is not None:
+            print(f"Saving model to {model_save_path}")
+            torch.save(model, model_save_path)
+    elif quant.startswith("int4wo"):
+        group_size = int(quant.split("-")[1])
+        use_hqq = "hqq" in quant
+        print(f"running {quant} quantization with group size {group_size}")
+        int4_weight_only_config = int4_weight_only(
+            group_size=group_size, use_hqq=use_hqq
+        )
+        if "xpu" in device:
+            int4_weight_only_config.layout = Int4XPULayout()
+        quantize_(model, int4_weight_only_config)
+    if compile:
+        model = torch.compile(model)
+
+    return benchmark(model, tokenizer, sequence_length, tasks=tasks, device=device)
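The quant string above follows a simple convention: "awq-uint4-64" selects a uint4 weight dtype with group size 64, while "int4wo-64" / "int4wo-64-hqq" select the int4 weight-only path. The awq branch parses it as:

import torch

quant = "awq-uint4-64"
kind, dtype_name, group = quant.split("-")  # "awq", "uint4", "64"
quant_dtype = getattr(torch, dtype_name, torch.bfloat16)  # torch.uint4
group_size = int(group)  # 64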
" + mmlu_list = [l.replace("hendrycksTest-", "") for l in mmlu_list.split(",")] + results_mmlu = results_mmlu["mmlu"] + + k = [] + for r in results_mmlu: + if np.any([(l in r) for l in mmlu_list]): + k.append(results_mmlu[r]["acc,none"]) + + assert len(k) == 57 + print("MMLU avg acc", np.mean(k)) + + results["mmlu"] = np.mean(k) + return results + + +def wikitext2_ppl( + repo_id: str, + quant: str, + tasks: list[str], + calibration_size: int, + validation_size: int, + device: str, + precision: torch.dtype, + sequence_length: int, + compile: bool, + model_save_path: str, +): + print(f"Loading model on {device}...") + torch.manual_seed(34) + t0 = time.time() + # load any model with torch.nn.linear layers + tokenizer = AutoTokenizer.from_pretrained(repo_id) + model = ( + AutoModelForCausalLM.from_pretrained(repo_id, torch_dtype=precision) + .eval() + .to(device) + ) + print(f"Time to load model: {time.time() - t0:.02f} seconds") + if quant.startswith("awq"): + quant_dtype = quant.split("-")[1] + group_size = int(quant.split("-")[2]) + quant_dtype = getattr(torch, quant_dtype, torch.bfloat16) + base_config = Int8DynamicActivationIntxWeightConfig( + weight_dtype=quant_dtype, + weight_granularity=PerGroup(group_size), + weight_scale_dtype=torch.bfloat16, + layout=QDQLayout(), + ) + print(f"running {quant_dtype} calibration") + t0 = time.time() + # insert observers to find average magnitude and calculate scales + insert_awq_observer_( + model, validation_size, sequence_length, base_config=base_config + ) + calibration_data = get_calib_dataset( + tokenizer=tokenizer, n_samples=calibration_size, block_size=sequence_length + ) + for batch in calibration_data: + model(batch.to(device)) + batch.to("cpu") + print(f"time for calibration: {time.time() - t0:.02f} seconds") + + is_observed_linear = lambda m, fqn: isinstance(m, AWQObservedLinear) + use_hqq = "hqq" in quant + print(f"running {quant_dtype} quantization") + t0 = time.time() + # awq_uintx_config = awq_uintx( + # quant_dtype=quant_dtype, group_size=group_size, use_hqq=use_hqq + # ) + if "xpu" in device: + base_config.layout = Int4XPULayout() + awq_config = AWQConfig(base_config) + quantize_( + model, + awq_config, + is_observed_linear, + ) + print(f"time for quantization: {time.time() - t0:.02f} seconds") + if model_save_path is not None: + print(f"Saving model to {model_save_path}") + torch.save(model, model_save_path) + elif quant.startswith("int4wo"): + group_size = int(quant.split("-")[1]) + use_hqq = "hqq" in quant + print(f"running {quant} quantization with group size {group_size}") + int4_weight_only_config = int4_weight_only( + group_size=group_size, use_hqq=use_hqq + ) + if "xpu" in device: + int4_weight_only_config.layout = Int4XPULayout() + quantize_(model, int4_weight_only_config) + if compile: + model = torch.compile(model) + + return benchmark(model, tokenizer, sequence_length, tasks=tasks, device=device) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Evaluate a model with the specified parameters." + ) + + # Optional arguments with default values + parser.add_argument("repo", type=str, help="Repository ID of the model.") + parser.add_argument( + "quant", + type=str, + help="Quantization method. Options are either awq-uint- for x =[1..8], int4wo-, or int4wo--hqq.", + ) + parser.add_argument( + "--tasks", + type=list[str], + help="Task to benchmark model on. 
diff --git a/torchao/prototype/moe_quant/utils.py b/torchao/prototype/moe_quant/utils.py
index 0e75de2ee4..28291afdf4 100644
--- a/torchao/prototype/moe_quant/utils.py
+++ b/torchao/prototype/moe_quant/utils.py
@@ -20,18 +20,7 @@
     dataclass,
     register_quantize_module_handler,
 )
-from torchao.utils import fill_defaults
-
-
-class DummyModule(torch.nn.Module):
-    """This is used because the TorchAO quantization functions tend to operate on modules so to apply the transform to a tensor, we can load a
-    DummyModule with the target tensor and then apply the transformation to the module and then extract the transformed tensor.
-    """
-
-    def __init__(self, weight: torch.Tensor, bias: Optional[torch.Tensor] = None):
-        super().__init__()
-        self.weight = weight
-        self.bias = bias
+from torchao.utils import DummyModule, fill_defaults
 
 
 class FakeExtraDimTensor(torch.Tensor):
diff --git a/torchao/utils.py b/torchao/utils.py
index 416d23d785..f4a7d100af 100644
--- a/torchao/utils.py
+++ b/torchao/utils.py
@@ -11,7 +11,7 @@
 from functools import reduce
 from importlib.metadata import version
 from math import gcd
-from typing import Any, Callable
+from typing import Any, Callable, Optional
 
 import torch
 import torch.nn.utils.parametrize as parametrize
@@ -42,6 +42,7 @@
     "is_sm_at_least_89",
     "is_sm_at_least_90",
     "is_package_at_least",
+    "DummyModule",
 ]
 
 
@@ -710,3 +711,14 @@ def is_package_at_least(package_name: str, min_version: str):
         return False
 
     return version(package_name) >= min_version
+
+
+class DummyModule(torch.nn.Module):
+    """TorchAO quantization functions generally operate on modules. To apply a transform to a bare
+    tensor, wrap it in a DummyModule, run the module transform, and then extract the transformed tensor.
+    """
+
+    def __init__(self, weight: torch.Tensor, bias: Optional[torch.Tensor] = None):
+        super().__init__()
+        self.weight = weight
+        self.bias = bias
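With this move, DummyModule is importable from torchao.utils (it was previously defined locally in torchao/prototype/moe_quant/utils.py). A trivial round trip, just to show the interface:

import torch
from torchao.utils import DummyModule

m = DummyModule(torch.randn(4, 8), bias=torch.zeros(4))
assert m.weight.shape == (4, 8) and m.bias.shape == (4,)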