diff --git a/trainer/craft/.gitignore b/trainer/craft/.gitignore
new file mode 100644
index 000000000..46d5bda86
--- /dev/null
+++ b/trainer/craft/.gitignore
@@ -0,0 +1,4 @@
+__pycache__/
+model/__pycache__/
+wandb/*
+vis_result/*
diff --git a/trainer/craft/README.md b/trainer/craft/README.md
new file mode 100644
index 000000000..823c83dc6
--- /dev/null
+++ b/trainer/craft/README.md
@@ -0,0 +1,105 @@
+# CRAFT-train
+Many users of the official CRAFT repository have asked for training code.
+
+However, the training code has not been published in the official CRAFT repository.
+
+Other reproductions exist, but there is a gap between their performance and the performance reported in the original paper (https://arxiv.org/pdf/1904.01941.pdf).
+
+A model trained with this code reaches a level of performance similar to that of the original paper.
+
+```bash
+├── config
+│ ├── syn_train.yaml
+│ └── custom_data_train.yaml
+├── data
+│ ├── pseudo_label
+│ │ ├── make_charbox.py
+│ │ └── watershed.py
+│ ├── boxEnlarge.py
+│ ├── dataset.py
+│ ├── gaussian.py
+│ ├── imgaug.py
+│ └── imgproc.py
+├── loss
+│ └── mseloss.py
+├── metrics
+│ └── eval_det_iou.py
+├── model
+│ ├── craft.py
+│ └── vgg16_bn.py
+├── utils
+│ ├── craft_utils.py
+│ ├── inference_boxes.py
+│ └── utils.py
+├── trainSynth.py
+├── train.py
+├── train_distributed.py
+├── eval.py
+├── data_root_dir (place dataset folder here)
+└── exp (model and experiment result files will be saved here)
+```
+
+### Installation
+
+Install using `pip`
+
+``` bash
+pip install -r requirements.txt
+```
+
+
+### Training
+1. Place your training and test data in the following structure
+ ```
+ └── data_root_dir (you can change root dir in yaml file)
+ ├── ch4_training_images
+ │ ├── img_1.jpg
+ │ └── img_2.jpg
+ ├── ch4_training_localization_transcription_gt
+ │ ├── gt_img_1.txt
+ │ └── gt_img_2.txt
+ ├── ch4_test_images
+ │ ├── img_1.jpg
+ │ └── img_2.jpg
+ └── ch4_training_localization_transcription_gt
+ ├── gt_img_1.txt
+ └── gt_img_2.txt
+ ```
+ * Format of the `localization_transcription_gt` files (see the parsing sketch below):
+ ```
+ 377,117,463,117,465,130,378,130,Genaxis Theatre
+ 493,115,519,115,519,131,493,131,[06]
+ 374,155,409,155,409,170,374,170,###
+ ```
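+ * Each line contains the 8 corner coordinates of one word quadrilateral (x1,y1,...,x4,y4) followed by its transcription; `###` marks do-not-care regions. As a rough sketch of how such a line is parsed (mirroring `load_img_gt_box` in `data/dataset.py`):
+    ```python
+    import numpy as np
+
+    line = "377,117,463,117,465,130,378,130,Genaxis Theatre"
+    fields = line.strip().split(",")
+    box = np.array(fields[:8], dtype=np.float32).reshape(4, 2)  # 4 corner points (x, y)
+    word = ",".join(fields[8:])  # the transcription itself may contain commas
+    ```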
+2. Write a configuration file in yaml format (example config files are provided in the `config` folder)
+    * To speed up training with multiple GPUs, set num_workers > 0
+3. Put the yaml file in the `config` folder
+4. Run the training script as shown below (if you have multiple GPUs, run train_distributed.py)
+5. Experiment results will be saved to ```./exp/[yaml]``` by default.
+
+* Step 1 : Train CRAFT on the SynthText dataset from scratch
+    * Note : This step is not necessary if you start training in step 2 from a pretrained SynthText checkpoint. You can download it, place it at `exp/CRAFT_clr_amp_29500.pth`, and change `ckpt_path` in the config file according to your local setup.
+ ```
+ CUDA_VISIBLE_DEVICES=0 python3 trainSynth.py --yaml=syn_train
+ ```
+
+* Step 2 : Train CRAFT on [SynthText + IC15] or a custom dataset
+ ```
+ CUDA_VISIBLE_DEVICES=0 python3 train.py --yaml=custom_data_train ## if you run on single GPU
+ CUDA_VISIBLE_DEVICES=0,1 python3 train_distributed.py --yaml=custom_data_train ## if you run on multi GPU
+ ```
+
+### Arguments
+* ```--yaml``` : configuration file name
+
+### Evaluation
+* In the official repository issues, the author mentioned that the F1-score for the first-row setting is around 0.75.
+* The official paper reports an F1-score of 0.87 for the second-row setting.
+  * If you lower the post-processing parameter 'text_threshold' from 0.85 to 0.75, the F1-score reaches 0.856.
+* Training 25k iterations with weak supervision took 14 hours on 8 RTX 3090 Ti GPUs.
+  * Half of the GPUs were assigned to training and the other half to the supervision (pseudo-label generation) setting.
+
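+* To evaluate a trained checkpoint on its own, set `trained_model` under the `test` section of your yaml file and run the eval script with the same `--yaml` argument used for training, for example:
+  ```
+  CUDA_VISIBLE_DEVICES=0 python3 eval.py --yaml=custom_data_train
+  ```
+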
+| Training Dataset | Evaluation Dataset | Precision | Recall | F1-score | pretrained model |
+| ------------- |-----|:-----:|:-----:|:-----:|-----:|
+| SynthText | ICDAR2013 | 0.801 | 0.748 | 0.773| download link|
+| SynthText + ICDAR2015 | ICDAR2015 | 0.909 | 0.794 | 0.848| download link|
diff --git a/trainer/craft/config/__init__.py b/trainer/craft/config/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/trainer/craft/config/custom_data_train.yaml b/trainer/craft/config/custom_data_train.yaml
new file mode 100644
index 000000000..bf7b577bf
--- /dev/null
+++ b/trainer/craft/config/custom_data_train.yaml
@@ -0,0 +1,100 @@
+wandb_opt: False
+
+results_dir: "./exp/"
+vis_test_dir: "./vis_result/"
+
+data_root_dir: "./data_root_dir/"
+score_gt_dir: None # "/data/ICDAR2015_official_supervision"
+mode: "weak_supervision"
+
+
+train:
+ backbone : vgg
+  use_synthtext: False # To mix SynthText into training as CRAFT did, turn this option on
+ synth_data_dir: "/data/SynthText/"
+ synth_ratio: 5
+ real_dataset: custom
+ ckpt_path: "./pretrained_model/CRAFT_clr_amp_29500.pth"
+ eval_interval: 1000
+ batch_size: 5
+ st_iter: 0
+ end_iter: 25000
+ lr: 0.0001
+ lr_decay: 7500
+ gamma: 0.2
+ weight_decay: 0.00001
+  num_workers: 0 # On a single GPU, train.py only works with num_workers = 0; on multi-GPU, you can set num_workers > 0 to speed up training
+ amp: True
+ loss: 2
+ neg_rto: 0.3
+ n_min_neg: 5000
+ data:
+ vis_opt: False
+ pseudo_vis_opt: False
+ output_size: 768
+ do_not_care_label: ['###', '']
+ mean: [0.485, 0.456, 0.406]
+ variance: [0.229, 0.224, 0.225]
+ enlarge_region : [0.5, 0.5] # x axis, y axis
+ enlarge_affinity: [0.5, 0.5]
+ gauss_init_size: 200
+ gauss_sigma: 40
+ watershed:
+ version: "skimage"
+ sure_fg_th: 0.75
+ sure_bg_th: 0.05
+ syn_sample: -1
+ custom_sample: -1
+ syn_aug:
+ random_scale:
+ range: [1.0, 1.5, 2.0]
+ option: False
+ random_rotate:
+ max_angle: 20
+ option: False
+ random_crop:
+ version: "random_resize_crop_synth"
+ option: True
+ random_horizontal_flip:
+ option: False
+ random_colorjitter:
+ brightness: 0.2
+ contrast: 0.2
+ saturation: 0.2
+ hue: 0.2
+ option: True
+ custom_aug:
+ random_scale:
+ range: [ 1.0, 1.5, 2.0 ]
+ option: False
+ random_rotate:
+ max_angle: 20
+ option: True
+ random_crop:
+ version: "random_resize_crop"
+ scale: [0.03, 0.4]
+ ratio: [0.75, 1.33]
+ rnd_threshold: 1.0
+ option: True
+ random_horizontal_flip:
+ option: True
+ random_colorjitter:
+ brightness: 0.2
+ contrast: 0.2
+ saturation: 0.2
+ hue: 0.2
+ option: True
+
+test:
+ trained_model : null
+ custom_data:
+ test_set_size: 500
+ test_data_dir: "./data_root_dir/"
+ text_threshold: 0.75
+ low_text: 0.5
+ link_threshold: 0.2
+ canvas_size: 2240
+ mag_ratio: 1.75
+ poly: False
+ cuda: True
+ vis_opt: False
diff --git a/trainer/craft/config/load_config.py b/trainer/craft/config/load_config.py
new file mode 100644
index 000000000..abe3551f2
--- /dev/null
+++ b/trainer/craft/config/load_config.py
@@ -0,0 +1,37 @@
+import os
+import yaml
+from functools import reduce
+
+CONFIG_PATH = os.path.dirname(__file__)
+
+def load_yaml(config_name):
+
+ with open(os.path.join(CONFIG_PATH, config_name)+ '.yaml') as file:
+ config = yaml.safe_load(file)
+
+ return config
+
+class DotDict(dict):
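+    """dict subclass with attribute access (e.g. cfg.train.lr) and dotted-key lookup (e.g. cfg["train.data.output_size"]); nested dicts are wrapped as DotDict."""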
+    def __getattr__(self, k):
+        try:
+            v = self[k]
+        except Exception:
+            raise AttributeError(k)
+ if isinstance(v, dict):
+ return DotDict(v)
+ return v
+
+ def __getitem__(self, k):
+ if isinstance(k, str) and '.' in k:
+ k = k.split('.')
+ if isinstance(k, (list, tuple)):
+ return reduce(lambda d, kk: d[kk], k, self)
+ return super().__getitem__(k)
+
+ def get(self, k, default=None):
+ if isinstance(k, str) and '.' in k:
+ try:
+ return self[k]
+ except KeyError:
+ return default
+ return super().get(k, default=default)
\ No newline at end of file
diff --git a/trainer/craft/config/syn_train.yaml b/trainer/craft/config/syn_train.yaml
new file mode 100644
index 000000000..a41663f7c
--- /dev/null
+++ b/trainer/craft/config/syn_train.yaml
@@ -0,0 +1,68 @@
+wandb_opt: False
+
+results_dir: "./exp/"
+vis_test_dir: "./vis_result/"
+data_dir:
+ synthtext: "/data/SynthText/"
+ synthtext_gt: NULL
+
+train:
+ backbone : vgg
+ dataset: ["synthtext"]
+ ckpt_path: null
+ eval_interval: 1000
+ batch_size: 5
+ st_iter: 0
+ end_iter: 50000
+ lr: 0.0001
+ lr_decay: 15000
+ gamma: 0.2
+ weight_decay: 0.00001
+ num_workers: 4
+ amp: True
+ loss: 3
+ neg_rto: 1
+ n_min_neg: 1000
+ data:
+ vis_opt: False
+ output_size: 768
+ mean: [0.485, 0.456, 0.406]
+ variance: [0.229, 0.224, 0.225]
+ enlarge_region : [0.5, 0.5] # x axis, y axis
+ enlarge_affinity: [0.5, 0.5]
+ gauss_init_size: 200
+ gauss_sigma: 40
+ syn_sample : -1
+ syn_aug:
+ random_scale:
+ range: [1.0, 1.5, 2.0]
+ option: False
+ random_rotate:
+ max_angle: 20
+ option: False
+ random_crop:
+ version: "random_resize_crop_synth"
+ rnd_threshold : 1.0
+ option: True
+ random_horizontal_flip:
+ option: False
+ random_colorjitter:
+ brightness: 0.2
+ contrast: 0.2
+ saturation: 0.2
+ hue: 0.2
+ option: True
+
+test:
+ trained_model: null
+ icdar2013:
+ test_set_size: 233
+ cuda: True
+ vis_opt: True
+ test_data_dir : "/data/ICDAR2013/"
+ text_threshold: 0.85
+ low_text: 0.5
+ link_threshold: 0.2
+ canvas_size: 960
+ mag_ratio: 1.5
+ poly: False
\ No newline at end of file
diff --git a/trainer/craft/data/boxEnlarge.py b/trainer/craft/data/boxEnlarge.py
new file mode 100644
index 000000000..73d5bc5c2
--- /dev/null
+++ b/trainer/craft/data/boxEnlarge.py
@@ -0,0 +1,65 @@
+import math
+import numpy as np
+
+
+def pointAngle(Apoint, Bpoint):
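+    # NOTE: despite the name, this returns the slope (dy / dx) between the two points, not an angle in radians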
+ angle = (Bpoint[1] - Apoint[1]) / ((Bpoint[0] - Apoint[0]) + 10e-8)
+ return angle
+
+def pointDistance(Apoint, Bpoint):
+ return math.sqrt((Bpoint[1] - Apoint[1])**2 + (Bpoint[0] - Apoint[0])**2)
+
+def lineBiasAndK(Apoint, Bpoint):
+
+ K = pointAngle(Apoint, Bpoint)
+ B = Apoint[1] - K*Apoint[0]
+ return K, B
+
+def getX(K, B, Ypoint):
+ return int((Ypoint-B)/K)
+
+def sidePoint(Apoint, Bpoint, h, w, placehold, enlarge_size):
+
+ K, B = lineBiasAndK(Apoint, Bpoint)
+ angle = abs(math.atan(pointAngle(Apoint, Bpoint)))
+ distance = pointDistance(Apoint, Bpoint)
+
+ x_enlarge_size, y_enlarge_size = enlarge_size
+
+ XaxisIncreaseDistance = abs(math.cos(angle) * x_enlarge_size * distance)
+ YaxisIncreaseDistance = abs(math.sin(angle) * y_enlarge_size * distance)
+
+ if placehold == 'leftTop':
+ x1 = max(0, Apoint[0] - XaxisIncreaseDistance)
+ y1 = max(0, Apoint[1] - YaxisIncreaseDistance)
+ elif placehold == 'rightTop':
+ x1 = min(w, Bpoint[0] + XaxisIncreaseDistance)
+ y1 = max(0, Bpoint[1] - YaxisIncreaseDistance)
+ elif placehold == 'rightBottom':
+ x1 = min(w, Bpoint[0] + XaxisIncreaseDistance)
+ y1 = min(h, Bpoint[1] + YaxisIncreaseDistance)
+ elif placehold == 'leftBottom':
+ x1 = max(0, Apoint[0] - XaxisIncreaseDistance)
+ y1 = min(h, Apoint[1] + YaxisIncreaseDistance)
+ return int(x1), int(y1)
+
+def enlargebox(box, h, w, enlarge_size, horizontal_text_bool):
+
+ if not horizontal_text_bool:
+ enlarge_size = (enlarge_size[1], enlarge_size[0])
+
+ box = np.roll(box, -np.argmin(box.sum(axis=1)), axis=0)
+
+ Apoint, Bpoint, Cpoint, Dpoint = box
+ K1, B1 = lineBiasAndK(box[0], box[2])
+ K2, B2 = lineBiasAndK(box[3], box[1])
+ X = (B2 - B1)/(K1 - K2)
+ Y = K1 * X + B1
+ center = [X, Y]
+
+ x1, y1 = sidePoint(Apoint, center, h, w, 'leftTop', enlarge_size)
+ x2, y2 = sidePoint(center, Bpoint, h, w, 'rightTop', enlarge_size)
+ x3, y3 = sidePoint(center, Cpoint, h, w, 'rightBottom', enlarge_size)
+ x4, y4 = sidePoint(Dpoint, center, h, w, 'leftBottom', enlarge_size)
+ newcharbox = np.array([[x1, y1], [x2, y2], [x3, y3], [x4, y4]])
+ return newcharbox
\ No newline at end of file
diff --git a/trainer/craft/data/dataset.py b/trainer/craft/data/dataset.py
new file mode 100644
index 000000000..e7e6943e4
--- /dev/null
+++ b/trainer/craft/data/dataset.py
@@ -0,0 +1,542 @@
+import os
+import re
+import itertools
+import random
+
+import numpy as np
+import scipy.io as scio
+from PIL import Image
+import cv2
+from torch.utils.data import Dataset
+import torchvision.transforms as transforms
+
+from data import imgproc
+from data.gaussian import GaussianBuilder
+from data.imgaug import (
+ rescale,
+ random_resize_crop_synth,
+ random_resize_crop,
+ random_horizontal_flip,
+ random_rotate,
+ random_scale,
+ random_crop,
+)
+from data.pseudo_label.make_charbox import PseudoCharBoxBuilder
+from utils.util import saveInput, saveImage
+
+
+class CraftBaseDataset(Dataset):
+ def __init__(
+ self,
+ output_size,
+ data_dir,
+ saved_gt_dir,
+ mean,
+ variance,
+ gauss_init_size,
+ gauss_sigma,
+ enlarge_region,
+ enlarge_affinity,
+ aug,
+ vis_test_dir,
+ vis_opt,
+ sample,
+ ):
+ self.output_size = output_size
+ self.data_dir = data_dir
+ self.saved_gt_dir = saved_gt_dir
+ self.mean, self.variance = mean, variance
+ self.gaussian_builder = GaussianBuilder(
+ gauss_init_size, gauss_sigma, enlarge_region, enlarge_affinity
+ )
+ self.aug = aug
+ self.vis_test_dir = vis_test_dir
+ self.vis_opt = vis_opt
+ self.sample = sample
+ if self.sample != -1:
+ random.seed(0)
+ self.idx = random.sample(range(0, len(self.img_names)), self.sample)
+
+ self.pre_crop_area = []
+
+ def augment_image(
+ self, image, region_score, affinity_score, confidence_mask, word_level_char_bbox
+ ):
+ augment_targets = [image, region_score, affinity_score, confidence_mask]
+
+ if self.aug.random_scale.option:
+ augment_targets, word_level_char_bbox = random_scale(
+ augment_targets, word_level_char_bbox, self.aug.random_scale.range
+ )
+
+ if self.aug.random_rotate.option:
+ augment_targets = random_rotate(
+ augment_targets, self.aug.random_rotate.max_angle
+ )
+
+ if self.aug.random_crop.option:
+ if self.aug.random_crop.version == "random_crop_with_bbox":
+ augment_targets = random_crop_with_bbox(
+ augment_targets, word_level_char_bbox, self.output_size
+ )
+ elif self.aug.random_crop.version == "random_resize_crop_synth":
+ augment_targets = random_resize_crop_synth(
+ augment_targets, self.output_size
+ )
+ elif self.aug.random_crop.version == "random_resize_crop":
+
+ if len(self.pre_crop_area) > 0:
+ pre_crop_area = self.pre_crop_area
+ else:
+ pre_crop_area = None
+
+ augment_targets = random_resize_crop(
+ augment_targets,
+ self.aug.random_crop.scale,
+ self.aug.random_crop.ratio,
+ self.output_size,
+ self.aug.random_crop.rnd_threshold,
+ pre_crop_area,
+ )
+
+ elif self.aug.random_crop.version == "random_crop":
+ augment_targets = random_crop(augment_targets, self.output_size,)
+
+ else:
+                raise ValueError("Undefined RandomCrop version")
+
+ if self.aug.random_horizontal_flip.option:
+ augment_targets = random_horizontal_flip(augment_targets)
+
+ if self.aug.random_colorjitter.option:
+ image, region_score, affinity_score, confidence_mask = augment_targets
+ image = Image.fromarray(image)
+ image = transforms.ColorJitter(
+ brightness=self.aug.random_colorjitter.brightness,
+ contrast=self.aug.random_colorjitter.contrast,
+ saturation=self.aug.random_colorjitter.saturation,
+ hue=self.aug.random_colorjitter.hue,
+ )(image)
+ else:
+ image, region_score, affinity_score, confidence_mask = augment_targets
+
+ return np.array(image), region_score, affinity_score, confidence_mask
+
+ def resize_to_half(self, ground_truth, interpolation):
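+        # CRAFT predicts score maps at half of the input resolution, so GT maps are downscaled to match the network output size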
+ return cv2.resize(
+ ground_truth,
+ (self.output_size // 2, self.output_size // 2),
+ interpolation=interpolation,
+ )
+
+ def __len__(self):
+ if self.sample != -1:
+ return len(self.idx)
+ else:
+ return len(self.img_names)
+
+ def __getitem__(self, index):
+ if self.sample != -1:
+ index = self.idx[index]
+ if self.saved_gt_dir is None:
+ (
+ image,
+ region_score,
+ affinity_score,
+ confidence_mask,
+ word_level_char_bbox,
+ all_affinity_bbox,
+ words,
+ ) = self.make_gt_score(index)
+ else:
+ (
+ image,
+ region_score,
+ affinity_score,
+ confidence_mask,
+ word_level_char_bbox,
+ words,
+ ) = self.load_saved_gt_score(index)
+ all_affinity_bbox = []
+
+ if self.vis_opt:
+ saveImage(
+ self.img_names[index],
+ self.vis_test_dir,
+ image.copy(),
+ word_level_char_bbox.copy(),
+ all_affinity_bbox.copy(),
+ region_score.copy(),
+ affinity_score.copy(),
+ confidence_mask.copy(),
+ )
+
+ image, region_score, affinity_score, confidence_mask = self.augment_image(
+ image, region_score, affinity_score, confidence_mask, word_level_char_bbox
+ )
+
+ if self.vis_opt:
+ saveInput(
+ self.img_names[index],
+ self.vis_test_dir,
+ image,
+ region_score,
+ affinity_score,
+ confidence_mask,
+ )
+
+ region_score = self.resize_to_half(region_score, interpolation=cv2.INTER_CUBIC)
+ affinity_score = self.resize_to_half(
+ affinity_score, interpolation=cv2.INTER_CUBIC
+ )
+ confidence_mask = self.resize_to_half(
+ confidence_mask, interpolation=cv2.INTER_NEAREST
+ )
+
+ image = imgproc.normalizeMeanVariance(
+ np.array(image), mean=self.mean, variance=self.variance
+ )
+ image = image.transpose(2, 0, 1)
+
+ return image, region_score, affinity_score, confidence_mask
+
+
+class SynthTextDataSet(CraftBaseDataset):
+ def __init__(
+ self,
+ output_size,
+ data_dir,
+ saved_gt_dir,
+ mean,
+ variance,
+ gauss_init_size,
+ gauss_sigma,
+ enlarge_region,
+ enlarge_affinity,
+ aug,
+ vis_test_dir,
+ vis_opt,
+ sample,
+ ):
+ super().__init__(
+ output_size,
+ data_dir,
+ saved_gt_dir,
+ mean,
+ variance,
+ gauss_init_size,
+ gauss_sigma,
+ enlarge_region,
+ enlarge_affinity,
+ aug,
+ vis_test_dir,
+ vis_opt,
+ sample,
+ )
+ self.img_names, self.char_bbox, self.img_words = self.load_data()
+ self.vis_index = list(range(1000))
+
+ def load_data(self, bbox="char"):
+
+ gt = scio.loadmat(os.path.join(self.data_dir, "gt.mat"))
+ img_names = gt["imnames"][0]
+ img_words = gt["txt"][0]
+
+ if bbox == "char":
+ img_bbox = gt["charBB"][0]
+ else:
+ img_bbox = gt["wordBB"][0] # word bbox needed for test
+
+ return img_names, img_bbox, img_words
+
+ def dilate_img_to_output_size(self, image, char_bbox):
+ h, w, _ = image.shape
+ if min(h, w) <= self.output_size:
+ scale = float(self.output_size) / min(h, w)
+ else:
+ scale = 1.0
+ image = cv2.resize(
+ image, dsize=None, fx=scale, fy=scale, interpolation=cv2.INTER_CUBIC
+ )
+ char_bbox *= scale
+ return image, char_bbox
+
+ def make_gt_score(self, index):
+ img_path = os.path.join(self.data_dir, self.img_names[index][0])
+ image = cv2.imread(img_path, cv2.IMREAD_COLOR)
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+ all_char_bbox = self.char_bbox[index].transpose(
+ (2, 1, 0)
+ ) # shape : (Number of characters in image, 4, 2)
+
+ img_h, img_w, _ = image.shape
+
+ confidence_mask = np.ones((img_h, img_w), dtype=np.float32)
+
+ words = [
+ re.split(" \n|\n |\n| ", word.strip()) for word in self.img_words[index]
+ ]
+ words = list(itertools.chain(*words))
+ words = [word for word in words if len(word) > 0]
+
+ word_level_char_bbox = []
+ char_idx = 0
+
+ for i in range(len(words)):
+ length_of_word = len(words[i])
+ word_bbox = all_char_bbox[char_idx : char_idx + length_of_word]
+ assert len(word_bbox) == length_of_word
+ char_idx += length_of_word
+ word_bbox = np.array(word_bbox)
+ word_level_char_bbox.append(word_bbox)
+
+ region_score = self.gaussian_builder.generate_region(
+ img_h,
+ img_w,
+ word_level_char_bbox,
+ horizontal_text_bools=[True for _ in range(len(words))],
+ )
+ affinity_score, all_affinity_bbox = self.gaussian_builder.generate_affinity(
+ img_h,
+ img_w,
+ word_level_char_bbox,
+ horizontal_text_bools=[True for _ in range(len(words))],
+ )
+
+ return (
+ image,
+ region_score,
+ affinity_score,
+ confidence_mask,
+ word_level_char_bbox,
+ all_affinity_bbox,
+ words,
+ )
+
+
+class CustomDataset(CraftBaseDataset):
+ def __init__(
+ self,
+ output_size,
+ data_dir,
+ saved_gt_dir,
+ mean,
+ variance,
+ gauss_init_size,
+ gauss_sigma,
+ enlarge_region,
+ enlarge_affinity,
+ aug,
+ vis_test_dir,
+ vis_opt,
+ sample,
+ watershed_param,
+ pseudo_vis_opt,
+ do_not_care_label,
+ ):
+ super().__init__(
+ output_size,
+ data_dir,
+ saved_gt_dir,
+ mean,
+ variance,
+ gauss_init_size,
+ gauss_sigma,
+ enlarge_region,
+ enlarge_affinity,
+ aug,
+ vis_test_dir,
+ vis_opt,
+ sample,
+ )
+ self.pseudo_vis_opt = pseudo_vis_opt
+ self.do_not_care_label = do_not_care_label
+ self.pseudo_charbox_builder = PseudoCharBoxBuilder(
+ watershed_param, vis_test_dir, pseudo_vis_opt, self.gaussian_builder
+ )
+ self.vis_index = list(range(1000))
+ self.img_dir = os.path.join(data_dir, "ch4_training_images")
+ self.img_gt_box_dir = os.path.join(
+ data_dir, "ch4_training_localization_transcription_gt"
+ )
+ self.img_names = os.listdir(self.img_dir)
+
+ def update_model(self, net):
+ self.net = net
+
+ def update_device(self, gpu):
+ self.gpu = gpu
+
+ def load_img_gt_box(self, img_gt_box_path):
+ lines = open(img_gt_box_path, encoding="utf-8").readlines()
+ word_bboxes = []
+ words = []
+ for line in lines:
+ box_info = line.strip().encode("utf-8").decode("utf-8-sig").split(",")
+ box_points = [int(box_info[i]) for i in range(8)]
+ box_points = np.array(box_points, np.float32).reshape(4, 2)
+ word = box_info[8:]
+ word = ",".join(word)
+ if word in self.do_not_care_label:
+ words.append(self.do_not_care_label[0])
+ word_bboxes.append(box_points)
+ continue
+ word_bboxes.append(box_points)
+ words.append(word)
+ return np.array(word_bboxes), words
+
+ def load_data(self, index):
+ img_name = self.img_names[index]
+ img_path = os.path.join(self.img_dir, img_name)
+ image = cv2.imread(img_path)
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+
+ img_gt_box_path = os.path.join(
+ self.img_gt_box_dir, "gt_%s.txt" % os.path.splitext(img_name)[0]
+ )
+ word_bboxes, words = self.load_img_gt_box(
+ img_gt_box_path
+ ) # shape : (Number of word bbox, 4, 2)
+ confidence_mask = np.ones((image.shape[0], image.shape[1]), np.float32)
+
+ word_level_char_bbox = []
+ do_care_words = []
+ horizontal_text_bools = []
+
+ if len(word_bboxes) == 0:
+ return (
+ image,
+ word_level_char_bbox,
+ do_care_words,
+ confidence_mask,
+ horizontal_text_bools,
+ )
+ _word_bboxes = word_bboxes.copy()
+ for i in range(len(word_bboxes)):
+ if words[i] in self.do_not_care_label:
+ cv2.fillPoly(confidence_mask, [np.int32(_word_bboxes[i])], 0)
+ continue
+
+ (
+ pseudo_char_bbox,
+ confidence,
+ horizontal_text_bool,
+ ) = self.pseudo_charbox_builder.build_char_box(
+ self.net, self.gpu, image, word_bboxes[i], words[i], img_name=img_name
+ )
+
+ cv2.fillPoly(confidence_mask, [np.int32(_word_bboxes[i])], confidence)
+ do_care_words.append(words[i])
+ word_level_char_bbox.append(pseudo_char_bbox)
+ horizontal_text_bools.append(horizontal_text_bool)
+
+ return (
+ image,
+ word_level_char_bbox,
+ do_care_words,
+ confidence_mask,
+ horizontal_text_bools,
+ )
+
+ def make_gt_score(self, index):
+ """
+ Make region, affinity scores using pseudo character-level GT bounding box
+ word_level_char_bbox's shape : [word_num, [char_num_in_one_word, 4, 2]]
+ :rtype region_score: np.float32
+ :rtype affinity_score: np.float32
+ :rtype confidence_mask: np.float32
+ :rtype word_level_char_bbox: np.float32
+ :rtype words: list
+ """
+ (
+ image,
+ word_level_char_bbox,
+ words,
+ confidence_mask,
+ horizontal_text_bools,
+ ) = self.load_data(index)
+ img_h, img_w, _ = image.shape
+
+ if len(word_level_char_bbox) == 0:
+ region_score = np.zeros((img_h, img_w), dtype=np.float32)
+ affinity_score = np.zeros((img_h, img_w), dtype=np.float32)
+ all_affinity_bbox = []
+ else:
+ region_score = self.gaussian_builder.generate_region(
+ img_h, img_w, word_level_char_bbox, horizontal_text_bools
+ )
+ affinity_score, all_affinity_bbox = self.gaussian_builder.generate_affinity(
+ img_h, img_w, word_level_char_bbox, horizontal_text_bools
+ )
+
+ return (
+ image,
+ region_score,
+ affinity_score,
+ confidence_mask,
+ word_level_char_bbox,
+ all_affinity_bbox,
+ words,
+ )
+
+ def load_saved_gt_score(self, index):
+ """
+ Load pre-saved official CRAFT model's region, affinity scores to train
+ word_level_char_bbox's shape : [word_num, [char_num_in_one_word, 4, 2]]
+ :rtype region_score: np.float32
+ :rtype affinity_score: np.float32
+ :rtype confidence_mask: np.float32
+ :rtype word_level_char_bbox: np.float32
+ :rtype words: list
+ """
+ img_name = self.img_names[index]
+ img_path = os.path.join(self.img_dir, img_name)
+ image = cv2.imread(img_path)
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+
+ img_gt_box_path = os.path.join(
+ self.img_gt_box_dir, "gt_%s.txt" % os.path.splitext(img_name)[0]
+ )
+ word_bboxes, words = self.load_img_gt_box(img_gt_box_path)
+ image, word_bboxes = rescale(image, word_bboxes)
+ img_h, img_w, _ = image.shape
+
+ query_idx = int(self.img_names[index].split(".")[0].split("_")[1])
+
+ saved_region_scores_path = os.path.join(
+ self.saved_gt_dir, f"res_img_{query_idx}_region.jpg"
+ )
+ saved_affi_scores_path = os.path.join(
+ self.saved_gt_dir, f"res_img_{query_idx}_affi.jpg"
+ )
+ saved_cf_mask_path = os.path.join(
+ self.saved_gt_dir, f"res_img_{query_idx}_cf_mask_thresh_0.6.jpg"
+ )
+ region_score = cv2.imread(saved_region_scores_path, cv2.IMREAD_GRAYSCALE)
+ affinity_score = cv2.imread(saved_affi_scores_path, cv2.IMREAD_GRAYSCALE)
+ confidence_mask = cv2.imread(saved_cf_mask_path, cv2.IMREAD_GRAYSCALE)
+
+ region_score = cv2.resize(region_score, (img_w, img_h))
+ affinity_score = cv2.resize(affinity_score, (img_w, img_h))
+ confidence_mask = cv2.resize(
+ confidence_mask, (img_w, img_h), interpolation=cv2.INTER_NEAREST
+ )
+
+ region_score = region_score.astype(np.float32) / 255
+ affinity_score = affinity_score.astype(np.float32) / 255
+ confidence_mask = confidence_mask.astype(np.float32) / 255
+
+        # NOTE : word_level_char_bbox is not strictly needed here, but keep the bbox format aligned with make_gt_score()
+ word_level_char_bbox = []
+
+ for i in range(len(word_bboxes)):
+ word_level_char_bbox.append(np.expand_dims(word_bboxes[i], 0))
+
+ return (
+ image,
+ region_score,
+ affinity_score,
+ confidence_mask,
+ word_level_char_bbox,
+ words,
+ )
diff --git a/trainer/craft/data/gaussian.py b/trainer/craft/data/gaussian.py
new file mode 100644
index 000000000..2d0b76e0a
--- /dev/null
+++ b/trainer/craft/data/gaussian.py
@@ -0,0 +1,192 @@
+import numpy as np
+import cv2
+
+from data.boxEnlarge import enlargebox
+
+
+class GaussianBuilder(object):
+ def __init__(self, init_size, sigma, enlarge_region, enlarge_affinity):
+ self.init_size = init_size
+ self.sigma = sigma
+ self.enlarge_region = enlarge_region
+ self.enlarge_affinity = enlarge_affinity
+ self.gaussian_map, self.gaussian_map_color = self.generate_gaussian_map()
+
+ def generate_gaussian_map(self):
+ circle_mask = self.generate_circle_mask()
+
+ gaussian_map = np.zeros((self.init_size, self.init_size), np.float32)
+
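+        # Isotropic 2D Gaussian centered at the map center, exp(-((i - c)^2 + (j - c)^2) / (2 * sigma^2)); it is masked to a circle below and rescaled so its peak value is 1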
+ for i in range(self.init_size):
+ for j in range(self.init_size):
+ gaussian_map[i, j] = (
+ 1
+ / 2
+ / np.pi
+ / (self.sigma ** 2)
+ * np.exp(
+ -1
+ / 2
+ * (
+ (i - self.init_size / 2) ** 2 / (self.sigma ** 2)
+ + (j - self.init_size / 2) ** 2 / (self.sigma ** 2)
+ )
+ )
+ )
+
+ gaussian_map = gaussian_map * circle_mask
+ gaussian_map = (gaussian_map / np.max(gaussian_map)).astype(np.float32)
+
+ gaussian_map_color = (gaussian_map * 255).astype(np.uint8)
+ gaussian_map_color = cv2.applyColorMap(gaussian_map_color, cv2.COLORMAP_JET)
+ return gaussian_map, gaussian_map_color
+
+ def generate_circle_mask(self):
+
+ zero_arr = np.zeros((self.init_size, self.init_size), np.float32)
+ circle_mask = cv2.circle(
+ img=zero_arr,
+ center=(self.init_size // 2, self.init_size // 2),
+ radius=self.init_size // 2,
+ color=1,
+ thickness=-1,
+ )
+
+ return circle_mask
+
+ def four_point_transform(self, bbox):
+ """
+        Warp the canonical 2D Gaussian map onto the given character bbox with a perspective transform and return the warped map along with its width and height.
+ """
+ width, height = (
+ np.max(bbox[:, 0]).astype(np.int32),
+ np.max(bbox[:, 1]).astype(np.int32),
+ )
+ init_points = np.array(
+ [
+ [0, 0],
+ [self.init_size, 0],
+ [self.init_size, self.init_size],
+ [0, self.init_size],
+ ],
+ dtype="float32",
+ )
+
+ M = cv2.getPerspectiveTransform(init_points, bbox)
+ warped_gaussian_map = cv2.warpPerspective(self.gaussian_map, M, (width, height))
+ return warped_gaussian_map, width, height
+
+ def add_gaussian_map_to_score_map(
+ self, score_map, bbox, enlarge_size, horizontal_text_bool, map_type=None
+ ):
+ """
+ Mapping 2D Gaussian to the character box coordinates of the score_map.
+
+ :param score_map: Target map to put 2D gaussian on character box
+ :type score_map: np.float32
+ :param bbox: character boxes
+ :type bbox: np.float32
+ :param enlarge_size: Enlarge size of gaussian map to fit character shape
+ :type enlarge_size: list of enlarge size [x dim, y dim]
+ :param horizontal_text_bool: Flag that bbox is horizontal text or not
+ :type horizontal_text_bool: bool
+ :param map_type: Whether map's type is "region" | "affinity"
+ :type map_type: str
+ :return score_map: score map that all 2D gaussian put on character box
+ :rtype: np.float32
+ """
+
+ map_h, map_w = score_map.shape
+ bbox = enlargebox(bbox, map_h, map_w, enlarge_size, horizontal_text_bool)
+
+        # If any point of the character bbox is out of range, don't put it on the map
+ if np.any(bbox < 0) or np.any(bbox[:, 0] > map_w) or np.any(bbox[:, 1] > map_h):
+ return score_map
+
+ bbox_left, bbox_top = np.array([np.min(bbox[:, 0]), np.min(bbox[:, 1])]).astype(
+ np.int32
+ )
+ bbox -= (bbox_left, bbox_top)
+ warped_gaussian_map, width, height = self.four_point_transform(
+ bbox.astype(np.float32)
+ )
+
+ try:
+ bbox_area_of_image = score_map[
+ bbox_top : bbox_top + height, bbox_left : bbox_left + width,
+ ]
+ high_value_score = np.where(
+ warped_gaussian_map > bbox_area_of_image,
+ warped_gaussian_map,
+ bbox_area_of_image,
+ )
+ score_map[
+ bbox_top : bbox_top + height, bbox_left : bbox_left + width,
+ ] = high_value_score
+
+ except Exception as e:
+ print("Error : {}".format(e))
+ print(
+ "On generating {} map, strange box came out. (width: {}, height: {})".format(
+ map_type, width, height
+ )
+ )
+
+ return score_map
+
+ def calculate_affinity_box_points(self, bbox_1, bbox_2, vertical=False):
+ center_1, center_2 = np.mean(bbox_1, axis=0), np.mean(bbox_2, axis=0)
+ if vertical:
+ tl = (bbox_1[0] + bbox_1[-1] + center_1) / 3
+ tr = (bbox_1[1:3].sum(0) + center_1) / 3
+ br = (bbox_2[1:3].sum(0) + center_2) / 3
+ bl = (bbox_2[0] + bbox_2[-1] + center_2) / 3
+ else:
+ tl = (bbox_1[0:2].sum(0) + center_1) / 3
+ tr = (bbox_2[0:2].sum(0) + center_2) / 3
+ br = (bbox_2[2:4].sum(0) + center_2) / 3
+ bl = (bbox_1[2:4].sum(0) + center_1) / 3
+ affinity_box = np.array([tl, tr, br, bl]).astype(np.float32)
+ return affinity_box
+
+ def generate_region(
+ self, img_h, img_w, word_level_char_bbox, horizontal_text_bools
+ ):
+ region_map = np.zeros([img_h, img_w], dtype=np.float32)
+ for i in range(
+ len(word_level_char_bbox)
+ ): # shape : [word_num, [char_num_in_one_word, 4, 2]]
+ for j in range(len(word_level_char_bbox[i])):
+ region_map = self.add_gaussian_map_to_score_map(
+ region_map,
+ word_level_char_bbox[i][j].copy(),
+ self.enlarge_region,
+ horizontal_text_bools[i],
+ map_type="region",
+ )
+ return region_map
+
+ def generate_affinity(
+ self, img_h, img_w, word_level_char_bbox, horizontal_text_bools
+ ):
+
+ affinity_map = np.zeros([img_h, img_w], dtype=np.float32)
+ all_affinity_bbox = []
+ for i in range(len(word_level_char_bbox)):
+ for j in range(len(word_level_char_bbox[i]) - 1):
+ affinity_bbox = self.calculate_affinity_box_points(
+ word_level_char_bbox[i][j], word_level_char_bbox[i][j + 1]
+ )
+
+ affinity_map = self.add_gaussian_map_to_score_map(
+ affinity_map,
+ affinity_bbox.copy(),
+ self.enlarge_affinity,
+ horizontal_text_bools[i],
+ map_type="affinity",
+ )
+ all_affinity_bbox.append(np.expand_dims(affinity_bbox, axis=0))
+
+ if len(all_affinity_bbox) > 0:
+ all_affinity_bbox = np.concatenate(all_affinity_bbox, axis=0)
+ return affinity_map, all_affinity_bbox
\ No newline at end of file
diff --git a/trainer/craft/data/imgaug.py b/trainer/craft/data/imgaug.py
new file mode 100644
index 000000000..d24a456d2
--- /dev/null
+++ b/trainer/craft/data/imgaug.py
@@ -0,0 +1,175 @@
+import random
+
+import cv2
+import numpy as np
+from PIL import Image
+from torchvision.transforms.functional import resized_crop, crop
+from torchvision.transforms import RandomResizedCrop, RandomCrop
+from torchvision.transforms import InterpolationMode
+
+
+def rescale(img, bboxes, target_size=2240):
+ h, w = img.shape[0:2]
+ scale = target_size / max(h, w)
+ img = cv2.resize(img, dsize=None, fx=scale, fy=scale, interpolation=cv2.INTER_CUBIC)
+ bboxes = bboxes * scale
+ return img, bboxes
+
+
+def random_resize_crop_synth(augment_targets, size):
+ image, region_score, affinity_score, confidence_mask = augment_targets
+
+ image = Image.fromarray(image)
+ region_score = Image.fromarray(region_score)
+ affinity_score = Image.fromarray(affinity_score)
+ confidence_mask = Image.fromarray(confidence_mask)
+
+ short_side = min(image.size)
+ i, j, h, w = RandomCrop.get_params(image, output_size=(short_side, short_side))
+
+ image = resized_crop(
+ image, i, j, h, w, size=(size, size), interpolation=InterpolationMode.BICUBIC
+ )
+ region_score = resized_crop(
+ region_score, i, j, h, w, (size, size), interpolation=InterpolationMode.BICUBIC
+ )
+ affinity_score = resized_crop(
+ affinity_score,
+ i,
+ j,
+ h,
+ w,
+ (size, size),
+ interpolation=InterpolationMode.BICUBIC,
+ )
+ confidence_mask = resized_crop(
+ confidence_mask,
+ i,
+ j,
+ h,
+ w,
+ (size, size),
+ interpolation=InterpolationMode.NEAREST,
+ )
+
+ image = np.array(image)
+ region_score = np.array(region_score)
+ affinity_score = np.array(affinity_score)
+ confidence_mask = np.array(confidence_mask)
+ augment_targets = [image, region_score, affinity_score, confidence_mask]
+
+ return augment_targets
+
+
+def random_resize_crop(
+ augment_targets, scale, ratio, size, threshold, pre_crop_area=None
+):
+ image, region_score, affinity_score, confidence_mask = augment_targets
+
+ image = Image.fromarray(image)
+ region_score = Image.fromarray(region_score)
+ affinity_score = Image.fromarray(affinity_score)
+ confidence_mask = Image.fromarray(confidence_mask)
+
+ if pre_crop_area != None:
+ i, j, h, w = pre_crop_area
+
+ else:
+ if random.random() < threshold:
+ i, j, h, w = RandomResizedCrop.get_params(image, scale=scale, ratio=ratio)
+ else:
+ i, j, h, w = RandomResizedCrop.get_params(
+ image, scale=(1.0, 1.0), ratio=(1.0, 1.0)
+ )
+
+ image = resized_crop(
+ image, i, j, h, w, size=(size, size), interpolation=InterpolationMode.BICUBIC
+ )
+ region_score = resized_crop(
+ region_score, i, j, h, w, (size, size), interpolation=InterpolationMode.BICUBIC
+ )
+ affinity_score = resized_crop(
+ affinity_score,
+ i,
+ j,
+ h,
+ w,
+ (size, size),
+ interpolation=InterpolationMode.BICUBIC,
+ )
+ confidence_mask = resized_crop(
+ confidence_mask,
+ i,
+ j,
+ h,
+ w,
+ (size, size),
+ interpolation=InterpolationMode.NEAREST,
+ )
+
+ image = np.array(image)
+ region_score = np.array(region_score)
+ affinity_score = np.array(affinity_score)
+ confidence_mask = np.array(confidence_mask)
+ augment_targets = [image, region_score, affinity_score, confidence_mask]
+
+ return augment_targets
+
+
+def random_crop(augment_targets, size):
+ image, region_score, affinity_score, confidence_mask = augment_targets
+
+ image = Image.fromarray(image)
+ region_score = Image.fromarray(region_score)
+ affinity_score = Image.fromarray(affinity_score)
+ confidence_mask = Image.fromarray(confidence_mask)
+
+ i, j, h, w = RandomCrop.get_params(image, output_size=(size, size))
+
+ image = crop(image, i, j, h, w)
+ region_score = crop(region_score, i, j, h, w)
+ affinity_score = crop(affinity_score, i, j, h, w)
+ confidence_mask = crop(confidence_mask, i, j, h, w)
+
+ image = np.array(image)
+ region_score = np.array(region_score)
+ affinity_score = np.array(affinity_score)
+ confidence_mask = np.array(confidence_mask)
+ augment_targets = [image, region_score, affinity_score, confidence_mask]
+
+ return augment_targets
+
+
+def random_horizontal_flip(imgs):
+ if random.random() < 0.5:
+ for i in range(len(imgs)):
+ imgs[i] = np.flip(imgs[i], axis=1).copy()
+ return imgs
+
+
+def random_scale(images, word_level_char_bbox, scale_range):
+ scale = random.sample(scale_range, 1)[0]
+
+ for i in range(len(images)):
+ images[i] = cv2.resize(images[i], dsize=None, fx=scale, fy=scale)
+
+ for i in range(len(word_level_char_bbox)):
+ word_level_char_bbox[i] *= scale
+
+ return images
+
+
+def random_rotate(images, max_angle):
+ angle = random.random() * 2 * max_angle - max_angle
+ for i in range(len(images)):
+ img = images[i]
+ w, h = img.shape[:2]
+ rotation_matrix = cv2.getRotationMatrix2D((h / 2, w / 2), angle, 1)
+ if i == len(images) - 1:
+ img_rotation = cv2.warpAffine(
+ img, M=rotation_matrix, dsize=(h, w), flags=cv2.INTER_NEAREST
+ )
+ else:
+ img_rotation = cv2.warpAffine(img, rotation_matrix, (h, w))
+ images[i] = img_rotation
+ return images
diff --git a/trainer/craft/data/imgproc.py b/trainer/craft/data/imgproc.py
new file mode 100644
index 000000000..9c302af06
--- /dev/null
+++ b/trainer/craft/data/imgproc.py
@@ -0,0 +1,91 @@
+"""
+Copyright (c) 2019-present NAVER Corp.
+MIT License
+"""
+
+# -*- coding: utf-8 -*-
+import numpy as np
+
+import cv2
+from skimage import io
+
+
+def loadImage(img_file):
+ img = io.imread(img_file) # RGB order
+ if img.shape[0] == 2:
+ img = img[0]
+ if len(img.shape) == 2:
+ img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
+ if img.shape[2] == 4:
+ img = img[:, :, :3]
+ img = np.array(img)
+
+ return img
+
+
+def normalizeMeanVariance(
+ in_img, mean=(0.485, 0.456, 0.406), variance=(0.229, 0.224, 0.225)
+):
+ # should be RGB order
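+    # ImageNet-style normalization in 0-255 space: (img - mean * 255) / (variance * 255), where the variance values are per-channel std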
+ img = in_img.copy().astype(np.float32)
+
+ img -= np.array(
+ [mean[0] * 255.0, mean[1] * 255.0, mean[2] * 255.0], dtype=np.float32
+ )
+ img /= np.array(
+ [variance[0] * 255.0, variance[1] * 255.0, variance[2] * 255.0],
+ dtype=np.float32,
+ )
+ return img
+
+
+def denormalizeMeanVariance(
+ in_img, mean=(0.485, 0.456, 0.406), variance=(0.229, 0.224, 0.225)
+):
+ # should be RGB order
+ img = in_img.copy()
+ img *= variance
+ img += mean
+ img *= 255.0
+ img = np.clip(img, 0, 255).astype(np.uint8)
+ return img
+
+
+def resize_aspect_ratio(img, square_size, interpolation, mag_ratio=1):
+ height, width, channel = img.shape
+
+ # magnify image size
+ target_size = mag_ratio * max(height, width)
+
+ # set original image size
+ if target_size > square_size:
+ target_size = square_size
+
+ ratio = target_size / max(height, width)
+
+ target_h, target_w = int(height * ratio), int(width * ratio)
+
+ # NOTE
+ valid_size_heatmap = (int(target_h / 2), int(target_w / 2))
+
+ proc = cv2.resize(img, (target_w, target_h), interpolation=interpolation)
+
+ # make canvas and paste image
+ target_h32, target_w32 = target_h, target_w
+ if target_h % 32 != 0:
+ target_h32 = target_h + (32 - target_h % 32)
+ if target_w % 32 != 0:
+ target_w32 = target_w + (32 - target_w % 32)
+ resized = np.zeros((target_h32, target_w32, channel), dtype=np.float32)
+ resized[0:target_h, 0:target_w, :] = proc
+
+ # target_h, target_w = target_h32, target_w32
+ # size_heatmap = (int(target_w/2), int(target_h/2))
+
+ return resized, ratio, valid_size_heatmap
+
+
+def cvt2HeatmapImg(img):
+ img = (np.clip(img, 0, 1) * 255).astype(np.uint8)
+ img = cv2.applyColorMap(img, cv2.COLORMAP_JET)
+ return img
diff --git a/trainer/craft/data/pseudo_label/make_charbox.py b/trainer/craft/data/pseudo_label/make_charbox.py
new file mode 100644
index 000000000..09c5219b3
--- /dev/null
+++ b/trainer/craft/data/pseudo_label/make_charbox.py
@@ -0,0 +1,263 @@
+import os
+import random
+import math
+
+import numpy as np
+import cv2
+import torch
+
+from data import imgproc
+from data.pseudo_label.watershed import exec_watershed_by_version
+
+
+class PseudoCharBoxBuilder:
+ def __init__(self, watershed_param, vis_test_dir, pseudo_vis_opt, gaussian_builder):
+ self.watershed_param = watershed_param
+ self.vis_test_dir = vis_test_dir
+ self.pseudo_vis_opt = pseudo_vis_opt
+ self.gaussian_builder = gaussian_builder
+ self.cnt = 0
+ self.flag = False
+
+ def crop_image_by_bbox(self, image, box, word):
+ w = max(
+ int(np.linalg.norm(box[0] - box[1])), int(np.linalg.norm(box[2] - box[3]))
+ )
+ h = max(
+ int(np.linalg.norm(box[0] - box[3])), int(np.linalg.norm(box[1] - box[2]))
+ )
+        # Guard against degenerate word boxes with zero width
+        if w == 0:
+            raise ValueError("Degenerate word box with zero width: {}".format(box))
+        word_ratio = h / w
+
+ one_char_ratio = min(h, w) / (max(h, w) / len(word))
+
+        # NOTE: the criterion for classifying vertical words here is tuned to work properly on the IC15 dataset
+ if word_ratio > 2 or (word_ratio > 1.6 and one_char_ratio > 2.4):
+ # warping method of vertical word (classified by upper condition)
+ horizontal_text_bool = False
+ long_side = h
+ short_side = w
+ M = cv2.getPerspectiveTransform(
+ np.float32(box),
+ np.float32(
+ np.array(
+ [
+ [long_side, 0],
+ [long_side, short_side],
+ [0, short_side],
+ [0, 0],
+ ]
+ )
+ ),
+ )
+ self.flag = True
+ else:
+ # warping method of horizontal word
+ horizontal_text_bool = True
+ long_side = w
+ short_side = h
+ M = cv2.getPerspectiveTransform(
+ np.float32(box),
+ np.float32(
+ np.array(
+ [
+ [0, 0],
+ [long_side, 0],
+ [long_side, short_side],
+ [0, short_side],
+ ]
+ )
+ ),
+ )
+ self.flag = False
+
+ warped = cv2.warpPerspective(image, M, (long_side, short_side))
+ return warped, M, horizontal_text_bool
+
+ def inference_word_box(self, net, gpu, word_image):
+ if net.training:
+ net.eval()
+
+ with torch.no_grad():
+ word_img_torch = torch.from_numpy(
+ imgproc.normalizeMeanVariance(
+ word_image,
+ mean=(0.485, 0.456, 0.406),
+ variance=(0.229, 0.224, 0.225),
+ )
+ )
+ word_img_torch = word_img_torch.permute(2, 0, 1).unsqueeze(0)
+ word_img_torch = word_img_torch.type(torch.FloatTensor).cuda(gpu)
+ with torch.cuda.amp.autocast():
+ word_img_scores, _ = net(word_img_torch)
+ return word_img_scores
+
+ def visualize_pseudo_label(
+ self, word_image, region_score, watershed_box, pseudo_char_bbox, img_name,
+ ):
+ word_img_h, word_img_w, _ = word_image.shape
+ word_img_cp1 = word_image.copy()
+ word_img_cp2 = word_image.copy()
+ _watershed_box = np.int32(watershed_box)
+ _pseudo_char_bbox = np.int32(pseudo_char_bbox)
+
+ region_score_color = cv2.applyColorMap(np.uint8(region_score), cv2.COLORMAP_JET)
+ region_score_color = cv2.resize(region_score_color, (word_img_w, word_img_h))
+
+ for box in _watershed_box:
+ cv2.polylines(
+ np.uint8(word_img_cp1),
+ [np.reshape(box, (-1, 1, 2))],
+ True,
+ (255, 0, 0),
+ )
+
+ for box in _pseudo_char_bbox:
+ cv2.polylines(
+ np.uint8(word_img_cp2), [np.reshape(box, (-1, 1, 2))], True, (255, 0, 0)
+ )
+
+        # NOTE: For visualization only, put the gaussian map on the char boxes
+ pseudo_gt_region_score = self.gaussian_builder.generate_region(
+ word_img_h, word_img_w, [_pseudo_char_bbox], [True]
+ )
+
+ pseudo_gt_region_score = cv2.applyColorMap(
+ (pseudo_gt_region_score * 255).astype("uint8"), cv2.COLORMAP_JET
+ )
+
+ overlay_img = cv2.addWeighted(
+ word_image[:, :, ::-1], 0.7, pseudo_gt_region_score, 0.3, 5
+ )
+ vis_result = np.hstack(
+ [
+ word_image[:, :, ::-1],
+ region_score_color,
+ word_img_cp1[:, :, ::-1],
+ word_img_cp2[:, :, ::-1],
+ pseudo_gt_region_score,
+ overlay_img,
+ ]
+ )
+
+ if not os.path.exists(os.path.dirname(self.vis_test_dir)):
+ os.makedirs(os.path.dirname(self.vis_test_dir))
+ cv2.imwrite(
+ os.path.join(
+ self.vis_test_dir,
+ "{}_{}".format(
+ img_name, f"pseudo_char_bbox_{random.randint(0,100)}.jpg"
+ ),
+ ),
+ vis_result,
+ )
+
+ def clip_into_boundary(self, box, bound):
+ if len(box) == 0:
+ return box
+ else:
+ box[:, :, 0] = np.clip(box[:, :, 0], 0, bound[1])
+ box[:, :, 1] = np.clip(box[:, :, 1], 0, bound[0])
+ return box
+
+ def get_confidence(self, real_len, pseudo_len):
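+        # Word-level confidence from the CRAFT paper: (l(w) - min(l(w), |l(w) - l_c(w)|)) / l(w), where l(w) is the word length and l_c(w) the number of predicted character boxes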
+ if pseudo_len == 0:
+ return 0.0
+ return (real_len - min(real_len, abs(real_len - pseudo_len))) / real_len
+
+ def split_word_equal_gap(self, word_img_w, word_img_h, word):
+ width = word_img_w
+ height = word_img_h
+
+ width_per_char = width / len(word)
+ bboxes = []
+ for j, char in enumerate(word):
+ if char == " ":
+ continue
+ left = j * width_per_char
+ right = (j + 1) * width_per_char
+ bbox = np.array([[left, 0], [right, 0], [right, height], [left, height]])
+ bboxes.append(bbox)
+
+ bboxes = np.array(bboxes, np.float32)
+ return bboxes
+
+ def cal_angle(self, v1):
+ theta = np.arccos(min(1, v1[0] / (np.linalg.norm(v1) + 10e-8)))
+ return 2 * math.pi - theta if v1[1] < 0 else theta
+
+ def clockwise_sort(self, points):
+ # returns 4x2 [[x1,y1],[x2,y2],[x3,y3],[x4,y4]] ndarray
+ v1, v2, v3, v4 = points
+ center = (v1 + v2 + v3 + v4) / 4
+ theta = np.array(
+ [
+ self.cal_angle(v1 - center),
+ self.cal_angle(v2 - center),
+ self.cal_angle(v3 - center),
+ self.cal_angle(v4 - center),
+ ]
+ )
+ index = np.argsort(theta)
+ return np.array([v1, v2, v3, v4])[index, :]
+
+ def build_char_box(self, net, gpu, image, word_bbox, word, img_name=""):
+ word_image, M, horizontal_text_bool = self.crop_image_by_bbox(
+ image, word_bbox, word
+ )
+        real_word_without_space = word.replace(" ", "")
+ real_char_len = len(real_word_without_space)
+
+ scale = 128.0 / word_image.shape[0]
+
+ word_image = cv2.resize(word_image, None, fx=scale, fy=scale)
+ word_img_h, word_img_w, _ = word_image.shape
+
+ scores = self.inference_word_box(net, gpu, word_image)
+ region_score = scores[0, :, :, 0].cpu().data.numpy()
+ region_score = np.uint8(np.clip(region_score, 0, 1) * 255)
+
+ region_score_rgb = cv2.resize(region_score, (word_img_w, word_img_h))
+ region_score_rgb = cv2.cvtColor(region_score_rgb, cv2.COLOR_GRAY2RGB)
+
+ pseudo_char_bbox = exec_watershed_by_version(
+ self.watershed_param, region_score, word_image, self.pseudo_vis_opt
+ )
+
+        # Used for visualization only
+ watershed_box = pseudo_char_bbox.copy()
+
+ pseudo_char_bbox = self.clip_into_boundary(
+ pseudo_char_bbox, region_score_rgb.shape
+ )
+
+ confidence = self.get_confidence(real_char_len, len(pseudo_char_bbox))
+
+ if confidence <= 0.5:
+ pseudo_char_bbox = self.split_word_equal_gap(word_img_w, word_img_h, word)
+ confidence = 0.5
+
+ if self.pseudo_vis_opt and self.flag:
+ self.visualize_pseudo_label(
+ word_image, region_score, watershed_box, pseudo_char_bbox, img_name,
+ )
+
+ if len(pseudo_char_bbox) != 0:
+ index = np.argsort(pseudo_char_bbox[:, 0, 0])
+ pseudo_char_bbox = pseudo_char_bbox[index]
+
+ pseudo_char_bbox /= scale
+
+ M_inv = np.linalg.pinv(M)
+ for i in range(len(pseudo_char_bbox)):
+ pseudo_char_bbox[i] = cv2.perspectiveTransform(
+ pseudo_char_bbox[i][None, :, :], M_inv
+ )
+
+ pseudo_char_bbox = self.clip_into_boundary(pseudo_char_bbox, image.shape)
+
+ return pseudo_char_bbox, confidence, horizontal_text_bool
diff --git a/trainer/craft/data/pseudo_label/watershed.py b/trainer/craft/data/pseudo_label/watershed.py
new file mode 100644
index 000000000..72cadc089
--- /dev/null
+++ b/trainer/craft/data/pseudo_label/watershed.py
@@ -0,0 +1,45 @@
+import cv2
+import numpy as np
+from skimage.segmentation import watershed
+
+
+def segment_region_score(watershed_param, region_score, word_image, pseudo_vis_opt):
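+    # Region-score pixels above 0.75 are treated as sure foreground (character cores) and below 0.05 as sure background;
+    # the remaining pixels are assigned by watershed. Boxes are scaled by 2 because the score map is half the word-image resolution.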
+ region_score = np.float32(region_score) / 255
+ fore = np.uint8(region_score > 0.75)
+ back = np.uint8(region_score < 0.05)
+ unknown = 1 - (fore + back)
+ ret, markers = cv2.connectedComponents(fore)
+ markers += 1
+ markers[unknown == 1] = 0
+
+ labels = watershed(-region_score, markers)
+ boxes = []
+ for label in range(2, ret + 1):
+ y, x = np.where(labels == label)
+ x_max = x.max()
+ y_max = y.max()
+ x_min = x.min()
+ y_min = y.min()
+ box = [[x_min, y_min], [x_max, y_min], [x_max, y_max], [x_min, y_max]]
+ box = np.array(box)
+ box *= 2
+ boxes.append(box)
+ return np.array(boxes, dtype=np.float32)
+
+
+def exec_watershed_by_version(
+ watershed_param, region_score, word_image, pseudo_vis_opt
+):
+
+ func_name_map_dict = {
+ "skimage": segment_region_score,
+ }
+
+ try:
+ return func_name_map_dict[watershed_param.version](
+ watershed_param, region_score, word_image, pseudo_vis_opt
+ )
+    except KeyError:
+        print(
+            f"Watershed version {watershed_param.version} does not exist in func_name_map_dict."
+        )
diff --git a/trainer/craft/data_root_dir/folder.txt b/trainer/craft/data_root_dir/folder.txt
new file mode 100644
index 000000000..8e455ad77
--- /dev/null
+++ b/trainer/craft/data_root_dir/folder.txt
@@ -0,0 +1 @@
+place dataset folder here
diff --git a/trainer/craft/eval.py b/trainer/craft/eval.py
new file mode 100644
index 000000000..fceea4735
--- /dev/null
+++ b/trainer/craft/eval.py
@@ -0,0 +1,381 @@
+# -*- coding: utf-8 -*-
+
+import argparse
+import os
+
+import cv2
+import numpy as np
+import torch
+import torch.backends.cudnn as cudnn
+from tqdm import tqdm
+import wandb
+
+from config.load_config import load_yaml, DotDict
+from model.craft import CRAFT
+from metrics.eval_det_iou import DetectionIoUEvaluator
+from utils.inference_boxes import (
+ test_net,
+ load_icdar2015_gt,
+ load_icdar2013_gt,
+ load_synthtext_gt,
+)
+from utils.util import copyStateDict
+
+
+
+def save_result_synth(img_file, img, pre_output, pre_box, gt_box=None, result_dir=""):
+
+ img = np.array(img)
+ img_copy = img.copy()
+ region = pre_output[0]
+ affinity = pre_output[1]
+
+ # make result file list
+ filename, file_ext = os.path.splitext(os.path.basename(img_file))
+
+ # draw bounding boxes for prediction, color green
+ for i, box in enumerate(pre_box):
+ poly = np.array(box).astype(np.int32).reshape((-1))
+ poly = poly.reshape(-1, 2)
+ try:
+ cv2.polylines(
+ img, [poly.reshape((-1, 1, 2))], True, color=(0, 255, 0), thickness=2
+ )
+ except:
+ pass
+
+ # draw bounding boxes for gt, color red
+ if gt_box is not None:
+ for j in range(len(gt_box)):
+ cv2.polylines(
+ img,
+ [np.array(gt_box[j]["points"]).astype(np.int32).reshape((-1, 1, 2))],
+ True,
+ color=(0, 0, 255),
+ thickness=2,
+ )
+
+ # draw overlay image
+ overlay_img = overlay(img_copy, region, affinity, pre_box)
+
+ # Save result image
+ res_img_path = result_dir + "/res_" + filename + ".jpg"
+ cv2.imwrite(res_img_path, img)
+
+ overlay_image_path = result_dir + "/res_" + filename + "_box.jpg"
+ cv2.imwrite(overlay_image_path, overlay_img)
+
+
+def save_result_2015(img_file, img, pre_output, pre_box, gt_box, result_dir):
+
+ img = np.array(img)
+ img_copy = img.copy()
+ region = pre_output[0]
+ affinity = pre_output[1]
+
+ # make result file list
+ filename, file_ext = os.path.splitext(os.path.basename(img_file))
+
+ for i, box in enumerate(pre_box):
+ poly = np.array(box).astype(np.int32).reshape((-1))
+ poly = poly.reshape(-1, 2)
+ try:
+ cv2.polylines(
+ img, [poly.reshape((-1, 1, 2))], True, color=(0, 255, 0), thickness=2
+ )
+ except:
+ pass
+
+ if gt_box is not None:
+ for j in range(len(gt_box)):
+ _gt_box = np.array(gt_box[j]["points"]).reshape(-1, 2).astype(np.int32)
+ if gt_box[j]["text"] == "###":
+ cv2.polylines(img, [_gt_box], True, color=(128, 128, 128), thickness=2)
+ else:
+ cv2.polylines(img, [_gt_box], True, color=(0, 0, 255), thickness=2)
+
+ # draw overlay image
+ overlay_img = overlay(img_copy, region, affinity, pre_box)
+
+ # Save result image
+ res_img_path = result_dir + "/res_" + filename + ".jpg"
+ cv2.imwrite(res_img_path, img)
+
+ overlay_image_path = result_dir + "/res_" + filename + "_box.jpg"
+ cv2.imwrite(overlay_image_path, overlay_img)
+
+
+def save_result_2013(img_file, img, pre_output, pre_box, gt_box=None, result_dir=""):
+
+ img = np.array(img)
+ img_copy = img.copy()
+ region = pre_output[0]
+ affinity = pre_output[1]
+
+ # make result file list
+ filename, file_ext = os.path.splitext(os.path.basename(img_file))
+
+ # draw bounding boxes for prediction, color green
+ for i, box in enumerate(pre_box):
+ poly = np.array(box).astype(np.int32).reshape((-1))
+ poly = poly.reshape(-1, 2)
+ try:
+ cv2.polylines(
+ img, [poly.reshape((-1, 1, 2))], True, color=(0, 255, 0), thickness=2
+ )
+ except:
+ pass
+
+ # draw bounding boxes for gt, color red
+ if gt_box is not None:
+ for j in range(len(gt_box)):
+ cv2.polylines(
+ img,
+ [np.array(gt_box[j]["points"]).reshape((-1, 1, 2))],
+ True,
+ color=(0, 0, 255),
+ thickness=2,
+ )
+
+ # draw overlay image
+ overlay_img = overlay(img_copy, region, affinity, pre_box)
+
+ # Save result image
+ res_img_path = result_dir + "/res_" + filename + ".jpg"
+ cv2.imwrite(res_img_path, img)
+
+ overlay_image_path = result_dir + "/res_" + filename + "_box.jpg"
+ cv2.imwrite(overlay_image_path, overlay_img)
+
+
+def overlay(image, region, affinity, single_img_bbox):
+
+ height, width, channel = image.shape
+
+ region_score = cv2.resize(region, (width, height))
+ affinity_score = cv2.resize(affinity, (width, height))
+
+ overlay_region = cv2.addWeighted(image.copy(), 0.4, region_score, 0.6, 5)
+ overlay_aff = cv2.addWeighted(image.copy(), 0.4, affinity_score, 0.6, 5)
+
+ boxed_img = image.copy()
+ for word_box in single_img_bbox:
+ cv2.polylines(
+ boxed_img,
+ [word_box.astype(np.int32).reshape((-1, 1, 2))],
+ True,
+ color=(0, 255, 0),
+ thickness=3,
+ )
+
+ temp1 = np.hstack([image, boxed_img])
+ temp2 = np.hstack([overlay_region, overlay_aff])
+ temp3 = np.vstack([temp1, temp2])
+
+ return temp3
+
+
+def load_test_dataset_iou(test_folder_name, config):
+
+ if test_folder_name == "synthtext":
+ total_bboxes_gt, total_img_path = load_synthtext_gt(config.test_data_dir)
+
+ elif test_folder_name == "icdar2013":
+ total_bboxes_gt, total_img_path = load_icdar2013_gt(
+ dataFolder=config.test_data_dir
+ )
+
+ elif test_folder_name == "icdar2015":
+ total_bboxes_gt, total_img_path = load_icdar2015_gt(
+ dataFolder=config.test_data_dir
+ )
+
+ elif test_folder_name == "custom_data":
+ total_bboxes_gt, total_img_path = load_icdar2015_gt(
+ dataFolder=config.test_data_dir
+ )
+
+ else:
+ print("not found test dataset")
+ return None, None
+
+ return total_bboxes_gt, total_img_path
+
+
+def viz_test(img, pre_output, pre_box, gt_box, img_name, result_dir, test_folder_name):
+
+ if test_folder_name == "synthtext":
+ save_result_synth(
+ img_name, img[:, :, ::-1].copy(), pre_output, pre_box, gt_box, result_dir
+ )
+ elif test_folder_name == "icdar2013":
+ save_result_2013(
+ img_name, img[:, :, ::-1].copy(), pre_output, pre_box, gt_box, result_dir
+ )
+ elif test_folder_name == "icdar2015":
+ save_result_2015(
+ img_name, img[:, :, ::-1].copy(), pre_output, pre_box, gt_box, result_dir
+ )
+ elif test_folder_name == "custom_data":
+ save_result_2015(
+ img_name, img[:, :, ::-1].copy(), pre_output, pre_box, gt_box, result_dir
+ )
+ else:
+ print("not found test dataset")
+
+
+def main_eval(model_path, backbone, config, evaluator, result_dir, buffer, model, mode):
+
+ if not os.path.exists(result_dir):
+ os.makedirs(result_dir, exist_ok=True)
+
+ total_imgs_bboxes_gt, total_imgs_path = load_test_dataset_iou("custom_data", config)
+
+ if mode == "weak_supervision" and torch.cuda.device_count() != 1:
+ gpu_count = torch.cuda.device_count() // 2
+ else:
+ gpu_count = torch.cuda.device_count()
+ gpu_idx = torch.cuda.current_device()
+ torch.cuda.set_device(gpu_idx)
+
+    # Standalone evaluation (model loaded from checkpoint, not called during training)
+ if model is None:
+ piece_imgs_path = total_imgs_path
+
+ if backbone == "vgg":
+ model = CRAFT()
+ else:
+ raise Exception("Undefined architecture")
+
+ print("Loading weights from checkpoint (" + model_path + ")")
+ net_param = torch.load(model_path, map_location=f"cuda:{gpu_idx}")
+ model.load_state_dict(copyStateDict(net_param["craft"]))
+
+ if config.cuda:
+ model = model.cuda()
+ cudnn.benchmark = False
+
+    # Distributed evaluation in the middle of training
+ else:
+ if buffer is not None:
+ # check all buffer value is None for distributed evaluation
+ assert all(
+ v is None for v in buffer
+ ), "Buffer already filled with another value."
+ slice_idx = len(total_imgs_bboxes_gt) // gpu_count
+
+ # last gpu
+ if gpu_idx == gpu_count - 1:
+ piece_imgs_path = total_imgs_path[gpu_idx * slice_idx :]
+ # piece_imgs_bboxes_gt = total_imgs_bboxes_gt[gpu_idx * slice_idx:]
+ else:
+ piece_imgs_path = total_imgs_path[
+ gpu_idx * slice_idx : (gpu_idx + 1) * slice_idx
+ ]
+ # piece_imgs_bboxes_gt = total_imgs_bboxes_gt[gpu_idx * slice_idx: (gpu_idx + 1) * slice_idx]
+
+ model.eval()
+
+ # -----------------------------------------------------------------------------------------------------------------#
+ total_imgs_bboxes_pre = []
+ for k, img_path in enumerate(tqdm(piece_imgs_path)):
+ image = cv2.imread(img_path)
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+ single_img_bbox = []
+ bboxes, polys, score_text = test_net(
+ model,
+ image,
+ config.text_threshold,
+ config.link_threshold,
+ config.low_text,
+ config.cuda,
+ config.poly,
+ config.canvas_size,
+ config.mag_ratio,
+ )
+
+ for box in bboxes:
+ box_info = {"points": box, "text": "###", "ignore": False}
+ single_img_bbox.append(box_info)
+ total_imgs_bboxes_pre.append(single_img_bbox)
+ # Distributed evaluation -------------------------------------------------------------------------------------#
+ if buffer is not None:
+ buffer[gpu_idx * slice_idx + k] = single_img_bbox
+ # print(sum([element is not None for element in buffer]))
+ # -------------------------------------------------------------------------------------------------------------#
+
+ if config.vis_opt:
+ viz_test(
+ image,
+ score_text,
+ pre_box=polys,
+ gt_box=total_imgs_bboxes_gt[k],
+ img_name=img_path,
+ result_dir=result_dir,
+ test_folder_name="custom_data",
+ )
+
+    # In distributed evaluation mode, wait until every buffer slot has been filled
+ if buffer is not None:
+ while None in buffer:
+ continue
+ assert all(v is not None for v in buffer), "Buffer not filled"
+ total_imgs_bboxes_pre = buffer
+
+ results = []
+ for i, (gt, pred) in enumerate(zip(total_imgs_bboxes_gt, total_imgs_bboxes_pre)):
+ perSampleMetrics_dict = evaluator.evaluate_image(gt, pred)
+ results.append(perSampleMetrics_dict)
+
+ metrics = evaluator.combine_results(results)
+ print(metrics)
+ return metrics
+
+def cal_eval(config, data, res_dir_name, opt, mode):
+ evaluator = DetectionIoUEvaluator()
+ test_config = DotDict(config.test[data])
+ res_dir = os.path.join(os.path.join("exp", args.yaml), "{}".format(res_dir_name))
+
+ if opt == "iou_eval":
+ main_eval(
+ config.test.trained_model,
+ config.train.backbone,
+ test_config,
+ evaluator,
+ res_dir,
+ buffer=None,
+ model=None,
+ mode=mode,
+ )
+ else:
+ print("Undefined evaluation")
+
+
+if __name__ == "__main__":
+
+ parser = argparse.ArgumentParser(description="CRAFT Text Detection Eval")
+ parser.add_argument(
+ "--yaml",
+ "--yaml_file_name",
+ default="custom_data_train",
+ type=str,
+ help="Load configuration",
+ )
+ args = parser.parse_args()
+
+ # load configure
+ config = load_yaml(args.yaml)
+ config = DotDict(config)
+
+ if config["wandb_opt"]:
+ wandb.init(project="evaluation", entity="gmuffiness", name=args.yaml)
+ wandb.config.update(config)
+
+ val_result_dir_name = args.yaml
+ cal_eval(
+ config,
+ "custom_data",
+ val_result_dir_name + "-ic15-iou",
+ opt="iou_eval",
+ mode=None,
+ )
diff --git a/trainer/craft/exp/folder.txt b/trainer/craft/exp/folder.txt
new file mode 100644
index 000000000..b557bc215
--- /dev/null
+++ b/trainer/craft/exp/folder.txt
@@ -0,0 +1 @@
+Trained models will be saved here.
diff --git a/trainer/craft/loss/mseloss.py b/trainer/craft/loss/mseloss.py
new file mode 100644
index 000000000..dc24d5ab4
--- /dev/null
+++ b/trainer/craft/loss/mseloss.py
@@ -0,0 +1,172 @@
+import torch
+import torch.nn as nn
+
+
+class Loss(nn.Module):
+ def __init__(self):
+ super(Loss, self).__init__()
+
+ def forward(self, gt_region, gt_affinity, pred_region, pred_affinity, conf_map):
+ loss = torch.mean(
+ ((gt_region - pred_region).pow(2) + (gt_affinity - pred_affinity).pow(2))
+ * conf_map
+ )
+ return loss
+
+
+class Maploss_v2(nn.Module):
+ def __init__(self):
+
+ super(Maploss_v2, self).__init__()
+
+ def batch_image_loss(self, pred_score, label_score, neg_rto, n_min_neg):
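+        """Online hard negative mining over the whole batch: the positive-pixel
+        loss is averaged directly, while negatives are restricted to the hardest
+        ones via top-k, roughly neg_rto negatives per positive (falling back to
+        n_min_neg negatives when there are too few negatives or no positives)."""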
+
+ # positive_loss
+ positive_pixel = (label_score > 0.1).float()
+ positive_pixel_number = torch.sum(positive_pixel)
+
+ positive_loss_region = pred_score * positive_pixel
+
+ # negative_loss
+ negative_pixel = (label_score <= 0.1).float()
+ negative_pixel_number = torch.sum(negative_pixel)
+ negative_loss_region = pred_score * negative_pixel
+
+ if positive_pixel_number != 0:
+ if negative_pixel_number < neg_rto * positive_pixel_number:
+ negative_loss = (
+ torch.sum(
+ torch.topk(
+ negative_loss_region.view(-1), n_min_neg, sorted=False
+ )[0]
+ )
+ / n_min_neg
+ )
+ else:
+ negative_loss = torch.sum(
+ torch.topk(
+ negative_loss_region.view(-1),
+ int(neg_rto * positive_pixel_number),
+ sorted=False,
+ )[0]
+ ) / (positive_pixel_number * neg_rto)
+ positive_loss = torch.sum(positive_loss_region) / positive_pixel_number
+ else:
+ # only negative pixel
+ negative_loss = (
+ torch.sum(
+ torch.topk(negative_loss_region.view(-1), n_min_neg, sorted=False)[
+ 0
+ ]
+ )
+ / n_min_neg
+ )
+ positive_loss = 0.0
+ total_loss = positive_loss + negative_loss
+ return total_loss
+
+    def forward(
+        self,
+        region_scores_label,
+        affinity_scores_label,
+        region_scores_pre,
+        affinity_scores_pre,
+        mask,
+        neg_rto,
+        n_min_neg,
+    ):
+        loss_fn = torch.nn.MSELoss(reduction="none")
+        assert (
+            region_scores_label.size() == region_scores_pre.size()
+            and affinity_scores_label.size() == affinity_scores_pre.size()
+        )
+        loss1 = loss_fn(region_scores_pre, region_scores_label)
+        loss2 = loss_fn(affinity_scores_pre, affinity_scores_label)
+
+        loss_region = torch.mul(loss1, mask)
+        loss_affinity = torch.mul(loss2, mask)
+
+        char_loss = self.batch_image_loss(
+            loss_region, region_scores_label, neg_rto, n_min_neg
+        )
+        affi_loss = self.batch_image_loss(
+            loss_affinity, affinity_scores_label, neg_rto, n_min_neg
+        )
+        return char_loss + affi_loss
+
+
+class Maploss_v3(nn.Module):
+ def __init__(self):
+
+ super(Maploss_v3, self).__init__()
+
+ def single_image_loss(self, pre_loss, loss_label, neg_rto, n_min_neg):
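+        """Same hard negative mining as Maploss_v2.batch_image_loss, but
+        computed image by image and then averaged over the batch, so a single
+        image with many easy negatives cannot dominate the loss."""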
+
+ batch_size = pre_loss.shape[0]
+
+ positive_loss, negative_loss = 0, 0
+ for single_loss, single_label in zip(pre_loss, loss_label):
+
+ # positive_loss
+ pos_pixel = (single_label >= 0.1).float()
+ n_pos_pixel = torch.sum(pos_pixel)
+ pos_loss_region = single_loss * pos_pixel
+ positive_loss += torch.sum(pos_loss_region) / max(n_pos_pixel, 1e-12)
+
+ # negative_loss
+ neg_pixel = (single_label < 0.1).float()
+ n_neg_pixel = torch.sum(neg_pixel)
+ neg_loss_region = single_loss * neg_pixel
+
+ if n_pos_pixel != 0:
+ if n_neg_pixel < neg_rto * n_pos_pixel:
+ negative_loss += torch.sum(neg_loss_region) / n_neg_pixel
+ else:
+ n_hard_neg = max(n_min_neg, neg_rto * n_pos_pixel)
+ # n_hard_neg = neg_rto*n_pos_pixel
+ negative_loss += (
+ torch.sum(
+ torch.topk(neg_loss_region.view(-1), int(n_hard_neg))[0]
+ )
+ / n_hard_neg
+ )
+ else:
+ # only negative pixel
+ negative_loss += (
+ torch.sum(torch.topk(neg_loss_region.view(-1), n_min_neg)[0])
+ / n_min_neg
+ )
+
+ total_loss = (positive_loss + negative_loss) / batch_size
+
+ return total_loss
+
+ def forward(
+ self,
+ region_scores_label,
+ affinity_scores_label,
+ region_scores_pre,
+ affinity_scores_pre,
+ mask,
+ neg_rto,
+ n_min_neg,
+ ):
+        loss_fn = torch.nn.MSELoss(reduction="none")
+
+ assert (
+ region_scores_label.size() == region_scores_pre.size()
+ and affinity_scores_label.size() == affinity_scores_pre.size()
+ )
+ loss1 = loss_fn(region_scores_pre, region_scores_label)
+ loss2 = loss_fn(affinity_scores_pre, affinity_scores_label)
+
+ loss_region = torch.mul(loss1, mask)
+ loss_affinity = torch.mul(loss2, mask)
+ char_loss = self.single_image_loss(
+ loss_region, region_scores_label, neg_rto, n_min_neg
+ )
+ affi_loss = self.single_image_loss(
+ loss_affinity, affinity_scores_label, neg_rto, n_min_neg
+ )
+
+ return char_loss + affi_loss
diff --git a/trainer/craft/metrics/eval_det_iou.py b/trainer/craft/metrics/eval_det_iou.py
new file mode 100644
index 000000000..f89004fd8
--- /dev/null
+++ b/trainer/craft/metrics/eval_det_iou.py
@@ -0,0 +1,244 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+from collections import namedtuple
+import numpy as np
+from shapely.geometry import Polygon
+"""
+cite from:
+PaddleOCR, github: https://github.com/PaddlePaddle/PaddleOCR
+PaddleOCR reference from :
+https://github.com/MhLiao/DB/blob/3c32b808d4412680310d3d28eeb6a2d5bf1566c5/concern/icdar2015_eval/detection/iou.py#L8
+"""
+
+
+class DetectionIoUEvaluator(object):
+ def __init__(self, iou_constraint=0.5, area_precision_constraint=0.5):
+ self.iou_constraint = iou_constraint
+ self.area_precision_constraint = area_precision_constraint
+
+ def evaluate_image(self, gt, pred):
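+        """Compute per-image precision / recall / hmean.
+
+        gt and pred are lists of dicts with a 'points' polygon; ground-truth
+        entries also carry 'ignore' to mark don't-care regions. A detection
+        matches a ground truth when their IoU exceeds iou_constraint, and
+        don't-care regions are excluded from the counts."""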
+ def get_union(pD, pG):
+ return Polygon(pD).union(Polygon(pG)).area
+
+ def get_intersection_over_union(pD, pG):
+ return get_intersection(pD, pG) / get_union(pD, pG)
+
+ def get_intersection(pD, pG):
+ return Polygon(pD).intersection(Polygon(pG)).area
+
+ def compute_ap(confList, matchList, numGtCare):
+ correct = 0
+ AP = 0
+ if len(confList) > 0:
+ confList = np.array(confList)
+ matchList = np.array(matchList)
+ sorted_ind = np.argsort(-confList)
+ confList = confList[sorted_ind]
+ matchList = matchList[sorted_ind]
+ for n in range(len(confList)):
+ match = matchList[n]
+ if match:
+ correct += 1
+ AP += float(correct) / (n + 1)
+
+ if numGtCare > 0:
+ AP /= numGtCare
+
+ return AP
+
+ perSampleMetrics = {}
+
+ matchedSum = 0
+
+ Rectangle = namedtuple('Rectangle', 'xmin ymin xmax ymax')
+
+ numGlobalCareGt = 0
+ numGlobalCareDet = 0
+
+ arrGlobalConfidences = []
+ arrGlobalMatches = []
+
+ recall = 0
+ precision = 0
+ hmean = 0
+
+ detMatched = 0
+
+ iouMat = np.empty([1, 1])
+
+ gtPols = []
+ detPols = []
+
+ gtPolPoints = []
+ detPolPoints = []
+
+ # Array of Ground Truth Polygons' keys marked as don't Care
+ gtDontCarePolsNum = []
+ # Array of Detected Polygons' matched with a don't Care GT
+ detDontCarePolsNum = []
+
+ pairs = []
+ detMatchedNums = []
+
+ arrSampleConfidences = []
+ arrSampleMatch = []
+
+ evaluationLog = ""
+
+ # print(len(gt))
+
+ for n in range(len(gt)):
+ points = gt[n]['points']
+ # transcription = gt[n]['text']
+ dontCare = gt[n]['ignore']
+ # points = Polygon(points)
+ # points = points.buffer(0)
+            try:
+                if not Polygon(points).is_valid or not Polygon(points).is_simple:
+                    continue
+            except Exception:
+                # skip polygons shapely cannot handle instead of dropping into a debugger
+                continue
+
+ gtPol = points
+ gtPols.append(gtPol)
+ gtPolPoints.append(points)
+ if dontCare:
+ gtDontCarePolsNum.append(len(gtPols) - 1)
+
+ evaluationLog += "GT polygons: " + str(len(gtPols)) + (
+ " (" + str(len(gtDontCarePolsNum)) + " don't care)\n"
+ if len(gtDontCarePolsNum) > 0 else "\n")
+
+ for n in range(len(pred)):
+ points = pred[n]['points']
+ # points = Polygon(points)
+ # points = points.buffer(0)
+ if not Polygon(points).is_valid or not Polygon(points).is_simple:
+ continue
+
+ detPol = points
+ detPols.append(detPol)
+ detPolPoints.append(points)
+ if len(gtDontCarePolsNum) > 0:
+ for dontCarePol in gtDontCarePolsNum:
+ dontCarePol = gtPols[dontCarePol]
+ intersected_area = get_intersection(dontCarePol, detPol)
+ pdDimensions = Polygon(detPol).area
+ precision = 0 if pdDimensions == 0 else intersected_area / pdDimensions
+ if (precision > self.area_precision_constraint):
+ detDontCarePolsNum.append(len(detPols) - 1)
+ break
+
+ evaluationLog += "DET polygons: " + str(len(detPols)) + (
+ " (" + str(len(detDontCarePolsNum)) + " don't care)\n"
+ if len(detDontCarePolsNum) > 0 else "\n")
+
+ if len(gtPols) > 0 and len(detPols) > 0:
+ # Calculate IoU and precision matrixs
+ outputShape = [len(gtPols), len(detPols)]
+ iouMat = np.empty(outputShape)
+ gtRectMat = np.zeros(len(gtPols), np.int8)
+ detRectMat = np.zeros(len(detPols), np.int8)
+ for gtNum in range(len(gtPols)):
+ for detNum in range(len(detPols)):
+ pG = gtPols[gtNum]
+ pD = detPols[detNum]
+ iouMat[gtNum, detNum] = get_intersection_over_union(pD, pG)
+
+ for gtNum in range(len(gtPols)):
+ for detNum in range(len(detPols)):
+ if gtRectMat[gtNum] == 0 and detRectMat[
+ detNum] == 0 and gtNum not in gtDontCarePolsNum and detNum not in detDontCarePolsNum:
+ if iouMat[gtNum, detNum] > self.iou_constraint:
+ gtRectMat[gtNum] = 1
+ detRectMat[detNum] = 1
+ detMatched += 1
+ pairs.append({'gt': gtNum, 'det': detNum})
+ detMatchedNums.append(detNum)
+ evaluationLog += "Match GT #" + \
+ str(gtNum) + " with Det #" + str(detNum) + "\n"
+
+ numGtCare = (len(gtPols) - len(gtDontCarePolsNum))
+ numDetCare = (len(detPols) - len(detDontCarePolsNum))
+ if numGtCare == 0:
+ recall = float(1)
+ precision = float(0) if numDetCare > 0 else float(1)
+ else:
+ recall = float(detMatched) / numGtCare
+ precision = 0 if numDetCare == 0 else float(detMatched) / numDetCare
+
+ hmean = 0 if (precision + recall) == 0 else 2.0 * \
+ precision * recall / (precision + recall)
+
+ matchedSum += detMatched
+ numGlobalCareGt += numGtCare
+ numGlobalCareDet += numDetCare
+
+ perSampleMetrics = {
+ 'precision': precision,
+ 'recall': recall,
+ 'hmean': hmean,
+ 'pairs': pairs,
+ 'iouMat': [] if len(detPols) > 100 else iouMat.tolist(),
+ 'gtPolPoints': gtPolPoints,
+ 'detPolPoints': detPolPoints,
+ 'gtCare': numGtCare,
+ 'detCare': numDetCare,
+ 'gtDontCare': gtDontCarePolsNum,
+ 'detDontCare': detDontCarePolsNum,
+ 'detMatched': detMatched,
+ 'evaluationLog': evaluationLog
+ }
+
+ return perSampleMetrics
+
+ def combine_results(self, results):
+ numGlobalCareGt = 0
+ numGlobalCareDet = 0
+ matchedSum = 0
+ for result in results:
+ numGlobalCareGt += result['gtCare']
+ numGlobalCareDet += result['detCare']
+ matchedSum += result['detMatched']
+
+ methodRecall = 0 if numGlobalCareGt == 0 else float(
+ matchedSum) / numGlobalCareGt
+ methodPrecision = 0 if numGlobalCareDet == 0 else float(
+ matchedSum) / numGlobalCareDet
+ methodHmean = 0 if methodRecall + methodPrecision == 0 else 2 * \
+ methodRecall * methodPrecision / (
+ methodRecall + methodPrecision)
+ # print(methodRecall, methodPrecision, methodHmean)
+ # sys.exit(-1)
+ methodMetrics = {
+ 'precision': methodPrecision,
+ 'recall': methodRecall,
+ 'hmean': methodHmean
+ }
+
+ return methodMetrics
+
+
+if __name__ == '__main__':
+ evaluator = DetectionIoUEvaluator()
+ gts = [[{
+ 'points': [(0, 0), (1, 0), (1, 1), (0, 1)],
+ 'text': 1234,
+ 'ignore': False,
+ }, {
+ 'points': [(2, 2), (3, 2), (3, 3), (2, 3)],
+ 'text': 5678,
+ 'ignore': False,
+ }]]
+ preds = [[{
+ 'points': [(0.1, 0.1), (1, 0), (1, 1), (0, 1)],
+ 'text': 123,
+ 'ignore': False,
+ }]]
+ results = []
+ for gt, pred in zip(gts, preds):
+ results.append(evaluator.evaluate_image(gt, pred))
+ metrics = evaluator.combine_results(results)
+ print(metrics)
diff --git a/trainer/craft/model/craft.py b/trainer/craft/model/craft.py
new file mode 100644
index 000000000..f0da362a5
--- /dev/null
+++ b/trainer/craft/model/craft.py
@@ -0,0 +1,112 @@
+"""
+Copyright (c) 2019-present NAVER Corp.
+MIT License
+"""
+
+# -*- coding: utf-8 -*-
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from model.vgg16_bn import vgg16_bn, init_weights
+
+class double_conv(nn.Module):
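+    """Decoder block of the CRAFT U-network: a 1x1 conv fuses the concatenated
+    skip connection and upsampled features (in_ch + mid_ch channels), followed
+    by a 3x3 conv."""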
+ def __init__(self, in_ch, mid_ch, out_ch):
+ super(double_conv, self).__init__()
+ self.conv = nn.Sequential(
+ nn.Conv2d(in_ch + mid_ch, mid_ch, kernel_size=1),
+ nn.BatchNorm2d(mid_ch),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(mid_ch, out_ch, kernel_size=3, padding=1),
+ nn.BatchNorm2d(out_ch),
+ nn.ReLU(inplace=True)
+ )
+
+ def forward(self, x):
+ x = self.conv(x)
+ return x
+
+
+class CRAFT(nn.Module):
+ def __init__(self, pretrained=True, freeze=False, amp=False):
+ super(CRAFT, self).__init__()
+
+ self.amp = amp
+
+ """ Base network """
+ self.basenet = vgg16_bn(pretrained, freeze)
+
+ """ U network """
+ self.upconv1 = double_conv(1024, 512, 256)
+ self.upconv2 = double_conv(512, 256, 128)
+ self.upconv3 = double_conv(256, 128, 64)
+ self.upconv4 = double_conv(128, 64, 32)
+
+ num_class = 2
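+        # two output channels: 0 = character region score, 1 = affinity (link) score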
+ self.conv_cls = nn.Sequential(
+ nn.Conv2d(32, 32, kernel_size=3, padding=1), nn.ReLU(inplace=True),
+ nn.Conv2d(32, 32, kernel_size=3, padding=1), nn.ReLU(inplace=True),
+ nn.Conv2d(32, 16, kernel_size=3, padding=1), nn.ReLU(inplace=True),
+ nn.Conv2d(16, 16, kernel_size=1), nn.ReLU(inplace=True),
+ nn.Conv2d(16, num_class, kernel_size=1),
+ )
+
+ init_weights(self.upconv1.modules())
+ init_weights(self.upconv2.modules())
+ init_weights(self.upconv3.modules())
+ init_weights(self.upconv4.modules())
+ init_weights(self.conv_cls.modules())
+
+ def forward(self, x):
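+        # The two branches below are identical except that the first one runs
+        # under torch.cuda.amp.autocast for mixed-precision training.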
+ """ Base network """
+ if self.amp:
+ with torch.cuda.amp.autocast():
+ sources = self.basenet(x)
+
+ """ U network """
+ y = torch.cat([sources[0], sources[1]], dim=1)
+ y = self.upconv1(y)
+
+ y = F.interpolate(y, size=sources[2].size()[2:], mode='bilinear', align_corners=False)
+ y = torch.cat([y, sources[2]], dim=1)
+ y = self.upconv2(y)
+
+ y = F.interpolate(y, size=sources[3].size()[2:], mode='bilinear', align_corners=False)
+ y = torch.cat([y, sources[3]], dim=1)
+ y = self.upconv3(y)
+
+ y = F.interpolate(y, size=sources[4].size()[2:], mode='bilinear', align_corners=False)
+ y = torch.cat([y, sources[4]], dim=1)
+ feature = self.upconv4(y)
+
+ y = self.conv_cls(feature)
+
+ return y.permute(0,2,3,1), feature
+ else:
+
+ sources = self.basenet(x)
+
+ """ U network """
+ y = torch.cat([sources[0], sources[1]], dim=1)
+ y = self.upconv1(y)
+
+ y = F.interpolate(y, size=sources[2].size()[2:], mode='bilinear', align_corners=False)
+ y = torch.cat([y, sources[2]], dim=1)
+ y = self.upconv2(y)
+
+ y = F.interpolate(y, size=sources[3].size()[2:], mode='bilinear', align_corners=False)
+ y = torch.cat([y, sources[3]], dim=1)
+ y = self.upconv3(y)
+
+ y = F.interpolate(y, size=sources[4].size()[2:], mode='bilinear', align_corners=False)
+ y = torch.cat([y, sources[4]], dim=1)
+ feature = self.upconv4(y)
+
+ y = self.conv_cls(feature)
+
+ return y.permute(0, 2, 3, 1), feature
+
+if __name__ == '__main__':
+ model = CRAFT(pretrained=True).cuda()
+ output, _ = model(torch.randn(1, 3, 768, 768).cuda())
+ print(output.shape)
\ No newline at end of file
diff --git a/trainer/craft/model/vgg16_bn.py b/trainer/craft/model/vgg16_bn.py
new file mode 100644
index 000000000..f3f21a79e
--- /dev/null
+++ b/trainer/craft/model/vgg16_bn.py
@@ -0,0 +1,73 @@
+from collections import namedtuple
+
+import torch
+import torch.nn as nn
+import torch.nn.init as init
+from torchvision import models
+from torchvision.models.vgg import model_urls
+
+def init_weights(modules):
+ for m in modules:
+ if isinstance(m, nn.Conv2d):
+ init.xavier_uniform_(m.weight.data)
+ if m.bias is not None:
+ m.bias.data.zero_()
+ elif isinstance(m, nn.BatchNorm2d):
+ m.weight.data.fill_(1)
+ m.bias.data.zero_()
+ elif isinstance(m, nn.Linear):
+ m.weight.data.normal_(0, 0.01)
+ m.bias.data.zero_()
+
+class vgg16_bn(torch.nn.Module):
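+    """VGG16-BN backbone sliced into five stages (up to conv2_2, conv3_3,
+    conv4_3, conv5_3, plus a dilated-conv replacement for fc6/fc7) so that the
+    intermediate feature maps can feed the U-network skip connections."""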
+ def __init__(self, pretrained=True, freeze=True):
+ super(vgg16_bn, self).__init__()
+ model_urls['vgg16_bn'] = model_urls['vgg16_bn'].replace('https://', 'http://')
+ vgg_pretrained_features = models.vgg16_bn(pretrained=pretrained).features
+ self.slice1 = torch.nn.Sequential()
+ self.slice2 = torch.nn.Sequential()
+ self.slice3 = torch.nn.Sequential()
+ self.slice4 = torch.nn.Sequential()
+ self.slice5 = torch.nn.Sequential()
+ for x in range(12): # conv2_2
+ self.slice1.add_module(str(x), vgg_pretrained_features[x])
+ for x in range(12, 19): # conv3_3
+ self.slice2.add_module(str(x), vgg_pretrained_features[x])
+ for x in range(19, 29): # conv4_3
+ self.slice3.add_module(str(x), vgg_pretrained_features[x])
+ for x in range(29, 39): # conv5_3
+ self.slice4.add_module(str(x), vgg_pretrained_features[x])
+
+ # fc6, fc7 without atrous conv
+ self.slice5 = torch.nn.Sequential(
+ nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
+ nn.Conv2d(512, 1024, kernel_size=3, padding=6, dilation=6),
+ nn.Conv2d(1024, 1024, kernel_size=1)
+ )
+
+ if not pretrained:
+ init_weights(self.slice1.modules())
+ init_weights(self.slice2.modules())
+ init_weights(self.slice3.modules())
+ init_weights(self.slice4.modules())
+
+ init_weights(self.slice5.modules()) # no pretrained model for fc6 and fc7
+
+ if freeze:
+ for param in self.slice1.parameters(): # only first conv
+                param.requires_grad = False
+
+ def forward(self, X):
+ h = self.slice1(X)
+ h_relu2_2 = h
+ h = self.slice2(h)
+ h_relu3_2 = h
+ h = self.slice3(h)
+ h_relu4_3 = h
+ h = self.slice4(h)
+ h_relu5_3 = h
+ h = self.slice5(h)
+ h_fc7 = h
+ vgg_outputs = namedtuple("VggOutputs", ['fc7', 'relu5_3', 'relu4_3', 'relu3_2', 'relu2_2'])
+ out = vgg_outputs(h_fc7, h_relu5_3, h_relu4_3, h_relu3_2, h_relu2_2)
+ return out
diff --git a/trainer/craft/requirements.txt b/trainer/craft/requirements.txt
new file mode 100644
index 000000000..af5f524cb
--- /dev/null
+++ b/trainer/craft/requirements.txt
@@ -0,0 +1,10 @@
+conda==4.10.3
+opencv-python==4.5.3.56
+Pillow==8.2.0
+Polygon3==3.0.9.1
+PyYAML==5.4.1
+scikit-image==0.17.2
+Shapely==1.8.0
+torch==1.9.0
+torchvision==0.10.0
+wandb==0.12.9
diff --git a/trainer/craft/scripts/run_cde.sh b/trainer/craft/scripts/run_cde.sh
new file mode 100644
index 000000000..5fd232184
--- /dev/null
+++ b/trainer/craft/scripts/run_cde.sh
@@ -0,0 +1,7 @@
+# sed -i -e 's/\r$//' scripts/run_cde.sh
+EXP_NAME=custom_data_release_test_3
+yaml_path="config/$EXP_NAME.yaml"
+cp config/custom_data_train.yaml $yaml_path
+#CUDA_VISIBLE_DEVICES=0,1 python3 train_distributed.py --yaml=$EXP_NAME --port=2468
+CUDA_VISIBLE_DEVICES=0 python3 train.py --yaml=$EXP_NAME --port=2468
+rm "config/$EXP_NAME.yaml"
\ No newline at end of file
diff --git a/trainer/craft/train.py b/trainer/craft/train.py
new file mode 100644
index 000000000..de441cf2e
--- /dev/null
+++ b/trainer/craft/train.py
@@ -0,0 +1,479 @@
+# -*- coding: utf-8 -*-
+import argparse
+import os
+import shutil
+import time
+import multiprocessing as mp
+import yaml
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.optim as optim
+import wandb
+
+from config.load_config import load_yaml, DotDict
+from data.dataset import SynthTextDataSet, CustomDataset
+from loss.mseloss import Maploss_v2, Maploss_v3
+from model.craft import CRAFT
+from eval import main_eval
+from metrics.eval_det_iou import DetectionIoUEvaluator
+from utils.util import copyStateDict, save_parser
+
+
+class Trainer(object):
+ def __init__(self, config, gpu, mode):
+
+ self.config = config
+ self.gpu = gpu
+ self.mode = mode
+ self.net_param = self.get_load_param(gpu)
+
+ def get_synth_loader(self):
+
+ dataset = SynthTextDataSet(
+ output_size=self.config.train.data.output_size,
+ data_dir=self.config.train.synth_data_dir,
+ saved_gt_dir=None,
+ mean=self.config.train.data.mean,
+ variance=self.config.train.data.variance,
+ gauss_init_size=self.config.train.data.gauss_init_size,
+ gauss_sigma=self.config.train.data.gauss_sigma,
+ enlarge_region=self.config.train.data.enlarge_region,
+ enlarge_affinity=self.config.train.data.enlarge_affinity,
+ aug=self.config.train.data.syn_aug,
+ vis_test_dir=self.config.vis_test_dir,
+ vis_opt=self.config.train.data.vis_opt,
+ sample=self.config.train.data.syn_sample,
+ )
+
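+        # the synthetic loader contributes batch_size // synth_ratio samples per
+        # step; they are concatenated with the real-data batch in train()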
+ syn_loader = torch.utils.data.DataLoader(
+ dataset,
+ batch_size=self.config.train.batch_size // self.config.train.synth_ratio,
+ shuffle=False,
+ num_workers=self.config.train.num_workers,
+ drop_last=True,
+ pin_memory=True,
+ )
+ return syn_loader
+
+ def get_custom_dataset(self):
+
+ custom_dataset = CustomDataset(
+ output_size=self.config.train.data.output_size,
+ data_dir=self.config.data_root_dir,
+ saved_gt_dir=None,
+ mean=self.config.train.data.mean,
+ variance=self.config.train.data.variance,
+ gauss_init_size=self.config.train.data.gauss_init_size,
+ gauss_sigma=self.config.train.data.gauss_sigma,
+ enlarge_region=self.config.train.data.enlarge_region,
+ enlarge_affinity=self.config.train.data.enlarge_affinity,
+ watershed_param=self.config.train.data.watershed,
+ aug=self.config.train.data.custom_aug,
+ vis_test_dir=self.config.vis_test_dir,
+ sample=self.config.train.data.custom_sample,
+ vis_opt=self.config.train.data.vis_opt,
+ pseudo_vis_opt=self.config.train.data.pseudo_vis_opt,
+ do_not_care_label=self.config.train.data.do_not_care_label,
+ )
+
+ return custom_dataset
+
+ def get_load_param(self, gpu):
+
+ if self.config.train.ckpt_path is not None:
+ map_location = "cuda:%d" % gpu
+ param = torch.load(self.config.train.ckpt_path, map_location=map_location)
+ else:
+ param = None
+
+ return param
+
+ def adjust_learning_rate(self, optimizer, gamma, step, lr):
+ lr = lr * (gamma ** step)
+ for param_group in optimizer.param_groups:
+ param_group["lr"] = lr
+ return param_group["lr"]
+
+ def get_loss(self):
+ if self.config.train.loss == 2:
+ criterion = Maploss_v2()
+ elif self.config.train.loss == 3:
+ criterion = Maploss_v3()
+ else:
+ raise Exception("Undefined loss")
+ return criterion
+
+ def iou_eval(self, dataset, train_step, buffer, model):
+ test_config = DotDict(self.config.test[dataset])
+
+ val_result_dir = os.path.join(
+ self.config.results_dir, "{}/{}".format(dataset + "_iou", str(train_step))
+ )
+
+ evaluator = DetectionIoUEvaluator()
+
+ metrics = main_eval(
+ None,
+ self.config.train.backbone,
+ test_config,
+ evaluator,
+ val_result_dir,
+ buffer,
+ model,
+ self.mode,
+ )
+ if self.gpu == 0 and self.config.wandb_opt:
+ wandb.log(
+ {
+ "{} iou Recall".format(dataset): np.round(metrics["recall"], 3),
+ "{} iou Precision".format(dataset): np.round(
+ metrics["precision"], 3
+ ),
+ "{} iou F1-score".format(dataset): np.round(metrics["hmean"], 3),
+ }
+ )
+
+ def train(self, buffer_dict):
+
+ torch.cuda.set_device(self.gpu)
+
+ # MODEL -------------------------------------------------------------------------------------------------------#
+ # SUPERVISION model
+ if self.config.mode == "weak_supervision":
+ if self.config.train.backbone == "vgg":
+ supervision_model = CRAFT(pretrained=False, amp=self.config.train.amp)
+ else:
+ raise Exception("Undefined architecture")
+
+ supervision_device = self.gpu
+ if self.config.train.ckpt_path is not None:
+ supervision_param = self.get_load_param(supervision_device)
+ supervision_model.load_state_dict(
+ copyStateDict(supervision_param["craft"])
+ )
+ supervision_model = supervision_model.to(f"cuda:{supervision_device}")
+ print(f"Supervision model loading on : gpu {supervision_device}")
+ else:
+ supervision_model, supervision_device = None, None
+
+ # TRAIN model
+ if self.config.train.backbone == "vgg":
+ craft = CRAFT(pretrained=False, amp=self.config.train.amp)
+ else:
+ raise Exception("Undefined architecture")
+
+ if self.config.train.ckpt_path is not None:
+ craft.load_state_dict(copyStateDict(self.net_param["craft"]))
+
+ craft = craft.cuda()
+ craft = torch.nn.DataParallel(craft)
+
+ torch.backends.cudnn.benchmark = True
+
+ # DATASET -----------------------------------------------------------------------------------------------------#
+
+ if self.config.train.use_synthtext:
+ trn_syn_loader = self.get_synth_loader()
+ batch_syn = iter(trn_syn_loader)
+
+ if self.config.train.real_dataset == "custom":
+ trn_real_dataset = self.get_custom_dataset()
+ else:
+ raise Exception("Undefined dataset")
+
+ if self.config.mode == "weak_supervision":
+ trn_real_dataset.update_model(supervision_model)
+ trn_real_dataset.update_device(supervision_device)
+
+ trn_real_loader = torch.utils.data.DataLoader(
+ trn_real_dataset,
+ batch_size=self.config.train.batch_size,
+ shuffle=False,
+ num_workers=self.config.train.num_workers,
+ drop_last=False,
+ pin_memory=True,
+ )
+
+ # OPTIMIZER ---------------------------------------------------------------------------------------------------#
+ optimizer = optim.Adam(
+ craft.parameters(),
+ lr=self.config.train.lr,
+ weight_decay=self.config.train.weight_decay,
+ )
+
+ if self.config.train.ckpt_path is not None and self.config.train.st_iter != 0:
+ optimizer.load_state_dict(copyStateDict(self.net_param["optimizer"]))
+ self.config.train.st_iter = self.net_param["optimizer"]["state"][0]["step"]
+ self.config.train.lr = self.net_param["optimizer"]["param_groups"][0]["lr"]
+
+ # LOSS --------------------------------------------------------------------------------------------------------#
+ # mixed precision
+ if self.config.train.amp:
+ scaler = torch.cuda.amp.GradScaler()
+
+ if (
+ self.config.train.ckpt_path is not None
+ and self.config.train.st_iter != 0
+ ):
+ scaler.load_state_dict(copyStateDict(self.net_param["scaler"]))
+ else:
+ scaler = None
+
+ criterion = self.get_loss()
+
+ # TRAIN -------------------------------------------------------------------------------------------------------#
+ train_step = self.config.train.st_iter
+ whole_training_step = self.config.train.end_iter
+ update_lr_rate_step = 0
+ training_lr = self.config.train.lr
+ loss_value = 0
+ batch_time = 0
+ start_time = time.time()
+
+ print(
+ "================================ Train start ================================"
+ )
+ while train_step < whole_training_step:
+ for (
+ index,
+ (
+ images,
+ region_scores,
+ affinity_scores,
+ confidence_masks,
+ ),
+ ) in enumerate(trn_real_loader):
+ craft.train()
+ if train_step > 0 and train_step % self.config.train.lr_decay == 0:
+ update_lr_rate_step += 1
+ training_lr = self.adjust_learning_rate(
+ optimizer,
+ self.config.train.gamma,
+ update_lr_rate_step,
+ self.config.train.lr,
+ )
+
+ images = images.cuda(non_blocking=True)
+ region_scores = region_scores.cuda(non_blocking=True)
+ affinity_scores = affinity_scores.cuda(non_blocking=True)
+ confidence_masks = confidence_masks.cuda(non_blocking=True)
+
+ if self.config.train.use_synthtext:
+ # Synth image load
+ syn_image, syn_region_label, syn_affi_label, syn_confidence_mask = next(
+ batch_syn
+ )
+ syn_image = syn_image.cuda(non_blocking=True)
+ syn_region_label = syn_region_label.cuda(non_blocking=True)
+ syn_affi_label = syn_affi_label.cuda(non_blocking=True)
+ syn_confidence_mask = syn_confidence_mask.cuda(non_blocking=True)
+
+ # concat syn & custom image
+ images = torch.cat((syn_image, images), 0)
+ region_image_label = torch.cat(
+ (syn_region_label, region_scores), 0
+ )
+ affinity_image_label = torch.cat((syn_affi_label, affinity_scores), 0)
+ confidence_mask_label = torch.cat(
+ (syn_confidence_mask, confidence_masks), 0
+ )
+ else:
+ region_image_label = region_scores
+ affinity_image_label = affinity_scores
+ confidence_mask_label = confidence_masks
+
+ if self.config.train.amp:
+ with torch.cuda.amp.autocast():
+
+ output, _ = craft(images)
+ out1 = output[:, :, :, 0]
+ out2 = output[:, :, :, 1]
+
+ loss = criterion(
+ region_image_label,
+ affinity_image_label,
+ out1,
+ out2,
+ confidence_mask_label,
+ self.config.train.neg_rto,
+ self.config.train.n_min_neg,
+ )
+
+ optimizer.zero_grad()
+ scaler.scale(loss).backward()
+ scaler.step(optimizer)
+ scaler.update()
+
+ else:
+ output, _ = craft(images)
+ out1 = output[:, :, :, 0]
+ out2 = output[:, :, :, 1]
+ loss = criterion(
+ region_image_label,
+ affinity_image_label,
+ out1,
+ out2,
+ confidence_mask_label,
+                        self.config.train.neg_rto,
+                        self.config.train.n_min_neg,
+                    )
+
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+
+ end_time = time.time()
+ loss_value += loss.item()
+ batch_time += end_time - start_time
+
+ if train_step > 0 and train_step % 5 == 0:
+ mean_loss = loss_value / 5
+ loss_value = 0
+ avg_batch_time = batch_time / 5
+ batch_time = 0
+
+ print(
+ "{}, training_step: {}|{}, learning rate: {:.8f}, "
+ "training_loss: {:.5f}, avg_batch_time: {:.5f}".format(
+ time.strftime(
+ "%Y-%m-%d:%H:%M:%S", time.localtime(time.time())
+ ),
+ train_step,
+ whole_training_step,
+ training_lr,
+ mean_loss,
+ avg_batch_time,
+ )
+ )
+
+ if self.config.wandb_opt:
+ wandb.log({"train_step": train_step, "mean_loss": mean_loss})
+
+ if (
+ train_step % self.config.train.eval_interval == 0
+ and train_step != 0
+ ):
+
+ craft.eval()
+
+ print("Saving state, index:", train_step)
+ save_param_dic = {
+ "iter": train_step,
+ "craft": craft.state_dict(),
+ "optimizer": optimizer.state_dict(),
+ }
+ save_param_path = (
+ self.config.results_dir
+ + "/CRAFT_clr_"
+ + repr(train_step)
+ + ".pth"
+ )
+
+ if self.config.train.amp:
+ save_param_dic["scaler"] = scaler.state_dict()
+ save_param_path = (
+ self.config.results_dir
+ + "/CRAFT_clr_amp_"
+ + repr(train_step)
+ + ".pth"
+ )
+
+ torch.save(save_param_dic, save_param_path)
+
+ # validation
+ self.iou_eval(
+ "custom_data",
+ train_step,
+ buffer_dict["custom_data"],
+ craft,
+ )
+
+ train_step += 1
+ if train_step >= whole_training_step:
+ break
+
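+                # push the latest trained weights into the supervision model so
+                # pseudo character boxes are regenerated from an up-to-date model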
+ if self.config.mode == "weak_supervision":
+ state_dict = craft.module.state_dict()
+ supervision_model.load_state_dict(state_dict)
+ trn_real_dataset.update_model(supervision_model)
+
+ # save last model
+ save_param_dic = {
+ "iter": train_step,
+ "craft": craft.state_dict(),
+ "optimizer": optimizer.state_dict(),
+ }
+ save_param_path = (
+ self.config.results_dir + "/CRAFT_clr_" + repr(train_step) + ".pth"
+ )
+
+ if self.config.train.amp:
+ save_param_dic["scaler"] = scaler.state_dict()
+ save_param_path = (
+ self.config.results_dir
+ + "/CRAFT_clr_amp_"
+ + repr(train_step)
+ + ".pth"
+ )
+ torch.save(save_param_dic, save_param_path)
+
+
+def main():
+ parser = argparse.ArgumentParser(description="CRAFT custom data train")
+ parser.add_argument(
+ "--yaml",
+ "--yaml_file_name",
+ default="custom_data_train",
+ type=str,
+ help="Load configuration",
+ )
+ parser.add_argument(
+ "--port", "--use ddp port", default="2346", type=str, help="Port number"
+ )
+
+ args = parser.parse_args()
+
+ # load configure
+ exp_name = args.yaml
+ config = load_yaml(args.yaml)
+
+ print("-" * 20 + " Options " + "-" * 20)
+ print(yaml.dump(config))
+ print("-" * 40)
+
+ # Make result_dir
+ res_dir = os.path.join(config["results_dir"], args.yaml)
+ config["results_dir"] = res_dir
+ if not os.path.exists(res_dir):
+ os.makedirs(res_dir)
+
+ # Duplicate yaml file to result_dir
+ shutil.copy(
+ "config/" + args.yaml + ".yaml", os.path.join(res_dir, args.yaml) + ".yaml"
+ )
+
+ if config["mode"] == "weak_supervision":
+ mode = "weak_supervision"
+ else:
+ mode = None
+
+
+ # Apply config to wandb
+ if config["wandb_opt"]:
+ wandb.init(project="craft-stage2", entity="user_name", name=exp_name)
+ wandb.config.update(config)
+
+ config = DotDict(config)
+
+ # Start train
+ buffer_dict = {"custom_data":None}
+ trainer = Trainer(config, 0, mode)
+ trainer.train(buffer_dict)
+
+ if config["wandb_opt"]:
+ wandb.finish()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/trainer/craft/trainSynth.py b/trainer/craft/trainSynth.py
new file mode 100644
index 000000000..4d1d0dc72
--- /dev/null
+++ b/trainer/craft/trainSynth.py
@@ -0,0 +1,408 @@
+# -*- coding: utf-8 -*-
+import argparse
+import os
+import shutil
+import time
+import yaml
+import multiprocessing as mp
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.optim as optim
+import wandb
+
+from config.load_config import load_yaml, DotDict
+from data.dataset import SynthTextDataSet
+from loss.mseloss import Maploss_v2, Maploss_v3
+from model.craft import CRAFT
+from metrics.eval_det_iou import DetectionIoUEvaluator
+from eval import main_eval
+from utils.util import copyStateDict, save_parser
+
+
+class Trainer(object):
+ def __init__(self, config, gpu):
+
+ self.config = config
+ self.gpu = gpu
+ self.mode = None
+ self.trn_loader, self.trn_sampler = self.get_trn_loader()
+ self.net_param = self.get_load_param(gpu)
+
+ def get_trn_loader(self):
+
+ dataset = SynthTextDataSet(
+ output_size=self.config.train.data.output_size,
+ data_dir=self.config.data_dir.synthtext,
+ saved_gt_dir=None,
+ mean=self.config.train.data.mean,
+ variance=self.config.train.data.variance,
+ gauss_init_size=self.config.train.data.gauss_init_size,
+ gauss_sigma=self.config.train.data.gauss_sigma,
+ enlarge_region=self.config.train.data.enlarge_region,
+ enlarge_affinity=self.config.train.data.enlarge_affinity,
+ aug=self.config.train.data.syn_aug,
+ vis_test_dir=self.config.vis_test_dir,
+ vis_opt=self.config.train.data.vis_opt,
+ sample=self.config.train.data.syn_sample,
+ )
+
+ trn_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
+
+ trn_loader = torch.utils.data.DataLoader(
+ dataset,
+ batch_size=self.config.train.batch_size,
+ shuffle=False,
+ num_workers=self.config.train.num_workers,
+ sampler=trn_sampler,
+ drop_last=True,
+ pin_memory=True,
+ )
+ return trn_loader, trn_sampler
+
+ def get_load_param(self, gpu):
+ if self.config.train.ckpt_path is not None:
+ map_location = {"cuda:%d" % 0: "cuda:%d" % gpu}
+ param = torch.load(self.config.train.ckpt_path, map_location=map_location)
+ else:
+ param = None
+ return param
+
+ def adjust_learning_rate(self, optimizer, gamma, step, lr):
+ lr = lr * (gamma ** step)
+ for param_group in optimizer.param_groups:
+ param_group["lr"] = lr
+ return param_group["lr"]
+
+ def get_loss(self):
+ if self.config.train.loss == 2:
+ criterion = Maploss_v2()
+ elif self.config.train.loss == 3:
+ criterion = Maploss_v3()
+ else:
+ raise Exception("Undefined loss")
+ return criterion
+
+ def iou_eval(self, dataset, train_step, save_param_path, buffer, model):
+ test_config = DotDict(self.config.test[dataset])
+
+ val_result_dir = os.path.join(
+ self.config.results_dir, "{}/{}".format(dataset + "_iou", str(train_step))
+ )
+
+ evaluator = DetectionIoUEvaluator()
+
+ metrics = main_eval(
+ save_param_path,
+ self.config.train.backbone,
+ test_config,
+ evaluator,
+ val_result_dir,
+ buffer,
+ model,
+ self.mode,
+ )
+ if self.gpu == 0 and self.config.wandb_opt:
+ wandb.log(
+ {
+ "{} IoU Recall".format(dataset): np.round(metrics["recall"], 3),
+ "{} IoU Precision".format(dataset): np.round(
+ metrics["precision"], 3
+ ),
+ "{} IoU F1-score".format(dataset): np.round(metrics["hmean"], 3),
+ }
+ )
+
+ def train(self, buffer_dict):
+ torch.cuda.set_device(self.gpu)
+
+ # DATASET -----------------------------------------------------------------------------------------------------#
+ trn_loader = self.trn_loader
+
+ # MODEL -------------------------------------------------------------------------------------------------------#
+ if self.config.train.backbone == "vgg":
+ craft = CRAFT(pretrained=True, amp=self.config.train.amp)
+ else:
+ raise Exception("Undefined architecture")
+
+ if self.config.train.ckpt_path is not None:
+ craft.load_state_dict(copyStateDict(self.net_param["craft"]))
+ craft = nn.SyncBatchNorm.convert_sync_batchnorm(craft)
+ craft = craft.cuda()
+ craft = torch.nn.parallel.DistributedDataParallel(craft, device_ids=[self.gpu])
+
+ torch.backends.cudnn.benchmark = True
+
+ # OPTIMIZER----------------------------------------------------------------------------------------------------#
+
+ optimizer = optim.Adam(
+ craft.parameters(),
+ lr=self.config.train.lr,
+ weight_decay=self.config.train.weight_decay,
+ )
+
+ if self.config.train.ckpt_path is not None and self.config.train.st_iter != 0:
+ optimizer.load_state_dict(copyStateDict(self.net_param["optimizer"]))
+ self.config.train.st_iter = self.net_param["optimizer"]["state"][0]["step"]
+ self.config.train.lr = self.net_param["optimizer"]["param_groups"][0]["lr"]
+
+ # LOSS --------------------------------------------------------------------------------------------------------#
+ # mixed precision
+ if self.config.train.amp:
+ scaler = torch.cuda.amp.GradScaler()
+
+ # load model
+ if (
+ self.config.train.ckpt_path is not None
+ and self.config.train.st_iter != 0
+ ):
+ scaler.load_state_dict(copyStateDict(self.net_param["scaler"]))
+ else:
+ scaler = None
+
+ criterion = self.get_loss()
+
+ # TRAIN -------------------------------------------------------------------------------------------------------#
+ train_step = self.config.train.st_iter
+ whole_training_step = self.config.train.end_iter
+ update_lr_rate_step = 0
+ training_lr = self.config.train.lr
+ loss_value = 0
+ batch_time = 0
+ epoch = 0
+ start_time = time.time()
+
+ while train_step < whole_training_step:
+ self.trn_sampler.set_epoch(train_step)
+ for (
+ index,
+ (image, region_image, affinity_image, confidence_mask,),
+ ) in enumerate(trn_loader):
+ craft.train()
+ if train_step > 0 and train_step % self.config.train.lr_decay == 0:
+ update_lr_rate_step += 1
+ training_lr = self.adjust_learning_rate(
+ optimizer,
+ self.config.train.gamma,
+ update_lr_rate_step,
+ self.config.train.lr,
+ )
+
+ images = image.cuda(non_blocking=True)
+ region_image_label = region_image.cuda(non_blocking=True)
+ affinity_image_label = affinity_image.cuda(non_blocking=True)
+ confidence_mask_label = confidence_mask.cuda(non_blocking=True)
+
+ if self.config.train.amp:
+ with torch.cuda.amp.autocast():
+
+ output, _ = craft(images)
+ out1 = output[:, :, :, 0]
+ out2 = output[:, :, :, 1]
+
+ loss = criterion(
+ region_image_label,
+ affinity_image_label,
+ out1,
+ out2,
+ confidence_mask_label,
+ self.config.train.neg_rto,
+ self.config.train.n_min_neg,
+ )
+
+ optimizer.zero_grad()
+ scaler.scale(loss).backward()
+ scaler.step(optimizer)
+ scaler.update()
+
+ else:
+ output, _ = craft(images)
+ out1 = output[:, :, :, 0]
+ out2 = output[:, :, :, 1]
+ loss = criterion(
+ region_image_label,
+ affinity_image_label,
+ out1,
+ out2,
+ confidence_mask_label,
+                        self.config.train.neg_rto,
+                        self.config.train.n_min_neg,
+                    )
+
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+
+ end_time = time.time()
+ loss_value += loss.item()
+ batch_time += end_time - start_time
+
+ if train_step > 0 and train_step % 5 == 0 and self.gpu == 0:
+ mean_loss = loss_value / 5
+ loss_value = 0
+ avg_batch_time = batch_time / 5
+ batch_time = 0
+
+ print(
+ "{}, training_step: {}|{}, learning rate: {:.8f}, "
+ "training_loss: {:.5f}, avg_batch_time: {:.5f}".format(
+ time.strftime(
+ "%Y-%m-%d:%H:%M:%S", time.localtime(time.time())
+ ),
+ train_step,
+ whole_training_step,
+ training_lr,
+ mean_loss,
+ avg_batch_time,
+ )
+ )
+ if self.gpu == 0 and self.config.wandb_opt:
+ wandb.log({"train_step": train_step, "mean_loss": mean_loss})
+
+ if (
+ train_step % self.config.train.eval_interval == 0
+ and train_step != 0
+ ):
+
+                    # reset all buffer entries to None before distributed evaluation
+ if self.gpu == 0:
+ for buffer in buffer_dict.values():
+ for i in range(len(buffer)):
+ buffer[i] = None
+
+ print("Saving state, index:", train_step)
+ save_param_dic = {
+ "iter": train_step,
+ "craft": craft.state_dict(),
+ "optimizer": optimizer.state_dict(),
+ }
+ save_param_path = (
+ self.config.results_dir
+ + "/CRAFT_clr_"
+ + repr(train_step)
+ + ".pth"
+ )
+
+ if self.config.train.amp:
+ save_param_dic["scaler"] = scaler.state_dict()
+ save_param_path = (
+ self.config.results_dir
+ + "/CRAFT_clr_amp_"
+ + repr(train_step)
+ + ".pth"
+ )
+
+ torch.save(save_param_dic, save_param_path)
+
+ # validation
+ self.iou_eval(
+ "icdar2013",
+ train_step,
+ save_param_path,
+ buffer_dict["icdar2013"],
+ craft,
+ )
+
+ train_step += 1
+ if train_step >= whole_training_step:
+ break
+ epoch += 1
+
+ # save last model
+ if self.gpu == 0:
+ save_param_dic = {
+ "iter": train_step,
+ "craft": craft.state_dict(),
+ "optimizer": optimizer.state_dict(),
+ }
+ save_param_path = (
+ self.config.results_dir + "/CRAFT_clr_" + repr(train_step) + ".pth"
+ )
+
+ if self.config.train.amp:
+ save_param_dic["scaler"] = scaler.state_dict()
+ save_param_path = (
+ self.config.results_dir
+ + "/CRAFT_clr_amp_"
+ + repr(train_step)
+ + ".pth"
+ )
+ torch.save(save_param_dic, save_param_path)
+
+
+def main():
+ parser = argparse.ArgumentParser(description="CRAFT SynthText Train")
+ parser.add_argument(
+ "--yaml",
+ "--yaml_file_name",
+ default="syn_train",
+ type=str,
+ help="Load configuration",
+ )
+ parser.add_argument(
+        "--port", "--use ddp port", default="2646", type=str, help="Port number"
+ )
+
+ args = parser.parse_args()
+
+ # load configure
+ exp_name = args.yaml
+ config = load_yaml(args.yaml)
+
+ print("-" * 20 + " Options " + "-" * 20)
+ print(yaml.dump(config))
+ print("-" * 40)
+
+ # Make result_dir
+ res_dir = os.path.join(config["results_dir"], args.yaml)
+ config["results_dir"] = res_dir
+ if not os.path.exists(res_dir):
+ os.makedirs(res_dir)
+
+ # Duplicate yaml file to result_dir
+ shutil.copy(
+ "config/" + args.yaml + ".yaml", os.path.join(res_dir, args.yaml) + ".yaml"
+ )
+
+ ngpus_per_node = torch.cuda.device_count()
+ print(f"Total device num : {ngpus_per_node}")
+
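+    # a multiprocessing Manager list is shared across the spawned processes so
+    # each rank can write its per-image results into the same buffer during
+    # distributed evaluation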
+ manager = mp.Manager()
+ buffer1 = manager.list([None] * config["test"]["icdar2013"]["test_set_size"])
+ buffer_dict = {"icdar2013": buffer1}
+ torch.multiprocessing.spawn(
+ main_worker,
+ nprocs=ngpus_per_node,
+ args=(args.port, ngpus_per_node, config, buffer_dict, exp_name,),
+ )
+
+
+def main_worker(gpu, port, ngpus_per_node, config, buffer_dict, exp_name):
+
+ torch.distributed.init_process_group(
+ backend="nccl",
+ init_method="tcp://127.0.0.1:" + port,
+ world_size=ngpus_per_node,
+ rank=gpu,
+ )
+
+ # Apply config to wandb
+ if gpu == 0 and config["wandb_opt"]:
+ wandb.init(project="craft-stage1", entity="gmuffiness", name=exp_name)
+ wandb.config.update(config)
+
+ batch_size = int(config["train"]["batch_size"] / ngpus_per_node)
+ config["train"]["batch_size"] = batch_size
+ config = DotDict(config)
+
+ # Start train
+ trainer = Trainer(config, gpu)
+ trainer.train(buffer_dict)
+
+ if gpu == 0 and config["wandb_opt"]:
+ wandb.finish()
+ torch.distributed.destroy_process_group()
+
+if __name__ == "__main__":
+ main()
diff --git a/trainer/craft/train_distributed.py b/trainer/craft/train_distributed.py
new file mode 100644
index 000000000..8ab320c19
--- /dev/null
+++ b/trainer/craft/train_distributed.py
@@ -0,0 +1,523 @@
+# -*- coding: utf-8 -*-
+import argparse
+import os
+import shutil
+import time
+import multiprocessing as mp
+import yaml
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.optim as optim
+import wandb
+
+from config.load_config import load_yaml, DotDict
+from data.dataset import SynthTextDataSet, CustomDataset
+from loss.mseloss import Maploss_v2, Maploss_v3
+from model.craft import CRAFT
+from eval import main_eval
+from metrics.eval_det_iou import DetectionIoUEvaluator
+from utils.util import copyStateDict, save_parser
+
+
+class Trainer(object):
+ def __init__(self, config, gpu, mode):
+
+ self.config = config
+ self.gpu = gpu
+ self.mode = mode
+ self.net_param = self.get_load_param(gpu)
+
+ def get_synth_loader(self):
+
+ dataset = SynthTextDataSet(
+ output_size=self.config.train.data.output_size,
+ data_dir=self.config.train.synth_data_dir,
+ saved_gt_dir=None,
+ mean=self.config.train.data.mean,
+ variance=self.config.train.data.variance,
+ gauss_init_size=self.config.train.data.gauss_init_size,
+ gauss_sigma=self.config.train.data.gauss_sigma,
+ enlarge_region=self.config.train.data.enlarge_region,
+ enlarge_affinity=self.config.train.data.enlarge_affinity,
+ aug=self.config.train.data.syn_aug,
+ vis_test_dir=self.config.vis_test_dir,
+ vis_opt=self.config.train.data.vis_opt,
+ sample=self.config.train.data.syn_sample,
+ )
+
+ syn_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
+
+ syn_loader = torch.utils.data.DataLoader(
+ dataset,
+ batch_size=self.config.train.batch_size // self.config.train.synth_ratio,
+ shuffle=False,
+ num_workers=self.config.train.num_workers,
+ sampler=syn_sampler,
+ drop_last=True,
+ pin_memory=True,
+ )
+ return syn_loader
+
+ def get_custom_dataset(self):
+
+ custom_dataset = CustomDataset(
+ output_size=self.config.train.data.output_size,
+ data_dir=self.config.data_root_dir,
+ saved_gt_dir=None,
+ mean=self.config.train.data.mean,
+ variance=self.config.train.data.variance,
+ gauss_init_size=self.config.train.data.gauss_init_size,
+ gauss_sigma=self.config.train.data.gauss_sigma,
+ enlarge_region=self.config.train.data.enlarge_region,
+ enlarge_affinity=self.config.train.data.enlarge_affinity,
+ watershed_param=self.config.train.data.watershed,
+ aug=self.config.train.data.custom_aug,
+ vis_test_dir=self.config.vis_test_dir,
+ sample=self.config.train.data.custom_sample,
+ vis_opt=self.config.train.data.vis_opt,
+ pseudo_vis_opt=self.config.train.data.pseudo_vis_opt,
+ do_not_care_label=self.config.train.data.do_not_care_label,
+ )
+
+ return custom_dataset
+
+ def get_load_param(self, gpu):
+
+ if self.config.train.ckpt_path is not None:
+ map_location = "cuda:%d" % gpu
+ param = torch.load(self.config.train.ckpt_path, map_location=map_location)
+ else:
+ param = None
+
+ return param
+
+ def adjust_learning_rate(self, optimizer, gamma, step, lr):
+ lr = lr * (gamma ** step)
+ for param_group in optimizer.param_groups:
+ param_group["lr"] = lr
+ return param_group["lr"]
+
+ def get_loss(self):
+ if self.config.train.loss == 2:
+ criterion = Maploss_v2()
+ elif self.config.train.loss == 3:
+ criterion = Maploss_v3()
+ else:
+ raise Exception("Undefined loss")
+ return criterion
+
+ def iou_eval(self, dataset, train_step, buffer, model):
+ test_config = DotDict(self.config.test[dataset])
+
+ val_result_dir = os.path.join(
+ self.config.results_dir, "{}/{}".format(dataset + "_iou", str(train_step))
+ )
+
+ evaluator = DetectionIoUEvaluator()
+
+ metrics = main_eval(
+ None,
+ self.config.train.backbone,
+ test_config,
+ evaluator,
+ val_result_dir,
+ buffer,
+ model,
+ self.mode,
+ )
+ if self.gpu == 0 and self.config.wandb_opt:
+ wandb.log(
+ {
+ "{} iou Recall".format(dataset): np.round(metrics["recall"], 3),
+ "{} iou Precision".format(dataset): np.round(
+ metrics["precision"], 3
+ ),
+ "{} iou F1-score".format(dataset): np.round(metrics["hmean"], 3),
+ }
+ )
+
+ def train(self, buffer_dict):
+
+ torch.cuda.set_device(self.gpu)
+ total_gpu_num = torch.cuda.device_count()
+
+ # MODEL -------------------------------------------------------------------------------------------------------#
+ # SUPERVISION model
+ if self.config.mode == "weak_supervision":
+ if self.config.train.backbone == "vgg":
+ supervision_model = CRAFT(pretrained=False, amp=self.config.train.amp)
+ else:
+ raise Exception("Undefined architecture")
+
+ # NOTE: only work on half GPU assign train / half GPU assign supervision setting
+ supervision_device = total_gpu_num // 2 + self.gpu
+ if self.config.train.ckpt_path is not None:
+ supervision_param = self.get_load_param(supervision_device)
+ supervision_model.load_state_dict(
+ copyStateDict(supervision_param["craft"])
+ )
+ supervision_model = supervision_model.to(f"cuda:{supervision_device}")
+ print(f"Supervision model loading on : gpu {supervision_device}")
+ else:
+ supervision_model, supervision_device = None, None
+
+ # TRAIN model
+ if self.config.train.backbone == "vgg":
+ craft = CRAFT(pretrained=False, amp=self.config.train.amp)
+ else:
+ raise Exception("Undefined architecture")
+
+ if self.config.train.ckpt_path is not None:
+ craft.load_state_dict(copyStateDict(self.net_param["craft"]))
+
+ craft = nn.SyncBatchNorm.convert_sync_batchnorm(craft)
+ craft = craft.cuda()
+ craft = torch.nn.parallel.DistributedDataParallel(craft, device_ids=[self.gpu])
+
+ torch.backends.cudnn.benchmark = True
+
+ # DATASET -----------------------------------------------------------------------------------------------------#
+
+ if self.config.train.use_synthtext:
+ trn_syn_loader = self.get_synth_loader()
+ batch_syn = iter(trn_syn_loader)
+
+ if self.config.train.real_dataset == "custom":
+ trn_real_dataset = self.get_custom_dataset()
+ else:
+ raise Exception("Undefined dataset")
+
+ if self.config.mode == "weak_supervision":
+ trn_real_dataset.update_model(supervision_model)
+ trn_real_dataset.update_device(supervision_device)
+
+ trn_real_sampler = torch.utils.data.distributed.DistributedSampler(
+ trn_real_dataset
+ )
+ trn_real_loader = torch.utils.data.DataLoader(
+ trn_real_dataset,
+ batch_size=self.config.train.batch_size,
+ shuffle=False,
+ num_workers=self.config.train.num_workers,
+ sampler=trn_real_sampler,
+ drop_last=False,
+ pin_memory=True,
+ )
+
+ # OPTIMIZER ---------------------------------------------------------------------------------------------------#
+ optimizer = optim.Adam(
+ craft.parameters(),
+ lr=self.config.train.lr,
+ weight_decay=self.config.train.weight_decay,
+ )
+
+ if self.config.train.ckpt_path is not None and self.config.train.st_iter != 0:
+ optimizer.load_state_dict(copyStateDict(self.net_param["optimizer"]))
+ self.config.train.st_iter = self.net_param["optimizer"]["state"][0]["step"]
+ self.config.train.lr = self.net_param["optimizer"]["param_groups"][0]["lr"]
+
+ # LOSS --------------------------------------------------------------------------------------------------------#
+ # mixed precision
+ if self.config.train.amp:
+ scaler = torch.cuda.amp.GradScaler()
+
+ if (
+ self.config.train.ckpt_path is not None
+ and self.config.train.st_iter != 0
+ ):
+ scaler.load_state_dict(copyStateDict(self.net_param["scaler"]))
+ else:
+ scaler = None
+
+ criterion = self.get_loss()
+
+ # TRAIN -------------------------------------------------------------------------------------------------------#
+ train_step = self.config.train.st_iter
+ whole_training_step = self.config.train.end_iter
+ update_lr_rate_step = 0
+ training_lr = self.config.train.lr
+ loss_value = 0
+ batch_time = 0
+ start_time = time.time()
+
+ print(
+ "================================ Train start ================================"
+ )
+ while train_step < whole_training_step:
+ trn_real_sampler.set_epoch(train_step)
+ for (
+ index,
+ (
+ images,
+ region_scores,
+ affinity_scores,
+ confidence_masks,
+ ),
+ ) in enumerate(trn_real_loader):
+ craft.train()
+ if train_step > 0 and train_step % self.config.train.lr_decay == 0:
+ update_lr_rate_step += 1
+ training_lr = self.adjust_learning_rate(
+ optimizer,
+ self.config.train.gamma,
+ update_lr_rate_step,
+ self.config.train.lr,
+ )
+
+ images = images.cuda(non_blocking=True)
+ region_scores = region_scores.cuda(non_blocking=True)
+ affinity_scores = affinity_scores.cuda(non_blocking=True)
+ confidence_masks = confidence_masks.cuda(non_blocking=True)
+
+ if self.config.train.use_synthtext:
+ # Synth image load
+ syn_image, syn_region_label, syn_affi_label, syn_confidence_mask = next(
+ batch_syn
+ )
+ syn_image = syn_image.cuda(non_blocking=True)
+ syn_region_label = syn_region_label.cuda(non_blocking=True)
+ syn_affi_label = syn_affi_label.cuda(non_blocking=True)
+ syn_confidence_mask = syn_confidence_mask.cuda(non_blocking=True)
+
+ # concat syn & custom image
+ images = torch.cat((syn_image, images), 0)
+ region_image_label = torch.cat(
+ (syn_region_label, region_scores), 0
+ )
+ affinity_image_label = torch.cat((syn_affi_label, affinity_scores), 0)
+ confidence_mask_label = torch.cat(
+ (syn_confidence_mask, confidence_masks), 0
+ )
+ else:
+ region_image_label = region_scores
+ affinity_image_label = affinity_scores
+ confidence_mask_label = confidence_masks
+
+ if self.config.train.amp:
+ with torch.cuda.amp.autocast():
+
+ output, _ = craft(images)
+ out1 = output[:, :, :, 0]
+ out2 = output[:, :, :, 1]
+
+ loss = criterion(
+ region_image_label,
+ affinity_image_label,
+ out1,
+ out2,
+ confidence_mask_label,
+ self.config.train.neg_rto,
+ self.config.train.n_min_neg,
+ )
+
+ optimizer.zero_grad()
+ scaler.scale(loss).backward()
+ scaler.step(optimizer)
+ scaler.update()
+
+ else:
+ output, _ = craft(images)
+ out1 = output[:, :, :, 0]
+ out2 = output[:, :, :, 1]
+ loss = criterion(
+ region_image_label,
+ affinity_image_label,
+ out1,
+ out2,
+ confidence_mask_label,
+                        self.config.train.neg_rto,
+                        self.config.train.n_min_neg,
+                    )
+
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+
+ end_time = time.time()
+ loss_value += loss.item()
+ batch_time += end_time - start_time
+
+ if train_step > 0 and train_step % 5 == 0 and self.gpu == 0:
+ mean_loss = loss_value / 5
+ loss_value = 0
+ avg_batch_time = batch_time / 5
+ batch_time = 0
+
+ print(
+ "{}, training_step: {}|{}, learning rate: {:.8f}, "
+ "training_loss: {:.5f}, avg_batch_time: {:.5f}".format(
+ time.strftime(
+ "%Y-%m-%d:%H:%M:%S", time.localtime(time.time())
+ ),
+ train_step,
+ whole_training_step,
+ training_lr,
+ mean_loss,
+ avg_batch_time,
+ )
+ )
+
+ if self.gpu == 0 and self.config.wandb_opt:
+ wandb.log({"train_step": train_step, "mean_loss": mean_loss})
+
+ if (
+ train_step % self.config.train.eval_interval == 0
+ and train_step != 0
+ ):
+
+ craft.eval()
+                    # reset all buffer entries to None before distributed evaluation
+ if self.gpu == 0:
+ for buffer in buffer_dict.values():
+ for i in range(len(buffer)):
+ buffer[i] = None
+
+ print("Saving state, index:", train_step)
+ save_param_dic = {
+ "iter": train_step,
+ "craft": craft.state_dict(),
+ "optimizer": optimizer.state_dict(),
+ }
+ save_param_path = (
+ self.config.results_dir
+ + "/CRAFT_clr_"
+ + repr(train_step)
+ + ".pth"
+ )
+
+ if self.config.train.amp:
+ save_param_dic["scaler"] = scaler.state_dict()
+ save_param_path = (
+ self.config.results_dir
+ + "/CRAFT_clr_amp_"
+ + repr(train_step)
+ + ".pth"
+ )
+
+ torch.save(save_param_dic, save_param_path)
+
+ # validation
+ self.iou_eval(
+ "custom_data",
+ train_step,
+ buffer_dict["custom_data"],
+ craft,
+ )
+
+ train_step += 1
+ if train_step >= whole_training_step:
+ break
+
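+                # refresh the supervision model with the latest trained weights
+                # so pseudo-label generation keeps pace with training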
+ if self.config.mode == "weak_supervision":
+ state_dict = craft.module.state_dict()
+ supervision_model.load_state_dict(state_dict)
+ trn_real_dataset.update_model(supervision_model)
+
+ # save last model
+ if self.gpu == 0:
+ save_param_dic = {
+ "iter": train_step,
+ "craft": craft.state_dict(),
+ "optimizer": optimizer.state_dict(),
+ }
+ save_param_path = (
+ self.config.results_dir + "/CRAFT_clr_" + repr(train_step) + ".pth"
+ )
+
+ if self.config.train.amp:
+ save_param_dic["scaler"] = scaler.state_dict()
+ save_param_path = (
+ self.config.results_dir
+ + "/CRAFT_clr_amp_"
+ + repr(train_step)
+ + ".pth"
+ )
+ torch.save(save_param_dic, save_param_path)
+
+def main():
+ parser = argparse.ArgumentParser(description="CRAFT custom data train")
+ parser.add_argument(
+ "--yaml",
+ "--yaml_file_name",
+ default="custom_data_train",
+ type=str,
+ help="Load configuration",
+ )
+ parser.add_argument(
+ "--port", "--use ddp port", default="2346", type=str, help="Port number"
+ )
+
+ args = parser.parse_args()
+
+ # load configure
+ exp_name = args.yaml
+ config = load_yaml(args.yaml)
+
+ print("-" * 20 + " Options " + "-" * 20)
+ print(yaml.dump(config))
+ print("-" * 40)
+
+ # Make result_dir
+ res_dir = os.path.join(config["results_dir"], args.yaml)
+ config["results_dir"] = res_dir
+ if not os.path.exists(res_dir):
+ os.makedirs(res_dir)
+
+ # Duplicate yaml file to result_dir
+ shutil.copy(
+ "config/" + args.yaml + ".yaml", os.path.join(res_dir, args.yaml) + ".yaml"
+ )
+
+ if config["mode"] == "weak_supervision":
+        # NOTE: in weak-supervision mode, half of the GPUs run training and the other half run the supervision (pseudo-label) model
+ ngpus_per_node = torch.cuda.device_count() // 2
+ mode = "weak_supervision"
+ else:
+ ngpus_per_node = torch.cuda.device_count()
+ mode = None
+
+ print(f"Total process num : {ngpus_per_node}")
+
+ manager = mp.Manager()
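+    # shared, process-safe buffer with one slot per test image so validation results can be
+    # gathered across the spawned worker processes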
+ buffer1 = manager.list([None] * config["test"]["custom_data"]["test_set_size"])
+
+ buffer_dict = {"custom_data": buffer1}
+ torch.multiprocessing.spawn(
+ main_worker,
+ nprocs=ngpus_per_node,
+ args=(args.port, ngpus_per_node, config, buffer_dict, exp_name, mode,),
+ )
+
+
+def main_worker(gpu, port, ngpus_per_node, config, buffer_dict, exp_name, mode):
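+    # one process per GPU: join the NCCL process group, optionally set up wandb on rank 0,
+    # split the global batch size across GPUs, and run the Trainer on this rank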
+
+ torch.distributed.init_process_group(
+ backend="nccl",
+ init_method="tcp://127.0.0.1:" + port,
+ world_size=ngpus_per_node,
+ rank=gpu,
+ )
+
+ # Apply config to wandb
+ if gpu == 0 and config["wandb_opt"]:
+ wandb.init(project="craft-stage2", entity="user_name", name=exp_name)
+ wandb.config.update(config)
+
+ batch_size = int(config["train"]["batch_size"] / ngpus_per_node)
+ config["train"]["batch_size"] = batch_size
+ config = DotDict(config)
+
+ # Start train
+ trainer = Trainer(config, gpu, mode)
+ trainer.train(buffer_dict)
+
+ if gpu == 0:
+ if config["wandb_opt"]:
+ wandb.finish()
+
+ torch.distributed.barrier()
+ torch.distributed.destroy_process_group()
+
+if __name__ == "__main__":
+ main()
diff --git a/trainer/craft/utils/craft_utils.py b/trainer/craft/utils/craft_utils.py
new file mode 100644
index 000000000..f5d39df4e
--- /dev/null
+++ b/trainer/craft/utils/craft_utils.py
@@ -0,0 +1,345 @@
+
+# -*- coding: utf-8 -*-
+import os
+import torch
+import cv2
+import math
+import numpy as np
+from data import imgproc
+
+""" auxilary functions """
+# unwarp corodinates
+
+
+
+
+def warpCoord(Minv, pt):
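+    # Apply the inverse perspective transform Minv to point pt and return its 2D coordinates.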
+ out = np.matmul(Minv, (pt[0], pt[1], 1))
+ return np.array([out[0]/out[2], out[1]/out[2]])
+""" end of auxilary functions """
+
+def test():
+ print('pass')
+
+
+def getDetBoxes_core(textmap, linkmap, text_threshold, link_threshold, low_text):
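+    # Threshold the region (text) and affinity (link) score maps, label the connected
+    # components, and return one rotated box per text component together with the
+    # component label map and the kept component ids (mapper).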
+ # prepare data
+ linkmap = linkmap.copy()
+ textmap = textmap.copy()
+ img_h, img_w = textmap.shape
+
+ """ labeling method """
+ ret, text_score = cv2.threshold(textmap, low_text, 1, 0)
+ ret, link_score = cv2.threshold(linkmap, link_threshold, 1, 0)
+
+ text_score_comb = np.clip(text_score + link_score, 0, 1)
+ nLabels, labels, stats, centroids = \
+ cv2.connectedComponentsWithStats(text_score_comb.astype(np.uint8), connectivity=4)
+
+ det = []
+ mapper = []
+ for k in range(1,nLabels):
+ # size filtering
+ size = stats[k, cv2.CC_STAT_AREA]
+ if size < 10: continue
+
+ # thresholding
+ if np.max(textmap[labels==k]) < text_threshold: continue
+
+ # make segmentation map
+ segmap = np.zeros(textmap.shape, dtype=np.uint8)
+ segmap[labels==k] = 255
+ segmap[np.logical_and(link_score==1, text_score==0)] = 0 # remove link area
+ x, y = stats[k, cv2.CC_STAT_LEFT], stats[k, cv2.CC_STAT_TOP]
+ w, h = stats[k, cv2.CC_STAT_WIDTH], stats[k, cv2.CC_STAT_HEIGHT]
+ niter = int(math.sqrt(size * min(w, h) / (w * h)) * 2)
+ sx, ex, sy, ey = x - niter, x + w + niter + 1, y - niter, y + h + niter + 1
+ # boundary check
+ if sx < 0 : sx = 0
+ if sy < 0 : sy = 0
+ if ex >= img_w: ex = img_w
+ if ey >= img_h: ey = img_h
+ kernel = cv2.getStructuringElement(cv2.MORPH_RECT,(1 + niter, 1 + niter))
+ segmap[sy:ey, sx:ex] = cv2.dilate(segmap[sy:ey, sx:ex], kernel, iterations=1)
+ #kernel1 = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 5))
+ #segmap[sy:ey, sx:ex] = cv2.dilate(segmap[sy:ey, sx:ex], kernel1, iterations=1)
+
+
+ # make box
+ np_contours = np.roll(np.array(np.where(segmap!=0)),1,axis=0).transpose().reshape(-1,2)
+ rectangle = cv2.minAreaRect(np_contours)
+ box = cv2.boxPoints(rectangle)
+
+ # align diamond-shape
+ w, h = np.linalg.norm(box[0] - box[1]), np.linalg.norm(box[1] - box[2])
+ box_ratio = max(w, h) / (min(w, h) + 1e-5)
+ if abs(1 - box_ratio) <= 0.1:
+ l, r = min(np_contours[:,0]), max(np_contours[:,0])
+ t, b = min(np_contours[:,1]), max(np_contours[:,1])
+ box = np.array([[l, t], [r, t], [r, b], [l, b]], dtype=np.float32)
+
+ # make clock-wise order
+ startidx = box.sum(axis=1).argmin()
+ box = np.roll(box, 4-startidx, 0)
+ box = np.array(box)
+
+ det.append(box)
+ mapper.append(k)
+
+ return det, labels, mapper
+
+def getPoly_core(boxes, labels, mapper, linkmap):
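+    # Refine each detected box into a polygon: warp the label map into the box frame,
+    # scan it column-wise for top/bottom text contours, place pivot points along the
+    # center line, and expand them by the estimated character height. Boxes that
+    # cannot be refined get a None entry so the caller can fall back to the box.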
+ # configs
+ num_cp = 5
+ max_len_ratio = 0.7
+ expand_ratio = 1.45
+ max_r = 2.0
+ step_r = 0.2
+
+ polys = []
+ for k, box in enumerate(boxes):
+ # size filter for small instance
+ w, h = int(np.linalg.norm(box[0] - box[1]) + 1), int(np.linalg.norm(box[1] - box[2]) + 1)
+ if w < 30 or h < 30:
+ polys.append(None); continue
+
+ # warp image
+ tar = np.float32([[0,0],[w,0],[w,h],[0,h]])
+ M = cv2.getPerspectiveTransform(box, tar)
+ word_label = cv2.warpPerspective(labels, M, (w, h), flags=cv2.INTER_NEAREST)
+ try:
+ Minv = np.linalg.inv(M)
+ except:
+ polys.append(None); continue
+
+ # binarization for selected label
+ cur_label = mapper[k]
+ word_label[word_label != cur_label] = 0
+ word_label[word_label > 0] = 1
+
+ """ Polygon generation """
+ # find top/bottom contours
+ cp = []
+ max_len = -1
+ for i in range(w):
+ region = np.where(word_label[:,i] != 0)[0]
+ if len(region) < 2 : continue
+ cp.append((i, region[0], region[-1]))
+ length = region[-1] - region[0] + 1
+ if length > max_len: max_len = length
+
+ # pass if max_len is similar to h
+ if h * max_len_ratio < max_len:
+ polys.append(None); continue
+
+ # get pivot points with fixed length
+ tot_seg = num_cp * 2 + 1
+ seg_w = w / tot_seg # segment width
+ pp = [None] * num_cp # init pivot points
+ cp_section = [[0, 0]] * tot_seg
+ seg_height = [0] * num_cp
+ seg_num = 0
+ num_sec = 0
+ prev_h = -1
+ for i in range(0,len(cp)):
+ (x, sy, ey) = cp[i]
+ if (seg_num + 1) * seg_w <= x and seg_num <= tot_seg:
+ # average previous segment
+ if num_sec == 0: break
+ cp_section[seg_num] = [cp_section[seg_num][0] / num_sec, cp_section[seg_num][1] / num_sec]
+ num_sec = 0
+
+ # reset variables
+ seg_num += 1
+ prev_h = -1
+
+ # accumulate center points
+ cy = (sy + ey) * 0.5
+ cur_h = ey - sy + 1
+ cp_section[seg_num] = [cp_section[seg_num][0] + x, cp_section[seg_num][1] + cy]
+ num_sec += 1
+
+ if seg_num % 2 == 0: continue # No polygon area
+
+ if prev_h < cur_h:
+ pp[int((seg_num - 1)/2)] = (x, cy)
+ seg_height[int((seg_num - 1)/2)] = cur_h
+ prev_h = cur_h
+
+ # processing last segment
+ if num_sec != 0:
+ cp_section[-1] = [cp_section[-1][0] / num_sec, cp_section[-1][1] / num_sec]
+
+        # pass if the number of pivots is not sufficient or the segment width is smaller than the character height
+ if None in pp or seg_w < np.max(seg_height) * 0.25:
+ polys.append(None); continue
+
+        # calc half character height from the median segment height (scaled by expand_ratio)
+ half_char_h = np.median(seg_height) * expand_ratio / 2
+
+        # calc gradient and apply it to make horizontal pivots
+ new_pp = []
+ for i, (x, cy) in enumerate(pp):
+ dx = cp_section[i * 2 + 2][0] - cp_section[i * 2][0]
+ dy = cp_section[i * 2 + 2][1] - cp_section[i * 2][1]
+            if dx == 0:  # gradient is zero
+ new_pp.append([x, cy - half_char_h, x, cy + half_char_h])
+ continue
+ rad = - math.atan2(dy, dx)
+ c, s = half_char_h * math.cos(rad), half_char_h * math.sin(rad)
+ new_pp.append([x - s, cy - c, x + s, cy + c])
+
+ # get edge points to cover character heatmaps
+ isSppFound, isEppFound = False, False
+ grad_s = (pp[1][1] - pp[0][1]) / (pp[1][0] - pp[0][0]) + (pp[2][1] - pp[1][1]) / (pp[2][0] - pp[1][0])
+ grad_e = (pp[-2][1] - pp[-1][1]) / (pp[-2][0] - pp[-1][0]) + (pp[-3][1] - pp[-2][1]) / (pp[-3][0] - pp[-2][0])
+ for r in np.arange(0.5, max_r, step_r):
+ dx = 2 * half_char_h * r
+ if not isSppFound:
+ line_img = np.zeros(word_label.shape, dtype=np.uint8)
+ dy = grad_s * dx
+ p = np.array(new_pp[0]) - np.array([dx, dy, dx, dy])
+ cv2.line(line_img, (int(p[0]), int(p[1])), (int(p[2]), int(p[3])), 1, thickness=1)
+ if np.sum(np.logical_and(word_label, line_img)) == 0 or r + 2 * step_r >= max_r:
+ spp = p
+ isSppFound = True
+ if not isEppFound:
+ line_img = np.zeros(word_label.shape, dtype=np.uint8)
+ dy = grad_e * dx
+ p = np.array(new_pp[-1]) + np.array([dx, dy, dx, dy])
+ cv2.line(line_img, (int(p[0]), int(p[1])), (int(p[2]), int(p[3])), 1, thickness=1)
+ if np.sum(np.logical_and(word_label, line_img)) == 0 or r + 2 * step_r >= max_r:
+ epp = p
+ isEppFound = True
+ if isSppFound and isEppFound:
+ break
+
+ # pass if boundary of polygon is not found
+ if not (isSppFound and isEppFound):
+ polys.append(None); continue
+
+ # make final polygon
+ poly = []
+ poly.append(warpCoord(Minv, (spp[0], spp[1])))
+ for p in new_pp:
+ poly.append(warpCoord(Minv, (p[0], p[1])))
+ poly.append(warpCoord(Minv, (epp[0], epp[1])))
+ poly.append(warpCoord(Minv, (epp[2], epp[3])))
+ for p in reversed(new_pp):
+ poly.append(warpCoord(Minv, (p[2], p[3])))
+ poly.append(warpCoord(Minv, (spp[2], spp[3])))
+
+ # add to final result
+ polys.append(np.array(poly))
+
+ return polys
+
+def getDetBoxes(textmap, linkmap, text_threshold, link_threshold, low_text, poly=False):
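+    # Top-level post-processing: extract word boxes from the score maps and, when
+    # poly=True, additionally refine each box into a polygon.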
+ boxes, labels, mapper = getDetBoxes_core(textmap, linkmap, text_threshold, link_threshold, low_text)
+
+ if poly:
+ polys = getPoly_core(boxes, labels, mapper, linkmap)
+ else:
+ polys = [None] * len(boxes)
+
+ return boxes, polys
+
+def adjustResultCoordinates(polys, ratio_w, ratio_h, ratio_net = 2):
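+    # Map detections from score-map coordinates back to the original image; ratio_net=2
+    # compensates for the network producing half-resolution score maps.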
+ if len(polys) > 0:
+ polys = np.array(polys)
+ for k in range(len(polys)):
+ if polys[k] is not None:
+ polys[k] *= (ratio_w * ratio_net, ratio_h * ratio_net)
+ return polys
+
+def save_outputs(image, region_scores, affinity_scores, text_threshold, link_threshold,
+ low_text, outoput_path, confidence_mask = None):
+    """Save the image, region_scores, and affinity_scores side by side in a single image.
+    region_scores and affinity_scores must be CPU numpy arrays. You can convert GPU tensors
+    to CPU numpy arrays like this:
+    >>> array = tensor.cpu().data.numpy()
+    When saving outputs of the network during training, make sure you convert ALL tensors
+    (image, region_score, affinity_score) to numpy arrays first.
+    :param image: numpy array
+    :param region_scores: 2D numpy array with each element between 0 and 1
+    :param affinity_scores: same as region_scores
+    :param text_threshold: 0 ~ 1. The closer to 0, the more low-confidence characters are boxed as words
+    :param link_threshold: 0 ~ 1. The closer to 0, the more low-confidence links are boxed as words
+    :param low_text: 0 ~ 1. The closer to 0, the more loosely the boxes are drawn
+    :param outoput_path: path of the output image
+    :param confidence_mask: optional confidence mask to visualize next to the score maps
+    :return: the composed output image
+    """
+
+ assert region_scores.shape == affinity_scores.shape
+ assert len(image.shape) - 1 == len(region_scores.shape)
+
+ boxes, polys = getDetBoxes(region_scores, affinity_scores, text_threshold, link_threshold,
+ low_text, False)
+ boxes = np.array(boxes, np.int32) * 2
+ if len(boxes) > 0:
+        boxes[:, :, 0] = np.clip(boxes[:, :, 0], 0, image.shape[1])
+        boxes[:, :, 1] = np.clip(boxes[:, :, 1], 0, image.shape[0])
+ for box in boxes:
+ cv2.polylines(image, [np.reshape(box, (-1, 1, 2))], True, (0, 0, 255))
+
+ target_gaussian_heatmap_color = imgproc.cvt2HeatmapImg(region_scores)
+ target_gaussian_affinity_heatmap_color = imgproc.cvt2HeatmapImg(affinity_scores)
+
+ if confidence_mask is not None:
+ confidence_mask_gray = imgproc.cvt2HeatmapImg(confidence_mask)
+ gt_scores = np.hstack([target_gaussian_heatmap_color, target_gaussian_affinity_heatmap_color])
+ confidence_mask_gray = np.hstack([np.zeros_like(confidence_mask_gray), confidence_mask_gray])
+ output = np.concatenate([gt_scores, confidence_mask_gray], axis=0)
+ output = np.hstack([image, output])
+
+ else:
+ gt_scores = np.concatenate([target_gaussian_heatmap_color, target_gaussian_affinity_heatmap_color], axis=0)
+ output = np.hstack([image, gt_scores])
+
+ cv2.imwrite(outoput_path, output)
+ return output
+
+
+def save_outputs_from_tensors(images, region_scores, affinity_scores, text_threshold, link_threshold,
+ low_text, output_dir, image_names, confidence_mask = None):
+
+    """Takes images, region_scores, and affinity_scores as tensors (can be on GPU).
+    :param images: 4D tensor
+    :param region_scores: 3D tensor with values between 0 and 1
+    :param affinity_scores: 3D tensor with values between 0 and 1
+    :param text_threshold: see save_outputs
+    :param link_threshold: see save_outputs
+    :param low_text: see save_outputs
+    :param output_dir: directory to save the output images; joined with the base names of image_names
+    :param image_names: names of each image; do not have to be the base file names
+    :param confidence_mask: optional confidence mask passed through to save_outputs
+    :return: list of composed output images
+    """
+    if isinstance(images, torch.Tensor):
+        images = images.cpu().data.numpy()
+
+ region_scores = region_scores.cpu().data.numpy()
+ affinity_scores = affinity_scores.cpu().data.numpy()
+
+ batch_size = images.shape[0]
+ assert batch_size == region_scores.shape[0] and batch_size == affinity_scores.shape[0] and batch_size == len(image_names), \
+ "The first dimension (i.e. batch size) of images, region scores, and affinity scores must be equal"
+
+ output_images = []
+
+ for i in range(batch_size):
+ image = images[i]
+ region_score = region_scores[i]
+ affinity_score = affinity_scores[i]
+
+ image_name = os.path.basename(image_names[i])
+        outoput_path = os.path.join(output_dir, image_name)
+
+ output_image = save_outputs(image, region_score, affinity_score, text_threshold, link_threshold,
+ low_text, outoput_path, confidence_mask=confidence_mask)
+
+ output_images.append(output_image)
+
+ return output_images
\ No newline at end of file
diff --git a/trainer/craft/utils/inference_boxes.py b/trainer/craft/utils/inference_boxes.py
new file mode 100644
index 000000000..e334395bb
--- /dev/null
+++ b/trainer/craft/utils/inference_boxes.py
@@ -0,0 +1,361 @@
+import os
+import re
+import itertools
+
+import cv2
+import time
+import numpy as np
+import torch
+from torch.autograd import Variable
+
+from utils.craft_utils import getDetBoxes, adjustResultCoordinates
+from data import imgproc
+from data.dataset import SynthTextDataSet
+import math
+import xml.etree.ElementTree as elemTree
+
+
+#-------------------------------------------------------------------------------------------------------------------#
+def rotatePoint(xc, yc, xp, yp, theta):
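+    # Rotate point (xp, yp) around the center (xc, yc) by angle theta (radians) and return integer coordinates.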
+ xoff = xp - xc
+ yoff = yp - yc
+
+ cosTheta = math.cos(theta)
+ sinTheta = math.sin(theta)
+ pResx = cosTheta * xoff + sinTheta * yoff
+ pResy = - sinTheta * xoff + cosTheta * yoff
+ # pRes = (xc + pResx, yc + pResy)
+ return int(xc + pResx), int(yc + pResy)
+
+def addRotatedShape(cx, cy, w, h, angle):
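+    # Return the four corner points of a w x h rectangle centered at (cx, cy) and rotated by angle.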
+ p0x, p0y = rotatePoint(cx, cy, cx - w / 2, cy - h / 2, -angle)
+ p1x, p1y = rotatePoint(cx, cy, cx + w / 2, cy - h / 2, -angle)
+ p2x, p2y = rotatePoint(cx, cy, cx + w / 2, cy + h / 2, -angle)
+ p3x, p3y = rotatePoint(cx, cy, cx - w / 2, cy + h / 2, -angle)
+
+ points = [[p0x, p0y], [p1x, p1y], [p2x, p2y], [p3x, p3y]]
+
+ return points
+
+def xml_parsing(xml):
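+    # Parse an XML annotation file containing rotated boxes ("robndbox": cx, cy, w, h, angle)
+    # or axis-aligned boxes ("bndbox"); objects named "dnc" are marked as ignore regions ("###").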
+ tree = elemTree.parse(xml)
+
+ annotations = [] # Initialize the list to store labels
+ iter_element = tree.iter(tag="object")
+
+ for element in iter_element:
+ annotation = {} # Initialize the dict to store labels
+
+ annotation['name'] = element.find("name").text # Save the name tag value
+
+ box_coords = element.iter(tag="robndbox")
+
+ for box_coord in box_coords:
+ cx = float(box_coord.find("cx").text)
+ cy = float(box_coord.find("cy").text)
+ w = float(box_coord.find("w").text)
+ h = float(box_coord.find("h").text)
+ angle = float(box_coord.find("angle").text)
+
+ convertcoodi = addRotatedShape(cx, cy, w, h, angle)
+
+ annotation['box_coodi'] = convertcoodi
+ annotations.append(annotation)
+
+ box_coords = element.iter(tag="bndbox")
+
+ for box_coord in box_coords:
+ xmin = int(box_coord.find("xmin").text)
+ ymin = int(box_coord.find("ymin").text)
+ xmax = int(box_coord.find("xmax").text)
+ ymax = int(box_coord.find("ymax").text)
+ # annotation['bndbox'] = [xmin,ymin,xmax,ymax]
+
+ annotation['box_coodi'] = [[xmin, ymin], [xmax, ymin], [xmax, ymax],
+ [xmin, ymax]]
+ annotations.append(annotation)
+
+
+ bounds = []
+ for i in range(len(annotations)):
+ box_info_dict = {"points": None, "text": None, "ignore": None}
+
+ box_info_dict["points"] = np.array(annotations[i]['box_coodi'])
+ if annotations[i]['name'] == "dnc":
+ box_info_dict["text"] = "###"
+ box_info_dict["ignore"] = True
+ else:
+ box_info_dict["text"] = annotations[i]['name']
+ box_info_dict["ignore"] = False
+
+ bounds.append(box_info_dict)
+
+ return bounds
+
+#-------------------------------------------------------------------------------------------------------------------#
+
+def load_prescription_gt(dataFolder):
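+    # Walk dataFolder, pairing each .jpg image with its .xml annotation file (matched by name),
+    # and return the parsed boxes together with the sorted image paths.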
+ total_img_path = []
+ total_imgs_bboxes = []
+ for (root, directories, files) in os.walk(dataFolder):
+ for file in files:
+ if '.jpg' in file:
+ img_path = os.path.join(root, file)
+ total_img_path.append(img_path)
+ if '.xml' in file:
+ gt_path = os.path.join(root, file)
+ total_imgs_bboxes.append(gt_path)
+
+
+ total_imgs_parsing_bboxes = []
+ for img_path, bbox in zip(sorted(total_img_path), sorted(total_imgs_bboxes)):
+ # check file
+
+ assert img_path.split(".jpg")[0] == bbox.split(".xml")[0]
+
+ result_label = xml_parsing(bbox)
+ total_imgs_parsing_bboxes.append(result_label)
+
+
+ return total_imgs_parsing_bboxes, sorted(total_img_path)
+
+
+# NOTE: ground truth stored in CLEval format ("*_label_cl.txt": one comma-separated 8-point box per line)
+def load_prescription_cleval_gt(dataFolder):
+
+
+ total_img_path = []
+ total_gt_path = []
+ for (root, directories, files) in os.walk(dataFolder):
+ for file in files:
+ if '.jpg' in file:
+ img_path = os.path.join(root, file)
+ total_img_path.append(img_path)
+ if '_cl.txt' in file:
+ gt_path = os.path.join(root, file)
+ total_gt_path.append(gt_path)
+
+
+ total_imgs_parsing_bboxes = []
+ for img_path, gt_path in zip(sorted(total_img_path), sorted(total_gt_path)):
+ # check file
+
+ assert img_path.split(".jpg")[0] == gt_path.split('_label_cl.txt')[0]
+
+ lines = open(gt_path, encoding="utf-8").readlines()
+ word_bboxes = []
+
+ for line in lines:
+ box_info_dict = {"points": None, "text": None, "ignore": None}
+ box_info = line.strip().encode("utf-8").decode("utf-8-sig").split(",")
+
+ box_points = [int(box_info[i]) for i in range(8)]
+ box_info_dict["points"] = np.array(box_points)
+
+ word_bboxes.append(box_info_dict)
+ total_imgs_parsing_bboxes.append(word_bboxes)
+
+ return total_imgs_parsing_bboxes, sorted(total_img_path)
+
+
+def load_synthtext_gt(data_folder):
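+    # Load word-level boxes and transcriptions for the first 100 SynthText samples,
+    # splitting the stored text strings into individual words so they align with the boxes.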
+
+ synth_dataset = SynthTextDataSet(
+ output_size=768, data_dir=data_folder, saved_gt_dir=data_folder, logging=False
+ )
+ img_names, img_bbox, img_words = synth_dataset.load_data(bbox="word")
+
+ total_img_path = []
+ total_imgs_bboxes = []
+ for index in range(len(img_bbox[:100])):
+ img_path = os.path.join(data_folder, img_names[index][0])
+ total_img_path.append(img_path)
+ try:
+ wordbox = img_bbox[index].transpose((2, 1, 0))
+        except ValueError:  # a single box is stored as a 2D array
+ wordbox = np.expand_dims(img_bbox[index], axis=0)
+ wordbox = wordbox.transpose((0, 2, 1))
+
+ words = [re.split(" \n|\n |\n| ", t.strip()) for t in img_words[index]]
+ words = list(itertools.chain(*words))
+ words = [t for t in words if len(t) > 0]
+
+ if len(words) != len(wordbox):
+            raise ValueError(
+                f"word/box count mismatch for {img_names[index][0]}: "
+                f"{len(words)} words vs {len(wordbox)} boxes"
+            )
+
+ single_img_bboxes = []
+ for j in range(len(words)):
+ box_info_dict = {"points": None, "text": None, "ignore": None}
+ box_info_dict["points"] = wordbox[j]
+ box_info_dict["text"] = words[j]
+ box_info_dict["ignore"] = False
+ single_img_bboxes.append(box_info_dict)
+
+ total_imgs_bboxes.append(single_img_bboxes)
+
+ return total_imgs_bboxes, total_img_path
+
+
+def load_icdar2015_gt(dataFolder, isTraing=False):
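+    # Build (boxes, image paths) for ICDAR 2015: read each gt_*.txt file, parse the 8-point
+    # word boxes and their transcriptions, and mark "###" entries as ignore regions.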
+ if isTraing:
+ img_folderName = "ch4_training_images"
+ gt_folderName = "ch4_training_localization_transcription_gt"
+ else:
+ img_folderName = "ch4_test_images"
+ gt_folderName = "ch4_test_localization_transcription_gt"
+
+ gt_folder_path = os.listdir(os.path.join(dataFolder, gt_folderName))
+ total_imgs_bboxes = []
+ total_img_path = []
+ for gt_path in gt_folder_path:
+ gt_path = os.path.join(os.path.join(dataFolder, gt_folderName), gt_path)
+ img_path = (
+ gt_path.replace(gt_folderName, img_folderName)
+ .replace(".txt", ".jpg")
+ .replace("gt_", "")
+ )
+ image = cv2.imread(img_path)
+ lines = open(gt_path, encoding="utf-8").readlines()
+ single_img_bboxes = []
+ for line in lines:
+ box_info_dict = {"points": None, "text": None, "ignore": None}
+
+ box_info = line.strip().encode("utf-8").decode("utf-8-sig").split(",")
+ box_points = [int(box_info[j]) for j in range(8)]
+ word = box_info[8:]
+ word = ",".join(word)
+ box_points = np.array(box_points, np.int32).reshape(4, 2)
+ cv2.polylines(
+                image, [np.array(box_points).astype(np.int32)], True, (0, 0, 255), 1
+ )
+ box_info_dict["points"] = box_points
+ box_info_dict["text"] = word
+ if word == "###":
+ box_info_dict["ignore"] = True
+ else:
+ box_info_dict["ignore"] = False
+
+ single_img_bboxes.append(box_info_dict)
+ total_imgs_bboxes.append(single_img_bboxes)
+ total_img_path.append(img_path)
+ return total_imgs_bboxes, total_img_path
+
+
+def load_icdar2013_gt(dataFolder, isTraing=False):
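+    # Same as above for ICDAR 2013, whose ground truth stores axis-aligned boxes as
+    # xmin, ymin, xmax, ymax followed by the transcription.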
+
+    # choose the dataset split (NOTE: both branches currently point to the test set)
+ if isTraing:
+ img_folderName = "Challenge2_Test_Task12_Images"
+ gt_folderName = "Challenge2_Test_Task1_GT"
+ else:
+ img_folderName = "Challenge2_Test_Task12_Images"
+ gt_folderName = "Challenge2_Test_Task1_GT"
+
+ gt_folder_path = os.listdir(os.path.join(dataFolder, gt_folderName))
+
+ total_imgs_bboxes = []
+ total_img_path = []
+ for gt_path in gt_folder_path:
+ gt_path = os.path.join(os.path.join(dataFolder, gt_folderName), gt_path)
+ img_path = (
+ gt_path.replace(gt_folderName, img_folderName)
+ .replace(".txt", ".jpg")
+ .replace("gt_", "")
+ )
+ image = cv2.imread(img_path)
+ lines = open(gt_path, encoding="utf-8").readlines()
+ single_img_bboxes = []
+ for line in lines:
+ box_info_dict = {"points": None, "text": None, "ignore": None}
+
+ box_info = line.strip().encode("utf-8").decode("utf-8-sig").split(",")
+ box = [int(box_info[j]) for j in range(4)]
+ word = box_info[4:]
+ word = ",".join(word)
+ box = [
+ [box[0], box[1]],
+ [box[2], box[1]],
+ [box[2], box[3]],
+ [box[0], box[3]],
+ ]
+
+ box_info_dict["points"] = box
+ box_info_dict["text"] = word
+ if word == "###":
+ box_info_dict["ignore"] = True
+ else:
+ box_info_dict["ignore"] = False
+
+ single_img_bboxes.append(box_info_dict)
+
+ total_imgs_bboxes.append(single_img_bboxes)
+ total_img_path.append(img_path)
+
+ return total_imgs_bboxes, total_img_path
+
+
+def test_net(
+ net,
+ image,
+ text_threshold,
+ link_threshold,
+ low_text,
+ cuda,
+ poly,
+ canvas_size=1280,
+ mag_ratio=1.5,
+):
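+    # Single-image inference: resize and normalize the image, forward it through the network,
+    # then post-process the region/affinity score maps into word boxes (and optional polygons).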
+ # resize
+
+ img_resized, target_ratio, size_heatmap = imgproc.resize_aspect_ratio(
+ image, canvas_size, interpolation=cv2.INTER_LINEAR, mag_ratio=mag_ratio
+ )
+ ratio_h = ratio_w = 1 / target_ratio
+
+ # preprocessing
+ x = imgproc.normalizeMeanVariance(img_resized)
+ x = torch.from_numpy(x).permute(2, 0, 1) # [h, w, c] to [c, h, w]
+ x = Variable(x.unsqueeze(0)) # [c, h, w] to [b, c, h, w]
+ if cuda:
+ x = x.cuda()
+
+ # forward pass
+ with torch.no_grad():
+ y, feature = net(x)
+
+ # make score and link map
+ score_text = y[0, :, :, 0].cpu().data.numpy().astype(np.float32)
+ score_link = y[0, :, :, 1].cpu().data.numpy().astype(np.float32)
+
+    # NOTE: crop the score maps to the valid heatmap region (discard padding added by resize_aspect_ratio)
+ score_text = score_text[: size_heatmap[0], : size_heatmap[1]]
+ score_link = score_link[: size_heatmap[0], : size_heatmap[1]]
+
+ # Post-processing
+ boxes, polys = getDetBoxes(
+ score_text, score_link, text_threshold, link_threshold, low_text, poly
+ )
+
+ # coordinate adjustment
+ boxes = adjustResultCoordinates(boxes, ratio_w, ratio_h)
+ polys = adjustResultCoordinates(polys, ratio_w, ratio_h)
+ for k in range(len(polys)):
+ if polys[k] is None:
+ polys[k] = boxes[k]
+
+ # render results (optional)
+ score_text = score_text.copy()
+ render_score_text = imgproc.cvt2HeatmapImg(score_text)
+ render_score_link = imgproc.cvt2HeatmapImg(score_link)
+ render_img = [render_score_text, render_score_link]
+ # ret_score_text = imgproc.cvt2HeatmapImg(render_img)
+
+ return boxes, polys, render_img
diff --git a/trainer/craft/utils/util.py b/trainer/craft/utils/util.py
new file mode 100644
index 000000000..f6c862220
--- /dev/null
+++ b/trainer/craft/utils/util.py
@@ -0,0 +1,142 @@
+from collections import OrderedDict
+import os
+
+import cv2
+import numpy as np
+
+from data import imgproc
+from utils import craft_utils
+
+
+def copyStateDict(state_dict):
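+    # Strip the "module." prefix that (Distributed)DataParallel adds to parameter names
+    # so the checkpoint can be loaded into a bare (non-wrapped) model.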
+ if list(state_dict.keys())[0].startswith("module"):
+ start_idx = 1
+ else:
+ start_idx = 0
+ new_state_dict = OrderedDict()
+ for k, v in state_dict.items():
+ name = ".".join(k.split(".")[start_idx:])
+ new_state_dict[name] = v
+ return new_state_dict
+
+
+def saveInput(
+ imagename, vis_dir, image, region_scores, affinity_scores, confidence_mask
+):
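+    # Visualization helper: recover boxes from the region/affinity score maps, draw them on
+    # the image, and save it side by side with the heatmap overlays and the confidence mask.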
+ image = np.uint8(image.copy())
+ image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
+
+ boxes, polys = craft_utils.getDetBoxes(
+ region_scores, affinity_scores, 0.85, 0.2, 0.5, False
+ )
+
+ if image.shape[0] / region_scores.shape[0] >= 2:
+ boxes = np.array(boxes, np.int32) * 2
+ else:
+ boxes = np.array(boxes, np.int32)
+
+ if len(boxes) > 0:
+        boxes[:, :, 0] = np.clip(boxes[:, :, 0], 0, image.shape[1])
+        boxes[:, :, 1] = np.clip(boxes[:, :, 1], 0, image.shape[0])
+ for box in boxes:
+ cv2.polylines(image, [np.reshape(box, (-1, 1, 2))], True, (0, 0, 255))
+ target_gaussian_heatmap_color = imgproc.cvt2HeatmapImg(region_scores)
+ target_gaussian_affinity_heatmap_color = imgproc.cvt2HeatmapImg(affinity_scores)
+ confidence_mask_gray = imgproc.cvt2HeatmapImg(confidence_mask)
+
+ # overlay
+ height, width, channel = image.shape
+ overlay_region = cv2.resize(target_gaussian_heatmap_color, (width, height))
+ overlay_aff = cv2.resize(target_gaussian_affinity_heatmap_color, (width, height))
+ confidence_mask_gray = cv2.resize(
+ confidence_mask_gray, (width, height), interpolation=cv2.INTER_NEAREST
+ )
+ overlay_region = cv2.addWeighted(image, 0.4, overlay_region, 0.6, 5)
+ overlay_aff = cv2.addWeighted(image, 0.4, overlay_aff, 0.7, 6)
+
+ gt_scores = np.concatenate([overlay_region, overlay_aff], axis=1)
+
+ output = np.concatenate([gt_scores, confidence_mask_gray], axis=1)
+
+ output = np.hstack([image, output])
+
+ # synthtext
+    # SynthText image names come as arrays; extract the base file name
+ imagename = imagename[0].split("/")[-1][:-4]
+
+ outpath = vis_dir + f"/{imagename}_input.jpg"
+ if not os.path.exists(os.path.dirname(outpath)):
+ os.makedirs(os.path.dirname(outpath), exist_ok=True)
+ cv2.imwrite(outpath, output)
+ # print(f'Logging train input into {outpath}')
+
+
+def saveImage(
+ imagename,
+ vis_dir,
+ image,
+ bboxes,
+ affi_bboxes,
+ region_scores,
+ affinity_scores,
+ confidence_mask,
+):
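+    # Visualization helper: draw the ground-truth character boxes (red) and affinity boxes (blue)
+    # on the image and save it alongside the heatmap overlays and the confidence mask.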
+ output_image = np.uint8(image.copy())
+ output_image = cv2.cvtColor(output_image, cv2.COLOR_RGB2BGR)
+ if len(bboxes) > 0:
+ for i in range(len(bboxes)):
+ _bboxes = np.int32(bboxes[i])
+ for j in range(_bboxes.shape[0]):
+ cv2.polylines(
+ output_image,
+ [np.reshape(_bboxes[j], (-1, 1, 2))],
+ True,
+ (0, 0, 255),
+ )
+
+ for i in range(len(affi_bboxes)):
+ cv2.polylines(
+ output_image,
+ [np.reshape(affi_bboxes[i].astype(np.int32), (-1, 1, 2))],
+ True,
+ (255, 0, 0),
+ )
+
+ target_gaussian_heatmap_color = imgproc.cvt2HeatmapImg(region_scores)
+ target_gaussian_affinity_heatmap_color = imgproc.cvt2HeatmapImg(affinity_scores)
+ confidence_mask_gray = imgproc.cvt2HeatmapImg(confidence_mask)
+
+ # overlay
+ height, width, channel = image.shape
+ overlay_region = cv2.resize(target_gaussian_heatmap_color, (width, height))
+ overlay_aff = cv2.resize(target_gaussian_affinity_heatmap_color, (width, height))
+
+ overlay_region = cv2.addWeighted(image.copy(), 0.4, overlay_region, 0.6, 5)
+ overlay_aff = cv2.addWeighted(image.copy(), 0.4, overlay_aff, 0.6, 5)
+
+ heat_map = np.concatenate([overlay_region, overlay_aff], axis=1)
+
+ # synthtext
+    # SynthText image names come as arrays; extract the base file name
+ imagename = imagename[0].split("/")[-1][:-4]
+
+ output = np.concatenate([output_image, heat_map, confidence_mask_gray], axis=1)
+ outpath = vis_dir + f"/{imagename}.jpg"
+ if not os.path.exists(os.path.dirname(outpath)):
+ os.makedirs(os.path.dirname(outpath), exist_ok=True)
+
+ cv2.imwrite(outpath, output)
+ # print(f'Logging original image into {outpath}')
+
+
+def save_parser(args):
+
+    """ Dump the final options to <results_dir>/opt.txt and print them. """
+ with open(f"{args.results_dir}/opt.txt", "a", encoding="utf-8") as opt_file:
+ opt_log = "------------ Options -------------\n"
+ arg = vars(args)
+ for k, v in arg.items():
+ opt_log += f"{str(k)}: {str(v)}\n"
+ opt_log += "---------------------------------------\n"
+ print(opt_log)
+ opt_file.write(opt_log)