Source code for ler.lens_galaxy_population.sampler_functions

# -*- coding: utf-8 -*-
"""
Module for lens galaxy parameter sampling functions.

This module provides probability density functions (PDFs) and random variable
samplers (RVS) for lens galaxy parameters including redshift, velocity dispersion,
and axis ratio. It also includes rejection and importance sampling algorithms
for sampling lens parameters weighted by gravitational lensing cross sections.

Key Components: \n
- Lens redshift samplers (SIS model from Haris et al. 2018) \n
- Velocity dispersion samplers (generalized gamma distribution) \n
- Axis ratio samplers (Rayleigh and Padilla-Strauss distributions) \n
- Rejection and importance sampling for cross-section weighted parameters \n

Copyright (C) 2026 Hemantakumar Phurailatpam. Distributed under MIT License.
"""

import numpy as np
from numba import njit, prange
from scipy.interpolate import CubicSpline
from astropy.cosmology import LambdaCDM
from ..utils import is_njitted
from multiprocessing import Pool
from tqdm import tqdm
from ..utils import (
    save_pickle,
    load_pickle,
    inverse_transform_sampler,
    redshift_optimal_spacing,
)
import os


def available_sampler_list():
    """
    Return list of available lens parameter samplers.

    Returns
    -------
    sampler_list : ``list``
        Names of the sampler / PDF functions exposed by this module. \n
        A new list is built on every call, so callers may modify it freely.

    Examples
    --------
    >>> samplers = available_sampler_list()
    >>> len(samplers)
    15
    >>> "axis_ratio_rayleigh_rvs" in samplers
    True
    >>> samplers[:2]
    ['lens_redshift_strongly_lensed_sis_haris_pdf', 'lens_redshift_strongly_lensed_sis_haris_rvs']
    """
    return [
        "lens_redshift_strongly_lensed_sis_haris_pdf",
        "lens_redshift_strongly_lensed_sis_haris_rvs",
        "velocity_dispersion_ewoud_denisty_function",
        "velocity_dispersion_bernardi_denisty_function",
        "velocity_dispersion_gengamma_density_function",
        "velocity_dispersion_gengamma_pdf",
        "velocity_dispersion_gengamma_rvs",
        "axis_ratio_rayleigh_rvs",
        "axis_ratio_rayleigh_pdf",
        "axis_ratio_padilla_strauss_rvs",
        "axis_ratio_padilla_strauss_pdf",
        "bounded_normal_sample",
        "rejection_sampler",
        "importance_sampler",
        "importance_sampler_mp",
    ]
# ---------------------------------
# Lens redshift sampler functions
# ---------------------------------
def lens_redshift_strongly_lensed_sis_haris_pdf(
    zl,
    zs,
    cosmo=LambdaCDM(H0=70, Om0=0.3, Ode0=0.7, Tcmb0=0.0, Neff=3.04, m_nu=None, Ob0=0.0),
):
    """
    Compute lens redshift PDF for SIS model (Haris et al. 2018).

    Computes the probability density function for lens redshift between
    zl=0 and zl=zs using the analytical form from Haris et al. (2018)
    equation A7, based on the SIS (Singular Isothermal Sphere) lens model.

    Parameters
    ----------
    zl : ``float`` or ``numpy.ndarray``
        Redshift of the lens galaxy.
    zs : ``float`` or ``numpy.ndarray``
        Redshift of the source.
    cosmo : ``astropy.cosmology``
        Cosmology object for distance calculations. \n
        default: LambdaCDM(H0=70, Om0=0.3, Ode0=0.7, Tcmb0=0.0, Neff=3.04, m_nu=None, Ob0=0.0)

    Returns
    -------
    pdf : ``float`` or ``numpy.ndarray``
        Value of 30 x^2 (1 - x)^2 with x = D_c(zl) / D_c(zs).

    Notes
    -----
    This function is intentionally *not* Numba-jitted. It calls methods of an
    ``astropy`` cosmology object, which Numba's nopython mode cannot type, so
    an ``@njit`` decorator makes every call fail.

    The returned expression integrates to one over x in [0, 1]. It is
    therefore a density in the comoving-distance ratio x. A density in zl
    itself would additionally carry the Jacobian dx/dzl.

    Examples
    --------
    >>> from astropy.cosmology import LambdaCDM
    >>> cosmo = LambdaCDM(H0=70, Om0=0.3, Ode0=0.7, Tcmb0=0.0, Neff=3.04, m_nu=None, Ob0=0.0)
    >>> pdf = lens_redshift_strongly_lensed_sis_haris_pdf(zl=0.5, zs=1.0, cosmo=cosmo)
    >>> print(f"PDF at zl=0.5: {pdf:.2f}")
    PDF at zl=0.5: 1.80
    """
    Dc_zl = cosmo.comoving_distance(zl).value
    Dc_zs = cosmo.comoving_distance(zs).value
    x = Dc_zl / Dc_zs
    return 30 * x**2 * (1 - x) ** 2
def lens_redshift_strongly_lensed_sis_haris_rvs(
    size,
    zs,
    z_min=0.001,
    z_max=10.0,
    cosmo=LambdaCDM(H0=70, Om0=0.3, Ode0=0.7, Tcmb0=0.0, Neff=3.04, m_nu=None, Ob0=0.0),
):
    """
    Sample lens redshifts for SIS model (Haris et al. 2018).

    The distance ratio x = D_c(zl) / D_c(zs) has density 30 x^2 (1 - x)^2 on
    [0, 1]. Its CDF, 6 x^5 - 15 x^4 + 10 x^3, is tabulated and inverted with
    ``inverse_transform_sampler``. Each drawn x is mapped back to a redshift
    through a cubic-spline inverse of the comoving-distance relation.

    Parameters
    ----------
    size : ``int``
        Number of samples to draw.
    zs : ``float``
        Redshift of the source.
    z_min : ``float``
        Lower end of the redshift grid used for the D_c -> z spline. \n
        default: 0.001
    z_max : ``float``
        Upper end of the redshift grid used for the D_c -> z spline. \n
        default: 10.0
    cosmo : ``astropy.cosmology``
        Cosmology object for distance calculations. \n
        default: LambdaCDM(H0=70, Om0=0.3, Ode0=0.7, Tcmb0=0.0, Neff=3.04, m_nu=None, Ob0=0.0)

    Returns
    -------
    zl : ``numpy.ndarray``
        Sampled lens redshifts with shape (size,).

    Examples
    --------
    >>> from astropy.cosmology import LambdaCDM
    >>> cosmo = LambdaCDM(H0=70, Om0=0.3, Ode0=0.7, Tcmb0=0.0, Neff=3.04, m_nu=None, Ob0=0.0)
    >>> zl_samples = lens_redshift_strongly_lensed_sis_haris_rvs(
    ...     size=1000,
    ...     zs=1.5,
    ...     z_min=0.001,
    ...     z_max=10.0,
    ...     cosmo=cosmo,
    ... )
    >>> print(f"Mean lens redshift: {zl_samples.mean():.2f}")
    >>> print(f"Redshift range: [{zl_samples.min():.3f}, {zl_samples.max():.3f}]")
    """
    # Tabulate z -> D_c and fit the inverse relation D_c -> z.
    z_grid = redshift_optimal_spacing(z_min, z_max, 500)
    dc_grid = cosmo.comoving_distance(z_grid).value
    z_of_dc = CubicSpline(dc_grid, z_grid)

    # Closed-form CDF of the distance-ratio density on a uniform x grid.
    x_grid = np.linspace(0.0, 1.0, 500)
    cdf_grid = 6 * x_grid**5 - 15 * x_grid**4 + 10 * x_grid**3

    # Draw x, scale by the source distance, and convert back to redshift.
    x_draws = inverse_transform_sampler(size, cdf_grid, x_grid)
    dc_source = cosmo.comoving_distance(zs).value
    dc_lens = dc_source * x_draws
    return z_of_dc(dc_lens)
# ---------------------------------------
# Velocity dispersion sampler functions
# ---------------------------------------
@njit
def _gamma(x):
    """
    Compute the gamma function using the Lanczos approximation.

    Uses the g=7, 9-coefficient Lanczos series. For x < 0.5 it applies the
    reflection formula Gamma(x) Gamma(1-x) = pi / sin(pi x).

    Parameters
    ----------
    x : ``float``
        Input value.

    Returns
    -------
    result : ``float``
        Gamma function value at x.
    """
    g = 7
    p = np.array(
        [
            0.99999999999980993,
            676.5203681218851,
            -1259.1392167224028,
            771.32342877765313,
            -176.61502916214059,
            12.507343278686905,
            -0.13857109526572012,
            9.9843695780195716e-6,
            1.5056327351493116e-7,
        ]
    )
    if x < 0.5:
        # Reflection formula keeps the series in its accurate range.
        return np.pi / (np.sin(np.pi * x) * _gamma(1 - x))
    else:
        x -= 1
        y = p[0]
        for i in range(1, g + 2):
            y += p[i] / (x + i)
        t = x + g + 0.5
        return np.sqrt(2 * np.pi) * t ** (x + 0.5) * np.exp(-t) * y


