Trouble using OCRDataset #151
-
Hi again, I am trying to make a dataset with the OCRDataset class of this project. In OCRDatasetConfig, can 'path' be a path to a CSV file, or should I convert my data to another format? If CSV is not supported, why does the config take column names?
Beta Was this translation helpful? Give feedback.
Replies: 5 comments 20 replies
-
Hello @kghezelbash, you can use this template as an example:
from hezar.models import CRNNImage2TextConfig, CRNNImage2Text
from hezar.preprocessors import ImageProcessor
from hezar.trainer import Trainer, TrainerConfig
from hezar.data import OCRDataset, OCRDatasetConfig
class PersianOCRDataset(OCRDataset):
    """OCR dataset whose samples are rows of a CSV file.

    Each row provides an image path and its transcription. The columns are
    named by ``config.images_paths_column`` and ``config.text_column``.
    """

    def __init__(self, config: OCRDatasetConfig, split=None, **kwargs):
        super().__init__(config=config, split=split, **kwargs)

    def _load(self, split=None):
        """Read the CSV at ``self.config.path`` into a pandas DataFrame.

        NOTE(review): ``split`` is not used, so every split gets the same rows.
        Filter here if the CSV contains more than one split.
        """
        # Local import: this snippet otherwise uses ``pd`` without importing it.
        import pandas as pd

        data = pd.read_csv(self.config.path)
        # preprocess if needed
        return data

    def __getitem__(self, index):
        """Return the processed pixel values and encoded label for row ``index``."""
        row = self.data.iloc[index]
        # ``Series.values`` is an attribute (a NumPy array), not a method, so
        # the former ``.values()`` call raised TypeError. Look the fields up by
        # the configured column names instead of relying on column order.
        path = row[self.config.images_paths_column]
        text = row[self.config.text_column]
        pixel_values = self.image_processor(path, return_tensors="pt")["pixel_values"][0]
        labels = self._text_to_tensor(text)
        return {
            "pixel_values": pixel_values,
            "labels": labels,
        }
# One CSV-backed dataset configuration shared by the train and test splits.
dataset_config = OCRDatasetConfig(
    path="path/to/csv",
    text_split_type="char_split",
    text_column="label",
    images_paths_column="image_path",
    reverse_digits=True,
)
train_dataset, eval_dataset = (
    PersianOCRDataset(dataset_config, split=split_name)
    for split_name in ("train", "test")
)

# The CRNN label vocabulary is taken from the training dataset's config.
model = CRNNImage2Text(
    CRNNImage2TextConfig(
        id2label=train_dataset.config.id2label,
        map2seq_in_dim=1024,
        map2seq_out_dim=96,
    )
)
preprocessor = ImageProcessor(train_dataset.config.image_processor_config)

# Trainer settings: CER is both the reported metric and the model-selection metric.
train_config = TrainerConfig(
    output_dir="crnn-plate-fa-v1",
    task="image2text",
    device="cuda",
    batch_size=8,
    num_epochs=20,
    metrics=["cer"],
    metric_for_best_model="cer",
)
trainer = Trainer(
    config=train_config,
    model=model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=train_dataset.data_collator,
    preprocessor=preprocessor,
)
trainer.train()
Beta Was this translation helpful? Give feedback.
-
Hi again, if we want to fine-tune the OCR model with text_split_type='tokenize', what id2label should we use?
Beta Was this translation helpful? Give feedback.
-
@kghezelbash Hi. The tokenize type is only applicable to transformer models like TrOCR that need a tokenizer, so id2label is not necessary. Generally, we do not recommend using that method since it has never been tested.
Beta Was this translation helpful? Give feedback.
-
@kghezelbash Appreciate it. Thanks. |
Beta Was this translation helpful? Give feedback.
-
Thank you again =))
import csv
import os
from hezar.data.datasets.ocr_dataset import TextSplitType
from hezar.constants import TaskType
from hezar.data.datasets.ocr_dataset import OCRDatasetConfig
from hezar.preprocessors import ImageProcessor
from hezar.trainer import Trainer, TrainerConfig
import pandas as pd
from hezar.models import CRNNImage2TextConfig, CRNNImage2Text
from hezar.data import OCRDataset, OCRDatasetConfig
from hezar.preprocessors.image_processor import ImageProcessorConfig
from tqdm import tqdm
csv_file = "IDDATA2.csv"

# Character vocabulary for the char-level OCR head. Order matters: a
# character's position in ``all_characters`` becomes its integer label id.
# Index 0 is the empty string.
fa_characters = [
    "", "آ", "ا", "ب", "پ", "ت", "ث", "ج", "چ", "ح", "خ", "د", "ذ", "ر", "ز", "ژ", "س", "ش",
    "ص", "ض", "ط", "ظ", "ع", "غ", "ف", "ق", "ک", "گ", "ل", "م", "ن", "و", "ه", "ی", " " , "ي"
]
fa_numbers = ["۱", "۲", "۳", "۴", "۵", "۶", "۷", "۸", "۹", "۰"]
fa_special_characters = ["ء", "ؤ", "ئ", "أ", "ّ", 'ٓ', 'ٕ', "ٔ", "\u200c" , "j", "p", "g"]
fa_symbols = ["/", "(", ")", "+", "-", ":", "،", "!", ".", "؛", "=", "%", "؟"]
en_numbers = ["1", "2", "3", "4", "5", "6", "7", "8", "9", "0"]

