Skip to content

Commit

Permalink
Merge pull request #27 from ankandrew/dev
Browse files Browse the repository at this point in the history
Add more CLI train options
  • Loading branch information
ankandrew authored Dec 8, 2024
2 parents 29e29f2 + 522ac8f commit 8c51776
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 11 deletions.
9 changes: 9 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,15 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [0.2.0] - 2024-10-14

### Added

- New European model using MobileViTV2 - trained on +40 countries 🚀 .
- Added more logging to train script.

[0.2.0]: https://github.com/ankandrew/fast-plate-ocr/compare/v0.1.6...v0.2.0

## [0.1.6] - 2024-05-09

### Added
Expand Down
19 changes: 12 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ the trained models for inference.

The idea is to use this after a plate object detector, since the OCR expects the cropped plates.

> [!TIP]
> Try `fast-plate-ocr` pre-trained models in [Hugging Spaces](https://huggingface.co/spaces/ankandrew/fast-alpr).
### Features

- **Keras 3 Backend Support**: Compatible with **[TensorFlow](https://www.tensorflow.org/)**, **[JAX](https://github.com/google/jax)**, and **[PyTorch](https://pytorch.org/)** backends 🧠
Expand All @@ -40,15 +43,14 @@ The idea is to use this after a plate object detector, since the OCR expects the
| `argentinian-plates-cnn-synth-model` | 2.1 | 476 | [arg_plate_dataset.zip](https://github.com/ankandrew/fast-plate-ocr/releases/download/arg-plates/arg_plate_dataset_plus_synth.zip) | 94.19% | Plates up to 2020 + synthetic plates. |
| 🆕 `european-plates-mobile-vit-v2-model` | 2.9 | 344 | - | 92.5%<sup>[3]</sup> | European plates (+40 countries). |

> [!TIP]
> Try `fast-plate-ocr` pre-trained models in [Hugging Spaces](https://huggingface.co/spaces/ankandrew/fast-alpr).

_<sup>[1]</sup> Inference on Mac M1 chip using CPUExecutionProvider. Utilizing CoreMLExecutionProvider accelerates speed
> [!NOTE]
> _<sup>[1]</sup> Inference on Mac M1 chip using CPUExecutionProvider. Utilizing CoreMLExecutionProvider accelerates speed
by 5x in the CNN models._

_<sup>[2]</sup> Accuracy is what we refer as plate_acc. See [metrics section](#model-metrics)._

_<sup>[3]</sup> For detailed accuracy for each country see [results](https://github.com/ankandrew/fast-plate-ocr/releases/download/arg-plates/european_mobile_vit_v2_ocr_results.json) and
>
> _<sup>[2]</sup> Accuracy is what we refer as plate_acc. See [metrics section](#model-metrics)._
>
> _<sup>[3]</sup> For detailed accuracy for each country see [results](https://github.com/ankandrew/fast-plate-ocr/releases/download/arg-plates/european_mobile_vit_v2_ocr_results.json) and
the corresponding [val split](https://github.com/ankandrew/fast-plate-ocr/releases/download/arg-plates/european_mobile_vit_v2_ocr_val.zip) used._

<details>
Expand Down Expand Up @@ -136,6 +138,9 @@ To train or use the CLI tool, you'll need to install:
pip install fast_plate_ocr[train]
```

> [!IMPORTANT]
> Make sure you have installed a supported backend for Keras.
#### Train Model

To train the model you will need:
Expand Down
20 changes: 19 additions & 1 deletion fast_plate_ocr/cli/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,13 @@
type=float,
help="Initial learning rate to use.",
)
@click.option(
"--label-smoothing",
default=0.05,
show_default=True,
type=float,
help="Amount of label smoothing to apply.",
)
@click.option(
"--batch-size",
default=128,
Expand Down Expand Up @@ -142,6 +149,11 @@
type=click.Choice(["max", "avg"]),
help="Choose the pooling layer to use.",
)
@click.option(
"--weights-path",
type=click.Path(exists=True, file_okay=True, path_type=pathlib.Path),
help="Path to the pretrained model weights file.",
)
@print_params(table_title="CLI Training Parameters", c1_title="Parameter", c2_title="Details")
def train(
dense: bool,
Expand All @@ -150,6 +162,7 @@ def train(
val_annotations: pathlib.Path,
augmentation_path: pathlib.Path | None,
lr: float,
label_smoothing: float,
batch_size: int,
num_workers: int,
output_dir: pathlib.Path,
Expand All @@ -161,6 +174,7 @@ def train(
reduce_lr_factor: float,
activation: str,
pool_layer: Literal["max", "avg"],
weights_path: pathlib.Path | None,
) -> None:
"""
Train the License Plate OCR model.
Expand Down Expand Up @@ -200,8 +214,12 @@ def train(
activation=activation,
pool_layer=pool_layer,
)

if weights_path:
model.load_weights(weights_path)

model.compile(
loss=cce_loss(vocabulary_size=config.vocabulary_size),
loss=cce_loss(vocabulary_size=config.vocabulary_size, label_smoothing=label_smoothing),
optimizer=Adam(lr),
metrics=[
cat_acc_metric(
Expand Down
6 changes: 4 additions & 2 deletions fast_plate_ocr/train/model/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def top_3_k(y_true, y_pred):


# Custom loss
def cce_loss(vocabulary_size: int):
def cce_loss(vocabulary_size: int, label_smoothing: float = 0.2):
"""
Categorical cross-entropy loss.
"""
Expand All @@ -73,7 +73,9 @@ def cce(y_true, y_pred):
y_true = ops.reshape(y_true, newshape=(-1, vocabulary_size))
y_pred = ops.reshape(y_pred, newshape=(-1, vocabulary_size))
return ops.mean(
losses.categorical_crossentropy(y_true, y_pred, from_logits=False, label_smoothing=0.2)
losses.categorical_crossentropy(
y_true, y_pred, from_logits=False, label_smoothing=label_smoothing
)
)

return cce
2 changes: 1 addition & 1 deletion test/fast_lp_ocr/inference/test_hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def test_model_and_config_urls(model_name):
model_url, config_url = AVAILABLE_ONNX_MODELS[model_name]

for url in [model_url, config_url]:
response = requests.get(url, timeout=5, allow_redirects=True)
response = requests.head(url, timeout=5, allow_redirects=True)
assert (
response.status_code == HTTPStatus.OK
), f"URL {url} is not accessible, got {response.status_code}"

0 comments on commit 8c51776

Please sign in to comment.