Способы ускорить работу SciPy?

Я вызываю функцию, которая использует odeint при каждом проходе через цикл for (к сожалению, я не могу ничего сломать из этого цикла). Но дела идут намного медленнее, чем я надеялся. Вот код:

def get_STM(t_i, t_f, X_ref_i, dxdt, Amat):
    """Evaluate the state transition matrix rate of change for a given A matrix.
    """

    STM_i = np.eye(X_ref_i.size).flatten()
    args = (dxdt, Amat)
    X_aug_i = np.hstack((X_ref_i, STM_i))
    t = [t_i, t_f]

    # Propogate reference trajectory & STM together!    
    X_aug_f = odeint(dxdt_interface, X_aug_i, t, args=args)
    X_f = X_aug_f[-1, :X_ref_i.size]
    STM_f = X_aug_f[-1, X_ref_i.size:].reshape(X_ref_i.size, X_ref_i.size)

    return X_f, STM_f

def dxdt_interface(X,t,dxdt,Amat):
    """
    Provides an interface between odeint and dxdt
    Parameters :
    ------------
    X : (42-by-1 np array) augmented state (with Phi)
    t : time
    dxdt : (function handle) time derivative of the (6-by-1) state vector
    Amat : (function handle) state-space matrix
    Returns:
    --------
    (42-by-1 np.array) time derivative of the components of the augmented state 
    """
    # State derivative
    Xdot = np.zeros_like(X)
    X_stacked = np.hstack((X[:6], t))
    Xdot_state = dxdt(*(X_stacked))
    Xdot[:6] = Xdot_state[:6].T

    # STM
    Phi = X[6:].reshape((Xdot_state.size, Xdot_state.size))

    # State-Space matrix
    A = Amat(*(X_stacked))
    Xdot[6:] = (A .dot (Phi)).reshape((A.size))

    return Xdot

Проблема в том, что я вызываю get_STM что-то порядка 8640 раз за прогон, и это приводит к 232217 вызовам dxdt_interface, что составляет около 70% моего общего времени вычислений по 5 мс на вызов get_STM (99,9% из которых связано с odeint ).

Я новичок в методах интеграции SciPy, и я не могу понять, как вообще ускорить это, основываясь на odeint. .html" rel="nofollow">документация. Я изучил джиттинг dxdt_interface с помощью Numba, но не могу заставить его работать, потому что dxdt и Amat являются символическими.

Есть ли какие-то способы ускорить odeint, которые я упустил?

РЕДАКТИРОВАТЬ: включены функции Amat и dxdt ниже. Обратите внимание, что они не вызываются в моем основном цикле for, они создают дескрипторы символических лямбдифицированных функций, которые передаются моей функции get_STM (я вызываю import sympy as sym).

def get_A(use_j3=False):
    """ Returns the jacobian of the state time rate of change
    Parameters
    ----------
    R : Earth's equatorial radius (m)
    theta_dot : Earth's rotation rate (rad/s)
    mu : Earth's standard gravitationnal parameter (m^3/s^2)
    j2 : second zonal harmonic coefficient
    j3 : third zonal harmonic coefficient
    Returns
    ----------    
    A : (function handle) jacobian of the state time rate of change
    """
    theta_dot = EARTH['rotation rate']
    R = EARTH['radius']
    mu = EARTH['mu']
    j2 = EARTH['J2']
    if use_j3:
        j3 = EARTH['J3']
    else:
        j3 = 0

    # Symbolic derivations
    x, y, z, mus, j2s, j3s, Rs, t = sym.symbols('x y z mus j2s j3s Rs t', real=True)
    theta_dots = sym.symbols('theta_dots', real=True)
    xdot,ydot,zdot = sym.symbols('xdot ydot zdot ', real=True)

    X = sym.Matrix([x,y,z,xdot,ydot,zdot])

    A_mat = sym.lambdify( (x,y,z,xdot,ydot,zdot,t), dxdt_s().jacobian(X).subs([
        (theta_dots, theta_dot),(Rs, R),(j2s,j2),(j3s,j3),(mus,mu)]), modules='numpy')

    return A_mat

def Dxdt(use_j3=False):
    """ Returns the time derivative of the state vector
    Parameters
    ----------
    R : Earth's equatorial radius (m)
    theta_dot : Earth's rotation rate (rad/s)
    mu : Earth's standard gravitationnal parameter (m^3/s^2)
    j2 : second zonal harmonic coefficient
    j3 : third zonal harmonic coefficient
    Returns
    ----------    
    dxdt : (function handle) time derivative of the state vector
    """

    theta_dot = EARTH['rotation rate']
    R = EARTH['radius']
    mu = EARTH['mu']
    j2 = EARTH['J2']
    if use_j3:
        j3 = EARTH['J3']
    else:
        j3 = 0

    # Symbolic derivations
    x, y, z, mus, j2s, j3s, Rs, t = sym.symbols('x y z mus j2s j3s Rs t', real=True)
    theta_dots = sym.symbols('theta_dots', real=True)
    xdot,ydot,zdot = sym.symbols('xdot ydot zdot ', real=True)

    dxdt = sym.lambdify( (x,y,z,xdot,ydot,zdot,t), dxdt_s().subs([
        (theta_dots, theta_dot),(Rs, R),(j2s,j2),(j3s,j3),(mus,mu)]), modules='numpy')

    return dxdt

person Nick Sweet    schedule 26.03.2016    source источник
comment
Я думаю, вам нужно ускорить dxdt и Amat. Возможно, используйте codegen в sympy для генерации кода C или Fortran и создания dxdt_interface с помощью cython.   -  person HYRY    schedule 27.03.2016


Ответы (1)


С dxdt и Amat в качестве черных ящиков вы мало что можете сделать, чтобы ускорить это. Одна из возможностей — упростить их вызов. hstack может быть излишним.

In [355]: def dxdt_quiet(*args):
    x=args
    return x
   .....: 
In [356]: t=1.23
In [357]: dxdt_quiet(*xs)
Out[357]: (0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 1.23)
In [358]: dxdt_quiet(*tuple(x[:6])+(t,))
Out[358]: (0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 1.23)

кортежный подход немного быстрее:

In [359]: timeit dxdt_quiet(*tuple(x[:6])+(t,))
100000 loops, best of 3: 5.1 µs per loop
In [360]: %%timeit
xs=np.hstack((x[:6],1.234))
dxdt_quiet(*xs)
   .....: 
10000 loops, best of 3: 25.4 µs per loop

Я бы провел больше подобных тестов, чтобы оптимизировать вызовы dxdt_interface.

person hpaulj    schedule 27.03.2016
comment
Только что отредактировал свой пост, чтобы включить функции создания dxdt и Amat. Я также заменил вызов hstack на следующий: X_stacked = tuple(list(X[:6]) + [t]), но не заметил значительного ускорения. - person Nick Sweet; 27.03.2016