Skip to content

Commit

Permalink
engine: Check if LLM is actually loaded
Browse files Browse the repository at this point in the history
engine/engine.py
engine/llm.py
  • Loading branch information
ShikiOkasaka committed Sep 21, 2024
1 parent 2f9eaff commit 5151b56
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 9 deletions.
3 changes: 1 addition & 2 deletions engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -769,8 +769,7 @@ def _load_use_llm(self):
enabled = self._settings.get_boolean('use-llm')
use_cuda = self._settings.get_boolean('use-cuda')
LOGGER.info(f'use-llm: {enabled}, use-cuda: {use_cuda}')
llm.load(enabled, 'cuda' if use_cuda else 'cpu')
return enabled
return llm.load(enabled, 'cuda' if use_cuda else 'cpu')

def _lookup_dictionary(self, text, pos, anchor=0):
self._lookup_table.clear()
Expand Down
13 changes: 6 additions & 7 deletions engine/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,16 +46,16 @@ def loaded() -> bool:
return True


def load(enable: bool, device_type: str = 'cpu'):
def load(enable: bool, device_type: str = 'cpu') -> bool:
global device, model, katuyou_tokens, tokenizer, torch, yougen_tokens
if not enable:
return
return False
try:
import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer
except ImportError:
LOGGER.warning('Could not import transformers')
return
return False
try:
if model is None:
if device_type == 'cuda' and torch.cuda.is_available():
Expand All @@ -69,8 +69,7 @@ def load(enable: bool, device_type: str = 'cpu'):
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, local_files_only=True)
except OSError:
LOGGER.warning(f'Local {MODEL_NAME} is not found')
return

return False
if loaded() and not yougen_tokens:
try:
vocab = tokenizer.get_vocab()
Expand All @@ -83,7 +82,6 @@ def load(enable: bool, device_type: str = 'cpu'):
yomi = words[0]
words = words[1].strip(' \n/').split('/')
yougen_tokens[yomi] = [vocab[word] for word in words]

with open(os.path.join(package.get_datadir(), 'dic', 'katuyou_token.dic'), 'r') as f:
for line in f:
line = line.strip('')
Expand All @@ -93,9 +91,10 @@ def load(enable: bool, device_type: str = 'cpu'):
stem = words[0]
words = words[1].strip(' \n/').split('/')
katuyou_tokens[stem] = words

except OSError:
LOGGER.warning('Could not load "yougen_vocab.dic"')
return False
return True


def pick(prefix, candidates, yougen=-1, yougen_shrunk='', yougen_yomi='') -> dict[int, str]:
Expand Down

0 comments on commit 5151b56

Please sign in to comment.