Skip to main content

第8章 乱数の扱い方(PRNG キー)

JAX は乱数も「関数型流」に扱います。最初は少し違和感があるかもしれませんが、その理由を理解すれば「なるほど、こうしたかったのか」と納得できるはずです。

8.1 NumPy / PyTorch の乱数のどこが問題なのか

NumPy や PyTorch では、np.random.normal() のように、グローバルな乱数状態 から値を取り出します。

import numpy as np
np.random.seed(0)
print(np.random.normal()) # -0.13...
print(np.random.normal()) # -1.30...

これは確かに便利ですが、グローバルな状態が裏でこっそり更新される ため、次のような問題が起きやすくなります。

  • 関数が「副作用」を持つ(呼ぶたびに結果が変わる)
  • 並列実行で結果がバラバラになる(誰がいつ更新するか分からない)
  • JIT コンパイルとも相性が悪い

8.2 JAX の流儀:明示的な「鍵」を渡す

JAX では 乱数の状態 = PRNG キー(PRNG Key) を毎回明示的に関数に渡します。現在の公式ドキュメントでは、新しい typed key を作る jax.random.key() が推奨されています。古い資料では jax.random.PRNGKey() もよく出てきますが、これは legacy key と呼ばれるもので、現在でも利用できます(後述)。

import jax

key = jax.random.key(0)
print(jax.random.normal(key, (3,))) # 同じ key なら毎回同じ値
print(jax.random.normal(key, (3,))) # ← 同じ値

「同じ key からは同じ値が出る」ので、乱数生成も 同じ入力なら同じ出力 という、純粋関数に近い形で扱えます。これによって、実験の再現性がぐっと高まります。

8.3 でも毎回違う値が欲しい! → split する

異なる乱数が欲しいときは、split鍵を分割 します。

key = jax.random.key(0)
key, subkey1 = jax.random.split(key)
key, subkey2 = jax.random.split(key)

a = jax.random.normal(subkey1, (3,))
b = jax.random.normal(subkey2, (3,))
print(a)
print(b)

慣習として、

  • key は「将来また分割するための、元の鍵」
  • subkey は「今回 1 回だけ使う鍵」

として使い分けます。「鍵 1 本で同じ乱数しか出ないなら、鍵を増やせばいいじゃない」という発想です。

8.4 複数の subkey をまとめて作る

key = jax.random.key(42)
keys = jax.random.split(key, num=5) # 5 個の鍵を一度に作る
samples = [jax.random.normal(k, (10,)) for k in keys]

8.5 よく使う乱数関数

import jax.numpy as jnp

key = jax.random.key(0)
jax.random.normal(key, (3, 3)) # 標準正規分布
jax.random.uniform(key, (3,), minval=-1, maxval=1) # 一様分布
jax.random.bernoulli(key, p=0.5, shape=(10,)) # ベルヌーイ分布(0 or 1)
jax.random.randint(key, (5,), 0, 100) # 整数乱数
jax.random.permutation(key, jnp.arange(10)) # シャッフル

8.6 まとめ

  • JAX は乱数を PRNG キー で明示的に管理する(現在は jax.random.key() が基本)
  • 同じ key からは同じ乱数が出る(再現性が高い
  • 異なる乱数が欲しいときは split で分割
  • これにより関数の 純粋性並列実行時の再現性 が保たれる

8.7 PRNGKeykey() の違い

書き方位置づけ
jax.random.key(0)現在推奨される typed keyスカラー(shape ())、専用の dtype
jax.random.PRNGKey(0)legacy key。古いコードや一部ライブラリとの互換用uint32 の配列、shape (2,)

このチュートリアルでは基本的に jax.random.key() を使います。古い記事を読むと PRNGKey が出てきますが、「乱数の鍵を明示的に持ち回る」という考え方そのものは同じです。

➡️ 第9章 PyTorch との比較