Skip to content

Commit

Permalink
Enable visualization of attribute-context after save
Browse files Browse the repository at this point in the history
  • Loading branch information
gsarti committed Jan 23, 2024
1 parent 06fe1b4 commit 5742858
Show file tree
Hide file tree
Showing 5 changed files with 76 additions and 18 deletions.
2 changes: 2 additions & 0 deletions inseq/commands/attribute_context/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
# Public API of the attribute-context command package: the CLI command, its
# argument dataclass, the output containers, and the visualization entry point.
from .attribute_context import AttributeContextCommand
from .attribute_context_args import AttributeContextArgs
from .attribute_context_helpers import AttributeContextOutput, CCIOutput
from .attribute_context_viz_helpers import visualize_attribute_context

# Names exported for `from ... import *` and considered stable public API.
__all__ = [
"AttributeContextCommand",
"AttributeContextArgs",
"AttributeContextOutput",
"CCIOutput",
"visualize_attribute_context",
]
10 changes: 6 additions & 4 deletions inseq/commands/attribute_context/attribute_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
get_source_target_cci_scores,
prepare_outputs,
)
from .attribute_context_viz_helpers import handle_visualization
from .attribute_context_viz_helpers import visualize_attribute_context

warnings.filterwarnings("ignore")
transformers.logging.set_verbosity_error()
Expand Down Expand Up @@ -134,7 +134,7 @@ def attribute_context(args: AttributeContextArgs):
output_current=args.output_current_text,
output_current_tokens=output_current_tokens,
cti_scores=cti_scores,
info=args if args.add_output_info else None,
info=args,
)
# Part 2: Contextual Cues Imputation (CCI)
for cti_idx, cti_score, cti_tok in cti_ranked_tokens:
Expand Down Expand Up @@ -202,11 +202,13 @@ def attribute_context(args: AttributeContextArgs):
output_context_scores=target_scores,
)
output.cci_scores.append(cci_out)
if args.show_viz or args.viz_path:
visualize_attribute_context(output, model, cti_threshold)
if not args.add_output_info:
output.info = None
if args.save_path:
with open(args.save_path, "w") as f:
json.dump(output.to_dict(), f, indent=4)
if args.show_viz or args.viz_path:
handle_visualization(args, model, output, cti_threshold)
return output


Expand Down
16 changes: 14 additions & 2 deletions inseq/commands/attribute_context/attribute_context_helpers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging
import re
from dataclasses import dataclass, field
from dataclasses import dataclass, field, fields
from typing import Any, Optional

from rich import print as rprint
Expand Down Expand Up @@ -59,6 +59,18 @@ def to_dict(self) -> dict[str, Any]:
out_dict["info"] = self.info.to_dict()
return out_dict

@classmethod
def from_dict(cls, out_dict: dict[str, Any]) -> "AttributeContextOutput":
    """Rebuild an :class:`AttributeContextOutput` from its ``to_dict`` representation.

    Args:
        out_dict: Dictionary produced by ``to_dict`` (e.g. loaded back from a
            JSON file written via ``save_path``).

    Returns:
        A populated ``AttributeContextOutput`` instance equal to the one that
        was serialized.
    """
    out = cls()
    for key, value in out_dict.items():
        # "cci_scores" and "info" need dedicated deserialization below;
        # every other field round-trips as a plain JSON value.
        if key not in ("cci_scores", "info"):
            setattr(out, key, value)
    # Fix: tolerate a missing "cci_scores" key instead of raising KeyError,
    # mirroring the optional handling of "info" just below.
    out.cci_scores = [CCIOutput(**cci_out) for cci_out in out_dict.get("cci_scores", [])]
    if "info" in out_dict:
        # Dataclass fields declared with init=False cannot be passed to the
        # constructor, so they are filtered out of the saved mapping.
        not_init_fields = [f.name for f in fields(AttributeContextArgs) if not f.init]
        out.info = AttributeContextArgs(
            **{k: v for k, v in out_dict["info"].items() if k not in not_init_fields}
        )
    return out


