nn.embedding原理详解

nn.embedding

nn.Embedding 是 PyTorch 中用于处理序列数据中的词嵌入(word embeddings)的核心模块。它本质上是一个查找表,将输入的离散型数据(通常是整数形式的单词索引)映射为连续型的数据表示(即词向量)。这种转换在自然语言处理(NLP)、推荐系统等领域中非常常见。

函数解释

当你调用 nn.Embedding(vocab_size, num_hiddens) 时,你正在初始化一个嵌入层,其中:

  • vocab_size:这是你的词汇表大小,也就是你希望这个嵌入层能够支持的最大单词索引值加一(因为索引是从0开始的)。例如,如果你有一个包含10,000个不同单词的词汇表,那么 vocab_size 应该设置为10,000。
  • num_hiddens:这是每个单词对应的嵌入向量的维度。这个值决定了每个单词被表示为一个多维空间中的点,其坐标数量就是 num_hiddens 的值。通常,这个值可以根据具体任务和模型的需求来选择,比如50、100或300等。

除了这两个主要参数之外,nn.Embedding 还接受其他一些可选参数,如 padding_idxmax_norm 等,这些参数可以用来控制嵌入层的行为,比如指定填充标记的索引,或者限制嵌入向量的最大范数等。

示例代码

下面是一个简单的例子,展示了如何使用 nn.Embedding 来创建一个嵌入层,并将一批单词索引转换为对应的词向量:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
import torch
import torch.nn as nn

# 初始化一个嵌入层,假设词汇表大小为5,每个单词的嵌入维度为3
embedding = nn.Embedding(5, 3)

# 创建一批输入数据,这里我们有两个句子,每个句子有4个单词,单词索引分别为[1, 2, 4, 3]和[4, 3, 2, 1]
x = torch.LongTensor([[1, 2, 4, 3], [4, 3, 2, 1]])

# 使用嵌入层获取对应的词向量
y = embedding(x)

print('权重:\n', embedding.weight)
print('输出:')
print(y)

上述代码输出

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
权重:
Parameter containing:
tensor([[ 1.2475, 0.2461, -0.1228],
[ 0.5988, -2.0277, -1.4456],
[ 1.2011, 0.2131, -0.9624],
[ 1.2717, 1.7339, 1.2558],
[-0.3740, 2.0479, 0.6131]], requires_grad=True)
输出:
tensor([[[ 0.5988, -2.0277, -1.4456],
[ 1.2011, 0.2131, -0.9624],
[-0.3740, 2.0479, 0.6131],
[ 1.2717, 1.7339, 1.2558]],

[[-0.3740, 2.0479, 0.6131],
[ 1.2717, 1.7339, 1.2558],
[ 1.2011, 0.2131, -0.9624],
[ 0.5988, -2.0277, -1.4456]]], grad_fn=<EmbeddingBackward0>)

在这个例子中,embedding 层被初始化为一个10x3的矩阵,意味着它可以表示最多10个不同的单词,每个单词由一个三维向量表示。输入 x 是一个二维张量,包含了两个句子的单词索引,形状为 [2, 4],即每句话有4个单词。通过 embedding(x) 操作,我们可以得到一个新的张量 y,其形状为 [2, 4, 3],表示每个句子中的每个单词都被替换成了相应的3维词向量。

值得注意的是,在实际应用中,nn.Embedding 层通常作为神经网络的一部分,与其他层(如RNN、LSTM或Transformer等)一起训练,以学习到更有效的词表示。此外,有时我们会使用预训练的词向量(如Word2Vec或GloVe),并通过设置 _weight 参数将其加载到 nn.Embedding 中,同时设置 requires_grad=False 来固定这些预训练的词向量不参与后续的训练过程。

总之,nn.Embedding 是构建深度学习模型特别是涉及文本处理任务的重要组件之一。正确地配置和利用它可以极大地提升模型对文本数据的理解能力。

嵌入表示缩放

由于这里使用的是值范围在-1和1之间的固定位置编码,因此通过学习得到的输入的嵌入表示的值需要先乘以嵌入维度的平方根进行重新缩放,然后再与位置编码相加。

1
2
3
4
# 因为位置编码值在-1和1之间,
# 因此嵌入值乘以嵌入维度的平方根进行缩放,
# 然后再与位置编码相加。
X = self.pos_encoding(self.embedding(X) * math.sqrt(self.num_hiddens))

nn.embedding原理详解
https://cosmoliu2002.github.io/posts/embedding-detail/
作者
LiuYu
发布于
2025年3月6日
许可协议