Обобщенная гипергеометрическая функция для numpy

Я хочу подогнать некоторые данные к гипергеометрической функции. Я использую обобщенную гипергеометрическую функцию, приведенную в mpmath, hyper. Пытаюсь преобразовать для работы с curve_fit с помощью np.frompyfunc. Когда я делаю

np_hyp = np.frompyfunc(hyper,3,1)
np_hyp([-1/3],[-2/3,2/3],x**2/4)

где x - некоторый массив numpy. Я получаю ошибку len(a_s): 'float' object has no length или что-то в этом роде (я буду более точен, когда смогу вернуться к своему компьютеру, чтобы воспроизвести ошибку). Я подозреваю, что это как-то связано с тем, что входные данные являются списками и странным образом конвертируются, когда numpy пытается преобразовать функцию.

Кто-нибудь знает способ исправить эту ошибку? Любая помощь будет принята с благодарностью.


person Sam    schedule 29.03.2021    source источник
comment
быстрое предположение состоит в том, что он также выполняет итерацию по первым двум аргументам, т.е. hyper вызывается с -1/3, а не с [-1/3]. Поскольку функции необходимо определить длину, чтобы получить p и q, она пытается получить длину числа с плавающей запятой -1/3, что вызывает ошибку. Это критично по времени или можно просто написать свою оболочку? Или, может быть, vectorize может помочь с аргументом excluded.   -  person mikuszefski    schedule 30.03.2021


Ответы (1)


Оказывается, мой комментарий сверху верен, т.е. первый и второй список также разлагаются и передаются как один элемент. Этого не должно быть. Решение, следовательно,

from mpmath import hyper
import numpy as np

print( hyper( [ -1 / 3 ],[ -2 / 3, 2 / 3 ], 0.255 ) )

nphyper = np.vectorize( hyper )
nphyper.excluded.add(0)
nphyper.excluded.add(1)

print( 
    nphyper(
        [ -1 / 3 ],
        [ -2 / 3, 2 / 3 ],
        np.array( [ 0.255, 0.257 ] )
    )
)

Из документации неясно, поэтому благодаря этому сообщению я понял, как исключить позиционные аргументы.

person mikuszefski    schedule 30.03.2021
comment
Здорово, спасибо! Хотя теперь у меня есть замечательная ошибка Cannot cast array data from dtype('O') to dtype('float64') according to the rule 'safe', которая, как я подозреваю, связана с тем, что я использую гистограммы для соответствия данным, а не вашу реализацию здесь. - person Sam; 30.03.2021
comment
Исправлено это, необходимо было добавить .astype ('float'), чтобы заменить его mpf float. - person Sam; 30.03.2021