Implement coalesced pooling over entire batches #368

Merged 2 commits on Mar 13, 2023
88 changes: 70 additions & 18 deletions spacy_transformers/layers/trfs2arrays.py
@@ -1,4 +1,5 @@
from typing import List, cast
from typing import Callable, List, Optional, Tuple, cast
import numpy
from spacy.util import all_equal
from transformers.file_utils import ModelOutput
from transformers.modeling_outputs import BaseModelOutput
@@ -13,15 +14,19 @@ def trfs2arrays(
) -> Model[List[TransformerData], List[Floats2d]]:
"""Pool transformer data into token-aligned tensors."""
return Model(
"trfs2arrays", forward, layers=[pooling], attrs={"grad_factor": grad_factor}
"trfs2arrays",
forward,
layers=[pooling],
attrs={"grad_factor": grad_factor},
)


def forward(model: Model, trf_datas: List[TransformerData], is_train: bool):
pooling: Model[Ragged, Floats2d] = model.layers[0]
grad_factor = model.attrs["grad_factor"]
outputs = []
backprops = []
zero_outputs: List[Tuple[int, Floats2d]] = []
backprops_alignment: List[Optional[Callable]] = []
aligned_outputs: List[Tuple[int, Ragged]] = []

# For zero-length documents, we could cache the output width by iterating
# through the batch outputs and retrieving the shape of a non-zero length
@@ -31,55 +36,70 @@ def forward(model: Model, trf_datas: List[TransformerData], is_train: bool):
# zero in these cases as the effective length of the resultant tensor is zero anyway.
output_width = 0

for trf_data in trf_datas:
for i, trf_data in enumerate(trf_datas):
if "last_hidden_state" in trf_data.model_output:
tensor_t_i = cast(BaseModelOutput, trf_data.model_output).last_hidden_state
if tensor_t_i.size == 0:
# This can happen during prediction/initialization if the transformer pipe was disabled/not executed and one of the inputs
# was of length zero. This causes the listener to generate a zero-sized (in the sequence length dim) TransformerData
# output and pass it downstream.
outputs.append(model.ops.alloc2f(0, output_width))
backprops.append((None, None))
zero_outputs.append((i, model.ops.alloc2f(0, output_width)))
backprops_alignment.append(None)
else:
# This is the general case for non-zero length documents.
src = model.ops.reshape2f(tensor_t_i, -1, trf_data.width) # type: ignore
dst, get_d_src = apply_alignment(model.ops, trf_data.align, src)
output, get_d_dst = pooling(dst, is_train)
outputs.append(output)
backprops.append((get_d_dst, get_d_src)) # type: ignore
aligned_outputs.append((i, dst))
backprops_alignment.append(get_d_src)
else:
# This can happen during prediction/training for zero-length documents. Since zero-length docs
# are implicitly ignored in the span generation stage, the transformer model does not return any
# predictions for them and subsequently, FullTransformerBatch.split_by_doc() generates an empty
# TransformerData.
outputs.append(model.ops.alloc2f(0, output_width))
backprops.append((None, None))
zero_outputs.append((i, model.ops.alloc2f(0, output_width)))
backprops_alignment.append(None)

pooling_outputs, backprop_pooling = concat_pooling_forward(
pooling, [dst for _, dst in aligned_outputs], is_train
)

# Interleave the zero and non-zero outputs into the final result.
outputs: List[Optional[Floats2d]] = [None] * (
len(zero_outputs) + len(aligned_outputs)
)
for i, zero_output in zero_outputs:
outputs[i] = zero_output
for (i, _), pooling_output in zip(aligned_outputs, pooling_outputs):
outputs[i] = pooling_output

def backprop_trf_to_tensor(d_outputs: List[Floats2d]) -> List[TransformerData]:
d_trf_datas: List[TransformerData] = []
to_zip = (trf_datas, d_outputs, backprops)

# Only update the gradients that are relevant for pooling.
d_pooling = backprop_pooling([d_outputs[i] for i, _ in aligned_outputs])
for (i, _), d_pooling_i in zip(aligned_outputs, d_pooling):
d_outputs[i] = d_pooling_i

to_zip = (trf_datas, d_outputs, backprops_alignment)
assert all_equal(len(x) for x in to_zip) # type: ignore
zipped = zip(*to_zip)
for trf_data, d_output, (get_d_dst, get_d_src) in zipped:
for trf_data, d_output, get_d_src in zipped:
if "last_hidden_state" not in trf_data.model_output:
# This gradient belongs to a zero-length doc and must be ignored as it doesn't have a corresponding
# output from the transformer model (due to empty documents being skipped during the span generation
# stage in the forward pass).
assert len(d_output) == 0
assert get_d_src is None
assert get_d_dst is None
continue

assert get_d_src is not None
assert get_d_dst is not None
d_model_output = ModelOutput(
last_hidden_state=model.ops.alloc(
trf_data.model_output.last_hidden_state.shape, # type: ignore
dtype=trf_data.model_output.last_hidden_state.dtype, # type: ignore
)
)
d_dst = cast(Floats2d, get_d_dst(d_output))
d_src = cast(Floats2d, get_d_src(d_dst))
d_src = get_d_src(d_output)
d_src *= grad_factor
d_model_output["last_hidden_state"] = d_src.reshape(
cast(BaseModelOutput, trf_data.model_output).last_hidden_state.shape
@@ -95,3 +115,35 @@ def backprop_trf_to_tensor(d_outputs: List[Floats2d]) -> List[TransformerData]:

assert len(outputs) == len(trf_datas)
return outputs, backprop_trf_to_tensor


def concat_pooling_forward(
pooling: Model[Ragged, Floats2d], X: List[Ragged], is_train: bool
):
xp = pooling.ops.xp

datas = []
lens = []
doc_lens = []
for X_doc_data in X:
datas.append(X_doc_data.dataXd)
lens.append(X_doc_data.lengths)
doc_lens.append(len(X_doc_data.lengths))

X_flat = Ragged(xp.concatenate(datas, axis=0), xp.concatenate(lens, axis=0))
Y_pooled, pooling_backprop = pooling(X_flat, is_train)
Y = xp.split(Y_pooled, numpy.cumsum(doc_lens)[:-1])

def backprop(dY):
dY_pooled_flat = xp.concatenate(dY)
dY_flat = pooling_backprop(dY_pooled_flat).dataXd

dY = []
for X_doc_data in X:
doc_unpooled_len = X_doc_data.dataXd.shape[0]
dY.append(Ragged(dY_flat[:doc_unpooled_len], X_doc_data.lengths))
dY_flat = dY_flat[doc_unpooled_len:]

return dY

return Y, backprop
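

For readers skimming the diff, the change can be summarised as follows: instead of calling the pooling layer once per document, concat_pooling_forward concatenates the token-aligned rows and span lengths of all documents into a single Ragged, pools that once, and splits the pooled rows back per document with numpy.cumsum over the per-document span counts. The self-contained sketch below illustrates that equivalence in plain NumPy; the mean_pool helper and the toy documents are stand-ins for the real pooling layer and Ragged alignment data, not part of the PR.

import numpy

def mean_pool(data, lengths):
    # Reduce each span of `lengths[i]` consecutive rows of `data` to one row.
    out = numpy.zeros((len(lengths), data.shape[1]), dtype=data.dtype)
    start = 0
    for i, length in enumerate(lengths):
        out[i] = data[start : start + length].mean(axis=0)
        start += length
    return out

# Two toy "documents": per-token feature rows plus the span lengths to pool over.
docs = [
    (numpy.random.rand(5, 4).astype("f"), numpy.array([2, 3])),
    (numpy.random.rand(4, 4).astype("f"), numpy.array([1, 1, 2])),
]

# Per-document pooling: one call per document (the pre-PR behaviour).
per_doc = [mean_pool(data, lengths) for data, lengths in docs]

# Coalesced pooling: concatenate everything, pool once, split back per document.
flat_data = numpy.concatenate([data for data, _ in docs], axis=0)
flat_lens = numpy.concatenate([lengths for _, lengths in docs], axis=0)
pooled = mean_pool(flat_data, flat_lens)
doc_counts = numpy.cumsum([len(lengths) for _, lengths in docs])[:-1]
coalesced = numpy.split(pooled, doc_counts)

for a, b in zip(per_doc, coalesced):
    assert numpy.allclose(a, b)

Pooling the whole batch in one call rather than once per document is the coalescing the PR title refers to; the interleaving code in forward then slots the pooled outputs back into their original positions alongside the zero-length placeholders.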