Source code for grl.agents.srpo

from typing import Dict, Union

import numpy as np
import torch
from easydict import EasyDict

from grl.agents import obs_transform, action_transform


class SRPOAgent:
    """
    Overview:
        The SRPO agent.
    Interface:
        ``__init__``, ``act``
    """

    def __init__(
        self,
        config: EasyDict,
        model: Union[torch.nn.Module, torch.nn.ModuleDict],
    ):
        """
        Overview:
            Initialize the agent.
        Arguments:
            config (:obj:`EasyDict`): The configuration.
            model (:obj:`Union[torch.nn.Module, torch.nn.ModuleDict]`): The model.
        """

        self.config = config
        self.device = config.device
        self.model = model.to(self.device)

    def act(
        self,
        obs: Union[np.ndarray, torch.Tensor, Dict],
        return_as_torch_tensor: bool = False,
    ) -> Union[np.ndarray, torch.Tensor, Dict]:
        """
        Overview:
            Given an observation, return an action.
        Arguments:
            obs (:obj:`Union[np.ndarray, torch.Tensor, Dict]`): The observation.
            return_as_torch_tensor (:obj:`bool`): Whether to return the action as a torch tensor.
        Returns:
            action (:obj:`Union[np.ndarray, torch.Tensor, Dict]`): The action.
        """

        obs = obs_transform(obs, self.device)

        with torch.no_grad():

            # ---------------------------------------
            # Customized inference code ↓
            # ---------------------------------------
            obs = obs.unsqueeze(0)
            action = (
                self.model["SRPOPolicy"].policy(obs).squeeze(0).detach().cpu().numpy()
            )
            # ---------------------------------------
            # Customized inference code ↑
            # ---------------------------------------

            action = action_transform(action, return_as_torch_tensor)

        return action
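
A minimal usage sketch of the class above. The ``config.device`` field and the ``"SRPOPolicy"`` key mirror what ``__init__`` and ``act`` read from the configuration and model; the stub policy, observation dimension, and action dimension are placeholders standing in for a trained SRPO policy network, not part of the library.

import numpy as np
import torch
from easydict import EasyDict

from grl.agents.srpo import SRPOAgent


class StubSRPOPolicy(torch.nn.Module):
    """Stand-in for a trained SRPO policy network. SRPOAgent.act only relies
    on a `.policy(obs)` call mapping a batch of observations to a batch of
    actions; here it simply returns zeros of the chosen action dimension."""

    def __init__(self, action_dim: int):
        super().__init__()
        self.action_dim = action_dim

    def policy(self, obs: torch.Tensor) -> torch.Tensor:
        return torch.zeros(obs.shape[0], self.action_dim, device=obs.device)


config = EasyDict(device="cuda:0" if torch.cuda.is_available() else "cpu")
model = torch.nn.ModuleDict({"SRPOPolicy": StubSRPOPolicy(action_dim=6)})
agent = SRPOAgent(config=config, model=model)

# A single un-batched observation; `act` adds the batch dimension itself and
# returns a NumPy array unless `return_as_torch_tensor=True`.
obs = np.random.randn(17).astype(np.float32)
action = agent.act(obs)
print(action.shape)  # (6,)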