第6章 ベクトル化(vmap)と並列化(pmap / shard_map / sharding)
JAX の魅力のひとつが、「1 サンプル用に書いた関数を、いとも簡単にバッチ化したり、複数 GPU に分散したりできる」 という点です。本章では、その主役である vmap と、複数デバイス向けの pmap / shard_map / sharding を順に見ていきます。
6.1 vmap:手書きループとサヨナラする
ニューラルネットでは、1 サンプル用に書いた関数を バッチ全体に適用 することがよくあります。普通なら for ループや「先頭次元を考慮した行列演算」を書く必要がありますが、vmap を使えば 1 サンプル用の関数を、書き直すことなくバッチ対応にできます。
例:1 つのベクトル用の関数
import jax
import jax.numpy as jnp
def predict(W, x):
# x: shape (D,) の 1 サンプル
return jnp.dot(W, x) # shape (out,)
W = jnp.ones((3, 4)) # 出力 3、入力 4
x = jnp.ones((4,))
print(predict(W, x)) # shape (3,)
バッチ対応にしたい!
xs = jnp.ones((32, 4)) # 32 サンプル
# 普通ならこう書く
# preds = jnp.stack([predict(W, x) for x in xs])
# vmap ならこう!
batched_predict = jax.vmap(predict, in_axes=(None, 0))
preds = batched_predict(W, xs)
print(preds.shape) # (32, 3)
in_axes は「どの引数の、どの軸を バッチ方向 とみなすか」を指定するものです。
None→ バッチしない(全サンプル共通の値)0→ 0 軸目(先頭)をバッチ方向にする1→ 1 軸目をバッチ方向にする
out_axes
出力側にバッチ軸をどこに付けるかも指定できます(デフォルトは 0、つまり先頭です)。
jax.vmap(predict, in_axes=(None, 0), out_axes=0)
vmap のうれしさ
- 手書きループより速い(XLA が裏で上手にベクトル化してくれる)
- コードが「1 サンプル用」のまま済むので、ロジックが読みやすい
vmap(vmap(f))のように重ねれば、多次元バッチもラクラク
# 画像(H, W)の各ピクセルに関数を適用したい
def f(pixel):
return pixel * 2 + 1
img = jnp.ones((28, 28))
# 行・列の両方向にベクトル化
result = jax.vmap(jax.vmap(f))(img)
print(result.shape) # (28, 28)
6.2 pmap:複数デバイスで並列実行(古典的 API)
pmap は、vmap と書き方は似ていますが、バッチを複数の物理デバイス(GPU / TPU)に分散して並列実行する 関数です。現在の公式ドキュメントでは「古い並列 map」と位置づけられ、新しいコードでは後述の jax.shard_map や jax.jit + sharding を検討する流れになっています。ただし、既存コードや入門資料では今もよく登場するので、考え方を知っておく価値はあります。
import jax
import jax.numpy as jnp
print(jax.device_count()) # 例: 8(TPU pod など)
def f(x):
return x ** 2
# 8 個のデバイスに、それぞれ x を 1 つずつ渡す
xs = jnp.arange(8)
ys = jax.pmap(f)(xs)
print(ys) # [ 0 1 4 9 16 25 36 49]
pmap で並列実行された関数の中では、jax.lax.psum や jax.lax.pmean などを使うことで、デバイス間で合計や平均を取れます。「データ並列の学習を 1 行で書ける」のが魅力です。
6.3 現在の主流:shard_map / jit + sharding
最近の JAX では、pmap だけでなく、
jax.shardingで配列の分散方法を宣言し、- それを
jax.jitに渡すだけで自動的に並列化する、
という「明示的シャーディング」スタイルが推奨されつつあります。さらに「各デバイス上でどんな形の配列をどう処理するか」をプログラム的に明示したい場合は jax.shard_map を使います。
import jax
from jax.sharding import Mesh, PartitionSpec as P, NamedSharding
import jax.numpy as jnp
import numpy as np
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("data",))
x = jnp.ones((1024, 128))
sharding = NamedSharding(mesh, P("data", None))
x_sharded = jax.device_put(x, sharding)
@jax.jit
def f(x):
return x * 2
y = f(x_sharded) # 自動で並列実行
この方式は大規模言語モデル(LLM)の学習などで広く使われており、覚えておくと将来きっと役に立ちます。詳細は 公式 sharding ガイド を参照してください。
6.4 ありがちな組み合わせパターン
# 1 サンプル用の loss
def loss(params, x, y):
pred = model(params, x)
return (pred - y) ** 2
# バッチ平均の loss
batched_loss = lambda params, xs, ys: jnp.mean(
jax.vmap(loss, in_axes=(None, 0, 0))(params, xs, ys)
)
# 勾配 + JIT
update = jax.jit(jax.grad(batched_loss))
この 「vmap で 1 サンプル → grad で微分 → jit で高速化」 という組み合わせは、JAX の典型的な書き方です。3 つの関数変換が美しく重なり合うところが、JAX らしさの真骨頂と言えるでしょう。
6.5 まとめ
vmap:1 サンプル用関数をバッチ対応にする(同じデバイス上で)pmap:バッチを複数デバイスに 物理的に 分散して並列実行する古典的 API- 最近は
jax.shard_mapやjit+ sharding がモダンな選択肢 vmap・grad・jitの組み合わせが JAX の真骨頂