Skip to main content

第17章 ニューロエボリューションと進化的モデルマージ

第16章では「数値ベクトル」を進化させましたが、ベクトル = ニューラルネットの重み と考えれば、そのままニューラルネット全体を進化させられます。これが ニューロエボリューション(Neuroevolution)。さらに最近は、ベクトルが 「複数の学習済みモデルの混合比率」 だったりもします(進化的モデルマージ)。

JAX はこれらすべての中心的なツールになっており、Sakana AI を含む多くの研究機関が JAX で関連研究を進めています。

📚 前提第16章 進化計算入門 と、第7章 PyTree第10章 実践 NN を読んでいると本章はスムーズです。本章ではより本格的に vmapscan を組み合わせます。

目次


17.1 ニューロエボリューションとは

ニューロエボリューション は、ニューラルネットのパラメータ(や構造)を 進化計算で最適化 する手法の総称です。代表的な使い道:

  • 強化学習:報酬しか得られない(勾配が直接取れない)環境でも学習できる
  • ハイパーパラメータ・アーキテクチャ探索:勾配不可能な離散的選択を最適化
  • 微分不可能・スパース報酬問題:勾配ベース手法が苦手な領域
  • モデルマージ:複数の学習済みモデルを混ぜる比率を進化で探す(後述)

JAX で書くと、

重みベクトル θ ──(unflatten)──→ パラメータ PyTree ──(model forward)──→ 報酬/損失
↑ │
└────────── 進化計算で更新(OpenAI-ES, CMA-ES, GA, …) ────────────┘

という流れになります。鍵は 「PyTree ⇄ 1 本のベクトル」の往復 です。


17.2 重みベクトル ⇄ パラメータ PyTree の変換

進化計算は 1 本のベクトル を扱うのが基本(共分散行列など)ですが、ニューラルネットの重みは PyTree(辞書 / dataclass) が自然です。両者をつなぐのは、jax.flatten_util.ravel_pytree1 関数

import jax, jax.numpy as jnp
from jax.flatten_util import ravel_pytree
from functools import partial

# 例:MLP のパラメータ PyTree
params = {
"W1": jnp.zeros((4, 16)),
"b1": jnp.zeros((16,)),
"W2": jnp.zeros((16, 2)),
"b2": jnp.zeros((2,)),
}

# 1 本のベクトルに「平らに」して、戻すための関数も貰う
flat, unravel = ravel_pytree(params)
print(flat.shape) # 例: (114,)
restored = unravel(flat) # 元の PyTree に戻る

これがあれば、進化アルゴリズム側はベクトル、ニューラルネット側は PyTree という役割分担で書けます。


17.3 例:MLP を OpenAI-ES で進化させる

回帰タスクを 勾配なし で学習させてみます。sin 関数のフィッティング。

import jax, jax.numpy as jnp
from jax.flatten_util import ravel_pytree
from functools import partial

# --- データ ---
key = jax.random.key(0)
x_train = jnp.linspace(-3.14, 3.14, 100)[:, None]
y_train = jnp.sin(x_train)

# --- モデル(純関数)---
def init_params(key, hidden=32):
k1, k2 = jax.random.split(key)
return {
"W1": jax.random.normal(k1, (1, hidden)) * 0.3,
"b1": jnp.zeros(hidden),
"W2": jax.random.normal(k2, (hidden, 1)) * 0.3,
"b2": jnp.zeros(1),
}

def mlp(params, x):
h = jnp.tanh(x @ params["W1"] + params["b1"])
return h @ params["W2"] + params["b2"]

def fitness_params(params):
pred = mlp(params, x_train)
return jnp.mean((pred - y_train) ** 2) # 小さいほど良い

# --- ベクトル化 ---
init = init_params(key)
theta0, unravel = ravel_pytree(init)
dim = theta0.shape[0]
print("総パラメータ数:", dim)

def fitness_vec(theta):
return fitness_params(unravel(theta))

