资讯专栏INFORMATION COLUMN

机器学习(八)-基于KNN分类算法的手写识别系统

lily_wang / 816人阅读

摘要:项目介绍基于近邻分类器的手写识别系统这里构造的系统只能识别数字到。将图像格式化处理为一个向量。

1 项目介绍

基于k-近邻分类器(KNN)的手写识别系统, 这里构造的系统只能识别数字0到9。

数据集和项目源代码

难点: 图形信息如何处理?

图像转换为文本格式

2 准备数据

将图像转换为测试向量

训练集:

目录trainingDigits

大约2000个例子

每个数字大约有200个样本;

测试集

目录testDigits

大约900个测试数据。

将图像格式化处理为一个向量。我们将把一个32×32的二进制图像矩阵转换为1×1024的向量, 如下图所示,

import numpy as np
def img2vector(filename):
    """
    # 将图像数据转换为(1,1024)向量
    :param filename: 
    :return: (1,1024)向量
    """
    # 生成一个1*1024且值全为0的向量;
    returnVect = np.zeros((1, 1024))
    # 读取要转换的信息;
    file = open(filename)
    # 依次填充
    # 读取每一行数据;
    for i in range(32):
        lineStr = file.readline()
        # 读取每一列数据;
        for j in range(32):
            returnVect[0, 32 * i + j] = int(lineStr[j])
    return returnVect
3 实施 KNN 算法

对未知类别属性的数据集中的每个点依次执行以下操作, 与上一个案例代码相同:
(1) 计算已知类别数据集中的点与当前点之间的距离;
(2) 按照距离递增次序排序;
(3) 选取与当前点距离最小的k个点;
(4) 确定前k个点所在类别的出现频率;
(5) 返回前k个点出现频率最高的类别作为当前点的预测分类。

def classify(inX, dataSet, labels, k):
    """
    :param inX: 要预测的数据
    :param dataSet: 我们要传入的已知数据集
    :param labels:  我们要传入的标签
    :param k: KNN里的k, 也就是说我们要选几个近邻
    :return: 排序的结果
    """
    dataSetSize = dataSet.shape[0]  # (6,2) 6
    # tile会重复inX, 把他重复成(datasetsize, 1)型的矩阵
    # print(inX)
    # (x1 - y1), (x2- y2)
    diffMat = np.tile(inX, (dataSetSize, 1)) - dataSet
    # 平方
    sqDiffMat = diffMat ** 2
    # 相加, axis=1 行相加
    sqDistance = sqDiffMat.sum(axis=1)
    # 开根号
    distances = sqDistance ** 0.5
    # print(distances)
    # 排序 输出的是序列号index,并不是值
    sortedDistIndicies = distances.argsort()
    # print(sortedDistIndicies)

    classCount = {}
    for i in range(k):
        voteLabel = labels[sortedDistIndicies[i]]
        classCount[voteLabel] = classCount.get(voteLabel, 0) + 1
        # print(classCount)
    sortedClassCount = sorted(classCount.items(), key=lambda d: float(d[1]), reverse=True)
    return sortedClassCount[0]

4 测试算法

使用 k-近邻算法识别手写数字

测试集里面的信息;

def handWritingClassTest(k):
    """
    # 测试手写数字识别错误率的代码
    :param k:
    :return:
    """
    hwLabels = []
    import os
    # 读取所有的训练集文件;
    trainingFileList = os.listdir("data/knn-digits/trainingDigits")
    # 获取训练集个数;
    m = len(trainingFileList)
    # 生成m行1024列全为0的矩阵;
    trainingMat = np.zeros((m, 1024))
    # 填充训练集矩阵;
    for i in range(m):
        fileNameStr = trainingFileList[i]    # fileNameStr: 0_0.txt
        fileStr = fileNameStr.split(".")[0]  # fileStr: 0_0
        classNumStr = int(fileStr.split("_")[0])    # (数字分类的结果)classNumStr: 0
        # 填写真实的数字结果;
        hwLabels.append(classNumStr)
        # 图形的数据: (1,1024)向量
        trainingMat[i, :] = img2vector("data/knn-digits/trainingDigits/%s" % fileNameStr)

    # 填充测试集矩阵;
    testFileList = os.listdir("data/knn-digits/testDigits")
    # 默认错误率为0;
    errorCount = 0.0
    # 测试集的总数;
    mTest = len(testFileList)
    # 填充测试集矩阵;
    for i in range(mTest):
        fileNameStr = testFileList[i]
        fileStr = fileNameStr.split(".")[0]
        classNumStr = int(fileStr.split("_")[0])
        vectorTest = img2vector("data/knn-digits/testDigits/%s" % fileNameStr)

        # 判断预测结果与真实结果是否一致?
        result = classify(vectorTest, trainingMat, hwLabels, k)

        if result != classNumStr:
            # 如果不一致,则统计出来, 计算错误率;
            errorCount += 1.0
            print("[预测失误]:分类结果是:%d, 真实结果是:%d" % (result, classNumStr))
    print("错误总数:%d" % errorCount)
    print("错误率:%f" % (errorCount / mTest))
    print("模型准确率:%f" %(1-errorCount / mTest))
    return errorCount


print(handWritingClassTest(2))

效果展示

5 KNN算法手写识别的缺点

算法的执行效率并不高。

每个测试向量做2000次距离计算,每个距离计算包括了1024个维度浮点运算,总计要执行900次;

需要为测试向量准备2MB的存储空间

有没有更好的方法?

k决策树就是k-近邻算法的优化版,可以节省大量的计算开销。

文章版权归作者所有,未经允许请勿转载,若此文章存在违规行为,您可以联系管理员删除。

转载请注明本文地址:https://www.ucloud.cn/yun/19981.html

相关文章

  • 机器学习()-基于KNN分类算法手写识别系统

    摘要:项目介绍基于近邻分类器的手写识别系统这里构造的系统只能识别数字到。将图像格式化处理为一个向量。 1 项目介绍 基于k-近邻分类器(KNN)的手写识别系统, 这里构造的系统只能识别数字0到9。 数据集和项目源代码 难点: 图形信息如何处理? 图像转换为文本格式 2 准备数据 将图像转换为测试向量 训练集: 目录trainingDigits 大约2000个例子 每个数字大约有200个...

    Warren 评论0 收藏0
  • Python数据挖掘与机器学习技术入门实战

    摘要:在本次课程中,着重讲解的是传统的机器学习技术及各种算法。回归对连续型数据进行预测趋势预测等除了分类之外,数据挖掘技术和机器学习技术还有一个非常经典的场景回归。 摘要: 什么是数据挖掘?什么是机器学习?又如何进行Python数据预处理?本文将带领大家一同了解数据挖掘和机器学习技术,通过淘宝商品案例进行数据预处理实战,通过鸢尾花案例介绍各种分类算法。 课程主讲简介:韦玮,企业家,资深IT领...

    ephererid 评论0 收藏0

发表评论

0条评论

lily_wang

|高级讲师

TA的文章

阅读更多
最新活动
阅读需要支付1元查看
<