Skip to main content

第15章 移行ケース集:応用と複合パターン

第13章・14章では「1 対 1」の置き換えを中心に見ました。本章では、もう少し大きな単位の処理JAX らしさが活きる応用パターン をケーススタディとして紹介します。「同じ目的を、JAX ならどう書くか?」の引き出しを広げる章です。

📚 前提:第13章・第14章で扱った置き換えに加えて、jax.vmap / jax.lax.scan / jax.grad / jit の組み合わせに親しんでいると、本章のケースが「なぜ短く書けるのか」がよく分かります。基礎が不安なら 第4章第5章第6章 を先に読んでください。

一部のサンプルでは外部ライブラリ(NumPyro、Gymnax など)を使います。必要なときだけ pip install numpyro gymnax などをしてください。

目次


15.1 ミニバッチ学習ループ全体(純 JAX)

PyTorch の標準的な学習ループを JAX 本体だけ で書くと、こんな形になります。

import jax, jax.numpy as jnp
import numpy as np

# --- データ ---
key = jax.random.key(0)
N, D = 1024, 20
X = jax.random.normal(jax.random.key(1), (N, D))
y = (X @ jnp.ones(D) > 0).astype(jnp.float32)

# --- モデル(純粋関数) ---
def init(key):
return {"W": 0.01 * jax.random.normal(key, (D, 1)), "b": jnp.zeros(1)}

def predict(params, x):
return jax.nn.sigmoid(x @ params["W"] + params["b"]).squeeze(-1)

def loss_fn(params, x, y):
p = predict(params, x)
eps = 1e-7
return -jnp.mean(y * jnp.log(p + eps) + (1 - y) * jnp.log(1 - p + eps))

@jax.jit
def step(params, x, y, lr=0.1):
loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
return params, loss

# --- ミニバッチを返すジェネレータ ---
def batches(key, X, y, bs):
n = X.shape[0]
idx = jax.random.permutation(key, n)
for i in range(0, n, bs):
b = idx[i:i + bs]
yield X[b], y[b]

# --- 学習ループ ---
params = init(jax.random.key(2))
key_iter = jax.random.key(3)
for epoch in range(20):
key_iter, sub = jax.random.split(key_iter)
for xb, yb in batches(sub, X, y, bs=64):
params, loss = step(params, xb, yb)
print(f"epoch {epoch:02d} loss={float(loss):.4f}")

ポイント

  • 毎エポック、jax.random.split で新しいシャッフル用 key を作る
  • jit した step を呼ぶだけのシンプル構成
  • PyTorch でいう optimizer.zero_grad()loss.backward() も不要

15.2 アンサンブル学習:vmap でモデルを並列化

「同じネットワークを、初期値だけ変えて N 個並列に学習する」のが、JAX なら vmap で本当に 1 行で書けます。

import jax, jax.numpy as jnp

D, H, K = 20, 32, 3 # 入力次元、隠れ、アンサンブル数

# 1 モデル分の初期化
def init_one(key):
k1, k2 = jax.random.split(key)
return {
"W1": 0.1 * jax.random.normal(k1, (D, H)),
"W2": 0.1 * jax.random.normal(k2, (H, 1)),
}

# K 個のモデルを並列に初期化
keys = jax.random.split(jax.random.key(0), K)
ensemble_params = jax.vmap(init_one)(keys) # 各テンソルが先頭次元 K で揃う

# 1 モデル分の predict
def predict_one(params, x):
return jnp.tanh(x @ params["W1"]) @ params["W2"]

# 同じ入力 x に対して K モデルの予測を一気に取得
def predict_ensemble(ensemble_params, x):
return jax.vmap(predict_one, in_axes=(0, None))(ensemble_params, x)

x = jax.random.normal(jax.random.key(1), (8, D))
preds = predict_ensemble(ensemble_params, x) # shape: (K, 8, 1)
print(preds.shape)

ポイント

  • 「モデル = パラメータ(PyTree)」だからこそ、vmap で並べられる
  • PyTorch では nn.ModuleList を作って for ループで回す必要があり、ここまで自然に書けない

15.3 ハイパーパラメータ並列探索

学習率を 16 通り並列に試すような実験も、vmap で簡潔に書けます。

import jax, jax.numpy as jnp

def train_with(lr, key):
params0 = init(key)
# Python for で 100 回 unroll せず、scan で「学習を 100 回繰り返す」計算として渡す
def body(params, _):
params, _ = step(params, X, y, lr=lr)
return params, None
params, _ = jax.lax.scan(body, params0, None, length=100)
return loss_fn(params, X, y)

lrs = jnp.logspace(-4, 0, 16)
keys = jax.random.split(jax.random.key(0), 16)

# 16 個の学習を並列に
final_losses = jax.vmap(train_with)(lrs, keys)
print(final_losses)

pmap を使えば「16 通りを 8 GPU に分散」も同じノリで書けます。


15.4 メタ学習:勾配の勾配

MAML のようなメタ学習では「内側の学習ステップで得たパラメータについて、外側の損失を勾配計算する」必要があります。PyTorch では create_graph=True を意識する必要がありますが、JAX なら grad を 2 重に被せるだけ です。

