第9章 PyTorch との比較
すでに PyTorch を使っている方や、これから「どっちを学ぶべきか?」と迷っている方のために、両者の違いをまとめます。
9.1 哲学の違い
| PyTorch | JAX | |
|---|---|---|
| 設計思想 | オブジェクト指向 | 関数型 |
| モデル | nn.Module クラスに状態と処理を持たせる | 純粋関数 + パラメータ(PyTree) |
| 微分 | テンソルに .grad 属性、loss.backward() | jax.grad(f) で勾配関数を作る |
| 計算グラフ | 動的(毎回構築) | 動的だが jit で静的化できる |
| 乱数 | グローバル状態を使う場面が多い | 明示的な PRNG キー(jax.random.key()) |
| 並列化 | DDP・FSDP など | pmap / shard_map / sharding |
ひとことで言うと、PyTorch は「オブジェクトに状態を持たせて自然に書く」、JAX は「データと関数を分けて、関数を変換していく」という違いです。
9.2 同じ処理の比較
線形回帰の 1 ステップ
PyTorch
import torch
import torch.nn as nn
model = nn.Linear(3, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
loss_fn = nn.MSELoss()
pred = model(x)
loss = loss_fn(pred, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
JAX(+Optax)
import jax
import jax.numpy as jnp
import optax
def model(params, x):
return x @ params["W"] + params["b"]
def loss_fn(params, x, y):
pred = model(params, x)
return jnp.mean((pred - y) ** 2)
opt = optax.sgd(0.01)
opt_state = opt.init(params)
@jax.jit
def step(params, opt_state, x, y):
loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
updates, opt_state = opt.update(grads, opt_state)
params = optax.apply_updates(params, updates)
return params, opt_state, loss
PyTorch は「短く書ける」、JAX は「明示的だが純粋」という違いがよく出ています。JAX のコードは少し長く見えますが、「何が入力で、何が出力なのか」がはっきりしているのも特徴です。
9.3 「テンソル」と「配列」
PyTorch では torch.Tensor、JAX では jax.Array(または jnp.ndarray)と呼ばれます。中身は近いですが、
- PyTorch:
.grad属性を持ち、.backward()で勾配が溜まる、deviceを持つ - JAX:
.gradを持たない純粋な配列データ、変更不可、配置はdevice/shardingとして扱う
という違いがあります。
9.4 「モデル」の作り方
| PyTorch | JAX | |
|---|---|---|
| 層の定義 | nn.Module を継承 | Flax / Equinox / Haiku などのライブラリで dataclass / Module を定義 |
| パラメータの保持 | self.weight = nn.Parameter(...) | PyTree(dict / dataclass)として外部で保持 |
| 順伝播 | def forward(self, x): ... | def __call__(self, x): ... または純粋関数 |
JAX では「モデルの定義(構造)」と「パラメータの値」を分けて扱うのが特徴です。これにより、
- 同じモデルに異なるパラメータを渡せる
- パラメータをそのまま
vmapでバッチ化できる(=アンサンブル学習がとても簡単) - 関数型なので JIT / grad / vmap と相性が抜群
というメリットがあります。
9.5 移行のコツ
PyTorch から JAX へ移行(あるいは併用)を考える場合は、次の順で慣れていくとスムーズです。
- まず NumPy 部分 を
jnpに置き換える - 状態を持たない 純粋関数 にリファクタする
- パラメータを PyTree にまとめる
- 学習ループは
value_and_grad+ Optax で書く - 速くしたいところに
jitを被せる - バッチ化したいところに
vmapを被せる
9.6 JAX が向くケース / PyTorch が向くケース
- JAX が強い:TPU を活用した学習、巨大なモデルの並列化、研究のプロトタイピング、関数型な書き方を好む人
- PyTorch が強い:エコシステムの広さ、本番デプロイ(TorchScript / ExecuTorch)、初学者の入りやすさ
両者は「思想が違う兄弟」のような関係で、どちらが優れているという話ではありません。目的やチームの好みに応じて選ぶ、あるいは両方を使いこなせると強い、というのが現実的なところです。