Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Explanation2 text #97

Merged
merged 5 commits into from
Mar 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 9 additions & 31 deletions aleph_alpha_client/aleph_alpha_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,12 @@

import aleph_alpha_client
from aleph_alpha_client.document import Document
from aleph_alpha_client.explanation import ExplanationRequest, ExplanationResponse
from aleph_alpha_client.explanation import (
    ExplanationRequest,
    ExplanationResponse,
)
from aleph_alpha_client.image import Image
from aleph_alpha_client.prompt import _to_json, _to_serializable_prompt
from aleph_alpha_client.summarization import SummarizationRequest, SummarizationResponse
Expand Down Expand Up @@ -832,34 +837,6 @@ def summarize(
_raise_for_status(response.status_code, response.text)
return response.json()

def _explain(
self,
model: str,
request: ExplanationRequest,
hosting: Optional[str] = None,
):
body = {
"model": model,
"prompt": [_to_json(item) for item in request.prompt.items],
"target": request.target,
"suppression_factor": request.suppression_factor,
"conceptual_suppression_threshold": request.conceptual_suppression_threshold,
"normalize": request.normalize,
"square_outputs": request.square_outputs,
"prompt_explain_indices": request.prompt_explain_indices,
}

if hosting is not None:
body["hosting"] = hosting

response = self.post_request(
f"{self.host}explain",
headers=self.request_headers,
json=body,
)
_raise_for_status(response.status_code, response.text)
return response.json()


AnyRequest = Union[
CompletionRequest,
Expand All @@ -871,6 +848,7 @@ def _explain(
QaRequest,
SummarizationRequest,
ExplanationRequest,
ExplanationRequest,
SearchRequest,
]

Expand Down Expand Up @@ -1278,7 +1256,7 @@ def _explain(
model: str,
) -> ExplanationResponse:
response = self._post_request(
"explain",
"explain2",
request,
model,
)
Expand Down Expand Up @@ -1737,7 +1715,7 @@ async def _explain(
model: str,
) -> ExplanationResponse:
response = await self._post_request(
"explain",
"explain2",
request,
model,
)
Expand Down
8 changes: 0 additions & 8 deletions aleph_alpha_client/aleph_alpha_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
SemanticEmbeddingResponse,
)
from aleph_alpha_client.evaluation import EvaluationRequest, EvaluationResponse
from aleph_alpha_client.explanation import ExplanationRequest
from aleph_alpha_client.qa import QaRequest, QaResponse
from aleph_alpha_client.tokenization import TokenizationRequest, TokenizationResponse
from aleph_alpha_client.summarization import SummarizationRequest, SummarizationResponse
Expand Down Expand Up @@ -158,13 +157,6 @@ def qa(self, request: QaRequest) -> QaResponse:
)
return QaResponse.from_json(response_json)

def _explain(self, request: ExplanationRequest) -> Mapping[str, Any]:
return self.client._explain(
model=self.model_name,
hosting=self.hosting,
request=request,
)