@njit
def _cvdf_fit(log_vd, redshift):
    """
    Compute the cumulative velocity dispersion function fit.

    Each of the four fit coefficients is a quadratic polynomial in redshift.
    The result is evaluated at mstar = log_vd - coeffs[3].

    Parameters
    ----------
    log_vd : ``float`` or ``numpy.ndarray``
        Log10 of the velocity dispersion (km/s).
    redshift : ``float`` or ``numpy.ndarray``
        Redshift of the lens galaxy.

    Returns
    -------
    result : ``float`` or ``numpy.ndarray``
        Cumulative velocity dispersion function fit value (log10 scale).
    """
    this_vars = np.array(
        [
            [7.39149763, 5.72940031, -1.12055245],
            [-6.86339338, -5.27327109, 1.10411386],
            [2.85208259, 1.25569600, -0.28663846],
            [0.06703215, -0.04868317, 0.00764841],
        ]
    )
    coeffs = [
        this_vars[i][0] + this_vars[i][1] * redshift + this_vars[i][2] * redshift**2
        for i in range(4)
    ]
    mstar = log_vd - coeffs[3]
    return coeffs[0] + coeffs[1] * mstar + coeffs[2] * mstar**2 - np.exp(mstar)


@njit
def _cvdf_derivative(log_vd, redshift, dx):
    """
    Compute the central-difference derivative of ``_cvdf_fit`` w.r.t. log_vd.

    Parameters
    ----------
    log_vd : ``float`` or ``numpy.ndarray``
        Log10 of the velocity dispersion (km/s).
    redshift : ``float`` or ``numpy.ndarray``
        Redshift of the lens galaxy.
    dx : ``float``
        Step size for numerical differentiation.

    Returns
    -------
    derivative : ``float`` or ``numpy.ndarray``
        Numerical derivative value.
    """
    return (
        0.5
        * (_cvdf_fit(log_vd + dx, redshift) - _cvdf_fit(log_vd - dx, redshift))
        / dx
    )


@njit
def _pdf_phi_z_ratio(sigma, z):
    """
    Compute the ratio of the velocity dispersion function at redshift z to z=0.

    This ratio is used to evolve the local velocity dispersion function to
    redshift z following Oguri et al. (2018b). The ratio lacks an overall
    scaling factor.

    Parameters
    ----------
    sigma : ``float`` or ``numpy.ndarray``
        Velocity dispersion (km/s).
    z : ``float`` or ``numpy.ndarray``
        Redshift of the lens galaxy.

    Returns
    -------
    ratio : ``float`` or ``numpy.ndarray``
        Ratio phi(sigma, z) / phi(sigma, 0).
    """
    log_vd = np.log10(sigma)
    phi_sim_z = 10 ** _cvdf_fit(log_vd, z) / sigma * _cvdf_derivative(log_vd, z, 1e-8)
    phi_sim_0 = 10 ** _cvdf_fit(log_vd, 0) / sigma * _cvdf_derivative(log_vd, 0, 1e-8)
    return phi_sim_z / phi_sim_0


@njit
def velocity_dispersion_ewoud_denisty_function(
    sigma, z, alpha=0.94, beta=1.85, phistar=2.099e-2, sigmastar=113.78
):
    """
    Calculate the lens galaxy velocity dispersion function at redshift z.

    Implements Oguri et al. (2018b) + Wempe et al. (2022). It multiplies the
    local (Bernardi-form) function by the redshift-evolution ratio from
    ``_pdf_phi_z_ratio``.

    Parameters
    ----------
    sigma : ``float`` or ``numpy.ndarray``
        Velocity dispersion of the lens galaxy (km/s).
    z : ``float`` or ``numpy.ndarray``
        Redshift of the lens galaxy.
    alpha : ``float``
        Shape parameter of the velocity dispersion function. \n
        default: 0.94
    beta : ``float``
        Slope parameter of the velocity dispersion function. \n
        default: 1.85
    phistar : ``float``
        Normalization of the velocity dispersion function (Mpc^-3). \n
        default: 2.099e-2
    sigmastar : ``float``
        Characteristic velocity dispersion (km/s). \n
        default: 113.78

    Returns
    -------
    result : ``float`` or ``numpy.ndarray``
        Velocity dispersion function values. \n
        Negative values are clipped to 0; NaN values are left unchanged.
    """
    result = _pdf_phi_z_ratio(sigma, z) * velocity_dispersion_bernardi_denisty_function(
        sigma=sigma, alpha=alpha, beta=beta, phistar=phistar, sigmastar=sigmastar
    )
    # np.maximum clips both scalars and arrays. The previous boolean-mask
    # assignment only worked for arrays. NaN propagates unchanged, as before.
    return np.maximum(result, 0.0)
@njit
def velocity_dispersion_bernardi_denisty_function(
    sigma, alpha, beta, phistar, sigmastar
):
    """
    Calculate the local universe velocity dispersion function.

    Implements the Bernardi et al. (2010) functional form::

        phi(sigma) = phistar * (sigma/sigmastar)^alpha
                     * exp(-(sigma/sigmastar)^beta) * beta / Gamma(alpha/beta) / sigma

    Parameters
    ----------
    sigma : ``float`` or ``numpy.ndarray``
        Velocity dispersion of the lens galaxy (km/s).
    alpha : ``float``
        Power-law shape parameter (alpha/beta is the argument of the gamma
        function). \n
        For Oguri et al. (2018b): alpha=0.94 \n
        For Choi et al. (2007): alpha=2.32
    beta : ``float``
        Slope parameter of the velocity dispersion function. \n
        For Oguri et al. (2018b): beta=1.85 \n
        For Choi et al. (2007): beta=2.67
    phistar : ``float``
        Normalization of the velocity dispersion function (Mpc^-3). \n
        For Oguri et al. (2018b): phistar=2.099e-2*(h/0.7)^3 \n
        For Choi et al. (2007): phistar=8.0e-3*h^3
    sigmastar : ``float``
        Characteristic velocity dispersion (km/s). \n
        For Oguri et al. (2018b): sigmastar=113.78 \n
        For Choi et al. (2007): sigmastar=161.0

    Returns
    -------
    philoc : ``float`` or ``numpy.ndarray``
        Local velocity dispersion function values.
    """
    ratio = sigma / sigmastar
    philoc = (
        phistar
        * ratio**alpha
        * np.exp(-(ratio**beta))
        * beta
        / _gamma(alpha / beta)
        / sigma
    )
    return philoc
def velocity_dispersion_gengamma_density_function(
    sigma,
    alpha=0.94,
    beta=1.85,
    phistar=2.099e-2,
    sigmastar=113.78,
    **kwargs,
):
    """
    Compute unnormalized velocity dispersion function using generalized gamma.

    Evaluates ``phistar * gengamma.pdf(sigma / sigmastar, a=alpha/beta, c=beta)``.
    This is the generalized gamma formulation of the galaxy velocity
    dispersion function (Choi et al. 2007).

    The scipy density is expressed per unit of x = sigma / sigmastar. The
    result therefore equals ``sigmastar`` times the Bernardi-form function
    phistar (sigma/sigmastar)^alpha exp(-(sigma/sigmastar)^beta)
    beta / Gamma(alpha/beta) / sigma.

    Parameters
    ----------
    sigma : ``float``, ``list`` or ``numpy.ndarray``
        Velocity dispersion in km/s.
    alpha : ``float``
        Power-law index governing the low-velocity slope. \n
        default: 0.94
    beta : ``float``
        Exponential parameter for high-velocity cutoff sharpness. \n
        default: 1.85
    phistar : ``float``
        Normalization constant (comoving number density, Mpc^-3). \n
        default: 2.099e-2
    sigmastar : ``float``
        Characteristic velocity scale in km/s. \n
        default: 113.78
    **kwargs
        Accepted for interface compatibility with other density functions
        (e.g. a redshift argument) and ignored.

    Returns
    -------
    density : ``float`` or ``numpy.ndarray``
        Unnormalized velocity dispersion function value.

    Examples
    --------
    >>> import numpy as np
    >>> sigma = np.array([150.0, 200.0, 250.0])
    >>> density = velocity_dispersion_gengamma_density_function(sigma)
    >>> print(f"Density at sigma=200 km/s: {density[1]:.6f}")
    """
    from scipy.stats import gengamma

    # np.asarray lets plain Python lists through; scalars stay scalars.
    density = phistar * gengamma.pdf(
        np.asarray(sigma) / sigmastar,
        a=alpha / beta,
        c=beta,
    )
    return density
