Shortcuts

Source code for ding.model.template.coma

from typing import Dict, Union
import torch
import torch.nn as nn

from functools import reduce
from ding.torch_utils import one_hot, MLP
from ding.utils import squeeze, list_split, MODEL_REGISTRY, SequenceType
from .q_learning import DRQN


class COMAActorNetwork(nn.Module):
    """
    Overview:
        Decentralized actor network in COMA algorithm.
    Interface:
         ``__init__``, ``forward``
    """

    def __init__(
        self,
        obs_shape: int,
        action_shape: int,
        hidden_size_list: SequenceType = [128, 128, 64],
    ):
        """
        Overview:
            Initialize COMA actor network
        Arguments:
            - obs_shape (:obj:`int`): the dimension of each agent's observation state
            - action_shape (:obj:`int`): the dimension of action shape
            - hidden_size_list (:obj:`list`): the list of hidden size, default to [128, 128, 64]
        """
        super(COMAActorNetwork, self).__init__()
        self.main = DRQN(obs_shape, action_shape, hidden_size_list)

    def forward(self, inputs: Dict) -> Dict:
        """
        Overview:
            The forward computation graph of COMA actor network
        Arguments:
            - inputs (:obj:`dict`): input data dict with keys ['obs', 'prev_state']
            - agent_state (:obj:`torch.Tensor`): each agent local state(obs)
            - action_mask (:obj:`torch.Tensor`): the masked action
            - prev_state (:obj:`torch.Tensor`): the previous hidden state
        Returns:
            - output (:obj:`dict`): output data dict with keys ['logit', 'next_state', 'action_mask']
        ArgumentsKeys:
            - necessary: ``obs`` { ``agent_state``, ``action_mask`` }, ``prev_state``
        ReturnsKeys:
            - necessary: ``logit``, ``next_state``, ``action_mask``
        Examples:
            >>> T, B, A, N = 4, 8, 3, 32
            >>> embedding_dim = 64
            >>> action_dim = 6
            >>> data = torch.randn(T, B, A, N)
            >>> model = COMAActorNetwork((N, ), action_dim, [128, embedding_dim])
            >>> prev_state = [[None for _ in range(A)] for _ in range(B)]
            >>> for t in range(T):
            >>>     inputs = {'obs': {'agent_state': data[t], 'action_mask': None}, 'prev_state': prev_state}
            >>>     outputs = model(inputs)
            >>>     logit, prev_state = outputs['logit'], outputs['next_state']
        """
        agent_state = inputs['obs']['agent_state']
        prev_state = inputs['prev_state']
        if len(agent_state.shape) == 3:  # B, A, N
            agent_state = agent_state.unsqueeze(0)
            unsqueeze_flag = True
        else:
            unsqueeze_flag = False
        T, B, A = agent_state.shape[:3]
        agent_state = agent_state.reshape(T, -1, *agent_state.shape[3:])
        prev_state = reduce(lambda x, y: x + y, prev_state)
        output = self.main({'obs': agent_state, 'prev_state': prev_state, 'enable_fast_timestep': True})
        logit, next_state = output['logit'], output['next_state']
        next_state, _ = list_split(next_state, step=A)
        logit = logit.reshape(T, B, A, -1)
        if unsqueeze_flag:
            logit = logit.squeeze(0)
        return {'logit': logit, 'next_state': next_state, 'action_mask': inputs['obs']['action_mask']}


class COMACriticNetwork(nn.Module):
    """
    Overview:
        Centralized critic network in COMA algorithm.
    Interface:
         ``__init__``, ``forward``
    """

    def __init__(
        self,
        input_size: int,
        action_shape: int,
        hidden_size: int = 128,
    ):
        """
        Overview:
            initialize COMA critic network
        Arguments:
            - input_size (:obj:`int`): the size of input global observation
            - action_shape (:obj:`int`): the dimension of action shape
            - hidden_size_list (:obj:`list`): the list of hidden size, default to 128
        Returns:
            - output (:obj:`dict`): output data dict with keys ['q_value']
        Shapes:
            - obs (:obj:`dict`): ``agent_state``: :math:`(T, B, A, N, D)`, ``action_mask``: :math:`(T, B, A, N, A)`
            - prev_state (:obj:`list`): :math:`[[[h, c] for _ in range(A)] for _ in range(B)]`
            - logit (:obj:`torch.Tensor`): :math:`(T, B, A, N, A)`
            - next_state (:obj:`list`): :math:`[[[h, c] for _ in range(A)] for _ in range(B)]`
            - action_mask (:obj:`torch.Tensor`): :math:`(T, B, A, N, A)`
        """
        super(COMACriticNetwork, self).__init__()
        self.action_shape = action_shape
        self.act = nn.ReLU()
        self.mlp = nn.Sequential(
            MLP(input_size, hidden_size, hidden_size, 2, activation=self.act), nn.Linear(hidden_size, action_shape)
        )

    def forward(self, data: Dict) -> Dict:
        """
        Overview:
            forward computation graph of qmix network
        Arguments:
            - data (:obj:`dict`): input data dict with keys ['obs', 'prev_state', 'action']
            - agent_state (:obj:`torch.Tensor`): each agent local state(obs)
            - global_state (:obj:`torch.Tensor`): global state(obs)
            - action (:obj:`torch.Tensor`): the masked action
        ArgumentsKeys:
            - necessary: ``obs`` { ``agent_state``, ``global_state`` }, ``action``, ``prev_state``
        ReturnsKeys:
            - necessary: ``q_value``
        Examples:
            >>> agent_num, bs, T = 4, 3, 8
            >>> obs_dim, global_obs_dim, action_dim = 32, 32 * 4, 9
            >>> coma_model = COMACriticNetwork(
            >>>     obs_dim - action_dim + global_obs_dim + 2 * action_dim * agent_num, action_dim)
            >>> data = {
            >>>     'obs': {
            >>>         'agent_state': torch.randn(T, bs, agent_num, obs_dim),
            >>>         'global_state': torch.randn(T, bs, global_obs_dim),
            >>>     },
            >>>     'action': torch.randint(0, action_dim, size=(T, bs, agent_num)),
            >>> }
            >>> output = coma_model(data)
        """
        x = self._preprocess_data(data)
        q = self.mlp(x)
        return {'q_value': q}

    def _preprocess_data(self, data: Dict) -> torch.Tensor:
        """
        Overview:
            preprocess data to make it can be used by MLP net
        Arguments:
            - data (:obj:`dict`): input data dict with keys ['obs', 'prev_state', 'action']
            - agent_state (:obj:`torch.Tensor`): each agent local state(obs)
            - global_state (:obj:`torch.Tensor`): global state(obs)
            - action (:obj:`torch.Tensor`): the masked action
        ArgumentsKeys:
            - necessary: ``obs`` { ``agent_state``, ``global_state``} , ``action``, ``prev_state``
        Return:
            - x (:obj:`torch.Tensor`): the data can be used by MLP net, including \
                ``global_state``, ``agent_state``, ``last_action``, ``action``, ``agent_id``
        """
        t_size, batch_size, agent_num = data['obs']['agent_state'].shape[:3]
        agent_state_ori, global_state = data['obs']['agent_state'], data['obs']['global_state']

        # splite obs, last_action and agent_id
        agent_state = agent_state_ori[..., :-self.action_shape - agent_num]
        last_action = agent_state_ori[..., -self.action_shape - agent_num:-agent_num]
        last_action = last_action.reshape(t_size, batch_size, 1, -1).repeat(1, 1, agent_num, 1)
        agent_id = agent_state_ori[..., -agent_num:]

        action = one_hot(data['action'], self.action_shape)  # T, B, A,N
        action = action.reshape(t_size, batch_size, -1, agent_num * self.action_shape).repeat(1, 1, agent_num, 1)
        action_mask = (1 - torch.eye(agent_num).to(action.device))
        action_mask = action_mask.view(-1, 1).repeat(1, self.action_shape).view(agent_num, -1)  # A, A*N
        action = (action_mask.unsqueeze(0).unsqueeze(0)) * action  # T, B, A, A*N
        global_state = global_state.unsqueeze(2).repeat(1, 1, agent_num, 1)

        x = torch.cat([global_state, agent_state, last_action, action, agent_id], -1)
        return x


