Skip to main content

第2章 インストールと環境構築

2.1 Python のバージョン

JAX の Python サポートはリリースごとに更新されます。公式のサポートポリシーでは、各 JAX リリース時点から過去 45 か月以内の Python 機能リリースが少なくともサポート対象になります。2026-06-02 時点では Python 3.11 もサポート対象ですが、3.11 は少なくとも 2026 年 7 月までという扱いなので、新しく環境を作るなら Python 3.12 以上 をおすすめします。python --version で確認しましょう。

python --version
# 例: Python 3.12.3

2.2 仮想環境を作る(推奨)

Python のライブラリは「プロジェクトごとに環境を分ける」のが鉄則です。バージョン違いの衝突を防げます。

venv を使う場合(標準・お手軽)

# プロジェクトのフォルダで
python -m venv .venv

# 仮想環境を有効化(mac / Linux)
source .venv/bin/activate

# Windows の場合
# .venv\Scripts\activate

uv を使う場合(高速・最近の人気者)

# uv 自体のインストール(mac)
brew install uv

# プロジェクトを作る
uv init my_jax_project
cd my_jax_project
uv add jax

2.3 JAX のインストール

JAX は実行環境(CPU・NVIDIA GPU・TPU)によってインストールコマンドが変わります。

CPU のみで動かす(学習・お試しに最適)

pip install -U jax

これだけで OK です。

NVIDIA GPU(CUDA)で動かす

2026-06-02 時点の公式ドキュメントでは、pip で CUDA ライブラリを同梱する CUDA 13 用ホイール が推奨されています(Linux 用)。

pip install --upgrade pip
pip install -U "jax[cuda13]"

CUDA 12 が必要な環境では次も使えます。

pip install -U "jax[cuda12]"

注意点:

  • NVIDIA ドライバは事前に入っている必要があります。
  • Linux では CUDA 13 の場合ドライバ 580 以上、CUDA 12 の場合ドライバ 525 以上が目安です。
  • CUDA ホイールは主に Linux 向けです。Windows のネイティブ GPU 実行や macOS GPU 実行は状況が異なるので、必ず 公式インストールガイド を確認してください。

Google Cloud TPU で動かす

pip install -U "jax[tpu]"

💡 Google Colab では多くの場合 JAX が使える状態になっています。TPU を使う場合は Colab 側のランタイム種別や公式インストールガイドも確認してください。

2.4 動作確認

Python を起動して、次のコードを実行してみましょう。

import jax
import jax.numpy as jnp

# バージョン確認
print("JAX version:", jax.__version__)

# 利用可能なデバイス確認(CPU / GPU / TPU)
print("Devices:", jax.devices())

# 簡単な計算
x = jnp.arange(5)
print("x =", x)
print("x ** 2 =", x ** 2)

出力例(CPU の場合):

JAX version: 0.x.y
Devices: [CpuDevice(id=0)]
x = [0 1 2 3 4]
x ** 2 = [ 0 1 4 9 16]

Devices のところに CudaDeviceTpuDevice が出てくれば、GPU/TPU が認識できています 🎉

2.5 よく使うエディタ・実行環境

  • Jupyter Notebook / JupyterLab: 試行錯誤に便利。pip install jupyterlab で導入。
  • VS Code: Python 拡張機能を入れると快適。
  • Google Colab: ブラウザだけで完結、しかも無料 GPU が使える。

初心者の方は、まず Colab で動かすのが一番おすすめです。

2.6 次の章へ

環境が整ったら、いよいよ JAX のコードを書いていきましょう。次は NumPy そっくりな書き心地の jax.numpy を学びます。

➡️ 第3章 jax.numpy の基本