Я пытаюсь реализовать внимание, описанное в Luong et al. 2015 в PyTorch, но я не мог заставить его работать. Ниже мой код, меня пока интересует только «общий» случай внимания. Интересно, не упускаю ли я какой-нибудь очевидной ошибки. Он работает, но, похоже, не учится.
class AttnDecoderRNN(nn.Module):
def __init__(self, hidden_size, output_size, dropout_p=0.1):
super(AttnDecoderRNN, self).__init__()
self.hidden_size = hidden_size
self.output_size = output_size
self.dropout_p = dropout_p
self.embedding = nn.Embedding(
num_embeddings=self.output_size,
embedding_dim=self.hidden_size
)
self.dropout = nn.Dropout(self.dropout_p)
self.gru = nn.GRU(self.hidden_size, self.hidden_size)
self.attn = nn.Linear(self.hidden_size, self.hidden_size)
# hc: [hidden, context]
self.Whc = nn.Linear(self.hidden_size * 2, self.hidden_size)
# s: softmax
self.Ws = nn.Linear(self.hidden_size, self.output_size)
def forward(self, input, hidden, encoder_outputs):
embedded = self.embedding(input).view(1, 1, -1)
embedded = self.dropout(embedded)
gru_out, hidden = self.gru(embedded, hidden)
# [0] remove the dimension of directions x layers for now
attn_prod = torch.mm(self.attn(hidden)[0], encoder_outputs.t())
attn_weights = F.softmax(attn_prod, dim=1) # eq. 7/8
context = torch.mm(attn_weights, encoder_outputs)
# hc: [hidden: context]
out_hc = F.tanh(self.Whc(torch.cat([hidden[0], context], dim=1)) # eq.5
output = F.log_softmax(self.Ws(out_hc), dim=1) eq. 6
return output, hidden, attn_weights
Я изучил внимание, реализованное в
https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html
а также
https://github.com/spro/practical-pytorch/blob/master/seq2seq-translation/seq2seq-translation.ipynb
- Первый не тот механизм внимания, который я ищу. Основным недостатком является то, что его внимание зависит от длины последовательности (
self.attn = nn.Linear(self.hidden_size * 2, self.max_length)
), что может быть дорогостоящим для длинных последовательностей. - Второй больше похож на то, что описано в статье, но все же не такой, как нет
tanh
. Кроме того, после обновления до последней версии pytorch он работает очень медленно (ref) . Также я не знаю, почему он принимает последний контекст (ref).