Source code for euclid_dsps.workflows.core

"""End-to-end workflows used by the CLI."""

from __future__ import annotations

from pathlib import Path
from typing import Any

import numpy as np
import pandas as pd

from ..filters import load_filters
from ..fit import fit_galaxy_batch_adam, fit_one_galaxy, fit_population_batch_adam
from ..io import (
    abmag_to_flux_fnu_cgs,
    build_observation,
    ensure_dir,
    flux_error_to_sigma_mag,
    flux_fnu_cgs_to_abmag,
    iter_catalog_batches,
    load_row_indices,
    microjy_to_abmag,
    microjy_to_flux_fnu_cgs,
    read_catalog,
    required_catalog_columns,
    truth_column_from_spec,
    truth_value_from_spec,
    write_dataframe_outputs,
    write_json,
)
from ..model import (
    BatchSedResult,
    ModelResult,
    load_context,
    parameters_for_row,
    predict_batch_derived,
    predict_batch_mags,
    predict_batch_seds,
    run_dsps_model,
)
from ..nebular import write_nebular_diagnostic_outputs
from ..performance import PerformanceRecorder, write_performance_outputs
from ..photometry import magerr_to_fluxerr_fnu_cgs
from ..reporting import (
    write_batch_outputs,
    write_eda_outputs,
    write_fit_diagnostic_outputs,
    write_fit_outputs,
    write_mcmc_outputs,
    write_population_corner_outputs,
    write_run_outputs,
    write_sed_diagnostic_outputs,
    write_trace_truth_outputs,
    write_workflow_comparison,
)
from ..selection import select_galaxy_row
from ..semantics import is_forward_active, is_inferred

try:
    from tqdm.auto import tqdm
except ImportError:  # pragma: no cover - optional runtime dependency fallback
    tqdm = None


def run_eda(config: dict[str, Any], out_dir: str | Path) -> None:
    """Write exploratory-data-analysis outputs for the configured catalog.

    The catalog is read with the columns reported by
    ``required_catalog_columns`` plus whichever of the optional truth columns
    listed below are present in the parquet schema, and the resulting frame is
    handed to ``write_eda_outputs``.

    Args:
        config: Workflow configuration. ``catalog_path`` and ``bands`` are
            required; ``eda.nrows`` and ``redshift`` are optional.
        out_dir: Output directory forwarded to ``write_eda_outputs``.
    """
    import pyarrow.parquet as pq

    # Copy first: the optional columns appended below must not leak into the
    # list object handed back by ``required_catalog_columns``.
    columns = list(required_catalog_columns(config))
    available_columns = set(pq.ParquetFile(config["catalog_path"]).schema.names)
    optional_truth_columns = (
        "metallicity_true",
        "sfr_true",
        "log_sfr_true",
        "dust_ebv_true",
    )
    for name in optional_truth_columns:
        if name in available_columns and name not in columns:
            columns.append(name)

    # Tolerate a missing or null ``eda`` section, like the reporting/output
    # sections elsewhere in this module; ``nrows=None`` is what the original
    # passed when ``eda.nrows`` was not set.
    eda_config = config.get("eda") or {}
    df = read_catalog(
        config["catalog_path"], columns=columns, nrows=eda_config.get("nrows")
    )
    write_eda_outputs(
        df, config["bands"], out_dir, redshift_config=config.get("redshift")
    )
def prepare_one(config: dict[str, Any]):
    """Select one catalog row and build everything needed to model it.

    Args:
        config: Workflow configuration; reads ``catalog_path``, ``selection``,
            ``bands``, ``ssp_path``, ``model`` and the optional ``cosmos_sed``,
            ``nebular_emission`` and ``redshift`` entries.

    Returns:
        A ``(context, observation, params)`` tuple: the context from
        ``load_context``, the observation from ``build_observation`` and the
        parameter mapping from ``parameters_for_row`` for the selected row.
    """
    catalog_columns = required_catalog_columns(config)
    selection = config["selection"]
    catalog = read_catalog(
        config["catalog_path"],
        columns=catalog_columns,
        nrows=selection.get("read_nrows"),
    )

    bands = config["bands"]
    row_index, row = select_galaxy_row(
        catalog,
        band_columns=[band["column"] for band in bands],
        index=selection.get("index"),
        require_positive_flux=bool(selection.get("require_positive_flux", True)),
        nondetection_policy=selection.get("nondetection_policy"),
        sort_by_flux=selection.get("sort_by_flux"),
    )
    observation = build_observation(row_index, row, bands)

    band_filters = load_filters(bands)
    model_config = config["model"]
    context = load_context(
        config["ssp_path"],
        band_filters,
        n_sfh_bins=int(model_config.get("n_sfh_bins", 96)),
        cosmos_config=config.get("cosmos_sed"),
        nebular_emission=config.get("nebular_emission", "ssp_flux"),
    )
    params = parameters_for_row(
        model_config["fixed_parameters"],
        model_config.get("parameter_columns", {}),
        row.to_dict(),
        config.get("redshift", {}),
    )
    return context, observation, params
def run_one(config: dict[str, Any], out_dir: str | Path) -> pd.DataFrame:
    """Run the forward model for one selected galaxy and write its outputs.

    Args:
        config: Workflow configuration, forwarded to ``prepare_one``.
        out_dir: Directory that receives the run outputs, ``run_summary.json``
            and ``normalized_config.json``.

    Returns:
        The comparison table returned by ``write_run_outputs``.
    """
    target_dir = ensure_dir(out_dir)
    context, observation, params = prepare_one(config)
    model_result = run_dsps_model(context, params)
    truth_sed, truth_status = _ground_truth_sed_for_row(
        observation.row, observation.row_index, context.filters, config
    )
    comparison = write_run_outputs(
        observation,
        model_result,
        target_dir,
        ground_truth_sed=truth_sed,
        include_filters=_plot_filters(config),
    )

    # Row-context entries are merged last so they win on any key collision.
    summary: dict[str, Any] = {
        "row_index": observation.row_index,
        "n_bands": len(observation.bands),
        "ground_truth_sed_status": truth_status,
    }
    summary.update(_row_context(observation.row, params, config))
    write_json(target_dir / "run_summary.json", summary)
    write_json(target_dir / "normalized_config.json", config)
    return comparison
def fit_one(config: dict[str, Any], out_dir: str | Path) -> None:
    """Fit one selected galaxy and write the fit and run outputs.

    Args:
        config: Workflow configuration; ``config["fit"]`` is passed to
            ``fit_one_galaxy``.
        out_dir: Directory that receives the run outputs, the fit outputs,
            ``sed_diagnostic_summary.json`` and ``normalized_config.json``.
    """
    target_dir = ensure_dir(out_dir)
    context, observation, params = prepare_one(config)
    fit_result = fit_one_galaxy(context, observation, params, config["fit"])
    truth_sed, truth_status = _ground_truth_sed_for_row(
        observation.row, observation.row_index, context.filters, config
    )

    write_run_outputs(
        observation,
        fit_result.model_result,
        target_dir,
        ground_truth_sed=truth_sed,
        include_filters=_plot_filters(config),
    )
    write_fit_outputs(fit_result, target_dir)

    diagnostic_summary = {"ground_truth_sed_status": truth_status}
    write_json(target_dir / "sed_diagnostic_summary.json", diagnostic_summary)
    write_json(target_dir / "normalized_config.json", config)
def sample_one(config: dict[str, Any], out_dir: str | Path) -> None:
    """Sample the posterior of one selected galaxy and write the results.

    When ``_map_start`` returns a fit result, its ``best_parameters`` seed the
    sampler; otherwise ``initial_params`` is ``None``.

    Args:
        config: Workflow configuration; ``config["fit"]`` and
            ``config["sample"]`` are passed to ``sample_one_galaxy``.
        out_dir: Directory that receives ``sampled_galaxy.json``, the MCMC
            outputs and ``normalized_config.json``.
    """
    from ..mcmc import sample_one_galaxy

    target_dir = ensure_dir(out_dir)
    context, observation, params = prepare_one(config)

    map_result = _map_start(context, observation, params, config)
    started_from_map = map_result is not None
    start_params = map_result.best_parameters if started_from_map else None

    mcmc_result = sample_one_galaxy(
        context,
        observation,
        params,
        config["fit"],
        config["sample"],
        initial_params=start_params,
    )
    context_values = _row_context(observation.row, params, config)

    # Context values are merged last so they win on any key collision.
    galaxy_summary: dict[str, Any] = {
        "row_index": observation.row_index,
        "n_bands": len(observation.bands),
        "hmc_initialized_from_map": started_from_map,
        "map_chi2": float(map_result.chi2) if started_from_map else None,
    }
    galaxy_summary.update(context_values)
    write_json(target_dir / "sampled_galaxy.json", galaxy_summary)

    write_mcmc_outputs(
        mcmc_result,
        target_dir,
        truth_values=_posterior_truth_values(context_values),
    )
    write_json(target_dir / "normalized_config.json", config)
