第14章 移行ケース集:PyTorch → JAX
PyTorch で書かれた既存コードを JAX(+ Flax / Optax)に移行するための 実践レシピ集 です。PyTorch の各機能ごとに、JAX での書き方を左右比較で示します。
📚 前提:PyTorch で
nn.Moduleを継承してモデルを書いた経験があれば、本章はそのまま読めます。JAX 側の流儀(イミュータブル、関数変換、PyTree、PRNG キー)にまだ慣れていない場合は、先に 第9章 PyTorch との比較 と 第7章 PyTree、第8章 乱数 を確認すると、各ケースの「なぜそう書くのか」が腑に落ちます。本章では Flax は新 API の NNX、最適化は Optax、チェックポイントは Orbax を前提に書いています。インストールはそれぞれ
pip install flax optax orbax-checkpointです。
💡 大方針
- テンソル →
jax.Array/jnpnn.Module→ Flax NNX のnnx.Module(または自前の純粋関数 + PyTree)optimizer.step()→ Optax のupdate+apply_updates.backward()→jax.grad/jax.value_and_gradDataLoader→ ふつうの Python イテレータ、または Grain / tf.data
目次
- 14.1 テンソル基本操作
- 14.2 デバイス転送
- 14.3 自動微分
- 14.4 線形層(MLP)
- 14.5 学習ループ全体
- 14.6 損失関数
- 14.7 オプティマイザ
- 14.8 学習率スケジューラ
- 14.9 重みの保存・読み込み
- 14.10 Dropout / BatchNorm の扱い
- 14.11 CNN(畳み込みネット)
- 14.12 RNN / シーケンスの逐次処理
- 14.13 アテンション(Transformer の核)
- 14.14 マスクとパディング
- 14.15 勾配クリッピング・重み減衰
- 14.16 DataLoader
- 14.17 マルチ GPU / 分散学習
- 14.18 学習・評価モードの切り替え
- 14.19 一般的なつまずきポイント
- 14.20 よくある PyTorch 小技の対応表
14.1 テンソル基本操作
PyTorch
import torch
x = torch.tensor([1.0, 2.0, 3.0])
y = torch.zeros((3, 3))
z = torch.randn(4, 4)
a = x + 1
b = z @ z.T
c = z.reshape(2, 8)
d = x.unsqueeze(0) # (3,) → (1, 3)
e = x.squeeze()
f = torch.cat([x, x])
g = torch.stack([x, x], dim=0)
JAX
import jax, jax.numpy as jnp
x = jnp.array([1.0, 2.0, 3.0])
y = jnp.zeros((3, 3))
z = jax.random.normal(jax.random.key(0), (4, 4))
a = x + 1
b = z @ z.T
c = z.reshape(2, 8)
d = x[None, :] # = jnp.expand_dims(x, 0)
e = jnp.squeeze(x)
f = jnp.concatenate([x, x])
g = jnp.stack([x, x], axis=0)
主な対応表
| PyTorch | JAX |
|---|---|
torch.tensor | jnp.array |
torch.zeros / ones / randn | jnp.zeros / ones、乱数は jax.random.* |
x.unsqueeze(dim) | x[None] または jnp.expand_dims(x, dim) |
x.squeeze() | jnp.squeeze(x) |
torch.cat | jnp.concatenate |
torch.stack | jnp.stack |
x.view(...) | x.reshape(...) |
x.permute(...) | jnp.transpose(x, axes=...) |
14.2 デバイス転送
PyTorch
import torch
device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(3).to(device)
y = x.cpu().numpy()
JAX
import jax, jax.numpy as jnp, numpy as np
# JAX 配列はデフォルトで利用可能なアクセラレータに置かれる
x = jax.random.normal(jax.random.key(0), (3,))
print(x.devices()) # 例: {CudaDevice(id=0)}
# 明示的にデバイス指定
cpu = jax.devices("cpu")[0]
x_on_cpu = jax.device_put(x, cpu)
# NumPy に変換
y = np.asarray(x)
14.3 自動微分
PyTorch
import torch
x = torch.tensor(2.0, requires_grad=True)
y = x ** 3 + 3 * x
y.backward()
print(x.grad) # 3x² + 3 = 15
JAX
import jax
f = lambda x: x ** 3 + 3 * x
print(jax.grad(f)(2.0)) # 15.0
損失 + 勾配を同時に
# PyTorch
loss = loss_fn(model(x), y)
loss.backward()
grad = [p.grad for p in model.parameters()]
# JAX
loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
14.4 線形層(MLP)
PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
class MLP(nn.Module):
def __init__(self, din, dhid, dout):
super().__init__()
self.l1 = nn.Linear(din, dhid)
self.l2 = nn.Linear(dhid, dout)
def forward(self, x):
return self.l2(F.relu(self.l1(x)))
model = MLP(10, 32, 1)
y = model(torch.randn(4, 10))
JAX + Flax (NNX)
import jax
from flax import nnx
class MLP(nnx.Module):
def __init__(self, din, dhid, dout, *, rngs: nnx.Rngs):
self.l1 = nnx.Linear(din, dhid, rngs=rngs)
self.l2 = nnx.Linear(dhid, dout, rngs=rngs)
def __call__(self, x):
return self.l2(nnx.relu(self.l1(x)))
model = MLP(10, 32, 1, rngs=nnx.Rngs(0))
y = model(jax.random.normal(jax.random.key(1), (4, 10)))
JAX 本体だけで書く(純粋関数版)
import jax, jax.numpy as jnp
def init_mlp(key, din, dhid, dout):
k1, k2 = jax.random.split(key)
return {
"W1": jax.random.normal(k1, (din, dhid)) * 0.1,
"b1": jnp.zeros(dhid),
"W2": jax.random.normal(k2, (dhid, dout)) * 0.1,
"b2": jnp.zeros(dout),
}
def mlp(params, x):
h = jnp.maximum(0, x @ params["W1"] + params["b1"]) # ReLU
return h @ params["W2"] + params["b2"]
14.5 学習ループ全体
PyTorch
import torch, torch.nn as nn
model = MLP(10, 32, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.MSELoss()
for x, y in dataloader:
pred = model(x)
loss = loss_fn(pred, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
JAX + Flax NNX + Optax
import jax, jax.numpy as jnp
import optax
from flax import nnx
model = MLP(10, 32, 1, rngs=nnx.Rngs(0))
optimizer = nnx.Optimizer(model, optax.adam(1e-3))
@nnx.jit
def train_step(model, optimizer, x, y):
def loss_fn(model):
return jnp.mean((model(x) - y) ** 2)
loss, grads = nnx.value_and_grad(loss_fn)(model)
optimizer.update(grads)
return loss
for x, y in dataloader:
loss = train_step(model, optimizer, x, y)
純粋関数スタイル(Flax を使わない場合)
import jax, jax.numpy as jnp, optax
def loss_fn(params, x, y):
return jnp.mean((mlp(params, x) - y) ** 2)
opt = optax.adam(1e-3)
params = init_mlp(jax.random.key(0), 10, 32, 1)
opt_state = opt.init(params)
@jax.jit
def step(params, opt_state, x, y):
loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
updates, opt_state = opt.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)
return params, opt_state, loss
for x, y in dataloader:
params, opt_state, loss = step(params, opt_state, x, y)
14.6 損失関数
| PyTorch | JAX / Optax |
|---|---|
nn.MSELoss() | jnp.mean((pred - y) ** 2) または optax.l2_loss(pred, y).mean() |
nn.L1Loss() | jnp.mean(jnp.abs(pred - y)) |
nn.CrossEntropyLoss() | optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean() |
nn.BCEWithLogitsLoss() | optax.sigmoid_binary_cross_entropy(logits, labels).mean() |
nn.NLLLoss() | -jnp.take_along_axis(log_probs, labels[..., None], axis=-1).mean() |
例:
import optax
loss = optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
14.7 オプティマイザ
| PyTorch | Optax |
|---|---|
optim.SGD(lr=0.1, momentum=0.9) | optax.sgd(0.1, momentum=0.9) |
optim.Adam(lr=1e-3) | optax.adam(1e-3) |
optim.AdamW(lr=1e-3, weight_decay=0.01) | optax.adamw(1e-3, weight_decay=0.01) |
optim.RMSprop(lr=1e-3) | optax.rmsprop(1e-3) |
勾配クリッピング + AdamW のチェイン
import optax
opt = optax.chain(
optax.clip_by_global_norm(1.0),
optax.adamw(1e-3, weight_decay=0.01),
)
PyTorch の「optimizer と clip_grad_norm_ を別々に呼ぶ」のと違って、Optax では 複数の変換をパイプとして繋ぐ スタイルです。
14.8 学習率スケジューラ
PyTorch
import torch.optim as optim
optimizer = optim.Adam(model.parameters(), lr=1e-3)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=1000)
for step in range(1000):
train_step(...)
scheduler.step()
Optax(学習率自体がスケジュール関数)
import optax
schedule = optax.warmup_cosine_decay_schedule(
init_value=0.0,
peak_value=1e-3,
warmup_steps=100,
decay_steps=1000,
)
opt = optax.adam(learning_rate=schedule)
PyTorch は「optimizer の lr を後から書き換える」感じですが、Optax では 「step 数 → 学習率」という関数 を learning_rate 引数に渡すだけ、と発想がスッキリしています。
14.9 重みの保存・読み込み
PyTorch
torch.save(model.state_dict(), "model.pt")
model.load_state_dict(torch.load("model.pt"))
JAX(Orbax を推奨)
import orbax.checkpoint as ocp
ckptr = ocp.StandardCheckpointer()
# 保存
ckptr.save("/tmp/ckpt", args=ocp.args.StandardSave(params))
# 復元
restored = ckptr.restore("/tmp/ckpt", args=ocp.args.StandardRestore(params))
簡易的には pickle でも済みますが、本番では Orbax(公式の Checkpoint ライブラリ)が標準です。
14.10 Dropout / BatchNorm の扱い
PyTorch では model.train() / model.eval() で内部状態を切り替えますが、JAX では 「乱数キーを渡す」「学習時かどうかをフラグで渡す」のように、状態を明示的に扱う のが流儀です。
PyTorch
class Net(nn.Module):
def __init__(self):
super().__init__()
self.fc = nn.Linear(10, 10)
self.drop = nn.Dropout(0.5)
def forward(self, x):
return self.drop(self.fc(x))
model.train() # ドロップアウト ON
y = model(x)
model.eval() # ドロップアウト OFF
y = model(x)
JAX + Flax NNX
from flax import nnx
class Net(nnx.Module):
def __init__(self, *, rngs: nnx.Rngs):
self.fc = nnx.Linear(10, 10, rngs=rngs)
self.drop = nnx.Dropout(0.5, rngs=rngs)
def __call__(self, x):
return self.drop(self.fc(x))
model = Net(rngs=nnx.Rngs(0))
model.train() # 学習モード
y = model(x)
model.eval() # 推論モード
y = model(x)
BatchNorm(移動平均統計)
- PyTorch:
running_mean/running_varがモジュールに自動で蓄積される - Flax NNX:
nnx.BatchNormを使うと同様に内部状態を持つ - 純関数型(Flax Linen など)の場合は 「学習用統計」を別の PyTree として返り値で受け取る スタイル
14.11 CNN(畳み込みネット)
PyTorch
import torch.nn as nn
import torch.nn.functional as F
class Conv(nn.Module):
def __init__(self):
super().__init__()
self.c1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
self.c2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
self.pool = nn.MaxPool2d(2)
self.fc = nn.Linear(64 * 8 * 8, 10)
def forward(self, x):
x = self.pool(F.relu(self.c1(x)))
x = self.pool(F.relu(self.c2(x)))
return self.fc(x.flatten(1))
JAX + Flax NNX
import jax.numpy as jnp
from flax import nnx
class Conv(nnx.Module):
def __init__(self, *, rngs: nnx.Rngs):
# Flax は NHWC(channels last)が標準
self.c1 = nnx.Conv(3, 32, kernel_size=(3, 3), padding="SAME", rngs=rngs)
self.c2 = nnx.Conv(32, 64, kernel_size=(3, 3), padding="SAME", rngs=rngs)
self.fc = nnx.Linear(64 * 8 * 8, 10, rngs=rngs)
def __call__(self, x):
x = nnx.max_pool(nnx.relu(self.c1(x)), window_shape=(2, 2), strides=(2, 2))
x = nnx.max_pool(nnx.relu(self.c2(x)), window_shape=(2, 2), strides=(2, 2))
return self.fc(x.reshape(x.shape[0], -1))
重要な違い:チャンネル次元の位置
| 画像入力の形 | |
|---|---|
| PyTorch | (N, C, H, W)(NCHW) |
| JAX / Flax | (N, H, W, C)(NHWC) |
PyTorch から移植するときは x = x.transpose(0, 2, 3, 1) で入れ替える のを忘れずに。
14.12 RNN / シーケンスの逐次処理
PyTorch の nn.LSTM のように「内部で時刻方向にループを回す」処理は、JAX では jax.lax.scan がそのまま当てはまります。
PyTorch(自分で 1 ステップずつ回す例)
def step(h, x):
return torch.tanh(W_h @ h + W_x @ x)
h = torch.zeros(hidden_dim)
for t in range(T):
h = step(h, xs[t])
JAX(scan でループを XLA に渡す)
def step(h, x_t):
new_h = jnp.tanh(W_h @ h + W_x @ x_t)
return new_h, new_h # (次状態, 出力)
h0 = jnp.zeros(hidden_dim)
h_final, hs = jax.lax.scan(step, h0, xs) # xs: (T, input_dim)
scan を使うとコンパイル時間も実行時間も劇的に縮みます。
14.13 アテンション(Transformer の核)
PyTorch
import torch.nn.functional as F
out = F.scaled_dot_product_attention(Q, K, V, attn_mask=mask)
JAX
import jax.nn as jnn
# JAX には scaled dot-product attention が用意されている
# Q/K/V はおおむね (batch, seq, heads, dim) 形式を想定する
# mask は bool で、True が「参加する」位置。加算マスクは bias= に渡す。
out = jnn.dot_product_attention(Q, K, V, mask=mask)
自前で書きたい場合は次のように:
import jax.numpy as jnp, jax.nn as jnn
def attention(Q, K, V, mask=None):
d = Q.shape[-1]
logits = (Q @ K.swapaxes(-2, -1)) / jnp.sqrt(d)
if mask is not None:
logits = jnp.where(mask, logits, -jnp.inf)
weights = jnn.softmax(logits, axis=-1)
return weights @ V
14.14 マスクとパディング
PyTorch
mask = (tokens != pad_id) # (B, T)
loss = F.cross_entropy(logits, labels, reduction="none")
loss = (loss * mask).sum() / mask.sum()
JAX
import jax.numpy as jnp, optax
mask = (tokens != pad_id).astype(jnp.float32)
per_token = optax.softmax_cross_entropy_with_integer_labels(logits, labels)
loss = (per_token * mask).sum() / mask.sum()
14.15 勾配クリッピング・重み減衰
PyTorch
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
JAX + Optax
opt = optax.chain(
optax.clip_by_global_norm(1.0), # 勾配クリッピング
optax.adamw(1e-3, weight_decay=0.01),
)
オプティマイザの中に「クリップしてから AdamW」というパイプを書けるのが Optax の良さです。
14.16 DataLoader
JAX には公式の DataLoader 相当はありません。「ミニバッチを返す Python ジェネレータ」を自分で書く か、PyTorch / Grain / tf.data を使います。
PyTorch の DataLoader を JAX 学習で流用する
import torch
from torch.utils.data import DataLoader
import numpy as np
import jax.numpy as jnp
loader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=4,
collate_fn=lambda batch: tuple(np.stack(x) for x in zip(*batch)))
for x_np, y_np in loader:
x, y = jnp.asarray(x_np), jnp.asarray(y_np)
params, opt_state, loss = step(params, opt_state, x, y)
JAX 自体は NumPy 配列を自然に受け取れます。PyTorch の CPU Tensor を使う場合は、DataLoader 側で NumPy 配列にしてから jnp.asarray するのが分かりやすいです。GPU Tensor からコピーを避けたい場合は DLPack 連携(jax.dlpack / torch.utils.dlpack)を検討します。
14.17 マルチ GPU / 分散学習
PyTorch DDP
import torch.distributed as dist
model = nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])
JAX:データ並列を sharding で表現
import jax, jax.numpy as jnp, numpy as np
from jax.sharding import Mesh, PartitionSpec as P, NamedSharding
devices = np.array(jax.devices())
mesh = Mesh(devices, ("data",))
# バッチ次元をデバイス分割。例: x_global は (global_batch, features)
batch_x = NamedSharding(mesh, P("data", None))
batch_y = NamedSharding(mesh, P("data"))
replicated = NamedSharding(mesh, P()) # パラメータなどを全デバイスに複製
x = jax.device_put(x_global, batch_x)
y = jax.device_put(y_global, batch_y)
params = jax.device_put(params, replicated)
@jax.jit
def step(params, x, y):
loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
# params / grads は PyTree なので、葉っぱごとに更新する
params = jax.tree_util.tree_map(lambda p, g: p - 1e-3 * g, params, grads)
return params, loss
jit + sharding を使えば、PyTorch DDP のような 明示的な通信コード を手で書かなくても、JAX/XLA が必要な集約や通信を含む実行計画を作ります。ただし、巨大モデルの FSDP/ZeRO 相当まで踏み込む場合は、パラメータ sharding、optimizer state sharding、shard_map などを設計する必要があります。
14.18 学習・評価モードの切り替え
PyTorch
model.train()
... # 学習
model.eval()
with torch.no_grad():
... # 評価
JAX
- Dropout / BatchNorm がない場合は、そもそも切り替える必要なし
- ある場合は、Flax NNX なら
model.train()/model.eval()、Linen ならdeterministic=True/Falseを関数引数で渡す - 「勾配を計算しない」は
with torch.no_grad():のようなコンテキストは不要。そもそもjax.gradを呼ばなければ勾配は計算されません
# 評価時は jit した forward だけを呼ぶ
@jax.jit
def predict(params, x):
return mlp(params, x)
14.19 一般的なつまずきポイント
- チャンネル次元の順序:PyTorch (NCHW) ↔ Flax (NHWC) を入れ替え忘れる
requires_gradがない:JAX では「微分したい関数」をjax.gradに渡すだけ。フラグ管理は不要optimizer.zero_grad()がない:JAX には「勾配を持つテンソル」が無いので、ゼロ化も不要model.parameters()がない:パラメータは PyTree(dict / dataclass) として明示的に持つ- 状態の暗黙更新がない:BatchNorm の
running_meanなどは「引数で渡して、返り値で受け取る」スタイル with torch.no_grad():がない:「gradを呼ばない限り勾配は計算されない」- データセットは自前のイテレータでも OK:PyTorch の DataLoader を流用するのも実用的
14.20 よくある PyTorch 小技の対応表
PyTorch の日常的な書き方を、JAX ではどう考えるかをまとめます。
| PyTorch | JAX | コメント |
|---|---|---|
x.detach() | jax.lax.stop_gradient(x) | 勾配をそこで止めたい場合だけ使う |
x.item() | float(x) / int(x) | jit 内では Python 値に戻せないので注意 |
x.clone() | 基本不要 | JAX 配列は immutable。どうしてもコピーなら jnp.array(x, copy=True) |
x.contiguous() | 基本不要 | JAX/XLA がレイアウトを管理する |
x.requires_grad_(False) | なし | grad の対象にしなければ勾配は取られない |
torch.no_grad() | なし | jax.grad を呼ばない限り勾配計算はしない |
torch.where(mask, a, b) | jnp.where(mask, a, b) | ほぼそのまま |
x.masked_fill(mask, v) | jnp.where(mask, v, x) | mask が True のところを置換 |
torch.gather(x, dim, index) | jnp.take_along_axis(x, index, axis=dim) | index の shape に注意 |
scatter_add_ | x = x.at[idx].add(v) | JAX は戻り値を受け取る |
torch.nn.functional.one_hot(y, C) | jax.nn.one_hot(y, C) | dtype は必要に応じて指定 |
torch.einsum(eq, ...) | jnp.einsum(eq, ...) | ほぼそのまま |
torch.vmap(f) | jax.vmap(f) | JAX では最重要 API のひとつ |
torch.compile(f) | jax.jit(f) | JAX では「最初から jit 前提」で設計することが多い |
torch.utils.checkpoint | jax.checkpoint / jax.remat | メモリ節約の再計算 |
| AMP / autocast | 明示的 dtype / Flax の dtype 設計 | bfloat16/float32 を層や配列ごとに設計する |
detach() の置き換え例
PyTorch では、テンソルを計算グラフから切り離すために detach() をよく使います。
# PyTorch
z = encoder(x).detach()
loss = decoder_loss(decoder(z), y)
JAX では、勾配を止めたい場所に jax.lax.stop_gradient を置きます。
# JAX
import jax
z = jax.lax.stop_gradient(encoder(params_enc, x))
loss = decoder_loss(params_dec, z, y)
gather / scatter の置き換え例
import jax.numpy as jnp
# JAX: gather
picked = jnp.take_along_axis(x, indices, axis=1)
# JAX: scatter add
out = jnp.zeros((10,)).at[indices].add(values)
このあたりは「PyTorch では tensor method / in-place 操作、JAX では jnp.* 関数や .at[...] 更新」と覚えると移行しやすいです。
まとめ:PyTorch 移行の指針
| 観点 | PyTorch | JAX |
|---|---|---|
| モデル | nn.Module クラス | Flax NNX の nnx.Module または純粋関数 + PyTree |
| パラメータ | self.weight = nn.Parameter | PyTree(dict / dataclass) |
| 微分 | loss.backward() | jax.value_and_grad(loss_fn)(params, ...) |
| 最適化 | optimizer.step() | optax.update + apply_updates |
| デバイス | .to("cuda") | デフォルトでアクセラレータに置かれる |
| 並列化 | DDP / FSDP | jit + jax.sharding |
| 状態 | モジュール内部に持つ | 引数・返り値で明示的にやり取り |
「状態を関数の外に出し、関数は純粋に保つ」 という発想に慣れれば、PyTorch の知識のほとんどはそのまま活かせます。