Critic Modules
Critic modules implement the Critic half of the Actor-Critic architecture. In NF's case, we follow a Soft Actor-Critic (SAC) base with Neural Circuit Policy (NCP) networks. Both variants (discrete and continuous) estimate Q-values with two Critic networks, each paired with its own target network.
The layout of the two modules is identical, but their underlying functionality differs to handle their respective action spaces.
The only differences are the parameters required by the `predict` and `target_predict` methods.
Critic modules are wrappers around PyTorch functionality and are made up of the following components:
| Attribute | Description | PyTorch Item |
|---|---|---|
| `network1` | The first Critic network. | `torch.nn.Module` |
| `network2` | The second Critic network. | `torch.nn.Module` |
| `target1` | The first Critic's target network. | `torch.nn.Module` |
| `target2` | The second Critic's target network. | `torch.nn.Module` |
| `optim1` | The first Critic network's optimizer. | `torch.optim.Optimizer` |
| `optim2` | The second Critic network's optimizer. | `torch.optim.Optimizer` |
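To make the layout concrete, here is a minimal sketch of a container holding the same six attributes. It is hypothetical: the class name `TwinCriticSketch`, its `in_dim`/`out_dim` arguments, and the small MLPs standing in for NF's NCP networks are assumptions for illustration, not NF's implementation. The target networks start as exact copies of their Critic networks and have gradients disabled, because only `network1` and `network2` are trained by the optimizers.

```python
import copy

import torch
import torch.nn as nn


class TwinCriticSketch:
    """Hypothetical container mirroring the Critic module attribute layout.

    Plain MLPs stand in for NF's NCP networks; this is not NF's implementation.
    """

    def __init__(self, in_dim: int, out_dim: int, n_neurons: int = 64, lr: float = 3e-4) -> None:
        def make_net() -> nn.Module:
            return nn.Sequential(
                nn.Linear(in_dim, n_neurons),
                nn.ReLU(),
                nn.Linear(n_neurons, out_dim),
            )

        # Two independently initialised Critic networks
        self.network1 = make_net()
        self.network2 = make_net()

        # Target networks start as exact copies and are never stepped by an optimizer
        self.target1 = copy.deepcopy(self.network1)
        self.target2 = copy.deepcopy(self.network2)
        for p in list(self.target1.parameters()) + list(self.target2.parameters()):
            p.requires_grad_(False)

        # One optimizer per Critic network
        self.optim1 = torch.optim.Adam(self.network1.parameters(), lr=lr)
        self.optim2 = torch.optim.Adam(self.network2.parameters(), lr=lr)


critic = TwinCriticSketch(in_dim=8, out_dim=4)
```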
Discrete
For discrete action spaces, we use the `CriticModuleDiscrete` class.
It accepts the following parameters:
| Parameter | Description | Default |
|---|---|---|
| `state_dim` | The dimension of the state space. | - |
| `n_neurons` | The number of decision/hidden neurons. | - |
| `action_dim` | The dimension of the action space. | - |
| `optim` | The PyTorch optimizer class. | `torch.optim.Adam` |
| `lr` | The optimizer learning rate. | `0.0003` |
| `tau` | The soft target network update factor. | `0.0005` |
| `device` | The device to perform computations on, e.g., `cpu` or `cuda:0`. | `None` |
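The sketch below shows one way these parameters could map onto a single discrete Critic network. It assumes, as is common in discrete SAC, that the network takes a state and outputs one Q-value per action, so its output size is `action_dim`. `build_discrete_critic` is a hypothetical helper, an MLP stands in for the NCP network, and falling back to CUDA when `device` is `None` is an assumption. `tau` is not needed at construction time; it only matters for target updates.

```python
from typing import Optional, Tuple, Type

import torch
import torch.nn as nn


def build_discrete_critic(
    state_dim: int,
    n_neurons: int,
    action_dim: int,
    optim: Type[torch.optim.Optimizer] = torch.optim.Adam,
    lr: float = 0.0003,
    device: Optional[str] = None,
) -> Tuple[nn.Module, torch.optim.Optimizer]:
    """Hypothetical: one discrete Critic network and its optimizer.

    An MLP stands in for NF's NCP network; the output holds one Q-value per action.
    """
    # Assumed fallback when no device is given
    device = device or ("cuda:0" if torch.cuda.is_available() else "cpu")
    net = nn.Sequential(
        nn.Linear(state_dim, n_neurons),
        nn.ReLU(),
        nn.Linear(n_neurons, action_dim),
    ).to(device)
    # `optim` is an optimizer class, instantiated with this network's parameters
    return net, optim(net.parameters(), lr=lr)


net, opt = build_discrete_critic(state_dim=8, n_neurons=64, action_dim=4, device="cpu")
q_values = net(torch.randn(32, 8))  # a batch of 32 states
print(q_values.shape)  # torch.Size([32, 4])
```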
Continuous
For continuous action spaces, we use the `CriticModule` class.
Its parameters are the same as those of the `CriticModuleDiscrete` class.
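In the continuous case the action is a real-valued vector, so the usual SAC design feeds the state and the action into the Critic together and outputs a single Q-value. This is likely why the two variants' `predict` methods take different parameters. The sketch assumes concatenation as the way of combining state and action; `ContinuousQNetSketch` is hypothetical and uses an MLP in place of an NCP network.

```python
import torch
import torch.nn as nn


class ContinuousQNetSketch(nn.Module):
    """Hypothetical continuous Critic: scores a (state, action) pair with one Q-value.

    An MLP stands in for NF's NCP network.
    """

    def __init__(self, state_dim: int, action_dim: int, n_neurons: int) -> None:
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim + action_dim, n_neurons),
            nn.ReLU(),
            nn.Linear(n_neurons, 1),
        )

    def forward(self, obs: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
        # Unlike the discrete case, the action is an input to the network
        return self.net(torch.cat([obs, actions], dim=-1))


q_net = ContinuousQNetSketch(state_dim=8, action_dim=2, n_neurons=64)
q = q_net(torch.randn(32, 8), torch.randn(32, 2))
print(q.shape)  # torch.Size([32, 1])
```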
Target Updates
To update the target networks, we use the `update_targets` method.
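The `tau` parameter points to a soft (Polyak) update, where each target parameter moves a small step toward its Critic parameter: θ_target ← τ·θ + (1 − τ)·θ_target. Below is a minimal sketch of that rule, assuming this is what `update_targets` applies to both the (`network1`, `target1`) and (`network2`, `target2`) pairs; `soft_update` is a hypothetical helper.

```python
import torch
import torch.nn as nn


@torch.no_grad()  # target updates are not part of any gradient computation
def soft_update(network: nn.Module, target: nn.Module, tau: float) -> None:
    """Polyak update: target <- tau * network + (1 - tau) * target."""
    for param, target_param in zip(network.parameters(), target.parameters()):
        target_param.mul_(1.0 - tau).add_(param, alpha=tau)


net = nn.Linear(1, 1, bias=False)
target = nn.Linear(1, 1, bias=False)
with torch.no_grad():
    net.weight.fill_(1.0)
    target.weight.fill_(0.0)

soft_update(net, target, tau=0.0005)
```

With the default `tau` of `0.0005`, a target network moves only 0.05% of the way toward its Critic on each update, so the Q-value targets change slowly and training stays stable.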
Updating Gradients
To update the network gradients, we use the `gradient_step` method.
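A gradient step for two Critics typically follows the standard PyTorch pattern: clear old gradients, backpropagate the loss, then step each optimizer. This sketch assumes a single combined loss (the sum of both Critics' mean squared errors against a shared target), so one `backward()` call fills the gradients of both networks. The standalone `gradient_step` function and its arguments are hypothetical, not the signature of NF's method, and `y` is random data standing in for real SAC targets.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def gradient_step(
    optim1: torch.optim.Optimizer,
    optim2: torch.optim.Optimizer,
    critic_loss: torch.Tensor,
) -> None:
    """Hypothetical: update both Critic networks from one combined loss."""
    optim1.zero_grad()
    optim2.zero_grad()
    critic_loss.backward()  # fills gradients for both networks' parameters
    optim1.step()
    optim2.step()


torch.manual_seed(0)
network1, network2 = nn.Linear(4, 1), nn.Linear(4, 1)
optim1 = torch.optim.Adam(network1.parameters(), lr=0.0003)
optim2 = torch.optim.Adam(network2.parameters(), lr=0.0003)

obs = torch.randn(16, 4)
y = torch.randn(16, 1)  # stand-in for SAC target Q-values (no gradient)

before1 = network1.weight.detach().clone()
before2 = network2.weight.detach().clone()
critic_loss = F.mse_loss(network1(obs), y) + F.mse_loss(network2(obs), y)
gradient_step(optim1, optim2, critic_loss)
```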
Prediction
To make a prediction with the Critic networks, we use the `predict` method.
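The two variants differ mainly in what `predict` needs. The hypothetical helpers below assume that a discrete Critic maps states to one Q-value per action, while a continuous Critic scores concatenated state-action pairs; both return one prediction per Critic network. Plain `nn.Linear` layers stand in for the NCP networks.

```python
from typing import Tuple

import torch
import torch.nn as nn


def predict_continuous(
    network1: nn.Module, network2: nn.Module, obs: torch.Tensor, actions: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Hypothetical: Q(s, a) from both Critic networks; actions are required."""
    x = torch.cat([obs, actions], dim=-1)
    return network1(x), network2(x)


def predict_discrete(
    network1: nn.Module, network2: nn.Module, obs: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Hypothetical: one Q-value per action from both Critic networks."""
    return network1(obs), network2(obs)


torch.manual_seed(0)
state_dim, action_dim, batch = 8, 2, 32
obs = torch.randn(batch, state_dim)

# Continuous: networks score the concatenated state and action
c_net1 = nn.Linear(state_dim + action_dim, 1)
c_net2 = nn.Linear(state_dim + action_dim, 1)
actions = torch.randn(batch, action_dim)
q1, q2 = predict_continuous(c_net1, c_net2, obs, actions)
print(q1.shape)  # torch.Size([32, 1])

# Discrete: networks output one Q-value per action
n_actions = 4
d_net1, d_net2 = nn.Linear(state_dim, n_actions), nn.Linear(state_dim, n_actions)
dq1, dq2 = predict_discrete(d_net1, d_net2, obs)
print(dq1.shape)  # torch.Size([32, 4])
```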
Target Prediction
To make a prediction with the target networks, we use the `target_predict` method.
The `target_predict` method returns the smaller of the two target networks' next-state Q-value predictions, taken element-wise with `torch.min(q_values1, q_values2)`.
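Taking the element-wise minimum of the two target estimates (clipped double Q-learning) counteracts the overestimation bias that a single Q-network tends to accumulate. Here is a minimal sketch, assuming the continuous case where the targets score concatenated next states and next actions (a discrete version would drop `next_actions`). `target_predict` here is a hypothetical standalone function, and hand-set weights make the result easy to check.

```python
import torch
import torch.nn as nn


@torch.no_grad()  # targets act as fixed labels, so no gradient tracking
def target_predict(
    target1: nn.Module,
    target2: nn.Module,
    next_obs: torch.Tensor,
    next_actions: torch.Tensor,
) -> torch.Tensor:
    """Hypothetical: the smaller next-state Q-value of the two target networks."""
    x = torch.cat([next_obs, next_actions], dim=-1)
    q_values1, q_values2 = target1(x), target2(x)
    return torch.min(q_values1, q_values2)  # element-wise minimum


# Hand-set weights: target1 reads the state, target2 reads the action
target1 = nn.Linear(2, 1, bias=False)
target2 = nn.Linear(2, 1, bias=False)
with torch.no_grad():
    target1.weight.copy_(torch.tensor([[1.0, 0.0]]))
    target2.weight.copy_(torch.tensor([[0.0, 1.0]]))

next_obs = torch.tensor([[3.0], [1.0]])
next_actions = torch.tensor([[2.0], [5.0]])
result = target_predict(target1, target2, next_obs, next_actions)
print(result.tolist())  # [[2.0], [1.0]]
```

Each row keeps the smaller estimate: `min(3, 2) = 2` for the first sample and `min(1, 5) = 1` for the second. Running under `torch.no_grad()` keeps the target values out of the Critics' gradient computation.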
Next, we'll look at the entropy modules! 🚀