JAX Autodiff Workflow
This example demonstrates using the JAX backend to obtain parametrized functions suitable for jax.grad, jax.vmap, and jax.jit. It’s aimed at users who want to perform parameter optimization or batched evaluations.
What you’ll learn
How to compile a model for the JAX backend.
How to use the returned functions in JAX workflows (grad, vmap, jit).
Source
1# flake8: noqa
2"""
3JAX backend workflow with autodiff and platform configuration.
4
5Goals:
6- Show how to guard imports when JAX may be absent.
7- Configure dtype/platform overrides via JaxBackendOptions.
8- Compile a small Hamiltonian and take gradients with jax.grad.
9- Illustrate parameter tracking (compiled.parameters set).
10
11This file is safe to import even if JAX is not installed; demos check at runtime.
12"""
13
14from __future__ import annotations
15
16import os
17import sys
18from pathlib import Path
19from typing import Any, Dict
20
# Put the directory one level above this file's folder at the front of
# sys.path (presumably so the `latex_parser` imports below resolve when the
# example is run directly from a checkout).
ROOT = Path(__file__).resolve().parents[1]
if sys.path.count(str(ROOT)) == 0:
    sys.path.insert(0, str(ROOT))

# Force CPU by default to avoid GPU/driver issues in example contexts.
# Values already present in the environment are left untouched.
if "JAX_PLATFORM_NAME" not in os.environ:
    os.environ["JAX_PLATFORM_NAME"] = "cpu"
if "JAX_PLATFORMS" not in os.environ:
    os.environ["JAX_PLATFORMS"] = "cpu"
28
29from latex_parser.backend_jax import JaxBackend, JaxBackendOptions
30from latex_parser.dsl import HilbertConfig, QubitSpec
31from latex_parser.ir import latex_to_ir
32
33
34def _static_matrix(compiled: Any) -> Any:
35 """
36 Return the static Hamiltonian component from a compiled JAX model.
37
38 CompiledOpenSystemJax exposes ``H`` which is either a single array
39 (static) or a time-dependent list whose first element is the static H0.
40 """
41 H = compiled.H
42 return H[0] if isinstance(H, list) else H
43
44
45def _require_jax():
46 try:
47 import jax
48 import jax.numpy as jnp
49 except Exception as exc: # pragma: no cover - optional dependency
50 print("JAX not installed; install with `pip install jax jaxlib`.")
51 raise RuntimeError("JAX missing") from exc
52 return jax, jnp
53
54
def compile_with_custom_options() -> Dict[str, Any]:
    """
    Compile a driven single-qubit model with explicit dtype/platform options.

    Returns:
        A dict with keys ``"compiled"`` (the compiled model), ``"opts"``
        (the ``JaxBackendOptions`` used), and ``"jax"`` / ``"jnp"`` (the
        imported modules).

    Raises:
        RuntimeError: propagated from ``_require_jax`` when JAX cannot be
            imported.
    """
    jax, jnp = _require_jax()

    # Request complex64 arrays on the CPU platform.
    opts = JaxBackendOptions(dtype=jnp.complex64, platform="cpu")
    hilbert = HilbertConfig(
        qubits=[QubitSpec(label="q", index=1)],
        bosons=[],
        customs=[],
    )
    drive_latex = r"A \cos(\omega t) \sigma_{x,1}"
    drive_params = {"A": 0.5, "omega": 1.0}

    compiled = JaxBackend().compile_open_system_from_latex(
        H_latex=drive_latex,
        params=drive_params,
        config=hilbert,
        c_ops_latex=None,
        t_name="t",
        options=opts,
    )

    print("Compiled type:", type(compiled.H))
    print("Parameters tracked:", compiled.parameters)
    return dict(compiled=compiled, opts=opts, jax=jax, jnp=jnp)
76
77
def gradient_of_expectation() -> None:
    """
    Differentiate a toy energy with respect to the amplitude A and print it.

    The energy is the smallest eigenvalue of the static Hamiltonian component
    after rescaling it by ``A / 0.5``; the gradient is evaluated at A = 0.5.

    Raises:
        RuntimeError: propagated from ``_require_jax`` when JAX cannot be
            imported.
    """
    jax, jnp = _require_jax()
    H = _static_matrix(compile_with_custom_options()["compiled"])

    def energy(A_val: float) -> Any:
        # 0.5 is the value of A the model was compiled with, so the ratio
        # rescales the compiled matrix to the requested amplitude.
        spectrum = jnp.linalg.eigvalsh((A_val / 0.5) * H)
        # Ground state energy (simplified for demo): min eigenvalue.
        return jnp.min(spectrum)

    dE_dA = jax.grad(energy)(0.5)
    print("dE/dA at A=0.5:", dE_dA)
95
96
def time_dependent_jax_ir() -> None:
    """
    Compile a cosine-driven qubit with default options and print how the
    result reports time dependence (``time_dependent`` flag and ``H``).

    Raises:
        RuntimeError: propagated from ``_require_jax`` when JAX cannot be
            imported.
    """
    # Only the availability check matters here; the modules are not used.
    _require_jax()

    hilbert = HilbertConfig(
        qubits=[QubitSpec(label="q", index=1)],
        bosons=[],
        customs=[],
    )
    drive_latex = r"A \cos(\omega t) \sigma_{x,1}"
    drive_params = {"A": 0.5, "omega": 1.0}

    compiled = JaxBackend().compile_open_system_from_latex(
        H_latex=drive_latex,
        params=drive_params,
        config=hilbert,
        c_ops_latex=None,
        t_name="t",
    )

    print("Time-dependent flag:", compiled.time_dependent)
    print("H list structure:", compiled.H)
