Numerical integration#

%matplotlib widget

Hide code cell source

import os
import time
from collections.abc import Callable
from dataclasses import dataclass, field
from functools import cache, partial
from typing import Literal, cast

import ipywidgets as w
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import numpy.typing as npt
import quadax
import sympy as sp
from ipympl.backend_nbagg import Canvas
from IPython.display import SVG, Math, display
from matplotlib.axes import Axes
from matplotlib.collections import LineCollection, QuadMesh
from matplotlib.lines import Line2D
from scipy.integrate import quad_vec

from ampform.dynamics.form_factor import BlattWeisskopfSquared, FormFactor
from ampform.dynamics.phasespace import (
    ChewMandelstamIntegral,
    PhaseSpaceFactorSplitSqrt,
)
from ampform.io import aslatex
from ampform.kinematics.phasespace import BreakupMomentumSplitSqrt

# Names of the integration routines that can be selected in the widget: the quadax
# routines are looked up by name, "quad_vec" refers to SciPy's quad_vec.
# cspell:disable-next-line
Algorithm = Literal["quadcc", "quadgk", "quadts", "romberg", "rombergts", "quad_vec"]
# Enable 64-bit (double-precision) arrays in JAX for the numerical integration below.
jax.config.update("jax_enable_x64", True)


def hide_toolbars(canvas: Canvas) -> None:
    """Hide the header, the footer, and the toolbar of an ipympl figure canvas."""
    for attribute in ("header_visible", "footer_visible", "toolbar_visible"):
        setattr(canvas, attribute, False)

The dispersion integral of the phase-space factor cannot be solved analytically for higher angular momenta (\(\ell > 0\)). We therefore need to perform a numerical integration. Since the integrand of the dispersion integral contains singularities, it is important to choose the numerical integration algorithm carefully.

This notebook contains an interactive widget that shows the numerical stability and performance of different numerical integration algorithms when computing the dispersion integral along (close to) the real axis of the complex energy plane.

Tip

The conclusion is that Romberg’s method (quadax.romberg()) is the best choice for computing the dispersion integral close to the physical (real) axis. For computing the dispersion integral further away from the axis, the Gauss–Kronrod method (e.g. quadax.quadgk()) is accurate enough and faster than Romberg’s method.

  • Use quadax.romberg() when computing the dispersion integral for an amplitude model fit (physical axis).

  • Use quadax.quadgk() when determining the pole positions of the amplitude model (complex plane).

Chew–Mandelstam dispersion integral#

The general form of the Chew–Mandelstam dispersion integral \(\Sigma_\ell(s)\) is an integral over the product \(\rho(s) \, n_\ell^2(s)\) of the phase space factor \(\rho(s)\) (PhaseSpaceFactor) and the square of the form factor \(n_\ell^2(s)\) (FormFactor). The required formulas are:

Hide code cell source

s, m1, m2, z = sp.symbols("s m1 m2 z", nonnegative=True)
ell = sp.Symbol("ell", integer=True, nonnegative=True)
cm = ChewMandelstamIntegral(s, m1, m2, ell, dummify=False)
ff = FormFactor(s, m1, m2, ell)
rho = PhaseSpaceFactorSplitSqrt(s, m1, m2)
q = BreakupMomentumSplitSqrt(s, m1, m2)
bl = BlattWeisskopfSquared(z, ell)
max_ell = 5
# Render "expression = definition" pairs: the main building blocks are unfolded one
# level (deep=False), the Blatt-Weisskopf factors fully for ell = 0, ..., max_ell.
src = aslatex(
    {expr: expr.doit(deep=False) for expr in (cm, rho, ff, q)}
    | {
        bl_ell: bl_ell.doit()
        for bl_ell in (bl.subs(ell, i) for i in range(max_ell + 1))
    }
)
Math(src)
\[\begin{split}\displaystyle \begin{aligned} \Sigma_{\ell}\left(s\right) \;&=\; \frac{s - \left(m_{1} + m_{2}\right)^{2}}{\pi} \int\limits_{\left(m_{1} + m_{2}\right)^{2}}^{\infty} \frac{\mathcal{F}_{\ell}\left(x, m_{1}, m_{2}\right)^{2} \rho\left(x\right)}{\left(x - \left(m_{1} + m_{2}\right)^{2}\right) \left(- i \epsilon - s + x\right)}\, dx \\ \rho\left(s\right) \;&=\; \frac{\sqrt{s - \left(m_{1} - m_{2}\right)^{2}} \sqrt{s - \left(m_{1} + m_{2}\right)^{2}}}{s} \\ \mathcal{F}_{\ell}\left(s, m_{1}, m_{2}\right) \;&=\; \sqrt{B_{\ell}^2\left(q^2\left(s\right)\right)} \\ q\left(s\right) \;&=\; \frac{\sqrt{s - \left(m_{1} - m_{2}\right)^{2}} \sqrt{s - \left(m_{1} + m_{2}\right)^{2}}}{2 \sqrt{s}} \\ B_{0}^2\left(z\right) \;&=\; 1 \\ B_{1}^2\left(z\right) \;&=\; \frac{2 z}{z + 1} \\ B_{2}^2\left(z\right) \;&=\; \frac{13 z^{2}}{z^{2} + 3 z + 9} \\ B_{3}^2\left(z\right) \;&=\; \frac{277 z^{3}}{z^{3} + 6 z^{2} + 45 z + 225} \\ B_{4}^2\left(z\right) \;&=\; \frac{12746 z^{4}}{z^{4} + 10 z^{3} + 135 z^{2} + 1575 z + 11025} \\ B_{5}^2\left(z\right) \;&=\; \frac{998881 z^{5}}{z^{5} + 15 z^{4} + 315 z^{3} + 6300 z^{2} + 99225 z + 893025} \\ \end{aligned}\end{split}\]

