Пользовательская активация с параметром

Я пытаюсь создать функцию активации в Keras, которая может принимать параметр beta следующим образом:

from keras import backend as K
from keras.utils.generic_utils import get_custom_objects
from keras.layers import Activation

class Swish(Activation):

    def __init__(self, activation, beta, **kwargs):
        super(Swish, self).__init__(activation, **kwargs)
        self.__name__ = 'swish'
        self.beta = beta


def swish(x):
    return (K.sigmoid(beta*x) * x)

get_custom_objects().update({'swish': Swish(swish, beta=1.)})

Он отлично работает без параметра beta, но как я могу включить этот параметр в определение активации? Я также хочу, чтобы это значение сохранялось, когда я делаю model.to_json() как для активации ELU.


Обновление: я написал следующий код на основе ответа @today:

from keras.layers import Layer
from keras import backend as K

class Swish(Layer):
    def __init__(self, beta, **kwargs):
        super(Swish, self).__init__(**kwargs)
        self.beta = K.cast_to_floatx(beta)
        self.__name__ = 'swish'

    def call(self, inputs):
        return K.sigmoid(self.beta * inputs) * inputs

    def get_config(self):
        config = {'beta': float(self.beta)}
        base_config = super(Swish, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

    def compute_output_shape(self, input_shape):
        return input_shape

from keras.utils.generic_utils import get_custom_objects
get_custom_objects().update({'swish': Swish(beta=1.)})
gnn = keras.models.load_model("Model.h5")
arch = gnn.to_json()
with open(directory + 'architecture.json', 'w') as arch_file:
    arch_file.write(arch)

Однако в настоящее время он не сохраняет значение beta в файле .json. Как я могу заставить его сохранить значение?


person user7867665    schedule 29.10.2018    source источник


Ответы (1)


Поскольку вы хотите сохранить параметры функции активации при сериализации модели, я думаю, что лучше определить функцию активации как слой, подобный расширенные активации, определенные в Keras. Вы можете сделать это следующим образом:

from keras.layers import Layer
from keras import backend as K

class Swish(Layer):
    def __init__(self, beta, **kwargs):
        super(Swish, self).__init__(**kwargs)
        self.beta = K.cast_to_floatx(beta)

    def call(self, inputs):
        return K.sigmoid(self.beta * inputs) * inputs

    def get_config(self):
        config = {'beta': float(self.beta)}
        base_config = super(Swish, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

    def compute_output_shape(self, input_shape):
        return input_shape

Затем вы можете использовать его так же, как вы используете слой Keras:

# ...
model.add(Swish(beta=0.3))

Поскольку метод get_config() был реализован в его определении, параметр beta будет сохранен при использовании таких методов, как to_json() или save().

person today    schedule 29.10.2018
comment
Это то, что я делаю, но значение параметра не сохраняется в файле json - person user7867665; 30.10.2018
comment
@ user7867665 Вы уверены, что реализовали метод get_config() и включили в него параметр beta? - person today; 30.10.2018
comment
Я сделал по другому, сейчас тестирую вашу реализацию - person user7867665; 30.10.2018
comment
он не сохраняет бета-значение в файле .json, я использовал именно ваш код - person user7867665; 30.10.2018
comment
@user7867665 user7867665 Действительно странно! Меня устраивает. Не могли бы вы поместить свой код в github gist (или на любой другой веб-сайт для обмена заметками, который дает ссылку для общего доступа) а мне ссылку дайте? - person today; 30.10.2018
comment
Как мне дать вам ссылку, есть ли прямые сообщения? - person user7867665; 30.10.2018
comment
@user7867665 user7867665 Просто вставьте это в комментарий здесь. - person today; 30.10.2018
comment
Давайте продолжим это обсуждение в чате. - person user7867665; 30.10.2018
comment
@user7867665 user7867665 Вы используете generator.to_json(), но ваша модель хранится в gnn?! Кроме того, не рекомендуется редактировать свой вопрос, удаляя исходный вопрос. Вместо этого добавьте дополнительную информацию в конце. Следовательно, я откатил ваше редактирование и изменил его как таковой. - person today; 30.10.2018
comment
Это gnn.to_json(), я вставил что-то не то. Исправлено. Проблема все та же. Спасибо за редактирование - person user7867665; 30.10.2018
comment
@user7867665 Запустите этот код на своем компьютере и убедитесь, что вы видите параметр beta в распечатанной конфигурации. Далее посмотрите, как пользовательский объект передается в функцию load_model. - person today; 30.10.2018
comment
Оно работает! Так почему это не работает для моей модели? Может быть, потому, что он не был обучен этому определению взмаха? - person user7867665; 31.10.2018
comment
Так что это не сработало на моей модели, потому что она не была обучена с использованием этого определения взмаха. Большое спасибо! - person user7867665; 31.10.2018