Конвейер Spark, объединяющий преобразователи VectorAssembler и HashingTF

Давайте определим конвейер Spark, который собирает вместе несколько столбцов, а затем применяет хэширование функций:

val df = sqlContext.createDataFrame(Seq((0.0, 1.0, 2.0), (3.0, 4.0, 5.0))).toDF("colx", "coly", "colz")
val va = new VectorAssembler().setInputCols(Array("colx", "coly", "colz")).setOutputCol("ft")
val hashIt = new HashingTF().setInputCol("ft").setOutputCol("ft2")
val pipeline = new Pipeline().setStages(Array(va, hashIt))

Установка конвейера с pipeline.fit(df) бросками:

java.lang.IllegalArgumentException: требование не выполнено: входной столбец должен быть ArrayType, но получен org.apache.spark.mllib.linalg.VectorUDT@f71b0bce

Есть ли трансформатор, который позволит VectorAssembler и HashingTF работать вместе?


person ranlot    schedule 01.03.2016    source источник


Ответы (1)


Лично я даже не буду использовать Pipeline API для этой цели, достаточно функции array

val df = sqlContext.createDataFrame(Seq((0.0, 1.0, 2.0), (3.0, 4.0, 5.0)))
               .toDF("colx", "coly", "colz")
               .withColumn("ft", array('colx, 'coly, 'colz))

val hashIt = new HashingTF().setInputCol("ft").setOutputCol("ft2")
val res = hashIt.transform(df)

res.show(false)
# +----+----+----+---------------+------------------------------+
# |colx|coly|colz|ft             |ft2                           |
# +----+----+----+---------------+------------------------------+
# |0.0 |1.0 |2.0 |[0.0, 1.0, 2.0]|(262144,[0,1,2],[1.0,1.0,1.0])|
# |3.0 |4.0 |5.0 |[3.0, 4.0, 5.0]|(262144,[3,4,5],[1.0,1.0,1.0])|
# +----+----+----+---------------+------------------------------+

В качестве продолжения вопроса, чтобы обобщить применение функции массива в случае количества столбцов> 3, на следующем шаге все столбцы объединяются в один столбец с массивом всех необходимых столбцов:

val df2 = sqlContext.createDataFrame(Seq((0.0, 1.0, 2.0), (3.0, 4.0, 5.0)))
                .toDF("colx", "coly", "colz")
val cols = (for (i <- df2.columns) yield df2(i)).toList
df2.withColumn("ft",array(cols :_*)).show

# +----+----+----+---------------+
# |colx|coly|colz|             ft|
# +----+----+----+---------------+
# | 0.0| 1.0| 2.0|[0.0, 1.0, 2.0]|
# | 3.0| 4.0| 5.0|[3.0, 4.0, 5.0]|
# +----+----+----+---------------+
person eliasah    schedule 01.03.2016