Lecture 17 – Collision theory

from __future__ import annotations

import warnings
from typing import Callable

import numpy as np
import plotly.graph_objects as go
import sympy as sp
from ampform.io import aslatex
from ampform.sympy import UnevaluatedExpression, implement_doit_method
from ampform.sympy.math import create_expression
from IPython.display import Math
from plotly.colors import DEFAULT_PLOTLY_COLORS
from plotly.subplots import make_subplots


This notebook is an attempt to recreate the Mathematica notebook provided by Miguel Albaladejo. Another nice tutorial about the complex plane is this Julia notebook by Mikhail Mikhasenko.

Riemann sheets

Square root example

There are multiple solutions for \(y\) to the equation \(y^2 = x\) – the fact that we usually take \(y = \sqrt{x}\) to be the solution to this equation is just a matter of convention. It would be more complete to represent the solution as a set of points in the complex plane. In this case, we have the set \(S = \left\{\left(z, w\right)\in\mathbb{C}^2 \mid w^2=z\right\}\). This set forms a Riemann surface in \(\mathbb{C}^2\) space.

def plot_riemann_surfaces(
    funcs: list[Callable],
    func_unicode: str,
    boundaries: tuple[complex, float] | tuple[complex, complex] = (0, 1),
    resolution: int | tuple[int, int] = 50,
    colorize: bool = True,
    mask: Callable[[np.ndarray, np.ndarray], np.ndarray] | None = None,
) -> None:
    """Plot the imaginary and real parts of several sheets side by side.

    Each callable in ``funcs`` is evaluated on the complex grid produced by
    `create_meshgrid` and drawn as one surface ("Sheet 1", "Sheet 2", ...) in
    two 3D subplots: the imaginary part on the left, the real part on the
    right.

    Args:
        funcs: Vectorized functions of a complex array, one per sheet.
        func_unicode: Label of the function, used in the subplot titles.
        boundaries: Forwarded to `create_meshgrid`.
        resolution: Forwarded to `create_meshgrid`.
        colorize: If `True`, color each surface by its real part; otherwise
            give each sheet a single color from `DEFAULT_PLOTLY_COLORS`.
        mask: Optional function ``mask(Z, T)`` returning a boolean array. A
            grid point is blanked out (set to NaN) on *all* sheets as soon
            as any sheet flags it.

    Raises:
        ValueError: If the mask removes every grid point.
    """
    X, Y = create_meshgrid(boundaries, resolution)
    Z = X + Y * 1j
    T = [f(Z) for f in funcs]
    if mask is not None:
        the_mask = np.full(Z.shape, False)
        for t in T:
            the_mask |= mask(Z, t)
        if np.all(the_mask):
            raise ValueError("All points were masked away")
        X[the_mask] = np.nan
        Y[the_mask] = np.nan
        Z[the_mask] = np.nan
        for t in T:
            t[the_mask] = np.nan

    # Masked points are NaN, so a plain .max() would poison the color range.
    vmax = max(max(np.nanmax(t.imag), np.nanmax(t.real)) for t in T)

    def style(i: int, t: np.ndarray) -> dict:
        # NOTE(review): the keyword in front of the conditional and the use of
        # vmax were lost in this copy of the file. `colorscale` with a
        # symmetric cmin/cmax is the reconstruction; confirm "RdBu_r" against
        # the original notebook.
        if colorize:
            colorscale = "RdBu_r"
            surfacecolor = t.real
        else:
            colorscale = [[0, "rgb(0, 0, 0)"], [1, DEFAULT_PLOTLY_COLORS[i - 1]]]
            surfacecolor = np.ones(shape=t.shape)
        return dict(
            cmin=-vmax,
            cmax=+vmax,
            colorscale=colorscale,
            surfacecolor=surfacecolor,
        )

    S_im = [
        go.Surface(x=X, y=Y, z=t.imag, **style(i, t), name=f"Sheet {i}")
        for i, t in enumerate(T, 1)
    ]
    S_re = [
        go.Surface(x=X, y=Y, z=t.real, **style(i, t), name=f"Sheet {i}")
        for i, t in enumerate(T, 1)
    ]
    # `specs` describes a 1x2 grid, so rows/cols have to be given explicitly.
    fig = make_subplots(
        rows=1,
        cols=2,
        specs=[[{"type": "surface"}, {"type": "surface"}]],
        subplot_titles=(f"Im {func_unicode}", f"Re {func_unicode}"),
    )
    for surface_im, surface_re in zip(S_im, S_re):
        fig.add_trace(surface_im, col=1, row=1)
        fig.add_trace(surface_re, col=2, row=1)
    fig.update_layout(height=550, width=1_000)
    # The function returns None, so the figure has to be displayed here.
    fig.show()

