Source code for euclid_dsps.fit

"""Single-galaxy fitting helpers."""

# ruff: noqa: I001, E402

from __future__ import annotations

from dataclasses import dataclass
from typing import Any

from .jax_runtime import configure_jax_runtime

configure_jax_runtime()

import jax
import jax.numpy as jnp
import numpy as np
from jax.scipy.optimize import minimize as jax_minimize

from .io import GalaxyObservation
from .model import DspsContext, ModelResult, model_mags_jax, run_dsps_model
from .photometry import abmag_to_fnu_cgs, abmag_to_fnu_cgs_jax, magerr_to_fluxerr_fnu_cgs


# Working precision of the vmapped batch fit: `fit_galaxy_batch_adam` casts the
# arrays it hands to the cached optimizer to this dtype.
_FIT_DTYPE = jnp.float32


@dataclass(frozen=True)
class FitResult:
    """Immutable outcome of a single-galaxy fit, as built by ``fit_one_galaxy``."""

    # Whether the optimizer reported a usable solution.
    success: bool
    # Human-readable optimizer status line.
    message: str
    # Full parameter set (base parameters overlaid with the fitted values).
    best_parameters: dict[str, float]
    # Best objective value reported by the optimizer.
    chi2: float
    # Number of bands in the fitted observation.
    n_bands: int
    # Optimizer trace entries, one dict per recorded step.
    trace: list[dict[str, float]]
    # Output of ``run_dsps_model`` evaluated at ``best_parameters``.
    model_result: ModelResult
    # Gradient norm associated with the reported solution.
    gradient_norm: float
@dataclass(frozen=True)
class BatchFitResult:
    """Immutable outcome of a batched fit (one entry per fitted galaxy)."""

    # Per-galaxy success flags (the batch fitters set this to isfinite(chi2)).
    success: np.ndarray
    # Human-readable summary of the optimizer run.
    message: str
    # Names of all parameters, in the column order of `best_parameter_matrix`.
    parameter_names: list[str]
    # Names of the parameters that were actually optimized.
    free_parameter_names: list[str]
    # Base parameter matrix with the fitted free values written in.
    best_parameter_matrix: np.ndarray
    # Per-galaxy chi-square reported by the optimizer.
    chi2: np.ndarray
    # Per-galaxy gradient norm reported by the optimizer.
    gradient_norm: np.ndarray
    # Model magnitudes reported by the optimizer.
    model_mags: np.ndarray
    # Optimizer trace entries.
    trace: list[dict[str, float]]
    # Description of the JAX device the fit ran on.
    device: str
@dataclass(frozen=True)
class PopulationFitResult:
    """Immutable outcome of the joint population fit."""

    # Per-galaxy results of the joint fit.
    batch: BatchFitResult
    # Fitted population mean, keyed by free-parameter name.
    hyper_mu: dict[str, float]
    # Fitted population scatter, keyed by free-parameter name.
    hyper_sigma: dict[str, float]
    # One row per configured population relation.
    hyper_relations: list[dict[str, Any]]
    # Final value of the joint loss.
    loss: float
def fit_one_galaxy(
    context: DspsContext,
    observation: GalaxyObservation,
    base_params: dict[str, float],
    fit_config: dict[str, Any],
) -> FitResult:
    """Fit configured DSPS parameters with pure-JAX gradients.

    ``fit_config["method"]`` (default ``"jax_adam"``, case-insensitive) selects
    the optimizer:

    * ``"jax_adam"`` / ``"jax_adam_vmap"``: the galaxy is fitted as a batch of
      one through ``fit_galaxy_batch_adam``.
    * ``"jax_bfgs"``: bounded BFGS through ``_fit_bounded_bfgs``.

    Args:
        context: Model context forwarded to the DSPS model functions.
        observation: Observation whose ``bands`` provide the photometry.
        base_params: Parameter values used for everything that is not fitted
            (and, where configured, as initial values / prior locations).
        fit_config: The ``fit`` configuration mapping.

    Returns:
        A ``FitResult`` whose ``model_result`` is ``run_dsps_model`` evaluated
        at the best parameters.

    Raises:
        ValueError: If ``method`` is not one of the supported names. This is
            checked before any setup work so a typo in the config fails fast.
    """
    method = str(fit_config.get("method", "jax_adam")).lower()
    if method in {"jax_adam", "jax_adam_vmap"}:
        return _fit_one_galaxy_as_batch(context, observation, base_params, fit_config)
    if method != "jax_bfgs":
        raise ValueError(f"Unsupported fit.method: {fit_config.get('method')}")
    return _fit_one_galaxy_bfgs(context, observation, base_params, fit_config)


def _fit_one_galaxy_as_batch(
    context: DspsContext,
    observation: GalaxyObservation,
    base_params: dict[str, float],
    fit_config: dict[str, Any],
) -> FitResult:
    """Run the vmapped Adam fitter on a batch of one and unwrap row 0."""
    observed_mag, sigma_mag, _ = _observation_arrays(observation)
    observed_flux, flux_error = _observation_flux_arrays(observation)
    # `[None, :]` adds the leading batch axis the batch fitter expects.
    batch = fit_galaxy_batch_adam(
        context=context,
        base_params_rows=[base_params],
        observed_mag=np.asarray(observed_mag)[None, :],
        sigma_mag=np.asarray(sigma_mag)[None, :],
        fit_config=fit_config,
        observed_flux=np.asarray(observed_flux)[None, :],
        flux_error=np.asarray(flux_error)[None, :],
    )
    best_params = {
        name: float(batch.best_parameter_matrix[0, index])
        for index, name in enumerate(batch.parameter_names)
    }
    best_model = run_dsps_model(context, best_params)
    return FitResult(
        success=bool(batch.success[0]),
        message=batch.message,
        best_parameters=best_params,
        chi2=float(batch.chi2[0]),
        n_bands=len(observation.bands),
        trace=batch.trace,
        model_result=best_model,
        gradient_norm=float(batch.gradient_norm[0]),
    )


def _fit_one_galaxy_bfgs(
    context: DspsContext,
    observation: GalaxyObservation,
    base_params: dict[str, float],
    fit_config: dict[str, Any],
) -> FitResult:
    """Fit one galaxy with bounded BFGS on the chi-square plus prior objective."""
    free = fit_config["free_parameters"]
    names = list(free)
    x0 = jnp.asarray(
        [_initial_value(free[name], name, base_params) for name in names], dtype=float
    )
    bounds = jnp.asarray(
        [tuple(float(x) for x in free[name]["bounds"]) for name in names], dtype=float
    )
    lower = bounds[:, 0]
    upper = bounds[:, 1]
    x0 = jnp.clip(x0, lower, upper)

    # The prior helper works on (row, parameter) matrices; build a single-row
    # matrix here and strip the row axis again right after.
    base_parameter_names = list(base_params)
    prior_arrays = _prepare_prior_arrays(
        fit_config.get("priors", {}),
        names,
        base_parameter_names,
        np.asarray(
            [[float(base_params[name]) for name in base_parameter_names]], dtype=float
        ),
        np.asarray(bounds, dtype=float),
    )
    prior_arrays_jax = {key: jnp.asarray(value[0]) for key, value in prior_arrays.items()}

    observed_mag, sigma_mag, _ = _observation_arrays(observation)
    observed_flux, flux_error = _observation_flux_arrays(observation)
    observed_fit, sigma_fit, finite_mask = _fit_arrays_for_likelihood(
        np.asarray(observed_mag)[None, :],
        np.asarray(sigma_mag)[None, :],
        fit_config,
        observed_flux=np.asarray(observed_flux)[None, :],
        flux_error=np.asarray(flux_error)[None, :],
    )
    observed_fit = jnp.asarray(observed_fit[0], dtype=float)
    sigma_fit = jnp.asarray(sigma_fit[0], dtype=float)
    finite_mask = jnp.asarray(finite_mask[0])

    maxiter = int(fit_config.get("maxiter", 80))
    tolerance = float(fit_config.get("tolerance", 1.0e-5))
    # Read once here instead of on every trace of the objective.
    prior_weight = float(fit_config.get("prior_weight", 1.0))
    trace: list[dict[str, float]] = []

    def unpack_jax(x: jnp.ndarray) -> dict[str, Any]:
        """Overlay the free values `x` on the base parameters as JAX values."""
        params: dict[str, Any] = {
            key: jnp.asarray(value) for key, value in base_params.items()
        }
        params.update({name: value for name, value in zip(names, x, strict=True)})
        return params

    def objective(x: jnp.ndarray) -> jnp.ndarray:
        """Masked chi-square plus weighted prior penalty, with non-finite values capped."""
        params = unpack_jax(x)
        model_mag = model_mags_jax(context, params)
        model_values = _model_values_for_likelihood(model_mag, fit_config)
        # Bands rejected by the mask contribute exactly zero.
        chi = jnp.where(finite_mask, (observed_fit - model_values) / sigma_fit, 0.0)
        chi2 = jnp.sum(chi**2)
        prior = _physical_prior_penalty(
            x,
            lower,
            upper,
            prior_arrays_jax["prior_gaussian_mask"],
            prior_arrays_jax["prior_gaussian_loc"],
            prior_arrays_jax["prior_gaussian_scale"],
            prior_arrays_jax["prior_beta_mask"],
            prior_arrays_jax["prior_beta_alpha"],
            prior_arrays_jax["prior_beta_beta"],
        )
        objective_value = chi2 + prior_weight * prior
        return jnp.nan_to_num(objective_value, nan=1.0e30, posinf=1.0e30, neginf=1.0e30)

    value_and_grad = jax.jit(jax.value_and_grad(objective))
    best_x, _state, best_value, best_grad_norm, success, message = _fit_bounded_bfgs(
        objective=objective,
        value_and_grad=value_and_grad,
        x0=x0,
        lower=lower,
        upper=upper,
        maxiter=maxiter,
        tolerance=tolerance,
        names=names,
        trace=trace,
    )

    best_params = _unpack_numpy(best_x, names, base_params)
    best_model = run_dsps_model(context, best_params)
    return FitResult(
        success=success,
        message=message,
        best_parameters=best_params,
        chi2=float(best_value),
        n_bands=len(observation.bands),
        trace=trace,
        model_result=best_model,
        gradient_norm=float(best_grad_norm),
    )
