美文网首页
x.fit()、x.transform()、x.fit_tran

x.fit()、x.transform()、x.fit_tran

作者: 马尔代夫Maldives | 来源:发表于2022-12-15 13:49 被阅读0次

    在sklearn中,有很多算法(比如PCA、StandarScaler等,本文中算法用‘x’代替)都有x.fit()x.transform()x.fit_transform()三个函数,这三者究竟有什么关系?

    Note:

    x.transform()必须在x.fit()或x.fit_transform()后使用,否则报错,说明前者是依赖于后二者的。
    (实际上后二者主要是从数据中提取需要的“参数”,而x.transform()则是利用这些参数对原数据或其他数据做相应的处理)。

    使用方法1:

    第1步:x.fit(trainData):(无返回值),主要就是根据trainData计算相应的参数,比如对于StandarScaler方法,fit会计算trainData的均值、方差等,并保存在x中,留做后用注意:x.fit(trainData)只是根据trainData计算需要的参数,并不会对trainData做任何处理);
    第2步:result1 = x.transform(trainData):(有返回值),根据第1步得到的参数,对源数据trainData进行处理,这一步才是根据参数对数据进行处理;
    第3步:result2 = x.transform(newData):(有返回值),根据第1步计算得到的参数,对新数据newData进行处理
    (第2、3步本质上一样,这么做就能保证newData与trainData是在相同的标准下进行的数据转换。)

    使用方法2:

    第1步:result1 = x.fit_transform(trainData):(有返回值),将前面的x.fit()和x.transform()合并成了一步,但还是会先根据trainData计算参数并保存在x中,然后再根据参数对trianData进行处理
    第2步:result2 = x.transform(newData):(有返回值),根据得到的参数对newData进行处理,得到result2;

    注意:
    result = x.fit_transform(trainData)之后再执行result = x.fit_transform(newData),虽然能得到结果,但是这么做trainData和newData是在各自的参数(标准)下进行的处理,两者没有任何联系,如果trainData和newData需要在相同的标准下进行处理,则这么写是错误的,应该避免。

    示例:

    公共部分:

    # 导入包
    from sklearn.preprocessing import StandardScaler
    # 定义数据
    trainData = np.array([[-7,2,3],[4,5,6],[0,1,-8]])
    newData = np.array([[-3,0,1],[-10,2,-7],[5,-9,8]])
    

    使用方法1:

    scaler = StandardScaler()  # 生成一个标准化对象scaler
    scaler.fit(trainData)  # 利用scaler的fit()函数计算trainData的均值、方差(相当于拿trainData做标准)
    y1 = scaler.transform(trainData)  # 将标准应用于trainData
    y2 = scaler.transform(newData)  # 将标准应用于newData
    print(y1)
    print(y2)
    [[-1.31982404 -0.39223227  0.44307902]
     [ 1.09985336  1.37281295  0.94154292]
     [ 0.21997067 -0.98058068 -1.38462194]]
    [[-0.43994135 -1.56892908  0.11076976]
     [-1.97973605 -0.39223227 -1.21846731]
     [ 1.31982404 -6.86406473  1.27385218]]
    
    # 如果newData基于自己进行标准化,则结果y22与前面的y2是有一些差异的
    scaler.fit(newData)
    y22 = scaler.transform(newData)
    print(y22)
    [[-0.05439283  0.48771311  0.05439283]
     [-1.19664225  0.90575292 -1.25103507]
     [ 1.25103507 -1.39346603  1.19664225]]
    

    使用方法2:

    # 直接用fit_transform()
    scaler = StandardScaler()
    y1 = scaler.fit_transform(trainData)
    y2 = scaler.transform(newData)
    print(y1)
    print(y2)
    [[-1.31982404 -0.39223227  0.44307902]
     [ 1.09985336  1.37281295  0.94154292]
     [ 0.21997067 -0.98058068 -1.38462194]]
    [[-0.43994135 -1.56892908  0.11076976]
     [-1.97973605 -0.39223227 -1.21846731]
     [ 1.31982404 -6.86406473  1.27385218]]
    # 这里的y1和y2和方法1得到的y1和y2完全一样
    
    # 直接对newData应用fit_transform(),这就相当于利用newData自己的标准,跟前面没啥关系,结果y22与这里的y2有差异,但和前面的y22相等。
    y22 = scaler.fit_transform(newData)
    print(y22)
    [[-0.05439283  0.48771311  0.05439283]
     [-1.19664225  0.90575292 -1.25103507]
     [ 1.25103507 -1.39346603  1.19664225]]
    

    相关文章

      网友评论

          本文标题:x.fit()、x.transform()、x.fit_tran

          本文链接:https://www.haomeiwen.com/subject/pvlrqdtx.html