Examples
The examples/ directory contains a collection of astrophysics simulations demonstrating various applications of the Adirondax library.
Gallery
kelvin_helmholtz
See on GitHub: examples/kelvin_helmholtz
README:
# Kelvin-Helmholtz
Simulate the Kelvin-Helmholtz Instability (Euler equations)
Philip Mocz (2025)
Usage:
```console
python kelvin_helmholtz.py
```
Takes around 6 seconds to run on my MacBook (CPU).
## Simulation snapshots
<div style="display:flex;flex-wrap:wrap;gap:8px">
<img src="output.png" alt="output" width="45%"/>
</div>
Script:
import jax.numpy as jnp
# TODO: REMOVE THE FOLLOWING LINES
import sys
sys.path.append("../../")
import adirondax as adx
import time
import matplotlib.pyplot as plt
"""
Simulate the Kelvin-Helmholtz instability
Philip Mocz (2025)
"""
def set_up_simulation():
    """Build the Kelvin-Helmholtz simulation and set its initial conditions.

    Configuration: a 256 x 256 Cartesian mesh on the unit square, hydro only,
    ideal-gas EOS with gamma = 5/3 and no slope limiting. The run goes to
    t = 2.0 in 400 * (256 / 32) = 3200 steps and writes 10 checkpoints. No
    boundary condition is given, so the adirondax default applies
    (presumably periodic -- confirm).

    Initial conditions: a dense band (rho = 2, vx = +0.5) for
    |y - 0.5| < 0.25, surrounded by rho = 1 gas moving at vx = -0.5. A
    sin(4 pi x) vy perturbation is concentrated on the two shear interfaces
    at y = 0.25 and y = 0.75. Pressure is uniform (P = 2.5).

    Returns:
        The configured ``adx.Simulation``.
    """
    resolution = 256
    num_steps = 400 * int(resolution / 32)
    stop_time = 2.0

    physics = {"hydro": True}
    mesh = {
        "type": "cartesian",
        "resolution": [resolution, resolution],
        "box_size": [1.0, 1.0],
    }
    timing = {"span": stop_time, "num_timesteps": num_steps}
    output = {
        "num_checkpoints": 10,
        "save": True,
        "plot_dynamic_range": 2.0,
    }
    hydro = {
        "eos": {"type": "ideal", "gamma": 5.0 / 3.0},
        "slope_limiting": False,
    }
    sim = adx.Simulation(
        {
            "physics": physics,
            "mesh": mesh,
            "time": timing,
            "output": output,
            "hydro": hydro,
        }
    )

    # Counter-streaming layers with a perturbation localised on the interfaces
    sim.state["t"] = 0.0
    X, Y = sim.mesh
    in_band = jnp.abs(Y - 0.5) < 0.25
    amplitude = 0.1
    width = 0.05 / jnp.sqrt(2.0)
    lower_bump = jnp.exp(-((Y - 0.25) ** 2) / (2.0 * width**2))
    upper_bump = jnp.exp(-((Y - 0.75) ** 2) / (2.0 * width**2))

    sim.state["rho"] = 1.0 + in_band
    sim.state["vx"] = -0.5 + in_band
    sim.state["vy"] = amplitude * jnp.sin(4.0 * jnp.pi * X) * (lower_bump + upper_bump)
    sim.state["P"] = 2.5 * jnp.ones(X.shape)
    return sim
def make_plot(sim):
    """Plot the final density, save it as ``output.png`` (240 dpi) and show it.

    The colour scale is fixed to [0.8, 2.2]. ``rot90`` rotates the array
    counter-clockwise before display, so its second index runs up the image
    (presumably the y direction, if rho is indexed [x, y] -- confirm).
    """
    rotated_rho = jnp.rot90(sim.state["rho"])
    plt.figure(figsize=(6, 4), dpi=80)
    image = plt.imshow(rotated_rho, cmap="jet", vmin=0.8, vmax=2.2)
    plt.colorbar(image, label="density")
    plt.tight_layout()
    plt.savefig("output.png", dpi=240)
    plt.show()
