Я изучаю несколько реализаций BiLSTM, основанного на самовнимании, и не понимаю, почему в каждой из них размер ввода и вывода различается. В частности, я имею в виду следующие коды, взятые из разных реализаций:
def attnetwork(self, encoder_out, final_hidden):
# encoder_out shape = (batch_size, seq_len, n_hidden)
# final_hidden shape = (1, batch_size, n_hidden)
hidden = final_hidden.squeeze(0)
attn_weights = torch.bmm(encoder_out, hidden.unsqueeze(2)).squeeze(2)
soft_attn_weights = F.softmax(attn_weights, 1)
new_hidden = torch.bmm(encoder_out.transpose(1,2), soft_attn_weights.unsqueeze(2)).squeeze(2)
return new_hidden # shape = (batch_size, n_hidden)
Как вы можете видеть, эта реализация принимает в качестве входных данных два вектора размерности (batch_size, seq_len, n_hidden)
и (1, batch_size, n_hidden)
, соответственно, и возвращает вектор размерности (batch_size, n_hidden)
. Но где размер относительно seq_len
? Мне нужен выходной вектор, равный входному (т.е. (batch_size, seq_len, n_hidden)
).
Другая реализация, в которой размер ввода не совпадает с размером вывода:
def attention(self,H):
M = torch.tanh(H) # Non-linear transformation size:(batch_size, hidden_dim, seq_len)
a = F.softmax(torch.bmm(self.att_weight,M),dim=2) # a.Size : (batch_size,1,seq_len)
a = torch.transpose(a,1,2) # (batch_size,seq_len,1)
return torch.bmm(H,a) # (batch_size,hidden_dim,1)
Другая реализация с той же проблемой:
def attention(self, rnn_out, state):
merged_state = torch.cat([s for s in state],1)
merged_state = merged_state.squeeze(0).unsqueeze(2)
# (batch, seq_len, cell_size) * (batch, cell_size, 1) = (batch, seq_len, 1)
weights = torch.bmm(rnn_out, merged_state)
weights = torch.nn.functional.softmax(weights.squeeze(2)).unsqueeze(2)
# (batch, cell_size, seq_len) * (batch, seq_len, 1) = (batch, cell_size, 1)
return torch.bmm(torch.transpose(rnn_out, 1, 2), weights).squeeze(2)
Как можно сделать вывод тензора того же размера, что и входной, без нарушения механизма самовнимания?
Спасибо!
РЕДАКТИРОВАТЬ: функция пересылки, которую я должен использовать, такова:
def forward(self, x, x_len):
x = nn.utils.rnn.pack_padded_sequence(x, x_len, batch_first=True)
out1, (h_n, c_n) = self.lstm1(x)
# out1 = (seq_len, batch, num_directions * hidden_size)
# h_n = (num_layers * num_directions, batch, hidden_size)
x, lengths = nn.utils.rnn.pad_packed_sequence(out1, batch_first=True)
x, att1 = self.atten1(x, lengths) # skip connect
return x
последний x
в return x
Мне абсолютно необходимо, чтобы он имел форму (batch_size, seq_len, hidden_state)
(obv также в другом порядке, чтобы transpose
было достаточно, чтобы исправить это).