"""Native DSPS model wrapper."""
# ruff: noqa: I001, E402
from __future__ import annotations
import hashlib
from dataclasses import dataclass
from typing import Any
from .jax_runtime import configure_jax_runtime
configure_jax_runtime()
import jax
import jax.numpy as jnp
import numpy as np
from .filters import FilterCurve
from .io import GalaxyObservation
from .photometry import abmag_to_fnu_cgs
[docs]
@dataclass
class DspsContext:
    """Everything the DSPS forward model needs, loaded once and reused.

    Built by ``load_context``. The ``*_jax`` fields hold float32 JAX copies of
    the corresponding SSP / filter / dust arrays. When they are ``None`` or
    empty, the ``_context_ssp_*`` helpers and ``predict_mags_jax`` rebuild the
    arrays from ``ssp`` / ``filters`` on every call.
    """

    # SSP template object from ``dsps.load_ssp_templates``; read for
    # ssp_wave, ssp_lgmet, ssp_lg_age_gyr and ssp_flux.
    ssp: Any
    # Band name -> filter curve. Iteration order is the order of the model
    # magnitudes produced by predict_mags_jax / run_dsps_model.
    filters: dict[str, FilterCurve]
    # Number of cosmic-time bins in the SFH table.
    n_sfh_bins: int = 96
    # NumPy k(lambda) grid on the SSP wavelength grid, one row per COSMOS
    # curve code (row index == code); None when COSMOS dust is disabled.
    cosmos_dust_k_by_code: np.ndarray | None = None
    # Curve name for each row of cosmos_dust_k_by_code ("none" = zero row).
    cosmos_dust_curve_names: tuple[str, ...] = ()
    # Pre-staged float32 JAX copies of the SSP arrays.
    ssp_wave_jax: Any | None = None
    ssp_lgmet_jax: Any | None = None
    ssp_lg_age_gyr_jax: Any | None = None
    ssp_flux_jax: Any | None = None
    # Optional emission-line tables from the SSP object or HDF5 file.
    # NOTE(review): not consumed by the code shown here; confirm the consumer.
    ssp_emline_luminosity: np.ndarray | None = None
    ssp_emline_wave: np.ndarray | None = None
    ssp_emline_name: tuple[str, ...] = ()
    # Copied from load_context(nebular_emission=...).
    # NOTE(review): not read by the code shown here; confirm where it is used.
    nebular_emission_mode: str = "ssp_flux"
    # (wave, transmission) float32 JAX arrays per filter, in `filters` order.
    jax_filters: tuple[tuple[Any, Any], ...] = ()
    # JAX copy of cosmos_dust_k_by_code; when not None the forward model uses
    # COSMOS two-component dust instead of Salim dust.
    cosmos_dust_k_by_code_jax: Any | None = None
[docs]
@dataclass(frozen=True)
class ModelResult:
    """Host-side (NumPy / Python float) output of one ``run_dsps_model`` call."""

    # Input parameters, each coerced to float.
    parameters: dict[str, float]
    # Keys: t_obs_gyr, formed_mass_msun, log10_formed_mass_msun,
    # sfr_at_obs_msun_per_yr, log10_sfr_at_obs.
    derived: dict[str, float]
    # SSP wavelength grid (Angstrom, as passed to the dust code).
    wave: np.ndarray
    # Rest-frame SED before dust attenuation.
    rest_sed: np.ndarray
    # Rest-frame SED after dust attenuation.
    dusted_rest_sed: np.ndarray
    # Band name -> record with model_mag_ab, model_flux_fnu_cgs and filter
    # metadata (source, effective wavelength, wave/transmission arrays), so the
    # record values are not all floats.
    photometry: dict[str, dict[str, Any]]
[docs]
@dataclass(frozen=True)
class JaxModelResult:
    """JAX-array output of ``run_dsps_model_jax``."""

    wave: jnp.ndarray  # SSP wavelength grid
    rest_sed: jnp.ndarray  # rest-frame SED before dust
    dusted_rest_sed: jnp.ndarray  # rest-frame SED after dust
    model_mags: jnp.ndarray  # apparent AB magnitudes, one per configured filter
    t_obs_gyr: jnp.ndarray  # cosmic age at z_obs (Gyr)
    formed_mass_msun: jnp.ndarray  # time integral of the SFH table (Msun)
    sfr_at_obs_msun_per_yr: jnp.ndarray  # last value of the SFH table
[docs]
@dataclass(frozen=True)
class BatchSedResult:
    """Batch DSPS SEDs and photometry from one JAX-vmapped call.

    Every per-row array has one leading entry per row of ``parameter_matrix``.
    """
    parameter_names: list[str]  # column order of parameter_matrix
    parameter_matrix: np.ndarray  # (n_rows, n_params) inputs, as float64
    wave: np.ndarray  # SSP wavelength grid shared by all rows
    rest_sed: np.ndarray  # per-row rest-frame SED before dust
    dusted_rest_sed: np.ndarray  # per-row rest-frame SED after dust
    model_mags: np.ndarray  # per-row magnitudes, one column per filter
    derived: dict[str, np.ndarray]  # DERIVED_QUANTITY_NAMES -> per-row values
