美文网首页
如何训练AI玩飞机大战游戏(创号版)

如何训练AI玩飞机大战游戏(创号版)

作者: f5065e26181b | 来源:发表于2019-02-18 20:00 被阅读17次

    虽然没有谷歌强大的集群和DeepMind变态的算法的团队,但基于深度强化学习(Deep Q Network DQN )的自制小游戏AI效果同样很赞。先上效果图:


    500000次训练效果.gif

    下面分四个部分,具体给大家介绍。

    /1/背景介绍
    2013年DeepMind团队发表论文“Playing Atari with Deep Reinforcement Learning”,用Q-Network模型成功让AI玩起了Atari系列游戏。并于2015年在《Nature》上发表了一篇升级版,“Human-level control through deep reinforcement learning”,自此,在这类游戏领域,人已经无法超过机器了。AI玩游戏的姿势是这样的:


    像素乒乓.gif

    后来的故事大家都很熟悉了,AlphaGo击败世界冠军,星际争霸2职业选手也被打败,连大家接触较多的王者荣耀也不能幸免。


    星际2.gif

    /2/深度强化学习模型
    看完了轻松的部分,下面简单介绍一下模型。DQN是DRL的一种算法,它将卷积神经网络(CNN)和Q-Learning结合起来。
    Q-learning是强化学习的一种,原理图如下:


    强化学习原理.png

    也就是Agent在观察得到当前的状态state和回报reward的基础上,选取输出一个动作action,进而影响环境,使环境状态和回报都产生变化。通过不断循环让Agent学习如何在环境中获得更高的回报。
    卷积神经网络CNN是图像处理领域非常经典的神经网络模型,在本模型中,输入是原始图像数据,输出为每个动作action对应的评估值。
    因此DQN总体结构是这样的:


    DQN原理图.png
    图比较简单,但原理很清晰,是将Agent中的模型用CNN来代替,环境的State为游戏界面截图,输出为AI的动作,在飞机大战中就是飞机向左、向右还是不动。回报reward具体为,在一次循环中没有被击中为0.1,被击中为-1,击中敌机为1。图中回放记忆单元、当前网络和目标网络都是为了将CNN这种需要大量样本的监督学习融合在强化学习模型中的手段。篇幅限制这里只是概述性的介绍,后期会专门讲。

    /3/模型实现
    3.1程序的总体结构
    程序主函数在PlaneDQN.py中,与DQN模型相关的函数在BrainDQN_Nature.py中,游戏模型在game文件夹中,训练过程保存的训练值在saved_networks文件夹中。


    程序结构.png

    3.2主函数搭建
    大家注意看while循环里的结构,其实非常明确:

    • getaction()为在当前的Q值下选取动作
    • framestep()为运行环境,并输出观测值
    • process()为对图像数据进行处理的函数
    • setPerception()根据图像和回报,对网络进行训练
    def playPlane():
       # Step 1: 初始化DQN
       actions = 3
       brain = BrainDQN(actions)
       # Step 2: 初始化游戏
       plane = game.GameState()
       # Step 3: 玩游戏
       # Step 3.1: 获取初始动作
       action0 = np.array([1,0,0])  # [1,0,0]do nothing,[0,1,0]left,[0,0,1]right
       observation0, reward0, terminal = plane.frame_step(action0)
       observation0 = cv2.cvtColor(cv2.resize(observation0, (80, 80)), cv2.COLOR_BGR2GRAY)
       ret, observation0 = cv2.threshold(observation0,1,255,cv2.THRESH_BINARY)
       brain.setInitState(observation0)
       # Step 3.2: 开始游戏
       while 1!= 0:
          action = brain.getAction()
          nextObservation,reward,terminal = plane.frame_step(action)
          nextObservation = preprocess(nextObservation)
          brain.setPerception(nextObservation,action,reward,terminal)
    

    3.3 游戏类GameState和framestep
    通过pygame实现游戏界面的搭建,分别建立子弹类、玩家类、敌机类和游戏类,结构代码所示。

    class Bullet(pygame.sprite.Sprite):
        def __init__(self, bullet_img, init_pos):
        def move(self):
    # 我方飞机类
    class Player(pygame.sprite.Sprite):
        def __init__(self, plane_img, player_rect, init_pos):
        def shoot(self, bullet_img):           def moveLeft(self):        
        def moveRight(self):
    # 敌方飞机类
    class Enemy(pygame.sprite.Sprite):
        def __init__(self, enemy_img, enemy_down_imgs, init_pos):
        def move(self):
    
    class GameState:
        def __init__(self):       
        def frame_step(self, input_actions):       
            if input_actions[0] == 1 or input_actions[1]== 1 or input_actions[2]== 1:  # 检查输入正常
                if input_actions[0] == 0 and input_actions[1] == 1 and input_actions[2] == 0:
                    self.player.moveLeft()
                elif input_actions[0] == 0 and input_actions[1] == 0 and input_actions[2] == 1:
                    self.player.moveRight()
                else:
                    pass
            else:
                raise ValueError('Multiple input actions!')
            image_data = pygame.surfarray.array3d(pygame.display.get_surface())
            pygame.display.update()
            clock = pygame.time.Clock()
            clock.tick(30)
            return image_data, reward, terminal
    

    其中GameState中的framestep()函数,是整个DQN运行一次使环境发生变化的基础函数,该函数运行一次,会根据inputaction进行动作实施,接着会在该时段对界面上的元素进行移动,并判断是否撞击。最后通过get_surface获取界面图像,最后返回环境的image_data,reward和游戏是否停止的terminal。本文游戏效果图为:


    游戏界面.png

    为提高模型收敛速度,在实际运行时将背景图片去掉。
    3.4 DQN模型类
    该部分为DQN模型的核心,主要有根据参数建立CNN网络的createQNetwork(),进行模型训练的trainQNetwork(),进行动作选择的getAction()。

    class BrainDQN:
       def __init__(self,actions):   
       def createQNetwork(self):
          return stateInput,QValue,W_conv1,b_conv1,W_conv2,b_conv2,W_conv3,b_conv3,W_fc1,b_fc1,W_fc2,b_fc2
       def copyTargetQNetwork(self):
          self.session.run(self.copyTargetQNetworkOperation)
       def createTrainingMethod(self):
       def trainQNetwork(self):   
       def getAction(self):
          return action
       def setInitState(self,observation):
          self.currentState = np.stack((observation, observation, observation, observation), axis = 2)
       def weight_variable(self,shape):
          return tf.Variable(initial)
       def bias_variable(self,shape):
          return tf.Variable(initial)
       def conv2d(self,x, W, stride):
          return tf.nn.conv2d(x, W, strides = [1, stride, stride, 1], padding = "SAME")
       def max_pool_2x2(self,x):
          return tf.nn.max_pool(x, ksize = [1, 2, 2, 1], strides = [1, 2, 2, 1], padd
    

    3.5图像处理
    图像预处理调用cv2库函数,对图像进行大小和灰度处理。

    def preprocess(observation):
       observation = cv2.cvtColor(cv2.resize(observation, (80, 80)), cv2.COLOR_BGR2GRAY)#灰度转化
       ret, observation = cv2.threshold(observation,1,255,cv2.THRESH_BINARY)
        return np.reshape(observation,(80,80,1))
    

    /4/环境搭建

    • 系统:Ubuntu16.04、win10
    • Python3.5
    • pygame 1.9.4
    • TensorFlow1.11(GPU版)
    • OpenCV-Python

    关注公众号“1024程序开发者社区”回复“AI飞机”,获取代码,包含训练500000次的结果。本程序对硬件要求不高,显存2GB以上就可运行。


    关注我.png

    1024程序开发者社区的交流群已经建立,许多小伙伴已经加入其中,感谢大家的支持。大家可以在群里就技术问题进行交流,还没有加入的小伙伴可以扫描下方“社区物业”二维码,让管理员帮忙拉进群,期待大家的加入。


    加入我.png

    相关文章

      网友评论

          本文标题:如何训练AI玩飞机大战游戏(创号版)

          本文链接:https://www.haomeiwen.com/subject/sqpzeqtx.html