OpenVINO手写数字识别
小o
更新于 3年前
模型介绍
之前没有注意到,最近在OpenVINO2020R04版本的模型库中发现了它有个手写数字识别的模型,支持 or . 格式的数字识别与小数点识别。相关的模型为:handwritten-score-recognition-0003
该模型是基于LSTM双向神经网络训练,基于CTC损失,
输入格式为:[NCHW]= [1x1x32x64]
输出格式为:[WxBxL]=[16x1x13]
其中13表示"0123456789._#",#表示空白、_表示非数字的字符
对输出格式的解码方式支持CTC贪心与Beam搜索,演示程序使用CTC贪心解码,这种方式最简单,我喜欢!
代码实现
代码基于OPenVINO-Python SDK实现,首先需要说明一下,OpenVINO python SDK中主要的类是IECore,首先创建IECore实例对象,然后完成下面的流程操作:创建实例,加载模型
log.info("Creating Inference Engine")
ie = IECore()
net = ie.read_network(model=model_xml, weight***odel_bin)
获取输入与输出层名称
log.info("Preparing input blob*****r>input_it = iter(net.input_info)
input_blob = next(input_it)
print(input_blob)
output_it = iter(net.output****r>out_blob = next(output_it)
#Read and pre-process input image***r>print(net.input_info[input_blob].input_data.shape)
n, c, h, w = net.input_info[input_blob].input_data.shape
加载网络为可执行网络,
#Loading model to the plugin
exec_net = ie.load_network(network=net, device_name="CPU")
读取输入图像,并处理为 or ., 格式,代码实现如下:
ocr = cv.imread("D:/images/zsxq/ocr1.png")
cv.imshow("input", ocr)
gray = cv.cvtColor(ocr, cv.COLOR_BGR2GRAY)
binary = cv.adaptiveThreshold(gray, 255, cv.ADAPTIVE_THRESH_GAUSSIAN_C, cv.THRESH_BINARY_INV, 25, 10)
cv.imshow("binary", binary)
contours, hireachy = cv.findContour***inary, cv.RETR_EXTERNAL, cv.CHAIN_APPROX_SIMPLE)
for cnt in range(len(contours)):
area = cv.contourArea(contours[cnt])
if area < 10:
cv.drawContour***inary, contours, cnt, (0), -1, 8)
cv.imshow("remove noise", binary)
#获取每个分数
temp = np.copy(binary)
se = cv.getStructuringElement(cv.MORPH_RECT, (45, 5))
temp = cv.dilate(temp, se)
contours, hireachy = cv.findContours(temp, cv.RETR_EXTERNAL, cv.CHAIN_APPROX_SIMPLE)
for cnt in range(len(contours)):
x, y, iw, ih = cv.boundingRect(contours[cnt])
roi = gray[y:y + ih, x:x + iw]
image = cv.resize(roi, (w, h))
运行测试图像:
输出结果:
针对每个字符识别,推理与CTC解析输出结果:
img_blob = np.expand_dims(image, 0)
#Start sync inference
log.info("Starting inference in synchronou***ode")
inf_start1 = time.time()
res = exec_net.infer(inputs={input_blob: [img_blob]})
inf_end1 = time.time() - inf_start1
print("inference time(ms) : %.3f" % (inf_end1 * 1000))
res = res[out_blob]
#CTC greedy decode from here
print(res.shape)
#解析输出text
ocrstr = ""
prev_pad = False;
for i in range(res.shape[0]):
ctc = res[i] # 1x13
ctc = np.squeeze(ctc, 0)
index, prob = ctc_soft_max(ctc)
if digit_nums[index] == '#':
prev_pad = True
else:
if len(ocrstr) == 0 or prev_pad or (len(ocrstr) > 0 and digit_nums[index] != ocrstr[-1]):
prev_pad = False
ocrstr += digit_nums[index]
cv.putText(ocr, ocrstr, (x, y-5), cv.FONT_HERSHEY_SIMPLEX, 0.75, (0, 0, 255), 2, 8)
cv.rectangle(ocr, (x, y), (x+iw, y+ih), (0, 255, 0), 2, 8, 0)
最终的运行截图如下:
0个评论