def main():
    """Set up, evolve and plot the Kelvin-Helmholtz simulation.

    Prints the wall-clock time spent evolving the system.
    """
    import jax

    sim = set_up_simulation()

    # Evolve the system. JAX dispatches work asynchronously, so wait until the
    # final state is actually computed before reading the clock.
    # perf_counter is the clock intended for measuring elapsed intervals.
    t0 = time.perf_counter()
    sim.run()
    jax.block_until_ready(sim.state)
    print("Run time (s): ", time.perf_counter() - t0)

    make_plot(sim)


if __name__ == "__main__":
    main()
logo_inverse_problem
See on GitHub: examples/logo_inverse_problem
README:
# Logo Inverse Problem
Inverse problem with the Schrödinger-Poisson system.
Find the initial phases that evolve the system into the JAX logo.
Philip Mocz (2025)
Usage:
```console
python logo_inverse_problem.py
```
Takes around 2 seconds to run on my MacBook (CPU).
## Simulation snapshots
<div style="display:flex;flex-wrap:wrap;gap:8px">
<img src="output.png" alt="output" width="45%"/>
</div>
Script:
import jax
import jax.numpy as jnp
# TODO: REMOVE THE FOLLOWING LINES
import sys
sys.path.append("../../")
import adirondax as adx
from jaxopt import ScipyMinimize
import time
import matplotlib.pyplot as plt
import matplotlib.image as img
"""
Solve an Inverse-Problem that finds the initial wave function phases that
evolve under the Schrodinger-Poisson equations into a target density field
Philip Mocz (2024)
"""
def set_up_simulation(save=False):
    """Create the Schrodinger-Poisson simulation used by the inverse problem.

    Configuration: a 128 x 128 Cartesian mesh on the unit square, with
    quantum and gravity physics enabled. The run goes to t = 0.03 in
    100 * (128 / 128) = 100 steps. No initial state is set here; callers
    fill in ``sim.state``.

    Args:
        save: copied into the ``"output"`` section as ``"save"``
            (presumably toggles writing checkpoints -- confirm).

    Returns:
        The configured ``adx.Simulation``.
    """
    n = 128
    settings = dict(
        physics=dict(quantum=True, gravity=True),
        mesh=dict(type="cartesian", resolution=[n, n], box_size=[1.0, 1.0]),
        time=dict(span=0.03, num_timesteps=100 * int(n / 128)),
        output=dict(save=save, plot_dynamic_range=2.0),
    )
    return adx.Simulation(settings)
def solve_inverse_problem(sim):
    """Find the initial phases whose evolution best reproduces ``target.png``.

    The image is turned into a unit-mean target density. L-BFGS-B (through
    ``jaxopt.ScipyMinimize``) then searches for the phase field ``theta``
    that minimises the mean-squared error between ``|psi|^2`` at the end of
    the run and the target, where the run starts from
    ``psi = exp(1j * theta)``.

    Args:
        sim: the ``adx.Simulation`` to evolve. Its ``state`` is overwritten
            on every loss evaluation.

    Returns:
        The optimised initial phases, wrapped into [-pi, pi).

    Raises:
        ValueError: if the target image shape does not match
            ``sim.resolution``.
    """
    # Load the target density field. For an RGB(A) image, use the first
    # colour channel; a single-channel (greyscale) image is used as-is.
    target_data = img.imread("target.png")
    if target_data.ndim == 3:
        target_data = target_data[:, :, 0]
    rho_target = jnp.flipud(jnp.array(target_data, dtype=float)).T
    rho_target = 1.0 - 0.5 * (rho_target - 0.5)
    rho_target /= jnp.mean(rho_target)

    # Explicit check rather than `assert`, which is stripped under `python -O`
    expected_shape = (sim.resolution[0], sim.resolution[1])
    if rho_target.shape != expected_shape:
        raise ValueError(
            f"target.png gives a density of shape {rho_target.shape}, "
            f"expected {expected_shape} to match the simulation resolution"
        )

    # Loss: evolve from psi = exp(i theta) and compare the final density
    @jax.jit
    def loss_function(theta, rho_target):
        sim.state["t"] = 0.0
        sim.state["psi"] = jnp.exp(1.0j * theta)
        sim.run()
        psi = sim.state["psi"]
        rho = jnp.abs(psi) ** 2
        return jnp.mean((rho - rho_target) ** 2)

    # Solve the inverse problem, starting from uniform (zero) phases
    opt = ScipyMinimize(
        method="l-bfgs-b", fun=loss_function, tol=1e-5, options={"disp": True}
    )
    theta = jnp.zeros_like(rho_target)
    t0 = time.time()
    sol = opt.run(theta, rho_target)
    print("Inverse-problem solve time (s): ", time.time() - t0)

    # Wrap each phase into [-pi, pi) without changing it modulo 2 pi.
    # Adding pi before the modulo and subtracting it afterwards avoids the
    # uniform pi offset that `mod(theta, 2 pi) - pi` would introduce.
    theta = jnp.mod(sol.params + jnp.pi, 2.0 * jnp.pi) - jnp.pi
    print("Mean theta:", jnp.mean(theta))
    return theta