# --- OpenAI-ES(最小化版)---
@partial(jax.jit, static_argnames=("pop",))
def es_step(theta, key, sigma=0.05, lr=0.05, pop=64):
key, sub = jax.random.split(key)
eps = jax.random.normal(sub, (pop, dim))
candidates = theta + sigma * eps
losses = jax.vmap(fitness_vec)(candidates) # ← ここで pop 並列評価
# ランクベースで -1..+1 に正規化(小さいほど高評価 → +へ)
ranks = jnp.argsort(jnp.argsort(-losses)).astype(jnp.float32)
adv = ranks / (pop - 1) - 0.5
adv = (adv - adv.mean()) / (adv.std() + 1e-8)
grad_est = (eps.T @ adv) / (pop * sigma)
return theta + lr * grad_est, losses.min()

theta = theta0
key = jax.random.key(1)
for g in range(500):
key, sub = jax.random.split(key)
theta, best = es_step(theta, sub)
if g % 50 == 0:
print(f"gen {g:3d} best loss={float(best):.4f}")

best_params = unravel(theta)

ポイント

  • 1 個体(= 1 つのニューラルネット)用の関数 fitness_params を書いただけ
  • jax.vmap(fitness_vec)(candidates) で、pop 個のニューラルネットを 1 行で並列評価
  • これは PyTorch で同じことをやるとずっと骨が折れます

17.4 RL タスクへの応用:CartPole を ES で解く

ニューロエボリューションが本領を発揮するのは強化学習です。Gymnax を使えば、環境自体が JAX で書かれており、vmap大量並列ロールアウト ができます。

import jax, jax.numpy as jnp
import gymnax
from jax.flatten_util import ravel_pytree
from functools import partial

env, env_params = gymnax.make("CartPole-v1")

# --- 方策ネットワーク ---
def init_policy(key, obs_dim=4, hidden=16, n_actions=2):
k1, k2 = jax.random.split(key)
return {
"W1": jax.random.normal(k1, (obs_dim, hidden)) * 0.5,
"b1": jnp.zeros(hidden),
"W2": jax.random.normal(k2, (hidden, n_actions)) * 0.5,
"b2": jnp.zeros(n_actions),
}

def policy(params, obs):
h = jnp.tanh(obs @ params["W1"] + params["b1"])
logits = h @ params["W2"] + params["b2"]
return jnp.argmax(logits) # 決定論的方策

# --- 1 エピソードのロールアウト ---
def rollout(params, key, max_steps=500):
obs, state = env.reset(key, env_params)
def step(carry, _):
obs, state, total, done, key = carry
action = policy(params, obs)
key, sub = jax.random.split(key)
next_obs, next_state, reward, new_done, _ = env.step(sub, state, action, env_params)
new_done = new_done.astype(jnp.float32)
# done 以降は報酬を加算しない
total = total + reward * (1.0 - done)
done = jnp.maximum(done, new_done)
return (next_obs, next_state, total, done, key), None
init = (obs, state, jnp.float32(0.0), jnp.float32(0.0), key)
(_, _, total, _, _), _ = jax.lax.scan(step, init, None, length=max_steps)
return total

# --- 適応度(複数シードの平均で安定化)---
def fitness_params(params, key, n_eval=4):
keys = jax.random.split(key, n_eval)
rewards = jax.vmap(rollout, in_axes=(None, 0))(params, keys)
return -jnp.mean(rewards) # 「小さいほど良い」に統一

# --- ベクトル化 ---
init = init_policy(jax.random.key(0))
theta0, unravel = ravel_pytree(init)
dim = theta0.shape[0]

def fitness_vec(theta, key):
return fitness_params(unravel(theta), key)

# --- ES ステップ(rollout 用キーも分配する版)---
@partial(jax.jit, static_argnames=("pop",))
def es_step(theta, key, sigma=0.1, lr=0.05, pop=64):
key, k_noise, k_eval = jax.random.split(key, 3)
eps = jax.random.normal(k_noise, (pop, dim))
candidates = theta + sigma * eps
keys = jax.random.split(k_eval, pop)
losses = jax.vmap(fitness_vec)(candidates, keys)
ranks = jnp.argsort(jnp.argsort(-losses)).astype(jnp.float32)
adv = ranks / (pop - 1) - 0.5
adv = (adv - adv.mean()) / (adv.std() + 1e-8)
grad_est = (eps.T @ adv) / (pop * sigma)
return theta + lr * grad_est, -losses.min() # 表示は報酬で