# Column order of the per-row derived-quantity vectors built by
# derived_quantities_jax and predict_batch_seds; keep those in sync with this.
DERIVED_QUANTITY_NAMES = [
    "t_obs_gyr", "formed_mass_msun", "log10_formed_mass_msun",
    "sfr_at_obs_msun_per_yr", "log10_sfr_at_obs",
]
[docs]
def load_context(
    ssp_path: str,
    filters: dict[str, FilterCurve],
    n_sfh_bins: int = 96,
    cosmos_config: dict[str, Any] | None = None,
    nebular_emission: str = "ssp_flux",
) -> DspsContext:
    """Load DSPS SSP templates and pre-stage every array the model needs.

    This function does the following:

    * reads the SSP file;
    * tabulates optional COSMOS dust curves on the SSP wavelength grid;
    * picks up optional emission-line tables;
    * stores float32 JAX copies of the SSP arrays, filter curves and dust grid
      so repeated model calls skip the conversion.
    """
    from dsps import load_ssp_templates

    def to_f32(values: Any) -> Any:
        return jnp.asarray(values, dtype=jnp.float32)

    ssp = load_ssp_templates(fn=ssp_path)
    k_by_code, curve_names = _load_cosmos_dust_grid(ssp, cosmos_config)
    line_luminosity, line_wave, line_names = _load_ssp_emline_data(ssp_path, ssp)
    filter_arrays = tuple(
        (to_f32(curve.wave), to_f32(curve.transmission)) for curve in filters.values()
    )
    return DspsContext(
        ssp=ssp,
        filters=filters,
        n_sfh_bins=n_sfh_bins,
        cosmos_dust_k_by_code=k_by_code,
        cosmos_dust_curve_names=curve_names,
        ssp_wave_jax=to_f32(ssp.ssp_wave),
        ssp_lgmet_jax=to_f32(ssp.ssp_lgmet),
        ssp_lg_age_gyr_jax=to_f32(ssp.ssp_lg_age_gyr),
        ssp_flux_jax=to_f32(ssp.ssp_flux),
        ssp_emline_luminosity=line_luminosity,
        ssp_emline_wave=line_wave,
        ssp_emline_name=line_names,
        nebular_emission_mode=str(nebular_emission),
        jax_filters=filter_arrays,
        cosmos_dust_k_by_code_jax=None if k_by_code is None else to_f32(k_by_code),
    )
def _load_ssp_emline_data(
ssp_path: str, ssp: Any
) -> tuple[np.ndarray | None, np.ndarray | None, tuple[str, ...]]:
luminosity = getattr(ssp, "ssp_emline_luminosity", None)
wave = getattr(ssp, "ssp_emline_wave", None)
names: tuple[str, ...] = ()
if luminosity is not None:
luminosity = np.asarray(luminosity, dtype=float)
if wave is not None:
wave = np.asarray(wave, dtype=float)
try:
import h5py
with h5py.File(ssp_path, "r") as handle:
if luminosity is None and "ssp_emline_luminosity" in handle:
luminosity = np.asarray(handle["ssp_emline_luminosity"], dtype=float)
if wave is None and "ssp_emline_wave" in handle:
wave = np.asarray(handle["ssp_emline_wave"], dtype=float)
if "ssp_emline_name" in handle:
raw = np.asarray(handle["ssp_emline_name"])
decoded = []
for item in raw:
if isinstance(item, (bytes, np.bytes_)):
decoded.append(item.decode("utf-8", errors="replace"))
else:
decoded.append(str(item))
names = tuple(decoded)
except (OSError, ImportError, KeyError, TypeError):
pass
if luminosity is None:
return None, wave, names
n_lines = int(luminosity.shape[-1])
if wave is not None and len(wave) != n_lines:
wave = None
if not names or len(names) != n_lines:
names = tuple(f"line_{i:03d}" for i in range(n_lines))
return luminosity, wave, names
def _load_cosmos_dust_grid(
ssp: Any, cosmos_config: dict[str, Any] | None
) -> tuple[np.ndarray | None, tuple[str, ...]]:
if not cosmos_config or not bool(cosmos_config.get("use_cosmos_dust_in_dsps")):
return None, ()
from .cosmos import load_extinction_curves
mapping, curves, _ = load_extinction_curves(cosmos_config)
if not mapping:
return None, ()
max_code = max(int(code) for code in mapping)
wave = np.asarray(ssp.ssp_wave, dtype=float)
k_by_code = np.zeros((max_code + 1, len(wave)), dtype=float)
names: list[str] = []
for code in range(max_code + 1):
curve_name = mapping.get(code, "none")
names.append(curve_name)
if curve_name == "none":
continue
curve = curves.get(curve_name)
if curve is None:
raise ValueError(
f"COSMOS extinction curve {curve_name!r} is configured but not loaded."
)
k_by_code[code] = np.interp(
wave,
curve.wave_angstrom,
curve.k_lambda,
left=curve.k_lambda[0],
right=curve.k_lambda[-1],
)
return k_by_code, tuple(names)
[docs]
def parameters_for_row(
    base: dict[str, Any],
    parameter_columns: dict[str, str],
    row: dict[str, Any],
    redshift_config: dict[str, Any] | None = None,
) -> dict[str, float]:
    """Merge fixed config parameters with optional per-row catalog overrides.

    The merge proceeds in four steps:

    1. Start from ``base`` (values coerced to float).
    2. Replace a parameter by its mapped catalog column whenever that column
       is present in ``row`` with a finite value.
    3. Resolve ``z_obs``.
    4. Add any redshift-prior metadata.
    """
    z_config = redshift_config or {}
    merged = {name: float(val) for name, val in base.items()}
    for name, column in (parameter_columns or {}).items():
        if column not in row:
            continue
        if np.isfinite(row[column]):
            merged[name] = float(row[column])
    merged["z_obs"] = resolve_redshift(merged, row, z_config)
    merged.update(redshift_prior_parameters(merged["z_obs"], row, z_config))
    return merged