import jax, jax.numpy as jnp

def inner_loss(params, x, y):
return jnp.mean((mlp(params, x) - y) ** 2)

# 1 ステップ内側更新
def inner_step(params, x, y, lr=0.01):
grads = jax.grad(inner_loss)(params, x, y)
return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

# 外側損失:内側更新後の損失を、初期パラメータの関数として見る
def outer_loss(init_params, task):
x_tr, y_tr, x_te, y_te = task
updated = inner_step(init_params, x_tr, y_tr)
return inner_loss(updated, x_te, y_te)

meta_grad = jax.grad(outer_loss) # ← 内側の grad を通した勾配

grad が自然に合成されるところが JAX の強さです。


15.5 ベイズ推論:MCMC / NUTS

NumPyro(JAX ベースの確率的プログラミングライブラリ)を使うと、ベイズ推論も簡潔に書けます。

import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS
import jax, jax.numpy as jnp

def model(X, y=None):
w = numpyro.sample("w", dist.Normal(jnp.zeros(X.shape[1]), 1.0))
b = numpyro.sample("b", dist.Normal(0.0, 1.0))
logits = X @ w + b
numpyro.sample("y", dist.Bernoulli(logits=logits), obs=y)

mcmc = MCMC(NUTS(model), num_warmup=500, num_samples=1000)
mcmc.run(jax.random.key(0), X=X, y=y)
mcmc.print_summary()

JAX の JIT + 自動微分 + ベクトル化 によって、NUTS のような勾配ベース MCMC も高速化しやすくなります。


15.6 強化学習:環境のベクトル化ロールアウト

Gym 風の環境を vmap で並列ロールアウトするのが、JAX 強化学習の真骨頂です。Gymnax を使うと簡単に書けます。

import gymnax, jax, jax.numpy as jnp

env, env_params = gymnax.make("CartPole-v1")

def rollout(key, policy_params, steps=200):
obs, state = env.reset(key, env_params)
def body(carry, _):
obs, state, key = carry
key, k_a, k_s = jax.random.split(key, 3)
action = policy(policy_params, obs, k_a)
next_obs, next_state, reward, done, info = env.step(k_s, state, action, env_params)
return (next_obs, next_state, key), reward
(_, _, _), rewards = jax.lax.scan(body, (obs, state, key), None, length=steps)
return rewards.sum()

# 1024 並列でロールアウト → 平均報酬
keys = jax.random.split(jax.random.key(0), 1024)
returns = jax.vmap(rollout, in_axes=(0, None))(keys, policy_params)
print(returns.mean())

PyTorch で同じことをやろうとすると multiprocessing などが必要ですが、JAX なら 1 つのプロセスで GPU 上に並列ロールアウト できます。


15.7 物理シミュレーション

scan を使えば、時間発展のシミュレーションを高速に書けます。

import jax, jax.numpy as jnp

# 単振り子:状態 (θ, ω) の時間発展
def step(state, _, g=9.8, L=1.0, dt=0.01):
theta, omega = state
domega = -(g / L) * jnp.sin(theta)
new_state = (theta + omega * dt, omega + domega * dt)
return new_state, new_state

(theta0, omega0) = (jnp.pi / 4, 0.0)
_, traj = jax.lax.scan(step, (theta0, omega0), None, length=1000)

jax.grad を組み合わせれば「シミュレーション結果に対するパラメータの勾配」が取れるので、微分可能シミュレーション が自然に書けます。これは PyTorch でもできますが、JAX の方がコードがすっきりします。


15.8 K-Means クラスタリング

scikit-learn

from sklearn.cluster import KMeans
km = KMeans(n_clusters=4).fit(X_np)
labels = km.labels_
centers = km.cluster_centers_

JAX で自前実装

import jax, jax.numpy as jnp

def kmeans(X, K, n_iter=20, key=jax.random.key(0)):
# 初期重心:データからランダムに選ぶ
idx = jax.random.choice(key, X.shape[0], (K,), replace=False)
centers = X[idx]

def step(centers, _):
# 距離行列:(N, K)
d = jnp.sum((X[:, None, :] - centers[None, :, :]) ** 2, axis=-1)
labels = jnp.argmin(d, axis=-1)
# 各クラスタの平均で更新
new_centers = []
for k in range(K):
count = (labels == k).sum()
mean = jnp.where((labels == k)[:, None], X, 0).sum(axis=0) / jnp.maximum(count, 1)
# 空クラスタは 0 に飛ばさず、前回の中心を維持
new_centers.append(jnp.where(count > 0, mean, centers[k]))
new_centers = jnp.stack(new_centers)
return new_centers, None

centers, _ = jax.lax.scan(step, centers, None, length=n_iter)
labels = jnp.argmin(jnp.sum((X[:, None, :] - centers[None, :, :]) ** 2, axis=-1), axis=-1)
return centers, labels

jit + scan のおかげで、scikit-learn より速いことも珍しくありません。


