先创建一个数据样本集合
import numpy as np
import pandas as pd
from pandas import DataFrame,Series
import torch
# Noise term: 100 draws from a standard normal distribution.
e = np.random.randn(100)
# Predictor: 100 values drawn uniformly from [1, 10), shaped (100, 1).
x = np.random.uniform(1, 10, size=(100, 1))
数据清洗连接操作
# Wrap the raw arrays as DataFrames, build y = 2*x + e + 3, and persist
# the (x, y) pairs to LinR.csv for the regression steps below.
X = DataFrame(data=x)
print(X)
E = DataFrame(data=e)
print(E)
# Linear model: slope 2, intercept 3, plus Gaussian noise.
Y = X.mul(2).add(E).add(3)
Y.columns = ['1']
print(Y)
# Column-wise concatenation of predictor and response into one table.
Dt = pd.concat([X, Y.reindex(X.index)], axis=1)
print(Dt)
# No header/index so np.genfromtxt later sees only raw numbers.
Dt.to_csv('LinR.csv', index=False, header=False)
0
0 7.959296
1 2.914118
2 6.531930
3 6.033350
4 3.329694
.. ...
95 6.689540
96 9.719845
97 1.911316
98 6.540799
99 2.833886
[100 rows x 1 columns]
0
0 -0.988615
1 -0.298217
2 1.846047
3 1.982957
4 0.644001
.. ...
95 -1.317616
96 0.844359
97 0.475709
98 -0.182511
99 -0.075437
[100 rows x 1 columns]
1
0 17.929977
1 8.530018
2 17.909907
3 17.049656
4 10.303389
.. ...
95 15.061463
96 23.284048
97 7.298340
98 15.899087
99 8.592336
[100 rows x 1 columns]
0 1
0 7.959296 17.929977
1 2.914118 8.530018
2 6.531930 17.909907
3 6.033350 17.049656
4 3.329694 10.303389
.. ... ...
95 6.689540 15.061463
96 9.719845 23.284048
97 1.911316 7.298340
98 6.540799 15.899087
99 2.833886 8.592336
[100 rows x 2 columns]
import matplotlib.pyplot as plt

# Reload the persisted sample and split it into predictor / response columns.
sample = np.genfromtxt('LinR.csv', delimiter=',')
print(sample)
x = sample[:, 0]  # independent variable column
y = sample[:, 1]  # observed response column
[[ 7.95929583 17.92997689]
[ 2.91411788 8.5300183 ]
[ 6.53192986 17.90990674]
[ 6.0333498 17.04965623]
[ 3.32969394 10.30338909]
[ 9.78091003 24.11104012]
[ 5.50414492 15.01162983]
[ 2.3905083 7.98267035]
[ 9.10073718 20.63593365]
[ 8.6771435 20.00840521]
[ 8.59297259 20.47870044]
[ 3.71421162 13.64842939]
[ 9.53717825 20.65189706]
[ 1.65102104 5.72905434]
[ 2.25873778 6.97617708]
[ 4.22231735 11.55259733]
[ 1.53631731 4.94287203]
[ 5.34000791 13.16707094]
[ 6.17545718 15.06864934]
[ 4.90415263 12.41321868]
[ 7.75327441 18.6794576 ]
[ 5.69725718 12.84726204]
[ 3.56909863 10.5966961 ]
[ 9.94681425 21.33072322]
[ 3.59799413 9.56658453]
[ 3.87835658 11.11208729]
[ 5.03228019 13.68579256]
[ 9.48569476 22.77879716]
[ 4.66228634 12.75729249]
[ 2.4908625 7.26331882]
[ 4.81153515 12.35458447]
[ 9.58406536 22.38109883]
[ 4.28121897 10.74117052]
[ 2.45652558 8.71965016]
[ 1.9736664 6.4809288 ]
[ 9.51051954 22.72786535]
[ 6.75959417 17.81727331]
[ 5.61578003 14.22442185]
[ 3.85178666 9.75542375]
[ 5.68632445 13.52993683]
[ 5.4966711 15.11321743]
[ 1.11686002 6.05160856]
[ 8.82731607 19.98729226]
[ 3.23499699 8.89761767]
[ 3.22340228 9.82438038]
[ 3.79773535 9.75677403]
[ 7.91327594 19.74140776]
[ 6.96384595 16.55981935]
[ 8.04811251 18.00044451]
[ 5.95047664 15.91780519]
[ 6.86179131 14.57196115]
[ 9.92403963 24.85022378]
[ 2.28427921 5.8765021 ]
[ 3.65702265 9.91875919]
[ 2.94110273 10.2254981 ]
[ 1.44233169 6.68000409]
[ 8.11509328 19.30801919]
[ 9.61113455 22.19692088]
[ 5.31715886 10.11119983]
[ 4.18841633 10.38081049]
[ 3.38696086 10.89271924]
[ 6.79149049 16.38415715]
[ 9.43502226 22.16533282]
[ 2.39967181 8.2302037 ]
[ 7.36239661 17.6441651 ]
[ 2.94883598 9.58722116]
[ 8.07898524 18.88577106]
[ 7.27030879 17.48551352]
[ 3.82638905 10.42645614]
[ 2.21195155 6.25717108]
[ 2.62373052 7.75337578]
[ 4.2888883 13.05445549]
[ 4.79372014 13.47932057]
[ 7.27504625 18.03400179]
[ 7.6804564 18.91495491]
[ 4.96708467 11.14166529]
[ 1.3403768 3.22305645]
[ 3.31317609 10.54117355]
[ 8.83562974 21.6072844 ]
[ 5.2485316 15.68878013]
[ 3.09095248 8.00604601]
[ 1.40126741 5.61514191]
[ 5.97243366 16.6769609 ]
[ 9.27273317 21.88579855]
[ 1.10719167 7.34636051]
[ 7.64405604 16.02157126]
[ 1.46871259 7.17774005]
[ 8.04512646 19.50343429]
[ 6.44703966 14.47562085]
[ 2.53241481 6.62603648]
[ 7.88971862 19.72029653]
[ 5.25505582 13.99824232]
[ 6.42160172 17.67089747]
[ 4.71902614 10.58448458]
[ 3.69982573 11.40370471]
[ 6.68953965 15.06146343]
[ 9.7198446 23.28404806]
[ 1.91131556 7.29834036]
[ 6.54079876 15.89908666]
[ 2.83388631 8.59233584]]
# Scatter plot of the generated sample y = x*2 + e + 3.
plt.scatter(x, y, s=40, color='red', marker='o')
plt.show()
output_5_0.png
# Loss function
def loss_function(data, b, w):
    """Mean squared error of the line y = w*x + b over the sample.

    Parameters
    ----------
    data : array-like of shape (N, 2)
        Rows of (x, y) sample points.
    b : float
        Intercept of the candidate line.
    w : float
        Slope of the candidate line.

    Returns
    -------
    float
        (1/N) * sum((w*x_i + b - y_i)**2).
    """
    data = np.asarray(data, dtype=float)
    # Vectorized over the whole sample instead of a Python-level loop.
    residuals = w * data[:, 0] + b - data[:, 1]
    return float(np.mean(residuals ** 2))
# One gradient-descent step
def gradient_step(b_current, w_current, data, learning_rate):
    """Perform a single gradient-descent update on the MSE loss.

    Parameters
    ----------
    b_current, w_current : float
        Current intercept and slope estimates.
    data : array-like of shape (N, 2)
        Rows of (x, y) sample points.
    learning_rate : float
        Step size applied to each gradient.

    Returns
    -------
    list[float]
        [new_b, new_w] after one update step.
    """
    data = np.asarray(data, dtype=float)
    x = data[:, 0]
    y = data[:, 1]
    N = float(len(data))
    # Prediction error of the current line on every sample point.
    error = w_current * x + b_current - y
    # d(MSE)/db = (2/N) * sum(error); d(MSE)/dw = (2/N) * sum(x * error)
    b_grad = (2.0 / N) * np.sum(error)
    w_grad = (2.0 / N) * np.sum(x * error)
    new_b = b_current - learning_rate * b_grad
    # BUG FIX: the original computed `new_w = b_current - lr*w_grad`,
    # updating the slope from the intercept — which made b and w collapse
    # onto the same value (see the run output: b ≈ w ≈ 2.16).
    new_w = w_current - learning_rate * w_grad
    return [float(new_b), float(new_w)]
# Gradient-descent iteration loop
def gradient_descent_iter(initial_b, initial_w, iteration_num, data, learning_rate):
    """Run `iteration_num` gradient-descent steps and return the final [b, w].

    Parameters
    ----------
    initial_b, initial_w : float
        Starting intercept and slope.
    iteration_num : int
        Number of gradient steps to take.
    data : array-like of shape (N, 2)
        Rows of (x, y) sample points.
    learning_rate : float
        Step size passed through to each gradient step.

    Returns
    -------
    list[float]
        [b, w] after all iterations.
    """
    # Convert once up front: the original rebuilt np.array(data) on every
    # iteration, which is pure loop-invariant overhead.
    data = np.asarray(data)
    b_current = initial_b
    w_current = initial_w
    for _ in range(iteration_num):
        b_current, w_current = gradient_step(b_current, w_current, data, learning_rate)
    return [b_current, w_current]
# Driver: load the sample, fit the line by gradient descent, report results.
def run_fun(csv_path='LinR.csv'):
    """Load the (x, y) sample from `csv_path`, fit y = w*x + b, and print
    the loss before and after 20000 gradient-descent iterations.

    Parameters
    ----------
    csv_path : str, optional
        Path of the comma-separated sample file (default 'LinR.csv'),
        added as a backward-compatible parameter.
    """
    initial_b = 0
    initial_w = 0
    iteration_num = 20000
    learning_rate = 0.0001
    data = np.genfromtxt(csv_path, delimiter=',')
    # BUG FIX: message typo "decent" -> "descent".
    print('Starting gradient descent at b = {0},w={1},loss_value = {2}'.format(initial_b, initial_w, loss_function(data, initial_b, initial_w)))
    print('Running...')
    [b, w] = gradient_descent_iter(initial_b, initial_w, iteration_num, data, learning_rate)
    print('After {0} iterations b = {1},w = {2},loss_value = {3}'.format(iteration_num, b, w, loss_function(data, b, w)))
# Run only when executed as a script. The original guard was commented out
# because it was malformed: `if __name__ == '__main__'():` tries to CALL the
# string '__main__', which raises TypeError.
if __name__ == '__main__':
    run_fun()
Starting gradient decent at b = 0,w=0,loss_value = 217.6852494453724
Running...
After 20000 iterations b = 2.1616063195346547,w = 2.1614204809921613,loss_value = 1.4195063550286242
可以看到迭代出的 w 值（约 2.16）与预先设定的 2 比较接近，但 b 值（约 2.16）与预设的 3 仍有明显偏差；损失函数值虽已大幅下降，却没有收敛到最优，提示梯度更新的实现可能有误。
import matplotlib.pyplot as plt

# Reload the saved sample for the fitted-curve comparison plot below.
sample = np.genfromtxt('LinR.csv', delimiter=',')
print(sample)
x_1 = sample[:, 0]  # predictor column
y_1 = sample[:, 1]  # response column
[[ 7.95929583 17.92997689]
[ 2.91411788 8.5300183 ]
[ 6.53192986 17.90990674]
[ 6.0333498 17.04965623]
[ 3.32969394 10.30338909]
[ 9.78091003 24.11104012]
[ 5.50414492 15.01162983]
[ 2.3905083 7.98267035]
[ 9.10073718 20.63593365]
[ 8.6771435 20.00840521]
[ 8.59297259 20.47870044]
[ 3.71421162 13.64842939]
[ 9.53717825 20.65189706]
[ 1.65102104 5.72905434]
[ 2.25873778 6.97617708]
[ 4.22231735 11.55259733]
[ 1.53631731 4.94287203]
[ 5.34000791 13.16707094]
[ 6.17545718 15.06864934]
[ 4.90415263 12.41321868]
[ 7.75327441 18.6794576 ]
[ 5.69725718 12.84726204]
[ 3.56909863 10.5966961 ]
[ 9.94681425 21.33072322]
[ 3.59799413 9.56658453]
[ 3.87835658 11.11208729]
[ 5.03228019 13.68579256]
[ 9.48569476 22.77879716]
[ 4.66228634 12.75729249]
[ 2.4908625 7.26331882]
[ 4.81153515 12.35458447]
[ 9.58406536 22.38109883]
[ 4.28121897 10.74117052]
[ 2.45652558 8.71965016]
[ 1.9736664 6.4809288 ]
[ 9.51051954 22.72786535]
[ 6.75959417 17.81727331]
[ 5.61578003 14.22442185]
[ 3.85178666 9.75542375]
[ 5.68632445 13.52993683]
[ 5.4966711 15.11321743]
[ 1.11686002 6.05160856]
[ 8.82731607 19.98729226]
[ 3.23499699 8.89761767]
[ 3.22340228 9.82438038]
[ 3.79773535 9.75677403]
[ 7.91327594 19.74140776]
[ 6.96384595 16.55981935]
[ 8.04811251 18.00044451]
[ 5.95047664 15.91780519]
[ 6.86179131 14.57196115]
[ 9.92403963 24.85022378]
[ 2.28427921 5.8765021 ]
[ 3.65702265 9.91875919]
[ 2.94110273 10.2254981 ]
[ 1.44233169 6.68000409]
[ 8.11509328 19.30801919]
[ 9.61113455 22.19692088]
[ 5.31715886 10.11119983]
[ 4.18841633 10.38081049]
[ 3.38696086 10.89271924]
[ 6.79149049 16.38415715]
[ 9.43502226 22.16533282]
[ 2.39967181 8.2302037 ]
[ 7.36239661 17.6441651 ]
[ 2.94883598 9.58722116]
[ 8.07898524 18.88577106]
[ 7.27030879 17.48551352]
[ 3.82638905 10.42645614]
[ 2.21195155 6.25717108]
[ 2.62373052 7.75337578]
[ 4.2888883 13.05445549]
[ 4.79372014 13.47932057]
[ 7.27504625 18.03400179]
[ 7.6804564 18.91495491]
[ 4.96708467 11.14166529]
[ 1.3403768 3.22305645]
[ 3.31317609 10.54117355]
[ 8.83562974 21.6072844 ]
[ 5.2485316 15.68878013]
[ 3.09095248 8.00604601]
[ 1.40126741 5.61514191]
[ 5.97243366 16.6769609 ]
[ 9.27273317 21.88579855]
[ 1.10719167 7.34636051]
[ 7.64405604 16.02157126]
[ 1.46871259 7.17774005]
[ 8.04512646 19.50343429]
[ 6.44703966 14.47562085]
[ 2.53241481 6.62603648]
[ 7.88971862 19.72029653]
[ 5.25505582 13.99824232]
[ 6.42160172 17.67089747]
[ 4.71902614 10.58448458]
[ 3.69982573 11.40370471]
[ 6.68953965 15.06146343]
[ 9.7198446 23.28404806]
[ 1.91131556 7.29834036]
[ 6.54079876 15.89908666]
[ 2.83388631 8.59233584]]
# Overlay the fitted line on the sample scatter to compare fit quality.
import matplotlib.pyplot as plt
import numpy as np
plt.scatter(x_1, y_1, marker = 'o',color = 'red', s = 40,label = 'sampelSC' )
x_2 = np.arange(0,10)
# Parameters reported by the gradient-descent run above.
w = 2.1614204809921613
# BUG FIX: the original read `b = u2.1616...` — the stray 'u' makes this a
# SyntaxError (attribute access on name `u2` with a numeric attribute).
b = 2.1616063195346547
# BUG FIX: the original computed `y_2 = w*x+b` from the 100-point sample
# array `x`, whose length does not match the 10-point x_2, so
# plt.plot(x_2, y_2) would raise a shape-mismatch error.
y_2 = w*x_2+b
plt.plot(x_2,y_2,label='Fitting Curve')
plt.show()
output_11_0.png
可以看到曲线拟合的效果还是很不错的
网友评论