# -*- coding: utf-8 -*-
"""
Back-test of a simple momentum trading strategy on one stock's daily price data file.
"""
"""
读取文件路径
方法:
1 利用cmd命令把所有目标文件路径写入文件, dir /b/s >filepath.txt,再删掉filepath.txt中包含的filepath.txt自己的路径
2 利用python的open和readlines完成
"""
# Locate and load one stock's daily data file.
# filepath.txt is produced on Windows with `dir /b/s > filepath.txt`
# (one absolute path per line; the line for filepath.txt itself is removed
# by hand).
FILE_LIST_PATH = r'C:\finance\filepath.txt'
# 0-based line in filepath.txt naming the stock data file to back-test.
STOCK_FILE_INDEX = 455

with open(FILE_LIST_PATH) as target:
    filelist = target.readlines()

# Path lines normally end with a newline. Strip only line-ending characters:
# slicing off the last character would truncate the path whenever the chosen
# line has no trailing newline (e.g. the last line of the file).
stockFilePath = filelist[STOCK_FILE_INDEX].rstrip('\r\n')

with open(stockFilePath) as stockFile:
    # The first line holds the column headers; read and discard it.
    stockFile.readline()
    stockData = stockFile.readlines()
# Per-column buffers; position i in every list refers to the same trading day.
date_array = []
open_array = []
high_array = []
close_array = []
low_array = []
volume_array = []
amount_array = []

# Data-extraction helpers follow.


def dataInit(date_array=date_array, open_array=open_array, high_array=high_array, close_array=close_array,
             low_array=low_array, volume_array=volume_array, amount_array=amount_array):
    """Empty the given column lists in place.

    With no arguments, the module-level column lists are cleared.
    """
    for column in (date_array, open_array, high_array, close_array,
                   low_array, volume_array, amount_array):
        # Mutate in place: assigning `column = []` would only rebind the local
        # name and leave the caller's list untouched. `del column[:]` is used
        # instead of `list.clear()` to stay Python 2 compatible (file uses six).
        del column[:]
def dataGot(stockData=stockData, date_array=date_array, open_array=open_array, high_array=high_array, close_array=close_array,
            low_array=low_array, volume_array=volume_array, amount_array=amount_array):
    """Split comma-separated price rows into the per-column lists.

    The first seven comma-separated fields of each row are appended, as
    strings, to date/open/high/close/low/volume/amount respectively. Afterwards
    every column is reversed, so the last row of *stockData* ends up at index 0
    (presumably the file lists the newest day first and this yields
    chronological order -- TODO confirm against the data source).

    The target lists are emptied first, so calling this again replaces the
    data rather than appending a second, mis-ordered copy. Blank lines are
    skipped.

    Raises:
        ValueError: a non-blank row has fewer than seven fields.
    """
    columns = (date_array, open_array, high_array, close_array,
               low_array, volume_array, amount_array)
    for column in columns:
        del column[:]
    for line in stockData:
        if not line.strip():
            continue
        # Strip only the line ending; `line[:-1]` would chop a real character
        # off the last field when the final line has no trailing newline.
        fields = line.rstrip('\r\n').split(',')
        if len(fields) < len(columns):
            # Validate before appending so the columns never get out of step.
            raise ValueError('expected at least %d comma-separated fields, got %d: %r'
                             % (len(columns), len(fields), line))
        for column, value in zip(columns, fields):
            column.append(value)
    for column in columns:
        column.reverse()


dataGot()
from collections import namedtuple,OrderedDict
from functools import reduce
class StockTradeDays(object):
    """Day-by-day view over parallel per-column price lists.

    Column values are stored as given; only the close prices are converted to
    float, to compute each day's change ratio. Days are exposed as ``stock``
    namedtuples held in an OrderedDict keyed by date, and the object supports
    iteration, integer indexing by position and ``len()``.
    """

    def __init__(self, date_array=date_array, open_array=open_array, high_array=high_array, close_array=close_array,
                 low_array=low_array, volume_array=volume_array, amount_array=amount_array):
        self.__date_array = date_array
        self.__open_array = open_array
        self.__high_array = high_array
        self.__close_array = close_array
        self.__low_array = low_array
        self.__volume_array = volume_array
        self.__amount_array = amount_array
        self.__change_array = self.__init_change()
        self.stock_dict = self._init_stock_dict()

    def __init_change(self):
        """Return day-over-day close change ratios rounded to 3 decimals.

        Entry 0 has no previous day and is the integer 0.
        """
        closes = [float(text) for text in self.__close_array]
        changes = [0]
        for previous, current in zip(closes, closes[1:]):
            changes.append(round((current - previous) / previous, 3))
        return changes

    def _init_stock_dict(self):
        """Build an OrderedDict mapping each date to its ``stock`` record."""
        record_type = namedtuple('stock', ('date', 'open', 'high', 'close', 'low', 'volume', 'amount', 'change'))
        rows = zip(self.__date_array, self.__open_array, self.__high_array, self.__close_array,
                   self.__low_array, self.__volume_array, self.__amount_array, self.__change_array)
        by_date = OrderedDict()
        for row in rows:
            by_date[row[0]] = record_type(*row)
        return by_date

    def __str__(self):
        return str(self.stock_dict)

    __repr__ = __str__

    def __iter__(self):
        """Yield the per-day records in insertion order."""
        for record in self.stock_dict.values():
            yield record

    def __getitem__(self, ind):
        """Return the record of the *ind*-th date (negative indices allowed)."""
        return self.stock_dict[self.__date_array[ind]]

    def __len__(self):
        return len(self.stock_dict)
