Skip to main content

第4章 JIT コンパイル(jit)で高速化する

jit(Just-In-Time コンパイル)は、JAX を一気に速くしてくれる、いわば「魔法の関数」です。本章では、その仕組み、使い方、そしてつまずきやすい注意点を順に見ていきます。

4.1 そもそも JIT とは

通常の Python コードは「1 行ずつ読みながら実行する」スタイル(インタプリタ実行)です。これは柔軟ですが、その分どうしても遅くなりがちです。JAX は XLA(Accelerated Linear Algebra) という Google 製のコンパイラに関数を渡すことで、

  1. 関数の計算の流れを 一度だけ追跡(トレース) して計算グラフを作り、
  2. それを 最適化 し、
  3. 機械語にコンパイルし、
  4. 以降は コンパイル済みのコードで超高速に実行

ということを実現しています。これが jit の正体です。

4.2 使い方は超シンプル

import jax
import jax.numpy as jnp

def slow_f(x):
for _ in range(10):
x = jnp.sin(x) + jnp.cos(x)
return x

fast_f = jax.jit(slow_f) # ← これだけ!

x = jnp.arange(1_000_000, dtype=jnp.float32)
%timeit slow_f(x).block_until_ready() # 例: 12 ms
%timeit fast_f(x).block_until_ready() # 例: 1 ms

jax.jit(関数) を呼ぶだけで、高速化された新しい関数が返ってきます。

💡 block_until_ready() を付ける理由 JAX は「非同期実行」が基本で、計算指示だけ出してすぐに Python に制御が戻ってきます。本当の実行時間を測りたいときは、「結果が出るまで待ってね」と明示する必要があるためです。

4.3 デコレータとしても書ける

@jax.jit を関数の上に書いても同じ意味になります。実際にはこちらの書き方の方が一般的です。

@jax.jit
def f(x, y):
return jnp.dot(x, y) + jnp.sin(x).sum()

4.4 「トレース」とは何か

jit は、最初の呼び出し時に関数の中身を 「実際の値ではなく、抽象的な型情報だけを使って」追いかけ、計算グラフを作ります。この「下見」のことを トレーシング と呼びます。

@jax.jit
def f(x):
print("トレース中! x =", x) # ← トレースが起きたときだけ呼ばれる
return x * 2

print(f(jnp.array(1.0))) # トレース中! x = Traced<...>
print(f(jnp.array(2.0))) # ← 同じ shape/dtype なら、2 回目以降は print されない
print(f(jnp.array(3.0))) # ← print されない

ここから分かるのは、「Python の print などの 副作用 は、トレース時にだけ実行される」ということです。逆に言えば、毎回実行したい処理を jit の中に入れてはいけない、ということでもあります。

4.5 引数の形が変わると再コンパイル

トレースは「入力配列の shape と dtype」を見て行われます。これらが変わるとキャッシュにヒットせず、再コンパイル が走ります。

@jax.jit
def f(x):
return x * 2

f(jnp.zeros((10,))) # コンパイル 1 回目
f(jnp.zeros((10,))) # キャッシュにヒット → 速い
f(jnp.zeros((20,))) # shape が違うので再コンパイル

なので、配列の形が頻繁に変わるコード(パディングしていない可変長データなど)は jit と相性が悪い ことを覚えておきましょう。

4.6 jit の中では Python の if / for が(直接は)使えない

jit の中身は「計算グラフ」として静的に表現されるため、配列の 値に依存した iffor は使えません。

@jax.jit
def f(x):
if x > 0: # ← エラー! x は具体的な値ではない
return x
else:
return -x

この場合は jnp.wherejax.lax.cond を使います。

@jax.jit
def f(x):
return jnp.where(x > 0, x, -x) # abs(x) と同じ意味

ループ系は jax.lax.fori_loopjax.lax.scan を使います。

@jax.jit
def cumsum_like(x):
def body(i, carry):
return carry + x[i]
return jax.lax.fori_loop(0, x.shape[0], body, 0.0)

💡 ただし「配列の shape」に依存した普通の Python if / for は問題なく使えます。shape はコンパイル時にすでに決まっているからです。

4.7 static_argnums で「Python 値」として扱う

「この引数はコンパイル時の定数として扱いたい」というときは static_argnums を使います。

from functools import partial

@partial(jax.jit, static_argnums=(1,))
def repeat(x, n):
return jnp.tile(x, n)

ただし、static 指定した引数の値が変わるたびに再コンパイルが走るので、頻繁に値が変わるものに使うと逆に遅くなります。

4.8 デバッグの心得

jit の中では print(x) でテンソルの中身は見えません。デバッグには jax.debug.print を使いましょう。

@jax.jit
def f(x):
jax.debug.print("x = {}", x)
return x * 2

これなら、実行時の値をちゃんと表示してくれます。

4.9 まとめ

  • jax.jit は、関数を XLA でコンパイルして高速化する
  • 最初の 1 回はコンパイルに時間がかかるが、2 回目以降は爆速
  • shape / dtype が変わると再コンパイル
  • 値に依存する if / forlax.cond / lax.scan / where で書き換える
  • 副作用(print、グローバル変数の書き換えなど)はトレース時にだけ実行される。値を毎回見たいなら jax.debug.print を使う

➡️ 第5章 自動微分(grad