
grl.rl_modules

GymEnvSimulator

class grl.rl_modules.GymEnvSimulator(env_id)[source]
Overview:

A simple gym environment simulator in GenerativeRL. This simulator is used to collect episodes and steps using a given policy in a gym environment. It runs in a single process and is suitable for small-scale experiments.

Interfaces:

__init__, collect_episodes, collect_steps, evaluate

__init__(env_id)[source]
Overview:

Initialize the GymEnvSimulator with the given gym environment id.

Parameters:

env_id (str) – The id of the gym environment to simulate.

collect_episodes(policy, num_episodes=None, num_steps=None)[source]
Overview:

Collect several episodes using the given policy. The environment will be reset at the beginning of each episode. No history will be stored by this method. The collected step information will be returned as a list of dictionaries.

Parameters:
  • policy (Union[Callable, torch.nn.Module]) – The policy to collect episodes.

  • num_episodes (int) – The number of episodes to collect.

  • num_steps (int) – The number of steps to collect.

Return type:

List[Dict]
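
Example (a minimal sketch; the CartPole environment id and the stand-in random policy are illustrative assumptions, not part of the API):

import numpy as np
from grl.rl_modules import GymEnvSimulator

def toy_policy(obs):
    # Illustrative stand-in for a trained policy or torch.nn.Module:
    # picks a random CartPole action regardless of the observation.
    return np.random.randint(0, 2)

simulator = GymEnvSimulator(env_id="CartPole-v1")

# Collect 10 full episodes; the result is a list of dictionaries
# describing the collected steps.
episodes = simulator.collect_episodes(policy=toy_policy, num_episodes=10)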

collect_steps(policy, num_episodes=None, num_steps=None, random_policy=False)[source]
Overview:

Collect several steps using the given policy. The environment will not be reset until the end of an episode. The last observation will be stored by this method so that collection can resume where it left off on the next call. The collected step information will be returned as a list of dictionaries.

Parameters:
  • policy (Union[Callable, torch.nn.Module]) – The policy to collect steps.

  • num_episodes (int) – The number of episodes to collect.

  • num_steps (int) – The number of steps to collect.

  • random_policy (bool) – Whether to use a random policy.

Return type:

List[Dict]
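
Example (continuing the sketch above; whether the policy argument is bypassed when random_policy=True is an assumption based on the parameter description):

# Collect 1000 environment steps without resetting between calls; with
# random_policy=True the simulator samples random actions instead of
# querying the given policy.
transitions = simulator.collect_steps(
    policy=toy_policy,
    num_steps=1000,
    random_policy=True,
)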

evaluate(policy, num_episodes=None, render_args=None)[source]
Overview:

Evaluate the given policy in the environment. The environment will be reset at the beginning of each episode. No history will be stored by this method. The evaluation results will be returned as a list of dictionaries.

Return type:

List[Dict]
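
Example (continuing the sketch above; the exact keys of each result dictionary depend on the implementation):

# Run 5 evaluation episodes and inspect the returned result dictionaries.
results = simulator.evaluate(policy=toy_policy, num_episodes=5)
for result in results:
    print(result)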

OneShotValueFunction

class grl.rl_modules.OneShotValueFunction(config)[source]
Overview:

Value network for one-shot cases, which means that no Bellman backup is needed for training.

Interfaces:

__init__, forward

__init__(config)[source]
Overview:

Initialization of the one-shot value network.

Parameters:

config (EasyDict) – The configuration dict.

compute_double_v(state, condition=None)[source]
Overview:

Return the output of two value networks.

Parameters:
  • state (Union[torch.Tensor, TensorDict]) – The input state.

  • condition (Union[torch.Tensor, TensorDict]) – The input condition.

Returns:
  • v1 (Union[torch.Tensor, TensorDict]) – The output of the first value network.

  • v2 (Union[torch.Tensor, TensorDict]) – The output of the second value network.

forward(state, condition=None)[source]
Overview:

Return the output of the one-shot value network.

Parameters:
  • state (Union[torch.Tensor, TensorDict]) – The input state.

  • condition (Union[torch.Tensor, TensorDict]) – The input condition.

Return type:

Tensor

v_loss(state, value, condition=None)[source]
Overview:

Calculate the v loss.

Parameters:
  • state (Union[torch.Tensor, TensorDict]) – The input state.

  • value (Union[torch.Tensor, TensorDict]) – The input value.

  • condition (Union[torch.Tensor, TensorDict]) – The input condition.

Returns:

The v loss.

Return type:

v_loss (torch.Tensor)
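
Example (a hedged sketch of a single regression-style update; it assumes value_function is an OneShotValueFunction already built from a valid configuration, and the batch size and state dimension are arbitrary illustrations):

import torch

# Assumes `value_function` is an OneShotValueFunction constructed from a
# valid grl configuration (see the library's default configs for the schema).
state = torch.randn(32, 17)    # illustrative batch of 32 states
target = torch.randn(32, 1)    # one-shot regression targets (no Bellman backup)

v = value_function(state)                    # forward pass, predicted values
loss = value_function.v_loss(state, target)  # value regression loss

optimizer = torch.optim.Adam(value_function.parameters(), lr=3e-4)
optimizer.zero_grad()
loss.backward()
optimizer.step()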

VNetwork

class grl.rl_modules.VNetwork(config)[source]
Overview:

Value network, which is used to approximate the value function.

Interfaces:

__init__, forward

__init__(config)[source]
Overview:

Initialization of the value network.

Parameters:

config (EasyDict) – The configuration dict.

forward(state, condition=None)[source]
Overview:

Return the output of the value network.

Parameters:
  • state (Union[torch.Tensor, TensorDict]) – The input state.

  • condition (Union[torch.Tensor, TensorDict]) – The input condition.

Returns:

The output of the value network.

Return type:

value (Union[torch.Tensor, TensorDict])
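
Example (a hedged sketch; it assumes v_network is a VNetwork already built from a valid configuration, and the tensor shapes and the use of a condition are arbitrary illustrations):

import torch

state = torch.randn(32, 17)              # batch of states
value = v_network(state)                 # V(s), one value per state

# An optional condition (e.g. a goal or task embedding) can be passed as well.
condition = torch.randn(32, 8)
conditional_value = v_network(state, condition=condition)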

DoubleVNetwork

class grl.rl_modules.DoubleVNetwork(config)[source]
Overview:

Double value network, which has two value networks.

Interfaces:

__init__, forward, compute_double_v, compute_mininum_v

__init__(config)[source]
Overview:

Initialization of the double value network.

Parameters:

config (EasyDict) – The configuration dict.

compute_double_v(state, condition)[source]
Overview:

Return the output of two value networks.

Parameters:
  • state (Union[torch.Tensor, TensorDict]) – The input state.

  • condition (Union[torch.Tensor, TensorDict]) – The input condition.

Returns:
  • v1 (Union[torch.Tensor, TensorDict]) – The output of the first value network.

  • v2 (Union[torch.Tensor, TensorDict]) – The output of the second value network.

compute_mininum_v(state, condition)[source]
Overview:

Return the minimum output of the two value networks.

Parameters:
  • state (Union[torch.Tensor, TensorDict]) – The input state.

  • condition (Union[torch.Tensor, TensorDict]) – The input condition.

Returns:

The minimum output of the two value networks.

Return type:

minimum_v (Union[torch.Tensor, TensorDict])

forward(state, condition)[source]
Overview:

Return the minimum output of the two value networks.

Parameters:
  • state (Union[torch.Tensor, TensorDict]) – The input state.

  • condition (Union[torch.Tensor, TensorDict]) – The input condition.

Returns:

The minimum output of the two value networks.

Return type:

minimum_v (Union[torch.Tensor, TensorDict])
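
Example (a hedged sketch; it assumes double_v is a DoubleVNetwork already built from a valid configuration, and the tensor shapes are arbitrary illustrations):

import torch

state = torch.randn(32, 17)
condition = torch.randn(32, 8)

v1, v2 = double_v.compute_double_v(state, condition)  # both value heads
v_min = double_v.compute_mininum_v(state, condition)  # elementwise minimum
v = double_v(state, condition)                        # forward() also returns the minimum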

QNetwork

class grl.rl_modules.QNetwork(config)[source]
Overview:

Q network, which is used to approximate the Q value.

Interfaces:

__init__, forward

__init__(config)[source]
Overview:

Initialization of the Q network.

Parameters:

config (EasyDict) – The configuration dict.

forward(action, state)[source]
Overview:

Return the output of the Q network.

Parameters:
  • action (Union[torch.Tensor, TensorDict]) – The input action.

  • state (Union[torch.Tensor, TensorDict]) – The input state.

Returns:

The output of the Q network.

Return type:

q (Union[torch.Tensor, TensorDict])
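
Example (a hedged sketch; it assumes q_network is a QNetwork already built from a valid configuration, and the tensor shapes are arbitrary illustrations):

import torch

state = torch.randn(32, 17)      # batch of states
action = torch.randn(32, 6)      # batch of actions
q = q_network(action, state)     # Q(s, a); note the (action, state) argument order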

DoubleQNetwork

class grl.rl_modules.DoubleQNetwork(config)[source]
Overview:

Double Q network, which has two Q networks.

Interfaces:

__init__, forward, compute_double_q, compute_mininum_q

__init__(config)[source]
Overview:

Initialization of the double Q network.

Parameters:

config (EasyDict) – The configuration dict.

compute_double_q(action, state)[source]
Overview:

Return the output of two Q networks.

Parameters:
  • action (Union[torch.Tensor, TensorDict]) – The input action.

  • state (Union[torch.Tensor, TensorDict]) – The input state.

Returns:
  • q1 (Union[torch.Tensor, TensorDict]) – The output of the first Q network.

  • q2 (Union[torch.Tensor, TensorDict]) – The output of the second Q network.

compute_mininum_q(action, state)[source]
Overview:

Return the minimum output of the two Q networks.

Parameters:
  • action (Union[torch.Tensor, TensorDict]) – The input action.

  • state (Union[torch.Tensor, TensorDict]) – The input state.

Returns:

The minimum output of the two Q networks.

Return type:

minimum_q (Union[torch.Tensor, TensorDict])

forward(action, state)[source]
Overview:

Return the minimum output of the two Q networks.

Parameters:
  • action (Union[torch.Tensor, TensorDict]) – The input action.

  • state (Union[torch.Tensor, TensorDict]) – The input state.

Returns:

The minimum output of the two Q networks.

Return type:

minimum_q (Union[torch.Tensor, TensorDict])
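
Example (a hedged sketch of the usual clipped double-Q pattern; it assumes double_q is a DoubleQNetwork already built from a valid configuration, and the tensor shapes are arbitrary illustrations):

import torch

state = torch.randn(32, 17)
action = torch.randn(32, 6)

q1, q2 = double_q.compute_double_q(action, state)  # both Q heads, e.g. for the critic loss
q_min = double_q.compute_mininum_q(action, state)  # elementwise minimum (conservative estimate)
q = double_q(action, state)                        # forward() also returns the minimum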