第17章 ニューロエボリューションと進化的モデルマージ
第16章では「数値ベクトル」を進化させましたが、ベクトル = ニューラルネットの重み と考えれば、そのままニューラルネット全体を進化させられます。これが ニューロエボリューション(Neuroevolution)。さらに最近は、ベクトルが 「複数の学習済みモデルの混合比率」 だったりもします(進化的モデルマージ)。
JAX はこれらすべての中心的なツールになっており、Sakana AI を含む多くの研究機関が JAX で関連研究を進めています。
📚 前提:第16章 進化計算入門 と、第7章 PyTree、第10章 実践 NN を読んでいると本章はスムーズです。本章ではより本格的に
vmapとscanを組み合わせます。
目次
- 17.1 ニューロエボリューションとは
- 17.2 重みベクトル ⇄ パラメータ PyTree の変換
- 17.3 例:MLP を OpenAI-ES で進化させる
- 17.4 RL タスクへの応用:CartPole を ES で解く
- 17.5 evosax で同じことをスマートに
- 17.6 進化的モデルマージ(Evolutionary Model Merge)
- 17.7 微分とのハイブリッド:勾配 + 進化
- 17.8 規模を出すための工夫
- 17.9 さらに学ぶには
17.1 ニューロエボリューションとは
ニューロエボリューション は、ニューラルネットのパラメータ(や構造)を 進化計算で最適化 する手法の総称です。代表的な使い道:
- 強化学習:報酬しか得られない(勾配が直接取れない)環境でも学習できる
- ハイパーパラメータ・アーキテクチャ探索:勾配不可能な離散的選択を最適化
- 微分不可能・スパース報酬問題:勾配ベース手法が苦手な領域
- モデルマージ:複数の学習済みモデルを混ぜる比率を進化で探す(後述)
JAX で書くと、
重みベクトル θ ──(unflatten)──→ パラメータ PyTree ──(model forward)──→ 報酬/損失
↑ │
└────────── 進化計算で更新(OpenAI-ES, CMA-ES, GA, …) ────────────┘
という流れになります。鍵は 「PyTree ⇄ 1 本のベクトル」の往復 です。
17.2 重みベクトル ⇄ パラメータ PyTree の変換
進化計算は 1 本のベクトル を扱うのが基本(共分散行列など)ですが、ニューラルネットの重みは PyTree(辞書 / dataclass) が自然です。両者をつなぐのは、jax.flatten_util.ravel_pytree の 1 関数。
import jax, jax.numpy as jnp
from jax.flatten_util import ravel_pytree
from functools import partial
# 例:MLP のパラメータ PyTree
params = {
"W1": jnp.zeros((4, 16)),
"b1": jnp.zeros((16,)),
"W2": jnp.zeros((16, 2)),
"b2": jnp.zeros((2,)),
}
# 1 本のベクトルに「平らに」して、戻すための関数も貰う
flat, unravel = ravel_pytree(params)
print(flat.shape) # 例: (114,)
restored = unravel(flat) # 元の PyTree に戻る
これがあれば、進化アルゴリズム側はベクトル、ニューラルネット側は PyTree という役割分担で書けます。
17.3 例:MLP を OpenAI-ES で進化させる
回帰タスクを 勾配なし で学習させてみます。sin 関数のフィッティング。
import jax, jax.numpy as jnp
from jax.flatten_util import ravel_pytree
from functools import partial
# --- データ ---
key = jax.random.key(0)
x_train = jnp.linspace(-3.14, 3.14, 100)[:, None]
y_train = jnp.sin(x_train)
# --- モデル(純関数)---
def init_params(key, hidden=32):
k1, k2 = jax.random.split(key)
return {
"W1": jax.random.normal(k1, (1, hidden)) * 0.3,
"b1": jnp.zeros(hidden),
"W2": jax.random.normal(k2, (hidden, 1)) * 0.3,
"b2": jnp.zeros(1),
}
def mlp(params, x):
h = jnp.tanh(x @ params["W1"] + params["b1"])
return h @ params["W2"] + params["b2"]
def fitness_params(params):
pred = mlp(params, x_train)
return jnp.mean((pred - y_train) ** 2) # 小さいほど良い
# --- ベクトル化 ---
init = init_params(key)
theta0, unravel = ravel_pytree(init)
dim = theta0.shape[0]
print("総パラメータ数:", dim)
def fitness_vec(theta):
return fitness_params(unravel(theta))
# --- OpenAI-ES(最小化版)---
@partial(jax.jit, static_argnames=("pop",))
def es_step(theta, key, sigma=0.05, lr=0.05, pop=64):
key, sub = jax.random.split(key)
eps = jax.random.normal(sub, (pop, dim))
candidates = theta + sigma * eps
losses = jax.vmap(fitness_vec)(candidates) # ← ここで pop 並列評価
# ランクベースで -1..+1 に正規化(小さいほど高評価 → +へ)
ranks = jnp.argsort(jnp.argsort(-losses)).astype(jnp.float32)
adv = ranks / (pop - 1) - 0.5
adv = (adv - adv.mean()) / (adv.std() + 1e-8)
grad_est = (eps.T @ adv) / (pop * sigma)
return theta + lr * grad_est, losses.min()
theta = theta0
key = jax.random.key(1)
for g in range(500):
key, sub = jax.random.split(key)
theta, best = es_step(theta, sub)
if g % 50 == 0:
print(f"gen {g:3d} best loss={float(best):.4f}")
best_params = unravel(theta)
ポイント
- 1 個体(= 1 つのニューラルネット)用の関数
fitness_paramsを書いただけ jax.vmap(fitness_vec)(candidates)で、pop個のニューラルネットを 1 行で並列評価- これは PyTorch で同じことをやるとずっと骨が折れます
17.4 RL タスクへの応用:CartPole を ES で解く
ニューロエボリューションが本領を発揮するのは強化学習です。Gymnax を使えば、環境自体が JAX で書かれており、vmap で 大量並列ロールアウト ができます。
import jax, jax.numpy as jnp
import gymnax
from jax.flatten_util import ravel_pytree
from functools import partial
env, env_params = gymnax.make("CartPole-v1")
# --- 方策ネットワーク ---
def init_policy(key, obs_dim=4, hidden=16, n_actions=2):
k1, k2 = jax.random.split(key)
return {
"W1": jax.random.normal(k1, (obs_dim, hidden)) * 0.5,
"b1": jnp.zeros(hidden),
"W2": jax.random.normal(k2, (hidden, n_actions)) * 0.5,
"b2": jnp.zeros(n_actions),
}
def policy(params, obs):
h = jnp.tanh(obs @ params["W1"] + params["b1"])
logits = h @ params["W2"] + params["b2"]
return jnp.argmax(logits) # 決定論的方策
# --- 1 エピソードのロールアウト ---
def rollout(params, key, max_steps=500):
obs, state = env.reset(key, env_params)
def step(carry, _):
obs, state, total, done, key = carry
action = policy(params, obs)
key, sub = jax.random.split(key)
next_obs, next_state, reward, new_done, _ = env.step(sub, state, action, env_params)
new_done = new_done.astype(jnp.float32)
# done 以降は報酬を加算しない
total = total + reward * (1.0 - done)
done = jnp.maximum(done, new_done)
return (next_obs, next_state, total, done, key), None
init = (obs, state, jnp.float32(0.0), jnp.float32(0.0), key)
(_, _, total, _, _), _ = jax.lax.scan(step, init, None, length=max_steps)
return total
# --- 適応度(複数シードの平均で安定化)---
def fitness_params(params, key, n_eval=4):
keys = jax.random.split(key, n_eval)
rewards = jax.vmap(rollout, in_axes=(None, 0))(params, keys)
return -jnp.mean(rewards) # 「小さいほど良い」に統一
# --- ベクトル化 ---
init = init_policy(jax.random.key(0))
theta0, unravel = ravel_pytree(init)
dim = theta0.shape[0]
def fitness_vec(theta, key):
return fitness_params(unravel(theta), key)
# --- ES ステップ(rollout 用キーも分配する版)---
@partial(jax.jit, static_argnames=("pop",))
def es_step(theta, key, sigma=0.1, lr=0.05, pop=64):
key, k_noise, k_eval = jax.random.split(key, 3)
eps = jax.random.normal(k_noise, (pop, dim))
candidates = theta + sigma * eps
keys = jax.random.split(k_eval, pop)
losses = jax.vmap(fitness_vec)(candidates, keys)
ranks = jnp.argsort(jnp.argsort(-losses)).astype(jnp.float32)
adv = ranks / (pop - 1) - 0.5
adv = (adv - adv.mean()) / (adv.std() + 1e-8)
grad_est = (eps.T @ adv) / (pop * sigma)
return theta + lr * grad_est, -losses.min() # 表示は報酬で
theta = theta0
key = jax.random.key(1)
for g in range(100):
key, sub = jax.random.split(key)
theta, best_reward = es_step(theta, sub)
if g % 10 == 0:
print(f"gen {g:3d} best reward={float(best_reward):.1f}")
ここで起きていること
pop=64個体 ×n_eval=4シード = 256 並列のロールアウト が GPU 上で同時実行- 全体は
@jax.jitで 1 つの XLA 計算にコンパイルされる - それを Python の
forで 100 世代回すだけ
PyTorch で同じものを書くと、multiprocessing でワーカーを立てるか、torch.vmap を駆使するか、いずれにせよ JAX ほど素直には書けません。
17.5 evosax で同じことをスマートに
第16章でも触れた evosax を使えば、上のコードの大半が 数行に短縮 できます。同じ CartPole 学習を CMA-ES に差し替えるのも 1 行:
import jax, jax.numpy as jnp
from evosax.algorithms import OpenES # API はバージョンに合わせて要確認(公式 examples 優先)
strategy = OpenES(population_size=64, solution=theta0)
params = strategy.default_params
state = strategy.init(jax.random.key(0), params)
@jax.jit
def step(carry, _):
key, state = carry
key, k_ask, k_eval, k_tell = jax.random.split(key, 4)
x, state = strategy.ask(k_ask, state, params)
eval_keys = jax.random.split(k_eval, x.shape[0])
losses = jax.vmap(fitness_vec)(x, eval_keys)
state = strategy.tell(k_tell, x, losses, state, params)
return (key, state), losses.min()
(_, state), traj = jax.lax.scan(step, (jax.random.key(1), state), None, length=100)
print("best:", -float(traj.min()))
OpenES を CMA_ES や SimpleGA に置き換えれば、まったく同じコードで別アルゴリズムが試せます。「数十種類のアルゴリズムを統一インターフェースで」 が evosax の真骨頂です。
17.6 進化的モデルマージ(Evolutionary Model Merge)
進化的モデルマージ は、複数の学習済みモデル(同じアーキテクチャ)を「重みレベルで線形補間して混ぜる」ときの 混合比率を進化計算で探す という発想です。Sakana AI が 2024 年に発表して話題になりました(Evolutionary Optimization of Model Merging Recipes)。
なぜ進化計算なのか
- マージ比率の探索は 離散選択(どの層をどう混ぜるか)+ 連続パラメータ(比率) を含む
- 評価関数は「ベンチマークでの正答率」など 微分不可能
- 1 つのマージ案ごとに「マージしてベンチマーク評価」する必要があり、勾配が取れない
これはまさに進化計算の得意分野です。ただし、評価 1 回あたりのコストが非常に大きくなりやすいので、実用では「小さな検証セットで粗く探索 → 有望な候補だけ本評価」という二段階評価が重要になります。
骨格コード:パラメータごとの線形マージ
import jax, jax.numpy as jnp
# 事前に学習しておいた N 個のモデルのパラメータ(同じ PyTree 構造)
# expert_params: list of pytrees, length N
N = len(expert_params)
def merge(weights, expert_params):
"""weights: shape (N,)。各葉に対して重み付き和を取る。"""
weights = jax.nn.softmax(weights)
# stacked_leaves: 各 leaf について shape (N, *leaf.shape)
stacked = jax.tree_util.tree_map(
lambda *xs: jnp.stack(xs, axis=0), *expert_params
)
return jax.tree_util.tree_map(
lambda x: jnp.tensordot(weights, x, axes=1), stacked
)
def fitness(weights):
merged = merge(weights, expert_params)
return -evaluate_on_benchmark(merged) # 高い精度ほど良い → 小さい損失
あとは weights ∈ R^N を CMA-ES や OpenAI-ES で最適化するだけ。先ほどの es_step がそのまま使えます。
層ごとに別の比率にする(DFS 風)
Sakana AI の論文では「パラメータ空間でのマージ(PS)」と「データフロー空間でのマージ(DFS:層の並べ替え)」を同時に最適化しています。最初の PS の部分だけなら、weights を「層ごと(あるいは葉ごと)に独立」させるのが直接的な拡張です。
def merge_per_leaf(weights_pytree, expert_params):
"""weights_pytree: expert_params[0] と同じ構造。各葉が shape (N,)"""
def merge_leaf(w, *leaves):
w = jax.nn.softmax(w)
return sum(wi * li for wi, li in zip(w, leaves))
return jax.tree_util.tree_map(merge_leaf, weights_pytree, *expert_params)
💡 実用の注意
- 進化に何百〜何千回の評価が必要なので、評価関数(ベンチマーク)を高速化 することが死活問題。小さな代理タスクで pre-screen、有望なものだけフルベンチで再評価、など。
- LLM のような巨大モデルでは、フル評価 1 回が数十秒〜数分 になります。
population_sizeとn_generationsの積をコントロールするのが大事。- 大規模実験では Sakana AI が公開している Evolutionary Model Merge リポジトリが参考になります。
17.7 微分とのハイブリッド:勾配 + 進化
JAX のうれしいところは、勾配ベースの最適化と進化計算を同じコードに混ぜられる こと。たとえば、
- PGA-MAP-Elites:内側のポリシー更新は勾配、外側の多様性探索は進化
- ES-MAML:MAML 風メタ学習だが、内側を ES で
- 進化 + ファインチューニング:進化で初期重みを探し、見つかったらそこから勾配ファインチューン
JAX なら次のように混ぜられます。
@jax.jit
def hybrid_step(theta, key):
# 1. 進化で大域探索
theta, _ = es_step(theta, key)
# 2. 勾配で局所探索(数ステップ)
for _ in range(5):
g = jax.grad(loss_fn)(unravel(theta))
g_flat, _ = ravel_pytree(g)
theta = theta - 1e-3 * g_flat
return theta
「jit の中で grad と進化を両方使う」のは、JAX の真骨頂と言えるパターンです。
17.8 規模を出すための工夫
ニューロエボリューションは規模が出てこその手法です。スケールアップの定石をいくつか。
vmapで個体×評価シードの 2 次元バッチを作る:本章のvmap(rollout, in_axes=(None, 0))を、さらに個体方向にvmapで重ねる- 複数 GPU に sharding:
popをjax.shardingでデバイス分割すれば、評価がそのまま分散 - 環境を
vmap/scanで完全 JAX 化:Python の Gym では並列効率が出ない。Gymnax / Brax / EnvPool(JAX バックエンド)を使う bfloat16の活用:方策評価は精度が低くても OK。bfloat16で計算量とメモリを半減scanのチャンク化:エピソードが長いときはscanのlengthを区切って評価をパイプライン化- 早期終了の代替:個体ごとに
doneで打ち切るのは難しいので、最大長まで回して報酬に(1 - done_cum)を掛ける のが定石(17.4 の rollout 参照)
17.9 さらに学ぶには
- evosax — 進化アルゴリズム集
- EvoJAX — タスク + アルゴリズム + ポリシーのフレームワーク
- QDax — Quality-Diversity(MAP-Elites, PGA-ME 等)
- Gymnax — JAX 完結の RL 環境
- Brax — JAX で書かれた剛体物理シミュレータ
- Sakana AI: Evolutionary Model Merge:論文 / 実装リポジトリ
まとめ
- ニューラルネットの重みも、
ravel_pytreeで 1 本のベクトルに丸めれば、そのまま進化計算の対象になる vmapでモデル並列、scanで世代ループ、jitで全体高速化 という三段構えが鉄則- 強化学習(CartPole レベルなら数十秒で解ける)、RL 風の問題、進化的モデルマージなど、応用範囲は非常に広い
- 自分で書いてから、必要に応じて evosax / EvoJAX / QDax などのフレームワークに乗り換えるのがおすすめ
- 進化と勾配の ハイブリッド は JAX の特に面白い使いどころ