def sample_batch(
    config: dict[str, Any],
    out_dir: str | Path,
    limit: int | None = 5,
    batch_size: int = 1,
    row_indices_file: str | None = None,
) -> None:
    """Sample independent galaxy posteriors with NumPyro NUTS.

    This intentionally runs one galaxy at a time; HMC/NUTS is for posterior
    density checks on small subsets, while Adam/MAP remains the full-catalog
    path.

    Args:
        config: Workflow configuration; ``config["fit"]`` and
            ``config["sample"]`` are passed to ``sample_one_galaxy``.
        out_dir: Directory that receives the ``batch_posterior_*.csv`` /
            ``batch_mcmc_diagnostics.csv`` tables, ``normalized_config.json``
            and the reports from ``write_mcmc_batch_outputs``.
        limit: Row limit forwarded to ``iter_catalog_batches``.
        batch_size: Catalog read chunk size; sampling itself is per row.
        row_indices_file: Optional file of row indices restricting the rows.

    Raises:
        ValueError: If neither ``limit`` nor ``row_indices_file`` is given.
    """
    # Import both lazily-loaded helpers up front so an import problem shows up
    # before any sampling work is done rather than after it.
    from ..mcmc import sample_one_galaxy
    from ..reporting import write_mcmc_batch_outputs

    if limit is None and not row_indices_file:
        raise ValueError(
            "Bayesian batch sampling requires --limit; full-catalog HMC is not practical here."
        )

    out = ensure_dir(out_dir)
    row_indices = _load_row_indices_set(row_indices_file)
    columns = required_catalog_columns(config)
    filters = load_filters(config["bands"])
    context = load_context(
        config["ssp_path"],
        filters,
        n_sfh_bins=int(config["model"].get("n_sfh_bins", 96)),
        cosmos_config=config.get("cosmos_sed"),
        nebular_emission=config.get("nebular_emission", "ssp_flux"),
    )

    summary_rows: list[dict[str, Any]] = []
    predictive_rows: list[dict[str, Any]] = []
    diagnostic_rows: list[dict[str, Any]] = []
    sample_rows: list[dict[str, Any]] = []
    save_samples = bool(config["sample"].get("save_samples", True))
    total = _progress_total(config["catalog_path"], limit, row_indices)

    with _make_progress_bar(
        total=total, desc="sample-batch", unit="galaxy"
    ) as progress:
        for batch in iter_catalog_batches(
            config["catalog_path"],
            columns=columns,
            batch_size=batch_size,
            limit=limit,
            row_indices=row_indices,
        ):
            for row_index, row in batch.iterrows():
                row_id = int(row_index)
                observation = build_observation(row_id, row, config["bands"])
                params = parameters_for_row(
                    config["model"]["fixed_parameters"],
                    config["model"].get("parameter_columns", {}),
                    row.to_dict(),
                    config.get("redshift", {}),
                )
                map_result = _map_start(context, observation, params, config)
                initial_params = (
                    map_result.best_parameters if map_result is not None else None
                )
                result = sample_one_galaxy(
                    context,
                    observation,
                    params,
                    config["fit"],
                    config["sample"],
                    initial_params=initial_params,
                )
                # The MAP bookkeeping keys are written after the row context
                # so they override any same-named context entry.
                context_values = {
                    **_row_context(row.to_dict(), params, config),
                    "hmc_initialized_from_map": map_result is not None,
                    "map_chi2": (
                        float(map_result.chi2) if map_result is not None else None
                    ),
                }
                summary_rows.extend(
                    {"row_index": row_id, **item, **context_values}
                    for item in result.summary
                )
                predictive_rows.extend(
                    _posterior_predictive_rows(row_id, result, context_values)
                )
                diagnostic_rows.append(
                    {"row_index": row_id, **result.diagnostics, **context_values}
                )
                if save_samples:
                    sample_rows.extend(_posterior_sample_rows(row_id, result))
                _update_progress(progress, row_index=row_id)

    # Build each table once and reuse it for both the CSV and the reports.
    summary = pd.DataFrame(summary_rows)
    predictive = pd.DataFrame(predictive_rows)
    diagnostics = pd.DataFrame(diagnostic_rows)

    summary.to_csv(out / "batch_posterior_summary.csv", index=False)
    predictive.to_csv(out / "batch_posterior_predictive.csv", index=False)
    diagnostics.to_csv(out / "batch_mcmc_diagnostics.csv", index=False)
    if sample_rows:
        pd.DataFrame(sample_rows).to_csv(
            out / "batch_posterior_samples.csv", index=False
        )
    write_json(out / "normalized_config.json", config)
    write_mcmc_batch_outputs(summary, predictive, diagnostics, out)
def fit_workflow(
    config: dict[str, Any],
    out_dir: str | Path,
    limit: int | None = 1000,
    batch_size: int = 64,
    hmc_n: int = 20,
    hmc_batch_size: int = 1,
    population_batch_size: int | None = None,
    hmc_select: str = "stratified",
    seed: int = 42,
) -> None:
    """Run MAP batch, HMC subset, population MAP, and comparison reports.

    Stages write into the ``map``, ``hmc``, ``population`` and ``comparison``
    sub-directories of ``out_dir``. The HMC stage only runs when
    ``_select_hmc_row_indices`` picks at least one row from the MAP results;
    the chosen indices are always written to ``hmc_row_indices.txt``.

    Args:
        config: Workflow configuration shared by every stage.
        out_dir: Root directory of the workflow run.
        limit: Row limit for the MAP and population stages.
        batch_size: Chunk size for the MAP stage, and for the population stage
            when ``population_batch_size`` is falsy.
        hmc_n: Requested number of rows for the HMC subset.
        hmc_batch_size: Catalog chunk size for the HMC stage.
        population_batch_size: Optional chunk size for the population stage.
        hmc_select: Selection method passed to ``_select_hmc_row_indices``.
        seed: Random seed passed to ``_select_hmc_row_indices``.
    """
    root = ensure_dir(out_dir)
    map_dir = ensure_dir(root / "map")
    hmc_dir = ensure_dir(root / "hmc")
    population_dir = ensure_dir(root / "population")
    comparison_dir = ensure_dir(root / "comparison")

    # Stage 1: per-galaxy MAP fits.
    fit_batch(config, map_dir, limit=limit, batch_size=batch_size)
    map_fits = _read_table(map_dir / "batch_fit_results.csv")

    # Stage 2: HMC on a subset chosen from the MAP results.
    hmc_rows = _select_hmc_row_indices(
        map_fits, n=hmc_n, method=hmc_select, seed=seed
    )
    hmc_rows_path = root / "hmc_row_indices.txt"
    hmc_rows_text = "\n".join(str(row) for row in hmc_rows) + "\n"
    hmc_rows_path.write_text(hmc_rows_text, encoding="utf-8")
    if hmc_rows:
        sample_batch(
            config,
            hmc_dir,
            limit=None,
            batch_size=hmc_batch_size,
            row_indices_file=str(hmc_rows_path),
        )

    # Stage 3: population fit initialised from the MAP table.
    fit_population(
        config,
        population_dir,
        limit=limit,
        batch_size=population_batch_size or batch_size,
        map_init_file=str(map_dir / "batch_fit_results.csv"),
    )
    population_fits = _read_table(population_dir / "population_fit_results.csv")

    # Stage 4: comparison reports; the HMC tables may legitimately be absent.
    hmc_summary = _read_optional_csv(hmc_dir / "batch_posterior_summary.csv")
    hmc_diagnostics = _read_optional_csv(hmc_dir / "batch_mcmc_diagnostics.csv")
    hmc_samples = _read_optional_csv(hmc_dir / "batch_posterior_samples.csv")
    write_workflow_comparison(
        map_fits=map_fits,
        population_fits=population_fits,
        hmc_summary=hmc_summary,
        hmc_diagnostics=hmc_diagnostics,
        hmc_samples=hmc_samples,
        free_parameters=list(config["fit"]["free_parameters"]),
        out_dir=comparison_dir,
    )

    workflow_summary = {
        "limit": limit,
        "batch_size": batch_size,
        "hmc_n": len(hmc_rows),
        "hmc_selection": hmc_select,
        "hmc_row_indices_file": str(hmc_rows_path),
        "map_out": str(map_dir),
        "hmc_out": str(hmc_dir),
        "population_out": str(population_dir),
        "comparison_out": str(comparison_dir),
    }
    write_json(root / "fit_workflow_summary.json", workflow_summary)
