Если бы я суммировал, что такое Google Jax, я бы сказал, что это неоднородная смесь стиля функционального программирования (FP) и дифференцируемых операций NumPy, выполняемых на ускорителях.

Знакомство с NumPy и FP делает его особенным. Его способ действия без побочных эффектов делает его, так сказать, безопасным. Вам не разрешается производить мутации, то есть модификацию места. Некоторые могут подумать, что это может снизить производительность, но обычно это не так; компилятор позаботится об этом. Он предлагает механизм асинхронной отправки, при котором не нужно ждать завершения вычислений, прежде чем управление будет передано обратно пользователю. По сути, мы получаем объект будущего, который не связан с вычислением (обещание). Эта парадигма способствует гибкости и распределенным вычислениям. Еще одним важным моментом является своевременная компиляция (jit), которая позволяет нам скомпилировать несколько операций вместе, используя XLA (оптимизированный компилятор линейной алгебры). Кроме того, мы получаем векторизованную карту, доступную через vmap API.

Давайте поговорим о функции grad, которую предлагает jax. С точки зрения FP,

grad :: Differentiable f => f -> f’

То есть для дифференцируемой функции f мы получаем ее градиент. Grad (f) - это функция, которая вычисляет градиент, а grad (f) (x) - градиент f, вычисляемый в x. Чтобы проиллюстрировать, как можно использовать градиент и vmap вместе, вот простая функция -

jax.grad (f) (x) - градиент f, вычисленный в x

Повсеместная карта определяется как:

map :: (a -> b) -> [a] -> [b]

Расширенный vmap делает еще один шаг вперед, и мы получаем преимущества автоматической векторизации. По умолчанию ось нулевого массива используется для отображения всех аргументов. Вот некоторые варианты использования vmap:

mat = random.normal(key, (150,100))
batched_x = random.normal(key, (10, 100))
def apply_matrix(v):
    return jnp.vdot(mat, v)
vmap(lambda mat,v: jnp.dot(mat, v), (None,0) ) (mat, batched_x)
vmap(lambda v: jnp.dot(mat, v), 0) (batched_x)
(vmap(lambda v: jnp.dot(mat, v), 1, 0) (random.normal(key, (100, 10))))
(vmap(lambda v: jnp.dot(mat, v), 1, 1) (random.normal(key, (100, 10))))
vv = lambda v1, v2: jnp.vdot(v1,v2)
mv = vmap(vv, (0,None), 0) #([b,a], [a]) -> [b]
mm = vmap(mv, (None, 1), 0) # Note: (None, 0), normally. Here, we have unusual (10, None) shape
mm(mat, batched_x.T)
vmap(mv, (None, 0), 0) (mat, batched_x)

Стоит упомянуть особенность, заключающуюся в том, что мы можем зарегистрировать наши собственные типы данных, реализовав интерфейс Pytree. Pytree - это древовидная структура, построенная из контейнерных объектов Python. Таким образом, преобразования функций JAX могут быть применены к функциям, которые принимают как входные и производят как выходные пирамиды массивов.

from jax.tree_util import register_pytree_node
@register_pytree_node_class
class Point:
    def __init__(self, x, y, z):
        self.x = x
        self.y = y
        self.z = z
def __repr__(self):
        return f"Point({self.x}, {self.y}, {self.z})"
def tree_flatten(self):
        return ((self.x, self.y, self.z), None)
@classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)

Теперь мы можем определить произвольные функции, которые работают с нашим типом данных, и сделать их дифференцируемыми:

@jit
def dist_orig(pt: Point):
    return jnp.sqrt(pt.x**2 + pt.y**2 + pt.z**2)
grad(dist_orig)(Point(1., 2., 3.))

Давайте поговорим о двух фундаментальных операциях: векторном произведении Якоби и произведении вектора-Якобейна.

JVP - это проекция данного вектора на матрицу Якоби оператора. Он фиксирует важную информацию о локальной геометрии отображения ввода-вывода глубокой нейронной сети (DN), что является одной из основных причин его популярности. К сожалению, JVP требуют больших вычислительных ресурсов для реальных архитектур DN.

import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
key = random.PRNGKey(0)
# Linear logistic model
def predict(W, b, inputs):
    return sigmoid(jnp.dot(inputs, W) + b)  # inputs is data matrix
inputs = jnp.array([[0.52, 1.12, 0.77],
                    [0.88, -1.08, 0.15],
                    [0.52, 0.06, -1.39],
                    [0.74, -2.49, 1.39]])
targets = jnp.array([True, True, False, True])
# loss is a scalar
def loss(W, b):
    preds = predict(W, b, inputs)
    label_probs = preds * targets + (1 - preds)*(1 - targets)
    return -jnp.sum(jnp.log(label_probs))
key, W_key, b_key = random.split(key, 3)
W = random.normal(W_key, (3,))
b = random.normal(b_key, ())
from jax import jvp
# Isolate the function from the weight matrix to the predictions
def f(W): return predict(W, b, inputs)
key, subkey = random.split(key)
v = random.normal(subkey, W.shape)
# Push forward the vector `v` along `f` evaluated at `W`
y, u = jvp(lambda W: predict(W, b, inputs), primals=(W,), tangents=(v,))

Векторные произведения Якоби образуют основу авто-дифференцирования в обратном режиме.

vjpfun - это функция от котангенсного вектора той же формы, что и primals_out, до кортежа котангенсных векторов той же формы, что и primals, представляющая векторно-якобианское произведение fun по оценке primals.

from jax import vjp
y, vjp_fun = vjp(lambda W: predict(W, b, inputs), W)
key, subkey = random.split(key)
u = random.normal(subkey, y.shape)
identity = jnp.eye(*y.shape, dtype=jnp.float32)
# Pull back the covector `u` along `f` evaluated at `W`
vjp_fun(u)
print("Recovering Jacobian elements row-wise!")
print(vjp_fun(identity[0]),
      vjp_fun(identity[1]),
      vjp_fun(identity[2]),
      vjp_fun(identity[3]), sep="\n")

В приведенных выше примерах я попытался проиллюстрировать ключевые особенности JAX, включая его подход, основанный на FP, который поощряет компоновку и приводит к чистому коду. Ознакомьтесь с официальной документацией, если статья вызвала ваш интерес.