diff --git a/src/llmcompressor/transformers/sparsification/sparse_model.py b/src/llmcompressor/transformers/sparsification/sparse_model.py
index 6a148c784..c91af1299 100644
--- a/src/llmcompressor/transformers/sparsification/sparse_model.py
+++ b/src/llmcompressor/transformers/sparsification/sparse_model.py
@@ -74,7 +74,9 @@ def skip(*args, **kwargs):
         )

         # instantiate compressor from model config
-        compressor = ModelCompressor.from_pretrained(pretrained_model_name_or_path)
+        compressor = ModelCompressor.from_pretrained(
+            pretrained_model_name_or_path, **kwargs
+        )

         # temporarily set the log level to error, to ignore printing out long missing
         # and unexpected key error messages (these are EXPECTED for quantized models)
@@ -82,9 +84,19 @@ def skip(*args, **kwargs):
         restore_log_level = transformers_logger.getEffectiveLevel()
         transformers_logger.setLevel(level=logging.ERROR)

+        if kwargs.get("trust_remote_code"):
+            # By artificially aliasing
+            # the class name SparseAutoModelForCausalLM to
+            # AutoModelForCausalLM we can "trick" the
+            # `from_pretrained` method into properly
+            # resolving the logic when
+            # (has_remote_code and trust_remote_code) == True
+            cls.__name__ = AutoModelForCausalLM.__name__
+
         model = super(AutoModelForCausalLM, cls).from_pretrained(
             pretrained_model_name_or_path, *model_args, **kwargs
         )
+
         if model.dtype != model.config.torch_dtype:
             logger.warning(
                 f"The dtype of the loaded model: {model.dtype} is different "
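
For context, a minimal usage sketch of the path this change enables (a sketch only, assuming the package's usual SparseAutoModelForCausalLM export; the model id below is a hypothetical placeholder, not a real checkpoint):

    # Hypothetical usage: loading a checkpoint that ships custom modeling code.
    # trust_remote_code is forwarded through **kwargs, so it now reaches
    # ModelCompressor.from_pretrained and triggers the class-name aliasing
    # branch before the transformers from_pretrained call.
    from llmcompressor.transformers import SparseAutoModelForCausalLM

    model = SparseAutoModelForCausalLM.from_pretrained(
        "example-org/model-with-remote-code",  # placeholder model id
        trust_remote_code=True,
    )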