[docs]@MODEL_REGISTRY.register('coma') class COMA(nn.Module): """ Overview: The network of COMA algorithm, which is QAC-type actor-critic. Interface: ``__init__``, ``forward`` Properties: - mode (:obj:`list`): The list of forward mode, including ``compute_actor`` and ``compute_critic`` """ mode = ['compute_actor', 'compute_critic']
[docs] def __init__( self, agent_num: int, obs_shape: Dict, action_shape: Union[int, SequenceType], actor_hidden_size_list: SequenceType ) -> None: """ Overview: initialize COMA network Arguments: - agent_num (:obj:`int`): the number of agent - obs_shape (:obj:`Dict`): the observation information, including agent_state and \ global_state - action_shape (:obj:`Union[int, SequenceType]`): the dimension of action shape - actor_hidden_size_list (:obj:`SequenceType`): the list of hidden size """ super(COMA, self).__init__() action_shape = squeeze(action_shape) actor_input_size = squeeze(obs_shape['agent_state']) critic_input_size = squeeze(obs_shape['agent_state']) + squeeze(obs_shape['global_state']) + \ agent_num * action_shape + (agent_num - 1) * action_shape critic_hidden_size = actor_hidden_size_list[-1] self.actor = COMAActorNetwork(actor_input_size, action_shape, actor_hidden_size_list) self.critic = COMACriticNetwork(critic_input_size, action_shape, critic_hidden_size)
[docs] def forward(self, inputs: Dict, mode: str) -> Dict: """ Overview: forward computation graph of COMA network Arguments: - inputs (:obj:`dict`): input data dict with keys ['obs', 'prev_state', 'action'] - agent_state (:obj:`torch.Tensor`): each agent local state(obs) - global_state (:obj:`torch.Tensor`): global state(obs) - action (:obj:`torch.Tensor`): the masked action ArgumentsKeys: - necessary: ``obs`` { ``agent_state``, ``global_state``, ``action_mask`` }, ``action``, ``prev_state`` ReturnsKeys: - necessary: - compute_critic: ``q_value`` - compute_actor: ``logit``, ``next_state``, ``action_mask`` Shapes: - obs (:obj:`dict`): ``agent_state``: :math:`(T, B, A, N, D)`, ``action_mask``: :math:`(T, B, A, N, A)` - prev_state (:obj:`list`): :math:`[[[h, c] for _ in range(A)] for _ in range(B)]` - logit (:obj:`torch.Tensor`): :math:`(T, B, A, N, A)` - next_state (:obj:`list`): :math:`[[[h, c] for _ in range(A)] for _ in range(B)]` - action_mask (:obj:`torch.Tensor`): :math:`(T, B, A, N, A)` - q_value (:obj:`torch.Tensor`): :math:`(T, B, A, N, A)` Examples: >>> agent_num, bs, T = 4, 3, 8 >>> agent_num, bs, T = 4, 3, 8 >>> obs_dim, global_obs_dim, action_dim = 32, 32 * 4, 9 >>> coma_model = COMA( >>> agent_num=agent_num, >>> obs_shape=dict(agent_state=(obs_dim, ), global_state=(global_obs_dim, )), >>> action_shape=action_dim, >>> actor_hidden_size_list=[128, 64], >>> ) >>> prev_state = [[None for _ in range(agent_num)] for _ in range(bs)] >>> data = { >>> 'obs': { >>> 'agent_state': torch.randn(T, bs, agent_num, obs_dim), >>> 'action_mask': None, >>> }, >>> 'prev_state': prev_state, >>> } >>> output = coma_model(data, mode='compute_actor') >>> data= { >>> 'obs': { >>> 'agent_state': torch.randn(T, bs, agent_num, obs_dim), >>> 'global_state': torch.randn(T, bs, global_obs_dim), >>> }, >>> 'action': torch.randint(0, action_dim, size=(T, bs, agent_num)), >>> } >>> output = coma_model(data, mode='compute_critic') """ assert mode in self.mode, "not support forward mode: {}/{}".format(mode, self.mode) if mode == 'compute_actor': return self.actor(inputs) elif mode == 'compute_critic': return self.critic(inputs)