網(wǎng)站虛擬主持百度域名收錄
?
本人主頁:機器學(xué)習(xí)司貓白
ok,話不多說,我們進入正題吧
項目概述
本案例使用經(jīng)典的MNIST手寫數(shù)字數(shù)據(jù)集,通過Keras構(gòu)建全連接神經(jīng)網(wǎng)絡(luò),實現(xiàn)0-9數(shù)字的分類識別。文章將包含:
- 關(guān)鍵概念圖解
- 完整實現(xiàn)代碼
- 訓(xùn)練過程可視化
- 模型效果深度分析
環(huán)境準備
import numpy as np
import matplotlib.pyplot as plt
from tensorflow import keras
from tensorflow.keras import layers
from sklearn.metrics import confusion_matrix
import seaborn as sns
三、數(shù)據(jù)加載與探索
3.1 加載數(shù)據(jù)集
# 加載內(nèi)置MNIST數(shù)據(jù)集
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()print("訓(xùn)練集形狀:", x_train.shape)
print("測試集形狀:", x_test.shape)
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
11490434/11490434 ━━━━━━━━━━━━━━━━━━━━ 5s 0us/step
訓(xùn)練集形狀: (60000, 28, 28)
測試集形狀: (10000, 28, 28)
可以看到圖像是28*28,每張圖像一共有784個像素點。
3.2 數(shù)據(jù)可視化
plt.figure(figsize=(10,5))
for i in range(15):plt.subplot(3,5,i+1)plt.imshow(x_train[i], cmap='gray')plt.title(f"Label: {y_train[i]}")plt.axis('off')
plt.tight_layout()
plt.savefig('mnist_samples.png', dpi=300)
plt.show()
四、數(shù)據(jù)預(yù)處理
4.1 數(shù)據(jù)歸一化
# 將像素值縮放到0-1范圍
x_train = x_train.astype("float32") / 255
x_test = x_test.astype("float32") / 255# 將圖像展平為784維向量
x_train = x_train.reshape(-1,