def format_template(template: str, current: str, context: Optional[str] = None) -> str:
kwargs = {"current": current}
Expand Down Expand Up @@ -290,7 +302,7 @@ def filter_rank_tokens(
indices = list(range(0, len(scores)))
token_score_tuples = sorted(zip(indices, scores, tokens), key=lambda x: abs(x[1]), reverse=True)
threshold = None
if std_threshold:
if std_threshold is not None:
threshold = tensor(scores).mean() + std_threshold * tensor(scores).std()
token_score_tuples = [(i, s, t) for i, s, t in token_score_tuples if abs(s) >= threshold]
if topk:
Expand Down
49 changes: 37 additions & 12 deletions inseq/commands/attribute_context/attribute_context_viz_helpers.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from typing import Literal, Optional
from copy import deepcopy
from typing import Literal, Optional, Union

from rich.console import Console
from torch import tensor

from ... import load_model
from ...models import HuggingfaceModel
from .attribute_context_args import AttributeContextArgs
from .attribute_context_helpers import AttributeContextOutput, filter_rank_tokens, get_filtered_tokens
Expand Down Expand Up @@ -110,19 +113,41 @@ def format_context_comment(
return out_string


def visualize_attribute_context(
    output: AttributeContextOutput,
    model: Union[HuggingfaceModel, str, None] = None,
    cti_threshold: Optional[float] = None,
    return_html: bool = False,
) -> Optional[str]:
    """Visualize the results of an ``attribute-context`` run.

    Args:
        output: Output of the attribute-context procedure. Must carry its
            generation ``info`` (i.e. saved with ``add_output_info=True``).
        model: Loaded model, model identifier, or ``None`` to reload the model
            named in ``output.info``.
        cti_threshold: Context-sensitivity threshold; recomputed from
            ``output.cti_scores`` and the std threshold in ``output.info``
            when not provided.
        return_html: If ``True``, return the rendered visualization as HTML.

    Returns:
        The HTML string when ``return_html`` is set, otherwise ``None``.

    Raises:
        ValueError: If ``output.info`` is missing.
        TypeError: If ``model`` is neither ``None``, a string, nor a
            ``HuggingfaceModel``.
    """
    if output.info is None:
        raise ValueError("Cannot visualize attribution results without args. Set add_output_info = True.")
    console = Console(record=True)
    viz = get_formatted_procedure_details(output.info)
    if model is None:
        model = output.info.model_name_or_path
    if isinstance(model, str):
        # Reload the model used for attribution from the saved run metadata.
        # deepcopy guards the stored kwargs against in-place mutation by loaders.
        model = load_model(
            output.info.model_name_or_path,
            output.info.attribution_method,
            model_kwargs=deepcopy(output.info.model_kwargs),
            tokenizer_kwargs=deepcopy(output.info.tokenizer_kwargs),
        )
    elif not isinstance(model, HuggingfaceModel):
        raise TypeError(f"Unsupported model type {type(model)} for visualization.")
    if cti_threshold is None:
        # Same mean + k*std rule used during attribution to flag sensitive tokens.
        cti_threshold = (
            tensor(output.cti_scores).mean()
            + output.info.context_sensitivity_std_threshold * tensor(output.cti_scores).std()
        )
    viz += "\n\n" + get_formatted_attribute_context_results(model, output.info, output, cti_threshold)
    # Print inside a capture so the recording console has content to export as
    # HTML without echoing to the terminal; exporting before any print would
    # yield an empty document. Echo to the terminal only when show_viz is set.
    with console.capture() as _:
        console.print(viz, soft_wrap=False)
    html = console.export_html()
    if output.info.viz_path:
        with open(output.info.viz_path, "w") as f:
            f.write(html)
    if output.info.show_viz:
        console.print(viz, soft_wrap=False)
    return html if return_html else None
17 changes: 17 additions & 0 deletions tests/commands/test_attribute_context.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import json

import pytest
from pytest import fixture
from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, GPT2LMHeadModel, MarianMTModel
Expand Down Expand Up @@ -335,3 +337,18 @@ def test_in_out_ctx_encdec_langtag_whitespace_sep():
)
cli_out = attribute_context(in_out_ctx_encdec_langtag_whitespace_sep)
assert round_scores(cli_out) == expected_output


def test_save_reload_attribute_context_outputs(tmp_path):
    """Round-trip check: an output saved to JSON reloads to an equal object."""
    save_file = str(tmp_path) + "/test.json"
    cli_args = AttributeContextArgs(
        model_name_or_path="gpt2",
        input_context_text="George was sick yesterday.",
        input_current_text="His colleagues asked him to come",
        attributed_fn="contrast_prob_diff",
        show_viz=False,
        save_path=save_file,
    )
    original_output = attribute_context(cli_args)
    with open(tmp_path / "test.json") as fh:
        reloaded_output = AttributeContextOutput.from_dict(json.load(fh))
    assert original_output == reloaded_output

0 comments on commit 5742858

Please sign in to comment.