Fix torch.cat()
issue when processing large number of documents with TransformersModelForTokenClassificationNerStep
#80
+90
−69
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
The Issue
Noticed a weird problem occurring in the evaluation script when trying to naively process a large number (365) of Kazu documents with the
TransformersModelForTokenClassificationNerStep
step using theMPS
device.The 365 documents used totalled over 14k sections and were being processed with a newly trained 400MB model. Performing this on a Mac M3 with MPS, I saw Python's memory usage peak at 18GB:
In the end the step failed to predict any entities but without any exceptions thrown. The result was a weird phenomenon inside https://github.com/AstraZeneca/KAZU/blob/main/kazu/steps/ner/hf_token_classification.py:
Where
torch.cat
was producing a tensor full of zeros, indicating the model has not found any entities. This is likely due to torch.cat exceeding the allocated memory of the device.The Fix
The fix is in two places. Firstly in the evaluate script we now process the documents in batches through the pipeline. However, to stop a user naively processing many documents with the Kazu pipeline and hitting this issue, there is also a fix inside the
TransformersModelForTokenClassificationNerStep
. This offloads the model logits onto CPU before concatenation.Testing Performance
Here we perform the test with the naive call to the step with all the documents at once as before and test the version of
TransformersModelForTokenClassificationNerStep
before and after the change.Before the change we observe a peak memory usage of 18GB and it takes 690s to process all the documents. With the new implementation we see a peak memory usage of 4GB and it takes 680s to process all the documents - also fixing the weird issue. Thus there doesn't seem to be any performance degradation in executing
torch.cat
on cpu vs mps. Cuda device was not tested however.General Test for single label classification
A test script with the default model pipeline and Kazu model pack was run as a sanity check. The integration tests will now also be run.
Note
There is also a small refactor moving some functions from
train_multilabel_ner
tomodelling_utils
. Individual changes can be seen at commit level.