115
116
def inspect_ir_for_jax() -> None:
    """
    Print the intermediate representation of a two-term Hamiltonian.

    Developer aid: shows the IR's ``has_time_dep`` flag and, for every term,
    its scalar expression and operators, before anything is handed to the
    JAX backend. This function does not call ``_require_jax``.
    """
    hilbert = HilbertConfig(
        qubits=[QubitSpec(label="q", index=1)],
        bosons=[],
        customs=[],
    )
    ir = latex_to_ir(
        r"A \cos(\omega t) \sigma_{x,1} + B \sigma_{z,1}",
        hilbert,
        t_name="t",
    )

    print("IR has_time_dep:", ir.has_time_dep)
    for position, ir_term in enumerate(ir.terms):
        print(f"Term {position}: scalar={ir_term.scalar_expr}, ops={ir_term.ops}")
127
128
def explore_platform_settings() -> None:
    """
    Show how to set JAX platform via options/env.

    For each of ``"cpu"`` and ``"gpu"`` this sets ``JAX_PLATFORM_NAME`` in
    the process environment, calls ``_apply_platform`` with the same name,
    and prints the request.

    ``JAX_PLATFORM_NAME`` is restored to its previous state (its old value,
    or absent) when the demo finishes, even if ``_apply_platform`` raises.
    Without this, the variable would stay at ``"gpu"`` for everything that
    runs afterwards, overriding the CPU default this module sets at import.

    If ``latex_parser.backend_jax`` cannot be imported, a message is printed
    and the demo is skipped.
    """
    try:
        from latex_parser.backend_jax import _apply_platform
    except Exception:  # pragma: no cover - optional
        print("JAX not installed; skipping platform demo.")
        return

    env_key = "JAX_PLATFORM_NAME"
    previous = os.environ.get(env_key)
    try:
        for platform in ("cpu", "gpu"):
            os.environ[env_key] = platform
            _apply_platform(platform)
            print("Requested platform:", platform)
    finally:
        # NOTE(review): only the environment variable is restored here; if
        # _apply_platform also changes in-process JAX configuration, that is
        # not reverted - confirm against its implementation.
        if previous is None:
            os.environ.pop(env_key, None)
        else:
            os.environ[env_key] = previous
143
144
def dtype_demonstration() -> None:
    """
    Compile the same Hamiltonian with complex64 and complex128 options.

    For each requested dtype the model is compiled on the CPU platform and
    the dtype of the resulting static Hamiltonian is printed next to the
    requested one. Returns silently when JAX cannot be imported.
    """
    try:
        jnp = _require_jax()[1]
    except RuntimeError:
        return

    hilbert = HilbertConfig(
        qubits=[QubitSpec(label="q", index=1)],
        bosons=[],
        customs=[],
    )
    detuning_latex = r"\delta \sigma_{x,1}"
    detuning_params = {"delta": 0.4}
    backend = JaxBackend()

    for requested in (jnp.complex64, jnp.complex128):
        compiled = backend.compile_open_system_from_latex(
            H_latex=detuning_latex,
            params=detuning_params,
            config=hilbert,
            c_ops_latex=None,
            t_name="t",
            options=JaxBackendOptions(dtype=requested, platform="cpu"),
        )
        print("dtype:", requested, "H dtype:", _static_matrix(compiled).dtype)
169
170
def simple_batch_eval() -> None:
    """
    Scale a compiled Hamiltonian by a batch of amplitudes using ``jax.vmap``.

    Compiles ``A sigma_x`` with A = 1.0, multiplies its static matrix by each
    of four evenly spaced values between 0.1 and 1.0 in one vectorised call,
    and prints the batch and the shape of the stacked result. Returns
    silently when JAX cannot be imported.
    """
    try:
        jax, jnp = _require_jax()
    except RuntimeError:
        return

    hilbert = HilbertConfig(
        qubits=[QubitSpec(label="q", index=1)],
        bosons=[],
        customs=[],
    )
    compiled = JaxBackend().compile_open_system_from_latex(
        H_latex=r"A \sigma_{x,1}",
        params={"A": 1.0},
        config=hilbert,
        c_ops_latex=None,
        t_name="t",
    )
    H0 = _static_matrix(compiled)

    # vmap maps the scalar-times-matrix product over the leading batch axis.
    scale_static = jax.vmap(lambda amplitude: amplitude * H0)

    batch = jnp.linspace(0.1, 1.0, 4)
    result = scale_static(batch)
    print("Batch scales:", batch)
    print("Batch H shape:", result.shape)
199
200
def guard_missing_jax_install() -> None:
    """
    Report on stdout whether JAX can be imported.

    Never raises for a missing install: the ``RuntimeError`` from
    ``_require_jax`` is caught and turned into a message.
    """
    try:
        _require_jax()
    except RuntimeError:
        print("JAX missing; install to run autodiff demos.")
    else:
        print("JAX present.")
210
211
def main() -> None:
    """
    Run every demo in order.

    The JAX-dependent demos run first; the first ``RuntimeError`` any of them
    raises stops that group and prints a skip message. The IR inspection and
    the installation check always run afterwards.
    """
    jax_demos = (
        compile_with_custom_options,
        gradient_of_expectation,
        time_dependent_jax_ir,
        explore_platform_settings,
        dtype_demonstration,
        simple_batch_eval,
    )
    try:
        for demo in jax_demos:
            demo()
    except RuntimeError:
        print("Skipping JAX-specific demos.")

    inspect_ir_for_jax()
    guard_missing_jax_install()


# Run the demos only when executed as a script, not on import.
if __name__ == "__main__":
    main()
Run
python examples/example_jax_autodiff_workflow.py
Notes
JAX must be installed and available in your environment for this example to run. If JAX is not present, the example will either skip the JAX-specific demos or raise.