Source code for ler.gw_source_population.prior_functions

# -*- coding: utf-8 -*-
"""
Module for JIT-compiled functions used in GW source population simulations.

This module provides Numba JIT-compiled functions for efficient sampling and
computation of merger rate densities, star formation rates, and mass distributions for various gravitational wave source populations including BBH, BNS, and NSBH.

Key Features: \n
- Merger rate density functions for PopI/II, PopIII, and Primordial BBH \n
- Star formation rate models (Madau & Dickinson 2014, Madau & Fragos 2017) \n
- Mass distribution samplers (power-law Gaussian, broken power-law, bimodal) \n
- JIT-compiled for high performance with Numba \n

Copyright (C) 2026 Hemanta Ph. Distributed under MIT License.
"""

import numpy as np
import math
from numba import njit
from scipy.interpolate import CubicSpline
from astropy.cosmology import LambdaCDM

from ..utils import inverse_transform_sampler, cumulative_trapezoid, is_njitted

# ------------------------------
# Merger rate density functions
# ------------------------------
@njit(fastmath=True)
[docs] def merger_rate_density_bbh_oguri2018_function(zs, R0=19 * 1e-9, b2=1.6, b3=2.1, b4=30): """ Compute the merger rate density for PopI/II BBH. Reference: Oguri et al. (2018). The source-frame rate density is .. math:: R(z) = R_0 \\frac{(b_4 + 1) e^{b_2 z}}{b_4 + e^{b_3 z}}. Parameters ---------- zs : ``float`` or ``numpy.ndarray`` Source redshifts. R0 : ``float`` Local merger rate density at low redshift (Mpc^-3 yr^-1). \n default: 19e-9 (GWTC-4) b2 : ``float`` Fitting parameter. \n default: 1.6 b3 : ``float`` Fitting parameter. \n default: 2.1 b4 : ``float`` Fitting parameter. \n default: 30 Returns ------- rate_density : ``float`` or ``numpy.ndarray`` Merger rate density. Examples -------- >>> import numpy as np >>> from ler.gw_source_population import merger_rate_density_bbh_oguri2018_function >>> rate_density = merger_rate_density_bbh_oguri2018_function(zs=np.array([0.1])) """ return R0 * (b4 + 1) * np.exp(b2 * zs) / (b4 + np.exp(b3 * zs))
@njit(fastmath=True)
[docs] def merger_rate_density_bbh_popIII_ken2022_function(zs, R0=19.2 * 1e-9, aIII=0.66, bIII=0.3, zIII=11.6): """ Compute the unnormalized merger rate density for PopIII BBH. Reference: Ng et al. (2022). The model is .. math:: R(z) = R_0 \\frac{e^{a_{\\rm III}(z-z_{\\rm III})}} {b_{\\rm III} + a_{\\rm III} e^{(a_{\\rm III}+b_{\\rm III})(z-z_{\\rm III})}}. Parameters ---------- zs : ``float`` or ``numpy.ndarray`` Source redshifts. R0 : ``float`` Normalization constant. \n default: 19.2e-9 aIII : ``float`` Fitting parameter. \n default: 0.66 bIII : ``float`` Fitting parameter. \n default: 0.3 zIII : ``float`` Characteristic redshift. \n default: 11.6 Returns ------- rate_density : ``float`` or ``numpy.ndarray`` Merger rate density. Examples -------- >>> import numpy as np >>> from ler.gw_source_population import merger_rate_density_bbh_popIII_ken2022_function >>> rate_density = merger_rate_density_bbh_popIII_ken2022_function(zs=np.array([0.1])) """ return ( R0 * np.exp(aIII * (zs - zIII)) / (bIII + aIII * np.exp((aIII + bIII) * (zs - zIII))) )
@njit(fastmath=True)
[docs] def merger_rate_density_madau_dickinson2014_function(zs, R0=19 * 1e-9, a=0.015, b=2.7, c=2.9, d=5.6): """ Compute the merger rate density for BBH using Madau & Dickinson (2014) model. The shape follows .. math:: \\psi(z) = a \\frac{(1+z)^b}{1 + \\left((1+z)/c\\right)^d}, \\qquad R(z) = R_0 \\frac{\\psi(z)}{\\psi(0)}. Reference: Eq. 15 of https://arxiv.org/pdf/1403.0007 Parameters ---------- zs : ``float`` or ``numpy.ndarray`` Source redshifts. R0 : ``float`` Local merger rate density (Mpc^-3 yr^-1). \n default: 19e-9 a : ``float`` Normalization parameter. \n default: 0.015 b : ``float`` Low-redshift power-law slope. \n default: 2.7 c : ``float`` Turnover redshift parameter. \n default: 2.9 d : ``float`` High-redshift power-law slope. \n default: 5.6 Returns ------- rate_density : ``float`` or ``numpy.ndarray`` Merger rate density (Mpc^-3 yr^-1). Examples -------- >>> import numpy as np >>> from ler.gw_source_population import merger_rate_density_madau_dickinson2014_function >>> rate_density = merger_rate_density_madau_dickinson2014_function(zs=np.array([0.1])) """ def density_helper(zs): return sfr_madau_dickinson2014( zs=zs, a=a, b=b, c=c, d=d, ) density_zs = R0 * density_helper(zs)/ density_helper(np.array([0.]))[0] return density_zs
@njit(fastmath=True)
[docs] def merger_rate_density_madau_dickinson_belczynski_ng_function(zs, R0=19 * 1e-9, alpha_F=2.57, beta_F=5.83, c_F=3.36): """ Compute BBH merger rate density following Ng et al. (2021). This model uses a Madau-Dickinson-like functional form to fit the merger rate density of field BHs, accounting for time delays and metallicity effects. Coefficients from Madau & Dickinson (2014) are translated as: B-> alpha_F, D-> beta_F, C-> c_F. .. math:: \\psi_F(z) = \\frac{(1+z)^{\\alpha_F}} {1 + \\left((1+z)/c_F\\right)^{\\beta_F}}, \\qquad R(z) = R_0 \\frac{\\psi_F(z)}{\\psi_F(0)}. Parameters ---------- zs : ``float`` or ``numpy.ndarray`` Source redshifts. R0 : ``float`` Local merger rate density (Mpc^-3 yr^-1). \n default: 19e-9 alpha_F : ``float`` Low-redshift power-law slope. \n default: 2.57 beta_F : ``float`` High-redshift power-law slope. \n default: 5.83 c_F : ``float`` Turnover redshift parameter. \n default: 3.36 Returns ------- rate_density : ``float`` or ``numpy.ndarray`` Merger rate density (Mpc^-3 yr^-1). Examples -------- >>> import numpy as np >>> from ler.gw_source_population import merger_rate_density_madau_dickinson_belczynski_ng_function >>> rate_density = merger_rate_density_madau_dickinson_belczynski_ng_function(zs=np.array([0.1])) """ def density_helper(zs): return sfr_madau_dickinson2014( zs=zs, a=1.0, b=alpha_F, c=c_F, d=beta_F, ) density_zs = R0 * density_helper(zs)/ density_helper(np.array([0.]))[0] return density_zs
[docs] def merger_rate_density_bbh_primordial_ken2022_function(zs, cosmology=None, R0=0.044 * 1e-9, t0=13.786885302009708): """ Compute the merger rate density for Primordial BBH. Reference: Ng et al. (2022). The rate is parameterized by cosmic age: .. math:: R(z) = R_0 \\left(\\frac{t(z)}{t_0}\\right)^{-34/37}. Parameters ---------- zs : ``float`` or ``numpy.ndarray`` Source redshifts. cosmology : ``astropy.cosmology`` or ``None`` Cosmology object for age calculations. \n default: LambdaCDM(H0=70, Om0=0.3, Ode0=0.7, Tcmb0=0.0, Neff=3.04, m_nu=None, Ob0=0.0) R0 : ``float`` Normalization constant. \n default: 0.044e-9 t0 : ``float`` Present age of the Universe (Gyr). \n default: 13.786885302009708 Returns ------- rate_density : ``float`` or ``numpy.ndarray`` Merger rate density. Examples -------- >>> import numpy as np >>> from ler.gw_source_population import merger_rate_density_bbh_primordial_ken2022_function >>> rate_density = merger_rate_density_bbh_primordial_ken2022_function(zs=np.array([0.1])) """ if cosmology is None: cosmology = LambdaCDM(H0=70, Om0=0.3, Ode0=0.7, Tcmb0=0.0, Neff=3.04, m_nu=None, Ob0=0.0) rate_density = R0 * (cosmology.age(z=zs).value / t0) ** (-34 / 37) return rate_density
[docs] def sfr_madau_fragos2017_with_bbh_td(zs, R0=19 * 1e-9): """ Compute the merger rate density for BBH. This is computed from star formation rate, Madau & Fragos (2017), with an additional time delay. This function relies on pre-generated data points. Parameters ---------- zs : ``float`` or ``numpy.ndarray`` Source redshifts. R0 : ``float`` Local merger rate density (Mpc^-3 yr^-1). \n default: 19e-9 Returns ------- SFR : ``float`` or ``numpy.ndarray`` Star formation rate (Mpc^-3 yr^-1). """ rm = np.array([1.00304765, 1.00370075, 1.00449545, 1.00546251, 1.00663937, 1.00807168, 1.00981505, 1.01193727, 1.01046483, 1.01359803, 1.01741386, 1.02206193, 1.02772495, 1.03462601, 1.04303746, 1.05329142, 1.07093106, 1.08624215, 1.10489848, 1.12760683,1.15519183, 1.18858451, 1.22878158, 1.27676494, 1.33727882, 1.40335222, 1.47956936, 1.56759515, 1.6711375 , 1.79690371, 1.95410462, 2.15201042, 2.36151109, 2.66742932, 3.04354598, 3.49048755, 3.98122536, 4.42347511, 4.61710896, 4.30190679, 3.50890876, 2.37699066, 1.41830834, 0.77944771,0.40667706, 0.20463758, 0.09975143, 0.04745116]) zs_ = np.geomspace(0.001, 10, 48) spline = CubicSpline(zs_, rm, extrapolate=True) SFR = spline(zs)*R0 return SFR
[docs] def sfr_madau_dickinson2014_with_bbh_td(zs, R0=19 * 1e-9): """ Compute the merger rate density for BBH. This is computed from star formation rate, Madau & Dickinson (2014), with an additional time delay. This function relies on pre-generated data points. Parameters ---------- zs : ``float`` or ``numpy.ndarray`` Source redshifts. R0 : ``float`` Local merger rate density (Mpc^-3 yr^-1). \n default: 19e-9 Returns ------- SFR : ``float`` or ``numpy.ndarray`` Star formation rate (Mpc^-3 yr^-1). """ rm = np.array([1.00292325, 1.0035494 , 1.0043112 , 1.00523807, 1.0063658 , 1.00773798, 1.00940767, 1.01143948, 1.00997839, 1.01297699, 1.01662662, 1.02106895, 1.02647649, 1.03305927, 1.04107277, 1.05082743, 1.06802831, 1.08255749, 1.10022606, 1.12169013, 1.14772134, 1.1792097 , 1.21715406, 1.26263694, 1.32051095, 1.38462461, 1.45997648, 1.54851567, 1.65349288, 1.78046645, 1.93811129, 2.1354612 , 2.34086287, 2.63664802, 2.98892341, 3.38353439, 3.76990612, 4.03489696, 4.00806904, 3.56766897, 2.86966689, 2.01282062, 1.29696347, 0.78913584, 0.46166281, 0.26226345, 0.14509118, 0.07854392]) zs_ = np.geomspace(0.001, 10, 48) spline = CubicSpline(zs_, rm, extrapolate=True) SFR = spline(zs)*R0 return SFR
[docs] def sfr_madau_fragos2017_with_bns_td(zs, R0=89 * 1e-9): """ Compute the merger rate density for BNS. This is computed from star formation rate, Madau & Fragos (2017), with an additional time delay. This function relies on pre-generated data points. Parameters ---------- zs : ``float`` or ``numpy.ndarray`` Source redshifts. R0 : ``float`` Local merger rate density (Mpc^-3 yr^-1). \n default: 89e-9 Returns ------- SFR : ``float`` or ``numpy.ndarray`` Star formation rate (Mpc^-3 yr^-1). """ rm = np.array([1.00309364, 1.00375139, 1.00455175, 1.00552568, 1.00671091, 1.00815339, 1.00990912, 1.01204635, 1.00757017, 1.01071962, 1.01455507, 1.01922677, 1.02491815, 1.03185311, 1.04030479, 1.05060602, 1.06970166, 1.08508957, 1.10382838, 1.12661829, 1.15427005, 1.18768774, 1.22781836, 1.27555711, 1.31791484, 1.38209039, 1.4555543 , 1.5397332 , 1.63806934, 1.75685668, 1.90448546, 2.08862044, 2.34440211, 2.63899295, 2.99729389, 3.41567274, 3.86324106, 4.24545603, 4.37018218, 4.00555831, 3.10525751, 2.06354992, 1.20906304, 0.65233811, 0.33356891, 0.16397688, 0.08024945, 0.036953]) zs_ = np.geomspace(0.001, 10, 48) spline = CubicSpline(zs_, rm, extrapolate=True) SFR = spline(zs)*R0 return SFR
[docs] def sfr_madau_dickinson2014_with_bns_td(zs, R0=89 * 1e-9): """ Compute the merger rate density for BNS. This is computed from star formation rate, Madau & Dickinson (2014), with an additional time delay. This function relies on pre-generated data points. Parameters ---------- zs : ``float`` or ``numpy.ndarray`` Source redshifts. R0 : ``float`` Local merger rate density (Mpc^-3 yr^-1). \n default: 89e-9 Returns ------- SFR : ``float`` or ``numpy.ndarray`` Star formation rate (Mpc^-3 yr^-1). """ rm = np.array([1.0029945 , 1.00362259, 1.00438674, 1.00531645, 1.00644763, 1.00782396, 1.00949865, 1.01153645, 1.00240992, 1.00539605, 1.00903013, 1.01345293, 1.01883579, 1.02538714, 1.03336026, 1.04306247, 1.05841698, 1.07283625, 1.09035966, 1.11162909, 1.1373949 , 1.16851509, 1.20594024, 1.25068092, 1.3085267 , 1.37111306, 1.44421094, 1.52948237, 1.62985636, 1.75058453, 1.90010572, 2.0870216 , 2.33573104, 2.6218286 , 2.96031682, 3.3343522 , 3.69149889, 3.92099769, 3.86227814, 3.40811745, 2.59314381, 1.79588097, 1.14260538, 0.686002 , 0.3954134 , 0.22083291, 0.11548455, 0.06064368]) zs_ = np.geomspace(0.001, 10, 48) spline = CubicSpline(zs_, rm, extrapolate=True) SFR = spline(zs)*R0 return SFR
# ------------------------------ # Star formation rate functions # ------------------------------ @njit(fastmath=True)
[docs] def sfr_madau_fragos2017(zs, a=0.01, b=2.6, c=3.2, d=6.2): """ Compute star formation rate using Madau & Fragos (2017) model. Reference: https://arxiv.org/pdf/1606.07887.pdf Parameters ---------- zs : ``float`` or ``numpy.ndarray`` Source redshifts. a : ``float`` Normalization parameter. \n default: 0.01 b : ``float`` Low-redshift power-law slope. \n default: 2.6 c : ``float`` Turnover redshift parameter. \n default: 3.2 d : ``float`` High-redshift power-law slope. \n default: 6.2 Returns ------- SFR : ``float`` or ``numpy.ndarray`` Star formation rate (Msun yr^-1 Mpc^-3). """ return a * (1+zs)**b / (1 + ((1+zs)/c)**d)
@njit(fastmath=True)
[docs] def sfr_madau_dickinson2014(zs, a=0.015, b=2.7, c=2.9, d=5.6): """ Compute star formation rate using Madau & Dickinson (2014) model. Reference: Eq. 15 of https://arxiv.org/pdf/1403.0007 Parameters ---------- zs : ``float`` or ``numpy.ndarray`` Source redshifts. a : ``float`` Normalization parameter. \n default: 0.015 b : ``float`` Low-redshift power-law slope. \n default: 2.7 c : ``float`` Turnover redshift parameter. \n default: 2.9 d : ``float`` High-redshift power-law slope. \n default: 5.6 Returns ------- SFR : ``float`` or ``numpy.ndarray`` Star formation rate (Msun yr^-1 Mpc^-3). Examples -------- >>> from ler.gw_source_population import sfr_madau_dickinson2014 >>> sfr = sfr_madau_dickinson2014(zs=np.array([0.1])) """ return a * (1 + zs) ** b / (1 + ((1 + zs) / c) ** d)
# ------------------------------ # Binary mass functions # ------------------------------ @njit(fastmath=True) def _lognormal_psi(m, Mc, sigma): """ Compute lognormal mass function (equation 1, Ng et al. 2022). Parameters ---------- m : ``numpy.ndarray`` Mass values (solar masses). Mc : ``float`` Characteristic mass scale (solar masses). sigma : ``float`` Width parameter of the lognormal distribution. Returns ------- psi : ``numpy.ndarray`` Mass function values. """ return np.exp(-np.log(m / Mc) ** 2 / (2 * sigma**2)) / ( np.sqrt(2 * np.pi) * sigma * m ) @njit(fastmath=True)
[docs] def ng2022_lognormal_joint_pdf(m1, m2, Mc, sigma): """ Compute joint probability density for lognormal 2D mass distribution. Parameters ---------- m1 : ``numpy.ndarray`` Primary masses (solar masses). m2 : ``numpy.ndarray`` Secondary masses (solar masses). Mc : ``float`` Characteristic mass scale (solar masses). sigma : ``float`` Width parameter of the lognormal distribution. Returns ------- pdf : ``numpy.ndarray`` Joint probability density values. """ return ( (m1 + m2) ** (36 / 37) * (m1 * m2) ** (32 / 37) * _lognormal_psi(m1, Mc, sigma) * _lognormal_psi(m2, Mc, sigma) )
@njit(fastmath=True)
[docs] def binary_masses_BBH_popIII_lognormal_rvs(size, m_min=1.0, m_max=100.0, Mc=20.0, sigma=0.3, chunk_size=10000): """ Generate random samples of binary masses for PopIII BBH (lognormal model). Draws samples from a 2D lognormal distribution in mass space using rejection sampling. Reference: Ng et al. (2022). Parameters ---------- size : ``int`` Number of binary systems to sample. m_min : ``float`` Minimum mass (solar masses). default: 1.0 m_max : ``float`` Maximum mass (solar masses). default: 100.0 Mc : ``float`` Characteristic mass scale (solar masses). default: 20.0 sigma : ``float`` Width parameter of the lognormal distribution. default: 0.3 chunk_size : ``int`` Number of samples per rejection sampling chunk. default: 10000 Returns ------- m1 : ``numpy.ndarray`` Primary mass samples (solar masses). m2 : ``numpy.ndarray`` Secondary mass samples (solar masses). Examples -------- >>> from ler.gw_source_population.prior_functions import binary_masses_BBH_popIII_lognormal_rvs >>> m1, m2 = binary_masses_BBH_popIII_lognormal_rvs(size=1000) >>> print(m1.shape, m2.shape) (1000,) (1000,) """ # rejection sampling initialization m1 = np.random.uniform(m_min, m_max, chunk_size) m2 = np.random.uniform(m_min, m_max, chunk_size) z = ng2022_lognormal_joint_pdf(m1, m2, Mc, sigma) zmax = np.max(z) # rejection sampling in chunks m1_sample = np.zeros(size) m2_sample = np.zeros(size) old_num = 0 while True: m1_try = np.random.uniform(m_min, m_max, size=chunk_size) m2_try = np.random.uniform(m_min, m_max, size=chunk_size) z_try = np.random.uniform(0, zmax, size=chunk_size) zmax = max(zmax, np.max(z_try)) idx = z_try < ng2022_lognormal_joint_pdf(m1_try, m2_try, Mc, sigma) new_num = old_num + np.sum(idx) if new_num >= size: m1_sample[old_num:size] = m1_try[idx][: size - old_num] m2_sample[old_num:size] = m2_try[idx][: size - old_num] break else: m1_sample[old_num:new_num] = m1_try[idx] m2_sample[old_num:new_num] = m2_try[idx] old_num = new_num # swap masses to ensure m1 >= m2 idx = m1_sample < m2_sample m1_sample[idx], m2_sample[idx] = m2_sample[idx], m1_sample[idx] return m1_sample, m2_sample
@njit(fastmath=True)
[docs] def binary_masses_BBH_primordial_lognormal_rvs(size, m_min=1.0, m_max=100.0, Mc=20.0, sigma=0.3, chunk_size=10000): """ Generate random samples of binary masses for Primordial BBH (lognormal model). Draws samples from a 2D lognormal distribution in mass space using rejection sampling. Reference: Ng et al. (2022). Parameters ---------- size : ``int`` Number of binary systems to sample. m_min : ``float`` Minimum mass (solar masses). default: 1.0 m_max : ``float`` Maximum mass (solar masses). default: 100.0 Mc : ``float`` Characteristic mass scale (solar masses). default: 20.0 sigma : ``float`` Width parameter of the lognormal distribution. default: 0.3 chunk_size : ``int`` Number of samples per rejection sampling chunk. default: 10000 Returns ------- m1 : ``numpy.ndarray`` Primary mass samples (solar masses). m2 : ``numpy.ndarray`` Secondary mass samples (solar masses). Examples -------- >>> from ler.gw_source_population.prior_functions import binary_masses_BBH_primordial_lognormal_rvs >>> m1, m2 = binary_masses_BBH_primordial_lognormal_rvs(size=1000) >>> print(m1.shape, m2.shape) (1000,) (1000,) """ # rejection sampling initialization m1 = np.random.uniform(m_min, m_max, chunk_size) m2 = np.random.uniform(m_min, m_max, chunk_size) z = ng2022_lognormal_joint_pdf(m1, m2, Mc, sigma) zmax = np.max(z) # rejection sampling in chunks m1_sample = np.zeros(size) m2_sample = np.zeros(size) old_num = 0 while True: m1_try = np.random.uniform(m_min, m_max, size=chunk_size) m2_try = np.random.uniform(m_min, m_max, size=chunk_size) z_try = np.random.uniform(0, zmax, size=chunk_size) zmax = max(zmax, np.max(z_try)) idx = z_try < ng2022_lognormal_joint_pdf(m1_try, m2_try, Mc, sigma) new_num = old_num + np.sum(idx) if new_num >= size: m1_sample[old_num:size] = m1_try[idx][: size - old_num] m2_sample[old_num:size] = m2_try[idx][: size - old_num] break else: m1_sample[old_num:new_num] = m1_try[idx] m2_sample[old_num:new_num] = m2_try[idx] old_num = new_num # swap masses to ensure m1 >= m2 idx = m1_sample < m2_sample m1_sample[idx], m2_sample[idx] = m2_sample[idx], m1_sample[idx] return m1_sample, m2_sample
@njit(fastmath=True) def _erf(x): """ Compute the error function using Abramowitz & Stegun approximation. Parameters ---------- x : ``float`` or ``numpy.ndarray`` Input value(s). Returns ------- result : ``float`` or ``numpy.ndarray`` Error function value(s). """ # constants for A&S formula 7.1.26 p = 0.3275911 a1 = 0.254829592 a2 = -0.284496736 a3 = 1.421413741 a4 = -1.453152027 a5 = 1.061405429 sign = np.sign(x) x = np.abs(x) t = 1.0 / (1.0 + p * x) y = 1.0 - (((((a5 * t + a4) * t) + a3) * t + a2) * t + a1) * t * np.exp(-x * x) return sign * y @njit(fastmath=True) def _compute_normalization_factor(mu, sigma, mmin, mmax): """ Compute normalization factor for truncated Gaussian. Parameters ---------- mu : ``float`` Mean of the Gaussian. sigma : ``float`` Standard deviation of the Gaussian. mmin : ``float`` Minimum mass bound. mmax : ``float`` Maximum mass bound. Returns ------- N : ``float`` Normalization factor. """ part1 = (mmax - mu) / (np.sqrt(2) * sigma) part2 = (mmin - mu) / (np.sqrt(2) * sigma) N = np.sqrt(2 * np.pi) * sigma * (0.5 * (_erf(part1) - _erf(part2))) return N @njit(fastmath=True) def _bimodal_unnormalized(m, w=0.643, muL=1.352, sigmaL=0.08, muR=1.88, sigmaR=0.3, mmin=1.0, mmax=2.3): """ Compute unnormalized bimodal Gaussian density for BNS mass distribution. Parameters ---------- m : ``numpy.ndarray`` Mass values (solar masses). w : ``float`` Weight of the left (low-mass) peak. default: 0.643 muL : ``float`` Mean of the left peak (solar masses). default: 1.352 sigmaL : ``float`` Standard deviation of the left peak (solar masses). default: 0.08 muR : ``float`` Mean of the right peak (solar masses). default: 1.88 sigmaR : ``float`` Standard deviation of the right peak (solar masses). default: 0.3 mmin : ``float`` Minimum mass (solar masses). default: 1.0 mmax : ``float`` Maximum mass (solar masses). default: 2.3 Returns ------- density : ``numpy.ndarray`` Unnormalized probability density values. """ # left peak (before truncation) pdf_unnormL = np.exp(-((m - muL) ** 2) / (2 * sigmaL**2)) # right peak (before truncation) pdf_unnormR = np.exp(-((m - muR) ** 2) / (2 * sigmaR**2)) # mixture before normalization density = w * pdf_unnormL + (1 - w) * pdf_unnormR return density @njit(fastmath=True)
[docs] def bimodal_pdf(m, w=0.643, muL=1.352, sigmaL=0.08, muR=1.88, sigmaR=0.3, mmin=1.0, mmax=2.3): """ Evaluate fully normalized bimodal Gaussian PDF for BNS mass distribution. Computes the normalized probability density function combining two truncated Gaussian peaks. Reference: Will M. Farr et al. 2020. Parameters ---------- m : ``numpy.ndarray`` Mass values (solar masses). w : ``float`` Weight of the left (low-mass) peak. default: 0.643 muL : ``float`` Mean of the left peak (solar masses). default: 1.352 sigmaL : ``float`` Standard deviation of the left peak (solar masses). default: 0.08 muR : ``float`` Mean of the right peak (solar masses). default: 1.88 sigmaR : ``float`` Standard deviation of the right peak (solar masses). default: 0.3 mmin : ``float`` Minimum mass (solar masses). default: 1.0 mmax : ``float`` Maximum mass (solar masses). default: 2.3 Returns ------- pdf : ``numpy.ndarray`` Normalized probability density values. Examples -------- >>> from ler.gw_source_population.prior_functions import bimodal_pdf >>> import numpy as np >>> m = np.array([1.2, 1.4, 1.8]) >>> pdf_values = bimodal_pdf(m) >>> print(pdf_values) # doctest: +SKIP """ # left peak pdf_unnormL = np.exp(-((m - muL) ** 2) / (2 * sigmaL**2)) normL = _compute_normalization_factor(muL, sigmaL, mmin, mmax) # right peak pdf_unnormR = np.exp(-((m - muR) ** 2) / (2 * sigmaR**2)) normR = _compute_normalization_factor(muR, sigmaR, mmin, mmax) # total pdf pdf = w * pdf_unnormL / normL + (1 - w) * pdf_unnormR / normR return pdf
@njit(fastmath=True)
[docs] def bimodal_cdf(size, w=0.643, muL=1.352, sigmaL=0.08, muR=1.88, sigmaR=0.3, mmin=1.0, mmax=2.3): """ Compute cumulative distribution function for BNS bimodal mass distribution. Parameters ---------- size : ``int`` Resolution of the mass grid. w : ``float`` Weight of the left (low-mass) peak. default: 0.643 muL : ``float`` Mean of the left peak (solar masses). default: 1.352 sigmaL : ``float`` Standard deviation of the left peak (solar masses). default: 0.08 muR : ``float`` Mean of the right peak (solar masses). default: 1.88 sigmaR : ``float`` Standard deviation of the right peak (solar masses). default: 0.3 mmin : ``float`` Minimum mass (solar masses). default: 1.0 mmax : ``float`` Maximum mass (solar masses). default: 2.3 Returns ------- cdf_values : ``numpy.ndarray`` Cumulative distribution function values. mass_grid : ``numpy.ndarray`` Mass grid corresponding to CDF values. """ mass_grid = np.linspace(mmin, mmax, size) pdf_values = bimodal_pdf(mass_grid, w, muL, sigmaL, muR, sigmaR, mmin, mmax) cdf_values, mass_grid, _ = cumulative_trapezoid( x=mass_grid, y=pdf_values, initial=0.0 ) return cdf_values, mass_grid
@njit(fastmath=True)
[docs] def binary_masses_BNS_bimodal_rvs(size, w=0.643, muL=1.352, sigmaL=0.08, muR=1.88, sigmaR=0.3, mmin=1.0, mmax=2.3, resolution=500): """ Generate random samples of binary masses for BNS (bimodal Gaussian model). Uses inverse transform sampling to draw samples from a bimodal Gaussian distribution of neutron star masses. Parameters ---------- size : ``int`` Number of binary systems to sample. w : ``float`` Weight of the left (low-mass) peak. default: 0.643 muL : ``float`` Mean of the left peak (solar masses). default: 1.352 sigmaL : ``float`` Standard deviation of the left peak (solar masses). default: 0.08 muR : ``float`` Mean of the right peak (solar masses). default: 1.88 sigmaR : ``float`` Standard deviation of the right peak (solar masses). default: 0.3 mmin : ``float`` Minimum mass (solar masses). default: 1.0 mmax : ``float`` Maximum mass (solar masses). default: 2.3 resolution : ``int`` Resolution of the mass grid for CDF computation. default: 500 Returns ------- mass_samples : ``numpy.ndarray`` Array of mass samples (solar masses). """ cdf_values, mass_grid = bimodal_cdf(resolution, w, muL, sigmaL, muR, sigmaR, mmin, mmax) return inverse_transform_sampler(size=size, cdf=cdf_values, x=mass_grid)
# @njit # def _inverse_transform_sampler_m1m2(size, cdf_values, x): # """ # Sample m1 and m2 using inverse transform sampling for BNS. # This is a helper function for the BNS Alsing mass distribution function. # Parameters # ---------- # size : ``int`` # Number of samples to draw. # cdf_values : ``numpy.ndarray`` # Cumulative distribution function values. # x : ``numpy.ndarray`` # Mass values corresponding to the CDF. # Returns # ------- # m1 : ``numpy.ndarray`` # Primary mass samples (Msun). # m2 : ``numpy.ndarray`` # Secondary mass samples (Msun). # Examples # -------- # >>> from ler.gw_source_population import _inverse_transform_sampler_m1m2 # >>> m1, m2 = _inverse_transform_sampler_m1m2(size=1000, cdf_values=cdf_values, x=mass_arr) # """ # m1 = inverse_transform_sampler(size=size, cdf=cdf_values, x=x) # m2 = inverse_transform_sampler(size=size, cdf=cdf_values, x=x) # # swap m1 and m2 if m1 < m2 # idx = m1 < m2 # m1[idx], m2[idx] = m2[idx], m1[idx] # return m1, m2 # ------------------------------ # broken_powerlaw # ------------------------------ @njit(fastmath=True) def _broken_powerlaw_unnormalized(m, mminbh=26., mmaxbh=125., alpha_1=6.75, alpha_2=0., b=0.5, delta_m=5.): """ Compute unnormalized PDF for broken power-law distribution. Parameters ---------- m : ``numpy.ndarray`` Mass values. mminbh : ``float`` Minimum BH mass. mmaxbh : ``float`` Maximum BH mass. alpha_1 : ``float`` Power-law index below break. alpha_2 : ``float`` Power-law index above break. b : ``float`` Break location parameter. delta_m : ``float`` Smoothing width. Returns ------- pdf_unnormalized : ``numpy.ndarray`` Unnormalized PDF values. """ mbreak = mminbh + b * (mmaxbh - mminbh) idx_1 = (m > mminbh) & (m < mbreak) idx_2 = (m >= mbreak) & (m < mmaxbh) pdf_unnormalized = np.zeros_like(m) pdf_unnormalized[idx_1] = powerlaw_with_smoothing(m[idx_1], m[idx_1], mminbh, -alpha_1, delta_m) norm_1 = pdf_unnormalized[idx_1][np.sum(idx_1)-1] pdf_unnormalized[idx_2] = powerlaw_with_smoothing(m[idx_2], m[idx_2], mminbh, -alpha_2, delta_m) norm_2 = pdf_unnormalized[idx_2][0] pdf_unnormalized[idx_2] = pdf_unnormalized[idx_2] * (norm_1 / norm_2) return pdf_unnormalized
[docs] def broken_powerlaw_pdf(m, mminbh=26., mmaxbh=125., alpha_1=6.75, alpha_2=0., b=0.5, delta_m=5.): """ Compute normalized PDF for broken power-law distribution. Parameters ---------- m : ``numpy.ndarray`` Mass values. mminbh : ``float`` Minimum BH mass. mmaxbh : ``float`` Maximum BH mass. alpha_1 : ``float`` Power-law index below break. alpha_2 : ``float`` Power-law index above break. b : ``float`` Break location parameter. delta_m : ``float`` Smoothing width. Returns ------- pdf : ``numpy.ndarray`` Normalized PDF values. """ # compute unnormalized PDF on a fine grid for numerical integration m_try = np.geomspace(mminbh, mmaxbh, 1000) pdf_unnormalized = _broken_powerlaw_unnormalized(m_try, mminbh=mminbh, mmaxbh=mmaxbh, alpha_1=alpha_1, alpha_2=alpha_2, b=b, delta_m=delta_m) _, _, norm = cumulative_trapezoid(y=pdf_unnormalized, x=m_try, initial=0.0) pdf = _broken_powerlaw_unnormalized(m, mminbh=mminbh, mmaxbh=mmaxbh, alpha_1=alpha_1, alpha_2=alpha_2, b=b, delta_m=delta_m) / norm return pdf
@njit(fastmath=True) def _gaussian_pdf(m, mu_g=32.27, sigma_g=3.88): """ Compute Gaussian distribution. Parameters ---------- m : ``numpy.ndarray`` Mass values. mu_g : ``float`` Mean of the Gaussian. sigma_g : ``float`` Standard deviation. Returns ------- pdf : ``numpy.ndarray`` Gaussian PDF values. """ normalization = 1.0 / (sigma_g * np.sqrt(2 * np.pi)) exponent = -0.5 * ((m - mu_g) / sigma_g) ** 2 pdf = normalization * np.exp(exponent) return pdf # ----------------------- # Spin Tilt Angles # -----------------------
[docs] def gaussian_plus_isotropic_pdf(x, mu_t=0.426, sigma_t=1.222, zeta=0.652): """ 1D marginal PDF for a single cosine tilt x = cos(theta). p(x) = zeta * TruncNorm[-1,1](x | mu_t, sigma_t) + (1-zeta)/2 """ return ( zeta * truncated_normal_pdf(x, mu_t, sigma_t, -1.0, 1.0) + (1.0 - zeta) * 0.5 )
[docs] def gaussian_plus_isotropic_joint_pdf(x1, x2, mu_t=0.426, sigma_t=1.222, zeta=0.652): """ 2D joint PDF for (x1, x2) = (cos(theta1), cos(theta2)). p(x1, x2) = zeta * TN(x1) * TN(x2) + (1-zeta) / 4 on [-1,1]^2. """ g1 = truncated_normal_pdf(x1, mu_t, sigma_t, -1.0, 1.0) g2 = truncated_normal_pdf(x2, mu_t, sigma_t, -1.0, 1.0) pdf = zeta * g1 * g2 + (1.0 - zeta) * 0.25 return pdf
# ----------------------- # common helper functions # ----------------------- @njit(fastmath=True)
[docs] def powerlaw_pdf(x, alpha=-7.7, x_min=1.0, x_max=2.5): """ Compute normalized power-law distribution. p(x) is proportional to ``x**(-alpha)`` for ``x`` in ``[x_min, x_max]``. Parameters ---------- x : ``numpy.ndarray`` Input values. alpha : ``float`` Power-law spectral index. x_min : ``float`` Minimum value. x_max : ``float`` Maximum value. Returns ------- pdf : ``numpy.ndarray`` Normalized power-law PDF. """ normalization = (x_max ** (-alpha + 1)) / (-alpha + 1) - (x_min ** (-alpha + 1)) / (-alpha + 1) pdf = x ** (-alpha) / normalization return pdf
@njit(fastmath=True)
[docs] def powerlaw_rvs(size, alpha, x_min, x_max): """ Inverse transform sampling for a power-law distribution. p(x) ∝ x^{-alpha}, x in [x_min, x_max] Parameters ---------- size : ``int`` Number of samples to generate. alpha : ``float`` Power-law index (alpha). x_min : ``float`` Minimum value (lower bound). x_max : ``float`` Maximum value (upper bound). Returns ------- x : ``numpy.ndarray`` Array of sampled values. """ u = np.random.uniform(0, 1, size) if alpha == 1.0: # Special case α=1 x = x_min * (x_max / x_min) ** u elif alpha == 0.0: # Special case α=0 (uniform distribution) x = x_min + (x_max - x_min) * u else: pow1 = 1.0 - alpha x_min_pow = x_min**pow1 x_max_pow = x_max**pow1 x = (u * (x_max_pow - x_min_pow) + x_min_pow) ** (1.0 / pow1) return x
@njit()
[docs] def truncated_normal_pdf(x, mu, sigma, x_min, x_max=np.inf): """ Compute left-truncated or left-and-right-truncated normal probability density function. Evaluates the truncated normal distribution $N_{[x_min, x_max]}(x | μ, σ)$, which is a Gaussian distribution with support only between a minimum and maximum value. If x_max is not provided (or set to np.inf), it defaults to left-only truncation. Parameters ---------- x : ``numpy.ndarray`` Input values. mu : ``float`` Mean of the Gaussian distribution. sigma : ``float`` Standard deviation of the Gaussian distribution. x_min : ``float`` Minimum value (left truncation point). x_max : ``float``, optional Maximum value (right truncation point). Default is np.inf (no right truncation). Returns ------- pdf : ``numpy.ndarray`` Probability density values, 0 for $x < x_min$ or $x > x_max$. """ # 1. Compute CDF at the lower bound a = (x_min - mu) / sigma cdf_a = 0.5 * (1.0 + math.erf(a / math.sqrt(2.0))) # 2. Compute CDF at the upper bound if x_max == np.inf: cdf_b = 1.0 else: b = (x_max - mu) / sigma cdf_b = 0.5 * (1.0 + math.erf(b / math.sqrt(2.0))) # 3. Calculate normalization factor norm = cdf_b - cdf_a # Safety guard against numerical precision issues yielding 0 normalization if norm <= 0.0: return np.zeros_like(x) # 4. Evaluate Gaussian z = (x - mu) / sigma pdf = np.exp(-0.5 * z * z) / (sigma * math.sqrt(2.0 * math.pi) * norm) # 5. Apply bounds mask = (x >= x_min) & (x <= x_max) return np.where(mask, pdf, 0.0)
@njit() def _standard_normal_ppf(p): """ Approximate inverse CDF (quantile) of the standard normal distribution. Uses the Acklam rational approximation, accurate enough for sampling. """ if p <= 0.0: return -np.inf if p >= 1.0: return np.inf plow = 0.02425 phigh = 1.0 - plow a1 = -39.69683028665376 a2 = 220.9460984245205 a3 = -275.9285104469687 a4 = 138.3577518672690 a5 = -30.66479806614716 a6 = 2.506628277459239 b1 = -54.47609879822406 b2 = 161.5858368580409 b3 = -155.6989798598866 b4 = 66.80131188771972 b5 = -13.28068155288572 c1 = -0.007784894002430293 c2 = -0.3223964580411365 c3 = -2.400758277161838 c4 = -2.549732539343734 c5 = 4.374664141464968 c6 = 2.938163982698783 d1 = 0.007784695709041462 d2 = 0.3224671290700398 d3 = 2.445134137142996 d4 = 3.754408661907416 if p < plow: q = math.sqrt(-2.0 * math.log(p)) return ( (((((c1 * q + c2) * q + c3) * q + c4) * q + c5) * q + c6) / ((((d1 * q + d2) * q + d3) * q + d4) * q + 1.0) ) if p > phigh: q = math.sqrt(-2.0 * math.log(1.0 - p)) return -( (((((c1 * q + c2) * q + c3) * q + c4) * q + c5) * q + c6) / ((((d1 * q + d2) * q + d3) * q + d4) * q + 1.0) ) q = p - 0.5 r = q * q return ( (((((a1 * r + a2) * r + a3) * r + a4) * r + a5) * r + a6) * q / (((((b1 * r + b2) * r + b3) * r + b4) * r + b5) * r + 1.0) ) @njit() def _truncated_normal_ppf(u, mu, sigma, x_min, x_max=np.inf): """ Inverse CDF (quantile) for a truncated normal distribution. """ a = (x_min - mu) / sigma cdf_a = 0.5 * (1.0 + math.erf(a / math.sqrt(2.0))) if x_max == np.inf: cdf_b = 1.0 else: b = (x_max - mu) / sigma cdf_b = 0.5 * (1.0 + math.erf(b / math.sqrt(2.0))) norm = cdf_b - cdf_a if norm <= 0.0: return x_min if u <= 0.0: return x_min if u >= 1.0: return x_max p = cdf_a + u * norm z = _standard_normal_ppf(p) x = mu + sigma * z if x < x_min: return x_min if x > x_max: return x_max return x @njit()
[docs] def truncated_normal_rvs(size, mu, sigma, x_min, x_max=np.inf): """ Draw samples from a truncated normal distribution using analytical inverse CDF. """ u = np.random.uniform(0.0, 1.0, size) samples = np.zeros(size) for i in range(size): samples[i] = _truncated_normal_ppf(u[i], mu, sigma, x_min, x_max) return samples
# ----------------------- @njit() def _smoothing_S(m, mmin, delta_m, threshold=709.0): """ Compute low-mass smoothing function to avoid sharp cutoffs. Parameters ---------- m : ``numpy.ndarray`` Mass values. mmin : ``float`` Minimum mass. delta_m : ``float`` Smoothing width. threshold : ``float`` Maximum exponent to avoid overflow. \n default: 709.0 Returns ------- s : ``numpy.ndarray`` Smoothing function values. """ s = np.zeros_like(m) # region where smoothing is not needed: m >= mmin + delta_m idx_2 = m >= mmin + delta_m s[idx_2] = 1.0 # region where smoothing is applied: mmin <= m < mmin + delta_m idx_1 = (m >= mmin) & (m < mmin + delta_m) mprime = m[idx_1] - mmin exponent = delta_m / mprime + delta_m / (mprime - delta_m) # safe exponentiation only where exponent is below threshold safe_idx = exponent <= threshold s_vals = np.zeros_like(mprime) s_vals[safe_idx] = 1.0 / (np.exp(exponent[safe_idx]) + 1.0) s[idx_1] = s_vals return s @njit(fastmath=True)
[docs] def powerlaw_with_smoothing(q, m, mmin, beta, delta_m): """ Compute power-law distribution with low-mass smoothing. """ return q ** (beta) * _smoothing_S(m, mmin, delta_m)
def _njit_checks( zs_rvs_, m1_rvs_, q_rvs_, m2_rvs_, tc_rvs_, ra_rvs_, dec_rvs_, phase_rvs_, psi_rvs_, theta_jn_rvs_, a_1_rvs_=None, a_2_rvs_=None, tilt_1_rvs_=None, tilt_2_rvs_=None, phi_12_rvs_=None, phi_jl_rvs_=None, spin_zero=False, spin_precession=False, ): """ Check and wrap sampler functions for JIT compatibility. """ if m1_rvs_.__code__.co_argcount == 1: if is_njitted(m1_rvs_): @njit def m1_rvs(size, zs=None): return m1_rvs_(size) else: def m1_rvs(size, zs=None): return m1_rvs_(size) else: m1_rvs = m1_rvs_ if m2_rvs_ is None: if q_rvs_.__code__.co_argcount == 1: if is_njitted(q_rvs_): @njit def q_rvs(size, m1=None): return q_rvs_(size) else: def q_rvs(size, m1=None): return q_rvs_(size) else: q_rvs = q_rvs_ else: if m2_rvs_.__code__.co_argcount == 1: if is_njitted(m2_rvs_): @njit def m2_rvs(size, zs=None): return m2_rvs_(size) else: def m2_rvs(size, zs=None): return m2_rvs_(size) else: m2_rvs = m2_rvs_ if not spin_zero: if a_2_rvs_.__code__.co_argcount == 1: if is_njitted(a_2_rvs_): @njit def a_2_rvs(size, a_1=None): return a_2_rvs_(size) else: def a_2_rvs(size, a_1=None): return a_2_rvs_(size) else: a_2_rvs = a_2_rvs_ if spin_precession: if tilt_2_rvs_.__code__.co_argcount == 1: if is_njitted(tilt_2_rvs_): @njit def tilt_2_rvs(size, tilt_1=None): return tilt_2_rvs_(size) else: def tilt_2_rvs(size, tilt_1=None): return tilt_2_rvs_(size) else: tilt_2_rvs = tilt_2_rvs_ # Build dictionary of wrapped functions dict_ = { "zs_rvs": zs_rvs_, "m1_rvs": m1_rvs, "q_rvs": q_rvs if m2_rvs_ is None else None, "m2_rvs": m2_rvs if m2_rvs_ is not None else None, "tc_rvs": tc_rvs_, "ra_rvs": ra_rvs_, "dec_rvs": dec_rvs_, "phase_rvs": phase_rvs_, "psi_rvs": psi_rvs_, "theta_jn_rvs": theta_jn_rvs_, "a_1_rvs": a_1_rvs_ if not spin_zero else None, "a_2_rvs": a_2_rvs if not spin_zero else None, "tilt_1_rvs": tilt_1_rvs_ if (not spin_zero and spin_precession) else None, "tilt_2_rvs": tilt_2_rvs if (not spin_zero and spin_precession) else None, "phi_12_rvs": phi_12_rvs_ if (not spin_zero and spin_precession) else None, "phi_jl_rvs": phi_jl_rvs_ if (not spin_zero and spin_precession) else None, } # Check if all functions are JIT compiled use_njit_sampler = True for key, value in dict_.items(): if not is_njitted(value) and value is not None: print(f"Warning: {key} is not njitted.") use_njit_sampler = False return use_njit_sampler, dict_
[docs] def create_gw_parameters_sampler( zs_rvs, m1_rvs, q_rvs, m2_rvs, tc_rvs, ra_rvs, dec_rvs, phase_rvs, psi_rvs, theta_jn_rvs, a_1_rvs, a_2_rvs, tilt_1_rvs, tilt_2_rvs, phi_12_rvs, phi_jl_rvs, use_njit_sampler=True, spin_zero=False, spin_precession=False, ): if m2_rvs is None: if spin_zero: def sampler(size): zs = zs_rvs(size) m1 = m1_rvs(size, zs) q = q_rvs(size, m1) m2 = q * m1 tc = tc_rvs(size) ra = ra_rvs(size) dec = dec_rvs(size) phase = phase_rvs(size) psi = psi_rvs(size) theta_jn = theta_jn_rvs(size) # swap m1 and m2 to ensure m1 >= m2 m1, m2 = np.where(m1 > m2, m1, m2), np.where(m1 > m2, m2, m1) return zs, m1, m2, tc, ra, dec, phase, psi, theta_jn else: if spin_precession: def sampler(size): zs = zs_rvs(size) m1 = m1_rvs(size, zs) q = q_rvs(size, m1) m2 = q * m1 tc = tc_rvs(size) ra = ra_rvs(size) dec = dec_rvs(size) phase = phase_rvs(size) psi = psi_rvs(size) theta_jn = theta_jn_rvs(size) a_1 = a_1_rvs(size) a_2 = a_2_rvs(size) tilt_1 = tilt_1_rvs(size) tilt_2 = tilt_2_rvs(size, tilt_1) phi_12 = phi_12_rvs(size) phi_jl = phi_jl_rvs(size) # swap m1 and m2 to ensure m1 >= m2 m1, m2 = np.where(m1 > m2, m1, m2), np.where(m1 > m2, m2, m1) return zs, m1, m2, tc, ra, dec, phase, psi, theta_jn, a_1, a_2, tilt_1, tilt_2, phi_12, phi_jl else: def sampler(size): zs = zs_rvs(size) m1 = m1_rvs(size, zs) q = q_rvs(size, m1) m2 = q * m1 tc = tc_rvs(size) ra = ra_rvs(size) dec = dec_rvs(size) phase = phase_rvs(size) psi = psi_rvs(size) theta_jn = theta_jn_rvs(size) a_1 = a_1_rvs(size) a_2 = a_2_rvs(size) # swap m1 and m2 to ensure m1 >= m2 m1, m2 = np.where(m1 > m2, m1, m2), np.where(m1 > m2, m2, m1) return zs, m1, m2, tc, ra, dec, phase, psi, theta_jn, a_1, a_2 else: # if m2_rvs is not None if spin_zero: def sampler(size): zs = zs_rvs(size) m1 = m1_rvs(size, zs) m2 = m2_rvs(size, m1) tc = tc_rvs(size) ra = ra_rvs(size) dec = dec_rvs(size) phase = phase_rvs(size) psi = psi_rvs(size) theta_jn = theta_jn_rvs(size) # swap m1 and m2 to ensure m1 >= m2 m1, m2 = np.where(m1 > m2, m1, m2), np.where(m1 > m2, m2, m1) return zs, m1, m2, tc, ra, dec, phase, psi, theta_jn else: if spin_precession: def sampler(size): zs = zs_rvs(size) m1 = m1_rvs(size, zs) m2 = m2_rvs(size, m1) tc = tc_rvs(size) ra = ra_rvs(size) dec = dec_rvs(size) phase = phase_rvs(size) psi = psi_rvs(size) theta_jn = theta_jn_rvs(size) a_1 = a_1_rvs(size, zs) a_2 = a_2_rvs(size, a_1) tilt_1 = tilt_1_rvs(size, zs) tilt_2 = tilt_2_rvs(size, tilt_1) phi_12 = phi_12_rvs(size, zs) phi_jl = phi_jl_rvs(size, zs) # swap m1 and m2 to ensure m1 >= m2 m1, m2 = np.where(m1 > m2, m1, m2), np.where(m1 > m2, m2, m1) return zs, m1, m2, tc, ra, dec, phase, psi, theta_jn, a_1, a_2, tilt_1, tilt_2, phi_12, phi_jl else: def sampler(size): zs = zs_rvs(size) m1 = m1_rvs(size, zs) m2 = m2_rvs(size, m1) tc = tc_rvs(size) ra = ra_rvs(size) dec = dec_rvs(size) phase = phase_rvs(size) psi = psi_rvs(size) theta_jn = theta_jn_rvs(size) a_1 = a_1_rvs(size, zs) a_2 = a_2_rvs(size, a_1) # swap m1 and m2 to ensure m1 >= m2 m1, m2 = np.where(m1 > m2, m1, m2), np.where(m1 > m2, m2, m1) return zs, m1, m2, tc, ra, dec, phase, psi, theta_jn, a_1, a_2 if use_njit_sampler: return njit(sampler) else: return sampler