def create_meshgrid(
    boundaries: tuple[complex, float] | tuple[complex, complex] = (0, 1),
    resolution: int | tuple[int, int] = 50,
) -> tuple[np.ndarray, np.ndarray]:
    """Create real X and Y coordinate grids over a region of the complex plane.

    Args:
        boundaries: Either ``(center, radius)`` with a real ``radius``, which
            gives a polar grid on a disk around ``center``, or two complex
            corner points ``(box_min, box_max)``, which gives a rectangular
            grid.
        resolution: Number of grid points. A single integer is used for both
            directions; a tuple is ``(x_res, y_res)``. For the polar grid
            these are the number of radial and angular steps.

    Returns:
        Two real arrays ``(X, Y)`` of shape ``(y_res, x_res)`` holding the
        real and imaginary coordinates of the grid points.
    """
    if isinstance(resolution, tuple):
        x_res, y_res = resolution
    else:
        x_res, y_res = resolution, resolution
    box_min, box_max = boundaries
    if isinstance(box_max, (float, int)):
        # Polar grid: box_min is the (possibly complex) centre of the disk.
        center, r_max = complex(box_min), box_max
        R, theta = np.meshgrid(
            np.linspace(0, r_max, num=x_res),
            np.linspace(-np.pi, +np.pi, num=y_res),
        )
        # Shift X by the real and Y by the imaginary part of the centre, so
        # that both grids stay real-valued for an off-axis centre.
        X = R * np.cos(theta) + center.real
        Y = R * np.sin(theta) + center.imag
        return X, Y
    x1 = complex(box_min).real
    x2 = complex(box_max).real
    y1 = complex(box_min).imag
    y2 = complex(box_max).imag
    X, Y = np.meshgrid(
        np.linspace(x1, x2, num=x_res),
        np.linspace(y1, y2, num=y_res),
    )
    return X, Y

def cut_t(
    cutoff: float | tuple[float, float]
) -> Callable[[np.ndarray, np.ndarray], np.ndarray]:
    """Build a mask function that flags points where ``t`` is too large.

    Args:
        cutoff: Maximum allowed absolute value. A tuple is interpreted as
            ``(re_cut, im_cut)`` for the real and imaginary part separately;
            a single number is applied to both. A NaN entry disables that
            cut, because comparisons with NaN are always false.

    Returns:
        A function ``mask(z, t)`` that returns a boolean array which is
        `True` where ``|Re t| > re_cut`` or ``|Im t| > im_cut``. The ``z``
        argument is accepted but not used.
    """
    if isinstance(cutoff, tuple):
        re_cut, im_cut = cutoff
    else:
        re_cut, im_cut = cutoff, cutoff
    return lambda z, t: (np.abs(t.real) > re_cut) | (np.abs(t.imag) > im_cut)
    funcs=[lambda z: -np.sqrt(z), lambda z: +np.sqrt(z)],
        lambda z: -1 / np.sqrt(z),
        lambda z: +1 / np.sqrt(z),

