Examples#

The examples/ directory contains a collection of astrophysics simulations demonstrating various applications of the Adirondax library.

kelvin_helmholtz#

kelvin_helmholtz

See on GitHub: examples/kelvin_helmholtz#

README:

# Kelvin-Helmholtz

Simulate the Kelvin-Helmholtz Instability (Euler equations)

Philip Mocz (2025)

Usage:

```console
python kelvin_helmholtz.py
```

Takes around 6 seconds to run on my macbook (cpu).


## Simulation snapshots

<div style="display:flex;flex-wrap:wrap;gap:8px">
  <img src="output.png" alt="output" width="45%"/>
</div>


## References

[Springel, V.; E pur si muove: Galilean-invariant cosmological hydrodynamical simulations on a moving mesh. Monthly Notices of the Royal Astronomical Society (2010)](https://ui.adsabs.harvard.edu/abs/2010MNRAS.401..791S)

Script:

import jax.numpy as jnp

# TODO: REMOVE THE FOLLOWING LINES
# Development shim: appends the directory two levels up to the import search
# path, presumably so that `import adirondax` below resolves to the in-repo
# package when the script is run from its own example directory. The path is
# relative, so it is resolved against the current working directory, not
# against the location of this file.
import sys

sys.path.append("../../")

import adirondax as adx
import time
import matplotlib.pyplot as plt

# NOTE: because the string below follows the imports, Python treats it as a
# bare expression statement, not as the module docstring.
"""
Simulate the Kelvin-Helmholtz instability

Philip Mocz (2025)
"""


def set_up_simulation():
    """Create the Kelvin-Helmholtz simulation and load its initial state.

    Returns:
        The `adx.Simulation` object with `t`, `rho`, `vx`, `vy` and `P`
        assigned on `sim.state`.
    """
    # Resolution, number of steps (scaled with the resolution) and end time
    n = 256
    nt = 400 * (n // 32)
    t_stop = 2.0

    # Configuration sections, assembled one at a time
    physics_cfg = {"hydro": True}
    mesh_cfg = {
        "type": "cartesian",
        "resolution": [n, n],
        "box_size": [1.0, 1.0],
    }
    time_cfg = {"span": t_stop, "num_timesteps": nt}
    output_cfg = {
        "num_checkpoints": 10,
        "save": True,
        "plot_dynamic_range": 2.0,
    }
    hydro_cfg = {
        "eos": {"type": "ideal", "gamma": 5.0 / 3.0},
        "slope_limiting": False,
    }
    params = {
        "physics": physics_cfg,
        "mesh": mesh_cfg,
        "time": time_cfg,
        "output": output_cfg,
        "hydro": hydro_cfg,
    }

    sim = adx.Simulation(params)

    # Initial conditions: opposite moving streams with a perturbation
    sim.state["t"] = 0.0
    X, Y = sim.mesh
    w0 = 0.1
    sigma = 0.05 / jnp.sqrt(2.0)

    def gaussian_about(y_center):
        # Gaussian profile in Y of width sigma centred on y_center
        return jnp.exp(-((Y - y_center) ** 2) / (2.0 * sigma**2))

    # Boolean mask of the band |Y - 0.5| < 0.25; adding it to a float raises
    # the value by one inside the band (rho: 1 -> 2, vx: -0.5 -> +0.5)
    central_band = jnp.abs(Y - 0.5) < 0.25
    sim.state["rho"] = 1.0 + central_band
    sim.state["vx"] = -0.5 + central_band

    # Sinusoidal vertical velocity, localized at Y = 0.25 and Y = 0.75
    # (the two edges of the band)
    sine_in_x = w0 * jnp.sin(4.0 * jnp.pi * X)
    sim.state["vy"] = sine_in_x * (gaussian_about(0.25) + gaussian_about(0.75))

    # Uniform pressure
    sim.state["P"] = 2.5 * jnp.ones(X.shape)

    return sim


def make_plot(sim):
    """Draw the density field, write it to output.png and show the figure.

    Args:
        sim: simulation whose `state["rho"]` array is plotted.
    """
    # Rotate the array by 90 degrees before handing it to imshow
    density_image = jnp.rot90(sim.state["rho"])
    color_limits = {"vmin": 0.8, "vmax": 2.2}

    plt.figure(figsize=(6, 4), dpi=80)
    image = plt.imshow(density_image, cmap="jet", **color_limits)
    plt.colorbar(image, label="density")
    plt.tight_layout()
    # Saved at a higher dpi than the on-screen figure
    plt.savefig("output.png", dpi=240)
    plt.show()


def main():
    """Run the Kelvin-Helmholtz example: set up, evolve, report timing, plot."""
    sim = set_up_simulation()

    # Evolve the system.
    # time.perf_counter() is a monotonic, high-resolution clock intended for
    # measuring intervals; time.time() follows the system wall clock and can
    # jump if that clock is adjusted during the run.
    # NOTE(review): JAX dispatches work asynchronously; this assumes sim.run()
    # only returns once the computation has finished - confirm.
    t0 = time.perf_counter()
    sim.run()
    elapsed = time.perf_counter() - t0
    print("Run time (s): ", elapsed)

    make_plot(sim)


if __name__ == "__main__":
    main()

logo_inverse_problem#

logo_inverse_problem

See on GitHub: examples/logo_inverse_problem#

README:

# Logo Inverse Problem

Inverse problem with the Schrödinger-Poisson system.
Find the initial phases that evolve the system into the JAX logo.

Philip Mocz (2025)

Usage:

```console
python logo_inverse_problem.py
```

Takes around 2 seconds to run on my macbook (cpu).


## Simulation snapshots

<div style="display:flex;flex-wrap:wrap;gap:8px">
  <img src="output.png" alt="output" width="45%"/>
</div>

Script:

import jax
import jax.numpy as jnp

# TODO: REMOVE THE FOLLOWING LINES
# Development shim: appends the directory two levels up to the import search
# path, presumably so that `import adirondax` below resolves to the in-repo
# package when the script is run from its own example directory. The path is
# relative, so it is resolved against the current working directory, not
# against the location of this file.
import sys

sys.path.append("../../")

import adirondax as adx
from jaxopt import ScipyMinimize
import time
import matplotlib.pyplot as plt
import matplotlib.image as img

# NOTE: because the string below follows the imports, Python treats it as a
# bare expression statement, not as the module docstring.
"""
Solve an Inverse-Problem that finds the initial wave function phases that
evolve under the Schrodinger-Poisson equations into a target density field

Philip Mocz (2024)
"""


def set_up_simulation(save=False):
    """Create the quantum + gravity simulation used by the inverse problem.

    Args:
        save: value passed through as the "save" entry of the output
            configuration.

    Returns:
        A freshly constructed `adx.Simulation`; no state fields are set here.
    """
    # Resolution, number of steps (scaled with the resolution) and end time
    n = 128
    nt = 100 * (n // 128)
    t_stop = 0.03

    physics_cfg = {"quantum": True, "gravity": True}
    mesh_cfg = {
        "type": "cartesian",
        "resolution": [n, n],
        "box_size": [1.0, 1.0],
    }
    time_cfg = {"span": t_stop, "num_timesteps": nt}
    output_cfg = {"save": save, "plot_dynamic_range": 2.0}

    params = {
        "physics": physics_cfg,
        "mesh": mesh_cfg,
        "time": time_cfg,
        "output": output_cfg,
    }

    return adx.Simulation(params)


def solve_inverse_problem(sim):
    """Find initial phases whose evolved density matches the image target.png.

    The initial wave function is exp(i * theta); theta is optimized with
    L-BFGS-B so that |psi|^2 after `sim.run()` matches the target density in
    the mean-squared sense.

    Args:
        sim: simulation object; its "t" and "psi" state entries are
            overwritten each time the loss is evaluated.

    Returns:
        The optimized phase array, mapped into [-pi, pi).

    Raises:
        AssertionError: if the image shape differs from `sim.resolution`.
    """
    # Load the target density field (first channel only; assumes imread
    # returns an array with a trailing channel axis - TODO confirm)
    target_data = img.imread("target.png")[:, :, 0]
    # Flip and transpose, presumably to convert image (row, column) layout to
    # the mesh's index order - verify against the mesh convention
    rho_target = jnp.flipud(jnp.array(target_data, dtype=float)).T
    # Inverted linear rescale (input 0 -> 1.25, input 1 -> 0.75), then
    # normalize to unit mean
    rho_target = 1.0 - 0.5 * (rho_target - 0.5)
    rho_target /= jnp.mean(rho_target)

    assert rho_target.shape[0] == sim.resolution[0]
    assert rho_target.shape[1] == sim.resolution[1]

    # Define the loss function for the optimization.
    # NOTE(review): `sim` is captured by closure and its state is assigned
    # inside the jitted function, so after tracing sim.state may hold traced
    # values; this is presumably why main() builds a fresh simulation for the
    # final re-run - confirm.
    @jax.jit
    def loss_function(theta, rho_target):
        sim.state["t"] = 0.0
        sim.state["psi"] = jnp.exp(1.0j * theta)
        sim.run()
        psi = sim.state["psi"]
        rho = jnp.abs(psi) ** 2
        return jnp.mean((rho - rho_target) ** 2)

    # Solve the inverse-problem
    # (original author's note: takes around 5 seconds on a macbook)
    opt = ScipyMinimize(
        method="l-bfgs-b", fun=loss_function, tol=1e-5, options={"disp": True}
    )
    theta = jnp.zeros_like(rho_target)
    # time.perf_counter() is a monotonic, high-resolution clock intended for
    # measuring intervals; time.time() can jump if the system clock changes.
    t0 = time.perf_counter()
    sol = opt.run(theta, rho_target)
    elapsed = time.perf_counter() - t0
    print("Inverse-problem solve time (s): ", elapsed)
    print("number of iterations: ", sol.state.iter_num)
    # Map the phases into [-pi, pi). Relative to the optimizer's result this
    # is a shift by pi (mod 2*pi), i.e. a global sign flip of exp(i * theta),
    # which leaves |exp(i * theta)|^2 unchanged.
    theta = jnp.mod(sol.params, 2.0 * jnp.pi) - jnp.pi
    print("Mean theta:", jnp.mean(theta))

    return theta


def rerun_simulation(sim, theta):
    """Evolve `sim` from the phases `theta` and return the final wave function.

    Args:
        sim: simulation to run; its "t" and "psi" state entries are reset.
        theta: phase array defining the initial wave function exp(i * theta).

    Returns:
        The "psi" state entry after `sim.run()` has completed.
    """
    psi_initial = jnp.exp(1.0j * theta)

    # Reset the clock and load the initial wave function, then evolve
    sim.state["t"] = 0.0
    sim.state["psi"] = psi_initial
    sim.run()

    final_time = sim.state["t"]
    print("Final time:", final_time)

    psi_final = sim.state["psi"]
    return psi_final


def make_plot(psi, theta):
    """Plot the initial phases next to the final log-density; save and show.

    Args:
        psi: final wave function; log10(|psi|^2) is drawn in the right panel.
        theta: initial phases; drawn in the left panel with limits -pi..pi.
    """

    def _finish_panel(ax, title):
        # Shared cosmetics: hide both axes, flip the y axis, square aspect
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
        ax.invert_yaxis()
        ax.set_aspect("equal")
        plt.title(title)

    plt.figure(figsize=(6, 4), dpi=80)
    layout = plt.GridSpec(1, 2, wspace=0.0, hspace=0.0)
    ax_phase = plt.subplot(layout[0, 0])
    ax_density = plt.subplot(layout[0, 1])

    # Left panel: initial phase field
    plt.sca(ax_phase)
    plt.cla()
    plt.imshow(theta.T, cmap="bwr")
    plt.clim(-jnp.pi, jnp.pi)
    _finish_panel(ax_phase, r"${\rm initial\,angle}(\psi)$")

    # Right panel: log10 of the final density
    log_density = jnp.log10(jnp.abs(psi) ** 2)
    plt.sca(ax_density)
    plt.cla()
    plt.imshow(log_density.T, cmap="inferno")
    plt.clim(-0.2, 0.2)
    _finish_panel(ax_density, r"${\rm final\,}\log_{10}(|\psi|^2)$")

    plt.tight_layout()
    plt.savefig("output.png", dpi=240)
    plt.show()


def main():
    """Solve for the initial phases, re-run from them, and plot the result."""
    # Pass 1: optimization on a simulation created with save=False
    theta = solve_inverse_problem(set_up_simulation(save=False))

    # Pass 2: a freshly built simulation created with save=True is evolved
    # from the optimized phases
    final_psi = rerun_simulation(set_up_simulation(save=True), theta)

    make_plot(final_psi, theta)


if __name__ == "__main__":
    main()

orszag_tang#

orszag_tang

See on GitHub: examples/orszag_tang#

README:

# Orszag-Tang

Simulate the Orszag-Tang vortex (2D MHD equations)

Philip Mocz (2025)

Usage:

```console
python orszag_tang.py
```

Takes around 55 seconds ('llf') / 105 seconds ('hlld') to run on my macbook (cpu).


## Simulation snapshots

<div style="display:flex;flex-wrap:wrap;gap:8px">
  <img src="output.png" alt="output" width="45%"/>
</div>


## References

[Orszag, S.A.; Tang, C.M.; Small-scale structure of two-dimensional magnetohydrodynamic turbulence. Journal of Fluid Mechanics (1979)](https://ui.adsabs.harvard.edu/abs/1979JFM....90..129O)

Script:

import jax.numpy as jnp

# TODO: REMOVE THE FOLLOWING LINES
# Development shim: appends the directory two levels up to the import search
# path, presumably so that `import adirondax` below resolves to the in-repo
# package when the script is run from its own example directory. The path is
# relative, so it is resolved against the current working directory, not
# against the location of this file.
import sys

sys.path.append("../../")

import adirondax as adx
from adirondax.hydro.common2d import get_curl, get_avg
import time
import matplotlib.pyplot as plt

# switch on for double precision
# (uncommenting the line below also requires `import jax`; only `jax.numpy`
# is imported above)
# jax.config.update("jax_enable_x64", True)

# NOTE: because the string below follows the imports, Python treats it as a
# bare expression statement, not as the module docstring.
"""
Simulate the Orszag-Tang vortex

Philip Mocz (2025)
"""


def set_up_simulation():
    """Create the Orszag-Tang vortex simulation and load its initial state.

    Returns:
        The `adx.Simulation` object with `t`, `rho`, `vx`, `vy`, `P`, `bx`
        and `by` assigned on `sim.state`.
    """
    # Run parameters
    n = 512
    nt = -1  # 2400
    # NOTE(review): -1 presumably lets the solver pick the number of steps
    # itself (main() reports sim.steps_taken) - confirm
    t_stop = 0.5
    gamma = 5.0 / 3.0
    box_size = 1.0
    dx = box_size / n

    # Recurring constants
    two_pi = 2.0 * jnp.pi
    four_pi = 4.0 * jnp.pi

    physics_cfg = {"hydro": True, "magnetic": True}
    mesh_cfg = {
        "type": "cartesian",
        "resolution": [n, n],
        "box_size": [box_size, box_size],
    }
    time_cfg = {"span": t_stop, "num_timesteps": nt}
    output_cfg = {
        "num_checkpoints": 100,
        "save": False,
        "plot_dynamic_range": 2.3,
    }
    hydro_cfg = {
        "eos": {"type": "ideal", "gamma": gamma},
        "cfl": 0.6,
        "riemann_solver": "hlld",
        "slope_limiting": True,
    }
    params = {
        "physics": physics_cfg,
        "mesh": mesh_cfg,
        "time": time_cfg,
        "output": output_cfg,
        "hydro": hydro_cfg,
    }

    sim = adx.Simulation(params)

    # Initial conditions: uniform density and gas pressure, sinusoidal velocity
    sim.state["t"] = jnp.array(0.0)
    X, Y = sim.mesh
    uniform = jnp.ones(X.shape)
    sim.state["rho"] = (gamma**2 / four_pi) * uniform
    sim.state["vx"] = -jnp.sin(two_pi * Y)
    sim.state["vy"] = jnp.sin(two_pi * X)
    P_gas = (gamma / four_pi) * uniform

    # Vector potential Az, sampled at the top-right node of each cell
    # (node coordinates run from dx to box_size)
    node_coords = jnp.linspace(dx, box_size, n)
    Xn, Yn = jnp.meshgrid(node_coords, node_coords, indexing="ij")
    sqrt_four_pi = jnp.sqrt(four_pi)
    az_from_x = jnp.cos(four_pi * Xn) / (four_pi * sqrt_four_pi)
    az_from_y = jnp.cos(two_pi * Yn) / (two_pi * sqrt_four_pi)
    Az = az_from_x + az_from_y

    # Magnetic field from the curl of Az, plus an averaged copy
    # (presumably face-centred -> cell-centred; see get_curl / get_avg)
    bx, by = get_curl(Az, dx, dx)
    Bx, By = get_avg(bx, by)

    # The stored "P" is gas pressure plus magnetic pressure
    magnetic_pressure = 0.5 * (Bx**2 + By**2)
    sim.state["P"] = P_gas + magnetic_pressure
    sim.state["bx"] = bx
    sim.state["by"] = by

    return sim


def make_plot(sim):
    """Draw the density field, write it to output.png and show the figure.

    Args:
        sim: simulation whose `state["rho"]` array is plotted.
    """
    # Rotate the array by 90 degrees before handing it to imshow
    density_image = jnp.rot90(sim.state["rho"])
    color_limits = {"vmin": 0.06, "vmax": 0.5}

    plt.figure(figsize=(6, 4), dpi=80)
    image = plt.imshow(density_image, cmap="jet", **color_limits)
    plt.colorbar(image, label="density")
    plt.tight_layout()
    # Saved at a higher dpi than the on-screen figure
    plt.savefig("output.png", dpi=240)
    plt.show()


def main():
    """Run the Orszag-Tang example: set up, evolve, report stats, plot."""
    sim = set_up_simulation()

    # Evolve the system.
    # time.perf_counter() is a monotonic, high-resolution clock intended for
    # measuring intervals; time.time() follows the system wall clock and can
    # jump if that clock is adjusted during the run. The elapsed time is
    # captured right after run() so the print below is not included in it.
    # NOTE(review): JAX dispatches work asynchronously; this assumes sim.run()
    # only returns once the computation has finished - confirm.
    t0 = time.perf_counter()
    sim.run()
    elapsed = time.perf_counter() - t0
    print("Steps taken:", sim.steps_taken)
    print("Run time (s):", elapsed)

    make_plot(sim)


if __name__ == "__main__":
    main()

rayleigh_taylor#

rayleigh_taylor

See on GitHub: examples/rayleigh_taylor#

README:

# Rayleigh-Taylor

Simulate the Rayleigh-Taylor Instability (Euler equations)

Philip Mocz (2025)

Usage:

```console
python rayleigh_taylor.py
```

Takes around 9 seconds to run on my macbook (cpu).


## Simulation snapshots

<div style="display:flex;flex-wrap:wrap;gap:8px">
  <img src="output.png" alt="output" width="45%"/>
</div>


## References

[Springel, V.; E pur si muove: Galilean-invariant cosmological hydrodynamical simulations on a moving mesh. Monthly Notices of the Royal Astronomical Society (2010)](https://ui.adsabs.harvard.edu/abs/2010MNRAS.401..791S)

Script:

import jax.numpy as jnp

# TODO: REMOVE THE FOLLOWING LINES
# Development shim: appends the directory two levels up to the import search
# path, presumably so that `import adirondax` below resolves to the in-repo
# package when the script is run from its own example directory. The path is
# relative, so it is resolved against the current working directory, not
# against the location of this file.
import sys

sys.path.append("../../")

import adirondax as adx
import time
import matplotlib.pyplot as plt

# NOTE: because the string below follows the imports, Python treats it as a
# bare expression statement, not as the module docstring.
"""
Simulate the Rayleigh-Taylor Instability

Philip Mocz (2025)
"""


def set_up_simulation():
    """Create the Rayleigh-Taylor simulation and load its initial state.

    Returns:
        The `adx.Simulation` object with `t`, `rho`, `vx`, `vy` and `P`
        assigned on `sim.state` and `external_potential` attached.
    """
    # Run parameters
    nx = 64
    ny = 192
    nt = 13000  #  -1
    t_stop = 20.0

    physics_cfg = {"hydro": True, "external_potential": True}
    mesh_cfg = {
        "type": "cartesian",
        "resolution": [nx, ny],
        "box_size": [0.5, 1.5],
        "boundary_condition": ["periodic", "reflective"],
    }
    time_cfg = {"span": t_stop, "num_timesteps": nt}
    output_cfg = {
        "num_checkpoints": 100,
        "save": True,
        "plot_dynamic_range": 2.0,
    }
    hydro_cfg = {
        "eos": {"type": "ideal", "gamma": 1.4},
        "slope_limiting": False,
    }
    params = {
        "physics": physics_cfg,
        "mesh": mesh_cfg,
        "time": time_cfg,
        "output": output_cfg,
        "hydro": hydro_cfg,
    }

    sim = adx.Simulation(params)

    # Initial conditions: heavy fluid on top of light
    sim.state["t"] = 0.0
    X, Y = sim.mesh
    w0 = 0.0025
    P0 = 2.5
    g = 0.1
    y_interface = 0.75

    # Density is 2 above the interface and 1 below (boolean mask added to 1.0)
    sim.state["rho"] = 1.0 + (Y > y_interface)
    sim.state["vx"] = jnp.zeros(X.shape)

    # Small vertical velocity perturbation, a product of two cosine profiles
    profile_x = 1.0 - jnp.cos(4.0 * jnp.pi * X)
    profile_y = 1.0 - jnp.cos(4.0 * jnp.pi * Y / 3.0)
    sim.state["vy"] = w0 * profile_x * profile_y

    # Pressure falls off linearly with height on each side of the interface,
    # weighted by the local density
    sim.state["P"] = P0 - g * (Y - y_interface) * sim.state["rho"]

    # external potential: linear in y with slope g
    def external_potential(x, y):
        return g * y

    sim.external_potential = external_potential

    return sim


def make_plot(sim):
    """Draw the density field, write it to output.png and show the figure.

    Args:
        sim: simulation whose `state["rho"]` array is plotted.
    """
    # Transposed for imshow; the y axis is flipped after drawing
    density_image = sim.state["rho"].T
    color_limits = {"vmin": 0.8, "vmax": 2.2}

    plt.figure(figsize=(4, 6), dpi=80)
    image = plt.imshow(density_image, cmap="jet", **color_limits)
    axes = plt.gca()
    axes.invert_yaxis()
    plt.colorbar(image, label="density")
    plt.tight_layout()
    # Saved at a higher dpi than the on-screen figure
    plt.savefig("output.png", dpi=240)
    plt.show()


def main():
    """Run the Rayleigh-Taylor example: set up, evolve, report stats, plot."""
    sim = set_up_simulation()

    # Evolve the system.
    # time.perf_counter() is a monotonic, high-resolution clock intended for
    # measuring intervals; time.time() follows the system wall clock and can
    # jump if that clock is adjusted during the run.
    # NOTE(review): JAX dispatches work asynchronously; this assumes sim.run()
    # only returns once the computation has finished - confirm.
    t0 = time.perf_counter()
    sim.run()
    elapsed = time.perf_counter() - t0
    print("Run time (s): ", elapsed)
    print("Steps taken:", sim.steps_taken)

    make_plot(sim)


if __name__ == "__main__":
    main()