Source code for numpyro_schechter.distribution

import jax.numpy as jnp
from jax.scipy.special import gammaln
from numpyro.distributions import Distribution, constraints, Poisson
from .math_utils import custom_gammaincc, schechter_mag, SUPPORTED_ALPHA_DOMAIN_DEPTHS, LN10_0p4


[docs] class SchechterMag(Distribution): """ NumPyro-compatible distribution based on the single-component Schechter luminosity function in absolute magnitude space. This distribution models galaxy number densities using the Schechter parameterisation in magnitude space: φ(M) ∝ 10**[0.4(α + 1)(M* − M)] * exp[−10**(0.4(M* − M))] Parameters ---------- alpha : float Faint-end slope of the Schechter function. M_star : float Characteristic magnitude (M*). logphi : float Logarithm (base 10) of the normalisation φ* (number density per unit volume per magnitude). The internal `phi_star` is computed as 10**logphi. mag_obs : array-like Observed magnitudes used for computing the likelihood. include_poisson_term : bool, optional If True, includes a Poisson likelihood term on the total observed counts to constrain φ*. volume : float, optional Survey volume, used to scale expected counts for the Poisson term. Volume is in the same units as `phi_star` (e.g., if `phi_star` is per Mpc³ per mag, volume should be in Mpc³). alpha_domain_depth : int, optional Depth parameter for recurrence-based incomplete gamma calculation (default=3). Constraints ----------- - `alpha` must be real and non-integer. By default, the valid domain for `alpha + 1` is (−3, 3), corresponding to `alpha_domain_depth=3`. To support broader domains, increase `alpha_domain_depth` (see note). - `mag_obs` must contain values such that `10**(0.4(M* − M)) > 0`. Notes ----- - Normalisation is handled via a custom, recurrence-based computation of the upper incomplete gamma function, allowing support for automatic differentiation in JAX/NumPyro. To ensure compatibility with NUTS/HMC (which uses JAX's reverse-mode autodiff), only a fixed set of `alpha_domain_depth` values are supported. See `SchechterMag.supported_depths()` for available options (e.g., 3, 5, 10, 15). - If `include_poisson_term=True`, the log-likelihood includes: log L = Σ log(pdf_shape) + Poisson(N_obs | N_exp) where N_exp = volume * normalisation integral over observed magnitudes. """ support = constraints.real @property def has_rsample(self) -> bool: return False
[docs] @staticmethod def supported_depths(): """ Returns a list of supported values for `alpha_domain_depth`, corresponding to increasing valid alpha ranges. Larger `alpha_domain_depth` will see reduced performance due to the corresponding increase in recursions. """ return SUPPORTED_ALPHA_DOMAIN_DEPTHS
def __init__(self, alpha, M_star, logphi, mag_obs, include_poisson_term=False, volume=1.0, alpha_domain_depth=3, validate_args=None): # Store input parameters self.alpha = alpha self.M_star = M_star self.logphi = logphi self.phi_star = 10.0 ** logphi # convert log10(phi*) to phi* self.mag_obs = mag_obs self.alpha_domain_depth = alpha_domain_depth self.include_poisson_term = include_poisson_term self.N_obs = len(mag_obs) self.volume = volume # Magnitude range of the data M_min, M_max = jnp.min(mag_obs), jnp.max(mag_obs) # Precompute normalisation over observed magnitude range a = alpha + 1.0 x_min = 10 ** (0.4 * (M_star - M_max)) # faint light x_max = 10 ** (0.4 * (M_star - M_min)) # bright light norm_shape = LN10_0p4 * self.phi_star * (custom_gammaincc(a, x_min, recur_depth=self.alpha_domain_depth) - custom_gammaincc(a, x_max, recur_depth=self.alpha_domain_depth)) self.norm = jnp.where(norm_shape > 0, norm_shape, jnp.inf) # avoid /0 super().__init__(batch_shape=(), event_shape=(), validate_args=validate_args) def __str__(self): return ( f"SchechterMag distribution\n" f" alpha = {self.alpha}\n" f" M_star = {self.M_star}\n" f" logphi = {self.logphi}\n" f" alpha_domain_depth = {self.alpha_domain_depth}" ) def __repr__(self): return str(self)
[docs] def log_prob(self, value): # Compute the Schechter PDF (shape only) at given magnitudes pdf_shape = LN10_0p4 * schechter_mag(self.phi_star, self.M_star, self.alpha, value) / self.norm logp_shape = jnp.log(pdf_shape + 1e-30) # avoid log(0) if self.include_poisson_term: # Expected total counts for the survey N_exp = self.norm * self.volume # Add Poisson term for total number counts log_poisson = Poisson(N_exp).log_prob(self.N_obs) return jnp.sum(logp_shape) + log_poisson else: # Shape-only likelihood return jnp.sum(logp_shape)
[docs] def sample(self, key, sample_shape=()): raise NotImplementedError("Sampling not implemented for SchechterMag.")
[docs] class DoubleSchechterMag(Distribution): """ NumPyro-compatible distribution based on the double-component Schechter luminosity function in absolute magnitude space. This distribution models galaxy number densities using the sum of two Schechter functions: φ(M) = φ1(M) + φ2(M) with each component of the form: φ_i(M) ∝ 10**[0.4(α_i + 1)(M*_i − M)] * exp[−10**(0.4(M*_i − M))] Parameters ---------- alpha1, alpha2 : float Faint-end slopes of the two Schechter components. M_star1, M_star2 : float Characteristic magnitudes (M*) of the first and second components. logphi1, logphi2 : float Logarithm (base 10) of the normalisation φ*_i (number density per unit volume per magnitude) for the first and second Schechter components, respectively. Internally converted as phi_star_i = 10**logphi_i. mag_obs : array-like Observed magnitudes used for computing the likelihood. include_poisson_term : bool, optional If True, includes a Poisson likelihood term on the total observed counts to constrain φ*_i. volume : float, optional Survey volume, used to scale expected counts for the Poisson term. Volume is in the same units as `phi_star_i` (e.g., if `phi_star_i` is per Mpc³ per mag, volume should be in Mpc³). alpha_domain_depth : int, optional Depth parameter for recurrence-based incomplete gamma calculation (default=3). Constraints ----------- - `alpha1` and `alpha2` must be real and non-integer. By default, the valid domain for `alpha_i + 1` is (−3, 3), corresponding to `alpha_domain_depth=3`. To support broader domains, increase `alpha_domain_depth` (see note). - `mag_obs` must contain values such that `10**(0.4(M*_i − M)) > 0`. Notes ----- - Normalisation is handled via a custom, recurrence-based computation of the upper incomplete gamma function, allowing support for automatic differentiation in JAX/NumPyro. To ensure compatibility with NUTS/HMC (which uses JAX's reverse-mode autodiff), only a fixed set of `alpha_domain_depth` values are supported. See `SchechterMag.supported_depths()` for available options (e.g., 3, 5, 10, 15). - If `include_poisson_term=True`, the log-likelihood includes: log L = Σ log(pdf_shape) + Poisson(N_obs | N_exp) where N_exp = volume * (normalisation1 + normalisation2) over observed magnitudes. """ support = constraints.real @property def has_rsample(self): return False
[docs] @staticmethod def supported_depths(): """ Returns a list of supported values for `alpha_domain_depth`, corresponding to increasing valid alpha ranges. Larger `alpha_domain_depth` will see reduced performance due to the corresponding increase in recursions. """ return SUPPORTED_ALPHA_DOMAIN_DEPTHS
def __init__(self, alpha1, M_star1, logphi1, alpha2, M_star2, logphi2, mag_obs, include_poisson_term=False, volume=1.0, alpha_domain_depth=3, validate_args=None): # First Schechter component parameters self.alpha1 = alpha1 self.M_star1 = M_star1 self.logphi1 = logphi1 self.phi_star1 = 10.0 ** logphi1 # Second Schechter component parameters self.alpha2 = alpha2 self.M_star2 = M_star2 self.logphi2 = logphi2 self.phi_star2 = 10.0 ** logphi2 # Data and settings self.mag_obs = mag_obs self.alpha_domain_depth = alpha_domain_depth self.include_poisson_term = include_poisson_term self.N_obs = len(mag_obs) self.volume = volume # Magnitude range of the data M_min, M_max = jnp.min(mag_obs), jnp.max(mag_obs) # First component normalisation a1 = alpha1 + 1.0 x1_min = 10 ** (0.4 * (M_star1 - M_max)) x1_max = 10 ** (0.4 * (M_star1 - M_min)) norm1 = self.phi_star1 * ( custom_gammaincc(a1, x1_min, recur_depth=alpha_domain_depth) - custom_gammaincc(a1, x1_max, recur_depth=alpha_domain_depth) ) # Second component normalisation a2 = alpha2 + 1.0 x2_min = 10 ** (0.4 * (M_star2 - M_max)) x2_max = 10 ** (0.4 * (M_star2 - M_min)) norm2 = self.phi_star2 * ( custom_gammaincc(a2, x2_min, recur_depth=alpha_domain_depth) - custom_gammaincc(a2, x2_max, recur_depth=alpha_domain_depth) ) # Combined normalisation for both components self.norm = LN10_0p4 * (norm1 + norm2) self.norm = jnp.where(self.norm > 0, self.norm, jnp.inf) # avoid /0 super().__init__(batch_shape=(), event_shape=(), validate_args=validate_args) def __str__(self): return ( f"DoubleSchechterMag distribution\n" f" alpha1={self.alpha1}, M_star1={self.M_star1}, logphi1={self.logphi1}\n" f" alpha2={self.alpha2}, M_star2={self.M_star2}, logphi2={self.logphi2}\n" f" alpha_domain_depth={self.alpha_domain_depth}" ) def __repr__(self): return str(self)
[docs] def log_prob(self, value): # Compute PDFs for each component and sum phi1 = schechter_mag(self.phi_star1, self.M_star1, self.alpha1, value) phi2 = schechter_mag(self.phi_star2, self.M_star2, self.alpha2, value) pdf_shape = LN10_0p4 * (phi1 + phi2) / self.norm logp_shape = jnp.log(pdf_shape + 1e-30) # avoid log(0) if self.include_poisson_term: # Expected total counts from both components N_exp = self.norm * self.volume # Add Poisson term for number counts log_poisson = Poisson(N_exp).log_prob(self.N_obs) return jnp.sum(logp_shape) + log_poisson else: # Shape-only likelihood return jnp.sum(logp_shape)
[docs] def sample(self, key, sample_shape=()): raise NotImplementedError("Sampling not implemented for DoubleSchechterMag.")