Note also that since \(y = e^{x} = e^{x + 2n \pi i}\) for all \(n \in \mathbb{Z}\), the logarithm is multi-valued: \(x = \log(y) + 2n\pi i\):

        lambda z: np.log(z) - 2j * np.pi,
        lambda z: np.log(z) + 2j * np.pi,
    func_unicode="log z",
    boundaries=(0, np.e**2),
    mask=cut_t((np.e, np.nan)),
Video explainers

Definition of the G(s) functions

# The decorator provides the doit() method that expands the class through
# evaluate(); code further down relies on SignedSqrt(z).doit().
@implement_doit_method
class SignedSqrt(UnevaluatedExpression):
    """Square root built from the modulus and `PosArg` of its argument.

    Evaluates to ``sqrt(|z|) * exp(i * PosArg(z) / 2)`` and is rendered in
    LaTeX as a square root with a ``+`` index.
    """

    is_commutative = True
    is_real = False

    def __new__(cls, z, **hints) -> SignedSqrt:
        return create_expression(cls, z, **hints)

    def evaluate(self) -> sp.Expr:
        """Return the explicit expression in terms of ``abs`` and `PosArg`."""
        z = self.args[0]
        return sp.sqrt(abs(z)) * sp.exp(sp.I * PosArg(z) / 2)

    def _latex(self, printer, *args) -> str:
        z = printer._print(self.args[0])
        return Rf"\sqrt[+]{{{z}}}"

# The decorator provides the doit() method that expands the class through
# evaluate().
@implement_doit_method
class PosArg(UnevaluatedExpression):
    """Complex argument that is shifted by 2π in the lower half plane.

    Evaluates to ``arg(z) + 2π`` where ``Im z < 0`` and to ``arg(z)``
    everywhere else.
    """

    is_commutative = True

    def __new__(cls, z, **hints) -> PosArg:
        return create_expression(cls, z, **hints)

    def evaluate(self) -> sp.Expr:
        """Return the piecewise definition in terms of ``sp.arg``."""
        z = self.args[0]
        arg = sp.arg(z)
        return sp.Piecewise(
            (arg + 2 * sp.pi, sp.im(z) < 0),
            (arg, True),
        )

    def _latex(self, printer, *args) -> str:
        z = printer._print(self.args[0])
        return Rf"\arg^+\left({z}\right)"

# Complex symbol used to display what the custom square root and argument
# classes evaluate to.
z = sp.Symbol("z", complex=True)
Math(
    aslatex(
        {
            expr: expr.evaluate()
            for expr in (SignedSqrt(z), PosArg(z))
        }
    )
)
\[\begin{split}\displaystyle \begin{array}{rcl} \sqrt[+]{z} &=& e^{\frac{i \arg^+\left(z\right)}{2}} \sqrt{\left|{z}\right|} \\ \arg^+\left(z\right) &=& \begin{cases} \arg{\left(z \right)} + 2 \pi & \text{for}\: \operatorname{im}{\left(z\right)} < 0 \\\arg{\left(z \right)} & \text{otherwise} \end{cases} \\ \end{array}\end{split}\]
    funcs=[sp.lambdify(z, SignedSqrt(z).doit())],
    mask=lambda z, t: (np.abs(z.imag) < 1e-5) & (z.real > 0),
    resolution=(30, 301),
# The decorator provides the doit() method that expands the class through
# evaluate(); the notebook calls G(...).doit() further down.
@implement_doit_method
class G(UnevaluatedExpression):
    """The function G(s) on sheet I (``sign < 0``) or sheet II (otherwise).

    Arguments are ``s``, ``m``, ``g0`` and ``sign``. With ``σ = Sigma(s, m)``
    the sheet I expression is ``(g0 - σ log((σ - 1)/(σ + 1))) / (16 π²)``.
    Sheet II is defined as the sheet I function plus ``2 i σ / (16 π)``.
    """

    is_commutative = True
    is_real = False

    def __new__(cls, s, m, g0, sign=+1, **hints) -> G:
        return create_expression(cls, s, m, g0, sign, **hints)

    def evaluate(self) -> sp.Expr:
        """Return the sheet I expression or sheet II in terms of sheet I."""
        s, m, g0, sign = self.args
        sigma = Sigma(s, m)
        g = (g0 - sigma * sp.log((sigma - 1) / (sigma + 1))) / (16 * sp.pi**2)
        return sp.Piecewise(
            (g, sign < 0),
            (G(s, m, g0, sign=-1) + 2 * sp.I * sigma / (16 * sp.pi), True),
        )

    def _latex(self, printer, *args) -> str:
        s = printer._print(self.args[0])
        sign = self.args[-1]
        number = "I" if sign < 0 else "II"
        return f"G_{{{number}}}({s})"

