Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

decouple batch_size to det_batch_size, rec_batch_size and kie_batch_size in MMOCRInferencer #1801

Merged
merged 5 commits into from
Mar 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions docs/en/user_guides/inference.md
Original file line number Diff line number Diff line change
Expand Up @@ -460,6 +460,9 @@ Here are extensive lists of parameters that you can use.
| `inputs` | str/list/tuple/np.array | **required** | It can be a path to an image/a folder, an np array or a list/tuple (with img paths or np arrays) |
| `return_datasamples` | bool | False | Whether to return results as DataSamples. If False, the results will be packed into a dict. |
| `batch_size` | int | 1 | Inference batch size. |
| `det_batch_size` | int, optional | None | Inference batch size for the text detection model. Overrides `batch_size` if it is not None. |
| `rec_batch_size` | int, optional | None | Inference batch size for the text recognition model. Overrides `batch_size` if it is not None. |
| `kie_batch_size` | int, optional | None | Inference batch size for the KIE model. Overrides `batch_size` if it is not None. |
| `return_vis` | bool | False | Whether to return the visualization result. |
| `print_result` | bool | False | Whether to print the inference result to the console. |
| `show` | bool | False | Whether to display the visualization results in a popup window. |
Expand Down
3 changes: 3 additions & 0 deletions docs/zh_cn/user_guides/inference.md
Original file line number Diff line number Diff line change
Expand Up @@ -457,6 +457,9 @@ outputs
| `inputs` | str/list/tuple/np.array | **必需** | 它可以是一个图片/文件夹的路径,一个 numpy 数组,或者是一个包含图片路径或 numpy 数组的列表/元组 |
| `return_datasamples` | bool | False | 是否将结果作为 DataSample 返回。如果为 False,结果将被打包成一个字典。 |
| `batch_size` | int | 1 | 推理的批大小。 |
| `det_batch_size` | int, 可选 | None | 推理的批大小 (文本检测模型)。如果不为 None,则覆盖 batch_size。 |
| `rec_batch_size` | int, 可选 | None | 推理的批大小 (文本识别模型)。如果不为 None,则覆盖 batch_size。 |
| `kie_batch_size` | int, 可选 | None | 推理的批大小 (关键信息提取模型)。如果不为 None,则覆盖 batch_size。 |
| `return_vis` | bool | False | 是否返回可视化结果。 |
| `print_result` | bool | False | 是否将推理结果打印到控制台。 |
| `show` | bool | False | 是否在弹出窗口中显示可视化结果。 |
Expand Down
55 changes: 49 additions & 6 deletions mmocr/apis/inferencers/mmocr_inferencer.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,34 +105,54 @@ def _inputs2ndarrray(self, inputs: List[InputsType]) -> List[np.ndarray]:
'supported yet.')
return new_inputs

def forward(self, inputs: InputsType, batch_size: int,
def forward(self,
inputs: InputsType,
batch_size: int = 1,
det_batch_size: Optional[int] = None,
rec_batch_size: Optional[int] = None,
kie_batch_size: Optional[int] = None,
**forward_kwargs) -> PredType:
"""Forward the inputs to the model.

Args:
inputs (InputsType): The inputs to be forwarded.
batch_size (int): Batch size. Defaults to 1.
det_batch_size (Optional[int]): Batch size for text detection
model. Overwrite batch_size if it is not None.
Defaults to None.
rec_batch_size (Optional[int]): Batch size for text recognition
model. Overwrite batch_size if it is not None.
Defaults to None.
kie_batch_size (Optional[int]): Batch size for KIE model.
Overwrite batch_size if it is not None.
Defaults to None.

Returns:
Dict: The prediction results. Possibly with keys "det", "rec", and
"kie"..
"""
result = {}
forward_kwargs['progress_bar'] = False
if det_batch_size is None:
det_batch_size = batch_size
if rec_batch_size is None:
rec_batch_size = batch_size
if kie_batch_size is None:
kie_batch_size = batch_size
if self.mode == 'rec':
# The extra list wrapper here is for the ease of postprocessing
self.rec_inputs = inputs
predictions = self.textrec_inferencer(
self.rec_inputs,
return_datasamples=True,
batch_size=batch_size,
batch_size=rec_batch_size,
**forward_kwargs)['predictions']
result['rec'] = [[p] for p in predictions]
elif self.mode.startswith('det'): # 'det'/'det_rec'/'det_rec_kie'
result['det'] = self.textdet_inferencer(
inputs,
return_datasamples=True,
batch_size=batch_size,
batch_size=det_batch_size,
**forward_kwargs)['predictions']
if self.mode.startswith('det_rec'): # 'det_rec'/'det_rec_kie'
result['rec'] = []
Expand All @@ -149,7 +169,7 @@ def forward(self, inputs: InputsType, batch_size: int,
self.textrec_inferencer(
self.rec_inputs,
return_datasamples=True,
batch_size=batch_size,
batch_size=rec_batch_size,
**forward_kwargs)['predictions'])
if self.mode == 'det_rec_kie':
self.kie_inputs = []
Expand All @@ -172,7 +192,7 @@ def forward(self, inputs: InputsType, batch_size: int,
result['kie'] = self.kie_inferencer(
self.kie_inputs,
return_datasamples=True,
batch_size=batch_size,
batch_size=kie_batch_size,
**forward_kwargs)['predictions']
return result

Expand Down Expand Up @@ -219,6 +239,9 @@ def __call__(
self,
inputs: InputsType,
batch_size: int = 1,
det_batch_size: Optional[int] = None,
rec_batch_size: Optional[int] = None,
kie_batch_size: Optional[int] = None,
out_dir: str = 'results/',
return_vis: bool = False,
save_vis: bool = False,
Expand All @@ -231,6 +254,15 @@ def __call__(
inputs (InputsType): Inputs for the inferencer. It can be a path
to image / image directory, or an array, or a list of these.
batch_size (int): Batch size. Defaults to 1.
det_batch_size (Optional[int]): Batch size for text detection
model. Overwrite batch_size if it is not None.
Defaults to None.
rec_batch_size (Optional[int]): Batch size for text recognition
model. Overwrite batch_size if it is not None.
Defaults to None.
kie_batch_size (Optional[int]): Batch size for KIE model.
Overwrite batch_size if it is not None.
Defaults to None.
out_dir (str): Output directory of results. Defaults to 'results/'.
return_vis (bool): Whether to return the visualization result.
Defaults to False.
Expand Down Expand Up @@ -269,12 +301,23 @@ def __call__(
**kwargs)

ori_inputs = self._inputs_to_list(inputs)
if det_batch_size is None:
det_batch_size = batch_size
if rec_batch_size is None:
rec_batch_size = batch_size
if kie_batch_size is None:
kie_batch_size = batch_size

chunked_inputs = super(BaseMMOCRInferencer,
self)._get_chunk_data(ori_inputs, batch_size)
results = {'predictions': [], 'visualization': []}
for ori_input in track(chunked_inputs, description='Inference'):
preds = self.forward(ori_input, batch_size, **forward_kwargs)
preds = self.forward(
ori_input,
det_batch_size=det_batch_size,
rec_batch_size=rec_batch_size,
kie_batch_size=kie_batch_size,
**forward_kwargs)
visualization = self.visualize(
ori_input, preds, img_out_dir=img_out_dir, **visualize_kwargs)
batch_res = self.postprocess(
Expand Down