Source code for grl.rl_modules.value_network.value_network
from typing import Tuple, Union
import torch
import torch.nn as nn
from easydict import EasyDict
from tensordict import TensorDict
from grl.neural_network import get_module
from grl.neural_network.encoders import get_encoder
class VNetwork(nn.Module):
    """
    Overview:
        Value network, which is used to approximate the value function.
        The network holds three sub-modules in ``self.model``: ``state_encoder``,
        ``condition_encoder`` and ``backbone``.
    Interfaces:
        ``__init__``, ``forward``
    """

    def __init__(self, config: EasyDict):
        """
        Overview:
            Initialization of value network.
        Arguments:
            config (:obj:`EasyDict`): The configuration dict. ``config.backbone`` must provide \
                ``type`` and ``args``. ``config.state_encoder`` and ``config.condition_encoder`` \
                are optional (each with ``type`` and ``args``); when one is absent, \
                ``torch.nn.Identity`` is used in its place.
        """
        super().__init__()
        self.config = config
        self.model = torch.nn.ModuleDict()

        # Both encoders are built the same way; they are registered before the
        # backbone, state encoder first.
        for encoder_name in ("state_encoder", "condition_encoder"):
            if hasattr(config, encoder_name):
                encoder_config = getattr(config, encoder_name)
                self.model[encoder_name] = get_encoder(encoder_config.type)(
                    **encoder_config.args
                )
            else:
                self.model[encoder_name] = torch.nn.Identity()

        # TODO: specific backbone network
        self.model["backbone"] = get_module(config.backbone.type)(
            **config.backbone.args
        )

    def forward(
        self,
        state: Union[torch.Tensor, TensorDict],
        condition: Union[torch.Tensor, TensorDict, None] = None,
    ) -> torch.Tensor:
        """
        Overview:
            Return output of value networks.
        Arguments:
            state (:obj:`Union[torch.Tensor, TensorDict]`): The input state.
            condition (:obj:`Union[torch.Tensor, TensorDict, None]`): The input condition. \
                When it is ``None`` the condition encoder is skipped and the backbone is \
                called with the state embedding only.
        Returns:
            value (:obj:`torch.Tensor`): The output of value network.
        """
        state_embedding = self.model["state_encoder"](state)
        if condition is None:
            return self.model["backbone"](state_embedding)

        condition_embedding = self.model["condition_encoder"](condition)
        return self.model["backbone"](state_embedding, condition_embedding)
class DoubleVNetwork(nn.Module):
    """
    Overview:
        Double value network, which has two value networks.
    Interfaces:
        ``__init__``, ``forward``, ``compute_double_v``, ``compute_mininum_v``, \
        ``compute_minimum_v``
    """

    def __init__(self, config: EasyDict):
        """
        Overview:
            Initialization of double value network.
        Arguments:
            config (:obj:`EasyDict`): The configuration dict. It is passed unchanged to both \
                ``VNetwork`` instances.
        """
        super().__init__()
        self.model = torch.nn.ModuleDict()
        # Two separately constructed networks built from the same configuration.
        self.model["v1"] = VNetwork(config)
        self.model["v2"] = VNetwork(config)

    def compute_double_v(
        self,
        state: Union[torch.Tensor, TensorDict],
        condition: Union[torch.Tensor, TensorDict, None] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Overview:
            Return the output of two value networks.
        Arguments:
            state (:obj:`Union[torch.Tensor, TensorDict]`): The input state.
            condition (:obj:`Union[torch.Tensor, TensorDict, None]`): The input condition. \
                Defaults to ``None`` and is forwarded as-is to both value networks.
        Returns:
            v1 (:obj:`torch.Tensor`): The output of the first value network.
            v2 (:obj:`torch.Tensor`): The output of the second value network.
        """
        v1 = self.model["v1"](state, condition)
        v2 = self.model["v2"](state, condition)
        return v1, v2

    def compute_mininum_v(
        self,
        state: Union[torch.Tensor, TensorDict],
        condition: Union[torch.Tensor, TensorDict, None] = None,
    ) -> torch.Tensor:
        """
        Overview:
            Return the minimum output of two value networks.
            The (misspelled) method name is kept for backward compatibility.
        Arguments:
            state (:obj:`Union[torch.Tensor, TensorDict]`): The input state.
            condition (:obj:`Union[torch.Tensor, TensorDict, None]`): The input condition. \
                Defaults to ``None``.
        Returns:
            minimum_v (:obj:`torch.Tensor`): The element-wise minimum of the two outputs.
        """
        v1, v2 = self.compute_double_v(state, condition=condition)
        return torch.min(v1, v2)

    def compute_minimum_v(
        self,
        state: Union[torch.Tensor, TensorDict],
        condition: Union[torch.Tensor, TensorDict, None] = None,
    ) -> torch.Tensor:
        """
        Overview:
            Correctly spelled alias of ``compute_mininum_v``.
        Arguments:
            state (:obj:`Union[torch.Tensor, TensorDict]`): The input state.
            condition (:obj:`Union[torch.Tensor, TensorDict, None]`): The input condition. \
                Defaults to ``None``.
        Returns:
            minimum_v (:obj:`torch.Tensor`): The element-wise minimum of the two outputs.
        """
        return self.compute_mininum_v(state, condition=condition)

    def forward(
        self,
        state: Union[torch.Tensor, TensorDict],
        condition: Union[torch.Tensor, TensorDict, None] = None,
    ) -> torch.Tensor:
        """
        Overview:
            Return the minimum output of two value networks.
        Arguments:
            state (:obj:`Union[torch.Tensor, TensorDict]`): The input state.
            condition (:obj:`Union[torch.Tensor, TensorDict, None]`): The input condition. \
                Defaults to ``None``.
        Returns:
            minimum_v (:obj:`torch.Tensor`): The element-wise minimum of the two outputs.
        """
        return self.compute_mininum_v(state, condition=condition)