Skip to main content

第14章 移行ケース集:PyTorch → JAX

PyTorch で書かれた既存コードを JAX(+ Flax / Optax)に移行するための 実践レシピ集 です。PyTorch の各機能ごとに、JAX での書き方を左右比較で示します。

📚 前提:PyTorch で nn.Module を継承してモデルを書いた経験があれば、本章はそのまま読めます。JAX 側の流儀(イミュータブル、関数変換、PyTree、PRNG キー)にまだ慣れていない場合は、先に 第9章 PyTorch との比較第7章 PyTree第8章 乱数 を確認すると、各ケースの「なぜそう書くのか」が腑に落ちます。

本章では Flax は新 API の NNX、最適化は Optax、チェックポイントは Orbax を前提に書いています。インストールはそれぞれ pip install flax optax orbax-checkpoint です。

💡 大方針

  • テンソルjax.Array / jnp
  • nn.Module → Flax NNX の nnx.Module(または自前の純粋関数 + PyTree)
  • optimizer.step() → Optax の update + apply_updates
  • .backward()jax.grad / jax.value_and_grad
  • DataLoader → ふつうの Python イテレータ、または Grain / tf.data

目次


14.1 テンソル基本操作

PyTorch

import torch

x = torch.tensor([1.0, 2.0, 3.0])
y = torch.zeros((3, 3))
z = torch.randn(4, 4)

a = x + 1
b = z @ z.T
c = z.reshape(2, 8)
d = x.unsqueeze(0) # (3,) → (1, 3)
e = x.squeeze()
f = torch.cat([x, x])
g = torch.stack([x, x], dim=0)

JAX

import jax, jax.numpy as jnp

x = jnp.array([1.0, 2.0, 3.0])
y = jnp.zeros((3, 3))
z = jax.random.normal(jax.random.key(0), (4, 4))

a = x + 1
b = z @ z.T
c = z.reshape(2, 8)
d = x[None, :] # = jnp.expand_dims(x, 0)
e = jnp.squeeze(x)
f = jnp.concatenate([x, x])
g = jnp.stack([x, x], axis=0)

主な対応表

PyTorchJAX
torch.tensorjnp.array
torch.zeros / ones / randnjnp.zeros / ones、乱数は jax.random.*
x.unsqueeze(dim)x[None] または jnp.expand_dims(x, dim)
x.squeeze()jnp.squeeze(x)
torch.catjnp.concatenate
torch.stackjnp.stack
x.view(...)x.reshape(...)
x.permute(...)jnp.transpose(x, axes=...)

14.2 デバイス転送

PyTorch

import torch
device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(3).to(device)
y = x.cpu().numpy()

JAX

import jax, jax.numpy as jnp, numpy as np

# JAX 配列はデフォルトで利用可能なアクセラレータに置かれる
x = jax.random.normal(jax.random.key(0), (3,))
print(x.devices()) # 例: {CudaDevice(id=0)}

# 明示的にデバイス指定
cpu = jax.devices("cpu")[0]
x_on_cpu = jax.device_put(x, cpu)

# NumPy に変換
y = np.asarray(x)

14.3 自動微分

PyTorch

import torch
x = torch.tensor(2.0, requires_grad=True)
y = x ** 3 + 3 * x
y.backward()
print(x.grad) # 3x² + 3 = 15

JAX

import jax
f = lambda x: x ** 3 + 3 * x
print(jax.grad(f)(2.0)) # 15.0

損失 + 勾配を同時に

# PyTorch
loss = loss_fn(model(x), y)
loss.backward()
grad = [p.grad for p in model.parameters()]

# JAX
loss, grads = jax.value_and_grad(loss_fn)(params, x, y)

14.4 線形層(MLP)

PyTorch

import torch
import torch.nn as nn
import torch.nn.functional as F

class MLP(nn.Module):
def __init__(self, din, dhid, dout):
super().__init__()
self.l1 = nn.Linear(din, dhid)
self.l2 = nn.Linear(dhid, dout)
def forward(self, x):
return self.l2(F.relu(self.l1(x)))

model = MLP(10, 32, 1)
y = model(torch.randn(4, 10))

JAX + Flax (NNX)

import jax
from flax import nnx

