Я пытался решить задание 1 для CS 231n, и у меня возникли проблемы с реализацией градиента для softmax. Обратите внимание: я не записан на курс и занимаюсь им только в учебных целях. Сначала я рассчитал градиент вручную, и мне кажется, что это нормально, и я реализовал градиент, как показано ниже, но когда код запускается против числового градиента, результаты не совпадают, я хочу понять, где я ошибаюсь в этой реализации, если кто-нибудь может помочь мне прояснить это ясно.
Спасибо.
Код:
def softmax_loss_naive(W, X, y, reg):
"""
Softmax loss function, naive implementation (with loops)
Inputs have dimension D, there are C classes, and we operate on minibatches
of N examples.
Inputs:
- W: A numpy array of shape (D, C) containing weights.
- X: A numpy array of shape (N, D) containing a minibatch of data.
- y: A numpy array of shape (N,) containing training labels; y[i] = c means
that X[i] has label c, where 0 <= c < C.
- reg: (float) regularization strength
Returns a tuple of:
- loss as single float
- gradient with respect to weights W; an array of same shape as W
"""
# Initialize the loss and gradient to zero.
loss = 0.0
dW = np.zeros_like(W)
#############################################################################
# TODO: Compute the softmax loss and its gradient using explicit loops. #
# Store the loss in loss and the gradient in dW. If you are not careful #
# here, it is easy to run into numeric instability. Don't forget the #
# regularization! #
#############################################################################
# *****START OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****
num_train = X.shape[0]
for i in range(num_train):
score = X[i].dot(W)
correct = y[i]
denom_sum = 0.0
num = 0.0
for j,s in enumerate(score):
denom_sum += np.exp(s)
if j == correct:
num = np.exp(s)
else:
dW[:,j] = X[i].T * np.exp(s)
loss += -np.log(num / denom_sum)
dW[:, correct] += -X[i].T * ( (denom_sum - num) )
dW = dW / (denom_sum)
loss = loss / (num_train)
dW /= num_train
loss += reg * np.sum(W * W)
dW += reg * 2 * W
return loss, dW