Tensorflow в отражении Scala

Я пытаюсь заставить tensorflow для java работать на Scala. Я использую java-библиотеку tensorflow без какой-либо оболочки для Scala.

На sbt у меня есть:

Если я запустил HelloWord найденный здесь, он РАБОТАЕТ нормально, с адаптации Scala:

import org.tensorflow.Graph
import org.tensorflow.Session
import org.tensorflow.Tensor
import org.tensorflow.TensorFlow


val g = new Graph()
val value = "Hello from " + TensorFlow.version()
val t = Tensor.create(value.getBytes("UTF-8"))
// The Java API doesn't yet include convenience functions for adding operations.
g.opBuilder("Const", "MyConst").setAttr("dtype", t.dataType()).setAttr("value", t).build();

val s = new Session(g)
val output = s.runner().fetch("MyConst").run().get(0)

Однако если я попытаюсь использовать отражение Scala для компиляции функции из строки, она НЕ РАБОТАЕТ. Вот фрагмент, который я использовал:

import scala.reflect.runtime.{universe => ru}
import scala.tools.reflect.ToolBox
val fnStr = """
    {() =>
      import org.tensorflow.Graph
      import org.tensorflow.Session
      import org.tensorflow.Tensor
      import org.tensorflow.TensorFlow

      val g = new Graph()
      val value = "Hello from " + TensorFlow.version()
      val t = Tensor.create(value.getBytes("UTF-8"))
      g.opBuilder("Const", "MyConst").setAttr("dtype", t.dataType()).setAttr("value", t).build();

      val s = new Session(g)

      s.runner().fetch("MyConst").run().get(0)
    }
    """
val mirror = ru.runtimeMirror(getClass.getClassLoader)
val tb = mirror.mkToolBox()
var t = tb.parse(fnStr)
val fn = tb.eval(t).asInstanceOf[() => Any]
// and finally, executing the function
fn()

Здесь упрощено build.sbt, чтобы воспроизвести ошибку выше:

lazy val commonSettings = Seq(
    scalaVersion := "2.12.10",

    libraryDependencies ++= {
      Seq(
                  // To support runtime compilation
        "org.scala-lang" % "scala-reflect" % scalaVersion.value,
        "org.scala-lang" % "scala-compiler" % scalaVersion.value,

        // for tensorflow4java
        "org.tensorflow" % "tensorflow" % "1.15.0",
        "org.tensorflow" % "proto" % "1.15.0",
        "org.tensorflow" % "libtensorflow_jni" % "1.15.0"

      )
    }
)

lazy val `test-proj` = project
  .in(file("."))
  .settings(commonSettings)

При выполнении вышеуказанного, например, с sbt console, я получаю следующую ошибку и трассировку стека:

java.lang.NoSuchMethodError: org.tensorflow.Session.runner()Lorg/tensorflow/Session$$Runner;
  at __wrapper$1$f093d26a3c504d4381a37ef78b6c3d54.__wrapper$1$f093d26a3c504d4381a37ef78b6c3d54$.$anonfun$wrapper$1(<no source file>:15)

Пожалуйста, не обращайте внимания на утечки памяти, указанные в предыдущем коде, о том, что контекст ресурсов (to close ()) не используется


person aitorhh    schedule 21.03.2020    source источник
comment
Зачем вам нужно компилировать код из строки с помощью ToolBox?   -  person Dmytro Mitin    schedule 26.03.2020
comment
Подумайте, например, о сценарии «без обслуживания» или «функция как услуга» (FaaS).   -  person aitorhh    schedule 26.03.2020
comment
Возможно, в качестве обходного пути вы можете программно записать эту строку кода в текстовый файл, скомпилировать файл и запустить скомпилированный файл класса. Это работает для вас?   -  person Dmytro Mitin    schedule 29.03.2020
comment
Спасибо @DmytroMitin, это действительный обходной путь. Я также рассматриваю тензорный поток для scala, но могу найти аналогичную проблему   -  person aitorhh    schedule 30.03.2020


Ответы (2)


