TensorFlow возвращает ValueError при создании простого GAN со слоем Conv1D

Я пытаюсь настроить простой GAN с TF, включая слой Conv1D в модели дискриминатора. Чтобы добиться правильной выходной формы, я включил слой Flatten.

К сожалению, при добавлении слоя генератора и дискриминатора вместе TF возвращает ошибку «ValueError: Входной тензор должен иметь ранг 3, 4 или 5, но был 2». Я попытался сделать то же самое с простейшей фиктивной сетью, и компиляция GAN сработала. Я предполагаю, что проблема во входной форме слоя дискриминатора, но описание ошибки не дает слишком много опережения.

Как я могу справиться с этим типом ошибки? Спасибо заранее за вашу помощь.

def define_discriminator(n_inputs=2):
    model = Sequential()
    model.add(Conv1D(filters = 128, kernel_size = 2, strides=1, input_shape = (n_inputs,1) ))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Flatten())
    model.add(Dense(25, kernel_initializer='he_uniform'))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dense(1, activation='sigmoid'))

    model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
    model.summary()
    return model

# simple dummy net
"""
model = Sequential()
model.add(Dense(25, activation='relu', kernel_initializer='he_uniform', input_dim=n_inputs))
model.add(Dense(1, activation='sigmoid'))
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
model.summary()
return model
"""

def define_generator(latent_dim, n_outputs=2):
    model = Sequential()
    model.add(Dense(15, activation='relu', kernel_initializer='he_uniform', input_dim=latent_dim))
    model.add(Dense(n_outputs, activation='linear'))
    model.summary()
    return model

def define_gan(generator, discriminator):
    discriminator.trainable = False
    model = Sequential()
    model.add(generator)
    model.add(discriminator)
    model.compile(loss='binary_crossentropy', optimizer='adam')
    model.summary()
    return model

Полное сообщение об ошибке здесь:

Traceback (most recent call last):
  File "C:/pracovni_addr/python_projects/GAN_1D.py", line 181, in <module>
    gan_model = define_gan(m_gen, m_disc)
  File "C:/pracovni_addr/python_projects/GAN_1D.py", line 115, in define_gan
    model.add(discriminator)
  File "C:\Users\CLIENT\AppData\Local\Programs\Python\Python37\lib\site-packages\keras\engine\sequential.py", line 182, in add
    output_tensor = layer(self.outputs[0])
  File "C:\Users\CLIENT\AppData\Local\Programs\Python\Python37\lib\site-packages\keras\backend\tensorflow_backend.py", line 75, in symbolic_fn_wrapper
    return func(*args, **kwargs)
  File "C:\Users\CLIENT\AppData\Local\Programs\Python\Python37\lib\site-packages\keras\engine\base_layer.py", line 489, in __call__
    output = self.call(inputs, **kwargs)
  File "C:\Users\CLIENT\AppData\Local\Programs\Python\Python37\lib\site-packages\keras\engine\network.py", line 583, in call
    output_tensors, _, _ = self.run_internal_graph(inputs, masks)
  File "C:\Users\CLIENT\AppData\Local\Programs\Python\Python37\lib\site-packages\keras\engine\network.py", line 740, in run_internal_graph
    layer.call(computed_tensor, **kwargs))
  File "C:\Users\CLIENT\AppData\Local\Programs\Python\Python37\lib\site-packages\keras\layers\convolutional.py", line 163, in call
    dilation_rate=self.dilation_rate[0])
  File "C:\Users\CLIENT\AppData\Local\Programs\Python\Python37\lib\site-packages\keras\backend\tensorflow_backend.py", line 3671, in conv1d
    **kwargs)
  File "C:\Users\CLIENT\AppData\Local\Programs\Python\Python37\lib\site-packages\tensorflow_core\python\ops\nn_ops.py", line 917, in convolution_v2
    name=name)
  File "C:\Users\CLIENT\AppData\Local\Programs\Python\Python37\lib\site-packages\tensorflow_core\python\ops\nn_ops.py", line 969, in convolution_internal
    "Input tensor must be of rank 3, 4 or 5 but was {}.".format(n + 2))
ValueError: Input tensor must be of rank 3, 4 or 5 but was 2.

person Jan Kaňka    schedule 29.12.2019    source источник
comment
Не могли бы вы опубликовать точное сообщение об ошибке или показать нам трассировку стека? Спасибо :)   -  person Exr0n    schedule 29.12.2019
comment
Привет, конечно, добавил!   -  person Jan Kaňka    schedule 29.12.2019


Ответы (1)


Таким образом, ошибка заключалась в неправильной форме вывода генератора (дискриминатор ожидает входную форму как (Нет, 2, 1), но было задано только (Нет, 2).

Проблема решена с помощью:

    model.add(Reshape((n_outputs,1)))

до

    model.sumary() 

в блоке define_generator

person Jan Kaňka    schedule 29.12.2019