第13章 移行ケース集:NumPy → JAX
NumPy で書かれた既存のコードを JAX に置き換えるための 実践レシピ集 です。「NumPy ならこう書く / JAX ならこう書く」をひたすら左右に並べて見比べられるようにしてあります。コピペで動くサンプルが中心です。
📚 前提:NumPy の基本(
np.array、np.zeros、ブロードキャストなど)を一度でも書いたことがあれば読めます。JAX 側のインストールや基本概念に不安があれば、先に 第1章〜第3章、第8章 乱数 を眺めるとスムーズです。
💡 多くの場合
import numpy as npをimport jax.numpy as jnpに書き換えるだけで動きます。問題になるのは、書き換え(in-place 操作)・乱数・ループ・データ型 あたりです。本章ではそこを重点的に扱います。
目次
- 13.1 配列の生成・基本演算
- 13.2 配列の書き換え(in-place)
- 13.3 ファンシーインデックスとマスク
- 13.4 条件分岐
if/where - 13.5 累積計算(for ループ)→
scan - 13.6 乱数生成
- 13.7 線形代数
- 13.8 統計・集約
- 13.9 FFT・信号処理
- 13.10 数値微分 → 自動微分
- 13.11 ベクトル化(Python ループ →
vmap) - 13.12 ソート・unique・set 演算
- 13.13 dtype と数値精度
- 13.14 SciPy 的な処理
- 13.15 よくある小技・周辺操作
13.1 配列の生成・基本演算
NumPy
import numpy as np
a = np.array([1.0, 2.0, 3.0])
b = np.zeros((3, 3))
c = np.eye(4)
d = np.arange(10)
e = np.linspace(0, 1, 100)
y = np.sin(a) + np.cos(a) ** 2
z = a @ a.T
JAX
import jax.numpy as jnp
a = jnp.array([1.0, 2.0, 3.0])
b = jnp.zeros((3, 3))
c = jnp.eye(4)
d = jnp.arange(10)
e = jnp.linspace(0, 1, 100)
y = jnp.sin(a) + jnp.cos(a) ** 2
z = a @ a.T
ポイント
- ほぼ
np→jnpの置き換えだけで OK - ただし デフォルトの浮動小数点数は
float32(NumPy はfloat64)
13.2 配列の書き換え(in-place)
NumPy で当たり前に書いていた「要素代入」は JAX では使えません。at[...].set/add/... に書き換えます。
NumPy
import numpy as np
x = np.zeros(5)
x[0] = 1.0
x[2:4] += 5.0
x[x < 0] = 0.0 # 負の値を 0 に
JAX
import jax.numpy as jnp
x = jnp.zeros(5)
x = x.at[0].set(1.0)
x = x.at[2:4].add(5.0)
x = jnp.where(x < 0, 0.0, x) # 条件付き代入は where が便利
主な対応表
| NumPy | JAX |
|---|---|
x[i] = v | x = x.at[i].set(v) |
x[i] += v | x = x.at[i].add(v) |
x[i] *= v | x = x.at[i].mul(v) |
x[i] = min/max(x[i], v) | x = x.at[i].min(v) / .max(v) |
x[mask] = v | x = jnp.where(mask, v, x) |
💡
at[...].set(...)は「新しい配列を返す」操作です。元の配列は変更されません。戻り値を必ず受け取る ことを忘れずに。
13.3 ファンシーインデックスとマスク
NumPy
import numpy as np
x = np.arange(10)
print(x[[0, 2, 5]]) # [0 2 5]
print(x[x > 3]) # [4 5 6 7 8 9]
# 散布代入
y = np.zeros(10)
np.add.at(y, [0, 0, 1, 1], 1.0) # 重複インデックスも安全に加算
print(y) # [2. 2. 0. ...]
JAX
import jax.numpy as jnp
x = jnp.arange(10)
print(x[jnp.array([0, 2, 5])]) # [0 2 5]
print(x[x > 3]) # ※ jit 内では使えない
# 散布代入(重複インデックス OK)
y = jnp.zeros(10).at[jnp.array([0, 0, 1, 1])].add(1.0)
print(y) # [2. 2. 0. ...]
ポイント
- ブーリアンマスクで
x[x > 3]のように 「出力サイズが実行時に決まる」インデックス はjit内では使えません。 - 代わりに「出力サイズ固定」の
jnp.where、もしくはパディングを使うのが定石。
# jit 互換:マスクされた要素を 0 にする(サイズは固定)
masked = jnp.where(x > 3, x, 0)
13.4 条件分岐 if / where
NumPy
import numpy as np
def relu(x):
if x > 0:
return x
else:
return 0.0
JAX
import jax
import jax.numpy as jnp
# 値に依存する if は jit 内で使えない → where か lax.cond
@jax.jit
def relu(x):
return jnp.where(x > 0, x, 0.0)
# 「枝ごとに重い処理がある」なら lax.cond の方が効率的
@jax.jit
def f(x):
return jax.lax.cond(
x > 0,
lambda x: jnp.sin(x),
lambda x: jnp.cos(x),
x,
)
使い分けの目安
| ケース | おすすめ |
|---|---|
| 単純な要素ごとの条件選択 | jnp.where |
| 枝ごとに重い計算がある | jax.lax.cond |
多分岐(switch) | jax.lax.switch |
13.5 累積計算(for ループ)→ scan
NumPy で for を回して累積を計算するパターンは、JAX では jax.lax.scan が定石です(速度・コンパイル時間ともに有利)。
NumPy(累積和)
import numpy as np
x = np.arange(10).astype(np.float32)
s = np.zeros_like(x)
acc = 0.0
for i in range(len(x)):
acc += x[i]
s[i] = acc
JAX(scan で書く)
import jax
import jax.numpy as jnp
x = jnp.arange(10, dtype=jnp.float32)
def step(carry, xi):
new_carry = carry + xi
return new_carry, new_carry # (次の状態, 出力)
_, s = jax.lax.scan(step, 0.0, x)
print(s) # [0 1 3 6 10 15 21 28 36 45]
回数だけ決まっているループ:fori_loop
@jax.jit
def sum_to(n):
def body(i, acc):
return acc + i
return jax.lax.fori_loop(0, n, body, 0)
「途中で止める」ループ:while_loop
@jax.jit
def newton(x0, tol=1e-6):
def cond(state):
x, prev = state
return jnp.abs(x - prev) > tol
def body(state):
x, _ = state
# f(x) = x^2 - 2 のニュートン法
return (x - (x**2 - 2) / (2 * x), x)
x, _ = jax.lax.while_loop(cond, body, (x0, x0 + 1.0))
return x
print(newton(1.0)) # ≒ √2
13.6 乱数生成
NumPy(グローバル状態)
import numpy as np
np.random.seed(0)
a = np.random.normal(size=(3, 3))
b = np.random.uniform(0, 1, size=(5,))
c = np.random.randint(0, 10, size=(4,))
JAX(明示的な鍵)
import jax
import jax.numpy as jnp
key = jax.random.key(0)
k1, k2, k3 = jax.random.split(key, 3)
a = jax.random.normal(k1, (3, 3))
b = jax.random.uniform(k2, (5,), minval=0.0, maxval=1.0)
c = jax.random.randint(k3, (4,), 0, 10)
ループ内で乱数が必要なとき
# 5 回イテレーションのループで毎回違う乱数が欲しい
key = jax.random.key(0)
keys = jax.random.split(key, 5)
for k in keys:
print(jax.random.normal(k, (3,)))
詳しくは 第8章 を参照。
13.7 線形代数
NumPy
import numpy as np
A = np.random.randn(4, 4)
b = np.random.randn(4)
x = np.linalg.solve(A, b)
U, S, Vt = np.linalg.svd(A)
w, v = np.linalg.eig(A)
inv = np.linalg.inv(A)
det = np.linalg.det(A)
norm = np.linalg.norm(b)
JAX
import jax
import jax.numpy as jnp
key = jax.random.key(0)
k1, k2 = jax.random.split(key)
A = jax.random.normal(k1, (4, 4))
b = jax.random.normal(k2, (4,))
x = jnp.linalg.solve(A, b)
U, S, Vt = jnp.linalg.svd(A)
w, v = jnp.linalg.eig(A) # 環境・dtype によって制約があるため、対称行列なら eigh を優先
inv = jnp.linalg.inv(A)
det = jnp.linalg.det(A)
norm = jnp.linalg.norm(b)
ポイント
jnp.linalg.*の API はnp.linalg.*とほぼ同じ- 対称行列なら
jnp.linalg.eighを使うと、一般のeigより安定しやすく、アクセラレータでも扱いやすい - 大きな疎行列・反復解法には
jax.experimental.sparseやjax.scipy.sparse.linalg.cgも選べる
13.8 統計・集約
NumPy
import numpy as np
x = np.random.randn(100, 5)
x.mean(axis=0)
x.std(axis=0, ddof=1)
x.sum(axis=0)
np.median(x, axis=0)
np.percentile(x, 95, axis=0)
np.cov(x, rowvar=False)
np.corrcoef(x, rowvar=False)
JAX
import jax
import jax.numpy as jnp
key = jax.random.key(0)
x = jax.random.normal(key, (100, 5))
x.mean(axis=0)
x.std(axis=0, ddof=1)
x.sum(axis=0)
jnp.median(x, axis=0)
jnp.percentile(x, 95, axis=0)
jnp.cov(x, rowvar=False)
jnp.corrcoef(x, rowvar=False)
ほぼ完全互換です。
13.9 FFT・信号処理
NumPy
import numpy as np
x = np.random.randn(1024)
X = np.fft.fft(x)
x2 = np.fft.ifft(X).real
P = np.abs(np.fft.rfft(x)) ** 2 # パワースペクトル
JAX
import jax
import jax.numpy as jnp
key = jax.random.key(0)
x = jax.random.normal(key, (1024,))
X = jnp.fft.fft(x)
x2 = jnp.fft.ifft(X).real
P = jnp.abs(jnp.fft.rfft(x)) ** 2
畳み込み(np.convolve)も jnp.convolve がそのまま使えます。2D 畳み込み(CNN 用)は jax.lax.conv_general_dilated を使うのが定番です(第14章で扱います)。
13.10 数値微分 → 自動微分
NumPy で「微分が必要だから差分近似で…」と書いていた箇所は、JAX なら jax.grad で 正確かつ高速 に置き換えられます。
NumPy(差分近似)
import numpy as np
def f(x):
return np.sin(x) * x ** 2
def numerical_grad(f, x, h=1e-5):
return (f(x + h) - f(x - h)) / (2 * h)
print(numerical_grad(f, 2.0))
JAX(自動微分)
import jax
import jax.numpy as jnp
def f(x):
return jnp.sin(x) * x ** 2
print(jax.grad(f)(2.0)) # 差分近似より安定した自動微分の値
ヤコビアンも 1 行
def g(x): # R^3 → R^2
return jnp.array([x[0]**2 + x[1], x[1] * x[2]])
x = jnp.array([1.0, 2.0, 3.0])
print(jax.jacrev(g)(x)) # 2×3 のヤコビ行列
13.11 ベクトル化(Python ループ → vmap)
NumPy(暗黙にベクトル化 or for で書く)
import numpy as np
def kernel(x, y):
# 1 ペア用:スカラー入力 → スカラー出力
return np.exp(-((x - y) ** 2))
xs = np.linspace(0, 1, 100)
ys = np.linspace(0, 1, 100)
# 全ペアのカーネル行列を作りたい
K = np.empty((100, 100))
for i in range(100):
for j in range(100):
K[i, j] = kernel(xs[i], ys[j])
JAX(vmap を 2 重に)
import jax
import jax.numpy as jnp
def kernel(x, y):
return jnp.exp(-((x - y) ** 2))
xs = jnp.linspace(0, 1, 100)
ys = jnp.linspace(0, 1, 100)
# kernel(xs[i], ys[j]) を全 (i, j) で計算
K = jax.vmap(jax.vmap(kernel, in_axes=(None, 0)), in_axes=(0, None))(xs, ys)
print(K.shape) # (100, 100)
for で 2 重ループを回すよりずっと高速で、jit を被せればさらに加速します。
13.12 ソート・unique・set 演算
NumPy
import numpy as np
x = np.array([3, 1, 4, 1, 5, 9, 2, 6])
np.sort(x)
np.argsort(x)
np.unique(x)
np.in1d(x, [1, 2, 3])
JAX
import jax.numpy as jnp
x = jnp.array([3, 1, 4, 1, 5, 9, 2, 6])
jnp.sort(x)
jnp.argsort(x)
jnp.unique(x, size=8, fill_value=-1) # ★ jit 互換のため size 指定が必要
jnp.isin(x, jnp.array([1, 2, 3]))
ポイント
jit化したい場合、uniqueのように 出力サイズが動的に決まる関数 にはsize=を渡して上限を明示する必要があります。
13.13 dtype と数値精度
NumPy はデフォルト float64、JAX はデフォルト float32。これが原因で「NumPy と微妙に値がずれる」ことがあります。
JAX で float64 を使いたい
import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
print(jnp.array([1.0]).dtype) # float64
精度ヒエラルキー
| 用途 | 推奨 dtype |
|---|---|
| ふつうの DL 学習 | float32 |
| TPU / 大規模 LLM 学習 | bfloat16 + 一部 float32 |
| 厳密な科学計算 | float64(x64 を ON にする) |
13.14 SciPy 的な処理
JAX には jax.scipy があり、SciPy の代表的な関数が JIT・自動微分対応で再実装されています。
| やりたいこと | SciPy | JAX |
|---|---|---|
| 正規分布の PDF | scipy.stats.norm.pdf(x) | jax.scipy.stats.norm.pdf(x) |
特殊関数 erf | scipy.special.erf(x) | jax.scipy.special.erf(x) |
| LU 分解 | scipy.linalg.lu(A) | jax.scipy.linalg.lu(A) |
| FFT(2D) | scipy.fft.fft2(x) | jnp.fft.fft2(x) |
| 共役勾配法 | scipy.sparse.linalg.cg(A, b) | jax.scipy.sparse.linalg.cg(A, b) |
これにより、SciPy で書いていた科学計算コードに 微分・JIT・vmap を後付けできるようになります。
13.15 よくある小技・周辺操作
NumPy の既存コードには、細かい便利関数がたくさん出てきます。JAX でも多くはそのまま対応できますが、「本当に高速化したいなら vmap や jit と相性のよい書き方にする」 のがポイントです。
| NumPy | JAX | 注意点 |
|---|---|---|
np.einsum("ij,jk->ik", A, B) | jnp.einsum("ij,jk->ik", A, B) | ほぼそのまま。XLA が最適化しやすい |
np.clip(x, a, b) | jnp.clip(x, a, b) | そのまま |
np.maximum(x, 0) | jnp.maximum(x, 0) | ReLU なら jax.nn.relu(x) も可 |
np.nan_to_num(x) | jnp.nan_to_num(x) | そのまま |
np.meshgrid(x, y) | jnp.meshgrid(x, y) | そのまま |
np.concatenate(xs, axis=0) | jnp.concatenate(xs, axis=0) | jit 内では xs の長さを固定する |
np.stack(xs, axis=0) | jnp.stack(xs, axis=0) | 同上 |
np.expand_dims(x, axis) | jnp.expand_dims(x, axis) | x[None, ...] でも可 |
np.squeeze(x) | jnp.squeeze(x) | そのまま |
np.asarray(x) | jnp.asarray(x) | JAX デバイス配列になる。必要なら np.asarray(x) でホストへ戻す |
np.vectorize / np.apply_along_axis は vmap に置き換える
np.vectorize は名前に反して「高速なベクトル化」ではなく、Python ループを薄く包んだものです。JAX では vmap を使いましょう。
NumPy
import numpy as np
def f(x):
return np.sin(x) + x ** 2
xs = np.linspace(0, 1, 1000)
ys = np.vectorize(f)(xs)
JAX
import jax
import jax.numpy as jnp
def f(x):
return jnp.sin(x) + x ** 2
xs = jnp.linspace(0, 1, 1000)
ys = jax.vmap(f)(xs)
apply_along_axis 的な処理も、「1 本のベクトル用の関数」を作って vmap で軸方向に適用するのが JAX らしい書き方です。
# x: (batch, features)
def normalize_one(row):
return (row - row.mean()) / (row.std() + 1e-6)
x_norm = jax.vmap(normalize_one)(x)
まとめ:NumPy 移行のチェックリスト
import numpy as np→import jax.numpy as jnpで動作確認- 配列の 書き換え を
at[...].set(...)などに置き換え - 乱数 を
jax.random.key()+splitに置き換え - 値依存の
ifをjnp.where/lax.condに置き換え - Python の
forループをlax.scan/lax.fori_loop/vmapに置き換え uniqueなどの 動的サイズ API はsize=を渡す- 数値の 再現性が崩れる ときは
jax_enable_x64を疑う - 数値微分を書いていたら
jax.gradに置き換える絶好のチャンス np.vectorize/apply_along_axisは、JAX ではvmapに置き換える