PARL QuickStart

Author: vickeex | Published 2019-09-29 12:25

    GitHub: https://github.com/PaddlePaddle/PARL.git
    Docs (Read the Docs): https://parl.readthedocs.io/en/latest/index.html

    Preliminary: the fluid program

    PARL is implemented on top of Fluid (PaddleFluid) and wraps its own abstractions around the fluid program concept; the details are covered in the sections below.
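
    To make the program concept concrete, here is a minimal, self-contained Fluid example (assuming PaddlePaddle 1.x, which PARL's fluid-based API targets); it is background only, not PARL code:

        import numpy as np
        import paddle.fluid as fluid

        # a Program holds a static computation graph; layers defined inside
        # program_guard are recorded into it
        prog = fluid.Program()
        startup = fluid.Program()
        with fluid.program_guard(prog, startup):
            x = fluid.layers.data(name='x', shape=[4], dtype='float32')
            y = fluid.layers.fc(input=x, size=2, act='softmax')

        exe = fluid.Executor(fluid.CPUPlace())
        exe.run(startup)  # initialize parameters
        out = exe.run(prog,
                      feed={'x': np.random.rand(1, 4).astype('float32')},
                      fetch_list=[y])[0]
        print(out.shape)  # (1, 2)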

    PARL introduction

    Model

    A Model builds the network and implements the forward method (the forward pass). A simple two-layer model looks like this:

    import parl
    from parl import layers  # PARL's wrappers around fluid layers

    class CartpoleModel(parl.Model):
        def __init__(self, act_dim):
            hid1_size = act_dim * 10  # hidden layer size scales with the action dim

            self.fc1 = layers.fc(size=hid1_size, act='tanh')
            self.fc2 = layers.fc(size=act_dim, act='softmax')

        def forward(self, obs):
            # returns the action probability distribution for an observation
            out = self.fc1(obs)
            out = self.fc2(out)
            return out
    
    Algorithm

    An Algorithm updates the parameters of the model passed to it. An Algorithm needs to implement the following methods (a minimal sketch follows the list):

    • __init__ : store the model (a parl.Model subclass) and hyperparameters such as the learning rate, reward decay, and action dimension.
    • learn : define the loss function and update the model parameters from the loss and the training data.
    • predict : predict an action from the current environment state.
    • sample : based on predict, generate a noisy action for exploration in some scenarios.
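
    A minimal sketch of a custom Algorithm with these four methods, loosely following the shape of PARL's fluid-based PolicyGradient (simplified for illustration, not the exact library source):

    import paddle.fluid as fluid
    import parl
    from parl import layers

    class MyPolicyGradient(parl.Algorithm):
        def __init__(self, model, lr=1e-3):
            self.model = model  # the parl.Model whose parameters get updated
            self.lr = lr        # learning-rate hyperparameter

        def predict(self, obs):
            # action probabilities from the model's forward pass
            return self.model(obs)

        def sample(self, obs):
            # for a stochastic policy the Agent can draw the action from
            # the same predicted probabilities
            return self.model(obs)

        def learn(self, obs, action, reward):
            # REINFORCE-style loss: log-prob of the taken action,
            # weighted by the (discounted) return
            act_prob = self.model(obs)
            log_prob = layers.cross_entropy(act_prob, action)
            cost = layers.reduce_mean(log_prob * reward)
            fluid.optimizer.Adam(learning_rate=self.lr).minimize(cost)
            return cost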

    Several algorithms are already implemented in parl.algorithms, e.g. PolicyGradient, A2C, A3C, and IMPALA; they can be used directly:

        model = CartpoleModel(act_dim=ACT_DIM)
        alg = parl.algorithms.PolicyGradient(model, lr=LEARNING_RATE)
    
    Agent

    The Algorithm is passed to the Agent as a constructor argument; the Agent interacts with the environment and produces training data. The Agent class also needs to implement the following methods:

    • build_program : define the fluid programs, usually two instances, one for prediction and one for training.
    • learn : preprocess the intermediate data and feed it to the training program.
    • predict : feed the current environment state to the prediction program and return the action to execute.
    • sample : explore based on the current state (the predicted probabilities determine how likely each action is to be picked).

    import numpy as np
    import paddle.fluid as fluid
    import parl
    from parl import layers

    class CartpoleAgent(parl.Agent):
        def __init__(self, algorithm, obs_dim, act_dim):
            self.obs_dim = obs_dim
            self.act_dim = act_dim
            # parl.Agent.__init__ stores the algorithm as self.alg, calls
            # build_program() and creates self.fluid_executor
            super(CartpoleAgent, self).__init__(algorithm)
    
        def build_program(self):
            # one fluid program for prediction, one for training
            self.pred_program = fluid.Program()
            self.learn_program = fluid.Program()
    
            with fluid.program_guard(self.pred_program):
                obs = layers.data(
                    name='obs', shape=[self.obs_dim], dtype='float32')
                self.act_prob = self.alg.predict(obs)
    
            with fluid.program_guard(self.learn_program):
                obs = layers.data(
                    name='obs', shape=[self.obs_dim], dtype='float32')
                act = layers.data(name='act', shape=[1], dtype='int64')
                reward = layers.data(name='reward', shape=[], dtype='float32')
                self.cost = self.alg.learn(obs, act, reward)
    
        def sample(self, obs):
            # exploration: draw the action from the predicted probabilities
            obs = np.expand_dims(obs, axis=0)
            act_prob = self.fluid_executor.run(
                self.pred_program,
                feed={'obs': obs.astype('float32')},
                fetch_list=[self.act_prob])[0]
            act_prob = np.squeeze(act_prob, axis=0)
            act = np.random.choice(range(self.act_dim), p=act_prob)
            return act
    
        def predict(self, obs):
            # exploitation: take the most probable action
            obs = np.expand_dims(obs, axis=0)
            act_prob = self.fluid_executor.run(
                self.pred_program,
                feed={'obs': obs.astype('float32')},
                fetch_list=[self.act_prob])[0]
            act_prob = np.squeeze(act_prob, axis=0)
            act = np.argmax(act_prob)
            return act
    
        def learn(self, obs, act, reward):
            # reshape actions to [N, 1] and run the training program once
            act = np.expand_dims(act, axis=-1)
            feed = {
                'obs': obs.astype('float32'),
                'act': act.astype('int64'),
                'reward': reward.astype('float32')
            }
            cost = self.fluid_executor.run(
                self.learn_program, feed=feed, fetch_list=[self.cost])[0]
            return cost
    
    
    Main loop

    Build the environment, model, algorithm, and agent, then run the training loop:

        import gym

        env = gym.make("CartPole-v0")
        model = CartpoleModel(act_dim=ACT_DIM)
        alg = parl.algorithms.PolicyGradient(model, lr=LEARNING_RATE)
        agent = CartpoleAgent(alg, obs_dim=OBS_DIM, act_dim=ACT_DIM)
        
        for i in range(1000):
            # collect one episode of experience with the exploratory policy
            obs_list, action_list, reward_list = run_episode(env, agent)
            # turn the episode into training batches (preprocessing elided in the post)
            batch_obs, batch_action, batch_reward = ......
            agent.learn(batch_obs, batch_action, batch_reward)
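
    run_episode is not shown in the post; one plausible version of such a rollout helper for CartPole looks like the sketch below (an illustration, not the post's exact code):

        def run_episode(env, agent):
            # roll out one episode, sampling actions so the agent explores
            obs_list, action_list, reward_list = [], [], []
            obs = env.reset()
            done = False
            while not done:
                obs_list.append(obs)
                action = agent.sample(obs)  # stochastic action
                obs, reward, done, _ = env.step(action)
                action_list.append(action)
                reward_list.append(reward)
            return obs_list, action_list, reward_list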
    
    