def _initial_value( spec: dict[str, Any], name: str, base_params: dict[str, float] ) -> float: """Allow YAML configs to use `initial: from_base` for row-dependent values.""" value = spec.get("initial", base_params.get(name, 0.0)) if isinstance(value, str): if value != "from_base": raise ValueError(f"Unsupported initial value for {name}: {value}") value = base_params[name] return float(value) def _observation_arrays( observation: GalaxyObservation, ) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: observed_mag = jnp.asarray([band.mag_ab for band in observation.bands], dtype=float) sigma_mag = jnp.asarray([band.sigma_mag for band in observation.bands], dtype=float) finite_mask = jnp.isfinite(observed_mag) & jnp.isfinite(sigma_mag) & (sigma_mag > 0) return observed_mag, sigma_mag, finite_mask def _observation_flux_arrays(observation: GalaxyObservation) -> tuple[np.ndarray, np.ndarray]: observed_flux = np.asarray( [band.flux_fnu_cgs for band in observation.bands], dtype=float ) flux_error = np.asarray( [ ( band.flux_error_fnu_cgs if band.flux_error_fnu_cgs is not None else magerr_to_fluxerr_fnu_cgs(band.flux_fnu_cgs, band.sigma_mag) ) for band in observation.bands ], dtype=float, ) return observed_flux, flux_error def _likelihood_space(fit_config: dict[str, Any]) -> str: return str(fit_config.get("likelihood_space", "flux")).lower() def _fit_arrays_for_likelihood( observed_mag: np.ndarray, sigma_mag: np.ndarray, fit_config: dict[str, Any], observed_flux: np.ndarray | None = None, flux_error: np.ndarray | None = None, ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: likelihood_space = _likelihood_space(fit_config) if likelihood_space == "mag": observed = np.asarray(observed_mag, dtype=float) sigma = np.asarray(sigma_mag, dtype=float) elif likelihood_space == "flux": if observed_flux is None: observed = np.asarray(abmag_to_fnu_cgs(observed_mag), dtype=float) else: observed = np.asarray(observed_flux, dtype=float) if flux_error is None: sigma = np.asarray( 
magerr_to_fluxerr_fnu_cgs(observed, sigma_mag), dtype=float ) else: sigma = np.asarray(flux_error, dtype=float) floor_frac = float(fit_config.get("flux_error_floor_frac", 0.0)) jitter = float(fit_config.get("flux_error_jitter", 0.0)) sigma = np.sqrt(sigma**2 + (floor_frac * observed) ** 2 + jitter**2) else: raise ValueError(f"Unsupported fit.likelihood_space: {likelihood_space}") mask = np.isfinite(observed) & np.isfinite(sigma) & (sigma > 0.0) return observed, sigma, mask def _model_values_for_likelihood(model_mag: jnp.ndarray, fit_config: dict[str, Any]): model_mag = _apply_band_calibration(model_mag, fit_config) if _likelihood_space(fit_config) == "flux": return abmag_to_fnu_cgs_jax(model_mag) return model_mag def _apply_band_calibration(model_mag: jnp.ndarray, fit_config: dict[str, Any]): offsets = fit_config.get("band_calibration_offsets_mag") or [] if not offsets: return model_mag return model_mag + jnp.asarray(offsets, dtype=model_mag.dtype) def _fit_bounded_adam( value_and_grad, x0: jnp.ndarray, lower: jnp.ndarray, upper: jnp.ndarray, maxiter: int, learning_rate: float, tolerance: float, patience: int, names: list[str], trace: list[dict[str, float]], ): beta1 = 0.9 beta2 = 0.999 eps = 1.0e-8 x = x0 m = jnp.zeros_like(x) v = jnp.zeros_like(x) best_x = x best_value = np.inf best_grad_norm = np.inf stalled = 0 for iteration in range(1, maxiter + 1): value, grad = value_and_grad(x) value_f = float(value) grad_norm = float(jnp.linalg.norm(grad)) trace.append(_trace_entry(names, x, value_f, grad_norm, iteration)) if value_f + tolerance < best_value: best_x = x best_value = value_f best_grad_norm = grad_norm stalled = 0 else: stalled += 1 if stalled >= patience: break grad = jnp.nan_to_num(grad, nan=0.0, posinf=0.0, neginf=0.0) m = beta1 * m + (1.0 - beta1) * grad v = beta2 * v + (1.0 - beta2) * (grad**2) m_hat = m / (1.0 - beta1**iteration) v_hat = v / (1.0 - beta2**iteration) x = jnp.clip(x - learning_rate * m_hat / (jnp.sqrt(v_hat) + eps), lower, upper) state = 
{"iteration": iteration, "x": x} success = bool(np.isfinite(best_value)) and stalled < patience message = ( f"jax_adam converged: chi2 improvement < {tolerance} for {patience} steps" if success else f"jax_adam stopped after {state['iteration']} iterations" ) return best_x, state, best_value, best_grad_norm, success, message def _fit_bounded_bfgs( objective, value_and_grad, x0: jnp.ndarray, lower: jnp.ndarray, upper: jnp.ndarray, maxiter: int, tolerance: float, names: list[str], trace: list[dict[str, float]], ): y0 = _bounded_to_unconstrained(x0, lower, upper) def unconstrained_objective(y: jnp.ndarray) -> jnp.ndarray: return objective(_unconstrained_to_bounded(y, lower, upper)) initial_value, initial_grad = value_and_grad(x0) trace.append( _trace_entry( names, x0, float(initial_value), float(jnp.linalg.norm(initial_grad)), 0 ) ) opt = jax_minimize( jax.jit(unconstrained_objective), y0, method="BFGS", tol=tolerance, options={"maxiter": maxiter}, ) best_x = _unconstrained_to_bounded(opt.x, lower, upper) best_value, best_grad = value_and_grad(best_x) best_value_f = float(best_value) best_grad_norm = float(jnp.linalg.norm(best_grad)) trace.append( _trace_entry( names, best_x, best_value_f, best_grad_norm, int(getattr(opt, "nit", maxiter)), ) ) success = bool(getattr(opt, "success", False)) and np.isfinite(best_value_f) message = ( f"jax_bfgs status={int(getattr(opt, 'status', -1))}, " f"nit={int(getattr(opt, 'nit', -1))}" ) state = {"iteration": int(getattr(opt, "nit", -1)), "x": best_x} return best_x, state, best_value_f, best_grad_norm, success, message def _bounded_to_unconstrained( x: jnp.ndarray, lower: jnp.ndarray, upper: jnp.ndarray ) -> jnp.ndarray: eps = 1.0e-6 scaled = jnp.clip((x - lower) / (upper - lower), eps, 1.0 - eps) return jnp.log(scaled) - jnp.log1p(-scaled) def _unconstrained_to_bounded( y: jnp.ndarray, lower: jnp.ndarray, upper: jnp.ndarray ) -> jnp.ndarray: return lower + (upper - lower) * jax.nn.sigmoid(y) def _trace_entry( names: list[str], x: 
jnp.ndarray, chi2: float, grad_norm: float, iteration: int ) -> dict[str, float]: entry = { name: float(value) for name, value in zip(names, np.asarray(x), strict=True) } entry["iteration"] = float(iteration) entry["chi2"] = float(chi2) entry["gradient_norm"] = float(grad_norm) return entry def _unpack_numpy( x: jnp.ndarray, names: list[str], base_params: dict[str, float] ) -> dict[str, float]: params = dict(base_params) params.update( {name: float(value) for name, value in zip(names, np.asarray(x), strict=True)} ) return params _OPTIMIZER_CACHE = {}
def fit_galaxy_batch_adam(
    context: DspsContext,
    base_params_rows: list[dict[str, float]],
    observed_mag: np.ndarray,
    sigma_mag: np.ndarray,
    fit_config: dict[str, Any],
    truth_theta: np.ndarray | None = None,
    observed_flux: np.ndarray | None = None,
    flux_error: np.ndarray | None = None,
    initial_theta: np.ndarray | None = None,
) -> BatchFitResult:
    """Fit many independent galaxies in one JAX-vmapped Adam run.

    The compiled optimizer is cached in ``_OPTIMIZER_CACHE`` under a key made of
    everything that is fixed when it is built (context identity, parameter
    layout, iteration count, learning rate, prior weight, likelihood space and
    band calibration offsets), so repeated calls with the same settings reuse it.

    Args:
        context: Model context forwarded to the optimizer builder.
        base_params_rows: One parameter mapping per galaxy.
        observed_mag: Observed magnitudes, one row per galaxy.
        sigma_mag: Magnitude errors, same layout as `observed_mag`.
        fit_config: The ``fit`` configuration mapping.
        truth_theta: Optional reference values for the free parameters; when
            given, the trace includes truth metrics.
        observed_flux: Optional observed fluxes used in flux likelihood space.
        flux_error: Optional flux errors used in flux likelihood space.
        initial_theta: Optional starting values for the free parameters.

    Returns:
        A ``BatchFitResult`` whose ``success`` flags are ``isfinite(chi2)``.
    """
    setup = _prepare_batch_fit(base_params_rows, fit_config, initial_theta=initial_theta)
    truth_theta_arr, has_truth = _prepare_truth_theta(setup["theta0"], truth_theta)
    theta0 = jnp.asarray(setup["theta0"], dtype=_FIT_DTYPE)
    base_matrix = jnp.asarray(setup["base_matrix"], dtype=_FIT_DTYPE)
    observed_values, sigma_values, valid_mask = _fit_arrays_for_likelihood(
        observed_mag,
        sigma_mag,
        fit_config,
        observed_flux=observed_flux,
        flux_error=flux_error,
    )
    observed = jnp.asarray(observed_values, dtype=_FIT_DTYPE)
    sigma = jnp.asarray(sigma_values, dtype=_FIT_DTYPE)
    mask = jnp.asarray(valid_mask)
    # The optimizer also receives the raw magnitudes with their own finiteness mask.
    warm_observed_mag = jnp.asarray(observed_mag, dtype=_FIT_DTYPE)
    warm_mask = jnp.isfinite(warm_observed_mag) & jnp.isfinite(
        jnp.asarray(sigma_mag, dtype=_FIT_DTYPE)
    )
    lower = jnp.asarray(setup["lower"], dtype=_FIT_DTYPE)
    upper = jnp.asarray(setup["upper"], dtype=_FIT_DTYPE)
    maxiter = int(fit_config.get("maxiter", 80))
    learning_rate = float(fit_config.get("learning_rate", 0.03))
    prior_weight = float(fit_config.get("prior_weight", 1.0))
    likelihood_space = _likelihood_space(fit_config)
    # An explicit null in the config means "no offsets", matching how
    # `_apply_band_calibration` reads the same setting.
    raw_offsets = fit_config.get("band_calibration_offsets_mag")
    band_offsets_mag = (
        () if raw_offsets is None else tuple(float(value) for value in raw_offsets)
    )

    cache_key = (
        "independent",
        id(context),
        tuple(setup["parameter_names"]),
        tuple(setup["free_names"]),
        maxiter,
        learning_rate,
        prior_weight,
        likelihood_space,
        band_offsets_mag,
    )
    if cache_key not in _OPTIMIZER_CACHE:
        _OPTIMIZER_CACHE[cache_key] = _build_independent_adam_optimizer(
            context=context,
            parameter_names=setup["parameter_names"],
            free_indices=setup["free_indices"],
            maxiter=maxiter,
            learning_rate=learning_rate,
            prior_weight=prior_weight,
            likelihood_space=likelihood_space,
            band_offsets_mag=band_offsets_mag,
        )
    optimize = _OPTIMIZER_CACHE[cache_key]

    best_theta, chi2, grad_norm, model_mags, trace_arrays = optimize(
        theta0,
        base_matrix,
        observed,
        sigma,
        mask,
        warm_observed_mag,
        warm_mask,
        lower,
        upper,
        jnp.asarray(setup["prior_gaussian_mask"]),
        jnp.asarray(setup["prior_gaussian_loc"], dtype=_FIT_DTYPE),
        jnp.asarray(setup["prior_gaussian_scale"], dtype=_FIT_DTYPE),
        jnp.asarray(setup["prior_beta_mask"]),
        jnp.asarray(setup["prior_beta_alpha"], dtype=_FIT_DTYPE),
        jnp.asarray(setup["prior_beta_beta"], dtype=_FIT_DTYPE),
        jnp.asarray(truth_theta_arr, dtype=_FIT_DTYPE),
    )
    best_matrix = _apply_free_values(
        base_matrix, best_theta, jnp.asarray(setup["free_indices"])
    )
    trace = _batch_trace_from_arrays(
        trace_arrays, setup["free_names"], include_truth_metrics=has_truth
    )
    device = _jax_device()
    return BatchFitResult(
        success=np.isfinite(np.asarray(chi2)),
        message=f"jax_adam_vmap maxiter={maxiter}, device={device}",
        parameter_names=setup["parameter_names"],
        free_parameter_names=setup["free_names"],
        best_parameter_matrix=np.asarray(best_matrix),
        chi2=np.asarray(chi2),
        gradient_norm=np.asarray(grad_norm),
        model_mags=np.asarray(model_mags),
        trace=trace,
        device=device,
    )
def fit_population_batch_adam(
    context: DspsContext,
    base_params_rows: list[dict[str, float]],
    observed_mag: np.ndarray,
    sigma_mag: np.ndarray,
    fit_config: dict[str, Any],
    initial_theta: np.ndarray | None = None,
    truth_theta: np.ndarray | None = None,
    observed_flux: np.ndarray | None = None,
    flux_error: np.ndarray | None = None,
) -> PopulationFitResult:
    """Joint MAP fit with a Gaussian population prior over free parameters.

    The compiled optimizer is cached in ``_OPTIMIZER_CACHE``. The cache key
    covers every value handed to the builder, including the relation constants
    (pivot, initial intercept/slope/sigma and slope scale), so a change in any
    of them builds a fresh optimizer instead of reusing a stale one.

    Args:
        context: Model context forwarded to the optimizer builder.
        base_params_rows: One parameter mapping per galaxy.
        observed_mag: Observed magnitudes, one row per galaxy.
        sigma_mag: Magnitude errors, same layout as `observed_mag`.
        fit_config: The ``fit`` configuration mapping; population settings are
            read from its ``population`` entry.
        initial_theta: Optional starting values for the free parameters.
        truth_theta: Optional reference values for the free parameters; when
            given, the trace includes truth metrics.
        observed_flux: Optional observed fluxes used in flux likelihood space.
        flux_error: Optional flux errors used in flux likelihood space.

    Returns:
        A ``PopulationFitResult`` with the per-galaxy batch result, the fitted
        population mean and scatter per free parameter, the relation rows and
        the final loss.
    """
    setup = _prepare_batch_fit(
        base_params_rows, fit_config, initial_theta=initial_theta
    )
    truth_theta_arr, has_truth = _prepare_truth_theta(setup["theta0"], truth_theta)
    theta0 = jnp.asarray(setup["theta0"])
    base_matrix = jnp.asarray(setup["base_matrix"])
    observed_values, sigma_values, valid_mask = _fit_arrays_for_likelihood(
        observed_mag,
        sigma_mag,
        fit_config,
        observed_flux=observed_flux,
        flux_error=flux_error,
    )
    observed = jnp.asarray(observed_values)
    sigma = jnp.asarray(sigma_values)
    mask = jnp.asarray(valid_mask)
    # The optimizer also receives the raw magnitudes with their own finiteness mask.
    warm_observed_mag = jnp.asarray(observed_mag)
    warm_mask = jnp.isfinite(warm_observed_mag) & jnp.isfinite(jnp.asarray(sigma_mag))
    lower = jnp.asarray(setup["lower"])
    upper = jnp.asarray(setup["upper"])
    maxiter = int(fit_config.get("maxiter", 80))
    learning_rate = float(fit_config.get("learning_rate", 0.03))
    pop = fit_config.get("population", {})
    sigma_floor = float(pop.get("sigma_floor", 0.03))
    prior_weight = float(pop.get("prior_weight", 1.0))
    hyper_mu_scale = float(pop.get("hyper_mu_scale", 5.0))
    physical_prior_weight = float(fit_config.get("prior_weight", 1.0))
    likelihood_space = _likelihood_space(fit_config)
    # An explicit null in the config means "no offsets", matching how
    # `_apply_band_calibration` reads the same setting.
    raw_offsets = fit_config.get("band_calibration_offsets_mag")
    band_offsets_mag = (
        () if raw_offsets is None else tuple(float(value) for value in raw_offsets)
    )

    def constant_token(key: str) -> tuple[tuple[int, ...], str, bytes]:
        """Hashable, NaN-safe fingerprint of a build-time array in `setup`."""
        values = np.asarray(setup[key])
        return values.shape, values.dtype.str, values.tobytes()

    # These arrays reach the optimizer only through the builder, so they are
    # frozen into the built optimizer and must take part in the cache key.
    relation_constant_keys = (
        "population_relation_pivot",
        "population_relation_intercept0",
        "population_relation_slope0",
        "population_relation_sigma0",
        "population_relation_slope_scale",
    )

    cache_key = (
        "population",
        id(context),
        tuple(setup["parameter_names"]),
        tuple(setup["free_names"]),
        maxiter,
        learning_rate,
        sigma_floor,
        prior_weight,
        hyper_mu_scale,
        physical_prior_weight,
        likelihood_space,
        band_offsets_mag,
        tuple(setup["population_relation_default_mask"]),
        tuple(setup["population_relation_target_pos"]),
        tuple(setup["population_relation_predictor_free_pos"]),
        tuple(setup["population_relation_predictor_base_index"]),
        tuple(constant_token(key) for key in relation_constant_keys),
    )
    if cache_key not in _OPTIMIZER_CACHE:
        _OPTIMIZER_CACHE[cache_key] = _build_population_adam_optimizer(
            context=context,
            parameter_names=setup["parameter_names"],
            free_indices=setup["free_indices"],
            maxiter=maxiter,
            learning_rate=learning_rate,
            sigma_floor=sigma_floor,
            prior_weight=prior_weight,
            hyper_mu_scale=hyper_mu_scale,
            physical_prior_weight=physical_prior_weight,
            likelihood_space=likelihood_space,
            band_offsets_mag=band_offsets_mag,
            relation_default_mask=jnp.asarray(
                setup["population_relation_default_mask"]
            ),
            relation_target_pos=jnp.asarray(setup["population_relation_target_pos"]),
            relation_predictor_free_pos=jnp.asarray(
                setup["population_relation_predictor_free_pos"]
            ),
            relation_predictor_base_index=jnp.asarray(
                setup["population_relation_predictor_base_index"]
            ),
            relation_pivot=jnp.asarray(setup["population_relation_pivot"]),
            relation_intercept0=jnp.asarray(setup["population_relation_intercept0"]),
            relation_slope0=jnp.asarray(setup["population_relation_slope0"]),
            relation_sigma0=jnp.asarray(setup["population_relation_sigma0"]),
            relation_slope_scale=jnp.asarray(setup["population_relation_slope_scale"]),
        )
    optimize = _OPTIMIZER_CACHE[cache_key]

    (
        best_theta,
        mu,
        sigma_pop,
        relation_intercept,
        relation_slope,
        relation_sigma,
        loss,
        chi2,
        grad_norm,
        model_mags,
        trace_arrays,
    ) = optimize(
        theta0,
        base_matrix,
        observed,
        sigma,
        mask,
        warm_observed_mag,
        warm_mask,
        lower,
        upper,
        jnp.asarray(setup["prior_gaussian_mask"]),
        jnp.asarray(setup["prior_gaussian_loc"]),
        jnp.asarray(setup["prior_gaussian_scale"]),
        jnp.asarray(setup["prior_beta_mask"]),
        jnp.asarray(setup["prior_beta_alpha"]),
        jnp.asarray(setup["prior_beta_beta"]),
        jnp.asarray(truth_theta_arr),
    )
    best_matrix = _apply_free_values(
        base_matrix, best_theta, jnp.asarray(setup["free_indices"])
    )
    device = _jax_device()
    batch = BatchFitResult(
        success=np.isfinite(np.asarray(chi2)),
        message=f"jax_population_adam maxiter={maxiter}, device={device}",
        parameter_names=setup["parameter_names"],
        free_parameter_names=setup["free_names"],
        best_parameter_matrix=np.asarray(best_matrix),
        chi2=np.asarray(chi2),
        gradient_norm=np.asarray(grad_norm),
        model_mags=np.asarray(model_mags),
        trace=_batch_trace_from_arrays(
            trace_arrays, setup["free_names"], include_truth_metrics=has_truth
        ),
        device=device,
    )
    return PopulationFitResult(
        batch=batch,
        hyper_mu={
            name: float(value)
            for name, value in zip(setup["free_names"], np.asarray(mu), strict=True)
        },
        hyper_sigma={
            name: float(value)
            for name, value in zip(
                setup["free_names"], np.asarray(sigma_pop), strict=True
            )
        },
        hyper_relations=_population_relation_hyper_rows(
            setup,
            np.asarray(relation_intercept),
            np.asarray(relation_slope),
            np.asarray(relation_sigma),
        ),
        loss=float(loss),
    )
def _prepare_batch_fit(
    base_params_rows: list[dict[str, float]],
    fit_config: dict[str, Any],
    initial_theta: np.ndarray | None = None,
) -> dict[str, Any]:
    """Assemble the dense NumPy arrays needed to fit a batch of galaxies.

    Args:
        base_params_rows: One parameter dict per galaxy. Rows may carry
            different key sets; keys missing from a row become NaN.
        fit_config: Fit configuration. ``fit_config["free_parameters"]`` maps
            each free parameter name to a spec holding at least ``"bounds"``.
            Optional ``"priors"`` and ``"population"`` sections are forwarded
            to the prior / relation preparation helpers.
        initial_theta: Optional ``(n_rows, n_free)`` starting values. When
            omitted, starting values come from ``_initial_value`` per row.

    Returns:
        A dict with ``parameter_names``, ``free_names``, ``free_indices``
        (column of each free parameter in ``base_matrix``), ``base_matrix``,
        ``theta0`` (clipped to the bounds), ``lower``, ``upper`` and every key
        produced by ``_prepare_prior_arrays`` and
        ``_prepare_population_relation_arrays``.

    Raises:
        ValueError: If the batch is empty, a free parameter is absent from the
            base parameters, or ``initial_theta`` has the wrong shape.
    """
    if not base_params_rows:
        raise ValueError("Cannot fit an empty batch")
    free = fit_config["free_parameters"]
    free_names = list(free)
    parameter_names = _parameter_names_for_rows(base_params_rows)
    missing = [name for name in free_names if name not in parameter_names]
    if missing:
        raise ValueError(f"Free parameters missing from base params: {missing}")
    # Rows lacking an optional key get NaN so the matrix stays rectangular.
    base_matrix = np.asarray(
        [
            [float(row.get(name, np.nan)) for name in parameter_names]
            for row in base_params_rows
        ],
        dtype=float,
    )
    bounds = np.asarray(
        [tuple(float(x) for x in free[name]["bounds"]) for name in free_names],
        dtype=float,
    )
    if initial_theta is None:
        theta0 = np.asarray(
            [
                [_initial_value(free[name], name, row) for name in free_names]
                for row in base_params_rows
            ],
            dtype=float,
        )
    else:
        theta0 = np.asarray(initial_theta, dtype=float)
        if theta0.shape != (len(base_params_rows), len(free_names)):
            raise ValueError(
                "initial_theta shape must be "
                f"{(len(base_params_rows), len(free_names))}, got {theta0.shape}"
            )
    theta0 = np.clip(theta0, bounds[:, 0], bounds[:, 1])
    prior_arrays = _prepare_prior_arrays(
        fit_config.get("priors", {}),
        free_names,
        parameter_names,
        base_matrix,
        bounds,
    )
    relation_arrays = _prepare_population_relation_arrays(
        fit_config.get("population", {}),
        free_names,
        parameter_names,
        base_matrix,
        theta0,
    )
    return {
        "parameter_names": parameter_names,
        "free_names": free_names,
        "free_indices": np.asarray(
            [parameter_names.index(name) for name in free_names], dtype=np.int32
        ),
        "base_matrix": base_matrix,
        "theta0": theta0,
        "lower": bounds[:, 0],
        "upper": bounds[:, 1],
        **prior_arrays,
        **relation_arrays,
    }


def _parameter_names_for_rows(base_params_rows: list[dict[str, float]]) -> list[str]:
    """Build a stable schema when optional per-row prior/catalog keys are missing.

    The first row's key order is kept; keys that first appear in later rows
    are appended in encounter order, without duplicates.
    """
    names = list(base_params_rows[0])
    seen = set(names)
    for row in base_params_rows[1:]:
        for name in row:
            if name not in seen:
                names.append(name)
                seen.add(name)
    return names


def _prepare_prior_arrays(
    priors: dict[str, Any],
    free_names: list[str],
    parameter_names: list[str],
    base_matrix: np.ndarray,
    bounds: np.ndarray,
) -> dict[str, np.ndarray]:
    """Expand per-parameter prior specs into ``(n_rows, n_free)`` arrays.

    Supported ``type`` values (default ``"normal"``):

    * ``"uniform"`` - no penalty; the masks stay False.
    * ``"normal"`` / ``"truncated_normal"`` - handled identically here: a
      Gaussian with ``loc`` and ``scale``. Either may be the string
      ``"from_base"`` to read per-row values from ``base_matrix`` (the scale
      column defaults to ``"<name>_prior_sigma"`` unless ``scale_parameter``
      is given). Non-finite or non-positive per-row scales fall back to 1.0.
    * ``"scaled_beta"`` - Beta prior with ``alpha`` / ``beta`` shapes.

    Args:
        priors: Mapping from free parameter name to its prior spec.
        free_names: Free parameter names, defining the column order.
        parameter_names: Column names of ``base_matrix``.
        base_matrix: ``(n_rows, n_parameters)`` base parameter values.
        bounds: Accepted for signature symmetry; not read by this function.

    Returns:
        Dict with ``prior_gaussian_mask/loc/scale`` and
        ``prior_beta_mask/alpha/beta`` arrays.

    Raises:
        ValueError: For an unsupported prior type, or when
            ``scale="from_base"`` names a column that does not exist.
    """
    n_rows = base_matrix.shape[0]
    n_free = len(free_names)
    gaussian_mask = np.zeros((n_rows, n_free), dtype=bool)
    gaussian_loc = np.zeros((n_rows, n_free), dtype=float)
    gaussian_scale = np.ones((n_rows, n_free), dtype=float)
    beta_mask = np.zeros((n_rows, n_free), dtype=bool)
    beta_alpha = np.ones((n_rows, n_free), dtype=float)
    beta_beta = np.ones((n_rows, n_free), dtype=float)
    for index, name in enumerate(free_names):
        spec = priors.get(name)
        if not spec:
            continue
        prior_type = str(spec.get("type", "normal"))
        if prior_type == "uniform":
            continue
        if prior_type in {"normal", "truncated_normal"}:
            loc = spec.get("loc", 0.0)
            if loc == "from_base":
                loc_values = base_matrix[:, parameter_names.index(name)]
            else:
                loc_values = np.full(n_rows, float(loc), dtype=float)
            gaussian_mask[:, index] = True
            gaussian_loc[:, index] = loc_values
            scale = spec.get("scale", 1.0)
            if scale == "from_base":
                scale_name = str(spec.get("scale_parameter", f"{name}_prior_sigma"))
                if scale_name not in parameter_names:
                    raise ValueError(
                        f"fit.priors.{name}.scale='from_base' needs parameter "
                        f"{scale_name!r} in base parameters"
                    )
                scale_values = np.asarray(
                    base_matrix[:, parameter_names.index(scale_name)], dtype=float
                )
                # Rows with a missing/invalid catalog sigma use a unit scale.
                scale_values = np.where(
                    np.isfinite(scale_values) & (scale_values > 0),
                    scale_values,
                    1.0,
                )
                gaussian_scale[:, index] = np.maximum(scale_values, 1.0e-6)
            else:
                gaussian_scale[:, index] = max(float(scale), 1.0e-6)
        elif prior_type == "scaled_beta":
            beta_mask[:, index] = True
            beta_alpha[:, index] = max(float(spec.get("alpha", 1.0)), 1.0e-6)
            beta_beta[:, index] = max(float(spec.get("beta", 1.0)), 1.0e-6)
        else:
            raise ValueError(f"Unsupported fit prior type for {name}: {prior_type}")
    return {
        "prior_gaussian_mask": gaussian_mask,
        "prior_gaussian_loc": gaussian_loc,
        "prior_gaussian_scale": gaussian_scale,
        "prior_beta_mask": beta_mask,
        "prior_beta_alpha": beta_alpha,
        "prior_beta_beta": beta_beta,
    }


def _prepare_population_relation_arrays(
    population: dict[str, Any],
    free_names: list[str],
    parameter_names: list[str],
    base_matrix: np.ndarray,
    theta0: np.ndarray,
) -> dict[str, np.ndarray]:
    """Translate ``population["relations"]`` into arrays for the optimizer.

    Each enabled relation ties a free *target* parameter to a *predictor*
    parameter through ``intercept + slope * (predictor - pivot)`` with its own
    scatter. Targets that have a relation are removed from the default
    population prior by zeroing their entry in ``default_mask``.

    Args:
        population: The ``fit.population`` config section; anything that is
            not a dict (or whose ``"relations"`` is not a dict) yields no
            relations.
        free_names: Free parameter names (column order of ``theta0``).
        parameter_names: Column names of ``base_matrix``.
        base_matrix: ``(n_rows, n_parameters)`` base parameter values, used
            for the ``"median"`` pivot.
        theta0: ``(n_rows, n_free)`` starting values, used for the default
            intercept and scatter.

    Returns:
        Dict of ``population_relation_*`` arrays, one entry per relation,
        plus ``population_relation_default_mask`` of length ``n_free``.
        ``predictor_free_pos`` is ``-1`` when the predictor is not free.

    Raises:
        ValueError: If a target is not a free parameter or a predictor is not
            a known parameter.
    """
    relations = population.get("relations", {}) if isinstance(population, dict) else {}
    target_pos: list[int] = []
    predictor_free_pos: list[int] = []
    predictor_base_index: list[int] = []
    pivot: list[float] = []
    intercept0: list[float] = []
    slope0: list[float] = []
    sigma0: list[float] = []
    slope_scale: list[float] = []
    default_mask = np.ones(len(free_names), dtype=float)
    relation_names: list[tuple[str, str]] = []
    if not isinstance(relations, dict):
        relations = {}
    for target_name, spec in relations.items():
        if not spec or not bool(spec.get("enabled", True)):
            continue
        if target_name not in free_names:
            raise ValueError(
                f"fit.population.relations target {target_name!r} must be a free parameter"
            )
        predictor_name = str(spec.get("predictor", ""))
        if predictor_name not in parameter_names:
            raise ValueError(
                f"fit.population.relations.{target_name}.predictor={predictor_name!r} "
                "must exist in model parameters"
            )
        target_index = free_names.index(target_name)
        target_pos.append(target_index)
        default_mask[target_index] = 0.0
        predictor_base_index.append(parameter_names.index(predictor_name))
        predictor_free_pos.append(
            free_names.index(predictor_name) if predictor_name in free_names else -1
        )
        predictor_values = base_matrix[:, parameter_names.index(predictor_name)]
        raw_pivot = spec.get("pivot", "median")
        pivot.append(
            float(np.nanmedian(predictor_values))
            if raw_pivot == "median"
            else float(raw_pivot)
        )
        target_values = theta0[:, target_index]
        intercept0.append(
            float(spec.get("intercept_initial", np.nanmean(target_values)))
        )
        slope0.append(float(spec.get("slope_initial", 0.0)))
        sigma0.append(
            max(
                float(spec.get("sigma_initial", np.nanstd(target_values) + 0.2)),
                1.0e-3,
            )
        )
        slope_scale.append(max(float(spec.get("slope_scale", 2.0)), 1.0e-6))
        relation_names.append((str(target_name), predictor_name))
    return {
        "population_relation_default_mask": default_mask,
        "population_relation_target_pos": np.asarray(target_pos, dtype=np.int32),
        "population_relation_predictor_free_pos": np.asarray(
            predictor_free_pos, dtype=np.int32
        ),
        "population_relation_predictor_base_index": np.asarray(
            predictor_base_index, dtype=np.int32
        ),
        "population_relation_pivot": np.asarray(pivot, dtype=float),
        "population_relation_intercept0": np.asarray(intercept0, dtype=float),
        "population_relation_slope0": np.asarray(slope0, dtype=float),
        "population_relation_sigma0": np.asarray(sigma0, dtype=float),
        "population_relation_slope_scale": np.asarray(slope_scale, dtype=float),
        "population_relation_names": np.asarray(relation_names, dtype=object),
    }


def _population_relation_hyper_rows(
    setup: dict[str, Any],
    intercept: np.ndarray,
    slope: np.ndarray,
    sigma: np.ndarray,
) -> list[dict[str, Any]]:
    """Pair fitted relation hyper-parameters with their names and pivots.

    Returns one dict per relation with the keys ``target_parameter``,
    ``predictor_parameter``, ``pivot``, ``intercept``, ``slope`` and ``sigma``.
    """
    names = np.asarray(setup["population_relation_names"], dtype=object)
    pivots = np.asarray(setup["population_relation_pivot"], dtype=float)
    rows: list[dict[str, Any]] = []
    for index in range(len(intercept)):
        target, predictor = names[index]
        rows.append(
            {
                "target_parameter": str(target),
                "predictor_parameter": str(predictor),
                "pivot": float(pivots[index]),
                "intercept": float(intercept[index]),
                "slope": float(slope[index]),
                "sigma": float(sigma[index]),
            }
        )
    return rows


def _prepare_truth_theta(
    theta0: np.ndarray, truth_theta: np.ndarray | None
) -> tuple[np.ndarray, bool]:
    """Return a truth array shaped like ``theta0`` and whether it is usable.

    With no truth supplied, an all-NaN array and ``False`` are returned.
    Otherwise the flag is True when at least one truth value is finite.

    Raises:
        ValueError: If ``truth_theta`` does not have the shape of ``theta0``.
    """
    if truth_theta is None:
        return np.full_like(theta0, np.nan, dtype=float), False
    truth = np.asarray(truth_theta, dtype=float)
    if truth.shape != theta0.shape:
        raise ValueError(f"truth_theta shape must be {theta0.shape}, got {truth.shape}")
    return truth, bool(np.isfinite(truth).any())


def _build_independent_adam_optimizer(
    context: DspsContext,
    parameter_names: list[str],
    free_indices: np.ndarray,
    maxiter: int,
    learning_rate: float,
    prior_weight: float,
    likelihood_space: str,
    band_offsets_mag: tuple[float, ...],
):
    """Build a jitted Adam optimizer that fits every galaxy independently.

    The returned ``optimize`` callable runs ``maxiter`` Adam steps
    (beta1=0.9, beta2=0.999, eps=1e-8) in an unconstrained space obtained via
    ``_bounded_to_unconstrained`` and keeps, per galaxy, the bounded
    parameters with the lowest objective seen.

    ``optimize`` returns ``(best_theta, best_chi2, best_grad_norm,
    model_mags, metrics)`` where ``metrics`` has one row per iteration:
    mean chi2, median chi2, mean gradient norm, then the entries of
    ``_truth_metric_vector``.
    """
    free_indices_jax = jnp.asarray(free_indices)
    band_offsets_jax = (
        jnp.asarray(band_offsets_mag, dtype=_FIT_DTYPE) if band_offsets_mag else None
    )

    def single_chi2(theta, base, observed, sigma, mask):
        """Masked chi-square of one galaxy; non-finite results become 1e30."""
        params = _params_from_vectors(theta, base, parameter_names, free_indices_jax)
        model_mag = model_mags_jax(context, params)
        if band_offsets_jax is not None:
            model_mag = model_mag + band_offsets_jax
        model_values = (
            abmag_to_fnu_cgs_jax(model_mag) if likelihood_space == "flux" else model_mag
        )
        chi = jnp.where(mask, (observed - model_values) / sigma, 0.0)
        return jnp.nan_to_num(jnp.sum(chi**2), nan=1.0e30, posinf=1.0e30, neginf=1.0e30)

    def single_objective(
        theta,
        base,
        observed,
        sigma,
        mask,
        lower,
        upper,
        prior_gaussian_mask,
        prior_gaussian_loc,
        prior_gaussian_scale,
        prior_beta_mask,
        prior_beta_alpha,
        prior_beta_beta,
    ):
        """Chi-square plus the weighted physical prior for one galaxy."""
        chi2 = single_chi2(theta, base, observed, sigma, mask)
        prior = _physical_prior_penalty(
            theta,
            lower,
            upper,
            prior_gaussian_mask,
            prior_gaussian_loc,
            prior_gaussian_scale,
            prior_beta_mask,
            prior_beta_alpha,
            prior_beta_beta,
        )
        return jnp.nan_to_num(
            chi2 + prior_weight * prior,
            nan=1.0e30,
            posinf=1.0e30,
            neginf=1.0e30,
        )

    def single_mags(theta, base):
        """Model magnitudes (with band offsets applied) for one galaxy."""
        params = _params_from_vectors(theta, base, parameter_names, free_indices_jax)
        model_mag = model_mags_jax(context, params)
        if band_offsets_jax is not None:
            model_mag = model_mag + band_offsets_jax
        return model_mag

    batch_mags = jax.vmap(single_mags, in_axes=(0, 0))
    # Prefer the formed-mass amplitude for the warm start; fall back to SFR.
    amplitude_free_pos = _free_position(
        parameter_names, free_indices, "log10_formed_mass_msun"
    )
    if amplitude_free_pos is None:
        amplitude_free_pos = _free_position(parameter_names, free_indices, "log10_sfr")

    @jax.jit
    def optimize(
        theta0,
        base_matrix,
        observed,
        sigma,
        mask,
        warm_observed_mag,
        warm_mask,
        lower,
        upper,
        prior_gaussian_mask,
        prior_gaussian_loc,
        prior_gaussian_scale,
        prior_beta_mask,
        prior_beta_alpha,
        prior_beta_beta,
        truth_theta,
    ):
        theta0 = _warm_start_amplitude(
            theta0,
            base_matrix,
            warm_observed_mag,
            warm_mask,
            lower,
            upper,
            batch_mags,
            amplitude_free_pos,
        )
        y0 = _bounded_to_unconstrained(theta0, lower, upper)
        m0 = jnp.zeros_like(y0)
        v0 = jnp.zeros_like(y0)
        best_objective0 = jnp.full((theta0.shape[0],), jnp.inf)
        best_grad0 = jnp.full((theta0.shape[0],), jnp.inf)
        carry0 = (y0, m0, v0, theta0, best_objective0, best_grad0)

        def single_objective_y(
            y,
            base,
            observed_i,
            sigma_i,
            mask_i,
            prior_gaussian_mask_i,
            prior_gaussian_loc_i,
            prior_gaussian_scale_i,
            prior_beta_mask_i,
            prior_beta_alpha_i,
            prior_beta_beta_i,
        ):
            """Objective as a function of the unconstrained coordinates."""
            theta = _unconstrained_to_bounded(y, lower, upper)
            return single_objective(
                theta,
                base,
                observed_i,
                sigma_i,
                mask_i,
                lower,
                upper,
                prior_gaussian_mask_i,
                prior_gaussian_loc_i,
                prior_gaussian_scale_i,
                prior_beta_mask_i,
                prior_beta_alpha_i,
                prior_beta_beta_i,
            )

        batch_objective_y_grad = jax.vmap(
            jax.value_and_grad(single_objective_y, argnums=0),
            in_axes=(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0),
        )

        def step(carry, iteration):
            y, m, v, best_theta, best_objective, best_grad = carry
            theta = _unconstrained_to_bounded(y, lower, upper)
            objective, grad = batch_objective_y_grad(
                y,
                base_matrix,
                observed,
                sigma,
                mask,
                prior_gaussian_mask,
                prior_gaussian_loc,
                prior_gaussian_scale,
                prior_beta_mask,
                prior_beta_alpha,
                prior_beta_beta,
            )
            grad = jnp.nan_to_num(grad, nan=0.0, posinf=0.0, neginf=0.0)
            grad_norm = jnp.linalg.norm(grad, axis=1)
            # Track the best point per galaxy *before* applying the update.
            improved = objective < best_objective
            best_theta = jnp.where(improved[:, None], theta, best_theta)
            best_objective = jnp.where(improved, objective, best_objective)
            best_grad = jnp.where(improved, grad_norm, best_grad)
            # Adam moments with bias correction (t is the 1-based step).
            t = iteration + 1.0
            m = 0.9 * m + 0.1 * grad
            v = 0.999 * v + 0.001 * (grad**2)
            m_hat = m / (1.0 - 0.9**t)
            v_hat = v / (1.0 - 0.999**t)
            y = jnp.clip(
                y - learning_rate * m_hat / (jnp.sqrt(v_hat) + 1.0e-8), -30.0, 30.0
            )
            chi2 = jax.vmap(single_chi2, in_axes=(0, 0, 0, 0, 0))(
                theta, base_matrix, observed, sigma, mask
            )
            metrics = jnp.concatenate(
                [
                    jnp.asarray(
                        [jnp.nanmean(chi2), jnp.nanmedian(chi2), jnp.nanmean(grad_norm)]
                    ),
                    _truth_metric_vector(theta, truth_theta),
                ]
            )
            return (y, m, v, best_theta, best_objective, best_grad), metrics

        (_, _, _, best_theta, _, best_grad), metrics = jax.lax.scan(
            step,
            carry0,
            jnp.arange(maxiter, dtype=jnp.int32),
        )
        best_chi2 = jax.vmap(single_chi2, in_axes=(0, 0, 0, 0, 0))(
            best_theta, base_matrix, observed, sigma, mask
        )
        model_mags = batch_mags(best_theta, base_matrix)
        return best_theta, best_chi2, best_grad, model_mags, metrics

    return optimize


def _build_population_adam_optimizer(
    context: DspsContext,
    parameter_names: list[str],
    free_indices: np.ndarray,
    maxiter: int,
    learning_rate: float,
    sigma_floor: float,
    prior_weight: float,
    hyper_mu_scale: float,
    physical_prior_weight: float,
    likelihood_space: str,
    band_offsets_mag: tuple[float, ...],
    relation_default_mask: jnp.ndarray,
    relation_target_pos: jnp.ndarray,
    relation_predictor_free_pos: jnp.ndarray,
    relation_predictor_base_index: jnp.ndarray,
    relation_pivot: jnp.ndarray,
    relation_intercept0: jnp.ndarray,
    relation_slope0: jnp.ndarray,
    relation_sigma0: jnp.ndarray,
    relation_slope_scale: jnp.ndarray,
):
    """Build a jitted Adam optimizer for a joint (hierarchical) population fit.

    All galaxies are fitted together with shared hyper-parameters: a per
    parameter population mean ``mu`` and scatter (softplus of ``raw_sigma``
    plus ``sigma_floor``), and, for each configured relation, an intercept,
    slope and scatter. The state with the lowest total loss is kept.

    The returned ``optimize`` callable returns ``(best_theta, best_mu,
    sigma_pop, best_relation_intercept, best_relation_slope, relation_sigma,
    best_loss, chi2, gradient_norm, model_mags, metrics)``. Each ``metrics``
    row holds the loss, the mean chi2, the mean norm of the direct
    ``theta`` gradient, then the entries of ``_truth_metric_vector``.
    """
    free_indices_jax = jnp.asarray(free_indices)
    band_offsets_jax = (
        jnp.asarray(band_offsets_mag, dtype=_FIT_DTYPE) if band_offsets_mag else None
    )

    def single_chi2(theta, base, observed, sigma, mask):
        """Masked chi-square of one galaxy; non-finite results become 1e30."""
        params = _params_from_vectors(theta, base, parameter_names, free_indices_jax)
        model_mag = model_mags_jax(context, params)
        if band_offsets_jax is not None:
            model_mag = model_mag + band_offsets_jax
        model_values = (
            abmag_to_fnu_cgs_jax(model_mag) if likelihood_space == "flux" else model_mag
        )
        chi = jnp.where(mask, (observed - model_values) / sigma, 0.0)
        return jnp.nan_to_num(jnp.sum(chi**2), nan=1.0e30, posinf=1.0e30, neginf=1.0e30)

    # Despite the name this is the batched chi-square value, not a gradient.
    batch_chi2_grad = jax.vmap(
        single_chi2,
        in_axes=(0, 0, 0, 0, 0),
    )

    def loss(
        theta,
        mu,
        raw_sigma,
        relation_intercept,
        relation_slope,
        relation_raw_sigma,
        base_matrix,
        observed,
        sigma,
        mask,
        lower,
        upper,
        prior_gaussian_mask,
        prior_gaussian_loc,
        prior_gaussian_scale,
        prior_beta_mask,
        prior_beta_alpha,
        prior_beta_beta,
    ):
        """Total scalar loss of the population fit.

        The signature holds exactly the 18 positional arguments that both
        call sites in ``step`` pass. The warm-start arrays are consumed only
        by ``_warm_start_amplitude`` and are never read here, so they must not
        appear in this signature; including them shifts ``lower``/``upper``
        and the prior arrays into the wrong slots and makes the call fail
        with a missing-argument ``TypeError``.
        """
        chi2 = batch_chi2_grad(theta, base_matrix, observed, sigma, mask)
        sigma_pop = jax.nn.softplus(raw_sigma) + sigma_floor
        # Gaussian population prior on parameters without a relation.
        default_prior = 0.5 * jnp.sum(
            relation_default_mask
            * (((theta - mu) / sigma_pop) ** 2 + 2.0 * jnp.log(sigma_pop)),
            axis=1,
        )
        relation_prior = _population_relation_prior_penalty(
            theta,
            base_matrix,
            relation_intercept,
            relation_slope,
            relation_raw_sigma,
            sigma_floor,
            relation_target_pos,
            relation_predictor_free_pos,
            relation_predictor_base_index,
            relation_pivot,
        )
        physical_prior = _physical_prior_penalty(
            theta,
            lower,
            upper,
            prior_gaussian_mask,
            prior_gaussian_loc,
            prior_gaussian_scale,
            prior_beta_mask,
            prior_beta_alpha,
            prior_beta_beta,
        )
        # Weak quadratic regularisation of the hyper-parameters themselves.
        hyper = 0.5 * jnp.sum(relation_default_mask * (mu / hyper_mu_scale) ** 2)
        relation_hyper = 0.5 * jnp.sum(
            (relation_intercept / hyper_mu_scale) ** 2
            + (relation_slope / relation_slope_scale) ** 2
        )
        return (
            0.5 * jnp.sum(chi2)
            + prior_weight * (jnp.sum(default_prior) + jnp.sum(relation_prior))
            + physical_prior_weight * jnp.sum(physical_prior)
            + hyper
            + relation_hyper
        )

    value_and_grad = jax.value_and_grad(loss, argnums=(0, 1, 2, 3, 4, 5))

    def single_mags(theta, base):
        """Model magnitudes (with band offsets applied) for one galaxy."""
        params = _params_from_vectors(theta, base, parameter_names, free_indices_jax)
        model_mag = model_mags_jax(context, params)
        if band_offsets_jax is not None:
            model_mag = model_mag + band_offsets_jax
        return model_mag

    batch_mags = jax.vmap(single_mags, in_axes=(0, 0))
    # Prefer the formed-mass amplitude for the warm start; fall back to SFR.
    amplitude_free_pos = _free_position(
        parameter_names, free_indices, "log10_formed_mass_msun"
    )
    if amplitude_free_pos is None:
        amplitude_free_pos = _free_position(parameter_names, free_indices, "log10_sfr")

    @jax.jit
    def optimize(
        theta0,
        base_matrix,
        observed,
        sigma,
        mask,
        warm_observed_mag,
        warm_mask,
        lower,
        upper,
        prior_gaussian_mask,
        prior_gaussian_loc,
        prior_gaussian_scale,
        prior_beta_mask,
        prior_beta_alpha,
        prior_beta_beta,
        truth_theta,
    ):
        theta0 = _warm_start_amplitude(
            theta0,
            base_matrix,
            warm_observed_mag,
            warm_mask,
            lower,
            upper,
            batch_mags,
            amplitude_free_pos,
        )
        y0 = _bounded_to_unconstrained(theta0, lower, upper)
        mu0 = jnp.nanmean(theta0, axis=0)
        raw_sigma0 = _softplus_inverse(jnp.nanstd(theta0, axis=0) + 0.2)
        relation_raw_sigma0 = _softplus_inverse(relation_sigma0)
        m_y = jnp.zeros_like(y0)
        v_y = jnp.zeros_like(y0)
        m_mu = jnp.zeros_like(mu0)
        v_mu = jnp.zeros_like(mu0)
        m_sigma = jnp.zeros_like(raw_sigma0)
        v_sigma = jnp.zeros_like(raw_sigma0)
        m_relation_intercept = jnp.zeros_like(relation_intercept0)
        v_relation_intercept = jnp.zeros_like(relation_intercept0)
        m_relation_slope = jnp.zeros_like(relation_slope0)
        v_relation_slope = jnp.zeros_like(relation_slope0)
        m_relation_sigma = jnp.zeros_like(relation_raw_sigma0)
        v_relation_sigma = jnp.zeros_like(relation_raw_sigma0)
        best_theta = theta0
        best_mu = mu0
        best_raw_sigma = raw_sigma0
        best_relation_intercept = relation_intercept0
        best_relation_slope = relation_slope0
        best_relation_raw_sigma = relation_raw_sigma0
        best_loss = jnp.inf
        carry0 = (
            y0,
            mu0,
            raw_sigma0,
            relation_intercept0,
            relation_slope0,
            relation_raw_sigma0,
            m_y,
            v_y,
            m_mu,
            v_mu,
            m_sigma,
            v_sigma,
            m_relation_intercept,
            v_relation_intercept,
            m_relation_slope,
            v_relation_slope,
            m_relation_sigma,
            v_relation_sigma,
            best_theta,
            best_mu,
            best_raw_sigma,
            best_relation_intercept,
            best_relation_slope,
            best_relation_raw_sigma,
            best_loss,
        )

        def adam_update(value, grad, m, v, iteration):
            """One bias-corrected Adam step; non-finite gradients count as 0."""
            grad = jnp.nan_to_num(grad, nan=0.0, posinf=0.0, neginf=0.0)
            t = iteration + 1.0
            m = 0.9 * m + 0.1 * grad
            v = 0.999 * v + 0.001 * (grad**2)
            m_hat = m / (1.0 - 0.9**t)
            v_hat = v / (1.0 - 0.999**t)
            return value - learning_rate * m_hat / (jnp.sqrt(v_hat) + 1.0e-8), m, v

        def step(carry, iteration):
            (
                y,
                mu,
                raw_sigma,
                relation_intercept,
                relation_slope,
                relation_raw_sigma,
                m_y,
                v_y,
                m_mu,
                v_mu,
                m_sigma,
                v_sigma,
                m_relation_intercept,
                v_relation_intercept,
                m_relation_slope,
                v_relation_slope,
                m_relation_sigma,
                v_relation_sigma,
                best_theta,
                best_mu,
                best_raw_sigma,
                best_relation_intercept,
                best_relation_slope,
                best_relation_raw_sigma,
                best_loss,
            ) = carry
            theta = _unconstrained_to_bounded(y, lower, upper)
            # Gradients w.r.t. the bounded theta and all hyper-parameters.
            value, grads = value_and_grad(
                theta,
                mu,
                raw_sigma,
                relation_intercept,
                relation_slope,
                relation_raw_sigma,
                base_matrix,
                observed,
                sigma,
                mask,
                lower,
                upper,
                prior_gaussian_mask,
                prior_gaussian_loc,
                prior_gaussian_scale,
                prior_beta_mask,
                prior_beta_alpha,
                prior_beta_beta,
            )
            (
                grad_theta_direct,
                grad_mu,
                grad_sigma,
                grad_relation_intercept,
                grad_relation_slope,
                grad_relation_sigma,
            ) = grads
            # Gradient w.r.t. the unconstrained coordinates actually updated.
            _, grad_y = jax.value_and_grad(
                lambda yy: loss(
                    _unconstrained_to_bounded(yy, lower, upper),
                    mu,
                    raw_sigma,
                    relation_intercept,
                    relation_slope,
                    relation_raw_sigma,
                    base_matrix,
                    observed,
                    sigma,
                    mask,
                    lower,
                    upper,
                    prior_gaussian_mask,
                    prior_gaussian_loc,
                    prior_gaussian_scale,
                    prior_beta_mask,
                    prior_beta_alpha,
                    prior_beta_beta,
                )
            )(y)
            # The loss is a single scalar, so the whole state is kept or not.
            improved = value < best_loss
            best_theta = jnp.where(improved, theta, best_theta)
            best_mu = jnp.where(improved, mu, best_mu)
            best_raw_sigma = jnp.where(improved, raw_sigma, best_raw_sigma)
            best_relation_intercept = jnp.where(
                improved, relation_intercept, best_relation_intercept
            )
            best_relation_slope = jnp.where(
                improved, relation_slope, best_relation_slope
            )
            best_relation_raw_sigma = jnp.where(
                improved, relation_raw_sigma, best_relation_raw_sigma
            )
            best_loss = jnp.where(improved, value, best_loss)
            y, m_y, v_y = adam_update(y, grad_y, m_y, v_y, iteration)
            mu, m_mu, v_mu = adam_update(mu, grad_mu, m_mu, v_mu, iteration)
            raw_sigma, m_sigma, v_sigma = adam_update(
                raw_sigma, grad_sigma, m_sigma, v_sigma, iteration
            )
            relation_intercept, m_relation_intercept, v_relation_intercept = (
                adam_update(
                    relation_intercept,
                    grad_relation_intercept,
                    m_relation_intercept,
                    v_relation_intercept,
                    iteration,
                )
            )
            relation_slope, m_relation_slope, v_relation_slope = adam_update(
                relation_slope,
                grad_relation_slope,
                m_relation_slope,
                v_relation_slope,
                iteration,
            )
            relation_raw_sigma, m_relation_sigma, v_relation_sigma = adam_update(
                relation_raw_sigma,
                grad_relation_sigma,
                m_relation_sigma,
                v_relation_sigma,
                iteration,
            )
            y = jnp.clip(y, -30.0, 30.0)
            raw_sigma = jnp.clip(raw_sigma, -8.0, 4.0)
            relation_raw_sigma = jnp.clip(relation_raw_sigma, -8.0, 4.0)
            metrics = jnp.concatenate(
                [
                    jnp.asarray(
                        [
                            value,
                            jnp.nanmean(
                                batch_chi2_grad(
                                    theta, base_matrix, observed, sigma, mask
                                )
                            ),
                            jnp.nanmean(jnp.linalg.norm(grad_theta_direct, axis=1)),
                        ]
                    ),
                    _truth_metric_vector(theta, truth_theta),
                ]
            )
            return (
                y,
                mu,
                raw_sigma,
                relation_intercept,
                relation_slope,
                relation_raw_sigma,
                m_y,
                v_y,
                m_mu,
                v_mu,
                m_sigma,
                v_sigma,
                m_relation_intercept,
                v_relation_intercept,
                m_relation_slope,
                v_relation_slope,
                m_relation_sigma,
                v_relation_sigma,
                best_theta,
                best_mu,
                best_raw_sigma,
                best_relation_intercept,
                best_relation_slope,
                best_relation_raw_sigma,
                best_loss,
            ), metrics

        (
            (
                *_,
                best_theta,
                best_mu,
                best_raw_sigma,
                best_relation_intercept,
                best_relation_slope,
                best_relation_raw_sigma,
                best_loss,
            ),
            metrics,
        ) = jax.lax.scan(
            step,
            carry0,
            jnp.arange(maxiter, dtype=jnp.int32),
        )
        sigma_pop = jax.nn.softplus(best_raw_sigma) + sigma_floor
        relation_sigma = jax.nn.softplus(best_relation_raw_sigma) + sigma_floor
        chi2, grad = jax.vmap(
            jax.value_and_grad(single_chi2, argnums=0),
            in_axes=(0, 0, 0, 0, 0),
        )(best_theta, base_matrix, observed, sigma, mask)
        model_mags = batch_mags(best_theta, base_matrix)
        return (
            best_theta,
            best_mu,
            sigma_pop,
            best_relation_intercept,
            best_relation_slope,
            relation_sigma,
            best_loss,
            chi2,
            jnp.linalg.norm(grad, axis=1),
            model_mags,
            metrics,
        )

    return optimize


def _params_from_vectors(
    theta, base, parameter_names: list[str], free_indices: jnp.ndarray
) -> dict[str, Any]:
    """Merge free values into ``base`` and return them keyed by name."""
    values = _apply_free_values(base, theta, free_indices)
    return {name: values[index] for index, name in enumerate(parameter_names)}


def _physical_prior_penalty(
    theta,
    lower,
    upper,
    gaussian_mask,
    gaussian_loc,
    gaussian_scale,
    beta_mask,
    beta_alpha,
    beta_beta,
):
    """Negative log prior (up to constants), summed over the last axis.

    Gaussian terms apply where ``gaussian_mask`` is set. Beta terms apply
    where ``beta_mask`` is set and act on ``theta`` rescaled to the unit
    interval by the bounds, clipped away from 0 and 1 to keep logs finite.
    """
    gaussian_scale = jnp.maximum(gaussian_scale, 1.0e-6)
    gaussian = 0.5 * ((theta - gaussian_loc) / gaussian_scale) ** 2 + jnp.log(
        gaussian_scale
    )
    gaussian = jnp.where(gaussian_mask, gaussian, 0.0)
    scaled = jnp.clip((theta - lower) / (upper - lower), 1.0e-6, 1.0 - 1.0e-6)
    beta = -(
        (beta_alpha - 1.0) * jnp.log(scaled) + (beta_beta - 1.0) * jnp.log1p(-scaled)
    )
    beta = jnp.where(beta_mask, beta, 0.0)
    return jnp.sum(gaussian + beta, axis=-1)


def _population_relation_prior_penalty(
    theta,
    base_matrix,
    relation_intercept,
    relation_slope,
    relation_raw_sigma,
    sigma_floor,
    target_pos,
    predictor_free_pos,
    predictor_base_index,
    pivot,
):
    """Per-galaxy Gaussian penalty of each target around its linear relation.

    Returns zeros (one per row of ``theta``) when no relation is configured.
    The predictor is read from ``theta`` when it is a free parameter
    (``predictor_free_pos >= 0``) and from ``base_matrix`` otherwise.
    """
    if relation_intercept.shape[0] == 0:
        return jnp.zeros(theta.shape[0])
    target_values = theta[:, target_pos]
    # -1 marks "not free"; clamp to 0 so the gather is valid, then select.
    free_predictor = jnp.take(theta, jnp.maximum(predictor_free_pos, 0), axis=1)
    base_predictor = jnp.take(base_matrix, jnp.maximum(predictor_base_index, 0), axis=1)
    predictor = jnp.where(
        predictor_free_pos[None, :] >= 0, free_predictor, base_predictor
    )
    loc = relation_intercept[None, :] + relation_slope[None, :] * (
        predictor - pivot[None, :]
    )
    sigma_rel = jax.nn.softplus(relation_raw_sigma) + sigma_floor
    penalty = 0.5 * ((target_values - loc) / sigma_rel[None, :]) ** 2 + jnp.log(
        sigma_rel[None, :]
    )
    return jnp.sum(penalty, axis=1)


def _apply_free_values(base, theta, free_indices):
    """Return ``base`` with ``theta`` written at ``free_indices`` (last axis)."""
    return base.at[..., free_indices].set(theta)


def _free_position(
    parameter_names: list[str], free_indices: np.ndarray, name: str
) -> int | None:
    """Return the position of ``name`` among the free parameters, or None.

    None is returned when ``name`` is not a parameter or is not free.
    """
    if name not in parameter_names:
        return None
    parameter_index = parameter_names.index(name)
    positions = np.flatnonzero(free_indices == parameter_index)
    return int(positions[0]) if len(positions) else None


def _warm_start_amplitude(
    theta0, base_matrix, observed, mask, lower, upper, batch_mags, free_pos: int | None
):
    """Shift the amplitude parameter so model and observed magnitudes agree.

    The per-galaxy median of ``model_mag - observed`` over unmasked bands,
    divided by 2.5, is added to the free parameter at ``free_pos``; the
    result is clipped to the bounds. ``theta0`` is returned unchanged when
    ``free_pos`` is None.
    """
    if free_pos is None:
        return theta0
    model_mag = batch_mags(theta0, base_matrix)
    delta_mag = jnp.where(mask, model_mag - observed, jnp.nan)
    delta_log10_sfr = (
        jnp.nanmedian(delta_mag, axis=1) / jnp.asarray(2.5, dtype=theta0.dtype)
    ).astype(theta0.dtype)
    # Galaxies with no usable band produce NaN here; leave those untouched.
    warmed = theta0.at[:, free_pos].set(
        theta0[:, free_pos] + jnp.nan_to_num(delta_log10_sfr)
    )
    return jnp.clip(warmed, lower, upper)


def _softplus_inverse(value):
    """Inverse of softplus, ``log(expm1(value))``, with ``value`` floored at 1e-6."""
    value = jnp.maximum(value, 1.0e-6)
    return jnp.log(jnp.expm1(value))


def _truth_metric_vector(theta, truth_theta):
    """Return ``[mse, rmse, mae, per-parameter mse...]`` against the truth.

    Entries whose truth value is not finite are ignored via NaN-aware means.
    """
    diff = theta - truth_theta
    diff = jnp.where(jnp.isfinite(truth_theta), diff, jnp.nan)
    sq = diff**2
    mse = jnp.nanmean(sq)
    mae = jnp.nanmean(jnp.abs(diff))
    per_parameter_mse = jnp.nanmean(sq, axis=0)
    return jnp.concatenate([jnp.asarray([mse, jnp.sqrt(mse), mae]), per_parameter_mse])


def _batch_trace_from_arrays(
    metrics,
    free_names: list[str] | None = None,
    include_truth_metrics: bool = False,
) -> list[dict[str, float]]:
    """Convert the per-iteration metrics matrix into a list of trace dicts.

    Columns 0-2 are always reported (under the keys ``mean_chi2_or_loss``,
    ``median_chi2`` and ``mean_gradient_norm``). Columns 3-5 and the
    per-parameter columns from index 6 onward are reported only when
    ``include_truth_metrics`` is set and the row is long enough.
    """
    arr = np.asarray(metrics)
    rows = []
    for index, row in enumerate(arr):
        entry = {
            "iteration": float(index + 1),
            "mean_chi2_or_loss": float(row[0]),
            "median_chi2": float(row[1]),
            "mean_gradient_norm": float(row[2]),
        }
        if include_truth_metrics and len(row) >= 6:
            entry.update(
                {
                    "truth_mse": float(row[3]),
                    "truth_rmse": float(row[4]),
                    "truth_mae": float(row[5]),
                }
            )
            for offset, name in enumerate(free_names or []):
                metric_index = 6 + offset
                if metric_index < len(row):
                    entry[f"truth_mse_{name}"] = float(row[metric_index])
        rows.append(entry)
    return rows


def _jax_device() -> str:
    """Return ``"<platform>:<id>"`` for the first device JAX reports."""
    device = jax.devices()[0]
    return f"{device.platform}:{device.id}"