grl.rl_modules
GymEnvSimulator
- class grl.rl_modules.GymEnvSimulator(env_id)
- Overview:
A simple gym environment simulator in GenerativeRL. It collects episodes and steps in a gym environment using a given policy. It runs in a single process and is suitable for small-scale experiments.
- Interfaces:
__init__, collect_episodes, collect_steps, evaluate
- __init__(env_id)
- Overview:
Initialize the GymEnvSimulator for the gym environment specified by env_id.
- Parameters:
env_id (str) – The id of the gym environment to simulate.
- collect_episodes(policy, num_episodes=None, num_steps=None)
- Overview:
Collect several episodes using the given policy. The environment is reset at the beginning of each episode, and this method stores no history. The information collected at each step is returned as a list of dictionaries.
- Parameters:
policy (Union[Callable, torch.nn.Module]) – The policy used to collect episodes.
num_episodes (int) – The number of episodes to collect.
num_steps (int) – The number of steps to collect.
- Return type:
List[Dict]
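The reset-every-episode behaviour described above can be illustrated with a minimal, self-contained sketch. It is not GenerativeRL's implementation: `ToyEnv` is a hypothetical stand-in that follows the classic gym 4-tuple `step` API, plain Python values replace tensors, the step-dictionary keys are illustrative, and stopping at whichever of `num_episodes` or `num_steps` is reached first is an assumption of the sketch.

```python
from typing import Callable, Dict, List, Optional


class ToyEnv:
    """Hypothetical stand-in for a gym environment (classic 4-tuple step API).

    The observation is the step counter; every episode lasts `horizon` steps
    and the reward equals the action taken.
    """

    def __init__(self, horizon: int = 3):
        self.horizon = horizon
        self.t = 0

    def reset(self) -> int:
        self.t = 0
        return self.t

    def step(self, action: float):
        self.t += 1
        done = self.t >= self.horizon
        return self.t, float(action), done, {}


def collect_episodes(env, policy: Callable, num_episodes: Optional[int] = None,
                     num_steps: Optional[int] = None) -> List[Dict]:
    """Reset at the start of every episode and keep no state between calls.

    Stops at whichever limit is reached first (an assumption of this sketch).
    """
    if num_episodes is None and num_steps is None:
        raise ValueError("set num_episodes or num_steps")
    data: List[Dict] = []
    finished = 0
    while num_episodes is None or finished < num_episodes:
        obs, done = env.reset(), False  # fresh episode every time
        while not done:
            action = policy(obs)
            next_obs, reward, done, _ = env.step(action)
            data.append({"obs": obs, "action": action, "reward": reward,
                         "done": done, "next_obs": next_obs})
            obs = next_obs
            if num_steps is not None and len(data) >= num_steps:
                return data
        finished += 1
    return data
```

For example, `collect_episodes(ToyEnv(horizon=3), lambda obs: 1.0, num_episodes=2)` yields six step dictionaries, and the fourth one starts again from observation `0` because of the reset.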
- collect_steps(policy, num_episodes=None, num_steps=None, random_policy=False)
- Overview:
Collect several steps using the given policy. The environment is not reset until the current episode ends, and the last observation is stored by this method so that collection can resume from it. The information collected at each step is returned as a list of dictionaries.
- Parameters:
policy (Union[Callable, torch.nn.Module]) – The policy used to collect steps.
num_episodes (int) – The number of episodes to collect.
num_steps (int) – The number of steps to collect.
random_policy (bool) – Whether to use a random policy.
- Return type:
List[Dict]
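The difference from collect_episodes is that the episode in progress survives across calls. The sketch below shows one way to get that behaviour by remembering the last observation; `ToyEnv` and `StepCollector` are hypothetical, only `num_steps` is modelled, and drawing uniform actions in [-1, 1] for `random_policy=True` is an assumption rather than GenerativeRL's actual random policy.

```python
import random
from typing import Callable, Dict, List, Optional


class ToyEnv:
    """Hypothetical gym-like environment: obs is the step counter, reward = action."""

    def __init__(self, horizon: int = 3):
        self.horizon = horizon
        self.t = 0

    def reset(self) -> int:
        self.t = 0
        return self.t

    def step(self, action: float):
        self.t += 1
        return self.t, float(action), self.t >= self.horizon, {}


class StepCollector:
    """Hypothetical collector that remembers the last observation between calls."""

    def __init__(self, env):
        self.env = env
        self.last_obs: Optional[int] = None  # None means no episode in progress

    def collect_steps(self, policy: Optional[Callable], num_steps: int,
                      random_policy: bool = False) -> List[Dict]:
        data: List[Dict] = []
        for _ in range(num_steps):
            if self.last_obs is None:
                self.last_obs = self.env.reset()  # reset only after an episode ends
            obs = self.last_obs
            # Uniform random actions stand in for a random policy here.
            action = random.uniform(-1.0, 1.0) if random_policy else policy(obs)
            next_obs, reward, done, _ = self.env.step(action)
            data.append({"obs": obs, "action": action, "reward": reward,
                         "done": done, "next_obs": next_obs})
            # Carry the episode over to the next call unless it just ended.
            self.last_obs = None if done else next_obs
        return data
```

With a horizon of 3, two consecutive calls of two steps each see observations `[0, 1]` and then `[2, 0]`: the second call first finishes the open episode and only then resets.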
- evaluate(policy, num_episodes=None, render_args=None)
- Overview:
Evaluate the given policy in the environment. The environment is reset at the beginning of each episode, and this method stores no history. The evaluation results are returned as a list of dictionaries.
- Return type:
List[Dict]
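A minimal evaluation loop in the same spirit is sketched below: each episode starts from a fresh reset and produces one result dictionary. `ToyEnv` is hypothetical, the `"return"` and `"length"` keys are illustrative choices rather than the fields GenerativeRL reports, and `render_args` is not modelled.

```python
from typing import Callable, Dict, List


class ToyEnv:
    """Hypothetical gym-like environment: obs is the step counter, reward = action."""

    def __init__(self, horizon: int = 3):
        self.horizon = horizon
        self.t = 0

    def reset(self) -> int:
        self.t = 0
        return self.t

    def step(self, action: float):
        self.t += 1
        return self.t, float(action), self.t >= self.horizon, {}


def evaluate(env, policy: Callable, num_episodes: int = 1) -> List[Dict]:
    """Run full episodes from a fresh reset and report one result dict per episode."""
    results: List[Dict] = []
    for _ in range(num_episodes):
        obs, done = env.reset(), False
        total_return, length = 0.0, 0
        while not done:
            obs, reward, done, _ = env.step(policy(obs))
            total_return += reward
            length += 1
        results.append({"return": total_return, "length": length})
    return results
```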
OneShotValueFunction
- class grl.rl_modules.OneShotValueFunction(config)
- Overview:
Value network for one-shot cases, that is, cases in which training requires no Bellman backup.
- Interfaces:
__init__, forward, compute_double_v, v_loss
- __init__(config)
- Overview:
Initialize the one-shot value network.
- Parameters:
config (EasyDict) – The configuration dict.
- compute_double_v(state, condition=None)
- Overview:
Return the outputs of the two value networks.
- Parameters:
state (Union[torch.Tensor, TensorDict]) – The input state.
condition (Union[torch.Tensor, TensorDict]) – The input condition.
- Returns:
v1 (Union[torch.Tensor, TensorDict]) – The output of the first value network.
v2 (Union[torch.Tensor, TensorDict]) – The output of the second value network.
- forward(state, condition=None)
- Overview:
Return the output of the one-shot value network.
- Parameters:
state (Union[torch.Tensor, TensorDict]) – The input state.
condition (Union[torch.Tensor, TensorDict]) – The input condition.
- Return type:
Tensor
- v_loss(state, value, condition=None)
- Overview:
Calculate the v loss.
- Parameters:
state (Union[torch.Tensor, TensorDict]) – The input state.
value (Union[torch.Tensor, TensorDict]) – The input value.
condition (Union[torch.Tensor, TensorDict]) – The input condition.
- Returns:
v_loss (torch.Tensor) – The v loss.
- Return type:
torch.Tensor
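"No Bellman backup" means the value heads are regressed directly onto the supplied values instead of onto a bootstrapped target such as r + gamma * V(s'). The plain-Python sketch below shows that kind of loss for two heads, as returned by compute_double_v. Averaging the mean squared errors of both heads is an assumption made for illustration, not a confirmed detail of GenerativeRL's v_loss.

```python
from typing import Sequence


def mse(pred: Sequence[float], target: Sequence[float]) -> float:
    """Mean squared error between two equally long sequences."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(target)


def one_shot_v_loss(v1: Sequence[float], v2: Sequence[float],
                    value: Sequence[float]) -> float:
    """Regress both value heads directly onto the given values.

    No bootstrapped target is built, which is what "no Bellman backup" means.
    Averaging the two heads' errors is an assumption of this sketch.
    """
    return (mse(v1, value) + mse(v2, value)) / 2
```

For example, `one_shot_v_loss([1.0, 2.0], [0.0, 2.0], [1.0, 2.0])` is 0.25: the first head is exact and the second head's error of 1 on one of two samples gives an MSE of 0.5.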
VNetwork
- class grl.rl_modules.VNetwork(config)
- Overview:
Value network, which is used to approximate the value function.
- Interfaces:
__init__, forward
- __init__(config)
- Overview:
Initialize the value network.
- Parameters:
config (EasyDict) – The configuration dict.
- forward(state, condition=None)
- Overview:
Return the output of the value network.
- Parameters:
state (Union[torch.Tensor, TensorDict]) – The input state.
condition (Union[torch.Tensor, TensorDict]) – The input condition.
- Returns:
value (Union[torch.Tensor, TensorDict]) – The output of the value network.
- Return type:
Union[torch.Tensor, TensorDict]
DoubleVNetwork
- class grl.rl_modules.DoubleVNetwork(config)
- Overview:
Double value network, which has two value networks.
- Interfaces:
__init__, forward, compute_double_v, compute_mininum_v
- __init__(config)
- Overview:
Initialize the double value network.
- compute_double_v(state, condition)
- Overview:
Return the outputs of the two value networks.
- Parameters:
state (Union[torch.Tensor, TensorDict]) – The input state.
condition (Union[torch.Tensor, TensorDict]) – The input condition.
- Returns:
v1 (Union[torch.Tensor, TensorDict]) – The output of the first value network.
v2 (Union[torch.Tensor, TensorDict]) – The output of the second value network.
- compute_mininum_v(state, condition)
- Overview:
Return the minimum of the outputs of the two value networks.
- Parameters:
state (Union[torch.Tensor, TensorDict]) – The input state.
condition (Union[torch.Tensor, TensorDict]) – The input condition.
- Returns:
minimum_v (Union[torch.Tensor, TensorDict]) – The minimum output of the two value networks.
- Return type:
Union[torch.Tensor, TensorDict]
- forward(state, condition)
- Overview:
Return the minimum of the outputs of the two value networks.
- Parameters:
state (Union[torch.Tensor, TensorDict]) – The input state.
condition (Union[torch.Tensor, TensorDict]) – The input condition.
- Returns:
minimum_v (Union[torch.Tensor, TensorDict]) – The minimum output of the two value networks.
- Return type:
Union[torch.Tensor, TensorDict]
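Taking the minimum of two value estimates is a common way to counter overestimation bias: a single approximator's errors tend to be exploited upward, while the smaller of two estimates is more conservative. The sketch below shows the idea with plain Python callables standing in for the two networks; `double_values` and `min_value` are hypothetical helpers (not the library's methods), and the `condition` argument is omitted for brevity.

```python
from typing import Callable, List, Sequence, Tuple

ValueFn = Callable[[float], float]  # hypothetical stand-in for one value network


def double_values(v1: ValueFn, v2: ValueFn,
                  states: Sequence[float]) -> Tuple[List[float], List[float]]:
    """Evaluate both value functions on every state."""
    return [v1(s) for s in states], [v2(s) for s in states]


def min_value(v1: ValueFn, v2: ValueFn, states: Sequence[float]) -> List[float]:
    """Element-wise minimum of the two estimates (the conservative choice)."""
    a, b = double_values(v1, v2, states)
    return [min(x, y) for x, y in zip(a, b)]
```

With `v1(s) = 2s` and `v2(s) = s + 1` on states `[0.0, 1.0, 3.0]`, the heads give `[0.0, 2.0, 6.0]` and `[1.0, 2.0, 4.0]`, so the minimum is `[0.0, 2.0, 4.0]`.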
QNetwork
DoubleQNetwork
- class grl.rl_modules.DoubleQNetwork(config)
- Overview:
Double Q network, which has two Q networks.
- Interfaces:
__init__, forward, compute_double_q, compute_mininum_q
- __init__(config)
- Overview:
Initialize the double Q network.
- compute_double_q(action, state)
- Overview:
Return the outputs of the two Q networks.
- Parameters:
action (Union[torch.Tensor, TensorDict]) – The input action.
state (Union[torch.Tensor, TensorDict]) – The input state.
- Returns:
q1 (Union[torch.Tensor, TensorDict]) – The output of the first Q network.
q2 (Union[torch.Tensor, TensorDict]) – The output of the second Q network.
- compute_mininum_q(action, state)
- Overview:
Return the minimum of the outputs of the two Q networks.
- Parameters:
action (Union[torch.Tensor, TensorDict]) – The input action.
state (Union[torch.Tensor, TensorDict]) – The input state.
- Returns:
minimum_q (Union[torch.Tensor, TensorDict]) – The minimum output of the two Q networks.
- Return type:
Union[torch.Tensor, TensorDict]
- forward(action, state)
- Overview:
Return the minimum of the outputs of the two Q networks.
- Parameters:
action (Union[torch.Tensor, TensorDict]) – The input action.
state (Union[torch.Tensor, TensorDict]) – The input state.
- Returns:
minimum_q (Union[torch.Tensor, TensorDict]) – The minimum output of the two Q networks.
- Return type:
Union[torch.Tensor, TensorDict]
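A typical use of a minimum over two Q networks is the clipped double-Q target from TD3-style algorithms, y = r + gamma * (1 - done) * min(Q1(a', s'), Q2(a', s')). The sketch below illustrates that target with plain Python callables that take `(action, state)` in the same order as the documented forward. `min_q` and `td_targets` are hypothetical helpers, and whether and how GenerativeRL builds such a target is not stated here.

```python
from typing import Callable, List, Sequence

QFn = Callable[[float, float], float]  # hypothetical (action, state) -> Q-value


def min_q(q1: QFn, q2: QFn, actions: Sequence[float],
          states: Sequence[float]) -> List[float]:
    """Element-wise minimum of two Q estimates, argument order (action, state)."""
    return [min(q1(a, s), q2(a, s)) for a, s in zip(actions, states)]


def td_targets(rewards: Sequence[float], dones: Sequence[float],
               next_min_q: Sequence[float], gamma: float = 0.99) -> List[float]:
    """Clipped double-Q target: y = r + gamma * (1 - done) * min(Q1, Q2)(a', s')."""
    return [r + gamma * (1.0 - d) * q
            for r, d, q in zip(rewards, dones, next_min_q)]
```

A terminal transition (`done = 1.0`) drops the bootstrap term, so its target is just the reward.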