#############################################################
# grl.datasets.qgpo
# This QGPOD4RLDataset is a modified implementation based on
# https://github.com/ChenDRAG/CEP-energy-guided-diffusion
#############################################################
from abc import abstractmethod
from typing import List
import os
import gym
import numpy as np
import torch
from tensordict import TensorDict
from torchrl.data import LazyTensorStorage, LazyMemmapStorage
from grl.utils.log import log
class QGPODataset(torch.utils.data.Dataset):
"""
Overview:
        Dataset for the QGPO algorithm. The training of QGPO is based on contrastive energy prediction, \
        which requires both true actions and fake actions. The true actions are sampled from the dataset, \
        and the fake actions are sampled from the action support generated by the behaviour policy.
Interface:
``__init__``, ``__getitem__``, ``__len__``.
"""
def __getitem__(self, index):
"""
Overview:
Get data by index
Arguments:
index (:obj:`int`): Index of data
Returns:
data (:obj:`dict`): Data dict
.. note::
The data dict contains the following keys:
s (:obj:`torch.Tensor`): State
a (:obj:`torch.Tensor`): Action
r (:obj:`torch.Tensor`): Reward
s_ (:obj:`torch.Tensor`): Next state
d (:obj:`torch.Tensor`): Is finished
            fake_a (:obj:`torch.Tensor`): Fake action for contrastive energy prediction and QGPO training \
                (fake actions are sampled from the action support generated by the behaviour policy)
            fake_a_ (:obj:`torch.Tensor`): Fake next action for contrastive energy prediction and QGPO training \
                (fake actions are sampled from the action support generated by the behaviour policy)
"""
data = {
"s": self.states[index % self.len],
"a": self.actions[index % self.len],
"r": self.rewards[index % self.len],
"s_": self.next_states[index % self.len],
"d": self.is_finished[index % self.len],
"fake_a": (
self.fake_actions[index % self.len]
if hasattr(self, "fake_actions")
else 0.0
), # self.fake_actions <D, 16, A>
"fake_a_": (
self.fake_next_actions[index % self.len]
if hasattr(self, "fake_next_actions")
else 0.0
), # self.fake_next_actions <D, 16, A>
}
return data
def __len__(self):
return self.len
def load_fake_actions(self, fake_actions, fake_next_actions):
self.fake_actions = fake_actions
self.fake_next_actions = fake_next_actions
@abstractmethod
def return_range(self, dataset, max_episode_steps):
raise NotImplementedError
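
# Note on the QGPODataset base class above: concrete subclasses are expected to populate
# self.states, self.actions, self.rewards, self.next_states, self.is_finished and self.len
# in their __init__. Until load_fake_actions() is called, __getitem__ returns 0.0
# placeholders for the "fake_a" / "fake_a_" entries.
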
class QGPOTensorDictDataset(torch.utils.data.Dataset):
"""
Overview:
        Dataset for the QGPO algorithm. The training of QGPO is based on contrastive energy prediction, \
        which requires both true actions and fake actions. The true actions are sampled from the dataset, \
        and the fake actions are sampled from the action support generated by the behaviour policy.
Interface:
``__init__``, ``__getitem__``, ``__len__``.
"""
def __init__(self):
"""
Overview:
            Initialization method of the QGPOTensorDictDataset class
"""
pass
def __getitem__(self, index):
"""
Overview:
Get data by index
Arguments:
index (:obj:`int`): Index of data
Returns:
data (:obj:`dict`): Data dict
.. note::
The data dict contains the following keys:
s (:obj:`torch.Tensor`): State
a (:obj:`torch.Tensor`): Action
r (:obj:`torch.Tensor`): Reward
s_ (:obj:`torch.Tensor`): Next state
d (:obj:`torch.Tensor`): Is finished
            fake_a (:obj:`torch.Tensor`): Fake action for contrastive energy prediction and QGPO training \
                (fake actions are sampled from the action support generated by the behaviour policy)
            fake_a_ (:obj:`torch.Tensor`): Fake next action for contrastive energy prediction and QGPO training \
                (fake actions are sampled from the action support generated by the behaviour policy)
"""
data = self.storage.get(index=index)
return data
def __len__(self):
return self.len
def load_fake_actions(self, fake_actions, fake_next_actions):
self.fake_actions = fake_actions
self.fake_next_actions = fake_next_actions
self.storage.set(
range(self.len),
TensorDict(
{
"s": self.states,
"a": self.actions,
"r": self.rewards,
"s_": self.next_states,
"d": self.is_finished,
"fake_a": self.fake_actions,
"fake_a_": self.fake_next_actions,
},
batch_size=[self.len],
),
)
@abstractmethod
def return_range(self, dataset, max_episode_steps):
raise NotImplementedError
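
# Note on the QGPOTensorDictDataset base class above: concrete subclasses are expected to
# populate self.states, self.actions, self.rewards, self.next_states, self.is_finished,
# self.len and a torchrl storage in self.storage; load_fake_actions() then rewrites the
# whole storage so that "fake_a" / "fake_a_" hold the sampled action support.
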
class QGPOD4RLDataset(QGPODataset):
"""
Overview:
        Dataset for the QGPO algorithm. The training of QGPO is based on contrastive energy prediction, \
        which requires both true actions and fake actions. The true actions are sampled from the dataset, \
        and the fake actions are sampled from the action support generated by the behaviour policy.
Interface:
``__init__``, ``__getitem__``, ``__len__``.
"""
    def __init__(
self,
env_id: str,
):
"""
Overview:
            Initialization method of the QGPOD4RLDataset class
Arguments:
env_id (:obj:`str`): The environment id
"""
super().__init__()
import d4rl
data = d4rl.qlearning_dataset(gym.make(env_id))
self.states = torch.from_numpy(data["observations"]).float()
self.actions = torch.from_numpy(data["actions"]).float()
self.next_states = torch.from_numpy(data["next_observations"]).float()
reward = torch.from_numpy(data["rewards"]).view(-1, 1).float()
self.is_finished = torch.from_numpy(data["terminals"]).view(-1, 1).float()
reward_tune = "iql_antmaze" if "antmaze" in env_id else "iql_locomotion"
if reward_tune == "normalize":
reward = (reward - reward.mean()) / reward.std()
elif reward_tune == "iql_antmaze":
reward = reward - 1.0
elif reward_tune == "iql_locomotion":
min_ret, max_ret = QGPOD4RLDataset.return_range(data, 1000)
reward /= max_ret - min_ret
reward *= 1000
elif reward_tune == "cql_antmaze":
reward = (reward - 0.5) * 4.0
elif reward_tune == "antmaze":
reward = (reward - 0.25) * 2.0
self.rewards = reward
self.len = self.states.shape[0]
log.info(f"{self.len} data loaded in QGPOD4RLDataset")
    @staticmethod
    def return_range(dataset, max_episode_steps):
returns, lengths = [], []
ep_ret, ep_len = 0.0, 0
for r, d in zip(dataset["rewards"], dataset["terminals"]):
ep_ret += float(r)
ep_len += 1
if d or ep_len == max_episode_steps:
returns.append(ep_ret)
lengths.append(ep_len)
ep_ret, ep_len = 0.0, 0
# returns.append(ep_ret) # incomplete trajectory
lengths.append(ep_len) # but still keep track of number of steps
assert sum(lengths) == len(dataset["rewards"])
return min(returns), max(returns)
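
# Illustrative usage sketch (not part of the original module): one way a QGPOD4RLDataset
# could be consumed for QGPO training. The env id, batch size and the helper name
# `_example_qgpo_d4rl_usage` are arbitrary choices for illustration, and the zero tensors
# below merely stand in for the action support a behaviour policy would normally generate.
def _example_qgpo_d4rl_usage(action_augment_num: int = 16):
    dataset = QGPOD4RLDataset(env_id="halfcheetah-medium-expert-v2")
    # Placeholder support actions of shape (D, K, A); a real pipeline would sample these
    # from the behaviour policy instead of using zeros.
    fake_a = torch.zeros(dataset.len, action_augment_num, dataset.actions.shape[-1])
    dataset.load_fake_actions(fake_actions=fake_a, fake_next_actions=fake_a.clone())
    loader = torch.utils.data.DataLoader(dataset, batch_size=256, shuffle=True)
    batch = next(iter(loader))
    # batch["s"]: (256, obs_dim), batch["fake_a"]: (256, action_augment_num, action_dim)
    return batch["s"].shape, batch["fake_a"].shape
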
class QGPOOnlineDataset(QGPODataset):
"""
Overview:
        Dataset for the QGPO algorithm. The training of QGPO is based on contrastive energy prediction, \
        which requires both true actions and fake actions. The true actions are sampled from the dataset, \
        and the fake actions are sampled from the action support generated by the behaviour policy.
Interface:
``__init__``, ``__getitem__``, ``__len__``.
"""
def __init__(
self,
fake_action_shape: int = None,
data: List = None,
):
"""
Overview:
            Initialization method of the QGPOOnlineDataset class
        Arguments:
            fake_action_shape (:obj:`int`): The number of fake actions per state (size of the action support)
            data (:obj:`List`): The data list
"""
super().__init__()
self.fake_action_shape = fake_action_shape
if data is not None:
self.states = torch.from_numpy(data["observations"]).float()
self.actions = torch.from_numpy(data["actions"]).float()
self.next_states = torch.from_numpy(data["next_observations"]).float()
reward = torch.from_numpy(data["rewards"]).view(-1, 1).float()
self.is_finished = torch.from_numpy(data["terminals"]).view(-1, 1).float()
self.rewards = reward
# self.fake_actions = torch.zeros_like(self.actions.unsqueeze(1).expand(-1, fake_action_shape, -1))
# self.fake_next_actions = torch.zeros_like(self.actions.unsqueeze(1).expand(-1, fake_action_shape, -1))
self.len = self.states.shape[0]
else:
self.states = torch.tensor([])
self.actions = torch.tensor([])
self.next_states = torch.tensor([])
self.is_finished = torch.tensor([])
self.rewards = torch.tensor([])
# self.fake_actions = torch.tensor([])
# self.fake_next_actions = torch.tensor([])
self.len = 0
log.debug(f"{self.len} data loaded in QGPOOnlineDataset")
def drop_data(self, drop_ratio: float, random: bool = True):
# drop the data from the dataset
drop_num = int(self.len * drop_ratio)
# randomly drop the data if random is True
if random:
drop_indices = torch.randperm(self.len)[:drop_num]
else:
drop_indices = torch.arange(drop_num)
keep_mask = torch.ones(self.len, dtype=torch.bool)
keep_mask[drop_indices] = False
self.states = self.states[keep_mask]
self.actions = self.actions[keep_mask]
self.next_states = self.next_states[keep_mask]
self.is_finished = self.is_finished[keep_mask]
self.rewards = self.rewards[keep_mask]
# self.fake_actions = self.fake_actions[keep_mask]
# self.fake_next_actions = self.fake_next_actions[keep_mask]
self.len = self.states.shape[0]
log.debug(f"{drop_num} data dropped in QGPOOnlineDataset")
def load_data(self, data: List):
        # concatenate the new transitions into the dataset
        # collate the data by stacking each key across the list of transition dicts
        keys = ["obs", "action", "done", "next_obs", "reward"]
        collated_data = {
            k: torch.tensor(np.stack([item[k] for item in data])) for k in keys
        }
self.states = torch.cat([self.states, collated_data["obs"].float()], dim=0)
self.actions = torch.cat([self.actions, collated_data["action"].float()], dim=0)
self.next_states = torch.cat(
[self.next_states, collated_data["next_obs"].float()], dim=0
)
reward = collated_data["reward"].view(-1, 1).float()
self.is_finished = torch.cat(
[self.is_finished, collated_data["done"].view(-1, 1).float()], dim=0
)
self.rewards = torch.cat([self.rewards, reward], dim=0)
# self.fake_actions = torch.cat([self.fake_actions, torch.zeros_like(collated_data['action'].unsqueeze(1).expand(-1, self.fake_action_shape, -1))], dim=0)
# self.fake_next_actions = torch.cat([self.fake_next_actions, torch.zeros_like(collated_data['action'].unsqueeze(1).expand(-1, self.fake_action_shape, -1))], dim=0)
self.len = self.states.shape[0]
log.debug(f"{self.len} data loaded in QGPOOnlineDataset")
    @staticmethod
    def return_range(dataset, max_episode_steps):
returns, lengths = [], []
ep_ret, ep_len = 0.0, 0
for r, d in zip(dataset["rewards"], dataset["terminals"]):
ep_ret += float(r)
ep_len += 1
if d or ep_len == max_episode_steps:
returns.append(ep_ret)
lengths.append(ep_len)
ep_ret, ep_len = 0.0, 0
# returns.append(ep_ret) # incomplete trajectory
lengths.append(ep_len) # but still keep track of number of steps
assert sum(lengths) == len(dataset["rewards"])
return min(returns), max(returns)
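
# Illustrative usage sketch (not part of the original module): pushing transitions collected
# online into a QGPOOnlineDataset and ageing part of them out. The observation and action
# dimensions (17 and 6) and the helper name are arbitrary examples.
def _example_qgpo_online_usage(num_transitions: int = 100):
    dataset = QGPOOnlineDataset(fake_action_shape=16)
    transitions = [
        {
            "obs": np.random.randn(17).astype(np.float32),
            "action": np.random.randn(6).astype(np.float32),
            "reward": np.float32(0.0),
            "next_obs": np.random.randn(17).astype(np.float32),
            "done": np.float32(0.0),
        }
        for _ in range(num_transitions)
    ]
    dataset.load_data(transitions)
    # Randomly drop 10% of the stored transitions, e.g. to bound the buffer size.
    dataset.drop_data(drop_ratio=0.1)
    return len(dataset)
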
class QGPOD4RLOnlineDataset(QGPODataset):
"""
Overview:
        Dataset for the QGPO algorithm. The training of QGPO is based on contrastive energy prediction, \
        which requires both true actions and fake actions. The true actions are sampled from the dataset, \
        and the fake actions are sampled from the action support generated by the behaviour policy.
Interface:
``__init__``, ``__getitem__``, ``__len__``.
"""
def __init__(
self,
env_id: str,
fake_action_shape: int = None,
):
"""
Overview:
            Initialization method of the QGPOD4RLOnlineDataset class
        Arguments:
            env_id (:obj:`str`): The environment id
            fake_action_shape (:obj:`int`): The number of fake actions per state (size of the action support)
"""
super().__init__()
self.fake_action_shape = fake_action_shape
import d4rl
data = d4rl.qlearning_dataset(gym.make(env_id))
self.states = torch.from_numpy(data["observations"]).float()
self.actions = torch.from_numpy(data["actions"]).float()
self.next_states = torch.from_numpy(data["next_observations"]).float()
reward = torch.from_numpy(data["rewards"]).view(-1, 1).float()
self.is_finished = torch.from_numpy(data["terminals"]).view(-1, 1).float()
reward_tune = "iql_antmaze" if "antmaze" in env_id else "iql_locomotion"
if reward_tune == "normalize":
reward = (reward - reward.mean()) / reward.std()
elif reward_tune == "iql_antmaze":
reward = reward - 1.0
elif reward_tune == "iql_locomotion":
min_ret, max_ret = QGPOD4RLDataset.return_range(data, 1000)
reward /= max_ret - min_ret
reward *= 1000
elif reward_tune == "cql_antmaze":
reward = (reward - 0.5) * 4.0
elif reward_tune == "antmaze":
reward = (reward - 0.25) * 2.0
self.rewards = reward
self.len = self.states.shape[0]
log.debug(f"{self.len} data loaded in QGPOD4RLOnlineDataset")
def drop_data(self, drop_ratio: float, random: bool = True):
# drop the data from the dataset
drop_num = int(self.len * drop_ratio)
# randomly drop the data if random is True
if random:
drop_indices = torch.randperm(self.len)[:drop_num]
else:
drop_indices = torch.arange(drop_num)
keep_mask = torch.ones(self.len, dtype=torch.bool)
keep_mask[drop_indices] = False
self.states = self.states[keep_mask]
self.actions = self.actions[keep_mask]
self.next_states = self.next_states[keep_mask]
self.is_finished = self.is_finished[keep_mask]
self.rewards = self.rewards[keep_mask]
# self.fake_actions = self.fake_actions[keep_mask]
# self.fake_next_actions = self.fake_next_actions[keep_mask]
self.len = self.states.shape[0]
log.debug(f"{drop_num} data dropped in QGPOOnlineDataset")
def load_data(self, data: List):
        # concatenate the new transitions into the dataset
        # collate the data by stacking each key across the list of transition dicts
        keys = ["obs", "action", "done", "next_obs", "reward"]
        collated_data = {
            k: torch.tensor(np.stack([item[k] for item in data])) for k in keys
        }
self.states = torch.cat([self.states, collated_data["obs"].float()], dim=0)
self.actions = torch.cat([self.actions, collated_data["action"].float()], dim=0)
self.next_states = torch.cat(
[self.next_states, collated_data["next_obs"].float()], dim=0
)
reward = collated_data["reward"].view(-1, 1).float()
self.is_finished = torch.cat(
[self.is_finished, collated_data["done"].view(-1, 1).float()], dim=0
)
self.rewards = torch.cat([self.rewards, reward], dim=0)
# self.fake_actions = torch.cat([self.fake_actions, torch.zeros_like(collated_data['action'].unsqueeze(1).expand(-1, self.fake_action_shape, -1))], dim=0)
# self.fake_next_actions = torch.cat([self.fake_next_actions, torch.zeros_like(collated_data['action'].unsqueeze(1).expand(-1, self.fake_action_shape, -1))], dim=0)
self.len = self.states.shape[0]
log.debug(f"{self.len} data loaded in QGPOOnlineDataset")
    @staticmethod
    def return_range(dataset, max_episode_steps):
returns, lengths = [], []
ep_ret, ep_len = 0.0, 0
for r, d in zip(dataset["rewards"], dataset["terminals"]):
ep_ret += float(r)
ep_len += 1
if d or ep_len == max_episode_steps:
returns.append(ep_ret)
lengths.append(ep_len)
ep_ret, ep_len = 0.0, 0
# returns.append(ep_ret) # incomplete trajectory
lengths.append(ep_len) # but still keep track of number of steps
assert sum(lengths) == len(dataset["rewards"])
return min(returns), max(returns)
class QGPOCustomizedDataset(QGPODataset):
"""
Overview:
        Dataset for the QGPO algorithm. The training of QGPO is based on contrastive energy prediction, \
        which requires both true actions and fake actions. The true actions are sampled from the dataset, \
        and the fake actions are sampled from the action support generated by the behaviour policy.
Interface:
``__init__``, ``__getitem__``, ``__len__``.
"""
def __init__(
self,
env_id: str = None,
numpy_data_path: str = None,
):
"""
Overview:
            Initialization method of the QGPOCustomizedDataset class
Arguments:
env_id (:obj:`str`): The environment id
numpy_data_path (:obj:`str`): The path to the numpy data
"""
super().__init__()
data = np.load(numpy_data_path)
self.states = torch.from_numpy(data["obs"]).float()
self.actions = torch.from_numpy(data["action"]).float()
self.next_states = torch.from_numpy(data["next_obs"]).float()
reward = torch.from_numpy(data["reward"]).view(-1, 1).float()
self.is_finished = torch.from_numpy(data["done"]).view(-1, 1).float()
self.rewards = reward
self.len = self.states.shape[0]
log.info(f"{self.len} data loaded in QGPOCustomizedDataset")
class QGPOD4RLTensorDictDataset(QGPOTensorDictDataset):
"""
Overview:
        Dataset for the QGPO algorithm. The training of QGPO is based on contrastive energy prediction, \
        which requires both true actions and fake actions. The true actions are sampled from the dataset, \
        and the fake actions are sampled from the action support generated by the behaviour policy.
Interface:
``__init__``, ``__getitem__``, ``__len__``.
"""
def __init__(
self,
env_id: str,
action_augment_num: int = 16,
):
"""
Overview:
            Initialization method of the QGPOD4RLTensorDictDataset class
        Arguments:
            env_id (:obj:`str`): The environment id
            action_augment_num (:obj:`int`): The number of fake actions stored per state as placeholders
"""
super().__init__()
import d4rl
data = d4rl.qlearning_dataset(gym.make(env_id))
self.states = torch.from_numpy(data["observations"]).float()
self.actions = torch.from_numpy(data["actions"]).float()
self.next_states = torch.from_numpy(data["next_observations"]).float()
reward = torch.from_numpy(data["rewards"]).view(-1, 1).float()
self.is_finished = torch.from_numpy(data["terminals"]).view(-1, 1).float()
reward_tune = "iql_antmaze" if "antmaze" in env_id else "iql_locomotion"
if reward_tune == "normalize":
reward = (reward - reward.mean()) / reward.std()
elif reward_tune == "iql_antmaze":
reward = reward - 1.0
elif reward_tune == "iql_locomotion":
min_ret, max_ret = QGPOD4RLDataset.return_range(data, 1000)
reward /= max_ret - min_ret
reward *= 1000
elif reward_tune == "cql_antmaze":
reward = (reward - 0.5) * 4.0
elif reward_tune == "antmaze":
reward = (reward - 0.25) * 2.0
self.rewards = reward
self.len = self.states.shape[0]
log.info(f"{self.len} data loaded in QGPOD4RLDataset")
self.storage = LazyTensorStorage(max_size=self.len)
self.storage.set(
range(self.len),
TensorDict(
{
"s": self.states,
"a": self.actions,
"r": self.rewards,
"s_": self.next_states,
"d": self.is_finished,
"fake_a": torch.zeros_like(self.actions)
.unsqueeze(1)
.repeat_interleave(action_augment_num, dim=1),
"fake_a_": torch.zeros_like(self.actions)
.unsqueeze(1)
.repeat_interleave(action_augment_num, dim=1),
},
batch_size=[self.len],
),
)
    @staticmethod
    def return_range(dataset, max_episode_steps):
returns, lengths = [], []
ep_ret, ep_len = 0.0, 0
for r, d in zip(dataset["rewards"], dataset["terminals"]):
ep_ret += float(r)
ep_len += 1
if d or ep_len == max_episode_steps:
returns.append(ep_ret)
lengths.append(ep_len)
ep_ret, ep_len = 0.0, 0
# returns.append(ep_ret) # incomplete trajectory
lengths.append(ep_len) # but still keep track of number of steps
assert sum(lengths) == len(dataset["rewards"])
return min(returns), max(returns)
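
# Illustrative usage sketch (not part of the original module): the TensorDict-backed D4RL
# dataset above stores zero placeholders for "fake_a" / "fake_a_" until load_fake_actions()
# is called. The env id and the helper name are arbitrary examples; a real pipeline would
# sample the action support from the behaviour policy rather than use zeros.
def _example_qgpo_d4rl_tensordict_usage(action_augment_num: int = 16):
    dataset = QGPOD4RLTensorDictDataset(
        env_id="walker2d-medium-v2", action_augment_num=action_augment_num
    )
    # Stand-in action support of shape (D, K, A).
    fake_a = torch.zeros(dataset.len, action_augment_num, dataset.actions.shape[-1])
    dataset.load_fake_actions(fake_actions=fake_a, fake_next_actions=fake_a.clone())
    sample = dataset[0]  # a TensorDict with keys "s", "a", "r", "s_", "d", "fake_a", "fake_a_"
    return sample["fake_a"].shape
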
class QGPOCustomizedTensorDictDataset(QGPOTensorDictDataset):
"""
Overview:
        Dataset for the QGPO algorithm. The training of QGPO is based on contrastive energy prediction, \
        which requires both true actions and fake actions. The true actions are sampled from the dataset, \
        and the fake actions are sampled from the action support generated by the behaviour policy.
Interface:
``__init__``, ``__getitem__``, ``__len__``.
"""
def __init__(
self,
env_id: str = None,
action_augment_num: int = 16,
numpy_data_path: str = None,
):
"""
Overview:
            Initialization method of the QGPOCustomizedTensorDictDataset class
        Arguments:
            env_id (:obj:`str`): The environment id
            action_augment_num (:obj:`int`): The number of fake actions stored per state as placeholders
            numpy_data_path (:obj:`str`): The path to the numpy data
"""
super().__init__()
data = np.load(numpy_data_path)
self.states = torch.from_numpy(data["obs"]).float()
self.actions = torch.from_numpy(data["action"]).float()
self.next_states = torch.from_numpy(data["next_obs"]).float()
reward = torch.from_numpy(data["reward"]).view(-1, 1).float()
self.is_finished = torch.from_numpy(data["done"]).view(-1, 1).float()
self.rewards = reward
self.len = self.states.shape[0]
log.info(f"{self.len} data loaded in QGPOCustomizedDataset")
self.storage = LazyTensorStorage(max_size=self.len)
self.storage.set(
range(self.len),
TensorDict(
{
"s": self.states,
"a": self.actions,
"r": self.rewards,
"s_": self.next_states,
"d": self.is_finished,
"fake_a": torch.zeros_like(self.actions)
.unsqueeze(1)
.repeat_interleave(action_augment_num, dim=1),
"fake_a_": torch.zeros_like(self.actions)
.unsqueeze(1)
.repeat_interleave(action_augment_num, dim=1),
},
batch_size=[self.len],
),
)
class QGPODeepMindControlTensorDictDataset(QGPOTensorDictDataset):
def __init__(
self,
path: str,
action_augment_num: int = 16,
):
state_dicts = {}
next_states_dicts = {}
actions_list = []
rewards_list = []
data = np.load(path, allow_pickle=True)
obs_keys = list(data[0]["s"].keys())
for key in obs_keys:
if key not in state_dicts:
state_dicts[key] = []
next_states_dicts[key] = []
state_values = np.array([item["s"][key] for item in data], dtype=np.float32)
next_state_values = np.array(
[item["s_"][key] for item in data], dtype=np.float32
)
state_dicts[key].append(torch.tensor(state_values))
next_states_dicts[key].append(torch.tensor(next_state_values))
actions_values = np.array([item["a"] for item in data], dtype=np.float32)
rewards_values = np.array(
[item["r"] for item in data], dtype=np.float32
).reshape(-1, 1)
actions_list.append(torch.tensor(actions_values))
rewards_list.append(torch.tensor(rewards_values))
self.actions = torch.cat(actions_list, dim=0)
self.rewards = torch.cat(rewards_list, dim=0)
self.len = self.actions.shape[0]
self.states = TensorDict(
{key: torch.cat(state_dicts[key], dim=0) for key in obs_keys},
batch_size=[self.len],
)
self.next_states = TensorDict(
{key: torch.cat(next_states_dicts[key], dim=0) for key in obs_keys},
batch_size=[self.len],
)
self.is_finished = torch.zeros_like(self.rewards, dtype=torch.bool)
self.storage = LazyTensorStorage(max_size=self.len)
self.storage.set(
range(self.len),
TensorDict(
{
"s": self.states,
"a": self.actions,
"r": self.rewards,
"s_": self.next_states,
"fake_a": torch.zeros_like(self.actions)
.unsqueeze(1)
.repeat_interleave(action_augment_num, dim=1),
"fake_a_": torch.zeros_like(self.actions)
.unsqueeze(1)
.repeat_interleave(action_augment_num, dim=1),
"d": self.is_finished,
},
batch_size=[self.len],
),
)
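
# Illustrative sketch (not part of the original module): QGPODeepMindControlTensorDictDataset
# reads a pickled numpy object array of transition dicts, each holding a dict-valued
# observation under "s"/"s_" plus "a" and "r". The observation keys, shapes, file name and
# helper name below are arbitrary examples.
def _example_build_dmc_transitions(path: str = "dmc_transitions.npy", n: int = 100):
    transitions = np.array(
        [
            {
                "s": {
                    "position": np.random.randn(8).astype(np.float32),
                    "velocity": np.random.randn(9).astype(np.float32),
                },
                "a": np.random.randn(6).astype(np.float32),
                "r": np.float32(0.0),
                "s_": {
                    "position": np.random.randn(8).astype(np.float32),
                    "velocity": np.random.randn(9).astype(np.float32),
                },
            }
            for _ in range(n)
        ],
        dtype=object,
    )
    np.save(path, transitions, allow_pickle=True)
    return QGPODeepMindControlTensorDictDataset(path=path)
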
class QGPODeepMindControlVisualTensorDictDataset(torch.utils.data.Dataset):
def __init__(
self,
env_id: str,
policy_type: str,
pixel_size: int,
path: str,
stack_frames: int,
action_augment_num: int = 16,
):
assert env_id in ["cheetah_run", "humanoid_walk", "walker_walk"]
assert policy_type in [
"expert",
"medium",
"medium_expert",
"medium_replay",
"random",
]
assert pixel_size in [64, 84]
if pixel_size == 64:
npz_folder_path = os.path.join(path, env_id, policy_type, "64px")
else:
npz_folder_path = os.path.join(path, env_id, policy_type, "84px")
# find all npz files in the folder
npz_files = [f for f in os.listdir(npz_folder_path) if f.endswith(".npz")]
transition_counter = 0
obs_list = []
action_list = []
reward_list = []
next_obs_list = []
is_finished_list = []
episode_list = []
step_list = []
# open all npz files in the folder
for index, npz_file in enumerate(npz_files):
npz_path = os.path.join(npz_folder_path, npz_file)
data = np.load(npz_path, allow_pickle=True)
length = data["image"].shape[0]
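            # Frame stacking: row i of `obs` holds frames [i, i + stack_frames) and row i of
            # `next_obs` holds frames [i + 1, i + stack_frames + 1); actions, rewards and
            # termination flags are truncated below so that all tensors share the same length.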
obs = torch.stack(
[
torch.from_numpy(data["image"][i : length - stack_frames + i])
for i in range(stack_frames)
],
dim=1,
)
next_obs = torch.stack(
[
torch.from_numpy(
data["image"][i + 1 : length - stack_frames + i + 1]
)
for i in range(stack_frames)
],
dim=1,
)
action = torch.from_numpy(data["action"][stack_frames:])
reward = torch.from_numpy(data["reward"][stack_frames:])
is_finished = torch.from_numpy(
data["is_last"][stack_frames:] + data["is_terminal"][stack_frames:]
)
episode = torch.tensor([index] * obs.shape[0])
step = torch.arange(obs.shape[0])
transition_counter += obs.shape[0]
obs_list.append(obs)
action_list.append(action)
reward_list.append(reward)
next_obs_list.append(next_obs)
is_finished_list.append(is_finished)
episode_list.append(episode)
step_list.append(step)
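            # NOTE: only the first 22 episode files (indices 0..21) are loaded; this break
            # appears to act as a cap on dataset size / memory usage.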
if index > 20:
break
self.states = torch.cat(obs_list, dim=0)
self.actions = torch.cat(action_list, dim=0)
self.rewards = torch.cat(reward_list, dim=0)
self.next_states = torch.cat(next_obs_list, dim=0)
self.is_finished = torch.cat(is_finished_list, dim=0)
self.episode = torch.cat(episode_list, dim=0)
self.step = torch.cat(step_list, dim=0)
self.len = self.states.shape[0]
self.storage = LazyMemmapStorage(max_size=self.len)
self.storage.set(
range(self.len),
TensorDict(
{
"s": self.states,
"a": self.actions,
"r": self.rewards,
"s_": self.next_states,
"d": self.is_finished,
"fake_a": torch.zeros_like(self.actions)
.unsqueeze(1)
.repeat_interleave(action_augment_num, dim=1),
"fake_a_": torch.zeros_like(self.actions)
.unsqueeze(1)
.repeat_interleave(action_augment_num, dim=1),
"episode": self.episode,
"step": self.step,
},
batch_size=[self.len],
),
)
def __getitem__(self, index):
"""
Overview:
Get data by index
Arguments:
index (:obj:`int`): Index of data
Returns:
data (:obj:`dict`): Data dict
.. note::
The data dict contains the following keys:
s (:obj:`torch.Tensor`): State
a (:obj:`torch.Tensor`): Action
r (:obj:`torch.Tensor`): Reward
s_ (:obj:`torch.Tensor`): Next state
d (:obj:`torch.Tensor`): Is finished
            fake_a (:obj:`torch.Tensor`): Fake action for contrastive energy prediction and QGPO training \
                (fake actions are sampled from the action support generated by the behaviour policy)
            fake_a_ (:obj:`torch.Tensor`): Fake next action for contrastive energy prediction and QGPO training \
                (fake actions are sampled from the action support generated by the behaviour policy)
"""
data = self.storage.get(index=index)
return data
def __len__(self):
return self.len
def load_fake_actions(self, fake_actions, fake_next_actions):
self.fake_actions = fake_actions
self.fake_next_actions = fake_next_actions
self.storage.set(
range(self.len),
TensorDict(
{
"s": self.states,
"a": self.actions,
"r": self.rewards,
"s_": self.next_states,
"d": self.is_finished,
"fake_a": self.fake_actions,
"fake_a_": self.fake_next_actions,
"episode": self.episode,
"step": self.step,
},
batch_size=[self.len],
),
)
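
# Illustrative sketch (not part of the original module): each .npz episode file read by
# QGPODeepMindControlVisualTensorDictDataset is expected to hold per-step arrays named
# "image", "action", "reward", "is_last" and "is_terminal", located under
# <path>/<env_id>/<policy_type>/{64px,84px}/. The shapes, file name and helper name below
# are arbitrary examples.
def _example_build_visual_episode(npz_path: str = "episode_000.npz", steps: int = 200):
    np.savez(
        npz_path,
        image=np.random.randint(0, 255, size=(steps, 64, 64, 3), dtype=np.uint8),
        action=np.random.randn(steps, 6).astype(np.float32),
        reward=np.zeros((steps,), dtype=np.float32),
        is_last=np.zeros((steps,), dtype=bool),
        is_terminal=np.zeros((steps,), dtype=bool),
    )
    return npz_path
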