Source code for euclid_dsps.mcmc

"""MCMC posterior sampling for selected galaxies."""

# ruff: noqa: I001, E402

from __future__ import annotations

from dataclasses import dataclass
from typing import Any

from .jax_runtime import configure_jax_runtime

configure_jax_runtime()

import jax
import jax.numpy as jnp
import numpy as np
import numpyro
import numpyro.distributions as dist
from jax import random
from numpyro.distributions import constraints
from numpyro.infer import HMC, MCMC, NUTS
from numpyro.infer.initialization import init_to_value

from .fit import _initial_value
from .io import GalaxyObservation
from .model import (
    DspsContext,
    model_mags_jax,
    predict_batch_derived,
    predict_batch_mags,
)
from .photometry import abmag_to_fnu_cgs_jax, magerr_to_fluxerr_fnu_cgs


class ScaledBetaDistribution(dist.Distribution):
    """Beta distribution scaled to a finite interval.

    A draw is ``low + (high - low) * u`` with ``u ~ Beta(alpha, beta)``, so the
    support is the interval ``[low, high]``.

    Args:
        alpha: First shape parameter of the underlying Beta (positive).
        beta: Second shape parameter of the underlying Beta (positive).
        low: Lower edge of the interval.
        high: Upper edge of the interval.
        validate_args: Forwarded unchanged to ``dist.Distribution.__init__``.

    Note:
        Batch and event shapes are declared empty, i.e. the class is written
        for scalar parameters.
    """

    arg_constraints = {
        "alpha": constraints.positive,
        "beta": constraints.positive,
        "low": constraints.real,
        "high": constraints.real,
    }
    reparametrized_params = ["alpha", "beta"]

    def __init__(
        self,
        alpha: float,
        beta: float,
        low: float,
        high: float,
        validate_args: bool | None = None,
    ):
        self.alpha = jnp.asarray(alpha)
        self.beta = jnp.asarray(beta)
        self.low = jnp.asarray(low)
        self.high = jnp.asarray(high)
        super().__init__(batch_shape=(), event_shape=(), validate_args=validate_args)

    @property
    def _beta(self):
        """Unit-interval Beta distribution that this class rescales.

        Built on demand instead of being cached in ``__init__``: only the
        parameters named in ``arg_constraints`` are declared to NumPyro, so an
        instance rebuilt from those parameters (for example when it is
        flattened and unflattened as a JAX pytree) would not carry a cached
        helper object. Deriving it from ``alpha``/``beta`` keeps it available
        and in sync in every case.
        """
        return dist.Beta(self.alpha, self.beta)

    @constraints.dependent_property
    def support(self):
        """Interval constraint ``[low, high]``."""
        return constraints.interval(self.low, self.high)

    def sample(self, key, sample_shape=()):
        """Draw from the unit Beta and map the draw affinely onto the interval."""
        unit = self._beta.sample(key, sample_shape)
        return self.low + (self.high - self.low) * unit

    def log_prob(self, value):
        """Return the log-density of ``value``.

        The value is mapped back to the unit interval and evaluated under the
        Beta distribution; ``log(high - low)`` is subtracted as the Jacobian
        of the affine rescaling.
        """
        span = self.high - self.low
        unit = (value - self.low) / span
        return self._beta.log_prob(unit) - jnp.log(span)
@dataclass(frozen=True)
class MCMCResult:
    """Immutable container for the outputs of one posterior sampling run."""

    # --- posterior draws -------------------------------------------------
    # Sampled values, keyed by parameter name.
    samples: dict[str, np.ndarray]
    # Derived quantities evaluated for the posterior draws, keyed by name.
    derived_samples: dict[str, np.ndarray]
    # One row (dict) of summary statistics per sampled parameter.
    summary: list[dict[str, float | str]]
    # Model magnitudes predicted for the posterior draws.
    posterior_model_mags: np.ndarray

    # --- observed photometry ---------------------------------------------
    # Observed magnitudes and their uncertainties.
    observed_mag: np.ndarray
    sigma_mag: np.ndarray
    # Observed fluxes and their uncertainties.
    observed_flux_fnu_cgs: np.ndarray
    flux_error_fnu_cgs: np.ndarray
    # Names of the photometric bands.
    band_names: list[str]

    # --- sampler bookkeeping ---------------------------------------------
    # Free-form diagnostics describing the sampler run.
    diagnostics: dict[str, Any]