where we have used a split square root for a cleaner cut structure in the complex energy plane.

Here, \(i\epsilon\) indicates that \(\Sigma_\ell(s)\) with \(s\in\mathbb{R}\) is formulated just above the real axis in order to avoid \(s'-s=0\). However, as can be seen in the widget below, \(i\epsilon\) is not required when \(s\in\mathbb{C}\), giving us the general form:

\[\begin{split} \begin{aligned} \Sigma_\ell\left(s\right) \;&=\; \frac{s - s_\mathrm{thr}}{\pi} \int\limits_{s_\mathrm{thr}}^{\infty} \frac {\rho\!\left(s'\right) n_\ell^2\!\left(s'\right) ds'} {\left(s' - s_\mathrm{thr}\right) \left(s'- s\right)} \\ n_\ell^2(s') \;&=\; \mathcal{F}_\ell^2\!\left(s', m_1, m_2\right) \\ s_\mathrm{thr} \;&=\; (m_1 + m_2)^2 \end{aligned} \end{split}\]

The reason is that the function \(\rho(s')\,n_\ell^2(s')\) does not have a branch cut along the real axis above the \(s_\mathrm{thr}\) threshold, so there is no need to approach some contour around such a potential discontinuity. The form of the dispersion integral is explained and derived here.

When using this form of the dispersion integral to compute \(\Sigma_\ell(s)\) for \(s\in\mathbb{R}\), the caller should use the fact that \(\Sigma_\ell(s) = \lim_{\epsilon\to 0^+} \Sigma_\ell(s + i\epsilon)\). In the numerical implementation, that means passing s + epsilon * 1j as input when s is an array of floats. The challenge is to find a suitable value for \(\epsilon\): the smaller the value of \(\epsilon\), the closer we are to the physical axis, but the less accurate the numerical integration becomes. The widget below allows you to explore this trade-off for different numerical integration algorithms.

Numerical implementation#

In this notebook, we implement the formulas listed above numerically rather than lambdifying the symbolic expressions, so that we have full control over the numerical implementation and so that the implementation is recognizable for pure array-oriented workflows.

Hide code cell source

def integrate_numerically(
    s: npt.NDArray[np.float64],
    m1: float,
    m2: float,
    ell: int = 0,
    start_offset: float = 0,
    algorithm: Callable = quadax.quadcc,
    **configuration,
):
    r"""Compute the Chew–Mandelstam dispersion integral :math:`\Sigma_\ell(s)` numerically.

    The integrand is integrated from :math:`s_\mathrm{thr} + \text{start\_offset}`,
    with :math:`s_\mathrm{thr} = (m_1 + m_2)^2`, up to infinity. The result is then
    multiplied by :math:`(s - s_\mathrm{thr}) / \pi`.

    Args:
        s: Points at which the integral is evaluated (the widget passes complex
            values ``x + 1j * epsilon``).
        m1: First mass.
        m2: Second mass.
        ell: Angular momentum :math:`\ell`.
        start_offset: Shift of the lower integration limit above threshold.
        algorithm: SciPy's ``quad_vec`` (limits passed positionally) or a routine
            that accepts an ``interval=[lower, upper]`` keyword argument.
        **configuration: Options forwarded unchanged to ``algorithm``.
    """
    s_thr = (m1 + m2) ** 2
    lower_limit = s_thr + start_offset
    if algorithm is not quad_vec:
        # quadax-style routine: JAX-compatible partial and an ``interval`` keyword
        bound_integrand = jax.tree_util.Partial(integrand, s=s, m1=m1, m2=m2, ell=ell)
        integral, _ = algorithm(
            bound_integrand,
            interval=[lower_limit, jnp.inf],
            **configuration,
        )
    else:
        # SciPy quad_vec: plain functools.partial and positional limits
        bound_integrand = partial(integrand, s=s, m1=m1, m2=m2, ell=ell)
        integral, _ = algorithm(bound_integrand, lower_limit, np.inf, **configuration)
    return (s - s_thr) * integral / jnp.pi


@jax.jit
def integrand(sp, s, m1, m2, ell):
    r"""Integrand of the Chew–Mandelstam dispersion integral.

    Evaluates
    :math:`\rho(s')\,n_\ell^2(s') / \left[(s' - s_\mathrm{thr})(s' - s)\right]`
    with :math:`s_\mathrm{thr} = (m_1 + m_2)^2`.

    Args:
        sp: Integration variable :math:`s'` (this local name shadows the ``sympy``
            alias ``sp`` inside this function only).
        s: Point(s) at which the dispersion integral is evaluated.
        m1: First mass.
        m2: Second mass.
        ell: Angular momentum :math:`\ell`.
    """
    s_thr = (m1 + m2) ** 2
    # Both rho and the squared form factor depend on the integration variable s'.
    # Evaluating n2 at the external s would pull n_ell^2(s) out of the integral,
    # which is wrong for ell > 0 (for ell = 0, n2 == 1 so it would go unnoticed).
    return rho_func(sp, m1, m2) * n2(sp, m1, m2, ell) / ((sp - s_thr) * (sp - s))


@jax.jit
def rho_func(s, m1, m2):
    r"""Phase-space factor :math:`\rho(s)` with a split square root.

    Computes :math:`\sqrt{s - (m_1 - m_2)^2}\,\sqrt{s - (m_1 + m_2)^2} / s`, taking
    the two square roots separately rather than the root of their product.
    """
    s_pseudo_thr = (m1 - m2) ** 2
    s_thr = (m1 + m2) ** 2
    return jnp.sqrt(s - s_pseudo_thr) * jnp.sqrt(s - s_thr) / s


def n2(s, m1, m2, ell):
    r"""Squared form factor :math:`n_\ell^2(s) = B_\ell^2\!\left(q^2(s)\right)`.

    This matches :math:`\mathcal{F}_\ell(s, m_1, m_2) = \sqrt{B_\ell^2(q^2(s))}`: the
    Blatt–Weisskopf function takes the *squared* breakup momentum as argument.
    :math:`q^2(s) = (s - (m_1 - m_2)^2)(s - (m_1 + m_2)^2) / (4s)` is rational in
    :math:`s`, so unlike :math:`q(s)` it introduces no square-root branch cuts.

    Args:
        s: Real or complex value(s) of :math:`s`.
        m1: First mass.
        m2: Second mass.
        ell: Angular momentum :math:`\ell` (0 to 5, otherwise NaN).
    """
    q_squared = q(s, m1, m2) ** 2
    return blatt_weisskopf_squared(q_squared, ell)


def blatt_weisskopf_squared(z, ell):
    r"""Squared Blatt–Weisskopf factor :math:`B_\ell^2(z)` for :math:`0 \le \ell \le 5`.

    All six rational functions are evaluated, and ``jnp.select`` picks the one that
    matches ``ell``. Unlike Python branching, this also works when ``ell`` is a JAX
    tracer inside jit-compiled callers. Any other ``ell`` yields NaN.
    """
    ratios = [
        1,
        2 * z / (z + 1),
        13 * z**2 / (z**2 + 3 * z + 9),
        277 * z**3 / (z**3 + 6 * z**2 + 45 * z + 225),
        12746 * z**4 / (z**4 + 10 * z**3 + 135 * z**2 + 1575 * z + 11025),
        998881
        * z**5
        / (z**5 + 15 * z**4 + 315 * z**3 + 6300 * z**2 + 99225 * z + 893025),
    ]
    matches_ell = [ell == candidate for candidate in range(len(ratios))]
    return jnp.select(matches_ell, ratios, default=jnp.nan)


@jax.jit
def q(s, m1, m2):
    r"""Breakup momentum :math:`q(s)` with a split square root.

    Computes :math:`\sqrt{s - (m_1 - m_2)^2}\,\sqrt{s - (m_1 + m_2)^2} / (2\sqrt{s})`.
    """
    numerator = jnp.sqrt(s - (m1 - m2) ** 2) * jnp.sqrt(s - (m1 + m2) ** 2)
    denominator = 2 * jnp.sqrt(s)
    return numerator / denominator

In the case of \(S\)-waves (\(\ell=0\)), we can compare the result of the integration to the analytical solution to the integral:

Hide code cell source

\[\begin{split}\displaystyle \begin{aligned} \hat{\Sigma}_0\left(s\right) \;&=\; \frac{\frac{2 q^\mathrm{c}\left(s\right)}{\sqrt{s}} \log{\left(\frac{m_{1}^{2} + m_{2}^{2} + 2 \sqrt{s} q^\mathrm{c}\left(s\right) - s}{2 m_{1} m_{2}} \right)} - \left(m_{1}^{2} - m_{2}^{2}\right) \left(- \frac{1}{\left(m_{1} + m_{2}\right)^{2}} + \frac{1}{s}\right) \log{\left(\frac{m_{1}}{m_{2}} \right)}}{\pi} \\ \end{aligned}\end{split}\]

Hide code cell source

@jax.jit
def sigma0(s, m1, m2):
    r"""Analytic S-wave (:math:`\ell = 0`) Chew–Mandelstam function :math:`\hat\Sigma_0(s)`.

    Serves as the reference against which the numerical integration is compared
    for :math:`\ell = 0`.
    """
    sqrt_s = jnp.sqrt(s)
    momentum = q(s, m1, m2)
    log_argument = (m1**2 + m2**2 - s + 2 * momentum * sqrt_s) / (2 * m1 * m2)
    first_term = (2 * momentum / sqrt_s) * jnp.log(log_argument)
    second_term = (m1**2 - m2**2) * (1 / s - 1 / (m1 + m2) ** 2) * jnp.log(m1 / m2)
    return (1 / jnp.pi) * (first_term - second_term)

Interactive visualization#

Hide code cell source

# Shared slider option: only emit a new value when the slider is released, so the
# integration is not re-run while dragging.
cont = dict(continuous_update=False)
# Controls for the physics parameters and the plot ranges.
# NOTE: the insertion order matters; the UI layout slices physics_sliders.values()
# by position. FloatLogSlider min/max are base-10 exponents.
physics_sliders = dict(
    # Part of the complex rho(s) n_ell^2(s) shown in the color map
    projection=w.RadioButtons(
        description="Projection",
        options=["real", "imag", "abs"],
        value="imag",
        layout=w.Layout(width="max-content"),
    ),
    m1=w.FloatSlider(value=0.13, min=0.0, max=2.0, step=0.01, description="m₁", **cont),
    m2=w.FloatSlider(value=0.98, min=0.0, max=2.0, step=0.01, description="m₂", **cont),
    ell=w.IntSlider(value=0, min=0, max=5, description="ℓ", **cont),
    # y range of the Re and Im Sigma_ell(s) panels
    y_lim=w.FloatRangeSlider(
        description="y range",
        min=-5,
        max=10,
        value=(-1, +1),
        readout_format=".1f",
        **cont,
    ),
    # Symmetric color range [-z_max, +z_max] of the color map
    z_max=w.FloatLogSlider(
        value=1.0,
        min=-3,
        max=3,
        description="Color scale",
        step=0.25,
        readout_format=".3g",
        **cont,
    ),
    # Number of points along the real axis at which the integral is evaluated
    resolution=w.IntSlider(
        value=200,
        min=100,
        max=5000,
        description="Resolution",
        step=100,
        **cont,
    ),
    # Imaginary offset: the integral is evaluated at s + i*epsilon
    epsilon=w.FloatLogSlider(
        value=1e-4,
        min=-12,
        max=0.5,
        description="s + iϵ",
        step=0.25,
        readout_format=".0e",
        **cont,
    ),
    # Offset of the lower integration limit above the threshold (m1 + m2)^2
    start_offset=w.FloatLogSlider(
        value=1e-20,
        min=-64,
        max=2,
        description="thr + ϵ",
        step=0.25,
        readout_format=".0e",
        **cont,
    ),
)
# Controls for the integration routine (same positional-order caveat as above)
algorithm_sliders = dict(
    algorithm_name=w.RadioButtons(
        description="Integration algorithm",
        layout=w.Layout(width="150px"),
        options=Algorithm.__args__,
    ),
    # Used by quadcc/quadgk/quadts; options are replaced by on_algorithm_change()
    order=w.RadioButtons(
        description="Integration order",
        layout=w.Layout(width="100px"),
        options=[8, 16, 32, 64, 128, 256],
        value=256,
    ),
    # epsabs and epsrel are passed on to every algorithm
    epsabs=w.FloatLogSlider(
        value=1e-5,
        min=-12,
        max=0,
        description="epsabs",
        step=0.5,
        readout_format=".0e",
        **cont,
    ),
    epsrel=w.FloatLogSlider(
        value=1e-5,
        min=-12,
        max=0,
        description="epsrel",
        step=0.5,
        readout_format=".0e",
        **cont,
    ),
    # Used by romberg/rombergts only
    divmax=w.IntSlider(
        value=20,
        min=1,
        max=30,
        description="divmax",
        **cont,
    ),
    # Used by quad_vec only, and only if "Disable limit" is unchecked
    limit=w.IntSlider(
        value=50,
        min=1,
        max=100,
        description="limit",
        disabled=True,
        **cont,
    ),
    disable_limit=w.Checkbox(value=True, description="Disable limit", **cont),
)
# All controls; passed as keyword arguments to the plotting widget
sliders = dict(
    **physics_sliders,
    **algorithm_sliders,
)


def on_algorithm_change(change):
    """Enable the controls that apply to the newly selected integration algorithm.

    Registered as an observer of the ``algorithm_name`` radio buttons, so
    ``change["new"]`` holds the selected algorithm name:

    - ``quad_vec``: the "Disable limit" checkbox (and the limit slider, unless the
      checkbox is checked);
    - ``quadcc``/``quadgk``/``quadts``: the integration order, with
      algorithm-specific options;
    - ``romberg``/``rombergts``: ``divmax``.
    """
    algorithm_name = change["new"]
    limit_slider = sliders["limit"]
    order = sliders["order"]
    divmax_slider = sliders["divmax"]
    disable_limit = sliders["disable_limit"]
    limit_slider.disabled = True
    if algorithm_name == "quad_vec":
        disable_limit.disabled = False
        # The jslink between checkbox and slider only reacts to *changes* of the
        # checkbox, so restore the slider state explicitly. Otherwise, returning to
        # quad_vec with an unchecked box leaves the slider disabled even though
        # its limit is passed to quad_vec.
        limit_slider.disabled = disable_limit.value
        divmax_slider.disabled = True
        order.disabled = True
    elif algorithm_name in {"quadcc", "quadgk", "quadts"}:
        disable_limit.disabled = True
        divmax_slider.disabled = True
        order.disabled = False
        # Order options offered for each algorithm
        order_options = {
            "quadcc": [8, 16, 32, 64, 128, 256],
            "quadgk": [15, 21, 31, 41, 51, 61],
            "quadts": [41, 61, 81, 101],
        }
        _update_order(order, order_options[algorithm_name])
    elif algorithm_name in {"romberg", "rombergts"}:
        disable_limit.disabled = True
        divmax_slider.disabled = False
        order.disabled = True


def _update_order(slider, options: list[int]) -> None:
    current_value = slider.value
    slider.options = options
    slider.value = min(options, key=lambda o: abs(o - current_value))


# Tie the "limit" slider's disabled state to the value of the "Disable limit" checkbox
w.jslink((sliders["disable_limit"], "value"), (sliders["limit"], "disabled"))
# Reconfigure the algorithm-specific controls whenever another algorithm is chosen
sliders["algorithm_name"].observe(on_algorithm_change, names="value")
# Select Romberg's method initially. This happens after registering the observer so
# that on_algorithm_change() configures the controls for it.
sliders["algorithm_name"].value = "romberg"  # ty:ignore[invalid-assignment]  # trigger changes
# Shows the computation time of the latest integration
timer_box = cast("w.ValueWidget", w.HTML())
# Layout: a "Physics" and an "Integration" tab, with the timer below them
ui = w.VBox([
    tabs := w.Tab([
        w.HBox([
            physics_sliders["projection"],
            # m1, m2, ell, y_lim (relies on the insertion order of physics_sliders)
            w.VBox(list(physics_sliders.values())[1:5]),
            # z_max, resolution, epsilon, start_offset
            w.VBox(list(physics_sliders.values())[5:]),
        ]),
        w.HBox([
            algorithm_sliders["algorithm_name"],
            algorithm_sliders["order"],
            # epsabs, epsrel, divmax, limit, disable_limit
            w.VBox(list(algorithm_sliders.values())[2:]),
        ]),
    ]),
    timer_box,
])
tabs.titles = ["Physics", "Integration"]

Hide code cell source

@dataclass(kw_only=True)
class PlotContent:
    """Matplotlib artists of the dispersion-integral widget.

    Created on the widget's first call; later calls only update their data.
    """

    # Color map of rho(s) n_ell^2(s) over the complex s plane
    mesh: QuadMesh
    # Re Sigma_ell(s) curves: (reference at s + 1e-12j, analytic at s + i*epsilon,
    # numerical); the first two only carry data for ell == 0
    real: tuple[Line2D, Line2D, Line2D]
    # Im Sigma_ell(s) curves, in the same order as ``real``
    imag: tuple[Line2D, Line2D, Line2D]
    # Vertical lines at (m1 - m2)^2 on the color-map, Re, and Im axes
    pseudothreshold: tuple[Line2D, Line2D, Line2D]
    # Vertical lines at (m1 + m2)^2 on the color-map, Re, and Im axes
    threshold: tuple[Line2D, Line2D, Line2D]
    # Horizontal line at Im s = epsilon in the color map
    s_line: Line2D
    # Segment on the real axis from threshold + start_offset to the right x limit
    integrated_interval: LineCollection


@dataclass
class DispersionIntegralWidget:
    r"""Interactive plot comparing numerical and analytic dispersion integrals.

    Calling the instance with the current control values (as done through
    ``ipywidgets.interactive_output``) evaluates :math:`\Sigma_\ell(s + i\epsilon)`
    numerically, compares it with the analytic :math:`\ell = 0` solution, and draws
    :math:`\rho(s)\,n_\ell^2(s)` over the complex :math:`s` plane. The Matplotlib
    artists are created on the first call and only updated afterwards.

    Attributes:
        ax_rho: Axes for the color map of :math:`\rho(s)\,n_\ell^2(s)`.
        ax_real: Axes for :math:`\mathrm{Re}\,\Sigma_\ell(s)`.
        ax_imag: Axes for :math:`\mathrm{Im}\,\Sigma_\ell(s)`.
        real_lim: Range of :math:`\mathrm{Re}\,s` for the color map and the curves.
        imag_max: The color map spans ``-imag_max <= Im s <= +imag_max``.
        grid: Number of color-map grid points along Re s and Im s.
    """

    # Artists created on the first call; None until then
    _c: PlotContent | None = field(init=False, default=None)
    # NOTE(review): declared but never assigned or read within this class
    s: npt.NDArray[np.float64] = field(init=False)
    # Complex grid on which the integrand is drawn (set in __post_init__)
    S: npt.NDArray[np.complex128] = field(init=False)
    ax_rho: Axes
    ax_real: Axes
    ax_imag: Axes
    real_lim: tuple[float, float]
    imag_max: float
    grid: tuple[int, int]

    def __post_init__(self):
        """Build the complex mesh ``S = Re s + i Im s`` for the color map."""
        X, Y = np.meshgrid(
            np.linspace(*self.real_lim, self.grid[0]),
            np.linspace(-self.imag_max, +self.imag_max, self.grid[1]),
        )
        self.S = X + 1j * Y

    def __call__(
        self,
        *,
        projection: Literal["real", "imag", "abs"],
        m1: float,
        m2: float,
        ell: int,
        epsilon: float,
        start_offset: float,
        z_max: float,
        algorithm_name: Algorithm,
        disable_limit: bool,
        resolution: int,
        y_lim: tuple[float, float],
        **alg_kwargs,
    ) -> None:
        """Recompute the dispersion integral for the given controls and update the plot.

        Args:
            projection: Part of the complex integrand shown in the color map.
            m1: First mass.
            m2: Second mass.
            ell: Angular momentum.
            epsilon: The integral is evaluated at ``x + 1j * epsilon``.
            start_offset: Shift of the lower integration limit above threshold.
            z_max: Color range ``[-z_max, +z_max]`` of the color map.
            algorithm_name: Name of the integration routine (see get_algorithm()).
            disable_limit: If True, the ``limit`` option is not passed on.
            resolution: Number of points along the real axis.
            y_lim: y range of the Re and Im panels.
            **alg_kwargs: Remaining algorithm controls; only the options accepted
                by the selected algorithm are passed on (see
                get_algorithm_options()).
        """
        S = self.S
        x = generate_domain(*self.real_lim, resolution)
        s = x + 1j * epsilon
        algorithm = get_algorithm(algorithm_name)
        if disable_limit:
            alg_kwargs.pop("limit", None)
        if ell == 0:
            # Analytic reference: "exact" just above the real axis and at s + i*eps
            x_cmp = x
            z_exact = sigma0(x + 1e-12j, m1, m2)
            z_ana = sigma0(s, m1, m2)
        else:
            # No analytic solution for ell > 0: dummy data for the reference curves
            x_cmp = z_exact = z_ana = np.zeros(shape=(1,))
        alg_kwargs = get_algorithm_options(algorithm_name, **alg_kwargs)
        start_time = time.perf_counter()
        z_num = integrate_numerically(
            s, m1, m2, ell, start_offset, algorithm, **alg_kwargs
        )
        # Wait for JAX's asynchronous dispatch so that the timing covers the actual
        # computation. jax.block_until_ready() also accepts results that are not JAX
        # arrays (e.g. if the SciPy quad_vec path returns a NumPy array), whereas the
        # ``.block_until_ready()`` method only exists on JAX arrays.
        jax.block_until_ready(z_num)
        end_time = time.perf_counter()
        duration = end_time - start_time
        timer_box.value = f"Computation time: <b>{format_time(duration)}</b> for {resolution:,} points"
        if np.all(np.isnan(z_num)):
            timer_box.value += " (<font color='red'>all values are NaN</font>)"
        Z = rho_func(S, m1, m2) * n2(S, m1, m2, ell)
        if projection == "abs":
            Z = jnp.abs(Z)
        else:
            Z = getattr(Z, projection)
        s_neg = (m1 - m2) ** 2
        s_pos = (m1 + m2) ** 2
        if self._c is None:
            # First call: create all artists
            self._c = PlotContent(
                mesh=self.ax_rho.pcolormesh(
                    S.real,
                    S.imag,
                    Z,
                    cmap="RdBu_r",
                    rasterized=True,
                    vmin=-z_max,
                    vmax=+z_max,
                ),
                imag=(
                    self.ax_imag.plot(x_cmp, z_exact.imag, color="black", lw=0.2)[0],
                    self.ax_imag.plot(x_cmp, z_ana.imag, alpha=0.5, color="C0")[0],
                    self.ax_imag.plot(x, z_num.imag, color="C2", lw=0.5)[0],
                ),
                real=(
                    self.ax_real.plot(x_cmp, z_exact.real, color="black", lw=0.2)[0],
                    self.ax_real.plot(x_cmp, z_ana.real, alpha=0.5, color="C0")[0],
                    self.ax_real.plot(x, z_num.real, color="C2", lw=0.5)[0],
                ),
                pseudothreshold=tuple(
                    ax.axvline(s_neg, c="C4", label=R"$(m_1-m_2)^2$", ls="dotted")
                    for ax in (self.ax_rho, self.ax_real, self.ax_imag)
                ),
                threshold=tuple(
                    ax.axvline(s_pos, c="C3", label=R"$(m_1+m_2)^2$", ls="dotted")
                    for ax in (self.ax_rho, self.ax_real, self.ax_imag)
                ),
                s_line=self.ax_rho.axhline(
                    y=epsilon,
                    color="C2",
                    label=R"$s+i\epsilon$",
                    linewidth=0.5,
                ),
                integrated_interval=self.ax_rho.hlines(
                    y=0,
                    xmin=s_pos + start_offset,
                    xmax=self.ax_rho.get_xlim()[1],
                    color="black",
                    label="Integration path",
                    linewidth=1,
                ),
            )
        else:
            # Later calls: only update the data of the existing artists
            self._c.mesh.set_array(Z)
            self._c.mesh.set_clim(-z_max, +z_max)
            for i, (x_i, z) in enumerate(
                zip([x_cmp, x_cmp, x], [z_exact, z_ana, z_num], strict=True)
            ):
                self._c.imag[i].set_data(x_i, z.imag)
                self._c.real[i].set_data(x_i, z.real)
            for line in self._c.pseudothreshold:
                line.set_xdata([s_neg])
            for line in self._c.threshold:
                line.set_xdata([s_pos])
            self._c.s_line.set_ydata([epsilon])
            self._c.integrated_interval.set_segments([
                [
                    [s_pos + start_offset, 0],
                    [self.ax_rho.get_xlim()[1], 0],
                ]
            ])
        for ax in (self.ax_real, self.ax_imag):
            ax.set_ylim(*y_lim)
        self.ax_rho.set_title(R"$\rho(s) \, n_\ell^2(s)$")
        self.ax_real.set_ylabel(Rf"Re $\Sigma_{ell}(s)$")
        self.ax_imag.set_ylabel(Rf"Im $\Sigma_{ell}(s)$")


def get_algorithm(algorithm_name: Algorithm) -> Callable:
    """Return the integration routine called ``algorithm_name``.

    ``"quad_vec"`` maps to SciPy's ``quad_vec``; any other name is looked up as an
    attribute of the ``quadax`` module.
    """
    if algorithm_name != "quad_vec":
        return getattr(quadax, algorithm_name)
    return quad_vec


def get_algorithm_options(algorithm_name: Algorithm, **kwargs) -> dict:
    """Keep only the keyword arguments that ``algorithm_name`` accepts.

    ``epsabs`` and ``epsrel`` are kept for every algorithm. In addition,
    quadcc/quadgk/quadts keep ``order``, romberg/rombergts keep ``divmax``, and
    quad_vec keeps ``limit``. The order of the remaining entries is preserved.

    Raises:
        ValueError: If ``algorithm_name`` is not a supported algorithm.
    """
    specific_key = {
        "quadcc": "order",
        "quadgk": "order",
        "quadts": "order",
        "romberg": "divmax",
        "rombergts": "divmax",
        "quad_vec": "limit",
    }
    if algorithm_name not in specific_key:
        msg = f"Unknown algorithm: {algorithm_name}"
        raise ValueError(msg)
    allowed = {"epsabs", "epsrel", specific_key[algorithm_name]}
    return {key: value for key, value in kwargs.items() if key in allowed}


@cache
def generate_domain(start, stop, resolution: int) -> npt.NDArray[np.float64]:
    """Return ``resolution`` evenly spaced float64 values from ``start`` to ``stop``.

    Results are memoized per argument tuple, so repeated calls with the same
    arguments return the *same* array object; callers should not modify it in place.
    """
    return np.linspace(start=start, stop=stop, num=resolution, dtype=np.float64)


def format_time(seconds: float) -> str:
    """Format a duration with one decimal in µs, ms, s, or "<min> min <s> s".

    Examples: ``format_time(0.25) == "250.0 ms"``, ``format_time(125) == "2 min 5.0 s"``.
    """
    scaled_units = ((1e-3, 1e6, "µs"), (1, 1e3, "ms"), (60, 1, "s"))
    for upper_bound, factor, unit in scaled_units:
        if seconds < upper_bound:
            return f"{factor * seconds:,.1f} {unit}"
    minutes, remainder = divmod(seconds, 60)
    return f"{int(minutes)} min {remainder:,.1f} s"

Hide code cell source

# Global font size for the figure below
plt.rc("font", size=12)
# Three stacked panels sharing the Re s axis. They are passed, in this order, as
# ax_rho, ax_real, and ax_imag to DispersionIntegralWidget below.
fig, axes = plt.subplots(figsize=(10, 8), nrows=3, sharex=True)
hide_toolbars(fig.canvas)
fig.subplots_adjust(bottom=0.1, hspace=0.1, left=0.08, right=0.95, top=0.95)
ax1, ax2, ax3 = axes
# Minimal frames: no top/right spines, and no bottom spine for the color map
for ax in axes.ravel():
    ax.spines["right"].set_visible(False)
    ax.spines["top"].set_visible(False)
ax1.spines["bottom"].set_visible(False)
ax1.set_ylabel("Im $s$")
ax3.set_xlabel("Re $s$")
# Zero reference line in the Re and Im Sigma panels
for ax in (ax2, ax3):
    ax.axhline(0, color="gray", lw=0.5)
plot_widget = DispersionIntegralWidget(
    *axes,
    real_lim=(0, 6),
    imag_max=2,
    grid=(500, 300),
)
# Call plot_widget(**<slider values>) whenever one of the sliders changes
out = w.interactive_output(plot_widget, sliders)
# Assumes interactive_output() has already called plot_widget once, so that the
# labelled artists on ax1 exist when the legend is built.
ax1.legend(loc="upper right", bbox_to_anchor=(1.0, 1.15))
plt.show()
display(out, ui)
../../_images/30d616faf0d4ade20fb546249654b4c314f55e01c6a587e9a188f69d76685875.svg