Reinforcement Learning: A Detailed Guide to Stable-Baselines3 and Its Features

This article consolidates and translates the important parts of the official documentation and reorganizes them in an easy-to-follow order. It contains many usage examples for quick reference.

Official documentation: https://stable-baselines3.readthedocs.io/en/master/

Reference (Chinese): https://zhuanlan.zhihu.com/p/406517851

Introduction

If you have worked with reinforcement learning, you have almost certainly used OpenAI's Gym. Gym provides a wide variety of RL environments and also makes it easy to create your own, which makes it an excellent playground for testing RL algorithms. With a testbed in hand, we naturally want a handy tool that helps us implement all kinds of RL algorithms quickly. Of course, once we understand the basic principles of an algorithm, we can always implement it ourselves with a deep learning framework such as PyTorch or TensorFlow.

For example, a PyTorch implementation of DQN:

import gym
from gym import wrappers

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

import numpy as np

from IPython.display import clear_output
from matplotlib import pyplot as plt

import random
from timeit import default_timer as timer
from datetime import timedelta
import math
from utils.wrappers import make_atari, wrap_deepmind, wrap_pytorch

from utils.hyperparameters import Config
from agents.BaseAgent import BaseAgent


config = Config()

config.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

#epsilon variables
config.epsilon_start = 1.0
config.epsilon_final = 0.01
config.epsilon_decay = 30000
config.epsilon_by_frame = lambda frame_idx: config.epsilon_final + (config.epsilon_start - config.epsilon_final) * math.exp(-1. * frame_idx / config.epsilon_decay)

#misc agent variables
config.GAMMA=0.99
config.LR=1e-4

#memory
config.TARGET_NET_UPDATE_FREQ = 1000
config.EXP_REPLAY_SIZE = 100000
config.BATCH_SIZE = 32

#Learning control variables
config.LEARN_START = 10000
config.MAX_FRAMES=1000000

class ExperienceReplayMemory:
    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = []

    def push(self, transition):
        self.memory.append(transition)
        if len(self.memory) > self.capacity:
            del self.memory[0]

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)


class DQN(nn.Module):
    def __init__(self, input_shape, num_actions):
        super(DQN, self).__init__()

        self.input_shape = input_shape
        self.num_actions = num_actions

        self.conv1 = nn.Conv2d(self.input_shape[0], 32, kernel_size=8, stride=4)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1)

        self.fc1 = nn.Linear(self.feature_size(), 512)
        self.fc2 = nn.Linear(512, self.num_actions)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)

        return x

    def feature_size(self):
        return self.conv3(self.conv2(self.conv1(torch.zeros(1, *self.input_shape)))).view(1, -1).size(1)


class Model(BaseAgent):
    def __init__(self, static_policy=False, env=None, config=None):
        super(Model, self).__init__()
        self.device = config.device

        self.gamma = config.GAMMA
        self.lr = config.LR
        self.target_net_update_freq = config.TARGET_NET_UPDATE_FREQ
        self.experience_replay_size = config.EXP_REPLAY_SIZE
        self.batch_size = config.BATCH_SIZE
        self.learn_start = config.LEARN_START

        self.static_policy = static_policy
        self.num_feats = env.observation_space.shape
        self.num_actions = env.action_space.n
        self.env = env

        self.declare_networks()

        self.target_model.load_state_dict(self.model.state_dict())
        self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)

        # move to correct device
        self.model = self.model.to(self.device)
        self.target_model.to(self.device)

        if self.static_policy:
            self.model.eval()
            self.target_model.eval()
        else:
            self.model.train()
            self.target_model.train()

        self.update_count = 0

        self.declare_memory()

    def declare_networks(self):
        self.model = DQN(self.num_feats, self.num_actions)
        self.target_model = DQN(self.num_feats, self.num_actions)

    def declare_memory(self):
        self.memory = ExperienceReplayMemory(self.experience_replay_size)

    def append_to_replay(self, s, a, r, s_):
        self.memory.push((s, a, r, s_))

    def prep_minibatch(self):
        # random transition batch is taken from experience replay memory
        transitions = self.memory.sample(self.batch_size)

        batch_state, batch_action, batch_reward, batch_next_state = zip(*transitions)

        shape = (-1,) + self.num_feats

        # (32,1,84,84)
        batch_state = torch.tensor(batch_state, device=self.device, dtype=torch.float).view(shape)
        # (32,1)
        batch_action = torch.tensor(batch_action, device=self.device, dtype=torch.long).squeeze().view(-1, 1)
        # (32,1)
        batch_reward = torch.tensor(batch_reward, device=self.device, dtype=torch.float).squeeze().view(-1, 1)
        # map() applies the given function over the sequence; here it marks which next states are not None, shape (32)
        non_final_mask = torch.tensor(tuple(map(lambda s: s is not None, batch_next_state)), device=self.device,
                                      dtype=torch.uint8)
        try:  # sometimes all next states are None
            # stack the next states that are not None, shape (32,1,84,84)
            non_final_next_states = torch.tensor([s for s in batch_next_state if s is not None], device=self.device,
                                                 dtype=torch.float).view(shape)
            empty_next_state_values = False
        except:
            non_final_next_states = None
            empty_next_state_values = True

        return batch_state, batch_action, batch_reward, non_final_next_states, non_final_mask, empty_next_state_values

    def compute_loss(self, batch_vars):
        # unpack the six batch variables
        batch_state, batch_action, batch_reward, non_final_next_states, non_final_mask, empty_next_state_values = batch_vars

        # estimate
        current_q_values = self.model(batch_state).gather(1, batch_action)

        # target
        with torch.no_grad():
            max_next_q_values = torch.zeros(self.batch_size, device=self.device, dtype=torch.float).unsqueeze(dim=1)
            if not empty_next_state_values:
                max_next_action = self.get_max_next_state_action(non_final_next_states)
                max_next_q_values[non_final_mask] = self.target_model(non_final_next_states).gather(1, max_next_action)
            expected_q_values = batch_reward + (self.gamma * max_next_q_values)

        diff = (expected_q_values - current_q_values)
        loss = self.huber(diff)
        loss = loss.mean()

        return loss

    def update(self, s, a, r, s_, frame=0):
        if self.static_policy:
            return None

        self.append_to_replay(s, a, r, s_)

        if frame < self.learn_start:
            return None

        batch_vars = self.prep_minibatch()

        loss = self.compute_loss(batch_vars)

        # Optimize the model
        self.optimizer.zero_grad()
        loss.backward()
        for param in self.model.parameters():
            param.grad.data.clamp_(-1, 1)
        self.optimizer.step()

        self.update_target_model()
        self.save_loss(loss.item())
        self.save_sigma_param_magnitudes()

    def get_action(self, s, eps=0.1):
        with torch.no_grad():
            if np.random.random() >= eps or self.static_policy:
                X = torch.tensor([s], device=self.device, dtype=torch.float)
                a = self.model(X).max(1)[1].view(1, 1)
                return a.item()
            else:
                return np.random.randint(0, self.num_actions)

    def update_target_model(self):
        self.update_count += 1
        self.update_count = self.update_count % self.target_net_update_freq
        if self.update_count == 0:
            self.target_model.load_state_dict(self.model.state_dict())

    def get_max_next_state_action(self, next_states):
        return self.target_model(next_states).max(dim=1)[1].view(-1, 1)

    def huber(self, x):
        cond = (x.abs() < 1.0).to(torch.float)
        return 0.5 * x.pow(2) * cond + (x.abs() - 0.5) * (1 - cond)


def plot(frame_idx, rewards, losses, sigma, elapsed_time):
    clear_output(True)
    plt.figure(figsize=(20,5))
    plt.subplot(131)
    plt.title('frame %s. reward: %s. time: %s' % (frame_idx, np.mean(rewards[-10:]), elapsed_time))
    plt.plot(rewards)
    if losses:
        plt.subplot(132)
        plt.title('loss')
        plt.plot(losses)
    if sigma:
        plt.subplot(133)
        plt.title('noisy param magnitude')
        plt.plot(sigma)
    plt.show()


start = timer()

env_id = "PongNoFrameskip-v4"
env = make_atari(env_id)
env = wrap_deepmind(env, frame_stack=False)
env = wrap_pytorch(env)
model = Model(env=env, config=config)

episode_reward = 0

observation = env.reset()
for frame_idx in range(1, config.MAX_FRAMES + 1):
    epsilon = config.epsilon_by_frame(frame_idx)

    action = model.get_action(observation, epsilon)
    prev_observation = observation
    observation, reward, done, _ = env.step(action)
    observation = None if done else observation

    model.update(prev_observation, action, reward, observation, frame_idx)
    episode_reward += reward

    if done:
        print("Finish a episode")
        observation = env.reset()
        model.save_reward(episode_reward)
        episode_reward = 0

        if np.mean(model.rewards[-10:]) > 19:
            print("mean_reward>19....Plot...")
            plot(frame_idx, model.rewards, model.losses, model.sigma_parameter_mag,
                 timedelta(seconds=int(timer() - start)))
            break

    if frame_idx % 10000 == 0:
        plot(frame_idx, model.rewards, model.losses, model.sigma_parameter_mag, timedelta(seconds=int(timer() - start)))
        print("Plot..................")

model.save_w()
env.close()

This program is taken from the DQN implementation in Bolei Zhou's reinforcement learning course. As you can see, even without counting the custom utility modules it imports, the main file alone is close to 300 lines. The effort needed to organize, refactor, and encapsulate that much code is considerable, and code written this way is not very reusable.

To make popular reinforcement learning algorithms easy for the wider RL community to use, stable-baselines was created, and its redesign produced the PyTorch-based Stable-Baselines3. As one of the best-known RL algorithm libraries, it is usually paired with Gym and is widely used in all kinds of RL training.

Overview

Stable-Baselines3 is a very popular deep reinforcement learning toolkit. It lets you build and evaluate RL algorithms quickly, provides pre-trained agents, and supports saving models, recording videos, and much more; it is a very powerful library.

See the official site for details: https://stable-baselines3.readthedocs.io/en/master/

Setting Up an Environment Quickly

First make sure Gym is installed, then run the following code:

import gym

env_name = "CartPole-v0"
env = gym.make(env_name)          # create the environment

episodes = 10
for episode in range(1, episodes + 1):
    state = env.reset()           
    done = False
    score = 0

    while not done:
        env.render()                           # render the environment
        action = env.action_space.sample()     # sample a random action
        n_state, reward, done, info = env.step(action)    # step the environment: get next state, reward, done flag, info
        score += reward                        # accumulate the episode score
    print("Episode : {}, Score : {}".format(episode, score))

env.close()     # close the rendering window

CartPole is the "Hello World" of reinforcement learning and a very convenient testbed for our algorithms. If you run this code, you will see the cart-pole flailing around under random actions.

Controlling the Environment with Stable-Baselines3 (Quick Start)

Now let's use Stable-Baselines3 to train a reinforcement learning model that controls this environment well. First make sure the stable_baselines3 package is installed, then run the following code:

from stable_baselines3 import DQN
from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv
from stable_baselines3.common.evaluation import evaluate_policy
import gym

env_name = "CartPole-v0"
env = gym.make(env_name)
# Vectorize the environment. If you have several environments, pass them as a list to DummyVecEnv;
# it runs multiple environments sequentially in a single process, which can improve training efficiency.
env = DummyVecEnv([lambda : env])
# Define a DQN model and set its parameters
model = DQN(
    "MlpPolicy",                                # MlpPolicy uses an MLP as the policy network
    env=env, 
    learning_rate=5e-4,
    batch_size=128,
    buffer_size=50000,
    learning_starts=0,
    target_update_interval=250,
    policy_kwargs={"net_arch" : [256, 256]},     # 这里代表隐藏层为2层256个节点数的网络
    verbose=1,                                   # verbose=1代表打印训练信息,如果是0为不打印,2为打印调试信息
    tensorboard_log="./tensorboard/CartPole-v0/"  # 训练数据保存目录,可以用tensorboard查看
)
# start training
model.learn(total_timesteps=int(1e5))
# evaluate the policy; you should now see the pole being balanced steadily
mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=10, render=True)
#env.close()
print("mean_reward:", mean_reward, "std_reward:", std_reward)
# save the model to the given path
model.save("./model/CartPole.pkl")

Open TensorBoard from the command line:

tensorboard --logdir .\tensorboard\CartPole-v0\

To load an existing model and use it to control the environment, run the following code:

from stable_baselines3 import DQN
from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv
from stable_baselines3.common.evaluation import evaluate_policy
import gym

env_name = "CartPole-v0"
env = gym.make(env_name)
# load the saved model
model = DQN.load("./model/CartPole.pkl")

state = env.reset()
done = False
score = 0
while not done:
    # predict the action
    action, _ = model.predict(observation=state)
    # interact with the environment
    state, reward, done, info = env.step(action=action)
    score += reward
    env.render()
env.close()
print("score=",score)

Parameters of the Various Algorithms

See the full documentation at https://stable-baselines3.readthedocs.io/en/master/modules/base.html

We take PPO as an example; the other algorithms work the same way. A combined sketch that exercises several of the parameters below follows at the end of this reference.

Common parameters

  • policy (Type[BasePolicy]) – the policy object
  • env (Union[Env, VecEnv, str, None]) – the environment to learn from
  • policy_base (Type[BasePolicy]) – the base policy used by this method
  • learning_rate (Union[float, Callable[[float], float]]) – learning rate for the optimizer; it can be a function of the current remaining progress (from 1 to 0)
  • policy_kwargs (Optional[Dict[str, Any]]) – additional arguments passed to the policy on creation; default None
  • tensorboard_log (Optional[str]) – the log location for TensorBoard (if None, no logging); default None
  • verbose (int) – verbosity level: 0 none, 1 training information, 2 debug; default 0
  • device (Union[device, str]) – the device on which the code should run. By default it tries to use a CUDA-compatible device and falls back to CPU if that is not possible. Default "auto"
  • support_multi_env (bool) – whether the algorithm supports training with multiple environments (as in A2C); default False
  • create_eval_env (bool) – whether to create a second environment used to periodically evaluate the agent (only available when a string is passed for the environment); default False
  • monitor_wrapper (bool) – when creating the environment, whether to wrap it in a Monitor wrapper; default True
  • seed (Optional[int]) – seed for the pseudo-random generators; default None
  • use_sde (bool) – whether to use generalized State Dependent Exploration (gSDE) instead of action-noise exploration (default: False)
  • sde_sample_freq (int) – when using gSDE, sample a new noise matrix every n steps; default -1 (only sample at the beginning of the rollout)
  • supported_action_spaces (Optional[Tuple[Space, ...]]) – the action spaces supported by the algorithm; default None
  • PPO (other algorithms follow the same pattern)

    Quick example:

    import gym
    
    from stable_baselines3 import PPO
    from stable_baselines3.common.env_util import make_vec_env
    
    # Parallel environments
    env = make_vec_env("CartPole-v1", n_envs=4)
    
    model = PPO("MlpPolicy", env, verbose=1)
    model.learn(total_timesteps=25000)
    model.save("ppo_cartpole")
    
    del model # remove to demonstrate saving and loading
    
    model = PPO.load("ppo_cartpole")
    
    obs = env.reset()
    while True:
        action, _states = model.predict(obs)
        obs, rewards, dones, info = env.step(action)
        env.render()
    

    Parameters of the PPO constructor

    Required parameters (the first and second positional arguments):

  • policy: the type of policy network; one of MlpPolicy, CnnPolicy, or MultiInputPolicy.

  • env: the Gym environment.

  • Optional parameters:

  • learning_rate – learning rate; default 0.0003
  • n_steps (int) – number of steps to run in the environment per update; default 2048
  • batch_size (int) – minibatch size; default 64
  • n_epochs (int) – number of epochs when optimizing the loss; default 10
  • gamma (float) – discount factor; default 0.99
  • gae_lambda (float) – factor trading off bias vs. variance in the Generalized Advantage Estimator; default 0.95
  • clip_range (Union[float, Callable[[float], float]]) – clipping parameter; it can be a function of the current remaining progress (from 1 to 0). Default 0.2
  • clip_range_vf (Union[None, float, Callable[[float], float]]) – clipping parameter for the value function; it can be a function of the current remaining progress (from 1 to 0). This is a parameter specific to the OpenAI implementation. If None is passed (the default), no clipping is applied to the value function. IMPORTANT: this clipping depends on the reward scaling. Default None
  • ent_coef (float) – entropy coefficient for the loss calculation; default 0
  • vf_coef (float) – value-function coefficient for the loss calculation; default 0.5
  • max_grad_norm (float) – maximum value for gradient clipping; default 0.5
  • use_sde (bool) – whether to use generalized State Dependent Exploration (gSDE) instead of action-noise exploration; default False
  • sde_sample_freq (int) – when using gSDE, sample a new noise matrix every n steps (only sample at the beginning of the rollout); default -1
  • target_kl (Optional[float]) – limit the KL divergence between updates, because clipping alone is not enough to prevent large updates; see issue #213 (https://github.com/hill-a/stable-baselines/issues/213). By default there is no limit on the KL divergence (None)
  • tensorboard_log (Optional[str]) – the log folder for TensorBoard (if None, no logging); default None
  • create_eval_env (bool) – whether to create a second environment used to periodically evaluate the agent (only available when a string is passed for the environment); default False
  • policy_kwargs (Optional[Dict[str, Any]]) – additional arguments passed to the policy on creation; default None
  • verbose (int) – verbosity level: 0 no output, 1 info, 2 debug; default 0
  • seed (Optional[int]) – seed for the pseudo-random generators; default None
  • device (Union[device, str]) – the device to train on; default "auto"
  • _init_setup_model (bool) – whether to build the network when the instance is created; default True
  • The learn function (used to train the model):

    learn(total_timesteps, callback=None, log_interval=1, eval_env=None, eval_freq=-1, n_eval_episodes=5, tb_log_name='PPO', eval_log_path=None, reset_num_timesteps=True)

  • total_timesteps (int) – number of environment steps to train for.
  • callback (Union[None, Callable, List[BaseCallback], BaseCallback]) – callback(s) called at every step; for example, a CheckpointCallback can be used to create checkpoints at a given save interval.
  • log_interval (int) – number of timesteps between two logging updates.
  • tb_log_name (str) – the name of the run for TensorBoard logging
  • eval_env (Union[Env, VecEnv, None]) – environment used to evaluate the agent
  • eval_freq (int) – evaluate the agent every eval_freq timesteps
  • n_eval_episodes (int) – number of episodes used for each evaluation
  • eval_log_path (Optional[str]) – path to the folder where the evaluations will be saved
  • reset_num_timesteps (bool) – whether to reset the current timestep counter (used in logging)
  • The save function (saves the model):

    save(path, exclude=None, include=None)

  • path (Union[str, Path, BufferedIOBase]) – path to the file where the model is saved
  • exclude (Optional[Iterable[str]]) – names of parameters that should be excluded in addition to the default ones
  • include (Optional[Iterable[str]]) – names of parameters that might be excluded but should be included anyway
  • The load function (loads a model):

    load(path, env=None, device='auto', custom_objects=None, print_system_info=False, force_reset=True, **kwargs)

  • path (Union[str, Path, BufferedIOBase]) – path to the saved model file
  • env (Union[Env, VecEnv, None]) – a new environment to run the loaded model on (can be None if you only need prediction); takes priority over any saved environment
  • device (Union[device, str]) – the device on which the code should run.
  • custom_objects (Optional[Dict[str, Any]]) – dictionary of objects to replace on loading. If a variable is present as a key in this dictionary, it will not be deserialized and the corresponding item will be used instead. Similar to keras.models.load_model; useful when the file contains objects that cannot be deserialized.
  • print_system_info (bool) – whether to print the system info from the saved model and of the current system (helps with debugging loading issues)
  • force_reset (bool) – force a call to reset() before training to avoid unexpected behavior. See https://github.com/DLR-RM/stable-baselines3/issues/597
  • kwargs – extra arguments to change the model at loading time
  • stable_baselines3.common.policies.ActorCriticPolicy, the class that creates the corresponding network model

    It can be passed directly as the first argument of the PPO constructor; its parameters are as follows:

    Required parameters:

  • observation_space (Space) – the observation space
  • action_space (Space) – the action space
  • lr_schedule (Callable[[float], float]) – learning rate schedule (can be constant)
  • Optional parameters:

  • net_arch (Optional[List[Union[int, Dict[str, List[int]]]]]) – specification of the policy and value networks.
  • activation_fn (Type[Module]) – the activation function
  • ortho_init (bool) – whether to use orthogonal initialization
  • use_sde (bool) – whether to use state-dependent exploration
  • log_std_init (float) – initial value for the log standard deviation
  • full_std (bool) – when using gSDE, whether to use (n_features x n_actions) parameters for the std instead of only (n_features,)
  • sde_net_arch (Optional[List[int]]) – network architecture for extracting features when using gSDE. If None, the latent features from the policy are used. Pass an empty list to use the states as features.
  • use_expln (bool) – use the expln() function instead of exp() to ensure a positive standard deviation (see the paper). It keeps the variance above zero and prevents it from growing too fast. In practice, exp() is usually enough.
  • squash_output (bool) – whether to squash the output with a tanh function; this ensures bounded actions when using gSDE.
  • features_extractor_class (Type[BaseFeaturesExtractor]) – the features extractor to use.
  • features_extractor_kwargs (Optional[Dict[str, Any]]) – keyword arguments to pass to the features extractor.
  • normalize_images (bool) – whether to normalize images by dividing by 255.0 (default True)
  • optimizer_class (Type[Optimizer]) – the optimizer to use; th.optim.Adam by default
  • optimizer_kwargs (Optional[Dict[str, Any]]) – additional keyword arguments (excluding the learning rate) to pass to the optimizer
  • stable_baselines3.common.policies.ActorCriticCnnPolicy and stable_baselines3.common.policies.MultiInputActorCriticPolicy

    These work the same way as above; see https://stable-baselines3.readthedocs.io/en/master/modules/ppo.html
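
    Putting several of the parameters documented above together, here is a minimal combined sketch; the hyperparameter values and file names are illustrative only, not tuned recommendations:

    from stable_baselines3 import PPO
    from stable_baselines3.common.env_util import make_vec_env

    env = make_vec_env("CartPole-v1", n_envs=4)

    # Most keyword arguments have sensible defaults; a few are set explicitly here for illustration
    model = PPO(
        "MlpPolicy",
        env,
        learning_rate=3e-4,
        n_steps=1024,
        batch_size=64,
        n_epochs=10,
        gamma=0.99,
        gae_lambda=0.95,
        clip_range=0.2,
        ent_coef=0.0,
        verbose=1,
        seed=0,
        device="auto",
        tensorboard_log="./tensorboard/PPO-CartPole/",
    )

    # learn() accepts the arguments documented above (log_interval, callback, ...)
    model.learn(total_timesteps=50_000, log_interval=10)

    # save()/load() as documented above; print_system_info can help debug loading issues
    model.save("ppo_cartpole_params_demo")
    model = PPO.load("ppo_cartpole_params_demo", env=env, print_system_info=True)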

    DQN

    Quick example:

    import gym
    
    from stable_baselines3 import DQN
    
    env = gym.make("CartPole-v0")
    
    model = DQN("MlpPolicy", env, verbose=1)
    model.learn(total_timesteps=10000, log_interval=4)
    model.save("dqn_cartpole")
    
    del model # remove to demonstrate saving and loading
    
    model = DQN.load("dqn_cartpole")
    
    obs = env.reset()
    while True:
        action, _states = model.predict(obs, deterministic=True)
        obs, reward, done, info = env.step(action)
        env.render()
        if done:
          obs = env.reset()
    

    For detailed parameters, see https://stable-baselines3.readthedocs.io/en/master/modules/dqn.html

    TD3

    Quick example:

    import gym
    import numpy as np
    
    from stable_baselines3 import TD3
    from stable_baselines3.common.noise import NormalActionNoise, OrnsteinUhlenbeckActionNoise
    
    env = gym.make("Pendulum-v1")
    
    # The noise objects for TD3
    n_actions = env.action_space.shape[-1]
    action_noise = NormalActionNoise(mean=np.zeros(n_actions), sigma=0.1 * np.ones(n_actions))
    
    model = TD3("MlpPolicy", env, action_noise=action_noise, verbose=1)
    model.learn(total_timesteps=10000, log_interval=10)
    model.save("td3_pendulum")
    env = model.get_env()
    
    del model # remove to demonstrate saving and loading
    
    model = TD3.load("td3_pendulum")
    
    obs = env.reset()
    while True:
        action, _states = model.predict(obs)
        obs, rewards, dones, info = env.step(action)
        env.render()
    

    For detailed parameters, see https://stable-baselines3.readthedocs.io/en/master/modules/td3.html

    SAC

    Quick example:

    import gym
    import numpy as np
    
    from stable_baselines3 import SAC
    
    env = gym.make("Pendulum-v1")
    
    model = SAC("MlpPolicy", env, verbose=1)
    model.learn(total_timesteps=10000, log_interval=4)
    model.save("sac_pendulum")
    
    del model # remove to demonstrate saving and loading
    
    model = SAC.load("sac_pendulum")
    
    obs = env.reset()
    while True:
        action, _states = model.predict(obs, deterministic=True)
        obs, reward, done, info = env.step(action)
        env.render()
        if done:
          obs = env.reset()
    

    For detailed parameters, see https://stable-baselines3.readthedocs.io/en/master/modules/sac.html

    Usage Examples for the Various Features

    Training, Saving, and Loading

    import gym
    
    from stable_baselines3 import DQN
    from stable_baselines3.common.evaluation import evaluate_policy
    
    
    # Create environment
    env = gym.make('LunarLander-v2')
    
    # Instantiate the agent
    model = DQN('MlpPolicy', env, verbose=1)
    # Train the agent
    model.learn(total_timesteps=int(2e5))
    # Save the agent
    model.save("dqn_lunar")
    del model  # delete trained model to demonstrate loading
    
    # Load the trained agent
    # NOTE: if you have loading issue, you can pass `print_system_info=True`
    # to compare the system on which the model was trained vs the current one
    # model = DQN.load("dqn_lunar", env=env, print_system_info=True)
    model = DQN.load("dqn_lunar", env=env)
    
    # Evaluate the agent
    # NOTE: If you use wrappers with your environment that modify rewards,
    #       this will be reflected here. To evaluate with original rewards,
    #       wrap environment in a "Monitor" wrapper before other wrappers.
    mean_reward, std_reward = evaluate_policy(model, model.get_env(), n_eval_episodes=10)
    
    # Enjoy trained agent
    obs = env.reset()
    for i in range(1000):
        action, _states = model.predict(obs, deterministic=True)
        obs, rewards, dones, info = env.step(action)
        env.render()
    

    Advanced Saving and Loading

    The save function does not store the replay buffer, so save_replay_buffer() and load_replay_buffer() are provided to save and load it separately.

    from stable_baselines3 import SAC
    from stable_baselines3.common.evaluation import evaluate_policy
    from stable_baselines3.sac.policies import MlpPolicy
    
    # Create the model, the training environment
    # and the test environment (for evaluation)
    model = SAC('MlpPolicy', 'Pendulum-v1', verbose=1,
                learning_rate=1e-3, create_eval_env=True)
    
    # Evaluate the model every 1000 steps on 5 test episodes
    # and save the evaluation to the "logs/" folder
    model.learn(6000, eval_freq=1000, n_eval_episodes=5, eval_log_path="./logs/")
    
    # save the model
    model.save("sac_pendulum")
    
    # the saved model does not contain the replay buffer
    loaded_model = SAC.load("sac_pendulum")
    print(f"The loaded_model has {loaded_model.replay_buffer.size()} transitions in its buffer")
    
    # now save the replay buffer too
    model.save_replay_buffer("sac_replay_buffer")
    
    # load it into the loaded_model
    loaded_model.load_replay_buffer("sac_replay_buffer")
    
    # now the loaded replay is not empty anymore
    print(f"The loaded_model has {loaded_model.replay_buffer.size()} transitions in its buffer")
    
    # Save the policy independently from the model
    # Note: if you don't save the complete model with `model.save()`
    # you cannot continue training afterward
    policy = model.policy
    policy.save("sac_policy_pendulum")
    
    # Retrieve the environment
    env = model.get_env()
    
    # Evaluate the policy
    mean_reward, std_reward = evaluate_policy(policy, env, n_eval_episodes=10, deterministic=True)
    
    print(f"mean_reward={mean_reward:.2f} +/- {std_reward}")
    
    # Load the policy independently from the model
    saved_policy = MlpPolicy.load("sac_policy_pendulum")
    
    # Evaluate the loaded policy
    mean_reward, std_reward = evaluate_policy(saved_policy, env, n_eval_episodes=10, deterministic=True)
    
    print(f"mean_reward={mean_reward:.2f} +/- {std_reward}")
    

    Training Multiple Environments in Parallel

    import gym
    import numpy as np
    
    from stable_baselines3 import PPO
    from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv
    from stable_baselines3.common.env_util import make_vec_env
    from stable_baselines3.common.utils import set_random_seed
    
    def make_env(env_id, rank, seed=0):
        """
        Utility function for multiprocessed env.
    
        :param env_id: (str) the environment ID
        :param num_env: (int) the number of environments you wish to have in subprocesses
        :param seed: (int) the initial seed for RNG
        :param rank: (int) index of the subprocess
        """
        def _init():
            env = gym.make(env_id)
            env.seed(seed + rank)
            return env
        set_random_seed(seed)
        return _init
    
    if __name__ == '__main__':
        env_id = "CartPole-v1"
        num_cpu = 4  # Number of processes to use
        # Create the vectorized environment
        env = SubprocVecEnv([make_env(env_id, i) for i in range(num_cpu)])
    
        # Stable Baselines provides you with make_vec_env() helper
        # which does exactly the previous steps for you.
        # You can choose between `DummyVecEnv` (usually faster) and `SubprocVecEnv`
        # env = make_vec_env(env_id, n_envs=num_cpu, seed=0, vec_env_cls=SubprocVecEnv)
    
        model = PPO('MlpPolicy', env, verbose=1)
        model.learn(total_timesteps=25_000)
    
        obs = env.reset()
        for _ in range(1000):
            action, _states = model.predict(obs)
            obs, rewards, dones, info = env.step(action)
            env.render()
    

    Multiprocessing with Off-Policy Algorithms

    import gym
    
    from stable_baselines3 import SAC
    from stable_baselines3.common.env_util import make_vec_env
    
    env = make_vec_env("Pendulum-v0", n_envs=4, seed=0)
    
    # We collect 4 transitions per call to `env.step()`
    # and perform 2 gradient steps per call to `env.step()`
    # if gradient_steps=-1, then we would do 4 gradient steps per call to `env.step()`
    model = SAC('MlpPolicy', env, train_freq=1, gradient_steps=2, verbose=1)
    model.learn(total_timesteps=10_000)
    

    Using an Environment with Dict Observations

    For example, when there is an image input plus another single vector input, the observation can be a dict whose values are the image observation and the vector observation.

    from stable_baselines3 import PPO
    from stable_baselines3.common.envs import SimpleMultiObsEnv
    
    
    # Stable Baselines provides SimpleMultiObsEnv as an example environment with Dict observations
    env = SimpleMultiObsEnv(random_start=False)
    
    model = PPO("MultiInputPolicy", env, verbose=1)
    model.learn(total_timesteps=100_000)
    

    Defining a Callback to Monitor Training

    import os
    
    import gym
    import numpy as np
    import matplotlib.pyplot as plt
    
    from stable_baselines3 import TD3
    from stable_baselines3.common import results_plotter
    from stable_baselines3.common.monitor import Monitor
    from stable_baselines3.common.results_plotter import load_results, ts2xy, plot_results
    from stable_baselines3.common.noise import NormalActionNoise
    from stable_baselines3.common.callbacks import BaseCallback
    
    
    class SaveOnBestTrainingRewardCallback(BaseCallback):
        """
        Callback for saving a model (the check is done every ``check_freq`` steps)
        based on the training reward (in practice, we recommend using ``EvalCallback``).
    
        :param check_freq:
        :param log_dir: Path to the folder where the model will be saved.
          It must contain the file created by the ``Monitor`` wrapper.
        :param verbose: Verbosity level.
        """
        def __init__(self, check_freq: int, log_dir: str, verbose: int = 1):
            super(SaveOnBestTrainingRewardCallback, self).__init__(verbose)
            self.check_freq = check_freq
            self.log_dir = log_dir
            self.save_path = os.path.join(log_dir, 'best_model')
            self.best_mean_reward = -np.inf
    
        def _init_callback(self) -> None:
            # Create folder if needed
            if self.save_path is not None:
                os.makedirs(self.save_path, exist_ok=True)
    
        def _on_step(self) -> bool:
            if self.n_calls % self.check_freq == 0:
    
              # Retrieve training reward
              x, y = ts2xy(load_results(self.log_dir), 'timesteps')
              if len(x) > 0:
                  # Mean training reward over the last 100 episodes
                  mean_reward = np.mean(y[-100:])
                  if self.verbose > 0:
                    print(f"Num timesteps: {self.num_timesteps}")
                    print(f"Best mean reward: {self.best_mean_reward:.2f} - Last mean reward per episode: {mean_reward:.2f}")
    
                  # New best model, you could save the agent here
                  if mean_reward > self.best_mean_reward:
                      self.best_mean_reward = mean_reward
                      # Example for saving best model
                      if self.verbose > 0:
                        print(f"Saving new best model to {self.save_path}")
                      self.model.save(self.save_path)
    
            return True
    
    # Create log dir
    log_dir = "tmp/"
    os.makedirs(log_dir, exist_ok=True)
    
    # Create and wrap the environment
    env = gym.make('LunarLanderContinuous-v2')
    env = Monitor(env, log_dir)
    
    # Add some action noise for exploration
    n_actions = env.action_space.shape[-1]
    action_noise = NormalActionNoise(mean=np.zeros(n_actions), sigma=0.1 * np.ones(n_actions))
    # Exploration comes from the action noise defined above; the policy itself is a plain MlpPolicy
    model = TD3('MlpPolicy', env, action_noise=action_noise, verbose=0)
    # Create the callback: check every 1000 steps
    callback = SaveOnBestTrainingRewardCallback(check_freq=1000, log_dir=log_dir)
    # Train the agent
    timesteps = 1e5
    model.learn(total_timesteps=int(timesteps), callback=callback)
    
    plot_results([log_dir], timesteps, results_plotter.X_TIMESTEPS, "TD3 LunarLander")
    plt.show()
    

    Environments with Image Observations

    from stable_baselines3.common.env_util import make_atari_env
    from stable_baselines3.common.vec_env import VecFrameStack
    from stable_baselines3 import A2C
    
    # There already exists an environment generator
    # that will make and wrap atari environments correctly.
    # Here we are also multi-worker training (n_envs=4 => 4 environments)
    env = make_atari_env('PongNoFrameskip-v4', n_envs=4, seed=0)
    # Note: 4 consecutive frames are stacked into a single observation
    env = VecFrameStack(env, n_stack=4)
    # A CNN policy is used for the image observations
    model = A2C('CnnPolicy', env, verbose=1)
    model.learn(total_timesteps=25_000)
    
    obs = env.reset()
    while True:
        action, _states = model.predict(obs)
        obs, rewards, dones, info = env.step(action)
        env.render()
    

    Normalizing Inputs

    import os
    import gym
    import pybullet_envs
    
    from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize
    from stable_baselines3 import PPO
    
    env = DummyVecEnv([lambda: gym.make("HalfCheetahBulletEnv-v0")])
    # Automatically normalize the input features and reward
    env = VecNormalize(env, norm_obs=True, norm_reward=True,
                       clip_obs=10.)
    
    model = PPO('MlpPolicy', env)
    model.learn(total_timesteps=2000)
    
    # Don't forget to save the VecNormalize statistics when saving the agent
    log_dir = "/tmp/"
    model.save(log_dir + "ppo_halfcheetah")
    stats_path = os.path.join(log_dir, "vec_normalize.pkl")
    env.save(stats_path)
    
    # To demonstrate loading
    del model, env
    
    # Load the saved statistics
    env = DummyVecEnv([lambda: gym.make("HalfCheetahBulletEnv-v0")])
    env = VecNormalize.load(stats_path, env)
    #  do not update them at test time
    env.training = False      
    # reward normalization is not needed at test time
    env.norm_reward = False
    
    # Load the agent
    model = PPO.load(log_dir + "ppo_halfcheetah", env=env)
    

    Hindsight Experience Replay (HER)

    import gym
    import highway_env
    import numpy as np
    
    from stable_baselines3 import HerReplayBuffer, SAC, DDPG, TD3
    from stable_baselines3.common.noise import NormalActionNoise
    
    env = gym.make("parking-v0")
    
    # Create 4 artificial transitions per real transition
    n_sampled_goal = 4
    
    # SAC hyperparams:
    model = SAC(
        "MultiInputPolicy",
        env,
        replay_buffer_class=HerReplayBuffer,
        replay_buffer_kwargs=dict(
          n_sampled_goal=n_sampled_goal,
          goal_selection_strategy="future",
          # IMPORTANT: because the env is not wrapped with a TimeLimit wrapper
          # we have to manually specify the max number of steps per episode
          max_episode_length=100,
          online_sampling=True,
        ),
        verbose=1,
        buffer_size=int(1e6),
        learning_rate=1e-3,
        gamma=0.95,
        batch_size=256,
        policy_kwargs=dict(net_arch=[256, 256, 256]),
    )
    
    model.learn(int(2e5))
    model.save("her_sac_highway")
    
    # Load saved model
    # Because it needs access to `env.compute_reward()`
    # HER must be loaded with the env
    model = SAC.load("her_sac_highway", env=env)
    
    obs = env.reset()
    
    # Evaluate the agent
    episode_reward = 0
    for _ in range(100):
        action, _ = model.predict(obs, deterministic=True)
        obs, reward, done, info = env.step(action)
        env.render()
        episode_reward += reward
        if done or info.get("is_success", False):
            print("Reward:", episode_reward, "Success?", info.get("is_success", False))
            episode_reward = 0.0
            obs = env.reset()
    

    Setting a Learning Rate Schedule

    from typing import Callable
    
    from stable_baselines3 import PPO
    
    
    def linear_schedule(initial_value: float) -> Callable[[float], float]:
        def func(progress_remaining: float) -> float:
            return progress_remaining * initial_value
        return func
    
    # Initial learning rate of 0.001
    model = PPO("MlpPolicy", "CartPole-v1", learning_rate=linear_schedule(0.001), verbose=1)
    model.learn(total_timesteps=20_000)
    # By default, `reset_num_timesteps` is True, in which case the learning rate schedule resets.
    # progress_remaining = 1.0 - (num_timesteps / total_timesteps)
    model.learn(total_timesteps=10_000, reset_num_timesteps=True)
    

    Accessing and Modifying Model Parameters

    from typing import Dict
    
    import gym
    import numpy as np
    import torch as th
    
    from stable_baselines3 import A2C
    from stable_baselines3.common.evaluation import evaluate_policy
    
    
    def mutate(params: Dict[str, th.Tensor]) -> Dict[str, th.Tensor]:
        """Mutate parameters by adding normal noise to them"""
        return dict((name, param + th.randn_like(param)) for name, param in params.items())
    
    
    # Create policy with a small network
    model = A2C(
        "MlpPolicy",
        "CartPole-v1",
        ent_coef=0.0,
        policy_kwargs={"net_arch": [32]},
        seed=0,
        learning_rate=0.05,
    )
    
    # Use traditional actor-critic policy gradient updates to
    # find good initial parameters
    model.learn(total_timesteps=10_000)
    
    # Include only variables with "policy", "action" (policy) or "shared_net" (shared layers)
    # in their name: only these ones affect the action.
    # NOTE: you can retrieve those parameters using model.get_parameters() too
    mean_params = dict(
        (key, value)
        for key, value in model.policy.state_dict().items()
        if ("policy" in key or "shared_net" in key or "action" in key)
    )
    
    # population size of 50 individuals
    pop_size = 50
    # Keep top 10%
    n_elite = pop_size // 10
    # Retrieve the environment
    env = model.get_env()
    
    for iteration in range(10):
        # Create population of candidates and evaluate them
        population = []
        for population_i in range(pop_size):
            candidate = mutate(mean_params)
            # Load new policy parameters to agent.
            # Tell function that it should only update parameters
            # we give it (policy parameters)
            model.policy.load_state_dict(candidate, strict=False)
            # Evaluate the candidate
            fitness, _ = evaluate_policy(model, env)
            population.append((candidate, fitness))
        # Take top 10% and use average over their parameters as next mean parameter
        top_candidates = sorted(population, key=lambda x: x[1], reverse=True)[:n_elite]
        mean_params = dict(
            (
                name,
                th.stack([candidate[0][name] for candidate in top_candidates]).mean(dim=0),
            )
            for name in mean_params.keys()
        )
        mean_fitness = sum(top_candidate[1] for top_candidate in top_candidates) / n_elite
        print(f"Iteration {iteration + 1:<3} Mean top fitness: {mean_fitness:.2f}")
        print(f"Best fitness: {top_candidates[0][1]:.2f}")
    

    SB3 and ProcgenEnv

    Some environments, such as Procgen, already produce vectorized environments (see the discussion in issue #314). To use them with SB3, you must wrap them in a VecMonitor wrapper, which also allows tracking the agent's progress.

    from procgen import ProcgenEnv
    
    from stable_baselines3 import PPO
    from stable_baselines3.common.vec_env import VecExtractDictObs, VecMonitor
    
    # ProcgenEnv is already vectorized
    venv = ProcgenEnv(num_envs=2, env_name='starpilot')
    
    # To use only part of the observation:
    # venv = VecExtractDictObs(venv, "rgb")
    
    # Wrap with a VecMonitor to collect stats and avoid errors
    venv = VecMonitor(venv=venv)
    
    model = PPO("MultiInputPolicy", venv, verbose=1)
    model.learn(10_000)
    

    Recording Videos

    Your machine needs ffmpeg or avconv installed.

    import gym
    from stable_baselines3.common.vec_env import VecVideoRecorder, DummyVecEnv
    
    env_id = 'CartPole-v1'
    video_folder = 'logs/videos/'
    video_length = 100
    
    env = DummyVecEnv([lambda: gym.make(env_id)])
    
    obs = env.reset()
    
    # Record the video starting at the first step
    env = VecVideoRecorder(env, video_folder,
                           record_video_trigger=lambda x: x == 0, video_length=video_length,
                           name_prefix=f"random-agent-{env_id}")
    
    env.reset()
    for _ in range(video_length + 1):
      action = [env.action_space.sample()]
      obs, _, _, _ = env.step(action)
    # Save the video
    env.close()
    

    Making a GIF

    This does not work for Atari games.

    import imageio
    import numpy as np
    
    from stable_baselines3 import A2C
    
    model = A2C("MlpPolicy", "LunarLander-v2").learn(100_000)
    
    images = []
    obs = model.env.reset()
    img = model.env.render(mode='rgb_array')
    for i in range(350):
        images.append(img)
        action, _ = model.predict(obs)
        obs, _, _ ,_ = model.env.step(action)
        img = model.env.render(mode='rgb_array')
    
    imageio.mimsave('lander_a2c.gif', [np.array(img) for i, img in enumerate(images) if i%2 == 0], fps=29)
    

    Using a Custom Environment

    A custom environment must inherit from gym.Env, override its methods, and configure a few attributes. The template is as follows:

    import numpy as np
    import gym
    from gym import spaces
    
    class CustomEnv(gym.Env):
        """Custom Environment that follows gym interface"""
        metadata = {'render.modes': ['human']}
    
        def __init__(self, arg1, arg2, ...):
            super(CustomEnv, self).__init__()
            # Define action and observation space
            # They must be gym.spaces objects
            # Example when using discrete actions:
            self.action_space = spaces.Discrete(N_DISCRETE_ACTIONS)
            # Example for using image as input (channel-first; channel-last also works):
            self.observation_space = spaces.Box(low=0, high=255,
                                                shape=(N_CHANNELS, HEIGHT, WIDTH), dtype=np.uint8)
    
        def step(self, action):
            ...
            return observation, reward, done, info
        def reset(self):
            ...
            return observation  # reward, done, info can't be included
        def render(self, mode='human'):
            ...
        def close (self):
            ...
    

    To check that the environment conforms to the Gym interface:

    from stable_baselines3.common.env_checker import check_env
    
    env = CustomEnv(arg1, ...)
    # It will check your custom environment and output additional warnings if needed
    check_env(env)
    

    When you want to use the environment:

    from stable_baselines3 import A2C

    # Instantiate the env
    env = CustomEnv(arg1, ...)
    # Define and Train the agent
    model = A2C('CnnPolicy', env).learn(total_timesteps=1000)
    

    A small example:

    import numpy as np
    import gym
    from gym import spaces
    
    
    class GoLeftEnv(gym.Env):
      """
      Custom Environment that follows gym interface.
      This is a simple env where the agent must learn to go always left. 
      """
      # Because of google colab, we cannot implement the GUI ('human' render mode)
      metadata = {'render.modes': ['console']}
      # Define constants for clearer code
      LEFT = 0
      RIGHT = 1
    
      def __init__(self, grid_size=10):
        super(GoLeftEnv, self).__init__()
    
        # Size of the 1D-grid
        self.grid_size = grid_size
        # Initialize the agent at the right of the grid
        self.agent_pos = grid_size - 1
    
        # Define action and observation space
        # They must be gym.spaces objects
        # Example when using discrete actions, we have two: left and right
        n_actions = 2
        self.action_space = spaces.Discrete(n_actions)
        # The observation will be the coordinate of the agent
        # this can be described both by Discrete and Box space
        self.observation_space = spaces.Box(low=0, high=self.grid_size,
                                            shape=(1,), dtype=np.float32)
    
      def reset(self):
        """
        Important: the observation must be a numpy array
        :return: (np.array) 
        """
        # Initialize the agent at the right of the grid
        self.agent_pos = self.grid_size - 1
        # here we convert to float32 to make it more general (in case we want to use continuous actions)
        return np.array([self.agent_pos]).astype(np.float32)
    
      def step(self, action):
        if action == self.LEFT:
          self.agent_pos -= 1
        elif action == self.RIGHT:
          self.agent_pos += 1
        else:
          raise ValueError("Received invalid action={} which is not part of the action space".format(action))
    
        # Account for the boundaries of the grid
        self.agent_pos = np.clip(self.agent_pos, 0, self.grid_size)
    
        # Are we at the left of the grid?
        done = bool(self.agent_pos == 0)
    
        # Null reward everywhere except when reaching the goal (left of the grid)
        reward = 1 if self.agent_pos == 0 else 0
    
        # Optionally we can pass additional info, we are not using that for now
        info = {}
    
        return np.array([self.agent_pos]).astype(np.float32), reward, done, info
    
      def render(self, mode='console'):
        if mode != 'console':
          raise NotImplementedError()
        # agent is represented as a cross, rest as a dot
        print("." * self.agent_pos, end="")
        print("x", end="")
        print("." * (self.grid_size - self.agent_pos))
    
      def close(self):
        pass
        
    

    We can also register the environment with Gym and create it just like a built-in environment. The registration format is:

    from gym.envs.registration import register
    # Example for the CartPole environment
    register(
        # unique identifier for the env `name-version`
        id="CartPole-v1",
        # path to the class for creating the env
        # Note: entry_point also accept a class as input (and not only a string)
        entry_point="gym.envs.classic_control:CartPoleEnv",
        # Max number of steps per episode, using a `TimeLimitWrapper`
        max_episode_steps=500,
    )
    

    Custom Policy Networks

    One way to customize the policy network architecture is to pass the policy_kwargs argument when creating the model:

    import gym
    import torch as th
    
    from stable_baselines3 import PPO
    
    # Custom actor (pi) and value function (vf) networks
    # of two layers of size 32 each with Relu activation function
    policy_kwargs = dict(activation_fn=th.nn.ReLU,
                         net_arch=[dict(pi=[32, 32], vf=[32, 32])])
    # Create the agent
    model = PPO("MlpPolicy", "CartPole-v1", policy_kwargs=policy_kwargs, verbose=1)
    # Retrieve the environment
    env = model.get_env()
    # Train the agent
    model.learn(total_timesteps=100000)
    # Save the agent
    model.save("ppo_cartpole")
    
    del model
    # the policy_kwargs are automatically loaded
    model = PPO.load("ppo_cartpole", env=env)
    

    The code above defines actor and critic networks that each have two hidden layers of 32 units with ReLU activations.

    If we want to define the feature-extraction part separately (for example a CNN), we can define a class that inherits from BaseFeaturesExtractor:

    import gym
    import torch as th
    import torch.nn as nn
    
    from stable_baselines3 import PPO
    from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
    
    
    class CustomCNN(BaseFeaturesExtractor):
        """
        :param observation_space: (gym.Space)
        :param features_dim: (int) Number of features extracted.
            This corresponds to the number of unit for the last layer.
        """
    
        def __init__(self, observation_space: gym.spaces.Box, features_dim: int = 256):
            super(CustomCNN, self).__init__(observation_space, features_dim)
            # We assume CxHxW images (channels first)
            # Re-ordering will be done by pre-preprocessing or wrapper
            n_input_channels = observation_space.shape[0]
            self.cnn = nn.Sequential(
                nn.Conv2d(n_input_channels, 32, kernel_size=8, stride=4, padding=0),
                nn.ReLU(),
                nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=0),
                nn.ReLU(),
                nn.Flatten(),
            )
    
            # Compute shape by doing one forward pass
            with th.no_grad():
                n_flatten = self.cnn(
                    th.as_tensor(observation_space.sample()[None]).float()
                ).shape[1]
    
            self.linear = nn.Sequential(nn.Linear(n_flatten, features_dim), nn.ReLU())
    
        def forward(self, observations: th.Tensor) -> th.Tensor:
            return self.linear(self.cnn(observations))
    
    policy_kwargs = dict(
        features_extractor_class=CustomCNN,
        features_extractor_kwargs=dict(features_dim=128),
    )
    model = PPO("CnnPolicy", "BreakoutNoFrameskip-v4", policy_kwargs=policy_kwargs, verbose=1)
    model.learn(1000)
    

    By default the actor and the critic share the feature extractor (a single network), but this can be changed by setting share_features_extractor=False in policy_kwargs, as sketched below.
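
    For example, a minimal sketch of turning feature-extractor sharing off via policy_kwargs; this assumes the installed SB3 version exposes share_features_extractor as a policy keyword argument, as described above:

    from stable_baselines3 import PPO

    # share_features_extractor=False gives the actor and the critic their own,
    # separate feature extractors (assumes the keyword is available, see note above)
    policy_kwargs = dict(share_features_extractor=False)
    model = PPO("MlpPolicy", "CartPole-v1", policy_kwargs=policy_kwargs, verbose=1)
    model.learn(5_000)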

    Multiple Inputs and Dict Observations

    Multiple inputs means having both image and vector observations; each is processed, flattened, and then everything is concatenated together.

    import gym
    import torch as th
    from torch import nn
    
    from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
    
    class CustomCombinedExtractor(BaseFeaturesExtractor):
        def __init__(self, observation_space: gym.spaces.Dict):
            # We do not know features-dim here before going over all the items,
            # so put something dummy for now. PyTorch requires calling
            # nn.Module.__init__ before adding modules
            super(CustomCombinedExtractor, self).__init__(observation_space, features_dim=1)
    
            extractors = {}
    
            total_concat_size = 0
            # We need to know size of the output of this extractor,
            # so go over all the spaces and compute output feature sizes
            for key, subspace in observation_space.spaces.items():
                if key == "image":
                    # We will just downsample one channel of the image by 4x4 and flatten.
                    # Assume the image is single-channel (subspace.shape[0] == 1)
                    extractors[key] = nn.Sequential(nn.MaxPool2d(4), nn.Flatten())
                    total_concat_size += subspace.shape[1] // 4 * subspace.shape[2] // 4
                elif key == "vector":
                    # Run through a simple MLP
                    extractors[key] = nn.Linear(subspace.shape[0], 16)
                    total_concat_size += 16
    
            self.extractors = nn.ModuleDict(extractors)
    
            # Update the features dim manually
            self._features_dim = total_concat_size
    
        def forward(self, observations) -> th.Tensor:
            encoded_tensor_list = []
    
            # self.extractors contain nn.Modules that do all the processing.
            for key, extractor in self.extractors.items():
                encoded_tensor_list.append(extractor(observations[key]))
            # Return a (B, self._features_dim) PyTorch tensor, where B is batch dimension.
            return th.cat(encoded_tensor_list, dim=1)
    

    Customizing the Network Architecture for the Actor and Critic

    Shared networks:

    We can specify how many layers are shared between the policy network and the value network; we only need to add the corresponding key-value pair to the policy_kwargs dictionary.

              obs
               |
             <128>
               |
             <128>
       /               \
    action            value
    

    This network can be specified as: net_arch=[128, 128]

              obs
               |
             <128>
       /               \
    action             <256>
                         |
                       <256>
                         |
                       value
    

    This network can be specified as: net_arch=[128, dict(vf=[256, 256])]

              obs
               |
             <128>
       /               \
     <16>             <256>
       |                |
    action            value
    

    This network can be specified as: net_arch=[128, dict(vf=[256], pi=[16])]
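
    For instance, a minimal sketch of passing the last specification above through policy_kwargs (the number of training steps is illustrative):

    from stable_baselines3 import PPO

    # One shared 128-unit layer, then a 16-unit actor head and a 256-unit critic head,
    # matching the last diagram above
    policy_kwargs = dict(net_arch=[128, dict(vf=[256], pi=[16])])
    model = PPO("MlpPolicy", "CartPole-v1", policy_kwargs=policy_kwargs, verbose=1)
    model.learn(5_000)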

    For more precise control over the network, you need to define a new policy class:

    from typing import Callable, Dict, List, Optional, Tuple, Type, Union
    
    import gym
    import torch as th
    from torch import nn
    
    from stable_baselines3 import PPO
    from stable_baselines3.common.policies import ActorCriticPolicy
    
    
    class CustomNetwork(nn.Module):
        """
        Custom network for policy and value function.
        It receives as input the features extracted by the feature extractor.
    
        :param feature_dim: dimension of the features extracted with the features_extractor (e.g. features from a CNN)
        :param last_layer_dim_pi: (int) number of units for the last layer of the policy network
        :param last_layer_dim_vf: (int) number of units for the last layer of the value network
        """
    
        def __init__(
            self,
            feature_dim: int,
            last_layer_dim_pi: int = 64,
            last_layer_dim_vf: int = 64,
        ):
            super(CustomNetwork, self).__init__()
    
            # IMPORTANT:
            # Save output dimensions, used to create the distributions
            self.latent_dim_pi = last_layer_dim_pi
            self.latent_dim_vf = last_layer_dim_vf
    
            # Policy network
            self.policy_net = nn.Sequential(
                nn.Linear(feature_dim, last_layer_dim_pi), nn.ReLU()
            )
            # Value network
            self.value_net = nn.Sequential(
                nn.Linear(feature_dim, last_layer_dim_vf), nn.ReLU()
            )
    
        def forward(self, features: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
            """
            :return: (th.Tensor, th.Tensor) latent_policy, latent_value of the specified network.
                If all layers are shared, then ``latent_policy == latent_value``
            """
            return self.policy_net(features), self.value_net(features)
    
        def forward_actor(self, features: th.Tensor) -> th.Tensor:
            return self.policy_net(features)
    
        def forward_critic(self, features: th.Tensor) -> th.Tensor:
            return self.value_net(features)
    
    
    class CustomActorCriticPolicy(ActorCriticPolicy):
        def __init__(
            self,
            observation_space: gym.spaces.Space,
            action_space: gym.spaces.Space,
            lr_schedule: Callable[[float], float],
            net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
            activation_fn: Type[nn.Module] = nn.Tanh,
            *args,
            **kwargs,
        ):
    
            super(CustomActorCriticPolicy, self).__init__(
                observation_space,
                action_space,
                lr_schedule,
                net_arch,
                activation_fn,
                # Pass remaining arguments to base class
                *args,
                **kwargs,
            )
            # Disable orthogonal initialization
            self.ortho_init = False
    
        def _build_mlp_extractor(self) -> None:
            self.mlp_extractor = CustomNetwork(self.features_dim)
    
    
    model = PPO(CustomActorCriticPolicy, "CartPole-v1", verbose=1)
    model.learn(5000)
    

    Callback

    Callbacks are functions called at given points during the training process. They provide features such as automatic saving, monitoring, model manipulation, progress bars, and so on.

    The template for a custom callback:

    from stable_baselines3.common.callbacks import BaseCallback
    
    
    class CustomCallback(BaseCallback):
        """
        A custom callback that derives from ``BaseCallback``.
    
        :param verbose: (int) Verbosity level 0: not output 1: info 2: debug
        """
        def __init__(self, verbose=0):
            super(CustomCallback, self).__init__(verbose)
            # Those variables will be accessible in the callback
            # (they are defined in the base class)
            # The RL model
            # self.model = None  # type: BaseAlgorithm
            # An alias for self.model.get_env(), the environment used for training
            # self.training_env = None  # type: Union[gym.Env, VecEnv, None]
            # Number of time the callback was called
            # self.n_calls = 0  # type: int
            # self.num_timesteps = 0  # type: int
            # local and global variables
            # self.locals = None  # type: Dict[str, Any]
            # self.globals = None  # type: Dict[str, Any]
            # The logger object, used to report things in the terminal
            # self.logger = None  # stable_baselines3.common.logger
            # # Sometimes, for event callback, it is useful
            # # to have access to the parent object
            # self.parent = None  # type: Optional[BaseCallback]
    
        def _on_training_start(self) -> None:
            """
            This method is called before the first rollout starts.
            """
            pass
    
        def _on_rollout_start(self) -> None:
            """
            A rollout is the collection of environment interaction
            using the current policy.
            This event is triggered before collecting new samples.
            """
            pass
    
        def _on_step(self) -> bool:
            """
            This method will be called by the model after each call to `env.step()`.
    
            For child callback (of an `EventCallback`), this will be called
            when the event is triggered.
    
            :return: (bool) If the callback returns False, training is aborted early.
            """
            return True
    
        def _on_rollout_end(self) -> None:
            """
            This event is triggered before updating the policy.
            """
            pass
    
        def _on_training_end(self) -> None:
            """
            This event is triggered before exiting the `learn()` method.
            """
            pass
    

    Event callbacks:

    SB3 provides a second type of BaseCallback, called EventCallback. When an event is triggered, the child callback is called. For example:

    class EventCallback(BaseCallback):
        """
        Base class for triggering callback on event.
    
        :param callback: (Optional[BaseCallback]) Callback that will be called
            when an event is triggered.
        :param verbose: (int)
        """
        def __init__(self, callback: Optional[BaseCallback] = None, verbose: int = 0):
            super(EventCallback, self).__init__(verbose=verbose)
            self.callback = callback
            # Give access to the parent
            if callback is not None:
                self.callback.parent = self
        ...
    
        def _on_event(self) -> bool:
            if self.callback is not None:
                return self.callback()
            return True
    

    CheckpointCallback

    CheckpointCallback periodically saves the model during training. Here a checkpoint is saved every 1000 steps under the logs folder, with the file-name prefix rl_model:

    from stable_baselines3 import SAC
    from stable_baselines3.common.callbacks import CheckpointCallback
    # Save a checkpoint every 1000 steps in ./logs/, with file names starting with rl_model
    checkpoint_callback = CheckpointCallback(save_freq=1000, save_path='./logs/',
                                             name_prefix='rl_model')
    
    model = SAC('MlpPolicy', 'Pendulum-v1')
    model.learn(2000, callback=checkpoint_callback)
    

    EvalCallback

    EvalCallback evaluates the agent's performance on a separate test environment. It can save the best model to best_model_save_path and, if log_path is specified, store the evaluation results as evaluations.npz.

    import gym
    
    from stable_baselines3 import SAC
    from stable_baselines3.common.callbacks import EvalCallback
    
    # Separate evaluation env
    eval_env = gym.make('Pendulum-v1')
    # Use deterministic actions for evaluation
    eval_callback = EvalCallback(eval_env, best_model_save_path='./logs/',
                                 log_path='./logs/', eval_freq=500,
                                 deterministic=True, render=False)
    
    model = SAC('MlpPolicy', 'Pendulum-v1')
    model.learn(5000, callback=eval_callback)
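
    The saved evaluations.npz can later be inspected with numpy. A small sketch; the key names below ("timesteps", "results", "ep_lengths") are the ones written by recent SB3 versions and should be checked against your installed release:

    import numpy as np

    data = np.load("./logs/evaluations.npz")
    # "results" has shape (n_evaluations, n_eval_episodes)
    print(data["timesteps"], data["results"].mean(axis=1))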
    

    Callback list

    Using several callbacks at the same time:

    import gym
    
    from stable_baselines3 import SAC
    from stable_baselines3.common.callbacks import CallbackList, CheckpointCallback, EvalCallback
    
    checkpoint_callback = CheckpointCallback(save_freq=1000, save_path='./logs/')
    # Separate evaluation env
    eval_env = gym.make('Pendulum-v1')
    eval_callback = EvalCallback(eval_env, best_model_save_path='./logs/best_model',
                                 log_path='./logs/results', eval_freq=500)
    # Create the callback list
    callback = CallbackList([checkpoint_callback, eval_callback])
    
    model = SAC('MlpPolicy', 'Pendulum-v1')
    # Equivalent to:
    # model.learn(5000, callback=[checkpoint_callback, eval_callback])
    model.learn(5000, callback=callback)
    

    StopTrainingOnRewardThreshold

    Stops the training once the mean episode reward reaches a given threshold. It must be used together with EvalCallback:

    import gym
    
    from stable_baselines3 import SAC
    from stable_baselines3.common.callbacks import EvalCallback, StopTrainingOnRewardThreshold
    
    # Separate evaluation env
    eval_env = gym.make('Pendulum-v1')
    # Stop training when the model reaches the reward threshold
    callback_on_best = StopTrainingOnRewardThreshold(reward_threshold=-200, verbose=1)
    eval_callback = EvalCallback(eval_env, callback_on_new_best=callback_on_best, verbose=1)
    
    model = SAC('MlpPolicy', 'Pendulum-v1', verbose=1)
    # Almost infinite number of timesteps, but the training will stop
    # early as soon as the reward threshold is reached
    model.learn(int(1e10), callback=eval_callback)
    

    EveryNTimesteps

    A type of event callback that calls the given child callback once every n_steps timesteps.

    import gym
    
    from stable_baselines3 import PPO
    from stable_baselines3.common.callbacks import CheckpointCallback, EveryNTimesteps
    
    # this is equivalent to defining CheckpointCallback(save_freq=500)
    # checkpoint_callback will be triggered every 500 steps
    checkpoint_on_event = CheckpointCallback(save_freq=1, save_path='./logs/')
    event_callback = EveryNTimesteps(n_steps=500, callback=checkpoint_on_event)
    
    model = PPO('MlpPolicy', 'Pendulum-v1', verbose=1)
    
    model.learn(int(2e4), callback=event_callback)
    

    StopTrainingOnMaxEpisodes

    Automatically stops the training once the specified maximum number of episodes is reached. When several training environments run in parallel, the effective total is this number multiplied by the number of environments (see the sketch after the example below).

    from stable_baselines3 import A2C
    from stable_baselines3.common.callbacks import StopTrainingOnMaxEpisodes
    
    # Stops training when the model reaches the maximum number of episodes
    callback_max_episodes = StopTrainingOnMaxEpisodes(max_episodes=5, verbose=1)
    
    model = A2C('MlpPolicy', 'Pendulum-v1', verbose=1)
    # Almost infinite number of timesteps, but the training will stop
    # early as soon as the max number of episodes is reached
    model.learn(int(1e10), callback=callback_max_episodes)
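
    As noted above, the limit applies per environment. A sketch with four parallel environments, using the make_vec_env helper from stable_baselines3.common.env_util:

    from stable_baselines3 import A2C
    from stable_baselines3.common.callbacks import StopTrainingOnMaxEpisodes
    from stable_baselines3.common.env_util import make_vec_env

    vec_env = make_vec_env("Pendulum-v1", n_envs=4)
    callback_max_episodes = StopTrainingOnMaxEpisodes(max_episodes=5, verbose=1)

    model = A2C("MlpPolicy", vec_env, verbose=1)
    # With 4 parallel environments, training stops after roughly 5 * 4 = 20 episodes in total
    model.learn(int(1e10), callback=callback_max_episodes)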
    

    Viewing training results with Tensorboard

    Minimal usage

    from stable_baselines3 import A2C
    
    model = A2C('MlpPolicy', 'CartPole-v1', verbose=1, tensorboard_log="./a2c_cartpole_tensorboard/")
    model.learn(total_timesteps=10_000)
    

    To view the Tensorboard logs, run the following command in a terminal:

    tensorboard --logdir ./a2c_cartpole_tensorboard/
    

    Then open the printed URL to browse the results.
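
    When learn() is called several times on the same model, the run name and the step counter can be controlled with the tb_log_name and reset_num_timesteps arguments of learn(). A short sketch:

    from stable_baselines3 import A2C

    model = A2C('MlpPolicy', 'CartPole-v1', verbose=1, tensorboard_log="./a2c_cartpole_tensorboard/")
    model.learn(total_timesteps=10_000, tb_log_name="first_run")
    # Continue training without resetting the timestep counter,
    # so the Tensorboard curves keep increasing instead of starting over
    model.learn(total_timesteps=10_000, tb_log_name="second_run", reset_num_timesteps=False)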

    Logging your own values

    import numpy as np
    
    from stable_baselines3 import SAC
    from stable_baselines3.common.callbacks import BaseCallback
    
    model = SAC("MlpPolicy", "Pendulum-v1", tensorboard_log="/tmp/sac/", verbose=1)
    
    
    class TensorboardCallback(BaseCallback):
        """
        Custom callback for plotting additional values in tensorboard.
        """
    
        def __init__(self, verbose=0):
            super(TensorboardCallback, self).__init__(verbose)
    
        def _on_step(self) -> bool:
            # Log scalar value (here a random variable)
            value = np.random.random()
            self.logger.record('random_value', value)
            return True
    
    
    model.learn(50000, callback=TensorboardCallback())
    

    Logging images

    pillow must be installed before using this feature.

    from stable_baselines3 import SAC
    from stable_baselines3.common.callbacks import BaseCallback
    from stable_baselines3.common.logger import Image
    
    model = SAC("MlpPolicy", "Pendulum-v1", tensorboard_log="/tmp/sac/", verbose=1)
    
    
    class ImageRecorderCallback(BaseCallback):
        def __init__(self, verbose=0):
            super(ImageRecorderCallback, self).__init__(verbose)
    
        def _on_step(self):
            image = self.training_env.render(mode="rgb_array")
            # "HWC" specify the dataformat of the image, here channel last
            # (H for height, W for width, C for channel)
            # See https://pytorch.org/docs/stable/tensorboard.html
            # for supported formats
            self.logger.record("trajectory/image", Image(image, "HWC"), exclude=("stdout", "log", "json", "csv"))
            return True
    
    
    model.learn(50000, callback=ImageRecorderCallback())
    

    Logging figures/plots

    Tensorboard also supports periodic plotting with matplotlib. matplotlib must be installed; otherwise the figures are skipped and a warning is issued.

    import numpy as np
    import matplotlib.pyplot as plt
    
    from stable_baselines3 import SAC
    from stable_baselines3.common.callbacks import BaseCallback
    from stable_baselines3.common.logger import Figure
    
    model = SAC("MlpPolicy", "Pendulum-v1", tensorboard_log="/tmp/sac/", verbose=1)
    
    
    class FigureRecorderCallback(BaseCallback):
        def __init__(self, verbose=0):
            super(FigureRecorderCallback, self).__init__(verbose)
    
        def _on_step(self):
            # Plot values (here a random variable)
            figure = plt.figure()
            figure.add_subplot().plot(np.random.random(3))
            # Close the figure after logging it
            self.logger.record("trajectory/figure", Figure(figure, close=True), exclude=("stdout", "log", "json", "csv"))
            plt.close()
            return True
    
    
    model.learn(50000, callback=FigureRecorderCallback())
    

    Logging videos

    moviepy must be installed before using this feature; otherwise the recording is skipped with a warning.

    from typing import Any, Dict
    
    import gym
    import torch as th
    
    from stable_baselines3 import A2C
    from stable_baselines3.common.callbacks import BaseCallback
    from stable_baselines3.common.evaluation import evaluate_policy
    from stable_baselines3.common.logger import Video
    
    
    class VideoRecorderCallback(BaseCallback):
        def __init__(self, eval_env: gym.Env, render_freq: int, n_eval_episodes: int = 1, deterministic: bool = True):
            """
            Records a video of an agent's trajectory traversing ``eval_env`` and logs it to TensorBoard
    
            :param eval_env: A gym environment from which the trajectory is recorded
            :param render_freq: Render the agent's trajectory every eval_freq call of the callback.
            :param n_eval_episodes: Number of episodes to render
            :param deterministic: Whether to use deterministic or stochastic policy
            """
            super().__init__()
            self._eval_env = eval_env
            self._render_freq = render_freq
            self._n_eval_episodes = n_eval_episodes
            self._deterministic = deterministic
    
        def _on_step(self) -> bool:
            if self.n_calls % self._render_freq == 0:
                screens = []
    
                def grab_screens(_locals: Dict[str, Any], _globals: Dict[str, Any]) -> None:
                    """
                    Renders the environment in its current state, recording the screen in the captured `screens` list
    
                    :param _locals: A dictionary containing all local variables of the callback's scope
                    :param _globals: A dictionary containing all global variables of the callback's scope
                    """
                    screen = self._eval_env.render(mode="rgb_array")
                    # PyTorch uses CxHxW vs HxWxC gym (and tensorflow) image convention
                    screens.append(screen.transpose(2, 0, 1))
    
                evaluate_policy(
                    self.model,
                    self._eval_env,
                    callback=grab_screens,
                    n_eval_episodes=self._n_eval_episodes,
                    deterministic=self._deterministic,
                )
                self.logger.record(
                    "trajectory/video",
                    Video(th.ByteTensor([screens]), fps=40),
                    exclude=("stdout", "log", "json", "csv"),
                )
            return True
    
    
    model = A2C("MlpPolicy", "CartPole-v1", tensorboard_log="runs/", verbose=1)
    video_recorder = VideoRecorderCallback(gym.make("CartPole-v1"), render_freq=5000)
    model.learn(total_timesteps=int(5e4), callback=video_recorder)
    

    Logging arbitrary data directly

    You can also access the underlying SummaryWriter directly:

    from stable_baselines3 import SAC
    from stable_baselines3.common.callbacks import BaseCallback
    from stable_baselines3.common.logger import TensorBoardOutputFormat
    
    
    
    model = SAC("MlpPolicy", "Pendulum-v1", tensorboard_log="/tmp/sac/", verbose=1)
    
    
    class SummaryWriterCallback(BaseCallback):
    
        def _on_training_start(self):
            self._log_freq = 1000  # log every 1000 calls
    
            output_formats = self.logger.output_formats
            # Save a reference to the tensorboard formatter object
            # note: the failure case (no formatter found) is not handled here, it should be done with try/except.
            self.tb_formatter = next(formatter for formatter in output_formats if isinstance(formatter, TensorBoardOutputFormat))
    
        def _on_step(self) -> bool:
            if self.n_calls % self._log_freq == 0:
                self.tb_formatter.writer.add_text("direct_access", "this is a value", self.num_timesteps)
                self.tb_formatter.writer.flush()
            return True
    
    
    model.learn(50000, callback=SummaryWriterCallback())
    

    Imitation learning

    The imitation library implements imitation-learning algorithms on top of SB3, including the following four algorithms:

  • Behavioral Cloning
  • DAgger with synthetic examples
  • Adversarial Inverse Reinforcement Learning (AIRL)
  • Generative Adversarial Imitation Learning (GAIL)

    Installation:

    pip install imitation
    

    Basic usage from the command line:

    # Train PPO agent on cartpole and collect expert demonstrations
    python -m imitation.scripts.expert_demos with fast cartpole log_dir=quickstart
    
    # Train GAIL from demonstrations
    python -m imitation.scripts.train_adversarial with fast gail cartpole rollout_path=quickstart/rollouts/final.pkl
    
    # Train AIRL from demonstrations
    python -m imitation.scripts.train_adversarial with fast airl cartpole rollout_path=quickstart/rollouts/final.pkl
    

    A full example:

    """Trains BC, GAIL and AIRL models on saved CartPole-v1 demonstrations."""
    
    import pathlib
    import pickle
    import tempfile
    
    import seals  # noqa: F401
    import stable_baselines3 as sb3
    
    from imitation.algorithms import bc
    from imitation.algorithms.adversarial import airl, gail
    from imitation.data import rollout
    from imitation.rewards import reward_nets
    from imitation.util import logger, util
    
    # Load pickled test demonstrations.
    with open("tests/testdata/expert_models/cartpole_0/rollouts/final.pkl", "rb") as f:
        # This is a list of `imitation.data.types.Trajectory`, where
        # every instance contains observations and actions for a single expert
        # demonstration.
        trajectories = pickle.load(f)
    
    # Convert List[types.Trajectory] to an instance of `imitation.data.types.Transitions`.
    # This is a more general dataclass containing unordered
    # (observation, actions, next_observation) transitions.
    transitions = rollout.flatten_trajectories(trajectories)
    
    venv = util.make_vec_env("seals/CartPole-v0", n_envs=2)
    
    tempdir = tempfile.TemporaryDirectory(prefix="quickstart")
    tempdir_path = pathlib.Path(tempdir.name)
    print(f"All Tensorboards and logging are being written inside {tempdir_path}/.")
    
    # Train BC on expert data.
    # BC also accepts as `demonstrations` any PyTorch-style DataLoader that iterates over
    # dictionaries containing observations and actions.
    bc_logger = logger.configure(tempdir_path / "BC/")
    bc_trainer = bc.BC(
        observation_space=venv.observation_space,
        action_space=venv.action_space,
        demonstrations=transitions,
        custom_logger=bc_logger,
    )
    bc_trainer.train(n_epochs=1)
    
    # Train GAIL on expert data.
    # GAIL, and AIRL also accept as `demonstrations` any Pytorch-style DataLoader that
    # iterates over dictionaries containing observations, actions, and next_observations.
    gail_logger = logger.configure(tempdir_path / "GAIL/")
    gail_reward_net = reward_nets.BasicRewardNet(
        observation_space=venv.observation_space,
        action_space=venv.action_space,
    )
    gail_trainer = gail.GAIL(
        venv=venv,
        demonstrations=transitions,
        demo_batch_size=32,
        gen_algo=sb3.PPO("MlpPolicy", venv, verbose=1, n_steps=1024),
        reward_net=gail_reward_net,
        custom_logger=gail_logger,
    )
    gail_trainer.train(total_timesteps=2048)
    
    # Train AIRL on expert data.
    airl_logger = logger.configure(tempdir_path / "AIRL/")
    airl_reward_net = reward_nets.BasicShapedRewardNet(
        observation_space=venv.observation_space,
        action_space=venv.action_space,
    )
    airl_trainer = airl.AIRL(
        venv=venv,
        demonstrations=transitions,
        demo_batch_size=32,
        gen_algo=sb3.PPO("MlpPolicy", venv, verbose=1, n_steps=1024),
        reward_net=airl_reward_net,
        custom_logger=airl_logger,
    )
    airl_trainer.train(total_timesteps=2048)
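
    As a rough sanity check, the trained imitation policies can be evaluated with SB3's evaluate_policy helper. A sketch, assuming the BC trainer exposes its policy through the .policy attribute (as in recent imitation versions):

    from stable_baselines3.common.evaluation import evaluate_policy

    # Evaluate the behavioural-cloning policy on the vectorized environment
    mean_reward, std_reward = evaluate_policy(bc_trainer.policy, venv, n_eval_episodes=10)
    print(f"BC mean reward: {mean_reward:.2f} +/- {std_reward:.2f}")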
    

    Exporting a model in ONNX format

    An ONNX model can then be used, for example, with ml-agents.

    The following examples work for continuous actions; for discrete actions the output has to be post-processed (see the note after Example 1).

    Example 1

    from stable_baselines3 import PPO
    import torch
    
    class OnnxablePolicy(torch.nn.Module):
      def __init__(self, extractor, action_net, value_net):
          super(OnnxablePolicy, self).__init__()
          self.extractor = extractor
          self.action_net = action_net
          self.value_net = value_net
    
      def forward(self, observation):
          # NOTE: You may have to process (normalize) observation in the correct
          #       way before using this. See `common.preprocessing.preprocess_obs`
          action_hidden, value_hidden = self.extractor(observation)
          return self.action_net(action_hidden), self.value_net(value_hidden)
    
    # Example: model = PPO("MlpPolicy", "Pendulum-v1")
    model = PPO.load("PathToTrainedModel.zip")
    model.policy.to("cpu")
    onnxable_model = OnnxablePolicy(model.policy.mlp_extractor, model.policy.action_net, model.policy.value_net)
    
    # Infer the observation size from the loaded model (assumes a flat Box observation space)
    observation_size = model.observation_space.shape[0]
    dummy_input = torch.randn(1, observation_size)
    torch.onnx.export(onnxable_model, dummy_input, "my_ppo_model.onnx", opset_version=9)
    
    ##### Load and test with onnx
    
    import onnx
    import onnxruntime as ort
    import numpy as np
    
    onnx_path = "my_ppo_model.onnx"  # path used in torch.onnx.export above
    onnx_model = onnx.load(onnx_path)
    onnx.checker.check_model(onnx_model)
    
    observation = np.zeros((1, observation_size)).astype(np.float32)
    ort_sess = ort.InferenceSession(onnx_path)
    action, value = ort_sess.run(None, {'input.1': observation})
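
    For discrete action spaces (the case mentioned at the start of this section), the first output is a vector of action logits rather than an action. A minimal post-processing sketch, assuming the same session as above but a model trained on a discrete-action environment:

    logits, value = ort_sess.run(None, {'input.1': observation})
    # Greedy (deterministic) action: index of the largest logit
    discrete_action = int(np.argmax(logits, axis=1)[0])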
    

    Example 2

    from stable_baselines3 import SAC
    import torch
    
    class OnnxablePolicy(torch.nn.Module):
      def __init__(self,  actor):
          super(OnnxablePolicy, self).__init__()
    
          # Removing the flatten layer because it can't be onnxed
          self.actor = torch.nn.Sequential(actor.latent_pi, actor.mu)
    
      def forward(self, observation):
          # NOTE: You may have to process (normalize) observation in the correct
          #       way before using this. See `common.preprocessing.preprocess_obs`
          return self.actor(observation)
    
    model = SAC.load("PathToTrainedModel.zip")
    onnxable_model = OnnxablePolicy(model.policy.actor)
    
    # Infer the observation size from the loaded model (assumes a flat Box observation space)
    observation_size = model.observation_space.shape[0]
    dummy_input = torch.randn(1, observation_size)
    # onnxable_model is a plain nn.Module, so move it to CPU directly
    onnxable_model.to("cpu")
    torch.onnx.export(onnxable_model, dummy_input, "my_sac_actor.onnx", opset_version=9)
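
    The exported actor can be checked with onnxruntime as in Example 1. A sketch; applying tanh at the end is an assumption that you want the deterministic squashed action, since SAC uses a squashed Gaussian policy:

    import numpy as np
    import onnxruntime as ort

    observation = np.zeros((1, observation_size)).astype(np.float32)
    ort_sess = ort.InferenceSession("my_sac_actor.onnx")
    action = ort_sess.run(None, {ort_sess.get_inputs()[0].name: observation})[0]
    # The exported network outputs the pre-squash mean, so squash it with tanh
    action = np.tanh(action)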
    

    Postscript

    1. There are many comparable RL libraries, such as Tianshou (天授), RLlib, ElegantRL, and PARL. Each framework has its own strengths, so pick whichever suits you best for training.
    2. For the exact behaviour and the full parameter list of each function, please consult the documentation; this article only covers the most common usage.
    3. SB3 does not provide implementations of newly published papers; to add them you need to read the source code, inherit from some of its base classes, and implement the algorithm within the SB3 framework.

    Source: 微笑小星
