实例

1. 用数据集训练模型

import os

# Must be set BEFORE TensorFlow is imported: the native library reads this
# variable when it is loaded, so setting it after the import does not silence
# the start-up log messages.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import matplotlib.pyplot as plt
import tensorflow as tf
# Use the public tf.keras API. `tensorflow.python.keras` is a private package
# and should not be mixed with `tf.keras` objects (losses, saving/loading).
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Conv2D, Flatten, MaxPooling2D, Dropout

# Build the CNN: two conv + max-pool stages, dropout, then a dense classifier
# with a 10-way softmax output (one unit per digit).
model = Sequential()
model.add(Conv2D(10, (5, 5), activation="relu", input_shape=(28, 28, 1)))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Conv2D(20, (5, 5), activation="relu"))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))
model.add(Flatten())
model.add(Dense(100, activation="relu"))
model.add(Dense(10, activation="softmax"))
model.compile(optimizer="rmsprop",
              loss=tf.keras.losses.categorical_crossentropy, metrics=['accuracy'])

# Load the MNIST handwritten-digit dataset
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
# Scale pixel values from [0, 255] to [0, 1].
# The inference script feeds the model `image / 255`, so training has to use
# the same scaling. `tf.keras.utils.normalize` would instead L2-normalize each
# pixel row, giving the model a different input distribution than it sees at
# prediction time.
normalized_x_train = x_train.astype("float32") / 255.0
normalized_x_test = x_test.astype("float32") / 255.0
# One-hot encode the labels (10 classes) to match categorical_crossentropy
one_hot_y_train = tf.one_hot(y_train, 10)
one_hot_y_test = tf.one_hot(y_test, 10)
# Add the trailing channel axis expected by Conv2D: (N, 28, 28) -> (N, 28, 28, 1)
reshaped_x_train = normalized_x_train.reshape(-1, 28, 28, 1)
reshaped_x_test = normalized_x_test.reshape(-1, 28, 28, 1)
train_result = model.fit(reshaped_x_train, one_hot_y_train,
                         epochs=20, validation_data=(reshaped_x_test, one_hot_y_test))
# Save the trained model.
# NOTE(review): the path is kept as './model' because the inference script
# loads 'model'; newer Keras releases may require a '.keras' or '.h5' suffix --
# confirm against the installed version.
model.save('./model')
# Plot training vs. validation accuracy per epoch
plt.plot(train_result.history['accuracy'])
plt.plot(train_result.history['val_accuracy'])
plt.legend(["Accuracy", "ValidationAcc"])
plt.show()
 

2. 用训练好的模型进行判断

import cv2
import numpy as np
from tensorflow import keras
model = keras.models.load_model('model')
# Test the trained model on a real image
img = cv2.imread("test/2a.png")
# cv2.imread returns None (it does not raise) when the file is missing or
# cannot be decoded; fail here with a clear message instead of an
# AttributeError on `img.shape` below.
if img is None:
    raise FileNotFoundError("Cannot read image: test/2a.png")
# Image preprocessing
img_width = img.shape[1]
img_height = img.shape[0]
# Center-crop to a square whose side is the shorter image dimension.
# Cropping on both axes also handles portrait images; with a columns-only
# crop the start offset would go negative and select the wrong region.
side = min(img_width, img_height)
row_start = (img_height - side) // 2
row_end = row_start + side
col_start = (img_width - side) // 2
col_end = col_start + side
cropped_img = img[row_start:row_end, col_start:col_end, :]
gray_img = cv2.cvtColor(cropped_img, cv2.COLOR_BGR2GRAY)
# With THRESH_OTSU the threshold is chosen automatically from the histogram;
# the 128 passed here is not used as the cut-off.
(thresh, black_white) = cv2.threshold(gray_img,
                                      128, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)
# Invert so the digit is bright on a dark background, as in MNIST
black_white = cv2.bitwise_not(black_white)
black_white = cv2.resize(black_white, (28, 28))
# Scale to [0, 1]
black_white = black_white/255
# Shape (1, 28, 28, 1): a batch of one single-channel image
black_white = black_white.reshape(-1, 28, 28, 1)
# Feed the image to the classifier and print the most probable digit
prediction = model.predict(black_white)
print(np.argmax(prediction))
 

常见需求