|
|
""" |
|
|
PPO training script for the Continuous Thought Machine (CTM) and an LSTM baseline, adapted from CleanRL's recurrent (LSTM) PPO implementation:
|
|
https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/ppo_atari_lstm.py |
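
Example usage (illustrative: the flag names come from parse_args below, but the
script path / entry point may differ in your checkout):

    python tasks/rl/train.py --env_id Acrobot-v1 --model_type ctm --run_name acrobot_ctm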
|
|
""" |
|
|
import os |
|
|
import time |
|
|
import multiprocessing |
|
|
import gymnasium as gym |
|
|
import numpy as np |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.optim as optim |
|
|
from torch.distributions.categorical import Categorical |
|
|
from torch.utils.tensorboard import SummaryWriter |
|
|
from gymnasium.wrappers import NormalizeReward |
|
|
import minigrid |
|
|
from minigrid.wrappers import ImgObsWrapper |
|
|
import argparse |
|
|
from tqdm import tqdm |
|
|
|
|
|
from models.ctm_rl import ContinuousThoughtMachineRL |
|
|
from models.lstm_rl import LSTMBaseline |
|
|
from utils.housekeeping import set_seed |
|
|
from tasks.rl.envs import MaskVelocityWrapper |
|
|
from tasks.rl.utils import combine_tracking_data |
|
|
from tasks.rl.plotting import make_rl_gif |
|
|
from tasks.image_classification.plotting import plot_neural_dynamics |
|
|
|
|
|
|
|
|
def parse_args(): |
|
|
parser = argparse.ArgumentParser(description="Train CTM with RL.") |
|
|
|
|
|
|
|
|
    parser.add_argument('--model_type', type=str, default="ctm", choices=['ctm', 'lstm'], help='Model type: continuous thought machine (ctm) or LSTM baseline (lstm).')
|
|
parser.add_argument('--d_model', type=int, default=128, help='Dimension of the model.') |
|
|
parser.add_argument('--d_input', type=int, default=64, help='Dimension of the input projection.') |
|
|
parser.add_argument('--synapse_depth', type=int, default=1, help='Depth of U-NET model for synapse. 1=linear.') |
|
|
parser.add_argument('--n_synch_out', type=int, default=16, help='Number of neurons for output sync.') |
|
|
parser.add_argument('--neuron_select_type', type=str, default='random', choices=['first-last', 'random', 'random-pairing'], help='Protocol for selecting neuron subset.') |
|
|
parser.add_argument('--iterations', type=int, default=1, help='Number of internal ticks.') |
|
|
parser.add_argument('--memory_length', type=int, default=5, help='Length of pre-activation history for NLMs.') |
|
|
parser.add_argument('--deep_memory', action=argparse.BooleanOptionalAction, default=True, help='Use deep NLMs.') |
|
|
parser.add_argument('--memory_hidden_dims', type=int, default=2, help='Hidden dimensions for deep NLMs.') |
|
|
parser.add_argument('--dropout', type=float, default=0.0, help='Dropout rate.') |
|
|
parser.add_argument('--do_normalisation', action=argparse.BooleanOptionalAction, default=False, help='Apply normalization in NLMs.') |
|
|
parser.add_argument('--continuous_state_trace', action=argparse.BooleanOptionalAction, default=True, help='Flag to carry over state trace between environment steps.') |
|
|
|
|
|
|
|
|
parser.add_argument('--env_id', type=str, default="Acrobot-v1", help='Environment ID.') |
|
|
parser.add_argument('--mask_velocity', action=argparse.BooleanOptionalAction, default=True, help='Mask the velocity components of the observation.') |
|
|
parser.add_argument('--max_environment_steps', type=int, default=500, help='The maximum number of environment steps.') |
|
|
|
|
|
|
|
|
parser.add_argument('--num_steps', type=int, default=100, help='The number of environment steps to run in each environment per policy rollout.') |
|
|
parser.add_argument('--total_timesteps', type=int, default=1_000_000, help='The combined total of all environment steps (across all batches).') |
|
|
parser.add_argument('--num_envs', type=int, default=8, help='The number of parallel game environments.') |
|
|
parser.add_argument('--anneal_lr', action=argparse.BooleanOptionalAction, default=True, help='Use learning rate annealing.') |
|
|
parser.add_argument('--discount_gamma', type=float, default=0.99, help='The discount factor gamma.') |
|
|
parser.add_argument('--gae_lambda', type=float, default=0.95, help='The lambda for the Generalized Advantage Estimation (GAE).') |
|
|
parser.add_argument('--num_minibatches', type=int, default=4, help='The number of mini-batches.') |
|
|
parser.add_argument('--update_epochs', type=int, default=1, help='The number of epochs to update the policy.') |
|
|
parser.add_argument('--norm_adv', action=argparse.BooleanOptionalAction, default=True, help='Toggle advantages normalization.') |
|
|
parser.add_argument('--clip_coef', type=float, default=0.1, help='The surrogate clipping coefficient.') |
|
|
parser.add_argument('--clip_vloss', action=argparse.BooleanOptionalAction, default=False, help='Use clipped loss for the value function (as per the PPO paper).') |
|
|
parser.add_argument('--ent_coef', type=float, default=0.1, help='Entropy coefficient.') |
|
|
parser.add_argument('--vf_coef', type=float, default=0.25, help='Value function coefficient.') |
|
|
parser.add_argument('--max_grad_norm', type=float, default=0.5, help='The maximum norm for gradient clipping.') |
|
|
parser.add_argument('--target_kl', type=float, default=None, help='Target KL divergence threshold.') |
|
|
parser.add_argument('--lr', type=float, default=5e-4, help='Learning rate.') |
|
|
parser.add_argument('--num_validation_envs', type=int, default=1, help='Number of environments to evaluate with during training.') |
|
|
|
|
|
|
|
|
parser.add_argument('--log_dir', type=str, default='logs/rl/acrobot', help='Directory for logging.') |
|
|
parser.add_argument('--tb_log_dir', type=str, default='logs/runs', help='Directory for tensorboard logging.') |
|
|
parser.add_argument('--run_name', type=str, default='default_run', help='Name of the run for logging and tracking.') |
|
|
parser.add_argument('--save_every', type=int, default=100, help='Save checkpoint frequency.') |
|
|
parser.add_argument('--seed', type=int, default=0, help='Random seed.') |
|
|
parser.add_argument('--reload', action=argparse.BooleanOptionalAction, default=True, help='Reload checkpoint from log_dir?') |
|
|
parser.add_argument('--track_every', type=int, default=1000, help='Track metrics frequency.') |
|
|
parser.add_argument('--device', type=int, nargs='+', default=[-1], help='GPU(s) or -1 for CPU.') |
|
|
|
|
|
args = parser.parse_args() |
|
|
return args |
|
|
|
|
|
def make_env_classic_control(env_id, max_environment_steps, mask_velocity=True, render_mode=None): |
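    # Factory returning a thunk so gym.vector.SyncVectorEnv can build each classic-control env lazily:
    # optional velocity masking, reward normalisation, a hard time limit, and episode-statistics recording.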
|
|
def thunk(): |
|
|
env = gym.make(env_id, render_mode=render_mode) |
|
|
if mask_velocity: env = MaskVelocityWrapper(env) |
|
|
env = NormalizeReward(env, gamma=0.99, epsilon=1e-8) |
|
|
env = gym.wrappers.TimeLimit(env, max_episode_steps=max_environment_steps) |
|
|
env = gym.wrappers.RecordEpisodeStatistics(env) |
|
|
return env |
|
|
return thunk |
|
|
|
|
|
def make_env_minigrid(env_id, max_environment_steps): |
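    # MiniGrid variant of the env factory: image-only observations (ImgObsWrapper), reward normalisation,
    # a time limit, and episode-statistics recording.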
|
|
def thunk(): |
|
|
env = gym.make(env_id, max_steps=max_environment_steps, render_mode="rgb_array") |
|
|
env = ImgObsWrapper(env) |
|
|
env = NormalizeReward(env, gamma=0.99, epsilon=1e-8) |
|
|
env = gym.wrappers.TimeLimit(env, max_episode_steps=max_environment_steps) |
|
|
env = gym.wrappers.RecordEpisodeStatistics(env) |
|
|
return env |
|
|
return thunk |
|
|
|
|
|
def layer_init(layer, std=np.sqrt(2), bias_const=0.0): |
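    # Orthogonal weight initialisation with a constant bias, following CleanRL's PPO utilities.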
|
|
torch.nn.init.orthogonal_(layer.weight, std) |
|
|
torch.nn.init.constant_(layer.bias, bias_const) |
|
|
return layer |
|
|
|
|
|
class Agent(nn.Module): |
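    # Recurrent PPO agent: a CTM or LSTM core produces a latent vector (the CTM's output synchronisation
    # or the LSTM's hidden state) that feeds separate actor and critic MLP heads.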
|
|
def __init__(self, size_action_space, args, device): |
|
|
super().__init__() |
|
|
|
|
|
        self.continuous_state_trace = args.continuous_state_trace
|
|
self.device = device |
|
|
self.model_type = args.model_type |
|
|
|
|
|
if "MiniGrid" in args.env_id: |
|
|
backbone_type='navigation-backbone' |
|
|
else: |
|
|
backbone_type='classic-control-backbone' |
|
|
|
|
|
if args.model_type == "ctm": |
|
|
self.recurrent_model = ContinuousThoughtMachineRL( |
|
|
iterations=args.iterations, |
|
|
d_model=args.d_model, |
|
|
d_input=args.d_input, |
|
|
n_synch_out=args.n_synch_out, |
|
|
synapse_depth=args.synapse_depth, |
|
|
memory_length=args.memory_length, |
|
|
deep_nlms=args.deep_memory, |
|
|
memory_hidden_dims=args.memory_hidden_dims, |
|
|
do_layernorm_nlm=args.do_normalisation, |
|
|
backbone_type=backbone_type, |
|
|
prediction_reshaper=[-1], |
|
|
dropout=args.dropout, |
|
|
neuron_select_type=args.neuron_select_type, |
|
|
) |
|
|
actor_input_dim = critic_input_dim = self.recurrent_model.synch_representation_size_out |
|
|
else: |
|
|
self.recurrent_model = LSTMBaseline( |
|
|
iterations=args.iterations, |
|
|
d_model=args.d_model, |
|
|
d_input=args.d_input, |
|
|
backbone_type=backbone_type, |
|
|
) |
|
|
actor_input_dim = critic_input_dim = args.d_model |
|
|
|
|
|
self.actor = nn.Sequential( |
|
|
layer_init(nn.Linear(actor_input_dim, 64), std=1), |
|
|
nn.ReLU(), |
|
|
layer_init(nn.Linear(64, 64), std=1), |
|
|
nn.ReLU(), |
|
|
layer_init(nn.Linear(64, size_action_space), std=1) |
|
|
) |
|
|
|
|
|
self.critic = nn.Sequential( |
|
|
layer_init(nn.Linear(critic_input_dim, 64), std=1), |
|
|
nn.ReLU(), |
|
|
layer_init(nn.Linear(64, 64), std=1), |
|
|
nn.ReLU(), |
|
|
layer_init(nn.Linear(64, 1), std=1) |
|
|
) |
|
|
|
|
|
def get_initial_state(self, num_envs): |
|
|
if self.model_type == "ctm": |
|
|
return self.get_initial_ctm_state(num_envs) |
|
|
elif self.model_type == "lstm": |
|
|
return self.get_initial_lstm_state(num_envs) |
|
|
else: |
|
|
raise ValueError("Model type not supported.") |
|
|
|
|
|
def get_initial_ctm_state(self, num_envs): |
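        # Tile the model's learned start traces across the batch of environments.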
|
|
initial_state_trace = torch.repeat_interleave(self.recurrent_model.start_trace.unsqueeze(0), num_envs, 0) |
|
|
initial_activated_state_trace = torch.repeat_interleave(self.recurrent_model.start_activated_trace.unsqueeze(0), num_envs, 0) |
|
|
return initial_state_trace, initial_activated_state_trace |
|
|
|
|
|
def get_initial_lstm_state(self, num_envs): |
|
|
initial_hidden_state = torch.repeat_interleave(self.recurrent_model.start_hidden_state.unsqueeze(0), num_envs, 0) |
|
|
initial_cell_state = torch.repeat_interleave(self.recurrent_model.start_cell_state.unsqueeze(0), num_envs, 0) |
|
|
return initial_hidden_state, initial_cell_state |
|
|
|
|
|
def _get_hidden_states(self, state, done, num_envs): |
|
|
if self.model_type == "ctm": |
|
|
return self._get_ctm_hidden_states(state, done, num_envs) |
|
|
elif self.model_type == "lstm": |
|
|
return self._get_lstm_hidden_states(state, done, num_envs) |
|
|
else: |
|
|
raise ValueError("Model type not supported.") |
|
|
|
|
|
|
|
|
def _get_lstm_hidden_states(self, lstm_state, done, num_envs): |
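        # Reset the LSTM state at episode boundaries: keep the carried state where done == 0 and
        # substitute the learned initial state where done == 1.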
|
|
initial_hidden_state, initial_cell_state = self.get_initial_lstm_state(num_envs) |
|
|
|
|
|
masked_previous_hidden_state = (1.0 - done).view(-1, 1) * lstm_state[0] |
|
|
        masked_previous_cell_state = (1.0 - done).view(-1, 1) * lstm_state[1]
|
|
masked_initial_hidden_state = done.view(-1, 1) * initial_hidden_state |
|
|
masked_initial_cell_state = done.view(-1, 1) * initial_cell_state |
|
|
        return (masked_previous_hidden_state + masked_initial_hidden_state), (masked_previous_cell_state + masked_initial_cell_state)
|
|
|
|
|
|
|
|
def _get_ctm_hidden_states(self, ctm_state, done, num_envs): |
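        # Same reset-at-episode-boundary logic for the CTM traces; if continuous_state_trace is disabled,
        # the traces are re-initialised at every environment step instead.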
|
|
initial_state_trace, initial_activated_state_trace = self.get_initial_ctm_state(num_envs) |
|
|
        if self.continuous_state_trace:
|
|
masked_previous_state_trace = (1.0 - done).view(-1, 1, 1) * ctm_state[0] |
|
|
masked_previous_activated_state_trace = (1.0 - done).view(-1, 1, 1) * ctm_state[1] |
|
|
masked_initial_state_trace = done.view(-1, 1, 1) * initial_state_trace |
|
|
masked_initial_activated_state_trace = done.view(-1, 1, 1) * initial_activated_state_trace |
|
|
return (masked_previous_state_trace + masked_initial_state_trace), (masked_previous_activated_state_trace + masked_initial_activated_state_trace) |
|
|
else: |
|
|
return (initial_state_trace, initial_activated_state_trace) |
|
|
|
|
|
def get_states(self, x, ctm_state, done, track=False): |
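        # Unroll the recurrent model over a chunk of environment steps: x arrives flattened as
        # (chunk_length * num_envs, ...) and is reshaped back into per-step slices, with the recurrent
        # state masked/reset wherever an episode ended.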
|
|
num_envs = ctm_state[0].shape[0] |
|
|
|
|
|
if len(x.shape) == 4: |
|
|
_, C, H, W = x.shape |
|
|
xs = x.reshape((-1, num_envs, C, H, W)) |
|
|
elif len(x.shape) == 2: |
|
|
_, C = x.shape |
|
|
xs = x.reshape((-1, num_envs, C)) |
|
|
else: |
|
|
raise ValueError("Input shape not supported.") |
|
|
|
|
|
done = done.reshape((-1, num_envs)) |
|
|
new_hidden = [] |
|
|
        for x_t, d in zip(xs, done):
|
|
if not track: |
|
|
                synchronisation, ctm_state = self.recurrent_model(x_t, self._get_hidden_states(ctm_state, d, num_envs))
|
|
tracking_data = None |
|
|
new_hidden += [synchronisation] |
|
|
else: |
|
|
                synchronisation, ctm_state, pre_activations, post_activations = self.recurrent_model(x_t, self._get_hidden_states(ctm_state, d, num_envs), track=True)
|
|
tracking_data = { |
|
|
'pre_activations': pre_activations, |
|
|
'post_activations': post_activations, |
|
|
'synchronisation': synchronisation.detach().cpu().numpy(), |
|
|
} |
|
|
new_hidden += [synchronisation] |
|
|
|
|
|
return torch.cat(new_hidden), ctm_state, tracking_data |
|
|
|
|
|
def get_value(self, x, ctm_state, done): |
|
|
hidden, _, _ = self.get_states(x, ctm_state, done) |
|
|
return self.critic(hidden) |
|
|
|
|
|
def get_action_and_value(self, x, ctm_state, done, action=None, track=False): |
|
|
hidden, ctm_state, tracking_data = self.get_states(x, ctm_state, done, track=track) |
|
|
action_logits = self.actor(hidden) |
|
|
action_probs = Categorical(logits=action_logits) |
|
|
|
|
|
if action is None: |
|
|
action = action_probs.sample() |
|
|
|
|
|
value = self.critic(hidden) |
|
|
|
|
|
return action, action_probs.log_prob(action), action_probs.entropy(), value, ctm_state, tracking_data, action_logits, action_probs.probs |
|
|
|
|
|
def save_model(agent, optimizer, global_step, training_iteration, episode_rewards_tracking, episode_lengths_tracking, global_steps_tracking, args, save_path): |
|
|
torch.save({ |
|
|
'model_state_dict': agent.state_dict(), |
|
|
'optimizer_state_dict': optimizer.state_dict(), |
|
|
'global_step': global_step, |
|
|
'training_iteration': training_iteration, |
|
|
'episode_rewards_tracking': episode_rewards_tracking, |
|
|
'episode_lengths_tracking': episode_lengths_tracking, |
|
|
'global_steps_tracking': global_steps_tracking, |
|
|
'args': args |
|
|
}, save_path) |
|
|
|
|
|
def load_model(agent, optimizer, checkpoint_path, device): |
|
|
checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False) |
|
|
agent.load_state_dict(checkpoint['model_state_dict']) |
|
|
|
|
|
if optimizer is not None and 'optimizer_state_dict' in checkpoint: |
|
|
optimizer.load_state_dict(checkpoint['optimizer_state_dict']) |
|
|
|
|
|
global_step = checkpoint.get('global_step', 0) |
|
|
training_iteration = checkpoint.get('training_iteration', 0) |
|
|
episode_rewards_tracking = checkpoint.get('episode_rewards_tracking', []) |
|
|
episode_lengths_tracking = checkpoint.get('episode_lengths_tracking', []) |
|
|
global_steps_tracking = checkpoint.get('global_steps_tracking', []) |
|
|
model_args = checkpoint.get('args', None) |
|
|
|
|
|
print(f"Loaded checkpoint from {checkpoint_path} at iteration {training_iteration} and step {global_step}") |
|
|
return global_step, training_iteration, episode_rewards_tracking, episode_lengths_tracking, global_steps_tracking, model_args |
|
|
|
|
|
|
|
|
def plot_activations(agent, device, args): |
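    # Run evaluation episodes with tracking enabled, plot the neural dynamics, and spawn a subprocess
    # that renders a GIF of the policy.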
|
|
agent.eval() |
|
|
with torch.no_grad(): |
|
|
for idx in range(args.num_validation_envs): |
|
|
if args.env_id in ("CartPole-v1", "Acrobot-v1"): |
|
|
eval_env = make_env_classic_control(args.env_id, args.max_environment_steps, mask_velocity=args.mask_velocity, render_mode="rgb_array")() |
|
|
elif "MiniGrid" in args.env_id: |
|
|
eval_env = make_env_minigrid(args.env_id, args.max_environment_steps)() |
|
|
else: |
|
|
raise NotImplementedError(f"Environment {args.env_id} not supported.") |
|
|
|
|
|
eval_next_obs, _ = eval_env.reset() |
|
|
episode_reward = 0 |
|
|
eval_next_done = False |
|
|
eval_state = agent.get_initial_state(1) |
|
|
tracking_data_by_world_step = [] |
|
|
for _ in range(args.max_environment_steps): |
|
|
with torch.no_grad(): |
|
|
action, _, _, value, eval_state, tracking_data, action_logits, action_probs = agent.get_action_and_value( |
|
|
torch.Tensor(eval_next_obs).to(device).unsqueeze(0), |
|
|
eval_state, |
|
|
torch.Tensor([eval_next_done]).to(device), |
|
|
track=True |
|
|
) |
|
|
eval_next_obs, reward, termination, truncation, _ = eval_env.step(action.cpu().numpy()[0]) |
|
|
eval_next_done = termination or truncation |
|
|
|
|
|
tracking_data['actions'] = np.tile(action.detach().cpu().numpy(), (args.iterations)) |
|
|
tracking_data['values'] = np.tile(value.squeeze(-1).detach().cpu().numpy(), (args.iterations)) |
|
|
tracking_data['action_logits'] = np.tile(action_logits.detach().cpu().numpy(), (args.iterations, 1)) |
|
|
tracking_data['action_probs'] = np.tile(action_probs.detach().cpu().numpy(), (args.iterations, 1)) |
|
|
tracking_data['rewards'] = np.tile(np.array(reward), (args.iterations)) |
|
|
tracking_data['inputs'] = np.tile(np.array(eval_env.render()), (args.iterations, 1, 1, 1)) |
|
|
|
|
|
tracking_data_by_world_step.append(tracking_data) |
|
|
|
|
|
episode_reward += reward |
|
|
|
|
|
if eval_next_done: |
|
|
break |
|
|
|
|
|
eval_env.close() |
|
|
|
|
|
combined_tracking_data = combine_tracking_data(tracking_data_by_world_step) |
|
|
|
|
|
n_to_plot = 80 if combined_tracking_data['post_activations'].shape[-1] < 100 else 100 |
|
|
plot_neural_dynamics(combined_tracking_data['post_activations'], n_to_plot, args.log_dir, axis_snap=True) |
|
|
|
|
|
process = multiprocessing.Process( |
|
|
target=make_rl_gif, |
|
|
args=( |
|
|
combined_tracking_data['action_logits'], |
|
|
combined_tracking_data['action_probs'], |
|
|
combined_tracking_data['actions'], |
|
|
combined_tracking_data['values'], |
|
|
combined_tracking_data['rewards'], |
|
|
combined_tracking_data['pre_activations'], |
|
|
combined_tracking_data['post_activations'], |
|
|
combined_tracking_data['inputs'], |
|
|
f"{args.log_dir}/eval_output_val_{idx}.gif" |
|
|
) |
|
|
) |
|
|
process.start() |
|
|
agent.train() |
|
|
|
|
|
def initialize_args(): |
|
|
args = parse_args() |
|
|
args = initialise_dynamic_args(args) |
|
|
return args |
|
|
|
|
|
def initialise_dynamic_args(args): |
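    # Derived quantities: rollout batch size, minibatch size, and the number of training iterations
    # needed to reach total_timesteps.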
|
|
args.batch_size = int(args.num_envs * args.num_steps) |
|
|
args.minibatch_size = int(args.batch_size // args.num_minibatches) |
|
|
args.num_training_iterations = args.total_timesteps // args.batch_size |
|
|
return args |
|
|
|
|
|
if __name__ == "__main__": |
|
|
|
|
|
args = initialize_args() |
|
|
|
|
|
set_seed(args.seed) |
|
|
os.makedirs(args.log_dir, exist_ok=True) |
|
|
|
|
|
print(f"Tensorboard logging to {args.tb_log_dir}/{args.run_name}") |
|
|
|
|
|
writer = SummaryWriter(f"{args.tb_log_dir}/{args.run_name}") |
|
|
writer.add_text( |
|
|
"hyperparameters", |
|
|
"|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), |
|
|
) |
|
|
|
|
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
|
|
|
|
|
|
|
|
if args.env_id in ("CartPole-v1", "Acrobot-v1"): |
|
|
envs = gym.vector.SyncVectorEnv([make_env_classic_control(args.env_id, args.max_environment_steps, args.mask_velocity) for _ in range(args.num_envs)]) |
|
|
elif "MiniGrid" in args.env_id: |
|
|
        envs = gym.vector.SyncVectorEnv([make_env_minigrid(args.env_id, args.max_environment_steps) for _ in range(args.num_envs)])
    else:
        raise NotImplementedError(f"Environment {args.env_id} not supported.")
|
|
|
|
|
agent = Agent(envs.single_action_space.n, args, device).to(device) |
|
|
plot_activations(agent, device, args) |
|
|
print(f'Total params: {sum(p.numel() for p in agent.parameters())}') |
|
|
optimizer = optim.Adam(agent.parameters(), lr=args.lr, eps=1e-5) |
|
|
|
|
|
checkpoint_path = f"{args.log_dir}/checkpoint.pt" |
|
|
if os.path.exists(checkpoint_path) and args.reload: |
|
|
global_step, training_iteration, episode_rewards_tracking, episode_lengths_tracking, global_steps_tracking, _ = load_model(agent, optimizer, checkpoint_path, device) |
|
|
else: |
|
|
global_step, training_iteration, episode_rewards_tracking, episode_lengths_tracking, global_steps_tracking = 0, 0, [], [], [] |
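    # Rollout buffers: one slot per (step, env).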
|
|
|
|
|
|
|
|
obs = torch.zeros((args.num_steps, args.num_envs) + envs.single_observation_space.shape).to(device) |
|
|
actions = torch.zeros((args.num_steps, args.num_envs) + envs.single_action_space.shape).to(device) |
|
|
logprobs = torch.zeros((args.num_steps, args.num_envs)).to(device) |
|
|
rewards = torch.zeros((args.num_steps, args.num_envs)).to(device) |
|
|
dones = torch.zeros((args.num_steps, args.num_envs)).to(device) |
|
|
values = torch.zeros((args.num_steps, args.num_envs)).to(device) |
|
|
|
|
|
start_time = time.time() |
|
|
next_obs, _ = envs.reset(seed=args.seed) |
|
|
next_obs = torch.Tensor(next_obs).to(device) |
|
|
next_done = torch.zeros(args.num_envs).to(device) |
|
|
next_state = agent.get_initial_state(args.num_envs) |
|
|
|
|
|
progress_bar = tqdm(range(training_iteration, args.num_training_iterations + 1), total=args.num_training_iterations + 1, initial=training_iteration, desc="Training", dynamic_ncols=True,) |
|
|
|
|
|
for training_iteration in progress_bar: |
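        # Snapshot the recurrent state at the start of the rollout; the PPO update replays the rollout from this state.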
|
|
initial_state = (next_state[0].clone(), next_state[1].clone()) |
|
|
|
|
|
if args.anneal_lr: |
|
|
frac = 1.0 - (training_iteration - 1.0) / args.num_training_iterations |
|
|
lrnow = frac * args.lr |
|
|
optimizer.param_groups[0]["lr"] = lrnow |
|
|
|
|
|
for step in range(0, args.num_steps): |
|
|
next_obs = torch.Tensor(next_obs).to(device) |
|
|
global_step += args.num_envs |
|
|
obs[step] = next_obs |
|
|
dones[step] = next_done |
|
|
|
|
|
with torch.no_grad(): |
|
|
action, logprob, _, value, next_state, _, _, _ = agent.get_action_and_value(next_obs, next_state, next_done) |
|
|
values[step] = value.flatten() |
|
|
actions[step] = action |
|
|
logprobs[step] = logprob |
|
|
next_obs, reward, terminations, truncations, infos = envs.step(action.cpu().numpy()) |
|
|
next_done = np.logical_or(terminations, truncations) |
|
|
rewards[step] = torch.tensor(reward).to(device).view(-1) |
|
|
next_obs, next_done = torch.Tensor(next_obs).to(device), torch.Tensor(next_done).to(device) |
|
|
|
|
|
if "final_info" in infos: |
|
|
for info in infos["final_info"]: |
|
|
                    if info and "episode" in info:
|
|
progress_bar.set_postfix({ |
|
|
"step": global_step, |
|
|
"train. iter": training_iteration, |
|
|
"return": round(info[0]["episode"]["r"], 2), |
|
|
"length": info[0]["episode"]["l"] |
|
|
}) |
|
|
writer.add_scalar("charts/episodic_return", info[0]["episode"]["r"], global_step) |
|
|
writer.add_scalar("charts/episodic_length", info[0]["episode"]["l"], global_step) |
|
|
                        episode_rewards_tracking.append(info["episode"]["r"])
|
|
                        episode_lengths_tracking.append(info["episode"]["l"])
|
|
global_steps_tracking.append(global_step) |
|
|
|
|
|
elif "episode" in infos: |
|
|
if infos["episode"]: |
|
|
episode_rewards = infos["episode"]["r"] |
|
|
episode_lengths = infos["episode"]["l"] |
|
|
completed_episodes = infos["episode"]["_r"] |
|
|
for env_idx in range(len(completed_episodes)): |
|
|
if completed_episodes[env_idx]: |
|
|
progress_bar.set_postfix({ |
|
|
"step": global_step, |
|
|
"train. iter": training_iteration, |
|
|
"return": round(episode_rewards[env_idx], 2), |
|
|
"length": episode_lengths[env_idx] |
|
|
}) |
|
|
writer.add_scalar("charts/episodic_return", episode_rewards[env_idx], global_step) |
|
|
writer.add_scalar("charts/episodic_length", episode_lengths[env_idx], global_step) |
|
|
episode_rewards_tracking.append(episode_rewards[env_idx]) |
|
|
episode_lengths_tracking.append(episode_lengths[env_idx]) |
|
|
global_steps_tracking.append(global_step) |
|
|
|
|
|
with torch.no_grad(): |
|
|
next_value = agent.get_value( |
|
|
next_obs, |
|
|
next_state, |
|
|
next_done, |
|
|
).reshape(1, -1) |
|
|
advantages = torch.zeros_like(rewards).to(device) |
|
|
lastgaelam = 0 |
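            # Generalised Advantage Estimation, computed backwards over the rollout:
            #   delta_t = r_t + discount_gamma * V(s_{t+1}) * (1 - done_{t+1}) - V(s_t)
            #   A_t     = delta_t + discount_gamma * gae_lambda * (1 - done_{t+1}) * A_{t+1}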
|
|
for t in reversed(range(args.num_steps)): |
|
|
if t == args.num_steps - 1: |
|
|
nextnonterminal = 1.0 - next_done |
|
|
nextvalues = next_value |
|
|
else: |
|
|
nextnonterminal = 1.0 - dones[t + 1] |
|
|
nextvalues = values[t + 1] |
|
|
delta = rewards[t] + args.discount_gamma * nextvalues * nextnonterminal - values[t] |
|
|
advantages[t] = lastgaelam = delta + args.discount_gamma * args.gae_lambda * nextnonterminal * lastgaelam |
|
|
returns = advantages + values |
|
|
|
|
|
b_obs = obs.reshape((-1,) + envs.single_observation_space.shape) |
|
|
b_logprobs = logprobs.reshape(-1) |
|
|
b_actions = actions.reshape((-1,) + envs.single_action_space.shape) |
|
|
b_dones = dones.reshape(-1) |
|
|
b_advantages = advantages.reshape(-1) |
|
|
b_returns = returns.reshape(-1) |
|
|
b_values = values.reshape(-1) |
|
|
|
|
|
        assert args.num_envs % args.num_minibatches == 0, "num_envs must be divisible by num_minibatches so minibatches contain whole environments"
|
|
envsperbatch = args.num_envs // args.num_minibatches |
|
|
envinds = np.arange(args.num_envs) |
|
|
flatinds = np.arange(args.batch_size).reshape(args.num_steps, args.num_envs) |
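        # Minibatches are built from whole environments (every timestep of envsperbatch envs) so each
        # sequence can be replayed from its stored initial recurrent state.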
|
|
clipfracs = [] |
|
|
for epoch in range(args.update_epochs): |
|
|
for start in range(0, args.num_envs, envsperbatch): |
|
|
end = start + envsperbatch |
|
|
mbenvinds = envinds[start:end] |
|
|
mb_inds = flatinds[:, mbenvinds].ravel() |
|
|
|
|
|
if args.model_type == "ctm": |
|
|
selected_hidden_state = (initial_state[0][mbenvinds,:,:], initial_state[1][mbenvinds,:,:]) |
|
|
elif args.model_type == "lstm": |
|
|
selected_hidden_state = (initial_state[0][mbenvinds,:], initial_state[1][mbenvinds,:]) |
|
|
|
|
|
|
|
|
_, newlogprob, entropy, newvalue, _, _, _, _ = agent.get_action_and_value( |
|
|
b_obs[mb_inds], |
|
|
selected_hidden_state, |
|
|
b_dones[mb_inds], |
|
|
b_actions.long()[mb_inds], |
|
|
) |
|
|
|
|
|
logratio = newlogprob - b_logprobs[mb_inds] |
|
|
ratio = logratio.exp() |
|
|
|
|
|
with torch.no_grad(): |
|
|
old_approx_kl = (-logratio).mean() |
|
|
approx_kl = ((ratio - 1) - logratio).mean() |
|
|
clipfracs += [((ratio - 1.0).abs() > args.clip_coef).float().mean().item()] |
|
|
|
|
|
mb_advantages = b_advantages[mb_inds] |
|
|
if args.norm_adv: |
|
|
mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-8) |
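                # PPO clipped surrogate objective: take the elementwise maximum of the unclipped and
                # clipped policy-gradient losses (a pessimistic bound).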
|
|
|
|
|
|
|
|
pg_loss1 = -mb_advantages * ratio |
|
|
pg_loss2 = -mb_advantages * torch.clamp(ratio, 1 - args.clip_coef, 1 + args.clip_coef) |
|
|
pg_loss = torch.max(pg_loss1, pg_loss2).mean() |
|
|
|
|
|
|
|
|
newvalue = newvalue.view(-1) |
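                # Value loss: optionally clip the value prediction around the old estimate, otherwise a plain MSE to the returns.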
|
|
if args.clip_vloss: |
|
|
v_loss_unclipped = (newvalue - b_returns[mb_inds]) ** 2 |
|
|
v_clipped = b_values[mb_inds] + torch.clamp( |
|
|
newvalue - b_values[mb_inds], |
|
|
-args.clip_coef, |
|
|
args.clip_coef, |
|
|
) |
|
|
v_loss_clipped = (v_clipped - b_returns[mb_inds]) ** 2 |
|
|
v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped) |
|
|
v_loss = 0.5 * v_loss_max.mean() |
|
|
else: |
|
|
v_loss = 0.5 * ((newvalue - b_returns[mb_inds]) ** 2).mean() |
|
|
|
|
|
|
|
|
entropy_loss = entropy.mean() |
|
|
loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef |
|
|
|
|
|
optimizer.zero_grad() |
|
|
loss.backward() |
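                # Log per-parameter gradient norms (and their total) to TensorBoard before clipping, for debugging.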
|
|
|
|
|
|
|
|
total_norm = 0 |
|
|
for name, param in agent.named_parameters(): |
|
|
if param.grad is None: |
|
|
print(f"Warning: Gradient for {name} is None!") |
|
|
else: |
|
|
param_norm = param.grad.data.norm(2).item() |
|
|
total_norm += param_norm ** 2 |
|
|
writer.add_scalar(f"grad_norms/{name}", param_norm, global_step) |
|
|
total_norm = total_norm ** 0.5 |
|
|
writer.add_scalar("grad_norms/total", total_norm, global_step) |
|
|
|
|
|
|
|
|
nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm) |
|
|
optimizer.step() |
|
|
|
|
|
if args.target_kl is not None and approx_kl > args.target_kl: |
|
|
break |
|
|
|
|
|
if training_iteration % args.track_every == 0 or training_iteration == 1: |
|
|
plot_activations(agent, device, args) |
|
|
|
|
|
if training_iteration % args.save_every == 0 or training_iteration == 1 or global_step == args.total_timesteps-1: |
|
|
save_model(agent, optimizer, global_step, training_iteration, episode_rewards_tracking, episode_lengths_tracking, global_steps_tracking, args, f"{args.log_dir}/checkpoint.pt") |
|
|
|
|
|
y_pred, y_true = b_values.cpu().numpy(), b_returns.cpu().numpy() |
|
|
var_y = np.var(y_true) |
|
|
explained_var = np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y |
|
|
|
|
|
writer.add_scalar("charts/lr", optimizer.param_groups[0]["lr"], global_step) |
|
|
writer.add_scalar("losses/value_loss", v_loss.item(), global_step) |
|
|
writer.add_scalar("losses/policy_loss", pg_loss.item(), global_step) |
|
|
writer.add_scalar("losses/entropy", entropy_loss.item(), global_step) |
|
|
writer.add_scalar("losses/old_approx_kl", old_approx_kl.item(), global_step) |
|
|
writer.add_scalar("losses/approx_kl", approx_kl.item(), global_step) |
|
|
writer.add_scalar("losses/clipfrac", np.mean(clipfracs), global_step) |
|
|
writer.add_scalar("losses/explained_variance", explained_var, global_step) |
|
|
writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step) |
|
|
|
|
|
envs.close() |
|
|
writer.close() |