第10章 実践チュートリアル:線形回帰と簡単なニューラルネット
ここまでの知識を総動員して、実際に動かして学んでみましょう。ライブラリは JAX 本体だけ で進めます(ニューラルネットも自前で書きます)。
10.1 線形回帰
データに対して y = w * x + b の直線をフィットさせる、最もシンプルな機械学習タスクです。
import jax
import jax.numpy as jnp
# --- 1. データを作る(ノイズ込みの直線)---
key = jax.random.key(0)
key, k1, k2 = jax.random.split(key, 3)
true_w, true_b = 2.5, -1.0
x_data = jax.random.uniform(k1, (100,), minval=-5, maxval=5)
noise = jax.random.normal(k2, (100,)) * 0.5
y_data = true_w * x_data + true_b + noise
# --- 2. モデル(純粋関数)---
def predict(params, x):
return params["w"] * x + params["b"]
def loss_fn(params, x, y):
pred = predict(params, x)
return jnp.mean((pred - y) ** 2)
# --- 3. 初期パラメータ ---
params = {"w": 0.0, "b": 0.0}
# --- 4. 1 ステップ分の更新(JIT で高速化)---
@jax.jit
def update(params, x, y, lr=0.05):
loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
new_params = jax.tree_util.tree_map(
lambda p, g: p - lr * g, params, grads
)
return new_params, loss
# --- 5. 学習ループ ---
for step in range(200):
params, loss = update(params, x_data, y_data)
if step % 20 == 0:
print(f"step={step:3d} loss={float(loss):.4f} w={float(params['w']):.3f} b={float(params['b']):.3f}")
print("\n推定:", params)
print("真値: w=2.5, b=-1.0")
実行例:
step= 0 loss=22.6890 w=0.485 b=-0.018
step= 20 loss= 0.2891 w=2.481 b=-0.948
...
step=180 loss= 0.2480 w=2.498 b=-1.003
推定: {'w': 2.498..., 'b': -1.003...}
真値: w=2.5, b=-1.0
きれいに真の値に収束しているはずです 🎉
ここで使った JAX のテクニック
jnpでの数値計算value_and_gradでの損失と勾配の同時計算tree_mapでの「パラメータと勾配を構造ごと」一括更新@jax.jitでの高速化
10.2 簡単なニューラルネット(多層パーセプトロン)
入力 → 隠れ層 → 出力、という 2 層のシンプルな MLP を実装してみます。
import jax
import jax.numpy as jnp
# --- 1. パラメータ初期化(PyTree)---
def init_mlp(key, in_dim, hidden_dim, out_dim):
k1, k2 = jax.random.split(key)
params = {
"W1": jax.random.normal(k1, (in_dim, hidden_dim)) * 0.1,
"b1": jnp.zeros((hidden_dim,)),
"W2": jax.random.normal(k2, (hidden_dim, out_dim)) * 0.1,
"b2": jnp.zeros((out_dim,)),
}
return params
# --- 2. 順伝播(純粋関数)---
def mlp(params, x):
h = jnp.tanh(x @ params["W1"] + params["b1"])
y = h @ params["W2"] + params["b2"]
return y
# --- 3. 損失 ---
def loss_fn(params, x, y):
pred = mlp(params, x)
return jnp.mean((pred - y) ** 2)
# --- 4. データ:sin 関数をフィットさせてみる ---
key = jax.random.key(42)
k1, k2 = jax.random.split(key)
x_train = jnp.linspace(-3.14, 3.14, 200).reshape(-1, 1)
y_train = jnp.sin(x_train)
# --- 5. 初期化 ---
params = init_mlp(k1, in_dim=1, hidden_dim=64, out_dim=1)
# --- 6. 学習ステップ ---
@jax.jit
def step(params, x, y, lr=0.01):
loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
new_params = jax.tree_util.tree_map(
lambda p, g: p - lr * g, params, grads
)
return new_params, loss
# --- 7. 学習ループ ---
for s in range(2000):
params, loss = step(params, x_train, y_train)
if s % 200 == 0:
print(f"step={s:4d} loss={float(loss):.5f}")
# --- 8. 評価 ---
preds = mlp(params, x_train)
print("\nfinal loss:", float(jnp.mean((preds - y_train) ** 2)))
loss が下がっていけば学習成功です。matplotlib でプロットすれば、sin カーブをきれいに学習しているのが見えるはずです。
import matplotlib.pyplot as plt
plt.plot(x_train, y_train, label="true")
plt.plot(x_train, preds, label="pred")
plt.legend(); plt.show()
10.3 ミニバッチ学習(vmap の使いどころ)
データが多い場合は「ミニバッチ」で学習します。ここでは簡単のために単純なバッチ処理に留めますが、
- 1 サンプル用 loss を
vmapでバッチ化 - ミニバッチをランダムサンプリングして毎回違う
keyで取り出す - 学習率スケジューラを使う
といった工夫で、より本格的な学習ループになります。
10.4 まとめ
- JAX 本体だけでも、ちゃんとした学習ループが書ける
- パラメータは PyTree(辞書)にまとめる
value_and_gradで勾配を取り、tree_mapで更新- 計算重い部分は
@jitで高速化 - 実用では Flax / Equinox + Optax を使うとさらに楽(次章)