-
Notifications
You must be signed in to change notification settings - Fork 34
Expand file tree
/
Copy pathprepare_dataset.py
More file actions
90 lines (80 loc) · 3.1 KB
/
prepare_dataset.py
File metadata and controls
90 lines (80 loc) · 3.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import random
import click
import cv2
from loguru import logger
import numpy as np
import json
from pathlib import Path
from tqdm import tqdm
@logger.catch
def crop_minAreaRect(img, xc, yc, w, h, a):
    """Crop an axis-aligned patch for a rotated rectangle out of *img*.

    The rectangle is given in OpenCV ``minAreaRect`` style: center
    ``(xc, yc)``, size ``(w, h)`` and angle ``a`` in degrees. The angle is
    negated when building the box — this assumes the annotation's angle
    convention is the mirror of OpenCV's (TODO: confirm against the
    dataset's annotation docs).

    Parameters
    ----------
    img : np.ndarray
        Source image as loaded by ``cv2.imread``.
    xc, yc : float
        Rectangle center.
    w, h : float
        Rectangle width and height; the output image is ``int(w) x int(h)``.
    a : float
        Rotation angle in degrees (negated before use, see above).

    Returns
    -------
    np.ndarray
        The deskewed crop of shape ``(int(h), int(w), channels)``.
    """
    # Four corner points of the rotated rectangle, in float pixel coords.
    box = cv2.boxPoints(((xc, yc), (w, h), -a))
    w, h = int(w), int(h)
    # BUG FIX: np.int0 was removed in NumPy 2.0 (it was an alias for
    # np.intp). astype(np.intp) performs the same truncate-toward-zero
    # cast, so behavior is identical on older NumPy as well.
    box = box.astype(np.intp)
    src_pts = box.astype("float32")
    # boxPoints returns corners starting from the bottom-left going
    # clockwise, so map them to (bottom-left, top-left, top-right,
    # bottom-right) of the output frame.
    dst_pts = np.array([[0, h - 1],
                        [0, 0],
                        [w - 1, 0],
                        [w - 1, h - 1]], dtype="float32")
    # Perspective warp deskews the rotated rectangle into a w x h image.
    M = cv2.getPerspectiveTransform(src_pts, dst_pts)
    warped = cv2.warpPerspective(img, M, (w, h))
    return warped
@click.command()
@click.option('-a', '--annotations-path',
              type=click.Path(exists=True,
                              file_okay=False,
                              readable=True,
                              path_type=Path),
              default=Path('dataset_info/'))
@click.option('-s', '--save_path',
              type=click.Path(file_okay=False,
                              writable=True,
                              path_type=Path),
              default=Path('cropped/'))
@click.option('--no-split', is_flag=True)
@click.option('--reduce', type=float)
@logger.catch
def main(annotations_path: Path, save_path: Path, no_split: bool, reduce: float):
    """Crop per-word images out of the IMGUR5K-style annotation files.

    For each annotation JSON found under *annotations_path*, every
    annotated word region is cropped (via ``crop_minAreaRect``) and saved
    as ``<ann_id>.png`` under the matching subdirectory of *save_path*,
    together with a ``words.json`` mapping annotation id -> word text.

    Parameters
    ----------
    annotations_path : Path
        Directory containing the annotation JSON file(s); must exist.
    save_path : Path
        Output root; created if missing.
    no_split : bool
        If set, process the single combined annotation file into
        ``save_path/whole`` instead of train/val/test splits.
    reduce : float
        Optional fraction (0..1) of images to keep per split, sampled
        randomly.
    """
    splits = ('train', 'val', 'test')
    annotation_files: list[Path]
    output_dirs: list[Path]
    if no_split:
        annotation_files = [annotations_path / 'imgur5k_annotations.json']
        output_dirs = [save_path / 'whole']
    else:
        annotation_files = [
            annotations_path / f'imgur5k_annotations_{split}.json'
            for split in splits
        ]
        output_dirs = [save_path / split for split in splits]
    # BUG FIX: the original called annotations_path.mkdir() here, but that
    # directory is required to exist by click (exists=True); the directory
    # that may need creating is the output root.
    save_path.mkdir(parents=True, exist_ok=True)
    for annotation_path, output_path in tqdm(zip(annotation_files, output_dirs)):
        words: dict[str, str] = {}
        output_path.mkdir(parents=True, exist_ok=True)
        # read_text closes the file handle (the original .open('r') leaked it).
        annotation = json.loads(annotation_path.read_text())
        # NOTE: renamed from the original, which reused the name
        # `annotations` and shadowed the list of annotation files.
        ann_items = list(annotation['index_to_ann_map'].items())
        if reduce is not None:
            # Keep a random fraction of the images for quick experiments.
            random.shuffle(ann_items)
            ann_items = ann_items[:int(len(ann_items) * reduce)]
        for index_id, ann_ids in tqdm(ann_items, leave=False):
            img_info = annotation['index_id'][index_id]
            img = cv2.imread(img_info['image_path'])
            if img is None:
                # Missing or unreadable image: best-effort skip.
                continue
            for ann_id in ann_ids:
                info = annotation['ann_id'][ann_id]
                # Some labels are non-string in the JSON; normalize locally
                # instead of mutating the loaded annotation dict.
                word = str(info['word'])
                if not word:
                    continue
                words[ann_id] = word
                out_file = output_path / f'{ann_id}.png'
                if out_file.exists():
                    # Crop already produced by a previous run; skip the work.
                    continue
                img_cropped = crop_minAreaRect(img, *info['bounding_box'])
                cv2.imwrite(str(out_file), img_cropped)
        # Write the id -> word index; `with` guarantees the handle is closed
        # (the original .open('w') leaked it).
        with (output_path / 'words.json').open('w') as words_file:
            json.dump(words, words_file)
# Script entry point: click injects the CLI arguments, so pylint's
# no-value-for-parameter warning is a false positive here.
if __name__ == '__main__':
    main()  # pylint: disable=no-value-for-parameter