第5章 自動微分(grad)
機械学習では「損失関数の微分(勾配)」が必要不可欠です。JAX はその計算を 自動で、しかも正確に やってくれます。本章では JAX の自動微分機能、jax.grad の使い方を見ていきます。
5.1 まずは数学の復習:微分とは
ある関数 f(x) の微分 f'(x) は、ざっくり言うと 「x をほんの少し動かすと f(x) がどれくらい変化するか」を表す量 です。たとえば:
f(x) = x²のとき →f'(x) = 2xf(x) = sin(x)のとき →f'(x) = cos(x)
機械学習では、「損失(=モデルの予測と正解のズレ)」をパラメータで微分して、その方向にパラメータを少しずつ更新することで学習を進めます(これを 勾配降下法 といいます)。だから微分は機械学習の心臓部なのです。
💬 「微分なんて覚えてないよ…」という方へ 大丈夫です。JAX は微分の公式を自分で覚えてくれているので、あなたは関数を書くだけで、勾配は JAX が出してくれます。
5.2 jax.grad で 1 行微分
import jax
import jax.numpy as jnp
def f(x):
return x ** 2 + 3 * x + 1
df_dx = jax.grad(f)
print(f(2.0)) # 4 + 6 + 1 = 11
print(df_dx(2.0)) # 2x + 3 = 7
jax.grad(f) は「f を 第一引数で微分した、新しい関数」を返します。Python らしくない感じがするかもしれませんが、これが「関数を変換して、新しい関数を作る」という JAX の哲学です。
5.3 ベクトル入力にもそのまま使える
def loss(w):
# スカラー(1 つの数値)を返す関数
return jnp.sum(w ** 2)
grad_loss = jax.grad(loss)
w = jnp.array([1.0, 2.0, 3.0])
print(grad_loss(w)) # [2. 4. 6.]
ベクトルを渡しても、ちゃんと ベクトルの各要素についての偏微分 をまとめて返してくれます。
⚠️ 大事なルール
jax.gradの対象となる関数は スカラー(1 つの数値)を返さなければなりません。これは「微分された結果の形は、入力の形と同じになるべき」という数学的な理由からです。複数の出力を一度に微分したいときは、後述のjacobianなどを使います。
5.4 複数引数のうち、どれで微分するか
def f(x, y):
return x ** 2 + 3 * x * y
# デフォルトは第 0 引数(x)について微分
df_dx = jax.grad(f)
print(df_dx(2.0, 5.0)) # 2x + 3y = 19
# y について微分したい場合
df_dy = jax.grad(f, argnums=1)
print(df_dy(2.0, 5.0)) # 3x = 6
# 両方について
df = jax.grad(f, argnums=(0, 1))
print(df(2.0, 5.0)) # (19.0, 6.0)
argnums を指定するだけで、どの引数で微分するかを自由に選べます。
5.5 値と勾配を同時に取る value_and_grad
機械学習では「損失そのもの」と「勾配」の両方が欲しい場面が多いです。value_and_grad を使うと、両方を一度に取れます(しかも、内部で重複計算を避けるので効率的です)。
loss_and_grad = jax.value_and_grad(f)
loss_value, grad_value = loss_and_grad(2.0, 5.0)
print(loss_value, grad_value) # 34.0 19.0
5.6 高階微分も簡単
grad は「普通の関数」を返すので、もう一度 grad を被せれば 2 階微分になります。
f = lambda x: x ** 3
print(jax.grad(f)(2.0)) # 3x² = 12
print(jax.grad(jax.grad(f))(2.0)) # 6x = 12
print(jax.grad(jax.grad(jax.grad(f)))(2.0)) # 6
PyTorch では高階微分のために create_graph=True などのオプションが必要ですが、JAX では何の特別な指定もいりません。grad を重ねるだけ、というシンプルさは JAX の魅力のひとつです。
5.7 ヤコビアン・ヘッシアン
jax.jacfwd(f)/jax.jacrev(f):ヤコビ行列(多入力・多出力の微分)jax.hessian(f):ヘッセ行列(2 階偏微分の行列)
def f(x):
return jnp.array([x[0] ** 2, x[0] * x[1]])
x = jnp.array([2.0, 3.0])
print(jax.jacrev(f)(x)) # 2×2 のヤコビ行列
forward モード(jacfwd)と reverse モード(jacrev)の使い分けの原則は次のとおりです。
- 入力次元 ≪ 出力次元 →
jacfwd - 入力次元 ≫ 出力次元 →
jacrev
深層学習はパラメータの数が膨大(=入力次元が大きい)なので、ほぼ常に reverse モード(=grad の中身)が使われます。
5.8 jit との組み合わせ
grad と jit は自由に組み合わせられます。実用では「勾配計算 + パラメータ更新」を jit で高速化するのが定番のパターンです。
@jax.jit
def update(params, x, y, lr=0.01):
def loss_fn(p):
pred = p[0] * x + p[1]
return jnp.mean((pred - y) ** 2)
grads = jax.grad(loss_fn)(params)
return params - lr * grads
5.9 まとめ
jax.grad(f)で、関数fの勾配関数を作れる- 対象関数は スカラー出力 であること
value_and_gradを使えば、損失と勾配を同時に取得できて効率的- 何階微分でも
gradを重ねれば OK - 多入力・多出力なら
jacfwd/jacrev、二階ならhessian