added transformers_config for passing arguments to the transformer #268

Merged · 25 commits · Jul 8, 2021
40 changes: 38 additions & 2 deletions spacy_transformers/architectures.py
@@ -13,7 +13,7 @@ def transformer_listener_tok2vec_v1(
) -> Model[List[Doc], List[Floats2d]]:
"""Create a 'TransformerListener' layer, which will connect to a Transformer
component earlier in the pipeline.

The layer takes a list of Doc objects as input, and produces a list of
2d arrays as output, with each array having one row per token. Most spaCy
models expect a sublayer with this signature, making it easy to connect them
@@ -46,7 +46,7 @@ def transformer_listener_tok2vec_v1(
def transformer_tok2vec_v1(
name: str,
get_spans,
tokenizer_config,
tokenizer_config: dict,
pooling: Model[Ragged, Floats2d],
grad_factor: float = 1.0,
) -> Model[List[Doc], List[Floats2d]]:
@@ -74,6 +74,42 @@ def transformer_tok2vec_v1(
)


@registry.architectures.register("spacy-transformers.Tok2VecTransformer.v2")
def transformer_tok2vec_v2(
name: str,
get_spans,
tokenizer_config: dict,
transformer_config: dict,
pooling: Model[Ragged, Floats2d],
grad_factor: float = 1.0,
) -> Model[List[Doc], List[Floats2d]]:
"""Use a transformer as a "Tok2Vec" layer directly. This does not allow
multiple components to share the transformer weights, and does not allow
the transformer to set annotations into the `Doc` object, but it's a
simpler solution if you only need the transformer within one component.

get_spans (Callable[[List[Doc]], List[List[Span]]]): A function to extract
spans from the batch of Doc objects. See the "TransformerModel" layer
for details.
tokenizer_config (dict): Settings to pass to the transformers tokenizer.
transformer_config (dict): Settings to pass to the transformer's forward pass.
pooling (Model[Ragged, Floats2d]): A reduction layer used to calculate
the token vectors based on zero or more wordpiece vectors. If in doubt,
mean pooling (see `thinc.layers.reduce_mean`) is usually a good choice.
grad_factor (float): Reweight gradients from the component before passing
them to the transformer. You can set this to 0 to "freeze" the transformer
weights with respect to the component, or to make it learn more slowly.
Leaving it at 1.0 is usually fine.
"""
return chain(
TransformerModel(name, get_spans, tokenizer_config, transformer_config),
split_trf_batch(),
trfs2arrays(pooling, grad_factor),
)



registry.architectures.register(
"spacy-transformers.TransformerModel.v1", func=TransformerModel
)
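For context, here is a rough sketch (not part of the diff) of building the new Tok2VecTransformer.v2 layer directly in Python. The get_doc_spans import path and the model name are assumptions; reduce_mean comes from thinc:

    from thinc.api import reduce_mean
    from spacy_transformers.architectures import transformer_tok2vec_v2
    from spacy_transformers.span_getters import get_doc_spans  # import path assumed

    # Build a Doc-to-vectors layer backed by a transformer, forwarding extra
    # settings (here: output_attentions) to the transformer's forward pass.
    tok2vec = transformer_tok2vec_v2(
        name="roberta-base",
        get_spans=get_doc_spans,
        tokenizer_config={"use_fast": True},
        transformer_config={"output_attentions": True},
        pooling=reduce_mean(),
        grad_factor=1.0,
    )

As with the v1 layer, the returned Thinc model is initialized like any other layer (e.g. tok2vec.initialize()), which is when the pretrained weights are loaded and the output width is inferred.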
14 changes: 12 additions & 2 deletions spacy_transformers/data_classes.py
@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import Optional, List, Dict, Any
from typing import Optional, List, Dict, Any, Tuple
import torch
import numpy
from transformers.tokenization_utils import BatchEncoding
@@ -155,11 +155,14 @@ class TransformerData:
wordpieces: WordpieceBatch
tensors: List[FloatsXd]
align: Ragged
attention: Optional[Tuple[FloatsXd, ...]] = None

@classmethod
def empty(cls) -> "TransformerData":
align = Ragged(numpy.zeros((0,), dtype="i"), numpy.zeros((0,), dtype="i"))
return cls(wordpieces=WordpieceBatch.empty(), tensors=[], align=align)
return cls(
wordpieces=WordpieceBatch.empty(), tensors=[], align=align, attention=None
)

@classmethod
def zeros(cls, length: int, width: int, *, xp=numpy) -> "TransformerData":
@@ -247,6 +250,7 @@ class FullTransformerBatch:
wordpieces: WordpieceBatch
tensors: List[torch.Tensor]
align: Ragged
attention: Optional[Tuple[torch.Tensor]] = None
cached_doc_data: Optional[List[TransformerData]] = None

@classmethod
@@ -259,6 +263,7 @@ def empty(cls, nr_docs) -> "FullTransformerBatch":
wordpieces=WordpieceBatch.empty(),
tensors=[],
align=align,
attention=None,
cached_doc_data=doc_data,
)

@@ -312,11 +317,16 @@ def split_by_doc(self) -> List[TransformerData]:
doc_tokens = self.wordpieces[start:end]
doc_align = self.align[start_i:end_i]
doc_align.data = doc_align.data - prev_tokens
if self.attention:
attn = [torch2xp(t[start:end]) for t in self.attention]
else:
attn = None
outputs.append(
TransformerData(
wordpieces=doc_tokens,
tensors=[torch2xp(t[start:end]) for t in self.tensors],
align=doc_align,
attention=attn,
)
)
prev_tokens += doc_tokens.input_ids.size
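As a usage sketch (not from this PR): once the transformer is configured with {"output_attentions": true}, split_by_doc slices the per-layer attention tensors per Doc, so the new attention field shows up on each Doc's TransformerData; otherwise it stays None. The pipeline package below is an assumption for illustration:

    import spacy

    # Assumes any trained pipeline with a "transformer" pipe, e.g. en_core_web_trf.
    nlp = spacy.load("en_core_web_trf")
    doc = nlp("The quick brown fox jumps over the lazy dog.")

    trf_data = doc._.trf_data            # TransformerData for this Doc
    print(len(trf_data.tensors))
    if trf_data.attention is not None:   # populated only when output_attentions is enabled
        print(len(trf_data.attention), trf_data.attention[0].shape)
    else:
        print("attention not requested; field is None")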
32 changes: 21 additions & 11 deletions spacy_transformers/layers/transformer_model.py
@@ -15,7 +15,7 @@


def TransformerModel(
name: str, get_spans: Callable, tokenizer_config: dict
name: str, get_spans: Callable, tokenizer_config: dict = {}, transformer_config: dict = {}
) -> Model[List[Doc], FullTransformerBatch]:
"""
get_spans (Callable[[List[Doc]], List[Span]]):
@@ -25,6 +25,7 @@ def TransformerModel(
overlap, and you can also omit sections of the Doc if they are not
relevant.
tokenizer_config (dict): Settings to pass to the transformers tokenizer.
transformer_config (dict): Settings to pass to the transformer's forward pass.
"""

return Model(
@@ -38,6 +39,7 @@
"get_spans": get_spans,
"name": name,
"tokenizer_config": tokenizer_config,
"transformer_config": transformer_config,
"set_transformer": set_pytorch_transformer,
"has_transformer": False,
"flush_cache_chance": 0.0,
@@ -75,7 +77,8 @@ def init(model: Model, X=None, Y=None):
return
name = model.attrs["name"]
tok_cfg = model.attrs["tokenizer_config"]
tokenizer, transformer = huggingface_from_pretrained(name, tok_cfg)
trf_cfg = model.attrs["transformer_config"]
tokenizer, transformer = huggingface_from_pretrained(name, tok_cfg, trf_cfg)
model.attrs["tokenizer"] = tokenizer
model.attrs["set_transformer"](model, transformer)
# Call the model with a batch of inputs to infer the width
@@ -89,26 +92,23 @@ def init(model: Model, X=None, Y=None):
for doc_spans in nested_spans:
flat_spans.extend(doc_spans)
token_data = huggingface_tokenize(
model.attrs["tokenizer"],
[span.text for span in flat_spans]
model.attrs["tokenizer"], [span.text for span in flat_spans]
)
wordpieces = WordpieceBatch.from_batch_encoding(token_data)
align = get_alignment(
flat_spans,
wordpieces.strings, model.attrs["tokenizer"].all_special_tokens
flat_spans, wordpieces.strings, model.attrs["tokenizer"].all_special_tokens
)
wordpieces, align = truncate_oversize_splits(
wordpieces, align, tokenizer.model_max_length
)
else:
texts = ["hello world", "foo bar"]
token_data = huggingface_tokenize(
model.attrs["tokenizer"],
texts
)
token_data = huggingface_tokenize(model.attrs["tokenizer"], texts)
wordpieces = WordpieceBatch.from_batch_encoding(token_data)
model.layers[0].initialize(X=wordpieces)
tensors = model.layers[0].predict(wordpieces)
if trf_cfg["output_attentions"] is True:
tensors = tensors[:-1] # remove attention
t_i = find_last_hidden(tensors)
model.set_dim("nO", tensors[t_i].shape[-1])

@@ -118,6 +118,7 @@ def forward(
) -> Tuple[FullTransformerBatch, Callable]:
tokenizer = model.attrs["tokenizer"]
get_spans = model.attrs["get_spans"]
trf_config = model.attrs["transformer_config"]
transformer = model.layers[0]

nested_spans = get_spans(docs)
@@ -142,8 +143,17 @@ def forward(
tensors, bp_tensors = transformer(wordpieces, is_train)
if "logger" in model.attrs:
log_gpu_memory(model.attrs["logger"], "after forward")
if ("output_attentions" in trf_config) and (trf_config["output_attentions"] is True):
attn = tensors[-1]
tensors = tensors[:-1]
else:
attn = None
output = FullTransformerBatch(
spans=nested_spans, wordpieces=wordpieces, tensors=tensors, align=align
spans=nested_spans,
wordpieces=wordpieces,
tensors=tensors,
align=align,
attention=attn,
)
if "logger" in model.attrs:
log_gpu_memory(model.attrs["logger"], "return from forward")
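The split in forward() relies on the Hugging Face convention that, with output_attentions=True and tuple outputs, the per-layer attention weights come last in the model output, which is why the code pops tensors[-1]. A standalone sketch of that behavior (not from the diff; the model name is arbitrary):

    import torch
    from transformers import AutoModel, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
    model = AutoModel.from_pretrained("distilbert-base-uncased")
    batch = tokenizer(["hello world"], return_tensors="pt")

    with torch.no_grad():
        # return_dict=False gives the tuple form that the wrapped layer sees.
        outputs = model(**batch, output_attentions=True, return_dict=False)

    hidden_state = outputs[0]   # (batch, seq, width)
    attentions = outputs[-1]    # tuple with one (batch, heads, seq, seq) tensor per layer
    print(hidden_state.shape, len(attentions), attentions[0].shape)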
15 changes: 11 additions & 4 deletions spacy_transformers/pipeline_component.py
@@ -30,6 +30,7 @@
@architectures = "spacy-transformers.TransformerModel.v1"
name = "roberta-base"
tokenizer_config = {"use_fast": true}
transformer_config = {"output_attentions": false}

[transformer.model.get_spans]
@span_getters = "spacy-transformers.strided_spans.v1"
@@ -143,7 +144,9 @@ def add_listener(self, listener: TransformerListener, component_name: str) -> No
if self.model.has_dim("nO") and listener.has_dim("nO") is None:
listener.set_dim("nO", self.model.get_dim("nO"))

def remove_listener(self, listener: TransformerListener, component_name: str) -> bool:
def remove_listener(
self, listener: TransformerListener, component_name: str
) -> bool:
"""Remove a listener for a downstream component. Usually internals."""
if component_name in self.listener_map:
if listener in self.listener_map[component_name]:
@@ -167,7 +170,10 @@ def find_listeners(self, component) -> None:
names = ("*", self.name)
if isinstance(getattr(component, "model", None), Model):
for node in component.model.walk():
if isinstance(node, TransformerListener) and node.upstream_name in names:
if (
isinstance(node, TransformerListener)
and node.upstream_name in names
):
self.add_listener(node, component.name)

def __call__(self, doc: Doc) -> Doc:
@@ -294,7 +300,8 @@ def accumulate_gradient(d_trf_datas: List[TransformerData]):
nonlocal d_tensors
for i, d_trf_data in enumerate(d_trf_datas):
for d_tensor in d_trf_data.tensors:
losses[self.name] += float((d_tensor ** 2).sum()) # type: ignore
losses[self.name] += float((d_tensor ** 2).sum())  # type: ignore
if i >= len(d_tensors):
d_tensors.append(d_trf_data.tensors)
else:
@@ -387,7 +394,7 @@ def from_disk(
def load_model(p):
p = Path(p).absolute()
tokenizer, transformer = huggingface_from_pretrained(
p, self.model.attrs["tokenizer_config"]
p, self.model.attrs["tokenizer_config"], self.model.attrs["transformer_config"]
)
self.model.attrs["tokenizer"] = tokenizer
self.model.attrs["set_transformer"](self.model, transformer)
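Equivalently in Python, the new setting can be passed when adding the component. A hedged sketch mirroring the config excerpt in the docstring above; the output_attentions value and the window/stride numbers are illustrative, not prescribed by this PR:

    import spacy

    nlp = spacy.blank("en")
    nlp.add_pipe(
        "transformer",
        config={
            "model": {
                "@architectures": "spacy-transformers.TransformerModel.v1",
                "name": "roberta-base",
                "tokenizer_config": {"use_fast": True},
                "transformer_config": {"output_attentions": True},
                "get_spans": {
                    "@span_getters": "spacy-transformers.strided_spans.v1",
                    "window": 128,
                    "stride": 96,
                },
            }
        },
    )
    nlp.initialize()  # loads the pretrained weights and infers the output width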
4 changes: 3 additions & 1 deletion spacy_transformers/tests/test_model_wrapper.py
@@ -28,7 +28,9 @@ def name(request):

@pytest.fixture(scope="session")
def trf_model(name):
model = TransformerModel(name, get_doc_spans, {"use_fast": True})
model = TransformerModel(
name, get_doc_spans, {"use_fast": True}, {"output_attentions": False}
)
model.initialize()
return model

7 changes: 6 additions & 1 deletion spacy_transformers/tests/util.py
@@ -116,7 +116,11 @@ def _forward(model, tokens, is_train):
tensors.append(torch.zeros(*shape))
return tensors, lambda d_tensors: tokens

return Model("dummy-transformer", _forward, attrs={"width": width, "depth": depth})
return Model(
"dummy-transformer",
_forward,
attrs={"width": width, "depth": depth},
)


def DummyTransformer(
@@ -132,6 +136,7 @@ def DummyTransformer(
"tokenizer": DummyTokenizer(),
"grad_factor": 1.0,
"flush_cache_chance": 0.0,
"transformer_config": {}
},
dims={"nO": width},
)
11 changes: 8 additions & 3 deletions spacy_transformers/util.py
@@ -1,5 +1,6 @@
from typing import List, Dict, Union
from pathlib import Path
from functools import partial
import random
from transformers import AutoModel, AutoTokenizer
from transformers.tokenization_utils import BatchEncoding
@@ -16,20 +17,24 @@
# fmt: on


def huggingface_from_pretrained(source: Union[Path, str], config: Dict):
def huggingface_from_pretrained(
source: Union[Path, str], tok_config: Dict, trf_config: Dict
):
"""Create a Huggingface transformer model from pretrained weights. Will
download the model if it is not already downloaded.

source (Union[str, Path]): The name of the model or a path to it, such as
'bert-base-cased'.
config (dict): Settings to pass to the tokenizer.
tok_config (dict): Settings to pass to the tokenizer.
trf_config (dict): Settings to pass to the transformer.
"""
if hasattr(source, "absolute"):
str_path = str(source.absolute())
else:
str_path = source
tokenizer = AutoTokenizer.from_pretrained(str_path, **config)
tokenizer = AutoTokenizer.from_pretrained(str_path, **tok_config)
transformer = AutoModel.from_pretrained(str_path)
transformer.forward = partial(transformer.forward, **trf_config)
ops = get_current_ops()
if isinstance(ops, CupyOps):
transformer.cuda()
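The functools.partial re-binding above means every subsequent call to transformer.forward receives the extra keyword arguments from trf_config. A minimal standalone illustration of the pattern (class and argument names are made up):

    from functools import partial

    class Toy:
        def forward(self, x, scale=1, verbose=False):
            if verbose:
                print(f"scaling {x} by {scale}")
            return x * scale

    toy = Toy()
    # Pin keyword arguments once, like transformer.forward = partial(transformer.forward, **trf_config):
    toy.forward = partial(toy.forward, scale=3, verbose=True)
    print(toy.forward(2))  # prints "scaling 2 by 3", then 6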