def velocity_dispersion_gengamma_pdf(
    sigma,
    sigma_min=100.0,
    sigma_max=400.0,
    alpha=0.94,
    beta=1.85,
    sigmastar=113.78,
):
    """
    Compute normalized velocity dispersion PDF using generalized gamma.

    Evaluates the generalized gamma velocity dispersion distribution
    truncated to [sigma_min, sigma_max]. It is normalised exactly through the
    generalized gamma CDF, which replaces the earlier 500-point ``np.trapz``
    estimate (``np.trapz`` is deprecated/removed in NumPy 2). Values outside
    the truncation range are 0, matching the truncated sampler
    ``velocity_dispersion_gengamma_rvs``.

    Parameters
    ----------
    sigma : ``float`` or ``numpy.ndarray``
        Velocity dispersion in km/s.
    sigma_min : ``float``
        Minimum velocity dispersion for normalization (km/s). \n
        default: 100.0
    sigma_max : ``float``
        Maximum velocity dispersion for normalization (km/s). \n
        default: 400.0
    alpha : ``float``
        Power-law index governing the low-velocity slope. \n
        default: 0.94
    beta : ``float``
        Exponential parameter for high-velocity cutoff sharpness. \n
        default: 1.85
    sigmastar : ``float``
        Characteristic velocity scale in km/s. \n
        default: 113.78

    Returns
    -------
    pdf : ``float`` or ``numpy.ndarray``
        Normalized probability density (per km/s) at the given velocity
        dispersion; a scalar for scalar input.

    Raises
    ------
    ValueError
        If ``sigma_max`` is not greater than ``sigma_min``.

    Examples
    --------
    >>> pdf = velocity_dispersion_gengamma_pdf(
    ...     sigma=200.0,
    ...     sigma_min=100.0,
    ...     sigma_max=400.0,
    ... )
    >>> print(f"PDF at sigma=200 km/s: {pdf:.6f}")
    """
    from scipy.stats import gengamma

    if not sigma_max > sigma_min:
        raise ValueError("sigma_max must be greater than sigma_min.")

    a = alpha / beta
    c = beta

    # The gengamma density lives in x = sigma / sigmastar, so
    # d(sigma) = sigmastar * dx. The probability mass in [sigma_min, sigma_max]
    # follows from the CDF. The phistar prefactor cancels and is omitted.
    norm_const = sigmastar * (
        gengamma.cdf(sigma_max / sigmastar, a, c)
        - gengamma.cdf(sigma_min / sigmastar, a, c)
    )

    sigma = np.asarray(sigma, dtype=float)
    pdf = gengamma.pdf(sigma / sigmastar, a, c) / norm_const
    inside = (sigma >= sigma_min) & (sigma <= sigma_max)
    pdf = np.where(inside, pdf, 0.0)
    # Return a numpy scalar (not a 0-d array) for scalar input.
    return pdf[()] if pdf.ndim == 0 else pdf
def _truncated_gengamma_rvs(size, a, c, loc, scale, lower_bound, upper_bound): """ Sample from truncated generalized gamma distribution via rejection. Parameters ---------- size : ``int`` Number of samples to draw. a : ``float`` Shape parameter of gengamma distribution. c : ``float`` Shape parameter of gengamma distribution. loc : ``float`` Location parameter of gengamma distribution. scale : ``float`` Scale parameter of gengamma distribution. lower_bound : ``float`` Lower bound of the truncated distribution. upper_bound : ``float`` Upper bound of the truncated distribution. Returns ------- rvs : ``numpy.ndarray`` Truncated gengamma random samples with shape (size,). """ from scipy.stats import gengamma total_samples = [] while len(total_samples) < size: batch = gengamma.rvs(a, c, loc=loc, scale=scale, size=size * 2) valid = batch[(batch >= lower_bound) & (batch <= upper_bound)] total_samples.extend(valid) return np.array(total_samples[:size])
def velocity_dispersion_gengamma_rvs(
    size,
    sigma_min=100.0,
    sigma_max=400.0,
    alpha=0.94,
    beta=1.85,
    sigmastar=113.78,
):
    """
    Sample velocity dispersions from generalized gamma distribution.

    Draws from the generalized gamma distribution with a = alpha/beta,
    c = beta and scale = sigmastar. It is truncated by rejection to
    [sigma_min, sigma_max] via ``_truncated_gengamma_rvs``.

    Parameters
    ----------
    size : ``int``
        Number of samples to draw.
    sigma_min : ``float``
        Minimum velocity dispersion (km/s). \n
        default: 100.0
    sigma_max : ``float``
        Maximum velocity dispersion (km/s). \n
        default: 400.0
    alpha : ``float``
        Power-law index governing the low-velocity slope. \n
        default: 0.94
    beta : ``float``
        Exponential parameter for high-velocity cutoff sharpness. \n
        default: 1.85
    sigmastar : ``float``
        Characteristic velocity scale in km/s. \n
        default: 113.78

    Returns
    -------
    sigma : ``numpy.ndarray``
        Sampled velocity dispersions in km/s with shape (size,).

    Examples
    --------
    >>> sigma_samples = velocity_dispersion_gengamma_rvs(
    ...     size=1000,
    ...     sigma_min=100.0,
    ...     sigma_max=400.0,
    ... )
    >>> print(f"Mean sigma: {sigma_samples.mean():.2f} km/s")
    >>> print(f"Range: [{sigma_samples.min():.1f}, {sigma_samples.max():.1f}] km/s")
    """
    sigma_samples = _truncated_gengamma_rvs(
        size=size,
        a=alpha / beta,
        c=beta,
        loc=0,
        scale=sigmastar,
        lower_bound=sigma_min,
        upper_bound=sigma_max,
    )
    return sigma_samples
# ------------------------------
# Axis ratio sampler functions
# ------------------------------
@njit
def axis_ratio_rayleigh_rvs(size, sigma, q_min=0.2, q_max=1.0):
    """
    Sample axis ratios from velocity-dependent Rayleigh distribution.

    Draws b = 1 - q from a Rayleigh distribution whose scale depends on the
    velocity dispersion (Wierda et al. 2021, Appendix C). Draws outside
    [q_min, q_max] are redrawn until every slot is filled.

    Parameters
    ----------
    size : ``int``
        Number of samples to draw.
    sigma : ``numpy.ndarray``
        Velocity dispersions in km/s with shape (size,).
    q_min : ``float``
        Minimum allowed axis ratio. \n
        default: 0.2
    q_max : ``float``
        Maximum allowed axis ratio. \n
        default: 1.0

    Returns
    -------
    q : ``numpy.ndarray``
        Sampled axis ratios with shape (size,). Entries whose proposals are
        NaN (e.g. NaN sigma) are left at 1.0.

    Examples
    --------
    >>> import numpy as np
    >>> sigma = np.random.uniform(100, 300, 1000)
    >>> q_samples = axis_ratio_rayleigh_rvs(size=1000, sigma=sigma, q_min=0.2, q_max=1.0)
    >>> print(f"Mean axis ratio: {q_samples.mean():.2f}")
    >>> print(f"Range: [{q_samples.min():.2f}, {q_samples.max():.2f}]")
    """
    a = sigma / 161.0
    q = np.ones(size)
    idx = np.arange(size)
    size_ = size
    while size_ != 0:
        # Scale parameter for Rayleigh distribution (Wierda et al. 2021)
        s = 0.38 - 0.09177 * a[idx]
        s[s <= 0] = 0.0001
        u = np.random.uniform(0, 1, size=size_)
        b = s * np.sqrt(-2 * np.log(u))  # Inverse CDF of Rayleigh
        q_ = 1.0 - b
        # Select samples within bounds
        idx2 = (q_ >= q_min) & (q_ <= q_max)
        q[idx[idx2]] = q_[idx2]
        # Retry exactly the rejected draws. Strict inequalities make this the
        # complement of the accept mask, so boundary values are not both
        # accepted and resampled. NaN satisfies neither and is dropped, as before.
        idx = idx[(q_ < q_min) | (q_ > q_max)]
        size_ = len(idx)
    return q
