第11章 よくある落とし穴とデバッグ
JAX は強力ですが、独特な部分でつまずきやすいライブラリでもあります。本章では「初学者がよくハマるポイント」と、その対処法をまとめて紹介します。困ったときは、まずこの章を見返してみてください。
11.1 「配列を書き換えたい!」エラー
x = jnp.array([1, 2, 3])
x[0] = 99 # TypeError: ... object does not support item assignment
→ at[].set() を使いましょう(第3章を参照)。
x = x.at[0].set(99)
11.2 「if が使えない!」エラー
@jax.jit
def f(x):
if x > 0: # TracerBoolConversionError(または類似の tracer 関連エラー)
return x
else:
return -x
→ jnp.where か jax.lax.cond を使います。
@jax.jit
def f(x):
return jnp.where(x > 0, x, -x)
11.3 jit 内の print が 1 回しか効かない
これは「バグ」ではなく 仕様 です。jit の内側はトレース時に 1 回だけ Python コードが走るので、print も 1 回だけしか呼ばれません。
→ jax.debug.print を使えば、実行時のテンソル値を確認できます。
@jax.jit
def f(x):
jax.debug.print("x = {}", x)
return x * 2
11.4 毎回コンパイルが走って遅い
「jit 化したのに遅い…」というときは、呼び出すたびに引数の shape が変わっていないか を疑いましょう。
@jax.jit
def f(x):
return x.sum()
for n in [10, 11, 12]: # 毎回違う shape → 毎回コンパイル!
f(jnp.zeros((n,)))
→ shape を揃える(パディングする)か、static_argnums を活用しましょう。
11.5 「乱数が毎回同じ!」問題
key = jax.random.key(0)
for _ in range(3):
print(jax.random.normal(key)) # 3 回とも同じ値!
→ split を忘れずに。
key = jax.random.key(0)
for _ in range(3):
key, subkey = jax.random.split(key)
print(jax.random.normal(subkey))
11.6 「非同期実行」で時間計測が変
JAX は計算指示だけ出してすぐ戻ってくる 非同期実行 が基本です。time.time() で素朴に測ると、「計算が終わる前の時刻」になってしまうことがあります。
→ .block_until_ready() で結果を待ちましょう。
import time
t0 = time.time()
y = f(x).block_until_ready()
print(time.time() - t0)
11.7 GPU メモリが解放されない
JAX は GPU 使用時、最初の JAX 操作が実行されたタイミングで、デフォルトでは GPU メモリ全体の 75% を事前確保 します。複数のプロセスを動かしたいときや、メモリ使用量を細かく制御したいときは、
export XLA_PYTHON_CLIENT_PREALLOCATE=false
# あるいは
export XLA_PYTHON_CLIENT_MEM_FRACTION=0.50
などの環境変数で制御できます。
11.8 nan や inf の追跡
学習中に loss が nan になったら、jax.config.update("jax_debug_nans", True) を有効にすると、nan が出た瞬間にエラーで止まってくれます。
import jax
jax.config.update("jax_debug_nans", True)
11.9 jit の中で Python の副作用を期待しない
counter = 0
@jax.jit
def f(x):
global counter
counter += 1 # ← 1 回しかカウントされない!
return x * 2
→ 副作用は jit の 外 で扱いましょう。状態は引数として渡して、新しい値を返すスタイルにします。JAX の世界では「状態を関数の引数と返り値でやり取りする」のが基本です。
11.10 デバッグの心得
- まず
jitを外して、通常の Python として動かしてみる - 配列の値を
jax.debug.printで覗く - shape のミスは
print(x.shape)で確認 - NaN 追跡は
jax_debug_nansを ON にする - それでも分からなければ、
from jax import config; config.update("jax_disable_jit", True)でグローバルに JIT を OFF にして検証する
「困ったらまず jit を外す」——これを覚えておくと、デバッグ中の心の余裕が変わります。