[docs]
def resolve_redshift(
    params: dict[str, float], row: dict[str, Any], redshift_config: dict[str, Any]
) -> float:
    """Resolve DSPS redshift from the configured initializer.

    ``redshift_config["initial"]`` (default ``"catalog_column"``) selects:

    * ``"random_uniform"``: a deterministic per-row draw in ``[min, max]``.
    * ``"catalog_column"``: ``row[column]`` when present and finite, else
      ``fixed_value`` when finite.
    * ``"fixed"``: ``fixed_value`` when finite.

    Otherwise the current ``params["z_obs"]`` is kept (or ``fixed_value``,
    default 0.5). Missing, ``None``, non-numeric and non-finite values count as
    unavailable. If nothing usable remains, ``min`` is used. The result is
    clipped to ``[min, max]`` (defaults ``1e-4`` and ``6.0``).
    """
    z_min = float(redshift_config.get("min", 1.0e-4))
    z_max = float(redshift_config.get("max", 6.0))
    initial = str(redshift_config.get("initial", "catalog_column"))
    column = redshift_config.get("column")
    # ``fixed_value: null`` in YAML arrives as None; np.isfinite(None) would
    # raise TypeError, so go through a tolerant conversion instead.
    fixed_value = _finite_float_or_none(redshift_config.get("fixed_value"))
    value = _finite_float_or_none(
        params.get("z_obs", redshift_config.get("fixed_value", 0.5))
    )
    if initial == "random_uniform":
        value = _random_uniform_redshift(row, redshift_config, z_min, z_max)
    else:
        catalog_value = None
        if initial == "catalog_column" and column and column in row:
            catalog_value = _finite_float_or_none(row[column])
        if catalog_value is not None:
            value = catalog_value
        elif initial in {"catalog_column", "fixed"} and fixed_value is not None:
            value = fixed_value
    if value is None:
        value = z_min
    return float(np.clip(value, z_min, z_max))


def _finite_float_or_none(value: Any) -> float | None:
    """Return ``value`` as a finite float, or ``None`` if missing/invalid/non-finite."""
    try:
        number = float(value)
    except (TypeError, ValueError):
        return None
    return number if np.isfinite(number) else None
[docs]
def redshift_prior_parameters(
    z_value: float, row: dict[str, Any], redshift_config: dict[str, Any]
) -> dict[str, float]:
    """Return row-level redshift prior metadata consumed by fit priors.

    The result depends on ``redshift_config["prior_z"]``:

    * Anything but a dict with ``mode == "gaussian"`` gives an empty dict.
    * Otherwise the prior is centred on ``z_value``. Its width is ``sigma``
      (default 0.35), multiplied by ``1 + max(z, 0)`` unless
      ``scale_with_1pz`` is false, and floored at ``sigma_min`` (default 0.02).
    """
    prior_cfg = redshift_config.get("prior_z") or {}
    is_gaussian = (
        isinstance(prior_cfg, dict)
        and str(prior_cfg.get("mode", "none")) == "gaussian"
    )
    if not is_gaussian:
        return {}
    width = float(prior_cfg.get("sigma", 0.35))
    centre = float(z_value)
    if bool(prior_cfg.get("scale_with_1pz", True)):
        width *= 1.0 + max(centre, 0.0)
    floor = float(prior_cfg.get("sigma_min", 0.02))
    return {
        "z_obs_prior_mu": centre,
        "z_obs_prior_sigma": float(max(width, floor)),
    }