@njit(parallel=True)
def axis_ratio_rayleigh_pdf(q, sigma, q_min=0.2, q_max=1.0):
    """
    Compute truncated Rayleigh PDF for axis ratio.

    With b = 1 - q, the untruncated density is the Rayleigh form
    b / s^2 * exp(-b^2 / (2 s^2)). The scale s = 0.38 - 0.09177 * sigma / 161
    is floored at 1e-4 (Wierda et al. 2021, eq. C16). The density is divided
    by the Rayleigh probability mass of the window b in [1 - q_max, 1 - q_min].

    Parameters
    ----------
    q : ``numpy.ndarray``
        Axis ratios at which to evaluate PDF.
    sigma : ``numpy.ndarray``
        Velocity dispersions in km/s (same shape as q).
    q_min : ``float``
        Minimum axis ratio for truncation. \n
        default: 0.2
    q_max : ``float``
        Maximum axis ratio for truncation. \n
        default: 1.0

    Returns
    -------
    pdf : ``numpy.ndarray``
        Probability density values with same shape as q. The value is 0
        outside [q_min, q_max] or when the truncation mass is non-positive.

    Examples
    --------
    >>> import numpy as np
    >>> q = np.array([0.5, 0.7, 0.9])
    >>> sigma = np.array([150.0, 200.0, 250.0])
    >>> pdf = axis_ratio_rayleigh_pdf(q, sigma)
    >>> print(f"PDF values: {pdf}")
    """
    density = np.zeros_like(q)
    # Map the q window onto the b = 1 - q window.
    b_lo = 1.0 - q_max
    b_hi = 1.0 - q_min

    for k in prange(q.size):
        q_k = q[k]
        if q_k < q_min or q_k > q_max:
            continue  # outside the truncation window: density stays 0

        scale = 0.38 - 0.09177 * (sigma[k] / 161.0)
        if scale <= 0.0:
            scale = 1e-4
        scale_sq = scale * scale

        b = 1.0 - q_k
        untruncated = (b / scale_sq) * np.exp(-0.5 * (b * b) / scale_sq)
        window_mass = np.exp(-0.5 * (b_lo * b_lo) / scale_sq) - np.exp(
            -0.5 * (b_hi * b_hi) / scale_sq
        )
        if window_mass <= 0.0:
            continue  # degenerate window: density stays 0
        density[k] = untruncated / window_mass

    return density
@njit
def _axis_ratio_padilla_strauss_data():
    """
    Return axis ratio PDF data points from Padilla & Strauss (2008).

    The q grid is non-uniformly spaced.

    Returns
    -------
    q_array : ``numpy.ndarray``
        Axis ratio data points.
    pdf_array : ``numpy.ndarray``
        PDF values at each axis ratio.
    """
    q_array = np.array(
        [
            0.04903276402927845,
            0.09210526315789469,
            0.13596491228070173,
            0.20789473684210524,
            0.2899703729522482,
            0.3230132450331126,
            0.35350877192982455,
            0.37946148483792264,
            0.4219298245614036,
            0.4689525967235971,
            0.5075026141512723,
            0.5226472638550018,
            0.5640350877192983,
            0.6096491228070177,
            0.6500000000000001,
            0.6864848379226213,
            0.7377192982456142,
            0.7787295224817011,
            0.8007581038689441,
            0.822786685256187,
            0.8668438480306729,
            0.8973684210526317,
            0.9254385964912283,
        ]
    )
    pdf = np.array(
        [
            0.04185262687135349,
            0.06114520695141845,
            0.096997499638376,
            0.1932510900336828,
            0.39547914337673706,
            0.49569751276216234,
            0.6154609137685201,
            0.7182049959882812,
            0.920153741243567,
            1.1573982157399754,
            1.3353263628106684,
            1.413149656448315,
            1.5790713532948977,
            1.7280185150744938,
            1.8132994441344819,
            1.8365803753840484,
            1.8178662203211204,
            1.748929843583365,
            1.688182592496342,
            1.6274353414093188,
            1.4948487090314488,
            1.402785526832393,
            1.321844068356993,
        ]
    )
    return q_array, pdf


@njit
def axis_ratio_padilla_strauss_rvs(size):
    """
    Sample axis ratios from Padilla & Strauss (2008) distribution.

    Builds a CDF from the tabulated PDF by trapezoidal integration and draws
    samples with ``inverse_transform_sampler``. The grid is non-uniform, so
    each CDF increment is weighted by its bin width; a plain cumulative sum
    of PDF values would over-weight densely sampled regions.

    Parameters
    ----------
    size : ``int``
        Number of samples to draw.

    Returns
    -------
    q : ``numpy.ndarray``
        Sampled axis ratios with shape (size,). Samples lie within the
        tabulated range of q values.

    Examples
    --------
    >>> q_samples = axis_ratio_padilla_strauss_rvs(size=1000)
    >>> print(f"Mean axis ratio: {q_samples.mean():.2f}")
    >>> print(f"Range: [{q_samples.min():.2f}, {q_samples.max():.2f}]")
    """
    q_array, pdf = _axis_ratio_padilla_strauss_data()

    # Trapezoidal cumulative integral, starting at 0 at the first grid point.
    cdf_values = np.zeros(q_array.size)
    cdf_values[1:] = np.cumsum(0.5 * (pdf[1:] + pdf[:-1]) * np.diff(q_array))
    cdf_values = cdf_values / cdf_values[-1]

    return inverse_transform_sampler(size, cdf_values, q_array)
def axis_ratio_padilla_strauss_pdf(q):
    """
    Compute axis ratio PDF from Padilla & Strauss (2008).

    Evaluates a cubic spline through the tabulated Padilla & Strauss (2008)
    data. Extrapolation is allowed outside the tabulated q range, and any
    negative extrapolated values are clipped to 0.

    This function is intentionally *not* Numba-jitted. ``scipy``'s
    ``CubicSpline`` cannot be compiled in nopython mode, so an ``@njit``
    decorator makes every call fail.

    Parameters
    ----------
    q : ``float`` or ``numpy.ndarray``
        Axis ratios at which to evaluate PDF.

    Returns
    -------
    pdf : ``numpy.ndarray``
        Non-negative probability density values with same shape as q.

    Examples
    --------
    >>> import numpy as np
    >>> q = np.array([0.3, 0.5, 0.7, 0.9])
    >>> pdf = axis_ratio_padilla_strauss_pdf(q)
    >>> print(f"PDF at q=0.5: {pdf[1]:.4f}")
    """
    q_array, pdf = _axis_ratio_padilla_strauss_data()
    spline = CubicSpline(q_array, pdf, extrapolate=True)
    # Cubic extrapolation can dip below zero; a density cannot.
    return np.clip(spline(q), 0.0, None)
@njit
def bounded_normal_sample(size, mean, std, low, high):
    """
    Sample from truncated normal distribution via rejection.

    Each sample is drawn from N(mean, std) repeatedly until it falls in
    [low, high].

    Parameters
    ----------
    size : ``int``
        Number of samples to draw.
    mean : ``float``
        Mean of the normal distribution.
    std : ``float``
        Standard deviation of the normal distribution.
    low : ``float``
        Lower bound for samples.
    high : ``float``
        Upper bound for samples.

    Returns
    -------
    samples : ``numpy.ndarray``
        Bounded normal samples with shape (size,).

    Raises
    ------
    ValueError
        If ``low > high``. No draw could ever be accepted, so the
        rejection loop would never terminate.

    Examples
    --------
    >>> samples = bounded_normal_sample(size=1000, mean=2.0, std=0.2, low=1.5, high=2.5)
    >>> print(f"Mean: {samples.mean():.2f}, Std: {samples.std():.2f}")
    >>> print(f"Range: [{samples.min():.2f}, {samples.max():.2f}]")
    """
    if low > high:
        raise ValueError("bounded_normal_sample: low must not exceed high.")

    samples = np.empty(size)
    for i in range(size):
        while True:
            sample = np.random.normal(mean, std)
            if low <= sample <= high:
                break
        samples[i] = sample
    return samples
