grl.algorithms

QGPOCritic

class grl.algorithms.QGPOCritic(config)[source]
Overview:

Critic network for QGPO algorithm.

Interfaces:

__init__, forward

__init__(config)[source]
Overview:

Initialization of QGPO critic network.

Parameters:

config (EasyDict) – The configuration dict.

compute_double_q(action, state=None)[source]
Overview:

Return the output of two Q networks.

Parameters:
  • action (Union[torch.Tensor, TensorDict]) – The input action.

  • state (Union[torch.Tensor, TensorDict]) – The input state.

Returns:
  • q1 (Union[torch.Tensor, TensorDict]) – The output of the first Q network.

  • q2 (Union[torch.Tensor, TensorDict]) – The output of the second Q network.

Return type:

Tuple[Union[torch.Tensor, TensorDict], Union[torch.Tensor, TensorDict]]

forward(action, state=None)[source]
Overview:

Return the output of QGPO critic.

Parameters:
  • action (torch.Tensor) – The input action.

  • state (torch.Tensor) – The input state.

Return type:

Tensor

q_loss(action, state, reward, next_state, done, fake_next_action, discount_factor=1.0)[source]
Overview:

Calculate the Q loss.

Parameters:
  • action (torch.Tensor) – The input action.

  • state (torch.Tensor) – The input state.

  • reward (torch.Tensor) – The input reward.

  • next_state (torch.Tensor) – The input next state.

  • done (torch.Tensor) – The input done.

  • fake_next_action (torch.Tensor) – The input fake next action.

  • discount_factor (float) – The discount factor.

Return type:

Tensor
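
Example (illustrative sketch): the snippet below shows how the interfaces above fit together, assuming a QGPOCritic instance named critic has already been built from a valid configuration; the tensor shapes are placeholders.

    import torch

    state = torch.randn(32, 17)    # placeholder batch of states (batch_size, state_dim)
    action = torch.randn(32, 6)    # placeholder batch of actions (batch_size, action_dim)

    q1, q2 = critic.compute_double_q(action, state)   # outputs of the two Q networks
    q = critic(action, state)                         # single Q estimate from forward()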

QGPOPolicy

class grl.algorithms.QGPOPolicy(config)[source]
Overview:

QGPO policy network.

Interfaces:

__init__, forward, sample, behaviour_policy_sample, compute_q, behaviour_policy_loss, energy_guidance_loss, q_loss

__init__(config)[source]

Overview:

Initialize the QGPO policy network.

Parameters:

config (EasyDict) – The configuration dict.

behaviour_policy_loss(action, state)[source]
Overview:

Calculate the behaviour policy loss.

Parameters:
  • action (torch.Tensor) – The input action.

  • state (torch.Tensor) – The input state.

behaviour_policy_sample(state, batch_size=None, solver_config=None, t_span=None)[source]
Overview:

Return the output of behaviour policy, which is the action conditioned on the state.

Parameters:
  • state (Union[torch.Tensor, TensorDict]) – The input state.

  • batch_size (Union[torch.Size, int, Tuple[int], List[int]]) – The batch size.

  • solver_config (EasyDict) – The configuration for the ODE solver.

  • t_span (torch.Tensor) – The time span for the ODE solver or SDE solver.

Returns:

The output action.

Return type:

action (Union[torch.Tensor, TensorDict])

compute_q(state, action)[source]
Overview:

Calculate the Q value.

Parameters:
  • state (Union[torch.Tensor, TensorDict]) – The input state.

  • action (Union[torch.Tensor, TensorDict]) – The input action.

Returns:

The Q value.

Return type:

q (torch.Tensor)

energy_guidance_loss(state, fake_next_action)[source]
Overview:

Calculate the energy guidance loss of QGPO.

Parameters:
  • state (Union[torch.Tensor, TensorDict]) – The input state.

  • fake_next_action (Union[torch.Tensor, TensorDict]) – The input fake next action.

Return type:

Tensor

forward(state)[source]
Overview:

Return the output of QGPO policy, which is the action conditioned on the state.

Parameters:

state (Union[torch.Tensor, TensorDict]) – The input state.

Returns:

The output action.

Return type:

action (Union[torch.Tensor, TensorDict])

q_loss(action, state, reward, next_state, done, fake_next_action, discount_factor=1.0)[source]
Overview:

Calculate the Q loss.

Parameters:
  • action (torch.Tensor) – The input action.

  • state (torch.Tensor) – The input state.

  • reward (torch.Tensor) – The input reward.

  • next_state (torch.Tensor) – The input next state.

  • done (torch.Tensor) – The input done.

  • fake_next_action (torch.Tensor) – The input fake next action.

  • discount_factor (float) – The discount factor.

Return type:

Tensor

sample(state, batch_size=None, guidance_scale=tensor(1.), solver_config=None, t_span=None)[source]
Overview:

Return the output of QGPO policy, which is the action conditioned on the state.

Parameters:
  • state (Union[torch.Tensor, TensorDict]) – The input state.

  • batch_size (Union[torch.Size, int, Tuple[int], List[int]]) – The batch size.

  • guidance_scale (Union[torch.Tensor, float]) – The guidance scale.

  • solver_config (EasyDict) – The configuration for the ODE solver.

  • t_span (torch.Tensor) – The time span for the ODE solver or SDE solver.

Returns:

The output action.

Return type:

action (Union[torch.Tensor, TensorDict])
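
Example (illustrative sketch): assuming policy is a trained QGPOPolicy and the tensor shapes are placeholders, sampling with and without energy guidance looks roughly as follows.

    import torch

    state = torch.randn(32, 17)    # placeholder batch of states

    # Unguided sample from the behaviour (diffusion) policy.
    behaviour_action = policy.behaviour_policy_sample(state)

    # Guided sample; a larger guidance_scale is intended to bias sampling toward high-Q actions.
    action = policy.sample(state, guidance_scale=torch.tensor(2.0))

    q = policy.compute_q(state, action)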

QGPOAlgorithm

class grl.algorithms.QGPOAlgorithm(config=None, simulator=None, dataset=None, model=None)[source]
Overview:

Q-guided policy optimization (QGPO) algorithm, which is an offline reinforcement learning algorithm that uses an energy-based diffusion model for policy modeling.

Interfaces:

__init__, train, deploy

__init__(config=None, simulator=None, dataset=None, model=None)[source]
Overview:

Initialize the QGPO algorithm.

Parameters:
  • config (EasyDict) – The configuration, which must contain the following keys: train (EasyDict), the training configuration; deploy (EasyDict), the deployment configuration.

  • simulator (object) – The environment simulator.

  • dataset (QGPODataset) – The dataset.

  • model (Union[torch.nn.Module, torch.nn.ModuleDict]) – The model.

Interface:

__init__, train, deploy

deploy(config=None)[source]
Overview:

Deploy the model using the given configuration.

Parameters:

config (EasyDict) – The deployment configuration.

Return type:

QGPOAgent

train(config=None)[source]
Overview:

Train the model using the given configuration. A Weights & Biases run will be created automatically when this function is called.

Parameters:

config (EasyDict) – The training configuration.
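
Example (illustrative sketch): a typical end-to-end use of QGPOAlgorithm, assuming config (an EasyDict with train and deploy sub-configs), env_simulator, and qgpo_dataset have been constructed elsewhere.

    from grl.algorithms import QGPOAlgorithm

    algorithm = QGPOAlgorithm(
        config=config,
        simulator=env_simulator,
        dataset=qgpo_dataset,
    )
    algorithm.train()            # creates a Weights & Biases run automatically
    agent = algorithm.deploy()   # returns a QGPOAgent for evaluation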

SRPOCritic

class grl.algorithms.SRPOCritic(config)[source]
Overview:

The critic network used in SRPO algorithm.

Interfaces:

__init__, forward, v_loss, q_loss

__init__(config)[source]
Overview:

Initialize the critic network.

Parameters:

config (EasyDict) – The configuration.

forward(action, state=None)[source]
Overview:

Return the output of critic.

Parameters:
  • action (torch.Tensor) – The input action.

  • state (torch.Tensor) – The input state.

Return type:

Tensor

SRPOPolicy

class grl.algorithms.SRPOPolicy(config)[source]
Overview:

The SRPO policy network.

Interfaces:

__init__, forward, sample, behaviour_policy_loss, srpo_actor_loss

__init__(config)[source]
Overview:

Initialize the SRPO policy network.

Parameters:

config (EasyDict) – The configuration.

behaviour_policy_loss(action, state)[source]
Overview:

Calculate the behaviour policy loss.

Parameters:
  • action (torch.Tensor) – The input action.

  • state (torch.Tensor) – The input state.

forward(state)[source]
Overview:

Return the output of SRPO policy, which is the action conditioned on the state.

Parameters:

state (Union[torch.Tensor, TensorDict]) – The input state.

Returns:

The output action.

Return type:

action (Union[torch.Tensor, TensorDict])

sample(state, batch_size=None, solver_config=None, t_span=None)[source]
Overview:

Return the output of SRPO policy, which is the action conditioned on the state.

Parameters:
  • state (Union[torch.Tensor, TensorDict]) – The input state.

  • batch_size (Union[torch.Size, int, Tuple[int], List[int]]) – The batch size.

  • solver_config (EasyDict) – The configuration for the ODE solver.

  • t_span (torch.Tensor) – The time span for the ODE solver or SDE solver.

Returns:

The output action.

Return type:

action (Union[torch.Tensor, TensorDict])

srpo_actor_loss(state)[source]
Overview:

Calculate the SRPO actor loss.

Parameters:

state (torch.Tensor) – The input state.

Return type:

Tensor
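
Example (illustrative sketch): the two SRPOPolicy training losses used together on an offline batch; the policy instance and tensor shapes are placeholders.

    import torch

    state = torch.randn(256, 17)    # placeholder batch of states
    action = torch.randn(256, 6)    # placeholder batch of dataset actions

    # Fit the diffusion behaviour model to the dataset actions.
    bc_loss = policy.behaviour_policy_loss(action, state)

    # Actor loss computed from states only; the actor proposes the actions internally.
    actor_loss = policy.srpo_actor_loss(state)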

SRPOAlgorithm

class grl.algorithms.SRPOAlgorithm(config=None, simulator=None, dataset=None, model=None)[source]
__init__(config=None, simulator=None, dataset=None, model=None)[source]
Overview:

Initialize the SRPO algorithm.

Parameters:
  • config (EasyDict) – The configuration, which must contain the following keys: train (EasyDict), the training configuration; deploy (EasyDict), the deployment configuration.

  • simulator (object) – The environment simulator.

  • dataset (Dataset) – The dataset.

  • model (Union[torch.nn.Module, torch.nn.ModuleDict]) – The model.

Interface:

__init__, train, deploy

deploy(config=None)[source]
Overview:

Deploy the model using the given configuration.

Parameters:

config (EasyDict) – The deployment configuration.

Return type:

SRPOAgent

train(config=None)[source]
Overview:

Train the model using the given configuration. A Weights & Biases run will be created automatically when this function is called.

Parameters:

config (EasyDict) – The training configuration.

GMPOCritic

class grl.algorithms.GMPOCritic(config)[source]
Overview:

Critic network for GMPO algorithm.

Interfaces:

__init__, forward

__init__(config)[source]
Overview:

Initialization of GMPO critic network.

Parameters:

config (EasyDict) – The configuration dict.

compute_double_q(action, state=None)[source]
Overview:

Return the output of two Q networks.

Parameters:
  • action (Union[torch.Tensor, TensorDict]) – The input action.

  • state (Union[torch.Tensor, TensorDict]) – The input state.

Returns:
  • q1 (Union[torch.Tensor, TensorDict]) – The output of the first Q network.

  • q2 (Union[torch.Tensor, TensorDict]) – The output of the second Q network.

Return type:

Tuple[Union[torch.Tensor, TensorDict], Union[torch.Tensor, TensorDict]]

forward(action, state=None)[source]
Overview:

Return the output of GMPO critic.

Parameters:
  • action (torch.Tensor) – The input action.

  • state (torch.Tensor) – The input state.

Return type:

Tensor

GMPOPolicy

class grl.algorithms.GMPOPolicy(config)[source]
Overview:

GMPO policy network for GMPO algorithm, which includes the base model (optional), the guided model, and the critic.

Interfaces:

__init__, forward, sample, compute_q, behaviour_policy_loss, policy_optimization_loss_by_advantage_weighted_regression, policy_optimization_loss_by_advantage_weighted_regression_softmax

__init__(config)[source]
Overview:

Initialize the GMPO policy network.

Parameters:

config (EasyDict) – The configuration dict.

behaviour_policy_loss(action, state, maximum_likelihood=False)[source]
Overview:

Calculate the behaviour policy loss.

Parameters:
  • action (torch.Tensor) – The input action.

  • state (torch.Tensor) – The input state.

  • maximum_likelihood (bool) – Whether to use the maximum likelihood loss.

behaviour_policy_sample(state, batch_size=None, solver_config=None, t_span=None, with_grad=False)[source]
Overview:

Return the output of behaviour policy, which is the action conditioned on the state.

Parameters:
  • state (Union[torch.Tensor, TensorDict]) – The input state.

  • batch_size (Union[torch.Size, int, Tuple[int], List[int]]) – The batch size.

  • solver_config (EasyDict) – The configuration for the ODE solver.

  • t_span (torch.Tensor) – The time span for the ODE solver or SDE solver.

  • with_grad (bool) – Whether to calculate the gradient.

Returns:

The output action.

Return type:

action (Union[torch.Tensor, TensorDict])

compute_q(state, action)[source]
Overview:

Calculate the Q value.

Parameters:
  • state (Union[torch.Tensor, TensorDict]) – The input state.

  • action (Union[torch.Tensor, TensorDict]) – The input action.

Returns:

The Q value.

Return type:

q (torch.Tensor)

forward(state)[source]
Overview:

Return the output of GMPO policy, which is the action conditioned on the state.

Parameters:

state (Union[torch.Tensor, TensorDict]) – The input state.

Returns:

The output action.

Return type:

action (Union[torch.Tensor, TensorDict])

policy_optimization_loss_by_advantage_weighted_regression(action, state, maximum_likelihood=False, beta=1.0, weight_clamp=100.0)[source]
Overview:

Calculate the policy optimization loss using advantage-weighted regression.

Parameters:
  • action (torch.Tensor) – The input action.

  • state (torch.Tensor) – The input state.

  • maximum_likelihood (bool) – Whether to use the maximum likelihood loss.

  • beta (float) – The temperature coefficient used in the advantage weighting.

  • weight_clamp (float) – The clamp value applied to the advantage weights.

policy_optimization_loss_by_advantage_weighted_regression_softmax(state, fake_action, maximum_likelihood=False, beta=1.0)[source]
Overview:

Calculate the policy optimization loss using advantage-weighted regression with a softmax weighting over candidate actions.

Parameters:
  • state (torch.Tensor) – The input state.

  • fake_action (torch.Tensor) – The candidate (fake) actions for each state.

  • maximum_likelihood (bool) – Whether to use the maximum likelihood loss.

  • beta (float) – The temperature coefficient used in the softmax weighting.

sample(state, batch_size=None, solver_config=None, t_span=None, with_grad=False)[source]
Overview:

Return the output of GMPO policy, which is the action conditioned on the state.

Parameters:
  • state (Union[torch.Tensor, TensorDict]) – The input state.

  • batch_size (Union[torch.Size, int, Tuple[int], List[int]]) – The batch size.

  • solver_config (EasyDict) – The configuration for the ODE solver.

  • t_span (torch.Tensor) – The time span for the ODE solver or SDE solver.

  • with_grad (bool) – Whether to calculate the gradient.

Returns:

The output action.

Return type:

action (Union[torch.Tensor, TensorDict])
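
Example (illustrative sketch): how the GMPOPolicy losses above might be combined during training; the policy instance and all tensor shapes are placeholders, and the layout of fake_action (candidate actions per state) is an assumption that depends on the training pipeline.

    import torch

    state = torch.randn(256, 17)    # placeholder batch of states
    action = torch.randn(256, 6)    # placeholder batch of dataset actions

    # Pretrain the behaviour (base) model on dataset actions.
    bc_loss = policy.behaviour_policy_loss(action, state, maximum_likelihood=True)

    # Advantage-weighted regression on dataset actions.
    awr_loss = policy.policy_optimization_loss_by_advantage_weighted_regression(
        action, state, beta=1.0, weight_clamp=100.0
    )

    # Softmax variant over candidate ("fake") actions; placeholder shape (batch, candidates, action_dim).
    fake_action = torch.randn(256, 16, 6)
    awr_softmax_loss = policy.policy_optimization_loss_by_advantage_weighted_regression_softmax(
        state, fake_action, beta=1.0
    )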

GMPOAlgorithm

class grl.algorithms.GMPOAlgorithm(config=None, simulator=None, dataset=None, model=None, seed=None)[source]
Overview:

The Generative Model Policy Optimization (GMPO) algorithm.

Interfaces:

__init__, train, deploy

__init__(config=None, simulator=None, dataset=None, model=None, seed=None)[source]
Overview:

Initialize the GMPO algorithm.

Parameters:
  • config (EasyDict) – The configuration, which must contain the following keys: train (EasyDict), the training configuration; deploy (EasyDict), the deployment configuration.

  • simulator (object) – The environment simulator.

  • dataset (GPDataset) – The dataset.

  • model (Union[torch.nn.Module, torch.nn.ModuleDict]) – The model.

Interface:

__init__, train, deploy

train(config=None, seed=None)[source]
Overview:

Train the model using the given configuration. A Weights & Biases run will be created automatically when this function is called.

Parameters:
  • config (EasyDict) – The training configuration.

  • seed (int) – The random seed.
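
Example (illustrative sketch): training GMPOAlgorithm with a fixed seed, assuming config, env_simulator, and gp_dataset are defined elsewhere.

    from grl.algorithms import GMPOAlgorithm

    algorithm = GMPOAlgorithm(
        config=config,
        simulator=env_simulator,
        dataset=gp_dataset,
        seed=0,
    )
    algorithm.train(seed=0)   # creates a Weights & Biases run automatically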

GMPGCritic

class grl.algorithms.GMPGCritic(config)[source]
Overview:

Critic network for GMPG algorithm.

Interfaces:

__init__, forward

__init__(config)[source]
Overview:

Initialization of GMPG critic network.

Parameters:

config (EasyDict) – The configuration dict.

compute_double_q(action, state=None)[source]
Overview:

Return the output of two Q networks.

Parameters:
  • action (Union[torch.Tensor, TensorDict]) – The input action.

  • state (Union[torch.Tensor, TensorDict]) – The input state.

Returns:
  • q1 (Union[torch.Tensor, TensorDict]) – The output of the first Q network.

  • q2 (Union[torch.Tensor, TensorDict]) – The output of the second Q network.

Return type:

Tuple[Union[torch.Tensor, TensorDict], Union[torch.Tensor, TensorDict]]

forward(action, state=None)[source]
Overview:

Return the output of GMPG critic.

Parameters:
  • action (torch.Tensor) – The input action.

  • state (torch.Tensor) – The input state.

Return type:

Tensor

in_support_ql_loss(action, state, reward, next_state, done, fake_next_action, discount_factor=1.0)[source]
Overview:

Calculate the Q loss.

Parameters:
  • action (torch.Tensor) – The input action.

  • state (torch.Tensor) – The input state.

  • reward (torch.Tensor) – The input reward.

  • next_state (torch.Tensor) – The input next state.

  • done (torch.Tensor) – The input done.

  • fake_next_action (torch.Tensor) – The input fake next action.

  • discount_factor (float) – The discount factor.

Return type:

Tensor
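
Example (illustrative sketch): assembling the inputs of in_support_ql_loss; the critic instance and all shapes are placeholders, and fake_next_action is assumed to hold several candidate actions per next state.

    import torch

    batch_size = 256
    state = torch.randn(batch_size, 17)
    action = torch.randn(batch_size, 6)
    reward = torch.randn(batch_size, 1)
    next_state = torch.randn(batch_size, 17)
    done = torch.zeros(batch_size, 1)
    fake_next_action = torch.randn(batch_size, 16, 6)   # 16 candidate actions per next state

    loss = critic.in_support_ql_loss(
        action, state, reward, next_state, done, fake_next_action, discount_factor=0.99
    )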

GMPGPolicy

class grl.algorithms.GMPGPolicy(config)[source]
__init__(config)[source]

Overview:

Initialize the GMPG policy network.

Parameters:

config (EasyDict) – The configuration dict.

behaviour_policy_loss(action, state, maximum_likelihood=False)[source]
Overview:

Calculate the behaviour policy loss.

Parameters:
  • action (torch.Tensor) – The input action.

  • state (torch.Tensor) – The input state.

  • maximum_likelihood (bool) – Whether to use the maximum likelihood loss.

behaviour_policy_sample(state, batch_size=None, solver_config=None, t_span=None, with_grad=False)[source]
Overview:

Return the output of behaviour policy, which is the action conditioned on the state.

Parameters:
  • state (Union[torch.Tensor, TensorDict]) – The input state.

  • batch_size (Union[torch.Size, int, Tuple[int], List[int]]) – The batch size.

  • solver_config (EasyDict) – The configuration for the ODE solver.

  • t_span (torch.Tensor) – The time span for the ODE solver or SDE solver.

  • with_grad (bool) – Whether to calculate the gradient.

Returns:

The output action.

Return type:

action (Union[torch.Tensor, TensorDict])

compute_q(state, action)[source]
Overview:

Calculate the Q value.

Parameters:
  • state (Union[torch.Tensor, TensorDict]) – The input state.

  • action (Union[torch.Tensor, TensorDict]) – The input action.

Returns:

The Q value.

Return type:

q (torch.Tensor)

forward(state)[source]
Overview:

Return the output of GMPG policy, which is the action conditioned on the state.

Parameters:

state (Union[torch.Tensor, TensorDict]) – The input state.

Returns:

The output action.

Return type:

action (Union[torch.Tensor, TensorDict])

sample(state, batch_size=None, solver_config=None, t_span=None, with_grad=False)[source]
Overview:

Return the output of GMPG policy, which is the action conditioned on the state.

Parameters:
  • state (Union[torch.Tensor, TensorDict]) – The input state.

  • batch_size (Union[torch.Size, int, Tuple[int], List[int]]) – The batch size.

  • solver_config (EasyDict) – The configuration for the ODE solver.

  • t_span (torch.Tensor) – The time span for the ODE solver or SDE solver.

  • with_grad (bool) – Whether to calculate the gradient.

Returns:

The output action.

Return type:

action (Union[torch.Tensor, TensorDict])
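
Example (illustrative sketch): sampling from GMPGPolicy with and without gradient tracking; the policy instance and shapes are placeholders. With with_grad=True the sampling path is assumed to stay differentiable, so a Q-based objective can be backpropagated through the generated action.

    import torch

    state = torch.randn(32, 17)    # placeholder batch of states

    # Plain inference: sample actions without tracking gradients.
    action = policy.sample(state)

    # Differentiable sampling for policy-gradient-style updates.
    action = policy.sample(state, with_grad=True)
    loss = -policy.compute_q(state, action).mean()
    loss.backward()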

GMPGAlgorithm

class grl.algorithms.GMPGAlgorithm(config=None, simulator=None, dataset=None, model=None, seed=None)[source]
Overview:

The Generative Model Policy Gradient (GMPG) algorithm.

Interfaces:

__init__, train, deploy

__init__(config=None, simulator=None, dataset=None, model=None, seed=None)[source]
Overview:

Initialize the GMPG algorithm.

Parameters:
  • config (EasyDict) – The configuration, which must contain the following keys: train (EasyDict), the training configuration; deploy (EasyDict), the deployment configuration.

  • simulator (object) – The environment simulator.

  • dataset (GPDataset) – The dataset.

  • model (Union[torch.nn.Module, torch.nn.ModuleDict]) – The model.

Interface:

__init__, train, deploy

train(config=None, seed=None)[source]
Overview:

Train the model using the given configuration. A Weights & Biases run will be created automatically when this function is called.

Parameters:
  • config (EasyDict) – The training configuration.

  • seed (int) – The random seed.