-
Notifications
You must be signed in to change notification settings - Fork 1
/
embedding.py
44 lines (35 loc) · 1.49 KB
/
embedding.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import os
class EmbeddingDownLoader:
    """Fetch a HuggingFace embedding model and save it to a local directory.

    Requires the ``sentence_transformers`` package
    (``pip install sentence-transformers``).

    The default model is ``BAAI/bge-base-en-v1.5``. To use a different one,
    look up its name on the HuggingFace hub and pass it as ``model``.
    """

    def __init__(
        self,
        model: str = 'BAAI/bge-base-en-v1.5',
        path=None,
    ) -> None:
        """Store the settings and verify that the dependency is importable.

        Args:
            model: Model name handed to ``SentenceTransformer``.
            path: Directory the model is saved into. When ``None``,
                ``download`` chooses a directory named after the model,
                located next to this source file.

        Raises:
            ImportError: If ``sentence_transformers`` cannot be imported.
        """
        try:
            # Imported only to fail fast at construction time; the class is
            # imported again inside download() where it is actually used.
            from sentence_transformers import SentenceTransformer  # noqa: F401
        except ImportError as err:
            raise ImportError(
                "package not found. try install sentence_transformers. "
                "try following command : pip install sentence-transformers"
            ) from err
        self.model = model
        self.path = path

    def download(self) -> None:
        """Load ``self.model`` and save it under ``self.path``.

        If no path was configured, ``self.path`` is set to
        ``<directory of this file>/<model name>`` before saving. The target
        directory is created when missing and reused when it already exists.
        """
        from sentence_transformers import SentenceTransformer

        # No explicit path: default to a directory named after the model,
        # placed beside this source file.
        if self.path is None:
            self.path = os.path.join(
                os.path.dirname(os.path.abspath(__file__)),
                self.model,
            )

        # Instantiating SentenceTransformer is what retrieves the model.
        downloader = SentenceTransformer(model_name_or_path=self.model)

        # exist_ok=True so that a rerun, or a caller-supplied directory that
        # already exists, does not abort with FileExistsError.
        os.makedirs(self.path, exist_ok=True)
        downloader.save(self.path)
        print(f'model {self.model} download at path {self.path}.')
if __name__ == "__main__":
    # Script entry point: download the default model to the default location.
    EmbeddingDownLoader().download()