[Torchscript] Parallelized Text/Sequence Preprocessing (#2206)
geoffreyangus authored Jun 29, 2022
1 parent aa0c63b commit 76055a2
Showing 4 changed files with 83 additions and 67 deletions.
62 changes: 36 additions & 26 deletions ludwig/features/sequence_feature.py
@@ -92,38 +92,48 @@ def forward(self, v: TorchscriptPreprocessingInput) -> torch.Tensor:
         if not torch.jit.isinstance(v, List[str]):
             raise ValueError(f"Unsupported input: {v}")

-        v = [self.computed_fill_value if s == "nan" else s for s in v]
+        futures: List[torch.jit.Future[torch.Tensor]] = []
+        for sequence in v:
+            futures.append(
+                torch.jit.fork(
+                    self._process_sequence,
+                    sequence,
+                )
+            )
+
+        sequence_matrix = []
+        for future in futures:
+            sequence_matrix.append(torch.jit.wait(future))
+
+        return torch.stack(sequence_matrix)
+
+    def _process_sequence(self, sequence: str) -> torch.Tensor:
+        sequence = self.computed_fill_value if sequence == "nan" else sequence

         if self.lowercase:
-            sequences = [sequence.lower() for sequence in v]
+            sequence_str: str = sequence.lower()
         else:
-            sequences = v
+            sequence_str: str = sequence

-        unit_sequences = self.tokenizer(sequences)
-        # refines type of unit_sequences from Any to List[List[str]]
-        assert torch.jit.isinstance(unit_sequences, List[List[str]]), "unit_sequences is not a list of lists."
+        unit_sequence = self.tokenizer(sequence_str)
+        assert torch.jit.isinstance(unit_sequence, List[str])

-        sequence_matrix = torch.full(
-            [len(unit_sequences), self.max_sequence_length], self.unit_to_id[self.padding_symbol]
-        )
-        sequence_matrix[:, 0] = self.unit_to_id[self.start_symbol]
-        for sample_idx, unit_sequence in enumerate(unit_sequences):
-            # Add <EOS> if sequence length is less than max_sequence_length. Else, truncate to max_sequence_length.
-            if len(unit_sequence) + 1 < self.max_sequence_length:
-                sequence_length = len(unit_sequence)
-                sequence_matrix[sample_idx][len(unit_sequence) + 1] = self.unit_to_id[self.stop_symbol]
-            else:
-                sequence_length = self.max_sequence_length - 1
+        sequence_vector = torch.full([self.max_sequence_length], self.unit_to_id[self.padding_symbol])
+        sequence_vector[0] = self.unit_to_id[self.start_symbol]
+        if len(unit_sequence) + 1 < self.max_sequence_length:
+            sequence_length = len(unit_sequence)
+            sequence_vector[len(unit_sequence) + 1] = self.unit_to_id[self.stop_symbol]
+        else:
+            sequence_length = self.max_sequence_length - 1

-            for i in range(sequence_length):
-                curr_unit = unit_sequence[i]
-                if curr_unit in self.unit_to_id:
-                    curr_id = self.unit_to_id[curr_unit]
-                else:
-                    curr_id = self.unit_to_id[self.unknown_symbol]
-                sequence_matrix[sample_idx][i + 1] = curr_id
-
-        return sequence_matrix
+        for i in range(sequence_length):
+            curr_unit = unit_sequence[i]
+            if curr_unit in self.unit_to_id:
+                curr_id = self.unit_to_id[curr_unit]
+            else:
+                curr_id = self.unit_to_id[self.unknown_symbol]
+            sequence_vector[i + 1] = curr_id
+        return sequence_vector


class _SequencePostprocessing(torch.nn.Module):
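The rewrite above converts one batched tokenizer call into per-sample tasks: `forward` forks a `_process_sequence` call for each input string, then waits on the futures and stacks the results, so samples are tokenized concurrently on TorchScript's inter-op thread pool while output order is preserved. Below is a minimal, self-contained sketch of the same fork/wait pattern; `process_stub` and `parallel_process` are hypothetical stand-ins for illustration, not code from this commit.

from typing import List

import torch


@torch.jit.script
def process_stub(s: str) -> torch.Tensor:
    # Hypothetical per-sample work standing in for tokenization.
    return torch.full([3], float(len(s)))


@torch.jit.script
def parallel_process(batch: List[str]) -> torch.Tensor:
    # Fork one task per sample; tasks may run concurrently on the inter-op thread pool.
    futures: List[torch.jit.Future[torch.Tensor]] = []
    for s in batch:
        futures.append(torch.jit.fork(process_stub, s))
    # Wait on each future; results come back in input order regardless of completion order.
    results: List[torch.Tensor] = []
    for f in futures:
        results.append(torch.jit.wait(f))
    return torch.stack(results)


print(parallel_process(["hello", "hi"]))  # tensor of shape [2, 3]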
74 changes: 34 additions & 40 deletions ludwig/models/inference.py
@@ -56,8 +56,7 @@ def __init__(

     def preprocessor_forward(self, inputs: Dict[str, TorchscriptPreprocessingInput]) -> Dict[str, torch.Tensor]:
         """Forward pass through the preprocessor."""
-        with torch.no_grad():
-            return self.preprocessor(inputs)
+        return self.preprocessor(inputs)

     def predictor_forward(self, preproc_inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
         """Forward pass through the predictor.
@@ -67,24 +66,22 @@ def predictor_forward(self, preproc_inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
         for k, v in preproc_inputs.items():
             preproc_inputs[k] = v.to(self.predictor.device)

-        with torch.no_grad():
+        with torch.no_grad():  # Ensure model params do not compute gradients
             predictions_flattened = self.predictor(preproc_inputs)
         return predictions_flattened

     def postprocessor_forward(self, predictions_flattened: Dict[str, torch.Tensor]) -> Dict[str, Dict[str, Any]]:
         """Forward pass through the postprocessor."""
-        with torch.no_grad():
-            postproc_outputs_flattened: Dict[str, Any] = self.postprocessor(predictions_flattened)
-            # Turn flat inputs into nested predictions per feature name
-            postproc_outputs: Dict[str, Dict[str, Any]] = _unflatten_dict_by_feature_name(postproc_outputs_flattened)
-            return postproc_outputs
+        postproc_outputs_flattened: Dict[str, Any] = self.postprocessor(predictions_flattened)
+        # Turn flat inputs into nested predictions per feature name
+        postproc_outputs: Dict[str, Dict[str, Any]] = _unflatten_dict_by_feature_name(postproc_outputs_flattened)
+        return postproc_outputs

     def forward(self, inputs: Dict[str, TorchscriptPreprocessingInput]) -> Dict[str, Dict[str, Any]]:
-        with torch.no_grad():
-            preproc_inputs: Dict[str, torch.Tensor] = self.preprocessor_forward(inputs)
-            predictions_flattened: Dict[str, torch.Tensor] = self.predictor_forward(preproc_inputs)
-            postproc_outputs: Dict[str, Dict[str, Any]] = self.postprocessor_forward(predictions_flattened)
-            return postproc_outputs
+        preproc_inputs: Dict[str, torch.Tensor] = self.preprocessor_forward(inputs)
+        predictions_flattened: Dict[str, torch.Tensor] = self.predictor_forward(preproc_inputs)
+        postproc_outputs: Dict[str, Dict[str, Any]] = self.postprocessor_forward(predictions_flattened)
+        return postproc_outputs

     @torch.jit.unused
     def predict(
@@ -172,12 +169,11 @@ def __init__(self, config: Dict[str, Any], training_set_metadata: Dict[str, Any]):
             self.preproc_modules[module_dict_key] = feature.create_preproc_module(training_set_metadata[feature_name])

     def forward(self, inputs: Dict[str, TorchscriptPreprocessingInput]) -> Dict[str, torch.Tensor]:
-        with torch.no_grad():
-            preproc_inputs = {}
-            for module_dict_key, preproc in self.preproc_modules.items():
-                feature_name = get_name_from_module_dict_key(module_dict_key)
-                preproc_inputs[feature_name] = preproc(inputs[feature_name])
-            return preproc_inputs
+        preproc_inputs = {}
+        for module_dict_key, preproc in self.preproc_modules.items():
+            feature_name = get_name_from_module_dict_key(module_dict_key)
+            preproc_inputs[feature_name] = preproc(inputs[feature_name])
+        return preproc_inputs


 class _InferencePredictor(nn.Module):
@@ -200,17 +196,16 @@ def __init__(self, model: "BaseModel", device: TorchDevice):
             self.predict_modules[module_dict_key] = feature.prediction_module.to(device=self.device)

     def forward(self, preproc_inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
-        with torch.no_grad():
-            model_outputs = self.model(preproc_inputs)
-            predictions_flattened: Dict[str, torch.Tensor] = {}
-            for module_dict_key, predict in self.predict_modules.items():
-                feature_name = get_name_from_module_dict_key(module_dict_key)
-                feature_predictions = predict(model_outputs, feature_name)
-                # Flatten out the predictions to support Triton input/output
-                for predict_key, tensor_values in feature_predictions.items():
-                    predict_concat_key = output_feature_utils.get_feature_concat_name(feature_name, predict_key)
-                    predictions_flattened[predict_concat_key] = tensor_values
-            return predictions_flattened
+        model_outputs = self.model(preproc_inputs)
+        predictions_flattened: Dict[str, torch.Tensor] = {}
+        for module_dict_key, predict in self.predict_modules.items():
+            feature_name = get_name_from_module_dict_key(module_dict_key)
+            feature_predictions = predict(model_outputs, feature_name)
+            # Flatten out the predictions to support Triton input/output
+            for predict_key, tensor_values in feature_predictions.items():
+                predict_concat_key = output_feature_utils.get_feature_concat_name(feature_name, predict_key)
+                predictions_flattened[predict_concat_key] = tensor_values
+        return predictions_flattened


 class _InferencePostprocessor(nn.Module):
@@ -231,16 +226,15 @@ def __init__(self, model: "BaseModel", training_set_metadata: Dict[str, Any]):
             self.postproc_modules[module_dict_key] = feature.create_postproc_module(training_set_metadata[feature_name])

     def forward(self, predictions_flattened: Dict[str, torch.Tensor]) -> Dict[str, Any]:
-        with torch.no_grad():
-            postproc_outputs_flattened: Dict[str, Any] = {}
-            for module_dict_key, postproc in self.postproc_modules.items():
-                feature_name = get_name_from_module_dict_key(module_dict_key)
-                feature_postproc_outputs = postproc(predictions_flattened, feature_name)
-                # Flatten out the predictions to support Triton input/output
-                for postproc_key, tensor_values in feature_postproc_outputs.items():
-                    postproc_concat_key = output_feature_utils.get_feature_concat_name(feature_name, postproc_key)
-                    postproc_outputs_flattened[postproc_concat_key] = tensor_values
-            return postproc_outputs_flattened
+        postproc_outputs_flattened: Dict[str, Any] = {}
+        for module_dict_key, postproc in self.postproc_modules.items():
+            feature_name = get_name_from_module_dict_key(module_dict_key)
+            feature_postproc_outputs = postproc(predictions_flattened, feature_name)
+            # Flatten out the predictions to support Triton input/output
+            for postproc_key, tensor_values in feature_postproc_outputs.items():
+                postproc_concat_key = output_feature_utils.get_feature_concat_name(feature_name, postproc_key)
+                postproc_outputs_flattened[postproc_concat_key] = tensor_values
+        return postproc_outputs_flattened


 def save_ludwig_model_for_inference(
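The inference.py changes narrow `torch.no_grad` so it guards only the model call in `predictor_forward`, the one place where parameters that require gradients touch the data; the pre- and postprocessing passes no longer carry redundant wrappers. TorchScript supports `torch.no_grad` as a context manager inside scripted methods, as in this minimal sketch (the `TinyPredictor` module and its shapes are invented for illustration):

from typing import Dict

import torch


class TinyPredictor(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 2)

    def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        with torch.no_grad():  # outputs computed here do not track gradients
            return {"logits": self.linear(inputs["features"])}


module = torch.jit.script(TinyPredictor())
out = module({"features": torch.randn(3, 4)})
assert not out["logits"].requires_grad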
3 changes: 2 additions & 1 deletion tests/integration_tests/test_torchscript.py
@@ -736,9 +736,10 @@ def validate_torchscript_outputs(tmpdir, config, backend, training_data_csv_path

         assert output_name in feature_outputs
         output_values = feature_outputs[output_name]
+        assert utils.has_no_grad(output_values), f'"{feature_name}.{output_name}" tensors have gradients'
         assert utils.is_all_close(
             output_values, output_values_expected
-        ), f"feature: {feature_name}, output: {output_name}"
+        ), f'"{feature_name}.{output_name}" tensors are not close to ludwig model'


 def initialize_torchscript_module(tmpdir, config, backend, training_data_csv_path, device=None):
11 changes: 11 additions & 0 deletions tests/integration_tests/utils.py
@@ -501,6 +501,17 @@ def get_weights(model: torch.nn.Module) -> List[torch.Tensor]:
     return [param.data for param in model.parameters()]


+def has_no_grad(
+    val: Union[np.ndarray, torch.Tensor, str, list],
+):
+    """Checks that `val` contains no tensors that require gradients."""
+    if isinstance(val, list):
+        return all(has_no_grad(v) for v in val)
+    if isinstance(val, torch.Tensor):
+        return not val.requires_grad
+    return True
+
+
 def is_all_close(
     val1: Union[np.ndarray, torch.Tensor, str, list],
     val2: Union[np.ndarray, torch.Tensor, str, list],
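A short usage sketch for the new helper, assuming `has_no_grad` is in scope (the tensors are illustrative): non-tensor leaves such as strings pass trivially, and lists are checked element-wise.

import torch

t = torch.ones(3, requires_grad=True)
assert not has_no_grad(t)  # a tensor that tracks gradients fails the check
assert has_no_grad(t.detach())  # a detached tensor passes
assert has_no_grad(["some string", t.detach()])  # lists are checked recursively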
