美文网首页
10 ML locally weighted linear regression

10 ML locally weighted linear regression

作者: peimin | 来源:发表于2016-06-07 17:25 被阅读0次
    from numpy import *
    import matplotlib.pyplot as plt
    
    def loadDataSet(fileName):
        """Load a tab-delimited file of floats into features and labels.

        Every field except the last one on a line is parsed as a float
        feature; the last field is parsed as the float target value. The
        number of feature columns is taken from the first line of the file.
        Blank (whitespace-only) lines are skipped.

        Parameters
        ----------
        fileName : str
            Path to the tab-delimited text file.

        Returns
        -------
        dataMat : list of list of float
            One feature row per non-blank line.
        labelMat : list of float
            The last field of each non-blank line.
        """
        dataMat = []
        labelMat = []
        # Read the file exactly once inside a context manager so the handle
        # is always closed, even if parsing below raises.
        with open(fileName) as fr:
            lines = fr.readlines()
        if not lines:
            return dataMat, labelMat
        # Feature count = number of fields on the first line minus the label.
        numFeat = len(lines[0].split('\t')) - 1
        for line in lines:
            curLine = line.strip().split('\t')
            if curLine == ['']:
                continue  # blank line (e.g. a trailing empty line): nothing to parse
            dataMat.append([float(curLine[i]) for i in range(numFeat)])
            labelMat.append(float(curLine[-1]))  # last field is the label
        return dataMat, labelMat
    
    def lwlr(testPoint, xArr, yArr, k=1.0):
        """Predict the target at one query point with locally weighted linear regression.

        Each training row j gets the Gaussian kernel weight
        ``exp(-||testPoint - x_j||^2 / (2 k^2))`` and the weighted normal
        equations ``(X^T W X) ws = X^T W y`` are solved for this query only.

        Parameters
        ----------
        testPoint : sequence of float or 1 x n matrix
            Query feature vector, with the same number of columns as ``xArr``.
        xArr : 2-D array-like, m x n
            Training features.
        yArr : 1-D array-like, length m
            Training targets.
        k : float
            Kernel bandwidth; smaller values make the weights fall off faster
            with distance.

        Returns
        -------
        numpy.matrix or None
            The prediction ``testPoint * ws`` as a single-element matrix, or
            ``None`` (after printing a message) when ``X^T W X`` is singular.
        """
        # asmatrix instead of mat: the numpy.mat alias was removed in NumPy 2.0.
        xMat = asmatrix(xArr)
        yMat = asmatrix(yArr).T  # m x 1 column of targets

        # Squared distance from the query to every training row, turned into
        # an m x 1 column of Gaussian weights.
        diffMat = testPoint - xMat  # broadcasts the query against all m rows
        sqDist = multiply(diffMat, diffMat).sum(axis=1)
        w = exp(sqDist / (-2.0 * k ** 2))

        # Scaling row j by w_j equals left-multiplying by diag(w), so X^T W X
        # and X^T W y are formed without materialising an m x m matrix.
        xTx = xMat.T * multiply(w, xMat)
        # NOTE(review): exact-zero test only; nearly singular systems still
        # reach the inverse below.
        if linalg.det(xTx) == 0.0:
            print("This matrix is singular, cannot do inverse")
            return
        ws = xTx.I * (xMat.T * multiply(w, yMat))
        return testPoint * ws
    
    def lwlrTest(testArr, xArr, yArr, k=1.0):
        """Apply ``lwlr`` to every row of ``testArr``.

        Parameters
        ----------
        testArr : 2-D array-like
            Query points, one per row.
        xArr, yArr, k
            Training features, training targets and kernel bandwidth, passed
            unchanged to ``lwlr``.

        Returns
        -------
        numpy.ndarray
            1-D float array with one prediction per query row. Rows for which
            ``lwlr`` reports a singular system hold NaN.
        """
        m = shape(testArr)[0]  # number of query rows
        yHat = zeros(m)
        for i in range(m):
            est = lwlr(testArr[i], xArr, yArr, k)
            # lwlr returns None when X^T W X is singular; storing None in a
            # float array raises TypeError, so record NaN and keep going.
            # .item() unwraps the single-element matrix lwlr returns.
            yHat[i] = nan if est is None else asarray(est).item()
        return yHat
    
    
    # Fit a locally weighted linear regression at every training point of
    # ex0.txt and plot the fitted curve over the raw data.
    xArr, yArr = loadDataSet('ex0.txt')
    # Kernel bandwidth. Smaller values (e.g. 0.01 or 0.003) make the Gaussian
    # weights fall off faster with distance, so each prediction depends on
    # fewer nearby points.
    k = 1.0
    yHat = lwlrTest(xArr, xArr, yArr, k)

    # asmatrix instead of mat: the numpy.mat alias was removed in NumPy 2.0.
    xMat = asmatrix(xArr)
    yMat = asmatrix(yArr)

    # Sort rows by the column-1 feature so the fitted line is drawn left to right.
    srtInd = xMat[:, 1].argsort(0)  # m x 1 matrix of row indices
    xSort = xMat[srtInd.A.ravel(), :]

    fig = plt.figure()
    ax = fig.add_subplot(111)

    ax.plot(xSort[:, 1], yHat[srtInd])  # fitted curve
    ax.scatter(xMat[:, 1].A.ravel(), yMat.A.ravel(), s=2, c='red')  # raw data points

    plt.show()
    

    k = 1.0

    1.0.png

    k = 0.01

    0.01.png

    k = 0.003

    3.png

    相关文章

      网友评论

          本文标题:10 ML locally weighted linear regression

          本文链接:https://www.haomeiwen.com/subject/amnddttx.html