# The decorator provides the doit() method that expands the class through
# evaluate(); the notebook calls sigma.doit(deep=False) further down.
@implement_doit_method
class Sigma(UnevaluatedExpression):
    """The function σ(s), defined as ``SignedSqrt(1 - 4 m² / s)``.

    Arguments are ``s`` and ``m``; only ``s`` is shown in the LaTeX form.
    """

    is_commutative = True
    is_real = False

    def __new__(cls, s, m, **hints) -> Sigma:
        return create_expression(cls, s, m, **hints)

    def evaluate(self) -> sp.Expr:
        """Return σ(s) expressed with `SignedSqrt`."""
        s, m = self.args
        return SignedSqrt(1 - 4 * m**2 / s)

    def _latex(self, printer, *args) -> str:
        s = printer._print(self.args[0])
        return Rf"\sigma\left({s}\right)"

# Symbols: m is real and non-negative, s and g0 are complex.
m = sp.Symbol("m", real=True, nonnegative=True)
s, g0 = sp.symbols("s g0", complex=True)
sigma = Sigma(s, m)
# G1 is created with sign=-1 and G2 with sign=+1.
G1, G2 = (G(s, m, g0, sign=sheet_sign) for sheet_sign in (-1, +1))
# Expand each expression by one level only, so that nested definitions stay
# visible as symbols.
definitions = {
    expr: expr.doit(deep=False)
    for expr in (G1, G2, sigma)
}
\[\begin{split}\displaystyle \begin{array}{rcl} G_{I}(s) &=& \frac{g_{0} - \log{\left(\frac{\sigma\left(s\right) - 1}{\sigma\left(s\right) + 1} \right)} \sigma\left(s\right)}{16 \pi^{2}} \\ G_{II}(s) &=& G_{I}(s) + \frac{i \sigma\left(s\right)}{8 \pi} \\ \sigma\left(s\right) &=& \sqrt[+]{- \frac{4 m^{2}}{s} + 1} \\ \end{array}\end{split}\]
# Numerical values inserted for the symbols before the G functions are
# turned into numerical functions.
substitutions = {
    m: 139,
    g0: 3.0,
}
\[\begin{split}\displaystyle \begin{array}{rcl} m &=& 139 \\ g_{0} &=& 3.0 \\ \end{array}\end{split}\]
# Fully expand both sheets and insert the numerical parameter values.
G1_expr, G2_expr = (
    expr.doit().xreplace(substitutions) for expr in (G1, G2)
)
# After substitution, s has to be the only symbol left for lambdify.
assert G1_expr.free_symbols == {s}
assert G2_expr.free_symbols == {s}
G1_func, G2_func = (
    sp.lambdify(s, expr) for expr in (G1_expr, G2_expr)
)
        lambda z: G1_func(z**2),
        lambda z: G2_func(z**2),
    boundaries=(240 - 40j, 320 + 40j),
    resolution=(50, 401),
    mask=lambda z, t: np.abs(z.imag) == 0,

T-matrix definition

