Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Update library defaults to use onnx #288

Merged
merged 1 commit into from
Apr 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions langkit/metrics/embeddings_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,18 @@
from sentence_transformers import SentenceTransformer


class TransformerEmbeddingAdapter:
class EmbeddingEncoder(Protocol):
def encode(self, text: Tuple[str, ...]) -> "torch.Tensor":
...


class TransformerEmbeddingAdapter(EmbeddingEncoder):
def __init__(self, transformer: SentenceTransformer):
self._transformer = transformer

@lru_cache(maxsize=6, typed=True)
def encode(self, text: Tuple[str, ...]) -> "torch.Tensor":
return torch.as_tensor(self._transformer.encode(sentences=list(text))) # type: ignore[reportUnknownMemberType]


class EmbeddingEncoder(Protocol):
def encode(self, text: Tuple[str, ...]) -> "torch.Tensor":
...
def encode(self, text: Tuple[str, ...]) -> "torch.Tensor": # pyright: ignore[reportIncompatibleMethodOverride]
return torch.as_tensor(self._transformer.encode(sentences=list(text), show_progress_bar=False)) # type: ignore[reportUnknownMemberType]


class CachingEmbeddingEncoder(EmbeddingEncoder):
Expand Down
24 changes: 10 additions & 14 deletions langkit/metrics/injections.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from langkit.config import LANGKIT_CACHE
from langkit.core.metric import Metric, SingleMetric, SingleMetricResult
from langkit.metrics.util import retry
from langkit.transformer import embedding_adapter, sentence_transformer
from langkit.transformer import embedding_adapter

logger = getLogger(__name__)

Expand All @@ -34,7 +34,8 @@ def __cache_embeddings(harm_embeddings: pd.DataFrame, embeddings_path: str, file
logger.warning(f"Injections - unable to serialize embeddings to {embeddings_path_local}. Error: {serialization_error}")


def __download_embeddings(filename: str) -> pd.DataFrame:
def __download_embeddings(version: str) -> pd.DataFrame:
filename = f"embeddings_{__transformer_name}_harm_{version}.parquet"
embeddings_path_remote: str = __injections_base_url + filename
embeddings_path_local: str = os.path.join(LANGKIT_INJECTIONS_CACHE, filename)
try:
Expand All @@ -60,18 +61,17 @@ def __process_embeddings(harm_embeddings: pd.DataFrame) -> "np.ndarray[Any, Any]

@lru_cache
def _get_embeddings(version: str) -> "np.ndarray[Any, Any]":
filename = f"embeddings_{__transformer_name}_harm_{version}.parquet"
harm_embeddings = __download_embeddings(filename)
embeddings_norm = __process_embeddings(harm_embeddings)
return embeddings_norm
return __process_embeddings(__download_embeddings(version))


def injections_metric(column_name: str, version: str = "v2", onnx: bool = True) -> Metric:
def cache_assets():
_get_embeddings(version)
__download_embeddings(version)
embedding_adapter(onnx)

def init():
embedding_adapter()
_get_embeddings(version)
embedding_adapter(onnx)

def udf(text: pd.DataFrame) -> SingleMetricResult:
if column_name not in text.columns:
Expand All @@ -80,12 +80,8 @@ def udf(text: pd.DataFrame) -> SingleMetricResult:

input_series: "pd.Series[str]" = cast("pd.Series[str]", text[column_name])

if onnx:
_transformer = embedding_adapter()
target_embeddings: npt.NDArray[np.float32] = _transformer.encode(tuple(input_series)).numpy()
else:
_transformer = sentence_transformer()
target_embeddings: npt.NDArray[np.float32] = _transformer.encode(list(input_series), show_progress_bar=False) # pyright: ignore[reportAssignmentType, reportUnknownMemberType]
_transformer = embedding_adapter(onnx)
target_embeddings: npt.NDArray[np.float32] = _transformer.encode(tuple(input_series)).numpy()

target_norms = target_embeddings / np.linalg.norm(target_embeddings, axis=1, keepdims=True)
cosine_similarities = np.dot(_embeddings, target_norms.T)
Expand Down
10 changes: 7 additions & 3 deletions langkit/metrics/input_output_similarity.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,17 @@
from langkit.transformer import embedding_adapter


def input_output_similarity_metric(input_column_name: str = "prompt", output_column_name: str = "response") -> Metric:
def input_output_similarity_metric(input_column_name: str = "prompt", output_column_name: str = "response", onnx: bool = True) -> Metric:
def cache_assets():
embedding_adapter(onnx)

def init():
embedding_adapter()
embedding_adapter(onnx)

def udf(text: pd.DataFrame) -> SingleMetricResult:
in_np = UdfInput(text).to_list(input_column_name)
out_np = UdfInput(text).to_list(output_column_name)
encoder = embedding_adapter()
encoder = embedding_adapter(onnx)
similarity = compute_embedding_similarity(encoder, in_np, out_np)

if len(similarity.shape) == 1:
Expand All @@ -27,6 +30,7 @@ def udf(text: pd.DataFrame) -> SingleMetricResult:
input_names=[input_column_name, output_column_name],
evaluate=udf,
init=init,
cache_assets=cache_assets,
)


Expand Down
20 changes: 10 additions & 10 deletions langkit/metrics/library.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,27 +249,27 @@ def __call__(self) -> MetricCreator:
]

@staticmethod
def injection(version: Optional[str] = None) -> MetricCreator:
def injection(version: Optional[str] = None, onnx: bool = True) -> MetricCreator:
"""
Analyze the input for injection themes. The injection score is a measure of how similar the input is
to known injection examples, where 0 indicates no similarity and 1 indicates a high similarity.
"""
from langkit.metrics.injections import injections_metric, prompt_injections_metric

