JAX 日本語チュートリアル
このドキュメントは、高性能な数値計算と機械学習のための Python ライブラリ JAX を、Python にあまり詳しくない方でも理解できるように、日本語で丁寧に解説するチュートリアルです。
JAX は「NumPy ライクな API で書ける」「自動微分ができる」「JIT コンパイルで高速化できる」「CPU/GPU/TPU で同じコードが動く」「ベクトル化・分散並列化を関数変換として扱える」という特徴を持つ、ちょっと風変わりで強力なライブラリです。Google の研究部門や DeepMind、その他多くの研究機関が、大規模な機械学習や科学計算の現場で JAX を使っています。
本チュートリアルは、JAX の 公式ドキュメント、JAX の GitHub リポジトリ、Google Cloud Blog の PyTorch 開発者向け JAX ガイド(日本語)、Zenn の Turing Motors による JAX/Flax 入門記事 を参考に、2026 年時点の最新情報に合わせて整理しました。
このチュートリアルの想定読者
- Python の基本文法(変数・関数・リスト・
for文)はなんとなく分かる - NumPy を少し触ったことがある、または「なんか聞いたことがある」レベル
- 機械学習・ディープラーニングに興味がある
- PyTorch や TensorFlow の経験は なくても OK
「Python は学校で少しやった」「コードはまだ書き慣れていない」という方でも、できるだけ専門用語に頼らず、具体例を交えて読み進められるように書いています。読み始める前に、まず 用語集 に軽く目を通すと、本編がぐっと分かりやすくなります。
全体構成
本チュートリアルは 基礎編(第1〜12章) と 移行ケース集(第13〜15章) の 2 部構成です。
- 基礎編:JAX という言語そのものに入門するチュートリアル。順番に読み進めると JAX で書ける状態になります。
- 移行ケース集:すでに NumPy や PyTorch で書かれたコードを JAX に書き換えるためのレシピ集。「左に元のコード、右に JAX 版」の対比形式で、目的別にすぐ引けます。基礎編をひととおり読んだ後に辞書のように使うのもおすすめです。
目次
📘 基礎編:JAX の入門
- 用語集(最初に読むと便利)
- JAX とは何か
- インストールと環境構築
jax.numpyの基本- JIT コンパイル(
jit)で高速化する - 自動微分(
grad) - ベクトル化(
vmap)と並列化(pmap/shard_map) - PyTree という考え方
- 乱数の扱い方(PRNG キー)
- PyTorch との比較
- 実践チュートリアル:線形回帰と簡単なニューラルネット
- よくある落とし穴とデバッグ
- JAX エコシステム(Flax / Optax / Equinox など)
🔁 移行ケース集:既存コードを JAX に書き換える
- NumPy → JAX:配列・乱数・ループ・線形代数・FFT・SciPy など、NumPy の典型パターンを JAX で書く 15 ケース
- PyTorch → JAX:テンソル操作・モデル定義・学習ループ・損失/最適化・CNN/RNN/Attention・分散学習など、PyTorch の典型パターンを JAX(+ Flax / Optax)で書く 20 ケース
- 応用と複合パターン:アンサンブル学習・メタ学習・ベイズ推論・強化学習・微分可能シミュレーション・K-Means・ニュートン法など、「JAX らしい書き方」が活きる 12 ケース
🧬 進化計算(Evolutionary Computation)
- 進化計算入門と基本アルゴリズム:ランダム探索 / (1+1)-ES / GA / DE / PSO / OpenAI-ES / CMA-ES / MAP-Elites / NSGA-II
- ニューロエボリューションと進化的モデルマージ:MLP/方策ネットの ES 学習、CartPole の RL、evosax の使い方、Sakana AI の Evolutionary Model Merge、勾配 × 進化のハイブリッド
読み進め方
- はじめての方:用語集 → 第1章 → 第2章 …と 順番に 読むのがおすすめです。第10章で実際に手を動かし、第11章で「ハマりやすいポイント」を確認するところまで到達すると、JAX の世界がぐっと身近になります。その後、必要に応じて第13〜15章のケース集を辞書代わりに参照してください。
- PyTorch ユーザの方:先に第9章「PyTorch との比較」をざっと眺め、第3〜8章で JAX 独自の流儀を抑えてから、第14章「PyTorch → JAX 移行ケース集」 に進むのが最短ルートです。
- NumPy / SciPy ユーザの方:第1章 → 第3章 → 第4章 → 第5章 →(必要なら第8章)→ 第13章「NumPy → JAX 移行ケース集」 の順がスムーズです。
- すでに JAX を触ったことがある方:気になる章だけ拾い読みしても大丈夫です。各章は独立して読めるように書いてあります。
目的別クイックリンク(移行ケース集の早見表)
| やりたいこと | 行き先 |
|---|---|
np.xxx の JAX 版を知りたい | 13.1〜13.14 と 13.15 小技対応表 |
x[i] = v を JAX でどう書くか | 13.2 in-place 更新 |
for ループを高速化したい | 13.5 累積計算 → scan |
np.random.seed から JAX 乱数へ | 13.6 乱数生成、第8章 |
数値微分を jax.grad に置き換えたい | 13.10 自動微分 |
| PyTorch の学習ループを移植したい | 14.5 学習ループ全体 |
nn.Module を JAX でどう書くか | 14.4 線形層(MLP)、14.11 CNN |
loss.backward() の置き換え | 14.3 自動微分 |
| Adam / AdamW / Scheduler の対応 | 14.7・14.8 |
| Dropout / BatchNorm の扱い | 14.10 |
| RNN / Transformer の移植 | 14.12 RNN、14.13 Attention |
| DDP / FSDP 相当の分散学習 | 14.17 マルチ GPU / 分散学習 |
detach / gather / scatter などの対応 | 14.20 小技対応表 |
アンサンブルを vmap で並列化 | 15.2 アンサンブル学習 |
| ハイパラ探索を並列で回す | 15.3 ハイパラ並列探索 |
| MAML 風メタ学習 | 15.4 メタ学習 |
| 強化学習の並列ロールアウト | 15.6 強化学習 |
| 微分可能シミュレーション | 15.7 物理シミュレーション |
| 進化計算を JAX で書きたい | 第16章 進化計算入門 |
| GA / ES / CMA-ES の実装例 | 16.5 GA / 16.8 OpenAI-ES / 16.9 CMA-ES |
| ニューラルネットを進化させたい | 第17章 ニューロエボリューション |
| 強化学習を進化計算で(CartPole) | 17.4 RL タスクへの応用 |
| 進化的モデルマージ(Sakana AI) | 17.6 進化的モデルマージ |
参考にした主な資料
- JAX 公式ドキュメント — インストール、Quickstart、Random numbers、sharding など最新の一次情報
- JAX GitHub リポジトリ — README と issue から最新の状況を確認
- Google Cloud Blog: PyTorch デベロッパー向け JAX 基礎ガイド(日本語) — PyTorch ユーザ視点での違いと類似点
- Zenn: Turing Motors「今こそはじめる JAX/Flax 入門 Part 1」 — 国内事例を含む日本語の入門解説
- evosax / EvoJAX / QDax — 進化計算系の JAX ライブラリ
- Sakana AI: Evolutionary Optimization of Model Merging Recipes — 進化的モデルマージの論文と実装リポジトリ
💡 各章の末尾には「次の章へのリンク」が付いています。気軽にコードをコピーして、手元で試しながら読み進めてみてください。動かして初めて分かることがたくさんあります。