def sample_one_galaxy(
    context: DspsContext,
    observation: GalaxyObservation,
    base_params: dict[str, float],
    fit_config: dict[str, Any],
    sample_config: dict[str, Any],
    initial_params: dict[str, float] | None = None,
) -> MCMCResult:
    """Sample posterior over configured free parameters with NumPyro HMC/NUTS.

    Args:
        context: Model context forwarded to ``model_mags_jax`` and the batch
            prediction helpers.
        observation: Photometry of one galaxy; its ``bands`` supply
            ``name``, ``mag_ab``, ``sigma_mag``, ``flux_fnu_cgs`` and
            ``flux_error_fnu_cgs`` (the last may be ``None``, in which case it
            is derived from the flux and the magnitude error).
        base_params: Values for every model parameter; the free ones are
            replaced by sampled values inside the model.
        fit_config: Reads ``free_parameters`` (required), ``likelihood_space``
            (default ``"flux"``), ``flux_error_floor_frac``,
            ``flux_error_jitter`` and ``band_calibration_offsets_mag``.
        sample_config: Reads ``priors``, ``sampler`` (default ``"nuts"``),
            ``num_warmup``, ``num_samples``, ``num_chains``, ``chain_method``,
            ``progress_bar``, ``jit_model_args`` and ``seed``, plus the kernel
            options consumed by ``_build_kernel``.
        initial_params: Optional starting point for the chains; used only if
            it provides a finite value for every free parameter.

    Returns:
        An ``MCMCResult`` with the posterior draws, derived quantities,
        summary rows, posterior model magnitudes, the observed photometry and
        sampler diagnostics.

    Raises:
        ValueError: If the configured sampler or a prior type is unsupported.
    """
    bands = observation.bands
    observed_mag = jnp.asarray([band.mag_ab for band in bands], dtype=float)
    sigma_mag = jnp.asarray([band.sigma_mag for band in bands], dtype=float)
    observed_flux = jnp.asarray([band.flux_fnu_cgs for band in bands], dtype=float)
    flux_error = jnp.asarray(
        [
            (
                band.flux_error_fnu_cgs
                if band.flux_error_fnu_cgs is not None
                else magerr_to_fluxerr_fnu_cgs(band.flux_fnu_cgs, band.sigma_mag)
            )
            for band in bands
        ],
        dtype=float,
    )

    # Parse the likelihood space once; it drives the data selection, the
    # model branch and the reported diagnostics.
    likelihood_space = str(fit_config.get("likelihood_space", "flux")).lower()
    use_flux = likelihood_space == "flux"
    if use_flux:
        floor_frac = float(fit_config.get("flux_error_floor_frac", 0.0))
        jitter = float(fit_config.get("flux_error_jitter", 0.0))
        observed = observed_flux
        # Quoted error, fractional floor and absolute jitter add in quadrature.
        sigma = jnp.sqrt(
            flux_error**2 + (floor_frac * observed_flux) ** 2 + jitter**2
        )
    else:
        observed = observed_mag
        sigma = sigma_mag
    finite = jnp.isfinite(observed) & jnp.isfinite(sigma) & (sigma > 0)

    # Masking zeroes the log-density of unusable bands, but the gradient still
    # flows through the discarded branch: a NaN/zero/negative sigma or a NaN
    # observation there yields 0 * NaN = NaN and poisons the HMC/NUTS
    # gradient. Substitute benign placeholders for the masked entries; they
    # contribute nothing to the likelihood.
    safe_observed = jnp.where(finite, observed, 0.0)
    safe_sigma = jnp.where(finite, sigma, 1.0)

    band_names = [band.name for band in bands]
    band_offsets = jnp.asarray(
        fit_config.get("band_calibration_offsets_mag", []), dtype=float
    )
    free = fit_config["free_parameters"]
    free_names = list(free)
    priors = sample_config.get("priors", {})

    def model():
        params = {key: jnp.asarray(value) for key, value in base_params.items()}
        for name in free_names:
            prior_spec = priors.get(name, {})
            params[name] = numpyro.sample(
                name,
                _prior_distribution(name, free[name], prior_spec, base_params),
            )
        model_mag = model_mags_jax(context, params)
        if band_offsets.size:
            model_mag = model_mag + band_offsets
        numpyro.deterministic("model_mag", model_mag)
        if use_flux:
            model_obs = abmag_to_fnu_cgs_jax(model_mag)
            numpyro.deterministic("model_flux_fnu_cgs", model_obs)
        else:
            model_obs = model_mag
        numpyro.sample(
            "obs",
            dist.Normal(model_obs, safe_sigma).mask(finite),
            obs=safe_observed,
        )

    init_params = _initial_params(initial_params, free, free_names)
    kernel_kwargs = {}
    if init_params:
        kernel_kwargs["init_strategy"] = init_to_value(values=init_params)
    sampler = str(sample_config.get("sampler", "nuts")).lower()
    kernel = _build_kernel(model, sampler, sample_config, kernel_kwargs)
    mcmc = MCMC(
        kernel,
        num_warmup=int(sample_config.get("num_warmup", 100)),
        num_samples=int(sample_config.get("num_samples", 200)),
        num_chains=int(sample_config.get("num_chains", 1)),
        chain_method=str(sample_config.get("chain_method", "parallel")),
        progress_bar=bool(sample_config.get("progress_bar", True)),
        jit_model_args=bool(sample_config.get("jit_model_args", False)),
    )
    mcmc.run(
        random.PRNGKey(int(sample_config.get("seed", 42))),
        extra_fields=("diverging", "accept_prob", "num_steps"),
    )

    # Keep only the free parameters; deterministic sites are recomputed below.
    samples = {
        name: np.asarray(values)
        for name, values in mcmc.get_samples().items()
        if name in free_names
    }
    posterior_model_mags = _posterior_model_mags(
        context, base_params, samples, fit_config
    )
    derived_samples = _posterior_derived(context, base_params, samples)
    return MCMCResult(
        samples=samples,
        derived_samples=derived_samples,
        summary=_sample_summary(samples),
        posterior_model_mags=posterior_model_mags,
        observed_mag=np.asarray(observed_mag),
        sigma_mag=np.asarray(sigma_mag),
        observed_flux_fnu_cgs=np.asarray(observed_flux),
        flux_error_fnu_cgs=np.asarray(flux_error),
        band_names=band_names,
        diagnostics=_diagnostics(
            mcmc,
            sample_config=sample_config,
            sampler=sampler,
            initial_params=init_params,
            likelihood_space=likelihood_space,
        ),
    )
