From 5008d7b27295167ccfad0348c925b7349644ecf0 Mon Sep 17 00:00:00 2001 From: hankcs Date: Tue, 8 Oct 2024 01:37:15 -0700 Subject: [PATCH] Improve the safety of `torch.load` with `weights_only=True` --- hanlp/common/torch_component.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/hanlp/common/torch_component.py b/hanlp/common/torch_component.py index 6d8e07f6a..7750ab1d7 100644 --- a/hanlp/common/torch_component.py +++ b/hanlp/common/torch_component.py @@ -97,7 +97,10 @@ def load_weights(self, save_dir, filename='model.pt', **kwargs): save_dir = get_resource(save_dir) filename = os.path.join(save_dir, filename) # flash(f'Loading model: {filename} [blink]...[/blink][/yellow]') - self.model_.load_state_dict(torch.load(filename, map_location='cpu'), strict=False) + try: + self.model_.load_state_dict(torch.load(filename, map_location='cpu', weights_only=True), strict=False) + except TypeError: + self.model_.load_state_dict(torch.load(filename, map_location='cpu'), strict=False) # flash('') def save_config(self, save_dir, filename='config.json'):