def _random_uniform_redshift(
row: dict[str, Any], redshift_config: dict[str, Any], z_min: float, z_max: float
) -> float:
seed = int(float(redshift_config.get("seed", 42)))
payload = "|".join(
f"{key}={row[key]}" for key in sorted(row) if np.isscalar(row[key])
)
digest = hashlib.blake2b(f"{seed}|{payload}".encode(), digest_size=8).digest()
unit = int.from_bytes(digest, "big") / float(2**64 - 1)
return z_min + unit * (z_max - z_min)
[docs]
def run_dsps_model(context: DspsContext, params: dict[str, float]) -> ModelResult:
    """Run DSPS once and package the result as NumPy arrays and Python floats.

    Wraps ``run_dsps_model_jax``. For each configured band it adds the model AB
    magnitude, the corresponding f_nu (cgs) and the filter metadata.
    """
    out = run_dsps_model_jax(context, params)
    mags = np.asarray(out.model_mags, dtype=float)
    formed_mass = float(out.formed_mass_msun)
    sfr_now = float(out.sfr_at_obs_msun_per_yr)

    photometry: dict[str, dict[str, Any]] = {}
    for (band, curve), raw_mag in zip(context.filters.items(), mags, strict=True):
        mag_ab = float(raw_mag)
        photometry[band] = {
            "model_mag_ab": mag_ab,
            "model_flux_fnu_cgs": float(abmag_to_fnu_cgs(mag_ab)),
            "filter_source": curve.source,
            "effective_wavelength_angstrom": curve.effective_wavelength,
            "filter_wave_angstrom": curve.wave,
            "filter_transmission": curve.transmission,
        }

    derived = {
        "t_obs_gyr": float(out.t_obs_gyr),
        "formed_mass_msun": formed_mass,
        "log10_formed_mass_msun": _safe_log10(formed_mass),
        "sfr_at_obs_msun_per_yr": sfr_now,
        "log10_sfr_at_obs": _safe_log10(sfr_now),
    }
    return ModelResult(
        parameters={name: float(val) for name, val in params.items()},
        derived=derived,
        wave=np.asarray(out.wave, dtype=float),
        rest_sed=np.asarray(out.rest_sed, dtype=float),
        dusted_rest_sed=np.asarray(out.dusted_rest_sed, dtype=float),
        photometry=photometry,
    )
[docs]
def run_dsps_model_jax(context: DspsContext, params: dict[str, Any]) -> JaxModelResult:
    """Pure-JAX DSPS forward model used by gradient-based fits.

    The pipeline runs in five stages:

    1. Cosmic age at ``z_obs``.
    2. Lognormal SFH on ``context.n_sfh_bins`` time bins, optionally
       renormalized to a formed mass.
    3. DSPS rest-frame SED with a lognormal metallicity distribution.
    4. Dust attenuation.
    5. Apparent AB magnitudes through the configured filters.
    """
    from dsps import calc_rest_sed_sfh_table_lognormal_mdf
    from dsps.cosmology import DEFAULT_COSMOLOGY, age_at_z

    f32 = jnp.float32
    redshift = jnp.asarray(params["z_obs"], dtype=f32)
    age_now = jnp.ravel(age_at_z(redshift, *DEFAULT_COSMOLOGY))[0]
    # The grid starts at 0.05 Gyr; the end is floored at 0.06 Gyr so the grid
    # stays increasing even when the age at z_obs is tiny.
    t_grid = jnp.linspace(0.05, jnp.maximum(age_now, 0.06), context.n_sfh_bins)
    sfr_grid, mass_formed = normalize_sfh_mass_jax(
        t_grid, build_sfh_table_jax(t_grid, params), params
    )
    sed = calc_rest_sed_sfh_table_lognormal_mdf(
        t_grid,
        sfr_grid,
        jnp.asarray(params["log10_metallicity"], dtype=f32),
        jnp.asarray(params["metallicity_scatter"], dtype=f32),
        _context_ssp_lgmet(context),
        _context_ssp_lg_age_gyr(context),
        _context_ssp_flux(context),
        age_now,
    )
    ssp_wave = _context_ssp_wave(context)
    attenuated = apply_dust_jax(
        ssp_wave, sed.rest_sed, params, context.cosmos_dust_k_by_code_jax
    )
    return JaxModelResult(
        wave=ssp_wave,
        rest_sed=sed.rest_sed,
        dusted_rest_sed=attenuated,
        model_mags=predict_mags_jax(context, ssp_wave, attenuated, redshift),
        t_obs_gyr=age_now,
        formed_mass_msun=mass_formed,
        sfr_at_obs_msun_per_yr=sfr_grid[-1],
    )
[docs]
def predict_mags_jax(
    context: DspsContext, wave: jnp.ndarray, dusted_sed: jnp.ndarray, z_obs: jnp.ndarray
) -> jnp.ndarray:
    """Predict configured apparent AB magnitudes with DSPS photometry kernels.

    Uses the pre-staged ``context.jax_filters`` when available, otherwise
    converts ``context.filters`` on the fly. Returns one magnitude per filter,
    in filter order.
    """
    from dsps import calc_obs_mag
    from dsps.cosmology import DEFAULT_COSMOLOGY

    curves = context.jax_filters or tuple(
        (
            jnp.asarray(band.wave, dtype=jnp.float32),
            jnp.asarray(band.transmission, dtype=jnp.float32),
        )
        for band in context.filters.values()
    )
    return jnp.stack(
        [
            calc_obs_mag(
                wave, dusted_sed, band_wave, band_trans, z_obs, *DEFAULT_COSMOLOGY
            )
            for band_wave, band_trans in curves
        ]
    )
