Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Commit

Permalink
feat: update list_models function
Browse files Browse the repository at this point in the history
  • Loading branch information
lmmilliken committed Apr 14, 2023
1 parent d85f2d4 commit 4d627dd
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 16 deletions.
35 changes: 24 additions & 11 deletions finetuner/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from finetuner.constants import (
DEFAULT_FINETUNER_HOST,
DEFAULT_HUBBLE_REGISTRY,
EMBEDDING,
HOST,
HUBBLE_REGISTRY,
)
Expand Down Expand Up @@ -65,9 +66,16 @@ def _build_name_stub_map() -> Dict[str, model_stub.ModelStubType]:
return rv


def list_models() -> List[str]:
"""List available models."""
return [name for name in list_model_classes()]
def list_models(model_type: str = EMBEDDING) -> List[str]:
"""List available models.
:param type: The type of backbone model, one of 'embedding', 'cross_ecoding' or
'relation_mining'. 'embedding' by default.
"""
return [
stub.display_name for stub in list_model_classes(model_type=model_type).values()
]


def list_model_options() -> Dict[str, List[Dict[str, Any]]]:
Expand All @@ -91,16 +99,19 @@ def list_model_options() -> Dict[str, List[Dict[str, Any]]]:
}


def describe_models(task: Optional[str] = None) -> None:
def describe_models(task: Optional[str] = None, model_type: str = EMBEDDING) -> None:
"""Print model information, such as name, task, output dimension, architecture
and description as a table.
:param task: The task for the backbone model, one of `text-to-text`,
`text-to-image`, `image-to-image`. If not provided, will print all backbone
models.
:param type: The type of backbone model, one of 'embedding', 'cross_ecoding' or
'relation_mining'. 'embedding' by default, the `task` parameter will be ignored
if this is set to anything else.
"""
print_model_table(model, task=task)
print_model_table(model, task=task, model_type=model_type)


@login_required
Expand Down Expand Up @@ -294,18 +305,20 @@ def synthesize(
"""Create a Finetuner generation :class:`Run`, calling this function will submit a
data generation job to the Jina AI Cloud.
:param query_data: Either a :class:`DocumentArray` for example queries. can be the
name of a `DocumentArray` that is pushed on Jina AI Cloud, the dataset itself as
:param query_data: Either a :class:`DocumentArray` for example queries, name of a
`DocumentArray` that is pushed on Jina AI Cloud, the dataset itself as
a list of strings or a path to a CSV file.
:param corpus_data: Either a :class:`DocumentArray` for corpus data, a name of a
`DocumentArray` that is pushed on Jina AI Cloud, the dataset itself as a
list of strings or a path to a CSV file.
:param mining_models: The name or a list of names of models to be used during
relation mining. Run `finetuner.list_models()` or `finetuner.describe_models()`
to see the available model names. #TODO double check this
relation mining. Run `finetuner.list_models(model_type='relation_mining')` or
`finetuner.describe_models(model_type='relation_mining')` to see the
available model names.
:param cross_encoder_model: The name of the model to be used as the cross-encoder.
Run `finetuner.list_models()` or `finetuner.describe_models()` to see the
available model names. #TODO double check this
Run `finetuner.list_models(model_type='cross_encoding')` or
`finetuner.describe_models(model_type='cross_encoding')` to see the
available model names.
:param num_relations: The number of relations to mine per query.
:param max_num_docs: The maximum number of documents to consider.
:param run_name: Name of the run.
Expand Down
5 changes: 3 additions & 2 deletions finetuner/console.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@
from rich.console import Console
from rich.table import Table

from finetuner.constants import EMBEDDING
from finetuner.model import list_model_classes

console = Console()


def print_model_table(model, task: Optional[str] = None):
def print_model_table(model, task: Optional[str] = None, model_type: str = EMBEDDING):
"""Prints a table of model descriptions.
:param model: Module with model definitions
Expand All @@ -24,7 +25,7 @@ def print_model_table(model, task: Optional[str] = None):
for column in header:
table.add_column(column, justify='right', style='cyan', no_wrap=False)

for _, _model_class in list_model_classes().items():
for _, _model_class in list_model_classes(model_type=model_type).items():
if _model_class.display_name not in model_display_names:
row = model.get_row(_model_class)
if task and row[1] != task:
Expand Down
3 changes: 3 additions & 0 deletions finetuner/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,3 +93,6 @@
MODELS = 'models'
NUM_RELATIONS = 'num_relations'
MAX_NUM_DOCS = 'max_num_docs'
# Stub types
EMBEDDING = 'embedding'
CROSS_ENCODING = 'cross_encoding'
18 changes: 15 additions & 3 deletions finetuner/model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
from _finetuner.runner.stubs import model
from _finetuner.runner.stubs.model import * # noqa F401
from _finetuner.runner.stubs.model import _EmbeddingModelStub
from _finetuner.runner.stubs.model import (
_CrossEncoderStub,
_EmbeddingModelStub,
_TextTransformerStub,
)

from finetuner.constants import CROSS_ENCODING, EMBEDDING, RELATION_MINING


def get_header() -> Tuple[str, ...]:
Expand All @@ -19,15 +25,21 @@ def get_row(model_stub) -> Tuple[str, ...]:
)


def list_model_classes() -> Dict[str, ModelStubType]:
def list_model_classes(model_type: str = EMBEDDING) -> Dict[str, ModelStubType]:
rv = {}
members = inspect.getmembers(model, inspect.isclass)
if model_type == EMBEDDING:
parent_class = _EmbeddingModelStub
elif model_type == CROSS_ENCODING:
parent_class = _CrossEncoderStub
elif model_type == RELATION_MINING:
parent_class = _TextTransformerStub
for name, stub in members:
if (
name != 'MLPStub'
and not name.startswith('_')
and type(stub) != type
and issubclass(stub, _EmbeddingModelStub)
and issubclass(stub, parent_class)
):
rv[name] = stub
return rv

0 comments on commit 4d627dd

Please sign in to comment.