Source code for numpyro_schechter.math_utils

import jax
import jax.numpy as jnp
from jax import lax, jit
from jax.scipy.special import gamma, gammaincc
from jax._src.numpy.util import promote_args_inexact
from jax._src.typing import Array, ArrayLike

jax.config.update("jax_enable_x64", True)

SUPPORTED_ALPHA_DOMAIN_DEPTHS = [3, 4, 5, 6, 7, 8, 9, 10, 15, 20, 30]
LN10_0p4 = 0.4 * jnp.log(10.0)

[docs] def schechter_mag(phi_star, M_star, alpha, M): """ Schechter density in magnitude space (unnormalised). """ # M: absolute magnitude x = 10 ** (0.4 * (M_star - M)) return ( LN10_0p4 * phi_star * x ** (alpha + 1) * jnp.exp(-x) )
[docs] def custom_gammaincc(s: ArrayLike, x: ArrayLike, recur_depth: int = 3) -> Array: """ Computes Γ(s, x) using a recurrence-based method. Valid for real, non-integer s. The default supported domain for s is approximately (-3, 3), corresponding to `recur_depth=3`. Increase `recur_depth` to extend this range. Input `x` must be positive. Raises: ValueError: If `recur_depth` is not in the supported set. """ s, x = promote_args_inexact("custom_gammaincc", s, x) if recur_depth not in _dispatch_table: raise ValueError(f"Unsupported recur_depth={recur_depth}. " f"Must be one of {SUPPORTED_ALPHA_DOMAIN_DEPTHS}.") return _dispatch_table[recur_depth](s, x)
[docs] def s_positive(s: ArrayLike, x: ArrayLike) -> Array: """ Regularised upper incomplete gamma function * Gamma(s) """ return gamma(s) * gammaincc(s, x)
[docs] def compute_gamma(s: ArrayLike, x: ArrayLike, recur_depth: int) -> Array: def recur(gamma_val, s, x): return (gamma_val - x ** s * jnp.exp(-x)) / s def compute_recurrence(carry, _): gamma_val, s = carry new_s = s - 1 gamma_val_new = lax.cond( jnp.isinf(gamma_val), lambda _: jnp.inf, lambda _: recur(gamma_val, new_s, x), operand=None ) return (gamma_val_new, new_s), gamma_val_new s_start = s + recur_depth gamma_start = s_positive(s_start, x) initial_carry = (gamma_start, s_start) result, _ = lax.scan(compute_recurrence, initial_carry, None, length=recur_depth) return result[0]
_dispatch_table = {} for depth in SUPPORTED_ALPHA_DOMAIN_DEPTHS: def _make_fn(recur_depth=depth): @jit def fn(s, x): return compute_gamma(s, x, recur_depth) return fn _dispatch_table[depth] = _make_fn()