if version:
return partial(injections_metric, column_name="prompt", version=version)
return partial(injections_metric, column_name="prompt", version=version, onnx=onnx)

return prompt_injections_metric

@staticmethod
def jailbreak() -> MetricCreator:
def jailbreak(onnx: bool = True) -> MetricCreator:
"""
Analyze the input for jailbreak themes. The jailbreak score is a measure of how similar the input is
to known jailbreak examples, where 0 indicates no similarity and 1 indicates a high similarity.
"""
from langkit.metrics.themes.themes import prompt_jailbreak_similarity_metric

return prompt_jailbreak_similarity_metric
return partial(prompt_jailbreak_similarity_metric, onnx=onnx)

class sentiment:
def __call__(self) -> MetricCreator:
Expand Down Expand Up @@ -302,7 +302,7 @@ def __call__(self) -> MetricCreator:
return partial(topic_metric, "prompt", self.topics, self.hypothesis_template)

@staticmethod
def medicine(onnx: bool = False) -> MetricCreator:
def medicine(onnx: bool = True) -> MetricCreator:
if onnx:
from langkit.metrics.topic_onnx import topic_metric

Expand Down Expand Up @@ -486,24 +486,24 @@ def __call__(self) -> MetricCreator:
]

@staticmethod
def prompt() -> MetricCreator:
def prompt(onnx: bool = True) -> MetricCreator:
"""
Analyze the similarity between the input and the response. The output of this metric ranges from 0 to 1,
where 0 indicates no similarity and 1 indicates a high similarity.
"""
from langkit.metrics.input_output_similarity import prompt_response_input_output_similarity_metric

return prompt_response_input_output_similarity_metric
return partial(prompt_response_input_output_similarity_metric, onnx=onnx)

@staticmethod
def refusal() -> MetricCreator:
def refusal(onnx: bool = True) -> MetricCreator:
"""
Analyze the response for refusal themes. The refusal score is a measure of how similar the response is
to known refusal examples, where 0 indicates no similarity and 1 indicates a high similarity.
"""
from langkit.metrics.themes.themes import response_refusal_similarity_metric

return response_refusal_similarity_metric
return partial(response_refusal_similarity_metric, onnx=onnx)

class topics:
def __init__(self, topics: List[str], hypothesis_template: Optional[str] = None, onnx: bool = True):
Expand All @@ -522,7 +522,7 @@ def __call__(self) -> MetricCreator:
return partial(topic_metric, "response", self.topics, self.hypothesis_template)

@staticmethod
def medicine(onnx: bool = False) -> MetricCreator:
def medicine(onnx: bool = True) -> MetricCreator:
if onnx:
from langkit.metrics.topic_onnx import topic_metric

Expand Down
10 changes: 7 additions & 3 deletions langkit/metrics/themes/themes.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,18 +59,21 @@ def _get_themes(encoder: TransformerEmbeddingAdapter) -> Dict[str, torch.Tensor]
return {group: torch.as_tensor(encoder.encode(tuple(themes))) for group, themes in theme_groups.items()}


def __themes_metric(column_name: str, themes_group: Literal["jailbreak", "refusal"]) -> Metric:
def __themes_metric(column_name: str, themes_group: Literal["jailbreak", "refusal"], onnx: bool = True) -> Metric:
if themes_group == "refusal" and column_name == "prompt":
raise ValueError("Refusal themes are not applicable to prompt")

if themes_group == "jailbreak" and column_name == "response":
raise ValueError("Jailbreak themes are not applicable to response")

def cache_assets():
_get_themes(embedding_adapter())
embedding_adapter(onnx)

def init():
_get_themes(embedding_adapter(onnx))

def udf(text: pd.DataFrame) -> SingleMetricResult:
encoder = embedding_adapter()
encoder = embedding_adapter(onnx)
theme = _get_themes(encoder)[themes_group] # (n_theme_examples, embedding_dim)
text_list: List[str] = text[column_name].tolist()
encoded_text = encoder.encode(tuple(text_list)) # (n_input_rows, embedding_dim)
Expand All @@ -84,6 +87,7 @@ def udf(text: pd.DataFrame) -> SingleMetricResult:
input_names=[column_name],
evaluate=udf,
cache_assets=cache_assets,
init=init,
)


Expand Down
12 changes: 7 additions & 5 deletions langkit/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,11 @@
import torch
from sentence_transformers import SentenceTransformer

from langkit.metrics.embeddings_types import CachingEmbeddingEncoder, EmbeddingEncoder
from langkit.metrics.embeddings_types import CachingEmbeddingEncoder, EmbeddingEncoder, TransformerEmbeddingAdapter
from langkit.onnx_encoder import OnnxSentenceTransformer, TransformerModel


@lru_cache
def sentence_transformer(
def _sentence_transformer(
name_revision: Tuple[str, str] = ("all-MiniLM-L6-v2", "44eb4044493a3c34bc6d7faae1a71ec76665ebc6"),
) -> SentenceTransformer:
"""
Expand All @@ -25,5 +24,8 @@ def sentence_transformer(


@lru_cache
def embedding_adapter() -> EmbeddingEncoder:
return CachingEmbeddingEncoder(OnnxSentenceTransformer(TransformerModel.AllMiniLM))
def embedding_adapter(onnx: bool = True) -> EmbeddingEncoder:
if onnx:
return CachingEmbeddingEncoder(OnnxSentenceTransformer(TransformerModel.AllMiniLM))
else:
return TransformerEmbeddingAdapter(_sentence_transformer())
Loading