def rerun_simulation(sim, theta):
    """Evolve ``psi = exp(1j * theta)`` from t = 0 and return the final ``psi``.

    Prints the simulation time reached when the run finishes.
    """
    initial_psi = jnp.exp(1.0j * theta)
    sim.state["t"] = 0.0
    sim.state["psi"] = initial_psi
    sim.run()
    print("Final time:", sim.state["t"])
    return sim.state["psi"]
def make_plot(psi, theta):
    """Plot the initial phases next to the final log-density and save it.

    Left panel: ``theta`` on a ``bwr`` scale clipped to [-pi, pi].
    Right panel: ``log10(|psi|^2)`` on an ``inferno`` scale clipped to
    [-0.2, 0.2]. Each array is transposed and its axes hidden. With
    imshow's default top-left origin, inverting the y axis makes the
    array's second index increase upward. The figure is saved to
    ``output.png`` at 240 dpi and then shown.
    """
    plt.figure(figsize=(6, 4), dpi=80)
    grid = plt.GridSpec(1, 2, wspace=0.0, hspace=0.0)
    panel_axes = [plt.subplot(grid[0, col]) for col in range(2)]
    panels = [
        (theta.T, "bwr", (-jnp.pi, jnp.pi), r"${\rm initial\,angle}(\psi)$"),
        (
            jnp.log10(jnp.abs(psi) ** 2).T,
            "inferno",
            (-0.2, 0.2),
            r"${\rm final\,}\log_{10}(|\psi|^2)$",
        ),
    ]

    for ax, (data, cmap, limits, title) in zip(panel_axes, panels):
        plt.sca(ax)
        plt.cla()
        plt.imshow(data, cmap=cmap)
        plt.clim(*limits)
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
        ax.invert_yaxis()
        ax.set_aspect("equal")
        plt.title(title)

    plt.tight_layout()
    plt.savefig("output.png", dpi=240)
    plt.show()
def main():
    """Solve the inverse problem, then re-run the optimum and plot it.

    The optimisation uses a simulation built with ``save=False``. The final
    re-run uses a fresh one built with ``save=True``.
    """
    theta = solve_inverse_problem(set_up_simulation(save=False))
    final_psi = rerun_simulation(set_up_simulation(save=True), theta)
    make_plot(final_psi, theta)


if __name__ == "__main__":
    main()
