SAC
Soft Actor Critic (SAC): Off-Policy Maximum Entropy Deep Reinforcement Learning with a Stochastic Actor.
SAC is the successor of Soft Q-Learning (SQL) and incorporates the double Q-learning trick from TD3. A key feature of SAC, and a major difference from common RL algorithms, is that it is trained to maximize a trade-off between expected return and entropy, a measure of randomness in the policy.
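For reference, the original SAC paper writes this trade-off as an entropy-augmented objective. Here the temperature α weights the entropy term (it plays the role of the entropy coefficient, ent_coef, in this implementation) and ρ_π denotes the state-action distribution induced by the policy π:

```latex
J(\pi) = \sum_{t=0}^{T} \mathbb{E}_{(s_t, a_t) \sim \rho_\pi}\left[ r(s_t, a_t) + \alpha \, \mathcal{H}\big(\pi(\cdot \mid s_t)\big) \right],
\qquad
\mathcal{H}\big(\pi(\cdot \mid s)\big) = -\mathbb{E}_{a \sim \pi(\cdot \mid s)}\left[ \log \pi(a \mid s) \right]
```

With α = 0 this reduces to the usual expected return; larger α rewards more random (higher-entropy) policies.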
Available Policies
Policy | Description
---|---
MlpPolicy | alias of SACPolicy
CnnPolicy | Policy class (with both actor and critic) for SAC.
Notes
Original paper: https://arxiv.org/abs/1801.01290
OpenAI Spinning Up Guide for SAC: https://spinningup.openai.com/en/latest/algorithms/sac.html
Original Implementation: https://github.com/haarnoja/sac
Blog post on using SAC with real robots: https://bair.berkeley.edu/blog/2018/12/14/sac/
Note
In our implementation, we use an entropy coefficient (as in OpenAI Spinning Up or Facebook Horizon), which is equivalent to the inverse of the reward scale in the original SAC paper. The main reason is that it avoids overly large errors when updating the Q-functions.
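To make the equivalence concrete, here is a minimal pure-Python sketch of an entropy-regularized (soft) Q-learning target for a single transition, in the style of the Spinning Up description, using the minimum of two target critics. The function names are hypothetical and this is not the library's code. The check at the end illustrates that weighting the entropy term by ent_coef = α gives the same target, up to a factor α, as fixing the entropy weight to 1 and scaling rewards (and therefore Q-values) by 1/α.

```python
def soft_q_target(reward, done, next_q1, next_q2, next_log_prob, gamma=0.99, ent_coef=0.2):
    """Entropy-regularized (soft) TD target for a single transition.

    next_q1 and next_q2 are the two target critics evaluated at (s', a'), where a'
    is sampled from the current policy, and next_log_prob is log pi(a' | s').
    """
    # Clipped double-Q: keep the smaller of the two target critic estimates
    next_q = min(next_q1, next_q2)
    # Subtract the entropy term weighted by ent_coef; no bootstrapping after a terminal state
    return reward + gamma * (1.0 - done) * (next_q - ent_coef * next_log_prob)


def reward_scaled_target(reward, done, next_q1, next_q2, next_log_prob, gamma=0.99, reward_scale=5.0):
    """Original-paper style: entropy weight fixed to 1, rewards multiplied by reward_scale."""
    return soft_q_target(reward * reward_scale, done, next_q1, next_q2, next_log_prob, gamma, ent_coef=1.0)


# Same transition, two parameterizations: ent_coef = alpha vs. reward_scale = 1 / alpha
alpha = 0.2
r, d, q1, q2, logp = 1.5, 0.0, 10.0, 9.0, -0.7
y_coef = soft_q_target(r, d, q1, q2, logp, ent_coef=alpha)
# In the reward-scaled problem, the Q-values are scaled by 1 / alpha as well
y_scale = reward_scaled_target(r, d, q1 / alpha, q2 / alpha, logp, reward_scale=1.0 / alpha)
print(y_coef, alpha * y_scale)  # both approximately 10.5486
```

Because the whole soft Bellman equation scales linearly, a large reward scale produces proportionally large Q-values and TD errors; keeping rewards unscaled and shrinking the entropy weight instead avoids that.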
Note
The default policies for SAC differ a bit from the MlpPolicy of other algorithms: they use ReLU instead of tanh activation, to match the original paper.
Can I use?
Recurrent policies: ❌
Multi processing: ❌
Gym spaces:
Space | Action | Observation
---|---|---
Discrete | ❌ | ✔️
Box | ✔️ | ✔️
MultiDiscrete | ❌ | ✔️
MultiBinary | ❌ | ✔️
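Only Box (continuous) action spaces are supported because the SAC actor outputs continuous actions, squashed into [-1, 1] by a tanh, which then have to be mapped to the bounds of the Box. The following pure-Python sketch shows the usual affine mapping, assuming finite bounds; the function names describe the idea and are not meant as the exact library API.

```python
def scale_action(action, low, high):
    """Map an action from the Box bounds [low, high] to [-1, 1], per dimension."""
    return [2.0 * (a - lo) / (hi - lo) - 1.0 for a, lo, hi in zip(action, low, high)]


def unscale_action(scaled_action, low, high):
    """Map an action from [-1, 1] back to the Box bounds [low, high], per dimension."""
    return [lo + 0.5 * (s + 1.0) * (hi - lo) for s, lo, hi in zip(scaled_action, low, high)]


# Example: the Pendulum torque is bounded in [-2, 2]
low, high = [-2.0], [2.0]
print(unscale_action([0.5], low, high))  # [1.0]
```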
Example
import gym
import numpy as np
from stable_baselines3 import SAC
env = gym.make("Pendulum-v0")
model = SAC("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=10000, log_interval=4)
model.save("sac_pendulum")
del model # remove to demonstrate saving and loading
model = SAC.load("sac_pendulum")
obs = env.reset()
while True:
    action, _states = model.predict(obs, deterministic=True)
    obs, reward, done, info = env.step(action)
    env.render()
    if done:
        obs = env.reset()
Results
PyBullet Environments
Results on the PyBullet benchmark (1M steps) using 3 seeds. The complete learning curves are available in the associated issue #48.
Note
Hyperparameters from the gSDE paper were used (as they are tuned for PyBullet envs).
Gaussian means that unstructured Gaussian noise is used for exploration; gSDE (generalized State-Dependent Exploration) is used otherwise.
Environments | SAC (Gaussian) | SAC (gSDE) | TD3 (Gaussian)
---|---|---|---
HalfCheetah | 2757 +/- 53 | 2984 +/- 202 | 2774 +/- 35
Ant | 3146 +/- 35 | 3102 +/- 37 | 3305 +/- 43
Hopper | 2422 +/- 168 | 2262 +/- 1 | 2429 +/- 126
Walker2D | 2184 +/- 54 | 2136 +/- 67 | 2063 +/- 185
How to replicate the results?
Clone the rl-zoo repo:
git clone https://github.com/DLR-RM/rl-baselines3-zoo
cd rl-baselines3-zoo/
Run the benchmark (replace $ENV_ID with each of the environments listed above):
python train.py --algo sac --env $ENV_ID --eval-episodes 10 --eval-freq 10000
Plot the results:
python scripts/all_plots.py -a sac -e HalfCheetah Ant Hopper Walker2D -f logs/ -o logs/sac_results
python scripts/plot_from_file.py -i logs/sac_results.pkl -latex -l SAC
Parameters
class stable_baselines3.sac.SAC(policy, env, learning_rate=0.0003, buffer_size=1000000, learning_starts=100, batch_size=256, tau=0.005, gamma=0.99, train_freq=1, gradient_steps=1, action_noise=None, optimize_memory_usage=False, ent_coef='auto', target_update_interval=1, target_entropy='auto', use_sde=False, sde_sample_freq=-1, use_sde_at_warmup=False, tensorboard_log=None, create_eval_env=False, policy_kwargs=None, verbose=0, seed=None, device='auto', _init_setup_model=True)
Soft Actor-Critic (SAC): Off-Policy Maximum Entropy Deep Reinforcement Learning with a Stochastic Actor. This implementation borrows code from the original implementation (https://github.com/haarnoja/sac), from OpenAI Spinning Up (https://github.com/openai/spinningup), from the softlearning repo (https://github.com/rail-berkeley/softlearning/) and from Stable Baselines (https://github.com/hill-a/stable-baselines).
Paper: https://arxiv.org/abs/1801.01290
Introduction to SAC: https://spinningup.openai.com/en/latest/algorithms/sac.html
Note: we use a double Q target and not a value target, as discussed in https://github.com/hill-a/stable-baselines/issues/270
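The target critics used in this double Q target are slowly moving copies of the critics, updated with the soft ("Polyak") coefficient tau every target_update_interval gradient steps. A minimal sketch of that update, with parameters represented as plain lists of floats (an assumption for illustration; the library operates on PyTorch tensors and the function name here is only illustrative):

```python
def polyak_update(params, target_params, tau):
    """Soft update: target <- tau * online + (1 - tau) * target, element-wise."""
    if not 0.0 <= tau <= 1.0:
        raise ValueError("tau must be in [0, 1]")
    return [tau * p + (1.0 - tau) * tp for p, tp in zip(params, target_params)]


critic = [1.0, -2.0]
target_critic = [0.0, 0.0]
# With the default tau=0.005, the target moves only 0.5% of the way toward the critic per update
target_critic = polyak_update(critic, target_critic, tau=0.005)
```

tau=1 corresponds to a hard copy of the critic; small values make the bootstrapped targets change slowly, which stabilizes the Q-learning updates.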
- Parameters
policy (Union[str, Type[SACPolicy]]) – The policy model to use (MlpPolicy, CnnPolicy, …)
env (Union[Env, VecEnv, str]) – The environment to learn from (if registered in Gym, can be str)
learning_rate (Union[float, Callable[[float], float]]) – learning rate for the Adam optimizer; the same learning rate will be used for all networks (Q-Values, Actor and Value function). It can be a function of the current progress remaining (from 1 to 0)
buffer_size (int) – size of the replay buffer
learning_starts (int) – how many steps of the model to collect transitions for before learning starts
batch_size (int) – Minibatch size for each gradient update
tau (float) – the soft update coefficient ("Polyak update", between 0 and 1)
gamma (float) – the discount factor
train_freq (Union[int, Tuple[int, str]]) – Update the model every train_freq steps. Alternatively, pass a tuple of frequency and unit like (5, "step") or (2, "episode").
gradient_steps (int) – How many gradient steps to do after each rollout (see train_freq). Setting it to -1 means to do as many gradient steps as steps done in the environment during the rollout.
action_noise (Optional[ActionNoise]) – the action noise type (None by default); this can help for hard exploration problems. Cf common.noise for the different action noise types.
optimize_memory_usage (bool) – Enable a memory efficient variant of the replay buffer at a cost of more complexity. See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195
ent_coef (Union[str, float]) – Entropy regularization coefficient (equivalent to the inverse of the reward scale in the original SAC paper), controlling the exploration/exploitation trade-off. Set it to 'auto' to learn it automatically (and 'auto_0.1' to use 0.1 as the initial value).
target_update_interval (int) – update the target network every target_update_interval gradient steps.
target_entropy (Union[str, float]) – target entropy when learning ent_coef (ent_coef = 'auto')
use_sde (bool) – Whether to use generalized State Dependent Exploration (gSDE) instead of action noise exploration (default: False)
sde_sample_freq (int) – Sample a new noise matrix every n steps when using gSDE. Default: -1 (only sample at the beginning of the rollout)
use_sde_at_warmup (bool) – Whether to use gSDE instead of uniform sampling during the warm up phase (before learning starts)
create_eval_env (bool) – Whether to create a second environment that will be used for evaluating the agent periodically. (Only available when passing string for the environment)
policy_kwargs (Optional[Dict[str, Any]]) – additional arguments to be passed to the policy on creation
verbose (int) – the verbosity level: 0 no output, 1 info, 2 debug
seed (Optional[int]) – Seed for the pseudo random generators
device (Union[device, str]) – Device (cpu, cuda, …) on which the code should be run. Setting it to auto, the code will be run on the GPU if possible.
_init_setup_model (bool) – Whether or not to build the network at the creation of the instance
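When ent_coef is 'auto', the entropy coefficient α is learned so that the policy entropy tracks target_entropy. The following pure-Python sketch illustrates the idea under stated assumptions: the helper names are hypothetical, the initial value of 1.0 for a plain 'auto' is an assumption, target_entropy='auto' is taken as -dim(A) (a common heuristic for continuous actions), α is optimized in log space, and plain gradient descent stands in for Adam.

```python
import math


def parse_ent_coef(ent_coef):
    """Return (learnable, initial_value) for settings such as 0.1, 'auto' or 'auto_0.1'."""
    if isinstance(ent_coef, str) and ent_coef.startswith("auto"):
        init_value = 1.0  # assumed initial value when no suffix is given
        if "_" in ent_coef:
            init_value = float(ent_coef.split("_")[1])
        return True, init_value
    return False, float(ent_coef)


def auto_target_entropy(action_shape):
    """Heuristic target entropy: minus the number of action dimensions."""
    return -float(math.prod(action_shape))


def ent_coef_step(log_ent_coef, log_probs, target_entropy, lr=3e-4):
    """One gradient-descent step on loss = -mean(log_alpha * (log_prob + target_entropy))."""
    # d(loss)/d(log_alpha) = -mean(log_prob + target_entropy); log_probs are treated as constants
    grad = -sum(lp + target_entropy for lp in log_probs) / len(log_probs)
    return log_ent_coef - lr * grad


learnable, init_value = parse_ent_coef("auto_0.1")
log_alpha = math.log(init_value)
target = auto_target_entropy((1,))  # e.g. a 1-D Box action space
log_alpha = ent_coef_step(log_alpha, log_probs=[0.3, -0.1], target_entropy=target)
alpha = math.exp(log_alpha)
```

If the policy's entropy (roughly -log_prob) falls below target_entropy, log_prob + target_entropy is positive, the gradient is negative, and the step increases α, pushing the policy to explore more; the reverse happens when the entropy exceeds the target.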
collect_rollouts(env, callback, train_freq, replay_buffer, action_noise=None, learning_starts=0, log_interval=None)
Collect experiences and store them into a ReplayBuffer.
- Parameters
env (VecEnv) – The training environment
callback (BaseCallback) – Callback that will be called at each step (and at the beginning and end of the rollout)
train_freq (TrainFreq) – How much experience to collect by doing rollouts of the current policy. Either TrainFreq(<n>, TrainFrequencyUnit.STEP) or TrainFreq(<n>, TrainFrequencyUnit.EPISODE), with <n> being an integer greater than 0.
action_noise (Optional[ActionNoise]) – Action noise that will be used for exploration. Required for deterministic policies (e.g. TD3). This can also be used in addition to the stochastic policy for SAC.
learning_starts (int) – Number of steps before learning for the warm-up phase.
replay_buffer (ReplayBuffer) –
log_interval (Optional[int]) – Log data every log_interval episodes
- Return type
RolloutReturn
- Returns
get_env()
Returns the current environment (can be None if not defined).
- Return type
Optional[VecEnv]
- Returns
The current environment
get_parameters()
Return the parameters of the agent. This includes parameters from different networks, e.g. critics (value functions) and policies (pi functions).
- Return type
Dict[str, Dict]
- Returns
Mapping from names of the objects to PyTorch state-dicts.
get_vec_normalize_env()
Return the VecNormalize wrapper of the training env if it exists.
- Return type
Optional[VecNormalize]
- Returns
The VecNormalize env.
learn(total_timesteps, callback=None, log_interval=4, eval_env=None, eval_freq=-1, n_eval_episodes=5, tb_log_name='SAC', eval_log_path=None, reset_num_timesteps=True)
Return a trained model.
- Parameters
total_timesteps (int) – The total number of samples (env steps) to train on
callback (Union[None, Callable, List[BaseCallback], BaseCallback]) – callback(s) called at every step with state of the algorithm.
log_interval (int) – The number of episodes before logging.
tb_log_name (str) – the name of the run for TensorBoard logging
eval_env (Union[Env, VecEnv, None]) – Environment that will be used to evaluate the agent
eval_freq (int) – Evaluate the agent every eval_freq timesteps (this may vary a little)
n_eval_episodes (int) – Number of episodes to evaluate the agent
eval_log_path (Optional[str]) – Path to a folder where the evaluations will be saved
reset_num_timesteps (bool) – whether or not to reset the current timestep number (used in logging)
- Return type
- Returns
the trained model
classmethod load(path, env=None, device='auto', custom_objects=None, **kwargs)
Load the model from a zip-file.
- Parameters
path (Union[str, Path, BufferedIOBase]) – path to the file (or a file-like) where to load the agent from
env (Union[Env, VecEnv, None]) – the new environment to run the loaded model on (can be None if you only need prediction from a trained model). It has priority over any saved environment.
device (Union[device, str]) – Device on which the code should run.
custom_objects (Optional[Dict[str, Any]]) – Dictionary of objects to replace upon loading. If a variable is present in this dictionary as a key, it will not be deserialized and the corresponding item will be used instead. Similar to custom_objects in keras.models.load_model. Useful when you have an object in the file that cannot be deserialized.
kwargs – extra arguments to change the model when loading
- Return type
load_replay_buffer
(path)¶ Load a replay buffer from a pickle file.
- Parameters
path (
Union
[str
,Path
,BufferedIOBase
]) – Path to the pickled replay buffer.- Return type
None
predict(observation, state=None, mask=None, deterministic=False)
Get the model's action(s) from an observation.
- Parameters
observation (ndarray) – the input observation
state (Optional[ndarray]) – The last states (can be None, used in recurrent policies)
mask (Optional[ndarray]) – The last masks (can be None, used in recurrent policies)
deterministic (bool) – Whether or not to return deterministic actions.
- Return type
Tuple[ndarray, Optional[ndarray]]
- Returns
the model's action and the next state (used in recurrent policies)
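For a stochastic actor like SAC's, the deterministic flag typically switches between sampling from the policy and returning the squashed mean. A minimal pure-Python sketch, assuming a diagonal Gaussian policy whose samples are squashed into [-1, 1] with tanh (as in the original SAC paper); the function is hypothetical, handles a single observation, and omits batching:

```python
import math
import random


def sample_squashed_gaussian(mean, log_std, deterministic=False, rng=None):
    """Return (action, log_prob) for a diagonal Gaussian squashed into [-1, 1] by tanh."""
    if rng is None:
        rng = random.Random()
    action, log_prob = [], 0.0
    for mu, ls in zip(mean, log_std):
        # deterministic=True skips the noise, so tanh(mean) is returned
        eps = 0.0 if deterministic else rng.gauss(0.0, 1.0)
        u = mu + math.exp(ls) * eps  # pre-squash Gaussian sample
        a = math.tanh(u)
        # Gaussian log-density of u ...
        log_prob += -0.5 * eps * eps - ls - 0.5 * math.log(2.0 * math.pi)
        # ... minus the tanh change-of-variables term (small constant avoids log(0))
        log_prob -= math.log(1.0 - a * a + 1e-6)
        action.append(a)
    return action, log_prob


deterministic_action, _ = sample_squashed_gaussian([0.5, -0.2], [0.0, 0.0], deterministic=True)
stochastic_action, log_prob = sample_squashed_gaussian([0.5, -0.2], [0.0, 0.0], rng=random.Random(0))
```

During training, the sampled actions and their log-probabilities are what enter the entropy terms of the losses; deterministic actions are mainly useful for evaluation.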
save(path, exclude=None, include=None)
Save all the attributes of the object and the model parameters in a zip-file.
- Parameters
path (Union[str, Path, BufferedIOBase]) – path to the file where the RL agent should be saved
exclude (Optional[Iterable[str]]) – names of parameters that should be excluded in addition to the default ones
include (Optional[Iterable[str]]) – names of parameters that might be excluded but should be included anyway
- Return type
None
save_replay_buffer(path)
Save the replay buffer as a pickle file.
- Parameters
path (Union[str, Path, BufferedIOBase]) – Path to the file where the replay buffer should be saved. If path is a str or pathlib.Path, the path is automatically created if necessary.
- Return type
None
set_env(env)
Checks the validity of the environment and, if it is coherent, sets it as the current environment. Furthermore, wraps any non-vectorized env into a vectorized one. Checked parameters: observation_space, action_space.
- Parameters
env (Union[Env, VecEnv]) – The environment for learning a policy
- Return type
None
set_parameters(load_path_or_dict, exact_match=True, device='auto')
Load parameters from a given zip-file or a nested dictionary containing parameters for different modules (see get_parameters).
- Parameters
load_path_or_dict – Location of the saved data (path or file-like, see save), or a nested dictionary containing nn.Module parameters used by the policy. The dictionary maps object names to a state-dictionary returned by torch.nn.Module.state_dict().
exact_match (bool) – If True, the given parameters should include parameters for each module and each of their parameters; otherwise an Exception is raised. If set to False, this can be used to update only specific parameters.
device (Union[device, str]) – Device on which the code should run.
- Return type
None
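The difference between exact_match=True and exact_match=False can be illustrated with plain nested dictionaries standing in for the state-dicts. This is a hypothetical sketch of the behavior described above, not the library's implementation, which works on PyTorch modules and may perform other checks (the unknown-module check here is an assumption of the sketch).

```python
def set_parameters_sketch(current, new, exact_match=True):
    """Update a nested {module_name: {param_name: value}} dict in place."""
    unknown = [module for module in new if module not in current]
    if unknown:
        raise ValueError(f"Unknown modules: {unknown}")
    if exact_match:
        # Every parameter of every module must be provided
        missing = [
            f"{module}.{name}"
            for module, params in current.items()
            for name in params
            if name not in new.get(module, {})
        ]
        if missing:
            raise ValueError(f"Missing parameters: {missing}")
    # Checks pass: overwrite only the provided entries
    for module, params in new.items():
        current[module].update(params)


params = {"policy": {"w": 0.0, "b": 0.0}}
# Partial update is allowed only with exact_match=False
set_parameters_sketch(params, {"policy": {"w": 1.0}}, exact_match=False)
```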
set_random_seed(seed=None)
Set the seed of the pseudo-random generators (python, numpy, pytorch, gym, action_space)
- Parameters
seed (Optional[int]) –
- Return type
None
SAC Policies
stable_baselines3.sac.MlpPolicy
alias of stable_baselines3.sac.policies.SACPolicy
class stable_baselines3.sac.policies.SACPolicy(observation_space, action_space, lr_schedule, net_arch=None, activation_fn=<class 'torch.nn.modules.activation.ReLU'>, use_sde=False, log_std_init=-3, sde_net_arch=None, use_expln=False, clip_mean=2.0, features_extractor_class=<class 'stable_baselines3.common.torch_layers.FlattenExtractor'>, features_extractor_kwargs=None, normalize_images=True, optimizer_class=<class 'torch.optim.adam.Adam'>, optimizer_kwargs=None, n_critics=2, share_features_extractor=True)
Policy class (with both actor and critic) for SAC.
- Parameters
observation_space (Space) – Observation space
action_space (Space) – Action space
lr_schedule (Callable[[float], float]) – Learning rate schedule (could be constant)
net_arch (Union[List[int], Dict[str, List[int]], None]) – The specification of the policy and value networks.
activation_fn (Type[Module]) – Activation function
use_sde (bool) – Whether to use State Dependent Exploration or not
log_std_init (float) – Initial value for the log standard deviation
sde_net_arch (Optional[List[int]]) – Network architecture for extracting features when using gSDE. If None, the latent features from the policy will be used. Pass an empty list to use the states as features.
use_expln (bool) – Use the expln() function instead of exp() when using gSDE to ensure a positive standard deviation (cf paper). It keeps the variance above zero and prevents it from growing too fast. In practice, exp() is usually enough.
clip_mean (float) – Clip the mean output when using gSDE to avoid numerical instability.
features_extractor_class (Type[BaseFeaturesExtractor]) – Features extractor to use.
features_extractor_kwargs (Optional[Dict[str, Any]]) – Keyword arguments to pass to the features extractor.
normalize_images (bool) – Whether to normalize images or not, dividing by 255.0 (True by default)
optimizer_class (Type[Optimizer]) – The optimizer to use, th.optim.Adam by default
optimizer_kwargs (Optional[Dict[str, Any]]) – Additional keyword arguments, excluding the learning rate, to pass to the optimizer
n_critics (int) – Number of critic networks to create.
share_features_extractor (bool) – Whether to share or not the features extractor between the actor and the critic (this saves computation time)
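As an illustration of use_expln, here is a scalar pure-Python sketch of an expln() function with the properties described above: it matches exp() for non-positive inputs and switches to a logarithmic branch above zero, so the resulting standard deviation stays positive without growing exponentially. The exact form (exp(x) for x <= 0, log1p(x) + 1 otherwise) is our assumption of the referenced definition, and how the library applies it to log_std tensors is not shown.

```python
import math


def expln(x):
    """exp(x) for x <= 0, log1p(x) + 1 for x > 0: positive, continuous at 0, slow growth."""
    if x <= 0.0:
        return math.exp(x)
    return math.log1p(x) + 1.0


# Compare with exp() for a large log standard deviation
print(expln(10.0) < math.exp(10.0))  # True
```

Both branches equal 1 and have slope 1 at x = 0, so switching from exp() changes nothing for small log standard deviations.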
forward(obs, deterministic=False)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- Return type
Tensor
reset_noise(batch_size=1)
Sample new weights for the exploration matrix, when using gSDE.
- Parameters
batch_size (int) –
- Return type
None
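What reset_noise resamples can be sketched in pure Python: with gSDE, instead of drawing fresh Gaussian noise at every step, the exploration noise is a linear function of state features z(s) with weights θ_ε that are only resampled occasionally (here following the sde_sample_freq description). This is a simplified illustration with hypothetical helper names; std stands in for the learned standard deviation, and the feature extractor, batching and the combination with the policy mean are left out.

```python
import random


def sample_exploration_matrix(feature_dim, action_dim, std, rng):
    """Draw theta_eps ~ N(0, std^2) with one column of weights per action dimension."""
    return [[rng.gauss(0.0, std) for _ in range(action_dim)] for _ in range(feature_dim)]


def gsde_noise(features, theta):
    """State-dependent noise eps(s) = z(s) @ theta_eps, deterministic for a fixed theta_eps."""
    return [
        sum(z * theta[i][j] for i, z in enumerate(features))
        for j in range(len(theta[0]))
    ]


def should_resample(step, sde_sample_freq):
    """Resample every sde_sample_freq steps; with -1 (the default) only at the rollout start."""
    if sde_sample_freq > 0:
        return step % sde_sample_freq == 0
    return step == 0


rng = random.Random(0)
theta = sample_exploration_matrix(feature_dim=3, action_dim=2, std=0.5, rng=rng)
z = [0.1, -0.4, 2.0]
# Same state and same theta_eps give the same perturbation, hence smoother exploration
noise = gsde_noise(z, theta)
```

By contrast, unstructured Gaussian exploration would draw independent noise at every step, even when revisiting the same state.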
class stable_baselines3.sac.CnnPolicy(observation_space, action_space, lr_schedule, net_arch=None, activation_fn=<class 'torch.nn.modules.activation.ReLU'>, use_sde=False, log_std_init=-3, sde_net_arch=None, use_expln=False, clip_mean=2.0, features_extractor_class=<class 'stable_baselines3.common.torch_layers.NatureCNN'>, features_extractor_kwargs=None, normalize_images=True, optimizer_class=<class 'torch.optim.adam.Adam'>, optimizer_kwargs=None, n_critics=2, share_features_extractor=True)
Policy class (with both actor and critic) for SAC.
- Parameters
observation_space (Space) – Observation space
action_space (Space) – Action space
lr_schedule (Callable[[float], float]) – Learning rate schedule (could be constant)
net_arch (Union[List[int], Dict[str, List[int]], None]) – The specification of the policy and value networks.
activation_fn (Type[Module]) – Activation function
use_sde (bool) – Whether to use State Dependent Exploration or not
log_std_init (float) – Initial value for the log standard deviation
sde_net_arch (Optional[List[int]]) – Network architecture for extracting features when using gSDE. If None, the latent features from the policy will be used. Pass an empty list to use the states as features.
use_expln (bool) – Use the expln() function instead of exp() when using gSDE to ensure a positive standard deviation (cf paper). It keeps the variance above zero and prevents it from growing too fast. In practice, exp() is usually enough.
clip_mean (float) – Clip the mean output when using gSDE to avoid numerical instability.
features_extractor_class (Type[BaseFeaturesExtractor]) – Features extractor to use.
normalize_images (bool) – Whether to normalize images or not, dividing by 255.0 (True by default)
optimizer_class (Type[Optimizer]) – The optimizer to use, th.optim.Adam by default
optimizer_kwargs (Optional[Dict[str, Any]]) – Additional keyword arguments, excluding the learning rate, to pass to the optimizer
n_critics (int) – Number of critic networks to create.
share_features_extractor (bool) – Whether to share or not the features extractor between the actor and the critic (this saves computation time)
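Both policy classes receive an lr_schedule, a callable of the remaining training progress (going from 1 at the start to 0 at the end, as described for the learning_rate parameter of SAC); a constant learning rate is simply a function that ignores its input. A minimal sketch of a linear decay schedule (the name linear_schedule is illustrative):

```python
def linear_schedule(initial_value):
    """Return a schedule mapping progress_remaining (1 -> 0) to a learning rate."""
    def schedule(progress_remaining):
        # progress_remaining decreases from 1 (start of training) to 0 (end of training)
        return progress_remaining * initial_value
    return schedule


lr = linear_schedule(3e-4)
start_lr, mid_lr, end_lr = lr(1.0), lr(0.5), lr(0.0)
```

Since learning_rate accepts a Callable[[float], float], such a function could be passed when creating the model, for instance SAC("MlpPolicy", env, learning_rate=linear_schedule(3e-4)).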