基礎(chǔ)微網(wǎng)站開發(fā)口碑好seo基礎(chǔ)入門
在多頭注意力機制中,通常輸入的數(shù)據(jù)包括查詢(Q)、鍵(K)和值(V)。這些數(shù)據(jù)的維度以及權(quán)重矩陣的維度在多頭注意力機制中扮演關(guān)鍵角色。下面對數(shù)據(jù)及權(quán)重的維度進(jìn)行解釋:
-
輸入數(shù)據(jù)(Queries, Keys, Values):
- Queries (Q): 表示待查詢的信息,通常對應(yīng)輸入序列的每個位置。其維度通常為 (batch_size, seq_length, q_dim),其中
q_dim
是查詢向量的維度。 - Keys (K): 表示用于計算注意力分?jǐn)?shù)的信息,也通常對應(yīng)輸入序列的每個位置。其維度通常為 (batch_size, seq_length, key_dim),其中
key_dim
是鍵向量的維度。 - Values (V): 表示待加權(quán)求和的信息,同樣對應(yīng)輸入序列的每個位置。其維度通常為 (batch_size, seq_length, value_dim),其中
value_dim
是值向量的維度。
- Queries (Q): 表示待查詢的信息,通常對應(yīng)輸入序列的每個位置。其維度通常為 (batch_size, seq_length, q_dim),其中
-
權(quán)重矩陣:
- 查詢權(quán)重矩陣 (Q_weights): 用于對查詢(Q)進(jìn)行線性變換,將其映射到多個注意力頭的維度。其維度通常為 (q_dim, num_heads,?head_dim),其中
num_heads
是注意力頭的數(shù)量,head_dim
是每個注意力頭的維度。 - 鍵權(quán)重矩陣 (K_weights): 用于對鍵(K)進(jìn)行線性變換,同樣映射到多個注意力頭的維度。其維度通常為 (key_dim, num_heads,?head_dim)。
- 值權(quán)重矩陣 (V_weights): 用于對值(V)進(jìn)行線性變換,映射到多個注意力頭的維度。其維度通常為 (value_dim, num_heads,?head_dim)。
- 查詢權(quán)重矩陣 (Q_weights): 用于對查詢(Q)進(jìn)行線性變換,將其映射到多個注意力頭的維度。其維度通常為 (q_dim, num_heads,?head_dim),其中
def glorot_uniform():return hk.initializers.VarianceScaling(scale=1.0,mode='fan_avg',distribution='uniform')def stable_softmax(logits: jax.Array) -> jax.Array:"""Numerically stable softmax for (potential) bfloat 16."""if logits.dtype == jnp.float32:output = jax.nn.softmax(logits)elif logits.dtype == jnp.bfloat16:# Need to explicitly do softmax in float32 to avoid numerical issues# with large negatives. Large negatives can occur if trying to mask# by adding on large negative logits so that things softmax to zero.output = jax.nn.softmax(logits.astype(jnp.float32)).astype(jnp.bfloat16)else:raise ValueError(f'Unexpected input dtype {logits.dtype}')return outputclass Attention(hk.Module):"""Multihead attention."""def __init__(self, config, global_config, output_dim, name='attention'):super().__init__(name=name)self.config = configself.global_config = global_configself.output_dim = output_dimdef __call__(self, q_data, m_data, mask, nonbatched_bias=None):"""Builds Attention module.Arguments:q_data: A tensor of queries, shape [batch_size, N_queries, q_channels].m_data: A tensor of memories from which the keys and values areprojected, shape [batch_size, N_keys, m_channels].mask: A mask for the attention, shape [batch_size, N_queries, N_keys].nonbatched_bias: Shared bias, shape [N_queries, N_keys].Returns:A float32 tensor of shape [batch_size, N_queries, output_dim]."""# Sensible default for when the config keys are missingkey_dim = self.config.get('key_dim', int(q_data.shape[-1]))value_dim = self.config.get('value_dim', int(m_data.shape[-1]))num_head = self.config.num_headassert key_dim % num_head == 0assert value_dim % num_head == 0key_dim = key_dim // num_headvalue_dim = value_dim // num_head# weights維度(數(shù)據(jù)最后一維的維度數(shù),注意力頭數(shù)量,每個注意力頭映射的數(shù)據(jù)維度)q_weights = hk.get_parameter('query_w', shape=(q_data.shape[-1], num_head, key_dim),dtype=q_data.dtype,init=glorot_uniform())k_weights = hk.get_parameter('key_w', shape=(m_data.shape[-1], num_head, key_dim),dtype=q_data.dtype,init=glorot_uniform())v_weights = hk.get_parameter('value_w', shape=(m_data.shape[-1], num_head, value_dim),dtype=q_data.dtype,init=glorot_uniform())# bqa: 輸入張量 q_data 的軸的標(biāo)記。(batch_size, seq_length, q_dim)# 'b' :batch 維度,'q':查詢序列維度,'a' 查詢向量的維度。所以,'bqa' 表示 q_data 的三個軸。# ahc:查詢權(quán)重矩陣的形狀, a:查詢向量的維度,h:注意力頭的數(shù)量,c: 每個注意力頭中查詢的維度。# key_dim**(-0.5) 注意力縮放,避免注意力分?jǐn)?shù)過大或過小# jnp.einsum:Einstein Summation Notation(愛因斯坦求和約定)。# 一種緊湊、靈活的方式來指定和計算張量的乘積、求和和轉(zhuǎn)置等操作。q = jnp.einsum('bqa,ahc->bqhc', q_data, q_weights) * key_dim**(-0.5)k = jnp.einsum('bka,ahc->bkhc', m_data, k_weights)v = jnp.einsum('bka,ahc->bkhc', m_data, v_weights)# 注意力分?jǐn)?shù),計算每個查詢(q)和鍵(k)之間的點積,以獲得注意力分?jǐn)?shù)。# 結(jié)果維度為bhqk (batch_size, num_heads, num_q, num_k),?# num_q/num_k為查詢/鍵的數(shù)量,一般為 seq_length。logits = jnp.einsum('bqhc,bkhc->bhqk', q, k)if nonbatched_bias is not None:logits += jnp.expand_dims(nonbatched_bias, axis=0)# 注意力分?jǐn)?shù)中加入masklogits = jnp.where(mask, logits, _SOFTMAX_MASK)# 對注意力分?jǐn)?shù)進(jìn)行softmax操作,我們得到每個位置對輸入序列的權(quán)重分配。weights = stable_softmax(logits)# 注意力分?jǐn)?shù)對值進(jìn)行加權(quán)求和,得到多頭注意力機制的輸出# 兩個向量的點積可以用于度量它們之間的相似性。如果兩個向量越相似,它們的點積就越大weighted_avg = jnp.einsum('bhqk,bkhc->bqhc', weights, v)if self.global_config.zero_init:init = hk.initializers.Constant(0.0)else:init = glorot_uniform()# 帶有bias的門控注意力if self.config.gating:gating_weights = hk.get_parameter('gating_w',shape=(q_data.shape[-1], num_head, value_dim),dtype=q_data.dtype,init=hk.initializers.Constant(0.0))gating_bias = hk.get_parameter('gating_b',shape=(num_head, value_dim),dtype=q_data.dtype,init=hk.initializers.Constant(1.0))gate_values = jnp.einsum('bqc, chv->bqhv', q_data,gating_weights) + gating_biasgate_values = jax.nn.sigmoid(gate_values)# ⊙ 對應(yīng)元素相乘weighted_avg *= gate_valueso_weights = hk.get_parameter('output_w', shape=(num_head, value_dim, self.output_dim),dtype=q_data.dtype,init=init)o_bias = hk.get_parameter('output_b', shape=(self.output_dim,),dtype=q_data.dtype,init=hk.initializers.Constant(0.0))# 線性變換到輸出維度大小output = jnp.einsum('bqhc,hco->bqo', weighted_avg, o_weights) + o_biasreturn output