Shortcuts

Source code for grl.rl_modules.value_network.value_network

from typing import Tuple, Union

import torch
import torch.nn as nn
from easydict import EasyDict
from tensordict import TensorDict

from grl.neural_network import get_module
from grl.neural_network.encoders import get_encoder


class VNetwork(nn.Module):
    """
    Overview:
        Value network: optional state/condition encoders followed by a backbone \
        module that produces the value estimate.
    Interfaces:
        ``__init__``, ``forward``
    """

    def __init__(self, config: EasyDict):
        """
        Overview:
            Build the sub-modules of the value network from ``config``.
        Arguments:
            config (:obj:`EasyDict`): The configuration dict. ``config.backbone`` \
                (with ``type`` and ``args``) is always read; ``config.state_encoder`` \
                and ``config.condition_encoder`` (same layout) are optional, and an \
                identity mapping is used for whichever one is absent.
        """

        super().__init__()
        self.config = config
        self.model = torch.nn.ModuleDict()

        # Both encoders follow the same rule, so they are built in one loop.
        # The registration order (state, then condition) is kept as is.
        for encoder_name in ("state_encoder", "condition_encoder"):
            if not hasattr(config, encoder_name):
                self.model[encoder_name] = torch.nn.Identity()
                continue
            encoder_config = getattr(config, encoder_name)
            encoder_cls = get_encoder(encoder_config.type)
            self.model[encoder_name] = encoder_cls(**encoder_config.args)

        # TODO
        # specific backbone network
        backbone_cls = get_module(config.backbone.type)
        self.model["backbone"] = backbone_cls(**config.backbone.args)

    def forward(
        self,
        state: Union[torch.Tensor, TensorDict],
        condition: Union[torch.Tensor, TensorDict] = None,
    ) -> torch.Tensor:
        """
        Overview:
            Return output of value networks.
        Arguments:
            state (:obj:`Union[torch.Tensor, TensorDict]`): The input state.
            condition (:obj:`Union[torch.Tensor, TensorDict]`): The input condition. \
                When ``None``, the condition encoder is skipped and the backbone is \
                called with the state embedding only.
        Returns:
            value (:obj:`Union[torch.Tensor, TensorDict]`): The output of value network.
        """

        encoded_state = self.model["state_encoder"](state)

        # Unconditioned case: the backbone receives a single argument.
        if condition is None:
            return self.model["backbone"](encoded_state)

        encoded_condition = self.model["condition_encoder"](condition)
        return self.model["backbone"](encoded_state, encoded_condition)
class DoubleVNetwork(nn.Module):
    """
    Overview:
        Double value network, which has two value networks.
    Interfaces:
        ``__init__``, ``forward``, ``compute_double_v``, ``compute_mininum_v``
    """

    def __init__(self, config: EasyDict):
        """
        Overview:
            Initialization of double value network. Two ``VNetwork`` instances are \
            constructed separately from the same ``config`` and stored under the keys \
            ``"v1"`` and ``"v2"``.
        Arguments:
            config (:obj:`EasyDict`): The configuration dict, passed unchanged to each ``VNetwork``.
        """

        super().__init__()

        self.model = torch.nn.ModuleDict()
        self.model["v1"] = VNetwork(config)
        self.model["v2"] = VNetwork(config)

    def compute_double_v(
        self,
        state: Union[torch.Tensor, TensorDict],
        condition: Union[torch.Tensor, TensorDict, None] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Overview:
            Return the output of two value networks.
        Arguments:
            state (:obj:`Union[torch.Tensor, TensorDict]`): The input state.
            condition (:obj:`Union[torch.Tensor, TensorDict]`): The input condition. \
                Optional; ``None`` (the default) is forwarded as-is to both value networks, \
                matching the optional ``condition`` of ``VNetwork.forward``.
        Returns:
            v1 (:obj:`Union[torch.Tensor, TensorDict]`): The output of the first value network.
            v2 (:obj:`Union[torch.Tensor, TensorDict]`): The output of the second value network.
        """

        v1 = self.model["v1"](state, condition)
        v2 = self.model["v2"](state, condition)
        return v1, v2

    def compute_mininum_v(
        self,
        state: Union[torch.Tensor, TensorDict],
        condition: Union[torch.Tensor, TensorDict, None] = None,
    ) -> torch.Tensor:
        """
        Overview:
            Return the minimum output of two value networks.
        Arguments:
            state (:obj:`Union[torch.Tensor, TensorDict]`): The input state.
            condition (:obj:`Union[torch.Tensor, TensorDict]`): The input condition. \
                Optional; defaults to ``None``.
        Returns:
            minimum_v (:obj:`Union[torch.Tensor, TensorDict]`): The minimum output of value network.
        """

        # NOTE: the method name keeps its historical spelling ("mininum") so that
        # existing callers continue to work.
        v1, v2 = self.compute_double_v(state, condition=condition)
        # Two-tensor form of torch.min: element-wise minimum of v1 and v2.
        return torch.min(v1, v2)

    def forward(
        self,
        state: Union[torch.Tensor, TensorDict],
        condition: Union[torch.Tensor, TensorDict, None] = None,
    ) -> torch.Tensor:
        """
        Overview:
            Return the minimum output of two value networks.
        Arguments:
            state (:obj:`Union[torch.Tensor, TensorDict]`): The input state.
            condition (:obj:`Union[torch.Tensor, TensorDict]`): The input condition. \
                Optional; defaults to ``None``.
        Returns:
            minimum_v (:obj:`Union[torch.Tensor, TensorDict]`): The minimum output of value network.
        """

        return self.compute_mininum_v(state, condition=condition)