Я застрял на одной строчке кода и в результате застрял в проекте на все выходные.
Я работаю над проектом, который использует BERT для классификации предложений. Я успешно обучил модель и могу протестировать результаты, используя пример кода из run_classifier.py.
Я могу экспортировать модель, используя этот пример кода (который неоднократно репостировался, поэтому я считаю, что он подходит для этой модели):
def export(self):
def serving_input_fn():
label_ids = tf.placeholder(tf.int32, [None], name='label_ids')
input_ids = tf.placeholder(tf.int32, [None, self.max_seq_length], name='input_ids')
input_mask = tf.placeholder(tf.int32, [None, self.max_seq_length], name='input_mask')
segment_ids = tf.placeholder(tf.int32, [None, self.max_seq_length], name='segment_ids')
input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn({
'label_ids': label_ids, 'input_ids': input_ids,
'input_mask': input_mask, 'segment_ids': segment_ids})()
return input_fn
self.estimator._export_to_tpu = False
self.estimator.export_savedmodel(self.output_dir, serving_input_fn)
Я также могу загрузить экспортированный оценщик (где функция экспорта сохраняет экспортированную модель в подкаталог с меткой времени):
predict_fn = predictor.from_saved_model(self.output_dir + timestamp_number)
Однако, хоть убей, я не могу понять, что нужно предоставить в pred_fn в качестве входных данных для вывода. Вот мой лучший код на данный момент:
def predict(self):
input = 'Test input'
guid = 'predict-0'
text_a = tokenization.convert_to_unicode(input)
label = self.label_list[0]
examples = [InputExample(guid=guid, text_a=text_a, text_b=None, label=label)]
features = convert_examples_to_features(examples, self.label_list,
self.max_seq_length, self.tokenizer)
predict_input_fn = input_fn_builder(features, self.max_seq_length, False)
predict_fn = predictor.from_saved_model(self.output_dir + timestamp_number)
result = predict_fn(predict_input_fn) # this generates an error
print(result)
Кажется, не имеет значения, что я предоставляю для pred_fn: массив примеров, массив функций, функцию pred_input_fn. Ясно, что predic_fn нужен словарь какого-то типа, но каждая попытка, которую я пробовал, генерирует исключение из-за несоответствия тензора или других ошибок, которые обычно означают: неправильный ввод.
Я предположил, что функция from_saved_model требует того же типа ввода, что и функция тестирования модели - по-видимому, это не так.
Кажется, что многие люди задавали именно этот вопрос - "как мне использовать экспортированную модель BERT TensorFlow для вывода?" - и не получили ответов:
Любая помощь? Заранее спасибо.