# Number of leapfrog steps used for HMC when the configuration provides
# neither ``num_steps`` nor ``trajectory_length``.
_DEFAULT_HMC_NUM_STEPS = 8


def _build_kernel(
    model,
    sampler: str,
    sample_config: dict[str, Any],
    kernel_kwargs: dict[str, Any],
):
    """Create the NumPyro kernel selected by ``sampler``.

    Args:
        model: NumPyro model callable.
        sampler: ``"nuts"`` or ``"hmc"`` (already lower-cased by the caller).
        sample_config: Reads ``target_accept_prob`` (default 0.85),
            ``dense_mass`` (default False), ``step_size``, ``max_tree_depth``
            (NUTS, default 10), ``num_steps`` and ``trajectory_length`` (HMC).
        kernel_kwargs: Extra keyword arguments merged into the kernel call.

    Returns:
        A ``NUTS`` or ``HMC`` kernel instance.

    Raises:
        ValueError: If ``sampler`` is neither ``"nuts"`` nor ``"hmc"``.
    """
    common_kwargs = {
        "target_accept_prob": float(sample_config.get("target_accept_prob", 0.85)),
        "dense_mass": bool(sample_config.get("dense_mass", False)),
        **kernel_kwargs,
    }
    step_size = sample_config.get("step_size")
    if step_size is not None:
        common_kwargs["step_size"] = float(step_size)

    if sampler == "nuts":
        return NUTS(
            model,
            max_tree_depth=int(sample_config.get("max_tree_depth", 10)),
            **common_kwargs,
        )

    if sampler == "hmc":
        hmc_kwargs = dict(common_kwargs)
        num_steps = sample_config.get("num_steps")
        trajectory_length = sample_config.get("trajectory_length")
        # An explicit step count wins over a trajectory length; with neither
        # configured, fall back to a fixed number of steps.
        if num_steps is not None:
            hmc_kwargs["num_steps"] = int(num_steps)
            hmc_kwargs["trajectory_length"] = None
        elif trajectory_length is not None:
            hmc_kwargs["trajectory_length"] = float(trajectory_length)
        else:
            hmc_kwargs["num_steps"] = _DEFAULT_HMC_NUM_STEPS
            hmc_kwargs["trajectory_length"] = None
        return HMC(model, **hmc_kwargs)

    raise ValueError(f"Unsupported MCMC sampler: {sampler}. Use 'nuts' or 'hmc'.")


