RuntimeError: cudnn RNN backward можно вызвать только в режиме обучения

Я впервые столкнулся с этой проблемой, никогда не сталкивался с такой ошибкой в ​​предыдущих проектах Python. Вот мой обучающий код:

def train(net, opt, criterion,ucf_train, batchsize,i):
    opt.zero_grad()
    total_loss = 0
    net=net.eval()
    net=net.train()
    for vid in range(i*batchsize,i*batchsize+batchsize,1):
    
        output=infer(net,ucf_train[vid])
        m=get_label_no(ucf_train[vid])
        m=m.cuda( )
        loss = criterion(output,m)
        loss.backward(retain_graph=True)
        total_loss += loss 
        opt.step()       #updates wghts and biases

    return total_loss/n_points

код для infer (net, input)

def infer(net, name):
    net.eval()
    hidden_0 = net.init_hidden()
    hidden_1 = net.init_hidden()
    hidden_2 = net.init_hidden()
    video_path = fetch_ucf_video(name)
    cap = cv2.VideoCapture(video_path)
    resize=(224,224)
    T=FrameCapture(video_path)
    print(T)
    lim=T-(T%20)-2
    i=0
    while(1):
      ret, frame2 = cap.read()
      frame2= cv2.resize(frame2, resize)
    #  print(type(frame2))
      if (i%20==0 and i<lim):
          input=normalize(frame2)     
          input=input.cuda()       
          output,hidden_0,hidden_1, hidden_2  = net(input, hidden_0, hidden_1, hidden_2)
      elif (i>=lim):
          break
      i=i+1 
    op=output  
    torch.cuda.empty_cache() 
    op=op.cuda() 
    return op 

Я получаю эту ошибку, я пробовал model.train() после this, где net - моя модель:

 RuntimeError                              Traceback (most recent call last)
<ipython-input-62-42238f3f6877> in <module>()
----> 1 train(net1,opt,criterion,ucf_train,1,0)

2 frames
/usr/local/lib/python3.6/dist-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables)
    125     Variable._execution_engine.run_backward(
    126         tensors, grad_tensors, retain_graph, create_graph,
--> 127         allow_unreachable=True)  # allow_unreachable flag
    128 
    129 

RuntimeError: cudnn RNN backward can only be called in training mode

person ashwin    schedule 18.08.2020    source источник


Ответы (1)


Вам следует удалить вызов net.eval(), который идет сразу после def infer(net, name):

Его необходимо удалить, потому что вы вызываете эту функцию вывода внутри своего обучающего кода. Ваша модель должна находиться в режиме обучения на протяжении всего обучения.

И вы никогда не настраиваете свою модель обратно на обучение после вызова eval, так что это корень исключения, которое вы получаете. Если вы хотите использовать этот код вывода в своих тестовых примерах, вы можете покрыть этот случай с помощью if.

Также net.eval(), который идет сразу после присваивания total_loss=0, бесполезен, так как вы вызываете net.train() сразу после этого. Вы также можете удалить его, так как он будет нейтрализован прямо в следующей строке.

Обновленный код

def train(net, opt, criterion,ucf_train, batchsize,i):
    opt.zero_grad()
    total_loss = 0
    net=net.train()
    for vid in range(i*batchsize,i*batchsize+batchsize,1):
        output=infer(net,ucf_train[vid])
        m=get_label_no(ucf_train[vid])
        m=m.cuda( )
        loss = criterion(output,m)
        loss.backward(retain_graph=True)
        total_loss += loss 
        opt.step()       #updates wghts and biases

    return total_loss/n_points

код для infer (net, input)

def infer(net, name, is_train=True):
    if not is_train:
        net.eval()
    hidden_0 = net.init_hidden()
    hidden_1 = net.init_hidden()
    hidden_2 = net.init_hidden()
    video_path = fetch_ucf_video(name)
    cap = cv2.VideoCapture(video_path)
    resize=(224,224)
    T=FrameCapture(video_path)
    print(T)
    lim=T-(T%20)-2
    i=0
    while(1):
      ret, frame2 = cap.read()
      frame2= cv2.resize(frame2, resize)
      #  print(type(frame2))
      if (i%20==0 and i<lim):
          input=normalize(frame2)     
          input=input.cuda()       
          output,hidden_0,hidden_1, hidden_2  = net(input, hidden_0, hidden_1, hidden_2)
      elif (i>=lim):
          break
      i=i+1 
    op=output  
    torch.cuda.empty_cache() 
    op=op.cuda() 
    return op 
person BedirYilmaz    schedule 18.08.2020