15.9 ロジスティック回帰(scikit-learn → JAX)

scikit-learn

from sklearn.linear_model import LogisticRegression
clf = LogisticRegression().fit(X, y)
preds = clf.predict(X_test)

JAX

import jax, jax.numpy as jnp, optax

def loss_fn(params, X, y, l2=1e-3):
logits = X @ params["w"] + params["b"]
bce = optax.sigmoid_binary_cross_entropy(logits, y).mean()
return bce + l2 * jnp.sum(params["w"] ** 2)

params = {"w": jnp.zeros(X.shape[1]), "b": jnp.zeros(())}
opt = optax.adam(0.1)
opt_state = opt.init(params)

@jax.jit
def step(params, opt_state, X, y):
loss, grads = jax.value_and_grad(loss_fn)(params, X, y)
updates, opt_state = opt.update(grads, opt_state)
params = optax.apply_updates(params, updates)
return params, opt_state, loss

for _ in range(500):
params, opt_state, loss = step(params, opt_state, X, y)

ご利益

  • 同じパターンでロジスティック回帰、SVM ヒンジ損失、Poisson 回帰など好きな線形モデルが書ける
  • L1 正則化を入れるなら loss_fnl1 * jnp.sum(jnp.abs(params["w"])) を足します。optax.add_decayed_weights は主に L2/weight decay 用です。

15.10 動的計画法:scan で書く

「漸化式を順に評価していく」タイプの計算は、scan の典型的な使いどころです。

NumPy(フィボナッチ数列)

import numpy as np
def fib(n):
a, b = 0, 1
out = np.empty(n)
for i in range(n):
out[i] = a
a, b = b, a + b
return out

JAX(scan で書く + jit 化)

import jax, jax.numpy as jnp
from functools import partial

@partial(jax.jit, static_argnums=(0,))
def fib(n):
def body(carry, _):
a, b = carry
return (b, a + b), a
_, out = jax.lax.scan(body, (jnp.int32(0), jnp.int32(1)), None, length=n)
return out

print(fib(10))

「Python の for で書くと遅い → scan に置き換える」は JAX の鉄板パターンです。


15.11 画像処理:パッチ抽出と畳み込み

NumPy / Pillow で「画像から重ねないパッチを切り出す」

import numpy as np
def patches(img, ph, pw):
H, W, C = img.shape
out = []
for i in range(0, H, ph):
for j in range(0, W, pw):
out.append(img[i:i+ph, j:j+pw])
return np.stack(out)

JAX(reshape + transpose の組み合わせ)

import jax.numpy as jnp

def patches(img, ph, pw):
H, W, C = img.shape
img = img.reshape(H // ph, ph, W // pw, pw, C)
img = img.transpose(0, 2, 1, 3, 4)
return img.reshape(-1, ph, pw, C)

畳み込み(自前で)

import jax, jax.numpy as jnp

# x: (N, H, W, C), w: (kH, kW, in_C, out_C) のとき
def conv2d(x, w, stride=1, padding="SAME"):
return jax.lax.conv_general_dilated(
x, w,
window_strides=(stride, stride),
padding=padding,
dimension_numbers=("NHWC", "HWIO", "NHWC"),
)

PyTorch の F.conv2d は NCHW・OIHW が前提ですが、JAX は dimension_numbers で次元順を明示できる ので柔軟です。


15.12 ニュートン法・最適化ソルバ

NumPy(手書きのニュートン法)

import numpy as np
def newton(f, df, x0, tol=1e-8, max_iter=50):
x = x0
for _ in range(max_iter):
dx = f(x) / df(x)
x -= dx
if abs(dx) < tol:
break
return x

JAX(微分は自動、ループは while_loop

import jax, jax.numpy as jnp

def newton(f, x0, tol=1e-8, max_iter=50):
df = jax.grad(f)
def cond(state):
x, dx, i = state
return (jnp.abs(dx) > tol) & (i < max_iter)
def body(state):
x, _, i = state
dx = f(x) / df(x)
return (x - dx, dx, i + 1)
x, _, _ = jax.lax.while_loop(cond, body, (x0, jnp.array(1.0), 0))
return x

# f(x) = x^2 - 2 → √2 を求める
print(newton(lambda x: x**2 - 2.0, 1.0)) # ≒ 1.41421356

jax.grad で導関数を自動で得られるので、df を手で書く必要がありません。


まとめ:JAX らしい書き方が活きる場面

場面JAX の武器
反復計算・累積lax.scan / fori_loop / while_loop
サンプル方向の並列化vmap
デバイス間の分散pmap / shard_map / jit + sharding
アンサンブル / ハイパラ並列vmap(モデルもバッチ次元として扱う)
メタ学習・微分可能シミュレーションgrad の合成
ベイズ推論 / モンテカルロ法NumPyro + vmap + 鍵分割

「ループ」「並列」「微分」がコードのどこかに 1 つでもあれば、JAX の出番がある と思って書き換えてみると、コードが驚くほどシンプル&高速になることがあります。

➡️ さらに先へ:

➡️ JAX 日本語チュートリアルに戻る