def _initial_params(
    initial_params: dict[str, float] | None,
    free: dict[str, Any],
    free_names: list[str],
) -> dict[str, jnp.ndarray] | None:
    """Turn a starting point into per-parameter arrays inside the bounds.

    Returns ``None`` when no starting point is given or when any free
    parameter is missing from it or non-finite. Otherwise each value is
    clipped slightly inside its ``bounds`` so it never sits exactly on an
    edge of the interval.
    """
    if not initial_params:
        return None
    values = {}
    for name in free_names:
        if name not in initial_params or not np.isfinite(initial_params[name]):
            return None
        low, high = [float(value) for value in free[name]["bounds"]]
        # Margin relative to the interval width, with an absolute floor.
        eps = max((high - low) * 1.0e-6, 1.0e-8)
        values[name] = jnp.asarray(
            np.clip(float(initial_params[name]), low + eps, high - eps)
        )
    return values


def _prior_distribution(
    name: str,
    fit_spec: dict[str, Any],
    prior_spec: dict[str, Any],
    base_params: dict[str, float],
):
    """Build the prior distribution for one free parameter.

    Args:
        name: Parameter name (used for look-ups and error messages).
        fit_spec: Fit specification of the parameter; must contain ``bounds``.
        prior_spec: Prior specification; ``type`` selects ``uniform``,
            ``normal``, ``truncated_normal`` (default) or ``scaled_beta``.
            The type is matched case-insensitively.
        base_params: Base parameter values, used to resolve ``from_base``.

    Raises:
        ValueError: If the prior type is not one of the supported names.
    """
    low, high = [float(value) for value in fit_spec["bounds"]]
    loc = _prior_location(name, fit_spec, prior_spec, base_params)
    # Default scale: a quarter of the allowed interval, but never tiny.
    scale = _prior_scale(name, prior_spec, base_params, max((high - low) / 4.0, 1.0e-3))
    # Lower-cased for consistency with how the sampler name and the
    # likelihood space are parsed.
    prior_type = str(prior_spec.get("type", "truncated_normal")).lower()
    if prior_type == "uniform":
        return dist.Uniform(low, high)
    if prior_type == "normal":
        return dist.Normal(loc, scale)
    if prior_type == "truncated_normal":
        return dist.TruncatedNormal(loc=loc, scale=scale, low=low, high=high)
    if prior_type == "scaled_beta":
        alpha = float(prior_spec.get("alpha", 1.0))
        beta = float(prior_spec.get("beta", 1.0))
        return ScaledBetaDistribution(alpha=alpha, beta=beta, low=low, high=high)
    raise ValueError(f"Unsupported prior type for {name}: {prior_type}")


