Skip to main content

第13章 移行ケース集:NumPy → JAX

NumPy で書かれた既存のコードを JAX に置き換えるための 実践レシピ集 です。「NumPy ならこう書く / JAX ならこう書く」をひたすら左右に並べて見比べられるようにしてあります。コピペで動くサンプルが中心です。

📚 前提:NumPy の基本(np.arraynp.zeros、ブロードキャストなど)を一度でも書いたことがあれば読めます。JAX 側のインストールや基本概念に不安があれば、先に 第1章第3章第8章 乱数 を眺めるとスムーズです。

💡 多くの場合 import numpy as npimport jax.numpy as jnp に書き換えるだけで動きます。問題になるのは、書き換え(in-place 操作)・乱数・ループ・データ型 あたりです。本章ではそこを重点的に扱います。

目次


13.1 配列の生成・基本演算

NumPy

import numpy as np

a = np.array([1.0, 2.0, 3.0])
b = np.zeros((3, 3))
c = np.eye(4)
d = np.arange(10)
e = np.linspace(0, 1, 100)

y = np.sin(a) + np.cos(a) ** 2
z = a @ a.T

JAX

import jax.numpy as jnp

a = jnp.array([1.0, 2.0, 3.0])
b = jnp.zeros((3, 3))
c = jnp.eye(4)
d = jnp.arange(10)
e = jnp.linspace(0, 1, 100)

y = jnp.sin(a) + jnp.cos(a) ** 2
z = a @ a.T