# Build the day-by-day view; with no arguments StockTradeDays uses its
# default column lists, i.e. the module-level lists filled by dataGot().
stock1 = StockTradeDays()
# Example: print the first ten trading days.
#for ind, day in enumerate(stock1):
#    if ind < 10:
#        print(day)
#    else:
#        break
import six
from abc import ABCMeta, abstractmethod
class TradeStrategyBase(six.with_metaclass(ABCMeta, object)):
    """Abstract base class for trading strategies.

    Subclasses cannot be instantiated until they override both hooks below.
    """

    @abstractmethod
    def buy_strategy(self, *args, **kwargs):
        """Buy-side hook; the argument list is left to subclasses."""

    @abstractmethod
    def sell_strategy(self, *args, **kwargs):
        """Sell-side hook; the argument list is left to subclasses."""
class TradeStrategy1(TradeStrategyBase):
    """Buy after a large up day, then hold for a fixed number of days.

    When no position is held (``keep_stock_day == 0``) and a day's ``change``
    exceeds ``buy_change_threshold`` (0.07 by default), counting starts; the
    counter then grows by one each day and is reset to 0 once it reaches
    ``s_keep_stock_threshold``.
    """

    # Counter value at which the position is closed.
    s_keep_stock_threshold = 20

    def __init__(self):
        # Number of days counted since the position was opened; 0 = no position.
        self.keep_stock_day = 0
        self.__buy_change_threshold = 0.07

    def buy_strategy(self, trade_ind, trade_day, trade_days):
        """Advance the holding counter, or open a position on a big up day."""
        held = self.keep_stock_day
        if held > 0 or (held == 0 and trade_day.change > self.__buy_change_threshold):
            self.keep_stock_day = held + 1

    def sell_strategy(self, trade_ind, trade_day, trade_days):
        """Reset the counter once the holding period has been reached."""
        limit = TradeStrategy1.s_keep_stock_threshold
        if self.keep_stock_day >= limit:
            self.keep_stock_day = 0

    @property
    def buy_change_threshold(self):
        """Change ratio a day must exceed for a position to be opened."""
        return self.__buy_change_threshold

    @buy_change_threshold.setter
    def buy_change_threshold(self, buy_change_threshold):
        # Only floats are accepted; the stored value is rounded to 2 decimals.
        if isinstance(buy_change_threshold, float):
            self.__buy_change_threshold = round(buy_change_threshold, 2)
            return
        raise TypeError('buy_change_threshold must be float')
class TradeLoopBack(object):
    """Replay trading days through a strategy and record held-day changes.

    ``trade_days`` is iterated in order and each item must expose ``change``.
    ``trade_strategy`` must expose ``keep_stock_day`` and may provide
    ``buy_strategy`` / ``sell_strategy`` hooks, each called as
    ``hook(index, day, trade_days)``.
    """

    def __init__(self, trade_days, trade_strategy):
        self.trade_days = trade_days
        self.trade_strategy = trade_strategy
        # ``change`` of every day that started with keep_stock_day > 0.
        self.profit_array = []

    def execute_trade(self):
        """Walk all days once, appending to ``profit_array`` along the way."""
        for day_index, trade_day in enumerate(self.trade_days):
            # The position is checked before today's hooks run, so the day on
            # which buy_strategy opens a position is not itself recorded.
            if self.trade_strategy.keep_stock_day > 0:
                self.profit_array.append(trade_day.change)
            for hook_name in ('buy_strategy', 'sell_strategy'):
                if hasattr(self.trade_strategy, hook_name):
                    getattr(self.trade_strategy, hook_name)(
                        day_index, trade_day, self.trade_days)
# Run TradeStrategy1 over every trading day of stock1.
trade_loop_back = TradeLoopBack(stock1, TradeStrategy1())
trade_loop_back.execute_trade()
# Total P&L = plain (non-compounded) sum of the daily change ratios recorded
# while a position was held, shown as a percentage. sum() also handles the
# case where no trade was ever opened (empty list), on which
# reduce() without an initial value raises TypeError.
print('回测策略1 总盈亏为:{}%'.format(sum(trade_loop_back.profit_array) * 100))
import numpy as np
import matplotlib.pyplot as plt
# Plot the running total of the recorded daily change ratios.
plt.plot(np.array(trade_loop_back.profit_array).cumsum())
# Display the figure; outside an IPython/Spyder inline console, nothing is
# shown without an explicit show().
plt.show()