[docs]
def model_mags_jax(context: DspsContext, params: dict[str, Any]) -> jnp.ndarray:
    """Return only the model magnitudes (one per filter) for likelihood/gradient code.

    Runs the full forward model and discards everything but the magnitudes.
    """
    full = run_dsps_model_jax(context, params)
    return full.model_mags
def _context_ssp_wave(context: DspsContext) -> jnp.ndarray:
if context.ssp_wave_jax is not None:
return context.ssp_wave_jax
return jnp.asarray(context.ssp.ssp_wave, dtype=jnp.float32)
def _context_ssp_lgmet(context: DspsContext) -> jnp.ndarray:
if context.ssp_lgmet_jax is not None:
return context.ssp_lgmet_jax
return jnp.asarray(context.ssp.ssp_lgmet, dtype=jnp.float32)
def _context_ssp_lg_age_gyr(context: DspsContext) -> jnp.ndarray:
if context.ssp_lg_age_gyr_jax is not None:
return context.ssp_lg_age_gyr_jax
return jnp.asarray(context.ssp.ssp_lg_age_gyr, dtype=jnp.float32)
def _context_ssp_flux(context: DspsContext) -> jnp.ndarray:
if context.ssp_flux_jax is not None:
return context.ssp_flux_jax
return jnp.asarray(context.ssp.ssp_flux, dtype=jnp.float32)
# Jit-compiled, vmapped batch predictors keyed by
# (kind, id(context), tuple(parameter_names)), with kind in
# {"mags", "derived", "seds"}. Nothing in this module evicts entries, and each
# cached closure references its context, so cached contexts stay alive.
_BATCH_PREDICT_CACHE: dict[tuple[str, int, tuple[str, ...]], Any] = {}
[docs]
def predict_batch_mags(
    context: DspsContext, parameter_names: list[str], parameter_matrix: np.ndarray
) -> np.ndarray:
    """Predict magnitudes for many parameter rows with one JAX-vmapped call.

    Column ``j`` of ``parameter_matrix`` carries ``parameter_names[j]``. The
    compiled predictor is cached per (context identity, parameter names).
    """
    key = ("mags", id(context), tuple(parameter_names))
    compiled = _BATCH_PREDICT_CACHE.get(key)
    if compiled is None:

        def _row_mags(row_values):
            row_params = {
                name: row_values[col] for col, name in enumerate(parameter_names)
            }
            return model_mags_jax(context, row_params)

        compiled = jax.jit(jax.vmap(_row_mags))
        _BATCH_PREDICT_CACHE[key] = compiled
    batch = jnp.asarray(parameter_matrix, dtype=jnp.float32)
    return np.asarray(compiled(batch))
[docs]
def derived_quantities_jax(context: DspsContext, params: dict[str, Any]) -> jnp.ndarray:
    """Return derived quantities needed for scientifically comparable reports.

    Recomputes only the SFH part of the forward model (no SED). The result is
    a length-5 vector in ``DERIVED_QUANTITY_NAMES`` order:

    1. age at ``z_obs`` (Gyr);
    2. formed mass;
    3. log10 of the formed mass;
    4. SFR in the last SFH bin;
    5. log10 of that SFR.
    """
    from dsps.cosmology import DEFAULT_COSMOLOGY, age_at_z

    redshift = jnp.asarray(params["z_obs"], dtype=jnp.float32)
    age_now = jnp.ravel(age_at_z(redshift, *DEFAULT_COSMOLOGY))[0]
    t_grid = jnp.linspace(0.05, jnp.maximum(age_now, 0.06), context.n_sfh_bins)
    sfr_grid, mass_formed = normalize_sfh_mass_jax(
        t_grid, build_sfh_table_jax(t_grid, params), params
    )
    sfr_now = sfr_grid[-1]
    # NOTE(review): 1e-300 underflows to 0 in float32, so this floor only bites
    # under x64. The SFH clip at 1e-12 keeps these inputs positive anyway.
    log_floor = 1.0e-300
    return jnp.asarray(
        [
            age_now,
            mass_formed,
            jnp.log10(jnp.maximum(mass_formed, log_floor)),
            sfr_now,
            jnp.log10(jnp.maximum(sfr_now, log_floor)),
        ]
    )
