第16章 進化計算入門と基本アルゴリズム
進化計算(Evolutionary Computation, EC) は、自然界の進化の仕組み(突然変異・交叉・選択)を模倣して最適化問題を解く手法の総称です。勾配が使えない・関数が非滑らかでも動く・並列化しやすい、という強みがあり、ハイパーパラメータ最適化、強化学習、ニューラルネットの構造探索、最近では 大規模モデルのマージ などで再評価されています。
JAX は進化計算と相性が抜群です。理由は明快で、
vmapで個体群を一気に並列評価 できる(CPU/GPU/TPU すべてで)jitで世代ループを丸ごとコンパイル できて非常に速いjax.random.keyで再現性が高い 実験ができるgradは使わない(むしろ「勾配が要らない最適化」が EC)が、必要なら混ぜられる
📚 前提:本章は 第4章 JIT・第5章 grad・第6章 vmap・第8章 乱数 の基本を理解していると読みやすいです。
gradは本章ではほとんど使いません。
目次
- 16.1 なぜ JAX × 進化計算なのか
- 16.2 ライブラリの選択肢
- 16.3 もっとも単純な例:ランダム探索 / 山登り
- 16.4 (1+1) 進化戦略
- 16.5 単純な遺伝的アルゴリズム(GA)
- 16.6 差分進化(DE)
- 16.7 粒子群最適化(PSO)
- 16.8 OpenAI-ES(自然進化戦略)
- 16.9 CMA-ES(evosax を使う)
- 16.10 Quality-Diversity:MAP-Elites
- 16.11 多目的最適化のヒント(NSGA-II)
- 16.12 デザインの定石とつまずきポイント
16.1 なぜ JAX × 進化計算なのか
進化計算の基本ループは、おおむね次の 3 ステップを世代ごとに繰り返すだけです。
- 生成:個体群(population)からサンプリング/変異/交叉で次世代候補を作る
- 評価:各個体の適応度(fitness)を計算する
- 選択:強い個体を残し、弱い個体を捨てる
このうち 評価 が計算量の支配項になることが多く、しかも 個体ごとに独立 です。ここに JAX の vmap が刺さります。
個体 1 →┐
個体 2 →┤
⋮ ├─→ vmap(fitness_fn)(population)
個体 N →┘ ↑ GPU/TPU で並列実行
さらに jit で世代ループを包めば、「Python の for で世代を回しているように見えて、中身は XLA で最適化された 1 つの計算」になります。
16.2 ライブラリの選択肢
JAX エコシステムには、進化計算用の良質なライブラリがあります。
| ライブラリ | 開発元 | 特徴 |
|---|---|---|
| evosax | RobertTLange ほか | CMA-ES、PEPG、xNES、Open-ES、Sep-CMA-ES、SimpleGA など 30 以上のアルゴリズムを統一 API で提供。ニューロエボリューション標準 |
| EvoJAX | Google(旧 Brain)/ コミュニティ | 進化計算 + 並列環境のハードウェア加速フレームワーク。タスク・ポリシー・アルゴリズムを組み合わせるスタイル |
| QDax | Adaptive & Intelligent Robotics Lab ほか | Quality-Diversity(MAP-Elites など)を JAX で。ロボティクスでよく使われる |
| Brax | 物理シミュレータ。EC の評価環境としてよく組み合わせられる |
このうち本章では、自分で書く部分 を中心に紹介し、最後に evosax で CMA-ES を回す例を載せます。
💡 インストール
pip install evosax# ニューロエボリューションの環境用にpip install gymnax brax
16.3 もっとも単純な例:ランダム探索 / 山登り
進化計算の祖先は「ランダム探索」と「山登り」です。何より雰囲気をつかむのに最適。最適化対象は、定番の Rastrigin 関数(多峰性で最小値が 0)にしておきます。
import jax, jax.numpy as jnp
def rastrigin(x, A=10.0):
n = x.shape[-1]
return A * n + jnp.sum(x**2 - A * jnp.cos(2 * jnp.pi * x), axis=-1)
# ランダム探索:N 個の候補から最良の 1 個を選ぶ
key = jax.random.key(0)
candidates = jax.random.uniform(key, (10_000, 10), minval=-5.12, maxval=5.12)
fitnesses = jax.vmap(rastrigin)(candidates)
best = candidates[jnp.argmin(fitnesses)]
print("best fitness:", float(fitnesses.min()))
ポイント
jax.vmap(rastrigin)(candidates)で 10,000 個を一気に評価- ループも
ifも要らない。これが JAX × EC の気持ちよさ
山登り(greedy hill climbing)
import jax, jax.numpy as jnp
from functools import partial
@partial(jax.jit, static_argnames=("steps",))
def hill_climb(key, x0, steps=1000, sigma=0.1):
def body(carry, _):
x, f, key = carry
key, sub = jax.random.split(key)
x_new = x + sigma * jax.random.normal(sub, x.shape)
f_new = rastrigin(x_new)
better = f_new < f
x = jnp.where(better, x_new, x)
f = jnp.where(better, f_new, f)
return (x, f, key), f
init = (x0, rastrigin(x0), key)
(x, f, _), traj = jax.lax.scan(body, init, None, length=steps)
return x, f, traj
x_final, f_final, traj = hill_climb(jax.random.key(1), jnp.zeros(10))
print("final:", float(f_final))
for ループを書く代わりに jax.lax.scan を使う:これは JAX × EC の鉄則です。
16.4 (1+1) 進化戦略
「親 1 個体 + 子 1 個体を比べて、勝った方を次の親にする」という最小構成の進化戦略。山登りに ステップサイズの自己適応 を加えると、もう ES と呼べます。
import jax, jax.numpy as jnp
from functools import partial
@partial(jax.jit, static_argnames=("steps",))
def one_plus_one_es(key, x0, steps=2000):
"""1/5-success rule で sigma を自己適応する古典的な (1+1)-ES。"""
def body(carry, _):
x, f, sigma, key, success_buf = carry
key, sub = jax.random.split(key)
x_new = x + sigma * jax.random.normal(sub, x.shape)
f_new = rastrigin(x_new)
better = f_new < f
# 1/5 ルール:直近 10 回中の成功率が 1/5 を超えれば sigma を増やす
success_buf = jnp.roll(success_buf, 1).at[0].set(better.astype(jnp.float32))
rate = success_buf.mean()
sigma = jnp.where(rate > 0.2, sigma * 1.1, sigma * 0.9)
x = jnp.where(better, x_new, x)
f = jnp.where(better, f_new, f)
return (x, f, sigma, key, success_buf), f
init = (x0, rastrigin(x0), jnp.float32(0.5), key, jnp.zeros(10))
(x, f, sigma, _, _), traj = jax.lax.scan(body, init, None, length=steps)
return x, f, traj
x, f, traj = one_plus_one_es(jax.random.key(2), jnp.full((10,), 3.0))
print("final fitness:", float(f))
16.5 単純な遺伝的アルゴリズム(GA)
選択 → 交叉 → 突然変異 を 1 世代として、世代を scan で回します。実数値 GA で書いてみます。
import jax, jax.numpy as jnp
def init_population(key, pop_size, dim, low=-5.12, high=5.12):
return jax.random.uniform(key, (pop_size, dim), minval=low, maxval=high)
def tournament_select(key, fitness, pop, k=3):
"""k-トーナメント選択。同サイズの親プールを返す。"""
pop_size = pop.shape[0]
idx = jax.random.randint(key, (pop_size, k), 0, pop_size)
# 各行で最も適応度がよい個体を選ぶ
f_k = fitness[idx] # (pop_size, k)
winners = idx[jnp.arange(pop_size), jnp.argmin(f_k, axis=1)]
return pop[winners]
def crossover(key, parents):
"""隣り合うペアで一様交叉。pop_size は偶数を仮定。"""
pop_size, dim = parents.shape
mask = jax.random.bernoulli(key, 0.5, (pop_size // 2, dim))
p1 = parents[0::2]
p2 = parents[1::2]
c1 = jnp.where(mask, p1, p2)
c2 = jnp.where(mask, p2, p1)
children = jnp.empty_like(parents).at[0::2].set(c1).at[1::2].set(c2)
return children
def mutate(key, pop, sigma=0.1, p=0.1):
"""各遺伝子に確率 p で正規ノイズを加える。"""
k1, k2 = jax.random.split(key)
mask = jax.random.bernoulli(k1, p, pop.shape)
noise = jax.random.normal(k2, pop.shape) * sigma
return pop + mask * noise
@jax.jit
def ga_step(carry, _):
pop, key = carry
fitness = jax.vmap(rastrigin)(pop)
key, ks, kc, km = jax.random.split(key, 4)
parents = tournament_select(ks, fitness, pop)
children = crossover(kc, parents)
pop_next = mutate(km, children)
# エリート保存:いちばん良い親を 1 個体だけ残す
best_idx = jnp.argmin(fitness)
pop_next = pop_next.at[0].set(pop[best_idx])
return (pop_next, key), fitness.min()
key = jax.random.key(42)
pop0 = init_population(key, pop_size=200, dim=10)
(pop_final, _), best_traj = jax.lax.scan(
ga_step, (pop0, jax.random.key(1)), None, length=300
)
print("最小到達値:", float(best_traj.min()))
ポイント
- 個体群評価は
jax.vmap(rastrigin)(pop)の 1 行 - 全世代を
jax.lax.scanで回し、@jax.jitで内部処理を一発コンパイル - エリート保存(最良個体を必ず残す)も
.at[].set()で簡潔に書ける
16.6 差分進化(DE)
DE は「3 個の他個体を使って差分ベクトルを作り、新しい候補にする」というシンプルかつ強力なアルゴリズム。連続最適化のベンチマークでよく強い。
import jax, jax.numpy as jnp
@jax.jit
def de_step(carry, _, F=0.5, CR=0.9):
pop, fitness, key = carry
N, D = pop.shape
# 各個体について、自分以外の 3 個体 a, b, c を選ぶ
key, k_idx, k_cr, k_j = jax.random.split(key, 4)
def pick3(k, self_idx):
# rejection ではなく「自分の位置をシフト」する簡易版
ids = jax.random.permutation(k, N - 1)[:3]
ids = jnp.where(ids >= self_idx, ids + 1, ids)
return ids
keys = jax.random.split(k_idx, N)
abc = jax.vmap(pick3)(keys, jnp.arange(N)) # (N, 3)
a, b, c = pop[abc[:, 0]], pop[abc[:, 1]], pop[abc[:, 2]]
mutant = a + F * (b - c) # 突然変異ベクトル
cross_mask = jax.random.bernoulli(k_cr, CR, pop.shape)
# 各個体で少なくとも 1 次元は必ず引き継ぐ
forced = jax.nn.one_hot(jax.random.randint(k_j, (N,), 0, D), D).astype(bool)
cross_mask = cross_mask | forced
trial = jnp.where(cross_mask, mutant, pop)
trial_fit = jax.vmap(rastrigin)(trial)
better = trial_fit < fitness
pop_new = jnp.where(better[:, None], trial, pop)
fit_new = jnp.where(better, trial_fit, fitness)
return (pop_new, fit_new, key), fit_new.min()
key = jax.random.key(0)
pop0 = jax.random.uniform(key, (60, 10), minval=-5.12, maxval=5.12)
fit0 = jax.vmap(rastrigin)(pop0)
(pop, fit, _), traj = jax.lax.scan(
de_step, (pop0, fit0, jax.random.key(1)), None, length=500
)
print("best:", float(fit.min()))
DE は 個体間の差 が探索ステップを兼ねるので、ステップサイズを自分で適応しなくても良い解を見つけやすいのが特長です。
16.7 粒子群最適化(PSO)
PSO は「各粒子が、自分のベストと群れ全体のベストに引き寄せられて飛ぶ」というイメージのアルゴリズム。状態が「位置・速度・個人ベスト」と少し増えますが、scan でスッキリ書けます。
import jax, jax.numpy as jnp
from functools import partial
@partial(jax.jit, static_argnames=("n_particles", "dim", "steps"))
def pso(key, n_particles=50, dim=10, steps=300, w=0.7, c1=1.5, c2=1.5):
k1, k2 = jax.random.split(key)
x = jax.random.uniform(k1, (n_particles, dim), minval=-5.12, maxval=5.12)
v = jax.random.normal(k2, (n_particles, dim)) * 0.1
f = jax.vmap(rastrigin)(x)
pbest_x = x
pbest_f = f
gbest_x = x[jnp.argmin(f)]
gbest_f = f.min()
def step(carry, _):
x, v, pbest_x, pbest_f, gbest_x, gbest_f, key = carry
key, k_r1, k_r2 = jax.random.split(key, 3)
r1 = jax.random.uniform(k_r1, x.shape)
r2 = jax.random.uniform(k_r2, x.shape)
v = w * v + c1 * r1 * (pbest_x - x) + c2 * r2 * (gbest_x - x)
x = x + v
f = jax.vmap(rastrigin)(x)
better = f < pbest_f
pbest_x = jnp.where(better[:, None], x, pbest_x)
pbest_f = jnp.where(better, f, pbest_f)
idx = jnp.argmin(pbest_f)
new_gbest_f = pbest_f[idx]
update_g = new_gbest_f < gbest_f
gbest_x = jnp.where(update_g, pbest_x[idx], gbest_x)
gbest_f = jnp.where(update_g, new_gbest_f, gbest_f)
return (x, v, pbest_x, pbest_f, gbest_x, gbest_f, key), gbest_f
init = (x, v, pbest_x, pbest_f, gbest_x, gbest_f, jax.random.key(99))
(*_, gbest_f, _), traj = jax.lax.scan(step, init, None, length=steps)
return gbest_f, traj
best_f, traj = pso(jax.random.key(0))
print("best:", float(best_f))
16.8 OpenAI-ES(自然進化戦略)
ニューロエボリューションの王道。平均 θ の周りに正規ノイズで子個体を作り、報酬で重み付けして平均を更新する という、勾配降下に似た形のアルゴリズム。OpenAI が 2017 年に強化学習の代替として有名にしました。
import jax, jax.numpy as jnp
from functools import partial
@partial(jax.jit, static_argnames=("pop",))
def openai_es_step(theta, key, sigma=0.1, lr=0.05, pop=50):
"""1 世代分の OpenAI-ES。fitness は「大きいほど良い」と仮定する版に直しておく。"""
dim = theta.shape[0]
key, sub = jax.random.split(key)
eps = jax.random.normal(sub, (pop, dim)) # 摂動
candidates = theta + sigma * eps
rewards = -jax.vmap(rastrigin)(candidates) # 最小化を最大化に反転
# ランクベース正規化(外れ値に強い)
ranks = jnp.argsort(jnp.argsort(rewards)).astype(jnp.float32)
adv = ranks / (pop - 1) - 0.5 # [-0.5, 0.5]
adv = (adv - adv.mean()) / (adv.std() + 1e-8)
grad_est = (eps.T @ adv) / (pop * sigma)
theta = theta + lr * grad_est
return theta, -jnp.max(rewards) # 表示用に最小化指標で返す
key = jax.random.key(0)
theta = jnp.zeros(10)
for g in range(200):
key, sub = jax.random.split(key)
theta, best = openai_es_step(theta, sub)
if g % 20 == 0:
print(f"gen {g:3d} best={float(best):.4f}")
ポイント
- 「摂動 → 評価 → 重み付き平均で θ を更新」だけ
vmapでpop個体の評価が並列に走る- 勾配が要らないので、ブラックボックス関数(シミュレータの報酬など)でもそのまま使える
- ニューラルネットの重み をそのまま θ にすれば、これだけで RL が回ります(→ 第17章)
16.9 CMA-ES(evosax を使う)
CMA-ES(Covariance Matrix Adaptation ES)は連続最適化の 最強クラス のアルゴリズム。自分で書くと長いので、ここは evosax に任せます。
import jax, jax.numpy as jnp
from evosax.algorithms import CMA_ES # evosax v0.2 系の場合
# 古いバージョンでは `from evosax import CMA_ES` のこともあります
def fitness(x):
return rastrigin(x)
dim = 10
strategy = CMA_ES(population_size=32, solution=jnp.zeros(dim))
params = strategy.default_params
state = strategy.init(jax.random.key(0), params)
@jax.jit
def step(carry, _):
key, state = carry
key, ka, ke = jax.random.split(key, 3)
x, state = strategy.ask(ka, state, params) # 個体群を生成
fit = jax.vmap(fitness)(x)
state = strategy.tell(ke, x, fit, state, params) # 結果を反映
return (key, state), fit.min()
(_, state), traj = jax.lax.scan(step, (jax.random.key(1), state), None, length=200)
print("best so far:", float(traj.min()))
⚠️ API の注意 evosax は活発に開発されており、バージョンによって import パスや
ask/tellの引数が少し変わることがあります。以下のコードは「ask/tell 型 API の考え方」を示すためのものです。必ず手元のevosax.__version__と公式 README / examples を合わせて確認してください。
「同じインターフェースで PSO, Sep_CMA_ES, OpenES, xNES, PGPE, SimpleGA, ... と差し替えられる」のが evosax の魅力です。
16.10 Quality-Diversity:MAP-Elites
MAP-Elites は「適応度だけでなく、振る舞い空間(behavior space)にもグリッドを切って、各セルでいちばん良かった個体をアーカイブに残す」という QD アルゴリズム。多様な解を一気に集められるので、ロボット制御や Procedural Generation で人気です。
考え方の骨格を JAX で書くと、
import jax, jax.numpy as jnp
GRID = 20 # 2D 振る舞い空間を 20×20 に区切る
@jax.jit
def map_elites_step(carry, _, sigma=0.3):
"""archive: (GRID, GRID, dim) ベスト個体、archive_f: (GRID, GRID) その適応度"""
archive, archive_f, key = carry
# アーカイブからランダムに親を選ぶ(埋まっているセルから)
filled = jnp.isfinite(archive_f)
flat = jnp.where(filled.reshape(-1), 1.0, 0.0)
key, ks, km, ke = jax.random.split(key, 4)
idx = jax.random.choice(ks, GRID * GRID, p=flat / (flat.sum() + 1e-8))
i, j = idx // GRID, idx % GRID
parent = archive[i, j]
child = parent + sigma * jax.random.normal(km, parent.shape)
# 評価:適応度 + 振る舞い記述子 b ∈ [0,1]^2
f = rastrigin(child)
b = jax.nn.sigmoid(child[:2]) # 2 次元の振る舞い記述子(例)
bi = jnp.clip((b[0] * GRID).astype(jnp.int32), 0, GRID - 1)
bj = jnp.clip((b[1] * GRID).astype(jnp.int32), 0, GRID - 1)
current = archive_f[bi, bj]
is_better = f < current # 最小化問題
archive = archive.at[bi, bj].set(jnp.where(is_better, child, archive[bi, bj]))
archive_f = archive_f.at[bi, bj].set(jnp.where(is_better, f, current))
return (archive, archive_f, key), f
dim = 10
# 初期化(最初の数個体をランダムに撒く)
key = jax.random.key(0)
init_pop = jax.random.uniform(key, (50, dim), minval=-5.12, maxval=5.12)
archive = jnp.zeros((GRID, GRID, dim))
archive_f = jnp.full((GRID, GRID), jnp.inf)
for x in init_pop:
f = rastrigin(x)
b = jax.nn.sigmoid(x[:2])
bi = int(jnp.clip(b[0] * GRID, 0, GRID - 1))
bj = int(jnp.clip(b[1] * GRID, 0, GRID - 1))
if f < archive_f[bi, bj]:
archive = archive.at[bi, bj].set(x)
archive_f = archive_f.at[bi, bj].set(f)
実用には QDax ライブラリが整っていて、vmap でバッチ化された MAP-Elites や Policy Gradient Assisted MAP-Elites(PGA-ME)などが使えます。
16.11 多目的最適化のヒント(NSGA-II)
複数の目的を同時に最適化する場合は NSGA-II が定番。JAX で書く場合の勘所だけ:
# 適応度を多次元化:fitness shape (pop, n_objectives)
fitnesses = jax.vmap(multi_objective_fn)(pop)
# Pareto 支配(dominance)の判定は祓い込みのループになりがちなので、
# 「個体 i が個体 j を支配する」を (pop, pop) のマスクとして計算する:
def dominates(a, b):
return jnp.all(a <= b) & jnp.any(a < b)
D = jax.vmap(lambda a: jax.vmap(lambda b: dominates(a, b))(fitnesses))(fitnesses)
# D[i, j] = True ⇔ i が j を支配
n_dominated_by = D.sum(axis=0) # 各個体を支配する個体数 → 0 のものが最前線
front0 = (n_dominated_by == 0)
完全な NSGA-II(多段ランキング + crowding distance)の実装は長くなるので、実用では evosax の NSGA-II(提供されている場合)や、PyMoo を JAX 互換に呼ぶ ほうが手早いです。
16.12 デザインの定石とつまずきポイント
進化計算を JAX で書くときの やってよかった / やらかしやすい ポイントをまとめます。
✅ やってよかった
jitの単位は「1 世代」 にする。世代数を Python の for で回しても、内側がjitされていれば十分速い。世代まで全部scanに入れれば理論上はもっと速いが、コンパイル時間とのトレードオフ。- 個体群評価は必ず
vmap。forで 1 個体ずつ評価すると、JAX の良さが消える。 - PRNG キーを世代ごとに
splitする。これを忘れると毎世代同じ突然変異になる。 - エリート保存(前世代のベスト個体を必ず残す)を入れると安定する。
pop.at[0].set(best)で十分。 fitnessは「大きいほど良い」「小さいほど良い」を最初に決める。混ぜると符号バグの温床になる。- ランクベース正規化 を使うと、外れ値の大きい報酬でも安定する(OpenAI-ES の例参照)。
⚠️ つまずきがちなポイント
- 配列の動的サイズ:「適応度が閾値を超えた個体だけ残す」のような 動的サイズ な処理は
jitと相性が悪い。jnp.whereでマスクして「形は固定、無効な個体は無視」する設計にする。 pop_sizeやdimを頻繁に変える:shape が変わるので再コンパイルが走る。pop_size・dim・stepsのような shape / loop 長に関わる引数はstatic_argnamesで static にするか、実験ごとに固定する。popを Python のリストで持ち回す:JAX が活かせない。(pop_size, dim)の配列で持つ。- 乱数キーの使い回し:毎世代
splitを忘れて同じ key を使うと、まったく進化しない世代が量産される。 while_loopでの早期終了:「適応度が ε 以下になったら止める」はjax.lax.while_loopで書けるが、vmapとの合成には注意が必要。とくに pop ごとに止め時が違う場合は、固定回数 + マスクで書くほうが安全。- ハードウェア利用率:CMA-ES のような「行列計算が重め」のアルゴリズムは GPU を活かせるが、
pop_sizeが小さすぎると GPU が遊ぶ。少なくとも 数百〜数千 個体は乗せたい。
まとめ
- JAX の
vmap×jit×scanの組み合わせは、進化計算と 本当に相性がよい - 1 個体用のシンプルな関数を書いて
vmapするだけで、数千〜数万個体の評価が一発で並列化される - 自分で書く場合の典型形は「
scanで世代を回す +vmapで個体群を評価」 - 本格的にやるなら evosax(アルゴリズム集)、EvoJAX(フレームワーク)、QDax(QD)、Brax / Gymnax(環境)の組み合わせ
次章では、これらを使って ニューラルネット自体を進化させる(ニューロエボリューション) と、近年 Sakana AI が示して話題になった 進化的モデルマージ を扱います。