tensorflow——深度学习实现手写体识别

'''
使用手写体识别模型进行预测
'''
import tensorflow as tf
import pylab as plt

# 读取一张图片
img = tf.keras.preprocessing.image.load_img('./test_img/6.png')
img = img.resize((28, 28))
# 图片灰度化
img = img.convert('L')
# 图片转成张量
img = tf.keras.preprocessing.image.img_to_array(img).reshape((784))

# 图片归一化
img = img / 255.0

# 显示图片
plt.imshow(img.reshape((28, 28)), cmap='gray')
plt.show()

x = tf.placeholder(tf.float32, [None, 784])
y = tf.placeholder(tf.float32, [None, 10])

# 定义模型
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))
y_pred = tf.nn.softmax(tf.matmul(x, W) + b)

# 损失函数:交叉熵
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y * tf.log(y_pred), axis=1))
# 优化器:随机梯度下降法
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # 加载模型
    saver = tf.train.Saver()
    saver.restore(sess, './model/mnist')

    # 识别图片
    pred = sess.run(y_pred, feed_dict={x: [img]})
    # 显示预测结果
    # print(pred)
    print('\t\t\t\t\t识别结果:', tf.argmax(pred, 1).eval())


最后编辑于:2024/05/23作者: 牛逼PHP

发表评论