# -------------------------------------------------
# Rejection sampling of strongly lensed parameters
# -------------------------------------------------
def rejection_sampler(
    zs,
    zl,
    sigma_max,
    sigma_rvs,
    q_rvs,
    phi_rvs,
    gamma_rvs,
    shear_rvs,
    cross_section,
    safety_factor=1.2,
):
    """
    Core rejection sampling algorithm for lens parameters.

    For every (zs, zl) pair, lens parameters are proposed from the supplied
    samplers. A proposal is accepted with probability cs / cs_max, where cs
    is its lensing cross section. The bound cs_max is the cross section at
    sigma=sigma_max, q=0.9, phi=0, gamma=2.645 and zero shear, multiplied
    by ``safety_factor``. Pairs that were rejected are re-proposed until all
    are filled.

    ``create_rejection_sampler`` may wrap this function with ``njit``. All
    callables passed in must then be Numba-compatible.

    Caveats (from the acceptance rule ``random < cs / cs_max``):
    - If some proposal's cs exceeds cs_max, it is always accepted and the
      resulting distribution is biased. ``safety_factor`` exists to
      guard against this.
    - If cs is always 0 for some pair, that pair is never accepted and the
      loop does not terminate.

    Parameters
    ----------
    zs : ``numpy.ndarray``
        Source redshifts.
    zl : ``numpy.ndarray``
        Lens redshifts.
    sigma_max : ``float``
        Maximum velocity dispersion (km/s) for computing upper bound.
    sigma_rvs : ``callable``
        Function to sample velocity dispersion: sigma_rvs(n, zl) -> array.
    q_rvs : ``callable``
        Function to sample axis ratio: q_rvs(n, sigma) -> array.
    phi_rvs : ``callable``
        Function to sample orientation angle: phi_rvs(n) -> array.
    gamma_rvs : ``callable``
        Function to sample power-law index: gamma_rvs(n) -> array.
    shear_rvs : ``callable``
        Function to sample external shear: shear_rvs(n) -> (gamma1, gamma2).
    cross_section : ``callable``
        Function to compute lensing cross section; called with keyword
        arguments zs, zl, sigma, q, phi, gamma, gamma1, gamma2.
    safety_factor : ``float``
        Multiplicative safety factor for the upper bound. \n
        default: 1.2

    Returns
    -------
    sigma_array : ``numpy.ndarray``
        Sampled velocity dispersions (km/s).
    q_array : ``numpy.ndarray``
        Sampled axis ratios.
    phi_array : ``numpy.ndarray``
        Sampled orientation angles (rad).
    gamma_array : ``numpy.ndarray``
        Sampled power-law indices.
    gamma1_array : ``numpy.ndarray``
        Sampled external shear component 1.
    gamma2_array : ``numpy.ndarray``
        Sampled external shear component 2.
    """
    n_samples = zl.size
    sigma_array = np.zeros(n_samples)
    q_array = np.zeros(n_samples)
    phi_array = np.zeros(n_samples)
    gamma_array = np.zeros(n_samples)
    gamma1_array = np.zeros(n_samples)
    gamma2_array = np.zeros(n_samples)
    # Indices of (zs, zl) pairs still waiting for an accepted proposal.
    idx_remaining = np.arange(n_samples)

    # Compute maximum cross section for rejection bound
    # (one bound per pair, evaluated at fixed reference lens parameters).
    cs_max = (
        cross_section(
            zs=zs,
            zl=zl,
            sigma=sigma_max * np.ones(n_samples),
            q=0.9 * np.ones(n_samples),
            phi=np.zeros(n_samples),
            gamma=2.645 * np.ones(n_samples),
            gamma1=np.zeros(n_samples),
            gamma2=np.zeros(n_samples),
        )
        * safety_factor
    )

    while len(idx_remaining) > 0:
        n_remaining = len(idx_remaining)
        # Propose one full parameter set per remaining pair. The RNG call
        # order below determines the random stream.
        sigma_samples = sigma_rvs(n_remaining, zl[idx_remaining])
        q_samples = q_rvs(n_remaining, sigma_samples)
        phi_samples = phi_rvs(n_remaining)
        gamma_samples = gamma_rvs(n_remaining)
        gamma1_samples, gamma2_samples = shear_rvs(n_remaining)

        cs = cross_section(
            zs=zs[idx_remaining],
            zl=zl[idx_remaining],
            sigma=sigma_samples,
            q=q_samples,
            phi=phi_samples,
            gamma=gamma_samples,
            gamma1=gamma1_samples,
            gamma2=gamma2_samples,
        )

        # Accept with probability cs / cs_max (capped at 1 implicitly).
        accept = np.random.random(n_remaining) < (cs / cs_max[idx_remaining])
        accepted_indices = idx_remaining[accept]
        sigma_array[accepted_indices] = sigma_samples[accept]
        q_array[accepted_indices] = q_samples[accept]
        phi_array[accepted_indices] = phi_samples[accept]
        gamma_array[accepted_indices] = gamma_samples[accept]
        gamma1_array[accepted_indices] = gamma1_samples[accept]
        gamma2_array[accepted_indices] = gamma2_samples[accept]
        # Keep only the rejected pairs for the next round.
        idx_remaining = idx_remaining[~accept]

    return sigma_array, q_array, phi_array, gamma_array, gamma1_array, gamma2_array
def create_rejection_sampler(
    sigma_max,
    sigma_rvs,
    q_rvs,
    phi_rvs,
    gamma_rvs,
    shear_rvs,
    cross_section,
    safety_factor=1.2,
    use_njit_sampler=True,
):
    """
    Create a rejection sampler for cross-section weighted lens parameters.

    Binds every configuration argument of ``rejection_sampler`` into a
    closure that only needs (zs, zl). With ``use_njit_sampler=True`` both
    ``rejection_sampler`` and the closure are compiled with Numba. In that
    case all supplied callables must be Numba-compatible. A message stating
    which variant is used is printed.

    Parameters
    ----------
    sigma_max : ``float``
        Maximum velocity dispersion (km/s) for computing upper bound.
    sigma_rvs : ``callable``
        Function to sample velocity dispersion: sigma_rvs(n, zl) -> array.
    q_rvs : ``callable``
        Function to sample axis ratio: q_rvs(n, sigma) -> array.
    phi_rvs : ``callable``
        Function to sample orientation angle: phi_rvs(n) -> array.
    gamma_rvs : ``callable``
        Function to sample power-law index: gamma_rvs(n) -> array.
    shear_rvs : ``callable``
        Function to sample external shear: shear_rvs(n) -> (gamma1, gamma2).
    cross_section : ``callable``
        Function to compute lensing cross section.
    safety_factor : ``float``
        Multiplicative safety factor for the upper bound. \n
        default: 1.2
    use_njit_sampler : ``bool``
        If True, uses Numba JIT compilation for faster execution. \n
        default: True

    Returns
    -------
    rejection_sampler_wrapper : ``callable``
        Function with signature (zs, zl) -> (sigma, q, phi, gamma, gamma1, gamma2).

    Examples
    --------
    >>> import numpy as np
    >>> from numba import njit
    >>> @njit
    ... def sigma_rvs(n, zl):
    ...     return 100 + 200 * np.random.random(n)
    >>> @njit
    ... def q_rvs(n, sigma):
    ...     return 0.5 + 0.5 * np.random.random(n)
    >>> @njit
    ... def phi_rvs(n):
    ...     return np.pi * np.random.random(n)
    >>> @njit
    ... def gamma_rvs(n):
    ...     return 2.0 + 0.2 * np.random.randn(n)
    >>> @njit
    ... def shear_rvs(n):
    ...     return 0.05 * np.random.randn(n), 0.05 * np.random.randn(n)
    >>> @njit
    ... def cross_section(zs, zl, sigma, q, phi, gamma, gamma1, gamma2):
    ...     return sigma**4
    >>> sampler = create_rejection_sampler(
    ...     sigma_max=400.0,
    ...     sigma_rvs=sigma_rvs,
    ...     q_rvs=q_rvs,
    ...     phi_rvs=phi_rvs,
    ...     gamma_rvs=gamma_rvs,
    ...     shear_rvs=shear_rvs,
    ...     cross_section=cross_section,
    ... )
    >>> zs = np.array([1.0, 1.5, 2.0])
    >>> zl = np.array([0.3, 0.5, 0.7])
    >>> sigma, q, phi, gamma, gamma1, gamma2 = sampler(zs, zl)
    """
    if not use_njit_sampler:
        print(
            "Slower, non-njit and rejection sampling based lens parameter sampler will be used."
        )

        def rejection_sampler_wrapper(zs, zl):
            return rejection_sampler(
                zs=zs,
                zl=zl,
                sigma_max=sigma_max,
                sigma_rvs=sigma_rvs,
                q_rvs=q_rvs,
                phi_rvs=phi_rvs,
                gamma_rvs=gamma_rvs,
                shear_rvs=shear_rvs,
                cross_section=cross_section,
                safety_factor=safety_factor,
            )

        return rejection_sampler_wrapper

    # Compiled path: jit the core, then jit a closure that binds the config.
    jitted_core = njit(rejection_sampler)
    print(
        "Faster, njitted and rejection sampling based lens parameter sampler will be used."
    )

    @njit
    def rejection_sampler_wrapper(zs, zl):
        return jitted_core(
            zs=zs,
            zl=zl,
            sigma_max=sigma_max,
            sigma_rvs=sigma_rvs,
            q_rvs=q_rvs,
            phi_rvs=phi_rvs,
            gamma_rvs=gamma_rvs,
            shear_rvs=shear_rvs,
            cross_section=cross_section,
            safety_factor=safety_factor,
        )

    return rejection_sampler_wrapper
