第4章 JIT コンパイル(jit)で高速化する
jit(Just-In-Time コンパイル)は、JAX を一気に速くしてくれる、いわば「魔法の関数」です。本章では、その仕組み、使い方、そしてつまずきやすい注意点を順に見ていきます。
4.1 そもそも JIT とは
通常の Python コードは「1 行ずつ読みながら実行する」スタイル(インタプリタ実行)です。これは柔軟ですが、その分どうしても遅くなりがちです。JAX は XLA(Accelerated Linear Algebra) という Google 製のコンパイラに関数を渡すことで、
- 関数の計算の流れを 一度だけ追跡(トレース) して計算グラフを作り、
- それを 最適化 し、
- 機械語にコンパイルし、
- 以降は コンパイル済みのコードで超高速に実行
ということを実現しています。これが jit の正体です。
4.2 使い方は超シンプル
import jax
import jax.numpy as jnp
def slow_f(x):
for _ in range(10):
x = jnp.sin(x) + jnp.cos(x)
return x
fast_f = jax.jit(slow_f) # ← これだけ!
x = jnp.arange(1_000_000, dtype=jnp.float32)
%timeit slow_f(x).block_until_ready() # 例: 12 ms
%timeit fast_f(x).block_until_ready() # 例: 1 ms
jax.jit(関数) を呼ぶだけで、高速化された新しい関数が返ってきます。
💡
block_until_ready()を付ける理由 JAX は「非同期実行」が基本で、計算指示だけ出してすぐに Python に制御が戻ってきます。本当の実行時間を測りたいときは、「結果が出るまで待ってね」と明示する必要があるためです。
4.3 デコレータとしても書ける
@jax.jit を関数の上に書いても同じ意味になります。実際にはこちらの書き方の方が一般的です。
@jax.jit
def f(x, y):
return jnp.dot(x, y) + jnp.sin(x).sum()
4.4 「トレース」とは何か
jit は、最初の呼び出し時に関数の中身を 「実際の値ではなく、抽象的な型情報だけを使って」追いかけ、計算グラフを作ります。この「下見」のことを トレーシング と呼びます。
@jax.jit
def f(x):
print("トレース中! x =", x) # ← トレースが起きたときだけ呼ばれる
return x * 2
print(f(jnp.array(1.0))) # トレース中! x = Traced<...>
print(f(jnp.array(2.0))) # ← 同じ shape/dtype なら、2 回目以降は print されない
print(f(jnp.array(3.0))) # ← print されない
ここから分かるのは、「Python の print などの 副作用 は、トレース時にだけ実行される」ということです。逆に言えば、毎回実行したい処理を jit の中に入れてはいけない、ということでもあります。
4.5 引数の形が変わると再コンパイル
トレースは「入力配列の shape と dtype」を見て行われます。これらが変わるとキャッシュにヒットせず、再コンパイル が走ります。
@jax.jit
def f(x):
return x * 2
f(jnp.zeros((10,))) # コンパイル 1 回目
f(jnp.zeros((10,))) # キャッシュにヒット → 速い
f(jnp.zeros((20,))) # shape が違うので再コンパイル
なので、配列の形が頻繁に変わるコード(パディングしていない可変長データなど)は jit と相性が悪い ことを覚えておきましょう。
4.6 jit の中では Python の if / for が(直接は)使えない
jit の中身は「計算グラフ」として静的に表現されるため、配列の 値に依存した if や for は使えません。
@jax.jit
def f(x):
if x > 0: # ← エラー! x は具体的な値ではない
return x
else:
return -x
この場合は jnp.where や jax.lax.cond を使います。
@jax.jit
def f(x):
return jnp.where(x > 0, x, -x) # abs(x) と同じ意味
ループ系は jax.lax.fori_loop や jax.lax.scan を使います。
@jax.jit
def cumsum_like(x):
def body(i, carry):
return carry + x[i]
return jax.lax.fori_loop(0, x.shape[0], body, 0.0)
💡 ただし「配列の shape」に依存した普通の Python
if/forは問題なく使えます。shape はコンパイル時にすでに決まっているからです。
4.7 static_argnums で「Python 値」として扱う
「この引数はコンパイル時の定数として扱いたい」というときは static_argnums を使います。
from functools import partial
@partial(jax.jit, static_argnums=(1,))
def repeat(x, n):
return jnp.tile(x, n)
ただし、static 指定した引数の値が変わるたびに再コンパイルが走るので、頻繁に値が変わるものに使うと逆に遅くなります。
4.8 デバッグの心得
jit の中では print(x) でテンソルの中身は見えません。デバッグには jax.debug.print を使いましょう。
@jax.jit
def f(x):
jax.debug.print("x = {}", x)
return x * 2
これなら、実行時の値をちゃんと表示してくれます。
4.9 まとめ
jax.jitは、関数を XLA でコンパイルして高速化する- 最初の 1 回はコンパイルに時間がかかるが、2 回目以降は爆速
- shape / dtype が変わると再コンパイル
- 値に依存する
if/forはlax.cond/lax.scan/whereで書き換える - 副作用(
print、グローバル変数の書き換えなど)はトレース時にだけ実行される。値を毎回見たいならjax.debug.printを使う