def _prior_location(
    name: str,
    fit_spec: dict[str, Any],
    prior_spec: dict[str, Any],
    base_params: dict[str, float],
) -> float:
    """Return the prior centre for ``name``.

    Uses ``prior_spec["loc"]`` when present, otherwise the value returned by
    ``_initial_value``. The literal ``"from_base"`` selects the parameter's
    entry in ``base_params``.
    """
    value = prior_spec.get("loc", _initial_value(fit_spec, name, base_params))
    if value == "from_base":
        return float(base_params[name])
    return float(value)


def _prior_scale(
    name: str,
    prior_spec: dict[str, Any],
    base_params: dict[str, float],
    fallback: float,
) -> float:
    """Return the prior width for ``name``, never smaller than ``1e-6``.

    Uses ``prior_spec["scale"]`` when present, otherwise ``fallback``. The
    literal ``"from_base"`` reads the width from ``base_params`` under the key
    given by ``prior_spec["scale_parameter"]`` (default
    ``"<name>_prior_sigma"``), again falling back to ``fallback``.
    """
    value = prior_spec.get("scale", fallback)
    if value == "from_base":
        scale_name = str(prior_spec.get("scale_parameter", f"{name}_prior_sigma"))
        return max(float(base_params.get(scale_name, fallback)), 1.0e-6)
    return max(float(value), 1.0e-6)


def _posterior_model_mags(
    context: DspsContext,
    base_params: dict[str, float],
    samples: dict[str, np.ndarray],
    fit_config: dict[str, Any],
) -> np.ndarray:
    """Predict model magnitudes for every posterior draw.

    The configured ``band_calibration_offsets_mag`` (if any) are added so the
    result matches the magnitudes used inside the sampled model.
    """
    parameter_names, matrix = _posterior_parameter_matrix(base_params, samples)
    mags = predict_batch_mags(context, parameter_names, matrix)
    offsets = np.asarray(fit_config.get("band_calibration_offsets_mag", []), dtype=float)
    if offsets.size:
        mags = mags + offsets
    return mags


def _posterior_derived(
    context: DspsContext, base_params: dict[str, float], samples: dict[str, np.ndarray]
) -> dict[str, np.ndarray]:
    """Evaluate the derived quantities for every posterior draw."""
    parameter_names, matrix = _posterior_parameter_matrix(base_params, samples)
    return predict_batch_derived(context, parameter_names, matrix)


def _posterior_parameter_matrix(
    base_params: dict[str, float], samples: dict[str, np.ndarray]
) -> tuple[list[str], np.ndarray]:
    """Assemble a ``(n_samples, n_parameters)`` matrix of parameter values.

    Every row starts from ``base_params``; the columns of sampled parameters
    are then overwritten with their posterior draws.

    Returns:
        The column names and the matrix. Columns follow the order of
        ``base_params``; sampled parameters that are absent from
        ``base_params`` are appended at the end.

    Raises:
        ValueError: If ``samples`` is empty, because the number of rows
            cannot be determined.
    """
    if not samples:
        raise ValueError(
            "Cannot build a posterior parameter matrix without any samples."
        )
    parameter_names = list(base_params)
    # The sampled model accepts free parameters that have no base value, so
    # give them a column of their own instead of failing on the look-up.
    parameter_names.extend(name for name in samples if name not in base_params)
    column_of = {name: index for index, name in enumerate(parameter_names)}

    n_samples = len(next(iter(samples.values())))
    matrix = np.empty((n_samples, len(parameter_names)), dtype=float)
    for name, value in base_params.items():
        matrix[:, column_of[name]] = float(value)
    for name, values in samples.items():
        matrix[:, column_of[name]] = values
    return parameter_names, matrix


