第1章 JAX とは何か
1.1 ひとことで言うと
JAX(ジャックス)は、オープンソースの 数値計算・機械学習向け Python ライブラリ です。書き心地は NumPy にとてもよく似ていますが、その裏側で次のような強力な機能を提供してくれます。
- 自動微分(Autograd):関数の微分(勾配)を、人間が手で計算しなくても、JAX がきれいに求めてくれる
- JIT コンパイル:Python で書いた計算を XLA という Google 製のコンパイラで最適化し、CPU/GPU/TPU 上で高速に実行する
- CPU / GPU / TPU 対応:同じ JAX コードを、開発時は CPU、本番は GPU や TPU、というように複数のバックエンドで動かせる
- ベクトル化・並列化:
forループを書かずに、データ方向の繰り返しや複数デバイスへの分散実行ができる
ざっくり一言でまとめると、
「NumPy のように書けて、自動で微分でき、コンパイルや並列化も“関数を変換するだけ”で組み合わせられる」
そんなライブラリです。
💬 NumPy って何? Python で行列やベクトルを扱うときの定番ライブラリです。「数字をたくさん並べたデータ(配列)を、まとめて計算できる便利な道具」と思っておけば OK です。
1.2 なぜ今 JAX なのか
最近、JAX は以下のような場面でとくによく使われています。
- 大規模な機械学習・研究開発:JIT、シャーディング(後述)、TPU/GPU 実行を組み合わせて、大きなモデルや実験を効率よく回しやすい
- 強化学習:DeepMind 系の RLax / Optax / Haiku など、JAX ベースの研究ライブラリが豊富
- 科学計算・シミュレーション:物理シミュレーション、微分方程式、確率的プログラミングなど、「微分できる数値計算」が必要な分野
PyTorch が広いエコシステムを持っている一方で、「TPU や大量の GPU を活かして思い切り計算したい」「研究コードを純粋な関数として整理したい」「jit・grad・vmap・シャーディングをスマートに組み合わせて書きたい」というニーズに対して、JAX は有力な選択肢になっています。
1.3 JAX の設計思想:関数型プログラミング
PyTorch では「モデル(オブジェクト)の中に重みや状態が入っていて、loss.backward() を呼ぶと内部にこっそり勾配が溜まっていく」というような、いわゆる オブジェクト指向的 な書き方をします。
それに対して JAX は、関数型(functional)プログラミング の考え方を強く取り入れています。
- 関数は 副作用を持たない(同じ入力なら必ず同じ出力)
- データはなるべく 書き換えない(イミュータブル=変更不可)
- 「モデル」も、単なる パラメータの辞書 + 純粋関数 として表現する
最初は少しとっつきにくく感じるかもしれませんが、慣れてしまうと、コードがとてもクリーンになり、自動微分・並列化・JIT との相性も抜群によくなります。「最初の戸惑い → 慣れたら手放せない」というのが JAX のよく言われる特徴です。
1.4 JAX の中核:4 つの関数変換
JAX のすごさは、たった 4 つの関数変換 に集約されていると言っても過言ではありません。
| 関数 | 役割 | ざっくり言うと |
|---|---|---|
jit | JIT コンパイル | 関数を「速くする」 |
grad | 自動微分 | 関数を「微分する」 |
vmap | 自動ベクトル化 | 関数を「バッチ対応にする」 |
pmap | 並列実行 | 関数を「複数デバイスに分散する」(現在は shard_map や jit + sharding も重要) |
これらはすべて、「関数を受け取って、新しい関数を返す」 という、シンプルかつ強力な仕組みです(数学やプログラミングの用語で「高階関数」と呼びます)。
import jax
import jax.numpy as jnp
def f(x):
return jnp.sin(x) ** 2
# 微分版の関数を作る
df = jax.grad(f)
# 高速化版の関数を作る
fast_f = jax.jit(f)
# バッチ対応版の関数を作る
batched_f = jax.vmap(f)
# 組み合わせも自由自在
fast_batched_df = jax.jit(jax.vmap(jax.grad(f)))
この「組み合わせ可能(composable)」という性質こそ、JAX のいちばんの魅力です。「微分してから、バッチ化して、高速化する」を 3 行で書ける わけですから、研究のプロトタイピングが本当に楽になります。
1.5 NumPy / PyTorch / TensorFlow との位置づけ
| ライブラリ | 主な用途 | 特徴 |
|---|---|---|
| NumPy | 数値計算全般 | CPU 中心、自動微分なし、Python の数値計算の業界標準 |
| PyTorch | 深層学習 | 動的グラフ、書きやすい、オブジェクト指向 |
| TensorFlow | 深層学習 | tf.function などのグラフ実行、本番デプロイ周辺の仕組みが豊富 |
| JAX | 数値計算 + 深層学習 | 関数変換、JIT、シャーディング、研究・大規模計算向け |
JAX 単体には「ニューラルネットの層(Linear や Conv2d のようなもの)」や「最適化器(Adam など)」は含まれていません。それらは Flax / Equinox / Optax などの周辺ライブラリが提供しています(第12章で詳しく紹介します)。
1.6 次の章へ
それでは、まずは JAX を実際にインストールして、自分のパソコンで動かしてみましょう。次の章ではインストール方法と動作確認を行います。