# The decorator provides the doit() method that expands the class through
# evaluate().
@implement_doit_method
class S(UnevaluatedExpression):
    """The function S(s), defined as ``1 - 2 i σ(s) / (16 π) · T(s)``.

    Arguments are ``s, m, m_rho, GV, f_pi, g0, sign``; all of them are passed
    on unchanged to `T`. ``sign < 0`` is labelled sheet I, anything else
    sheet II.

    NOTE(review): the names of the third and fifth parameter were stripped
    from this copy of the file. ``m_rho`` and ``f_pi`` follow the symbol
    names ``"m_rho"`` and ``"f_pi"`` used further down in this file.
    """

    is_commutative = True
    is_real = False

    def __new__(cls, s, m, m_rho, GV, f_pi, g0, sign=+1, **hints) -> S:
        return create_expression(cls, s, m, m_rho, GV, f_pi, g0, sign, **hints)

    def evaluate(self) -> sp.Expr:
        """Return S in terms of `Sigma` and `T`."""
        s, m, *_ = self.args
        return 1 - 2 * sp.I * Sigma(s, m) / (16 * sp.pi) * T(*self.args)

    def _latex(self, printer, *args) -> str:
        s = printer._print(self.args[0])
        sign = self.args[-1]
        number = "I" if sign < 0 else "II"
        return f"S_{{{number}}}({s})"

# The decorator provides the doit() method that expands the class through
# evaluate().
@implement_doit_method
class T(UnevaluatedExpression):
    """The function T(s), defined as ``1 / (1 / V1(s) - G(s))``.

    Arguments are ``s, m, m_rho, GV, f_pi, g0, sign``. The ``sign`` selects
    the sheet of `G`: ``sign < 0`` is labelled sheet I, anything else
    sheet II.

    NOTE(review): the names of the third and fifth parameter were stripped
    from this copy of the file. ``m_rho`` and ``f_pi`` follow the symbol
    names ``"m_rho"`` and ``"f_pi"`` used further down in this file.
    """

    is_commutative = True
    is_real = False

    def __new__(cls, s, m, m_rho, GV, f_pi, g0, sign=+1, **hints) -> T:
        return create_expression(cls, s, m, m_rho, GV, f_pi, g0, sign, **hints)

    def evaluate(self) -> sp.Expr:
        """Return T in terms of `V1` and `G`."""
        s, m, m_rho, GV, f_pi, g0, sign = self.args
        return 1 / (1 / V1(s, m, m_rho, GV, f_pi) - G(s, m, g0, sign))

    def _latex(self, printer, *args) -> str:
        s = printer._print(self.args[0])
        sign = self.args[-1]
        number = "I" if sign < 0 else "II"
        return f"T_{{{number}}}({s})"

# The decorator provides the doit() method that expands the class through
# evaluate().
@implement_doit_method
class V1(UnevaluatedExpression):
    """The function V_1(s), built from `p2` and `h`.

    Arguments are ``s, m, m_rho, GV, f_pi``. It evaluates to::

        -2 p²/(3 f_pi²) · (1 - 2 GV² s / (f_pi² (s - m_rho²)))
            - GV²/f_pi⁴ · p² · h(m_rho² / (2 p²))

    with ``p² = p2(s, m)``.

    NOTE(review): the names of the third and fifth parameter were stripped
    from this copy of the file. ``m_rho`` and ``f_pi`` follow the symbol
    names ``"m_rho"`` and ``"f_pi"`` used further down in this file.
    """

    is_commutative = True
    is_real = False

    def __new__(cls, s, m, m_rho, GV, f_pi, **hints) -> V1:
        return create_expression(cls, s, m, m_rho, GV, f_pi, **hints)

    def evaluate(self) -> sp.Expr:
        """Return V_1 in terms of `p2` and `h`."""
        s, m, m_rho, GV, f_pi = self.args
        return -(2 * p2(s, m)) / (3 * f_pi**2) * (
            1 - GV**2 / f_pi**2 * 2 * s / (s - m_rho**2)
        ) - GV**2 / f_pi**4 * p2(s, m) * h(m_rho**2 / (2 * p2(s, m)))

    def _latex(self, printer, *args) -> str:
        s = printer._print(self.args[0])
        return Rf"V_1\left({s}\right)"