class MLP(nnx.Module):
def __init__(self, din, dhid, dout, *, rngs: nnx.Rngs):
self.l1 = nnx.Linear(din, dhid, rngs=rngs)
self.l2 = nnx.Linear(dhid, dout, rngs=rngs)
def __call__(self, x):
return self.l2(nnx.relu(self.l1(x)))

model = MLP(10, 32, 1, rngs=nnx.Rngs(0))
y = model(jax.random.normal(jax.random.key(1), (4, 10)))

JAX 本体だけで書く(純粋関数版)

import jax, jax.numpy as jnp

def init_mlp(key, din, dhid, dout):
k1, k2 = jax.random.split(key)
return {
"W1": jax.random.normal(k1, (din, dhid)) * 0.1,
"b1": jnp.zeros(dhid),
"W2": jax.random.normal(k2, (dhid, dout)) * 0.1,
"b2": jnp.zeros(dout),
}

def mlp(params, x):
h = jnp.maximum(0, x @ params["W1"] + params["b1"]) # ReLU
return h @ params["W2"] + params["b2"]

14.5 学習ループ全体

PyTorch

import torch, torch.nn as nn

model = MLP(10, 32, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.MSELoss()

for x, y in dataloader:
pred = model(x)
loss = loss_fn(pred, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()

JAX + Flax NNX + Optax

import jax, jax.numpy as jnp
import optax
from flax import nnx

model = MLP(10, 32, 1, rngs=nnx.Rngs(0))
optimizer = nnx.Optimizer(model, optax.adam(1e-3))

@nnx.jit
def train_step(model, optimizer, x, y):
def loss_fn(model):
return jnp.mean((model(x) - y) ** 2)
loss, grads = nnx.value_and_grad(loss_fn)(model)
optimizer.update(grads)
return loss

for x, y in dataloader:
loss = train_step(model, optimizer, x, y)

純粋関数スタイル(Flax を使わない場合)

import jax, jax.numpy as jnp, optax

def loss_fn(params, x, y):
return jnp.mean((mlp(params, x) - y) ** 2)

opt = optax.adam(1e-3)
params = init_mlp(jax.random.key(0), 10, 32, 1)
opt_state = opt.init(params)

@jax.jit
def step(params, opt_state, x, y):
loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
updates, opt_state = opt.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)
return params, opt_state, loss

for x, y in dataloader:
params, opt_state, loss = step(params, opt_state, x, y)

14.6 損失関数

PyTorchJAX / Optax
nn.MSELoss()jnp.mean((pred - y) ** 2) または optax.l2_loss(pred, y).mean()
nn.L1Loss()jnp.mean(jnp.abs(pred - y))
nn.CrossEntropyLoss()optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
nn.BCEWithLogitsLoss()optax.sigmoid_binary_cross_entropy(logits, labels).mean()
nn.NLLLoss()-jnp.take_along_axis(log_probs, labels[..., None], axis=-1).mean()

例:

import optax
loss = optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()

14.7 オプティマイザ

PyTorchOptax
optim.SGD(lr=0.1, momentum=0.9)optax.sgd(0.1, momentum=0.9)
optim.Adam(lr=1e-3)optax.adam(1e-3)
optim.AdamW(lr=1e-3, weight_decay=0.01)optax.adamw(1e-3, weight_decay=0.01)
optim.RMSprop(lr=1e-3)optax.rmsprop(1e-3)

勾配クリッピング + AdamW のチェイン

import optax
opt = optax.chain(
optax.clip_by_global_norm(1.0),
optax.adamw(1e-3, weight_decay=0.01),
)

PyTorch の「optimizerclip_grad_norm_ を別々に呼ぶ」のと違って、Optax では 複数の変換をパイプとして繋ぐ スタイルです。


14.8 学習率スケジューラ

PyTorch

import torch.optim as optim
optimizer = optim.Adam(model.parameters(), lr=1e-3)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=1000)

for step in range(1000):
train_step(...)
scheduler.step()

Optax(学習率自体がスケジュール関数)

import optax
schedule = optax.warmup_cosine_decay_schedule(
init_value=0.0,
peak_value=1e-3,
warmup_steps=100,
decay_steps=1000,
)
opt = optax.adam(learning_rate=schedule)

PyTorch は「optimizerlr を後から書き換える」感じですが、Optax では 「step 数 → 学習率」という関数learning_rate 引数に渡すだけ、と発想がスッキリしています。


14.9 重みの保存・読み込み

PyTorch