[docs]
def predict_batch_derived(
    context: DspsContext, parameter_names: list[str], parameter_matrix: np.ndarray
) -> dict[str, np.ndarray]:
    """Compute derived quantities for many fitted parameter rows.

    Returns ``DERIVED_QUANTITY_NAMES`` -> per-row NumPy arrays, using a cached
    jit-compiled vmapped ``derived_quantities_jax``.
    """
    key = ("derived", id(context), tuple(parameter_names))
    compiled = _BATCH_PREDICT_CACHE.get(key)
    if compiled is None:

        def _row_derived(row_values):
            row_params = {
                name: row_values[col] for col, name in enumerate(parameter_names)
            }
            return derived_quantities_jax(context, row_params)

        compiled = jax.jit(jax.vmap(_row_derived))
        _BATCH_PREDICT_CACHE[key] = compiled
    table = np.asarray(compiled(jnp.asarray(parameter_matrix, dtype=jnp.float32)))
    return {name: table[:, col] for col, name in enumerate(DERIVED_QUANTITY_NAMES)}
[docs]
def predict_batch_seds(
    context: DspsContext, parameter_names: list[str], parameter_matrix: np.ndarray
) -> BatchSedResult:
    """Predict rest SEDs, dusted rest SEDs, magnitudes, and derived quantities.

    This is the batch/GPU path used for COSMOS-template comparisons after MAP
    or population fits: one jit-compiled vmapped call covers every row instead
    of one Python DSPS call per galaxy.
    """
    key = ("seds", id(context), tuple(parameter_names))
    compiled = _BATCH_PREDICT_CACHE.get(key)
    if compiled is None:

        def _row_outputs(row_values):
            row_params = {
                name: row_values[col] for col, name in enumerate(parameter_names)
            }
            res = run_dsps_model_jax(context, row_params)
            mass = res.formed_mass_msun
            sfr = res.sfr_at_obs_msun_per_yr
            # Same order as DERIVED_QUANTITY_NAMES.
            extras = jnp.asarray(
                [
                    res.t_obs_gyr,
                    mass,
                    jnp.log10(jnp.maximum(mass, 1.0e-300)),
                    sfr,
                    jnp.log10(jnp.maximum(sfr, 1.0e-300)),
                ]
            )
            return res.rest_sed, res.dusted_rest_sed, res.model_mags, extras

        compiled = jax.jit(jax.vmap(_row_outputs))
        _BATCH_PREDICT_CACHE[key] = compiled
    rest, dusted, mags, extras = compiled(
        jnp.asarray(parameter_matrix, dtype=jnp.float32)
    )
    extras_np = np.asarray(extras)
    return BatchSedResult(
        parameter_names=list(parameter_names),
        parameter_matrix=np.asarray(parameter_matrix, dtype=float),
        wave=np.asarray(_context_ssp_wave(context), dtype=float),
        rest_sed=np.asarray(rest, dtype=float),
        dusted_rest_sed=np.asarray(dusted, dtype=float),
        model_mags=np.asarray(mags, dtype=float),
        derived={
            name: extras_np[:, col] for col, name in enumerate(DERIVED_QUANTITY_NAMES)
        },
    )
[docs]
def build_lognormal_sfh(
    gal_t_table: np.ndarray,
    log10_sfr: float,
    sfh_t_peak: float,
    sfh_tau: float,
) -> np.ndarray:
    """Build a positive SFH in Msun/yr on cosmic-time bins.

    NumPy (float64) front end to ``build_lognormal_sfh_jax``.
    """
    sfh = build_lognormal_sfh_jax(gal_t_table, log10_sfr, sfh_t_peak, sfh_tau)
    return np.asarray(sfh, dtype=float)
[docs]
def build_sfh_table_jax(
    gal_t_table: jnp.ndarray, params: dict[str, Any]
) -> jnp.ndarray:
    """Build the simple production SFH table without leaving JAX.

    Reads ``log10_sfr``, ``sfh_t_peak`` and ``sfh_tau`` from ``params`` as
    float32 and evaluates the lognormal SFH on ``gal_t_table``.
    """

    def _f32(name: str) -> jnp.ndarray:
        return jnp.asarray(params[name], dtype=jnp.float32)

    return build_lognormal_sfh_jax(
        gal_t_table=gal_t_table,
        log10_sfr=_f32("log10_sfr"),
        sfh_t_peak=_f32("sfh_t_peak"),
        sfh_tau=_f32("sfh_tau"),
    )
[docs]
def build_lognormal_sfh_jax(
    gal_t_table: jnp.ndarray,
    log10_sfr: jnp.ndarray,
    sfh_t_peak: jnp.ndarray,
    sfh_tau: jnp.ndarray,
) -> jnp.ndarray:
    """JAX lognormal SFH (Msun/yr) evaluated on ``gal_t_table``.

    SFR(t) = 10**log10_sfr * exp(-0.5 * ((ln t - ln t_peak) / tau)**2)

    The following safety clamps are applied:

    * ``t_peak`` is clipped into the grid's range;
    * ``tau`` is floored at 0.05;
    * ``t`` is floored at 1e-3 before the log;
    * the shape factor is floored at 1e-6;
    * the SFR is floored at 1e-12.
    """
    earliest = jnp.min(gal_t_table)
    latest = jnp.max(gal_t_table)
    peak = jnp.clip(sfh_t_peak, earliest, latest)
    width = jnp.maximum(sfh_tau, 0.05)
    offset = (jnp.log(jnp.clip(gal_t_table, 1.0e-3)) - jnp.log(peak)) / width
    profile = jnp.clip(jnp.exp(-0.5 * offset**2), 1.0e-6)
    return jnp.clip(10**log10_sfr * profile, 1.0e-12, jnp.inf)
[docs]
def normalize_sfh_mass_jax(
    gal_t_table: jnp.ndarray, gal_sfr_table: jnp.ndarray, params: dict[str, Any]
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Optionally scale an SFH to a configured formed stellar mass.

    Returns ``(sfr_table, formed_mass)``. Without ``log10_formed_mass_msun``
    the table is returned unchanged with its integrated mass. This is the
    historical behaviour, where ``log10_sfr`` is the SFH amplitude.

    With ``log10_formed_mass_msun`` the table is rescaled (floored at 1e-12)
    to that mass. ``log10_sfr`` then only sets the pre-normalization shape
    scale, and the returned mass is re-integrated after the floor.
    """

    def _mass_msun(sfr: jnp.ndarray) -> jnp.ndarray:
        # Integral over the time grid (Gyr) times 1e9 (yr per Gyr).
        return jnp.trapezoid(sfr, gal_t_table) * 1.0e9

    raw_mass = _mass_msun(gal_sfr_table)
    if "log10_formed_mass_msun" not in params:
        return gal_sfr_table, raw_mass
    target = 10.0 ** jnp.asarray(params["log10_formed_mass_msun"], dtype=jnp.float32)
    factor = target / jnp.maximum(raw_mass, 1.0e-30)
    rescaled = jnp.clip(gal_sfr_table * factor, 1.0e-12, jnp.inf)
    return rescaled, _mass_msun(rescaled)
[docs]
def apply_dust(
    wave_angstrom: np.ndarray, rest_sed: np.ndarray, params: dict[str, float]
) -> np.ndarray:
    """Apply the configured attenuation model and return a float64 NumPy SED.

    No COSMOS grid is passed, so this always takes the Salim-dust path of
    ``apply_dust_jax``.
    """
    attenuated = apply_dust_jax(wave_angstrom, rest_sed, params)
    return np.asarray(attenuated, dtype=float)
[docs]
def apply_dust_jax(
    wave_angstrom: jnp.ndarray,
    rest_sed: jnp.ndarray,
    params: dict[str, Any],
    cosmos_dust_k_by_code: np.ndarray | None = None,
) -> jnp.ndarray:
    """Apply COSMOS two-component dust when a k(lambda) grid is given, else Salim dust.

    ``wave_angstrom`` is only used on the Salim path. The COSMOS grid is
    presumably already sampled on the SED's wavelengths, as
    ``_load_cosmos_dust_grid`` builds it on the SSP grid.
    """
    if cosmos_dust_k_by_code is None:
        return apply_salim_dust_jax(wave_angstrom, rest_sed, params)
    return apply_cosmos_two_component_dust_jax(rest_sed, params, cosmos_dust_k_by_code)
[docs]
def apply_salim_dust_jax(
    wave_angstrom: jnp.ndarray, rest_sed: jnp.ndarray, params: dict[str, Any]
) -> jnp.ndarray:
    """Apply DSPS Salim+2018-style attenuation without leaving JAX.

    Reads ``dust_av`` (default 0, clamped at >= 0) and ``dust_slope``
    (default -0.7) from ``params``. The second argument of ``sbl18_k_lambda``
    (presumably the UV-bump amplitude) is fixed at 0.
    """
    from dsps.dust.att_curves import _frac_transmission_from_k_lambda, sbl18_k_lambda

    attenuation_v = jnp.maximum(jnp.asarray(params.get("dust_av", 0.0)), 0.0)
    # The DSPS curve is evaluated on wavelengths in microns.
    k_lambda = sbl18_k_lambda(
        jnp.asarray(wave_angstrom) / 10_000.0,
        0.0,
        jnp.asarray(params.get("dust_slope", -0.7)),
    )
    return jnp.asarray(rest_sed) * _frac_transmission_from_k_lambda(
        k_lambda, attenuation_v
    )
[docs]
def apply_cosmos_two_component_dust_jax(
    rest_sed: jnp.ndarray, params: dict[str, Any], cosmos_dust_k_by_code: np.ndarray
) -> jnp.ndarray:
    """Apply the two COSMOS dust curves as a differentiable mixture.

    Each component ``i`` (1, 2) works as follows:

    * It selects a row of ``cosmos_dust_k_by_code`` by rounding
      ``cosmos_ext_curve_i`` and clipping it into ``[0, n_codes - 1]``.
    * It has transmission ``10**(-0.4 * cosmos_ebv_i * k(lambda))``.
    * It is weighted by ``cosmos_frac_i``.

    Negative E(B-V) values and fractions are clamped to zero. The fractions
    are renormalized to sum to one, falling back to 0.5 / 0.5 when both are
    zero. Defaults are curve code 0, E(B-V) 0 and fraction 0.5.

    The curve codes have zero gradient (rounding); E(B-V) and fractions stay
    differentiable, with finite gradients even when both fractions clamp to 0.
    """
    k_grid = jnp.asarray(cosmos_dust_k_by_code)
    last_code = k_grid.shape[0] - 1

    def _curve_row(key: str) -> jnp.ndarray:
        code = jnp.rint(jnp.asarray(params.get(key, 0.0))).astype(jnp.int32)
        return k_grid[jnp.clip(code, 0, last_code)]

    def _non_negative(key: str, default: float) -> jnp.ndarray:
        return jnp.maximum(jnp.asarray(params.get(key, default)), 0.0)

    ebv_1 = _non_negative("cosmos_ebv_1", 0.0)
    ebv_2 = _non_negative("cosmos_ebv_2", 0.0)
    frac_1 = _non_negative("cosmos_frac_1", 0.5)
    frac_2 = _non_negative("cosmos_frac_2", 0.5)
    frac_sum = frac_1 + frac_2
    has_weight = frac_sum > 0.0
    # Divide by a safe denominator. With ``jnp.where(cond, a / s, 0.5)`` the
    # untaken branch still evaluates 0/0, and its NaN derivative leaks through
    # ``where`` into the gradient even though the forward value is fine.
    safe_sum = jnp.where(has_weight, frac_sum, 1.0)
    frac_1 = jnp.where(has_weight, frac_1 / safe_sum, 0.5)
    frac_2 = jnp.where(has_weight, frac_2 / safe_sum, 0.5)
    trans_1 = 10.0 ** (-0.4 * ebv_1 * _curve_row("cosmos_ext_curve_1"))
    trans_2 = 10.0 ** (-0.4 * ebv_2 * _curve_row("cosmos_ext_curve_2"))
    return jnp.asarray(rest_sed) * (frac_1 * trans_1 + frac_2 * trans_2)
def _safe_log10(value: float) -> float:
return float(np.log10(value)) if value > 0 else float("nan")
[docs]
def comparison_rows(
    observation: GalaxyObservation, result: ModelResult
) -> list[dict[str, float | str]]:
    """Build one observed-vs-model comparison record per observed band.

    For each band, the matching ``result.photometry`` entry is looked up by
    band name; a missing band raises ``KeyError``. Each record reports:

    * magnitude residuals in both sign conventions;
    * the model/observed flux ratio and fractional flux residual;
    * the magnitude chi (``residual / sigma_mag``) and the flux chi.

    A missing, non-finite or non-positive flux error is approximated from the
    magnitude error as ``|f| * ln(10) * 0.4 * sigma_mag``. Ratios and chis
    whose denominator is not positive are reported as NaN.
    """
    rows: list[dict[str, float | str]] = []
    for observed in observation.bands:
        model = result.photometry[observed.name]
        residual = observed.mag_ab - model["model_mag_ab"]
        model_flux = float(model["model_flux_fnu_cgs"])
        observed_flux = float(observed.flux_fnu_cgs)
        sigma_mag = observed.sigma_mag
        flux_error = observed.flux_error_fnu_cgs
        if flux_error is None or not np.isfinite(flux_error) or flux_error <= 0:
            # Propagate the magnitude error: |df| = |f| * ln(10) / 2.5 * dm.
            flux_error = abs(observed_flux) * np.log(10.0) * 0.4 * sigma_mag
        flux_ratio = model_flux / observed_flux if observed_flux > 0 else float("nan")
        chi_flux = (
            (model_flux - observed_flux) / flux_error
            if flux_error > 0
            else float("nan")
        )
        # Guarded like chi_flux: a zero sigma_mag would otherwise raise
        # ZeroDivisionError and abort the whole report.
        chi_mag = residual / sigma_mag if sigma_mag > 0 else float("nan")
        rows.append(
            {
                "band": observed.name,
                "column": observed.column,
                "effective_wavelength_angstrom": model["effective_wavelength_angstrom"],
                "observed_flux_fnu_cgs": observed_flux,
                "observed_flux_error_fnu_cgs": float(flux_error),
                "observed_mag_ab": observed.mag_ab,
                "sigma_mag": sigma_mag,
                "model_flux_fnu_cgs": model_flux,
                "model_mag_ab": model["model_mag_ab"],
                "residual_mag_observed_minus_model": residual,
                "residual_mag_model_minus_observed": -residual,
                "flux_ratio_model_over_observed": flux_ratio,
                "fractional_flux_residual_model_minus_observed": flux_ratio - 1.0,
                "chi": chi_mag,
                "chi_flux": chi_flux,
                "filter_source": model["filter_source"],
            }
        )
    return rows