DTrOCR-2 - Add support for key-value caching #11

Merged
merged 16 commits into from
Aug 2, 2024
10 changes: 10 additions & 0 deletions CITATION.cff
@@ -0,0 +1,10 @@
cff-version: 1.2.0
message: "If you use this software, please cite using metadata below."
authors:
  - family-names: "Rajan"
    given-names: "Arvind"
    orcid: "https://orcid.org/0000-0003-4829-5007"
title: "A PyTorch implementation of DTrOCR: Decoder-only Transformer for Optical Character Recognition"
repository-code: 'https://github.com/arvindrajan92/DTrOCR'
date-released: 2024-07-13
license: MIT
21 changes: 19 additions & 2 deletions README.md
@@ -1,11 +1,26 @@
# DTrOCR
![logo](logo.png)
[![Python application](https://github.com/arvindrajan92/DTrOCR/actions/workflows/python-app.yml/badge.svg)](https://github.com/arvindrajan92/DTrOCR/actions/workflows/python-app.yml)
[![CodeQL](https://github.com/arvindrajan92/DTrOCR/actions/workflows/github-code-scanning/codeql/badge.svg)](https://github.com/arvindrajan92/DTrOCR/actions/workflows/github-code-scanning/codeql)
[![Python Versions](https://img.shields.io/badge/python-3.11-blue)](https://www.python.org/downloads/)
[![License](https://img.shields.io/github/license/arvindrajan92/DTrOCR.svg)](https://github.com/arvindrajan92/DTrOCR/LICENSE)
[![GitHub stars](https://img.shields.io/github/stars/arvindrajan92/DTrOCR?style=social)](https://github.com/arvindrajan92/DTrOCR)

A PyTorch implementation of DTrOCR: Decoder-only Transformer for Optical Character Recognition.

> [!NOTE]
>
> The author of this repository is not affiliated in any way with the authors of the [DTrOCR paper](https://doi.org/10.48550/arXiv.2308.15996). This implementation is based purely on the published details of the DTrOCR model architecture and its training.
>
> Pre-trained weights for the model are not available at this time, as this is a personal project with limited resources.

Below are the key differences between the original implementation (from the paper) and this implementation.

| | Original implementation | This implementation |
| ------------------------------------------------------------ | ---------------------------- | --------------------- |
| Maximum token length<br />(including 128 image patch tokens) | 512 | 256 |
| Language | English & Chinese | English |
| Pre-training corpus (planned) | Scene, printed & handwritten | Printed & handwritten |

## Installation

@@ -28,6 +43,7 @@ config = DTrOCRConfig()
model = DTrOCRLMHeadModel(config)
processor = DTrOCRProcessor(DTrOCRConfig())

model.eval() # set model to evaluation mode for deterministic behaviour
path_to_image = "" # path to image file

inputs = processor(
@@ -39,7 +55,8 @@ inputs = processor(
model_output = model.generate(
    inputs=inputs,
    processor=processor,
    num_beams=3,  # defaults to 1 if not specified
    use_cache=True  # defaults to True if not specified
)

predicted_text = processor.tokeniser.decode(model_output[0], skip_special_tokens=True)
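The `use_cache` flag in the README example switches key-value caching on during generation. As a rough, framework-free illustration of the general idea (not this repository's implementation), the sketch below runs single-head causal attention twice: once recomputing keys and values for the whole prefix at every step, and once appending each new key and value to a cache. All names here (`attention`, `project`, `decode_without_cache`, `decode_with_cache`) and the toy dimensions are hypothetical.

```python
import math
import random


def project(matrix, vector):
    """Multiply a square weight matrix (list of rows) by a vector."""
    return [sum(m * x for m, x in zip(row, vector)) for row in matrix]


def attention(query, keys, values):
    """Scaled dot-product attention of one query over lists of keys and values."""
    scale = 1.0 / math.sqrt(len(query))
    scores = [scale * sum(q * k for q, k in zip(query, key)) for key in keys]
    peak = max(scores)  # subtract the maximum for a numerically stable softmax
    weights = [math.exp(s - peak) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * v[i] for w, v in zip(weights, values)) / total for i in range(dim)]


def decode_without_cache(tokens, w_q, w_k, w_v):
    """Recompute keys and values for the entire prefix at every step."""
    outputs = []
    for t in range(len(tokens)):
        prefix = tokens[: t + 1]
        keys = [project(w_k, x) for x in prefix]
        values = [project(w_v, x) for x in prefix]
        outputs.append(attention(project(w_q, tokens[t]), keys, values))
    return outputs


def decode_with_cache(tokens, w_q, w_k, w_v):
    """Project each token once and append its key and value to a growing cache."""
    cached_keys, cached_values, outputs = [], [], []
    for x in tokens:
        cached_keys.append(project(w_k, x))
        cached_values.append(project(w_v, x))
        outputs.append(attention(project(w_q, x), cached_keys, cached_values))
    return outputs


random.seed(0)
dim = 4


def random_matrix():
    return [[random.uniform(-1, 1) for _ in range(dim)] for _ in range(dim)]


w_q, w_k, w_v = random_matrix(), random_matrix(), random_matrix()
tokens = [[random.uniform(-1, 1) for _ in range(dim)] for _ in range(6)]

uncached = decode_without_cache(tokens, w_q, w_k, w_v)
cached = decode_with_cache(tokens, w_q, w_k, w_v)

# Largest element-wise difference between the two decoding strategies.
max_diff = max(abs(a - b) for u, c in zip(uncached, cached) for a, b in zip(u, c))
```

Both paths produce the same outputs. The cached path performs one key projection and one value projection per token, whereas the uncached path repeats those projections for every earlier position at each step.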
2 changes: 2 additions & 0 deletions dtrocr/data.py
@@ -8,13 +8,15 @@
@dataclass
class DTrOCRModelOutput:
    hidden_states: torch.FloatTensor
    past_key_values: torch.FloatTensor


@dataclass
class DTrOCRLMHeadModelOutput:
    logits: torch.FloatTensor
    loss: Optional[torch.FloatTensor] = None
    accuracy: Optional[torch.FloatTensor] = None
    past_key_values: Optional[torch.FloatTensor] = None


@dataclass
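The new `past_key_values` fields let a caller hand the cache from one decoding step to the next. The following is a minimal sketch of that pattern using only the standard library. The `ToyLMOutput` class, the `toy_step` function, and the per-layer `(keys, values)` tuple layout are assumptions made for illustration; the diff itself only annotates the field as `torch.FloatTensor` and does not confirm the layout.

```python
from dataclasses import dataclass
from typing import List, Optional, Tuple

# Assumed layout: one (keys, values) pair per layer, each a list with one
# entry per cached position. Real models would hold tensors instead.
LayerCache = Tuple[List[float], List[float]]


@dataclass
class ToyLMOutput:
    logits: List[float]
    loss: Optional[float] = None
    # Fields with defaults must follow fields without them, so the optional
    # cache goes last, as in the output classes shown in the diff.
    past_key_values: Optional[Tuple[LayerCache, ...]] = None


def toy_step(
    token: float,
    past: Optional[Tuple[LayerCache, ...]] = None,
    num_layers: int = 2,
    use_cache: bool = True,
) -> ToyLMOutput:
    """Append a stand-in key and value for this token to every layer's cache."""
    if past is None:
        past = tuple(([], []) for _ in range(num_layers))
    present = tuple((keys + [token], values + [2 * token]) for keys, values in past)
    logits = [sum(present[-1][1])]  # stand-in for a real forward pass
    return ToyLMOutput(logits=logits, past_key_values=present if use_cache else None)


past = None
for token in [1.0, 2.0, 3.0]:
    output = toy_step(token, past)
    past = output.past_key_values  # feed the cache back in on the next step

cache_length = len(past[0][0])  # number of cached positions in the first layer
```

Returning `None` when `use_cache` is false keeps the output type the same in both modes, which is one reason to make the field optional.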