第7章 PyTree という考え方
JAX を使いこなすうえで避けて通れないのが PyTree(パイツリー) という概念です。最初は耳慣れない言葉かもしれませんが、実体はとてもシンプルです。
7.1 PyTree とは
PyTree とは「配列を葉っぱとして持つ、入れ子になったコンテナ」のことです。コンテナとは、Python の辞書(dict)・タプル(tuple)・リスト(list)、あるいはユーザー定義クラスなど、「中にいろいろ入れられるもの」を指します。
たとえば、次のような構造を考えてみましょう。
params = {
"layer1": {"W": W1, "b": b1},
"layer2": {"W": W2, "b": b2},
}
これは「2 階層の辞書の中に jnp.array が並んでいる」構造で、立派な PyTree です。木のように枝分かれしていて、いちばん端っこ(葉っぱ)に配列がぶら下がっている イメージです。
JAX の主要な関数(grad、jit、vmap、value_and_grad、…)は、引数として PyTree をそのまま受け取れる ように設計されています。これが本当に便利なのです。
7.2 ニューラルネットのパラメータは PyTree
この性質のおかげで、モデルのパラメータをまとめて 1 つの PyTree にしておけば、
loss_value, grads = jax.value_and_grad(loss_fn)(params, x, y)
と書くだけで「params と同じ構造の 勾配 PyTree」が手に入ります。params が辞書なら grads も辞書、リストならリスト、というように 構造ごと対応する のです。
# 勾配降下も簡単
new_params = jax.tree_util.tree_map(
lambda p, g: p - 0.01 * g,
params, grads
)
tree_map(f, x, y) は「x と y という同じ構造の PyTree を、葉っぱ単位で f に渡して、新しい PyTree を作る」関数です。「同じ形の入れ物ふたつを葉っぱ同士組み合わせて、新しい入れ物を作る」と覚えるとイメージしやすいでしょう。
7.3 PyTree でよく使う関数
import jax
# 葉っぱを全部リストとして取り出す
leaves = jax.tree_util.tree_leaves(params)
# 構造を取り出す(あとから組み立て直すための「型」)
leaves, treedef = jax.tree_util.tree_flatten(params)
params2 = jax.tree_util.tree_unflatten(treedef, leaves)
# 葉っぱごとに関数を適用
doubled = jax.tree_util.tree_map(lambda x: x * 2, params)
「平らにする」「組み立て直す」「葉っぱごとに何かする」というシンプルな道具立てで、複雑な構造のパラメータを楽に扱えます。
7.4 ユーザー定義クラスを PyTree にする
@jax.tree_util.register_pytree_node_class などを使うと、自作クラスも PyTree として扱えるようになります。実務でよく使うのは dataclasses.dataclass + flax.struct.dataclass または equinox.Module の組み合わせです(第12章で紹介します)。
7.5 PyTree が便利な理由
- モデルの構造をそのまま自然に表現 できる(辞書のキーで層に名前を付けたり)
- 勾配・最適化・並列化を 構造ごと一気に 処理できる
- 配列が「どこにあるか」をデータ構造で表現できる(=バグが追いやすい)
PyTree という考え方は、慣れると「これなしで深層学習をどう書いてたんだっけ?」と思えるくらい便利です。「複雑なものは木の構造で表現して、葉っぱを一斉に処理する」という発想に親しんでいきましょう。