Бесплатно ~› Trampoline: рекурсивная программа вылетает с ошибкой OutOfMemoryError

Предположим, что я пытаюсь реализовать очень простой предметно-ориентированный язык только с одной операцией:

printLine(line)

Затем я хочу написать программу, которая принимает целое число n в качестве входных данных, печатает что-то, если n делится на 10k, а затем вызывает себя с n + 1, пока n не достигнет некоторого максимального значения N.

Опуская весь синтаксический шум, вызванный for-comprehension, я хочу:

@annotation.tailrec def p(n: Int): Unit = {
  if (n % 10000 == 0) printLine("line")
  if (n > N) () else p(n + 1)
}

По сути, это будет своего рода «шипение».

Вот несколько попыток реализовать это с помощью монады Free из Scalaz 7.3.0-M7:

import scalaz._

object Demo1 {

  // define operations of a little domain specific language
  sealed trait Lang[X]
  case class PrintLine(line: String) extends Lang[Unit]

  // define the domain specific language as the free monad of operations
  type Prog[X] = Free[Lang, X]

  import Free.{liftF, pure}

  // lift operations into the free monad
  def printLine(l: String): Prog[Unit] = liftF(PrintLine(l))
  def ret: Prog[Unit] = Free.pure(())

  // write a program that is just a loop that prints current index 
  // after every few iteration steps
  val mod =  100000
  val N =   1000000

  // straightforward syntax: deadly slow, exits with OutOfMemoryError
  def p0(i: Int): Prog[Unit] = for {
    _ <- (if (i % mod == 0) printLine("i = " + i) else ret)
    _ <- (if (i > N) ret else p0(i + 1))
  } yield ()

  // Same as above, but written out without `for`
  def p1(i: Int): Prog[Unit] = 
    (if (i % mod == 0) printLine("i = " + i) else ret).flatMap{
      ignore1 =>
      (if (i > N) ret else p1(i + 1)).map{ ignore2 => () }
    }

  // Same as above, with `map` attached to recursive call
  def p2(i: Int): Prog[Unit] = 
    (if (i % mod == 0) printLine("i = " + i) else ret).flatMap{
      ignore1 =>
      (if (i > N) ret else p2(i + 1).map{ ignore2 => () })
    }

  // Same as above, but without the `map`; performs ok.
  def p3(i: Int): Prog[Unit] = {
    (if (i % mod == 0) printLine("i = " + i) else ret).flatMap{ 
      ignore1 =>
      if (i > N) ret else p3(i + 1)
    }
  }

  // Variation of the above; Ok.
  def p4(i: Int): Prog[Unit] = (for {
    _ <- (if (i % mod == 0) printLine("i = " + i) else ret)
  } yield ()).flatMap{ ignored2 => 
    if (i > N) ret else p4(i + 1) 
  }

  // try to use the variable returned by the last generator after yield,
  // hope that the final `map` is optimized away (it's not optimized away...)
  def p5(i: Int): Prog[Unit] = for {
    _ <- (if (i % mod == 0) printLine("i = " + i) else ret)
    stopHere <- (if (i > N) ret else p5(i + 1))
  } yield stopHere

  // define an interpreter that translates the programs into Trampoline
  import scalaz.Trampoline
  type Exec[X] = Free.Trampoline[X]  
  val interpreter = new (Lang ~> Exec) {
    def apply[A](cmd: Lang[A]): Exec[A] = cmd match {
      case PrintLine(l) => Trampoline.delay(println(l))
    }
  }

  // try it out
  def main(args: Array[String]): Unit = {
    println("\n p0")
    p0(0).foldMap(interpreter).run // extremely slow; OutOfMemoryError
    println("\n p1")
    p1(0).foldMap(interpreter).run // extremely slow; OutOfMemoryError
    println("\n p2")
    p2(0).foldMap(interpreter).run // extremely slow; OutOfMemoryError
    println("\n p3")
    p3(0).foldMap(interpreter).run // ok 
    println("\n p4")
    p4(0).foldMap(interpreter).run // ok
    println("\n p5")
    p5(0).foldMap(interpreter).run // OutOfMemory
  }
}

