第12章 JAX エコシステム(Flax / Optax / Equinox など)
JAX 本体は意図的にスコープを絞っており、「効率的な配列演算 + 関数変換」を中心に提供しています。実際のディープラーニング開発では、その上に乗っている エコシステム を組み合わせるのが一般的です。本章では代表的なものを紹介します。
12.1 ニューラルネット定義ライブラリ
Flax(NNX)
- 開発元:Google / オープンソースコミュニティ
- 立ち位置:JAX の代表的なニューラルネットワークライブラリ
- 特徴:
- 「NNX」という新 API は オブジェクト指向風 に書けて、PyTorch ユーザにも親しみやすい
- 旧 API(Linen)は純関数型
- Google 内外で広く使われている
from flax import nnx
class MLP(nnx.Module):
def __init__(self, din, dhid, dout, rngs):
self.l1 = nnx.Linear(din, dhid, rngs=rngs)
self.l2 = nnx.Linear(dhid, dout, rngs=rngs)
def __call__(self, x):
return self.l2(nnx.relu(self.l1(x)))
Equinox
- 開発元:Patrick Kidger(個人発の有名 OSS)
- 立ち位置:「すべてが PyTree」というシンプルさ が魅力
- 特徴:
- 普通の Python クラスがそのまま PyTree
- 「JAX のフィロソフィにいちばん忠実」と評する人も多い
- 拡散モデルや科学計算系の研究者に人気
Haiku
- 開発元:DeepMind
- 立ち位置:DeepMind の研究コードで広く使われてきた、関数型ベースのニューラルネットワークライブラリ
- 特徴:
hk.transformによって「状態を持つように見えるモジュール定義」を、JAX が扱いやすい純粋関数 + パラメータに変換する
12.2 最適化アルゴリズム:Optax
- 開発元:DeepMind
- 立ち位置:JAX の標準オプティマイザライブラリ
- 提供:
- SGD / Momentum / Adam / AdamW / RMSProp など定番アルゴリズム
- 学習率スケジューラ(cosine、warmup、exponential など)
- 勾配クリッピング、重み減衰などの「変換」をパイプラインで組み合わせ可能
import optax
opt = optax.chain(
optax.clip_by_global_norm(1.0),
optax.adamw(learning_rate=1e-3, weight_decay=0.01),
)
opt_state = opt.init(params)
updates, opt_state = opt.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)
「複数の処理をパイプラインのように繋いでオプティマイザを組み立てる」という発想は、関数型に親しい JAX らしい設計です。
12.3 データセット・パイプライン
- Grain:Google が公開しているデータローディングライブラリ。JAX/Flax の大規模学習で使われることがある
- tf.data:TensorFlow のデータパイプラインを JAX 学習で使うことも多い
- Hugging Face Datasets:テキスト系を中心に、標準的に使われる
12.4 LLM・基盤モデル系
- MaxText(Google):TPU で LLM を学習するためのリファレンス実装
- EasyLM:LLaMA 系などをトレーニング・推論する JAX ベースのフレームワーク
- Levanter(Stanford CRFM):スケーラブルな LLM 学習向けライブラリ
- Penzai(Google DeepMind):ニューラルネットの可視化・操作
12.5 強化学習・物理シミュレーション
- RLax(DeepMind):強化学習用の数学パーツ集
- Brax(Google):物理シミュレータ(剛体力学)
- Gymnax:Gym 風の環境を JAX で実装、
vmapで大量並列ロールアウト - Mctx:MuZero / AlphaZero などのモンテカルロ木探索
12.6 科学計算・微分方程式
- Diffrax:微分方程式ソルバ(Equinox の作者と同じ)
- JAX MD:分子動力学
- NetKet:量子多体問題
12.7 まずはどれを覚える?
迷ったら、
- JAX 本体(
jnp/jit/grad/vmap) - Optax(最適化)
- Flax NNX(モデル定義、入りやすさ重視)または Equinox(シンプルさ重視)
の組み合わせを覚えれば、ディープラーニング開発の大半はカバーできます。
12.8 さらに学ぶには
12.9 技術的に正確に読み続けるための注意
JAX は進化が速いライブラリです。とくに インストール方法、CUDA 対応、並列化 API(pmap / shard_map / sharding)、乱数キー(PRNGKey / key()) は変化しやすいので、実際に使う前には公式ドキュメントの最新版を確認してください。本チュートリアルも、2026 年時点の情報をもとに書いていますが、半年〜1 年経つと細かい部分が変わっている可能性があります。
これで基礎編は終わりです。お疲れさまでした 🎉 続けて、既存コードからの 移行ケース集(第13〜15章) や、進化計算編(第16〜17章) にも進めます。
JAX は最初の学習コストが少しかかりますが、いったん慣れてしまえば「微分・並列化・JIT を当たり前に組み合わせる、新しい数値計算の世界」が広がります。「関数を変換して、変換して、また変換する」という、ちょっと不思議で気持ちのよい感覚を、ぜひ手元で体験してみてください。
➡️ 既存コードからの移行に進む: