Skip to main content

第10章 実践チュートリアル:線形回帰と簡単なニューラルネット

ここまでの知識を総動員して、実際に動かして学んでみましょう。ライブラリは JAX 本体だけ で進めます(ニューラルネットも自前で書きます)。

10.1 線形回帰

データに対して y = w * x + b の直線をフィットさせる、最もシンプルな機械学習タスクです。

import jax
import jax.numpy as jnp

# --- 1. データを作る(ノイズ込みの直線)---
key = jax.random.key(0)
key, k1, k2 = jax.random.split(key, 3)

true_w, true_b = 2.5, -1.0
x_data = jax.random.uniform(k1, (100,), minval=-5, maxval=5)
noise = jax.random.normal(k2, (100,)) * 0.5
y_data = true_w * x_data + true_b + noise

# --- 2. モデル(純粋関数)---
def predict(params, x):
return params["w"] * x + params["b"]

def loss_fn(params, x, y):
pred = predict(params, x)
return jnp.mean((pred - y) ** 2)

# --- 3. 初期パラメータ ---
params = {"w": 0.0, "b": 0.0}

# --- 4. 1 ステップ分の更新(JIT で高速化)---
@jax.jit
def update(params, x, y, lr=0.05):
loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
new_params = jax.tree_util.tree_map(
lambda p, g: p - lr * g, params, grads
)
return new_params, loss

# --- 5. 学習ループ ---
for step in range(200):
params, loss = update(params, x_data, y_data)
if step % 20 == 0:
print(f"step={step:3d} loss={float(loss):.4f} w={float(params['w']):.3f} b={float(params['b']):.3f}")

print("\n推定:", params)
print("真値: w=2.5, b=-1.0")

実行例:

step= 0 loss=22.6890 w=0.485 b=-0.018
step= 20 loss= 0.2891 w=2.481 b=-0.948
...
step=180 loss= 0.2480 w=2.498 b=-1.003
推定: {'w': 2.498..., 'b': -1.003...}
真値: w=2.5, b=-1.0

きれいに真の値に収束しているはずです 🎉

ここで使った JAX のテクニック

  • jnp での数値計算
  • value_and_grad での損失と勾配の同時計算
  • tree_map での「パラメータと勾配を構造ごと」一括更新
  • @jax.jit での高速化

10.2 簡単なニューラルネット(多層パーセプトロン)

入力 → 隠れ層 → 出力、という 2 層のシンプルな MLP を実装してみます。

import jax
import jax.numpy as jnp

# --- 1. パラメータ初期化(PyTree)---
def init_mlp(key, in_dim, hidden_dim, out_dim):
k1, k2 = jax.random.split(key)
params = {
"W1": jax.random.normal(k1, (in_dim, hidden_dim)) * 0.1,
"b1": jnp.zeros((hidden_dim,)),
"W2": jax.random.normal(k2, (hidden_dim, out_dim)) * 0.1,
"b2": jnp.zeros((out_dim,)),
}
return params

# --- 2. 順伝播(純粋関数)---
def mlp(params, x):
h = jnp.tanh(x @ params["W1"] + params["b1"])
y = h @ params["W2"] + params["b2"]
return y

# --- 3. 損失 ---
def loss_fn(params, x, y):
pred = mlp(params, x)
return jnp.mean((pred - y) ** 2)

# --- 4. データ:sin 関数をフィットさせてみる ---
key = jax.random.key(42)
k1, k2 = jax.random.split(key)
x_train = jnp.linspace(-3.14, 3.14, 200).reshape(-1, 1)
y_train = jnp.sin(x_train)

# --- 5. 初期化 ---
params = init_mlp(k1, in_dim=1, hidden_dim=64, out_dim=1)

# --- 6. 学習ステップ ---
@jax.jit
def step(params, x, y, lr=0.01):
loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
new_params = jax.tree_util.tree_map(
lambda p, g: p - lr * g, params, grads
)
return new_params, loss

# --- 7. 学習ループ ---
for s in range(2000):
params, loss = step(params, x_train, y_train)
if s % 200 == 0:
print(f"step={s:4d} loss={float(loss):.5f}")

# --- 8. 評価 ---
preds = mlp(params, x_train)
print("\nfinal loss:", float(jnp.mean((preds - y_train) ** 2)))

loss が下がっていけば学習成功です。matplotlib でプロットすれば、sin カーブをきれいに学習しているのが見えるはずです。

import matplotlib.pyplot as plt
plt.plot(x_train, y_train, label="true")
plt.plot(x_train, preds, label="pred")
plt.legend(); plt.show()

10.3 ミニバッチ学習(vmap の使いどころ)

データが多い場合は「ミニバッチ」で学習します。ここでは簡単のために単純なバッチ処理に留めますが、

  • 1 サンプル用 lossvmap でバッチ化
  • ミニバッチをランダムサンプリングして毎回違う key で取り出す
  • 学習率スケジューラを使う

といった工夫で、より本格的な学習ループになります。

10.4 まとめ

  • JAX 本体だけでも、ちゃんとした学習ループが書ける
  • パラメータは PyTree(辞書)にまとめる
  • value_and_grad で勾配を取り、tree_map で更新
  • 計算重い部分は @jit で高速化
  • 実用では Flax / Equinox + Optax を使うとさらに楽(次章)

➡️ 第11章 よくある落とし穴とデバッグ