К сожалению, прямой перевод (p0), кажется, работает с какими-то накладными расходами O(N^2) и вылетает с ошибкой OutOfMemoryError. Проблема, похоже, в том, что for-понимание добавляет map{x => ()} после рекурсивного вызова p0, что вынуждает монаду Free заполнить всю память напоминаниями "завершить 'p0' и ничего не делать". Если я вручную "разверну" понимание for, а последние flatMap выпишу явно (как в p3 и p4), то проблема исчезнет, ​​и все пойдет гладко. Это, однако, чрезвычайно хрупкий обходной путь: поведение программы резко меняется, если мы просто добавляем к ней map(id), а это map(id) даже не видно в коде, потому что оно генерируется автоматически for-пониманием.

В этом более старом сообщении здесь: https://apocalisp.wordpress.com/2011/10/26/tail-call-elimination-in-scala-monads/ неоднократно рекомендовалось заключать рекурсивные вызовы в suspend. Вот попытка с экземпляром Applicative и suspend:

import scalaz._

// Essentially same as in `Demo1`, but this time with 
// an `Applicative` and an explicit `Suspend` in the 
// `for`-comprehension
object Demo2 {

  sealed trait Lang[H]

  case class Const[H](h: H) extends Lang[H]
  case class PrintLine[H](line: String) extends Lang[H]

  implicit object Lang extends Applicative[Lang] {
    def point[A](a: => A): Lang[A] = Const(a)
    def ap[A, B](a: => Lang[A])(f: => Lang[A => B]): Lang[B] = a match {
      case Const(x) => {
        f match {
          case Const(ab) => Const(ab(x))
          case _ => throw new Error
        }
      }
      case PrintLine(l) => PrintLine(l)
    }
  }

  type Prog[X] = Free[Lang, X]

  import Free.{liftF, pure}
  def printLine(l: String): Prog[Unit] = liftF(PrintLine(l))
  def ret: Prog[Unit] = Free.pure(())

  val mod = 100000
  val N = 2000000

  // try to suspend the entire second generator
  def p7(i: Int): Prog[Unit] = for {
    _ <- (if (i % mod == 0) printLine("i = " + i) else ret)
    _ <- Free.suspend(if (i > N) ret else p7(i + 1))
  } yield ()

  // try to suspend the recursive call
  def p8(i: Int): Prog[Unit] = for {
    _ <- (if (i % mod == 0) printLine("i = " + i) else ret)
    _ <- if (i > N) ret else Free.suspend(p8(i + 1))
  } yield ()

  import scalaz.Trampoline
  type Exec[X] = Free.Trampoline[X]

  val interpreter = new (Lang ~> Exec) {
    def apply[A](cmd: Lang[A]): Exec[A] = cmd match {
      case Const(x) => Trampoline.done(x)
      case PrintLine(l) => 
        (Trampoline.delay(println(l))).asInstanceOf[Exec[A]]
    }
  }

  def main(args: Array[String]): Unit = {
    p7(0).foldMap(interpreter).run // extremely slow; OutOfMemoryError
    p8(0).foldMap(interpreter).run // same...
  }
}

Вставка suspend особо не помогла: все равно тормозит, а с OutOfMemoryErrors вылетает.

Должен ли я использовать suspend как-то по-другому?

Может быть, есть какое-то чисто синтаксическое средство, позволяющее использовать for-comprehension без генерации map в конце?

Я был бы очень признателен, если бы кто-нибудь мог указать, что я делаю неправильно здесь, и как это исправить.


person Andrey Tyukin    schedule 13.12.2016    source источник
comment
Привет, я скопировал и запустил ваш код, и он не был ни медленным, ни OutOfMemory. Когда я увеличил N в десять раз, он становился медленнее (что ожидаемо, потому что вы должны получить O (N * N)), по сравнению с наивным решением tailrec (где у вас есть O (N)), но все еще без ошибки OOM.   -  person I See Voices    schedule 13.12.2016
comment
Вероятно, это будет зависеть от настроек и оборудования JVM. Если вы не видите эффекта сразу, попробуйте что-то вроде n = 1000000, вместо этого N = 10 000 000. На моем ноутбуке некоторые программы работают заметно медленнее, и вылетают с OutOfMemory для N = 5000000. Но вы должны увидеть замедление для меньших значений N.   -  person Andrey Tyukin    schedule 13.12.2016


Ответы (1)


Этот лишний map, добавленный компилятором Scala, перемещает рекурсию из хвостовой позиции в не хвостовую позицию. Свободная монада по-прежнему делает этот стек безопасным, но сложность пространства становится O(N) вместо O(1). (В частности, это все еще не O(N2).)

Можно ли сделать так, чтобы scalac оптимизировалось то, что map далеко, это отдельный вопрос (на который я не знаю ответа).

