Skip to main content

第6章 ベクトル化(vmap)と並列化(pmap / shard_map / sharding)

JAX の魅力のひとつが、「1 サンプル用に書いた関数を、いとも簡単にバッチ化したり、複数 GPU に分散したりできる」 という点です。本章では、その主役である vmap と、複数デバイス向けの pmap / shard_map / sharding を順に見ていきます。

6.1 vmap:手書きループとサヨナラする

ニューラルネットでは、1 サンプル用に書いた関数を バッチ全体に適用 することがよくあります。普通なら for ループや「先頭次元を考慮した行列演算」を書く必要がありますが、vmap を使えば 1 サンプル用の関数を、書き直すことなくバッチ対応にできます

例:1 つのベクトル用の関数

import jax
import jax.numpy as jnp

def predict(W, x):
# x: shape (D,) の 1 サンプル
return jnp.dot(W, x) # shape (out,)

W = jnp.ones((3, 4)) # 出力 3、入力 4
x = jnp.ones((4,))
print(predict(W, x)) # shape (3,)

バッチ対応にしたい!

xs = jnp.ones((32, 4)) # 32 サンプル

# 普通ならこう書く
# preds = jnp.stack([predict(W, x) for x in xs])

# vmap ならこう!
batched_predict = jax.vmap(predict, in_axes=(None, 0))
preds = batched_predict(W, xs)
print(preds.shape) # (32, 3)

in_axes は「どの引数の、どの軸を バッチ方向 とみなすか」を指定するものです。

  • None → バッチしない(全サンプル共通の値)
  • 0 → 0 軸目(先頭)をバッチ方向にする
  • 1 → 1 軸目をバッチ方向にする

out_axes

出力側にバッチ軸をどこに付けるかも指定できます(デフォルトは 0、つまり先頭です)。

jax.vmap(predict, in_axes=(None, 0), out_axes=0)

vmap のうれしさ

  • 手書きループより速い(XLA が裏で上手にベクトル化してくれる)
  • コードが「1 サンプル用」のまま済むので、ロジックが読みやすい
  • vmap(vmap(f)) のように重ねれば、多次元バッチもラクラク
# 画像(H, W)の各ピクセルに関数を適用したい
def f(pixel):
return pixel * 2 + 1

img = jnp.ones((28, 28))
# 行・列の両方向にベクトル化
result = jax.vmap(jax.vmap(f))(img)
print(result.shape) # (28, 28)

6.2 pmap:複数デバイスで並列実行(古典的 API)

pmap は、vmap と書き方は似ていますが、バッチを複数の物理デバイス(GPU / TPU)に分散して並列実行する 関数です。現在の公式ドキュメントでは「古い並列 map」と位置づけられ、新しいコードでは後述の jax.shard_mapjax.jit + sharding を検討する流れになっています。ただし、既存コードや入門資料では今もよく登場するので、考え方を知っておく価値はあります。

import jax
import jax.numpy as jnp

print(jax.device_count()) # 例: 8(TPU pod など)

def f(x):
return x ** 2

# 8 個のデバイスに、それぞれ x を 1 つずつ渡す
xs = jnp.arange(8)
ys = jax.pmap(f)(xs)
print(ys) # [ 0 1 4 9 16 25 36 49]

pmap で並列実行された関数の中では、jax.lax.psumjax.lax.pmean などを使うことで、デバイス間で合計や平均を取れます。「データ並列の学習を 1 行で書ける」のが魅力です。

6.3 現在の主流:shard_map / jit + sharding

最近の JAX では、pmap だけでなく、

  • jax.sharding で配列の分散方法を宣言し、
  • それを jax.jit に渡すだけで自動的に並列化する、

という「明示的シャーディング」スタイルが推奨されつつあります。さらに「各デバイス上でどんな形の配列をどう処理するか」をプログラム的に明示したい場合は jax.shard_map を使います。

import jax
from jax.sharding import Mesh, PartitionSpec as P, NamedSharding
import jax.numpy as jnp
import numpy as np

devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("data",))

x = jnp.ones((1024, 128))
sharding = NamedSharding(mesh, P("data", None))
x_sharded = jax.device_put(x, sharding)

@jax.jit
def f(x):
return x * 2

y = f(x_sharded) # 自動で並列実行

この方式は大規模言語モデル(LLM)の学習などで広く使われており、覚えておくと将来きっと役に立ちます。詳細は 公式 sharding ガイド を参照してください。

6.4 ありがちな組み合わせパターン

# 1 サンプル用の loss
def loss(params, x, y):
pred = model(params, x)
return (pred - y) ** 2

# バッチ平均の loss
batched_loss = lambda params, xs, ys: jnp.mean(
jax.vmap(loss, in_axes=(None, 0, 0))(params, xs, ys)
)

# 勾配 + JIT
update = jax.jit(jax.grad(batched_loss))

この vmap で 1 サンプル → grad で微分 → jit で高速化」 という組み合わせは、JAX の典型的な書き方です。3 つの関数変換が美しく重なり合うところが、JAX らしさの真骨頂と言えるでしょう。

6.5 まとめ

  • vmap:1 サンプル用関数をバッチ対応にする(同じデバイス上で)
  • pmap:バッチを複数デバイスに 物理的に 分散して並列実行する古典的 API
  • 最近は jax.shard_mapjit + sharding がモダンな選択肢
  • vmapgradjit の組み合わせが JAX の真骨頂

➡️ 第7章 PyTree という考え方