one_hot 函数功能概述
将类型为 torch.LongTensor 的张量转化为其 one-hot 编码的形式。
此实现的执行效率略高于 torch.nn.functional.one_hot 。
import torch
import torch.nn as nn
def one_hot(val: torch.LongTensor, num: int) -> torch.FloatTensor:
保存原始 val 的形状。
old_shape = val.shape
将 val 改变形状至二维张量。
val_reshape = val.reshape(-1, 1)
初始化结果张量,确定其形状,并设置和 val 在相同的 device 上。
ret = torch.zeros(val_reshape.shape[0], num, device=val.device)
根据 val_reshape 中的值,将若干 1 填入结果张量中。注意,这一步是 in-place 操作(即直接原地改变结果张量的值)。
ret.scatter_(1, val_reshape, 1)
恢复原始形状,并将结果张量返回。
return ret.reshape(*old_shape, num)
get_one_hot_encoding 函数功能概述
使用 torch.nn.Embedding 实现 one-hot 编码。
def get_one_hot_encoding(num: int):
权重矩阵应当设置为大小为 num x num 的单位矩阵。这样对于第 i 行,其内容是只有第 i 维是 1,其它维度都是 0 的向量,恰好就是 one-hot 编码。同时冻结参数,确保权重矩阵不可改变。
return nn.Embedding.from_pretrained(torch.eye(num), freeze=True, padding_idx=None)
get_binary_encoding 函数功能概述
使用 torch.nn.Embedding 实现二进制编码。
def get_binary_encoding(bit_num: int):
生成形状为 $$2^{B} \times B $$ 的矩阵,其中 B 是比特数。
矩阵的第 i 行代表了数字 i 的二进制表达,是一个维度为 B 的向量。
location_embedding = []
for n in range(2 ** bit_num):
s = '0' * (bit_num - len(bin(n)[2:])) + bin(n)[2:]
location_embedding.append(list(int(i) for i in s))
mat = torch.FloatTensor(location_embedding)
使用生成的矩阵作为 embedding 的权重,同时冻结参数确保权重矩阵不可改变。
return torch.nn.Embedding.from_pretrained(mat, freeze=True, padding_idx=None)
test_encoding 函数功能概述
编码函数的主函数。对上述的若干种编码函数进行测试,检查输出的正确性。
def test_encoding():
测试上述两种 one-hot 编码方法,判断它们的输出结果是否一致。
x = torch.LongTensor([9, 0, 1, 2, 1, 3, 5])
one_hot_enc = get_one_hot_encoding(10)
y = one_hot_enc(x)
y_ = one_hot(x, num=10)
assert torch.sum(torch.abs(y - y_)) < 1e-6
测试二进制编码,判断其输出是否等于期望的结果。
bin_enc = get_binary_encoding(2)
x = torch.arange(4)
y = bin_enc(x)
ground_truth = torch.LongTensor([[0, 0], [0, 1], [1, 0], [1, 1]])
assert torch.eq(y, ground_truth).all()
如果读者关于本文档有任何问题和建议,可以在 GitHub 提 issue 或是直接发邮件给我们 (opendilab@pjlab.org.cn) 。