# --------------------------------------------------
# Importance sampling of strongly lensed parameters
# --------------------------------------------------
@njit
def _sigma_proposal_uniform(n, sigma_min, sigma_max):
    """
    Draw uniform samples for velocity dispersion proposal.

    Parameters
    ----------
    n : ``int``
        Number of samples to draw.
    sigma_min : ``float``
        Minimum velocity dispersion (km/s).
    sigma_max : ``float``
        Maximum velocity dispersion (km/s).

    Returns
    -------
    sigma : ``numpy.ndarray``
        Uniform samples in [sigma_min, sigma_max).
    """
    return sigma_min + (sigma_max - sigma_min) * np.random.random(n)


@njit
def _weighted_choice_1d(weights):
    """
    Draw an index with probability proportional to weights.

    Numba-safe replacement for np.random.choice(n, p=weights). If the weights
    do not have a positive total (all zero, or a NaN present), a uniformly
    random index is returned instead.

    Parameters
    ----------
    weights : ``numpy.ndarray``
        Non-negative weights (need not be normalized).

    Returns
    -------
    idx : ``int``
        Randomly selected index; entries with zero weight are never chosen
        when the total is positive.
    """
    total = 0.0
    for i in range(weights.size):
        total += weights[i]
    # "not (total > 0)" also catches a NaN total.
    if not (total > 0.0):
        return np.random.randint(weights.size)

    u = np.random.random() * total  # u lies in [0, total)
    c = 0.0
    for i in range(weights.size):
        c += weights[i]
        # Strict "<": with "<=", u == 0.0 would select a leading index whose
        # weight is zero.
        if u < c:
            return i
    # Defensive fallback; not expected for finite, non-negative weights.
    return weights.size - 1
def importance_sampler(
    zs,
    zl,
    sigma_min,
    sigma_max,
    q_rvs,
    phi_rvs,
    gamma_rvs,
    shear_rvs,
    number_density,
    cross_section,
    n_prop,
):
    """
    Core importance sampling algorithm for lens parameters.

    Samples lens galaxy parameters weighted by their lensing cross sections.
    Importance sampling uses a uniform proposal for velocity dispersion.

    Algorithm
    ---------
    For each lens-source pair (zl_i, zs_i):

    1. **Draw proposal samples** (n_prop samples per lens):
       - sigma_k ~ Uniform(sigma_min, sigma_max), with proposal density
         p0 = 1/(sigma_max - sigma_min)
       - q_k ~ p(q|sigma_k) (axis ratio conditioned on velocity dispersion)
       - phi_k ~ p(phi) (orientation angle prior)
       - gamma_k ~ p(gamma) (power-law index prior)
       - (gamma1_k, gamma2_k) ~ p(gamma1, gamma2) (external shear prior)

    2. **Compute lensing cross sections**:
       - cs_k = CrossSection(zs_i, zl_i, sigma_k, q_k, phi_k, gamma_k, gamma1_k, gamma2_k)

    3. **Compute importance weights**:
       - Target: p(theta|zl) ∝ number_density(sigma|zl) x CrossSection(theta)
       - Proposal: Uniform(sigma) x p(q|sigma) x p(phi) x p(gamma) x p(shear)
       - Weight: w_k = cs_k x number_density(sigma_k, zl_i) / p0
       - Non-positive and NaN weights are set to 0, then w is normalised to
         sum to one. If every weight is 0, uniform weights 1/n_prop are used.

    4. **Resample**:
       - Draw one index k with probability w_k (``_weighted_choice_1d``) and
         return that proposal's parameters as the sample for pair i.

    The loop over pairs uses ``prange``. Numba documents that outside a
    ``parallel=True`` compiled function ``prange`` behaves like ``range``.
    NOTE(review): presumably a caller compiles this with
    ``njit(parallel=True)``; in that case each pair's random draws come from
    per-thread RNG streams — confirm if reproducibility matters.

    Parameters
    ----------
    zs : ``numpy.ndarray``
        Source redshifts.
    zl : ``numpy.ndarray``
        Lens redshifts.
    sigma_min : ``float``
        Minimum velocity dispersion (km/s) for uniform proposal.
    sigma_max : ``float``
        Maximum velocity dispersion (km/s) for uniform proposal.
    q_rvs : ``callable``
        Function to sample axis ratio: q_rvs(n, sigma) -> array.
    phi_rvs : ``callable``
        Function to sample orientation angle: phi_rvs(n) -> array.
    gamma_rvs : ``callable``
        Function to sample power-law index: gamma_rvs(n) -> array.
    shear_rvs : ``callable``
        Function to sample external shear: shear_rvs(n) -> (gamma1, gamma2).
    number_density : ``callable``
        Number density or velocity dispersion function: number_density(sigma, zl) -> array.
    cross_section : ``callable``
        Function to compute lensing cross section, called positionally as
        cross_section(zs, zl, sigma, q, phi, gamma, gamma1, gamma2).
    n_prop : ``int``
        Number of proposal samples per lens.

    Returns
    -------
    sigma_post : ``numpy.ndarray``
        Sampled velocity dispersions (km/s).
    q_post : ``numpy.ndarray``
        Sampled axis ratios.
    phi_post : ``numpy.ndarray``
        Sampled orientation angles (rad).
    gamma_post : ``numpy.ndarray``
        Sampled power-law indices.
    gamma1_post : ``numpy.ndarray``
        Sampled external shear component 1.
    gamma2_post : ``numpy.ndarray``
        Sampled external shear component 2.
    """
    n_samples = zl.size
    sigma_post = np.zeros(n_samples)
    q_post = np.zeros(n_samples)
    phi_post = np.zeros(n_samples)
    gamma_post = np.zeros(n_samples)
    gamma1_post = np.zeros(n_samples)
    gamma2_post = np.zeros(n_samples)

    # Uniform proposal density for sigma.
    p0 = 1.0 / (sigma_max - sigma_min)

    for i in prange(n_samples):  # for each (zl, zs) pair
        # Draw proposals from uniform distribution
        sigma_prop = _sigma_proposal_uniform(n_prop, sigma_min, sigma_max)

        # Draw other parameters from their priors
        q_prop = q_rvs(n_prop, sigma_prop)
        phi_prop = phi_rvs(n_prop)
        gamma_prop = gamma_rvs(n_prop)
        gamma1_prop, gamma2_prop = shear_rvs(n_prop)

        # Compute cross sections
        zs_arr = zs[i] * np.ones(n_prop)
        zl_arr = zl[i] * np.ones(n_prop)
        cs = cross_section(
            zs_arr,
            zl_arr,
            sigma_prop,
            q_prop,
            phi_prop,
            gamma_prop,
            gamma1_prop,
            gamma2_prop,
        )

        # Compute importance weights
        sigma_function = number_density(sigma_prop, zl_arr)
        w = cs * (sigma_function / p0)
        # Zero out negative and NaN weights (NaN > 0 is False).
        w = np.where(w > 0.0, w, 0.0)
        w_sum = np.sum(w)

        # Normalize weights
        if w_sum > 0.0:
            w = w / w_sum
        else:
            w = np.ones(n_prop) / n_prop

        # Draw posterior sample via weighted choice
        idx = _weighted_choice_1d(w)

        sigma_post[i] = sigma_prop[idx]
        q_post[i] = q_prop[idx]
        phi_post[i] = phi_prop[idx]
        gamma_post[i] = gamma_prop[idx]
        gamma1_post[i] = gamma1_prop[idx]
        gamma2_post[i] = gamma2_prop[idx]

    return sigma_post, q_post, phi_post, gamma_post, gamma1_post, gamma2_post
# --------------------------------------
# Importance sampling (multiprocessing)
# --------------------------------------

# Legacy shared-data file name. ``_importance_sampler_worker`` falls back to
# reading this file when it runs in a process that was not set up by
# ``_importance_sampler_init`` (e.g. when the worker is called directly).
_IMPORTANCE_SAMPLER_DEFAULT_FILE = "importance_sampler_shared.pkl"

# Per-process copy of the shared sampler configuration, populated once per
# worker process by ``_importance_sampler_init``.
_IMPORTANCE_SAMPLER_SHARED = None