torch.save(model.state_dict(), "model.pt")
model.load_state_dict(torch.load("model.pt"))

JAX(Orbax を推奨)

import orbax.checkpoint as ocp
ckptr = ocp.StandardCheckpointer()

# 保存
ckptr.save("/tmp/ckpt", args=ocp.args.StandardSave(params))

# 復元
restored = ckptr.restore("/tmp/ckpt", args=ocp.args.StandardRestore(params))

簡易的には pickle でも済みますが、本番では Orbax(公式の Checkpoint ライブラリ)が標準です。


14.10 Dropout / BatchNorm の扱い

PyTorch では model.train() / model.eval() で内部状態を切り替えますが、JAX では 「乱数キーを渡す」「学習時かどうかをフラグで渡す」のように、状態を明示的に扱う のが流儀です。

PyTorch

class Net(nn.Module):
def __init__(self):
super().__init__()
self.fc = nn.Linear(10, 10)
self.drop = nn.Dropout(0.5)
def forward(self, x):
return self.drop(self.fc(x))

model.train() # ドロップアウト ON
y = model(x)
model.eval() # ドロップアウト OFF
y = model(x)

JAX + Flax NNX

from flax import nnx
class Net(nnx.Module):
def __init__(self, *, rngs: nnx.Rngs):
self.fc = nnx.Linear(10, 10, rngs=rngs)
self.drop = nnx.Dropout(0.5, rngs=rngs)
def __call__(self, x):
return self.drop(self.fc(x))

model = Net(rngs=nnx.Rngs(0))
model.train() # 学習モード
y = model(x)
model.eval() # 推論モード
y = model(x)

BatchNorm(移動平均統計)

  • PyTorch:running_mean / running_var がモジュールに自動で蓄積される
  • Flax NNX:nnx.BatchNorm を使うと同様に内部状態を持つ
  • 純関数型(Flax Linen など)の場合は 「学習用統計」を別の PyTree として返り値で受け取る スタイル

14.11 CNN(畳み込みネット)

PyTorch

import torch.nn as nn
import torch.nn.functional as F

class Conv(nn.Module):
def __init__(self):
super().__init__()
self.c1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
self.c2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
self.pool = nn.MaxPool2d(2)
self.fc = nn.Linear(64 * 8 * 8, 10)
def forward(self, x):
x = self.pool(F.relu(self.c1(x)))
x = self.pool(F.relu(self.c2(x)))
return self.fc(x.flatten(1))

JAX + Flax NNX

import jax.numpy as jnp
from flax import nnx

class Conv(nnx.Module):
def __init__(self, *, rngs: nnx.Rngs):
# Flax は NHWC(channels last)が標準
self.c1 = nnx.Conv(3, 32, kernel_size=(3, 3), padding="SAME", rngs=rngs)
self.c2 = nnx.Conv(32, 64, kernel_size=(3, 3), padding="SAME", rngs=rngs)
self.fc = nnx.Linear(64 * 8 * 8, 10, rngs=rngs)
def __call__(self, x):
x = nnx.max_pool(nnx.relu(self.c1(x)), window_shape=(2, 2), strides=(2, 2))
x = nnx.max_pool(nnx.relu(self.c2(x)), window_shape=(2, 2), strides=(2, 2))
return self.fc(x.reshape(x.shape[0], -1))

重要な違い:チャンネル次元の位置

画像入力の形
PyTorch(N, C, H, W)(NCHW)
JAX / Flax(N, H, W, C)(NHWC)

PyTorch から移植するときは x = x.transpose(0, 2, 3, 1) で入れ替える のを忘れずに。


14.12 RNN / シーケンスの逐次処理

PyTorch の nn.LSTM のように「内部で時刻方向にループを回す」処理は、JAX では jax.lax.scan がそのまま当てはまります。

PyTorch(自分で 1 ステップずつ回す例)

def step(h, x):
return torch.tanh(W_h @ h + W_x @ x)

h = torch.zeros(hidden_dim)
for t in range(T):
h = step(h, xs[t])

JAX(scan でループを XLA に渡す)

def step(h, x_t):
new_h = jnp.tanh(W_h @ h + W_x @ x_t)
return new_h, new_h # (次状態, 出力)

h0 = jnp.zeros(hidden_dim)
h_final, hs = jax.lax.scan(step, h0, xs) # xs: (T, input_dim)

scan を使うとコンパイル時間も実行時間も劇的に縮みます。