orszag_tang
See on GitHub: examples/orszag_tang
README:
# Orszag-Tang
Simulate the Orszag-Tang vortex (2D MHD equations)
Philip Mocz (2025)
Usage:
```console
python orszag_tang.py
```
Takes around 55 seconds with the `'llf'` Riemann solver, or 105 seconds with `'hlld'` (the solver selected in `orszag_tang.py`), to run on my MacBook (CPU).
## Simulation snapshots
<div style="display:flex;flex-wrap:wrap;gap:8px">
<img src="output.png" alt="output" width="45%"/>
</div>
Script:
import jax.numpy as jnp
# TODO: REMOVE THE FOLLOWING LINES
import sys
sys.path.append("../../")
import adirondax as adx
from adirondax.hydro.common2d import get_curl, get_avg
import time
import matplotlib.pyplot as plt
# switch on for double precision (this also requires `import jax`,
# which this script does not import)
# jax.config.update("jax_enable_x64", True)
"""
Simulate the Orszag-Tang vortex
Philip Mocz (2025)
"""
def set_up_simulation():
    """Create the Orszag-Tang vortex simulation with its initial conditions.

    Configuration: a 512 x 512 Cartesian mesh on the unit square with hydro
    and magnetic physics. It uses an ideal-gas EOS (gamma = 5/3), CFL 0.6,
    the HLLD Riemann solver and slope limiting, and runs to t = 0.5.
    ``num_timesteps`` is -1, presumably letting the code choose the step
    count (the commented value 2400 is a fixed-step alternative) -- confirm.

    Initial conditions: uniform rho = gamma^2 / (4 pi) and gas pressure
    gamma / (4 pi), with vx = -sin(2 pi y) and vy = sin(2 pi x). The
    face fields ``bx``/``by`` come from ``get_curl`` of a vector potential
    ``Az`` sampled on the top-right node of each cell. The stored ``P`` is
    the total pressure: gas pressure plus 0.5 * (Bx^2 + By^2), using the
    ``get_avg``-averaged field.

    Returns:
        The configured ``adx.Simulation``.
    """
    gamma = 5.0 / 3.0
    n = 512
    box_size = 1.0
    dx = box_size / n

    physics = {"hydro": True, "magnetic": True}
    mesh = {
        "type": "cartesian",
        "resolution": [n, n],
        "box_size": [box_size, box_size],
    }
    timing = {"span": 0.5, "num_timesteps": -1}  # fixed-step alternative: 2400
    output = {"num_checkpoints": 100, "save": False, "plot_dynamic_range": 2.3}
    hydro = {
        "eos": {"type": "ideal", "gamma": gamma},
        "cfl": 0.6,
        "riemann_solver": "hlld",
        "slope_limiting": True,
    }
    sim = adx.Simulation(
        {
            "physics": physics,
            "mesh": mesh,
            "time": timing,
            "output": output,
            "hydro": hydro,
        }
    )

    sim.state["t"] = jnp.array(0.0)
    X, Y = sim.mesh
    two_pi = 2.0 * jnp.pi
    four_pi = 4.0 * jnp.pi
    sim.state["rho"] = (gamma**2 / four_pi) * jnp.ones(X.shape)
    sim.state["vx"] = -jnp.sin(two_pi * Y)
    sim.state["vy"] = jnp.sin(two_pi * X)
    gas_pressure = (gamma / four_pi) * jnp.ones(X.shape)

    # Az sits on the top-right node of each cell: x, y = dx, 2 dx, ..., box_size
    nodes = jnp.linspace(dx, box_size, n)
    Xn, Yn = jnp.meshgrid(nodes, nodes, indexing="ij")
    sqrt_four_pi = jnp.sqrt(four_pi)
    Az_from_x = jnp.cos(four_pi * Xn) / (four_pi * sqrt_four_pi)
    Az_from_y = jnp.cos(two_pi * Yn) / (two_pi * sqrt_four_pi)
    Az = Az_from_x + Az_from_y

    bx, by = get_curl(Az, dx, dx)
    Bx, By = get_avg(bx, by)
    sim.state["P"] = gas_pressure + 0.5 * (Bx**2 + By**2)
    sim.state["bx"] = bx
    sim.state["by"] = by
    return sim
def make_plot(sim):
    """Plot the final density, save it as ``output.png`` (240 dpi) and show it.

    The colour scale is fixed to [0.06, 0.5]. ``rot90`` rotates the array
    counter-clockwise before display (presumably so that y points up, if rho
    is indexed [x, y] -- confirm).
    """
    rotated_rho = jnp.rot90(sim.state["rho"])
    plt.figure(figsize=(6, 4), dpi=80)
    image = plt.imshow(rotated_rho, cmap="jet", vmin=0.06, vmax=0.5)
    plt.colorbar(image, label="density")
    plt.tight_layout()
    plt.savefig("output.png", dpi=240)
    plt.show()
def main():
    """Run the Orszag-Tang vortex, report steps and run time, and plot it."""
    sim = set_up_simulation()

    start = time.time()
    sim.run()
    # NOTE(review): steps_taken is printed before the clock is read. If it is
    # a JAX device value, printing it also waits for the run to finish, which
    # keeps the timing honest -- keep this order.
    print("Steps taken:", sim.steps_taken)
    elapsed = time.time() - start
    print("Run time (s):", elapsed)

    make_plot(sim)


if __name__ == "__main__":
    main()