def _importance_sampler_init(shared_file):
    """
    Initialize a worker process for multiprocessing importance sampling.

    Loads the shared sampler configuration once per process (instead of once
    per lens) and reseeds NumPy's global random state.

    Parameters
    ----------
    shared_file : ``str``
        Path of the pickle file written by ``importance_sampler_mp``.
    """
    global _IMPORTANCE_SAMPLER_SHARED
    _IMPORTANCE_SAMPLER_SHARED = load_pickle(shared_file)
    # Processes created with the "fork" start method inherit an identical copy
    # of NumPy's legacy global RandomState. Without reseeding, every worker
    # would draw the same proposal sequence. ``seed()`` with no argument pulls
    # fresh entropy from the OS.
    np.random.seed()


def _importance_sampler_worker(params):
    """
    Worker function for multiprocessing importance sampling.

    Draws ``n_prop`` uniform velocity-dispersion proposals, plus prior draws
    of the other lens parameters, for one (zs, zl) pair. It then returns one
    proposal chosen with probability proportional to
    ``cross_section * number_density / p0``.

    Parameters
    ----------
    params : ``tuple``
        Packed parameters: (zs_i, zl_i, worker_idx).

    Returns
    -------
    worker_idx : ``int``
        Worker index for result ordering.
    result : ``tuple``
        (sigma, q, phi, gamma, gamma1, gamma2) for this sample.
    """
    zs_i, zl_i, worker_idx = params

    shared_data = _IMPORTANCE_SAMPLER_SHARED
    if shared_data is None:
        # Not initialized through the pool initializer: read the legacy file.
        shared_data = load_pickle(_IMPORTANCE_SAMPLER_DEFAULT_FILE)

    sigma_min = shared_data["sigma_min"]
    sigma_max = shared_data["sigma_max"]
    q_rvs = shared_data["q_rvs"]
    phi_rvs = shared_data["phi_rvs"]
    gamma_rvs = shared_data["gamma_rvs"]
    shear_rvs = shared_data["shear_rvs"]
    number_density = shared_data["number_density"]
    cross_section = shared_data["cross_section"]
    n_prop = shared_data["n_prop"]

    # Density of the uniform proposal over [sigma_min, sigma_max].
    p0 = 1.0 / (sigma_max - sigma_min)

    # Draw proposals
    sigma_prop = np.random.uniform(sigma_min, sigma_max, n_prop)

    # Draw other parameters from their priors
    q_prop = q_rvs(n_prop, sigma_prop)
    phi_prop = phi_rvs(n_prop)
    gamma_prop = gamma_rvs(n_prop)
    gamma1_prop, gamma2_prop = shear_rvs(n_prop)

    # Compute cross sections. The constant 1/(4*pi) factor cancels when the
    # weights are normalized below.
    zs_arr = zs_i * np.ones(n_prop)
    zl_arr = zl_i * np.ones(n_prop)
    cs = cross_section(
        zs_arr,
        zl_arr,
        sigma_prop,
        q_prop,
        phi_prop,
        gamma_prop,
        gamma1_prop,
        gamma2_prop,
    ) / (4.0 * np.pi)

    # Compute importance weights; non-positive (and NaN) weights become 0.
    sigma_function = number_density(sigma_prop, zl_arr)
    w = cs * (sigma_function / p0)
    w = np.where(w > 0.0, w, 0.0)
    w_sum = np.sum(w)

    # Normalize weights (fall back to uniform if every weight vanished).
    if w_sum > 0.0:
        w = w / w_sum
    else:
        w = np.ones(n_prop) / n_prop

    # Draw posterior sample via weighted choice
    idx = np.random.choice(n_prop, p=w)

    result = (
        sigma_prop[idx],
        q_prop[idx],
        phi_prop[idx],
        gamma_prop[idx],
        gamma1_prop[idx],
        gamma2_prop[idx],
    )
    return worker_idx, result


def importance_sampler_mp(
    zs,
    zl,
    sigma_min,
    sigma_max,
    q_rvs,
    phi_rvs,
    gamma_rvs,
    shear_rvs,
    number_density,
    cross_section,
    n_prop,
    npool=4,
):
    """
    Multiprocessing version of importance sampling for lens parameters.

    The shared configuration is written once to a unique temporary pickle
    file. Each worker process loads it a single time through the pool
    initializer. The file is removed even if sampling fails.

    Parameters
    ----------
    zs : ``numpy.ndarray``
        Source redshifts.
    zl : ``numpy.ndarray``
        Lens redshifts.
    sigma_min : ``float``
        Minimum velocity dispersion (km/s) for uniform proposal.
    sigma_max : ``float``
        Maximum velocity dispersion (km/s) for uniform proposal.
    q_rvs : ``callable``
        Function to sample axis ratio: q_rvs(n, sigma) -> array.
    phi_rvs : ``callable``
        Function to sample orientation angle: phi_rvs(n) -> array.
    gamma_rvs : ``callable``
        Function to sample power-law index: gamma_rvs(n) -> array.
    shear_rvs : ``callable``
        Function to sample external shear: shear_rvs(n) -> (gamma1, gamma2).
    number_density : ``callable``
        Number density or velocity dispersion function:
        number_density(sigma, zl) -> array.
    cross_section : ``callable``
        Function to compute lensing cross section.
    n_prop : ``int``
        Number of proposal samples per lens.
    npool : ``int``
        Number of parallel processes to use. \n
        default: 4

    Returns
    -------
    sigma_post : ``numpy.ndarray``
        Sampled velocity dispersions (km/s).
    q_post : ``numpy.ndarray``
        Sampled axis ratios.
    phi_post : ``numpy.ndarray``
        Sampled orientation angles (rad).
    gamma_post : ``numpy.ndarray``
        Sampled power-law indices.
    gamma1_post : ``numpy.ndarray``
        Sampled external shear component 1.
    gamma2_post : ``numpy.ndarray``
        Sampled external shear component 2.
    """
    import tempfile

    n_samples = zl.size

    # Initialize output arrays
    sigma_post = np.zeros(n_samples)
    q_post = np.zeros(n_samples)
    phi_post = np.zeros(n_samples)
    gamma_post = np.zeros(n_samples)
    gamma1_post = np.zeros(n_samples)
    gamma2_post = np.zeros(n_samples)

    if n_samples == 0:
        # Nothing to sample: avoid writing files and spawning a process pool.
        return sigma_post, q_post, phi_post, gamma_post, gamma1_post, gamma2_post

    shared_data = {
        "sigma_min": sigma_min,
        "sigma_max": sigma_max,
        "q_rvs": q_rvs,
        "phi_rvs": phi_rvs,
        "gamma_rvs": gamma_rvs,
        "shear_rvs": shear_rvs,
        "number_density": number_density,
        "cross_section": cross_section,
        "n_prop": n_prop,
    }

    # A unique temporary file avoids collisions between concurrent runs and
    # does not require a writable working directory.
    fd, shared_file = tempfile.mkstemp(
        prefix="importance_sampler_shared_", suffix=".pkl"
    )
    os.close(fd)
    try:
        save_pickle(shared_file, shared_data)

        # Prepare input parameters for workers
        input_params = [(zs[i], zl[i], i) for i in range(n_samples)]

        # Run multiprocessing with progress bar
        with Pool(
            processes=npool,
            initializer=_importance_sampler_init,
            initargs=(shared_file,),
        ) as pool:
            for worker_idx, result in tqdm(
                pool.imap_unordered(_importance_sampler_worker, input_params),
                total=n_samples,
                ncols=100,
                desc="Importance sampling",
            ):
                (
                    sigma_post[worker_idx],
                    q_post[worker_idx],
                    phi_post[worker_idx],
                    gamma_post[worker_idx],
                    gamma1_post[worker_idx],
                    gamma2_post[worker_idx],
                ) = result
    finally:
        # Always clean up the shared file, even if a worker raised.
        if os.path.exists(shared_file):
            os.remove(shared_file)

    return sigma_post, q_post, phi_post, gamma_post, gamma1_post, gamma2_post
