Source code for grl.rl_modules.value_network.q_network
from typing import Tuple, Union
import torch
import torch.nn as nn
from easydict import EasyDict
from tensordict import TensorDict
from grl.neural_network import get_module
from grl.neural_network.encoders import get_encoder
class QNetwork(nn.Module):
    """
    Overview:
        Q network: encodes the action and the state separately and passes both
        embeddings to a backbone module whose output is returned as the Q value.
    Interfaces:
        ``__init__``, ``forward``
    """

    def __init__(self, config: EasyDict):
        """
        Overview:
            Build the sub-modules of the Q network from the configuration.
        Arguments:
            config (:obj:`EasyDict`): The configuration dict. It must provide
                ``backbone.type`` and ``backbone.args``. ``action_encoder`` and
                ``state_encoder`` are optional; when present each must provide
                ``type`` and ``args``, and when absent the corresponding
                encoder is ``torch.nn.Identity``.
        """
        super().__init__()
        self.config = config
        self.model = submodules = nn.ModuleDict()

        # Both encoders follow the same rule, so build them in one pass.
        # The registration order (action, state, backbone) is kept on purpose.
        for encoder_key in ("action_encoder", "state_encoder"):
            if not hasattr(config, encoder_key):
                # No encoder configured: the raw input is forwarded unchanged.
                submodules[encoder_key] = nn.Identity()
                continue
            encoder_cfg = getattr(config, encoder_key)
            encoder_factory = get_encoder(encoder_cfg.type)
            submodules[encoder_key] = encoder_factory(**encoder_cfg.args)

        # TODO
        # specific backbone network
        backbone_factory = get_module(config.backbone.type)
        submodules["backbone"] = backbone_factory(**config.backbone.args)

    def forward(
        self,
        action: Union[torch.Tensor, TensorDict],
        state: Union[torch.Tensor, TensorDict],
    ) -> torch.Tensor:
        """
        Overview:
            Encode the action and the state, then evaluate the backbone on the
            two embeddings.
        Arguments:
            action (:obj:`Union[torch.Tensor, TensorDict]`): The input action.
            state (:obj:`Union[torch.Tensor, TensorDict]`): The input state.
        Returns:
            q (:obj:`torch.Tensor`): The output of the backbone, called as
                ``backbone(action_embedding, state_embedding)``.
        """
        encode_action = self.model["action_encoder"]
        encode_state = self.model["state_encoder"]
        backbone = self.model["backbone"]

        # The action is encoded before the state, and it is also the first
        # positional argument handed to the backbone.
        action_embedding = encode_action(action)
        state_embedding = encode_state(state)
        return backbone(action_embedding, state_embedding)
class DoubleQNetwork(nn.Module):
    """
    Overview:
        Double Q network, which holds two separately constructed Q networks
        built from the same configuration.
    Interfaces:
        ``__init__``, ``forward``, ``compute_double_q``, ``compute_mininum_q``,
        ``compute_minimum_q``
    """

    def __init__(self, config: EasyDict):
        """
        Overview:
            Initialization of double Q network. Two ``QNetwork`` instances are
            created from the same ``config`` and stored under the keys ``q1``
            and ``q2``.
        Arguments:
            config (:obj:`EasyDict`): The configuration dict, passed unchanged
                to both ``QNetwork`` constructors.
        """
        super().__init__()
        # Keep the configuration on the module, as ``QNetwork`` does.
        self.config = config
        self.model = torch.nn.ModuleDict()
        self.model["q1"] = QNetwork(config)
        self.model["q2"] = QNetwork(config)

    def compute_double_q(
        self,
        action: Union[torch.Tensor, TensorDict],
        state: Union[torch.Tensor, TensorDict],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Overview:
            Return the output of two Q networks.
        Arguments:
            action (:obj:`Union[torch.Tensor, TensorDict]`): The input action.
            state (:obj:`Union[torch.Tensor, TensorDict]`): The input state.
        Returns:
            q1 (:obj:`torch.Tensor`): The output of the first Q network.
            q2 (:obj:`torch.Tensor`): The output of the second Q network.
        """
        q1 = self.model["q1"](action, state)
        q2 = self.model["q2"](action, state)
        return q1, q2

    def compute_mininum_q(
        self,
        action: Union[torch.Tensor, TensorDict],
        state: Union[torch.Tensor, TensorDict],
    ) -> torch.Tensor:
        """
        Overview:
            Return the minimum output of two Q networks. The method name is
            misspelled but kept for backward compatibility; new code should
            prefer ``compute_minimum_q``.
        Arguments:
            action (:obj:`Union[torch.Tensor, TensorDict]`): The input action.
            state (:obj:`Union[torch.Tensor, TensorDict]`): The input state.
        Returns:
            minimum_q (:obj:`torch.Tensor`): ``torch.min`` applied to the two
                Q network outputs.
        """
        q1, q2 = self.compute_double_q(action, state)
        # Two-tensor form of ``torch.min``: compares q1 against q2.
        return torch.min(q1, q2)

    def compute_minimum_q(
        self,
        action: Union[torch.Tensor, TensorDict],
        state: Union[torch.Tensor, TensorDict],
    ) -> torch.Tensor:
        """
        Overview:
            Correctly spelled entry point for the minimum of the two Q network
            outputs. It delegates to ``compute_mininum_q`` so that subclasses
            overriding the original method keep working.
        Arguments:
            action (:obj:`Union[torch.Tensor, TensorDict]`): The input action.
            state (:obj:`Union[torch.Tensor, TensorDict]`): The input state.
        Returns:
            minimum_q (:obj:`torch.Tensor`): The minimum output of Q network.
        """
        return self.compute_mininum_q(action, state)

    def forward(
        self,
        action: Union[torch.Tensor, TensorDict],
        state: Union[torch.Tensor, TensorDict],
    ) -> torch.Tensor:
        """
        Overview:
            Return the minimum output of two Q networks.
        Arguments:
            action (:obj:`Union[torch.Tensor, TensorDict]`): The input action.
            state (:obj:`Union[torch.Tensor, TensorDict]`): The input state.
        Returns:
            minimum_q (:obj:`torch.Tensor`): The minimum output of Q network.
        """
        # Dispatch through the original method name so existing overrides of
        # ``compute_mininum_q`` are still honored.
        return self.compute_mininum_q(action, state)