Skip to main content

第5章 自動微分(grad

機械学習では「損失関数の微分(勾配)」が必要不可欠です。JAX はその計算を 自動で、しかも正確に やってくれます。本章では JAX の自動微分機能、jax.grad の使い方を見ていきます。

5.1 まずは数学の復習:微分とは

ある関数 f(x) の微分 f'(x) は、ざっくり言うと x をほんの少し動かすと f(x) がどれくらい変化するか」を表す量 です。たとえば:

  • f(x) = x² のとき → f'(x) = 2x
  • f(x) = sin(x) のとき → f'(x) = cos(x)

機械学習では、「損失(=モデルの予測と正解のズレ)」をパラメータで微分して、その方向にパラメータを少しずつ更新することで学習を進めます(これを 勾配降下法 といいます)。だから微分は機械学習の心臓部なのです。

💬 「微分なんて覚えてないよ…」という方へ 大丈夫です。JAX は微分の公式を自分で覚えてくれているので、あなたは関数を書くだけで、勾配は JAX が出してくれます。

5.2 jax.grad で 1 行微分

import jax
import jax.numpy as jnp

def f(x):
return x ** 2 + 3 * x + 1

df_dx = jax.grad(f)

print(f(2.0)) # 4 + 6 + 1 = 11
print(df_dx(2.0)) # 2x + 3 = 7

jax.grad(f) は「f第一引数で微分した、新しい関数」を返します。Python らしくない感じがするかもしれませんが、これが「関数を変換して、新しい関数を作る」という JAX の哲学です。

5.3 ベクトル入力にもそのまま使える

def loss(w):
# スカラー(1 つの数値)を返す関数
return jnp.sum(w ** 2)

grad_loss = jax.grad(loss)

w = jnp.array([1.0, 2.0, 3.0])
print(grad_loss(w)) # [2. 4. 6.]

ベクトルを渡しても、ちゃんと ベクトルの各要素についての偏微分 をまとめて返してくれます。

⚠️ 大事なルール jax.grad の対象となる関数は スカラー(1 つの数値)を返さなければなりません。これは「微分された結果の形は、入力の形と同じになるべき」という数学的な理由からです。複数の出力を一度に微分したいときは、後述の jacobian などを使います。

5.4 複数引数のうち、どれで微分するか

def f(x, y):
return x ** 2 + 3 * x * y

# デフォルトは第 0 引数(x)について微分
df_dx = jax.grad(f)
print(df_dx(2.0, 5.0)) # 2x + 3y = 19

# y について微分したい場合
df_dy = jax.grad(f, argnums=1)
print(df_dy(2.0, 5.0)) # 3x = 6

# 両方について
df = jax.grad(f, argnums=(0, 1))
print(df(2.0, 5.0)) # (19.0, 6.0)

argnums を指定するだけで、どの引数で微分するかを自由に選べます。

5.5 値と勾配を同時に取る value_and_grad

機械学習では「損失そのもの」と「勾配」の両方が欲しい場面が多いです。value_and_grad を使うと、両方を一度に取れます(しかも、内部で重複計算を避けるので効率的です)。

loss_and_grad = jax.value_and_grad(f)
loss_value, grad_value = loss_and_grad(2.0, 5.0)
print(loss_value, grad_value) # 34.0 19.0

5.6 高階微分も簡単

grad は「普通の関数」を返すので、もう一度 grad を被せれば 2 階微分になります。

f = lambda x: x ** 3
print(jax.grad(f)(2.0)) # 3x² = 12
print(jax.grad(jax.grad(f))(2.0)) # 6x = 12
print(jax.grad(jax.grad(jax.grad(f)))(2.0)) # 6

PyTorch では高階微分のために create_graph=True などのオプションが必要ですが、JAX では何の特別な指定もいりません。grad を重ねるだけ、というシンプルさは JAX の魅力のひとつです。

5.7 ヤコビアン・ヘッシアン

  • jax.jacfwd(f) / jax.jacrev(f):ヤコビ行列(多入力・多出力の微分)
  • jax.hessian(f):ヘッセ行列(2 階偏微分の行列)
def f(x):
return jnp.array([x[0] ** 2, x[0] * x[1]])

x = jnp.array([2.0, 3.0])
print(jax.jacrev(f)(x)) # 2×2 のヤコビ行列

forward モード(jacfwd)と reverse モード(jacrev)の使い分けの原則は次のとおりです。

  • 入力次元 ≪ 出力次元jacfwd
  • 入力次元 ≫ 出力次元jacrev

深層学習はパラメータの数が膨大(=入力次元が大きい)なので、ほぼ常に reverse モード(=grad の中身)が使われます。

5.8 jit との組み合わせ

gradjit は自由に組み合わせられます。実用では「勾配計算 + パラメータ更新」を jit で高速化するのが定番のパターンです。

@jax.jit
def update(params, x, y, lr=0.01):
def loss_fn(p):
pred = p[0] * x + p[1]
return jnp.mean((pred - y) ** 2)
grads = jax.grad(loss_fn)(params)
return params - lr * grads

5.9 まとめ

  • jax.grad(f) で、関数 f の勾配関数を作れる
  • 対象関数は スカラー出力 であること
  • value_and_grad を使えば、損失と勾配を同時に取得できて効率的
  • 何階微分でも grad を重ねれば OK
  • 多入力・多出力なら jacfwd / jacrev、二階なら hessian

➡️ 第6章 ベクトル化(vmap)と並列化(pmap