def _select_hmc_row_indices( map_fits: pd.DataFrame, n: int, method: str, seed: int ) -> list[int]: if n <= 0 or map_fits.empty: return [] work = map_fits.drop_duplicates("row_index").copy() work = work[pd.to_numeric(work.get("reduced_chi2"), errors="coerce").notna()] if work.empty: return [] n = min(int(n), len(work)) method = method.lower() if method == "random": selected = work.sample(n=n, random_state=seed) elif method == "best": selected = work.nsmallest(n, "reduced_chi2") elif method == "worst": selected = work.nlargest(n, "reduced_chi2") elif method == "stratified": sorted_work = work.sort_values("reduced_chi2").reset_index(drop=True) positions = np.rint(np.linspace(0, len(sorted_work) - 1, n)).astype(int) selected = sorted_work.iloc[np.unique(positions)] if len(selected) < n: missing = sorted_work[~sorted_work["row_index"].isin(selected["row_index"])] selected = pd.concat( [selected, missing.head(n - len(selected))], ignore_index=True ) else: raise ValueError(f"Unsupported HMC selection method: {method}") return sorted(int(row) for row in selected["row_index"].tolist()) def _read_optional_csv(path: Path) -> pd.DataFrame: if path.exists(): return pd.read_csv(path) parquet = path.with_suffix(".parquet") if parquet.exists(): return pd.read_parquet(parquet) return pd.DataFrame() def _read_table(path: Path) -> pd.DataFrame: if path.exists(): return pd.read_csv(path) parquet = path.with_suffix(".parquet") if parquet.exists(): return pd.read_parquet(parquet) raise FileNotFoundError(path) def _reporting_level(config: dict[str, Any]) -> str: return str((config.get("reporting", {}) or {}).get("level", "full")).lower() def _verbose_benchmark(config: dict[str, Any]) -> bool: return bool((config.get("output", {}) or {}).get("verbose_benchmark", False))
def report_workflow(config: dict[str, Any], run_dir: str | Path) -> None:
    """Regenerate workflow comparison reports from existing workflow outputs.

    Args:
        config: Workflow configuration; ``fit.free_parameters`` is forwarded
            to ``write_workflow_comparison``.
        run_dir: Root of a previous workflow run. The MAP and population
            tables are required; the three HMC tables are optional. Reports
            are written to ``run_dir/comparison``.
    """
    root = Path(run_dir)
    hmc_dir = root / "hmc"

    map_fits = _read_table(root / "map" / "batch_fit_results.csv")
    population_fits = _read_table(root / "population" / "population_fit_results.csv")
    hmc_summary, hmc_diagnostics, hmc_samples = (
        _read_optional_csv(hmc_dir / file_name)
        for file_name in (
            "batch_posterior_summary.csv",
            "batch_mcmc_diagnostics.csv",
            "batch_posterior_samples.csv",
        )
    )

    write_workflow_comparison(
        map_fits=map_fits,
        population_fits=population_fits,
        hmc_summary=hmc_summary,
        hmc_diagnostics=hmc_diagnostics,
        hmc_samples=hmc_samples,
        free_parameters=list(config["fit"]["free_parameters"]),
        out_dir=root / "comparison",
    )
def _map_start(context, observation, params: dict[str, float], config: dict[str, Any]): if not bool(config.get("sample", {}).get("init_from_map", True)): return None return fit_one_galaxy(context, observation, params, config["fit"])
def run_batch(
    config: dict[str, Any],
    out_dir: str | Path,
    limit: int | None = None,
    batch_size: int = 10_000,
    row_indices_file: str | None = None,
) -> None:
    """Run the same configured DSPS model over many catalog rows.

    This is intentionally conservative: it writes a flat comparison table and
    supports per-row physical parameters through `model.parameter_columns`.

    Args:
        config: Workflow configuration.
        out_dir: Directory that receives the ``batch_photometry_comparison``
            table, the batch reports, the optional
            ``sed_diagnostics_manifest.csv``, ``normalized_config.json``,
            ``batch_run_config.json`` and the performance outputs.
        limit: Row limit forwarded to ``iter_catalog_batches``.
        batch_size: Catalog chunk size.
        row_indices_file: Optional file of row indices restricting the rows.
    """
    out = ensure_dir(out_dir)
    perf = PerformanceRecorder(_verbose_benchmark(config))
    reporting_level = _reporting_level(config)
    row_indices = _load_row_indices_set(row_indices_file)
    columns = required_catalog_columns(config)
    filters = load_filters(config["bands"])
    context = load_context(
        config["ssp_path"],
        filters,
        n_sfh_bins=int(config["model"].get("n_sfh_bins", 96)),
        cosmos_config=config.get("cosmos_sed"),
        nebular_emission=config.get("nebular_emission", "ssp_flux"),
    )
    perf.mark("load_context", n_bands=len(config["bands"]))

    rows: list[dict[str, Any]] = []
    sed_manifest_rows: list[dict[str, Any]] = []
    sed_samples_written = 0
    sed_sample_limit = _sed_sample_limit(config)
    total = _progress_total(config["catalog_path"], limit, row_indices)
    chunk_index = 0

    with _make_progress_bar(total=total, desc="check-batch", unit="galaxy") as progress:
        for batch in iter_catalog_batches(
            config["catalog_path"],
            columns=columns,
            batch_size=batch_size,
            limit=limit,
            row_indices=row_indices,
        ):
            perf.mark("read_chunk", chunk_index=chunk_index, n_rows=len(batch))
            rows.extend(_forward_dataframe_batch(context, batch, config))
            perf.mark("forward_chunk", chunk_index=chunk_index, n_rows=len(batch))

            remaining = sed_sample_limit - sed_samples_written
            if remaining > 0:
                # ``_write_sed_samples`` only consumes the first ``remaining``
                # rows, so build parameters for just those rows instead of
                # for the whole chunk.
                base_rows = [
                    parameters_for_row(
                        config["model"]["fixed_parameters"],
                        config["model"].get("parameter_columns", {}),
                        row.to_dict(),
                        config.get("redshift", {}),
                    )
                    for _, row in batch.head(remaining).iterrows()
                ]
                new_rows = _write_sed_samples(
                    context,
                    batch,
                    base_rows,
                    config,
                    out,
                    mode="forward",
                    chunk_index=chunk_index,
                    remaining=remaining,
                )
                sed_manifest_rows.extend(new_rows)
                sed_samples_written += len(new_rows)
                perf.mark(
                    "write_sed_samples",
                    chunk_index=chunk_index,
                    n_rows=len(new_rows),
                )

            _update_progress(
                progress, row_index=int(batch.index[-1]), amount=len(batch)
            )
            chunk_index += 1

    comparison = pd.DataFrame(rows)
    write_dataframe_outputs(comparison, out, "batch_photometry_comparison", config)
    perf.mark("write_primary_outputs", n_rows=len(comparison))
    write_batch_outputs(
        comparison, out, label="batch", reporting_level=reporting_level, config=config
    )
    if sed_manifest_rows:
        pd.DataFrame(sed_manifest_rows).to_csv(
            out / "sed_diagnostics_manifest.csv", index=False
        )
    perf.mark("write_reports")

    write_json(out / "normalized_config.json", config)
    write_json(
        out / "batch_run_config.json",
        {
            "rows_written": len(rows),
            "limit": limit,
            "batch_size": batch_size,
            "row_indices_file": row_indices_file,
        },
    )
    write_performance_outputs(perf.rows, out, "batch")
