Skip to main content

第3章 jax.numpy の基本

JAX に入門するときの第一歩は、jax.numpy(しばしば jnp と略します) を使ってみることです。これは NumPy とほぼ同じ API を持っているので、NumPy が使える方ならすぐに馴染めますし、初めての方でも「数学の行列・ベクトル計算を Python で扱う道具」と思えば大丈夫です。

3.1 まずは触ってみる

import jax.numpy as jnp

# 配列を作る
a = jnp.array([1.0, 2.0, 3.0])
b = jnp.array([4.0, 5.0, 6.0])

# 四則演算
print(a + b) # [5. 7. 9.]
print(a * b) # [ 4. 10. 18.]
print(jnp.dot(a, b)) # 32.0(内積)

# 形を変える
m = jnp.arange(12).reshape(3, 4)
print(m)
print("shape:", m.shape)
print("sum:", jnp.sum(m))

NumPy をご存知の方には「ほぼ同じだ」と感じてもらえるはずです。Python 初心者の方は、「数字を箱(配列)にまとめて、まとめて計算できる道具」 くらいの感覚で読み進めてください。

💬 as jnp ってなに? import jax.numpy as jnp は「jax.numpy というモジュールを、これからは jnp という短い名前で呼ばせてね」という意味です。タイピングが楽になるための慣習です。

3.2 NumPy との大きな違い:イミュータブル(変更不可)

JAX 配列の最大の特徴は 書き換えができない(イミュータブル) ことです。

import numpy as np
import jax.numpy as jnp

# NumPy なら OK
x_np = np.array([1, 2, 3])
x_np[0] = 99
print(x_np) # [99 2 3]

# JAX ではエラーになる!
x_jax = jnp.array([1, 2, 3])
# x_jax[0] = 99 # ← TypeError

JAX で値を「書き換えたい」ときは、at[...].set(...) という独特な記法を使います。

x_jax = jnp.array([1, 2, 3])
x_new = x_jax.at[0].set(99)
print(x_jax) # [1 2 3] ← 元は変わらない
print(x_new) # [99 2 3] ← 新しい配列が返る

「元を変えずに、新しい配列を作って返す」というのが JAX の流儀です。最初は手間に感じますが、これが 関数の純粋性(同じ入力なら必ず同じ出力になる性質)や 並列化との相性のよさ につながっています。

やりたいことNumPyJAX
要素の代入x[i] = vx = x.at[i].set(v)
加算x[i] += vx = x.at[i].add(v)
最大値で更新x[i] = max(x[i], v)x = x.at[i].max(v)

3.3 デフォルトは float32

NumPy のデフォルトは float64(64 ビットの浮動小数点数)ですが、JAX は、64 ビットモード(jax_enable_x64)を有効にしない限り、浮動小数点数は基本的に float32 がデフォルト です。これは GPU / TPU が float32(あるいは bfloat16 という GPU/TPU 向けの 16 ビット型)で動くことが多いためで、機械学習では十分な精度です。

print(jnp.array([1.0, 2.0]).dtype) # float32

float64 を使いたい場合は、明示的に有効化する必要があります。

import jax
jax.config.update("jax_enable_x64", True)

3.4 配列の生成

NumPy と同じく、よく使う形の配列を一発で作る関数がそろっています。

jnp.zeros((2, 3)) # 0 で埋めた 2×3 の配列
jnp.ones((3,)) # 1 で埋めた長さ 3 のベクトル
jnp.eye(4) # 4×4 の単位行列
jnp.arange(10) # 0, 1, 2, …, 9
jnp.linspace(0, 1, 5) # 0 から 1 まで等間隔に 5 個

乱数だけは作法が少し違います。JAX では「乱数の鍵(PRNG キー)」を明示的に渡します。これは第8章で詳しく解説します。

import jax
key = jax.random.key(0)
r = jax.random.normal(key, (3,))
print(r) # [-0.20584226 0.46256346 1.0978508 ]

3.5 ブロードキャストとインデックス

「形の違う配列を、自動で揃えて計算してくれる」ブロードキャストや、[i:j] のようなスライスも、NumPy と同じように使えます。

x = jnp.arange(6).reshape(2, 3)
y = jnp.array([10, 20, 30])

print(x + y) # 行方向にブロードキャストして加算

# スライス
print(x[:, 1]) # 1 列目だけ取り出す → [1 4]
print(x[0, 1:]) # 0 行目の 1 番目以降 → [1 2]

3.6 配列がどこのデバイスにあるか

JAX の配列は、CPU・GPU・TPU のいずれかのデバイスに置かれます。「どこにあるか」は次のように確認できます。

x = jnp.array([1, 2, 3])
print(x.devices()) # 例: {CudaDevice(id=0)}

jax.device_put(x, device_or_sharding) を使うと、明示的にデバイスや sharding(複数デバイスへの分散方法)を指定することもできます。

3.7 まとめ

  • import jax.numpy as jnp で、NumPy 同様に書ける
  • ただし配列は イミュータブル。書き換えたいときは at[...].set(...) などを使う
  • デフォルトの浮動小数点型は float32
  • 配列は CPU/GPU/TPU いずれかのデバイス上に存在する

➡️ 第4章 JIT コンパイル(jit)で高速化する