Как извлечь функции из тонкой модели VGG tensorflow при выполнении прямого бега?

Я обучил модель классификации с помощью тонкой модели TensorFlow vgg, используя CASIA (набор данных распознавания лиц) в качестве набора данных для обучения. Я хочу протестировать модель, используя набор данных LFW, это задача сопоставления лиц. поэтому мне нужно извлечь сетевые функции, такие как fc7 / fc8, а не слой softmax, и сравнить расстояние между функциями, чтобы определить, принадлежат ли они одному и тому же человеку. Как я могу извлечь черты тонкой модели?

Вот часть обучающего кода.

import tensorflow as tf
from tensorflow.contrib.slim.python.slim.nets import vgg 
slim = tf.contrib.slim
FLAGS = tf.app.flags.FLAGS

def tower_loss(scope):
    images, labels = read_and_decode()
    with slim.arg_scope(vgg.vgg_arg_scope()):
        logits, end_points = vgg.vgg_16(images, num_classes=FLAGS.num_classes)
    _ = cal_loss(logits, labels)
    losses = tf.get_collection('losses', scope)
    total_loss = tf.add_n(losses, name='total_loss')
    return total_loss

person Yimu    schedule 16.05.2017    source источник


Ответы (2)


Вы можете попробовать использовать tf.get_default_graph().get_tensor_by_name("VGG16/fc16:0") или любое другое тензорное имя конкретной функции, которую вы хотите извлечь.

Чтобы проверить имя извлекаемых тензоров, вы можете попробовать

for operation in graph.get_operations():
    print operation.values()

Не забудьте поставить :0 в конце имен, поскольку они указывают на то, что извлекаемый вами элемент является тензором.

person kwotsin    schedule 24.05.2017

Получите end_points тонкой модели и извлеките функцию.

person Yimu    schedule 26.05.2017