diff --git a/trainer/craft/.gitignore b/trainer/craft/.gitignore
new file mode 100644
index 000000000..46d5bda86
--- /dev/null
+++ b/trainer/craft/.gitignore
@@ -0,0 +1,4 @@
+__pycache__/
+model/__pycache__/
+wandb/*
+vis_result/*
diff --git a/trainer/craft/README.md b/trainer/craft/README.md
new file mode 100644
index 000000000..823c83dc6
--- /dev/null
+++ b/trainer/craft/README.md
@@ -0,0 +1,105 @@
+# CRAFT-train
+On the official CRAFT GitHub repository, many people have asked how to train CRAFT models.
+
+However, the training code is not published in the official CRAFT repository.
+
+There are other reproduced implementations, but there is a gap between their performance and the performance reported in the original paper (https://arxiv.org/pdf/1904.01941.pdf).
+
+A model trained with this code achieves a level of performance similar to that of the original paper.
+
+```bash
+├── config
+│   ├── syn_train.yaml
+│   └── custom_data_train.yaml
+├── data
+│   ├── pseudo_label
+│   │   ├── make_charbox.py
+│   │   └── watershed.py
+│   ├── boxEnlarge.py
+│   ├── dataset.py
+│   ├── gaussian.py
+│   ├── imgaug.py
+│   └── imgproc.py
+├── loss
+│   └── mseloss.py
+├── metrics
+│   └── eval_det_iou.py
+├── model
+│   ├── craft.py
+│   └── vgg16_bn.py
+├── utils
+│   ├── craft_utils.py
+│   ├── inference_boxes.py
+│   └── utils.py
+├── trainSynth.py
+├── train.py
+├── train_distributed.py
+├── eval.py
+├── data_root_dir (place dataset folder here)
+└── exp (model and experiment result files will be saved here)
+```
+
+### Installation
+
+Install using `pip`
+
+```bash
+pip install -r requirements.txt
+```
+
+
+### Training
+1. Put your training and test data in the following format
+    ```
+    └── data_root_dir (you can change the root dir in the yaml file)
+        ├── ch4_training_images
+        │   ├── img_1.jpg
+        │   └── img_2.jpg
+        ├── ch4_training_localization_transcription_gt
+        │   ├── gt_img_1.txt
+        │   └── gt_img_2.txt
+        ├── ch4_test_images
+        │   ├── img_1.jpg
+        │   └── img_2.jpg
+        └── ch4_test_localization_transcription_gt
+            ├── gt_img_1.txt
+            └── gt_img_2.txt
+    ```
+    * localization_transcription_gt file format:
+      ```
+      377,117,463,117,465,130,378,130,Genaxis Theatre
+      493,115,519,115,519,131,493,131,[06]
+      374,155,409,155,409,170,374,170,###
+      ```
+2. Write a configuration file in yaml format (example config files are provided in the `config` folder)
+    * To speed up training with multiple GPUs, set num_workers > 0
+3. Put the yaml file in the config folder
+4. Run the training script as shown below (if you have multiple GPUs, run train_distributed.py)
+5. Experiment results will be saved to ```./exp/[yaml]``` by default.
+
+* Step 1 : Train CRAFT on the SynthText dataset from scratch
+    * Note : This step is not necessary if you use the pretrained model as a checkpoint when starting Step 2. You can download it, put it at `exp/CRAFT_clr_amp_29500.pth`, and change `ckpt_path` in the config file according to your local setup.
+    ```
+    CUDA_VISIBLE_DEVICES=0 python3 trainSynth.py --yaml=syn_train
+    ```
+
+* Step 2 : Train CRAFT on [SynthText + IC15] or a custom dataset
+    ```
+    CUDA_VISIBLE_DEVICES=0 python3 train.py --yaml=custom_data_train               ## if you run on a single GPU
+    CUDA_VISIBLE_DEVICES=0,1 python3 train_distributed.py --yaml=custom_data_train ## if you run on multiple GPUs
+    ```
+
+### Arguments
+* ```--yaml``` : configuration file name
+
+### Evaluation
+* In the official repository issues, the author mentioned that the F1-score for the first-row setting is around 0.75.
+* The official paper states that the F1-score for the second-row setting is 0.87.
+ * If you adjust post-process parameter 'text_threshold' from 0.85 to 0.75, then F1-score reaches to 0.856. +* It took 14h to train weak-supervision 25k iteration with 8 RTX 3090 Ti. + * Half of GPU assigned for training, and half of GPU assigned for supervision setting. + +| Training Dataset | Evaluation Dataset | Precision | Recall | F1-score | pretrained model | +| ------------- |-----|:-----:|:-----:|:-----:|-----:| +| SynthText | ICDAR2013 | 0.801 | 0.748 | 0.773| download link| +| SynthText + ICDAR2015 | ICDAR2015 | 0.909 | 0.794 | 0.848| download link| diff --git a/trainer/craft/config/__init__.py b/trainer/craft/config/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/trainer/craft/config/custom_data_train.yaml b/trainer/craft/config/custom_data_train.yaml new file mode 100644 index 000000000..bf7b577bf --- /dev/null +++ b/trainer/craft/config/custom_data_train.yaml @@ -0,0 +1,100 @@ +wandb_opt: False + +results_dir: "./exp/" +vis_test_dir: "./vis_result/" + +data_root_dir: "./data_root_dir/" +score_gt_dir: None # "/data/ICDAR2015_official_supervision" +mode: "weak_supervision" + + +train: + backbone : vgg + use_synthtext: False # If you want to combine SynthText in train time as CRAFT did, you can turn on this option + synth_data_dir: "/data/SynthText/" + synth_ratio: 5 + real_dataset: custom + ckpt_path: "./pretrained_model/CRAFT_clr_amp_29500.pth" + eval_interval: 1000 + batch_size: 5 + st_iter: 0 + end_iter: 25000 + lr: 0.0001 + lr_decay: 7500 + gamma: 0.2 + weight_decay: 0.00001 + num_workers: 0 # On single gpu, train.py execution only works when num worker = 0 / On multi-gpu, you can set num_worker > 0 to speed up + amp: True + loss: 2 + neg_rto: 0.3 + n_min_neg: 5000 + data: + vis_opt: False + pseudo_vis_opt: False + output_size: 768 + do_not_care_label: ['###', ''] + mean: [0.485, 0.456, 0.406] + variance: [0.229, 0.224, 0.225] + enlarge_region : [0.5, 0.5] # x axis, y axis + enlarge_affinity: [0.5, 0.5] + gauss_init_size: 200 + gauss_sigma: 40 + watershed: + version: "skimage" + sure_fg_th: 0.75 + sure_bg_th: 0.05 + syn_sample: -1 + custom_sample: -1 + syn_aug: + random_scale: + range: [1.0, 1.5, 2.0] + option: False + random_rotate: + max_angle: 20 + option: False + random_crop: + version: "random_resize_crop_synth" + option: True + random_horizontal_flip: + option: False + random_colorjitter: + brightness: 0.2 + contrast: 0.2 + saturation: 0.2 + hue: 0.2 + option: True + custom_aug: + random_scale: + range: [ 1.0, 1.5, 2.0 ] + option: False + random_rotate: + max_angle: 20 + option: True + random_crop: + version: "random_resize_crop" + scale: [0.03, 0.4] + ratio: [0.75, 1.33] + rnd_threshold: 1.0 + option: True + random_horizontal_flip: + option: True + random_colorjitter: + brightness: 0.2 + contrast: 0.2 + saturation: 0.2 + hue: 0.2 + option: True + +test: + trained_model : null + custom_data: + test_set_size: 500 + test_data_dir: "./data_root_dir/" + text_threshold: 0.75 + low_text: 0.5 + link_threshold: 0.2 + canvas_size: 2240 + mag_ratio: 1.75 + poly: False + cuda: True + vis_opt: False diff --git a/trainer/craft/config/load_config.py b/trainer/craft/config/load_config.py new file mode 100644 index 000000000..abe3551f2 --- /dev/null +++ b/trainer/craft/config/load_config.py @@ -0,0 +1,37 @@ +import os +import yaml +from functools import reduce + +CONFIG_PATH = os.path.dirname(__file__) + +def load_yaml(config_name): + + with open(os.path.join(CONFIG_PATH, config_name)+ '.yaml') as file: + config = yaml.safe_load(file) + + return config + 
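+# Usage sketch (comments only, added for illustration; not part of the upstream
+# trainer code). eval.py loads a config by name with load_yaml() and wraps it in
+# the DotDict helper defined below, which allows attribute-style and dotted-key
+# access to nested yaml values. The config name and keys come from
+# config/custom_data_train.yaml:
+#
+#   config = DotDict(load_yaml("custom_data_train"))
+#   config.train.batch_size    # -> 5       (attribute-style access)
+#   config["train.lr"]         # -> 0.0001  (dotted-key access)
+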
+class DotDict(dict): + def __getattr__(self, k): + try: + v = self[k] + except: + return super().__getattr__(k) + if isinstance(v, dict): + return DotDict(v) + return v + + def __getitem__(self, k): + if isinstance(k, str) and '.' in k: + k = k.split('.') + if isinstance(k, (list, tuple)): + return reduce(lambda d, kk: d[kk], k, self) + return super().__getitem__(k) + + def get(self, k, default=None): + if isinstance(k, str) and '.' in k: + try: + return self[k] + except KeyError: + return default + return super().get(k, default=default) \ No newline at end of file diff --git a/trainer/craft/config/syn_train.yaml b/trainer/craft/config/syn_train.yaml new file mode 100644 index 000000000..a41663f7c --- /dev/null +++ b/trainer/craft/config/syn_train.yaml @@ -0,0 +1,68 @@ +wandb_opt: False + +results_dir: "./exp/" +vis_test_dir: "./vis_result/" +data_dir: + synthtext: "/data/SynthText/" + synthtext_gt: NULL + +train: + backbone : vgg + dataset: ["synthtext"] + ckpt_path: null + eval_interval: 1000 + batch_size: 5 + st_iter: 0 + end_iter: 50000 + lr: 0.0001 + lr_decay: 15000 + gamma: 0.2 + weight_decay: 0.00001 + num_workers: 4 + amp: True + loss: 3 + neg_rto: 1 + n_min_neg: 1000 + data: + vis_opt: False + output_size: 768 + mean: [0.485, 0.456, 0.406] + variance: [0.229, 0.224, 0.225] + enlarge_region : [0.5, 0.5] # x axis, y axis + enlarge_affinity: [0.5, 0.5] + gauss_init_size: 200 + gauss_sigma: 40 + syn_sample : -1 + syn_aug: + random_scale: + range: [1.0, 1.5, 2.0] + option: False + random_rotate: + max_angle: 20 + option: False + random_crop: + version: "random_resize_crop_synth" + rnd_threshold : 1.0 + option: True + random_horizontal_flip: + option: False + random_colorjitter: + brightness: 0.2 + contrast: 0.2 + saturation: 0.2 + hue: 0.2 + option: True + +test: + trained_model: null + icdar2013: + test_set_size: 233 + cuda: True + vis_opt: True + test_data_dir : "/data/ICDAR2013/" + text_threshold: 0.85 + low_text: 0.5 + link_threshold: 0.2 + canvas_size: 960 + mag_ratio: 1.5 + poly: False \ No newline at end of file diff --git a/trainer/craft/data/boxEnlarge.py b/trainer/craft/data/boxEnlarge.py new file mode 100644 index 000000000..73d5bc5c2 --- /dev/null +++ b/trainer/craft/data/boxEnlarge.py @@ -0,0 +1,65 @@ +import math +import numpy as np + + +def pointAngle(Apoint, Bpoint): + angle = (Bpoint[1] - Apoint[1]) / ((Bpoint[0] - Apoint[0]) + 10e-8) + return angle + +def pointDistance(Apoint, Bpoint): + return math.sqrt((Bpoint[1] - Apoint[1])**2 + (Bpoint[0] - Apoint[0])**2) + +def lineBiasAndK(Apoint, Bpoint): + + K = pointAngle(Apoint, Bpoint) + B = Apoint[1] - K*Apoint[0] + return K, B + +def getX(K, B, Ypoint): + return int((Ypoint-B)/K) + +def sidePoint(Apoint, Bpoint, h, w, placehold, enlarge_size): + + K, B = lineBiasAndK(Apoint, Bpoint) + angle = abs(math.atan(pointAngle(Apoint, Bpoint))) + distance = pointDistance(Apoint, Bpoint) + + x_enlarge_size, y_enlarge_size = enlarge_size + + XaxisIncreaseDistance = abs(math.cos(angle) * x_enlarge_size * distance) + YaxisIncreaseDistance = abs(math.sin(angle) * y_enlarge_size * distance) + + if placehold == 'leftTop': + x1 = max(0, Apoint[0] - XaxisIncreaseDistance) + y1 = max(0, Apoint[1] - YaxisIncreaseDistance) + elif placehold == 'rightTop': + x1 = min(w, Bpoint[0] + XaxisIncreaseDistance) + y1 = max(0, Bpoint[1] - YaxisIncreaseDistance) + elif placehold == 'rightBottom': + x1 = min(w, Bpoint[0] + XaxisIncreaseDistance) + y1 = min(h, Bpoint[1] + YaxisIncreaseDistance) + elif placehold == 'leftBottom': + x1 = max(0, Apoint[0] - 
XaxisIncreaseDistance) + y1 = min(h, Apoint[1] + YaxisIncreaseDistance) + return int(x1), int(y1) + +def enlargebox(box, h, w, enlarge_size, horizontal_text_bool): + + if not horizontal_text_bool: + enlarge_size = (enlarge_size[1], enlarge_size[0]) + + box = np.roll(box, -np.argmin(box.sum(axis=1)), axis=0) + + Apoint, Bpoint, Cpoint, Dpoint = box + K1, B1 = lineBiasAndK(box[0], box[2]) + K2, B2 = lineBiasAndK(box[3], box[1]) + X = (B2 - B1)/(K1 - K2) + Y = K1 * X + B1 + center = [X, Y] + + x1, y1 = sidePoint(Apoint, center, h, w, 'leftTop', enlarge_size) + x2, y2 = sidePoint(center, Bpoint, h, w, 'rightTop', enlarge_size) + x3, y3 = sidePoint(center, Cpoint, h, w, 'rightBottom', enlarge_size) + x4, y4 = sidePoint(Dpoint, center, h, w, 'leftBottom', enlarge_size) + newcharbox = np.array([[x1, y1], [x2, y2], [x3, y3], [x4, y4]]) + return newcharbox \ No newline at end of file diff --git a/trainer/craft/data/dataset.py b/trainer/craft/data/dataset.py new file mode 100644 index 000000000..e7e6943e4 --- /dev/null +++ b/trainer/craft/data/dataset.py @@ -0,0 +1,542 @@ +import os +import re +import itertools +import random + +import numpy as np +import scipy.io as scio +from PIL import Image +import cv2 +from torch.utils.data import Dataset +import torchvision.transforms as transforms + +from data import imgproc +from data.gaussian import GaussianBuilder +from data.imgaug import ( + rescale, + random_resize_crop_synth, + random_resize_crop, + random_horizontal_flip, + random_rotate, + random_scale, + random_crop, +) +from data.pseudo_label.make_charbox import PseudoCharBoxBuilder +from utils.util import saveInput, saveImage + + +class CraftBaseDataset(Dataset): + def __init__( + self, + output_size, + data_dir, + saved_gt_dir, + mean, + variance, + gauss_init_size, + gauss_sigma, + enlarge_region, + enlarge_affinity, + aug, + vis_test_dir, + vis_opt, + sample, + ): + self.output_size = output_size + self.data_dir = data_dir + self.saved_gt_dir = saved_gt_dir + self.mean, self.variance = mean, variance + self.gaussian_builder = GaussianBuilder( + gauss_init_size, gauss_sigma, enlarge_region, enlarge_affinity + ) + self.aug = aug + self.vis_test_dir = vis_test_dir + self.vis_opt = vis_opt + self.sample = sample + if self.sample != -1: + random.seed(0) + self.idx = random.sample(range(0, len(self.img_names)), self.sample) + + self.pre_crop_area = [] + + def augment_image( + self, image, region_score, affinity_score, confidence_mask, word_level_char_bbox + ): + augment_targets = [image, region_score, affinity_score, confidence_mask] + + if self.aug.random_scale.option: + augment_targets, word_level_char_bbox = random_scale( + augment_targets, word_level_char_bbox, self.aug.random_scale.range + ) + + if self.aug.random_rotate.option: + augment_targets = random_rotate( + augment_targets, self.aug.random_rotate.max_angle + ) + + if self.aug.random_crop.option: + if self.aug.random_crop.version == "random_crop_with_bbox": + augment_targets = random_crop_with_bbox( + augment_targets, word_level_char_bbox, self.output_size + ) + elif self.aug.random_crop.version == "random_resize_crop_synth": + augment_targets = random_resize_crop_synth( + augment_targets, self.output_size + ) + elif self.aug.random_crop.version == "random_resize_crop": + + if len(self.pre_crop_area) > 0: + pre_crop_area = self.pre_crop_area + else: + pre_crop_area = None + + augment_targets = random_resize_crop( + augment_targets, + self.aug.random_crop.scale, + self.aug.random_crop.ratio, + self.output_size, + 
self.aug.random_crop.rnd_threshold, + pre_crop_area, + ) + + elif self.aug.random_crop.version == "random_crop": + augment_targets = random_crop(augment_targets, self.output_size,) + + else: + assert "Undefined RandomCrop version" + + if self.aug.random_horizontal_flip.option: + augment_targets = random_horizontal_flip(augment_targets) + + if self.aug.random_colorjitter.option: + image, region_score, affinity_score, confidence_mask = augment_targets + image = Image.fromarray(image) + image = transforms.ColorJitter( + brightness=self.aug.random_colorjitter.brightness, + contrast=self.aug.random_colorjitter.contrast, + saturation=self.aug.random_colorjitter.saturation, + hue=self.aug.random_colorjitter.hue, + )(image) + else: + image, region_score, affinity_score, confidence_mask = augment_targets + + return np.array(image), region_score, affinity_score, confidence_mask + + def resize_to_half(self, ground_truth, interpolation): + return cv2.resize( + ground_truth, + (self.output_size // 2, self.output_size // 2), + interpolation=interpolation, + ) + + def __len__(self): + if self.sample != -1: + return len(self.idx) + else: + return len(self.img_names) + + def __getitem__(self, index): + if self.sample != -1: + index = self.idx[index] + if self.saved_gt_dir is None: + ( + image, + region_score, + affinity_score, + confidence_mask, + word_level_char_bbox, + all_affinity_bbox, + words, + ) = self.make_gt_score(index) + else: + ( + image, + region_score, + affinity_score, + confidence_mask, + word_level_char_bbox, + words, + ) = self.load_saved_gt_score(index) + all_affinity_bbox = [] + + if self.vis_opt: + saveImage( + self.img_names[index], + self.vis_test_dir, + image.copy(), + word_level_char_bbox.copy(), + all_affinity_bbox.copy(), + region_score.copy(), + affinity_score.copy(), + confidence_mask.copy(), + ) + + image, region_score, affinity_score, confidence_mask = self.augment_image( + image, region_score, affinity_score, confidence_mask, word_level_char_bbox + ) + + if self.vis_opt: + saveInput( + self.img_names[index], + self.vis_test_dir, + image, + region_score, + affinity_score, + confidence_mask, + ) + + region_score = self.resize_to_half(region_score, interpolation=cv2.INTER_CUBIC) + affinity_score = self.resize_to_half( + affinity_score, interpolation=cv2.INTER_CUBIC + ) + confidence_mask = self.resize_to_half( + confidence_mask, interpolation=cv2.INTER_NEAREST + ) + + image = imgproc.normalizeMeanVariance( + np.array(image), mean=self.mean, variance=self.variance + ) + image = image.transpose(2, 0, 1) + + return image, region_score, affinity_score, confidence_mask + + +class SynthTextDataSet(CraftBaseDataset): + def __init__( + self, + output_size, + data_dir, + saved_gt_dir, + mean, + variance, + gauss_init_size, + gauss_sigma, + enlarge_region, + enlarge_affinity, + aug, + vis_test_dir, + vis_opt, + sample, + ): + super().__init__( + output_size, + data_dir, + saved_gt_dir, + mean, + variance, + gauss_init_size, + gauss_sigma, + enlarge_region, + enlarge_affinity, + aug, + vis_test_dir, + vis_opt, + sample, + ) + self.img_names, self.char_bbox, self.img_words = self.load_data() + self.vis_index = list(range(1000)) + + def load_data(self, bbox="char"): + + gt = scio.loadmat(os.path.join(self.data_dir, "gt.mat")) + img_names = gt["imnames"][0] + img_words = gt["txt"][0] + + if bbox == "char": + img_bbox = gt["charBB"][0] + else: + img_bbox = gt["wordBB"][0] # word bbox needed for test + + return img_names, img_bbox, img_words + + def dilate_img_to_output_size(self, image, 
char_bbox): + h, w, _ = image.shape + if min(h, w) <= self.output_size: + scale = float(self.output_size) / min(h, w) + else: + scale = 1.0 + image = cv2.resize( + image, dsize=None, fx=scale, fy=scale, interpolation=cv2.INTER_CUBIC + ) + char_bbox *= scale + return image, char_bbox + + def make_gt_score(self, index): + img_path = os.path.join(self.data_dir, self.img_names[index][0]) + image = cv2.imread(img_path, cv2.IMREAD_COLOR) + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + all_char_bbox = self.char_bbox[index].transpose( + (2, 1, 0) + ) # shape : (Number of characters in image, 4, 2) + + img_h, img_w, _ = image.shape + + confidence_mask = np.ones((img_h, img_w), dtype=np.float32) + + words = [ + re.split(" \n|\n |\n| ", word.strip()) for word in self.img_words[index] + ] + words = list(itertools.chain(*words)) + words = [word for word in words if len(word) > 0] + + word_level_char_bbox = [] + char_idx = 0 + + for i in range(len(words)): + length_of_word = len(words[i]) + word_bbox = all_char_bbox[char_idx : char_idx + length_of_word] + assert len(word_bbox) == length_of_word + char_idx += length_of_word + word_bbox = np.array(word_bbox) + word_level_char_bbox.append(word_bbox) + + region_score = self.gaussian_builder.generate_region( + img_h, + img_w, + word_level_char_bbox, + horizontal_text_bools=[True for _ in range(len(words))], + ) + affinity_score, all_affinity_bbox = self.gaussian_builder.generate_affinity( + img_h, + img_w, + word_level_char_bbox, + horizontal_text_bools=[True for _ in range(len(words))], + ) + + return ( + image, + region_score, + affinity_score, + confidence_mask, + word_level_char_bbox, + all_affinity_bbox, + words, + ) + + +class CustomDataset(CraftBaseDataset): + def __init__( + self, + output_size, + data_dir, + saved_gt_dir, + mean, + variance, + gauss_init_size, + gauss_sigma, + enlarge_region, + enlarge_affinity, + aug, + vis_test_dir, + vis_opt, + sample, + watershed_param, + pseudo_vis_opt, + do_not_care_label, + ): + super().__init__( + output_size, + data_dir, + saved_gt_dir, + mean, + variance, + gauss_init_size, + gauss_sigma, + enlarge_region, + enlarge_affinity, + aug, + vis_test_dir, + vis_opt, + sample, + ) + self.pseudo_vis_opt = pseudo_vis_opt + self.do_not_care_label = do_not_care_label + self.pseudo_charbox_builder = PseudoCharBoxBuilder( + watershed_param, vis_test_dir, pseudo_vis_opt, self.gaussian_builder + ) + self.vis_index = list(range(1000)) + self.img_dir = os.path.join(data_dir, "ch4_training_images") + self.img_gt_box_dir = os.path.join( + data_dir, "ch4_training_localization_transcription_gt" + ) + self.img_names = os.listdir(self.img_dir) + + def update_model(self, net): + self.net = net + + def update_device(self, gpu): + self.gpu = gpu + + def load_img_gt_box(self, img_gt_box_path): + lines = open(img_gt_box_path, encoding="utf-8").readlines() + word_bboxes = [] + words = [] + for line in lines: + box_info = line.strip().encode("utf-8").decode("utf-8-sig").split(",") + box_points = [int(box_info[i]) for i in range(8)] + box_points = np.array(box_points, np.float32).reshape(4, 2) + word = box_info[8:] + word = ",".join(word) + if word in self.do_not_care_label: + words.append(self.do_not_care_label[0]) + word_bboxes.append(box_points) + continue + word_bboxes.append(box_points) + words.append(word) + return np.array(word_bboxes), words + + def load_data(self, index): + img_name = self.img_names[index] + img_path = os.path.join(self.img_dir, img_name) + image = cv2.imread(img_path) + image = cv2.cvtColor(image, 
cv2.COLOR_BGR2RGB) + + img_gt_box_path = os.path.join( + self.img_gt_box_dir, "gt_%s.txt" % os.path.splitext(img_name)[0] + ) + word_bboxes, words = self.load_img_gt_box( + img_gt_box_path + ) # shape : (Number of word bbox, 4, 2) + confidence_mask = np.ones((image.shape[0], image.shape[1]), np.float32) + + word_level_char_bbox = [] + do_care_words = [] + horizontal_text_bools = [] + + if len(word_bboxes) == 0: + return ( + image, + word_level_char_bbox, + do_care_words, + confidence_mask, + horizontal_text_bools, + ) + _word_bboxes = word_bboxes.copy() + for i in range(len(word_bboxes)): + if words[i] in self.do_not_care_label: + cv2.fillPoly(confidence_mask, [np.int32(_word_bboxes[i])], 0) + continue + + ( + pseudo_char_bbox, + confidence, + horizontal_text_bool, + ) = self.pseudo_charbox_builder.build_char_box( + self.net, self.gpu, image, word_bboxes[i], words[i], img_name=img_name + ) + + cv2.fillPoly(confidence_mask, [np.int32(_word_bboxes[i])], confidence) + do_care_words.append(words[i]) + word_level_char_bbox.append(pseudo_char_bbox) + horizontal_text_bools.append(horizontal_text_bool) + + return ( + image, + word_level_char_bbox, + do_care_words, + confidence_mask, + horizontal_text_bools, + ) + + def make_gt_score(self, index): + """ + Make region, affinity scores using pseudo character-level GT bounding box + word_level_char_bbox's shape : [word_num, [char_num_in_one_word, 4, 2]] + :rtype region_score: np.float32 + :rtype affinity_score: np.float32 + :rtype confidence_mask: np.float32 + :rtype word_level_char_bbox: np.float32 + :rtype words: list + """ + ( + image, + word_level_char_bbox, + words, + confidence_mask, + horizontal_text_bools, + ) = self.load_data(index) + img_h, img_w, _ = image.shape + + if len(word_level_char_bbox) == 0: + region_score = np.zeros((img_h, img_w), dtype=np.float32) + affinity_score = np.zeros((img_h, img_w), dtype=np.float32) + all_affinity_bbox = [] + else: + region_score = self.gaussian_builder.generate_region( + img_h, img_w, word_level_char_bbox, horizontal_text_bools + ) + affinity_score, all_affinity_bbox = self.gaussian_builder.generate_affinity( + img_h, img_w, word_level_char_bbox, horizontal_text_bools + ) + + return ( + image, + region_score, + affinity_score, + confidence_mask, + word_level_char_bbox, + all_affinity_bbox, + words, + ) + + def load_saved_gt_score(self, index): + """ + Load pre-saved official CRAFT model's region, affinity scores to train + word_level_char_bbox's shape : [word_num, [char_num_in_one_word, 4, 2]] + :rtype region_score: np.float32 + :rtype affinity_score: np.float32 + :rtype confidence_mask: np.float32 + :rtype word_level_char_bbox: np.float32 + :rtype words: list + """ + img_name = self.img_names[index] + img_path = os.path.join(self.img_dir, img_name) + image = cv2.imread(img_path) + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + + img_gt_box_path = os.path.join( + self.img_gt_box_dir, "gt_%s.txt" % os.path.splitext(img_name)[0] + ) + word_bboxes, words = self.load_img_gt_box(img_gt_box_path) + image, word_bboxes = rescale(image, word_bboxes) + img_h, img_w, _ = image.shape + + query_idx = int(self.img_names[index].split(".")[0].split("_")[1]) + + saved_region_scores_path = os.path.join( + self.saved_gt_dir, f"res_img_{query_idx}_region.jpg" + ) + saved_affi_scores_path = os.path.join( + self.saved_gt_dir, f"res_img_{query_idx}_affi.jpg" + ) + saved_cf_mask_path = os.path.join( + self.saved_gt_dir, f"res_img_{query_idx}_cf_mask_thresh_0.6.jpg" + ) + region_score = cv2.imread(saved_region_scores_path, 
cv2.IMREAD_GRAYSCALE) + affinity_score = cv2.imread(saved_affi_scores_path, cv2.IMREAD_GRAYSCALE) + confidence_mask = cv2.imread(saved_cf_mask_path, cv2.IMREAD_GRAYSCALE) + + region_score = cv2.resize(region_score, (img_w, img_h)) + affinity_score = cv2.resize(affinity_score, (img_w, img_h)) + confidence_mask = cv2.resize( + confidence_mask, (img_w, img_h), interpolation=cv2.INTER_NEAREST + ) + + region_score = region_score.astype(np.float32) / 255 + affinity_score = affinity_score.astype(np.float32) / 255 + confidence_mask = confidence_mask.astype(np.float32) / 255 + + # NOTE : Even though word_level_char_bbox is not necessary, align bbox format with make_gt_score() + word_level_char_bbox = [] + + for i in range(len(word_bboxes)): + word_level_char_bbox.append(np.expand_dims(word_bboxes[i], 0)) + + return ( + image, + region_score, + affinity_score, + confidence_mask, + word_level_char_bbox, + words, + ) diff --git a/trainer/craft/data/gaussian.py b/trainer/craft/data/gaussian.py new file mode 100644 index 000000000..2d0b76e0a --- /dev/null +++ b/trainer/craft/data/gaussian.py @@ -0,0 +1,192 @@ +import numpy as np +import cv2 + +from data.boxEnlarge import enlargebox + + +class GaussianBuilder(object): + def __init__(self, init_size, sigma, enlarge_region, enlarge_affinity): + self.init_size = init_size + self.sigma = sigma + self.enlarge_region = enlarge_region + self.enlarge_affinity = enlarge_affinity + self.gaussian_map, self.gaussian_map_color = self.generate_gaussian_map() + + def generate_gaussian_map(self): + circle_mask = self.generate_circle_mask() + + gaussian_map = np.zeros((self.init_size, self.init_size), np.float32) + + for i in range(self.init_size): + for j in range(self.init_size): + gaussian_map[i, j] = ( + 1 + / 2 + / np.pi + / (self.sigma ** 2) + * np.exp( + -1 + / 2 + * ( + (i - self.init_size / 2) ** 2 / (self.sigma ** 2) + + (j - self.init_size / 2) ** 2 / (self.sigma ** 2) + ) + ) + ) + + gaussian_map = gaussian_map * circle_mask + gaussian_map = (gaussian_map / np.max(gaussian_map)).astype(np.float32) + + gaussian_map_color = (gaussian_map * 255).astype(np.uint8) + gaussian_map_color = cv2.applyColorMap(gaussian_map_color, cv2.COLORMAP_JET) + return gaussian_map, gaussian_map_color + + def generate_circle_mask(self): + + zero_arr = np.zeros((self.init_size, self.init_size), np.float32) + circle_mask = cv2.circle( + img=zero_arr, + center=(self.init_size // 2, self.init_size // 2), + radius=self.init_size // 2, + color=1, + thickness=-1, + ) + + return circle_mask + + def four_point_transform(self, bbox): + """ + Using the bbox, standard 2D gaussian map, returns Transformed 2d Gaussian map + """ + width, height = ( + np.max(bbox[:, 0]).astype(np.int32), + np.max(bbox[:, 1]).astype(np.int32), + ) + init_points = np.array( + [ + [0, 0], + [self.init_size, 0], + [self.init_size, self.init_size], + [0, self.init_size], + ], + dtype="float32", + ) + + M = cv2.getPerspectiveTransform(init_points, bbox) + warped_gaussian_map = cv2.warpPerspective(self.gaussian_map, M, (width, height)) + return warped_gaussian_map, width, height + + def add_gaussian_map_to_score_map( + self, score_map, bbox, enlarge_size, horizontal_text_bool, map_type=None + ): + """ + Mapping 2D Gaussian to the character box coordinates of the score_map. 
+ + :param score_map: Target map to put 2D gaussian on character box + :type score_map: np.float32 + :param bbox: character boxes + :type bbox: np.float32 + :param enlarge_size: Enlarge size of gaussian map to fit character shape + :type enlarge_size: list of enlarge size [x dim, y dim] + :param horizontal_text_bool: Flag that bbox is horizontal text or not + :type horizontal_text_bool: bool + :param map_type: Whether map's type is "region" | "affinity" + :type map_type: str + :return score_map: score map that all 2D gaussian put on character box + :rtype: np.float32 + """ + + map_h, map_w = score_map.shape + bbox = enlargebox(bbox, map_h, map_w, enlarge_size, horizontal_text_bool) + + # If any one point of character bbox is out of range, don't put in on map + if np.any(bbox < 0) or np.any(bbox[:, 0] > map_w) or np.any(bbox[:, 1] > map_h): + return score_map + + bbox_left, bbox_top = np.array([np.min(bbox[:, 0]), np.min(bbox[:, 1])]).astype( + np.int32 + ) + bbox -= (bbox_left, bbox_top) + warped_gaussian_map, width, height = self.four_point_transform( + bbox.astype(np.float32) + ) + + try: + bbox_area_of_image = score_map[ + bbox_top : bbox_top + height, bbox_left : bbox_left + width, + ] + high_value_score = np.where( + warped_gaussian_map > bbox_area_of_image, + warped_gaussian_map, + bbox_area_of_image, + ) + score_map[ + bbox_top : bbox_top + height, bbox_left : bbox_left + width, + ] = high_value_score + + except Exception as e: + print("Error : {}".format(e)) + print( + "On generating {} map, strange box came out. (width: {}, height: {})".format( + map_type, width, height + ) + ) + + return score_map + + def calculate_affinity_box_points(self, bbox_1, bbox_2, vertical=False): + center_1, center_2 = np.mean(bbox_1, axis=0), np.mean(bbox_2, axis=0) + if vertical: + tl = (bbox_1[0] + bbox_1[-1] + center_1) / 3 + tr = (bbox_1[1:3].sum(0) + center_1) / 3 + br = (bbox_2[1:3].sum(0) + center_2) / 3 + bl = (bbox_2[0] + bbox_2[-1] + center_2) / 3 + else: + tl = (bbox_1[0:2].sum(0) + center_1) / 3 + tr = (bbox_2[0:2].sum(0) + center_2) / 3 + br = (bbox_2[2:4].sum(0) + center_2) / 3 + bl = (bbox_1[2:4].sum(0) + center_1) / 3 + affinity_box = np.array([tl, tr, br, bl]).astype(np.float32) + return affinity_box + + def generate_region( + self, img_h, img_w, word_level_char_bbox, horizontal_text_bools + ): + region_map = np.zeros([img_h, img_w], dtype=np.float32) + for i in range( + len(word_level_char_bbox) + ): # shape : [word_num, [char_num_in_one_word, 4, 2]] + for j in range(len(word_level_char_bbox[i])): + region_map = self.add_gaussian_map_to_score_map( + region_map, + word_level_char_bbox[i][j].copy(), + self.enlarge_region, + horizontal_text_bools[i], + map_type="region", + ) + return region_map + + def generate_affinity( + self, img_h, img_w, word_level_char_bbox, horizontal_text_bools + ): + + affinity_map = np.zeros([img_h, img_w], dtype=np.float32) + all_affinity_bbox = [] + for i in range(len(word_level_char_bbox)): + for j in range(len(word_level_char_bbox[i]) - 1): + affinity_bbox = self.calculate_affinity_box_points( + word_level_char_bbox[i][j], word_level_char_bbox[i][j + 1] + ) + + affinity_map = self.add_gaussian_map_to_score_map( + affinity_map, + affinity_bbox.copy(), + self.enlarge_affinity, + horizontal_text_bools[i], + map_type="affinity", + ) + all_affinity_bbox.append(np.expand_dims(affinity_bbox, axis=0)) + + if len(all_affinity_bbox) > 0: + all_affinity_bbox = np.concatenate(all_affinity_bbox, axis=0) + return affinity_map, all_affinity_bbox \ No newline at end of 
file diff --git a/trainer/craft/data/imgaug.py b/trainer/craft/data/imgaug.py new file mode 100644 index 000000000..d24a456d2 --- /dev/null +++ b/trainer/craft/data/imgaug.py @@ -0,0 +1,175 @@ +import random + +import cv2 +import numpy as np +from PIL import Image +from torchvision.transforms.functional import resized_crop, crop +from torchvision.transforms import RandomResizedCrop, RandomCrop +from torchvision.transforms import InterpolationMode + + +def rescale(img, bboxes, target_size=2240): + h, w = img.shape[0:2] + scale = target_size / max(h, w) + img = cv2.resize(img, dsize=None, fx=scale, fy=scale, interpolation=cv2.INTER_CUBIC) + bboxes = bboxes * scale + return img, bboxes + + +def random_resize_crop_synth(augment_targets, size): + image, region_score, affinity_score, confidence_mask = augment_targets + + image = Image.fromarray(image) + region_score = Image.fromarray(region_score) + affinity_score = Image.fromarray(affinity_score) + confidence_mask = Image.fromarray(confidence_mask) + + short_side = min(image.size) + i, j, h, w = RandomCrop.get_params(image, output_size=(short_side, short_side)) + + image = resized_crop( + image, i, j, h, w, size=(size, size), interpolation=InterpolationMode.BICUBIC + ) + region_score = resized_crop( + region_score, i, j, h, w, (size, size), interpolation=InterpolationMode.BICUBIC + ) + affinity_score = resized_crop( + affinity_score, + i, + j, + h, + w, + (size, size), + interpolation=InterpolationMode.BICUBIC, + ) + confidence_mask = resized_crop( + confidence_mask, + i, + j, + h, + w, + (size, size), + interpolation=InterpolationMode.NEAREST, + ) + + image = np.array(image) + region_score = np.array(region_score) + affinity_score = np.array(affinity_score) + confidence_mask = np.array(confidence_mask) + augment_targets = [image, region_score, affinity_score, confidence_mask] + + return augment_targets + + +def random_resize_crop( + augment_targets, scale, ratio, size, threshold, pre_crop_area=None +): + image, region_score, affinity_score, confidence_mask = augment_targets + + image = Image.fromarray(image) + region_score = Image.fromarray(region_score) + affinity_score = Image.fromarray(affinity_score) + confidence_mask = Image.fromarray(confidence_mask) + + if pre_crop_area != None: + i, j, h, w = pre_crop_area + + else: + if random.random() < threshold: + i, j, h, w = RandomResizedCrop.get_params(image, scale=scale, ratio=ratio) + else: + i, j, h, w = RandomResizedCrop.get_params( + image, scale=(1.0, 1.0), ratio=(1.0, 1.0) + ) + + image = resized_crop( + image, i, j, h, w, size=(size, size), interpolation=InterpolationMode.BICUBIC + ) + region_score = resized_crop( + region_score, i, j, h, w, (size, size), interpolation=InterpolationMode.BICUBIC + ) + affinity_score = resized_crop( + affinity_score, + i, + j, + h, + w, + (size, size), + interpolation=InterpolationMode.BICUBIC, + ) + confidence_mask = resized_crop( + confidence_mask, + i, + j, + h, + w, + (size, size), + interpolation=InterpolationMode.NEAREST, + ) + + image = np.array(image) + region_score = np.array(region_score) + affinity_score = np.array(affinity_score) + confidence_mask = np.array(confidence_mask) + augment_targets = [image, region_score, affinity_score, confidence_mask] + + return augment_targets + + +def random_crop(augment_targets, size): + image, region_score, affinity_score, confidence_mask = augment_targets + + image = Image.fromarray(image) + region_score = Image.fromarray(region_score) + affinity_score = Image.fromarray(affinity_score) + confidence_mask = 
Image.fromarray(confidence_mask) + + i, j, h, w = RandomCrop.get_params(image, output_size=(size, size)) + + image = crop(image, i, j, h, w) + region_score = crop(region_score, i, j, h, w) + affinity_score = crop(affinity_score, i, j, h, w) + confidence_mask = crop(confidence_mask, i, j, h, w) + + image = np.array(image) + region_score = np.array(region_score) + affinity_score = np.array(affinity_score) + confidence_mask = np.array(confidence_mask) + augment_targets = [image, region_score, affinity_score, confidence_mask] + + return augment_targets + + +def random_horizontal_flip(imgs): + if random.random() < 0.5: + for i in range(len(imgs)): + imgs[i] = np.flip(imgs[i], axis=1).copy() + return imgs + + +def random_scale(images, word_level_char_bbox, scale_range): + scale = random.sample(scale_range, 1)[0] + + for i in range(len(images)): + images[i] = cv2.resize(images[i], dsize=None, fx=scale, fy=scale) + + for i in range(len(word_level_char_bbox)): + word_level_char_bbox[i] *= scale + + return images + + +def random_rotate(images, max_angle): + angle = random.random() * 2 * max_angle - max_angle + for i in range(len(images)): + img = images[i] + w, h = img.shape[:2] + rotation_matrix = cv2.getRotationMatrix2D((h / 2, w / 2), angle, 1) + if i == len(images) - 1: + img_rotation = cv2.warpAffine( + img, M=rotation_matrix, dsize=(h, w), flags=cv2.INTER_NEAREST + ) + else: + img_rotation = cv2.warpAffine(img, rotation_matrix, (h, w)) + images[i] = img_rotation + return images diff --git a/trainer/craft/data/imgproc.py b/trainer/craft/data/imgproc.py new file mode 100644 index 000000000..9c302af06 --- /dev/null +++ b/trainer/craft/data/imgproc.py @@ -0,0 +1,91 @@ +""" +Copyright (c) 2019-present NAVER Corp. +MIT License +""" + +# -*- coding: utf-8 -*- +import numpy as np + +import cv2 +from skimage import io + + +def loadImage(img_file): + img = io.imread(img_file) # RGB order + if img.shape[0] == 2: + img = img[0] + if len(img.shape) == 2: + img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB) + if img.shape[2] == 4: + img = img[:, :, :3] + img = np.array(img) + + return img + + +def normalizeMeanVariance( + in_img, mean=(0.485, 0.456, 0.406), variance=(0.229, 0.224, 0.225) +): + # should be RGB order + img = in_img.copy().astype(np.float32) + + img -= np.array( + [mean[0] * 255.0, mean[1] * 255.0, mean[2] * 255.0], dtype=np.float32 + ) + img /= np.array( + [variance[0] * 255.0, variance[1] * 255.0, variance[2] * 255.0], + dtype=np.float32, + ) + return img + + +def denormalizeMeanVariance( + in_img, mean=(0.485, 0.456, 0.406), variance=(0.229, 0.224, 0.225) +): + # should be RGB order + img = in_img.copy() + img *= variance + img += mean + img *= 255.0 + img = np.clip(img, 0, 255).astype(np.uint8) + return img + + +def resize_aspect_ratio(img, square_size, interpolation, mag_ratio=1): + height, width, channel = img.shape + + # magnify image size + target_size = mag_ratio * max(height, width) + + # set original image size + if target_size > square_size: + target_size = square_size + + ratio = target_size / max(height, width) + + target_h, target_w = int(height * ratio), int(width * ratio) + + # NOTE + valid_size_heatmap = (int(target_h / 2), int(target_w / 2)) + + proc = cv2.resize(img, (target_w, target_h), interpolation=interpolation) + + # make canvas and paste image + target_h32, target_w32 = target_h, target_w + if target_h % 32 != 0: + target_h32 = target_h + (32 - target_h % 32) + if target_w % 32 != 0: + target_w32 = target_w + (32 - target_w % 32) + resized = np.zeros((target_h32, 
target_w32, channel), dtype=np.float32) + resized[0:target_h, 0:target_w, :] = proc + + # target_h, target_w = target_h32, target_w32 + # size_heatmap = (int(target_w/2), int(target_h/2)) + + return resized, ratio, valid_size_heatmap + + +def cvt2HeatmapImg(img): + img = (np.clip(img, 0, 1) * 255).astype(np.uint8) + img = cv2.applyColorMap(img, cv2.COLORMAP_JET) + return img diff --git a/trainer/craft/data/pseudo_label/make_charbox.py b/trainer/craft/data/pseudo_label/make_charbox.py new file mode 100644 index 000000000..09c5219b3 --- /dev/null +++ b/trainer/craft/data/pseudo_label/make_charbox.py @@ -0,0 +1,263 @@ +import os +import random +import math + +import numpy as np +import cv2 +import torch + +from data import imgproc +from data.pseudo_label.watershed import exec_watershed_by_version + + +class PseudoCharBoxBuilder: + def __init__(self, watershed_param, vis_test_dir, pseudo_vis_opt, gaussian_builder): + self.watershed_param = watershed_param + self.vis_test_dir = vis_test_dir + self.pseudo_vis_opt = pseudo_vis_opt + self.gaussian_builder = gaussian_builder + self.cnt = 0 + self.flag = False + + def crop_image_by_bbox(self, image, box, word): + w = max( + int(np.linalg.norm(box[0] - box[1])), int(np.linalg.norm(box[2] - box[3])) + ) + h = max( + int(np.linalg.norm(box[0] - box[3])), int(np.linalg.norm(box[1] - box[2])) + ) + try: + word_ratio = h / w + except: + import ipdb + + ipdb.set_trace() + + one_char_ratio = min(h, w) / (max(h, w) / len(word)) + + # NOTE: criterion to split vertical word in here is set to work properly on IC15 dataset + if word_ratio > 2 or (word_ratio > 1.6 and one_char_ratio > 2.4): + # warping method of vertical word (classified by upper condition) + horizontal_text_bool = False + long_side = h + short_side = w + M = cv2.getPerspectiveTransform( + np.float32(box), + np.float32( + np.array( + [ + [long_side, 0], + [long_side, short_side], + [0, short_side], + [0, 0], + ] + ) + ), + ) + self.flag = True + else: + # warping method of horizontal word + horizontal_text_bool = True + long_side = w + short_side = h + M = cv2.getPerspectiveTransform( + np.float32(box), + np.float32( + np.array( + [ + [0, 0], + [long_side, 0], + [long_side, short_side], + [0, short_side], + ] + ) + ), + ) + self.flag = False + + warped = cv2.warpPerspective(image, M, (long_side, short_side)) + return warped, M, horizontal_text_bool + + def inference_word_box(self, net, gpu, word_image): + if net.training: + net.eval() + + with torch.no_grad(): + word_img_torch = torch.from_numpy( + imgproc.normalizeMeanVariance( + word_image, + mean=(0.485, 0.456, 0.406), + variance=(0.229, 0.224, 0.225), + ) + ) + word_img_torch = word_img_torch.permute(2, 0, 1).unsqueeze(0) + word_img_torch = word_img_torch.type(torch.FloatTensor).cuda(gpu) + with torch.cuda.amp.autocast(): + word_img_scores, _ = net(word_img_torch) + return word_img_scores + + def visualize_pseudo_label( + self, word_image, region_score, watershed_box, pseudo_char_bbox, img_name, + ): + word_img_h, word_img_w, _ = word_image.shape + word_img_cp1 = word_image.copy() + word_img_cp2 = word_image.copy() + _watershed_box = np.int32(watershed_box) + _pseudo_char_bbox = np.int32(pseudo_char_bbox) + + region_score_color = cv2.applyColorMap(np.uint8(region_score), cv2.COLORMAP_JET) + region_score_color = cv2.resize(region_score_color, (word_img_w, word_img_h)) + + for box in _watershed_box: + cv2.polylines( + np.uint8(word_img_cp1), + [np.reshape(box, (-1, 1, 2))], + True, + (255, 0, 0), + ) + + for box in _pseudo_char_bbox: + 
cv2.polylines( + np.uint8(word_img_cp2), [np.reshape(box, (-1, 1, 2))], True, (255, 0, 0) + ) + + # NOTE: Just for visualize, put gaussian map on char box + pseudo_gt_region_score = self.gaussian_builder.generate_region( + word_img_h, word_img_w, [_pseudo_char_bbox], [True] + ) + + pseudo_gt_region_score = cv2.applyColorMap( + (pseudo_gt_region_score * 255).astype("uint8"), cv2.COLORMAP_JET + ) + + overlay_img = cv2.addWeighted( + word_image[:, :, ::-1], 0.7, pseudo_gt_region_score, 0.3, 5 + ) + vis_result = np.hstack( + [ + word_image[:, :, ::-1], + region_score_color, + word_img_cp1[:, :, ::-1], + word_img_cp2[:, :, ::-1], + pseudo_gt_region_score, + overlay_img, + ] + ) + + if not os.path.exists(os.path.dirname(self.vis_test_dir)): + os.makedirs(os.path.dirname(self.vis_test_dir)) + cv2.imwrite( + os.path.join( + self.vis_test_dir, + "{}_{}".format( + img_name, f"pseudo_char_bbox_{random.randint(0,100)}.jpg" + ), + ), + vis_result, + ) + + def clip_into_boundary(self, box, bound): + if len(box) == 0: + return box + else: + box[:, :, 0] = np.clip(box[:, :, 0], 0, bound[1]) + box[:, :, 1] = np.clip(box[:, :, 1], 0, bound[0]) + return box + + def get_confidence(self, real_len, pseudo_len): + if pseudo_len == 0: + return 0.0 + return (real_len - min(real_len, abs(real_len - pseudo_len))) / real_len + + def split_word_equal_gap(self, word_img_w, word_img_h, word): + width = word_img_w + height = word_img_h + + width_per_char = width / len(word) + bboxes = [] + for j, char in enumerate(word): + if char == " ": + continue + left = j * width_per_char + right = (j + 1) * width_per_char + bbox = np.array([[left, 0], [right, 0], [right, height], [left, height]]) + bboxes.append(bbox) + + bboxes = np.array(bboxes, np.float32) + return bboxes + + def cal_angle(self, v1): + theta = np.arccos(min(1, v1[0] / (np.linalg.norm(v1) + 10e-8))) + return 2 * math.pi - theta if v1[1] < 0 else theta + + def clockwise_sort(self, points): + # returns 4x2 [[x1,y1],[x2,y2],[x3,y3],[x4,y4]] ndarray + v1, v2, v3, v4 = points + center = (v1 + v2 + v3 + v4) / 4 + theta = np.array( + [ + self.cal_angle(v1 - center), + self.cal_angle(v2 - center), + self.cal_angle(v3 - center), + self.cal_angle(v4 - center), + ] + ) + index = np.argsort(theta) + return np.array([v1, v2, v3, v4])[index, :] + + def build_char_box(self, net, gpu, image, word_bbox, word, img_name=""): + word_image, M, horizontal_text_bool = self.crop_image_by_bbox( + image, word_bbox, word + ) + real_word_without_space = word.replace("\s", "") + real_char_len = len(real_word_without_space) + + scale = 128.0 / word_image.shape[0] + + word_image = cv2.resize(word_image, None, fx=scale, fy=scale) + word_img_h, word_img_w, _ = word_image.shape + + scores = self.inference_word_box(net, gpu, word_image) + region_score = scores[0, :, :, 0].cpu().data.numpy() + region_score = np.uint8(np.clip(region_score, 0, 1) * 255) + + region_score_rgb = cv2.resize(region_score, (word_img_w, word_img_h)) + region_score_rgb = cv2.cvtColor(region_score_rgb, cv2.COLOR_GRAY2RGB) + + pseudo_char_bbox = exec_watershed_by_version( + self.watershed_param, region_score, word_image, self.pseudo_vis_opt + ) + + # Used for visualize only + watershed_box = pseudo_char_bbox.copy() + + pseudo_char_bbox = self.clip_into_boundary( + pseudo_char_bbox, region_score_rgb.shape + ) + + confidence = self.get_confidence(real_char_len, len(pseudo_char_bbox)) + + if confidence <= 0.5: + pseudo_char_bbox = self.split_word_equal_gap(word_img_w, word_img_h, word) + confidence = 0.5 + + if 
self.pseudo_vis_opt and self.flag: + self.visualize_pseudo_label( + word_image, region_score, watershed_box, pseudo_char_bbox, img_name, + ) + + if len(pseudo_char_bbox) != 0: + index = np.argsort(pseudo_char_bbox[:, 0, 0]) + pseudo_char_bbox = pseudo_char_bbox[index] + + pseudo_char_bbox /= scale + + M_inv = np.linalg.pinv(M) + for i in range(len(pseudo_char_bbox)): + pseudo_char_bbox[i] = cv2.perspectiveTransform( + pseudo_char_bbox[i][None, :, :], M_inv + ) + + pseudo_char_bbox = self.clip_into_boundary(pseudo_char_bbox, image.shape) + + return pseudo_char_bbox, confidence, horizontal_text_bool diff --git a/trainer/craft/data/pseudo_label/watershed.py b/trainer/craft/data/pseudo_label/watershed.py new file mode 100644 index 000000000..72cadc089 --- /dev/null +++ b/trainer/craft/data/pseudo_label/watershed.py @@ -0,0 +1,45 @@ +import cv2 +import numpy as np +from skimage.segmentation import watershed + + +def segment_region_score(watershed_param, region_score, word_image, pseudo_vis_opt): + region_score = np.float32(region_score) / 255 + fore = np.uint8(region_score > 0.75) + back = np.uint8(region_score < 0.05) + unknown = 1 - (fore + back) + ret, markers = cv2.connectedComponents(fore) + markers += 1 + markers[unknown == 1] = 0 + + labels = watershed(-region_score, markers) + boxes = [] + for label in range(2, ret + 1): + y, x = np.where(labels == label) + x_max = x.max() + y_max = y.max() + x_min = x.min() + y_min = y.min() + box = [[x_min, y_min], [x_max, y_min], [x_max, y_max], [x_min, y_max]] + box = np.array(box) + box *= 2 + boxes.append(box) + return np.array(boxes, dtype=np.float32) + + +def exec_watershed_by_version( + watershed_param, region_score, word_image, pseudo_vis_opt +): + + func_name_map_dict = { + "skimage": segment_region_score, + } + + try: + return func_name_map_dict[watershed_param.version]( + watershed_param, region_score, word_image, pseudo_vis_opt + ) + except: + print( + f"Watershed version {watershed_param.version} does not exist in func_name_map_dict." 
+ ) diff --git a/trainer/craft/data_root_dir/folder.txt b/trainer/craft/data_root_dir/folder.txt new file mode 100644 index 000000000..8e455ad77 --- /dev/null +++ b/trainer/craft/data_root_dir/folder.txt @@ -0,0 +1 @@ +place dataset folder here diff --git a/trainer/craft/eval.py b/trainer/craft/eval.py new file mode 100644 index 000000000..fceea4735 --- /dev/null +++ b/trainer/craft/eval.py @@ -0,0 +1,381 @@ +# -*- coding: utf-8 -*- + +import argparse +import os + +import cv2 +import numpy as np +import torch +import torch.backends.cudnn as cudnn +from tqdm import tqdm +import wandb + +from config.load_config import load_yaml, DotDict +from model.craft import CRAFT +from metrics.eval_det_iou import DetectionIoUEvaluator +from utils.inference_boxes import ( + test_net, + load_icdar2015_gt, + load_icdar2013_gt, + load_synthtext_gt, +) +from utils.util import copyStateDict + + + +def save_result_synth(img_file, img, pre_output, pre_box, gt_box=None, result_dir=""): + + img = np.array(img) + img_copy = img.copy() + region = pre_output[0] + affinity = pre_output[1] + + # make result file list + filename, file_ext = os.path.splitext(os.path.basename(img_file)) + + # draw bounding boxes for prediction, color green + for i, box in enumerate(pre_box): + poly = np.array(box).astype(np.int32).reshape((-1)) + poly = poly.reshape(-1, 2) + try: + cv2.polylines( + img, [poly.reshape((-1, 1, 2))], True, color=(0, 255, 0), thickness=2 + ) + except: + pass + + # draw bounding boxes for gt, color red + if gt_box is not None: + for j in range(len(gt_box)): + cv2.polylines( + img, + [np.array(gt_box[j]["points"]).astype(np.int32).reshape((-1, 1, 2))], + True, + color=(0, 0, 255), + thickness=2, + ) + + # draw overlay image + overlay_img = overlay(img_copy, region, affinity, pre_box) + + # Save result image + res_img_path = result_dir + "/res_" + filename + ".jpg" + cv2.imwrite(res_img_path, img) + + overlay_image_path = result_dir + "/res_" + filename + "_box.jpg" + cv2.imwrite(overlay_image_path, overlay_img) + + +def save_result_2015(img_file, img, pre_output, pre_box, gt_box, result_dir): + + img = np.array(img) + img_copy = img.copy() + region = pre_output[0] + affinity = pre_output[1] + + # make result file list + filename, file_ext = os.path.splitext(os.path.basename(img_file)) + + for i, box in enumerate(pre_box): + poly = np.array(box).astype(np.int32).reshape((-1)) + poly = poly.reshape(-1, 2) + try: + cv2.polylines( + img, [poly.reshape((-1, 1, 2))], True, color=(0, 255, 0), thickness=2 + ) + except: + pass + + if gt_box is not None: + for j in range(len(gt_box)): + _gt_box = np.array(gt_box[j]["points"]).reshape(-1, 2).astype(np.int32) + if gt_box[j]["text"] == "###": + cv2.polylines(img, [_gt_box], True, color=(128, 128, 128), thickness=2) + else: + cv2.polylines(img, [_gt_box], True, color=(0, 0, 255), thickness=2) + + # draw overlay image + overlay_img = overlay(img_copy, region, affinity, pre_box) + + # Save result image + res_img_path = result_dir + "/res_" + filename + ".jpg" + cv2.imwrite(res_img_path, img) + + overlay_image_path = result_dir + "/res_" + filename + "_box.jpg" + cv2.imwrite(overlay_image_path, overlay_img) + + +def save_result_2013(img_file, img, pre_output, pre_box, gt_box=None, result_dir=""): + + img = np.array(img) + img_copy = img.copy() + region = pre_output[0] + affinity = pre_output[1] + + # make result file list + filename, file_ext = os.path.splitext(os.path.basename(img_file)) + + # draw bounding boxes for prediction, color green + for i, box in enumerate(pre_box): 
+ poly = np.array(box).astype(np.int32).reshape((-1)) + poly = poly.reshape(-1, 2) + try: + cv2.polylines( + img, [poly.reshape((-1, 1, 2))], True, color=(0, 255, 0), thickness=2 + ) + except: + pass + + # draw bounding boxes for gt, color red + if gt_box is not None: + for j in range(len(gt_box)): + cv2.polylines( + img, + [np.array(gt_box[j]["points"]).reshape((-1, 1, 2))], + True, + color=(0, 0, 255), + thickness=2, + ) + + # draw overlay image + overlay_img = overlay(img_copy, region, affinity, pre_box) + + # Save result image + res_img_path = result_dir + "/res_" + filename + ".jpg" + cv2.imwrite(res_img_path, img) + + overlay_image_path = result_dir + "/res_" + filename + "_box.jpg" + cv2.imwrite(overlay_image_path, overlay_img) + + +def overlay(image, region, affinity, single_img_bbox): + + height, width, channel = image.shape + + region_score = cv2.resize(region, (width, height)) + affinity_score = cv2.resize(affinity, (width, height)) + + overlay_region = cv2.addWeighted(image.copy(), 0.4, region_score, 0.6, 5) + overlay_aff = cv2.addWeighted(image.copy(), 0.4, affinity_score, 0.6, 5) + + boxed_img = image.copy() + for word_box in single_img_bbox: + cv2.polylines( + boxed_img, + [word_box.astype(np.int32).reshape((-1, 1, 2))], + True, + color=(0, 255, 0), + thickness=3, + ) + + temp1 = np.hstack([image, boxed_img]) + temp2 = np.hstack([overlay_region, overlay_aff]) + temp3 = np.vstack([temp1, temp2]) + + return temp3 + + +def load_test_dataset_iou(test_folder_name, config): + + if test_folder_name == "synthtext": + total_bboxes_gt, total_img_path = load_synthtext_gt(config.test_data_dir) + + elif test_folder_name == "icdar2013": + total_bboxes_gt, total_img_path = load_icdar2013_gt( + dataFolder=config.test_data_dir + ) + + elif test_folder_name == "icdar2015": + total_bboxes_gt, total_img_path = load_icdar2015_gt( + dataFolder=config.test_data_dir + ) + + elif test_folder_name == "custom_data": + total_bboxes_gt, total_img_path = load_icdar2015_gt( + dataFolder=config.test_data_dir + ) + + else: + print("not found test dataset") + return None, None + + return total_bboxes_gt, total_img_path + + +def viz_test(img, pre_output, pre_box, gt_box, img_name, result_dir, test_folder_name): + + if test_folder_name == "synthtext": + save_result_synth( + img_name, img[:, :, ::-1].copy(), pre_output, pre_box, gt_box, result_dir + ) + elif test_folder_name == "icdar2013": + save_result_2013( + img_name, img[:, :, ::-1].copy(), pre_output, pre_box, gt_box, result_dir + ) + elif test_folder_name == "icdar2015": + save_result_2015( + img_name, img[:, :, ::-1].copy(), pre_output, pre_box, gt_box, result_dir + ) + elif test_folder_name == "custom_data": + save_result_2015( + img_name, img[:, :, ::-1].copy(), pre_output, pre_box, gt_box, result_dir + ) + else: + print("not found test dataset") + + +def main_eval(model_path, backbone, config, evaluator, result_dir, buffer, model, mode): + + if not os.path.exists(result_dir): + os.makedirs(result_dir, exist_ok=True) + + total_imgs_bboxes_gt, total_imgs_path = load_test_dataset_iou("custom_data", config) + + if mode == "weak_supervision" and torch.cuda.device_count() != 1: + gpu_count = torch.cuda.device_count() // 2 + else: + gpu_count = torch.cuda.device_count() + gpu_idx = torch.cuda.current_device() + torch.cuda.set_device(gpu_idx) + + # Only evaluation time + if model is None: + piece_imgs_path = total_imgs_path + + if backbone == "vgg": + model = CRAFT() + else: + raise Exception("Undefined architecture") + + print("Loading weights from 
checkpoint (" + model_path + ")") + net_param = torch.load(model_path, map_location=f"cuda:{gpu_idx}") + model.load_state_dict(copyStateDict(net_param["craft"])) + + if config.cuda: + model = model.cuda() + cudnn.benchmark = False + + # Distributed evaluation in the middle of training time + else: + if buffer is not None: + # check all buffer value is None for distributed evaluation + assert all( + v is None for v in buffer + ), "Buffer already filled with another value." + slice_idx = len(total_imgs_bboxes_gt) // gpu_count + + # last gpu + if gpu_idx == gpu_count - 1: + piece_imgs_path = total_imgs_path[gpu_idx * slice_idx :] + # piece_imgs_bboxes_gt = total_imgs_bboxes_gt[gpu_idx * slice_idx:] + else: + piece_imgs_path = total_imgs_path[ + gpu_idx * slice_idx : (gpu_idx + 1) * slice_idx + ] + # piece_imgs_bboxes_gt = total_imgs_bboxes_gt[gpu_idx * slice_idx: (gpu_idx + 1) * slice_idx] + + model.eval() + + # -----------------------------------------------------------------------------------------------------------------# + total_imgs_bboxes_pre = [] + for k, img_path in enumerate(tqdm(piece_imgs_path)): + image = cv2.imread(img_path) + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + single_img_bbox = [] + bboxes, polys, score_text = test_net( + model, + image, + config.text_threshold, + config.link_threshold, + config.low_text, + config.cuda, + config.poly, + config.canvas_size, + config.mag_ratio, + ) + + for box in bboxes: + box_info = {"points": box, "text": "###", "ignore": False} + single_img_bbox.append(box_info) + total_imgs_bboxes_pre.append(single_img_bbox) + # Distributed evaluation -------------------------------------------------------------------------------------# + if buffer is not None: + buffer[gpu_idx * slice_idx + k] = single_img_bbox + # print(sum([element is not None for element in buffer])) + # -------------------------------------------------------------------------------------------------------------# + + if config.vis_opt: + viz_test( + image, + score_text, + pre_box=polys, + gt_box=total_imgs_bboxes_gt[k], + img_name=img_path, + result_dir=result_dir, + test_folder_name="custom_data", + ) + + # When distributed evaluation mode, wait until buffer is full filled + if buffer is not None: + while None in buffer: + continue + assert all(v is not None for v in buffer), "Buffer not filled" + total_imgs_bboxes_pre = buffer + + results = [] + for i, (gt, pred) in enumerate(zip(total_imgs_bboxes_gt, total_imgs_bboxes_pre)): + perSampleMetrics_dict = evaluator.evaluate_image(gt, pred) + results.append(perSampleMetrics_dict) + + metrics = evaluator.combine_results(results) + print(metrics) + return metrics + +def cal_eval(config, data, res_dir_name, opt, mode): + evaluator = DetectionIoUEvaluator() + test_config = DotDict(config.test[data]) + res_dir = os.path.join(os.path.join("exp", args.yaml), "{}".format(res_dir_name)) + + if opt == "iou_eval": + main_eval( + config.test.trained_model, + config.train.backbone, + test_config, + evaluator, + res_dir, + buffer=None, + model=None, + mode=mode, + ) + else: + print("Undefined evaluation") + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser(description="CRAFT Text Detection Eval") + parser.add_argument( + "--yaml", + "--yaml_file_name", + default="custom_data_train", + type=str, + help="Load configuration", + ) + args = parser.parse_args() + + # load configure + config = load_yaml(args.yaml) + config = DotDict(config) + + if config["wandb_opt"]: + wandb.init(project="evaluation", entity="gmuffiness", 
name=args.yaml) + wandb.config.update(config) + + val_result_dir_name = args.yaml + cal_eval( + config, + "custom_data", + val_result_dir_name + "-ic15-iou", + opt="iou_eval", + mode=None, + ) diff --git a/trainer/craft/exp/folder.txt b/trainer/craft/exp/folder.txt new file mode 100644 index 000000000..b557bc215 --- /dev/null +++ b/trainer/craft/exp/folder.txt @@ -0,0 +1 @@ +trained model will be saved here diff --git a/trainer/craft/loss/mseloss.py b/trainer/craft/loss/mseloss.py new file mode 100644 index 000000000..dc24d5ab4 --- /dev/null +++ b/trainer/craft/loss/mseloss.py @@ -0,0 +1,172 @@ +import torch +import torch.nn as nn + + +class Loss(nn.Module): + def __init__(self): + super(Loss, self).__init__() + + def forward(self, gt_region, gt_affinity, pred_region, pred_affinity, conf_map): + loss = torch.mean( + ((gt_region - pred_region).pow(2) + (gt_affinity - pred_affinity).pow(2)) + * conf_map + ) + return loss + + +class Maploss_v2(nn.Module): + def __init__(self): + + super(Maploss_v2, self).__init__() + + def batch_image_loss(self, pred_score, label_score, neg_rto, n_min_neg): + + # positive_loss + positive_pixel = (label_score > 0.1).float() + positive_pixel_number = torch.sum(positive_pixel) + + positive_loss_region = pred_score * positive_pixel + + # negative_loss + negative_pixel = (label_score <= 0.1).float() + negative_pixel_number = torch.sum(negative_pixel) + negative_loss_region = pred_score * negative_pixel + + if positive_pixel_number != 0: + if negative_pixel_number < neg_rto * positive_pixel_number: + negative_loss = ( + torch.sum( + torch.topk( + negative_loss_region.view(-1), n_min_neg, sorted=False + )[0] + ) + / n_min_neg + ) + else: + negative_loss = torch.sum( + torch.topk( + negative_loss_region.view(-1), + int(neg_rto * positive_pixel_number), + sorted=False, + )[0] + ) / (positive_pixel_number * neg_rto) + positive_loss = torch.sum(positive_loss_region) / positive_pixel_number + else: + # only negative pixel + negative_loss = ( + torch.sum( + torch.topk(negative_loss_region.view(-1), n_min_neg, sorted=False)[ + 0 + ] + ) + / n_min_neg + ) + positive_loss = 0.0 + total_loss = positive_loss + negative_loss + return total_loss + + def forward( + self, + region_scores_label, + affinity_socres_label, + region_scores_pre, + affinity_scores_pre, + mask, + neg_rto, + n_min_neg, + ): + loss_fn = torch.nn.MSELoss(reduce=False, size_average=False) + assert ( + region_scores_label.size() == region_scores_pre.size() + and affinity_socres_label.size() == affinity_scores_pre.size() + ) + loss1 = loss_fn(region_scores_pre, region_scores_label) + loss2 = loss_fn(affinity_scores_pre, affinity_socres_label) + + loss_region = torch.mul(loss1, mask) + loss_affinity = torch.mul(loss2, mask) + + char_loss = self.batch_image_loss( + loss_region, region_scores_label, neg_rto, n_min_neg + ) + affi_loss = self.batch_image_loss( + loss_affinity, affinity_socres_label, neg_rto, n_min_neg + ) + return char_loss + affi_loss + + +class Maploss_v3(nn.Module): + def __init__(self): + + super(Maploss_v3, self).__init__() + + def single_image_loss(self, pre_loss, loss_label, neg_rto, n_min_neg): + + batch_size = pre_loss.shape[0] + + positive_loss, negative_loss = 0, 0 + for single_loss, single_label in zip(pre_loss, loss_label): + + # positive_loss + pos_pixel = (single_label >= 0.1).float() + n_pos_pixel = torch.sum(pos_pixel) + pos_loss_region = single_loss * pos_pixel + positive_loss += torch.sum(pos_loss_region) / max(n_pos_pixel, 1e-12) + + # negative_loss + neg_pixel = (single_label < 
0.1).float() + n_neg_pixel = torch.sum(neg_pixel) + neg_loss_region = single_loss * neg_pixel + + if n_pos_pixel != 0: + if n_neg_pixel < neg_rto * n_pos_pixel: + negative_loss += torch.sum(neg_loss_region) / n_neg_pixel + else: + n_hard_neg = max(n_min_neg, neg_rto * n_pos_pixel) + # n_hard_neg = neg_rto*n_pos_pixel + negative_loss += ( + torch.sum( + torch.topk(neg_loss_region.view(-1), int(n_hard_neg))[0] + ) + / n_hard_neg + ) + else: + # only negative pixel + negative_loss += ( + torch.sum(torch.topk(neg_loss_region.view(-1), n_min_neg)[0]) + / n_min_neg + ) + + total_loss = (positive_loss + negative_loss) / batch_size + + return total_loss + + def forward( + self, + region_scores_label, + affinity_scores_label, + region_scores_pre, + affinity_scores_pre, + mask, + neg_rto, + n_min_neg, + ): + loss_fn = torch.nn.MSELoss(reduce=False, size_average=False) + + assert ( + region_scores_label.size() == region_scores_pre.size() + and affinity_scores_label.size() == affinity_scores_pre.size() + ) + loss1 = loss_fn(region_scores_pre, region_scores_label) + loss2 = loss_fn(affinity_scores_pre, affinity_scores_label) + + loss_region = torch.mul(loss1, mask) + loss_affinity = torch.mul(loss2, mask) + char_loss = self.single_image_loss( + loss_region, region_scores_label, neg_rto, n_min_neg + ) + affi_loss = self.single_image_loss( + loss_affinity, affinity_scores_label, neg_rto, n_min_neg + ) + + return char_loss + affi_loss diff --git a/trainer/craft/metrics/eval_det_iou.py b/trainer/craft/metrics/eval_det_iou.py new file mode 100644 index 000000000..f89004fd8 --- /dev/null +++ b/trainer/craft/metrics/eval_det_iou.py @@ -0,0 +1,244 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +from collections import namedtuple +import numpy as np +from shapely.geometry import Polygon +""" +cite from: +PaddleOCR, github: https://github.com/PaddlePaddle/PaddleOCR +PaddleOCR reference from : +https://github.com/MhLiao/DB/blob/3c32b808d4412680310d3d28eeb6a2d5bf1566c5/concern/icdar2015_eval/detection/iou.py#L8 +""" + + +class DetectionIoUEvaluator(object): + def __init__(self, iou_constraint=0.5, area_precision_constraint=0.5): + self.iou_constraint = iou_constraint + self.area_precision_constraint = area_precision_constraint + + def evaluate_image(self, gt, pred): + def get_union(pD, pG): + return Polygon(pD).union(Polygon(pG)).area + + def get_intersection_over_union(pD, pG): + return get_intersection(pD, pG) / get_union(pD, pG) + + def get_intersection(pD, pG): + return Polygon(pD).intersection(Polygon(pG)).area + + def compute_ap(confList, matchList, numGtCare): + correct = 0 + AP = 0 + if len(confList) > 0: + confList = np.array(confList) + matchList = np.array(matchList) + sorted_ind = np.argsort(-confList) + confList = confList[sorted_ind] + matchList = matchList[sorted_ind] + for n in range(len(confList)): + match = matchList[n] + if match: + correct += 1 + AP += float(correct) / (n + 1) + + if numGtCare > 0: + AP /= numGtCare + + return AP + + perSampleMetrics = {} + + matchedSum = 0 + + Rectangle = namedtuple('Rectangle', 'xmin ymin xmax ymax') + + numGlobalCareGt = 0 + numGlobalCareDet = 0 + + arrGlobalConfidences = [] + arrGlobalMatches = [] + + recall = 0 + precision = 0 + hmean = 0 + + detMatched = 0 + + iouMat = np.empty([1, 1]) + + gtPols = [] + detPols = [] + + gtPolPoints = [] + detPolPoints = [] + + # Array of Ground Truth Polygons' keys marked as don't Care + gtDontCarePolsNum = [] + # Array of Detected Polygons' matched with a don't Care GT + detDontCarePolsNum = [] + + pairs = [] + 
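# pairs stores matched (gt, det) index pairs; detMatchedNums the indices of matched detections.
+        # A detection is matched when its IoU with an unmatched, cared-for GT polygon exceeds iou_constraint.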
detMatchedNums = [] + + arrSampleConfidences = [] + arrSampleMatch = [] + + evaluationLog = "" + + # print(len(gt)) + + for n in range(len(gt)): + points = gt[n]['points'] + # transcription = gt[n]['text'] + dontCare = gt[n]['ignore'] + # points = Polygon(points) + # points = points.buffer(0) + try: + if not Polygon(points).is_valid or not Polygon(points).is_simple: + continue + except: + import ipdb; + ipdb.set_trace() + + #import ipdb;ipdb.set_trace() + gtPol = points + gtPols.append(gtPol) + gtPolPoints.append(points) + if dontCare: + gtDontCarePolsNum.append(len(gtPols) - 1) + + evaluationLog += "GT polygons: " + str(len(gtPols)) + ( + " (" + str(len(gtDontCarePolsNum)) + " don't care)\n" + if len(gtDontCarePolsNum) > 0 else "\n") + + for n in range(len(pred)): + points = pred[n]['points'] + # points = Polygon(points) + # points = points.buffer(0) + if not Polygon(points).is_valid or not Polygon(points).is_simple: + continue + + detPol = points + detPols.append(detPol) + detPolPoints.append(points) + if len(gtDontCarePolsNum) > 0: + for dontCarePol in gtDontCarePolsNum: + dontCarePol = gtPols[dontCarePol] + intersected_area = get_intersection(dontCarePol, detPol) + pdDimensions = Polygon(detPol).area + precision = 0 if pdDimensions == 0 else intersected_area / pdDimensions + if (precision > self.area_precision_constraint): + detDontCarePolsNum.append(len(detPols) - 1) + break + + evaluationLog += "DET polygons: " + str(len(detPols)) + ( + " (" + str(len(detDontCarePolsNum)) + " don't care)\n" + if len(detDontCarePolsNum) > 0 else "\n") + + if len(gtPols) > 0 and len(detPols) > 0: + # Calculate IoU and precision matrixs + outputShape = [len(gtPols), len(detPols)] + iouMat = np.empty(outputShape) + gtRectMat = np.zeros(len(gtPols), np.int8) + detRectMat = np.zeros(len(detPols), np.int8) + for gtNum in range(len(gtPols)): + for detNum in range(len(detPols)): + pG = gtPols[gtNum] + pD = detPols[detNum] + iouMat[gtNum, detNum] = get_intersection_over_union(pD, pG) + + for gtNum in range(len(gtPols)): + for detNum in range(len(detPols)): + if gtRectMat[gtNum] == 0 and detRectMat[ + detNum] == 0 and gtNum not in gtDontCarePolsNum and detNum not in detDontCarePolsNum: + if iouMat[gtNum, detNum] > self.iou_constraint: + gtRectMat[gtNum] = 1 + detRectMat[detNum] = 1 + detMatched += 1 + pairs.append({'gt': gtNum, 'det': detNum}) + detMatchedNums.append(detNum) + evaluationLog += "Match GT #" + \ + str(gtNum) + " with Det #" + str(detNum) + "\n" + + numGtCare = (len(gtPols) - len(gtDontCarePolsNum)) + numDetCare = (len(detPols) - len(detDontCarePolsNum)) + if numGtCare == 0: + recall = float(1) + precision = float(0) if numDetCare > 0 else float(1) + else: + recall = float(detMatched) / numGtCare + precision = 0 if numDetCare == 0 else float(detMatched) / numDetCare + + hmean = 0 if (precision + recall) == 0 else 2.0 * \ + precision * recall / (precision + recall) + + matchedSum += detMatched + numGlobalCareGt += numGtCare + numGlobalCareDet += numDetCare + + perSampleMetrics = { + 'precision': precision, + 'recall': recall, + 'hmean': hmean, + 'pairs': pairs, + 'iouMat': [] if len(detPols) > 100 else iouMat.tolist(), + 'gtPolPoints': gtPolPoints, + 'detPolPoints': detPolPoints, + 'gtCare': numGtCare, + 'detCare': numDetCare, + 'gtDontCare': gtDontCarePolsNum, + 'detDontCare': detDontCarePolsNum, + 'detMatched': detMatched, + 'evaluationLog': evaluationLog + } + + return perSampleMetrics + + def combine_results(self, results): + numGlobalCareGt = 0 + numGlobalCareDet = 0 + matchedSum = 0 + for 
result in results: + numGlobalCareGt += result['gtCare'] + numGlobalCareDet += result['detCare'] + matchedSum += result['detMatched'] + + methodRecall = 0 if numGlobalCareGt == 0 else float( + matchedSum) / numGlobalCareGt + methodPrecision = 0 if numGlobalCareDet == 0 else float( + matchedSum) / numGlobalCareDet + methodHmean = 0 if methodRecall + methodPrecision == 0 else 2 * \ + methodRecall * methodPrecision / ( + methodRecall + methodPrecision) + # print(methodRecall, methodPrecision, methodHmean) + # sys.exit(-1) + methodMetrics = { + 'precision': methodPrecision, + 'recall': methodRecall, + 'hmean': methodHmean + } + + return methodMetrics + + +if __name__ == '__main__': + evaluator = DetectionIoUEvaluator() + gts = [[{ + 'points': [(0, 0), (1, 0), (1, 1), (0, 1)], + 'text': 1234, + 'ignore': False, + }, { + 'points': [(2, 2), (3, 2), (3, 3), (2, 3)], + 'text': 5678, + 'ignore': False, + }]] + preds = [[{ + 'points': [(0.1, 0.1), (1, 0), (1, 1), (0, 1)], + 'text': 123, + 'ignore': False, + }]] + results = [] + for gt, pred in zip(gts, preds): + results.append(evaluator.evaluate_image(gt, pred)) + metrics = evaluator.combine_results(results) + print(metrics) diff --git a/trainer/craft/model/craft.py b/trainer/craft/model/craft.py new file mode 100644 index 000000000..f0da362a5 --- /dev/null +++ b/trainer/craft/model/craft.py @@ -0,0 +1,112 @@ +""" +Copyright (c) 2019-present NAVER Corp. +MIT License +""" + +# -*- coding: utf-8 -*- +import torch +import torch.nn as nn +import torch.nn.functional as F + +from model.vgg16_bn import vgg16_bn, init_weights + +class double_conv(nn.Module): + def __init__(self, in_ch, mid_ch, out_ch): + super(double_conv, self).__init__() + self.conv = nn.Sequential( + nn.Conv2d(in_ch + mid_ch, mid_ch, kernel_size=1), + nn.BatchNorm2d(mid_ch), + nn.ReLU(inplace=True), + nn.Conv2d(mid_ch, out_ch, kernel_size=3, padding=1), + nn.BatchNorm2d(out_ch), + nn.ReLU(inplace=True) + ) + + def forward(self, x): + x = self.conv(x) + return x + + +class CRAFT(nn.Module): + def __init__(self, pretrained=True, freeze=False, amp=False): + super(CRAFT, self).__init__() + + self.amp = amp + + """ Base network """ + self.basenet = vgg16_bn(pretrained, freeze) + + """ U network """ + self.upconv1 = double_conv(1024, 512, 256) + self.upconv2 = double_conv(512, 256, 128) + self.upconv3 = double_conv(256, 128, 64) + self.upconv4 = double_conv(128, 64, 32) + + num_class = 2 + self.conv_cls = nn.Sequential( + nn.Conv2d(32, 32, kernel_size=3, padding=1), nn.ReLU(inplace=True), + nn.Conv2d(32, 32, kernel_size=3, padding=1), nn.ReLU(inplace=True), + nn.Conv2d(32, 16, kernel_size=3, padding=1), nn.ReLU(inplace=True), + nn.Conv2d(16, 16, kernel_size=1), nn.ReLU(inplace=True), + nn.Conv2d(16, num_class, kernel_size=1), + ) + + init_weights(self.upconv1.modules()) + init_weights(self.upconv2.modules()) + init_weights(self.upconv3.modules()) + init_weights(self.upconv4.modules()) + init_weights(self.conv_cls.modules()) + + def forward(self, x): + """ Base network """ + if self.amp: + with torch.cuda.amp.autocast(): + sources = self.basenet(x) + + """ U network """ + y = torch.cat([sources[0], sources[1]], dim=1) + y = self.upconv1(y) + + y = F.interpolate(y, size=sources[2].size()[2:], mode='bilinear', align_corners=False) + y = torch.cat([y, sources[2]], dim=1) + y = self.upconv2(y) + + y = F.interpolate(y, size=sources[3].size()[2:], mode='bilinear', align_corners=False) + y = torch.cat([y, sources[3]], dim=1) + y = self.upconv3(y) + + y = F.interpolate(y, size=sources[4].size()[2:], 
mode='bilinear', align_corners=False) + y = torch.cat([y, sources[4]], dim=1) + feature = self.upconv4(y) + + y = self.conv_cls(feature) + + return y.permute(0,2,3,1), feature + else: + + sources = self.basenet(x) + + """ U network """ + y = torch.cat([sources[0], sources[1]], dim=1) + y = self.upconv1(y) + + y = F.interpolate(y, size=sources[2].size()[2:], mode='bilinear', align_corners=False) + y = torch.cat([y, sources[2]], dim=1) + y = self.upconv2(y) + + y = F.interpolate(y, size=sources[3].size()[2:], mode='bilinear', align_corners=False) + y = torch.cat([y, sources[3]], dim=1) + y = self.upconv3(y) + + y = F.interpolate(y, size=sources[4].size()[2:], mode='bilinear', align_corners=False) + y = torch.cat([y, sources[4]], dim=1) + feature = self.upconv4(y) + + y = self.conv_cls(feature) + + return y.permute(0, 2, 3, 1), feature + +if __name__ == '__main__': + model = CRAFT(pretrained=True).cuda() + output, _ = model(torch.randn(1, 3, 768, 768).cuda()) + print(output.shape) \ No newline at end of file diff --git a/trainer/craft/model/vgg16_bn.py b/trainer/craft/model/vgg16_bn.py new file mode 100644 index 000000000..f3f21a79e --- /dev/null +++ b/trainer/craft/model/vgg16_bn.py @@ -0,0 +1,73 @@ +from collections import namedtuple + +import torch +import torch.nn as nn +import torch.nn.init as init +from torchvision import models +from torchvision.models.vgg import model_urls + +def init_weights(modules): + for m in modules: + if isinstance(m, nn.Conv2d): + init.xavier_uniform_(m.weight.data) + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + elif isinstance(m, nn.Linear): + m.weight.data.normal_(0, 0.01) + m.bias.data.zero_() + +class vgg16_bn(torch.nn.Module): + def __init__(self, pretrained=True, freeze=True): + super(vgg16_bn, self).__init__() + model_urls['vgg16_bn'] = model_urls['vgg16_bn'].replace('https://', 'http://') + vgg_pretrained_features = models.vgg16_bn(pretrained=pretrained).features + self.slice1 = torch.nn.Sequential() + self.slice2 = torch.nn.Sequential() + self.slice3 = torch.nn.Sequential() + self.slice4 = torch.nn.Sequential() + self.slice5 = torch.nn.Sequential() + for x in range(12): # conv2_2 + self.slice1.add_module(str(x), vgg_pretrained_features[x]) + for x in range(12, 19): # conv3_3 + self.slice2.add_module(str(x), vgg_pretrained_features[x]) + for x in range(19, 29): # conv4_3 + self.slice3.add_module(str(x), vgg_pretrained_features[x]) + for x in range(29, 39): # conv5_3 + self.slice4.add_module(str(x), vgg_pretrained_features[x]) + + # fc6, fc7 without atrous conv + self.slice5 = torch.nn.Sequential( + nn.MaxPool2d(kernel_size=3, stride=1, padding=1), + nn.Conv2d(512, 1024, kernel_size=3, padding=6, dilation=6), + nn.Conv2d(1024, 1024, kernel_size=1) + ) + + if not pretrained: + init_weights(self.slice1.modules()) + init_weights(self.slice2.modules()) + init_weights(self.slice3.modules()) + init_weights(self.slice4.modules()) + + init_weights(self.slice5.modules()) # no pretrained model for fc6 and fc7 + + if freeze: + for param in self.slice1.parameters(): # only first conv + param.requires_grad= False + + def forward(self, X): + h = self.slice1(X) + h_relu2_2 = h + h = self.slice2(h) + h_relu3_2 = h + h = self.slice3(h) + h_relu4_3 = h + h = self.slice4(h) + h_relu5_3 = h + h = self.slice5(h) + h_fc7 = h + vgg_outputs = namedtuple("VggOutputs", ['fc7', 'relu5_3', 'relu4_3', 'relu3_2', 'relu2_2']) + out = vgg_outputs(h_fc7, h_relu5_3, h_relu4_3, h_relu3_2, h_relu2_2) + 
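# feature maps are returned deepest-first (fc7, relu5_3, relu4_3, relu3_2, relu2_2);
+        # CRAFT's decoder concatenates neighbouring pairs of them as U-Net style skip connections.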
return out diff --git a/trainer/craft/requirements.txt b/trainer/craft/requirements.txt new file mode 100644 index 000000000..af5f524cb --- /dev/null +++ b/trainer/craft/requirements.txt @@ -0,0 +1,10 @@ +conda==4.10.3 +opencv-python==4.5.3.56 +Pillow==8.2.0 +Polygon3==3.0.9.1 +PyYAML==5.4.1 +scikit-image==0.17.2 +Shapely==1.8.0 +torch==1.9.0 +torchvision==0.10.0 +wandb==0.12.9 diff --git a/trainer/craft/scripts/run_cde.sh b/trainer/craft/scripts/run_cde.sh new file mode 100644 index 000000000..5fd232184 --- /dev/null +++ b/trainer/craft/scripts/run_cde.sh @@ -0,0 +1,7 @@ +# sed -i -e 's/\r$//' scripts/run_cde.sh +EXP_NAME=custom_data_release_test_3 +yaml_path="config/$EXP_NAME.yaml" +cp config/custom_data_train.yaml $yaml_path +#CUDA_VISIBLE_DEVICES=0,1 python3 train_distributed.py --yaml=$EXP_NAME --port=2468 +CUDA_VISIBLE_DEVICES=0 python3 train.py --yaml=$EXP_NAME --port=2468 +rm "config/$EXP_NAME.yaml" \ No newline at end of file diff --git a/trainer/craft/train.py b/trainer/craft/train.py new file mode 100644 index 000000000..de441cf2e --- /dev/null +++ b/trainer/craft/train.py @@ -0,0 +1,479 @@ +# -*- coding: utf-8 -*- +import argparse +import os +import shutil +import time +import multiprocessing as mp +import yaml + +import numpy as np +import torch +import torch.nn as nn +import torch.optim as optim +import wandb + +from config.load_config import load_yaml, DotDict +from data.dataset import SynthTextDataSet, CustomDataset +from loss.mseloss import Maploss_v2, Maploss_v3 +from model.craft import CRAFT +from eval import main_eval +from metrics.eval_det_iou import DetectionIoUEvaluator +from utils.util import copyStateDict, save_parser + + +class Trainer(object): + def __init__(self, config, gpu, mode): + + self.config = config + self.gpu = gpu + self.mode = mode + self.net_param = self.get_load_param(gpu) + + def get_synth_loader(self): + + dataset = SynthTextDataSet( + output_size=self.config.train.data.output_size, + data_dir=self.config.train.synth_data_dir, + saved_gt_dir=None, + mean=self.config.train.data.mean, + variance=self.config.train.data.variance, + gauss_init_size=self.config.train.data.gauss_init_size, + gauss_sigma=self.config.train.data.gauss_sigma, + enlarge_region=self.config.train.data.enlarge_region, + enlarge_affinity=self.config.train.data.enlarge_affinity, + aug=self.config.train.data.syn_aug, + vis_test_dir=self.config.vis_test_dir, + vis_opt=self.config.train.data.vis_opt, + sample=self.config.train.data.syn_sample, + ) + + syn_loader = torch.utils.data.DataLoader( + dataset, + batch_size=self.config.train.batch_size // self.config.train.synth_ratio, + shuffle=False, + num_workers=self.config.train.num_workers, + drop_last=True, + pin_memory=True, + ) + return syn_loader + + def get_custom_dataset(self): + + custom_dataset = CustomDataset( + output_size=self.config.train.data.output_size, + data_dir=self.config.data_root_dir, + saved_gt_dir=None, + mean=self.config.train.data.mean, + variance=self.config.train.data.variance, + gauss_init_size=self.config.train.data.gauss_init_size, + gauss_sigma=self.config.train.data.gauss_sigma, + enlarge_region=self.config.train.data.enlarge_region, + enlarge_affinity=self.config.train.data.enlarge_affinity, + watershed_param=self.config.train.data.watershed, + aug=self.config.train.data.custom_aug, + vis_test_dir=self.config.vis_test_dir, + sample=self.config.train.data.custom_sample, + vis_opt=self.config.train.data.vis_opt, + pseudo_vis_opt=self.config.train.data.pseudo_vis_opt, + 
do_not_care_label=self.config.train.data.do_not_care_label, + ) + + return custom_dataset + + def get_load_param(self, gpu): + + if self.config.train.ckpt_path is not None: + map_location = "cuda:%d" % gpu + param = torch.load(self.config.train.ckpt_path, map_location=map_location) + else: + param = None + + return param + + def adjust_learning_rate(self, optimizer, gamma, step, lr): + lr = lr * (gamma ** step) + for param_group in optimizer.param_groups: + param_group["lr"] = lr + return param_group["lr"] + + def get_loss(self): + if self.config.train.loss == 2: + criterion = Maploss_v2() + elif self.config.train.loss == 3: + criterion = Maploss_v3() + else: + raise Exception("Undefined loss") + return criterion + + def iou_eval(self, dataset, train_step, buffer, model): + test_config = DotDict(self.config.test[dataset]) + + val_result_dir = os.path.join( + self.config.results_dir, "{}/{}".format(dataset + "_iou", str(train_step)) + ) + + evaluator = DetectionIoUEvaluator() + + metrics = main_eval( + None, + self.config.train.backbone, + test_config, + evaluator, + val_result_dir, + buffer, + model, + self.mode, + ) + if self.gpu == 0 and self.config.wandb_opt: + wandb.log( + { + "{} iou Recall".format(dataset): np.round(metrics["recall"], 3), + "{} iou Precision".format(dataset): np.round( + metrics["precision"], 3 + ), + "{} iou F1-score".format(dataset): np.round(metrics["hmean"], 3), + } + ) + + def train(self, buffer_dict): + + torch.cuda.set_device(self.gpu) + + # MODEL -------------------------------------------------------------------------------------------------------# + # SUPERVISION model + if self.config.mode == "weak_supervision": + if self.config.train.backbone == "vgg": + supervision_model = CRAFT(pretrained=False, amp=self.config.train.amp) + else: + raise Exception("Undefined architecture") + + supervision_device = self.gpu + if self.config.train.ckpt_path is not None: + supervision_param = self.get_load_param(supervision_device) + supervision_model.load_state_dict( + copyStateDict(supervision_param["craft"]) + ) + supervision_model = supervision_model.to(f"cuda:{supervision_device}") + print(f"Supervision model loading on : gpu {supervision_device}") + else: + supervision_model, supervision_device = None, None + + # TRAIN model + if self.config.train.backbone == "vgg": + craft = CRAFT(pretrained=False, amp=self.config.train.amp) + else: + raise Exception("Undefined architecture") + + if self.config.train.ckpt_path is not None: + craft.load_state_dict(copyStateDict(self.net_param["craft"])) + + craft = craft.cuda() + craft = torch.nn.DataParallel(craft) + + torch.backends.cudnn.benchmark = True + + # DATASET -----------------------------------------------------------------------------------------------------# + + if self.config.train.use_synthtext: + trn_syn_loader = self.get_synth_loader() + batch_syn = iter(trn_syn_loader) + + if self.config.train.real_dataset == "custom": + trn_real_dataset = self.get_custom_dataset() + else: + raise Exception("Undefined dataset") + + if self.config.mode == "weak_supervision": + trn_real_dataset.update_model(supervision_model) + trn_real_dataset.update_device(supervision_device) + + trn_real_loader = torch.utils.data.DataLoader( + trn_real_dataset, + batch_size=self.config.train.batch_size, + shuffle=False, + num_workers=self.config.train.num_workers, + drop_last=False, + pin_memory=True, + ) + + # OPTIMIZER ---------------------------------------------------------------------------------------------------# + optimizer = optim.Adam( + 
craft.parameters(), + lr=self.config.train.lr, + weight_decay=self.config.train.weight_decay, + ) + + if self.config.train.ckpt_path is not None and self.config.train.st_iter != 0: + optimizer.load_state_dict(copyStateDict(self.net_param["optimizer"])) + self.config.train.st_iter = self.net_param["optimizer"]["state"][0]["step"] + self.config.train.lr = self.net_param["optimizer"]["param_groups"][0]["lr"] + + # LOSS --------------------------------------------------------------------------------------------------------# + # mixed precision + if self.config.train.amp: + scaler = torch.cuda.amp.GradScaler() + + if ( + self.config.train.ckpt_path is not None + and self.config.train.st_iter != 0 + ): + scaler.load_state_dict(copyStateDict(self.net_param["scaler"])) + else: + scaler = None + + criterion = self.get_loss() + + # TRAIN -------------------------------------------------------------------------------------------------------# + train_step = self.config.train.st_iter + whole_training_step = self.config.train.end_iter + update_lr_rate_step = 0 + training_lr = self.config.train.lr + loss_value = 0 + batch_time = 0 + start_time = time.time() + + print( + "================================ Train start ================================" + ) + while train_step < whole_training_step: + for ( + index, + ( + images, + region_scores, + affinity_scores, + confidence_masks, + ), + ) in enumerate(trn_real_loader): + craft.train() + if train_step > 0 and train_step % self.config.train.lr_decay == 0: + update_lr_rate_step += 1 + training_lr = self.adjust_learning_rate( + optimizer, + self.config.train.gamma, + update_lr_rate_step, + self.config.train.lr, + ) + + images = images.cuda(non_blocking=True) + region_scores = region_scores.cuda(non_blocking=True) + affinity_scores = affinity_scores.cuda(non_blocking=True) + confidence_masks = confidence_masks.cuda(non_blocking=True) + + if self.config.train.use_synthtext: + # Synth image load + syn_image, syn_region_label, syn_affi_label, syn_confidence_mask = next( + batch_syn + ) + syn_image = syn_image.cuda(non_blocking=True) + syn_region_label = syn_region_label.cuda(non_blocking=True) + syn_affi_label = syn_affi_label.cuda(non_blocking=True) + syn_confidence_mask = syn_confidence_mask.cuda(non_blocking=True) + + # concat syn & custom image + images = torch.cat((syn_image, images), 0) + region_image_label = torch.cat( + (syn_region_label, region_scores), 0 + ) + affinity_image_label = torch.cat((syn_affi_label, affinity_scores), 0) + confidence_mask_label = torch.cat( + (syn_confidence_mask, confidence_masks), 0 + ) + else: + region_image_label = region_scores + affinity_image_label = affinity_scores + confidence_mask_label = confidence_masks + + if self.config.train.amp: + with torch.cuda.amp.autocast(): + + output, _ = craft(images) + out1 = output[:, :, :, 0] + out2 = output[:, :, :, 1] + + loss = criterion( + region_image_label, + affinity_image_label, + out1, + out2, + confidence_mask_label, + self.config.train.neg_rto, + self.config.train.n_min_neg, + ) + + optimizer.zero_grad() + scaler.scale(loss).backward() + scaler.step(optimizer) + scaler.update() + + else: + output, _ = craft(images) + out1 = output[:, :, :, 0] + out2 = output[:, :, :, 1] + loss = criterion( + region_image_label, + affinity_image_label, + out1, + out2, + confidence_mask_label, + self.config.train.neg_rto, + ) + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + end_time = time.time() + loss_value += loss.item() + batch_time += end_time - start_time + + if 
train_step > 0 and train_step % 5 == 0: + mean_loss = loss_value / 5 + loss_value = 0 + avg_batch_time = batch_time / 5 + batch_time = 0 + + print( + "{}, training_step: {}|{}, learning rate: {:.8f}, " + "training_loss: {:.5f}, avg_batch_time: {:.5f}".format( + time.strftime( + "%Y-%m-%d:%H:%M:%S", time.localtime(time.time()) + ), + train_step, + whole_training_step, + training_lr, + mean_loss, + avg_batch_time, + ) + ) + + if self.config.wandb_opt: + wandb.log({"train_step": train_step, "mean_loss": mean_loss}) + + if ( + train_step % self.config.train.eval_interval == 0 + and train_step != 0 + ): + + craft.eval() + + print("Saving state, index:", train_step) + save_param_dic = { + "iter": train_step, + "craft": craft.state_dict(), + "optimizer": optimizer.state_dict(), + } + save_param_path = ( + self.config.results_dir + + "/CRAFT_clr_" + + repr(train_step) + + ".pth" + ) + + if self.config.train.amp: + save_param_dic["scaler"] = scaler.state_dict() + save_param_path = ( + self.config.results_dir + + "/CRAFT_clr_amp_" + + repr(train_step) + + ".pth" + ) + + torch.save(save_param_dic, save_param_path) + + # validation + self.iou_eval( + "custom_data", + train_step, + buffer_dict["custom_data"], + craft, + ) + + train_step += 1 + if train_step >= whole_training_step: + break + + if self.config.mode == "weak_supervision": + state_dict = craft.module.state_dict() + supervision_model.load_state_dict(state_dict) + trn_real_dataset.update_model(supervision_model) + + # save last model + save_param_dic = { + "iter": train_step, + "craft": craft.state_dict(), + "optimizer": optimizer.state_dict(), + } + save_param_path = ( + self.config.results_dir + "/CRAFT_clr_" + repr(train_step) + ".pth" + ) + + if self.config.train.amp: + save_param_dic["scaler"] = scaler.state_dict() + save_param_path = ( + self.config.results_dir + + "/CRAFT_clr_amp_" + + repr(train_step) + + ".pth" + ) + torch.save(save_param_dic, save_param_path) + + +def main(): + parser = argparse.ArgumentParser(description="CRAFT custom data train") + parser.add_argument( + "--yaml", + "--yaml_file_name", + default="custom_data_train", + type=str, + help="Load configuration", + ) + parser.add_argument( + "--port", "--use ddp port", default="2346", type=str, help="Port number" + ) + + args = parser.parse_args() + + # load configure + exp_name = args.yaml + config = load_yaml(args.yaml) + + print("-" * 20 + " Options " + "-" * 20) + print(yaml.dump(config)) + print("-" * 40) + + # Make result_dir + res_dir = os.path.join(config["results_dir"], args.yaml) + config["results_dir"] = res_dir + if not os.path.exists(res_dir): + os.makedirs(res_dir) + + # Duplicate yaml file to result_dir + shutil.copy( + "config/" + args.yaml + ".yaml", os.path.join(res_dir, args.yaml) + ".yaml" + ) + + if config["mode"] == "weak_supervision": + mode = "weak_supervision" + else: + mode = None + + + # Apply config to wandb + if config["wandb_opt"]: + wandb.init(project="craft-stage2", entity="user_name", name=exp_name) + wandb.config.update(config) + + config = DotDict(config) + + # Start train + buffer_dict = {"custom_data":None} + trainer = Trainer(config, 0, mode) + trainer.train(buffer_dict) + + if config["wandb_opt"]: + wandb.finish() + + +if __name__ == "__main__": + main() diff --git a/trainer/craft/trainSynth.py b/trainer/craft/trainSynth.py new file mode 100644 index 000000000..4d1d0dc72 --- /dev/null +++ b/trainer/craft/trainSynth.py @@ -0,0 +1,408 @@ +# -*- coding: utf-8 -*- +import argparse +import os +import shutil +import time +import yaml 
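+# NOTE: trainSynth.py performs stage-1 pretraining of CRAFT on SynthText only; it spawns one
+# DistributedDataParallel worker per visible GPU and runs IoU validation on ICDAR2013.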
+import multiprocessing as mp + +import numpy as np +import torch +import torch.nn as nn +import torch.optim as optim +import wandb + +from config.load_config import load_yaml, DotDict +from data.dataset import SynthTextDataSet +from loss.mseloss import Maploss_v2, Maploss_v3 +from model.craft import CRAFT +from metrics.eval_det_iou import DetectionIoUEvaluator +from eval import main_eval +from utils.util import copyStateDict, save_parser + + +class Trainer(object): + def __init__(self, config, gpu): + + self.config = config + self.gpu = gpu + self.mode = None + self.trn_loader, self.trn_sampler = self.get_trn_loader() + self.net_param = self.get_load_param(gpu) + + def get_trn_loader(self): + + dataset = SynthTextDataSet( + output_size=self.config.train.data.output_size, + data_dir=self.config.data_dir.synthtext, + saved_gt_dir=None, + mean=self.config.train.data.mean, + variance=self.config.train.data.variance, + gauss_init_size=self.config.train.data.gauss_init_size, + gauss_sigma=self.config.train.data.gauss_sigma, + enlarge_region=self.config.train.data.enlarge_region, + enlarge_affinity=self.config.train.data.enlarge_affinity, + aug=self.config.train.data.syn_aug, + vis_test_dir=self.config.vis_test_dir, + vis_opt=self.config.train.data.vis_opt, + sample=self.config.train.data.syn_sample, + ) + + trn_sampler = torch.utils.data.distributed.DistributedSampler(dataset) + + trn_loader = torch.utils.data.DataLoader( + dataset, + batch_size=self.config.train.batch_size, + shuffle=False, + num_workers=self.config.train.num_workers, + sampler=trn_sampler, + drop_last=True, + pin_memory=True, + ) + return trn_loader, trn_sampler + + def get_load_param(self, gpu): + if self.config.train.ckpt_path is not None: + map_location = {"cuda:%d" % 0: "cuda:%d" % gpu} + param = torch.load(self.config.train.ckpt_path, map_location=map_location) + else: + param = None + return param + + def adjust_learning_rate(self, optimizer, gamma, step, lr): + lr = lr * (gamma ** step) + for param_group in optimizer.param_groups: + param_group["lr"] = lr + return param_group["lr"] + + def get_loss(self): + if self.config.train.loss == 2: + criterion = Maploss_v2() + elif self.config.train.loss == 3: + criterion = Maploss_v3() + else: + raise Exception("Undefined loss") + return criterion + + def iou_eval(self, dataset, train_step, save_param_path, buffer, model): + test_config = DotDict(self.config.test[dataset]) + + val_result_dir = os.path.join( + self.config.results_dir, "{}/{}".format(dataset + "_iou", str(train_step)) + ) + + evaluator = DetectionIoUEvaluator() + + metrics = main_eval( + save_param_path, + self.config.train.backbone, + test_config, + evaluator, + val_result_dir, + buffer, + model, + self.mode, + ) + if self.gpu == 0 and self.config.wandb_opt: + wandb.log( + { + "{} IoU Recall".format(dataset): np.round(metrics["recall"], 3), + "{} IoU Precision".format(dataset): np.round( + metrics["precision"], 3 + ), + "{} IoU F1-score".format(dataset): np.round(metrics["hmean"], 3), + } + ) + + def train(self, buffer_dict): + torch.cuda.set_device(self.gpu) + + # DATASET -----------------------------------------------------------------------------------------------------# + trn_loader = self.trn_loader + + # MODEL -------------------------------------------------------------------------------------------------------# + if self.config.train.backbone == "vgg": + craft = CRAFT(pretrained=True, amp=self.config.train.amp) + else: + raise Exception("Undefined architecture") + + if self.config.train.ckpt_path is not 
None: + craft.load_state_dict(copyStateDict(self.net_param["craft"])) + craft = nn.SyncBatchNorm.convert_sync_batchnorm(craft) + craft = craft.cuda() + craft = torch.nn.parallel.DistributedDataParallel(craft, device_ids=[self.gpu]) + + torch.backends.cudnn.benchmark = True + + # OPTIMIZER----------------------------------------------------------------------------------------------------# + + optimizer = optim.Adam( + craft.parameters(), + lr=self.config.train.lr, + weight_decay=self.config.train.weight_decay, + ) + + if self.config.train.ckpt_path is not None and self.config.train.st_iter != 0: + optimizer.load_state_dict(copyStateDict(self.net_param["optimizer"])) + self.config.train.st_iter = self.net_param["optimizer"]["state"][0]["step"] + self.config.train.lr = self.net_param["optimizer"]["param_groups"][0]["lr"] + + # LOSS --------------------------------------------------------------------------------------------------------# + # mixed precision + if self.config.train.amp: + scaler = torch.cuda.amp.GradScaler() + + # load model + if ( + self.config.train.ckpt_path is not None + and self.config.train.st_iter != 0 + ): + scaler.load_state_dict(copyStateDict(self.net_param["scaler"])) + else: + scaler = None + + criterion = self.get_loss() + + # TRAIN -------------------------------------------------------------------------------------------------------# + train_step = self.config.train.st_iter + whole_training_step = self.config.train.end_iter + update_lr_rate_step = 0 + training_lr = self.config.train.lr + loss_value = 0 + batch_time = 0 + epoch = 0 + start_time = time.time() + + while train_step < whole_training_step: + self.trn_sampler.set_epoch(train_step) + for ( + index, + (image, region_image, affinity_image, confidence_mask,), + ) in enumerate(trn_loader): + craft.train() + if train_step > 0 and train_step % self.config.train.lr_decay == 0: + update_lr_rate_step += 1 + training_lr = self.adjust_learning_rate( + optimizer, + self.config.train.gamma, + update_lr_rate_step, + self.config.train.lr, + ) + + images = image.cuda(non_blocking=True) + region_image_label = region_image.cuda(non_blocking=True) + affinity_image_label = affinity_image.cuda(non_blocking=True) + confidence_mask_label = confidence_mask.cuda(non_blocking=True) + + if self.config.train.amp: + with torch.cuda.amp.autocast(): + + output, _ = craft(images) + out1 = output[:, :, :, 0] + out2 = output[:, :, :, 1] + + loss = criterion( + region_image_label, + affinity_image_label, + out1, + out2, + confidence_mask_label, + self.config.train.neg_rto, + self.config.train.n_min_neg, + ) + + optimizer.zero_grad() + scaler.scale(loss).backward() + scaler.step(optimizer) + scaler.update() + + else: + output, _ = craft(images) + out1 = output[:, :, :, 0] + out2 = output[:, :, :, 1] + loss = criterion( + region_image_label, + affinity_image_label, + out1, + out2, + confidence_mask_label, + self.config.train.neg_rto, + ) + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + end_time = time.time() + loss_value += loss.item() + batch_time += end_time - start_time + + if train_step > 0 and train_step % 5 == 0 and self.gpu == 0: + mean_loss = loss_value / 5 + loss_value = 0 + avg_batch_time = batch_time / 5 + batch_time = 0 + + print( + "{}, training_step: {}|{}, learning rate: {:.8f}, " + "training_loss: {:.5f}, avg_batch_time: {:.5f}".format( + time.strftime( + "%Y-%m-%d:%H:%M:%S", time.localtime(time.time()) + ), + train_step, + whole_training_step, + training_lr, + mean_loss, + avg_batch_time, + ) + ) + if 
self.gpu == 0 and self.config.wandb_opt: + wandb.log({"train_step": train_step, "mean_loss": mean_loss}) + + if ( + train_step % self.config.train.eval_interval == 0 + and train_step != 0 + ): + + # initialize all buffer value with zero + if self.gpu == 0: + for buffer in buffer_dict.values(): + for i in range(len(buffer)): + buffer[i] = None + + print("Saving state, index:", train_step) + save_param_dic = { + "iter": train_step, + "craft": craft.state_dict(), + "optimizer": optimizer.state_dict(), + } + save_param_path = ( + self.config.results_dir + + "/CRAFT_clr_" + + repr(train_step) + + ".pth" + ) + + if self.config.train.amp: + save_param_dic["scaler"] = scaler.state_dict() + save_param_path = ( + self.config.results_dir + + "/CRAFT_clr_amp_" + + repr(train_step) + + ".pth" + ) + + torch.save(save_param_dic, save_param_path) + + # validation + self.iou_eval( + "icdar2013", + train_step, + save_param_path, + buffer_dict["icdar2013"], + craft, + ) + + train_step += 1 + if train_step >= whole_training_step: + break + epoch += 1 + + # save last model + if self.gpu == 0: + save_param_dic = { + "iter": train_step, + "craft": craft.state_dict(), + "optimizer": optimizer.state_dict(), + } + save_param_path = ( + self.config.results_dir + "/CRAFT_clr_" + repr(train_step) + ".pth" + ) + + if self.config.train.amp: + save_param_dic["scaler"] = scaler.state_dict() + save_param_path = ( + self.config.results_dir + + "/CRAFT_clr_amp_" + + repr(train_step) + + ".pth" + ) + torch.save(save_param_dic, save_param_path) + + +def main(): + parser = argparse.ArgumentParser(description="CRAFT SynthText Train") + parser.add_argument( + "--yaml", + "--yaml_file_name", + default="syn_train", + type=str, + help="Load configuration", + ) + parser.add_argument( + "--port", "--use ddp port", default="2646", type=str, help="Load configuration" + ) + + args = parser.parse_args() + + # load configure + exp_name = args.yaml + config = load_yaml(args.yaml) + + print("-" * 20 + " Options " + "-" * 20) + print(yaml.dump(config)) + print("-" * 40) + + # Make result_dir + res_dir = os.path.join(config["results_dir"], args.yaml) + config["results_dir"] = res_dir + if not os.path.exists(res_dir): + os.makedirs(res_dir) + + # Duplicate yaml file to result_dir + shutil.copy( + "config/" + args.yaml + ".yaml", os.path.join(res_dir, args.yaml) + ".yaml" + ) + + ngpus_per_node = torch.cuda.device_count() + print(f"Total device num : {ngpus_per_node}") + + manager = mp.Manager() + buffer1 = manager.list([None] * config["test"]["icdar2013"]["test_set_size"]) + buffer_dict = {"icdar2013": buffer1} + torch.multiprocessing.spawn( + main_worker, + nprocs=ngpus_per_node, + args=(args.port, ngpus_per_node, config, buffer_dict, exp_name,), + ) + print('flag5') + + +def main_worker(gpu, port, ngpus_per_node, config, buffer_dict, exp_name): + + torch.distributed.init_process_group( + backend="nccl", + init_method="tcp://127.0.0.1:" + port, + world_size=ngpus_per_node, + rank=gpu, + ) + + # Apply config to wandb + if gpu == 0 and config["wandb_opt"]: + wandb.init(project="craft-stage1", entity="gmuffiness", name=exp_name) + wandb.config.update(config) + + batch_size = int(config["train"]["batch_size"] / ngpus_per_node) + config["train"]["batch_size"] = batch_size + config = DotDict(config) + + # Start train + trainer = Trainer(config, gpu) + trainer.train(buffer_dict) + + if gpu == 0 and config["wandb_opt"]: + wandb.finish() + torch.distributed.destroy_process_group() + +if __name__ == "__main__": + main() diff --git 
a/trainer/craft/train_distributed.py b/trainer/craft/train_distributed.py new file mode 100644 index 000000000..8ab320c19 --- /dev/null +++ b/trainer/craft/train_distributed.py @@ -0,0 +1,523 @@ +# -*- coding: utf-8 -*- +import argparse +import os +import shutil +import time +import multiprocessing as mp +import yaml + +import numpy as np +import torch +import torch.nn as nn +import torch.optim as optim +import wandb + +from config.load_config import load_yaml, DotDict +from data.dataset import SynthTextDataSet, CustomDataset +from loss.mseloss import Maploss_v2, Maploss_v3 +from model.craft import CRAFT +from eval import main_eval +from metrics.eval_det_iou import DetectionIoUEvaluator +from utils.util import copyStateDict, save_parser + + +class Trainer(object): + def __init__(self, config, gpu, mode): + + self.config = config + self.gpu = gpu + self.mode = mode + self.net_param = self.get_load_param(gpu) + + def get_synth_loader(self): + + dataset = SynthTextDataSet( + output_size=self.config.train.data.output_size, + data_dir=self.config.train.synth_data_dir, + saved_gt_dir=None, + mean=self.config.train.data.mean, + variance=self.config.train.data.variance, + gauss_init_size=self.config.train.data.gauss_init_size, + gauss_sigma=self.config.train.data.gauss_sigma, + enlarge_region=self.config.train.data.enlarge_region, + enlarge_affinity=self.config.train.data.enlarge_affinity, + aug=self.config.train.data.syn_aug, + vis_test_dir=self.config.vis_test_dir, + vis_opt=self.config.train.data.vis_opt, + sample=self.config.train.data.syn_sample, + ) + + syn_sampler = torch.utils.data.distributed.DistributedSampler(dataset) + + syn_loader = torch.utils.data.DataLoader( + dataset, + batch_size=self.config.train.batch_size // self.config.train.synth_ratio, + shuffle=False, + num_workers=self.config.train.num_workers, + sampler=syn_sampler, + drop_last=True, + pin_memory=True, + ) + return syn_loader + + def get_custom_dataset(self): + + custom_dataset = CustomDataset( + output_size=self.config.train.data.output_size, + data_dir=self.config.data_root_dir, + saved_gt_dir=None, + mean=self.config.train.data.mean, + variance=self.config.train.data.variance, + gauss_init_size=self.config.train.data.gauss_init_size, + gauss_sigma=self.config.train.data.gauss_sigma, + enlarge_region=self.config.train.data.enlarge_region, + enlarge_affinity=self.config.train.data.enlarge_affinity, + watershed_param=self.config.train.data.watershed, + aug=self.config.train.data.custom_aug, + vis_test_dir=self.config.vis_test_dir, + sample=self.config.train.data.custom_sample, + vis_opt=self.config.train.data.vis_opt, + pseudo_vis_opt=self.config.train.data.pseudo_vis_opt, + do_not_care_label=self.config.train.data.do_not_care_label, + ) + + return custom_dataset + + def get_load_param(self, gpu): + + if self.config.train.ckpt_path is not None: + map_location = "cuda:%d" % gpu + param = torch.load(self.config.train.ckpt_path, map_location=map_location) + else: + param = None + + return param + + def adjust_learning_rate(self, optimizer, gamma, step, lr): + lr = lr * (gamma ** step) + for param_group in optimizer.param_groups: + param_group["lr"] = lr + return param_group["lr"] + + def get_loss(self): + if self.config.train.loss == 2: + criterion = Maploss_v2() + elif self.config.train.loss == 3: + criterion = Maploss_v3() + else: + raise Exception("Undefined loss") + return criterion + + def iou_eval(self, dataset, train_step, buffer, model): + test_config = DotDict(self.config.test[dataset]) + + val_result_dir = 
os.path.join( + self.config.results_dir, "{}/{}".format(dataset + "_iou", str(train_step)) + ) + + evaluator = DetectionIoUEvaluator() + + metrics = main_eval( + None, + self.config.train.backbone, + test_config, + evaluator, + val_result_dir, + buffer, + model, + self.mode, + ) + if self.gpu == 0 and self.config.wandb_opt: + wandb.log( + { + "{} iou Recall".format(dataset): np.round(metrics["recall"], 3), + "{} iou Precision".format(dataset): np.round( + metrics["precision"], 3 + ), + "{} iou F1-score".format(dataset): np.round(metrics["hmean"], 3), + } + ) + + def train(self, buffer_dict): + + torch.cuda.set_device(self.gpu) + total_gpu_num = torch.cuda.device_count() + + # MODEL -------------------------------------------------------------------------------------------------------# + # SUPERVISION model + if self.config.mode == "weak_supervision": + if self.config.train.backbone == "vgg": + supervision_model = CRAFT(pretrained=False, amp=self.config.train.amp) + else: + raise Exception("Undefined architecture") + + # NOTE: only work on half GPU assign train / half GPU assign supervision setting + supervision_device = total_gpu_num // 2 + self.gpu + if self.config.train.ckpt_path is not None: + supervision_param = self.get_load_param(supervision_device) + supervision_model.load_state_dict( + copyStateDict(supervision_param["craft"]) + ) + supervision_model = supervision_model.to(f"cuda:{supervision_device}") + print(f"Supervision model loading on : gpu {supervision_device}") + else: + supervision_model, supervision_device = None, None + + # TRAIN model + if self.config.train.backbone == "vgg": + craft = CRAFT(pretrained=False, amp=self.config.train.amp) + else: + raise Exception("Undefined architecture") + + if self.config.train.ckpt_path is not None: + craft.load_state_dict(copyStateDict(self.net_param["craft"])) + + craft = nn.SyncBatchNorm.convert_sync_batchnorm(craft) + craft = craft.cuda() + craft = torch.nn.parallel.DistributedDataParallel(craft, device_ids=[self.gpu]) + + torch.backends.cudnn.benchmark = True + + # DATASET -----------------------------------------------------------------------------------------------------# + + if self.config.train.use_synthtext: + trn_syn_loader = self.get_synth_loader() + batch_syn = iter(trn_syn_loader) + + if self.config.train.real_dataset == "custom": + trn_real_dataset = self.get_custom_dataset() + else: + raise Exception("Undefined dataset") + + if self.config.mode == "weak_supervision": + trn_real_dataset.update_model(supervision_model) + trn_real_dataset.update_device(supervision_device) + + trn_real_sampler = torch.utils.data.distributed.DistributedSampler( + trn_real_dataset + ) + trn_real_loader = torch.utils.data.DataLoader( + trn_real_dataset, + batch_size=self.config.train.batch_size, + shuffle=False, + num_workers=self.config.train.num_workers, + sampler=trn_real_sampler, + drop_last=False, + pin_memory=True, + ) + + # OPTIMIZER ---------------------------------------------------------------------------------------------------# + optimizer = optim.Adam( + craft.parameters(), + lr=self.config.train.lr, + weight_decay=self.config.train.weight_decay, + ) + + if self.config.train.ckpt_path is not None and self.config.train.st_iter != 0: + optimizer.load_state_dict(copyStateDict(self.net_param["optimizer"])) + self.config.train.st_iter = self.net_param["optimizer"]["state"][0]["step"] + self.config.train.lr = self.net_param["optimizer"]["param_groups"][0]["lr"] + + # LOSS 
--------------------------------------------------------------------------------------------------------# + # mixed precision + if self.config.train.amp: + scaler = torch.cuda.amp.GradScaler() + + if ( + self.config.train.ckpt_path is not None + and self.config.train.st_iter != 0 + ): + scaler.load_state_dict(copyStateDict(self.net_param["scaler"])) + else: + scaler = None + + criterion = self.get_loss() + + # TRAIN -------------------------------------------------------------------------------------------------------# + train_step = self.config.train.st_iter + whole_training_step = self.config.train.end_iter + update_lr_rate_step = 0 + training_lr = self.config.train.lr + loss_value = 0 + batch_time = 0 + start_time = time.time() + + print( + "================================ Train start ================================" + ) + while train_step < whole_training_step: + trn_real_sampler.set_epoch(train_step) + for ( + index, + ( + images, + region_scores, + affinity_scores, + confidence_masks, + ), + ) in enumerate(trn_real_loader): + craft.train() + if train_step > 0 and train_step % self.config.train.lr_decay == 0: + update_lr_rate_step += 1 + training_lr = self.adjust_learning_rate( + optimizer, + self.config.train.gamma, + update_lr_rate_step, + self.config.train.lr, + ) + + images = images.cuda(non_blocking=True) + region_scores = region_scores.cuda(non_blocking=True) + affinity_scores = affinity_scores.cuda(non_blocking=True) + confidence_masks = confidence_masks.cuda(non_blocking=True) + + if self.config.train.use_synthtext: + # Synth image load + syn_image, syn_region_label, syn_affi_label, syn_confidence_mask = next( + batch_syn + ) + syn_image = syn_image.cuda(non_blocking=True) + syn_region_label = syn_region_label.cuda(non_blocking=True) + syn_affi_label = syn_affi_label.cuda(non_blocking=True) + syn_confidence_mask = syn_confidence_mask.cuda(non_blocking=True) + + # concat syn & custom image + images = torch.cat((syn_image, images), 0) + region_image_label = torch.cat( + (syn_region_label, region_scores), 0 + ) + affinity_image_label = torch.cat((syn_affi_label, affinity_scores), 0) + confidence_mask_label = torch.cat( + (syn_confidence_mask, confidence_masks), 0 + ) + else: + region_image_label = region_scores + affinity_image_label = affinity_scores + confidence_mask_label = confidence_masks + + if self.config.train.amp: + with torch.cuda.amp.autocast(): + + output, _ = craft(images) + out1 = output[:, :, :, 0] + out2 = output[:, :, :, 1] + + loss = criterion( + region_image_label, + affinity_image_label, + out1, + out2, + confidence_mask_label, + self.config.train.neg_rto, + self.config.train.n_min_neg, + ) + + optimizer.zero_grad() + scaler.scale(loss).backward() + scaler.step(optimizer) + scaler.update() + + else: + output, _ = craft(images) + out1 = output[:, :, :, 0] + out2 = output[:, :, :, 1] + loss = criterion( + region_image_label, + affinity_image_label, + out1, + out2, + confidence_mask_label, + self.config.train.neg_rto, + ) + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + end_time = time.time() + loss_value += loss.item() + batch_time += end_time - start_time + + if train_step > 0 and train_step % 5 == 0 and self.gpu == 0: + mean_loss = loss_value / 5 + loss_value = 0 + avg_batch_time = batch_time / 5 + batch_time = 0 + + print( + "{}, training_step: {}|{}, learning rate: {:.8f}, " + "training_loss: {:.5f}, avg_batch_time: {:.5f}".format( + time.strftime( + "%Y-%m-%d:%H:%M:%S", time.localtime(time.time()) + ), + train_step, + 
whole_training_step, + training_lr, + mean_loss, + avg_batch_time, + ) + ) + + if self.gpu == 0 and self.config.wandb_opt: + wandb.log({"train_step": train_step, "mean_loss": mean_loss}) + + if ( + train_step % self.config.train.eval_interval == 0 + and train_step != 0 + ): + + craft.eval() + # initialize all buffer value with zero + if self.gpu == 0: + for buffer in buffer_dict.values(): + for i in range(len(buffer)): + buffer[i] = None + + print("Saving state, index:", train_step) + save_param_dic = { + "iter": train_step, + "craft": craft.state_dict(), + "optimizer": optimizer.state_dict(), + } + save_param_path = ( + self.config.results_dir + + "/CRAFT_clr_" + + repr(train_step) + + ".pth" + ) + + if self.config.train.amp: + save_param_dic["scaler"] = scaler.state_dict() + save_param_path = ( + self.config.results_dir + + "/CRAFT_clr_amp_" + + repr(train_step) + + ".pth" + ) + + torch.save(save_param_dic, save_param_path) + + # validation + self.iou_eval( + "custom_data", + train_step, + buffer_dict["custom_data"], + craft, + ) + + train_step += 1 + if train_step >= whole_training_step: + break + + if self.config.mode == "weak_supervision": + state_dict = craft.module.state_dict() + supervision_model.load_state_dict(state_dict) + trn_real_dataset.update_model(supervision_model) + + # save last model + if self.gpu == 0: + save_param_dic = { + "iter": train_step, + "craft": craft.state_dict(), + "optimizer": optimizer.state_dict(), + } + save_param_path = ( + self.config.results_dir + "/CRAFT_clr_" + repr(train_step) + ".pth" + ) + + if self.config.train.amp: + save_param_dic["scaler"] = scaler.state_dict() + save_param_path = ( + self.config.results_dir + + "/CRAFT_clr_amp_" + + repr(train_step) + + ".pth" + ) + torch.save(save_param_dic, save_param_path) + +def main(): + parser = argparse.ArgumentParser(description="CRAFT custom data train") + parser.add_argument( + "--yaml", + "--yaml_file_name", + default="custom_data_train", + type=str, + help="Load configuration", + ) + parser.add_argument( + "--port", "--use ddp port", default="2346", type=str, help="Port number" + ) + + args = parser.parse_args() + + # load configure + exp_name = args.yaml + config = load_yaml(args.yaml) + + print("-" * 20 + " Options " + "-" * 20) + print(yaml.dump(config)) + print("-" * 40) + + # Make result_dir + res_dir = os.path.join(config["results_dir"], args.yaml) + config["results_dir"] = res_dir + if not os.path.exists(res_dir): + os.makedirs(res_dir) + + # Duplicate yaml file to result_dir + shutil.copy( + "config/" + args.yaml + ".yaml", os.path.join(res_dir, args.yaml) + ".yaml" + ) + + if config["mode"] == "weak_supervision": + # NOTE: half GPU assign train / half GPU assign supervision setting + ngpus_per_node = torch.cuda.device_count() // 2 + mode = "weak_supervision" + else: + ngpus_per_node = torch.cuda.device_count() + mode = None + + print(f"Total process num : {ngpus_per_node}") + + manager = mp.Manager() + buffer1 = manager.list([None] * config["test"]["custom_data"]["test_set_size"]) + + buffer_dict = {"custom_data": buffer1} + torch.multiprocessing.spawn( + main_worker, + nprocs=ngpus_per_node, + args=(args.port, ngpus_per_node, config, buffer_dict, exp_name, mode,), + ) + + +def main_worker(gpu, port, ngpus_per_node, config, buffer_dict, exp_name, mode): + + torch.distributed.init_process_group( + backend="nccl", + init_method="tcp://127.0.0.1:" + port, + world_size=ngpus_per_node, + rank=gpu, + ) + + # Apply config to wandb + if gpu == 0 and config["wandb_opt"]: + 
wandb.init(project="craft-stage2", entity="user_name", name=exp_name) + wandb.config.update(config) + + batch_size = int(config["train"]["batch_size"] / ngpus_per_node) + config["train"]["batch_size"] = batch_size + config = DotDict(config) + + # Start train + trainer = Trainer(config, gpu, mode) + trainer.train(buffer_dict) + + if gpu == 0: + if config["wandb_opt"]: + wandb.finish() + + torch.distributed.barrier() + torch.distributed.destroy_process_group() + +if __name__ == "__main__": + main() diff --git a/trainer/craft/utils/craft_utils.py b/trainer/craft/utils/craft_utils.py new file mode 100644 index 000000000..f5d39df4e --- /dev/null +++ b/trainer/craft/utils/craft_utils.py @@ -0,0 +1,345 @@ + +# -*- coding: utf-8 -*- +import os +import torch +import cv2 +import math +import numpy as np +from data import imgproc + +""" auxilary functions """ +# unwarp corodinates + + + + +def warpCoord(Minv, pt): + out = np.matmul(Minv, (pt[0], pt[1], 1)) + return np.array([out[0]/out[2], out[1]/out[2]]) +""" end of auxilary functions """ + +def test(): + print('pass') + + +def getDetBoxes_core(textmap, linkmap, text_threshold, link_threshold, low_text): + # prepare data + linkmap = linkmap.copy() + textmap = textmap.copy() + img_h, img_w = textmap.shape + + """ labeling method """ + ret, text_score = cv2.threshold(textmap, low_text, 1, 0) + ret, link_score = cv2.threshold(linkmap, link_threshold, 1, 0) + + text_score_comb = np.clip(text_score + link_score, 0, 1) + nLabels, labels, stats, centroids = \ + cv2.connectedComponentsWithStats(text_score_comb.astype(np.uint8), connectivity=4) + + det = [] + mapper = [] + for k in range(1,nLabels): + # size filtering + size = stats[k, cv2.CC_STAT_AREA] + if size < 10: continue + + # thresholding + if np.max(textmap[labels==k]) < text_threshold: continue + + # make segmentation map + segmap = np.zeros(textmap.shape, dtype=np.uint8) + segmap[labels==k] = 255 + segmap[np.logical_and(link_score==1, text_score==0)] = 0 # remove link area + x, y = stats[k, cv2.CC_STAT_LEFT], stats[k, cv2.CC_STAT_TOP] + w, h = stats[k, cv2.CC_STAT_WIDTH], stats[k, cv2.CC_STAT_HEIGHT] + niter = int(math.sqrt(size * min(w, h) / (w * h)) * 2) + sx, ex, sy, ey = x - niter, x + w + niter + 1, y - niter, y + h + niter + 1 + # boundary check + if sx < 0 : sx = 0 + if sy < 0 : sy = 0 + if ex >= img_w: ex = img_w + if ey >= img_h: ey = img_h + kernel = cv2.getStructuringElement(cv2.MORPH_RECT,(1 + niter, 1 + niter)) + segmap[sy:ey, sx:ex] = cv2.dilate(segmap[sy:ey, sx:ex], kernel, iterations=1) + #kernel1 = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 5)) + #segmap[sy:ey, sx:ex] = cv2.dilate(segmap[sy:ey, sx:ex], kernel1, iterations=1) + + + # make box + np_contours = np.roll(np.array(np.where(segmap!=0)),1,axis=0).transpose().reshape(-1,2) + rectangle = cv2.minAreaRect(np_contours) + box = cv2.boxPoints(rectangle) + + # align diamond-shape + w, h = np.linalg.norm(box[0] - box[1]), np.linalg.norm(box[1] - box[2]) + box_ratio = max(w, h) / (min(w, h) + 1e-5) + if abs(1 - box_ratio) <= 0.1: + l, r = min(np_contours[:,0]), max(np_contours[:,0]) + t, b = min(np_contours[:,1]), max(np_contours[:,1]) + box = np.array([[l, t], [r, t], [r, b], [l, b]], dtype=np.float32) + + # make clock-wise order + startidx = box.sum(axis=1).argmin() + box = np.roll(box, 4-startidx, 0) + box = np.array(box) + + det.append(box) + mapper.append(k) + + return det, labels, mapper + +def getPoly_core(boxes, labels, mapper, linkmap): + # configs + num_cp = 5 + max_len_ratio = 0.7 + expand_ratio = 1.45 + max_r = 2.0 + 
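# max_r / step_r define the search range and step (in multiples of the character height)
+    # used below to push the first and last pivot points outward so edge characters are covered.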
step_r = 0.2 + + polys = [] + for k, box in enumerate(boxes): + # size filter for small instance + w, h = int(np.linalg.norm(box[0] - box[1]) + 1), int(np.linalg.norm(box[1] - box[2]) + 1) + if w < 30 or h < 30: + polys.append(None); continue + + # warp image + tar = np.float32([[0,0],[w,0],[w,h],[0,h]]) + M = cv2.getPerspectiveTransform(box, tar) + word_label = cv2.warpPerspective(labels, M, (w, h), flags=cv2.INTER_NEAREST) + try: + Minv = np.linalg.inv(M) + except: + polys.append(None); continue + + # binarization for selected label + cur_label = mapper[k] + word_label[word_label != cur_label] = 0 + word_label[word_label > 0] = 1 + + """ Polygon generation """ + # find top/bottom contours + cp = [] + max_len = -1 + for i in range(w): + region = np.where(word_label[:,i] != 0)[0] + if len(region) < 2 : continue + cp.append((i, region[0], region[-1])) + length = region[-1] - region[0] + 1 + if length > max_len: max_len = length + + # pass if max_len is similar to h + if h * max_len_ratio < max_len: + polys.append(None); continue + + # get pivot points with fixed length + tot_seg = num_cp * 2 + 1 + seg_w = w / tot_seg # segment width + pp = [None] * num_cp # init pivot points + cp_section = [[0, 0]] * tot_seg + seg_height = [0] * num_cp + seg_num = 0 + num_sec = 0 + prev_h = -1 + for i in range(0,len(cp)): + (x, sy, ey) = cp[i] + if (seg_num + 1) * seg_w <= x and seg_num <= tot_seg: + # average previous segment + if num_sec == 0: break + cp_section[seg_num] = [cp_section[seg_num][0] / num_sec, cp_section[seg_num][1] / num_sec] + num_sec = 0 + + # reset variables + seg_num += 1 + prev_h = -1 + + # accumulate center points + cy = (sy + ey) * 0.5 + cur_h = ey - sy + 1 + cp_section[seg_num] = [cp_section[seg_num][0] + x, cp_section[seg_num][1] + cy] + num_sec += 1 + + if seg_num % 2 == 0: continue # No polygon area + + if prev_h < cur_h: + pp[int((seg_num - 1)/2)] = (x, cy) + seg_height[int((seg_num - 1)/2)] = cur_h + prev_h = cur_h + + # processing last segment + if num_sec != 0: + cp_section[-1] = [cp_section[-1][0] / num_sec, cp_section[-1][1] / num_sec] + + # pass if num of pivots is not sufficient or segment widh is smaller than character height + if None in pp or seg_w < np.max(seg_height) * 0.25: + polys.append(None); continue + + # calc median maximum of pivot points + half_char_h = np.median(seg_height) * expand_ratio / 2 + + # calc gradiant and apply to make horizontal pivots + new_pp = [] + for i, (x, cy) in enumerate(pp): + dx = cp_section[i * 2 + 2][0] - cp_section[i * 2][0] + dy = cp_section[i * 2 + 2][1] - cp_section[i * 2][1] + if dx == 0: # gradient if zero + new_pp.append([x, cy - half_char_h, x, cy + half_char_h]) + continue + rad = - math.atan2(dy, dx) + c, s = half_char_h * math.cos(rad), half_char_h * math.sin(rad) + new_pp.append([x - s, cy - c, x + s, cy + c]) + + # get edge points to cover character heatmaps + isSppFound, isEppFound = False, False + grad_s = (pp[1][1] - pp[0][1]) / (pp[1][0] - pp[0][0]) + (pp[2][1] - pp[1][1]) / (pp[2][0] - pp[1][0]) + grad_e = (pp[-2][1] - pp[-1][1]) / (pp[-2][0] - pp[-1][0]) + (pp[-3][1] - pp[-2][1]) / (pp[-3][0] - pp[-2][0]) + for r in np.arange(0.5, max_r, step_r): + dx = 2 * half_char_h * r + if not isSppFound: + line_img = np.zeros(word_label.shape, dtype=np.uint8) + dy = grad_s * dx + p = np.array(new_pp[0]) - np.array([dx, dy, dx, dy]) + cv2.line(line_img, (int(p[0]), int(p[1])), (int(p[2]), int(p[3])), 1, thickness=1) + if np.sum(np.logical_and(word_label, line_img)) == 0 or r + 2 * step_r >= max_r: + spp = p + isSppFound = True 
+ if not isEppFound: + line_img = np.zeros(word_label.shape, dtype=np.uint8) + dy = grad_e * dx + p = np.array(new_pp[-1]) + np.array([dx, dy, dx, dy]) + cv2.line(line_img, (int(p[0]), int(p[1])), (int(p[2]), int(p[3])), 1, thickness=1) + if np.sum(np.logical_and(word_label, line_img)) == 0 or r + 2 * step_r >= max_r: + epp = p + isEppFound = True + if isSppFound and isEppFound: + break + + # pass if boundary of polygon is not found + if not (isSppFound and isEppFound): + polys.append(None); continue + + # make final polygon + poly = [] + poly.append(warpCoord(Minv, (spp[0], spp[1]))) + for p in new_pp: + poly.append(warpCoord(Minv, (p[0], p[1]))) + poly.append(warpCoord(Minv, (epp[0], epp[1]))) + poly.append(warpCoord(Minv, (epp[2], epp[3]))) + for p in reversed(new_pp): + poly.append(warpCoord(Minv, (p[2], p[3]))) + poly.append(warpCoord(Minv, (spp[2], spp[3]))) + + # add to final result + polys.append(np.array(poly)) + + return polys + +def getDetBoxes(textmap, linkmap, text_threshold, link_threshold, low_text, poly=False): + boxes, labels, mapper = getDetBoxes_core(textmap, linkmap, text_threshold, link_threshold, low_text) + + if poly: + polys = getPoly_core(boxes, labels, mapper, linkmap) + else: + polys = [None] * len(boxes) + + return boxes, polys + +def adjustResultCoordinates(polys, ratio_w, ratio_h, ratio_net = 2): + if len(polys) > 0: + polys = np.array(polys) + for k in range(len(polys)): + if polys[k] is not None: + polys[k] *= (ratio_w * ratio_net, ratio_h * ratio_net) + return polys + +def save_outputs(image, region_scores, affinity_scores, text_threshold, link_threshold, + low_text, outoput_path, confidence_mask = None): + """save image, region_scores, and affinity_scores in a single image. region_scores and affinity_scores must be + cpu numpy arrays. You can convert GPU Tensors to CPU numpy arrays like this: + >>> array = tensor.cpu().data.numpy() + When saving outputs of the network during training, make sure you convert ALL tensors (image, region_score, + affinity_score) to numpy array first. + :param image: numpy array + :param region_scores: [] 2D numpy array with each element between 0~1. + :param affinity_scores: same as region_scores + :param text_threshold: 0 ~ 1. Closer to 0, characters with lower confidence will also be considered a word and be boxed + :param link_threshold: 0 ~ 1. Closer to 0, links with lower confidence will also be considered a word and be boxed + :param low_text: 0 ~ 1. Closer to 0, boxes will be more loosely drawn. 
+    :param outoput_path: file path where the rendered visualization image is written
+    :param confidence_mask: optional confidence map; rendered next to the score maps when given
+    :return: the rendered visualization image (numpy array)
+    """
+
+    assert region_scores.shape == affinity_scores.shape
+    assert len(image.shape) - 1 == len(region_scores.shape)
+
+    boxes, polys = getDetBoxes(region_scores, affinity_scores, text_threshold, link_threshold,
+                               low_text, False)
+    # score maps are half the input resolution, so scale the boxes back to image coordinates
+    boxes = np.array(boxes, np.int32) * 2
+    if len(boxes) > 0:
+        boxes[:, :, 0] = np.clip(boxes[:, :, 0], 0, image.shape[1])
+        boxes[:, :, 1] = np.clip(boxes[:, :, 1], 0, image.shape[0])
+        for box in boxes:
+            cv2.polylines(image, [np.reshape(box, (-1, 1, 2))], True, (0, 0, 255))
+
+    target_gaussian_heatmap_color = imgproc.cvt2HeatmapImg(region_scores)
+    target_gaussian_affinity_heatmap_color = imgproc.cvt2HeatmapImg(affinity_scores)
+
+    if confidence_mask is not None:
+        confidence_mask_gray = imgproc.cvt2HeatmapImg(confidence_mask)
+        gt_scores = np.hstack([target_gaussian_heatmap_color, target_gaussian_affinity_heatmap_color])
+        confidence_mask_gray = np.hstack([np.zeros_like(confidence_mask_gray), confidence_mask_gray])
+        output = np.concatenate([gt_scores, confidence_mask_gray], axis=0)
+        output = np.hstack([image, output])
+
+    else:
+        gt_scores = np.concatenate([target_gaussian_heatmap_color, target_gaussian_affinity_heatmap_color], axis=0)
+        output = np.hstack([image, gt_scores])
+
+    cv2.imwrite(outoput_path, output)
+    return output
+
+
+def save_outputs_from_tensors(images, region_scores, affinity_scores, text_threshold, link_threshold,
+                              low_text, output_dir, image_names, confidence_mask=None):
+    """Takes images, region_scores, and affinity_scores as tensors (can be on GPU) and saves one
+    visualization per image via save_outputs.
+    :param images: 4D tensor (or numpy array) of input images
+    :param region_scores: 3D tensor with values between 0 ~ 1
+    :param affinity_scores: 3D tensor with values between 0 ~ 1
+    :param text_threshold: see save_outputs
+    :param link_threshold: see save_outputs
+    :param low_text: see save_outputs
+    :param output_dir: directory to save the output images; joined with the base names of image_names
+    :param image_names: name of each image; does not have to be the base file name
+    :param confidence_mask: optional confidence map passed through to save_outputs
+    :return: list of rendered visualization images
+    """
+    # images = images.cpu().permute(0, 2, 3, 1).contiguous().data.numpy()
+    if isinstance(images, torch.Tensor):
+        # move to CPU before converting; callers are expected to pass images in HWC layout
+        images = images.cpu().data.numpy()
+
+    region_scores = region_scores.cpu().data.numpy()
+    affinity_scores = affinity_scores.cpu().data.numpy()
+
+    batch_size = images.shape[0]
+    assert batch_size == region_scores.shape[0] and batch_size == affinity_scores.shape[0] and batch_size == len(image_names), \
+        "The first dimension (i.e. 
batch size) of images, region scores, and affinity scores must be equal" + + output_images = [] + + for i in range(batch_size): + image = images[i] + region_score = region_scores[i] + affinity_score = affinity_scores[i] + + image_name = os.path.basename(image_names[i]) + outoput_path = os.path.join(output_dir,image_name) + + output_image = save_outputs(image, region_score, affinity_score, text_threshold, link_threshold, + low_text, outoput_path, confidence_mask=confidence_mask) + + output_images.append(output_image) + + return output_images \ No newline at end of file diff --git a/trainer/craft/utils/inference_boxes.py b/trainer/craft/utils/inference_boxes.py new file mode 100644 index 000000000..e334395bb --- /dev/null +++ b/trainer/craft/utils/inference_boxes.py @@ -0,0 +1,361 @@ +import os +import re +import itertools + +import cv2 +import time +import numpy as np +import torch +from torch.autograd import Variable + +from utils.craft_utils import getDetBoxes, adjustResultCoordinates +from data import imgproc +from data.dataset import SynthTextDataSet +import math +import xml.etree.ElementTree as elemTree + + +#-------------------------------------------------------------------------------------------------------------------# +def rotatePoint(xc, yc, xp, yp, theta): + xoff = xp - xc + yoff = yp - yc + + cosTheta = math.cos(theta) + sinTheta = math.sin(theta) + pResx = cosTheta * xoff + sinTheta * yoff + pResy = - sinTheta * xoff + cosTheta * yoff + # pRes = (xc + pResx, yc + pResy) + return int(xc + pResx), int(yc + pResy) + +def addRotatedShape(cx, cy, w, h, angle): + p0x, p0y = rotatePoint(cx, cy, cx - w / 2, cy - h / 2, -angle) + p1x, p1y = rotatePoint(cx, cy, cx + w / 2, cy - h / 2, -angle) + p2x, p2y = rotatePoint(cx, cy, cx + w / 2, cy + h / 2, -angle) + p3x, p3y = rotatePoint(cx, cy, cx - w / 2, cy + h / 2, -angle) + + points = [[p0x, p0y], [p1x, p1y], [p2x, p2y], [p3x, p3y]] + + return points + +def xml_parsing(xml): + tree = elemTree.parse(xml) + + annotations = [] # Initialize the list to store labels + iter_element = tree.iter(tag="object") + + for element in iter_element: + annotation = {} # Initialize the dict to store labels + + annotation['name'] = element.find("name").text # Save the name tag value + + box_coords = element.iter(tag="robndbox") + + for box_coord in box_coords: + cx = float(box_coord.find("cx").text) + cy = float(box_coord.find("cy").text) + w = float(box_coord.find("w").text) + h = float(box_coord.find("h").text) + angle = float(box_coord.find("angle").text) + + convertcoodi = addRotatedShape(cx, cy, w, h, angle) + + annotation['box_coodi'] = convertcoodi + annotations.append(annotation) + + box_coords = element.iter(tag="bndbox") + + for box_coord in box_coords: + xmin = int(box_coord.find("xmin").text) + ymin = int(box_coord.find("ymin").text) + xmax = int(box_coord.find("xmax").text) + ymax = int(box_coord.find("ymax").text) + # annotation['bndbox'] = [xmin,ymin,xmax,ymax] + + annotation['box_coodi'] = [[xmin, ymin], [xmax, ymin], [xmax, ymax], + [xmin, ymax]] + annotations.append(annotation) + + + + + bounds = [] + for i in range(len(annotations)): + box_info_dict = {"points": None, "text": None, "ignore": None} + + box_info_dict["points"] = np.array(annotations[i]['box_coodi']) + if annotations[i]['name'] == "dnc": + box_info_dict["text"] = "###" + box_info_dict["ignore"] = True + else: + box_info_dict["text"] = annotations[i]['name'] + box_info_dict["ignore"] = False + + bounds.append(box_info_dict) + + + + return bounds + 
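+# Note: xml_parsing above assumes roLabelImg-style annotation files. A minimal,
+# purely illustrative sketch of the expected layout (not a file shipped with this repo):
+#
+#   <annotation>
+#     <object>
+#       <name>sample_word</name>
+#       <robndbox>
+#         <cx>420.0</cx><cy>123.5</cy><w>88.0</w><h>14.0</h><angle>0.05</angle>
+#       </robndbox>
+#     </object>
+#     <object>
+#       <name>dnc</name>
+#       <bndbox>
+#         <xmin>374</xmin><ymin>155</ymin><xmax>409</xmax><ymax>170</ymax>
+#       </bndbox>
+#     </object>
+#   </annotation>
+#
+# Rotated boxes (robndbox) are converted to 4-point polygons via addRotatedShape,
+# and objects named "dnc" are treated as ignore regions ("###").
+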
+#-------------------------------------------------------------------------------------------------------------------# + +def load_prescription_gt(dataFolder): + + + total_img_path = [] + total_imgs_bboxes = [] + for (root, directories, files) in os.walk(dataFolder): + for file in files: + if '.jpg' in file: + img_path = os.path.join(root, file) + total_img_path.append(img_path) + if '.xml' in file: + gt_path = os.path.join(root, file) + total_imgs_bboxes.append(gt_path) + + + total_imgs_parsing_bboxes = [] + for img_path, bbox in zip(sorted(total_img_path), sorted(total_imgs_bboxes)): + # check file + + assert img_path.split(".jpg")[0] == bbox.split(".xml")[0] + + result_label = xml_parsing(bbox) + total_imgs_parsing_bboxes.append(result_label) + + + return total_imgs_parsing_bboxes, sorted(total_img_path) + + +# NOTE +def load_prescription_cleval_gt(dataFolder): + + + total_img_path = [] + total_gt_path = [] + for (root, directories, files) in os.walk(dataFolder): + for file in files: + if '.jpg' in file: + img_path = os.path.join(root, file) + total_img_path.append(img_path) + if '_cl.txt' in file: + gt_path = os.path.join(root, file) + total_gt_path.append(gt_path) + + + total_imgs_parsing_bboxes = [] + for img_path, gt_path in zip(sorted(total_img_path), sorted(total_gt_path)): + # check file + + assert img_path.split(".jpg")[0] == gt_path.split('_label_cl.txt')[0] + + lines = open(gt_path, encoding="utf-8").readlines() + word_bboxes = [] + + for line in lines: + box_info_dict = {"points": None, "text": None, "ignore": None} + box_info = line.strip().encode("utf-8").decode("utf-8-sig").split(",") + + box_points = [int(box_info[i]) for i in range(8)] + box_info_dict["points"] = np.array(box_points) + + word_bboxes.append(box_info_dict) + total_imgs_parsing_bboxes.append(word_bboxes) + + return total_imgs_parsing_bboxes, sorted(total_img_path) + + +def load_synthtext_gt(data_folder): + + synth_dataset = SynthTextDataSet( + output_size=768, data_dir=data_folder, saved_gt_dir=data_folder, logging=False + ) + img_names, img_bbox, img_words = synth_dataset.load_data(bbox="word") + + total_img_path = [] + total_imgs_bboxes = [] + for index in range(len(img_bbox[:100])): + img_path = os.path.join(data_folder, img_names[index][0]) + total_img_path.append(img_path) + try: + wordbox = img_bbox[index].transpose((2, 1, 0)) + except: + wordbox = np.expand_dims(img_bbox[index], axis=0) + wordbox = wordbox.transpose((0, 2, 1)) + + words = [re.split(" \n|\n |\n| ", t.strip()) for t in img_words[index]] + words = list(itertools.chain(*words)) + words = [t for t in words if len(t) > 0] + + if len(words) != len(wordbox): + import ipdb + + ipdb.set_trace() + + single_img_bboxes = [] + for j in range(len(words)): + box_info_dict = {"points": None, "text": None, "ignore": None} + box_info_dict["points"] = wordbox[j] + box_info_dict["text"] = words[j] + box_info_dict["ignore"] = False + single_img_bboxes.append(box_info_dict) + + total_imgs_bboxes.append(single_img_bboxes) + + return total_imgs_bboxes, total_img_path + + +def load_icdar2015_gt(dataFolder, isTraing=False): + if isTraing: + img_folderName = "ch4_training_images" + gt_folderName = "ch4_training_localization_transcription_gt" + else: + img_folderName = "ch4_test_images" + gt_folderName = "ch4_test_localization_transcription_gt" + + gt_folder_path = os.listdir(os.path.join(dataFolder, gt_folderName)) + total_imgs_bboxes = [] + total_img_path = [] + for gt_path in gt_folder_path: + gt_path = os.path.join(os.path.join(dataFolder, gt_folderName), 
gt_path) + img_path = ( + gt_path.replace(gt_folderName, img_folderName) + .replace(".txt", ".jpg") + .replace("gt_", "") + ) + image = cv2.imread(img_path) + lines = open(gt_path, encoding="utf-8").readlines() + single_img_bboxes = [] + for line in lines: + box_info_dict = {"points": None, "text": None, "ignore": None} + + box_info = line.strip().encode("utf-8").decode("utf-8-sig").split(",") + box_points = [int(box_info[j]) for j in range(8)] + word = box_info[8:] + word = ",".join(word) + box_points = np.array(box_points, np.int32).reshape(4, 2) + cv2.polylines( + image, [np.array(box_points).astype(np.int)], True, (0, 0, 255), 1 + ) + box_info_dict["points"] = box_points + box_info_dict["text"] = word + if word == "###": + box_info_dict["ignore"] = True + else: + box_info_dict["ignore"] = False + + single_img_bboxes.append(box_info_dict) + total_imgs_bboxes.append(single_img_bboxes) + total_img_path.append(img_path) + return total_imgs_bboxes, total_img_path + + +def load_icdar2013_gt(dataFolder, isTraing=False): + + # choise test dataset + if isTraing: + img_folderName = "Challenge2_Test_Task12_Images" + gt_folderName = "Challenge2_Test_Task1_GT" + else: + img_folderName = "Challenge2_Test_Task12_Images" + gt_folderName = "Challenge2_Test_Task1_GT" + + gt_folder_path = os.listdir(os.path.join(dataFolder, gt_folderName)) + + total_imgs_bboxes = [] + total_img_path = [] + for gt_path in gt_folder_path: + gt_path = os.path.join(os.path.join(dataFolder, gt_folderName), gt_path) + img_path = ( + gt_path.replace(gt_folderName, img_folderName) + .replace(".txt", ".jpg") + .replace("gt_", "") + ) + image = cv2.imread(img_path) + lines = open(gt_path, encoding="utf-8").readlines() + single_img_bboxes = [] + for line in lines: + box_info_dict = {"points": None, "text": None, "ignore": None} + + box_info = line.strip().encode("utf-8").decode("utf-8-sig").split(",") + box = [int(box_info[j]) for j in range(4)] + word = box_info[4:] + word = ",".join(word) + box = [ + [box[0], box[1]], + [box[2], box[1]], + [box[2], box[3]], + [box[0], box[3]], + ] + + box_info_dict["points"] = box + box_info_dict["text"] = word + if word == "###": + box_info_dict["ignore"] = True + else: + box_info_dict["ignore"] = False + + single_img_bboxes.append(box_info_dict) + + total_imgs_bboxes.append(single_img_bboxes) + total_img_path.append(img_path) + + return total_imgs_bboxes, total_img_path + + +def test_net( + net, + image, + text_threshold, + link_threshold, + low_text, + cuda, + poly, + canvas_size=1280, + mag_ratio=1.5, +): + # resize + + img_resized, target_ratio, size_heatmap = imgproc.resize_aspect_ratio( + image, canvas_size, interpolation=cv2.INTER_LINEAR, mag_ratio=mag_ratio + ) + ratio_h = ratio_w = 1 / target_ratio + + # preprocessing + x = imgproc.normalizeMeanVariance(img_resized) + x = torch.from_numpy(x).permute(2, 0, 1) # [h, w, c] to [c, h, w] + x = Variable(x.unsqueeze(0)) # [c, h, w] to [b, c, h, w] + if cuda: + x = x.cuda() + + # forward pass + with torch.no_grad(): + y, feature = net(x) + + # make score and link map + score_text = y[0, :, :, 0].cpu().data.numpy().astype(np.float32) + score_link = y[0, :, :, 1].cpu().data.numpy().astype(np.float32) + + # NOTE + score_text = score_text[: size_heatmap[0], : size_heatmap[1]] + score_link = score_link[: size_heatmap[0], : size_heatmap[1]] + + # Post-processing + boxes, polys = getDetBoxes( + score_text, score_link, text_threshold, link_threshold, low_text, poly + ) + + # coordinate adjustment + boxes = adjustResultCoordinates(boxes, ratio_w, 
ratio_h) + polys = adjustResultCoordinates(polys, ratio_w, ratio_h) + for k in range(len(polys)): + if polys[k] is None: + polys[k] = boxes[k] + + # render results (optional) + score_text = score_text.copy() + render_score_text = imgproc.cvt2HeatmapImg(score_text) + render_score_link = imgproc.cvt2HeatmapImg(score_link) + render_img = [render_score_text, render_score_link] + # ret_score_text = imgproc.cvt2HeatmapImg(render_img) + + return boxes, polys, render_img diff --git a/trainer/craft/utils/util.py b/trainer/craft/utils/util.py new file mode 100644 index 000000000..f6c862220 --- /dev/null +++ b/trainer/craft/utils/util.py @@ -0,0 +1,142 @@ +from collections import OrderedDict +import os + +import cv2 +import numpy as np + +from data import imgproc +from utils import craft_utils + + +def copyStateDict(state_dict): + if list(state_dict.keys())[0].startswith("module"): + start_idx = 1 + else: + start_idx = 0 + new_state_dict = OrderedDict() + for k, v in state_dict.items(): + name = ".".join(k.split(".")[start_idx:]) + new_state_dict[name] = v + return new_state_dict + + +def saveInput( + imagename, vis_dir, image, region_scores, affinity_scores, confidence_mask +): + image = np.uint8(image.copy()) + image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) + + boxes, polys = craft_utils.getDetBoxes( + region_scores, affinity_scores, 0.85, 0.2, 0.5, False + ) + + if image.shape[0] / region_scores.shape[0] >= 2: + boxes = np.array(boxes, np.int32) * 2 + else: + boxes = np.array(boxes, np.int32) + + if len(boxes) > 0: + np.clip(boxes[:, :, 0], 0, image.shape[1]) + np.clip(boxes[:, :, 1], 0, image.shape[0]) + for box in boxes: + cv2.polylines(image, [np.reshape(box, (-1, 1, 2))], True, (0, 0, 255)) + target_gaussian_heatmap_color = imgproc.cvt2HeatmapImg(region_scores) + target_gaussian_affinity_heatmap_color = imgproc.cvt2HeatmapImg(affinity_scores) + confidence_mask_gray = imgproc.cvt2HeatmapImg(confidence_mask) + + # overlay + height, width, channel = image.shape + overlay_region = cv2.resize(target_gaussian_heatmap_color, (width, height)) + overlay_aff = cv2.resize(target_gaussian_affinity_heatmap_color, (width, height)) + confidence_mask_gray = cv2.resize( + confidence_mask_gray, (width, height), interpolation=cv2.INTER_NEAREST + ) + overlay_region = cv2.addWeighted(image, 0.4, overlay_region, 0.6, 5) + overlay_aff = cv2.addWeighted(image, 0.4, overlay_aff, 0.7, 6) + + gt_scores = np.concatenate([overlay_region, overlay_aff], axis=1) + + output = np.concatenate([gt_scores, confidence_mask_gray], axis=1) + + output = np.hstack([image, output]) + + # synthtext + if type(imagename) is not str: + imagename = imagename[0].split("/")[-1][:-4] + + outpath = vis_dir + f"/{imagename}_input.jpg" + if not os.path.exists(os.path.dirname(outpath)): + os.makedirs(os.path.dirname(outpath), exist_ok=True) + cv2.imwrite(outpath, output) + # print(f'Logging train input into {outpath}') + + +def saveImage( + imagename, + vis_dir, + image, + bboxes, + affi_bboxes, + region_scores, + affinity_scores, + confidence_mask, +): + output_image = np.uint8(image.copy()) + output_image = cv2.cvtColor(output_image, cv2.COLOR_RGB2BGR) + if len(bboxes) > 0: + for i in range(len(bboxes)): + _bboxes = np.int32(bboxes[i]) + for j in range(_bboxes.shape[0]): + cv2.polylines( + output_image, + [np.reshape(_bboxes[j], (-1, 1, 2))], + True, + (0, 0, 255), + ) + + for i in range(len(affi_bboxes)): + cv2.polylines( + output_image, + [np.reshape(affi_bboxes[i].astype(np.int32), (-1, 1, 2))], + True, + (255, 0, 0), + ) + + 
target_gaussian_heatmap_color = imgproc.cvt2HeatmapImg(region_scores)
+    target_gaussian_affinity_heatmap_color = imgproc.cvt2HeatmapImg(affinity_scores)
+    confidence_mask_gray = imgproc.cvt2HeatmapImg(confidence_mask)
+
+    # overlay the region / affinity heatmaps on the input image
+    height, width, channel = image.shape
+    overlay_region = cv2.resize(target_gaussian_heatmap_color, (width, height))
+    overlay_aff = cv2.resize(target_gaussian_affinity_heatmap_color, (width, height))
+
+    overlay_region = cv2.addWeighted(image.copy(), 0.4, overlay_region, 0.6, 5)
+    overlay_aff = cv2.addWeighted(image.copy(), 0.4, overlay_aff, 0.6, 5)
+
+    heat_map = np.concatenate([overlay_region, overlay_aff], axis=1)
+
+    # SynthText image names arrive as arrays rather than strings; extract the base name
+    if type(imagename) is not str:
+        imagename = imagename[0].split("/")[-1][:-4]
+
+    output = np.concatenate([output_image, heat_map, confidence_mask_gray], axis=1)
+    outpath = vis_dir + f"/{imagename}.jpg"
+    if not os.path.exists(os.path.dirname(outpath)):
+        os.makedirs(os.path.dirname(outpath), exist_ok=True)
+
+    cv2.imwrite(outpath, output)
+    # print(f'Logging original image into {outpath}')
+
+
+def save_parser(args):
+    """Append the final run options to `{results_dir}/opt.txt`."""
+    with open(f"{args.results_dir}/opt.txt", "a", encoding="utf-8") as opt_file:
+        opt_log = "------------ Options -------------\n"
+        arg = vars(args)
+        for k, v in arg.items():
+            opt_log += f"{str(k)}: {str(v)}\n"
+        opt_log += "---------------------------------------\n"
+        print(opt_log)
+        opt_file.write(opt_log)
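+
+
+# Usage sketch (illustrative only, not called by the trainer): loading a checkpoint into a
+# bare CRAFT model with copyStateDict. Checkpoints written from DataParallel /
+# DistributedDataParallel models prefix every key with "module.", which copyStateDict strips.
+# The checkpoint path and the "craft" key below are hypothetical examples.
+#
+#   import torch
+#   from model.craft import CRAFT
+#
+#   net = CRAFT()
+#   ckpt = torch.load("exp/custom_data_train/CRAFT_25000.pth", map_location="cpu")
+#   state = ckpt["craft"] if isinstance(ckpt, dict) and "craft" in ckpt else ckpt
+#   net.load_state_dict(copyStateDict(state))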