Skip to main content

第9章 PyTorch との比較

すでに PyTorch を使っている方や、これから「どっちを学ぶべきか?」と迷っている方のために、両者の違いをまとめます。

9.1 哲学の違い

PyTorchJAX
設計思想オブジェクト指向関数型
モデルnn.Module クラスに状態と処理を持たせる純粋関数 + パラメータ(PyTree)
微分テンソルに .grad 属性、loss.backward()jax.grad(f) で勾配関数を作る
計算グラフ動的(毎回構築)動的だが jit で静的化できる
乱数グローバル状態を使う場面が多い明示的な PRNG キー(jax.random.key()
並列化DDP・FSDP などpmap / shard_map / sharding

ひとことで言うと、PyTorch は「オブジェクトに状態を持たせて自然に書く」、JAX は「データと関数を分けて、関数を変換していく」という違いです。

9.2 同じ処理の比較

線形回帰の 1 ステップ

PyTorch

import torch
import torch.nn as nn

model = nn.Linear(3, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
loss_fn = nn.MSELoss()

pred = model(x)
loss = loss_fn(pred, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()

JAX(+Optax)

import jax
import jax.numpy as jnp
import optax

def model(params, x):
return x @ params["W"] + params["b"]

def loss_fn(params, x, y):
pred = model(params, x)
return jnp.mean((pred - y) ** 2)

opt = optax.sgd(0.01)
opt_state = opt.init(params)

@jax.jit
def step(params, opt_state, x, y):
loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
updates, opt_state = opt.update(grads, opt_state)
params = optax.apply_updates(params, updates)
return params, opt_state, loss

PyTorch は「短く書ける」、JAX は「明示的だが純粋」という違いがよく出ています。JAX のコードは少し長く見えますが、「何が入力で、何が出力なのか」がはっきりしているのも特徴です。

9.3 「テンソル」と「配列」

PyTorch では torch.Tensor、JAX では jax.Array(または jnp.ndarray)と呼ばれます。中身は近いですが、

  • PyTorch: .grad 属性を持ち、.backward() で勾配が溜まる、device を持つ
  • JAX: .grad を持たない純粋な配列データ、変更不可、配置は device / sharding として扱う

という違いがあります。

9.4 「モデル」の作り方

PyTorchJAX
層の定義nn.Module を継承Flax / Equinox / Haiku などのライブラリで dataclass / Module を定義
パラメータの保持self.weight = nn.Parameter(...)PyTree(dict / dataclass)として外部で保持
順伝播def forward(self, x): ...def __call__(self, x): ... または純粋関数

JAX では「モデルの定義(構造)」と「パラメータの値」を分けて扱うのが特徴です。これにより、

  • 同じモデルに異なるパラメータを渡せる
  • パラメータをそのまま vmap でバッチ化できる(=アンサンブル学習がとても簡単)
  • 関数型なので JIT / grad / vmap と相性が抜群

というメリットがあります。

9.5 移行のコツ

PyTorch から JAX へ移行(あるいは併用)を考える場合は、次の順で慣れていくとスムーズです。

  • まず NumPy 部分jnp に置き換える
  • 状態を持たない 純粋関数 にリファクタする
  • パラメータを PyTree にまとめる
  • 学習ループは value_and_grad + Optax で書く
  • 速くしたいところに jit を被せる
  • バッチ化したいところに vmap を被せる

9.6 JAX が向くケース / PyTorch が向くケース

  • JAX が強い:TPU を活用した学習、巨大なモデルの並列化、研究のプロトタイピング、関数型な書き方を好む人
  • PyTorch が強い:エコシステムの広さ、本番デプロイ(TorchScript / ExecuTorch)、初学者の入りやすさ

両者は「思想が違う兄弟」のような関係で、どちらが優れているという話ではありません。目的やチームの好みに応じて選ぶ、あるいは両方を使いこなせると強い、というのが現実的なところです。

➡️ 第10章 実践チュートリアル:線形回帰と簡単なニューラルネット