Shortcuts

Source code for ding.rl_utils.sampler

import torch
import treetensor.torch as ttorch
from torch.distributions import Normal, Independent


class ArgmaxSampler:
    '''
    Overview:
        Deterministic discrete sampler: picks the position of the largest score along the last dimension.
    '''

    def __call__(self, logit: torch.Tensor) -> torch.Tensor:
        '''
        Overview:
            For every leading index of ``logit``, return the position of the maximum value
            in the last dimension. No randomness is involved.
        Arguments:
            - logit (:obj:`torch.Tensor`): Scores whose last dimension enumerates the candidate actions.
        Returns:
            - action (:obj:`torch.Tensor`): Index tensor with the last dimension of ``logit`` reduced away.
        '''
        best_index = logit.argmax(dim=-1)
        return best_index
class MultinomialSampler:
    '''
    Overview:
        Stochastic discrete sampler: draws one index from the categorical distribution defined by the logits.
    '''

    def __call__(self, logit: torch.Tensor) -> torch.Tensor:
        '''
        Overview:
            Treat the last dimension of ``logit`` as unnormalized log-probabilities and draw
            one index per leading position from the resulting categorical distribution.
        Arguments:
            - logit (:obj:`torch.Tensor`): Unnormalized log-probabilities over the candidate actions.
        Returns:
            - action (:obj:`torch.Tensor`): Sampled indices with the last dimension of ``logit`` reduced away.
        '''
        return torch.distributions.Categorical(logits=logit).sample()
class MuSampler:
    '''
    Overview:
        Deterministic continuous sampler: returns the ``mu`` field of its input as the action.
    '''

    def __call__(self, logit: ttorch.Tensor) -> torch.Tensor:
        '''
        Overview:
            Read the ``mu`` attribute of ``logit`` and hand it back unchanged. No sampling or
            copying is performed. ``mu`` is presumably the mean of a Gaussian policy head
            (confirm against the caller).
        Arguments:
            - logit (:obj:`ttorch.Tensor`): Tree tensor exposing a ``mu`` attribute.
        Returns:
            - action (:obj:`torch.Tensor`): The object stored under ``logit.mu``.
        '''
        mean = logit.mu
        return mean
class ReparameterizationSampler:
    '''
    Overview:
        Stochastic continuous sampler: draws a reparameterized Gaussian sample from ``mu`` and ``sigma``.
    '''

    def __call__(self, logit: ttorch.Tensor) -> torch.Tensor:
        '''
        Overview:
            Build a diagonal Gaussian from ``logit.mu`` and ``logit.sigma`` and draw one sample
            with ``rsample``, so the result stays differentiable with respect to both parameters.
        Arguments:
            - logit (:obj:`ttorch.Tensor`): Tree tensor exposing ``mu`` and ``sigma`` attributes;
              ``sigma`` presumably must be positive (``Normal`` validates its scale).
        Returns:
            - action (:obj:`torch.Tensor`): The reparameterized sample.
        '''
        # Independent(..., 1) reinterprets the rightmost batch dim as the event dim. It also
        # requires at least one dimension, so it is kept even though rsample itself is unaffected.
        return Independent(Normal(logit.mu, logit.sigma), 1).rsample()
class HybridStochasticSampler:
    '''
    Overview:
        Stochastic sampler for hybrid action spaces: samples the discrete action type and
        draws reparameterized continuous action arguments.
    '''

    def __call__(self, logit: ttorch.Tensor) -> ttorch.Tensor:
        '''
        Overview:
            Sample ``action_type`` from a categorical distribution over ``logit.action_type``.
            Then draw ``action_args`` with ``rsample`` from a diagonal Gaussian built from
            ``logit.action_args.mu`` and ``logit.action_args.sigma``.
        Arguments:
            - logit (:obj:`ttorch.Tensor`): Tree tensor with an ``action_type`` field (logits)
              and an ``action_args`` field exposing ``mu`` and ``sigma``.
        Returns:
            - action (:obj:`ttorch.Tensor`): Tree tensor with keys ``action_type`` (sampled
              indices) and ``action_args`` (reparameterized sample).
        '''
        # Discrete branch: one categorical draw per leading index.
        type_dist = torch.distributions.Categorical(logits=logit.action_type)
        sampled_type = type_dist.sample()

        # Continuous branch: Gaussian rsample keeps gradients flowing to mu and sigma.
        args_params = logit.action_args
        args_dist = Independent(Normal(args_params.mu, args_params.sigma), 1)
        sampled_args = args_dist.rsample()

        return ttorch.as_tensor({
            'action_type': sampled_type,
            'action_args': sampled_args,
        })
class HybridDeterminsticSampler:
    '''
    Overview:
        Deterministic sampler for hybrid action spaces: takes the argmax action type and the
        ``mu`` of the continuous action arguments.
        The class name keeps its historical misspelling for backward compatibility;
        ``HybridDeterministicSampler`` is defined directly below as a correctly spelled alias.
    '''

    def __call__(self, logit: ttorch.Tensor) -> ttorch.Tensor:
        '''
        Overview:
            Select ``action_type`` as the argmax over the last dimension of ``logit.action_type``.
            Use ``logit.action_args.mu`` unchanged as ``action_args``.
        Arguments:
            - logit (:obj:`ttorch.Tensor`): Tree tensor with an ``action_type`` field (scores)
              and an ``action_args`` field exposing ``mu``.
        Returns:
            - action (:obj:`ttorch.Tensor`): Tree tensor with keys ``action_type`` (argmax
              indices) and ``action_args`` (the ``mu`` values).
        '''
        action_type = logit.action_type.argmax(dim=-1)
        action_args = logit.action_args.mu
        return ttorch.as_tensor({
            'action_type': action_type,
            'action_args': action_args,
        })


# Correctly spelled alias; the misspelled original name is kept so existing imports still work.
HybridDeterministicSampler = HybridDeterminsticSampler