all_characters = [
    *fa_characters,
    *fa_numbers,
    *fa_special_characters,
    *fa_symbols,
    *en_numbers,
]
# Map label id -> character, numbering from 0 in vocabulary order.
ID2LABEL = {label_id: char for label_id, char in enumerate(all_characters)}
# Preprocessing for single-channel (grayscale) OCR input images.
image_processor_config = ImageProcessorConfig(
    mean=[0.5],  # one value: images are converted to grayscale (gray_scale=True)
    std=[0.5],  # one value, matching ``mean``
    # Scale 8-bit pixels from [0, 255] into [0, 1] so that mean/std of 0.5 are
    # meaningful. The former value of 1.0 left pixels unscaled.
    # NOTE(review): assumes the processor rescales before it normalizes;
    # confirm against hezar's ImageProcessor.
    rescale=1 / 255,
    # NOTE(review): in PIL's numbering 2 is BILINEAR (BICUBIC is 3), although
    # this was previously labelled BICUBIC; confirm which filter hezar uses.
    resample=2,
    size=(224, 224),  # target image size (width, height)
    mirror=False,  # no mirror augmentation
    gray_scale=True,  # convert inputs to grayscale
)
class PersianOCRDataset(OCRDataset):
    """OCR dataset that reads image paths and transcriptions from a CSV file.

    Rows are loaded with pandas from ``config.path``. Each item is a dict of
    the processed pixel values and the encoded label for one row.
    """

    def __init__(self, config: OCRDatasetConfig, split=None, **kwargs):
        super().__init__(config=config, split=split, **kwargs)

    def _load(self, split=None):
        """Read the CSV at ``self.config.path`` into a pandas DataFrame.

        NOTE(review): ``split`` is ignored, so every split returns the same
        rows. Filter here (e.g. on a split column) if the CSV mixes splits.
        """
        data = pd.read_csv(self.config.path)
        return data

    def __getitem__(self, index):
        """Return ``{"pixel_values", "labels"}`` for the row at position ``index``."""
        row = self.data.iloc[index]
        # Look fields up by their configured column names. The former
        # ``iloc[index][0]`` / ``iloc[index][1]`` depended on the CSV's column
        # order and on pandas' deprecated positional fallback for integer keys
        # in ``Series.__getitem__``.
        path = row[self.config.images_paths_column]
        text = row[self.config.text_column]
        pixel_values = self.image_processor(path, return_tensors="pt")["pixel_values"][0]
        labels = self._text_to_tensor(text)
        return {
            "pixel_values": pixel_values,
            "labels": labels,
        }
# Character-level OCR dataset built from the CSV named by ``csv_file``.
# NOTE(review): PersianOCRDataset._load in this snippet ignores ``split``, so
# the "train" and "test" datasets below hold the same rows and evaluation is
# not on held-out data.
dataset_config = OCRDatasetConfig(
    path=csv_file,
    text_split_type='char_split',
    id2label=ID2LABEL,
    text_column="text",
    images_paths_column="image_path",
    max_length=100,
    invalid_characters=[],
    reverse_text=False,
    reverse_digits=False,
    image_processor_config=image_processor_config,
)
train_dataset = PersianOCRDataset(dataset_config, split="train")
eval_dataset = PersianOCRDataset(dataset_config, split="test")

model = CRNNImage2Text(
    CRNNImage2TextConfig(
        id2label=train_dataset.config.id2label,
        # NOTE(review): presumably tied to the CNN feature size produced for
        # the configured image size; recompute if ``size`` changes.
        map2seq_in_dim=7168,
        map2seq_out_dim=96,
    )
)
preprocessor = ImageProcessor(train_dataset.config.image_processor_config)

train_config = TrainerConfig(
    output_dir="crnn-plate-fa-v1",
    task="image2text",
    # PyTorch device strings are lowercase; torch.device("CPU") is rejected.
    device="cpu",
    batch_size=10,
    num_epochs=1,
    metrics=["cer"],
    metric_for_best_model="cer",
)
trainer = Trainer(
    config=train_config,
    model=model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=train_dataset.data_collator,
    preprocessor=preprocessor,
)
trainer.train()
Beta Was this translation helpful? Give feedback.
Hello @kghezelbash,
To make it clearer: if you want your own dataset class for training a model with Hezar, you have to provide a regular PyTorch Dataset subclass. Hezar has its own dataset classes for common tasks like OCR, image captioning, text classification, etc. You are not forced to use only those classes; they're just there to make things easier for you. If your dataset is very different or needs a lot of customization, you can easily write your own dataset class.
You can use this template as an example: