import tushare
import codecs
#计算用
import numpy as np
import matplotlib.pyplot as plot
import matplotlib.dates as dates
import fileinput
import tkinter as tk
from tkinter import *
from tkinter import Button
from tkinter import Text
from tkinter import Entry
from tkinter import messagebox
import datetime #获取今日时间
# Linear model: y = w1*x1 + w2*x2 + w3*x3 + w4*x4 + w5*x5 + b
# Every weight starts at 0.2 and the bias starts at 0.01.
w1 = w2 = w3 = w4 = w5 = 0.2
b = 0.01
# Base learning rate for the gradient updates performed in traing().
learn_rate = 1e-05
# Placeholder arrays; get_data() overwrites both with the downloaded history.
data, date = np.random.rand(10), np.random.rand(10)
# GUI callback
def Test():
    """Handle the confirm button: validate the code, fetch data, predict and plot.

    The stock code is read from the module-level Entry widget ``text``.  An
    empty code, or one containing anything other than ASCII digits and
    lowercase ASCII letters, is rejected with a message box.  Otherwise the
    history is downloaded (get_data_online), loaded into the globals
    ``date``/``data`` (get_data), the model is trained and asked for the next
    close (traing), and the result is shown in a message box and in a line +
    bar chart of the closing prices.
    """
    t = text.get()
    if t == '':
        messagebox.showinfo('提示', '输入不能为空')
        # Stop here so an empty input does not also trigger the
        # "invalid characters" dialog below.
        return
    # Allowed characters are '0'-'9' and 'a'-'z' (codes such as '000725' or
    # 'sh000001').  The former test joined the two letter bounds with `and`,
    # which is never true, so invalid characters were never reported.
    invalid = [c for c in t if not ('0' <= c <= '9' or 'a' <= c <= 'z')]
    if invalid:
        print(''.join(invalid))
        messagebox.showinfo('提示', '输入股票代码由数字或者字母')
        return

    print(t)
    # Fetch data; get_data_online returns 1 on success and has already
    # informed the user on failure.
    if get_data_online(t) != 1:
        return
    get_data()
    print(data)
    # Machine learning
    pridicit = traing()
    # Display the result
    next_day = get_nextday_date()
    res = 'DATE: ' + str(next_day) + ' INDEX: ' + t + ' TOMORROW: ' + str(pridicit)
    messagebox.showinfo('预测结果:', res)
    # get_data loads prices as strings; convert them so matplotlib draws a
    # numeric y axis instead of a categorical one.
    prices = np.asarray(data, dtype=float)
    plot.title(res)
    plot.plot_date(date, prices, fmt='-', marker='o', c='g')
    # Bar chart
    plot.bar(x=date, height=prices, color='red')
    plot.show()
# Download market data with tushare
def get_data_online(t):
    """Download daily history for stock code *t* and save the closes to data.csv.

    The file is pandas' plain-text rendering of the ``close`` column: one header
    line, presumably the index name (assumed to be 'date' in tushare output —
    TODO confirm), followed by one ``<date>   <price>`` row per trading day,
    padded with spaces.  get_data() parses this file.

    Returns:
        1 on success; 0 (after showing a message box) when tushare returns no
        usable data for *t*.
    """
    # Example code: sh000001
    end = get_today_date()
    hist = tushare.get_hist_data(t, start='2018-03-10', end=end)
    try:
        # An unknown code makes tushare return None (TypeError when indexed);
        # a frame without a 'close' column raises KeyError.
        closes = hist['close']
    except (TypeError, KeyError):
        closes = None
    if closes is None or closes.empty:
        print('Error: 输入股票代码不存在')
        messagebox.showinfo('提示', '输入股票代码不存在')
        return 0
    # to_string() renders every row without a footer.  str() would truncate
    # long series to a few rows with a '...' line and append a
    # 'Name: close, dtype: ...' footer, which the old code then had to strip
    # (using the 'rU' open mode, which is rejected by Python 3.11+).
    dump = closes.to_string()
    with open('data.csv', 'w') as fh:
        fh.write(dump + '\n')
    return 1
# Center the input window on the screen
def center_win(root, width, height):
    """Resize *root* to width x height and place it in the middle of the screen.

    Offsets are truncated toward zero, as '%d' formatting would do.
    Returns None.
    """
    screen_w = root.winfo_screenwidth()
    screen_h = root.winfo_screenheight()
    left = int((screen_w - width) / 2)
    top = int((screen_h - height) / 2)
    root.geometry('{}x{}+{}+{}'.format(int(width), int(height), left, top))
def get_data():
    """Load data.csv into the module-level arrays ``date`` and ``data``.

    ``date`` receives column 0 converted to matplotlib date numbers; ``data``
    receives column 1 (closing prices) as strings.  The first line of the
    file (a header) is skipped.  Both arrays are always at least 1-D, even
    when the file holds a single row.
    """
    global date, data

    def _to_datenum(field):
        # Older numpy passes bytes to converters, newer numpy passes str.
        if isinstance(field, bytes):
            field = field.decode('ascii')
        return dates.date2num(datetime.datetime.strptime(field, "%Y-%m-%d"))

    # delimiter=None splits on runs of whitespace.  The file is a pandas text
    # dump whose columns are separated by several spaces, so a literal ' '
    # delimiter produced empty fields.  matplotlib.dates.bytespdate2num, used
    # previously, no longer exists in current matplotlib.
    date = np.loadtxt('data.csv', converters={0: _to_datenum}, usecols=(0,),
                      skiprows=1, unpack=True, ndmin=1)
    data = np.loadtxt('data.csv', usecols=(1,), unpack=True, dtype=str,
                      skiprows=1, ndmin=1)
# Today's date
def get_today_date():
    """Return today's local date as a zero-padded 'YYYY-MM-DD' string."""
    # The hand-assembled version left out the '-' between month and day
    # (e.g. '2018-0510'), producing an invalid end date for tushare.
    return datetime.date.today().strftime('%Y-%m-%d')
# Tomorrow's date
def get_nextday_date():
    """Return tomorrow's local date (now + 1 day) as a 'YYYY-MM-DD' string."""
    tomorrow = datetime.datetime.now() + datetime.timedelta(days=1)
    return tomorrow.strftime('%Y-%m-%d')
# Result display
def show(t='', pridicit=''):
    """Show a prediction in a message box, then plot the loaded prices.

    Args:
        t: stock code included in the message.
        pridicit: predicted next close included in the message.

    These values used to be read from globals ``t``/``pridicit`` that nothing
    in this file defines, so every call raised NameError.  They are now
    optional parameters.  The chart is drawn from the module-level
    ``date``/``data`` arrays as a line plot plus a bar chart.
    """
    next_day = get_nextday_date()
    res = '明日:' + str(next_day) + '股票:' + str(t) + '预计明日股价:' + str(pridicit)
    messagebox.showinfo('预测结果:', res)
    # data holds prices as strings; plot numbers so the y axis is numeric.
    prices = np.asarray(data, dtype=float)
    plot.plot_date(date, prices, fmt='-', marker='o', c='r')
    # Bar chart
    plot.bar(x=date, height=prices, color='red')
    plot.show()
# Training
def traing():
    """Fit the 5-lag linear model on ``data`` and predict the next close.

    Model: y = w1*x1 + w2*x2 + w3*x3 + w4*x4 + w5*x5 + b, where x1..x5 are
    five consecutive prices (x1 the earliest in the window) and y estimates
    the price that follows.  Windows are taken from the end of ``data``
    towards index 0, i.e. ``data`` is treated as newest-first
    (NOTE(review): relies on tushare returning rows in descending date
    order — confirm).

    For every window whose squared error exceeds 0.001, the module-level
    weights w1..w5 and b get one gradient step of size
    learn_rate / sqrt(i + 1) for window i.  The global ``learn_rate`` itself
    is left unchanged.

    Returns:
        The prediction computed from data[4], data[3], data[2], data[1],
        data[0] as x1..x5.  Raises IndexError if ``data`` has fewer than
        5 entries.
    """
    global w1, w2, w3, w4, w5, b
    # data is loaded as strings; convert once instead of in every expression.
    prices = [float(v) for v in data]
    # A window is 5 inputs plus 1 target, so there are len - 5 windows; the
    # last one targets prices[0].  The old bound (len - 6) skipped it.
    for i in range(len(prices) - 5):
        x1 = prices[-i - 1]
        x2 = prices[-i - 2]
        x3 = prices[-i - 3]
        x4 = prices[-i - 4]
        x5 = prices[-i - 5]
        target = prices[-i - 6]
        y = w1 * x1 + w2 * x2 + w3 * x3 + w4 * x4 + w5 * x5 + b
        print('y', y)
        loss = (y - target) * (y - target)
        print('loss:', loss)
        # Update the parameters only when the squared error exceeds 0.001.
        if loss > 0.001:
            # loss = (target - y)^2  =>  d(loss)/d(wi) = -2 * (target - y) * xi,
            # so stepping against the gradient adds step * 2 * (target - y) * xi.
            # Decay as lr0 / sqrt(i + 1).  The old code divided the global
            # learn_rate in place every step, which compounded to lr0 / sqrt(i!)
            # and also starved every later call of learning.
            step = learn_rate / (i + 1) ** 0.5
            grad = 2 * (target - y)
            w1 = w1 + step * grad * x1
            w2 = w2 + step * grad * x2
            w3 = w3 + step * grad * x3
            w4 = w4 + step * grad * x4
            w5 = w5 + step * grad * x5
            b = b + step * grad
            print('w1', w1)
    res = w1 * prices[4] + w2 * prices[3] + w3 * prices[2] \
        + w4 * prices[1] + w5 * prices[0] + b
    print('res', res)
    return res
# Build the input window: a prompt label, an Entry for the stock code (read
# by Test() through the module-level name ``text``) and a confirm button.
root = tk.Tk()
root.title('输入窗口')
center_win(root, 300, 150)
root.maxsize(600, 400)
root.minsize(250, 250)

la = tk.Label(root, text='请输入股票代码:预测股价20以内比较靠谱。例如:京东方 000725')
text = tk.Entry(root)
b1 = Button(root, text='确定', command=Test)
# Pack in creation order so the widgets stack top-to-bottom as before.
for widget in (la, text, b1):
    widget.pack()

root.mainloop()
网友评论