rayleigh_taylor
See on GitHub: examples/rayleigh_taylor
README:
# Rayleigh-Taylor
Simulate the Rayleigh-Taylor Instability (Euler equations)
Philip Mocz (2025)
Usage:
```console
python rayleigh_taylor.py
```
Takes around 9 seconds to run on my MacBook (CPU).
## Simulation snapshots
<div style="display:flex;flex-wrap:wrap;gap:8px">
<img src="output.png" alt="output" width="45%"/>
</div>
Script:
import jax.numpy as jnp
# TODO: REMOVE THE FOLLOWING LINES
import sys
sys.path.append("../../")
import adirondax as adx
import time
import matplotlib.pyplot as plt
"""
Simulate the Rayleigh-Taylor Instability
Philip Mocz (2025)
"""
def set_up_simulation():
    """Create the Rayleigh-Taylor simulation and its initial conditions.

    Configuration: a 64 x 192 mesh covering a 0.5 x 1.5 box, periodic in x
    and reflective in y. Hydro uses an ideal-gas EOS (gamma = 1.4) with no
    slope limiting, plus an external potential V = g * y with g = 0.1. The
    run goes to t = 20 in 13000 steps and writes 100 checkpoints.

    Initial conditions: heavy fluid (rho = 2) above light fluid (rho = 1),
    with the interface at y = 0.75. A small vy perturbation is added. The
    pressure is P0 - g * (y - 0.75) * rho, so dP/dy = -g * rho in each
    layer and P is continuous at the interface.

    Returns:
        The configured ``adx.Simulation``.
    """
    nx, ny = 64, 192
    g = 0.1
    P0 = 2.5
    w0 = 0.0025
    y_interface = 0.75

    physics = {"hydro": True, "external_potential": True}
    mesh = {
        "type": "cartesian",
        "resolution": [nx, ny],
        "box_size": [0.5, 1.5],
        "boundary_condition": ["periodic", "reflective"],
    }
    timing = {"span": 20.0, "num_timesteps": 13000}  # alternative: -1
    output = {"num_checkpoints": 100, "save": True, "plot_dynamic_range": 2.0}
    hydro = {"eos": {"type": "ideal", "gamma": 1.4}, "slope_limiting": False}
    sim = adx.Simulation(
        {
            "physics": physics,
            "mesh": mesh,
            "time": timing,
            "output": output,
            "hydro": hydro,
        }
    )

    # Heavy fluid on top of light
    sim.state["t"] = 0.0
    X, Y = sim.mesh
    sim.state["rho"] = 1.0 + (Y > y_interface)
    sim.state["vx"] = jnp.zeros(X.shape)
    x_profile = 1.0 - jnp.cos(4.0 * jnp.pi * X)
    y_profile = 1.0 - jnp.cos(4.0 * jnp.pi * Y / 3.0)
    sim.state["vy"] = w0 * x_profile * y_profile
    sim.state["P"] = P0 - g * (Y - y_interface) * sim.state["rho"]

    def gravity_potential(x, y):
        # linear in y: gives a uniform acceleration of magnitude g along -y
        return g * y

    sim.external_potential = gravity_potential
    return sim
def make_plot(sim):
    """Plot the final density, save it as ``output.png`` (240 dpi) and show it.

    The density is transposed, and the y axis is inverted so that the array's
    second index increases upward in the tall 4 x 6 inch figure. The colour
    scale is fixed to [0.8, 2.2].
    """
    plt.figure(figsize=(4, 6), dpi=80)
    image = plt.imshow(sim.state["rho"].T, cmap="jet", vmin=0.8, vmax=2.2)
    axes = plt.gca()
    axes.invert_yaxis()
    plt.colorbar(image, label="density")
    plt.tight_layout()
    plt.savefig("output.png", dpi=240)
    plt.show()
def main():
    """Set up, evolve and plot the Rayleigh-Taylor simulation.

    Prints the wall-clock time spent evolving and the number of steps taken.
    """
    import jax

    sim = set_up_simulation()

    # Evolve the system. JAX dispatches work asynchronously, so block until
    # the final state is computed before reading the clock. perf_counter is
    # the clock intended for measuring elapsed intervals.
    t0 = time.perf_counter()
    sim.run()
    jax.block_until_ready(sim.state)
    print("Run time (s): ", time.perf_counter() - t0)
    print("Steps taken:", sim.steps_taken)

    make_plot(sim)


if __name__ == "__main__":
    main()