Если бы я суммировал, что такое Google Jax, я бы сказал, что это неоднородная смесь стиля функционального программирования (FP) и дифференцируемых операций NumPy, выполняемых на ускорителях.
Знакомство с NumPy и FP делает его особенным. Его способ действия без побочных эффектов делает его, так сказать, безопасным. Вам не разрешается производить мутации, то есть модификацию места. Некоторые могут подумать, что это может снизить производительность, но обычно это не так; компилятор позаботится об этом. Он предлагает механизм асинхронной отправки, при котором не нужно ждать завершения вычислений, прежде чем управление будет передано обратно пользователю. По сути, мы получаем объект будущего, который не связан с вычислением (обещание). Эта парадигма способствует гибкости и распределенным вычислениям. Еще одним важным моментом является своевременная компиляция (jit), которая позволяет нам скомпилировать несколько операций вместе, используя XLA (оптимизированный компилятор линейной алгебры). Кроме того, мы получаем векторизованную карту, доступную через vmap API.
Давайте поговорим о функции grad, которую предлагает jax. С точки зрения FP,
grad :: Differentiable f => f -> f’
То есть для дифференцируемой функции f мы получаем ее градиент. Grad (f) - это функция, которая вычисляет градиент, а grad (f) (x) - градиент f, вычисляемый в x. Чтобы проиллюстрировать, как можно использовать градиент и vmap вместе, вот простая функция -
jax.grad (f) (x) - градиент f, вычисленный в x
Повсеместная карта определяется как:
map :: (a -> b) -> [a] -> [b]
Расширенный vmap делает еще один шаг вперед, и мы получаем преимущества автоматической векторизации. По умолчанию ось нулевого массива используется для отображения всех аргументов. Вот некоторые варианты использования vmap:
mat = random.normal(key, (150,100)) batched_x = random.normal(key, (10, 100)) def apply_matrix(v): return jnp.vdot(mat, v) vmap(lambda mat,v: jnp.dot(mat, v), (None,0) ) (mat, batched_x) vmap(lambda v: jnp.dot(mat, v), 0) (batched_x) (vmap(lambda v: jnp.dot(mat, v), 1, 0) (random.normal(key, (100, 10)))) (vmap(lambda v: jnp.dot(mat, v), 1, 1) (random.normal(key, (100, 10)))) vv = lambda v1, v2: jnp.vdot(v1,v2) mv = vmap(vv, (0,None), 0) #([b,a], [a]) -> [b] mm = vmap(mv, (None, 1), 0) # Note: (None, 0), normally. Here, we have unusual (10, None) shape mm(mat, batched_x.T) vmap(mv, (None, 0), 0) (mat, batched_x)
Стоит упомянуть особенность, заключающуюся в том, что мы можем зарегистрировать наши собственные типы данных, реализовав интерфейс Pytree. Pytree - это древовидная структура, построенная из контейнерных объектов Python. Таким образом, преобразования функций JAX могут быть применены к функциям, которые принимают как входные и производят как выходные пирамиды массивов.
from jax.tree_util import register_pytree_node @register_pytree_node_class class Point: def __init__(self, x, y, z): self.x = x self.y = y self.z = z def __repr__(self): return f"Point({self.x}, {self.y}, {self.z})" def tree_flatten(self): return ((self.x, self.y, self.z), None) @classmethod def tree_unflatten(cls, aux_data, children): return cls(*children)
Теперь мы можем определить произвольные функции, которые работают с нашим типом данных, и сделать их дифференцируемыми:
@jit def dist_orig(pt: Point): return jnp.sqrt(pt.x**2 + pt.y**2 + pt.z**2) grad(dist_orig)(Point(1., 2., 3.))
Давайте поговорим о двух фундаментальных операциях: векторном произведении Якоби и произведении вектора-Якобейна.
JVP - это проекция данного вектора на матрицу Якоби оператора. Он фиксирует важную информацию о локальной геометрии отображения ввода-вывода глубокой нейронной сети (DN), что является одной из основных причин его популярности. К сожалению, JVP требуют больших вычислительных ресурсов для реальных архитектур DN.
import jax.numpy as jnp from jax import grad, jit, vmap from jax import random key = random.PRNGKey(0) # Linear logistic model def predict(W, b, inputs): return sigmoid(jnp.dot(inputs, W) + b) # inputs is data matrix inputs = jnp.array([[0.52, 1.12, 0.77], [0.88, -1.08, 0.15], [0.52, 0.06, -1.39], [0.74, -2.49, 1.39]]) targets = jnp.array([True, True, False, True]) # loss is a scalar def loss(W, b): preds = predict(W, b, inputs) label_probs = preds * targets + (1 - preds)*(1 - targets) return -jnp.sum(jnp.log(label_probs)) key, W_key, b_key = random.split(key, 3) W = random.normal(W_key, (3,)) b = random.normal(b_key, ()) from jax import jvp # Isolate the function from the weight matrix to the predictions def f(W): return predict(W, b, inputs) key, subkey = random.split(key) v = random.normal(subkey, W.shape) # Push forward the vector `v` along `f` evaluated at `W` y, u = jvp(lambda W: predict(W, b, inputs), primals=(W,), tangents=(v,))
Векторные произведения Якоби образуют основу авто-дифференцирования в обратном режиме.
vjpfun
- это функция от котангенсного вектора той же формы, что и primals_out
, до кортежа котангенсных векторов той же формы, что и primals
, представляющая векторно-якобианское произведение fun
по оценке primals
.
from jax import vjp y, vjp_fun = vjp(lambda W: predict(W, b, inputs), W) key, subkey = random.split(key) u = random.normal(subkey, y.shape) identity = jnp.eye(*y.shape, dtype=jnp.float32) # Pull back the covector `u` along `f` evaluated at `W` vjp_fun(u) print("Recovering Jacobian elements row-wise!") print(vjp_fun(identity[0]), vjp_fun(identity[1]), vjp_fun(identity[2]), vjp_fun(identity[3]), sep="\n")
В приведенных выше примерах я попытался проиллюстрировать ключевые особенности JAX, включая его подход, основанный на FP, который поощряет компоновку и приводит к чистому коду. Ознакомьтесь с официальной документацией, если статья вызвала ваш интерес.