14.13 アテンション(Transformer の核)

PyTorch

import torch.nn.functional as F
out = F.scaled_dot_product_attention(Q, K, V, attn_mask=mask)

JAX

import jax.nn as jnn
# JAX には scaled dot-product attention が用意されている
# Q/K/V はおおむね (batch, seq, heads, dim) 形式を想定する
# mask は bool で、True が「参加する」位置。加算マスクは bias= に渡す。
out = jnn.dot_product_attention(Q, K, V, mask=mask)

自前で書きたい場合は次のように:

import jax.numpy as jnp, jax.nn as jnn

def attention(Q, K, V, mask=None):
d = Q.shape[-1]
logits = (Q @ K.swapaxes(-2, -1)) / jnp.sqrt(d)
if mask is not None:
logits = jnp.where(mask, logits, -jnp.inf)
weights = jnn.softmax(logits, axis=-1)
return weights @ V

14.14 マスクとパディング

PyTorch

mask = (tokens != pad_id) # (B, T)
loss = F.cross_entropy(logits, labels, reduction="none")
loss = (loss * mask).sum() / mask.sum()

JAX

import jax.numpy as jnp, optax

mask = (tokens != pad_id).astype(jnp.float32)
per_token = optax.softmax_cross_entropy_with_integer_labels(logits, labels)
loss = (per_token * mask).sum() / mask.sum()

14.15 勾配クリッピング・重み減衰

PyTorch

loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()

JAX + Optax

opt = optax.chain(
optax.clip_by_global_norm(1.0), # 勾配クリッピング
optax.adamw(1e-3, weight_decay=0.01),
)

オプティマイザの中に「クリップしてから AdamW」というパイプを書けるのが Optax の良さです。


14.16 DataLoader

JAX には公式の DataLoader 相当はありません。「ミニバッチを返す Python ジェネレータ」を自分で書く か、PyTorch / Grain / tf.data を使います。

PyTorch の DataLoader を JAX 学習で流用する

import torch
from torch.utils.data import DataLoader
import numpy as np
import jax.numpy as jnp

loader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=4,
collate_fn=lambda batch: tuple(np.stack(x) for x in zip(*batch)))

for x_np, y_np in loader:
x, y = jnp.asarray(x_np), jnp.asarray(y_np)
params, opt_state, loss = step(params, opt_state, x, y)

JAX 自体は NumPy 配列を自然に受け取れます。PyTorch の CPU Tensor を使う場合は、DataLoader 側で NumPy 配列にしてから jnp.asarray するのが分かりやすいです。GPU Tensor からコピーを避けたい場合は DLPack 連携(jax.dlpack / torch.utils.dlpack)を検討します。


14.17 マルチ GPU / 分散学習

PyTorch DDP

import torch.distributed as dist
model = nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])

JAX:データ並列を sharding で表現

import jax, jax.numpy as jnp, numpy as np
from jax.sharding import Mesh, PartitionSpec as P, NamedSharding

devices = np.array(jax.devices())
mesh = Mesh(devices, ("data",))

# バッチ次元をデバイス分割。例: x_global は (global_batch, features)
batch_x = NamedSharding(mesh, P("data", None))
batch_y = NamedSharding(mesh, P("data"))
replicated = NamedSharding(mesh, P()) # パラメータなどを全デバイスに複製

x = jax.device_put(x_global, batch_x)
y = jax.device_put(y_global, batch_y)
params = jax.device_put(params, replicated)

@jax.jit
def step(params, x, y):
loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
# params / grads は PyTree なので、葉っぱごとに更新する
params = jax.tree_util.tree_map(lambda p, g: p - 1e-3 * g, params, grads)
return params, loss

jit + sharding を使えば、PyTorch DDP のような 明示的な通信コード を手で書かなくても、JAX/XLA が必要な集約や通信を含む実行計画を作ります。ただし、巨大モデルの FSDP/ZeRO 相当まで踏み込む場合は、パラメータ sharding、optimizer state sharding、shard_map などを設計する必要があります。


14.18 学習・評価モードの切り替え

PyTorch

model.train()
... # 学習
model.eval()
with torch.no_grad():
... # 評価