def fit_batch(
    config: dict[str, Any],
    out_dir: str | Path,
    limit: int | None = 25,
    batch_size: int = 1000,
    row_indices_file: str | None = None,
) -> None:
    """Fit the configured free parameters for many rows.

    The default path optimizes each parquet chunk with one JAX-vmapped Adam run.

    Every chunk is also checkpointed under ``out_dir/_chunks`` as soon as it
    has been fitted; the combined tables and reports are written at the end.

    Args:
        config: Workflow configuration.
        out_dir: Directory that receives the ``batch_fit_*`` tables, reports,
            ``normalized_config.json``, ``batch_fit_run_config.json`` and the
            performance outputs.
        limit: Row limit forwarded to ``iter_catalog_batches``.
        batch_size: Catalog chunk size.
        row_indices_file: Optional file of row indices restricting the rows.
    """
    out = ensure_dir(out_dir)
    perf = PerformanceRecorder(_verbose_benchmark(config))
    reporting_level = _reporting_level(config)
    row_indices = _load_row_indices_set(row_indices_file)
    columns = required_catalog_columns(config)
    band_filters = load_filters(config["bands"])
    context = load_context(
        config["ssp_path"],
        band_filters,
        n_sfh_bins=int(config["model"].get("n_sfh_bins", 96)),
        cosmos_config=config.get("cosmos_sed"),
        nebular_emission=config.get("nebular_emission", "ssp_flux"),
    )
    perf.mark("load_context", n_bands=len(config["bands"]))

    comparison_rows: list[dict[str, Any]] = []
    fit_rows: list[dict[str, Any]] = []
    trace_rows: list[dict[str, Any]] = []
    sed_sample_limit = _sed_sample_limit(config)
    total = _progress_total(config["catalog_path"], limit, row_indices)

    chunk_index = 0
    with _make_progress_bar(total=total, desc="fit", unit="galaxy") as progress:
        chunks = iter_catalog_batches(
            config["catalog_path"],
            columns=columns,
            batch_size=batch_size,
            limit=limit,
            row_indices=row_indices,
        )
        for batch in chunks:
            n_rows = len(batch)
            perf.mark("read_chunk", chunk_index=chunk_index, n_rows=n_rows)
            chunk_result = _fit_dataframe_batch(
                context, batch, config, chunk_index=chunk_index, perf=perf
            )
            fit_rows += chunk_result["fit_rows"]
            comparison_rows += chunk_result["comparison_rows"]
            trace_rows += chunk_result["trace_rows"]
            _write_fit_chunk_checkpoint(out, chunk_result, config, chunk_index)
            perf.mark("fit_chunk", chunk_index=chunk_index, n_rows=len(batch))
            _update_progress(
                progress, row_index=int(batch.index[-1]), amount=len(batch)
            )
            chunk_index += 1

    # Combined primary tables.
    fits = pd.DataFrame(fit_rows)
    comparison = pd.DataFrame(comparison_rows)
    write_dataframe_outputs(fits, out, "batch_fit_results", config)
    write_dataframe_outputs(comparison, out, "batch_fit_photometry_comparison", config)
    perf.mark("write_primary_outputs", n_rows=len(fits))

    # SED diagnostics for the worst/best fits.
    sed_manifest_rows = _write_worst_fit_sed_samples(
        context,
        fits,
        comparison,
        config,
        out,
        limit=sed_sample_limit,
    )
    if sed_manifest_rows:
        perf.mark("write_sed_samples", n_rows=len(sed_manifest_rows))

    # Optimisation traces, when any were recorded.
    if trace_rows:
        trace = pd.DataFrame(trace_rows)
        write_dataframe_outputs(trace, out, "batch_fit_trace", config)
        make_plots = reporting_level == "full"
        write_trace_truth_outputs(
            trace, out, label="batch_fit", make_plots=make_plots
        )

    write_batch_outputs(
        comparison,
        out,
        label="batch_fit",
        reporting_level=reporting_level,
        config=config,
    )
    if sed_manifest_rows:
        manifest = pd.DataFrame(sed_manifest_rows)
        manifest.to_csv(out / "sed_diagnostics_manifest.csv", index=False)
    write_fit_diagnostic_outputs(fits, comparison, config, out, label="batch_fit")
    write_nebular_diagnostic_outputs(context, fits, out, label="batch_fit")
    perf.mark("write_reports")

    run_config = {
        "rows_written": len(comparison_rows),
        "limit": limit,
        "batch_size": batch_size,
        "row_indices_file": row_indices_file,
    }
    write_json(out / "normalized_config.json", config)
    write_json(out / "batch_fit_run_config.json", run_config)
    write_performance_outputs(perf.rows, out, "batch_fit")