# The decorator provides the doit() method that expands the class through
# evaluate().
@implement_doit_method
class h(UnevaluatedExpression):
    """Auxiliary function h(a) that appears in `V1`.

    Evaluates to ``a (a² + 3a + 2) log(1 + 2/a) - 2/3 (3a² + 6a + 1)``.
    """

    is_commutative = True

    def __new__(cls, a, **hints) -> h:
        return create_expression(cls, a, **hints)

    def evaluate(self) -> sp.Expr:
        """Return the explicit expression for h(a)."""
        a = self.args[0]
        return -sp.Mul(
            sp.Rational(2, 3),
            (1 + 6 * a + 3 * a**2),
        ) + a * (2 + 3 * a + a**2) * sp.log(1 + 2 / a)

    def _latex(self, printer, *args) -> str:
        a = printer._print(self.args[0])
        return Rf"h\left({a}\right)"

# The decorator provides the doit() method that expands the class through
# evaluate().
@implement_doit_method
class p2(UnevaluatedExpression):
    """The function p²(s), defined as ``s / 4 - m²``.

    Arguments are ``s`` and ``m``; only ``s`` is shown in the LaTeX form.
    """

    is_commutative = True

    def __new__(cls, s, m, **hints) -> p2:
        return create_expression(cls, s, m, **hints)

    def evaluate(self) -> sp.Expr:
        """Return the explicit expression for p²(s)."""
        s, m = self.args
        return s / 4 - m**2

    def _latex(self, printer, *args) -> str:
        s = printer._print(self.args[0])
        return Rf"p^2\left({s}\right)"

# NOTE(review): two variable names were stripped from this copy of the file.
# They are restored as m_rho and f_pi, matching the symbol names in the
# string passed to sp.symbols.
a, m_rho, GV, f_pi = sp.symbols("a m_rho, G_V f_pi")
# One entry per definition that is displayed; each is expanded by one level
# only, so that nested definitions stay visible as symbols.
_exprs = [
    S(s, m, m_rho, GV, f_pi, g0, sign=-1),
    T(s, m, m_rho, GV, f_pi, g0, sign=-1),
    T(s, m, m_rho, GV, f_pi, g0, sign=+1),
    V1(s, m, m_rho, GV, f_pi),
    h(a),
    p2(s, m),
]
Math(aslatex({e: e.doit(deep=False) for e in _exprs}))
\[\begin{split}\displaystyle \begin{array}{rcl} S_{I}(s) &=& - \frac{i \sigma\left(s\right) T_{I}(s)}{8 \pi} + 1 \\ T_{I}(s) &=& \frac{1}{- G_{I}(s) + \frac{1}{V_1\left(s\right)}} \\ T_{II}(s) &=& \frac{1}{- G_{II}(s) + \frac{1}{V_1\left(s\right)}} \\ V_1\left(s\right) &=& - \frac{G_{V}^{2} h\left(\frac{m_{\rho}^{2}}{2 p^2\left(s\right)}\right) p^2\left(s\right)}{f_{\pi}^{4}} - \frac{2 \left(- \frac{2 G_{V}^{2} s}{f_{\pi}^{2} \left(- m_{\rho}^{2} + s\right)} + 1\right) p^2\left(s\right)}{3 f_{\pi}^{2}} \\ h\left(a\right) &=& a \left(a^{2} + 3 a + 2\right) \log{\left(1 + \frac{2}{a} \right)} - \frac{2 \cdot \left(3 a^{2} + 6 a + 1\right)}{3} \\ p^2\left(s\right) &=& - m^{2} + \frac{s}{4} \\ \end{array}\end{split}\]
gv = sp.Symbol("g_v")
# NOTE(review): the keys for f_pi and m_rho were stripped from this copy of
# the file; they are restored from the rendered table that follows this cell.
# G_V is expressed through g_v and f_pi, so substituting it introduces
# symbols that need a second substitution pass.
substitutions = {
    f_pi: 87.3,
    GV: sp.sqrt(gv**2 * f_pi**2) / 2,
    gv: 1,
    m: 139,
    m_rho: 770,
    g0: -3,
}
\[\begin{split}\displaystyle \begin{array}{rcl} f_{\pi} &=& 87.3 \\ G_{V} &=& \frac{\sqrt{f_{\pi}^{2} g_{v}^{2}}}{2} \\ g_{v} &=& 1 \\ m &=& 139 \\ m_{\rho} &=& 770 \\ g_{0} &=& -3 \\ \end{array}\end{split}\]
# One fully evaluated expression per sheet: index 0 is sign=-1 (sheet I),
# index 1 is sign=+1 (sheet II).
# NOTE(review): the stripped argument names are restored as m_rho and f_pi.
# The evaluation chain after T(...) was lost in this copy; doit() followed by
# the substitutions is needed so that s is the only symbol left for lambdify.
# xreplace does not revisit inserted values, and the value for G_V contains
# g_v and f_pi, hence the second pass.
T_exprs = [
    T(s, m, m_rho, GV, f_pi, g0, sign)
    .doit()
    .xreplace(substitutions)
    .xreplace(substitutions)
    for sign in [-1, +1]
]
T_funcs = [sp.lambdify(s, expr) for expr in T_exprs]
# Grid in the complex plane: x along the real direction, y along the
# imaginary direction. y starts at 1e-5, so no grid point lies exactly on
# the real axis.
x = np.linspace(500, 1_100, num=200)
y = np.linspace(1e-5, 150, num=100)

