TFLearn: ошибка при загрузке 2 разных сохраненных моделей одна за другой

У меня есть 2 разные модели нейронных сетей, обученные и сохраненные с помощью TFLearn. Когда я запускаю каждый скрипт, сохраненные модели загружаются правильно. Мне нужна система, в которой вторая модель должна вызываться после вывода первой модели. Но когда я пытаюсь загрузить вторую модель после загрузки первой модели, это дает мне следующую ошибку:

NotFoundError (трассировку см. выше): ключ val_loss_2 не найден в контрольной точке [[Node: save_6/RestoreV2_42 = RestoreV2[dtypes=[DT_FLOAT], _device="/job:localhost/replica:0/task:0/cpu:0" ](_arg_save_6/Const_0_0, save_6/RestoreV2_42/tensor_names, save_6/RestoreV2_42/shape_and_slices)]]

Вторая модель загружается правильно, если я закомментирую загрузку первой модели или если я запущу два сценария отдельно. Любая идея, почему эта ошибка происходит?

Структура кода примерно такая..

from second_model_file import check_second_model

def run_first_model(input):
    features = convert_to_features(input)
    model = tflearn.DNN(get_model())
    model.load("model1_path/model1")   # relative path
    pred = model.predict(features)
    ...
    if pred == certain_value:
       check_second_model()

second_model_file.py что-то похожее:

def check_second_model():
    input_var = get_input_var()
    model2 = tflearn.DNN(regression_model())
    model2.load("model2_path/model2")   # relative path  
    pred = model2.predict(input_var)
    #other stuff  ......     

Модели были сохранены в разных папках, поэтому у каждой есть свой checkpoint файл.


person Anakin    schedule 18.10.2017    source источник


Ответы (1)


Ну хорошо, я нашел решение. Оно было скрыто в обсуждении в этой ветке. Я использовал tf.reset_default_graph() перед построением второй сети и модели, и это сработало. Надеюсь, это поможет кому-то еще.

Новый код:

import tensorflow as tf

def check_second_model():
    input_var = get_input_var()
    tf.reset_default_graph()
    model2 = tflearn.DNN(regression_model())
    model2.load("model2_path/model2")   # relative path  
    pred = model2.predict(input_var)

Хотя я интуитивно понимаю, почему это решение работает, я был бы рад, если бы кто-нибудь объяснил мне лучше, почему оно разработано таким образом.

person Anakin    schedule 19.10.2017
comment
что такое regression_model()? - person rkatkam; 22.06.2019