第8章 乱数の扱い方(PRNG キー)
JAX は乱数も「関数型流」に扱います。最初は少し違和感があるかもしれませんが、その理由を理解すれば「なるほど、こうしたかったのか」と納得できるはずです。
8.1 NumPy / PyTorch の乱数のどこが問題なのか
NumPy や PyTorch では、np.random.normal() のように、グローバルな乱数状態 から値を取り出します。
import numpy as np
np.random.seed(0)
print(np.random.normal()) # -0.13...
print(np.random.normal()) # -1.30...
これは確かに便利ですが、グローバルな状態が裏でこっそり更新される ため、次のような問題が起きやすくなります。
- 関数が「副作用」を持つ(呼ぶたびに結果が変わる)
- 並列実行で結果がバラバラになる(誰がいつ更新するか分からない)
- JIT コンパイルとも相性が悪い
8.2 JAX の流儀:明示的な「鍵」を渡す
JAX では 乱数の状態 = PRNG キー(PRNG Key) を毎回明示的に関数に渡します。現在の公式ドキュメントでは、新しい typed key を作る jax.random.key() が推奨されています。古い資料では jax.random.PRNGKey() もよく出てきますが、これは legacy key と呼ばれるもので、現在でも利用できます(後述)。
import jax
key = jax.random.key(0)
print(jax.random.normal(key, (3,))) # 同じ key なら毎回同じ値
print(jax.random.normal(key, (3,))) # ← 同じ値
「同じ key からは同じ値が出る」ので、乱数生成も 同じ入力なら同じ出力 という、純粋関数に近い形で扱えます。これによって、実験の再現性がぐっと高まります。
8.3 でも毎回違う値が欲しい! → split する
異なる乱数が欲しいときは、split で 鍵を分割 します。
key = jax.random.key(0)
key, subkey1 = jax.random.split(key)
key, subkey2 = jax.random.split(key)
a = jax.random.normal(subkey1, (3,))
b = jax.random.normal(subkey2, (3,))
print(a)
print(b)
慣習として、
keyは「将来また分割するための、元の鍵」subkeyは「今回 1 回だけ使う鍵」
として使い分けます。「鍵 1 本で同じ乱数しか出ないなら、鍵を増やせばいいじゃない」という発想です。
8.4 複数の subkey をまとめて作る
key = jax.random.key(42)
keys = jax.random.split(key, num=5) # 5 個の鍵を一度に作る
samples = [jax.random.normal(k, (10,)) for k in keys]
8.5 よく使う乱数関数
import jax.numpy as jnp
key = jax.random.key(0)
jax.random.normal(key, (3, 3)) # 標準正規分布
jax.random.uniform(key, (3,), minval=-1, maxval=1) # 一様分布
jax.random.bernoulli(key, p=0.5, shape=(10,)) # ベルヌーイ分布(0 or 1)
jax.random.randint(key, (5,), 0, 100) # 整数乱数
jax.random.permutation(key, jnp.arange(10)) # シャッフル
8.6 まとめ
- JAX は乱数を PRNG キー で明示的に管理する(現在は
jax.random.key()が基本) - 同じ key からは同じ乱数が出る(再現性が高い)
- 異なる乱数が欲しいときは
splitで分割 - これにより関数の 純粋性 と 並列実行時の再現性 が保たれる
8.7 PRNGKey と key() の違い
| 書き方 | 位置づけ | 形 |
|---|---|---|
jax.random.key(0) | 現在推奨される typed key | スカラー(shape ())、専用の dtype |
jax.random.PRNGKey(0) | legacy key。古いコードや一部ライブラリとの互換用 | uint32 の配列、shape (2,) |
このチュートリアルでは基本的に jax.random.key() を使います。古い記事を読むと PRNGKey が出てきますが、「乱数の鍵を明示的に持ち回る」という考え方そのものは同じです。