
What Is Reinforcement Learning: III. Surviving on the Ice (Part 2)

Author: 圣_狒司机 | Published 2023-05-03 23:17

    I. A small improvement

    def chose_action(self, ai=False):
        # Candidate next states: the current state plus every action offset.
        pre_s_space = self.s + self.action_space
        # Keep only candidates that do not already appear in the visited-state chain.
        pre_s_space = [pre_s for pre_s in pre_s_space if not np.all((self.s_chain == pre_s), axis=1).any()]
        # If any unvisited neighbours remain, restrict the action space to the moves that reach them.
        if pre_s_space:
            self.action_space = pre_s_space - self.s
        a = random.choice(self.action_space)
        return a
    

    Output:

    playing 0 times!
    526 times success!
    

    Not revisiting states roughly doubles the speed, and the improvement still holds when the environment is made a bit more complex:

    playing 20000 times!
    playing 40000 times!
    playing 60000 times!
    70613 times success!
    

    This improvement is not actually very useful, though: it is tailored to this particular game, and our little AI has not yet acquired any general learning ability.

    II. Q_learning

    Now let's add something that actually deserves to be called artificial intelligence: Q_learning!
    The formula:

    Q(s,a) \leftarrow Q(s,a) + \alpha \cdot \big( r + \gamma \cdot \max_{a'} Q(s',a') - Q(s,a) \big)

    Written as an in-place (augmented-assignment) update:

    Q(s,a) \mathrel{+}= \alpha \cdot \big( r + \gamma \cdot \max_{a'} Q(s',a') - Q(s,a) \big)

    We will use only this one formula!
    For the theory behind it, see 【强化学习】Q-Learning算法详解.
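
    To make the update concrete, here is a tiny worked example with illustrative numbers (alpha = 0.1 and gamma = 0.9, as in the code below; the Q-values and reward are made up):

    # One application of the Q-learning update with toy numbers.
    alpha, gamma = 0.1, 0.9
    Q_sa = 0.0        # current Q(s, a)
    r = 0.0           # reward received for the step
    max_next = -1.0   # max over a' of Q(s', a'): the next state looks bad
    Q_sa += alpha * (r + gamma * max_next - Q_sa)
    print(Q_sa)       # prints roughly -0.09: the action leading toward the bad state is pushed down
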
    Change the environment's step function to:

    def step(self, a=[0,0], ai=False):
        s = self.agent.s
        s_ = np.array(s) + np.array(a)
        if 0 <= s_[0] <= 3 and 0 <= s_[1] <= 11:
            # The move stays on the 4x12 board: update the agent and read the cell's reward.
            self.agent.s = s_
            r = self.env[s_[0], s_[1]]
        else:
            # The move would leave the board: stay put and take a -1 penalty.
            s_ = s
            r = -1
        # Let the agent record the transition and update its Q-table.
        self.agent.post_step(s, a, r, s_)
        return s_, r
    

    Changes to the agent:

    class Agent:
        def __init__(self):
            self.action_space = np.array([[0,1],[-1,0],[0,-1],[1,0]])
            self.s = np.array([0,0])
            self.s_chain = np.expand_dims(self.s, 0)
            self.sar_chain = np.expand_dims(np.hstack([np.array([0,0]), [0,0], 0]), 0)
            self.Q = np.zeros((4,12,4))   # Q-table: one value per (row, col, action)
            self.epsilon = 0.2

        def chose_action(self):
            # Note: with epsilon = 0.2 this branch takes a random action 80% of the time,
            # and the greedy (highest-Q) action the remaining 20%.
            if self.epsilon < random.random():
                a = random.choice(self.action_space)
            else:
                a = self.action_space[np.argmax(self.Q[self.s[0], self.s[1]])]
            return a

        def post_step(self, s, a, r, s_):
            self.s_chain = np.vstack([self.s_chain, s])
            self.sar_chain = np.vstack([self.sar_chain, np.hstack([s, a, r])])
            a_number = np.where((self.action_space == a).all(axis=1))
            # The reward is folded in by clamping a penalised state's values to -1,
            # rather than adding r to the update target directly.
            if r == -1:
                self.Q[s_[0], s_[1]] = -1
            update = 0.1 * (0.9 * self.Q[s_[0], s_[1]].max() - self.Q[s[0], s[1], a_number])
            self.Q[s[0], s[1], a_number] += update

        def reset(self):
            self.action_space = np.array([[0,1],[-1,0],[0,-1],[1,0]])
            self.s = np.array([0,0])
            self.s_chain = np.expand_dims(self.s, 0)
            self.sar_chain = np.expand_dims(np.hstack([np.array([0,0]), [0,0], 0]), 0)
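
    Note that post_step folds the reward in by clamping a penalised state's values to -1 instead of adding r to the update target. A variant that follows the formula literally would look something like the sketch below (my illustration, not the article's code; q_update is a hypothetical helper):

    import numpy as np

    def q_update(Q, s, a_idx, r, s_, alpha=0.1, gamma=0.9):
        """One literal Q-learning update on a (rows, cols, n_actions) value table."""
        td_target = r + gamma * Q[s_[0], s_[1]].max()              # r + gamma * max_a' Q(s', a')
        Q[s[0], s[1], a_idx] += alpha * (td_target - Q[s[0], s[1], a_idx])

    # Example: one update after receiving a -1 reward for moving right (action index 0) from the start cell.
    Q = np.zeros((4, 12, 4))
    q_update(Q, s=(0, 0), a_idx=0, r=-1, s_=(0, 1))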
    

    III. Taking it for a spin

    Make the environment a little more complex, and after roughly 300,000 rounds of bumping into things it learns to find the exit:

    327737 times success!
    
    [Figure: the successful path rendered on the grid]

    What it has learned is stored in a Q-table, which is effectively its knowledge base. Let's open it up and take a look:

    # Show the first three action-value channels of the Q-table as an RGB image.
    plt.imshow((agent.Q[:,:,:-1]*255).astype(np.uint8))
    
    [Figure: the Q-table]

    You can see that cells closer to the exit are brighter, while the very dark squares are positions it considers unsafe.
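
    Another way to read the Q-table is to take, in every cell, the action with the highest value, which gives the greedy policy. A small sketch of that (my addition, assuming the trained agent from the full code below; the arrow mapping follows the order of action_space, with row 0 drawn at the top):

    # Print one arrow per grid cell: the action the agent currently thinks is best there.
    arrows = '→↑←↓'                         # same order as agent.action_space
    best = agent.Q.argmax(axis=-1)          # shape (4, 12): index of the best action in each cell
    for row in best:
        print(''.join(arrows[i] for i in row))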

    IV. The full code:

    import numpy as np
    from matplotlib import pyplot as plt
    import random
    from itertools import count
    
    class Env:
        def __init__(self):
            self.action_space = []
            self.agent = None
            # 4x12 grid of rewards; 0 = safe ice.
            self.env = np.zeros((4,12))
            self.env[-1,1:-1] = -1     # bottom row between the corners: holes (-1)
            self.env[:2,3] = -1        # obstacle segment in column 3
            self.env[1:-1,8] = -1      # obstacle segment in column 8
            self.env[-1,-1] = 1        # bottom-right corner: the exit (+1)
            self.env_show = self.env.copy()
    
        def step(self,a=[0,0],ai=False):
            s = self.agent.s
            s_ = np.array(s) + np.array(a)
            if 0<=s_[0]<=3 and 0<=s_[1]<=11:
                self.agent.s = s_
                r = self.env[s_[0],s_[1]]
            else:
                s_ = s
                r = -1
            self.agent.post_step(s,a,r,s_)
            return s_,r
    
        def play(self):
            # Play one game: reset, then step until the reward is -1 (a hole) or +1 (the exit).
            self.reset()
            for t in count(1):
                a = self.agent.chose_action()
                if a is not None:
                    s,r = self.step(a)
                    if r in [-1,1]:
                        break
                else:
                    r = None
                    break
            return t,r
        
        def play_until_success(self):
            for t in count(1):
                _,r = self.play()
                if r:
                    if t%20000 == 0:
                        print(f"playing {t} times!")
                    if r == 1:
                        print(f"{t} times success!")
                        self.render()
                        break
                else:break
    
        def render(self):
            for i,j in self.agent.s_chain:
                self.env_show[i,j] = 0.5
            plt.imshow(self.env_show)
            plt.show()
    
        def reset(self):
            self.agent.reset()
            self.env_show = self.env.copy()
    
        def register(self,agent):
            self.agent = agent
    
    class Agent:
        def __init__(self):
            self.action_space = np.array([[0,1],[-1,0],[0,-1],[1,0]])
            self.s = np.array([0,0])
            self.s_chain = np.expand_dims(self.s,0)
            self.sar_chain = np.expand_dims(np.hstack([np.array([0,0]),[0,0],0]),0)
            self.Q = np.zeros((4,12,4))
            self.epsilon = 0.2
    
        def chose_action(self):
            if self.epsilon < random.random():
                a = random.choice(self.action_space)
            else:
                a = self.action_space[np.argmax(self.Q[self.s[0],self.s[1]])]
            return a
    
        def post_step(self,s,a,r,s_):
            self.s_chain = np.vstack([self.s_chain,s])
            self.sar_chain = np.vstack([self.sar_chain,np.hstack([s,a,r])])
            a_number = np.where((self.action_space == a).all(axis=1))
            if r == -1:
                self.Q[s_[0],s_[1]] = -1
            update = 0.1*(0.9*self.Q[s_[0],s_[1]].max() - self.Q[s[0],s[1],a_number])
            self.Q[s[0],s[1],a_number] += update
    
        def reset(self):
            self.action_space = np.array([[0,1],[-1,0],[0,-1],[1,0]])
            self.s = np.array([0,0])
            self.s_chain = np.expand_dims(self.s,0)
            self.sar_chain = np.expand_dims(np.hstack([np.array([0,0]),[0,0],0]),0)
    
    env = Env()
    agent = Agent()
    env.register(agent)
    # env.render()
    env.play_until_success()
    
