美文网首页程序员
利用TDD思想学写机器学习代码

利用TDD思想学写机器学习代码

作者: 心水 | 来源:发表于2018-06-28 21:03 被阅读142次

    学习一门新技术的最好方法就是找一个需求然后用新技术实现一把。最近想试试Python的机器学习库,就想到了下面这个需求:

    根据我从家里的出发时间和历史数据,预估我到达公司的时间。

    要一下子实现这个需求并不容易,TDD教育我们要从最简单的测试用例开始,小步渐进地实现最终的需求。

    最简单的测试是什么呢?我想应该是在没有任何历史数据积累的情况下,预估不出到达公司的时间,Python代码如下:

    class EstArrivalTimeTest(unittest.TestCase):
    
        def no_data_no_result(self):
            eta = EstArrivalTime()
            self.assertEqual(None, eta.of(
                {
                    "departedHour": 7,
                    "departedMin": 10 
                }
            ))
    

    其实影响到公司时间的因素很多,比如今天是礼拜几、是下雨天还是晴天、我是坐公交车还是共享单车去地铁站等等,TDD教育我们先完成再完美,所以输入参数我只选择了从家里出发的时间。运行测试,失败,写出下面的代码让测试通过:

    class EstArrivalTime:
        
        def of(self, input):
            return None
    

    下一个最简单的测试是什么呢?我想到了有一条历史数据的情况,比如昨天我7:10分从家里出发,8:30到公司,如果今天我也是7:10分出发,那么预计到达公司的时间也是8:30,就是需要80分钟的时间才能到公司,测试代码如下:

        def one_data_simple_result(self):
            eta = EstArrivalTime()
            eta.learnFrom([
                {
                 "departedHour": 7, 
                 "departedMin": 10, 
                 "estMinutes": 80
                }
            ])
            self.assertEqual(80, eta.of({
                "departedHour": 7,
                "departedMinute": 10}))
    

    返回值类型是估计到达公司所需要的分钟数,怎么才能让这个测试通过了,可以先用最简单的线性回归模型,代码如下:

    from sklearn.linear_model import LinearRegression
    import numpy as np
    
    class EstArrivalTime:
        X = []
        Y = []
        model = None
    
        def learnFrom(self, data):
            
            for record in data:
                self.X.append([
                    record['departedHour'], 
                    record['departedMinute']])
    
                self.Y.append(record['estMinutes'])
                
                self.model = LinearRegression()
                self.model.fit(self.X, self.Y)
        
        def of(self, input):
            if (self.model):
                y_fit = self.model.predict([[
                    input['departedHour'], 
                    input['departMinute']]])
                return y_fit[0]
            return None
    

    其中X存储的是我从家里出发的所有历史时间(特征数据),Y存储的是相对应的我到达公司所需的所有历史分钟数(标注数据)。运行一下测试,幸运通过!精通机器学习注定是一个长期学习和积累的过程,希望这是一个良好的开端。

    相关文章

      网友评论

        本文标题:利用TDD思想学写机器学习代码

        本文链接:https://www.haomeiwen.com/subject/sbdnyftx.html