diff --git a/src/sparsetensors/utils/helpers.py b/src/sparsetensors/utils/helpers.py
index 378ccd8c..8d06cde2 100644
--- a/src/sparsetensors/utils/helpers.py
+++ b/src/sparsetensors/utils/helpers.py
@@ -18,7 +18,6 @@
 from sparsetensors.compressors import ModelCompressor
 from sparsetensors.config import CompressionConfig
 from sparsetensors.utils.helpers import SPARSITY_CONFIG_NAME
-from torch.nn import Module
 from transformers import AutoConfig
 
 
@@ -44,20 +43,3 @@ def infer_compressor_from_model_config(
     sparsity_config = CompressionConfig.load_from_registry(format, **sparsity_config)
     compressor = ModelCompressor.load_from_registry(format, config=sparsity_config)
     return compressor
-
-
-def set_layer(target: str, layer: Module, module: Module) -> Module:
-    target = fix_fsdp_module_name(target)  # noqa TODO
-    with summon_full_params_context(module):  # noqa TODO
-        # importing here to avoid circular import
-        from sparseml.utils.fsdp.helpers import maybe_get_wrapped  # noqa TODO
-
-        parent_target = ".".join(target.split(".")[:-1])
-        if parent_target != "":
-            parent_layer = get_layer(parent_target, module)[1]  # noqa TODO
-        else:
-            parent_layer = maybe_get_wrapped(module)
-        old_layer = getattr(parent_layer, target.split(".")[-1])
-        setattr(parent_layer, target.split(".")[-1], layer)
-
-    return old_layer
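
For reference, the deleted set_layer helper swapped the submodule at a dotted path and returned the layer it replaced; the FSDP-specific pieces (fix_fsdp_module_name, summon_full_params_context, maybe_get_wrapped) live on the sparseml side. Below is a minimal sketch of that swap pattern in plain PyTorch; the helper name swap_layer and the toy Sequential model are illustrative only and are not part of this repository.

import torch
from torch.nn import Linear, Module


def swap_layer(target: str, new_layer: Module, model: Module) -> Module:
    # Walk to the parent of the dotted path, then swap the named child,
    # returning the layer that was previously registered there.
    parent_path, _, child_name = target.rpartition(".")
    parent = model.get_submodule(parent_path) if parent_path else model
    old_layer = getattr(parent, child_name)
    setattr(parent, child_name, new_layer)
    return old_layer


# Usage sketch: replace the second layer of a toy model and keep the original.
model = torch.nn.Sequential(Linear(8, 8), Linear(8, 4))
original = swap_layer("1", Linear(8, 4, bias=False), model)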