-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathinferenceModel.py
More file actions
51 lines (33 loc) · 1.48 KB
/
inferenceModel.py
File metadata and controls
51 lines (33 loc) · 1.48 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import cv2
import typing
import numpy as np
from mltu.inferenceModel import OnnxInferenceModel
from mltu.utils.text_utils import ctc_decoder, get_cer
class ImageToWordModel(OnnxInferenceModel):
    """ONNX inference wrapper that turns an image into text via CTC decoding.

    Args:
        char_list: Vocabulary handed to ``ctc_decoder`` to map model outputs
            back to characters.
        *args, **kwargs: Passed through unchanged to ``OnnxInferenceModel``
            (the script below supplies ``model_path``).
    """

    def __init__(self, char_list: typing.Union[str, list], *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Kept on the instance so predict() can decode with the same vocabulary.
        self.char_list = char_list

    def predict(self, image: np.ndarray):
        """Recognize the text in a single image.

        The image is resized to the model's input size, given a leading batch
        axis, cast to float32, run through the ONNX session, and the first
        entry of the CTC-decoded batch is returned.
        """
        # cv2.resize takes dsize as (width, height); input_shape[:2] is
        # presumably (height, width), hence the reversal — confirm against
        # OnnxInferenceModel.input_shape.
        height_width = self.input_shape[:2]
        resized = cv2.resize(image, height_width[::-1])

        # Single-image batch of shape (1, ...) in float32 for the ONNX session.
        batch = resized[np.newaxis, ...].astype(np.float32)

        # run(None, feed) returns every model output; the first holds the predictions.
        outputs = self.model.run(None, {self.input_name: batch})
        decoded = ctc_decoder(outputs[0], self.char_list)
        # Only one image was fed in, so only the first decoded entry is relevant.
        return decoded[0]
if __name__ == "__main__":
    # Evaluate the exported model on a slice of the validation split and
    # report the per-sample and average character error rate (CER).
    import pandas as pd
    from tqdm import tqdm
    from mltu.configs import BaseModelConfigs

    # Training-run directory holding configs.yaml and val.csv.
    MODEL_DIR = "Models/1_image_to_word/202211270035"
    # Only the first MAX_SAMPLES validation rows are scored (quick sanity check).
    MAX_SAMPLES = 20

    configs = BaseModelConfigs.load(f"{MODEL_DIR}/configs.yaml")
    model = ImageToWordModel(model_path=configs.model_path, char_list=configs.vocab)

    # Plain list of rows; each row is presumably [image_path, label] — confirm
    # against the CSV written at training time.
    rows = pd.read_csv(f"{MODEL_DIR}/val.csv").dropna().values.tolist()

    accum_cer = []
    skipped = 0
    for image_path, label in tqdm(rows[:MAX_SAMPLES]):
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread signals a missing/unreadable file by returning None,
            # not by raising, so check explicitly instead of crashing later.
            print(f"Skipping {image_path}: could not read image")
            skipped += 1
            continue

        try:
            prediction_text = model.predict(image)
            cer = get_cer(prediction_text, label)
        except Exception as exc:
            # Best-effort evaluation: keep going on a bad sample, but say why
            # (a bare except would also hide KeyboardInterrupt and real bugs).
            print(f"Skipping {image_path}: {type(exc).__name__}: {exc}")
            skipped += 1
            continue

        print(f"Image: {image_path}, Label: {label}, Prediction: {prediction_text}, CER: {cer}")
        accum_cer.append(cer)

    if accum_cer:
        print(f"Average CER: {np.average(accum_cer)}")
    else:
        # np.average of an empty list would print nan with a RuntimeWarning.
        print("Average CER: n/a (no samples were evaluated)")
    if skipped:
        print(f"Skipped {skipped} sample(s)")