在sklearn中,有很多算法(比如PCA、StandarScaler等,本文中算法用‘x’代替)都有x.fit()、x.transform()、x.fit_transform()三个函数,这三者究竟有什么关系?
Note:
x.transform()必须在x.fit()或x.fit_transform()后使用,否则报错,说明前者是依赖于后二者的。
(实际上后二者主要是从数据中提取需要的“参数”,而x.transform()则是利用这些参数对原数据或其他数据做相应的处理)。
使用方法1:
第1步:x.fit(trainData):(无返回值),主要就是根据trainData计算相应的参数,比如对于StandarScaler方法,fit会计算trainData的均值、方差等,并保存在x中,留做后用(注意:x.fit(trainData)只是根据trainData计算需要的参数,并不会对trainData做任何处理);
第2步:result1 = x.transform(trainData):(有返回值),根据第1步得到的参数,对源数据trainData进行处理,这一步才是根据参数对数据进行处理;
第3步:result2 = x.transform(newData):(有返回值),根据第1步计算得到的参数,对新数据newData进行处理。
(第2、3步本质上一样,这么做就能保证newData与trainData是在相同的标准下进行的数据转换。)
使用方法2:
第1步:result1 = x.fit_transform(trainData):(有返回值),将前面的x.fit()和x.transform()合并成了一步,但还是会先根据trainData计算参数并保存在x中,然后再根据参数对trianData进行处理;
第2步:result2 = x.transform(newData):(有返回值),根据得到的参数对newData进行处理,得到result2;
注意:
result = x.fit_transform(trainData)之后再执行result = x.fit_transform(newData),虽然能得到结果,但是这么做trainData和newData是在各自的参数(标准)下进行的处理,两者没有任何联系,如果trainData和newData需要在相同的标准下进行处理,则这么写是错误的,应该避免。
示例:
公共部分:
# 导入包
from sklearn.preprocessing import StandardScaler
# 定义数据
trainData = np.array([[-7,2,3],[4,5,6],[0,1,-8]])
newData = np.array([[-3,0,1],[-10,2,-7],[5,-9,8]])
使用方法1:
scaler = StandardScaler() # 生成一个标准化对象scaler
scaler.fit(trainData) # 利用scaler的fit()函数计算trainData的均值、方差(相当于拿trainData做标准)
y1 = scaler.transform(trainData) # 将标准应用于trainData
y2 = scaler.transform(newData) # 将标准应用于newData
print(y1)
print(y2)
[[-1.31982404 -0.39223227 0.44307902]
[ 1.09985336 1.37281295 0.94154292]
[ 0.21997067 -0.98058068 -1.38462194]]
[[-0.43994135 -1.56892908 0.11076976]
[-1.97973605 -0.39223227 -1.21846731]
[ 1.31982404 -6.86406473 1.27385218]]
# 如果newData基于自己进行标准化,则结果y22与前面的y2是有一些差异的
scaler.fit(newData)
y22 = scaler.transform(newData)
print(y22)
[[-0.05439283 0.48771311 0.05439283]
[-1.19664225 0.90575292 -1.25103507]
[ 1.25103507 -1.39346603 1.19664225]]
使用方法2:
# 直接用fit_transform()
scaler = StandardScaler()
y1 = scaler.fit_transform(trainData)
y2 = scaler.transform(newData)
print(y1)
print(y2)
[[-1.31982404 -0.39223227 0.44307902]
[ 1.09985336 1.37281295 0.94154292]
[ 0.21997067 -0.98058068 -1.38462194]]
[[-0.43994135 -1.56892908 0.11076976]
[-1.97973605 -0.39223227 -1.21846731]
[ 1.31982404 -6.86406473 1.27385218]]
# 这里的y1和y2和方法1得到的y1和y2完全一样
# 直接对newData应用fit_transform(),这就相当于利用newData自己的标准,跟前面没啥关系,结果y22与这里的y2有差异,但和前面的y22相等。
y22 = scaler.fit_transform(newData)
print(y22)
[[-0.05439283 0.48771311 0.05439283]
[-1.19664225 0.90575292 -1.25103507]
[ 1.25103507 -1.39346603 1.19664225]]
网友评论