# Upper (p) and lower (n) half plane share the same real coordinates.
X, Yp = np.meshgrid(x, y)
Yn = -Yp
Zp = X + 1j * Yp
Zn = X + 1j * Yn

# The functions take the square of the grid variable as input. The upper
# half plane uses the first function (sign=-1), the lower half plane the
# second one (sign=+1).
Tp = T_funcs[0](Zp**2)
Tn = T_funcs[1](Zn**2)

# Range shown on the z axis of the figures below.
vmax = 100
sty = lambda t: dict(
Sn = go.Surface(x=X, y=Yn, z=Tn.real, **sty(Tn), name="Unphysical")
Sp = go.Surface(
    x=X, y=Yp, z=Tp.real, **sty(Tp), name="Physical", colorbar_title="Re T"
y = Yp[0]
z = x + y * 1j
line = go.Scatter3d(
    line=dict(color="darkgreen", width=1),
fig = go.Figure(data=[Sn, Sp, line])
fig.update_layout(height=550, width=600)
# NOTE(review): the call that wrapped the scene settings below was lost in
# this copy of the file; update_scenes accepts these keyword arguments.
fig.update_scenes(
    xaxis_title_text="Re s",
    yaxis_title_text="Im s",
    # The surfaces of this figure use the real part of T as z coordinate
    # (their colorbar is titled "Re T"), so the z axis is labelled "Re T".
    zaxis_title_text="Re T",
    zaxis_range=[-vmax, +vmax],
)
sty = lambda t: dict(
Sn = go.Surface(x=X, y=Yn, z=Tn.imag, **sty(Tn), name="Unphysical")
Sp = go.Surface(
    x=X, y=Yp, z=Tp.imag, **sty(Tp), name="Physical", colorbar_title="Im T"
y = Yp[0]
z = x + y * 1j
line = go.Scatter3d(
    line=dict(color="darkgreen", width=1),
fig = go.Figure(data=[Sn, Sp, line])
fig.update_layout(height=550, width=600)
# NOTE(review): the call that wrapped the scene settings below was lost in
# this copy of the file; update_scenes accepts these keyword arguments.
fig.update_scenes(
    xaxis_title_text="Re s",
    yaxis_title_text="Im s",
    # The surfaces of this figure use the imaginary part of T as z coordinate
    # (their colorbar is titled "Im T"), so the z axis is labelled "Im T".
    zaxis_title_text="Im T",
    zaxis_range=[-vmax, +vmax],
)