Skip to content

Commit

Permalink
#106 done
Browse files Browse the repository at this point in the history
  • Loading branch information
nicolay-r committed Oct 15, 2023
1 parent 5663361 commit c716578
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 12 deletions.
5 changes: 3 additions & 2 deletions arelight/pipelines/items/inference_bert_opennre.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,12 @@ def init_bert_model(pretrain_path, rel2id, ckpt_path, device_type, dir_to_donwlo
""" This is a main and core method for inference based on OpenNRE framework.
"""
# Check predefined checkpoints for local downloading.
predefined_ckpt_path = try_download_predefined_checkpoint(
predefined_pretrain_path, predefined_ckpt_path = try_download_predefined_checkpoint(
checkpoint=ckpt_path, dir_to_download=dir_to_donwload)

# Update checkpoint path with the predefined one.
# Update checkpoint and pretrain paths with the predefined.
ckpt_path = predefined_ckpt_path if predefined_ckpt_path is not None else ckpt_path
pretrain_path = predefined_pretrain_path if predefined_pretrain_path is not None else pretrain_path

# Load original model.
bert_encoder = BertOpenNREInferencePipelineItem.load_bert_sentence_encoder(
Expand Down
24 changes: 14 additions & 10 deletions arelight/pipelines/items/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,23 +40,27 @@ def try_download_predefined_checkpoint(checkpoint, dir_to_download):
assert(isinstance(checkpoint, str))
assert(isinstance(dir_to_download, str))

predefined_checkponts = {
"ra4-rsr1_DeepPavlov-rubert-base-cased_cls.pth.tar":
"https://www.dropbox.com/scl/fi/rwjf7ag3w3z90pifeywrd/ra4-rsr1_DeepPavlov-rubert-base-cased_cls.pth.tar?rlkey=p0mmu81o6c2u6iboe9m20uzqk&dl=1",
"ra4-rsr1_bert-base-cased_cls.pth.tar":
"https://www.dropbox.com/scl/fi/k5arragv1g4wwftgw5xxd/ra-rsr_bert-base-cased_cls.pth.tar?rlkey=8hzavrxunekf0woesxrr0zqys&dl=1"
predefined = {
"ra4-rsr1_DeepPavlov-rubert-base-cased_cls.pth.tar": {
"state": "DeepPavlov/rubert-base-cased",
"checkpoint": "https://www.dropbox.com/scl/fi/rwjf7ag3w3z90pifeywrd/ra4-rsr1_DeepPavlov-rubert-base-cased_cls.pth.tar?rlkey=p0mmu81o6c2u6iboe9m20uzqk&dl=1",
},
"ra4-rsr1_bert-base-cased_cls.pth.tar": {
"state": "bert-base-cased",
"checkpoint": "https://www.dropbox.com/scl/fi/k5arragv1g4wwftgw5xxd/ra-rsr_bert-base-cased_cls.pth.tar?rlkey=8hzavrxunekf0woesxrr0zqys&dl=1"
}
}

if checkpoint in predefined_checkponts:
url = predefined_checkponts[checkpoint]
if checkpoint in predefined:
data = predefined[checkpoint]
target_path = join(dir_to_download, checkpoint)

logger.info("Found predefined checkpoint: {}".format(checkpoint))
# No need to do anything, file has been already downloaded.
if not exists(target_path):
logger.info("Downloading checkpoint to: {}".format(target_path))
download(dest_file_path=target_path, source_url=url)
download(dest_file_path=target_path, source_url=data["checkpoint"])

return target_path
return data["state"], data["checkpoint"]

return None
return None, None

0 comments on commit c716578

Please sign in to comment.