import  numpy as np
import operator
import matplotlib
import matplotlib.pyplot as plt
import os


def createDataSet():
    """Return a tiny hard-coded sample set for trying out the classifier.

    Returns:
        A ``(group, labels)`` pair: a 4x2 float array of points and a list
        with the class label ('A' or 'B') of each row, in row order.
    """
    # Keep each point next to its label so the pairing is obvious.
    samples = [
        ([1.0, 1.1], 'A'),
        ([1.0, 1.0], 'A'),
        ([0, 0], 'B'),
        ([0, 0.1], 'B'),
    ]
    points = np.array([point for point, _ in samples])
    tags = [tag for _, tag in samples]
    return points, tags

def classify0(inX, dataSet, labels, k):
    """Classify ``inX`` by majority vote among its k nearest neighbours.

    Distances are Euclidean, computed between ``inX`` and every row of
    ``dataSet``.

    Args:
        inX: Feature vector to classify; a length-d sequence or an array of
            shape (d,) or (1, d).
        dataSet: Array-like of shape (n, d) holding the known samples.
        labels: Sequence of n labels; ``labels[i]`` belongs to ``dataSet[i]``.
        k: Number of neighbours that vote. Values larger than n are clamped
            to n so that every sample votes at most once.

    Returns:
        The label with the most votes. On a tie, the tied label whose first
        vote came from the closest neighbour wins.

    Raises:
        ValueError: If ``k`` is smaller than 1 or ``dataSet`` has no rows.
    """
    dataSet = np.asarray(dataSet)
    dataSetSize = dataSet.shape[0]
    if dataSetSize == 0:
        raise ValueError("dataSet must contain at least one sample")
    if k < 1:
        raise ValueError("k must be at least 1, got %r" % (k,))
    # Asking for more neighbours than there are samples used to fail with an
    # IndexError; let every available sample vote instead.
    k = min(k, dataSetSize)

    # Broadcasting subtracts inX from every row; no need to tile it first.
    diffMat = np.asarray(inX) - dataSet
    distances = (diffMat ** 2).sum(axis=1) ** 0.5
    nearest = distances.argsort()[:k]

    classCount = {}
    for neighbourIndex in nearest:
        voteIlabel = labels[neighbourIndex]
        classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1

    # max() returns the first maximal entry in insertion order, which is the
    # same winner a stable descending sort of the vote table would give.
    return max(classCount.items(), key=operator.itemgetter(1))[0]


def file2matrix(filename):
    """Load a tab-separated text file into a feature matrix and label list.

    Every non-blank line must hold three numeric feature columns followed by
    an integer class label in its last column. Blank lines (for example a
    trailing newline at the end of the file) are ignored.

    Args:
        filename: Path of the text file to read.

    Returns:
        A ``(returnMat, classLabelVectors)`` pair: a float array of shape
        (number_of_records, 3) with the first three columns of each record,
        and a list with the integer label of each record in file order.

    Raises:
        ValueError: If a feature or label field cannot be parsed as a number,
            or a record has fewer than three feature fields.
    """
    # The context manager closes the file even if parsing fails later.
    with open(filename) as fr:
        records = [line.strip().split('\t') for line in fr if line.strip()]

    returnMat = np.zeros((len(records), 3))
    classLabelVectors = []
    for index, listFromLine in enumerate(records):
        returnMat[index, :] = [float(value) for value in listFromLine[0:3]]
        classLabelVectors.append(int(listFromLine[-1]))
    return returnMat, classLabelVectors


"""
data, labels = file2matrix("datingTestSet.txt")
# print(data)
fig = plt.figure()
ax = fig.add_subplot(111)
ax.scatter(data[:, 0], data[:, 1], 15*np.array(labels), 15*np.array(labels))
plt.show()
"""


def autoNorm(dataSet):
    """Rescale every feature column linearly into the range [0, 1].

    Each value becomes ``(value - column_min) / (column_max - column_min)``.
    A column whose values are all equal has a range of zero; its normalised
    values are set to 0 instead of dividing by zero.

    Args:
        dataSet: Array-like of shape (m, d) with one sample per row.

    Returns:
        A ``(normDataSet, ranges, minVals)`` triple: the normalised (m, d)
        array, the per-column ``max - min`` and the per-column minimum.
        ``ranges`` is returned as computed, so it contains 0 for constant
        columns; callers dividing by it should guard for that.
    """
    dataSet = np.asarray(dataSet)
    minVals = dataSet.min(0)
    maxVals = dataSet.max(0)
    ranges = maxVals - minVals
    # Divide constant columns by 1 so they normalise to 0 rather than NaN.
    safeRanges = np.where(ranges == 0, 1, ranges)
    # Broadcasting applies the per-column offset and scale to every row.
    normDataSet = (dataSet - minVals) / safeRanges
    return normDataSet, ranges, minVals


def datingClassTest():
    """Print the hold-out error rate of the classifier on the dating data.

    Loads 'datingTestSet2.txt', normalises it, treats the first 10% of the
    rows as test samples and the remaining rows as the training set, then
    classifies each test sample with k=3 and prints every prediction next to
    the true label, followed by the overall error rate.
    """
    holdout_fraction = 0.1
    features, targets = file2matrix('datingTestSet2.txt')
    scaled, _spans, _floors = autoNorm(features)
    total = scaled.shape[0]
    test_count = int(total * holdout_fraction)

    # The training portion is the same for every test sample.
    train_rows = scaled[test_count:total, :]
    train_targets = targets[test_count:total]

    mistakes = 0.0
    for sample, expected in zip(scaled[:test_count], targets[:test_count]):
        predicted = classify0(sample, train_rows, train_targets, 3)
        print("the classifier came back with: %d, the real answer is: %d" % (predicted,
                                                                             expected))
        if predicted != expected:
            mistakes += 1.0
    print("the total error rate is: %f" % (mistakes / float(test_count)))


def classifyPerson():
    """Interactively predict how much the user will like a person.

    Prompts for three figures on standard input, normalises them with the
    statistics of 'datingTestSet2.txt', classifies the result with k=3 and
    prints the matching description. The predicted label is used as a
    1-based index into the list of descriptions.
    """
    verdicts = ['not at all', 'in small doses', 'in large doses']

    gaming_share = float(input("percentage of time spent playing video games?  "))
    flier_miles = float(input("frequent flier miles earned per year?  "))
    ice_cream_liters = float(input("liters of ice cream consumed per year?  "))

    features, targets = file2matrix('datingTestSet2.txt')
    scaled, spans, floors = autoNorm(features)

    # Scale the candidate exactly like the training data.
    candidate = np.array([flier_miles, gaming_share, ice_cream_liters])
    scaled_candidate = (candidate - floors) / spans

    predicted = classify0(scaled_candidate, scaled, targets, 3)
    print("You will probably like this person:", verdicts[predicted - 1])


def img2vector(filename):
    """Flatten a 32x32 text image of digits into a single row vector.

    The file is expected to contain at least 32 lines, each starting with 32
    digit characters. Row ``i``, column ``j`` of the image ends up at index
    ``32 * i + j`` of the result.

    Args:
        filename: Path of the text image to read.

    Returns:
        A float array of shape (1, 1024).

    Raises:
        ValueError: If a row has fewer than 32 characters or contains a
            character that is not a digit.
    """
    returnVect = np.zeros((1, 1024))
    # The context manager closes the file even if a row is malformed.
    with open(filename) as fr:
        for i in range(32):
            lineStr = fr.readline()
            if len(lineStr) < 32:
                raise ValueError("%s: row %d has fewer than 32 characters"
                                 % (filename, i))
            returnVect[0, 32 * i:32 * (i + 1)] = [int(ch) for ch in lineStr[:32]]
    return returnVect

# Import-time smoke check: loads one sample digit image into a vector.
# Requires 'digits/testDigits/0_13.txt' relative to the working directory.
testVector = img2vector('digits/testDigits/0_13.txt')


def _labelFromFileName(fileNameStr):
    """Return the integer before the first '_' of a name such as '0_13.txt'."""
    return int(fileNameStr.split('.')[0].split('_')[0])


def handwritingClassTest():
    """Evaluate the kNN classifier on the handwritten-digit text images.

    Every file in 'digits/trainingDigits' becomes one training row, and every
    file in 'digits/testDigits' is classified with k=3. The true label of a
    file is the integer before the first underscore of its name. Each
    prediction is printed next to the true label, followed by the number of
    errors and the error rate.

    Raises:
        ZeroDivisionError: If 'digits/testDigits' contains no files.
    """
    hwLabels = []
    trainingFileList = os.listdir('digits/trainingDigits')
    trainingMat = np.zeros((len(trainingFileList), 1024))
    for i, fileNameStr in enumerate(trainingFileList):
        hwLabels.append(_labelFromFileName(fileNameStr))
        trainingMat[i, :] = img2vector('digits/trainingDigits/%s' % fileNameStr)

    testFileList = os.listdir('digits/testDigits')
    errorCount = 0.0
    mTest = len(testFileList)
    for fileNameStr in testFileList:
        # Parse the label the same way as for the training files.
        classNumStr = _labelFromFileName(fileNameStr)
        vectorUnderTest = img2vector('digits/testDigits/%s' % fileNameStr)
        classifierResult = classify0(vectorUnderTest, trainingMat, hwLabels, 3)
        print("the classifier came back with: %d, the real answer is: %d" % (classifierResult,
                                                                             classNumStr))
        if classifierResult != classNumStr:
            errorCount += 1.0
    print("\nthe total number of errors is: %d" % errorCount)
    print("\nthe total error rate is: %f" % (errorCount/float(mTest)))


# Runs the handwritten-digit evaluation whenever this module is executed or
# imported; it needs the 'digits' directories in the working directory.
handwritingClassTest()

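# NOTE: the lines below are web-page residue captured together with the code
# (article recommendations and sharing-widget text). They are not part of the
# program and are commented out so that the module remains valid Python.
#
# 优质内容筛选与推荐>>
# 1、Java关于接口的那点事儿
# 2、2018黑帽SEO优化排名技术方法大总结分类目录文章标签友情链接联系我们
# 3、羊皮书APP(Android版)开发系列(七)Android沉浸通知栏
# 4、数据脱敏——什么是数据脱敏
# 5、论Java中的内存分配
#
#
# 长按二维码向我转账
#
# 受苹果公司新规定影响,微信 iOS 版的赞赏功能被关闭,可通过二维码转账支持公众号。
#
#     阅读
#     好看
#     已推荐到看一看
#     你的朋友可以在“发现”-“看一看”看到你认为好看的文章。
#     已取消,“好看”想法已同步删除
#     已推荐到看一看 和朋友分享想法
#     最多200字,当前共 发送
#
#     已发送
#
#     朋友将在看一看看到
#
#     确定
#     分享你的想法...
#     取消
#
#     分享想法到看一看
#
#     确定
#     最多200字,当前共
#
#     发送中
#
#     网络异常,请稍后重试
#
#     微信扫一扫
#     关注该公众号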