JAX

  • Dropout / BatchNorm がない場合は、そもそも切り替える必要なし
  • ある場合は、Flax NNX なら model.train() / model.eval()、Linen なら deterministic=True/False を関数引数で渡す
  • 「勾配を計算しない」は with torch.no_grad(): のようなコンテキストは不要。そもそも jax.grad を呼ばなければ勾配は計算されません
# 評価時は jit した forward だけを呼ぶ
@jax.jit
def predict(params, x):
return mlp(params, x)

14.19 一般的なつまずきポイント

  1. チャンネル次元の順序:PyTorch (NCHW) ↔ Flax (NHWC) を入れ替え忘れる
  2. requires_grad がない:JAX では「微分したい関数」を jax.grad に渡すだけ。フラグ管理は不要
  3. optimizer.zero_grad() がない:JAX には「勾配を持つテンソル」が無いので、ゼロ化も不要
  4. model.parameters() がない:パラメータは PyTree(dict / dataclass) として明示的に持つ
  5. 状態の暗黙更新がない:BatchNorm の running_mean などは「引数で渡して、返り値で受け取る」スタイル
  6. with torch.no_grad(): がない:「grad を呼ばない限り勾配は計算されない」
  7. データセットは自前のイテレータでも OK:PyTorch の DataLoader を流用するのも実用的

14.20 よくある PyTorch 小技の対応表

PyTorch の日常的な書き方を、JAX ではどう考えるかをまとめます。

PyTorchJAXコメント
x.detach()jax.lax.stop_gradient(x)勾配をそこで止めたい場合だけ使う
x.item()float(x) / int(x)jit 内では Python 値に戻せないので注意
x.clone()基本不要JAX 配列は immutable。どうしてもコピーなら jnp.array(x, copy=True)
x.contiguous()基本不要JAX/XLA がレイアウトを管理する
x.requires_grad_(False)なしgrad の対象にしなければ勾配は取られない
torch.no_grad()なしjax.grad を呼ばない限り勾配計算はしない
torch.where(mask, a, b)jnp.where(mask, a, b)ほぼそのまま
x.masked_fill(mask, v)jnp.where(mask, v, x)mask が True のところを置換
torch.gather(x, dim, index)jnp.take_along_axis(x, index, axis=dim)index の shape に注意
scatter_add_x = x.at[idx].add(v)JAX は戻り値を受け取る
torch.nn.functional.one_hot(y, C)jax.nn.one_hot(y, C)dtype は必要に応じて指定
torch.einsum(eq, ...)jnp.einsum(eq, ...)ほぼそのまま
torch.vmap(f)jax.vmap(f)JAX では最重要 API のひとつ
torch.compile(f)jax.jit(f)JAX では「最初から jit 前提」で設計することが多い
torch.utils.checkpointjax.checkpoint / jax.rematメモリ節約の再計算
AMP / autocast明示的 dtype / Flax の dtype 設計bfloat16/float32 を層や配列ごとに設計する

detach() の置き換え例

PyTorch では、テンソルを計算グラフから切り離すために detach() をよく使います。

# PyTorch
z = encoder(x).detach()
loss = decoder_loss(decoder(z), y)

JAX では、勾配を止めたい場所に jax.lax.stop_gradient を置きます。

# JAX
import jax

z = jax.lax.stop_gradient(encoder(params_enc, x))
loss = decoder_loss(params_dec, z, y)

gather / scatter の置き換え例

import jax.numpy as jnp

# JAX: gather
picked = jnp.take_along_axis(x, indices, axis=1)

# JAX: scatter add
out = jnp.zeros((10,)).at[indices].add(values)

このあたりは「PyTorch では tensor method / in-place 操作、JAX では jnp.* 関数や .at[...] 更新」と覚えると移行しやすいです。


まとめ:PyTorch 移行の指針

観点PyTorchJAX
モデルnn.Module クラスFlax NNX の nnx.Module または純粋関数 + PyTree
パラメータself.weight = nn.ParameterPyTree(dict / dataclass)
微分loss.backward()jax.value_and_grad(loss_fn)(params, ...)
最適化optimizer.step()optax.update + apply_updates
デバイス.to("cuda")デフォルトでアクセラレータに置かれる
並列化DDP / FSDPjit + jax.sharding
状態モジュール内部に持つ引数・返り値で明示的にやり取り

「状態を関数の外に出し、関数は純粋に保つ」 という発想に慣れれば、PyTorch の知識のほとんどはそのまま活かせます。

➡️ 第15章 移行ケース集:応用と複合パターン