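Optical character recognition (or optical character reader, also known as OCR) is a technology that has been used over the past two decades to recognize and digitize the letters and digits that appear in images. In industry, it helps eliminate manual data entry.
In this article, we'll look at how to apply deep learning to OCR and walk through the steps needed to classify handwritten characters:
· Prepare a dataset of the digits 0–9 and the letters A–Z for training the OCR model.
· Load the datasets.
· Train a Keras and TensorFlow model on the dataset.
· Plot the training results and visualize the validation data.
· Predict the text present in some sample images.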
Preparing the dataset
We use the following two datasets to train our Keras and TensorFlow model.
· 0–9: MNIST
· A-Z: Kaggle
MNIST dataset
MNIST is built from NIST Special Database 3 and Special Database 1, which contain binary images of handwritten digits.
It ships with popular deep learning frameworks, including Keras, TensorFlow, and PyTorch. The MNIST dataset lets us recognize the digits 0–9. Each digit is stored as a 28 x 28 grayscale image.
Kaggle dataset
This dataset takes the capital letters A–Z from NIST Special Database 19. Kaggle also rescales them to 28 x 28 grayscale pixels, the same format as our MNIST data.
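Each row of the Kaggle CSV stores the letter label (0 for A through 25 for Z) followed by 784 pixel values. This is the layout that the loader in the next section parses row by row. The following minimal sketch parses that layout in one vectorized call with `np.loadtxt` instead of a Python loop. It uses two synthetic rows in place of the real file, so the data here is made up purely for illustration.

```python
import io

import numpy as np

# Two synthetic rows in the assumed Kaggle CSV layout: the first column is
# the letter index (0 = A, ..., 25 = Z), followed by 28 * 28 = 784 pixels.
rows = [
    [0] + [0] * 784,     # a blank "A"
    [25] + [255] * 784,  # a fully white "Z"
]
csv_text = "\n".join(",".join(str(v) for v in row) for row in rows)

# Parse every row at once into a (num_rows, 785) uint8 array.
table = np.loadtxt(io.StringIO(csv_text), delimiter=",", dtype="uint8")
labels = table[:, 0].astype("int")        # first column: letter index
images = table[:, 1:].reshape(-1, 28, 28)  # remaining 784 values: pixels

print(labels.tolist())  # [0, 25]
print(images.shape)     # (2, 28, 28)
```

On the real file you would pass the CSV path to `np.loadtxt` instead of a `StringIO` object. Expect the whole table to be held in memory at once.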
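Loading the datasets
Since we have two separate datasets, we first need to load both and combine them into a single dataset.
Loading the Kaggle A–Z letter dataset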
```python
import numpy as np


def load_az_dataset(dataset_path):
    # initialize the list of data and labels
    data = []
    labels = []

    # loop over the rows of the A-Z handwritten letter dataset
    with open(dataset_path) as f:
        for row in f:
            # parse the label and image from the row
            row = row.split(",")
            label = int(row[0])
            image = np.array([int(x) for x in row[1:]], dtype="uint8")

            # images are represented as single channel (grayscale) images
            # that are 28x28=784 pixels -- we need to take this flattened
            # 784-d list of numbers and reshape them into a 28x28 matrix
            image = image.reshape((28, 28))

            # update the list of data and labels
            data.append(image)
            labels.append(label)

    # convert the data and labels to NumPy arrays
    data = np.array(data, dtype="float32")
    labels = np.array(labels, dtype="int")

    # return a 2-tuple of the A-Z data and labels
    return (data, labels)
```

Loading the MNIST digit dataset
```python
import numpy as np
from tensorflow.keras.datasets import mnist


def load_zero_nine_dataset():
    # load the MNIST dataset and stack the training data and testing
    # data together (we'll create our own training and testing splits
    # later in the project)
    ((trainData, trainLabels), (testData, testLabels)) = mnist.load_data()
    data = np.vstack([trainData, testData])
    labels = np.hstack([trainLabels, testLabels])

    # return a 2-tuple of the MNIST data and labels
    return (data, labels)
```

Merging the datasets
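MNIST already uses labels 0–9, so the A–Z labels must be shifted before the two sets are stacked. The minimal sketch below uses tiny synthetic arrays to show the effect of a +10 offset. The names `az_data`, `digits_data` and `label_names` are illustrative and not taken from the project. After the shift, the letters occupy labels 10–35, so the combined set has 36 distinct classes.

```python
import numpy as np

# Tiny synthetic stand-ins: three "letters" labelled A, B, Z (0, 1, 25) and
# two "digits" labelled 3 and 7. Real data would be (N, 28, 28) images.
az_data = np.zeros((3, 28, 28), dtype="float32")
az_labels = np.array([0, 1, 25])
digits_data = np.ones((2, 28, 28), dtype="uint8")
digits_labels = np.array([3, 7])

# Shift the letter labels so they occupy 10..35 instead of colliding
# with the digit labels 0..9.
az_labels = az_labels + 10

# Stack images along the first axis and concatenate the label vectors.
data = np.vstack([az_data, digits_data])
labels = np.hstack([az_labels, digits_labels])

# One string maps every combined label back to its character.
label_names = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
print(data.shape)                        # (5, 28, 28)
print([label_names[i] for i in labels])  # ['A', 'B', 'Z', '3', '7']
```

Without the offset, a letter "D" (label 3) and a digit "3" would receive the same label, and the model could not tell them apart.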
```python
...
# load all datasets
(azData, azLabels) = load_az_dataset(args["az"])
(digitsData, digitsLabels) = load_zero_nine_dataset()

# the MNIST dataset occupies the labels 0-9, so let's add 10 to every
# A-Z label to ensure the A-Z characters are not incorrectly labeled
# as digits
azLabels += 10

# stack the A-Z data and labels with the MNIST digits data and labels
data = np.vstack([azData, digitsData])
labels = np.hstack([azLabels, digitsLabels])
...
```

Training the model on the dataset
This article uses Keras, TensorFlow, and a ResNet architecture to train the model.
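The training call that follows builds a network for 32 x 32 single-channel inputs, while both datasets store 28 x 28 images. It also passes `len(le.classes_)` and a `classWeight` mapping, which are prepared in the elided part of the script. The following is a minimal sketch of one plausible way to prepare such inputs; it is not the repository's actual code. Zero-padding each image by 2 pixels is a swapped-in choice (resizing to 32 x 32 is an equally valid alternative). Plain NumPy one-hot encoding stands in for whatever encoder `le` is. Both `preprocess` and `class_weights` are hypothetical helpers.

```python
import numpy as np

NUM_CLASSES = 36  # 10 digits + 26 letters


def preprocess(data, labels, num_classes=NUM_CLASSES):
    """Hypothetical helper: turn (N, 28, 28) images and integer labels into
    model-ready arrays for a 32 x 32 x 1 network."""
    # Pad 2 pixels of background (0) on each side: 2 + 28 + 2 = 32.
    padded = np.pad(data, ((0, 0), (2, 2), (2, 2)), mode="constant")
    # Scale pixel intensities from [0, 255] to [0, 1] and add a channel axis.
    x = (padded.astype("float32") / 255.0)[..., np.newaxis]
    # One-hot encode: row i has a 1 in column labels[i].
    y = np.eye(num_classes, dtype="float32")[labels]
    return x, y


def class_weights(one_hot):
    """Weight each class by (largest class count / this class count) so that
    rare characters contribute as much to the loss as common ones.
    Assumes every class appears at least once in `one_hot`."""
    totals = one_hot.sum(axis=0)
    weights = totals.max() / totals
    return {i: float(w) for i, w in enumerate(weights)}


# Usage on a small synthetic sample with 3 classes (counts 3, 2, 1).
rng = np.random.default_rng(0)
data = rng.integers(0, 256, size=(6, 28, 28), dtype=np.uint8)
labels = np.array([0, 0, 0, 1, 1, 2])
x, y = preprocess(data, labels, num_classes=3)
print(x.shape, y.shape)  # (6, 32, 32, 1) (6, 3)
print(class_weights(y))  # {0: 1.0, 1: 1.5, 2: 3.0}
```

Keras' `model.fit` accepts a `class_weight` dictionary that maps class indices to weights, which is the shape `class_weights` returns. The A–Z letters and the digits have very different counts, so some form of reweighting is a reasonable default here.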
```python
model = ResNet.build(32, 32, 1, len(le.classes_), (3, 3, 3),
    (64, 64, 128, 256), reg=0.0005)
...
H = model.fit(
    aug.flow(trainX, trainY, batch_size=BS),
    validation_data=(testX, testY),
    steps_per_epoch=len(trainX) // BS,
    epochs=EPOCHS,
    class_weight=classWeight,
    verbose=1)
...
```

Training the model with the following command takes about 30–45 minutes (the exact time depends on your hardware).
```
python train_model.py --az dataset/a_z_handwritten_data.csv --model trained_ocr.model
[INFO] loading datasets...
[INFO] compiling model...
[INFO] training network...
Epoch 1/50
 34/437 [=>............................] - ETA: 7:40 - loss: 2.5050 - accuracy: 0.2989
...
```

Visualizing the results
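We plot the training history and visualize predictions on the validation data to make sure the model is working correctly.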
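A common way to eyeball validation samples is to tile them into a single image. The following is a NumPy-only sketch of such a montage. `build_montage` is a hypothetical helper, not necessarily what the project uses. It assumes equally sized grayscale images and fills the grid row by row.

```python
import numpy as np


def build_montage(images, grid_shape):
    """Hypothetical helper: tile equally sized grayscale images (N, H, W)
    into a single (rows * H, cols * W) array, filling row by row.
    Unused cells stay black (0)."""
    rows, cols = grid_shape
    n, h, w = images.shape
    if n > rows * cols:
        raise ValueError("grid is too small for the number of images")
    montage = np.zeros((rows * h, cols * w), dtype=images.dtype)
    for idx, img in enumerate(images):
        r, c = divmod(idx, cols)  # grid row and column for this image
        montage[r * h:(r + 1) * h, c * w:(c + 1) * w] = img
    return montage


# Five 32 x 32 "validation" images with distinct constant intensities.
samples = np.stack(
    [np.full((32, 32), v, dtype="uint8") for v in (10, 20, 30, 40, 50)]
)
grid = build_montage(samples, (2, 3))
print(grid.shape)  # (64, 96)
```

In practice you would first scale the model inputs back to 0–255 and draw each predicted label onto its image. Then save or display `grid` with any image library.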
Prediction
We run prediction with the following command:
```
python prediction.py --model trained_ocr.model --image images/hello_world.png
[INFO] H - 92.48%
[INFO] W - 54.50%
[INFO] E - 94.93%
[INFO] L - 97.58%
[INFO] 2 - 65.73%
[INFO] L - 96.56%
[INFO] R - 97.31%
[INFO] 0 - 37.92%
[INFO] L - 97.13%
[INFO] D - 97.83%
```
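Each `[INFO]` line pairs the most likely character for one character crop with its probability. The following minimal sketch shows how such lines can be produced from the model's per-crop output. It assumes the label order digits 0–9 followed by A–Z, which matches the +10 offset used when merging the datasets. `decode_predictions` and the hand-written score rows are hypothetical.

```python
import numpy as np

# Assumed label order: digits 0-9 followed by letters A-Z (indices 10-35).
LABEL_NAMES = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"


def decode_predictions(probs):
    """Turn an (N, 36) array of class scores, one row per character crop,
    into (character, probability) pairs using the highest-scoring class."""
    results = []
    for row in probs:
        i = int(np.argmax(row))
        results.append((LABEL_NAMES[i], float(row[i])))
    return results


# Two fake score rows: a confident "H" (index 17) and an uncertain "0"
# whose runner-up is the visually similar "O" (index 24).
probs = np.zeros((2, 36))
probs[0, 17] = 0.9248
probs[1, 0] = 0.3792
probs[1, 24] = 0.35

for char, p in decode_predictions(probs):
    print(f"[INFO] {char} - {p * 100:.2f}%")
# [INFO] H - 92.48%
# [INFO] 0 - 37.92%
```

Low-confidence lines, such as `0 - 37.92%` in the real output, are where visually similar characters like O and 0 are likely being confused.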
The complete source code is available here: https://github.com/housecricket/how-to-train-OCR-with-Keras-and-TensorFlow
The file tree is laid out as follows:
```
├── __init__.py
├── dataset
│   └── a_z_handwritten_data.csv
├── images
│   ├── hello_world.png
│   └── vietnamxinchao.png
├── models
│   ├── __init__.py
│   └── resnet.py
├── prediction.py
├── requirements.txt
├── train_model.py
├── trained_ocr.model
└── utils.py
```

Summary
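In this article, we used Keras, TensorFlow, and Python to train an OCR model. Pretty simple, right? It makes a great introductory deep learning project, so give it a try!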