def _write_fit_chunk_checkpoint( out: Path, batch_result: dict[str, list[dict[str, Any]]], config: dict[str, Any], chunk_index: int, ) -> None: chunk_dir = ensure_dir(out / "_chunks") suffix = f"chunk_{chunk_index:06d}" if batch_result["fit_rows"]: write_dataframe_outputs( pd.DataFrame(batch_result["fit_rows"]), chunk_dir, f"batch_fit_results_{suffix}", config, ) if batch_result["comparison_rows"]: write_dataframe_outputs( pd.DataFrame(batch_result["comparison_rows"]), chunk_dir, f"batch_fit_photometry_comparison_{suffix}", config, ) if batch_result["trace_rows"]: write_dataframe_outputs( pd.DataFrame(batch_result["trace_rows"]), chunk_dir, f"batch_fit_trace_{suffix}", config, ) _COSMOS_RESOURCE_CACHE: dict[tuple[Any, ...], Any] = {} def _sed_sample_limit(config: dict[str, Any]) -> int: reporting = config.get("reporting", {}) or {} return max(int(reporting.get("save_sed_samples", 0) or 0), 0) def _plot_filters(config: dict[str, Any]) -> bool: reporting = config.get("reporting", {}) or {} return bool(reporting.get("plot_filters", True)) def _plot_ground_truth(config: dict[str, Any]) -> bool: reporting = config.get("reporting", {}) or {} return bool(reporting.get("plot_ground_truth", False)) def _write_worst_fit_sed_samples( context, fits: pd.DataFrame, comparison: pd.DataFrame, config: dict[str, Any], out: Path, *, limit: int, ) -> list[dict[str, Any]]: if limit <= 0 or fits.empty: return [] worst_limit = max(1, (limit + 1) // 2) best_limit = max(0, limit - worst_limit) worst = _ranked_fit_rows(fits, comparison, worst_limit, worst=True) used = set(int(value) for value in worst.get("row_index", [])) best = _ranked_fit_rows( fits, comparison, best_limit + len(used), worst=False, ) if not best.empty and used: best = best[~best["row_index"].astype(int).isin(used)].head(best_limit) manifest = [] manifest.extend( _write_selected_fit_sed_samples( context, worst, config, out, mode="fit_worst", limit=worst_limit, selection_reason="worst_fit", ) ) manifest.extend( 
_write_selected_fit_sed_samples( context, best, config, out, mode="fit_best", limit=best_limit, selection_reason="best_fit", ) ) return manifest def _write_selected_fit_sed_samples( context, selected: pd.DataFrame, config: dict[str, Any], out: Path, *, mode: str, limit: int, selection_reason: str, ) -> list[dict[str, Any]]: if limit <= 0: return [] if selected.empty: return [] row_indices = [int(value) for value in selected["row_index"]] wanted = set(row_indices) rows_by_index: dict[int, pd.Series] = {} for batch in iter_catalog_batches( config["catalog_path"], columns=required_catalog_columns(config), batch_size=max(1024, len(wanted)), row_indices=wanted, ): for row_index, row in batch.iterrows(): rows_by_index[int(row_index)] = row if len(rows_by_index) == len(wanted): break rows = [] fit_rows = [] for _, fit_row in selected.iterrows(): row_index = int(fit_row["row_index"]) row = rows_by_index.get(row_index) if row is None: continue rows.append(row) fit_rows.append(fit_row.to_dict()) if not rows: return [] batch = pd.DataFrame(rows) batch.index = [int(row["row_index"]) for row in fit_rows] manifest = _write_fit_sed_samples( context, batch, fit_rows, config, out, mode=mode, chunk_index=0, remaining=limit, ) ranks = { int(row["row_index"]): int(rank) for rank, row in enumerate(selected.to_dict("records"), start=1) } scores = { int(row["row_index"]): float(row["sed_diagnostic_score"]) for row in selected.to_dict("records") } for row in manifest: row_index = int(row.get("row_index", -1)) row["selection_reason"] = selection_reason row["selection_rank"] = ranks.get(row_index) row["selection_score"] = scores.get(row_index) return manifest def _ranked_fit_rows( fits: pd.DataFrame, comparison: pd.DataFrame, limit: int, *, worst: bool ) -> pd.DataFrame: if limit <= 0: return pd.DataFrame() work = fits.copy() if "row_index" not in work: return pd.DataFrame() work["row_index"] = work["row_index"].astype(int) if not comparison.empty and "row_index" in comparison: residual_col 
= "residual_mag_model_minus_observed" if residual_col in comparison: residual = ( comparison[["row_index", residual_col]] .replace([np.inf, -np.inf], np.nan) .dropna() ) by_row = ( residual.assign(abs_residual=residual[residual_col].abs()) .groupby("row_index", as_index=False)["abs_residual"] .median() .rename(columns={"abs_residual": "median_abs_mag_residual"}) ) work = work.merge(by_row, on="row_index", how="left") if "median_abs_mag_residual" in work: score = work["median_abs_mag_residual"] elif "reduced_chi2" in work: score = work["reduced_chi2"] else: score = pd.Series(np.zeros(len(work)), index=work.index) fill_value = -np.inf if worst else np.inf work["sed_diagnostic_score"] = score.replace([np.inf, -np.inf], np.nan).fillna( fill_value ) if "reduced_chi2" not in work: work["reduced_chi2"] = work["sed_diagnostic_score"] return ( work.sort_values( ["sed_diagnostic_score", "reduced_chi2"], ascending=[not worst, not worst], na_position="last", ) .drop_duplicates("row_index") .head(limit) ) def _write_fit_sed_samples( context, batch: pd.DataFrame, fit_rows: list[dict[str, Any]], config: dict[str, Any], out: Path, *, mode: str, chunk_index: int, remaining: int, ) -> list[dict[str, Any]]: fit_by_row = {int(row["row_index"]): row for row in fit_rows} parameter_rows = [] selected = [] for row_index, row in batch.iterrows(): if len(selected) >= remaining: break fit_row = fit_by_row.get(int(row_index)) if fit_row is None: continue params = parameters_for_row( config["model"]["fixed_parameters"], config["model"].get("parameter_columns", {}), row.to_dict(), config.get("redshift", {}), ) for name in list(params): fit_key = f"fit_{name}" if fit_key in fit_row and pd.notna(fit_row[fit_key]): params[name] = float(fit_row[fit_key]) selected.append((row_index, row)) parameter_rows.append(params) if not selected: return [] selected_batch = pd.DataFrame([row for _, row in selected]) selected_batch.index = [int(row_index) for row_index, _ in selected] return _write_sed_samples( 
context, selected_batch, parameter_rows, config, out, mode=mode, chunk_index=chunk_index, remaining=remaining, ) def _write_sed_samples( context, batch: pd.DataFrame, parameter_rows: list[dict[str, Any]], config: dict[str, Any], out: Path, *, mode: str, chunk_index: int, remaining: int, ) -> list[dict[str, Any]]: if remaining <= 0 or batch.empty or not parameter_rows: return [] n = min(int(remaining), len(batch), len(parameter_rows)) batch = batch.head(n) parameter_rows = parameter_rows[:n] parameter_names = _parameter_names_for_sed_rows(parameter_rows) parameter_matrix = pd.DataFrame(parameter_rows, columns=parameter_names).to_numpy( dtype=float ) sed_result = predict_batch_seds(context, parameter_names, parameter_matrix) sed_dir = ensure_dir(out / "sed_diagnostics") manifest_rows = [] for local_index, (row_index, row) in enumerate(batch.iterrows()): row_dict = row.to_dict() observation = build_observation(int(row_index), row, config["bands"]) model_result = _model_result_from_sed_batch( sed_result, local_index, context, config ) ground_truth_sed, ground_truth_status = _ground_truth_sed_for_row( row_dict, int(row_index), context.filters, config ) stem = f"{mode}_chunk_{chunk_index:06d}_row_{int(row_index):08d}" output = write_sed_diagnostic_outputs( observation, model_result, sed_dir, stem=stem, ground_truth_sed=ground_truth_sed, include_filters=_plot_filters(config), ) manifest_rows.append( { **output, "mode": mode, "chunk_index": int(chunk_index), "ground_truth_sed_status": ground_truth_status, } ) return manifest_rows def _parameter_names_for_sed_rows(parameter_rows: list[dict[str, Any]]) -> list[str]: names: list[str] = [] for row in parameter_rows: for name in row: if name not in names: names.append(name) return names def _model_result_from_sed_batch( batch_result: BatchSedResult, local_index: int, context, config: dict[str, Any], ) -> ModelResult: params = { name: float(batch_result.parameter_matrix[local_index, index]) for index, name in 
enumerate(batch_result.parameter_names) } derived = { name: float(values[local_index]) for name, values in batch_result.derived.items() } photometry = {} for band_index, band in enumerate(config["bands"]): name = band["name"] curve = context.filters[name] mag = float(batch_result.model_mags[local_index, band_index]) photometry[name] = { "model_mag_ab": mag, "model_flux_fnu_cgs": abmag_to_flux_fnu_cgs(mag), "filter_source": curve.source, "effective_wavelength_angstrom": curve.effective_wavelength, "filter_wave_angstrom": curve.wave, "filter_transmission": curve.transmission, } return ModelResult( parameters=params, derived=derived, wave=batch_result.wave, rest_sed=batch_result.rest_sed[local_index], dusted_rest_sed=batch_result.dusted_rest_sed[local_index], photometry=photometry, ) def _ground_truth_sed_for_row( row: dict[str, Any], row_index: int, filters: Any, config: dict[str, Any], ) -> tuple[pd.DataFrame | None, str]: if not _plot_ground_truth(config): return None, "disabled" try: from ..cosmos import ( MissingCosmosColumnsError, MissingCosmosResourceError, flambda_10pc_to_lnu_lsun, load_cosmos_sed_resources, reconstruct_cosmos_proxy_sed, ) except Exception as exc: # pragma: no cover - defensive optional import path return None, f"import_failed:{type(exc).__name__}" cosmos_config = config.get("cosmos_sed", {}) or {} cache_key = ( cosmos_config.get("value_added_data_dir"), cosmos_config.get("lephare_data_dir"), cosmos_config.get("template_subdir"), cosmos_config.get("template_list"), ) try: if cache_key not in _COSMOS_RESOURCE_CACHE: _COSMOS_RESOURCE_CACHE[cache_key] = load_cosmos_sed_resources( cosmos_config ) cosmos_result = reconstruct_cosmos_proxy_sed( row, row_index, _COSMOS_RESOURCE_CACHE[cache_key], filters, config["bands"], cosmos_config, ) except MissingCosmosColumnsError as exc: return None, f"missing_columns:{exc}" except MissingCosmosResourceError as exc: return None, f"missing_resource:{exc}" except Exception as exc: # pragma: no cover - report 
unexpected data issue return None, f"failed:{type(exc).__name__}:{exc}" rel_residual = np.asarray( list(cosmos_result.relative_residuals_vs_catalog_abs.values()), dtype=float ) finite_rel = rel_residual[np.isfinite(rel_residual)] norm_bands = sorted(cosmos_result.synthetic_abs_fluxes_after) return ( pd.DataFrame( { "wave_angstrom": cosmos_result.wave_angstrom, "ground_truth_lnu_lsun_per_hz": flambda_10pc_to_lnu_lsun( cosmos_result.wave_angstrom, cosmos_result.flambda_scaled ), "ground_truth_unscaled_lnu_lsun_per_hz": flambda_10pc_to_lnu_lsun( cosmos_result.wave_angstrom, cosmos_result.flambda_unscaled ), "ground_truth_label": "COSMOS proxy", "ground_truth_scale_factor": cosmos_result.alpha, "ground_truth_normalization_bands": ",".join(norm_bands), "ground_truth_norm_median_abs_rel_residual": ( float(np.nanmedian(np.abs(finite_rel))) if finite_rel.size else float("nan") ), "ground_truth_norm_max_abs_rel_residual": ( float(np.nanmax(np.abs(finite_rel))) if finite_rel.size else float("nan") ), } ), "ok", )
def fit_population(
    config: dict[str, Any],
    out_dir: str | Path,
    limit: int | None = 25,
    batch_size: int = 256,
    row_indices_file: str | None = None,
    map_init_file: str | None = None,
) -> None:
    """Fit chunked hierarchical population MAP models with JAX-vmapped Adam.

    Args:
        config: Workflow configuration.
        out_dir: Directory that receives the ``population_fit_*`` tables,
            ``population_hyperparameters``, reports, the optional
            ``sed_diagnostics_manifest.csv``, ``normalized_config.json``,
            ``population_fit_run_config.json`` and the performance outputs.
        limit: Row limit forwarded to ``iter_catalog_batches``.
        batch_size: Catalog chunk size.
        row_indices_file: Optional file of row indices restricting the rows.
        map_init_file: Optional table of per-row MAP results, indexed by its
            ``row_index`` column and passed to ``_fit_dataframe_batch`` as
            ``map_init``. A ``.parquet`` path is read as parquet; any other
            path is read as CSV, falling back to the sibling ``.parquet``
            file when the CSV does not exist.
    """
    out = ensure_dir(out_dir)
    perf = PerformanceRecorder(_verbose_benchmark(config))
    reporting_level = _reporting_level(config)
    row_indices = _load_row_indices_set(row_indices_file)

    map_init = None
    if map_init_file:
        # The MAP table is written through ``write_dataframe_outputs`` and may
        # exist only as parquet, so mirror the CSV-then-parquet lookup used
        # for the other workflow tables instead of insisting on CSV.
        map_init_path = Path(map_init_file)
        if map_init_path.suffix.lower() == ".parquet":
            map_init_table = pd.read_parquet(map_init_path)
        else:
            map_init_table = _read_table(map_init_path)
        map_init = map_init_table.set_index("row_index")

    columns = required_catalog_columns(config)
    filters = load_filters(config["bands"])
    context = load_context(
        config["ssp_path"],
        filters,
        n_sfh_bins=int(config["model"].get("n_sfh_bins", 96)),
        cosmos_config=config.get("cosmos_sed"),
        nebular_emission=config.get("nebular_emission", "ssp_flux"),
    )
    perf.mark("load_context", n_bands=len(config["bands"]))

    comparison_rows: list[dict[str, Any]] = []
    fit_rows: list[dict[str, Any]] = []
    hyper_rows: list[dict[str, Any]] = []
    trace_rows: list[dict[str, Any]] = []
    sed_manifest_rows: list[dict[str, Any]] = []
    sed_samples_written = 0
    sed_sample_limit = _sed_sample_limit(config)
    total = _progress_total(config["catalog_path"], limit, row_indices)

    with _make_progress_bar(
        total=total, desc="population-fit", unit="galaxy"
    ) as progress:
        chunk_index = 0
        for batch in iter_catalog_batches(
            config["catalog_path"],
            columns=columns,
            batch_size=batch_size,
            limit=limit,
            row_indices=row_indices,
        ):
            perf.mark("read_chunk", chunk_index=chunk_index, n_rows=len(batch))
            batch_result = _fit_dataframe_batch(
                context,
                batch,
                config,
                population=True,
                chunk_index=chunk_index,
                map_init=map_init,
                perf=perf,
            )
            fit_rows.extend(batch_result["fit_rows"])
            comparison_rows.extend(batch_result["comparison_rows"])
            hyper_rows.extend(batch_result["hyper_rows"])
            trace_rows.extend(batch_result["trace_rows"])

            if sed_samples_written < sed_sample_limit:
                new_rows = _write_fit_sed_samples(
                    context,
                    batch,
                    batch_result["fit_rows"],
                    config,
                    out,
                    mode="population_fit",
                    chunk_index=chunk_index,
                    remaining=sed_sample_limit - sed_samples_written,
                )
                sed_manifest_rows.extend(new_rows)
                sed_samples_written += len(new_rows)
                perf.mark(
                    "write_sed_samples",
                    chunk_index=chunk_index,
                    n_rows=len(new_rows),
                )

            perf.mark(
                "fit_population_chunk", chunk_index=chunk_index, n_rows=len(batch)
            )
            _update_progress(
                progress, row_index=int(batch.index[-1]), amount=len(batch)
            )
            chunk_index += 1

    fits = pd.DataFrame(fit_rows)
    comparison = pd.DataFrame(comparison_rows)
    write_dataframe_outputs(fits, out, "population_fit_results", config)
    write_dataframe_outputs(
        comparison, out, "population_fit_photometry_comparison", config
    )
    hyper_frame = pd.DataFrame(hyper_rows)
    if hyper_rows:
        write_dataframe_outputs(hyper_frame, out, "population_hyperparameters", config)
    perf.mark("write_primary_outputs", n_rows=len(fits))

    if trace_rows:
        trace = pd.DataFrame(trace_rows)
        write_dataframe_outputs(trace, out, "population_fit_trace", config)
        write_trace_truth_outputs(
            trace,
            out,
            label="population_fit",
            make_plots=reporting_level == "full",
        )
    write_batch_outputs(
        comparison,
        out,
        label="population_fit",
        reporting_level=reporting_level,
        config=config,
    )
    if sed_manifest_rows:
        pd.DataFrame(sed_manifest_rows).to_csv(
            out / "sed_diagnostics_manifest.csv", index=False
        )
    if reporting_level == "full":
        write_population_corner_outputs(
            fits, list(config["fit"]["free_parameters"]), out, config=config
        )
    write_fit_diagnostic_outputs(
        fits,
        comparison,
        config,
        out,
        label="population_fit",
        hyperparameters=hyper_frame,
    )
    write_nebular_diagnostic_outputs(context, fits, out, label="population_fit")
    perf.mark("write_reports")

    write_json(out / "normalized_config.json", config)
    write_json(
        out / "population_fit_run_config.json",
        {
            "rows_written": len(comparison_rows),
            "limit": limit,
            "batch_size": batch_size,
            "row_indices_file": row_indices_file,
            "map_init_file": map_init_file,
        },
    )
    write_performance_outputs(perf.rows, out, "population_fit")
