JAX Autodiff Workflow

JAX Autodiff Workflow

This example shows how to compile a LaTeX Hamiltonian with the JAX backend and use the compiled result with jax.grad and jax.vmap (and, by extension, jax.jit). It is aimed at users who want to perform parameter optimization or batched evaluations.

What you’ll learn

  • How to compile a model for the JAX backend.

  • How to use the compiled objects in JAX workflows (jax.grad and jax.vmap are demonstrated; the same functions can be wrapped in jax.jit).

Source

  1# flake8: noqa
  2"""
  3JAX backend workflow with autodiff and platform configuration.
  4
  5Goals:
  6- Show how to guard imports when JAX may be absent.
  7- Configure dtype/platform overrides via JaxBackendOptions.
  8- Compile a small Hamiltonian and take gradients with jax.grad.
  9- Illustrate parameter tracking (compiled.parameters set).
 10
 11This file is safe to import even if JAX is not installed; demos check at runtime.
 12"""
 13
 14from __future__ import annotations
 15
 16import os
 17import sys
 18from pathlib import Path
 19from typing import Any, Dict
 20
 21ROOT = Path(__file__).resolve().parents[1]
 22if str(ROOT) not in sys.path:
 23    sys.path.insert(0, str(ROOT))
 24
 25# Force CPU by default to avoid GPU/driver issues in example contexts.
 26os.environ.setdefault("JAX_PLATFORM_NAME", "cpu")
 27os.environ.setdefault("JAX_PLATFORMS", "cpu")
 28
 29from latex_parser.backend_jax import JaxBackend, JaxBackendOptions
 30from latex_parser.dsl import HilbertConfig, QubitSpec
 31from latex_parser.ir import latex_to_ir
 32
 33
 34def _static_matrix(compiled: Any) -> Any:
 35    """
 36    Return the static Hamiltonian component from a compiled JAX model.
 37
 38    CompiledOpenSystemJax exposes ``H`` which is either a single array
 39    (static) or a time-dependent list whose first element is the static H0.
 40    """
 41    H = compiled.H
 42    return H[0] if isinstance(H, list) else H
 43
 44
 45def _require_jax():
 46    try:
 47        import jax
 48        import jax.numpy as jnp
 49    except Exception as exc:  # pragma: no cover - optional dependency
 50        print("JAX not installed; install with `pip install jax jaxlib`.")
 51        raise RuntimeError("JAX missing") from exc
 52    return jax, jnp
 53
 54
 55def compile_with_custom_options() -> Dict[str, Any]:
 56    """
 57    Compile a driven qubit with explicit dtype/platform overrides.
 58    """
 59    jax, jnp = _require_jax()
 60    opts = JaxBackendOptions(dtype=jnp.complex64, platform="cpu")
 61    cfg = HilbertConfig(qubits=[QubitSpec(label="q", index=1)], bosons=[], customs=[])
 62    H_latex = r"A \cos(\omega t) \sigma_{x,1}"
 63    params = {"A": 0.5, "omega": 1.0}
 64    backend = JaxBackend()
 65    compiled = backend.compile_open_system_from_latex(
 66        H_latex=H_latex,
 67        params=params,
 68        config=cfg,
 69        c_ops_latex=None,
 70        t_name="t",
 71        options=opts,
 72    )
 73    print("Compiled type:", type(compiled.H))
 74    print("Parameters tracked:", compiled.parameters)
 75    return {"compiled": compiled, "opts": opts, "jax": jax, "jnp": jnp}
 76
 77
 78def gradient_of_expectation() -> None:
 79    """
 80    Compute a simple gradient of an expectation value with respect to A.
 81    """
 82    jax, jnp = _require_jax()
 83    context = compile_with_custom_options()
 84    compiled = context["compiled"]
 85    H = _static_matrix(compiled)
 86
 87    def energy(A_val: float) -> Any:
 88        H_scaled = (A_val / 0.5) * H
 89        # Ground state energy (simplified for demo): min eigenvalue.
 90        eigvals = jnp.linalg.eigvalsh(H_scaled)
 91        return jnp.min(eigvals)
 92
 93    grad_fn = jax.grad(energy)
 94    print("dE/dA at A=0.5:", grad_fn(0.5))
 95
 96
 97def time_dependent_jax_ir() -> None:
 98    """
 99    Show how IR time dependence is reflected in JAX compilation.
100    """
101    _require_jax()
102    cfg = HilbertConfig(qubits=[QubitSpec(label="q", index=1)], bosons=[], customs=[])
103    H_latex = r"A \cos(\omega t) \sigma_{x,1}"
104    params = {"A": 0.5, "omega": 1.0}
105    backend = JaxBackend()
106    compiled = backend.compile_open_system_from_latex(
107        H_latex=H_latex,
108        params=params,
109        config=cfg,
110        c_ops_latex=None,
111        t_name="t",
112    )
113    print("Time-dependent flag:", compiled.time_dependent)
114    print("H list structure:", compiled.H)
115
116
def inspect_ir_for_jax() -> None:
    """
    Print the intermediate representation of a two-term Hamiltonian.

    Developer aid: shows the IR's ``has_time_dep`` flag, followed by the
    scalar expression and operators of every term. No JAX import is needed.
    """
    hilbert = HilbertConfig(
        qubits=[QubitSpec(label="q", index=1)],
        bosons=[],
        customs=[],
    )
    ir = latex_to_ir(
        r"A \cos(\omega t) \sigma_{x,1} + B \sigma_{z,1}",
        hilbert,
        t_name="t",
    )
    print("IR has_time_dep:", ir.has_time_dep)
    for position, ir_term in enumerate(ir.terms):
        print(f"Term {position}: scalar={ir_term.scalar_expr}, ops={ir_term.ops}")
127
128
def explore_platform_settings() -> None:
    """
    Show how to set the JAX platform via options/env.

    For each of "cpu" and "gpu" the name is exported as
    ``JAX_PLATFORM_NAME`` and passed to the backend's ``_apply_platform``
    helper. The previous value of the environment variable is restored
    afterwards (and re-applied), so code that runs later in the same process
    is not left with the trailing "gpu" request.

    If the backend helper cannot be imported, a notice is printed and the
    demo is skipped.
    """
    try:
        from latex_parser.backend_jax import _apply_platform
    except Exception:  # pragma: no cover - optional
        print("JAX not installed; skipping platform demo.")
        return

    env_key = "JAX_PLATFORM_NAME"
    previous = os.environ.get(env_key)
    try:
        for platform in ("cpu", "gpu"):
            os.environ[env_key] = platform
            _apply_platform(platform)
            print("Requested platform:", platform)
    finally:
        # Undo the demo's global side effects even if a platform request fails.
        if previous is None:
            os.environ.pop(env_key, None)
        else:
            os.environ[env_key] = previous
            # NOTE(review): presumably _apply_platform updates JAX's active
            # platform setting; re-applying the prior value undoes the "gpu"
            # request made by the loop - confirm against the backend.
            _apply_platform(previous)
143
144
def dtype_demonstration() -> None:
    """
    Compile one Hamiltonian with complex64 and with complex128 options.

    For each requested dtype the model is compiled on the "cpu" platform and
    the dtype of its static matrix is printed next to the requested one.
    Returns early when JAX is unavailable.
    """
    try:
        jnp = _require_jax()[1]
    except RuntimeError:
        return
    hilbert = HilbertConfig(
        qubits=[QubitSpec(label="q", index=1)],
        bosons=[],
        customs=[],
    )
    backend = JaxBackend()
    for requested in (jnp.complex64, jnp.complex128):
        backend_options = JaxBackendOptions(dtype=requested, platform="cpu")
        compiled = backend.compile_open_system_from_latex(
            H_latex=r"\delta \sigma_{x,1}",
            params={"delta": 0.4},
            config=hilbert,
            c_ops_latex=None,
            t_name="t",
            options=backend_options,
        )
        print("dtype:", requested, "H dtype:", _static_matrix(compiled).dtype)
169
170
def simple_batch_eval() -> None:
    """
    Scale a compiled Hamiltonian by a batch of amplitudes with ``jax.vmap``.

    Compiles ``A \\sigma_{x,1}`` with A = 1.0, multiplies its static matrix by
    four evenly spaced values between 0.1 and 1.0, and prints the batch and
    the shape of the stacked result. Returns early when JAX is unavailable.
    """
    try:
        jax, jnp = _require_jax()
    except RuntimeError:
        return
    hilbert = HilbertConfig(
        qubits=[QubitSpec(label="q", index=1)],
        bosons=[],
        customs=[],
    )
    compiled = JaxBackend().compile_open_system_from_latex(
        H_latex=r"A \sigma_{x,1}",
        params={"A": 1.0},
        config=hilbert,
        c_ops_latex=None,
        t_name="t",
    )
    H0 = _static_matrix(compiled)

    # Map the scalar-times-matrix product over the leading axis of the input.
    scale_hamiltonian = jax.vmap(lambda amplitude: amplitude * H0)

    batch = jnp.linspace(0.1, 1.0, 4)
    result = scale_hamiltonian(batch)
    print("Batch scales:", batch)
    print("Batch H shape:", result.shape)
199
200
def guard_missing_jax_install() -> None:
    """
    Report on stdout whether JAX can be imported.

    Acts as the single place to check the installation: the RuntimeError
    raised by ``_require_jax`` is translated into a short status line.
    """
    try:
        _require_jax()
    except RuntimeError:
        print("JAX missing; install to run autodiff demos.")
    else:
        print("JAX present.")
210
211
def main() -> None:
    """
    Run every demo in order.

    The JAX-dependent demos run first; the first RuntimeError (raised when
    JAX is missing) aborts the remaining ones with a notice. The IR
    inspection and the installation check always run afterwards.
    """
    jax_demos = (
        compile_with_custom_options,
        gradient_of_expectation,
        time_dependent_jax_ir,
        explore_platform_settings,
        dtype_demonstration,
        simple_batch_eval,
    )
    try:
        for demo in jax_demos:
            demo()
    except RuntimeError:
        print("Skipping JAX-specific demos.")
    inspect_ir_for_jax()
    guard_missing_jax_install()
224
225
226if __name__ == "__main__":
227    main()

Run

python examples/example_jax_autodiff_workflow.py

Notes

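  • JAX must be installed and available in your environment for the autodiff demos to run. When the file is run as a script and JAX is missing, main() prints "Skipping JAX-specific demos." and still runs the IR inspection and the installation check; calling a JAX-dependent function such as compile_with_custom_options() directly raises RuntimeError instead.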