Add batch length padding for fixed function accelerators (#1457)
Summary:
Pull Request resolved: #1457

Add batch length padding for fixed function accelerators

Reviewed By: chenyangyu1988

Differential Revision: D23669471

fbshipit-source-id: 44cb02c521faeebcd82a396cdc9056a2ac8e926e
mikekgfb authored and facebook-github-bot committed Oct 2, 2020
1 parent ebd9547 commit 033f262
Showing 8 changed files with 106 additions and 45 deletions.
9 changes: 8 additions & 1 deletion pytext/config/pytext_config.py
@@ -128,7 +128,14 @@ class PyTextConfig(ConfigBase):
# larger than the actual longest sequence in a batch.
# The list of padding boundaries must be sorted in ascending order.
# The first list element must be 0. (Will serve as future padding control "version number")
padding_control: Optional[List[int]] = None
seq_padding_control: Optional[List[int]] = None
# Padding boundaries for padded tensor batch length dimension.
# Specified as a list of boundaries to be rounded up to.
# Each batch length dimension will be rounded to the smallest number
# larger than the actual longest sequence in a batch.
# The list of padding boundaries must be sorted in ascending order.
# The first list element must be 0. (Will serve as future padding control "version number")
batch_padding_control: Optional[List[int]] = None
# Base directory where modules are saved
modules_save_dir: str = ""
# Whether to save intermediate checkpoints for modules if they are best yet
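As a quick illustration of the two new config fields (hypothetical values, not defaults): with the settings below, a batch whose longest sequence has 50 tokens would be padded to 64 tokens, and a batch of 5 examples would be padded to 8 rows.

# Hypothetical padding-control settings; the names mirror the config fields above,
# the boundary values are illustrative only.
export_config_overrides = {
    "seq_padding_control": [0, 32, 64, 128],
    "batch_padding_control": [0, 8, 16, 32],
}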
5 changes: 3 additions & 2 deletions pytext/data/bert_tensorizer.py
@@ -162,8 +162,9 @@ def tensorize(
tokens, pad_mask = pad_2d_mask(
tokens_2d,
pad_value=self.vocab.pad_idx,
padding_control=self.padding_control,
max_pad_len=self.max_seq_len,
seq_padding_control=self.seq_padding_control,
max_seq_pad_len=self.max_seq_len,
batch_padding_control=self.batch_padding_control,
)
segment_labels = torch.tensor(
pad_2d(segment_labels_2d, seq_lens=seq_lens_1d, pad_idx=0), dtype=torch.long
21 changes: 15 additions & 6 deletions pytext/data/tensorizers.py
@@ -21,7 +21,7 @@
from pytext.data.sources.data_source import Gazetteer
from pytext.data.tokenizers import Token, Tokenizer
from pytext.torchscript.tensorizer import VectorNormalizer
from pytext.torchscript.utils import ScriptBatchInput
from pytext.torchscript.utils import ScriptBatchInput, validate_padding_control
from pytext.utils import cuda, precision
from pytext.utils.data import Slot
from pytext.utils.file_io import PathManager
@@ -110,28 +110,37 @@ def lookup_tokens(

class TensorizerScriptImpl(torch.nn.Module):
device: str
padding_control: Optional[List[int]]
seq_padding_control: Optional[List[int]]
batch_padding_control: Optional[List[int]]

def __init__(self):
super().__init__()
self.device: str = ""
# padding_control options:
# None - no padding
# [0, pad1, pad2, pad3,...] - pads sequence length to smallest padX larger than sequence
self.padding_control = torch.jit.annotate(Optional[List[int]], None)
# [0, pad1, pad2, pad3,...] - pads sequence/batch length to smallest padX larger than sequence
self.seq_padding_control = None
self.batch_padding_control = None

@torch.jit.export
def set_device(self, device: str):
self.device = device

@torch.jit.export
def set_padding_control(self, control: Optional[List[int]]):
def set_padding_control(self, dimension: str, padding_control: Optional[List[int]]):
"""
This function will be called to set a padding style.
None - No padding
List: first element 0, round seq length to the smallest list element larger than inputs
"""
self.padding_control = control
if not validate_padding_control(padding_control):
raise RuntimeError("Malformed padding_control value")
if dimension == "sequence_length":
self.seq_padding_control = padding_control
elif dimension == "batch_length":
self.batch_padding_control = padding_control
else:
raise RuntimeError("Illegal padding dimension specified.")

def batch_size(self, inputs: ScriptBatchInput) -> int:
texts: Optional[List[List[str]]] = inputs.texts
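For reference, a scripted tensorizer configured through this method might be set up as follows (the instance name and boundary values are placeholders):

# Sketch only: set per-dimension padding on a TensorizerScriptImpl instance.
script_tensorizer.set_padding_control("sequence_length", [0, 64, 128, 256])
script_tensorizer.set_padding_control("batch_length", [0, 8, 16])
# Any other dimension string raises "Illegal padding dimension specified."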
14 changes: 11 additions & 3 deletions pytext/task/new_task.py
@@ -285,7 +285,8 @@ def torchscript_export(
# unpack export kwargs
quantize = kwargs.get("quantize", False)
accelerate = kwargs.get("accelerate", [])
padding_control = kwargs.get("padding_control")
seq_padding_control = kwargs.get("seq_padding_control")
batch_padding_control = kwargs.get("batch_padding_control")
inference_interface = kwargs.get("inference_interface")

# Make sure to put the model on CPU and disable CUDA before exporting to
@@ -319,9 +320,16 @@
trace = model.trace(inputs)
if hasattr(model, "torchscriptify"):
trace = model.torchscriptify(self.data.tensorizers, trace)
if padding_control is not None:
if seq_padding_control is not None:
if hasattr(trace, "set_padding_control"):
trace.set_padding_control(padding_control)
trace.set_padding_control("sequence_length", seq_padding_control)
else:
print(
"Padding_control not supported by model. Ignoring padding_control"
)
if batch_padding_control is not None:
if hasattr(trace, "set_padding_control"):
trace.set_padding_control("batch_length", batch_padding_control)
else:
print(
"Padding_control not supported by model. Ignoring padding_control"
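A sketch of how the new kwargs might be supplied to torchscript_export (the task and model objects, export path, and boundary values are all placeholders):

# Hypothetical export call; the padding boundaries are illustrative.
task.torchscript_export(
    model,
    export_path="/tmp/model.pt1",
    quantize=False,
    seq_padding_control=[0, 64, 128],
    batch_padding_control=[0, 8, 16],
)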
14 changes: 11 additions & 3 deletions pytext/task/tasks.py
@@ -224,7 +224,8 @@ def torchscript_export(self, model, export_path=None, **kwargs):
# unpack export kwargs
quantize = kwargs.get("quantize", False)
accelerate = kwargs.get("accelerate", [])
padding_control = kwargs.get("padding_control")
seq_padding_control = kwargs.get("seq_padding_control")
batch_padding_control = kwargs.get("batch_padding_control")
inference_interface = kwargs.get("inference_interface")

cuda.CUDA_ENABLED = False
@@ -252,9 +253,16 @@ def torchscript_export(self, model, export_path=None, **kwargs):
trace = model.torchscriptify(
self.data.tensorizers, trace, self.trace_both_encoders
)
if padding_control is not None:
if seq_padding_control is not None:
if hasattr(trace, "set_padding_control"):
trace.set_padding_control(padding_control)
trace.set_padding_control("sequence_length", seq_padding_control)
else:
print(
"Padding_control not supported by model. Ignoring padding_control"
)
if batch_padding_control is not None:
if hasattr(trace, "set_padding_control"):
trace.set_padding_control("batch_length", batch_padding_control)
else:
print(
"Padding_control not supported by model. Ignoring padding_control"
24 changes: 15 additions & 9 deletions pytext/torchscript/module.py
@@ -224,13 +224,13 @@ def inference_interface(self, argument_type: str):
raise RuntimeError("Unsupported argument type.")

@torch.jit.script_method
def set_padding_control(self, control: Optional[List[int]]):
def set_padding_control(self, dimension: str, control: Optional[List[int]]):
"""
This function will be called to set a padding style.
None - No padding
List: first element 0, round seq length to the smallest list element larger than inputs
"""
self.tensorizer.set_padding_control(control)
self.tensorizer.set_padding_control(dimension, control)

@torch.jit.script_method
def _forward(self, inputs: ScriptBatchInput):
@@ -303,6 +303,9 @@ def make_prediction(
# batch elements (and which ones) were malformed
raise RuntimeError("Malformed request.")

if len(flat_texts) == 0:
raise RuntimeError("This is not good. Empty request batch.")

flat_result: torch.Tensor = self.forward(
texts=flat_texts,
multi_texts=None,
@@ -551,8 +554,8 @@ def make_prediction(
if argno == TEXTS:
flat_texts: List[str] = []

for i in range(batchsize):
batch_element = batch[i][0]
for be in batch:
batch_element = be[0]
if batch_element is not None:
flat_texts.extend(batch_element)
client_batch.append(len(batch_element))
@@ -564,7 +567,10 @@
# we can skip malformed requests,
# and return a list plus an indiction that one or more
# batch elements (and which ones) were malformed
raise RuntimeError("Malformed request.")
raise RuntimeError("(VE) Malformed request.")

if len(flat_texts) == 0:
raise RuntimeError("This is not good. Empty request batch.")

flat_result: List[torch.Tensor] = self.forward(
texts=flat_texts,
@@ -584,7 +590,7 @@
start = 0
for elems in client_batch:
end = start + elems
res_list.append(flat_result[start:elems])
res_list.append(flat_result[start:end])
start = end

return res_list
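The corrected slice above matters as soon as a client batch spans more than one row; a minimal sketch of the splitting logic, using stand-in data:

import torch

flat_result = torch.arange(5)   # stand-in for the model's flattened output
client_batch = [2, 3]           # per-client request sizes
res_list = []
start = 0
for elems in client_batch:
    end = start + elems
    res_list.append(flat_result[start:end])   # rows 0:2, then rows 2:5
    start = end
# The previous slice, flat_result[start:elems], returned rows 0:2 and then only row 2:3.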
@@ -622,14 +628,14 @@ def inference_interface(self, argument_type: str):
raise RuntimeError("Unsupported argument type.")

@torch.jit.script_method
def set_padding_control(self, control: Optional[List[int]]):
def set_padding_control(self, dimension: str, control: Optional[List[int]]):
"""
This functions will be called to set a padding style.
None - No padding
List: first element 0, round seq length to the smallest list element larger than inputs
"""
self.right_tensorizer.set_padding_control(control)
self.left_tensorizer.set_padding_control(control)
self.right_tensorizer.set_padding_control(dimension, control)
self.left_tensorizer.set_padding_control(dimension, control)

@torch.jit.script_method
def _forward(self, right_inputs: ScriptBatchInput, left_inputs: ScriptBatchInput):
61 changes: 41 additions & 20 deletions pytext/torchscript/utils.py
@@ -50,6 +50,34 @@ def list_membership(item: int, list: List[int]):
return item_present


@torch.jit.script
def validate_padding_control(padding_control: Optional[List[int]]) -> bool:
if padding_control is not None:
if len(padding_control) < 2:
return False
elif padding_control[0] != 0:
return False

return True


@torch.jit.script
def pad_length(
len: int, padding_control: Optional[List[int]], max_len: int = -1
) -> int:
if not validate_padding_control(padding_control):
raise NotImplementedError

if padding_control is not None:
for pad in padding_control:
if pad >= len:
len = pad
break
if max_len > 0:
len = min(len, max_len)
return len
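
Worked examples of pad_length under the rules above (boundary values are arbitrary):

assert pad_length(50, [0, 32, 64, 128]) == 64         # rounded up to the next boundary
assert pad_length(200, [0, 32, 64, 128]) == 200       # no boundary is large enough; length kept
assert pad_length(200, [0, 32, 64, 128], 128) == 128  # then capped by max_len
assert pad_length(50, None) == 50                     # no padding control: length unchanged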


@torch.jit.script
def reverse_tensor_list(int_list: List[torch.Tensor]) -> List[torch.Tensor]:
l_len = len(int_list)
@@ -76,34 +104,27 @@ def long_tensor_2d(shape: Tuple[int, int], fill_value: int = 0) -> torch.Tensor:
def pad_2d_mask(
input: List[List[int]],
pad_value: int = 0,
padding_control: Optional[List[int]] = None,
max_pad_len: int = -1,
seq_padding_control: Optional[List[int]] = None,
max_seq_pad_len: int = -1,
batch_padding_control: Optional[List[int]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Pad a list to a 2d tensor. Returns a pair of tensors, the padded tensor
as well as a mask tensor. The mask tensor has the same shape as the padded tensor,
with a 1 in the position of non-pad values and a 0 in the position of pads.
If padding_control is set, perform padding according to the specified padding style"""
max_len = 0
for i in input:
max_len = max(max_len, len(i))
if padding_control is not None:
if len(padding_control) < 2:
raise NotImplementedError
elif padding_control[0] != 0:
raise NotImplementedError
else:
for pad in padding_control:
if pad >= max_len:
max_len = pad
break
if max_pad_len > 0:
max_len = min(max_len, max_pad_len)
tensor = long_tensor_2d((len(input), max_len), pad_value)
mask = long_tensor_2d((len(input), max_len), 0)

# List comprehension required for TorchScript
max_seq_len = max([len(i) for i in input]) # noqa
max_seq_len = pad_length(max_seq_len, seq_padding_control, max_seq_pad_len)

max_batch_len = len(input)
max_batch_len = pad_length(max_batch_len, batch_padding_control, -1)

tensor = long_tensor_2d((max_batch_len, max_seq_len), pad_value)
for i in range(len(input)):
for j in range(len(input[i])):
tensor[i][j] = input[i][j]
mask[i][j] = 1
mask = tensor.ne(pad_value).to(torch.long)
return tensor, mask
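
A small worked example of pad_2d_mask with both controls set (input values and boundaries are illustrative):

tokens, mask = pad_2d_mask(
    [[11, 12, 13], [21, 22]],        # two ragged rows
    pad_value=0,
    seq_padding_control=[0, 4, 8],   # longest row (3) rounds up to 4 columns
    max_seq_pad_len=256,
    batch_padding_control=[0, 4],    # batch of 2 rounds up to 4 rows
)
# tokens.shape == (4, 4); padded cells hold pad_value, and
# mask == 1 exactly where tokens != pad_value.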


3 changes: 2 additions & 1 deletion pytext/workflow.py
@@ -206,7 +206,8 @@ def save_and_export(
quantize=config.torchscript_quantize,
inference_interface=config.inference_interface,
accelerate=config.accelerate,
padding_control=config.padding_control,
seq_padding_control=config.seq_padding_control,
batch_padding_control=config.batch_padding_control,
)