Дело в том, что эта ошибка возникает в сочетании рефлексивной компиляции и взаимодействия Scala-Java.

https://github.com/scala/bug/issues/8956

Toolbox не может проверить тип значения (s.runner()) зависимого от пути типа (s.Runner), если этот тип происходит из нестатического внутреннего класса Java. И Runner - это именно такой класс внутри org.tensorflow.Session.

Вы можете запустить компилятор вручную (аналогично как Toolbox это работает)

import org.tensorflow.Tensor
import scala.reflect.internal.util.{AbstractFileClassLoader, BatchSourceFile}
import scala.reflect.io.{AbstractFile, VirtualDirectory}
import scala.reflect.runtime
import scala.reflect.runtime.universe
import scala.reflect.runtime.universe._
import scala.tools.nsc.{Global, Settings}

val code: String =
  """
    |import org.tensorflow.Graph
    |import org.tensorflow.Session
    |import org.tensorflow.Tensor
    |import org.tensorflow.TensorFlow
    |
    |object Main {
    |  def foo() = () => {
    |      val g = new Graph()
    |      val value = "Hello from " + TensorFlow.version()
    |      val t = Tensor.create(value.getBytes("UTF-8"))
    |      g.opBuilder("Const", "MyConst").setAttr("dtype", t.dataType()).setAttr("value", t).build();
    |
    |      val s = new Session(g)
    |
    |      s.runner().fetch("MyConst").run().get(0)
    |  }
    |}
""".stripMargin

val directory = new VirtualDirectory("(memory)", None)
val runtimeMirror = createRuntimeMirror(directory, runtime.currentMirror)
compileCode(code, List(), directory)
val tensor = runObjectMethod("Main", runtimeMirror, "foo").asInstanceOf[() => Tensor[_]]
tensor() // STRING tensor with shape []

def compileCode(code: String, classpathDirectories: List[AbstractFile], outputDirectory: AbstractFile): Unit = {
  val settings = new Settings
  classpathDirectories.foreach(dir => settings.classpath.prepend(dir.toString))
  settings.outputDirs.setSingleOutput(outputDirectory)
  settings.usejavacp.value = true
  val global = new Global(settings)
  (new global.Run).compileSources(List(new BatchSourceFile("(inline)", code)))
}

def runObjectMethod(objectName: String, runtimeMirror: Mirror, methodName: String, arguments: Any*): Any = {
  val objectSymbol = runtimeMirror.staticModule(objectName)
  val objectModuleMirror = runtimeMirror.reflectModule(objectSymbol)
  val objectInstance = objectModuleMirror.instance
  val objectType = objectSymbol.typeSignature
  val methodSymbol = objectType.decl(TermName(methodName)).asMethod
  val objectInstanceMirror = runtimeMirror.reflect(objectInstance)
  val methodMirror = objectInstanceMirror.reflectMethod(methodSymbol)
  methodMirror(arguments: _*)
}

def createRuntimeMirror(directory: AbstractFile, parentMirror: Mirror): Mirror = {
  val classLoader = new AbstractFileClassLoader(directory, parentMirror.classLoader)
  universe.runtimeMirror(classLoader)
}

динамически анализировать json на карте flink

Динамическая компиляция нескольких классов Scala во время выполнения

Как оценить код, использующий аннотацию InterfaceStability (что не удается из-за недопустимой циклической ссылки с участием класса InterfaceStability)?

person Dmytro Mitin    schedule 26.11.2020
comment
Спасибо за ответ! Хотя ваш ответ действительно работает, здесь необходимо обратиться к функции Main.foo (), а в вопросе у нас есть непосредственно лямбда-функция. Точно ничего критичного! Это определенно путь - person aitorhh; 01.12.2020
comment
@aitorhh Ну, в Scala 2 лямбда не может быть верхнего уровня, она должна быть внутри какого-то метода или конструктора класса или объекта. Просто Toolbox прячет эту упаковку от ваших глаз. При запуске компилятора вручную я сделал это более подробно. Если вы предпочитаете скрывать это, вы можете сделать это, как в stackoverflow.com/questions/53976254/ - person Dmytro Mitin; 01.12.2020

