CNN和DNN
使用 CNN(卷积神经网络) 和 DNN(全连接神经网络) 训练 MNIST 数据集的完整代码,包括:
✅ 数据加载 & 预处理 ✅ CNN 模型定义 & 训练 ✅ DNN 模型定义 & 训练 ✅ 模型评估 & 预测 ✅ 可视化预测结果
📌 代码结构
- CNN 模型(用于图像分类,适用于 MNIST)
- DNN 模型(全连接网络,适用于基本分类任务)
- 对比 CNN 和 DNN 的性能
- 展示预测效果
📌 代码实现
import os

# Keras reads KERAS_BACKEND when it is first imported, so the backend has to
# be selected before `import keras`; setting it afterwards has no effect.
os.environ["KERAS_BACKEND"] = "tensorflow"

import keras
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import accuracy_score
# Load the MNIST dataset.
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

# Normalize the images: cast to float32 and divide by 255.
x_train = x_train.astype("float32") / 255
x_test = x_test.astype("float32") / 255

# Append a trailing channel axis so the images fit the CNN input.
x_train_cnn = np.expand_dims(x_train, -1)
x_test_cnn = np.expand_dims(x_test, -1)

# Flatten each 28x28 image into a 784-element vector for the DNN.
x_train_dnn = x_train.reshape(-1, 28 * 28)
x_test_dnn = x_test.reshape(-1, 28 * 28)

# Number of target classes.
num_classes = 10

# One-hot encode the labels (only the DNN is trained on these).
y_train_onehot = keras.utils.to_categorical(y_train, num_classes)
y_test_onehot = keras.utils.to_categorical(y_test, num_classes)
# ---------------------------------------
# 📌 CNN 模型定义
# ---------------------------------------
def build_cnn_model():
    """Build the (uncompiled) convolutional classifier.

    The network takes inputs of shape (28, 28, 1) and stacks two 64-filter
    and two 128-filter 3x3 ReLU convolutions, with a 2x2 max-pool between
    the two pairs, followed by global average pooling, dropout (0.5) and a
    softmax layer with ``num_classes`` units (``num_classes`` is read from
    module scope).

    Returns:
        A ``keras.Sequential`` model; the caller is expected to compile it.
    """
    model = keras.Sequential([
        # Declare the input shape with an explicit Input object rather than
        # an ``input_shape`` argument on the first layer.
        keras.Input(shape=(28, 28, 1)),
        keras.layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),
        keras.layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),
        keras.layers.MaxPooling2D(pool_size=(2, 2)),
        keras.layers.Conv2D(128, kernel_size=(3, 3), activation="relu"),
        keras.layers.Conv2D(128, kernel_size=(3, 3), activation="relu"),
        keras.layers.GlobalAveragePooling2D(),
        keras.layers.Dropout(0.5),
        keras.layers.Dense(num_classes, activation="softmax"),
    ])
    return model
# ---------------------------------------
# 📌 DNN 模型定义
# ---------------------------------------
def build_dnn_model():
    """Build the (uncompiled) fully connected classifier.

    The network takes flattened inputs of length 28 * 28 and stacks two
    512-unit ReLU dense layers, each followed by dropout (0.2), and a
    softmax layer with ``num_classes`` units (``num_classes`` is read from
    module scope).

    Returns:
        A ``keras.Sequential`` model; the caller is expected to compile it.
    """
    model = keras.Sequential([
        # Declare the input shape with an explicit Input object rather than
        # an ``input_shape`` argument on the first layer.
        keras.Input(shape=(28 * 28,)),
        keras.layers.Dense(512, activation="relu"),
        keras.layers.Dropout(0.2),
        keras.layers.Dense(512, activation="relu"),
        keras.layers.Dropout(0.2),
        keras.layers.Dense(num_classes, activation="softmax"),
    ])
    return model
