Source code for grl.agents.qgpo
from typing import Dict, Union
import numpy as np
import torch
from easydict import EasyDict
from grl.agents import obs_transform, action_transform
class QGPOAgent:
    """
    Overview:
        The agent for the QGPO algorithm.
    Interface:
        ``__init__``, ``act``
    """
    def __init__(
self,
config: EasyDict,
model: Union[torch.nn.Module, torch.nn.ModuleDict],
):
"""
Overview:
Initialize the agent.
Arguments:
config (:obj:`EasyDict`): The configuration.
model (:obj:`Union[torch.nn.Module, torch.nn.ModuleDict]`): The model.
"""
self.config = config
self.device = config.device
self.model = model.to(self.device)
        # Fall back to a guidance scale of 1.0 when the config does not set one.
        self.guidance_scale = getattr(self.config, "guidance_scale", 1.0)
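
        # A minimal sketch of the expected configuration, assuming only the
        # field names referenced in this class; the values below are
        # illustrative, not defaults shipped with the library:
        #
        #     config = EasyDict(
        #         device="cuda:0",      # torch device for the model
        #         t_span=32,            # number of sampling time steps, or None
        #         guidance_scale=1.0,   # optional; falls back to 1.0
        #     )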
    def act(
self,
obs: Union[np.ndarray, torch.Tensor, Dict],
return_as_torch_tensor: bool = False,
) -> Union[np.ndarray, torch.Tensor, Dict]:
"""
Overview:
Given an observation, return an action.
Arguments:
obs (:obj:`Union[np.ndarray, torch.Tensor, Dict]`): The observation.
return_as_torch_tensor (:obj:`bool`): Whether to return the action as a torch tensor.
Returns:
action (:obj:`Union[np.ndarray, torch.Tensor, Dict]`): The action.
"""
        # Convert the observation to a torch tensor on the configured device.
        obs = obs_transform(obs, self.device)
with torch.no_grad():
# ---------------------------------------
# Customized inference code ↓
# ---------------------------------------
            # Add a batch dimension before sampling from the policy.
            obs = obs.unsqueeze(0)
            action = (
                self.model["QGPOPolicy"]
                .sample(
                    state=obs,
                    # Build a uniform time grid for the sampler when t_span is
                    # configured; guard the attribute access as is done for
                    # guidance_scale so a config without t_span does not raise.
                    t_span=(
                        torch.linspace(0.0, 1.0, self.config.t_span).to(obs.device)
                        if getattr(self.config, "t_span", None) is not None
                        else None
                    ),
                    guidance_scale=self.guidance_scale,
                )
                .squeeze(0)
                .cpu()
                .detach()
                .numpy()
            )
# ---------------------------------------
# Customized inference code ↑
# ---------------------------------------
action = action_transform(action, return_as_torch_tensor)
return action
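
# A minimal usage sketch, assuming a trained ``torch.nn.ModuleDict`` with a
# "QGPOPolicy" entry exposing the ``sample`` signature used above; the config
# fields and the environment name are hypothetical:
#
#     import gym
#     from easydict import EasyDict
#     from grl.agents.qgpo import QGPOAgent
#
#     config = EasyDict(device="cuda:0", t_span=32, guidance_scale=1.0)
#     agent = QGPOAgent(config=config, model=model)  # model built/trained elsewhere
#
#     env = gym.make("HalfCheetah-v3")
#     obs = env.reset()
#     action = agent.act(obs)  # returns a numpy action by default
#     obs, reward, done, info = env.step(action)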