万恶之源
今天偶然之间看到这样一个公式
\[Attention(q_t, K, V) = \sum^m_{s=1} \frac{1}{Z}exp(\frac{<q_t,k_s>}{\sqrt d_k})\cdot v_s
\]
突然满脑子都是问号,Attention的公式我记得:$$Attention(Q,K,V) = Softmax(\frac{Q\cdot K^\top}{\sqrt d_k})V$$
为啥会有个Z?这形式哪里来的?
只看一个query
一个老生常谈的比喻,Q代表用户提出的问题,K代表一个文档库,V代表和K一一对应的问题答案,Attention的计算过程,是比较用户问题和文档库,找到最相近的文章来回答用户问题。
在实际计算过程中,Q、K、V都是矩阵,我们假设$$Q\in R^{n\times d_k}, K\in R^{m\times d_k}, V\in R^{m\times d_v}$$
我们从Q中取出一条query,从K中取出一篇文章key:
\[q_t \in R^{d_k}, k_s \in R^{d_k}
\]
那么这篇文章和query的“相关性”可以被表达为:
\[score(q_t, k_s) = \frac{<q_t. k_s^\top>}{\sqrt d_k}
\]
有了“分数”还不够,分数可能是正的可能是负的,而我们需要的是能回答query的key,需要从K中找到“最重要的一个”,因此需要计算出每个key的“权重”:
\[a_{t,s} = Softmax_s(score(q_t, k_s))= \frac{exp(score(q_t,k_s))}{\sum^m_{j=1}exp(score(q_t, k_j))}
\]
稍微注意一下K的维度,j要从1到m,然后我们可以计算出一篇文章的“结果”:
\[y_s = a_{t,s} \cdot v_s = \frac{exp(\frac{q_t \cdot k_s^\top}{\sqrt d_k})}{\sum_{j=1}^m exp(\frac{q_t\cdot k_j^\top}{\sqrt d_k})} \cdot v_s
\]
那么全局的计算方式也就出来了:
\[y = \sum_{s=1}^m \frac{exp(\frac{q_t\cdot k_s^\top}{\sqrt d_k})}{\sum_{j=1}^m exp(\frac{q_t \cdot k_j^\top}{\sqrt d_k})} \cdot v_s
\]
这时我们可以观察到分母中有一个大块头,甚至会发现这个大块头像是一个“常数”,如果我们换个元就会变成:
\[令Z= \sum_{j=1}^m exp(\frac{q_t\cdot k_j^\top}{\sqrt d_k})
\]
\[y = \sum_{s=1}^m \frac{1}{Z}\cdot exp(\frac{q_t\cdot k_s^\top}{\sqrt d_k})\cdot v_s
\]
