Вероятности класса Spark MultilayerPerceptronClassifier

Я опытный программист Python, пытающийся перевести некоторый код Python в Spark для задачи классификации. Я впервые работаю в Spark/Scala.

В Python нейронные сети Keras/tensorflow и sci-kit Learn отлично справляются с классификацией нескольких классов, и я могу легко вернуть 3 наиболее вероятных класса вместе с вероятностями, которые являются ключевыми для этого проекта.

В целом мне удалось перенести код в Spark (Scala), и я могу генерировать правильные прогнозы, но мне не удалось найти способ вернуть вероятности для лучших прогнозируемых классов из MultilayerPerceptronClassifier в MLlib.

Самое близкое решение, которое я нашел, было в этом сообщении: Как получить вероятности классификации из MultilayerPerceptronClassifier? Однако я не могу заставить работать решение в посте либо из-за отсутствия ключевого фрагмента кода, либо из-за того, что я слишком новичок в Scala (вероятно, последнее), чтобы внести необходимые корректировки.

Кто-нибудь решил эту проблему?

Это текущие версии в моей среде. Версия Spark: 2.1.1 Версия Scala: 2.11.8

Спасибо за вашу помощь,

РКБ


person RKB    schedule 06.02.2019    source источник


Ответы (1)


Если вы внимательно посмотрите на результаты MultilayerPerceptronClassificationModel.transform (model и test, как определено в примере пайплайна в официальной документации)

val result = model.transform(test)

result.printSchema
root
 |-- label: double (nullable = true)
 |-- features: vector (nullable = true)
 |-- rawPrediction: vector (nullable = true)
 |-- probability: vector (nullable = true)
 |-- prediction: double (nullable = false)

вы увидите, что они содержат столбец probability.

Он хранится как столбец o.a.s.ml.linalg.Vector:

result.select($"probability").show(3, false)
+---------------------------------------------------+
|probability                                        |
+---------------------------------------------------+
|[2.630203838780848E-29,1.7323171642231641E-19,1.0] |
|[1.0,1.448487547623119E-121,4.530084532282489E-44] |
|[1.0,5.157808976162274E-122,2.5702890543589884E-44]|
+---------------------------------------------------+
only showing top 3 rows

и доступ к ним можно получить с помощью стандартных методов.

Эта функция доступна, начиная с Spark 2.3 (SPARK-12664 Вероятность раскрытия, сырое предсказание в MultilayerPerceptronClassificationModel).

person user11020637    schedule 06.02.2019