theta = theta0
key = jax.random.key(1)
for g in range(100):
key, sub = jax.random.split(key)
theta, best_reward = es_step(theta, sub)
if g % 10 == 0:
print(f"gen {g:3d} best reward={float(best_reward):.1f}")

ここで起きていること

  • pop=64 個体 × n_eval=4 シード = 256 並列のロールアウト が GPU 上で同時実行
  • 全体は @jax.jit で 1 つの XLA 計算にコンパイルされる
  • それを Python の for で 100 世代回すだけ

PyTorch で同じものを書くと、multiprocessing でワーカーを立てるか、torch.vmap を駆使するか、いずれにせよ JAX ほど素直には書けません。


17.5 evosax で同じことをスマートに

第16章でも触れた evosax を使えば、上のコードの大半が 数行に短縮 できます。同じ CartPole 学習を CMA-ES に差し替えるのも 1 行:

import jax, jax.numpy as jnp
from evosax.algorithms import OpenES # API はバージョンに合わせて要確認(公式 examples 優先)

strategy = OpenES(population_size=64, solution=theta0)
params = strategy.default_params
state = strategy.init(jax.random.key(0), params)

@jax.jit
def step(carry, _):
key, state = carry
key, k_ask, k_eval, k_tell = jax.random.split(key, 4)
x, state = strategy.ask(k_ask, state, params)
eval_keys = jax.random.split(k_eval, x.shape[0])
losses = jax.vmap(fitness_vec)(x, eval_keys)
state = strategy.tell(k_tell, x, losses, state, params)
return (key, state), losses.min()

(_, state), traj = jax.lax.scan(step, (jax.random.key(1), state), None, length=100)
print("best:", -float(traj.min()))

OpenESCMA_ESSimpleGA に置き換えれば、まったく同じコードで別アルゴリズムが試せます。「数十種類のアルゴリズムを統一インターフェースで」 が evosax の真骨頂です。


17.6 進化的モデルマージ(Evolutionary Model Merge)

進化的モデルマージ は、複数の学習済みモデル(同じアーキテクチャ)を「重みレベルで線形補間して混ぜる」ときの 混合比率を進化計算で探す という発想です。Sakana AI が 2024 年に発表して話題になりました(Evolutionary Optimization of Model Merging Recipes)。

なぜ進化計算なのか

  • マージ比率の探索は 離散選択(どの層をどう混ぜるか)+ 連続パラメータ(比率) を含む
  • 評価関数は「ベンチマークでの正答率」など 微分不可能
  • 1 つのマージ案ごとに「マージしてベンチマーク評価」する必要があり、勾配が取れない

これはまさに進化計算の得意分野です。ただし、評価 1 回あたりのコストが非常に大きくなりやすいので、実用では「小さな検証セットで粗く探索 → 有望な候補だけ本評価」という二段階評価が重要になります。

骨格コード:パラメータごとの線形マージ

import jax, jax.numpy as jnp

# 事前に学習しておいた N 個のモデルのパラメータ(同じ PyTree 構造)
# expert_params: list of pytrees, length N
N = len(expert_params)

def merge(weights, expert_params):
"""weights: shape (N,)。各葉に対して重み付き和を取る。"""
weights = jax.nn.softmax(weights)
# stacked_leaves: 各 leaf について shape (N, *leaf.shape)
stacked = jax.tree_util.tree_map(
lambda *xs: jnp.stack(xs, axis=0), *expert_params
)
return jax.tree_util.tree_map(
lambda x: jnp.tensordot(weights, x, axes=1), stacked
)

def fitness(weights):
merged = merge(weights, expert_params)
return -evaluate_on_benchmark(merged) # 高い精度ほど良い → 小さい損失

あとは weights ∈ R^NCMA-ES や OpenAI-ES で最適化するだけ。先ほどの es_step がそのまま使えます。

層ごとに別の比率にする(DFS 風)