def create_importance_sampler(
    sigma_min,
    sigma_max,
    q_rvs,
    phi_rvs,
    gamma_rvs,
    shear_rvs,
    number_density,
    cross_section,
    n_prop,
    use_njit_sampler=True,
    npool=4,
):
    """
    Build a cross-section weighted importance sampler for lens parameters.

    The returned callable samples lens parameters by importance sampling with
    a uniform velocity-dispersion proposal. Depending on
    ``use_njit_sampler`` it is either a Numba-compiled function or a plain
    Python function backed by a multiprocessing pool.

    Parameters
    ----------
    sigma_min : ``float``
        Lower edge (km/s) of the uniform velocity-dispersion proposal.
    sigma_max : ``float``
        Upper edge (km/s) of the uniform velocity-dispersion proposal.
    q_rvs : ``callable``
        Axis-ratio sampler: q_rvs(n, sigma) -> array.
    phi_rvs : ``callable``
        Orientation-angle sampler: phi_rvs(n) -> array.
    gamma_rvs : ``callable``
        Power-law index sampler: gamma_rvs(n) -> array.
    shear_rvs : ``callable``
        External-shear sampler: shear_rvs(n) -> (gamma1, gamma2).
    number_density : ``callable``
        Number density / velocity dispersion function:
        number_density(sigma, zl) -> array.
    cross_section : ``callable``
        Lensing cross-section function.
    n_prop : ``int``
        Number of proposals drawn per lens.
    use_njit_sampler : ``bool``
        When True, JIT-compile the sampler with Numba
        (``njit(parallel=True)``). \n
        default: True
    npool : ``int``
        Number of worker processes; only used when
        ``use_njit_sampler=False``. \n
        default: 4

    Returns
    -------
    importance_sampler_wrapper : ``callable``
        Function (zs, zl) -> (sigma, q, phi, gamma, gamma1, gamma2).
        For ``use_njit_sampler=True`` it forwards to a compiled
        ``importance_sampler``. Otherwise it forwards to
        ``importance_sampler_mp``.

    Examples
    --------
    >>> import numpy as np
    >>> from numba import njit
    >>> @njit
    ... def q_rvs(n, sigma):
    ...     return 0.5 + 0.5 * np.random.random(n)
    >>> @njit
    ... def phi_rvs(n):
    ...     return np.pi * np.random.random(n)
    >>> @njit
    ... def gamma_rvs(n):
    ...     return 2.0 + 0.2 * np.random.randn(n)
    >>> @njit
    ... def shear_rvs(n):
    ...     return 0.05 * np.random.randn(n), 0.05 * np.random.randn(n)
    >>> @njit
    ... def number_density(sigma, zl):
    ...     return np.ones_like(sigma)
    >>> @njit
    ... def cross_section(zs, zl, sigma, q, phi, gamma, gamma1, gamma2):
    ...     return sigma**4
    >>> sampler = create_importance_sampler(
    ...     sigma_min=100.0,
    ...     sigma_max=400.0,
    ...     q_rvs=q_rvs,
    ...     phi_rvs=phi_rvs,
    ...     gamma_rvs=gamma_rvs,
    ...     shear_rvs=shear_rvs,
    ...     number_density=number_density,
    ...     cross_section=cross_section,
    ...     n_prop=100,
    ... )
    >>> zs = np.array([1.0, 1.5, 2.0])
    >>> zl = np.array([0.3, 0.5, 0.7])
    >>> sigma, q, phi, gamma, gamma1, gamma2 = sampler(zs, zl)
    """
    # Pure-Python path: delegate to the multiprocessing implementation.
    if not use_njit_sampler:
        print(
            "Slower, non-njit and importance sampling based lens parameter sampler will be used."
        )

        def importance_sampler_wrapper(zs, zl):
            return importance_sampler_mp(
                zs=zs,
                zl=zl,
                sigma_min=sigma_min,
                sigma_max=sigma_max,
                q_rvs=q_rvs,
                phi_rvs=phi_rvs,
                gamma_rvs=gamma_rvs,
                shear_rvs=shear_rvs,
                number_density=number_density,
                cross_section=cross_section,
                n_prop=n_prop,
                npool=npool,
            )

        return importance_sampler_wrapper

    # JIT path: compile the core sampler, then close over the configuration
    # inside a compiled (zs, zl) entry point.
    print(
        "Faster, njitted and importance sampling based lens parameter sampler will be used."
    )
    compiled_sampler = njit(parallel=True)(importance_sampler)

    @njit(parallel=True)
    def importance_sampler_wrapper(zs, zl):
        return compiled_sampler(
            zs=zs,
            zl=zl,
            sigma_min=sigma_min,
            sigma_max=sigma_max,
            q_rvs=q_rvs,
            phi_rvs=phi_rvs,
            gamma_rvs=gamma_rvs,
            shear_rvs=shear_rvs,
            number_density=number_density,
            cross_section=cross_section,
            n_prop=n_prop,
        )

    return importance_sampler_wrapper
def _njit_checks(
    sigma_rvs_,
    q_rvs_,
    phi_rvs,
    gamma_rvs,
    shear_rvs,
    sigma_pdf_,
    number_density_,
    cross_section_,
):
    """
    Check and wrap sampler functions for JIT compatibility.

    Functions with a reduced signature are adapted to the full signature
    used by the samplers:

    - ``cross_section_`` with 3 or 4 positional args is adapted to
      (zs, zl, sigma, q, phi, gamma, gamma1, gamma2).
    - ``sigma_rvs_`` with 1 arg is adapted to (size, zl). In that case
      ``sigma_pdf_`` and ``number_density_`` are also assumed to take only
      ``sigma`` and are adapted to (sigma, zl).
    - ``q_rvs_`` with 1 arg is adapted to (size, sigma).

    An adapter is compiled with ``njit`` only when the function it wraps is
    itself njitted. A nopython-mode adapter cannot call a plain Python
    function: compiling one would make ``is_njitted`` report True for a
    function that fails at its first call.

    Parameters
    ----------
    sigma_rvs_ : ``callable``
        Function to sample velocity dispersion.
    q_rvs_ : ``callable``
        Function to sample axis ratio.
    phi_rvs : ``callable``
        Function to sample orientation angle.
    gamma_rvs : ``callable``
        Function to sample power-law index.
    shear_rvs : ``callable``
        Function to sample external shear.
    sigma_pdf_ : ``callable``
        PDF of velocity dispersion.
    number_density_ : ``callable``
        Number density function of lens galaxies.
    cross_section_ : ``callable``
        Function to compute lensing cross section.

    Returns
    -------
    use_njit_sampler : ``bool``
        True if all functions are JIT compiled, False otherwise.
    dict_ : ``dict``
        Dictionary containing wrapped/compiled versions of all functions.
    """
    # Wrap cross_section function based on argument count. Compile the
    # adapter only if cross_section_ itself is njitted (same rule as the
    # sigma/q adapters below).
    n_cs_args = cross_section_.__code__.co_argcount
    if n_cs_args == 4:
        cs_adapter = lambda zs, zl, sigma, q, phi, gamma, gamma1, gamma2: cross_section_(
            zs, zl, sigma, q
        )
        cross_section_function = (
            njit(cs_adapter) if is_njitted(cross_section_) else cs_adapter
        )
    elif n_cs_args == 3:
        cs_adapter = lambda zs, zl, sigma, q, phi, gamma, gamma1, gamma2: cross_section_(
            zs, zl, sigma
        )
        cross_section_function = (
            njit(cs_adapter) if is_njitted(cross_section_) else cs_adapter
        )
    else:
        cross_section_function = cross_section_

    # Wrap samplers and PDFs based on argument count. The njit decision for
    # the sigma group follows sigma_rvs_, as before.
    if sigma_rvs_.__code__.co_argcount == 1:
        if is_njitted(sigma_rvs_):
            sigma_rvs = njit(lambda size, zl: sigma_rvs_(size))
            sigma_pdf = njit(lambda sigma, zl: sigma_pdf_(sigma))
            number_density = njit(lambda sigma, zl: number_density_(sigma))
        else:
            sigma_rvs = lambda size, zl: sigma_rvs_(size)
            sigma_pdf = lambda sigma, zl: sigma_pdf_(sigma)
            number_density = lambda sigma, zl: number_density_(sigma)
    else:
        sigma_rvs = sigma_rvs_
        sigma_pdf = sigma_pdf_
        number_density = number_density_

    if q_rvs_.__code__.co_argcount == 1:
        if is_njitted(q_rvs_):
            q_rvs = njit(lambda size, sigma: q_rvs_(size))
        else:
            q_rvs = lambda size, sigma: q_rvs_(size)
    else:
        q_rvs = q_rvs_

    # Build dictionary of wrapped functions
    dict_ = {
        "sigma_rvs": sigma_rvs,
        "q_rvs": q_rvs,
        "phi_rvs": phi_rvs,
        "gamma_rvs": gamma_rvs,
        "shear_rvs": shear_rvs,
        "sigma_pdf": sigma_pdf,
        "number_density": number_density,
        "cross_section_function": cross_section_function,
    }

    # Check if all functions are JIT compiled; warn about each one that is not.
    use_njit_sampler = True
    for key, value in dict_.items():
        if not is_njitted(value):
            print(f"Warning: {key} is not njitted.")
            use_njit_sampler = False

    return use_njit_sampler, dict_