Переменные с динамической формой TensorFlow

Мне нужно создать матрицу в TensorFlow для хранения некоторых значений. Хитрость в том, что матрица должна поддерживать динамическую форму.

Я пытаюсь сделать то же самое, что и в numpy:

myVar = tf.Variable(tf.zeros((x,y), validate_shape=False)

где x=(?) и y=2. Но это не работает, потому что нули не поддерживают «частично известный TensorShape», так что как мне это сделать в TensorFlow?


person gergf    schedule 06.04.2017    source источник
comment
Зачем нужна динамическая форма? И вы не можете исправить это, используя None в качестве дескриптора формы?   -  person rmeertens    schedule 07.04.2017
comment
Потому что моя матрица зависит от количества образцов в партии, которое может меняться. Насколько я знаю, ни tf.zeros, ни np.zeros не принимают None в форме.   -  person gergf    schedule 07.04.2017
comment
Ах я вижу. Могу я спросить, что вы хотите сделать с этой матрицей??   -  person rmeertens    schedule 07.04.2017
comment
Конечно. Я программирую weighted_softmax для семантической сегментации. Я хочу взвесить каждый класс по его априору, поэтому я вычисляю априор каждого изображения, когда получаю истинные метки в функции потерь: мне нужна матрица для хранения этих априорных значений. Было бы проще, если бы функция потерь могла получать дополнительные параметры, поэтому я могу вычислить априорные значения с помощью Numpy и передать его TensorFlow, но я не знаю, как это сделать.   -  person gergf    schedule 07.04.2017


Ответы (2)


1) Вы можете использовать tf.fill(dims, value=0.0), который работает с динамическими формами.

2) Вы можете использовать заполнитель для переменного измерения, например, например:

m = tf.placeholder(tf.int32, shape=[])
x = tf.zeros(shape=[m])

with tf.Session() as sess:
    print(sess.run(x, feed_dict={m: 5}))
person kafman    schedule 07.04.2017
comment
Каков результирующий dtype операции tf.fill(dims, value=0.0)? - person reubenjohn; 04.11.2017

Если вы знаете форму вне сеанса, это может помочь.

import tensorflow as tf
import numpy as np

v = tf.Variable([], validate_shape=False)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(v, feed_dict={v: np.zeros((3,4))}))
    print(sess.run(v, feed_dict={v: np.zeros((2,2))}))
person hychiang    schedule 19.04.2018