grl.algorithms
QGPOCritic
- class grl.algorithms.QGPOCritic(config)[source]
- Overview:
  Critic network for the QGPO algorithm.
- Interfaces:
  __init__, forward
- __init__(config)[source]
- Overview:
  Initialization of the QGPO critic network.
- Parameters:
  config (EasyDict) – The configuration dict.
- compute_double_q(action, state=None)[source]
- Overview:
  Return the output of two Q networks.
- Parameters:
  action (Union[torch.Tensor, TensorDict]) – The input action.
  state (Union[torch.Tensor, TensorDict]) – The input state.
- Returns:
  q1 (Union[torch.Tensor, TensorDict]): The output of the first Q network.
  q2 (Union[torch.Tensor, TensorDict]): The output of the second Q network.
- forward(action, state=None)[source]
- Overview:
  Return the output of the QGPO critic.
- Parameters:
  action (torch.Tensor) – The input action.
  state (torch.Tensor) – The input state.
- Return type:
  torch.Tensor
- q_loss(action, state, reward, next_state, done, fake_next_action, discount_factor=1.0)[source]
- Overview:
  Calculate the Q loss.
- Parameters:
  action (torch.Tensor) – The input action.
  state (torch.Tensor) – The input state.
  reward (torch.Tensor) – The input reward.
  next_state (torch.Tensor) – The input next state.
  done (torch.Tensor) – The input done flag.
  fake_next_action (torch.Tensor) – The input fake next action.
  discount_factor (float) – The discount factor.
- Return type:
  torch.Tensor
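The page documents only the interface of q_loss, but its arguments follow the shape of a standard clipped double-Q temporal-difference update evaluated over the sampled fake_next_action candidates. Below is a minimal, self-contained sketch of that idea with stand-in networks and dummy tensors; it is an illustration of the concept, not the library's implementation, and the final reduction over candidates (max, mean, or softmax-weighted) depends on the configuration.

    import torch
    import torch.nn.functional as F

    # Stand-ins for the critic's two Q networks (hypothetical dimensions).
    state_dim, action_dim = 6, 3
    q1_net = torch.nn.Linear(state_dim + action_dim, 1)
    q2_net = torch.nn.Linear(state_dim + action_dim, 1)

    def compute_double_q(action, state):
        x = torch.cat([action, state], dim=-1)
        return q1_net(x), q2_net(x)

    # Dummy batch: B transitions, each with K candidate "fake" next actions.
    B, K = 32, 16
    state = torch.randn(B, state_dim)
    action = torch.randn(B, action_dim)
    reward = torch.randn(B, 1)
    done = torch.zeros(B, 1)
    next_state = torch.randn(B, state_dim)
    fake_next_action = torch.randn(B, K, action_dim)
    discount_factor = 0.99

    with torch.no_grad():
        # Evaluate both Q networks on every candidate next action and take the
        # element-wise minimum (clipped double Q) before reducing over candidates.
        ns = next_state.unsqueeze(1).expand(-1, K, -1)
        q1_next, q2_next = compute_double_q(fake_next_action, ns)
        q_next = torch.min(q1_next, q2_next).max(dim=1).values        # (B, 1)
        td_target = reward + discount_factor * (1.0 - done) * q_next

    q1, q2 = compute_double_q(action, state)
    q_loss = F.mse_loss(q1, td_target) + F.mse_loss(q2, td_target)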
QGPOPolicy
- class grl.algorithms.QGPOPolicy(config)[source]
- Overview:
  QGPO policy network.
- Interfaces:
  __init__, forward, sample, behaviour_policy_sample, compute_q, behaviour_policy_loss, energy_guidance_loss, q_loss
- __init__(config)[source]
- Overview:
  Initialize the QGPO policy network.
- Parameters:
  config (EasyDict) – The configuration dict.
- behaviour_policy_loss(action, state)[source]
- Overview:
  Calculate the behaviour policy loss.
- Parameters:
  action (torch.Tensor) – The input action.
  state (torch.Tensor) – The input state.
- behaviour_policy_sample(state, batch_size=None, solver_config=None, t_span=None)[source]
- Overview:
  Return the output of the behaviour policy, which is the action conditioned on the state.
- Parameters:
  state (Union[torch.Tensor, TensorDict]) – The input state.
  batch_size (Union[torch.Size, int, Tuple[int], List[int]]) – The batch size.
  solver_config (EasyDict) – The configuration for the ODE solver.
  t_span (torch.Tensor) – The time span for the ODE solver or SDE solver.
- Returns:
  The output action.
- Return type:
  action (Union[torch.Tensor, TensorDict])
- compute_q(state, action)[source]
- Overview:
  Calculate the Q value.
- Parameters:
  state (Union[torch.Tensor, TensorDict]) – The input state.
  action (Union[torch.Tensor, TensorDict]) – The input action.
- Returns:
  The Q value.
- Return type:
  q (torch.Tensor)
- energy_guidance_loss(state, fake_next_action)[source]
- Overview:
  Calculate the energy guidance loss of QGPO.
- Parameters:
  state (Union[torch.Tensor, TensorDict]) – The input state.
  fake_next_action (Union[torch.Tensor, TensorDict]) – The input fake next action.
- Return type:
  torch.Tensor
- forward(state)[source]
- Overview:
  Return the output of the QGPO policy, which is the action conditioned on the state.
- Parameters:
  state (Union[torch.Tensor, TensorDict]) – The input state.
- Returns:
  The output action.
- Return type:
  action (Union[torch.Tensor, TensorDict])
- q_loss(action, state, reward, next_state, done, fake_next_action, discount_factor=1.0)[source]
- Overview:
  Calculate the Q loss.
- Parameters:
  action (torch.Tensor) – The input action.
  state (torch.Tensor) – The input state.
  reward (torch.Tensor) – The input reward.
  next_state (torch.Tensor) – The input next state.
  done (torch.Tensor) – The input done flag.
  fake_next_action (torch.Tensor) – The input fake next action.
  discount_factor (float) – The discount factor.
- Return type:
  torch.Tensor
- sample(state, batch_size=None, guidance_scale=tensor(1.), solver_config=None, t_span=None)[source]
- Overview:
  Return the output of the QGPO policy, which is the action conditioned on the state.
- Parameters:
  state (Union[torch.Tensor, TensorDict]) – The input state.
  batch_size (Union[torch.Size, int, Tuple[int], List[int]]) – The batch size.
  guidance_scale (Union[torch.Tensor, float]) – The guidance scale.
  solver_config (EasyDict) – The configuration for the ODE solver.
  t_span (torch.Tensor) – The time span for the ODE solver or SDE solver.
- Returns:
  The output action.
- Return type:
  action (Union[torch.Tensor, TensorDict])
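Taken together, the losses above correspond to the separate training stages of QGPO (behaviour-policy modelling, energy guidance, and Q-learning), followed by guided sampling at evaluation time. The following is a hedged sketch of how these documented interfaces could be wired together; policy and batch are assumed to already exist, and their construction depends on configuration that is not shown on this page.

    import torch

    def qgpo_training_losses(policy, batch, discount: float = 0.99):
        """Sketch of one QGPO training step using the documented QGPOPolicy interfaces."""
        # 1) Behaviour-policy (diffusion model) loss on the offline data.
        bc_loss = policy.behaviour_policy_loss(batch["action"], batch["state"])
        # 2) Energy-guidance loss over candidate ("fake") next actions.
        eg_loss = policy.energy_guidance_loss(batch["state"], batch["fake_next_action"])
        # 3) Q-function loss.
        q_loss = policy.q_loss(
            batch["action"], batch["state"], batch["reward"],
            batch["next_state"], batch["done"], batch["fake_next_action"],
            discount_factor=discount,
        )
        return bc_loss, eg_loss, q_loss

    def qgpo_act(policy, state):
        # Guided sampling: a larger guidance_scale biases samples toward high-Q actions.
        return policy.sample(state, guidance_scale=torch.tensor(1.5))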
QGPOAlgorithm
- class grl.algorithms.QGPOAlgorithm(config=None, simulator=None, dataset=None, model=None)[source]
- Overview:
  Q-guided policy optimization (QGPO), an offline reinforcement learning algorithm that uses an energy-based diffusion model for policy modeling.
- Interfaces:
  __init__, train, deploy
- __init__(config=None, simulator=None, dataset=None, model=None)[source]
- Overview:
  Initialize the QGPO algorithm.
- Parameters:
  config (EasyDict) – The configuration, which must contain the following keys:
    train (EasyDict): The training configuration.
    deploy (EasyDict): The deployment configuration.
  simulator (object) – The environment simulator.
  dataset (QGPODataset) – The dataset.
  model (Union[torch.nn.Module, torch.nn.ModuleDict]) – The model.
- Interface:
  __init__, train, deploy
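At the algorithm level the documented entry points are just the constructor plus train and deploy. Below is a hedged end-to-end sketch; only the presence of the train and deploy sub-configurations is documented here, so their fields are left empty and a real experiment would have to fill them in (and would usually also pass a simulator, dataset, or model).

    from easydict import EasyDict
    from grl.algorithms import QGPOAlgorithm

    # Only the `train` and `deploy` keys are documented on this page; their
    # contents depend on the environment, dataset, and model being used.
    config = EasyDict(
        train=EasyDict(),   # training configuration (fields omitted here)
        deploy=EasyDict(),  # deployment configuration (fields omitted here)
    )

    algo = QGPOAlgorithm(config=config)
    algo.train()    # offline training
    algo.deploy()   # deployment; the return value is not documented on this page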
SRPOCritic
SRPOPolicy
- class grl.algorithms.SRPOPolicy(config)[source]
- Overview:
  The SRPO policy network.
- Interfaces:
  __init__, forward, sample, behaviour_policy_loss, srpo_actor_loss
- __init__(config)[source]
- Overview:
  Initialize the SRPO policy network.
- Parameters:
  config (EasyDict) – The configuration.
- behaviour_policy_loss(action, state)[source]
- Overview:
  Calculate the behaviour policy loss.
- Parameters:
  action (torch.Tensor) – The input action.
  state (torch.Tensor) – The input state.
- forward(state)[source]
- Overview:
  Return the output of the SRPO policy, which is the action conditioned on the state.
- Parameters:
  state (Union[torch.Tensor, TensorDict]) – The input state.
- Returns:
  The output action.
- Return type:
  action (Union[torch.Tensor, TensorDict])
- sample(state, batch_size=None, solver_config=None, t_span=None)[source]
- Overview:
  Return the output of the SRPO policy, which is the action conditioned on the state.
- Parameters:
  state (Union[torch.Tensor, TensorDict]) – The input state.
  batch_size (Union[torch.Size, int, Tuple[int], List[int]]) – The batch size.
  solver_config (EasyDict) – The configuration for the ODE solver.
  t_span (torch.Tensor) – The time span for the ODE solver or SDE solver.
- Returns:
  The output action.
- Return type:
  action (Union[torch.Tensor, TensorDict])
- srpo_actor_loss(state)[source]
- Overview:
  Calculate the SRPO actor loss.
- Parameters:
  state (torch.Tensor) – The input state.
- Return type:
  torch.Tensor
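SRPO training thus reduces to two documented losses: a behaviour-policy loss fitted on the offline data and an actor loss computed from the state alone. The following is a hedged sketch of one update step using those interfaces; policy and batch are assumed to be constructed elsewhere, and the weighting between the two terms is configuration-dependent.

    def srpo_training_losses(policy, batch):
        """Sketch of one SRPO training step using the documented SRPOPolicy interfaces."""
        # Behaviour-policy (generative model) loss on logged state-action pairs.
        bc_loss = policy.behaviour_policy_loss(batch["action"], batch["state"])
        # Actor loss, computed from the state via srpo_actor_loss.
        actor_loss = policy.srpo_actor_loss(batch["state"])
        return bc_loss, actor_loss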
SRPOAlgorithm
- class grl.algorithms.SRPOAlgorithm(config=None, simulator=None, dataset=None, model=None)[source]
- __init__(config=None, simulator=None, dataset=None, model=None)[source]
- Overview:
  Initialize the SRPO algorithm.
- Parameters:
  config (EasyDict) – The configuration, which must contain the following keys:
    train (EasyDict): The training configuration.
    deploy (EasyDict): The deployment configuration.
  simulator (object) – The environment simulator.
  dataset (Dataset) – The dataset.
  model (Union[torch.nn.Module, torch.nn.ModuleDict]) – The model.
- Interface:
  __init__, train, deploy
GMPOCritic
- class grl.algorithms.GMPOCritic(config)[source]
- Overview:
  Critic network for the GMPO algorithm.
- Interfaces:
  __init__, forward
- __init__(config)[source]
- Overview:
  Initialization of the GMPO critic network.
- Parameters:
  config (EasyDict) – The configuration dict.
- compute_double_q(action, state=None)[source]
- Overview:
  Return the output of two Q networks.
- Parameters:
  action (Union[torch.Tensor, TensorDict]) – The input action.
  state (Union[torch.Tensor, TensorDict]) – The input state.
- Returns:
  q1 (Union[torch.Tensor, TensorDict]): The output of the first Q network.
  q2 (Union[torch.Tensor, TensorDict]): The output of the second Q network.
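compute_double_q exposes both Q estimates; a common way to consume them downstream is the element-wise minimum (the clipped double-Q estimate). The standalone illustration below uses a toy module with the same interface purely to show the call pattern; it is not the library's critic, and the dimensions are hypothetical.

    import torch

    class ToyDoubleQ(torch.nn.Module):
        """Stand-in exposing a compute_double_q interface, for illustration only."""
        def __init__(self, state_dim: int = 6, action_dim: int = 3):
            super().__init__()
            self.q1 = torch.nn.Linear(state_dim + action_dim, 1)
            self.q2 = torch.nn.Linear(state_dim + action_dim, 1)

        def compute_double_q(self, action, state=None):
            x = torch.cat([action, state], dim=-1)
            return self.q1(x), self.q2(x)

    critic = ToyDoubleQ()
    action, state = torch.randn(32, 3), torch.randn(32, 6)
    q1, q2 = critic.compute_double_q(action, state)
    q_conservative = torch.min(q1, q2)  # pessimistic estimate used by double-Q methods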
GMPOPolicy
- class grl.algorithms.GMPOPolicy(config)[source]
- Overview:
  GMPO policy network, which includes the base model (optional), the guided model, and the critic.
- Interfaces:
  __init__, forward, sample, compute_q, behaviour_policy_loss, policy_optimization_loss_by_advantage_weighted_regression, policy_optimization_loss_by_advantage_weighted_regression_softmax
- __init__(config)[source]
- Overview:
  Initialize the GMPO policy network.
- Parameters:
  config (EasyDict) – The configuration dict.
- behaviour_policy_loss(action, state, maximum_likelihood=False)[source]
- Overview:
  Calculate the behaviour policy loss.
- Parameters:
  action (torch.Tensor) – The input action.
  state (torch.Tensor) – The input state.
  maximum_likelihood (bool) – Whether to use the maximum likelihood form of the loss.
- behaviour_policy_sample(state, batch_size=None, solver_config=None, t_span=None, with_grad=False)[source]
- Overview:
  Return the output of the behaviour policy, which is the action conditioned on the state.
- Parameters:
  state (Union[torch.Tensor, TensorDict]) – The input state.
  batch_size (Union[torch.Size, int, Tuple[int], List[int]]) – The batch size.
  solver_config (EasyDict) – The configuration for the ODE solver.
  t_span (torch.Tensor) – The time span for the ODE solver or SDE solver.
  with_grad (bool) – Whether to calculate the gradient.
- Returns:
  The output action.
- Return type:
  action (Union[torch.Tensor, TensorDict])
- compute_q(state, action)[source]
- Overview:
  Calculate the Q value.
- Parameters:
  state (Union[torch.Tensor, TensorDict]) – The input state.
  action (Union[torch.Tensor, TensorDict]) – The input action.
- Returns:
  The Q value.
- Return type:
  q (torch.Tensor)
- forward(state)[source]
- Overview:
  Return the output of the GMPO policy, which is the action conditioned on the state.
- Parameters:
  state (Union[torch.Tensor, TensorDict]) – The input state.
- Returns:
  The output action.
- Return type:
  action (Union[torch.Tensor, TensorDict])
- policy_optimization_loss_by_advantage_weighted_regression(action, state, maximum_likelihood=False, beta=1.0, weight_clamp=100.0)[source]
- Overview:
  Calculate the policy optimization loss via advantage-weighted regression.
- Parameters:
  action (torch.Tensor) – The input action.
  state (torch.Tensor) – The input state.
  maximum_likelihood (bool) – Whether to use the maximum likelihood form of the loss.
  beta (float) – The beta coefficient for the advantage weighting.
  weight_clamp (float) – The clamp value for the advantage weights.
- policy_optimization_loss_by_advantage_weighted_regression_softmax(state, fake_action, maximum_likelihood=False, beta=1.0)[source]
- Overview:
  Calculate the policy optimization loss via advantage-weighted regression (softmax variant).
- Parameters:
  state (torch.Tensor) – The input state.
  fake_action (torch.Tensor) – The input fake action.
  maximum_likelihood (bool) – Whether to use the maximum likelihood form of the loss.
  beta (float) – The beta coefficient for the advantage weighting.
- sample(state, batch_size=None, solver_config=None, t_span=None, with_grad=False)[source]
- Overview:
  Return the output of the GMPO policy, which is the action conditioned on the state.
- Parameters:
  state (Union[torch.Tensor, TensorDict]) – The input state.
  batch_size (Union[torch.Size, int, Tuple[int], List[int]]) – The batch size.
  solver_config (EasyDict) – The configuration for the ODE solver.
  t_span (torch.Tensor) – The time span for the ODE solver or SDE solver.
  with_grad (bool) – Whether to calculate the gradient.
- Returns:
  The output action.
- Return type:
  action (Union[torch.Tensor, TensorDict])
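The two advantage-weighted regression losses above are documented only by their signatures. Conceptually, AWR-style objectives weight a likelihood (or log-likelihood) term by an exponentiated, clamped advantage, which is what the beta and weight_clamp arguments suggest. The snippet below is a self-contained illustration of that weighting scheme, not the loss implemented inside GMPOPolicy; the q/v inputs are stand-ins for a Q estimate and a baseline value.

    import torch

    def awr_weights(q_value, v_value, beta: float = 1.0, weight_clamp: float = 100.0):
        """Exponentiated-advantage weights in the style of advantage-weighted regression."""
        advantage = q_value - v_value
        weights = torch.exp(beta * advantage).clamp(max=weight_clamp)
        return weights.detach()  # weights are treated as constants w.r.t. the policy

    # The weighted objective is then roughly -(weights * log_prob).mean(), where
    # log_prob is the policy's (approximate) log-likelihood of the logged action.
    q = torch.randn(32, 1)
    v = torch.randn(32, 1)
    w = awr_weights(q, v, beta=1.0, weight_clamp=100.0)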
GMPOAlgorithm
- class grl.algorithms.GMPOAlgorithm(config=None, simulator=None, dataset=None, model=None, seed=None)[source]
- Overview:
  The Generative Model Policy Optimization (GMPO) algorithm.
- Interfaces:
  __init__, train, deploy
- __init__(config=None, simulator=None, dataset=None, model=None, seed=None)[source]
- Overview:
  Initialize the GMPO / GPG algorithm.
- Parameters:
  config (EasyDict) – The configuration, which must contain the following keys:
    train (EasyDict): The training configuration.
    deploy (EasyDict): The deployment configuration.
  simulator (object) – The environment simulator.
  dataset (GPDataset) – The dataset.
  model (Union[torch.nn.Module, torch.nn.ModuleDict]) – The model.
- Interface:
  __init__, train, deploy
GMPGCritic
- class grl.algorithms.GMPGCritic(config)[source]
- Overview:
  Critic network.
- Interfaces:
  __init__, forward
- __init__(config)[source]
- Overview:
  Initialization of the GPO critic network.
- Parameters:
  config (EasyDict) – The configuration dict.
- compute_double_q(action, state=None)[source]
- Overview:
  Return the output of two Q networks.
- Parameters:
  action (Union[torch.Tensor, TensorDict]) – The input action.
  state (Union[torch.Tensor, TensorDict]) – The input state.
- Returns:
  q1 (Union[torch.Tensor, TensorDict]): The output of the first Q network.
  q2 (Union[torch.Tensor, TensorDict]): The output of the second Q network.
- forward(action, state=None)[source]
- Overview:
  Return the output of the GPO critic.
- Parameters:
  action (torch.Tensor) – The input action.
  state (torch.Tensor) – The input state.
- Return type:
  torch.Tensor
- in_support_ql_loss(action, state, reward, next_state, done, fake_next_action, discount_factor=1.0)[source]
- Overview:
  Calculate the Q loss.
- Parameters:
  action (torch.Tensor) – The input action.
  state (torch.Tensor) – The input state.
  reward (torch.Tensor) – The input reward.
  next_state (torch.Tensor) – The input next state.
  done (torch.Tensor) – The input done flag.
  fake_next_action (torch.Tensor) – The input fake next action.
  discount_factor (float) – The discount factor.
- Return type:
  torch.Tensor
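in_support_ql_loss is documented only through its signature. The fake_next_action argument suggests that the bootstrapped target is evaluated on sampled candidate actions (i.e. actions in the support of the behaviour policy) rather than by an unconstrained maximization over the action space. The snippet below is a standalone sketch of such a target using dummy tensors and a softmax-weighted reduction over candidates; the actual reduction and loss inside the library may differ.

    import torch

    B, K = 32, 16                       # batch size, candidate next actions per state
    q_next = torch.randn(B, K)          # stand-in for Q(next_state, fake_next_action[k])
    reward = torch.randn(B)
    done = torch.zeros(B)
    discount_factor = 0.99

    # In-support value estimate: expectation of Q over the sampled candidates under a
    # softmax distribution, instead of a max over the whole (continuous) action space.
    probs = torch.softmax(q_next, dim=1)
    v_next = (probs * q_next).sum(dim=1)
    td_target = reward + discount_factor * (1.0 - done) * v_next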
GMPGPolicy
- class grl.algorithms.GMPGPolicy(config)[source]
- __init__(config)[source]
- Overview:
  Initialize the GMPG policy network.
- Parameters:
  config (EasyDict) – The configuration dict.
- behaviour_policy_loss(action, state, maximum_likelihood=False)[source]
- Overview:
  Calculate the behaviour policy loss.
- Parameters:
  action (torch.Tensor) – The input action.
  state (torch.Tensor) – The input state.
  maximum_likelihood (bool) – Whether to use the maximum likelihood form of the loss.
- behaviour_policy_sample(state, batch_size=None, solver_config=None, t_span=None, with_grad=False)[source]
- Overview:
  Return the output of the behaviour policy, which is the action conditioned on the state.
- Parameters:
  state (Union[torch.Tensor, TensorDict]) – The input state.
  batch_size (Union[torch.Size, int, Tuple[int], List[int]]) – The batch size.
  solver_config (EasyDict) – The configuration for the ODE solver.
  t_span (torch.Tensor) – The time span for the ODE solver or SDE solver.
  with_grad (bool) – Whether to calculate the gradient.
- Returns:
  The output action.
- Return type:
  action (Union[torch.Tensor, TensorDict])
- compute_q(state, action)[source]
- Overview:
  Calculate the Q value.
- Parameters:
  state (Union[torch.Tensor, TensorDict]) – The input state.
  action (Union[torch.Tensor, TensorDict]) – The input action.
- Returns:
  The Q value.
- Return type:
  q (torch.Tensor)
- forward(state)[source]
- Overview:
  Return the output of the GPO policy, which is the action conditioned on the state.
- Parameters:
  state (Union[torch.Tensor, TensorDict]) – The input state.
- Returns:
  The output action.
- Return type:
  action (Union[torch.Tensor, TensorDict])
- sample(state, batch_size=None, solver_config=None, t_span=None, with_grad=False)[source]
- Overview:
  Return the output of the GPO policy, which is the action conditioned on the state.
- Parameters:
  state (Union[torch.Tensor, TensorDict]) – The input state.
  batch_size (Union[torch.Size, int, Tuple[int], List[int]]) – The batch size.
  solver_config (EasyDict) – The configuration for the ODE solver.
  t_span (torch.Tensor) – The time span for the ODE solver or SDE solver.
  with_grad (bool) – Whether to calculate the gradient.
- Returns:
  The output action.
- Return type:
  action (Union[torch.Tensor, TensorDict])
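Because sample accepts with_grad, the sampling path can be kept differentiable, which is the natural hook for a policy-gradient-style objective: sample an action, score it with the critic, and backpropagate through the sampler. The sketch below combines the documented sample and compute_q interfaces in that way; policy is assumed to be a constructed grl.algorithms.GMPGPolicy, and the actual GMPG objective is not specified on this page.

    def gmpg_policy_objective(policy, state):
        """Sketch: differentiable sampling followed by a Q-based objective."""
        # with_grad=True keeps the ODE/SDE sampling path differentiable.
        action = policy.sample(state, with_grad=True)
        q = policy.compute_q(state, action)
        return -q.mean()   # ascend on Q by minimizing its negation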
GMPGAlgorithm
- class grl.algorithms.GMPGAlgorithm(config=None, simulator=None, dataset=None, model=None, seed=None)[source]
- Overview:
  The Generative Model Policy Gradient (GMPG) algorithm.
- Interfaces:
  __init__, train, deploy
- __init__(config=None, simulator=None, dataset=None, model=None, seed=None)[source]
- Overview:
  Initialize the GMPG algorithm.
- Parameters:
  config (EasyDict) – The configuration, which must contain the following keys:
    train (EasyDict): The training configuration.
    deploy (EasyDict): The deployment configuration.
  simulator (object) – The environment simulator.
  dataset (GPDataset) – The dataset.
  model (Union[torch.nn.Module, torch.nn.ModuleDict]) – The model.
- Interface:
  __init__, train, deploy