ポイント

  • ほぼ npjnp の置き換えだけで OK
  • ただし デフォルトの浮動小数点数は float32(NumPy は float64

13.2 配列の書き換え(in-place)

NumPy で当たり前に書いていた「要素代入」は JAX では使えません。at[...].set/add/... に書き換えます。

NumPy

import numpy as np
x = np.zeros(5)
x[0] = 1.0
x[2:4] += 5.0
x[x < 0] = 0.0 # 負の値を 0 に

JAX

import jax.numpy as jnp
x = jnp.zeros(5)
x = x.at[0].set(1.0)
x = x.at[2:4].add(5.0)
x = jnp.where(x < 0, 0.0, x) # 条件付き代入は where が便利

主な対応表

NumPyJAX
x[i] = vx = x.at[i].set(v)
x[i] += vx = x.at[i].add(v)
x[i] *= vx = x.at[i].mul(v)
x[i] = min/max(x[i], v)x = x.at[i].min(v) / .max(v)
x[mask] = vx = jnp.where(mask, v, x)

💡 at[...].set(...) は「新しい配列を返す」操作です。元の配列は変更されません。戻り値を必ず受け取る ことを忘れずに。


13.3 ファンシーインデックスとマスク

NumPy

import numpy as np
x = np.arange(10)
print(x[[0, 2, 5]]) # [0 2 5]
print(x[x > 3]) # [4 5 6 7 8 9]

# 散布代入
y = np.zeros(10)
np.add.at(y, [0, 0, 1, 1], 1.0) # 重複インデックスも安全に加算
print(y) # [2. 2. 0. ...]

JAX

import jax.numpy as jnp
x = jnp.arange(10)
print(x[jnp.array([0, 2, 5])]) # [0 2 5]
print(x[x > 3]) # ※ jit 内では使えない

# 散布代入(重複インデックス OK)
y = jnp.zeros(10).at[jnp.array([0, 0, 1, 1])].add(1.0)
print(y) # [2. 2. 0. ...]

ポイント

  • ブーリアンマスクで x[x > 3] のように 「出力サイズが実行時に決まる」インデックスjit 内では使えません。
  • 代わりに「出力サイズ固定」の jnp.where、もしくはパディングを使うのが定石。
# jit 互換:マスクされた要素を 0 にする(サイズは固定)
masked = jnp.where(x > 3, x, 0)

13.4 条件分岐 if / where

NumPy

import numpy as np
def relu(x):
if x > 0:
return x
else:
return 0.0

JAX

import jax
import jax.numpy as jnp

# 値に依存する if は jit 内で使えない → where か lax.cond
@jax.jit
def relu(x):
return jnp.where(x > 0, x, 0.0)

# 「枝ごとに重い処理がある」なら lax.cond の方が効率的
@jax.jit
def f(x):
return jax.lax.cond(
x > 0,
lambda x: jnp.sin(x),
lambda x: jnp.cos(x),
x,
)

使い分けの目安

ケースおすすめ
単純な要素ごとの条件選択jnp.where
枝ごとに重い計算があるjax.lax.cond
多分岐(switchjax.lax.switch

13.5 累積計算(for ループ)→ scan

NumPy で for を回して累積を計算するパターンは、JAX では jax.lax.scan が定石です(速度・コンパイル時間ともに有利)。

NumPy(累積和)

import numpy as np
x = np.arange(10).astype(np.float32)
s = np.zeros_like(x)
acc = 0.0
for i in range(len(x)):
acc += x[i]
s[i] = acc

JAX(scan で書く)

import jax
import jax.numpy as jnp

x = jnp.arange(10, dtype=jnp.float32)

def step(carry, xi):
new_carry = carry + xi
return new_carry, new_carry # (次の状態, 出力)

_, s = jax.lax.scan(step, 0.0, x)
print(s) # [0 1 3 6 10 15 21 28 36 45]

回数だけ決まっているループ:fori_loop

@jax.jit
def sum_to(n):
def body(i, acc):
return acc + i
return jax.lax.fori_loop(0, n, body, 0)

「途中で止める」ループ:while_loop

@jax.jit
def newton(x0, tol=1e-6):
def cond(state):
x, prev = state
return jnp.abs(x - prev) > tol
def body(state):
x, _ = state
# f(x) = x^2 - 2 のニュートン法
return (x - (x**2 - 2) / (2 * x), x)
x, _ = jax.lax.while_loop(cond, body, (x0, x0 + 1.0))
return x

print(newton(1.0)) # ≒ √2

13.6 乱数生成

NumPy(グローバル状態)

import numpy as np
np.random.seed(0)
a = np.random.normal(size=(3, 3))
b = np.random.uniform(0, 1, size=(5,))
c = np.random.randint(0, 10, size=(4,))

JAX(明示的な鍵)

import jax
import jax.numpy as jnp

key = jax.random.key(0)
k1, k2, k3 = jax.random.split(key, 3)

a = jax.random.normal(k1, (3, 3))
b = jax.random.uniform(k2, (5,), minval=0.0, maxval=1.0)
c = jax.random.randint(k3, (4,), 0, 10)

ループ内で乱数が必要なとき

# 5 回イテレーションのループで毎回違う乱数が欲しい
key = jax.random.key(0)
keys = jax.random.split(key, 5)
for k in keys:
print(jax.random.normal(k, (3,)))

詳しくは 第8章 を参照。


13.7 線形代数

NumPy

import numpy as np
A = np.random.randn(4, 4)
b = np.random.randn(4)

x = np.linalg.solve(A, b)
U, S, Vt = np.linalg.svd(A)
w, v = np.linalg.eig(A)
inv = np.linalg.inv(A)
det = np.linalg.det(A)
norm = np.linalg.norm(b)

JAX

import jax
import jax.numpy as jnp

key = jax.random.key(0)
k1, k2 = jax.random.split(key)
A = jax.random.normal(k1, (4, 4))
b = jax.random.normal(k2, (4,))

x = jnp.linalg.solve(A, b)
U, S, Vt = jnp.linalg.svd(A)
w, v = jnp.linalg.eig(A) # 環境・dtype によって制約があるため、対称行列なら eigh を優先
inv = jnp.linalg.inv(A)
det = jnp.linalg.det(A)
norm = jnp.linalg.norm(b)

ポイント

  • jnp.linalg.* の API は np.linalg.* とほぼ同じ
  • 対称行列なら jnp.linalg.eigh を使うと、一般の eig より安定しやすく、アクセラレータでも扱いやすい
  • 大きな疎行列・反復解法には jax.experimental.sparsejax.scipy.sparse.linalg.cg も選べる

13.8 統計・集約

NumPy

import numpy as np
x = np.random.randn(100, 5)

x.mean(axis=0)
x.std(axis=0, ddof=1)
x.sum(axis=0)
np.median(x, axis=0)
np.percentile(x, 95, axis=0)
np.cov(x, rowvar=False)
np.corrcoef(x, rowvar=False)

JAX

import jax
import jax.numpy as jnp

key = jax.random.key(0)
x = jax.random.normal(key, (100, 5))

x.mean(axis=0)
x.std(axis=0, ddof=1)
x.sum(axis=0)
jnp.median(x, axis=0)
jnp.percentile(x, 95, axis=0)
jnp.cov(x, rowvar=False)
jnp.corrcoef(x, rowvar=False)

ほぼ完全互換です。


13.9 FFT・信号処理

NumPy

import numpy as np
x = np.random.randn(1024)

X = np.fft.fft(x)
x2 = np.fft.ifft(X).real
P = np.abs(np.fft.rfft(x)) ** 2 # パワースペクトル

JAX

import jax
import jax.numpy as jnp

key = jax.random.key(0)
x = jax.random.normal(key, (1024,))

X = jnp.fft.fft(x)
x2 = jnp.fft.ifft(X).real
P = jnp.abs(jnp.fft.rfft(x)) ** 2

畳み込み(np.convolve)も jnp.convolve がそのまま使えます。2D 畳み込み(CNN 用)は jax.lax.conv_general_dilated を使うのが定番です(第14章で扱います)。


13.10 数値微分 → 自動微分

NumPy で「微分が必要だから差分近似で…」と書いていた箇所は、JAX なら jax.grad正確かつ高速 に置き換えられます。

NumPy(差分近似)

import numpy as np
def f(x):
return np.sin(x) * x ** 2

def numerical_grad(f, x, h=1e-5):
return (f(x + h) - f(x - h)) / (2 * h)

print(numerical_grad(f, 2.0))

JAX(自動微分)

import jax
import jax.numpy as jnp

def f(x):
return jnp.sin(x) * x ** 2

print(jax.grad(f)(2.0)) # 差分近似より安定した自動微分の値

ヤコビアンも 1 行

def g(x): # R^3 → R^2
return jnp.array([x[0]**2 + x[1], x[1] * x[2]])

x = jnp.array([1.0, 2.0, 3.0])
print(jax.jacrev(g)(x)) # 2×3 のヤコビ行列

13.11 ベクトル化(Python ループ → vmap

NumPy(暗黙にベクトル化 or for で書く)

import numpy as np
def kernel(x, y):
# 1 ペア用:スカラー入力 → スカラー出力
return np.exp(-((x - y) ** 2))

xs = np.linspace(0, 1, 100)
ys = np.linspace(0, 1, 100)

# 全ペアのカーネル行列を作りたい
K = np.empty((100, 100))
for i in range(100):
for j in range(100):
K[i, j] = kernel(xs[i], ys[j])

JAX(vmap を 2 重に)

import jax
import jax.numpy as jnp

def kernel(x, y):
return jnp.exp(-((x - y) ** 2))

xs = jnp.linspace(0, 1, 100)
ys = jnp.linspace(0, 1, 100)

# kernel(xs[i], ys[j]) を全 (i, j) で計算
K = jax.vmap(jax.vmap(kernel, in_axes=(None, 0)), in_axes=(0, None))(xs, ys)
print(K.shape) # (100, 100)

for で 2 重ループを回すよりずっと高速で、jit を被せればさらに加速します。


13.12 ソート・unique・set 演算

NumPy

import numpy as np
x = np.array([3, 1, 4, 1, 5, 9, 2, 6])
np.sort(x)
np.argsort(x)
np.unique(x)
np.in1d(x, [1, 2, 3])

JAX

import jax.numpy as jnp

x = jnp.array([3, 1, 4, 1, 5, 9, 2, 6])
jnp.sort(x)
jnp.argsort(x)
jnp.unique(x, size=8, fill_value=-1) # ★ jit 互換のため size 指定が必要
jnp.isin(x, jnp.array([1, 2, 3]))

ポイント

  • jit 化したい場合、unique のように 出力サイズが動的に決まる関数 には size= を渡して上限を明示する必要があります。

13.13 dtype と数値精度

NumPy はデフォルト float64、JAX はデフォルト float32。これが原因で「NumPy と微妙に値がずれる」ことがあります。

JAX で float64 を使いたい

import jax
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp
print(jnp.array([1.0]).dtype) # float64

精度ヒエラルキー

用途推奨 dtype
ふつうの DL 学習float32
TPU / 大規模 LLM 学習bfloat16 + 一部 float32
厳密な科学計算float64x64 を ON にする)

13.14 SciPy 的な処理

JAX には jax.scipy があり、SciPy の代表的な関数が JIT・自動微分対応で再実装されています。

やりたいことSciPyJAX
正規分布の PDFscipy.stats.norm.pdf(x)jax.scipy.stats.norm.pdf(x)
特殊関数 erfscipy.special.erf(x)jax.scipy.special.erf(x)
LU 分解scipy.linalg.lu(A)jax.scipy.linalg.lu(A)
FFT(2D)scipy.fft.fft2(x)jnp.fft.fft2(x)
共役勾配法scipy.sparse.linalg.cg(A, b)jax.scipy.sparse.linalg.cg(A, b)

これにより、SciPy で書いていた科学計算コードに 微分・JIT・vmap を後付けできるようになります。


13.15 よくある小技・周辺操作

NumPy の既存コードには、細かい便利関数がたくさん出てきます。JAX でも多くはそのまま対応できますが、「本当に高速化したいなら vmapjit と相性のよい書き方にする」 のがポイントです。

NumPyJAX注意点
np.einsum("ij,jk->ik", A, B)jnp.einsum("ij,jk->ik", A, B)ほぼそのまま。XLA が最適化しやすい
np.clip(x, a, b)jnp.clip(x, a, b)そのまま
np.maximum(x, 0)jnp.maximum(x, 0)ReLU なら jax.nn.relu(x) も可
np.nan_to_num(x)jnp.nan_to_num(x)そのまま
np.meshgrid(x, y)jnp.meshgrid(x, y)そのまま
np.concatenate(xs, axis=0)jnp.concatenate(xs, axis=0)jit 内では xs の長さを固定する
np.stack(xs, axis=0)jnp.stack(xs, axis=0)同上
np.expand_dims(x, axis)jnp.expand_dims(x, axis)x[None, ...] でも可
np.squeeze(x)jnp.squeeze(x)そのまま
np.asarray(x)jnp.asarray(x)JAX デバイス配列になる。必要なら np.asarray(x) でホストへ戻す

np.vectorize / np.apply_along_axisvmap に置き換える

np.vectorize は名前に反して「高速なベクトル化」ではなく、Python ループを薄く包んだものです。JAX では vmap を使いましょう。

NumPy

import numpy as np

def f(x):
return np.sin(x) + x ** 2

xs = np.linspace(0, 1, 1000)
ys = np.vectorize(f)(xs)

JAX

import jax
import jax.numpy as jnp

def f(x):
return jnp.sin(x) + x ** 2

xs = jnp.linspace(0, 1, 1000)
ys = jax.vmap(f)(xs)

apply_along_axis 的な処理も、「1 本のベクトル用の関数」を作って vmap で軸方向に適用するのが JAX らしい書き方です。

# x: (batch, features)
def normalize_one(row):
return (row - row.mean()) / (row.std() + 1e-6)

x_norm = jax.vmap(normalize_one)(x)

まとめ:NumPy 移行のチェックリスト

  1. import numpy as npimport jax.numpy as jnp で動作確認
  2. 配列の 書き換えat[...].set(...) などに置き換え
  3. 乱数jax.random.key() + split に置き換え
  4. 値依存の ifjnp.where / lax.cond に置き換え
  5. Python の for ループを lax.scan / lax.fori_loop / vmap に置き換え
  6. unique などの 動的サイズ API は size= を渡す
  7. 数値の 再現性が崩れる ときは jax_enable_x64 を疑う
  8. 数値微分を書いていたら jax.grad に置き換える絶好のチャンス
  9. np.vectorize / apply_along_axis は、JAX では vmap に置き換える

➡️ 第14章 移行ケース集:PyTorch → JAX