def _comparison_for_batch(observation, result, params, config): from ..model import comparison_rows context_values = _row_context(observation.row, params, config) for row in comparison_rows(observation, result): row["row_index"] = observation.row_index row.update(context_values) yield row def _posterior_predictive_rows( row_index: int, result, context_values: dict[str, Any] ) -> list[dict[str, Any]]: rows = [] for band_index, band in enumerate(result.band_names): values = result.posterior_model_mags[:, band_index] obs = float(result.observed_mag[band_index]) med = float(pd.Series(values).quantile(0.50)) rows.append( { "row_index": row_index, "band": band, "observed_mag_ab": obs, "sigma_mag": float(result.sigma_mag[band_index]), "model_mag_q05": float(pd.Series(values).quantile(0.05)), "model_mag_q16": float(pd.Series(values).quantile(0.16)), "model_mag_median": med, "model_mag_q84": float(pd.Series(values).quantile(0.84)), "model_mag_q95": float(pd.Series(values).quantile(0.95)), "residual_mag_median_model_minus_observed": med - obs, **context_values, } ) return rows def _posterior_sample_rows(row_index: int, result) -> list[dict[str, Any]]: names = list(result.samples) n_samples = len(result.samples[names[0]]) if names else 0 rows = [] for sample_index in range(n_samples): rows.append( { "row_index": row_index, "sample_index": sample_index, **{name: float(result.samples[name][sample_index]) for name in names}, } ) return rows def _posterior_truth_values(context_values: dict[str, Any]) -> dict[str, Any]: return { key: value for key, value in context_values.items() if key.startswith("truth_") or key.startswith("truth_source_") or key.startswith("truth_kind_") or key in {"redshift_truth", "redshift_truth_source", "z_obs", "z_obs_source"} } def _row_context( row: dict[str, Any], params: dict[str, float], config: dict[str, Any] ) -> dict[str, float | str]: values: dict[str, float | str] = {} values["z_obs"] = float(params["z_obs"]) values["dust_parameter_active"] = 
bool(is_forward_active(config, "dust_av")) values["dust_parameter_inferred"] = bool(is_inferred(config, "dust_av")) values["dust_model"] = str(config.get("dust_model", "salim")) redshift = config.get("redshift", {}) redshift_initial = str(redshift.get("initial", "catalog_column")) if redshift_initial == "random_uniform": z_source = "random_uniform" elif redshift_initial == "fixed": z_source = "fixed_value" else: z_source = truth_column_from_spec(redshift.get("column")) or "fixed_value" values["z_obs_source"] = z_source truth_col = redshift.get("truth_column") or config.get("truth", {}).get( "redshift_column" ) truth_col = truth_column_from_spec(truth_col) if truth_col and truth_col in row and pd.notna(row[truth_col]): values["redshift_truth_source"] = truth_col values["redshift_truth"] = float(row[truth_col]) values["delta_z_obs_minus_truth"] = values["z_obs"] - values["redshift_truth"] for key, value in params.items(): values[f"param_{key}"] = float(value) for truth_name, column in ( config.get("truth", {}).get("parameter_columns") or {} ).items(): truth_value = truth_value_from_spec(row, column) if truth_value is not None: truth_column = truth_column_from_spec(column) if truth_column: values[f"truth_source_{truth_name}"] = truth_column values[f"truth_kind_{truth_name}"] = _truth_kind(truth_name, column) values[f"truth_{truth_name}"] = float(truth_value) param_key = f"param_{truth_name}" if param_key in values: values[f"delta_{truth_name}"] = ( values[param_key] - values[f"truth_{truth_name}"] ) for column in config.get("extra_columns", []): if column in row and pd.notna(row[column]): values[f"catalog_{column}"] = float(row[column]) return values def _truth_kind(truth_name: str, spec: Any) -> str: if truth_name in {"dust_av", "log10_metallicity"}: return "proxy" if isinstance(spec, dict) and ( "scale" in spec or "offset" in spec or spec.get("transform") not in {None, "linear"} ): return "proxy" return "direct" def _fit_dataframe_batch( context, batch: 
pd.DataFrame, config: dict[str, Any], population: bool = False, chunk_index: int = 0, map_init: pd.DataFrame | None = None, perf: PerformanceRecorder | None = None, ) -> dict[str, list[dict[str, Any]]]: observed_mag, observed_flux, sigma_mag = _photometry_arrays(batch, config["bands"]) flux_error = _flux_error_arrays(batch, config["bands"], observed_flux, sigma_mag) valid_band_mask = _valid_band_mask( observed_mag, observed_flux, sigma_mag, flux_error, config["fit"] ) if perf is not None: perf.mark("prepare_photometry", chunk_index=chunk_index, n_rows=len(batch)) base_rows = [ parameters_for_row( config["model"]["fixed_parameters"], config["model"].get("parameter_columns", {}), row.to_dict(), config.get("redshift", {}), ) for _, row in batch.iterrows() ] if perf is not None: perf.mark("prepare_parameters", chunk_index=chunk_index, n_rows=len(batch)) if population: truth_theta = _truth_parameter_matrix( batch, config, list(config["fit"]["free_parameters"]) ) if perf is not None: perf.mark("prepare_truth", chunk_index=chunk_index, n_rows=len(batch)) pop_result = fit_population_batch_adam( context, base_rows, observed_mag, sigma_mag, config["fit"], initial_theta=_initial_theta_from_map(batch, config, map_init), truth_theta=truth_theta, observed_flux=observed_flux, flux_error=flux_error, ) if perf is not None: perf.mark("jax_optimize", chunk_index=chunk_index, n_rows=len(batch)) fit_result = pop_result.batch hyper_rows = [ { "chunk_index": chunk_index, "n_galaxies": len(batch), "kind": "gaussian", "parameter": name, "population_mu": pop_result.hyper_mu[name], "population_sigma": pop_result.hyper_sigma[name], "loss": pop_result.loss, "device": fit_result.device, } for name in fit_result.free_parameter_names ] for relation in pop_result.hyper_relations: hyper_rows.append( { "chunk_index": chunk_index, "n_galaxies": len(batch), "kind": "relation", "parameter": relation["target_parameter"], "target_parameter": relation["target_parameter"], "predictor_parameter": 
relation["predictor_parameter"], "population_pivot": relation["pivot"], "population_intercept": relation["intercept"], "population_slope": relation["slope"], "population_sigma": relation["sigma"], "loss": pop_result.loss, "device": fit_result.device, } ) else: truth_theta = _truth_parameter_matrix( batch, config, list(config["fit"]["free_parameters"]) ) if perf is not None: perf.mark("prepare_truth", chunk_index=chunk_index, n_rows=len(batch)) fit_result = fit_galaxy_batch_adam( context, base_rows, observed_mag, sigma_mag, config["fit"], truth_theta=truth_theta, observed_flux=observed_flux, flux_error=flux_error, ) if perf is not None: perf.mark("jax_optimize", chunk_index=chunk_index, n_rows=len(batch)) hyper_rows = [] fit_rows = [] comparison_rows = [] band_names = [band["name"] for band in config["bands"]] filter_curves = [context.filters[name] for name in band_names] param_matrix = fit_result.best_parameter_matrix derived = predict_batch_derived(context, fit_result.parameter_names, param_matrix) if perf is not None: perf.mark("derive_quantities", chunk_index=chunk_index, n_rows=len(batch)) for local_index, (row_index, row) in enumerate(batch.iterrows()): params = { name: float(param_matrix[local_index, param_index]) for param_index, name in enumerate(fit_result.parameter_names) } derived_values = { f"fit_{name}": float(values[local_index]) for name, values in derived.items() } context_values = { **_row_context(row.to_dict(), params, config), **{f"fit_{key}": value for key, value in params.items()}, **derived_values, } if "z_obs" in fit_result.free_parameter_names: context_values["z_obs_source"] = "DSPS fit" context_values["redshift_initial_mode"] = str( config.get("redshift", {}).get("initial", "catalog_column") ) context_values["redshift_prior_mode"] = str( (config.get("redshift", {}).get("prior_z") or {}).get("mode", "none") ) z_initial = ( _initial_value( config["fit"]["free_parameters"]["z_obs"], "z_obs", base_rows[local_index], ) if "z_obs" in 
config["fit"]["free_parameters"] else base_rows[local_index].get("z_obs", np.nan) ) n_bands = len(config["bands"]) n_valid_bands = int(np.asarray(valid_band_mask[local_index]).sum()) flux_values = np.asarray(observed_flux[local_index], dtype=float) flux_errors = np.asarray(flux_error[local_index], dtype=float) finite_flux_error = np.isfinite(flux_errors) & (flux_errors > 0.0) n_nondetected_bands = int( (np.isfinite(flux_values) & (flux_values <= 0.0) & finite_flux_error).sum() ) n_masked_bands = int(n_bands - n_valid_bands) n_free_effective = len(config["fit"]["free_parameters"]) dof = max(n_valid_bands - n_free_effective, 1) chi2_value = float(fit_result.chi2[local_index]) fit_rows.append( { "row_index": int(row_index), "chunk_index": int(chunk_index), "success": bool(fit_result.success[local_index]), "message": fit_result.message, "chi2": chi2_value, "chi2_per_band": chi2_value / max(n_valid_bands, 1), "reduced_chi2": chi2_value / dof, "reduced_chi2_dof": chi2_value / dof, "gradient_norm": float(fit_result.gradient_norm[local_index]), "n_bands": n_bands, "n_valid_bands": n_valid_bands, "n_masked_bands": n_masked_bands, "n_nondetected_bands": n_nondetected_bands, "n_upper_limit_bands": 0, "nondetection_policy": str( config.get("selection", {}).get("nondetection_policy", "drop") ), "n_free_effective": n_free_effective, "dof": dof, "z_initial": float(z_initial), "device": fit_result.device, **{f"fit_{key}": value for key, value in params.items()}, **derived_values, **context_values, } ) for band_index, band in enumerate(config["bands"]): model_mag = float(fit_result.model_mags[local_index, band_index]) obs_mag = float(observed_mag[local_index, band_index]) obs_flux = float(observed_flux[local_index, band_index]) obs_flux_error = float(flux_error[local_index, band_index]) sigma = float(sigma_mag[local_index, band_index]) model_flux = abmag_to_flux_fnu_cgs(model_mag) flux_ratio = model_flux / obs_flux if obs_flux > 0 else float("nan") residual = obs_mag - model_mag 
chi_flux = ( (model_flux - obs_flux) / obs_flux_error if obs_flux_error > 0 else float("nan") ) comparison_rows.append( { "row_index": int(row_index), "band": band["name"], "column": band["column"], "effective_wavelength_angstrom": filter_curves[ band_index ].effective_wavelength, "observed_flux_fnu_cgs": obs_flux, "observed_flux_error_fnu_cgs": obs_flux_error, "band_used_in_likelihood": bool( valid_band_mask[local_index, band_index] ), "band_is_nondetection": bool( np.isfinite(obs_flux) and obs_flux <= 0.0 and obs_flux_error > 0.0 ), "n_valid_bands": n_valid_bands, "n_masked_bands": n_masked_bands, "n_nondetected_bands": n_nondetected_bands, "n_upper_limit_bands": 0, "n_free_effective": n_free_effective, "dof": dof, "observed_mag_ab": obs_mag, "sigma_mag": sigma, "model_flux_fnu_cgs": model_flux, "model_mag_ab": model_mag, "residual_mag_observed_minus_model": residual, "residual_mag_model_minus_observed": -residual, "flux_ratio_model_over_observed": flux_ratio, "fractional_flux_residual_model_minus_observed": flux_ratio - 1.0, "chi": residual / sigma if sigma > 0 else float("nan"), "chi_flux": chi_flux, "filter_source": filter_curves[band_index].source, **context_values, } ) trace_rows = [ { "chunk_index": chunk_index, **entry, } for entry in fit_result.trace ] if perf is not None: perf.mark("materialize_rows", chunk_index=chunk_index, n_rows=len(batch)) return { "fit_rows": fit_rows, "comparison_rows": comparison_rows, "hyper_rows": hyper_rows, "trace_rows": trace_rows, } def _truth_parameter_matrix( batch: pd.DataFrame, config: dict[str, Any], free_names: list[str] ) -> np.ndarray | None: redshift = config.get("redshift", {}) redshift_truth_spec = redshift.get("truth_column") or config.get("truth", {}).get( "redshift_column" ) parameter_specs = config.get("truth", {}).get("parameter_columns") or {} rows = [] for _, row in batch.iterrows(): values = [] row_dict = row.to_dict() for name in free_names: spec = redshift_truth_spec if name == "z_obs" else 
parameter_specs.get(name) truth_value = truth_value_from_spec(row_dict, spec) if spec else None values.append(float(truth_value) if truth_value is not None else np.nan) rows.append(values) truth = np.asarray(rows, dtype=float) return truth if np.isfinite(truth).any() else None def _initial_theta_from_map( batch: pd.DataFrame, config: dict[str, Any], map_init: pd.DataFrame | None ) -> np.ndarray | None: if map_init is None: return None free_names = list(config["fit"]["free_parameters"]) rows = [] for row_index in batch.index: if row_index not in map_init.index: return None item = map_init.loc[row_index] if isinstance(item, pd.DataFrame): item = item.iloc[0] values = [] for name in free_names: col = f"fit_{name}" if col not in item or not np.isfinite(item[col]): return None values.append(float(item[col])) rows.append(values) return np.asarray(rows, dtype=float) def _initial_value( spec: dict[str, Any], name: str, base_params: dict[str, float] ) -> float: value = spec.get("initial", base_params.get(name, 0.0)) if isinstance(value, str): value = base_params[name] if value == "from_base" else value return float(value) def _forward_dataframe_batch( context, batch: pd.DataFrame, config: dict[str, Any] ) -> list[dict[str, Any]]: observed_mag, observed_flux, sigma_mag = _photometry_arrays(batch, config["bands"]) base_rows = [ parameters_for_row( config["model"]["fixed_parameters"], config["model"].get("parameter_columns", {}), row.to_dict(), config.get("redshift", {}), ) for _, row in batch.iterrows() ] parameter_names = list(base_rows[0]) parameter_matrix = pd.DataFrame(base_rows, columns=parameter_names).to_numpy( dtype=float ) model_mags = predict_batch_mags(context, parameter_names, parameter_matrix) comparison_rows = [] band_names = [band["name"] for band in config["bands"]] filter_curves = [context.filters[name] for name in band_names] for local_index, (row_index, row) in enumerate(batch.iterrows()): params = { name: float(parameter_matrix[local_index, param_index]) 
for param_index, name in enumerate(parameter_names) } context_values = _row_context(row.to_dict(), params, config) for band_index, band in enumerate(config["bands"]): model_mag = float(model_mags[local_index, band_index]) obs_mag = float(observed_mag[local_index, band_index]) obs_flux = float(observed_flux[local_index, band_index]) sigma = float(sigma_mag[local_index, band_index]) model_flux = abmag_to_flux_fnu_cgs(model_mag) flux_ratio = model_flux / obs_flux if obs_flux > 0 else float("nan") residual = obs_mag - model_mag comparison_rows.append( { "row_index": int(row_index), "band": band["name"], "column": band["column"], "effective_wavelength_angstrom": filter_curves[ band_index ].effective_wavelength, "observed_flux_fnu_cgs": obs_flux, "observed_mag_ab": obs_mag, "sigma_mag": sigma, "model_flux_fnu_cgs": model_flux, "model_mag_ab": model_mag, "residual_mag_observed_minus_model": residual, "residual_mag_model_minus_observed": -residual, "flux_ratio_model_over_observed": flux_ratio, "fractional_flux_residual_model_minus_observed": flux_ratio - 1.0, "chi": residual / sigma if sigma > 0 else float("nan"), "filter_source": filter_curves[band_index].source, **context_values, } ) return comparison_rows def _photometry_arrays( batch: pd.DataFrame, band_configs: list[dict[str, Any]] ) -> tuple[Any, Any, Any]: mag_columns = [] flux_columns = [] sigma_columns = [] for band in band_configs: values = batch[band["column"]].astype(float).to_numpy() units = band.get("units", "fnu_cgs") if units == "fnu_cgs": flux = values mag = [flux_fnu_cgs_to_abmag(value) for value in values] elif units == "abmag": mag = values flux = [abmag_to_flux_fnu_cgs(value) for value in values] elif units in {"microjy", "ujy"}: mag = [microjy_to_abmag(value) for value in values] flux = [microjy_to_flux_fnu_cgs(value) for value in values] else: raise ValueError( f"Unsupported photometry units for {band['name']}: {units}" ) mag_columns.append(mag) flux_columns.append(flux) 
sigma_columns.append(_sigma_mag_array(batch, band, flux, units)) return ( pd.DataFrame(mag_columns).transpose().to_numpy(dtype=float), pd.DataFrame(flux_columns).transpose().to_numpy(dtype=float), pd.DataFrame(sigma_columns).transpose().to_numpy(dtype=float), ) def _sigma_mag_array( batch: pd.DataFrame, band: dict[str, Any], flux: Any, units: str ) -> list[float]: fallback = float(band.get("sigma_mag", 0.05)) error_column = band.get("error_column") if not error_column or error_column not in batch: return [fallback] * len(batch) raw_errors = batch[str(error_column)].astype(float).to_numpy() error_units = band.get("error_units", units) floor = band.get("sigma_mag_floor") ceiling = band.get("sigma_mag_ceiling") sigma = [] flux_values = np.asarray(flux, dtype=float) for flux_value, raw_error in zip(flux_values, raw_errors, strict=True): value = float("nan") if np.isfinite(raw_error) and raw_error > 0.0: if error_units == "abmag": value = float(raw_error) if floor is not None and np.isfinite(floor): value = max(value, float(floor)) if ceiling is not None and np.isfinite(ceiling): value = min(value, float(ceiling)) elif error_units == "fnu_cgs": value = flux_error_to_sigma_mag( float(flux_value), float(raw_error), floor=floor, ceiling=ceiling ) elif error_units in {"microjy", "ujy"}: value = flux_error_to_sigma_mag( float(flux_value), microjy_to_flux_fnu_cgs(float(raw_error)), floor=floor, ceiling=ceiling, ) else: raise ValueError( f"Unsupported photometry error units for {band['name']}: {error_units}" ) sigma.append(value if np.isfinite(value) and value > 0.0 else fallback) return sigma def _flux_error_arrays( batch: pd.DataFrame, band_configs: list[dict[str, Any]], observed_flux: np.ndarray, sigma_mag: np.ndarray, ) -> np.ndarray: columns = [] for band_index, band in enumerate(band_configs): flux = np.asarray(observed_flux[:, band_index], dtype=float) fallback = np.asarray( magerr_to_fluxerr_fnu_cgs(flux, sigma_mag[:, band_index]), dtype=float ) error_column = 
band.get("error_column") if not error_column or error_column not in batch: columns.append(fallback) continue raw_errors = batch[str(error_column)].astype(float).to_numpy() error_units = band.get("error_units", band.get("units", "fnu_cgs")) if error_units == "fnu_cgs": errors = raw_errors elif error_units in {"microjy", "ujy"}: errors = np.asarray( [microjy_to_flux_fnu_cgs(value) for value in raw_errors], dtype=float ) elif error_units == "abmag": errors = np.asarray( magerr_to_fluxerr_fnu_cgs(flux, raw_errors), dtype=float ) else: raise ValueError( f"Unsupported photometry error units for {band['name']}: {error_units}" ) errors = np.where(np.isfinite(errors) & (errors > 0.0), errors, fallback) columns.append(errors) return pd.DataFrame(columns).transpose().to_numpy(dtype=float) def _valid_band_mask( observed_mag: np.ndarray, observed_flux: np.ndarray, sigma_mag: np.ndarray, flux_error: np.ndarray, fit_config: dict[str, Any], ) -> np.ndarray: if str(fit_config.get("likelihood_space", "flux")).lower() == "flux": sigma = np.asarray(flux_error, dtype=float) observed = np.asarray(observed_flux, dtype=float) else: sigma = np.asarray(sigma_mag, dtype=float) observed = np.asarray(observed_mag, dtype=float) return np.isfinite(observed) & np.isfinite(sigma) & (sigma > 0.0) def _load_row_indices_set(row_indices_file: str | None) -> set[int] | None: if not row_indices_file: return None return set(load_row_indices(row_indices_file)) def _progress_total( catalog_path: str | Path, limit: int | None, row_indices: set[int] | None = None ) -> int | None: if row_indices is not None: if limit is None: return len(row_indices) return min(int(limit), len(row_indices)) if limit is not None: return int(limit) try: import pyarrow.parquet as pq return int(pq.ParquetFile(catalog_path).metadata.num_rows) except Exception: return None class _NullProgress: def update(self, _: int = 1) -> None: return None def set_postfix_str(self, _: str, refresh: bool = False) -> None: return None def 
__enter__(self) -> _NullProgress: return self def __exit__(self, exc_type, exc, tb) -> None: return None def _make_progress_bar(total: int | None, desc: str, unit: str) -> Any: if tqdm is None: return _NullProgress() return tqdm( total=total, desc=desc, unit=unit, dynamic_ncols=True, mininterval=0.2, smoothing=0.05, ) def _update_progress(progress: Any, row_index: int, amount: int = 1) -> None: progress.update(amount) progress.set_postfix_str(f"row={row_index}", refresh=False)