Sakana AI の論文では「パラメータ空間でのマージ(PS)」と「データフロー空間でのマージ(DFS:層の並べ替え)」を同時に最適化しています。最初の PS の部分だけなら、weights を「層ごと(あるいは葉ごと)に独立」させるのが直接的な拡張です。

def merge_per_leaf(weights_pytree, expert_params):
"""weights_pytree: expert_params[0] と同じ構造。各葉が shape (N,)"""
def merge_leaf(w, *leaves):
w = jax.nn.softmax(w)
return sum(wi * li for wi, li in zip(w, leaves))
return jax.tree_util.tree_map(merge_leaf, weights_pytree, *expert_params)

💡 実用の注意

  • 進化に何百〜何千回の評価が必要なので、評価関数(ベンチマーク)を高速化 することが死活問題。小さな代理タスクで pre-screen、有望なものだけフルベンチで再評価、など。
  • LLM のような巨大モデルでは、フル評価 1 回が数十秒〜数分 になります。population_sizen_generations の積をコントロールするのが大事。
  • 大規模実験では Sakana AI が公開している Evolutionary Model Merge リポジトリが参考になります。

17.7 微分とのハイブリッド:勾配 + 進化

JAX のうれしいところは、勾配ベースの最適化と進化計算を同じコードに混ぜられる こと。たとえば、

  • PGA-MAP-Elites:内側のポリシー更新は勾配、外側の多様性探索は進化
  • ES-MAML:MAML 風メタ学習だが、内側を ES で
  • 進化 + ファインチューニング:進化で初期重みを探し、見つかったらそこから勾配ファインチューン

JAX なら次のように混ぜられます。

@jax.jit
def hybrid_step(theta, key):
# 1. 進化で大域探索
theta, _ = es_step(theta, key)
# 2. 勾配で局所探索(数ステップ)
for _ in range(5):
g = jax.grad(loss_fn)(unravel(theta))
g_flat, _ = ravel_pytree(g)
theta = theta - 1e-3 * g_flat
return theta

jit の中で grad と進化を両方使う」のは、JAX の真骨頂と言えるパターンです。


17.8 規模を出すための工夫

ニューロエボリューションは規模が出てこその手法です。スケールアップの定石をいくつか。

  1. vmap で個体×評価シードの 2 次元バッチを作る:本章の vmap(rollout, in_axes=(None, 0)) を、さらに個体方向に vmap で重ねる
  2. 複数 GPU に shardingpopjax.sharding でデバイス分割すれば、評価がそのまま分散
  3. 環境を vmap/scan で完全 JAX 化:Python の Gym では並列効率が出ない。Gymnax / Brax / EnvPool(JAX バックエンド)を使う
  4. bfloat16 の活用:方策評価は精度が低くても OK。bfloat16 で計算量とメモリを半減
  5. scan のチャンク化:エピソードが長いときは scanlength を区切って評価をパイプライン化
  6. 早期終了の代替:個体ごとに done で打ち切るのは難しいので、最大長まで回して報酬に (1 - done_cum) を掛ける のが定石(17.4 の rollout 参照)

17.9 さらに学ぶには

  • evosax — 進化アルゴリズム集
  • EvoJAX — タスク + アルゴリズム + ポリシーのフレームワーク
  • QDax — Quality-Diversity(MAP-Elites, PGA-ME 等)
  • Gymnax — JAX 完結の RL 環境
  • Brax — JAX で書かれた剛体物理シミュレータ
  • Sakana AI: Evolutionary Model Merge論文 / 実装リポジトリ

まとめ

  • ニューラルネットの重みも、ravel_pytree で 1 本のベクトルに丸めれば、そのまま進化計算の対象になる
  • vmap でモデル並列、scan で世代ループ、jit で全体高速化 という三段構えが鉄則
  • 強化学習(CartPole レベルなら数十秒で解ける)、RL 風の問題、進化的モデルマージなど、応用範囲は非常に広い
  • 自分で書いてから、必要に応じて evosax / EvoJAX / QDax などのフレームワークに乗り換えるのがおすすめ
  • 進化と勾配の ハイブリッド は JAX の特に面白い使いどころ

➡️ JAX 日本語チュートリアルに戻る