Я попытаюсь проиллюстрировать, что происходит при интерпретации p1 по сравнению с p3. (Я проигнорирую перевод на Trampoline, который является излишним (см. ниже).)

p3 (т.е. без лишних map)

Позвольте мне использовать следующее сокращение:

def cont(i: Int): Unit => Prg[Unit] =
  ignore1 => if (i > N) ret else p3(i + 1)

Теперь p3(0) интерпретируется следующим образом

p3(0)
printLine("i = " + 0) flatMap cont(0)
// side-effect: println("i = 0")
cont(0)
p3(1)
ret flatMap cont(1)
cont(1)
p3(2)
ret flatMap cont(2)
cont(2)

и так далее... Вы видите, что объем памяти, необходимый в любой момент, не превышает некоторой постоянной верхней границы.

p1 (т.е. с дополнительным map)

Я буду использовать следующие сокращения:

def cont(i: Int): Unit => Prg[Unit] =
  ignore1 => (if (i > N) ret else p1(i + 1)).map{ ignore2 => () }

def cpu: Unit => Prg[Unit] = // constant pure unit
  ignore => Free.pure(())

Теперь p1(0) интерпретируется следующим образом:

p1(0)
printLine("i = " + 0) flatMap cont(0)
// side-effect: println("i = 0")
cont(0)
p1(1) map { ignore2 => () }
// Free.map is implemented via flatMap
p1(1) flatMap cpu
(ret flatMap cont(1)) flatMap cpu
cont(1) flatMap cpu
(p1(2) map { ignore2 => () }) flatMap cpu
(p1(2) flatMap cpu) flatMap cpu
((ret flatMap cont(2)) flatMap cpu) flatMap cpu
(cont(2) flatMap cpu) flatMap cpu
((p1(3) map { ignore2 => () }) flatMap cpu) flatMap cpu
((p1(3) flatMap cpu) flatMap cpu) flatMap cpu
(((ret flatMap cont(3)) flatMap cpu) flatMap cpu) flatMap cpu

и так далее... Вы видите, что потребление памяти линейно зависит от N. Мы просто переместили оценку из стека в кучу.

Вывод: чтобы Free не мешало памяти, сохраняйте рекурсию в "хвостовой позиции", то есть справа от flatMap (или map).

Кроме того: перевод на Trampoline не нужен, так как Free уже трамплин. Вы можете напрямую интерпретировать Id и использовать foldMapRec для безопасной интерпретации стека:

val idInterpreter = new (Lang ~> Id) {
  def apply[A](cmd: Lang[A]): Id[A] = cmd match {
    case PrintLine(l) => println(l)
  }
}

p0(0).foldMapRec(idInterpreter)

Это вернет вам часть памяти (но не устранит проблему).

person Tomas Mikula    schedule 13.12.2016
comment
Большое спасибо за подробную иллюстрацию, она подтверждает мою интуицию, что p0 оставляет O(N)-след в памяти во время работы. Я не был уверен в накладных расходах времени: с некоторыми более старыми реализациями Free, которые использовали F:Functor для добавления следующей операции в конец структуры, похожей на связанный список, я мог представить, что это может быть O (N ^ 2) , но мне придется взглянуть на текущую реализацию и подумать об этом еще раз. - person Andrey Tyukin; 13.12.2016
comment
В стороне: я использовал Trampoline только для иллюстрации, я, вероятно, буду использовать что-то еще в качестве «цели интерпретации». Перевод рассуждений на Trampoline не нужен, поскольку Free уже трамплин, кажется, отличается от ответа на этот вопрос: stackoverflow.com/questions/29660067/ - person Andrey Tyukin; 13.12.2016
comment
При оптимизации map прочь: может быть, что-то вроде def noMap[X](x: X) = new { def map(f: Unit => Unit): X = x }, а затем обернуть последний генератор в noMap? Он работает и устраняет последний map, произведенный for, но было бы лучше использовать что-то более традиционное, если оно где-то уже существует (в Scalaz или где-то еще). - person Andrey Tyukin; 13.12.2016
comment
@AndreyTyukin да, была реализация Free, которая зависела от Functor и страдала от квадратичной временной сложности. С тех пор это было улучшено. Ответ, на который вы ссылаетесь, предварительно foldMapRec. Мне не удалось подчеркнуть важность foldMapRec в безопасной для стека интерпретации для Id. Это хитрый трюк с noMap! Я не знаю стандартного решения, но было бы интересно услышать, если вы его найдете. - person Tomas Mikula; 13.12.2016