Как отметил Дмитрий в своем ответе, это невозможно с помощью инструментария. И он указал на другой ответ (Как оценить код, который использует аннотацию InterfaceStability (которая не работает с недопустимой циклической ссылкой, включающей класс InterfaceStability)?). Я думаю, что есть отличное решение, просто заменив класс Compiler, определенный в предыдущем, и заменив Toolbox для этого класса Compiler.

В этом случае финальный фрагмент будет выглядеть так:

import your.package.Compiler
val fnStr = """
    {() =>
      import org.tensorflow.Graph
      import org.tensorflow.Session
      import org.tensorflow.Tensor
      import org.tensorflow.TensorFlow

      val g = new Graph()
      val value = "Hello from " + TensorFlow.version()
      val t = Tensor.create(value.getBytes("UTF-8"))
      g.opBuilder("Const", "MyConst").setAttr("dtype", t.dataType()).setAttr("value", t).build();

      val s = new Session(g)

      s.runner().fetch("MyConst").run().get(0)
    }
    """
val tb = new Compiler() // this replaces the mirror and toolbox instantiation
var t = tb.parse(fnStr)
val fn = tb.eval(t).asInstanceOf[() => Any]
// and finally, executing the function
println(fn())

И для завершения скопируйте / вставьте из решения по адресу этот ответ:

  class Compiler() {
    import scala.reflect.internal.util.{AbstractFileClassLoader, BatchSourceFile}
    import scala.reflect.io.{AbstractFile, VirtualDirectory}
    import scala.reflect.runtime
    import scala.reflect.runtime.universe
    import scala.reflect.runtime.universe._
    import scala.tools.nsc.{Global, Settings}
    import scala.collection.mutable
    import java.security.MessageDigest
    import java.math.BigInteger
       
    val target  = new VirtualDirectory("(memory)", None)
       
    val classCache = mutable.Map[String, Class[_]]()
       
    private val settings = new Settings()
    settings.deprecation.value = true // enable detailed deprecation warnings
    settings.unchecked.value = true // enable detailed unchecked warnings
    settings.outputDirs.setSingleOutput(target)
    settings.usejavacp.value = true
       
    private val global = new Global(settings)
    private lazy val run = new global.Run
       
    val classLoader = new AbstractFileClassLoader(target, this.getClass.getClassLoader)
       
    /**Compiles the code as a class into the class loader of this compiler.
      * 
      * @param code
      * @return
      */
    def compile(code: String) = {
      val className = classNameForCode(code)
      findClass(className).getOrElse {
        val sourceFiles = List(new BatchSourceFile("(inline)", wrapCodeInClass(className, code)))
        run.compileSources(sourceFiles)
        findClass(className).get
      } 
    }   
       
    /** Compiles the source string into the class loader and
      * evaluates it.
      * 
      * @param code
      * @tparam T
      * @return
      */
    def eval[T](code: String): T = {
      val cls = compile(code)
      cls.getConstructor().newInstance().asInstanceOf[() => Any].apply().asInstanceOf[T]
    }  
        
    def findClass(className: String): Option[Class[_]] = {
      synchronized {
        classCache.get(className).orElse {
          try {
            val cls = classLoader.loadClass(className)
            classCache(className) = cls
            Some(cls)
          } catch {
            case e: ClassNotFoundException => None
          }
        }
      } 
    }   
  
    protected def classNameForCode(code: String): String = {
      val digest = MessageDigest.getInstance("SHA-1").digest(code.getBytes)
      "sha"+new BigInteger(1, digest).toString(16)
    }   
  
    /*  
     * Wrap source code in a new class with an apply method.
     */ 
   private def wrapCodeInClass(className: String, code: String) = {
     "class " + className + " extends (() => Any) {\n" +
     "  def apply() = {\n" +
     code + "\n" +
     "  }\n" +
     "}\n"
   }    
  }  
person aitorhh    schedule 01.12.2020