Source code for ding.model.template.language_transformer
from typing import List, Dict, Optional
import torch
from torch import nn
try:
from transformers import AutoTokenizer, AutoModelForTokenClassification
except ImportError:
from ditk import logging
logging.warning("not found transformer, please install it using: pip install transformers")
from ding.utils import MODEL_REGISTRY
@MODEL_REGISTRY.register('language_transformer')
class LanguageTransformer(nn.Module):
"""
Overview:
The LanguageTransformer network. It downloads a pre-trained language model and adds a head on top of it.
By default, a BERT model is used as the text encoder, whose bidirectional attention is well suited for
obtaining an embedding of the whole sentence.
Interfaces:
``__init__``, ``forward``
"""
mode = ['compute_actor', 'compute_critic', 'compute_actor_critic']
def __init__(
self,
model_name: str = "bert-base-uncased",
add_linear: bool = False,
embedding_size: int = 128,
freeze_encoder: bool = True,
hidden_dim: int = 768,
norm_embedding: bool = False
) -> None:
"""
Overview:
Init the LanguageTransformer Model according to input arguments.
Arguments:
- model_name (:obj:`str`): The base language model name in huggingface, such as "bert-base-uncased".
- add_linear (:obj:`bool`): Whether to add a linear layer on top of the language model, defaults to \
``False``.
- embedding_size (:obj:`int`): The embedding size of the added linear layer, such as 128.
- freeze_encoder (:obj:`bool`): Whether to freeze the encoder language model while training, \
defaults to ``True``.
- hidden_dim (:obj:`int`): The embedding dimension of the encoding model (e.g. BERT). This value should \
correspond to the model you use. For bert-base-uncased, this value is 768.
- norm_embedding (:obj:`bool`): Whether to normalize the embedding vectors, defaults to ``False``.
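Examples:
>>> # A minimal instantiation sketch; it assumes the pre-trained "bert-base-uncased" weights are available to ``transformers``.
>>> model = LanguageTransformer(model_name="bert-base-uncased", add_linear=True, embedding_size=128)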
"""
super().__init__()
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.model = AutoModelForTokenClassification.from_pretrained(model_name)
in_channel = hidden_dim if not add_linear else embedding_size
self.value_head = nn.Linear(in_channel, 1)
self.norm = nn.Identity() if not norm_embedding else nn.LayerNorm(
normalized_shape=in_channel, elementwise_affine=False
)
# Freeze transformer encoder and only train the linear layer
if freeze_encoder:
for param in self.model.parameters():
param.requires_grad = False
if add_linear:
# Add a small, trainable linear layer on top of the language model, which is tuned through RL
self.embedding_size = embedding_size
self.linear = nn.Linear(self.model.config.hidden_size, embedding_size)
else:
self.linear = None
def _calc_embedding(self, x: list) -> torch.Tensor:
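"""
Overview:
    Encode a list of strings into sentence embeddings: tokenize the batch, run the encoder, take the \
    last-layer hidden state of the [CLS] token, then optionally normalize and project it.
"""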
# ``truncation=True`` means that if the length of a prompt exceeds the tokenizer's ``max_length``,
# the exceeding part will be truncated. ``padding=True`` means that shorter prompts in the batch are
# padded to the length of the longest one. These settings ensure all encoded token sequences in a
# batch have the same length, which enables batch-wise computation.
input = self.tokenizer(x, truncation=True, padding=True, return_tensors="pt").to(self.model.device)
output = self.model(**input, output_hidden_states=True)
# Get last layer hidden states
last_hidden_states = output.hidden_states[-1]
# Get [CLS] hidden states
sentence_embedding = last_hidden_states[:, 0, :] # len(input_list) x hidden_size
sentence_embedding = self.norm(sentence_embedding)
if self.linear:
sentence_embedding = self.linear(sentence_embedding) # len(input_list) x embedding_size
return sentence_embedding
def forward(
self,
train_samples: List[str],
candidate_samples: Optional[List[str]] = None,
mode: str = 'compute_actor'
) -> Dict:
"""
Overview:
LanguageTransformer forward computation graph: take two lists of strings as input and predict their matching scores.
Different ``mode`` values forward through different network modules to obtain different outputs.
Arguments:
- train_samples (:obj:`List[str]`): One list of strings.
- candidate_samples (:obj:`Optional[List[str]]`): The other list of strings to calculate matching scores.
- mode (:obj:`str`): The forward mode; all supported modes are defined at the beginning of this class.
Returns:
- output (:obj:`Dict`): Output dict data, including the logit of matching scores and the \
corresponding ``torch.distributions.Categorical`` object; in critic modes, the estimated ``value`` \
is also included.
Examples:
>>> test_pids = [1]
>>> cand_pids = [0, 2, 4]
>>> problems = [ \
"This is problem 0", "This is the first question", "Second problem is here", "Another problem", \
"This is the last problem" \
]
>>> ctxt_list = [problems[pid] for pid in test_pids]
>>> cands_list = [problems[pid] for pid in cand_pids]
>>> model = LanguageTransformer(model_name="bert-base-uncased", add_linear=True, embedding_size=256)
>>> output = model(ctxt_list, cands_list)
>>> assert output['logit'].shape == (1, 3)
"""
assert mode in self.mode
prompt_embedding = self._calc_embedding(train_samples)
res_dict = {}
if mode in ['compute_actor', 'compute_actor_critic']:
cands_embedding = self._calc_embedding(candidate_samples)
scores = torch.mm(prompt_embedding, cands_embedding.t())
res_dict.update({'dist': torch.distributions.Categorical(logits=scores), 'logit': scores})
if mode in ['compute_critic', 'compute_actor_critic']:
value = self.value_head(prompt_embedding)
res_dict.update({'value': value})
return res_dict
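# A minimal end-to-end usage sketch (not part of the original module). It assumes the
# "bert-base-uncased" weights can be downloaded by ``transformers`` and shows what each
# forward mode returns.
if __name__ == "__main__":
    model = LanguageTransformer(model_name="bert-base-uncased", add_linear=True, embedding_size=256)
    prompts = ["This is the first question", "Second problem is here"]
    candidates = ["Another problem", "This is the last problem", "This is problem 0"]
    # Actor mode: matching logits between every prompt and every candidate.
    actor_out = model(prompts, candidates, mode='compute_actor')
    assert actor_out['logit'].shape == (2, 3)
    # Critic mode: one scalar value per prompt from the value head.
    critic_out = model(prompts, mode='compute_critic')
    assert critic_out['value'].shape == (2, 1)
    # Actor-critic mode: both outputs in a single forward pass.
    ac_out = model(prompts, candidates, mode='compute_actor_critic')
    assert set(ac_out.keys()) == {'dist', 'logit', 'value'}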