def summarize(self, request: SummarizationRequest) -> SummarizationResponse:
response_json = self.client.summarize(
self.model_name,
Expand Down
116 changes: 108 additions & 8 deletions aleph_alpha_client/explanation.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,129 @@
from typing import Any, List, Dict, NamedTuple, Optional
from typing import Any, Generic, List, Dict, NamedTuple, Optional, TypeVar, Union
from aleph_alpha_client.prompt import Prompt


class ExplanationRequest(NamedTuple):
    """Request payload for the explain endpoint.

    Only ``prompt`` and ``target`` are required; the remaining tuning
    parameters default to ``None`` and are serialized as-is.
    """

    prompt: Prompt
    target: str
    suppression_factor: float
    conceptual_suppression_threshold: Optional[float] = None
    normalize: Optional[bool] = None
    square_outputs: Optional[bool] = None
    prompt_explain_indices: Optional[List[int]] = None

    def to_json(self) -> Dict[str, Any]:
        """Serialize the request, replacing the prompt with its JSON form."""
        # _asdict() returns a fresh dict, so merging in the serialized
        # prompt never mutates the NamedTuple's own state.
        return {**self._asdict(), "prompt": self.prompt.to_json()}


class TextScore(NamedTuple):
    """Attribution score for one character span of a text prompt item."""

    start: int  # character offset where the scored span begins
    length: int  # number of characters covered by the span
    score: float  # attribution score assigned to the span

    @staticmethod
    def from_json(score: Any) -> "TextScore":
        """Deserialize a TextScore from its API JSON object."""
        return TextScore(score["start"], score["length"], score["score"])


class TargetScore(NamedTuple):
    """Attribution score for one character span of the target string."""

    start: int  # character offset where the scored span begins
    length: int  # number of characters covered by the span
    score: float  # attribution score assigned to the span

    @staticmethod
    def from_json(score: Any) -> "TargetScore":
        """Deserialize a TargetScore from its API JSON object."""
        return TargetScore(score["start"], score["length"], score["score"])


class TokenScore(NamedTuple):
    """Attribution score for a single prompt token."""

    score: float

    @staticmethod
    def from_json(score: Any) -> "TokenScore":
        """Deserialize a TokenScore.

        Unlike the text/target scores, the API serializes a token score
        as a bare number rather than a JSON object.
        """
        return TokenScore(score)


class TextPromptItemExplanation(NamedTuple):
    """Per-span scores explaining one text item of the prompt."""

    scores: List[TextScore]

    @staticmethod
    def from_json(item: Dict[str, Any]) -> "TextPromptItemExplanation":
        """Deserialize from an API item with a ``scores`` list."""
        parsed = [TextScore.from_json(raw) for raw in item["scores"]]
        return TextPromptItemExplanation(scores=parsed)


class TargetPromptItemExplanation(NamedTuple):
    """Per-span scores explaining the target portion of the request."""

    scores: List[TargetScore]

    @staticmethod
    def from_json(item: Dict[str, Any]) -> "TargetPromptItemExplanation":
        """Deserialize from an API item with a ``scores`` list."""
        parsed = [TargetScore.from_json(raw) for raw in item["scores"]]
        return TargetPromptItemExplanation(scores=parsed)


class TokenPromptItemExplanation(NamedTuple):
    """Per-token scores explaining one token-ids item of the prompt."""

    scores: List[TokenScore]

    @staticmethod
    def from_json(item: Dict[str, Any]) -> "TokenPromptItemExplanation":
        """Deserialize from an API item with a ``scores`` list."""
        parsed = [TokenScore.from_json(raw) for raw in item["scores"]]
        return TokenPromptItemExplanation(scores=parsed)


class Explanation(NamedTuple):
    """Explanation for one target: one score item per prompt item.

    ``items`` holds one entry per prompt item (plus the target), each
    parsed into the concrete explanation type matching its ``type`` tag.
    """

    target: str
    items: List[
        Union[
            TextPromptItemExplanation,
            TargetPromptItemExplanation,
            TokenPromptItemExplanation,
        ]
    ]

    # Fix: this helper was defined without @staticmethod. Calls through the
    # class (Explanation.prompt_item_from_json(item)) happened to work, but
    # access through an instance would have bound the instance as ``item``.
    @staticmethod
    def prompt_item_from_json(
        item: Any,
    ) -> Union[
        TextPromptItemExplanation,
        TargetPromptItemExplanation,
        TokenPromptItemExplanation,
    ]:
        """Dispatch a serialized prompt-item explanation on its ``type`` tag.

        Raises:
            NotImplementedError: if the ``type`` tag is not one of
                ``text``, ``target``, or ``token_ids``.
        """
        if item["type"] == "text":
            return TextPromptItemExplanation.from_json(item)
        elif item["type"] == "target":
            return TargetPromptItemExplanation.from_json(item)
        elif item["type"] == "token_ids":
            return TokenPromptItemExplanation.from_json(item)
        else:
            raise NotImplementedError("Unsupported explanation type")

    @staticmethod
    def from_json(json: Dict[str, Any]) -> "Explanation":
        """Deserialize an Explanation from an API response entry."""
        return Explanation(
            target=json["target"],
            items=[Explanation.prompt_item_from_json(item) for item in json["items"]],
        )


class ExplanationResponse(NamedTuple):
    """Top-level response of the explain endpoint."""

    model_version: str  # version of the model that produced the explanation
    result: List[Any]  # raw result payload as returned by the API
    explanations: List[Explanation]  # parsed per-target explanations

    @staticmethod
    def from_json(json: Dict[str, Any]) -> "ExplanationResponse":
        """Deserialize the full response, parsing each explanation entry."""
        parsed = [Explanation.from_json(entry) for entry in json["explanations"]]
        return ExplanationResponse(
            model_version=json["model_version"],
            result=json["result"],
            explanations=parsed,
        )
37 changes: 24 additions & 13 deletions tests/test_explanation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
from aleph_alpha_client import ExplanationRequest, AlephAlphaClient
from aleph_alpha_client.aleph_alpha_client import AsyncClient, Client
from aleph_alpha_client.aleph_alpha_model import AlephAlphaModel
from aleph_alpha_client.prompt import Prompt
from aleph_alpha_client.explanation import ExplanationRequest
from aleph_alpha_client.prompt import Prompt, Text

from tests.common import (
sync_client,
Expand All @@ -20,28 +21,38 @@ async def test_can_explain_with_async_client(
async_client: AsyncClient, model_name: str
):
request = ExplanationRequest(
prompt=Prompt.from_text("An apple a day"),
target=" keeps the doctor away",
suppression_factor=0.1,
prompt=Prompt(
[
Text.from_text("I am a programmer and French. My favourite food is"),
# " My favorite food is"
[4014, 36316, 5681, 387],
]
),
target=" pizza with cheese",
)

response = await async_client._explain(request, model=model_name)
assert response.result
explanation = await async_client._explain(request, model=model_name)

assert len(explanation.explanations) == 3
assert all([len(exp.items) == 3 for exp in explanation.explanations])


# Client


def test_explanation(sync_client: Client, model_name: str):
request = ExplanationRequest(
prompt=Prompt.from_text("An apple a day"),
target=" keeps the doctor away",
suppression_factor=0.1,
prompt=Prompt(
[
Text.from_text("I am a programmer and French. My favourite food is"),
# " My favorite food is"
[4014, 36316, 5681, 387],
]
),
target=" pizza with cheese",
)

explanation = sync_client._explain(request, model=model_name)

assert len(explanation.result) > 0


# AlephAlphaClient
assert len(explanation.explanations) == 3
assert all([len(exp.items) == 3 for exp in explanation.explanations])