# ---------------------------------------
# 📌 训练 CNN
# ---------------------------------------
# Build the CNN and train it on the integer class labels.
cnn_model = build_cnn_model()
cnn_model.compile(
    # The sparse loss is paired with the integer labels in y_train.
    loss=keras.losses.SparseCategoricalCrossentropy(),
    optimizer=keras.optimizers.Adam(learning_rate=1e-3),
    metrics=["accuracy"],
)
# validation_split=0.15 reserves 15% of the training data for validation.
cnn_model.fit(
    x_train_cnn, y_train, batch_size=128, epochs=5, validation_split=0.15
)
# ---------------------------------------
# 📌 训练 DNN
# ---------------------------------------
# Build the DNN and train it on the one-hot encoded labels.
dnn_model = build_dnn_model()
dnn_model.compile(
    # The non-sparse loss is paired with the one-hot labels in y_train_onehot.
    loss="categorical_crossentropy",
    optimizer=keras.optimizers.Adam(learning_rate=1e-3),
    metrics=["accuracy"],
)
# validation_split=0.15 reserves 15% of the training data for validation.
dnn_model.fit(
    x_train_dnn, y_train_onehot, batch_size=128, epochs=5, validation_split=0.15
)
# ---------------------------------------
# 📌 评估 CNN 和 DNN
# ---------------------------------------
# Evaluate both models on the test set. Each call returns the loss followed
# by the compiled metrics, so index 1 is the accuracy.
cnn_score = cnn_model.evaluate(x_test_cnn, y_test, verbose=0)
dnn_score = dnn_model.evaluate(x_test_dnn, y_test_onehot, verbose=0)
print(f"CNN 测试集准确率: {cnn_score[1]:.4f}")
print(f"DNN 测试集准确率: {dnn_score[1]:.4f}")
# ---------------------------------------
# 📌 CNN 和 DNN 进行预测
# ---------------------------------------
# Turn each model's per-class scores into a predicted label by taking the
# index of the largest score along the class axis.
cnn_predictions = cnn_model.predict(x_test_cnn).argmax(axis=1)
dnn_predictions = dnn_model.predict(x_test_dnn).argmax(axis=1)
# ---------------------------------------
# 📌 可视化 CNN 和 DNN 预测结果
# ---------------------------------------
def plot_predictions(preds, model_name):
    """Show the first test images with their predicted and true labels.

    Images and true labels are read from the module-level ``x_test`` and
    ``y_test``; the figure is a 2x5 grid.

    Args:
        preds: Predicted labels, indexable in the same order as ``x_test``.
        model_name: Label placed in front of each subplot title.
    """
    # Never index past the available predictions (the grid holds 10 images).
    num_images = min(10, len(preds))
    plt.figure(figsize=(10, 5))
    for i in range(num_images):
        plt.subplot(2, 5, i + 1)
        plt.imshow(x_test[i], cmap="gray")  # render as a grayscale image
        plt.title(f"{model_name} Pred: {preds[i]}\nTrue: {y_test[i]}", fontsize=10)
        plt.axis("off")
    plt.tight_layout()
    plt.show()
# Show the first test images with each model's predictions.
print("\n🎯 CNN 预测结果:")
plot_predictions(cnn_predictions, "CNN")
print("\n🎯 DNN 预测结果:")
plot_predictions(dnn_predictions, "DNN")
# ---------------------------------------
# 📌 计算 CNN 和 DNN 预测错误的示例
# ---------------------------------------
def plot_wrong_predictions(preds, model_name):
    """Show up to 10 test images that the model misclassified.

    Images and true labels are read from the module-level ``x_test`` and
    ``y_test``; the figure is a 2x5 grid with red titles.

    Args:
        preds: Predicted labels, comparable element-wise with ``y_test``.
        model_name: Label placed in front of each subplot title, so figures
            from different models can be told apart.
    """
    wrong_preds = np.where(preds != y_test)[0]
    num_errors = min(10, len(wrong_preds))
    if num_errors == 0:
        # Nothing to draw; skip the figure instead of showing an empty one.
        print(f"{model_name} 没有预测错误的样本。")
        return
    plt.figure(figsize=(10, 5))
    for i in range(num_errors):
        idx = wrong_preds[i]
        plt.subplot(2, 5, i + 1)
        plt.imshow(x_test[idx], cmap="gray")
        plt.title(
            f"{model_name} Pred: {preds[idx]}\nTrue: {y_test[idx]}",
            fontsize=10,
            color="red",
        )
        plt.axis("off")
    plt.tight_layout()
    plt.show()
# Show misclassified test images for each model.
print("\n🚨 CNN 预测错误示例:")
plot_wrong_predictions(cnn_predictions, "CNN")
print("\n🚨 DNN 预测错误示例:")
plot_wrong_predictions(dnn_predictions, "DNN")
📌 代码解析
1️⃣ 数据预处理
- CNN 需要 (28, 28, 1) 格式的输入,因此使用 `np.expand_dims()` 增加通道维度
- DNN 需要将输入展平为 (784,) 形状,因此使用 `reshape(-1, 28*28)`
2️⃣ 模型定义
- CNN 模型
  - `Conv2D`:卷积层
  - `MaxPooling2D`:池化层
  - `GlobalAveragePooling2D`:全局平均池化层
  - `Dropout`:防止过拟合
  - `Dense`:输出层(Softmax 分类)
- DNN 模型
  - `Dense(512, activation="relu")`:两层全连接隐藏层
  - `Dropout(0.2)`:防止过拟合
  - `Dense(10, activation="softmax")`:输出层
3️⃣ 模型训练
- CNN 直接用 `y_train`(SparseCategoricalCrossentropy 适用于整数标签)
- DNN 需要 `y_train_onehot`(CategoricalCrossentropy 适用于 one-hot 标签)
4️⃣ 预测 & 可视化
- 对比 CNN 和 DNN 预测效果
- 展示预测错误的示例
📌 运行结果
✅ CNN 测试集准确率 ≈ 98.7% ✅ DNN 测试集准确率 ≈ 97.4% ✅ CNN 预测效果更佳,DNN 易混淆相似数字