def _sample_summary(samples: dict[str, np.ndarray]) -> list[dict[str, float | str]]:
    """Summarise each sampled parameter with moments and quantiles.

    Non-finite draws are ignored. A parameter with no finite draw at all gets
    NaN for every statistic instead of triggering an error or warning.

    Returns:
        One dict per parameter with the keys ``parameter``, ``mean``, ``std``,
        ``q05``, ``q16``, ``median``, ``q84`` and ``q95``.
    """
    labels = ("q05", "q16", "median", "q84", "q95")
    levels = (0.05, 0.16, 0.50, 0.84, 0.95)
    rows = []
    for name, values in samples.items():
        values = np.asarray(values)
        finite = values[np.isfinite(values)]
        if finite.size:
            mean = float(np.mean(finite))
            std = float(np.std(finite))
            # One vectorised call instead of one pass per quantile.
            quantiles = [float(q) for q in np.quantile(finite, levels)]
        else:
            mean = std = float("nan")
            quantiles = [float("nan")] * len(levels)
        row: dict[str, float | str] = {"parameter": name, "mean": mean, "std": std}
        row.update(zip(labels, quantiles))
        rows.append(row)
    return rows


def _diagnostics(
    mcmc: MCMC,
    sample_config: dict[str, Any],
    sampler: str,
    initial_params: dict[str, jnp.ndarray] | None = None,
    likelihood_space: str = "flux",
) -> dict[str, Any]:
    """Collect run diagnostics from a finished MCMC object.

    Args:
        mcmc: Sampler object after ``run``; its extra fields and samples are
            read.
        sample_config: Sampling configuration, echoed into the diagnostics.
        sampler: Sampler name (``"nuts"`` or ``"hmc"``).
        initial_params: Starting point handed to the kernel, if any.
        likelihood_space: Likelihood space used by the model.

    Returns:
        A dict with divergence/acceptance/step statistics (when the matching
        extra fields were recorded), the run configuration, the first JAX
        device, and the starting point. For HMC the trajectory setting that
        was actually configured is reported: ``num_steps`` when it is set (or
        when nothing is set, the default), otherwise ``trajectory_length``.
    """
    extra = mcmc.get_extra_fields()
    diagnostics: dict[str, Any] = {}
    if "diverging" in extra:
        diagnostics["n_divergent"] = int(np.asarray(extra["diverging"]).sum())
    if "accept_prob" in extra:
        diagnostics["mean_accept_prob"] = float(np.asarray(extra["accept_prob"]).mean())
    if "num_steps" in extra:
        num_steps = np.asarray(extra["num_steps"])
        diagnostics["mean_num_steps"] = float(np.mean(num_steps))
        diagnostics["max_num_steps"] = int(np.max(num_steps))
    diagnostics["n_samples"] = int(len(next(iter(mcmc.get_samples().values()))))
    diagnostics["backend"] = f"numpyro_{sampler}"
    diagnostics["sampler"] = sampler
    diagnostics["likelihood_space"] = likelihood_space
    diagnostics["num_warmup"] = int(sample_config.get("num_warmup", 100))
    diagnostics["num_chains"] = int(sample_config.get("num_chains", 1))
    diagnostics["chain_method"] = str(sample_config.get("chain_method", "parallel"))
    if sampler == "nuts":
        diagnostics["max_tree_depth"] = int(sample_config.get("max_tree_depth", 10))
    if sampler == "hmc":
        # Mirror the kernel construction: a null ``num_steps`` must not crash
        # ``int()``, and a run driven by ``trajectory_length`` must not be
        # reported as using the default step count.
        configured_steps = sample_config.get("num_steps")
        trajectory_length = sample_config.get("trajectory_length")
        if configured_steps is not None:
            diagnostics["num_steps"] = int(configured_steps)
        elif trajectory_length is not None:
            diagnostics["trajectory_length"] = float(trajectory_length)
        else:
            diagnostics["num_steps"] = _DEFAULT_HMC_NUM_STEPS
    device = jax.devices()[0]
    diagnostics["device"] = f"{device.platform}:{device.id}"
    diagnostics["initialized_from_map"] = bool(initial_params)
    if initial_params:
        diagnostics["initial_parameters"] = {
            name: float(value) for name, value in initial_params.items()
        }
    return diagnostics