Как передать вычисленные значения в сортировку списка с помощью numba.jit в Python?

Я пытаюсь отсортировать список с помощью настраиваемого ключа в функции numba-jit в Python. Простые настраиваемые ключи работают, например, я знаю, что могу просто отсортировать по абсолютному значению, используя что-то вроде этого:

import numba

@numba.jit(nopython=True)
def myfunc():
    mylist = [-4, 6, 2, 0, -1]
    mylist.sort(key=lambda x: abs(x))
    return mylist  # [0, -1, 2, -4, 6]

Однако в следующем более сложном примере я получаю непонятную ошибку.

import numba
import numpy as np


@numba.jit(nopython=True)
def dist_from_mean(val, mu):
    return abs(val - mu)

@numba.jit(nopython=True)
def func():
    l = [1,7,3,9,10,-4,-2,0]
    avg_val = np.array(l).mean()
    l.sort(key=lambda x: dist_from_mean(x, mu=avg_val))
    return l

Ошибка, о которой он сообщает, следующая:

Traceback (most recent call last):
  File "testitout.py", line 18, in <module>
    ret = func()
  File "/.../python3.6/site-packages/numba/core/dispatcher.py", line 415, in _compile_for_args
    error_rewrite(e, 'typing')
  File "/.../python3.6/site-packages/numba/core/dispatcher.py", line 358, in error_rewrite
    reraise(type(e), e, None)
  File "/.../python3.6/site-packages/numba/core/utils.py", line 80, in reraise
    raise value.with_traceback(tb)
numba.core.errors.TypingError: Failed in nopython mode pipeline (step: convert make_function into JIT functions)
Cannot capture the non-constant value associated with variable 'avg_val' in a function that will escape.

File "testitout.py", line 14:
def func():
    <source elided>
    l.sort(key=lambda x: dist_from_mean(x, mu=avg_val))
                                                ^

Вы знаете, что здесь происходит?


person tiberius    schedule 08.09.2020    source источник


Ответы (1)


Вы знаете, что здесь происходит?

Используя параметр nopython = True, вы деактивируете объектный режим, и, следовательно, Numba не может обрабатывать все значения как объекты Python (см .: https://numba.pydata.org/numba-doc/latest/glossary.html#term-object-mode). (Ссылка на самом деле - это еще одно сообщение, которое я случайно написал сегодня: Как вызвать `@ guvectorize` внутри` @ guvectorize` в numba?)

@numba.jit(nopython=True)
def func():
    l = [1,7,3,9,10,-4,-2,0]
    avg_val = np.array(l).mean()
    l.sort(key=lambda x: dist_from_mean(x, mu=avg_val))
    return l

В любом случае lambda слишком сложен для функции numba jit - по крайней мере, когда он передается в качестве аргумента (сравните https://github.com/numba/numba/issues/4481). При активированном режиме nopython вы можете использовать только ограниченное количество библиотек - полный список можно найти здесь: https://numba.pydata.org/numba-doc/dev/reference/numpysupported.html

Вот почему возникает следующая ошибка:

numba.core.errors.TypingError: сбой в конвейере режима nopython (шаг: преобразование make_function в функции JIT). Невозможно захватить непостоянное значение, связанное с переменной avg_val, в функции, которая будет экранирована.

Более того, ваша ссылка на jit-ускоренную функцию внутри другой - при наличии nopython = True. Это также могло быть источником проблемы.

Я настоятельно рекомендую взглянуть на следующий учебник: http://numba.pydata.org/numba-doc/latest/user/5minguide.html#will-numba-work-for-my-код; он должен вам помочь с похожими проблемами!


Дополнительная литература и источники:

person J. M. Arnold    schedule 23.12.2020
comment
Итак, как бы вы рекомендовали настраиваемую сортировку списка с такой функцией в jit-функции? - person tiberius; 30.12.2020