KNN + GridSearchCV 手写数字数据集识别

14次阅读
没有评论

本项目使用 scikit-learn 自带的手写数字数据集(digits),通过 KNN(K-Nearest Neighbors)分类器完成识别任务,并使用网格搜索 + 交叉验证自动寻找最优参数。

1. 项目目标

  • 使用 KNN 对手写数字进行分类(类别 0-9)
  • 使用 GridSearchCV 自动调参,提升模型效果
  • 输出测试集准确率和分类报告,评估模型表现

2. 项目结构

  • knn_mnist_experiment.py:主实验脚本(数据加载、训练、调参、评估)
  • requirements.txt:项目依赖

3. 环境与依赖

requirements.txt 内容:

  • scikit-learn>=1.3
  • numpy>=1.24

推荐步骤:

conda activate conda_environment
pip install -r requirements.txt

4. 运行方式

在项目根目录执行:

python knn_mnist_experiment.py

你将看到以下几类输出:

  • 最优参数(best_params_
  • 交叉验证最佳准确率(best_score_
  • 测试集准确率
  • 每个类别的 precision / recall / f1-score(分类报告)

5. 代码逐段讲解

5.1 导入库

from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.model_selection import GridSearchCV
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score, classification_report

作用说明:

  • datasets:加载内置数据集(这里用 load_digits
  • train_test_split:把数据拆成训练集和测试集
  • GridSearchCV:在参数网格上做交叉验证,自动选最优参数
  • KNeighborsClassifier:KNN 分类器
  • accuracy_scoreclassification_report:模型评估指标

5.2 加载数据集

digits = datasets.load_digits()
X = digits.data
y = digits.target

这里用的是 digits 数据集(常被称作“简化版 MNIST 风格”数据):

  • 样本数:1797
  • 每张图像是 8×8 灰度图
  • 展平后每个样本是 64 维特征(即 X
  • 标签 y 取值范围是 0-9

5.3 划分训练集与测试集

X_train, X_test, y_train, y_test = train_test_split(
    X, y,
    test_size=0.2,
    random_state=42
)

参数解释:

  • test_size=0.2:20% 做测试,80% 做训练
  • random_state=42:固定随机种子,保证每次划分一致,可复现

5.4 定义 KNN 模型

knn = KNeighborsClassifier()

先实例化一个基础 KNN,不手动指定核心参数,让后续网格搜索去选择。

5.5 定义参数搜索空间

param_grid = {
    'n_neighbors': list(range(1, 11)),
    'p': [1, 2, 3]
}

含义:

  • n_neighbors:近邻个数 $k$,搜索范围 1-10
  • p:Minkowski 距离参数

Minkowski 距离公式:

d(\mathbf{x}, \mathbf{z}) = \left(\sum_{i=1}^{n} |x_i - z_i|^p\right)^{1/p}

常见情况:

  • p=1:曼哈顿距离
  • p=2:欧氏距离
  • p=3:更高阶的 Minkowski 距离

5.6 配置 GridSearchCV

grid = GridSearchCV(
    knn,
    param_grid,
    cv=3,
    scoring='accuracy'
)

解释:

  • cv=3:3 折交叉验证
  • scoring='accuracy':以准确率作为调参目标
  • 一共会评估 10 x 3 = 30 组参数
  • 每组参数做 3 折验证,总计拟合 90 次

5.7 在训练集上搜索最优参数

grid.fit(X_train, y_train)
print("最优参数:", grid.best_params_)
print("交叉验证最佳准确率:", grid.best_score_)

GridSearchCV 会在训练集内部完成交叉验证,最终返回:

  • best_params_:表现最好的参数组合
  • best_score_:该组合在交叉验证中的平均准确率

5.8 使用最优模型在测试集评估

best_model = grid.best_estimator_
y_pred = best_model.predict(X_test)

best_estimator_ 是已经按最优参数重新拟合后的模型,可直接用于测试集预测。

5.9 输出评估结果

print("测试集准确率:", accuracy_score(y_test, y_pred))
print("分类报告:")
print(classification_report(y_test, y_pred))

指标含义:

  • accuracy:整体预测正确比例
  • precision:预测为某类时有多少是真的
  • recall:某类真实样本有多少被找回
  • f1-score:precision 与 recall 的综合指标
  • support:该类在测试集中的样本数量

6. 实验结果如何解读

  • 若测试准确率接近交叉验证准确率,说明泛化较稳定
  • 若测试准确率明显低于交叉验证结果,可能存在划分偶然性或轻微过拟合
  • 分类报告中某些类别分数低,通常说明这些数字在特征空间中更容易混淆

7. 可以继续优化的方向

  • 增大参数搜索范围(例如 n_neighbors 到 20 或 30)
  • 增加交叉验证折数(如 cv=5)提升评估稳定性
  • 在训练前加入特征缩放(StandardScaler)并比较效果
  • 尝试加权 KNN(weights='distance'
  • 对比其他模型(SVM、RandomForest、LogisticRegression)

8. 小结

这份脚本展示了一个标准的机器学习实验流程:

  1. 准备数据
  2. 划分数据
  3. 定义模型与参数空间
  4. 交叉验证调参
  5. 测试集评估

对于入门实验来说,这个结构清晰、可复现,也方便后续扩展成更完整的实验框架。

正文完
 0
评论(没有评论)