Skip to main content

第11章 よくある落とし穴とデバッグ

JAX は強力ですが、独特な部分でつまずきやすいライブラリでもあります。本章では「初学者がよくハマるポイント」と、その対処法をまとめて紹介します。困ったときは、まずこの章を見返してみてください。

11.1 「配列を書き換えたい!」エラー

x = jnp.array([1, 2, 3])
x[0] = 99 # TypeError: ... object does not support item assignment

at[].set() を使いましょう(第3章を参照)。

x = x.at[0].set(99)

11.2 「if が使えない!」エラー

@jax.jit
def f(x):
if x > 0: # TracerBoolConversionError(または類似の tracer 関連エラー)
return x
else:
return -x

jnp.wherejax.lax.cond を使います。

@jax.jit
def f(x):
return jnp.where(x > 0, x, -x)

11.3 jit 内の print が 1 回しか効かない

これは「バグ」ではなく 仕様 です。jit の内側はトレース時に 1 回だけ Python コードが走るので、print も 1 回だけしか呼ばれません。

jax.debug.print を使えば、実行時のテンソル値を確認できます。

@jax.jit
def f(x):
jax.debug.print("x = {}", x)
return x * 2

11.4 毎回コンパイルが走って遅い

jit 化したのに遅い…」というときは、呼び出すたびに引数の shape が変わっていないか を疑いましょう。

@jax.jit
def f(x):
return x.sum()

for n in [10, 11, 12]: # 毎回違う shape → 毎回コンパイル!
f(jnp.zeros((n,)))

→ shape を揃える(パディングする)か、static_argnums を活用しましょう。

11.5 「乱数が毎回同じ!」問題

key = jax.random.key(0)
for _ in range(3):
print(jax.random.normal(key)) # 3 回とも同じ値!

split を忘れずに。

key = jax.random.key(0)
for _ in range(3):
key, subkey = jax.random.split(key)
print(jax.random.normal(subkey))

11.6 「非同期実行」で時間計測が変

JAX は計算指示だけ出してすぐ戻ってくる 非同期実行 が基本です。time.time() で素朴に測ると、「計算が終わる前の時刻」になってしまうことがあります。

.block_until_ready() で結果を待ちましょう。

import time
t0 = time.time()
y = f(x).block_until_ready()
print(time.time() - t0)

11.7 GPU メモリが解放されない

JAX は GPU 使用時、最初の JAX 操作が実行されたタイミングで、デフォルトでは GPU メモリ全体の 75% を事前確保 します。複数のプロセスを動かしたいときや、メモリ使用量を細かく制御したいときは、

export XLA_PYTHON_CLIENT_PREALLOCATE=false
# あるいは
export XLA_PYTHON_CLIENT_MEM_FRACTION=0.50

などの環境変数で制御できます。

11.8 naninf の追跡

学習中に loss が nan になったら、jax.config.update("jax_debug_nans", True) を有効にすると、nan が出た瞬間にエラーで止まってくれます。

import jax
jax.config.update("jax_debug_nans", True)

11.9 jit の中で Python の副作用を期待しない

counter = 0

@jax.jit
def f(x):
global counter
counter += 1 # ← 1 回しかカウントされない!
return x * 2

→ 副作用は jit で扱いましょう。状態は引数として渡して、新しい値を返すスタイルにします。JAX の世界では「状態を関数の引数と返り値でやり取りする」のが基本です。

11.10 デバッグの心得

  1. まず jit を外して、通常の Python として動かしてみる
  2. 配列の値を jax.debug.print で覗く
  3. shape のミスは print(x.shape) で確認
  4. NaN 追跡は jax_debug_nans を ON にする
  5. それでも分からなければ、from jax import config; config.update("jax_disable_jit", True) でグローバルに JIT を OFF にして検証する

「困ったらまず jit を外す」——これを覚えておくと、デバッグ中の心の余裕が変わります。

➡️ 第12章 JAX エコシステム(Flax / Optax / Equinox など)