from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
import treetensor.torch as ttorch
|
|
|
|
class PPOFModel(nn.Module):
    """Convolutional actor-critic network.

    A ``Conv2d`` encoder followed by ``Flatten`` produces a feature vector shared by two MLP heads:

    * the critic maps features to one scalar value per sample, shape ``(B, 1)``;
    * the actor maps features to ``mu`` of shape ``(B, action_shape)``. ``sigma`` is
      ``exp(log_sigma)``, where ``log_sigma`` is a learned, state-independent parameter of shape
      ``(1, action_shape)`` broadcast to ``mu``'s shape (presumably the mean / std of a diagonal
      Gaussian policy).

    Args:
        obs_shape: Observation shape without the batch dimension, ``(C, H, W)``.
        action_shape: Number of action dimensions (size of ``mu`` / ``sigma``).
        encoder_hidden_size_list: Output channels of each conv layer. At most three layers are
            supported; they use kernel sizes 8, 4, 3 and strides 4, 2, 1, in order.
        actor_head_hidden_size: Width of the actor MLP hidden layers.
        actor_head_layer_num: Number of actor hidden layers (0 feeds features straight to ``mu``).
        critic_head_hidden_size: Width of the critic MLP hidden layers.
        critic_head_layer_num: Number of critic hidden layers (0 feeds features straight to the
            value output layer).
        activation: Module inserted after every conv / hidden linear layer. The same instance is
            reused at every position. ``None`` means no activation.

    Raises:
        ValueError: If ``encoder_hidden_size_list`` has more than three entries.
    """

    # Names of the methods ``forward`` may dispatch to.
    mode = ['compute_actor', 'compute_critic', 'compute_actor_critic']

    def __init__(
            self,
            obs_shape: Tuple[int, ...],
            action_shape: int,
            encoder_hidden_size_list: List[int] = [128, 128, 64],
            actor_head_hidden_size: int = 64,
            actor_head_layer_num: int = 1,
            critic_head_hidden_size: int = 64,
            critic_head_layer_num: int = 1,
            activation: Optional[nn.Module] = nn.ReLU(),
    ) -> None:
        super().__init__()
        self.obs_shape, self.action_shape = obs_shape, action_shape

        # Encoder: one Conv2d (+ activation) per entry of encoder_hidden_size_list, then Flatten.
        kernel_size_list = [8, 4, 3]
        stride_list = [4, 2, 1]
        if len(encoder_hidden_size_list) > len(kernel_size_list):
            raise ValueError(
                'encoder_hidden_size_list supports at most {} conv layers, got {}'.format(
                    len(kernel_size_list), len(encoder_hidden_size_list)
                )
            )
        layers = []
        in_channels = obs_shape[0]
        for out_channels, kernel_size, stride in zip(encoder_hidden_size_list, kernel_size_list, stride_list):
            layers.append(nn.Conv2d(in_channels, out_channels, kernel_size, stride))
            if activation is not None:
                layers.append(activation)
            in_channels = out_channels
        layers.append(nn.Flatten())
        self.encoder = nn.Sequential(*layers)

        # The feature size depends on obs_shape and the conv stack, so measure it with a dummy pass.
        flatten_size = self.get_flatten_size()

        # Critic: hidden MLP, then a single-output layer sized to whatever the MLP emits
        # (flatten_size itself when there are no hidden layers).
        layers, critic_out_size = self._mlp_layers(
            flatten_size, critic_head_hidden_size, critic_head_layer_num, activation
        )
        layers.append(nn.Linear(critic_out_size, 1))
        self.critic = nn.Sequential(*layers)

        # Actor: hidden MLP, then a linear layer producing mu; log_sigma does not depend on the input.
        layers, actor_out_size = self._mlp_layers(
            flatten_size, actor_head_hidden_size, actor_head_layer_num, activation
        )
        self.actor = nn.Sequential(*layers)
        self.mu = nn.Linear(actor_out_size, action_shape)
        self.log_sigma = nn.Parameter(torch.zeros(1, action_shape))

        self.init_weights()

    @staticmethod
    def _mlp_layers(
            input_size: int, hidden_size: int, layer_num: int, activation: Optional[nn.Module]
    ) -> Tuple[List[nn.Module], int]:
        """Build ``layer_num`` Linear(+activation) layers; return them and the final feature size."""
        layers = []
        for _ in range(layer_num):
            layers.append(nn.Linear(input_size, hidden_size))
            if activation is not None:
                layers.append(activation)
            input_size = hidden_size
        return layers, input_size

    def init_weights(self) -> None:
        """Initialize every ``Linear`` / ``Conv2d`` weight orthogonally and zero its bias.

        Hidden layers use gain ``sqrt(2)`` (``nn.init.calculate_gain('relu')``, matching the default
        activation). The ``mu`` layer uses gain 0.01, so initial action means are close to zero.
        The critic output layer uses gain 1. ``log_sigma`` keeps its zero initialization, i.e. an
        initial ``sigma`` of 1. Orthogonal initialization is a common choice for PPO networks.
        """
        hidden_gain = nn.init.calculate_gain('relu')
        for module in self.modules():
            if isinstance(module, (nn.Linear, nn.Conv2d)):
                nn.init.orthogonal_(module.weight, gain=hidden_gain)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
        # Output layers get their own gains; their biases were already zeroed by the loop above.
        nn.init.orthogonal_(self.mu.weight, gain=0.01)
        nn.init.orthogonal_(self.critic[-1].weight, gain=1.0)

    def get_flatten_size(self) -> int:
        """Return the number of features ``self.encoder`` produces for a single observation.

        A zero tensor of shape ``(1, *obs_shape)`` is used because only the output shape matters,
        and, unlike random data, it does not consume the global RNG.
        """
        dummy_obs = torch.zeros(1, *self.obs_shape)
        with torch.no_grad():
            return self.encoder(dummy_obs).shape[1]

    def forward(self, inputs: torch.Tensor, mode: str) -> Union[torch.Tensor, ttorch.Tensor]:
        """Dispatch ``inputs`` to the method named by ``mode`` (one of ``self.mode``).

        Raises:
            AssertionError: If ``mode`` is not listed in ``self.mode`` (when assertions are enabled).
        """
        assert mode in self.mode, "not support forward mode: {}/{}".format(mode, self.mode)
        return getattr(self, mode)(inputs)

    def _policy(self, features: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Map encoder features to ``(mu, sigma)``, both shaped like the ``mu`` layer output."""
        mu = self.mu(self.actor(features))
        # log_sigma is (1, action_shape); adding zeros_like(mu) broadcasts it to mu's shape.
        sigma = torch.exp(self.log_sigma + torch.zeros_like(mu))
        return mu, sigma

    def compute_actor(self, x: torch.Tensor) -> ttorch.Tensor:
        """Return a tree tensor with keys ``mu`` and ``sigma`` for the observations ``x``."""
        mu, sigma = self._policy(self.encoder(x))
        return ttorch.as_tensor({'mu': mu, 'sigma': sigma})

    def compute_critic(self, x: torch.Tensor) -> torch.Tensor:
        """Return the critic value of shape ``(B, 1)`` for the observations ``x``."""
        return self.critic(self.encoder(x))

    def compute_actor_critic(self, x: torch.Tensor) -> ttorch.Tensor:
        """Return ``{'logit': {'mu', 'sigma'}, 'value'}`` computed from one shared encoder pass."""
        features = self.encoder(x)
        value = self.critic(features)
        mu, sigma = self._policy(features)
        return ttorch.as_tensor({'logit': {'mu': mu, 'sigma': sigma}, 'value': value})
|
|
|
|
def test_ppof_model() -> None:
    """Smoke-test every forward mode of ``PPOFModel``.

    Builds a model for ``(4, 84, 84)`` observations and 5 action dimensions, runs a batch of 3
    random observations through each mode, and checks the output shapes.
    """
    batch, action_dim = 3, 5
    model = PPOFModel((4, 84, 84), action_dim)
    print(model)
    obs = torch.randn(batch, 4, 84, 84)

    value = model(obs, mode='compute_critic')
    assert value.shape == (batch, 1)

    actor_out = model(obs, mode='compute_actor')
    for stat in (actor_out.mu, actor_out.sigma):
        assert stat.shape == (batch, action_dim)

    joint_out = model(obs, mode='compute_actor_critic')
    assert joint_out.value.shape == (batch, 1)
    for stat in (joint_out.logit.mu, joint_out.logit.sigma):
        assert stat.shape == (batch, action_dim)
    print('End...')
|
|
|
|
if __name__ == '__main__':
    # Runs only when this file is executed as a script, not when it is imported.
    test_ppof_model()
|
|