"""EDA and run reporting."""
from __future__ import annotations
import os
import tempfile
from pathlib import Path
from typing import Any
os.environ.setdefault("MPLCONFIGDIR", str(Path(tempfile.gettempdir()) / "matplotlib"))
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from cycler import cycler
from ..io import GalaxyObservation, ensure_dir, write_json
from ..model import ModelResult, comparison_rows
from ..semantics import active_parameters, is_comparable_fit_parameter
configure_plot_style()
[docs]
def write_eda_outputs(
    df: pd.DataFrame,
    band_configs: list[dict[str, Any]],
    out_dir: str | Path,
    redshift_config: dict[str, Any] | None = None,
) -> None:
    """Write catalog-level EDA tables and distribution plots into ``out_dir``.

    Tables written: ``catalog_schema.json`` (column name and dtype),
    ``catalog_stats.csv`` (transposed ``describe(include="all")``) and
    ``missing_values.csv`` (missing count and fraction per column).

    Flux and color plots are produced only for band configs whose ``"column"``
    exists in ``df``; the redshift plot only when ``redshift_config`` is truthy.
    The physical-parameter plot is always requested.
    """
    target = ensure_dir(out_dir)

    write_json(
        target / "catalog_schema.json",
        [{"name": name, "dtype": str(kind)} for name, kind in df.dtypes.items()],
    )

    stats = df.describe(include="all").transpose()
    stats.to_csv(target / "catalog_stats.csv")

    # max(..., 1) keeps the fraction defined for an empty catalog.
    n_rows = max(len(df), 1)
    missing = df.isna().sum().rename("missing_count").to_frame()
    missing["missing_fraction"] = missing["missing_count"] / n_rows
    missing.to_csv(target / "missing_values.csv")

    present_bands = [
        cfg["column"] for cfg in band_configs if cfg["column"] in df.columns
    ]
    if present_bands:
        plot_flux_distributions(df, present_bands, target / "flux_distributions.png")
        plot_color_distributions(df, present_bands, target / "color_distributions.png")

    if redshift_config:
        plot_redshift_distributions(
            df, redshift_config, target / "redshift_diagnostics.png"
        )

    plot_physical_parameters_distributions(df, target / "physical_parameters.png")
[docs]
def write_run_outputs(
    observation: GalaxyObservation,
    result: ModelResult,
    out_dir: str | Path,
    *,
    ground_truth_sed: Any | None = None,
    include_filters: bool = True,
) -> pd.DataFrame:
    """Write all single-galaxy run artifacts and return the photometry comparison.

    Written into ``out_dir``: ``selected_galaxy.json``, ``model_parameters.json``,
    ``sed.csv``, ``photometry_comparison.csv``, ``sed.png``, the
    ``sed_diagnostic*`` files (via :func:`write_sed_diagnostic_outputs`) and
    ``photometry_comparison.png``.

    Returns:
        The frame built from ``comparison_rows(observation, result)``.
    """
    target = ensure_dir(out_dir)

    write_json(target / "selected_galaxy.json", observation)
    parameter_payload = {"parameters": result.parameters, "derived": result.derived}
    write_json(target / "model_parameters.json", parameter_payload)

    sed_columns = {
        "wave_angstrom": result.wave,
        "rest_sed_lsun_per_hz": result.rest_sed,
        "dusted_rest_sed_lsun_per_hz": result.dusted_rest_sed,
    }
    pd.DataFrame(sed_columns).to_csv(target / "sed.csv", index=False)

    comparison = pd.DataFrame(comparison_rows(observation, result))
    comparison.to_csv(target / "photometry_comparison.csv", index=False)

    plot_sed(result, target / "sed.png")
    write_sed_diagnostic_outputs(
        observation,
        result,
        target,
        stem="sed_diagnostic",
        ground_truth_sed=ground_truth_sed,
        include_filters=include_filters,
    )
    plot_photometry_comparison(comparison, target / "photometry_comparison.png")
    return comparison
[docs]
def write_sed_diagnostic_outputs(
    observation: GalaxyObservation,
    result: ModelResult,
    out_dir: str | Path,
    *,
    stem: str,
    ground_truth_sed: Any | None = None,
    include_filters: bool = True,
) -> dict[str, Any]:
    """Write one rich SED diagnostic: DSPS SED, optional COSMOS proxy, filters, photometry.

    Every file shares the ``stem`` prefix inside ``out_dir``:
    ``{stem}_dsps_sed.csv``, ``{stem}_photometry.csv``, ``{stem}.png`` and,
    when a non-empty ground-truth frame is obtained,
    ``{stem}_ground_truth_sed.csv``.

    Returns:
        A manifest holding the observation's row index, the written file
        *names* (not full paths), a ``has_ground_truth_sed`` flag and any
        normalization metadata read from the first row of the ground-truth
        frame.
    """
    target = ensure_dir(out_dir)

    sed_path = target / f"{stem}_dsps_sed.csv"
    pd.DataFrame(
        {
            "wave_angstrom": result.wave,
            "rest_sed_lsun_per_hz": result.rest_sed,
            "dusted_rest_sed_lsun_per_hz": result.dusted_rest_sed,
        }
    ).to_csv(sed_path, index=False)

    metadata_keys = (
        "ground_truth_scale_factor",
        "ground_truth_normalization_bands",
        "ground_truth_norm_median_abs_rel_residual",
        "ground_truth_norm_max_abs_rel_residual",
    )
    truth_path = None
    truth_metadata: dict[str, Any] = {}
    truth = None
    if ground_truth_sed is not None:
        truth = _ground_truth_sed_frame(ground_truth_sed)
    if truth is not None and not truth.empty:
        truth_path = target / f"{stem}_ground_truth_sed.csv"
        truth.to_csv(truth_path, index=False)
        for key in metadata_keys:
            if key not in truth:
                continue
            first = truth[key].iloc[0]
            # Numeric values become plain floats; anything else passes through.
            if isinstance(first, (int, float, np.number)):
                truth_metadata[key] = float(first)
            else:
                truth_metadata[key] = first

    phot_path = target / f"{stem}_photometry.csv"
    photometry = pd.DataFrame(comparison_rows(observation, result))
    photometry.to_csv(phot_path, index=False)

    plot_path = target / f"{stem}.png"
    plot_sed_diagnostic(
        result,
        plot_path,
        observation=observation,
        ground_truth_sed=ground_truth_sed,
        include_filters=include_filters,
    )

    manifest: dict[str, Any] = {
        "row_index": int(observation.row_index),
        "plot": plot_path.name,
        "dsps_sed": sed_path.name,
        "photometry": phot_path.name,
        "ground_truth_sed": truth_path.name if truth_path else None,
        "has_ground_truth_sed": truth_path is not None,
    }
    manifest.update(truth_metadata)
    return manifest
[docs]
def write_fit_outputs(fit_result: Any, out_dir: str | Path) -> None:
    """Write the fit summary JSON, the optimizer trace CSV and the trace plot.

    ``fit_result`` must expose ``success``, ``message``, ``best_parameters``,
    ``chi2``, ``n_bands``, ``gradient_norm`` and ``trace``.
    """
    target = ensure_dir(out_dir)

    reported_fields = (
        "success",
        "message",
        "best_parameters",
        "chi2",
        "n_bands",
        "gradient_norm",
    )
    write_json(
        target / "fit_result.json",
        {field: getattr(fit_result, field) for field in reported_fields},
    )

    trace_table = pd.DataFrame(fit_result.trace)
    trace_table.to_csv(target / "fit_trace.csv", index=False)
    # The plot receives its own freshly built frame, as before.
    plot_fit_trace(pd.DataFrame(fit_result.trace), target / "fit_trace.png")
[docs]
def write_mcmc_outputs(
    mcmc_result: Any,
    out_dir: str | Path,
    truth_values: dict[str, Any] | None = None,
) -> None:
    """Write posterior tables, diagnostics and plots for one MCMC run.

    ``posterior_derived_samples.csv`` is written only when the result carries
    non-empty ``derived_samples``; ``posterior_truth_values.json`` only when
    ``truth_values`` is truthy; the truth-overlay corner plot and its CSV only
    when the comparable frame is non-empty.
    """
    target = ensure_dir(out_dir)

    samples = pd.DataFrame(mcmc_result.samples)
    samples.to_csv(target / "posterior_samples.csv", index=False)

    derived = pd.DataFrame(getattr(mcmc_result, "derived_samples", {}))
    if not derived.empty:
        derived.to_csv(target / "posterior_derived_samples.csv", index=False)

    summary = pd.DataFrame(mcmc_result.summary)
    summary.to_csv(target / "posterior_summary.csv", index=False)
    write_json(target / "mcmc_diagnostics.json", mcmc_result.diagnostics)
    if truth_values:
        write_json(target / "posterior_truth_values.json", truth_values)

    write_posterior_predictive(
        mcmc_result, target / "posterior_predictive_photometry.csv"
    )
    plot_mcmc_traces(samples, target / "posterior_trace.png")
    plot_corner(samples, target / "posterior_corner.png")

    comparable = posterior_comparable_frame(samples, derived, truth_values or {})
    if not comparable.empty:
        comparable.to_csv(target / "posterior_comparable_samples.csv", index=False)
        plot_corner_with_truth(
            comparable,
            truth_values or {},
            target / "posterior_corner_with_truth.png",
        )

    plot_posterior_predictive(
        mcmc_result, target / "posterior_predictive_photometry.png"
    )
[docs]
def write_mcmc_batch_outputs(
    summary: pd.DataFrame,
    predictive: pd.DataFrame,
    diagnostics: pd.DataFrame,
    out_dir: str | Path,
) -> None:
    """Write the batch MCMC summary JSON and the three batch-level plots.

    The JSON always holds ``n_galaxies`` and ``n_parameters`` (0 when the
    source column is missing). ``n_divergent`` and
    ``median_abs_posterior_residual_mag`` are added only when their source
    columns exist.
    """
    target = ensure_dir(out_dir)

    def _distinct(column: str) -> int:
        # Number of distinct values, or 0 when the column is absent.
        return int(summary[column].nunique()) if column in summary else 0

    payload: dict[str, Any] = {
        "n_galaxies": _distinct("row_index"),
        "n_parameters": _distinct("parameter"),
    }
    if "n_divergent" in diagnostics:
        payload["n_divergent"] = int(diagnostics["n_divergent"].sum())

    residual_column = "residual_mag_median_model_minus_observed"
    if residual_column in predictive:
        absolute_residual = predictive[residual_column].abs()
        payload["median_abs_posterior_residual_mag"] = float(absolute_residual.median())

    write_json(target / "batch_mcmc_summary.json", payload)
    plot_batch_posterior_intervals(
        summary, target / "batch_posterior_parameter_intervals.png"
    )
    plot_batch_posterior_predictive(
        predictive, target / "batch_posterior_predictive.png"
    )
    plot_batch_mcmc_diagnostics(diagnostics, target / "batch_mcmc_diagnostics.png")
[docs]
def write_population_corner_outputs(
    fits: pd.DataFrame,
    free_parameters: list[str],
    out_dir: str | Path,
    config: dict[str, Any] | None = None,
) -> None:
    """Write population-level MAP point-estimate distributions.

    The parameter CSV, its percentile summary, the corner plot and the
    histograms are always written once the parameter frame is non-empty.
    The truth tables, truth metrics and truth-overlay plots are written only
    when paired truth values exist.
    """
    target = ensure_dir(out_dir)
    params = _fit_parameter_frame(fits, free_parameters)
    if params.empty:
        return

    percentiles = [0.05, 0.16, 0.5, 0.84, 0.95]
    paired_params, paired_truth = paired_fit_truth_frames(fits, config=config)

    params.to_csv(target / "population_map_parameters.csv", index=False)
    parameter_summary = params.describe(percentiles=percentiles).transpose()
    parameter_summary.to_csv(target / "population_map_parameter_summary.csv")
    plot_corner(params, target / "population_corner_parameters.png")
    plot_population_parameter_histograms(
        params, target / "population_parameter_distributions.png"
    )

    if paired_truth.empty:
        return

    paired_truth.to_csv(target / "population_truth_parameters.csv", index=False)
    truth_summary = paired_truth.describe(percentiles=percentiles).transpose()
    truth_summary.to_csv(target / "population_truth_parameter_summary.csv")

    metrics = parameter_truth_metrics(fits, config=config)
    if not metrics.empty:
        metrics.to_csv(
            target / "population_parameter_truth_metrics.csv", index=False
        )

    plot_corner_overlay(
        paired_params,
        paired_truth,
        target / "population_corner_parameters_with_truth.png",
    )
    plot_population_parameter_histograms(
        paired_params,
        target / "population_parameter_distributions_with_truth.png",
        truth=paired_truth,
    )
[docs]
def plot_population_parameter_histograms(
    params: pd.DataFrame, path: str | Path, truth: pd.DataFrame | None = None
) -> None:
    """Draw one histogram panel per parameter column, optionally with truth overlaid.

    Infinite values are treated as missing. When both inferred and truth
    values exist for a column and span a non-zero range, both histograms share
    50 common bins; otherwise each falls back to ``bins=50``. Nothing is drawn
    for an empty ``params`` frame.
    """
    if params.empty:
        return

    def _finite(series: pd.Series) -> pd.Series:
        # Drop NaN and +/-inf so histogram ranges stay finite.
        return series.replace([np.inf, -np.inf], np.nan).dropna()

    names = list(params.columns)
    n_panels = len(names)
    fig, axes = plt.subplots(n_panels, 1, figsize=(8, max(2.2 * n_panels, 3)))

    for ax, name in zip(np.atleast_1d(axes), names, strict=True):
        inferred = _finite(params[name])
        if truth is not None and name in truth:
            reference = _finite(truth[name])
        else:
            reference = pd.Series(dtype=float)

        edges = 50
        if not (inferred.empty or reference.empty):
            pooled = pd.concat([inferred, reference])
            low, high = float(pooled.min()), float(pooled.max())
            if low < high:
                edges = np.linspace(low, high, 51)

        if not inferred.empty:
            ax.hist(
                inferred,
                bins=edges,
                histtype="stepfilled",
                alpha=0.55,
                color="#8fbcd4",
                label="inferred",
            )
            ax.axvline(inferred.median(), color="black", lw=1, label="median")
        if not reference.empty:
            ax.hist(
                reference,
                bins=edges,
                histtype="step",
                lw=1.7,
                color="#e6a0a8",
                label=_truth_legend_label(name),
            )
            ax.axvline(reference.median(), color="#9f5360", lw=1, alpha=0.8)

        ax.set_xlabel(_parameter_display_label(name))
        ax.set_ylabel("galaxies")
        ax.grid(alpha=0.2)
        ax.legend(fontsize=8)

    fig.tight_layout()
    fig.savefig(path, dpi=170)
    plt.close(fig)
[docs]
def parameter_truth_metrics(
    frame: pd.DataFrame, config: dict[str, Any] | None = None
) -> pd.DataFrame:
    """Summarize paired inferred-vs-truth parameter errors.

    Pairs considered: ``z_obs`` vs ``redshift_truth`` when both columns exist,
    plus every ``truth_<name>`` column that has a matching ``fit_<name>``
    column and passes ``is_comparable_fit_parameter(config, name)``.

    Returns:
        One row per pair with at least one finite paired value, holding bias,
        absolute-error, RMSE, spread, correlation and median statistics of
        ``fit - truth``. Correlation is NaN unless both sides have more than
        one distinct value.
    """
    candidates: list[tuple[str, str, str]] = []
    if {"z_obs", "redshift_truth"}.issubset(frame.columns):
        candidates.append(("z_obs", "z_obs", "redshift_truth"))
    for truth_col in frame.columns:
        if not truth_col.startswith("truth_"):
            continue
        name = truth_col[6:]
        fit_col = f"fit_{name}"
        if fit_col not in frame.columns:
            continue
        if not is_comparable_fit_parameter(config, name):
            continue
        candidates.append((name, fit_col, truth_col))

    records = []
    for name, fit_col, truth_col in candidates:
        paired = frame[[fit_col, truth_col]]
        paired = paired.replace([np.inf, -np.inf], np.nan).dropna()
        if paired.empty:
            continue
        inferred, expected = paired[fit_col], paired[truth_col]
        delta = inferred - expected

        # Pearson correlation is undefined for a constant series.
        if len(paired) > 1 and inferred.nunique() > 1 and expected.nunique() > 1:
            correlation = float(inferred.corr(expected))
        else:
            correlation = float("nan")

        records.append(
            {
                "parameter": name,
                "comparison_kind": _truth_kind_from_frame(frame, name),
                "fit_column": fit_col,
                "truth_column": truth_col,
                "n": int(len(paired)),
                "bias_mean": float(delta.mean()),
                "bias_median": float(delta.median()),
                "mae": float(delta.abs().mean()),
                "median_abs_error": float(delta.abs().median()),
                "rmse": float(np.sqrt(np.mean(delta**2))),
                "std_delta": float(delta.std()),
                "correlation": correlation,
                "truth_median": float(expected.median()),
                "inferred_median": float(inferred.median()),
            }
        )
    return pd.DataFrame(records)
[docs]
def write_trace_truth_outputs(
    trace: pd.DataFrame, out_dir: str | Path, label: str, make_plots: bool = True
) -> None:
    """Write the truth-diagnostic trace summary CSV and, optionally, its plot.

    Does nothing, and does not create ``out_dir``, when the trace is empty,
    lacks a ``truth_mse`` column, or has no finite ``truth_mse`` value.
    """
    if trace.empty or "truth_mse" not in trace:
        return
    finite_mse = trace["truth_mse"].replace([np.inf, -np.inf], np.nan).dropna()
    if finite_mse.empty:
        return

    target = ensure_dir(out_dir)
    summary = trace_truth_summary(trace)
    summary.to_csv(target / f"{label}_trace_truth_summary.csv", index=False)
    if not make_plots:
        return
    plot_trace_truth_metrics(trace, target / f"{label}_trace_truth.png")
[docs]
def trace_truth_summary(trace: pd.DataFrame) -> pd.DataFrame:
    """Summarize final and best truth-diagnostic values, per chunk when available.

    Metric columns are ``truth_mse``, ``truth_rmse``, ``truth_mae`` and any
    column whose name starts with ``truth_mse_``. For each metric with at
    least one finite value the summary holds ``final_<metric>`` (last value
    after sorting by ``iteration`` when that column exists), ``min_<metric>``
    and ``final_minus_min_<metric>``.

    When ``chunk_index`` is present, one row is produced per chunk. Rows whose
    ``chunk_index`` is missing are kept as their own group and reported with a
    NaN ``chunk_index`` instead of raising on the integer conversion.

    Returns:
        An empty frame when no metric column exists, otherwise one row per
        group.
    """
    base_metrics = {"truth_mse", "truth_rmse", "truth_mae"}
    metric_columns = [
        col
        for col in trace.columns
        if col in base_metrics
        or (isinstance(col, str) and col.startswith("truth_mse_"))
    ]
    if not metric_columns:
        return pd.DataFrame()

    by_chunk = "chunk_index" in trace.columns
    # dropna=False keeps rows with a missing chunk index as their own group.
    groups = (
        trace.groupby("chunk_index", dropna=False) if by_chunk else [(None, trace)]
    )

    rows = []
    for key, group in groups:
        row: dict[str, float | int] = {}
        if by_chunk:
            if isinstance(key, tuple):
                key = key[0]
            # int(NaN) raises, so a missing chunk index is reported as NaN.
            row["chunk_index"] = int(key) if pd.notna(key) else float("nan")
        if "iteration" in group:
            group = group.sort_values("iteration")
        for col in metric_columns:
            values = group[col].replace([np.inf, -np.inf], np.nan).dropna()
            if values.empty:
                continue
            row[f"final_{col}"] = float(values.iloc[-1])
            row[f"min_{col}"] = float(values.min())
            row[f"final_minus_min_{col}"] = float(values.iloc[-1] - values.min())
        if row:
            rows.append(row)
    return pd.DataFrame(rows)
[docs]
def plot_trace_truth_metrics(trace: pd.DataFrame, path: str | Path) -> None:
    """Plot the optimized loss next to truth/proxy RMSE diagnostics.

    Panels, left to right: the loss column (``mean_chi2_or_loss``, else
    ``chi2``, else a blank panel), the overall ``truth_rmse``, and, when
    ``truth_mse_<parameter>`` columns exist, the per-parameter RMSE obtained
    as the square root of the per-iteration mean MSE.

    Nothing is written when the trace is empty, lacks ``truth_rmse``, or has
    no finite ``truth_rmse`` value.
    """
    if trace.empty or "truth_rmse" not in trace:
        return
    finite_rmse = trace["truth_rmse"].replace([np.inf, -np.inf], np.nan).dropna()
    if finite_rmse.empty:
        return

    loss_candidates = [
        name for name in ("mean_chi2_or_loss", "chi2") if name in trace
    ]
    loss_col = loss_candidates[0] if loss_candidates else None
    per_parameter = [col for col in trace.columns if col.startswith("truth_mse_")]

    n_panels = 3 if per_parameter else 2
    fig, axes = plt.subplots(1, n_panels, figsize=(5 * n_panels, 4))
    panels = np.atleast_1d(axes)
    loss_ax, rmse_ax = panels[0], panels[1]

    if loss_col:
        _plot_trace_metric(trace, loss_col, loss_ax, label=loss_col, logy=True)
        loss_ax.set_ylabel(loss_col)
    else:
        loss_ax.axis("off")

    _plot_trace_metric(
        trace,
        "truth_rmse",
        rmse_ax,
        label="configured truth/proxy RMSE",
        logy=True,
    )
    rmse_ax.set_ylabel("diagnostic RMSE")
    rmse_ax.set_title("not optimized")

    if per_parameter:
        detail_ax = panels[2]
        means = _trace_group_mean(trace, per_parameter)
        steps = means["iteration"] if "iteration" in means else np.arange(len(means))
        for col in per_parameter:
            mse = means[col].replace([np.inf, -np.inf], np.nan).to_numpy(dtype=float)
            rmse = np.sqrt(mse)
            if not np.isfinite(rmse).any():
                continue
            detail_ax.plot(
                steps,
                rmse,
                lw=1.4,
                label=_parameter_display_label(col.removeprefix("truth_mse_")),
            )
        detail_ax.set_xlabel("iteration")
        detail_ax.set_ylabel("per-parameter diagnostic RMSE")
        detail_ax.set_title("photometry objective can diverge from truth")
        detail_ax.set_yscale("log")
        detail_ax.grid(alpha=0.25)
        detail_ax.legend(fontsize=8)

    fig.tight_layout()
    fig.savefig(path, dpi=170)
    plt.close(fig)
def _plot_trace_metric(
    trace: pd.DataFrame, column: str, ax: plt.Axes, label: str, logy: bool = False
) -> None:
    """Draw one trace metric on ``ax``: faint per-chunk lines plus the mean line.

    Args:
        trace: Optimizer trace; ``chunk_index`` and ``iteration`` are optional.
        column: Metric column to plot.
        ax: Target axes.
        label: Legend label of the mean line.
        logy: Request a log y-axis; applied only if the mean line has at
            least one positive value.

    Without an ``iteration`` column both the per-chunk lines and the mean line
    use positional indices on the x-axis. Previously the per-chunk branch
    raised ``KeyError`` in that case while the mean line already handled it.
    """
    has_iteration = "iteration" in trace
    if "chunk_index" in trace:
        for _, group in trace.groupby("chunk_index"):
            if has_iteration:
                group = group.sort_values("iteration")
                chunk_x = group["iteration"]
            else:
                # Same positional fallback the mean line uses below.
                chunk_x = np.arange(len(group))
            ax.plot(
                chunk_x,
                group[column],
                color="#8fbcd4",
                alpha=0.18,
                lw=0.8,
            )
    grouped = _trace_group_mean(trace, [column])
    x = grouped["iteration"] if "iteration" in grouped else np.arange(len(grouped))
    ax.plot(x, grouped[column], color="black", lw=1.6, label=label)
    ax.set_xlabel("iteration")
    # A log axis needs at least one positive value to be meaningful.
    if logy and (grouped[column].dropna() > 0).any():
        ax.set_yscale("log")
    ax.grid(alpha=0.25)
    ax.legend(fontsize=8)
def _trace_group_mean(trace: pd.DataFrame, columns: list[str]) -> pd.DataFrame:
    """Average the requested columns per ``iteration``.

    Columns missing from ``trace`` are ignored. With an ``iteration`` column,
    infinities are replaced by NaN and the numeric mean per iteration is
    returned with ``iteration`` kept as a regular column. Without it, the
    selected columns are returned unchanged apart from a reset index.
    """
    present = [name for name in columns if name in trace]
    if "iteration" in trace:
        selected = trace[["iteration", *present]]
        cleaned = selected.replace([np.inf, -np.inf], np.nan)
        per_iteration = cleaned.groupby("iteration", as_index=False)
        return per_iteration.mean(numeric_only=True)
    return trace[present].reset_index(drop=True)
[docs]
def write_batch_outputs(
    comparison: pd.DataFrame,
    out_dir: str | Path,
    label: str = "batch",
    reporting_level: str = "full",
    config: dict[str, Any] | None = None,
) -> None:
    """Write aggregate tables and plots for multi-galaxy runs.

    Rows with a non-null ``error`` go to ``{label}_errors.csv``. Rows with a
    non-null ``band`` are the valid comparisons; when none exist only a
    minimal ``{label}_summary.json`` is written. Otherwise per-band and
    per-galaxy summaries, optional truth/residual/attractor tables, the
    summary JSON and the full plot set are produced.

    ``reporting_level`` is accepted but not consulted by this function.
    """
    target = ensure_dir(out_dir)

    if "error" in comparison:
        failed = comparison[comparison["error"].notna()]
    else:
        failed = pd.DataFrame()
    if "band" in comparison:
        valid = comparison[comparison["band"].notna()].copy()
    else:
        valid = pd.DataFrame()

    if not failed.empty:
        failed.to_csv(target / f"{label}_errors.csv", index=False)
    if valid.empty:
        write_json(
            target / f"{label}_summary.json",
            {"n_valid_rows": 0, "n_errors": int(len(failed))},
        )
        return

    by_band = summarize_by_band(valid)
    by_row = summarize_by_row(valid)
    by_band.to_csv(target / f"{label}_summary_by_band.csv")
    by_row.to_csv(target / f"{label}_summary_by_galaxy.csv")

    absolute_residual = valid["residual_mag_model_minus_observed"].abs()
    summary = {
        "n_valid_comparisons": int(len(valid)),
        "n_valid_galaxies": int(by_row.shape[0]),
        "n_error_rows": int(len(failed)),
        "median_chi2": float(by_row["chi2"].median()),
        "median_reduced_chi2": float(by_row["reduced_chi2"].median()),
        "median_abs_residual_mag": float(absolute_residual.median()),
    }
    if "delta_z_obs_minus_truth" in by_row:
        delta_z = by_row["delta_z_obs_minus_truth"].dropna()
        if not delta_z.empty:
            center = delta_z.median()
            summary["median_delta_z_obs_minus_truth"] = float(center)
            # Median absolute deviation about the median.
            summary["mad_delta_z_obs_minus_truth"] = float(
                (delta_z - delta_z.median()).abs().median()
            )

    truth_metrics = parameter_truth_metrics(by_row.reset_index(), config=config)
    if not truth_metrics.empty:
        truth_metrics.to_csv(target / f"{label}_truth_metrics.csv", index=False)

    residual_properties = residuals_by_property(by_row.reset_index())
    if not residual_properties.empty:
        residual_properties.to_csv(
            target / f"{label}_residuals_by_property.csv", index=False
        )

    attractors = redshift_attractor_summary(by_row.reset_index())
    if not attractors.empty:
        attractors.to_csv(target / f"{label}_redshift_attractors.csv", index=False)

    write_json(target / f"{label}_summary.json", summary)

    plot_batch_dashboard(valid, by_row, target / f"{label}_dashboard.png")
    plot_batch_residuals_by_band(valid, target / f"{label}_residuals_by_band.png")
    plot_batch_observed_vs_model(valid, target / f"{label}_observed_vs_model.png")
    plot_batch_redshift_truth(by_row, target / f"{label}_redshift_truth.png")
    plot_redshift_attractors(
        by_row.reset_index(), attractors, target / f"{label}_redshift_attractors.png"
    )
    plot_batch_parameter_truth(
        by_row, target / f"{label}_parameter_truth.png", config
    )
    plot_residuals_by_property(
        by_row.reset_index(), target / f"{label}_residuals_by_property.png"
    )
    plot_color_redshift_diagnostics(
        valid, by_row, target / f"{label}_color_redshift.png"
    )
    plot_physical_population_diagnostics(
        by_row.reset_index(), target / f"{label}_physical_population.png"
    )
    plot_population_bias_heatmap(
        by_row.reset_index(), target / f"{label}_bias_heatmap.png"
    )
[docs]
def write_fit_diagnostic_outputs(
    fits: pd.DataFrame,
    comparison: pd.DataFrame,
    config: dict[str, Any],
    out_dir: str | Path,
    label: str = "batch_fit",
    hyperparameters: pd.DataFrame | None = None,
) -> None:
    """Write fit audit tables that protect scientific interpretation.

    Produces ``{label}_parameter_audit.csv`` and
    ``{label}_objective_components.csv``; each file is skipped when its table
    is empty.
    """
    target = ensure_dir(out_dir)

    audit = fit_parameter_audit(fits, config)
    if not audit.empty:
        audit.to_csv(target / f"{label}_parameter_audit.csv", index=False)

    components = fit_objective_components(fits, comparison, config, hyperparameters)
    if not components.empty:
        components.to_csv(target / f"{label}_objective_components.csv", index=False)
[docs]
def fit_parameter_audit(fits: pd.DataFrame, config: dict[str, Any]) -> pd.DataFrame:
    """Summarize whether reported fit columns were truly inferred.

    One row is produced per ``fit_*`` column (sorted by column name). Each
    row classifies the parameter's source (free, fixed, row-injected, derived,
    ...), lists warning flags, and carries summary statistics of the finite
    values. For free parameters with a two-element ``bounds`` entry, the
    fraction of values within 1% of the bound span of each bound is also
    reported.

    Returns:
        An empty frame for empty ``fits``, otherwise the audit table.
    """
    if fits.empty:
        return pd.DataFrame()

    free = config.get("fit", {}).get("free_parameters", {}) or {}
    model_config = config.get("model", {})
    fixed = model_config.get("fixed_parameters", {}) or {}
    injected = model_config.get("parameter_columns", {}) or {}
    active = set(active_parameters(config))
    derived = {
        "t_obs_gyr",
        "formed_mass_msun",
        "log10_formed_mass_msun",
        "sfr_at_obs_msun_per_yr",
        "log10_sfr_at_obs",
    }

    fit_columns = sorted(col for col in fits.columns if col.startswith("fit_"))
    records = []
    for column in fit_columns:
        name = column[4:]
        numeric = pd.to_numeric(fits[column], errors="coerce")
        finite = numeric.replace([np.inf, -np.inf], np.nan).dropna()
        source = _fit_parameter_source(name, free, fixed, injected, derived, active)
        warning_flags = _fit_parameter_warning_flags(name, source, finite, free)
        is_free = name in free

        record: dict[str, Any] = {
            "parameter": name,
            "fit_column": column,
            "source": source,
            "is_free": bool(is_free),
            "is_fixed": bool(name in fixed and not is_free),
            "is_row_injected": bool(name in injected and not is_free),
            "is_derived": bool(name in derived),
            "active_in_forward_model": bool(name in active),
            "n": int(len(finite)),
            "n_unique": 0 if finite.empty else int(finite.nunique()),
            "warning_flags": ",".join(warning_flags),
        }

        if not finite.empty:
            record["min"] = float(finite.min())
            record["p16"] = float(finite.quantile(0.16))
            record["median"] = float(finite.median())
            record["p84"] = float(finite.quantile(0.84))
            record["max"] = float(finite.max())
            # A single value has no sample spread; report 0.0 rather than NaN.
            record["std"] = float(finite.std()) if len(finite) > 1 else 0.0

        if is_free:
            spec = free[name] or {}
            record["initial"] = spec.get("initial")
            bounds = spec.get("bounds")
            if isinstance(bounds, (list, tuple)) and len(bounds) == 2:
                low, high = float(bounds[0]), float(bounds[1])
                span = high - low
                record["lower_bound"] = low
                record["upper_bound"] = high
                if not finite.empty and span > 0:
                    tolerance = 0.01 * span
                    near_low = (finite - low).abs() <= tolerance
                    near_high = (high - finite).abs() <= tolerance
                    record["fraction_near_lower_1pct"] = float(near_low.mean())
                    record["fraction_near_upper_1pct"] = float(near_high.mean())

        if name in injected:
            record["source_column"] = injected[name]

        records.append(record)
    return pd.DataFrame(records)
[docs]
def fit_objective_components(
    fits: pd.DataFrame,
    comparison: pd.DataFrame,
    config: dict[str, Any],
    hyperparameters: pd.DataFrame | None = None,
) -> pd.DataFrame:
    """Post-hoc objective decomposition from saved MAP rows.

    Per fit row the table holds the photometric chi2, the physical prior
    penalties (gaussian and beta), the population prior penalties (gaussian
    and relation, zero when no hyperparameters are supplied), their two sums,
    and ``approx_objective = 0.5 * chi2 + physical + population``.

    Returns:
        An empty frame for empty ``fits``, otherwise one row per fit row.
    """
    if fits.empty:
        return pd.DataFrame()

    source = fits.copy()
    table = pd.DataFrame({"row_index": source["row_index"].astype(int)})
    if "chunk_index" in source:
        table["chunk_index"] = source["chunk_index"]
    table["photometric_chi2"] = _photometric_chi2_by_row(source, comparison)

    for column, penalty in _physical_prior_components(source, config).items():
        table[column] = penalty

    has_hyperparameters = hyperparameters is not None and not hyperparameters.empty
    if has_hyperparameters:
        population = _population_prior_components(source, hyperparameters)
        for column, penalty in population.items():
            table[column] = penalty

    # Any component that was not produced above contributes nothing.
    penalty_columns = (
        "physical_gaussian_prior_penalty",
        "physical_beta_prior_penalty",
        "population_gaussian_prior_penalty",
        "population_relation_prior_penalty",
    )
    for column in penalty_columns:
        if column not in table:
            table[column] = 0.0

    table["physical_prior_penalty"] = (
        table["physical_gaussian_prior_penalty"]
        + table["physical_beta_prior_penalty"]
    )
    table["population_prior_penalty"] = (
        table["population_gaussian_prior_penalty"]
        + table["population_relation_prior_penalty"]
    )
    table["approx_objective"] = (
        0.5 * table["photometric_chi2"]
        + table["physical_prior_penalty"]
        + table["population_prior_penalty"]
    )
    return table
def _fit_parameter_source(
    name: str,
    free: dict[str, Any],
    fixed: dict[str, Any],
    injected: dict[str, Any],
    derived: set[str],
    active: set[str],
) -> str:
    """Classify where a reported fit column's values came from.

    Precedence: free, derived, row-injected, fixed, ``*_prior_sigma`` context,
    and finally ``"reported_not_configured"``. Row-injected and fixed
    parameters get an ``inactive_`` variant when ``name`` is not in ``active``.
    """
    is_active = name in active
    if name in free:
        source = "free"
    elif name in derived:
        source = "derived"
    elif name in injected:
        source = "inactive_row_injected" if not is_active else "row_injected"
    elif name in fixed:
        source = "inactive_fixed" if not is_active else "fixed"
    elif name.endswith("_prior_sigma"):
        source = "prior_context"
    else:
        source = "reported_not_configured"
    return source
def _fit_parameter_warning_flags(
    name: str, source: str, finite: pd.Series, free: dict[str, Any]
) -> list[str]:
    """Collect audit warnings for one reported fit column.

    Flags, in order of appearance:
      * ``not_inferred_column`` - source is neither free nor derived.
      * ``constant_free_parameter`` - a free parameter with more than one
        value but at most one distinct value.
      * ``near_lower_bound_population`` / ``near_upper_bound_population`` -
        more than 10% of the values lie within 1% of the bound span of that
        bound (needs a two-element ``bounds`` entry with positive span).
    """
    flags: list[str] = []
    if source not in {"free", "derived"}:
        flags.append("not_inferred_column")
    if source == "free" and len(finite) > 1 and finite.nunique() <= 1:
        flags.append("constant_free_parameter")

    bounds = (free.get(name) or {}).get("bounds")
    has_bounds = isinstance(bounds, (list, tuple)) and len(bounds) == 2
    if not has_bounds or finite.empty:
        return flags

    low, high = float(bounds[0]), float(bounds[1])
    span = high - low
    if not span > 0:
        return flags

    tolerance = 0.01 * span
    bound_checks = (
        ("near_lower_bound_population", finite - low),
        ("near_upper_bound_population", high - finite),
    )
    for flag, distance in bound_checks:
        if (distance.abs() <= tolerance).mean() > 0.1:
            flags.append(flag)
    return flags
def _photometric_chi2_by_row(fits: pd.DataFrame, comparison: pd.DataFrame) -> pd.Series:
    """Return one photometric chi2 per fit row, always as finite-or-zero floats.

    The preferred source is the sum of squared per-band chi values in
    ``comparison`` (``chi_flux`` when present, otherwise ``chi``), matched to
    fit rows through ``row_index``. Rows without a comparison entry fall back
    to the ``chi2`` column of ``fits``; anything still unknown becomes 0.0.

    The fallback column is coerced to numeric first, so a non-numeric or
    missing ``chi2`` entry yields 0.0 in every branch instead of raising or
    leaking NaN from the comparison branch only.
    """
    if "chi2" in fits:
        fallback = pd.to_numeric(fits["chi2"], errors="coerce")
    else:
        fallback = pd.Series(np.zeros(len(fits)), index=fits.index)

    chi_column = "chi_flux" if "chi_flux" in comparison.columns else "chi"
    if not comparison.empty and {"row_index", chi_column}.issubset(comparison.columns):
        squared = pd.to_numeric(comparison[chi_column], errors="coerce") ** 2
        per_row = squared.groupby(comparison["row_index"]).sum()
        mapped = fits["row_index"].map(per_row)
        return mapped.fillna(fallback).fillna(0.0).astype(float)
    return fallback.fillna(0.0).astype(float)
def _physical_prior_components(
    fits: pd.DataFrame, config: dict[str, Any]
) -> dict[str, pd.Series]:
    """Evaluate the configured physical prior penalties for every fit row.

    Only priors whose parameter is also listed under ``fit.free_parameters``
    and has a value column in ``fits`` contribute.

    * ``normal`` / ``truncated_normal`` (the default type) add
      ``0.5 * ((x - loc) / scale) ** 2 + log(scale)`` to the gaussian penalty.
    * ``scaled_beta`` rescales the value to the parameter bounds (default
      ``[0, 1]``), clips it away from 0 and 1, and adds the negative
      unnormalized beta log-density to the beta penalty.

    A prior or free-parameter spec that is ``None`` is treated as an empty
    mapping throughout, so such an entry falls back to defaults instead of
    raising ``AttributeError``.

    Returns:
        ``physical_gaussian_prior_penalty`` and ``physical_beta_prior_penalty``
        series indexed like ``fits``.
    """
    fit_config = config.get("fit", {}) or {}
    priors = fit_config.get("priors", {}) or {}
    free = fit_config.get("free_parameters", {}) or {}
    gaussian = pd.Series(np.zeros(len(fits)), index=fits.index, dtype=float)
    beta = pd.Series(np.zeros(len(fits)), index=fits.index, dtype=float)

    for name, raw_spec in priors.items():
        if name not in free:
            continue
        values = _fit_values(fits, name)
        if values is None:
            continue
        # Normalize once so the helpers below never see None.
        spec = raw_spec or {}
        prior_type = str(spec.get("type", "normal"))
        if prior_type in {"normal", "truncated_normal"}:
            loc = _prior_loc(fits, name, spec)
            scale = _prior_scale(fits, name, spec)
            gaussian = gaussian + 0.5 * ((values - loc) / scale) ** 2 + np.log(scale)
        elif prior_type == "scaled_beta":
            bounds = (free.get(name) or {}).get("bounds") or [0.0, 1.0]
            low, high = float(bounds[0]), float(bounds[1])
            # Clip away from 0 and 1 so both logarithms stay finite.
            scaled = ((values - low) / max(high - low, 1.0e-12)).clip(
                1.0e-6, 1 - 1.0e-6
            )
            alpha = max(float(spec.get("alpha", 1.0)), 1.0e-6)
            beta_param = max(float(spec.get("beta", 1.0)), 1.0e-6)
            beta = beta - (
                (alpha - 1.0) * np.log(scaled) + (beta_param - 1.0) * np.log1p(-scaled)
            )
    return {
        "physical_gaussian_prior_penalty": gaussian,
        "physical_beta_prior_penalty": beta,
    }
def _population_prior_components(
    fits: pd.DataFrame, hyperparameters: pd.DataFrame
) -> dict[str, pd.Series]:
    """Evaluate population-level prior penalties per fit row.

    Each hyperparameter row applies to the fit rows whose ``chunk_index``
    equals the row's ``chunk_index`` (0 when either side lacks the column).

    * ``kind == "gaussian"`` adds ``0.5 * ((x - mu) / sigma) ** 2 + log(sigma)``
      unless the parameter is the target of any relation row.
    * ``kind == "relation"`` adds the same form with
      ``loc = intercept + slope * (predictor - pivot)``.

    A hyperparameter table without a ``kind`` or ``target_parameter`` column
    (for example a gaussian-only table) is treated as having no relation
    targets instead of raising during the lookup.

    Returns:
        ``population_gaussian_prior_penalty`` and
        ``population_relation_prior_penalty`` series indexed like ``fits``.
    """
    gaussian = pd.Series(np.zeros(len(fits)), index=fits.index, dtype=float)
    relation = pd.Series(np.zeros(len(fits)), index=fits.index, dtype=float)

    # Parameters governed by a relation must not also receive a gaussian term.
    if {"kind", "target_parameter"}.issubset(hyperparameters.columns):
        is_relation = hyperparameters["kind"].eq("relation")
        relation_targets = set(
            hyperparameters.loc[is_relation, "target_parameter"].dropna().astype(str)
        )
    else:
        relation_targets = set()

    chunk_values = (
        fits["chunk_index"] if "chunk_index" in fits else pd.Series(0, index=fits.index)
    )
    for _, row in hyperparameters.iterrows():
        chunk_mask = chunk_values.eq(row.get("chunk_index", 0))
        if not chunk_mask.any():
            continue
        kind = row.get("kind")
        if kind == "gaussian":
            parameter = str(row.get("parameter"))
            if parameter in relation_targets:
                continue
            values = _fit_values(fits.loc[chunk_mask], parameter)
            if values is None:
                continue
            mu = float(row.get("population_mu", 0.0))
            sigma = max(float(row.get("population_sigma", 1.0)), 1.0e-6)
            gaussian.loc[chunk_mask] += 0.5 * ((values - mu) / sigma) ** 2 + np.log(
                sigma
            )
        elif kind == "relation":
            target = str(row.get("target_parameter"))
            predictor = str(row.get("predictor_parameter"))
            target_values = _fit_values(fits.loc[chunk_mask], target)
            predictor_values = _fit_values(fits.loc[chunk_mask], predictor)
            if target_values is None or predictor_values is None:
                continue
            pivot = float(row.get("population_pivot", 0.0))
            intercept = float(row.get("population_intercept", 0.0))
            slope = float(row.get("population_slope", 0.0))
            sigma = max(float(row.get("population_sigma", 1.0)), 1.0e-6)
            loc = intercept + slope * (predictor_values - pivot)
            relation.loc[chunk_mask] += 0.5 * (
                (target_values - loc) / sigma
            ) ** 2 + np.log(sigma)
    return {
        "population_gaussian_prior_penalty": gaussian,
        "population_relation_prior_penalty": relation,
    }
def _fit_values(fits: pd.DataFrame, parameter: str) -> pd.Series | None:
    """Return the numeric values stored for ``parameter``, or ``None``.

    The first existing column among ``fit_<parameter>``, ``param_<parameter>``
    and ``<parameter>`` is used. Non-numeric entries and infinities become NaN.
    """
    candidates = (f"fit_{parameter}", f"param_{parameter}", parameter)
    matches = [column for column in candidates if column in fits]
    if not matches:
        return None
    numeric = pd.to_numeric(fits[matches[0]], errors="coerce")
    return numeric.replace([np.inf, -np.inf], np.nan)
def _prior_loc(fits: pd.DataFrame, name: str, spec: dict[str, Any]) -> pd.Series:
    """Return the prior location for every fit row.

    A numeric ``loc`` (default 0.0) is broadcast over ``fits.index``. With
    ``loc == "from_base"`` the stored values of ``name`` are returned, or
    zeros when no value column exists.
    """
    loc = spec.get("loc", 0.0)
    if loc != "from_base":
        return pd.Series(float(loc), index=fits.index)
    stored = _fit_values(fits, name)
    if stored is None:
        return pd.Series(np.zeros(len(fits)), index=fits.index)
    return stored
def _prior_scale(fits: pd.DataFrame, name: str, spec: dict[str, Any]) -> pd.Series:
    """Return the prior scale for every fit row, floored at 1e-6.

    A numeric ``scale`` (default 1.0) is broadcast over ``fits.index``. With
    ``scale == "from_base"`` the scale is read from the column named by
    ``scale_parameter`` (default ``<name>_prior_sigma``); missing entries
    become 1.0, and a missing column yields all ones.
    """
    scale = spec.get("scale", 1.0)
    if scale != "from_base":
        return pd.Series(max(float(scale), 1.0e-6), index=fits.index)
    sigma_name = str(spec.get("scale_parameter", f"{name}_prior_sigma"))
    sigma = _fit_values(fits, sigma_name)
    if sigma is None:
        return pd.Series(np.ones(len(fits)), index=fits.index)
    return sigma.fillna(1.0).clip(lower=1.0e-6)
[docs]
def plot_population_bias_heatmap(by_row: pd.DataFrame, path: str | Path) -> None:
    """Plot a heatmap of reduced chi2 in the Redshift-Mass plane.

    Redshift comes from ``redshift_truth`` (else ``z_obs``) and mass from
    ``catalog_log_stellar_mass`` (else ``fit_log10_formed_mass_msun``). Each
    axis is split into 11 equal-width bins between its min and max, and each
    cell shows the median ``reduced_chi2`` on a log color scale clipped to
    [0.1, 10]; empty cells are left blank.

    The grid is always 11 x 11 regardless of which bins are populated.
    Previously unpopulated bins were dropped from the pivot, the array no
    longer matched the bin edges, and the resulting error was swallowed, so no
    plot was written. The minimum value on each axis is now binned too.

    Nothing is written when a required column is missing, fewer than 10
    finite rows remain, or either axis has zero range. Plot construction stays
    best-effort, but a failure now emits a ``RuntimeWarning`` instead of
    passing silently.
    """
    import warnings  # only needed on the best-effort failure path

    z_col = "redshift_truth" if "redshift_truth" in by_row else "z_obs"
    m_col = (
        "catalog_log_stellar_mass"
        if "catalog_log_stellar_mass" in by_row
        else "fit_log10_formed_mass_msun"
    )
    if z_col not in by_row or m_col not in by_row or "reduced_chi2" not in by_row:
        return
    work = (
        by_row[[z_col, m_col, "reduced_chi2"]]
        .replace([np.inf, -np.inf], np.nan)
        .dropna()
    )
    if len(work) < 10:
        return

    n_bins = 11  # 12 edges per axis
    fig = None
    try:
        z = work[z_col].to_numpy(dtype=float)
        mass = work[m_col].to_numpy(dtype=float)
        chi2 = work["reduced_chi2"].to_numpy(dtype=float)
        z_bins = np.linspace(z.min(), z.max(), n_bins + 1)
        m_bins = np.linspace(mass.min(), mass.max(), n_bins + 1)
        if not (z_bins[-1] > z_bins[0] and m_bins[-1] > m_bins[0]):
            # A zero-width axis cannot be binned.
            return

        # Left-closed bins; clipping puts the maximum into the last bin.
        z_idx = np.clip(np.searchsorted(z_bins, z, side="right") - 1, 0, n_bins - 1)
        m_idx = np.clip(np.searchsorted(m_bins, mass, side="right") - 1, 0, n_bins - 1)

        # Rows are redshift bins, columns are mass bins; NaN marks empty cells.
        grid = np.full((n_bins, n_bins), np.nan)
        medians = pd.Series(chi2).groupby([z_idx, m_idx]).median()
        for (zi, mi), value in medians.items():
            grid[zi, mi] = value

        fig, ax = plt.subplots(figsize=(8, 6))
        im = ax.pcolormesh(
            m_bins,
            z_bins,
            np.ma.masked_invalid(grid),
            cmap="viridis",
            shading="flat",
            norm=matplotlib.colors.LogNorm(vmin=0.1, vmax=10.0),
        )
        fig.colorbar(im, ax=ax, label="median reduced chi2")
        ax.set_xlabel(_parameter_display_label(m_col))
        ax.set_ylabel(_parameter_display_label(z_col))
        ax.set_title("Population Fit Quality Map")
    except Exception as exc:
        # Best-effort diagnostic: skip the plot but make the failure visible.
        warnings.warn(
            f"population bias heatmap skipped: {exc}", RuntimeWarning, stacklevel=2
        )
        return
    else:
        fig.tight_layout()
        fig.savefig(path, dpi=170)
    finally:
        # Close the figure on every path, including a failed save.
        if fig is not None:
            plt.close(fig)
[docs]
def plot_color_redshift_diagnostics(
    comparison: pd.DataFrame, by_row: pd.DataFrame, path: str | Path
) -> None:
    """Plot broad-band color-redshift diagnostics, POP-COSMOS style.

    Observed magnitudes are pivoted to one row per ``row_index`` (median per
    band) and a fixed list of adjacent-band colors is plotted against
    ``redshift_truth`` (else ``z_obs``) looked up from ``by_row``. A panel
    needs at least 5 finite points; panels with more than 250 points use a
    hexbin instead of a scatter. No file is written when no panel is drawn.
    """
    z_col = "redshift_truth" if "redshift_truth" in by_row else "z_obs"
    if comparison.empty or z_col not in by_row or "observed_mag_ab" not in comparison:
        return

    mags = comparison.pivot_table(
        index="row_index", columns="band", values="observed_mag_ab", aggfunc="median"
    )
    redshift = by_row[z_col].reindex(mags.index)

    candidate_pairs = (
        ("lsst_u", "lsst_g"),
        ("lsst_g", "lsst_r"),
        ("lsst_r", "lsst_i"),
        ("lsst_i", "lsst_z"),
        ("lsst_z", "lsst_y"),
        ("euclid_nisp_y", "euclid_nisp_j"),
        ("euclid_nisp_j", "euclid_nisp_h"),
    )
    usable = [
        (blue, red) for blue, red in candidate_pairs if blue in mags and red in mags
    ]
    if not usable:
        return

    n_cols = min(4, len(usable))
    n_rows = int(np.ceil(len(usable) / n_cols))
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(3.0 * n_cols, 2.2 * n_rows))
    panels = np.atleast_1d(axes).ravel()

    drawn = 0
    for ax, (blue, red) in zip(panels, usable, strict=False):
        color = (mags[blue] - mags[red]).replace([np.inf, -np.inf], np.nan)
        points = pd.DataFrame({"z": redshift, "color": color}).dropna()
        if len(points) < 5:
            ax.set_visible(False)
            continue
        drawn += 1
        if len(points) > 250:
            ax.hexbin(
                points["z"], points["color"], gridsize=36, mincnt=1, cmap="viridis"
            )
        else:
            ax.scatter(points["z"], points["color"], s=8, alpha=0.5)
        ax.set_xlabel(_parameter_display_label(z_col))
        ax.set_ylabel(f"{blue}-{red} [mag]")

    # Hide the grid cells that have no color pair assigned.
    for spare in panels[len(usable) :]:
        spare.set_visible(False)

    if drawn == 0:
        plt.close(fig)
        return

    fig.suptitle("Color-redshift diagnostics", y=0.995)
    fig.tight_layout()
    fig.savefig(path, dpi=170)
    plt.close(fig)
[docs]
def plot_physical_population_diagnostics(by_row: pd.DataFrame, path: str | Path) -> None:
    """Save scatter panels of physical quantities versus mass, split by redshift bin.

    Each quantity (log SFR, log sSFR, log metallicity, dust) is read from the
    fitted column when present, otherwise from the catalog column. Redshift
    bin edges are the unique quantiles of the finite redshift values. Rows of
    the grid are quantities, columns are redshift bins, and every panel plots
    the quantity against the mass column.

    Nothing is written when ``by_row`` is empty, no mass or redshift column
    exists, fewer than eight finite redshifts or fewer than two distinct bins
    are available, or no panel has at least three points.
    """
    if by_row.empty:
        return
    z_col = "redshift_truth" if "redshift_truth" in by_row else "z_obs"
    mass_col = _first_present(
        by_row,
        ["fit_log10_formed_mass_msun", "catalog_log_stellar_mass"],
    )
    sfr_col = _first_present(by_row, ["fit_log10_sfr_at_obs", "catalog_log_sfr"])
    metallicity_col = _first_present(
        by_row, ["fit_log10_metallicity", "catalog_log10_metallicity_proxy"]
    )
    dust_col = _first_present(
        by_row, ["fit_dust_av", "catalog_dust_av_proxy", "catalog_dust_ebv_proxy"]
    )
    if z_col not in by_row or mass_col is None:
        return

    candidates = [
        ("log SFR", sfr_col),
        ("log sSFR", "__ssfr__" if sfr_col else None),
        ("log metallicity", metallicity_col),
        ("dust proxy/Av", dust_col),
    ]
    panels = [(label, column) for label, column in candidates if column is not None]
    if not panels:
        return

    frame = by_row.copy()
    if any(column == "__ssfr__" for _, column in panels):
        # sSFR is derived in log space as log SFR minus log mass.
        log_sfr = pd.to_numeric(frame[sfr_col], errors="coerce")
        log_mass = pd.to_numeric(frame[mass_col], errors="coerce")
        frame["__ssfr__"] = log_sfr - log_mass

    redshift = pd.to_numeric(frame[z_col], errors="coerce").replace(
        [np.inf, -np.inf], np.nan
    )
    finite_redshift = redshift.dropna()
    if len(finite_redshift) < 8:
        return
    probabilities = np.linspace(0.0, 1.0, min(5, len(finite_redshift)) + 1)
    edges = np.unique(np.quantile(finite_redshift, probabilities))
    if len(edges) < 3:
        return

    n_bins = len(edges) - 1
    n_panels = len(panels)
    # The masks depend only on the bin, so build them once for all rows.
    # Only the last bin includes its upper edge.
    bin_masks = []
    for bin_index in range(n_bins):
        lower, upper = edges[bin_index], edges[bin_index + 1]
        if bin_index == n_bins - 1:
            below_upper = redshift <= upper
        else:
            below_upper = redshift < upper
        bin_masks.append((redshift >= lower) & below_upper)

    fig, axes = plt.subplots(
        n_panels,
        n_bins,
        figsize=(2.7 * n_bins, 2.25 * n_panels),
        squeeze=False,
    )
    n_drawn = 0
    for panel_index, (ylabel, y_col) in enumerate(panels):
        for bin_index, in_bin in enumerate(bin_masks):
            points = (
                frame.loc[in_bin, [mass_col, y_col]]
                .replace([np.inf, -np.inf], np.nan)
                .dropna()
            )
            ax = axes[panel_index, bin_index]
            if len(points) < 3:
                ax.set_visible(False)
                continue
            n_drawn += 1
            ax.scatter(points[mass_col], points[y_col], s=7, alpha=0.35)
            if panel_index == 0:
                lo, hi = edges[bin_index], edges[bin_index + 1]
                ax.set_title(f"{lo:.2f} <= z < {hi:.2f}", fontsize=8)
            if panel_index == n_panels - 1:
                ax.set_xlabel(_parameter_display_label(mass_col))
            if bin_index == 0:
                ax.set_ylabel(ylabel)

    if not n_drawn:
        plt.close(fig)
        return
    fig.suptitle("Physical population diagnostics", y=0.995)
    fig.tight_layout()
    fig.savefig(path, dpi=170)
    plt.close(fig)
def _first_present(frame: pd.DataFrame, columns: list[str]) -> str | None:
    """Return the first name in ``columns`` that ``frame`` contains, else ``None``."""
    return next((name for name in columns if name in frame), None)
[docs]
def summarize_by_band(valid: pd.DataFrame) -> pd.DataFrame:
    """Aggregate comparison rows into one row of residual statistics per band.

    The result is indexed by ``band`` and reports the row count, the median
    effective wavelength, mean/median/std/RMS/mean-absolute magnitude
    residual, the median model-to-observed flux ratio, and the mean ``chi``.
    """
    residual = "residual_mag_model_minus_observed"
    spec = {
        "n": ("row_index", "count"),
        "effective_wavelength_angstrom": ("effective_wavelength_angstrom", "median"),
        "mean_residual_mag": (residual, "mean"),
        "median_residual_mag": (residual, "median"),
        "std_residual_mag": (residual, "std"),
        # NaN-aware reductions so a missing residual does not poison the band.
        "rms_residual_mag": (residual, lambda x: float(np.sqrt(np.nanmean(x**2)))),
        "mean_abs_residual_mag": (residual, lambda x: float(np.nanmean(np.abs(x)))),
        "median_flux_ratio": ("flux_ratio_model_over_observed", "median"),
        "mean_chi": ("chi", "mean"),
    }
    per_band = valid.groupby("band")
    return per_band.agg(**spec)
[docs]
def summarize_by_row(valid: pd.DataFrame) -> pd.DataFrame:
    """Collapse per-band comparison rows into one summary row per ``row_index``.

    For every ``row_index`` group the result holds the band count, chi-square
    (sum of squared ``chi_flux`` when that column exists, otherwise ``chi``),
    magnitude-residual statistics, and the ``"first"`` aggregate of each
    context column (redshift/bookkeeping columns plus any column prefixed
    ``param_``, ``fit_``, ``truth_``, ``delta_`` or ``catalog_``).

    ``n_valid_bands``, ``n_free_effective`` and ``dof`` are filled in when the
    input does not provide them. Chi-square ratios and
    ``delta_fit_<name>_minus_truth`` columns (for every ``truth_<name>`` that
    has a matching ``fit_<name>``) are appended as derived columns.

    A ``delta_fit_<name>_minus_truth`` column that already arrived through
    the input is kept as-is instead of being appended a second time, so the
    returned frame never carries duplicate column labels.
    """
    context_names = {
        "z_obs",
        "redshift_truth",
        "delta_z_obs_minus_truth",
        "z_obs_source",
        "redshift_truth_source",
        "n_valid_bands",
        "n_free_effective",
        "dof",
        "redshift_initial_mode",
        "redshift_prior_mode",
        "dust_parameter_active",
        "dust_parameter_inferred",
        "dust_model",
    }
    context_prefixes = ("param_", "fit_", "truth_", "delta_", "catalog_")
    context_columns = [
        col
        for col in valid.columns
        if col in context_names or col.startswith(context_prefixes)
    ]

    chi_column = "chi_flux" if "chi_flux" in valid.columns else "chi"
    residual = "residual_mag_model_minus_observed"
    aggregations: dict[str, tuple[str, Any]] = {
        "n_bands": ("band", "count"),
        "chi2": (chi_column, lambda x: float(np.nansum(x**2))),
        "mean_residual_mag": (residual, "mean"),
        "median_residual_mag": (residual, "median"),
        "rms_residual_mag": (residual, lambda x: float(np.sqrt(np.nanmean(x**2)))),
        "mean_abs_residual_mag": (residual, lambda x: float(np.nanmean(np.abs(x)))),
    }
    for col in context_columns:
        aggregations[col] = (col, "first")
    by_row = valid.groupby("row_index").agg(**aggregations)

    # Fall back to values derived from the band count when the fit bookkeeping
    # columns were not supplied with the input.
    if "n_valid_bands" not in by_row:
        by_row["n_valid_bands"] = by_row["n_bands"]
    if "n_free_effective" not in by_row:
        by_row["n_free_effective"] = 0
    if "dof" not in by_row:
        by_row["dof"] = (by_row["n_valid_bands"] - by_row["n_free_effective"]).clip(
            lower=1
        )

    # Denominators are clipped to 1 so the ratios never divide by zero.
    safe_dof = by_row["dof"].clip(lower=1)
    derived_columns = {
        "chi2_per_band": by_row["chi2"] / by_row["n_valid_bands"].clip(lower=1),
        "reduced_chi2": by_row["chi2"] / safe_dof,
        "reduced_chi2_dof": by_row["chi2"] / safe_dof,
    }
    for col in by_row.columns:
        if not col.startswith("truth_"):
            continue
        param_name = col[len("truth_") :]
        fit_col = f"fit_{param_name}"
        delta_col = f"delta_fit_{param_name}_minus_truth"
        # An upstream delta column is already in by_row via the "delta_"
        # prefix; appending it again would duplicate the label.
        if fit_col in by_row.columns and delta_col not in by_row.columns:
            derived_columns[delta_col] = by_row[fit_col] - by_row[col]
    return pd.concat([by_row, pd.DataFrame(derived_columns)], axis=1)
[docs]
def residuals_by_property(by_row: pd.DataFrame) -> pd.DataFrame:
    """Summarize row residuals against catalog/fit properties.

    Each known property column that exists in ``by_row`` is split into
    groups: ``catalog_color_kind`` by its string value, properties labelled
    ``"redshift"`` by fixed bin edges, and every other numeric property by up
    to six quantile bins. One output row per group is produced by
    ``_residual_property_row``; groups for which it returns an empty dict are
    omitted.

    Infinite values are treated as missing before binning. Quantile edges are
    therefore computed from finite values only, and a stray ``inf`` can
    neither distort the bins nor cause the property to be skipped.

    Returns:
        A frame with one row per (property, group); empty when nothing
        could be grouped.
    """
    property_specs = [
        ("redshift_truth", "redshift"),
        ("z_obs", "redshift"),
        ("z_true_gal", "truth redshift gal"),
        ("catalog_log_stellar_mass", "catalog log stellar mass"),
        ("fit_log10_formed_mass_msun", "fit log formed mass"),
        ("fit_log10_sfr_at_obs", "fit log sfr at obs"),
        ("catalog_color_kind", "color_kind"),
    ]
    redshift_bins = [0, 0.5, 1.0, 1.5, 2.5, 4.0, 6.0]

    rows: list[dict[str, Any]] = []
    for column, label in property_specs:
        if column not in by_row:
            continue
        values = by_row[column]
        if label == "color_kind":
            groups = values.astype("string")
        else:
            numeric = pd.to_numeric(values, errors="coerce").replace(
                [np.inf, -np.inf], np.nan
            )
            finite = numeric.dropna()
            # A single distinct value cannot be split into bins.
            if finite.nunique() < 2:
                continue
            try:
                # Use a mix of fixed and quantile bins for better scientific grouping
                if label == "redshift":
                    groups = pd.cut(numeric, bins=redshift_bins)
                else:
                    groups = pd.qcut(
                        numeric, q=min(6, finite.nunique()), duplicates="drop"
                    )
            except ValueError:
                continue
        work = by_row.assign(_group=groups)
        for group, subset in work.groupby("_group", dropna=True, observed=False):
            row = _residual_property_row(subset, label, column, str(group))
            if row:
                rows.append(row)
    return pd.DataFrame(rows)
[docs]
def redshift_attractor_summary(
    by_row: pd.DataFrame,
    bin_width: float = 0.05,
    min_count: int = 5,
    max_modes: int = 30,
) -> pd.DataFrame:
    """Tabulate fitted-redshift bins that collect many galaxies.

    Finite ``z_obs`` values are rounded to the nearest multiple of
    ``bin_width``; every bin holding at least ``min_count`` rows becomes one
    output row with its population, fraction of all finite rows, and
    ``z_obs`` median/min/max. When available, truth-redshift statistics,
    median/MAD of ``delta_z_obs_minus_truth``, the fraction of rows with
    ``|dz| > 0.15 (1 + z_truth)``, and medians of selected fit-quality
    metrics are added.

    Returns:
        At most ``max_modes`` rows, most populated bins first (ties ordered
        by bin value); an empty frame when nothing qualifies.
    """
    if by_row.empty or "z_obs" not in by_row:
        return pd.DataFrame()

    def as_numeric(frame: pd.DataFrame, column: str) -> pd.Series:
        return pd.to_numeric(frame[column], errors="coerce")

    def finite_only(series: pd.Series) -> pd.Series:
        return series[np.isfinite(series)]

    fitted = as_numeric(by_row, "z_obs")
    kept = by_row.loc[np.isfinite(fitted)].copy()
    if kept.empty:
        return pd.DataFrame()
    kept["_z_fit_bin"] = (fitted.loc[kept.index] / bin_width).round() * bin_width

    n_total = len(kept)
    modes: list[dict[str, Any]] = []
    for bin_value, members in kept.groupby("_z_fit_bin"):
        n_members = len(members)
        if n_members < min_count:
            continue
        member_z = as_numeric(members, "z_obs")
        mode: dict[str, Any] = {
            "z_fit_bin": float(bin_value),
            "n_galaxies": int(n_members),
            "fraction": float(n_members / n_total),
            "z_fit_median": float(member_z.median()),
            "z_fit_min": float(member_z.min()),
            "z_fit_max": float(member_z.max()),
        }

        has_truth = "redshift_truth" in members
        if has_truth:
            truth_all = as_numeric(members, "redshift_truth")
            truth = finite_only(truth_all)
            if not truth.empty:
                mode["z_truth_median"] = float(truth.median())
                mode["z_truth_min"] = float(truth.min())
                mode["z_truth_max"] = float(truth.max())

        if "delta_z_obs_minus_truth" in members:
            dz = as_numeric(members, "delta_z_obs_minus_truth")
            finite_dz = finite_only(dz)
            if not finite_dz.empty:
                center = finite_dz.median()
                mode["delta_z_median"] = float(center)
                mode["delta_z_mad"] = float((finite_dz - center).abs().median())
                if has_truth:
                    both = np.isfinite(dz) & np.isfinite(truth_all)
                    if both.any():
                        outlier = np.abs(dz[both]) > 0.15 * (1.0 + truth_all[both])
                        mode["catastrophic_fraction_0p15_1pz"] = float(outlier.mean())

        for metric in ("reduced_chi2", "chi2_per_band", "mean_residual_mag"):
            if metric not in members:
                continue
            metric_values = finite_only(as_numeric(members, metric))
            if not metric_values.empty:
                mode[f"{metric}_median"] = float(metric_values.median())
        modes.append(mode)

    if not modes:
        return pd.DataFrame()
    table = pd.DataFrame(modes)
    table = table.sort_values(["n_galaxies", "z_fit_bin"], ascending=[False, True])
    return table.head(max_modes).reset_index(drop=True)
def _residual_property_row(
    subset: pd.DataFrame, property_name: str, source_column: str, group_label: str
) -> dict[str, Any]:
    """Build one summary record for a group of rows.

    Infinite and missing values are dropped from ``mean_residual_mag`` and
    ``reduced_chi2`` before the statistics are taken. Returns an empty dict
    when no usable residual remains; ``median_reduced_chi2`` is included only
    when at least one usable reduced chi-square exists.
    """

    def usable(column: str) -> pd.Series:
        return subset[column].replace([np.inf, -np.inf], np.nan).dropna()

    residual = usable("mean_residual_mag")
    reduced = usable("reduced_chi2")
    if residual.empty:
        return {}

    magnitude = residual.abs()
    record: dict[str, Any] = {
        "property": property_name,
        "source_column": source_column,
        "group": group_label,
        "n_galaxies": int(len(residual)),
        "mean_residual_mag": float(residual.mean()),
        "median_residual_mag": float(residual.median()),
        "mean_abs_residual_mag": float(magnitude.mean()),
        "median_abs_residual_mag": float(magnitude.median()),
    }
    if not reduced.empty:
        record["median_reduced_chi2"] = float(reduced.median())
    return record
[docs]
def plot_flux_distributions(
    df: pd.DataFrame, columns: list[str], path: str | Path
) -> None:
    """Save overlaid step histograms of log10 flux, one per column.

    Only finite, strictly positive values enter each histogram; a column with
    no such value is left out of the plot.
    """
    fig, ax = plt.subplots(figsize=(8, 5))
    for name in columns:
        flux = df[name].to_numpy(dtype=float)
        positive = flux[np.isfinite(flux) & (flux > 0)]
        if positive.size == 0:
            continue
        ax.hist(np.log10(positive), bins=60, histtype="step", lw=1.4, label=name)
    ax.set_xlabel("log10 flux [Fnu cgs]")
    ax.set_ylabel("count")
    ax.legend(fontsize=8, ncol=2)
    ax.grid(alpha=0.2)
    fig.tight_layout()
    fig.savefig(path, dpi=170)
    plt.close(fig)
[docs]
def plot_color_distributions(
    df: pd.DataFrame, columns: list[str], path: str | Path
) -> None:
    """Save overlaid step histograms of colors between consecutive flux columns.

    The color of a pair is ``-2.5 * log10(left / right)``, computed only for
    rows where both fluxes are finite and positive. Does nothing when fewer
    than two columns are given.
    """
    if len(columns) < 2:
        return
    fig, ax = plt.subplots(figsize=(8, 5))
    for position in range(1, len(columns)):
        left, right = columns[position - 1], columns[position]
        left_flux = df[left].to_numpy(dtype=float)
        right_flux = df[right].to_numpy(dtype=float)
        usable = (
            np.isfinite(left_flux)
            & np.isfinite(right_flux)
            & (left_flux > 0)
            & (right_flux > 0)
        )
        if not usable.any():
            continue
        color = -2.5 * np.log10(left_flux[usable] / right_flux[usable])
        ax.hist(color, bins=60, histtype="step", lw=1.2, label=f"{left}-{right}")
    ax.set_xlabel("AB color [mag]")
    ax.set_ylabel("count")
    ax.legend(fontsize=8)
    ax.grid(alpha=0.2)
    fig.tight_layout()
    fig.savefig(path, dpi=170)
    plt.close(fig)
[docs]
def plot_redshift_distributions(
    df: pd.DataFrame, redshift_config: dict[str, Any], path: str | Path
) -> None:
    """Save redshift histograms and, when a truth column exists, their difference.

    The first panel overlays the histogram of ``redshift_config["column"]``
    with that of ``redshift_config["truth_column"]`` (if present in ``df``).
    A second panel, drawn only when the truth column exists, shows the
    histogram of ``column - truth_column``.

    Non-finite values (NaN and +/-inf) are excluded from every histogram,
    including the difference panel, because a histogram range cannot be built
    from infinite data. Does nothing when the redshift column is not
    configured or not in ``df``.
    """
    z_col = redshift_config.get("column")
    truth_col = redshift_config.get("truth_column")
    if not z_col or z_col not in df:
        return
    has_truth = truth_col in df

    fig, axes = plt.subplots(1, 2 if has_truth else 1, figsize=(10, 4))
    axes = np.atleast_1d(axes)

    z_all = df[z_col].to_numpy(dtype=float)
    axes[0].hist(
        z_all[np.isfinite(z_all)],
        bins=70,
        histtype="stepfilled",
        alpha=0.65,
        label=z_col,
    )
    if has_truth:
        truth_all = df[truth_col].to_numpy(dtype=float)
        axes[0].hist(
            truth_all[np.isfinite(truth_all)],
            bins=70,
            histtype="step",
            lw=1.4,
            label=truth_col,
        )
    axes[0].set_xlabel("redshift")
    axes[0].set_ylabel("count")
    axes[0].legend(fontsize=8)
    axes[0].grid(alpha=0.2)

    if has_truth:
        # Rows where either redshift is NaN or infinite give a non-finite
        # difference and are dropped here.
        delta = z_all - truth_all
        delta = delta[np.isfinite(delta)]
        axes[1].hist(delta, bins=80, histtype="stepfilled", alpha=0.7)
        axes[1].axvline(0, color="black", lw=1)
        axes[1].set_xlabel(f"{z_col} - {truth_col}")
        axes[1].set_ylabel("count")
        axes[1].grid(alpha=0.2)

    fig.tight_layout()
    fig.savefig(path, dpi=170)
    plt.close(fig)
[docs]
def plot_physical_parameters_distributions(df: pd.DataFrame, path: str | Path) -> None:
    """Save one histogram per known catalog physical-parameter column in ``df``.

    Only finite values are histogrammed. Does nothing when none of the known
    columns is present.
    """
    axis_labels = {
        "log10_metallicity_true": "Catalog proxy log10(Z)",
        "sfr_true": "Catalog truth SFR [Msun/yr]",
        "log_sfr_true": "Catalog truth log10(SFR)",
        "dust_ebv_true": "Catalog proxy E(B-V)",
    }
    present = [name for name in axis_labels if name in df.columns]
    if not present:
        return
    n_panels = len(present)
    fig, grid = plt.subplots(1, n_panels, figsize=(4 * n_panels, 4))
    for panel, name in zip(np.atleast_1d(grid), present, strict=True):
        data = df[name].to_numpy(dtype=float)
        panel.hist(data[np.isfinite(data)], bins=70, histtype="stepfilled", alpha=0.7)
        panel.set_xlabel(axis_labels[name])
        panel.set_ylabel("galaxies")
        panel.grid(alpha=0.2)
    fig.tight_layout()
    fig.savefig(path, dpi=170)
    plt.close(fig)
[docs]
def plot_sed(result: ModelResult, path: str | Path) -> None:
    """Save the intrinsic and dust-attenuated rest-frame SED with band markers.

    The curves are restricted to 800-30000 Angstrom and to points where both
    SEDs are positive. For each band in ``result.photometry`` the filter
    curve (when provided and ``z_obs`` is finite) is shifted to the rest
    frame and drawn as a shaded shape near the bottom of the log axis, and a
    labelled vertical line marks the rest-frame effective wavelength when it
    falls inside the plotted window.
    """
    wave = result.wave
    intrinsic = result.rest_sed
    dusted = result.dusted_rest_sed

    fig, ax = plt.subplots(figsize=(8, 5))
    shown = (
        np.isfinite(wave)
        & np.isfinite(dusted)
        & (wave >= 800)
        & (wave <= 30_000)
        & (intrinsic > 0)
        & (dusted > 0)
    )
    ax.plot(wave[shown], intrinsic[shown], label="Intrinsic rest SED", lw=1.1, alpha=0.65)
    ax.plot(wave[shown], dusted[shown], label="Dust-attenuated rest SED", lw=1.4)

    z_obs = result.parameters.get("z_obs", np.nan)
    ymin, ymax = _positive_axis_limits(dusted[shown])
    for band, values in result.photometry.items():
        short_name = band.replace("euclid_", "")
        filter_wave = np.asarray(values.get("filter_wave_angstrom", []), dtype=float)
        throughput = np.asarray(values.get("filter_transmission", []), dtype=float)
        usable = np.isfinite(filter_wave) & np.isfinite(throughput) & (throughput > 0)
        if usable.any() and np.isfinite(z_obs):
            peak = np.nanmax(throughput[usable])
            # Map the normalized throughput onto a thin band of the log axis
            # (6% to 22% of the way from ymin to ymax).
            curve_top = ymin * (ymax / ymin) ** (0.06 + 0.16 * throughput[usable] / peak)
            ax.fill_between(
                filter_wave[usable] / (1.0 + z_obs),
                ymin,
                curve_top,
                alpha=0.16,
                lw=0,
                label=f"{short_name} passband",
            )
        marker_wave = values["effective_wavelength_angstrom"] / (1.0 + z_obs)
        if not np.isfinite(marker_wave) or not 800 <= marker_wave <= 30_000:
            continue
        ax.axvline(marker_wave, color="black", lw=0.7, alpha=0.18)
        ax.text(
            marker_wave,
            0.98,
            short_name,
            rotation=90,
            va="top",
            ha="right",
            transform=ax.get_xaxis_transform(),
            fontsize=7,
        )

    ax.set_xlabel("rest-frame wavelength [Angstrom]")
    ax.set_ylabel("Lsun / Hz")
    ax.set_yscale("log")
    ax.set_title(f"DSPS rest SED, z={z_obs:.3f}")
    ax.legend(fontsize=8)
    ax.grid(alpha=0.2)
    fig.tight_layout()
    fig.savefig(path, dpi=170)
    plt.close(fig)
[docs]
def plot_sed_diagnostic(
    result: ModelResult,
    path: str | Path,
    *,
    observation: GalaxyObservation | None = None,
    ground_truth_sed: Any | None = None,
    include_filters: bool = True,
) -> None:
    """Plot DSPS SED, COSMOS proxy SED, and model-anchored photometry residuals.

    The figure has three stacked panels sharing a log wavelength axis style:

    * top: intrinsic and dusted model SED, an optional ground-truth SED
      (scaled, plus an unscaled-shape curve when available), optional
      rest-frame filter curves, and photometry markers;
    * middle: observed versus model AB magnitudes per band;
    * bottom: the negated ``chi_flux`` per band.

    The two lower panels are filled only when ``observation`` is given and
    its comparison table is non-empty. Band wavelengths are shifted to the
    rest frame when ``z_obs`` is finite.

    Args:
        result: Model output providing ``wave``, ``rest_sed``,
            ``dusted_rest_sed``, ``parameters`` and ``photometry``.
        path: Destination of the saved figure.
        observation: Observation passed to ``comparison_rows`` to build the
            per-band comparison table.
        ground_truth_sed: Optional reference SED, normalized through
            ``_ground_truth_sed_frame``.
        include_filters: Whether to draw the rest-frame filter curves.
    """
    fig, axes = plt.subplots(
        3,
        1,
        figsize=(8.6, 8.2),
        gridspec_kw={"height_ratios": [2.2, 1.0, 0.72]},
        sharex=False,
    )
    ax_sed, ax_phot, ax_flux = axes
    # Same plotting window and positivity cut for every SED curve.
    mask = _sed_plot_mask(result.wave, result.dusted_rest_sed)
    if mask.any():
        ax_sed.plot(
            result.wave[mask],
            result.rest_sed[mask],
            label="DSPS intrinsic",
            lw=1.0,
            alpha=0.6,
        )
        ax_sed.plot(
            result.wave[mask],
            result.dusted_rest_sed[mask],
            label="DSPS dusted",
            lw=1.35,
        )
    truth = _ground_truth_sed_frame(ground_truth_sed)
    if truth is not None and not truth.empty:
        truth_wave = truth["wave_angstrom"].to_numpy(dtype=float)
        truth_scaled = truth["ground_truth_lnu_lsun_per_hz"].to_numpy(dtype=float)
        truth_mask = _sed_plot_mask(
            truth_wave,
            truth_scaled,
        )
        if truth_mask.any():
            label = f"{truth['ground_truth_label'].iloc[0]} scaled"
            # Append the scale factor to the legend entry when it is a finite number.
            if "ground_truth_scale_factor" in truth:
                scale = pd.to_numeric(
                    truth["ground_truth_scale_factor"], errors="coerce"
                ).iloc[0]
                if np.isfinite(scale):
                    label = f"{label} (alpha={scale:.3g})"
            ax_sed.plot(
                truth_wave[truth_mask],
                truth_scaled[truth_mask],
                label=label,
                lw=1.2,
                ls="--",
                alpha=0.9,
            )
        # NOTE(review): nesting reconstructed after indentation loss; this
        # branch only needs truth_wave, so it is placed beside the scaled
        # curve rather than under it - confirm against the original file.
        if "ground_truth_unscaled_lnu_lsun_per_hz_display" in truth:
            unscaled = truth[
                "ground_truth_unscaled_lnu_lsun_per_hz_display"
            ].to_numpy(dtype=float)
            unscaled_mask = _sed_plot_mask(truth_wave, unscaled)
            if unscaled_mask.any():
                ax_sed.plot(
                    truth_wave[unscaled_mask],
                    unscaled[unscaled_mask],
                    label=f"{truth['ground_truth_label'].iloc[0]} unscaled shape",
                    lw=0.9,
                    ls=":",
                    alpha=0.85,
                    color="#5A8F6B",
                )
    z_obs = result.parameters.get("z_obs", np.nan)
    if include_filters and mask.any():
        _plot_rest_frame_filters(ax_sed, result, mask)
    if observation is not None:
        comparison = pd.DataFrame(comparison_rows(observation, result)).sort_values(
            "effective_wavelength_angstrom"
        )
        if not comparison.empty:
            _plot_sed_photometry_constraints(ax_sed, result, comparison, z_obs)
            # Band positions: observed-frame effective wavelength, shifted to
            # the rest frame when the redshift is usable.
            x_obs = comparison["effective_wavelength_angstrom"].to_numpy(dtype=float)
            if np.isfinite(z_obs):
                x_obs = x_obs / (1.0 + float(z_obs))
            obs_mag = comparison["observed_mag_ab"].to_numpy(dtype=float)
            model_mag = comparison["model_mag_ab"].to_numpy(dtype=float)
            sigma = comparison["sigma_mag"].to_numpy(dtype=float)
            ax_phot.errorbar(
                x_obs,
                obs_mag,
                yerr=sigma,
                fmt="o",
                ms=4.2,
                lw=0.8,
                capsize=2.0,
                label="observed",
            )
            ax_phot.plot(x_obs, model_mag, "s", ms=4.0, label="DSPS model")
            # Label every model point with its band name.
            for _, item in comparison.iterrows():
                x_val = float(item["effective_wavelength_angstrom"])
                if np.isfinite(z_obs):
                    x_val = x_val / (1.0 + float(z_obs))
                ax_phot.text(
                    x_val,
                    float(item["model_mag_ab"]),
                    str(item["band"]).replace("euclid_", ""),
                    fontsize=7,
                    rotation=35,
                    ha="left",
                    va="bottom",
                )
            # Fall back to NaN markers when the table has no chi_flux column.
            chi_flux = (
                pd.to_numeric(comparison["chi_flux"], errors="coerce").to_numpy(
                    dtype=float
                )
                if "chi_flux" in comparison
                else np.full(len(comparison), np.nan)
            )
            ax_flux.axhline(0.0, color="0.2", lw=0.9)
            # chi_flux is negated to match the (F_obs - F_model) axis label.
            ax_flux.scatter(
                x_obs,
                -chi_flux,
                s=18,
                color="#1F4E79",
                alpha=0.8,
                label="flux residual",
            )
            ax_flux.plot(
                x_obs,
                -chi_flux,
                color="#B85C38",
                lw=0.8,
                alpha=0.8,
            )
            ax_flux.set_ylabel(r"$(F_{obs}-F_{model})/\sigma_F$")
            ax_flux.grid(alpha=0.2)
    ax_sed.set_xscale("log")
    ax_sed.set_yscale("log")
    ax_sed.set_ylabel("rest Lsun / Hz")
    ax_sed.set_title(f"SED diagnostic, z={z_obs:.3f}")
    ax_sed.text(
        0.01,
        0.02,
        "Photometry markers are model-anchored ratios, not an independent rest-frame spectrum.",
        transform=ax_sed.transAxes,
        fontsize=7,
        va="bottom",
        ha="left",
        bbox={"facecolor": "white", "edgecolor": "none", "alpha": 0.65, "pad": 2},
    )
    ax_sed.legend(fontsize=8, loc="best")
    ax_phot.set_xscale("log")
    # Magnitudes: brighter (smaller) values plotted towards the top.
    ax_phot.invert_yaxis()
    ax_phot.set_ylabel("AB mag")
    ax_phot.legend(fontsize=8, loc="best")
    ax_flux.set_xscale("log")
    ax_flux.set_xlabel("rest-frame wavelength [Angstrom]")
    fig.tight_layout()
    fig.savefig(path, dpi=180)
    plt.close(fig)
def _sed_plot_mask(wave: np.ndarray, values: np.ndarray) -> np.ndarray:
    """Return a boolean mask of SED samples that can go on the log-log plot.

    A sample is kept when its wavelength lies in [800, 30000] and both the
    wavelength and the value are finite, with the value strictly positive.
    """
    in_window = (wave >= 800.0) & (wave <= 30_000.0)
    well_defined = np.isfinite(wave) & np.isfinite(values)
    positive = values > 0.0
    return well_defined & in_window & positive
def _plot_rest_frame_filters(
    ax: plt.Axes, result: ModelResult, sed_mask: np.ndarray
) -> None:
    """Draw each band's filter curve, shifted to the rest frame, along the axis bottom.

    Curves are normalized to their own peak and drawn in the x-axis blended
    transform, so their height is a fixed fraction of the axes regardless of
    the data range. Does nothing when ``z_obs`` is not finite or ``sed_mask``
    selects no sample; bands without a positive, finite curve are skipped.
    """
    z_obs = result.parameters.get("z_obs", np.nan)
    if not (np.isfinite(z_obs) and sed_mask.any()):
        return

    palette = (
        "#1f77b4",
        "#ff7f0e",
        "#2ca02c",
        "#d62728",
        "#9467bd",
        "#8c564b",
        "#e377c2",
        "#7f7f7f",
        "#bcbd22",
        "#17becf",
    )
    # Vertical placement in axes-fraction units.
    base = 0.035
    height = 0.055
    blended = ax.get_xaxis_transform()
    stretch = 1.0 + float(z_obs)

    for index, (band, values) in enumerate(result.photometry.items()):
        filter_wave = np.asarray(values.get("filter_wave_angstrom", []), dtype=float)
        throughput = np.asarray(values.get("filter_transmission", []), dtype=float)
        usable = np.isfinite(filter_wave) & np.isfinite(throughput) & (throughput > 0)
        if not usable.any():
            continue
        rest_wave = filter_wave[usable] / stretch
        # Colors cycle when there are more bands than palette entries.
        color = palette[index % len(palette)]
        relative = throughput[usable] / np.nanmax(throughput[usable])
        outline = base + height * relative
        ax.fill_between(
            rest_wave,
            base,
            outline,
            transform=blended,
            alpha=0.34,
            color=color,
            lw=0,
            label=f"{band.replace('euclid_', '')} filter",
            zorder=0,
            clip_on=True,
        )
        ax.plot(
            rest_wave,
            outline,
            color=color,
            lw=0.8,
            alpha=0.9,
            zorder=1,
            transform=blended,
            clip_on=True,
        )
def _plot_sed_photometry_constraints(
    ax: plt.Axes, result: ModelResult, comparison: pd.DataFrame, z_obs: float
) -> None:
    """Overlay per-band photometry on the SED panel, anchored to the model curve.

    The dusted model SED is interpolated at each band's rest-frame effective
    wavelength. The "observed" marker is that model value multiplied by the
    observed-to-model flux ratio, with an error bar of
    ``0.4 * ln(10) * sigma_mag`` times the marker value. Does nothing when
    ``z_obs`` is not finite or no band yields a positive, finite marker.
    """
    if not np.isfinite(z_obs):
        return

    def column(name: str) -> np.ndarray:
        return comparison[name].to_numpy(dtype=float)

    rest_wave = column("effective_wavelength_angstrom") / (1.0 + float(z_obs))
    model_flux = column("model_flux_fnu_cgs")
    observed_flux = column("observed_flux_fnu_cgs")
    sigma_mag = column("sigma_mag")

    sed_at_band = np.interp(rest_wave, result.wave, result.dusted_rest_sed)
    # Bands whose model flux is not positive and finite keep a NaN ratio.
    flux_ratio = np.divide(
        observed_flux,
        model_flux,
        out=np.full_like(observed_flux, np.nan),
        where=np.isfinite(model_flux) & (model_flux > 0),
    )
    anchored = sed_at_band * flux_ratio
    anchored_err = anchored * np.log(10.0) * 0.4 * sigma_mag

    keep = (
        np.isfinite(rest_wave)
        & np.isfinite(sed_at_band)
        & np.isfinite(anchored)
        & (rest_wave > 0)
        & (sed_at_band > 0)
        & (anchored > 0)
    )
    if not keep.any():
        return
    ax.errorbar(
        rest_wave[keep],
        anchored[keep],
        yerr=anchored_err[keep],
        fmt="o",
        ms=4.0,
        color="#1F4E79",
        ecolor="#1F4E79",
        elinewidth=0.75,
        capsize=1.5,
        label="model-anchored phot. ratio",
        zorder=5,
    )
    ax.scatter(
        rest_wave[keep],
        sed_at_band[keep],
        marker="s",
        s=18,
        color="#B85C38",
        label="DSPS band model",
        zorder=5,
    )
def _ground_truth_sed_frame(ground_truth_sed: Any | None) -> pd.DataFrame | None:
    """Normalize a reference SED into a frame with fixed column names.

    The input may be a DataFrame or anything ``pd.DataFrame`` accepts. It
    must contain ``wave_angstrom`` and one of the recognized value columns;
    otherwise ``None`` is returned. The output carries the wavelength, the
    value column renamed to ``ground_truth_lnu_lsun_per_hz``, and a label
    (``"COSMOS proxy"`` unless the input provides one).

    When an unscaled SED column is present it is copied through together with
    a display version rescaled by the ratio of medians of the scaled and
    unscaled values. Normalization metadata columns are reduced to their
    first value.
    """
    if ground_truth_sed is None:
        return None
    if isinstance(ground_truth_sed, pd.DataFrame):
        source = ground_truth_sed.copy()
    else:
        source = pd.DataFrame(ground_truth_sed)
    if source.empty or "wave_angstrom" not in source:
        return None

    # Candidates are checked in priority order; the first one present wins.
    value_candidates = (
        "ground_truth_lnu_lsun_per_hz",
        "lnu_lsun_per_hz_rest_proxy_scaled",
        "cosmos_proxy_lnu_lsun_per_hz",
    )
    value_column = next((name for name in value_candidates if name in source), None)
    if value_column is None:
        return None

    if "ground_truth_label" in source:
        label = str(source["ground_truth_label"].iloc[0])
    else:
        label = "COSMOS proxy"

    result = pd.DataFrame(
        {
            "wave_angstrom": source["wave_angstrom"].to_numpy(dtype=float),
            "ground_truth_lnu_lsun_per_hz": source[value_column].to_numpy(dtype=float),
            "ground_truth_label": label,
        }
    )

    if "ground_truth_unscaled_lnu_lsun_per_hz" in source:
        raw = source["ground_truth_unscaled_lnu_lsun_per_hz"].to_numpy(dtype=float)
        scaled = result["ground_truth_lnu_lsun_per_hz"].to_numpy(dtype=float)
        both_positive = (
            np.isfinite(raw) & np.isfinite(scaled) & (raw > 0) & (scaled > 0)
        )
        if both_positive.any():
            # Bring the unscaled shape onto the scaled curve's level for display.
            factor = np.nanmedian(scaled[both_positive]) / np.nanmedian(
                raw[both_positive]
            )
            display = raw * factor
        else:
            display = np.full_like(raw, np.nan, dtype=float)
        result["ground_truth_unscaled_lnu_lsun_per_hz"] = raw
        result["ground_truth_unscaled_lnu_lsun_per_hz_display"] = display

    metadata_columns = (
        "ground_truth_scale_factor",
        "ground_truth_normalization_bands",
        "ground_truth_norm_median_abs_rel_residual",
        "ground_truth_norm_max_abs_rel_residual",
    )
    for name in metadata_columns:
        if name in source:
            result[name] = source[name].iloc[0]
    return result
def _positive_axis_limits(values: np.ndarray) -> tuple[float, float]:
    """Return the 1st and 99.5th percentile of the finite, positive values.

    Falls back to ``(1.0, 10.0)`` when no such value exists.
    """
    usable = values[np.isfinite(values) & (values > 0)]
    if not usable.size:
        return 1.0, 10.0
    lower = np.nanpercentile(usable, 1)
    upper = np.nanpercentile(usable, 99.5)
    return float(lower), float(upper)
[docs]
def plot_photometry_comparison(comparison: pd.DataFrame, path: str | Path) -> None:
    """Save observed versus model magnitudes per band with a residual panel.

    Bands are ordered by effective wavelength, which is shown in microns
    (Angstrom / 10000). The upper panel holds the observed magnitudes with
    error bars and the model magnitudes; the lower panel holds the
    model-minus-observed residuals.
    """
    ordered = comparison.sort_values("effective_wavelength_angstrom").reset_index(
        drop=True
    )
    wavelength = ordered["effective_wavelength_angstrom"].to_numpy(dtype=float) / 10_000.0
    observed = ordered["observed_mag_ab"]
    sigma = ordered["sigma_mag"]

    fig, (ax_mag, ax_resid) = plt.subplots(
        2,
        1,
        figsize=(8, 6),
        sharex=True,
        gridspec_kw={"height_ratios": [2.2, 1.0]},
    )

    ax_mag.errorbar(
        wavelength,
        observed,
        yerr=sigma,
        fmt="o",
        ms=6,
        capsize=3,
        label="Simulated catalog",
    )
    ax_mag.plot(
        wavelength, ordered["model_mag_ab"], marker="s", lw=1.4, label="DSPS model"
    )
    for x_pos, y_pos, band in zip(wavelength, observed, ordered["band"], strict=True):
        ax_mag.annotate(
            band.replace("euclid_", ""),
            (x_pos, y_pos),
            textcoords="offset points",
            xytext=(3, 5),
            fontsize=8,
        )
    ax_mag.set_ylabel("AB magnitude")
    # Magnitudes: smaller values are brighter, so flip the axis.
    ax_mag.invert_yaxis()
    ax_mag.legend(fontsize=8)
    ax_mag.grid(alpha=0.25)

    ax_resid.axhline(0, color="black", lw=1)
    ax_resid.errorbar(
        wavelength,
        ordered["residual_mag_model_minus_observed"],
        yerr=sigma,
        fmt="o",
        capsize=3,
    )
    ax_resid.set_xlabel("observed-frame effective wavelength [micron]")
    ax_resid.set_ylabel("model - obs [mag]")
    ax_resid.grid(alpha=0.25)

    fig.tight_layout()
    fig.savefig(path, dpi=170)
    plt.close(fig)
[docs]
def plot_fit_trace(trace: pd.DataFrame, path: str | Path) -> None:
    """Save the objective value against evaluation number on a log y-axis.

    Uses the ``chi2`` column when present, otherwise ``mean_chi2_or_loss``.
    Does nothing when ``trace`` is empty or has neither column.
    """
    if trace.empty:
        return
    objective = next(
        (name for name in ("chi2", "mean_chi2_or_loss") if name in trace), None
    )
    if objective is None:
        return
    evaluations = np.arange(len(trace))
    fig, ax = plt.subplots(figsize=(7, 4))
    ax.plot(evaluations, trace[objective], lw=1)
    ax.set_xlabel("evaluation")
    ax.set_ylabel(objective)
    ax.set_yscale("log")
    ax.grid(alpha=0.25)
    fig.tight_layout()
    fig.savefig(path, dpi=170)
    plt.close(fig)
[docs]
def write_posterior_predictive(mcmc_result: Any, path: str | Path) -> None:
    """Write per-band posterior-predictive magnitude quantiles to a CSV file.

    ``mcmc_result.posterior_model_mags`` is indexed as ``[sample, band]``.
    For each band the observed magnitude, its uncertainty, and the 5/16/50/
    84/95% quantiles of the model magnitude are written as one row.
    """
    quantile_columns = {
        "model_mag_q05": 0.05,
        "model_mag_q16": 0.16,
        "model_mag_median": 0.50,
        "model_mag_q84": 0.84,
        "model_mag_q95": 0.95,
    }
    model_mags = np.asarray(mcmc_result.posterior_model_mags)
    records = []
    for position, band in enumerate(mcmc_result.band_names):
        draws = model_mags[:, position]
        record = {
            "band": band,
            "observed_mag_ab": float(mcmc_result.observed_mag[position]),
            "sigma_mag": float(mcmc_result.sigma_mag[position]),
        }
        for column, level in quantile_columns.items():
            record[column] = float(np.quantile(draws, level))
        records.append(record)
    pd.DataFrame(records).to_csv(path, index=False)
[docs]
def plot_mcmc_traces(samples: pd.DataFrame, path: str | Path) -> None:
    """Save one stacked trace panel per column of ``samples``.

    Does nothing when ``samples`` is empty.
    """
    if samples.empty:
        return
    names = samples.columns
    n_panels = len(names)
    fig, grid = plt.subplots(
        n_panels,
        1,
        figsize=(8, max(2.2 * n_panels, 3)),
        sharex=True,
    )
    panels = np.atleast_1d(grid)
    for panel, name in zip(panels, names, strict=True):
        panel.plot(samples[name].to_numpy(dtype=float), lw=0.8)
        panel.set_ylabel(name)
        panel.grid(alpha=0.2)
    # With a shared x-axis only the bottom panel needs the label.
    panels[-1].set_xlabel("posterior sample")
    fig.tight_layout()
    fig.savefig(path, dpi=170)
    plt.close(fig)
[docs]
def plot_corner(samples: pd.DataFrame, path: str | Path) -> None:
    """Save a corner plot of the sample columns that actually vary.

    This is best-effort: nothing is written when ``samples`` is empty, fewer
    than two columns vary, there are not more samples than varying columns,
    the optional ``corner`` package is missing, or ``corner`` rejects the
    input with ``AssertionError``/``ValueError``.
    """
    if samples.empty:
        return
    is_varying = samples.nunique(dropna=True) > 1
    varying = samples.loc[:, is_varying]
    n_draws, n_params = varying.shape
    if n_params < 2 or n_draws <= n_params:
        return
    try:
        import corner
    except ImportError:
        return
    try:
        fig = corner.corner(
            varying.to_numpy(dtype=float),
            labels=list(varying.columns),
            show_titles=True,
        )
    except (AssertionError, ValueError):
        return
    fig.savefig(path, dpi=170)
    plt.close(fig)
[docs]
def posterior_comparable_frame(
    samples: pd.DataFrame,
    derived: pd.DataFrame,
    truth_values: dict[str, Any],
) -> pd.DataFrame:
    """Collect posterior columns for which ``truth_values`` has a ``truth_<name>`` entry.

    SFR-at-observation and formed mass are taken from ``derived`` and
    metallicity from ``samples``, in that column order. Every remaining
    ``samples`` column with a matching truth key is then appended. Columns
    without a truth key, or missing from their source frame, are left out.
    """
    preferred_sources = (
        ("log10_sfr_at_obs", derived),
        ("log10_formed_mass_msun", derived),
        ("log10_metallicity", samples),
    )
    columns: dict[Any, np.ndarray] = {}
    for name, source in preferred_sources:
        if f"truth_{name}" in truth_values and name in source:
            columns[name] = source[name].to_numpy(dtype=float)
    for name in samples.columns:
        if name in columns:
            continue
        if f"truth_{name}" in truth_values:
            columns[name] = samples[name].to_numpy(dtype=float)
    return pd.DataFrame(columns)
[docs]
def plot_corner_with_truth(
    samples: pd.DataFrame,
    truth_values: dict[str, Any],
    path: str | Path,
) -> None:
    """Save a corner plot of the varying columns with truth/proxy markers.

    Axis ranges are widened so that each finite ``truth_<column>`` value from
    ``truth_values`` falls inside its panel. This is best-effort: nothing is
    written when ``samples`` is empty, fewer than two columns vary, there are
    not more samples than varying columns, the optional ``corner`` package is
    missing, or ``corner`` rejects the input.
    """
    if samples.empty:
        return
    varying = samples.loc[:, samples.nunique(dropna=True) > 1]
    n_draws, n_params = varying.shape
    if n_params < 2 or n_draws <= n_params:
        return
    names = list(varying.columns)

    truth_by_column = {}
    for name in names:
        candidate = truth_values.get(f"truth_{name}")
        if candidate is None or not np.isfinite(candidate):
            continue
        truth_by_column[name] = float(candidate)

    ranges = []
    for name in names:
        draws = varying[name].to_numpy(dtype=float)
        ranges.append(_corner_range_with_truth(draws, truth_by_column.get(name)))

    try:
        import corner
    except ImportError:
        return
    try:
        fig = corner.corner(
            varying.to_numpy(dtype=float),
            labels=[_posterior_display_label(name) for name in names],
            show_titles=True,
            range=ranges,
            title_kwargs={"fontsize": 11},
        )
    except (AssertionError, ValueError):
        return

    _annotate_corner_truth(fig, names, truth_by_column, truth_values)
    fig.suptitle(
        "Posterior with catalog truth/proxy markers",
        y=0.995,
        fontsize=13,
    )
    fig.subplots_adjust(top=0.92)
    fig.savefig(path, dpi=170)
    plt.close(fig)
def _corner_range_with_truth(
    values: np.ndarray, truth_value: float | None
) -> tuple[float, float]:
    """Return a padded axis range covering the finite samples and the truth value.

    The range spans the finite ``values`` and, when finite, ``truth_value``.
    It is padded by 8% of the span, or by 5% of the magnitude when the span
    is zero, and never by less than 1e-3. Falls back to ``(0.0, 1.0)`` when
    no finite sample exists.
    """
    fallback = (0.0, 1.0)
    finite = values[np.isfinite(values)]
    if finite.size == 0:
        return fallback

    extremes = [float(np.min(finite)), float(np.max(finite))]
    if truth_value is not None and np.isfinite(truth_value):
        extremes.append(float(truth_value))
    lo, hi = min(extremes), max(extremes)
    if not (np.isfinite(lo) and np.isfinite(hi)):
        return fallback

    relative_pad = abs(lo) * 0.05 if lo == hi else (hi - lo) * 0.08
    pad = max(relative_pad, 1.0e-3)
    return (lo - pad, hi + pad)
def _annotate_corner_truth(
    fig: Any,
    columns: list[str],
    truth_by_column: dict[str, float],
    truth_values: dict[str, Any],
) -> None:
    """Draw truth markers on the lower triangle of a corner figure.

    ``fig.axes`` is interpreted as an ``n x n`` grid in row-major order. Each
    panel gets a vertical line at the truth of its x parameter, each
    off-diagonal panel also a horizontal line at the truth of its y
    parameter, and each diagonal panel a text box with the truth value.
    """
    n_dim = len(columns)
    grid = np.asarray(fig.axes).reshape((n_dim, n_dim))
    truth_color = "#d94f70"
    line_style = {"color": truth_color, "lw": 2.0, "ls": "--", "zorder": 8}

    for row in range(n_dim):
        y_truth = truth_by_column.get(columns[row])
        # Only the diagonal and the panels below it are populated.
        for col in range(row + 1):
            x_name = columns[col]
            x_truth = truth_by_column.get(x_name)
            panel = grid[row, col]
            on_diagonal = row == col
            if x_truth is not None:
                panel.axvline(x_truth, **line_style)
            if not on_diagonal and y_truth is not None:
                panel.axhline(y_truth, **line_style)
            if on_diagonal and x_truth is not None:
                panel.text(
                    0.98,
                    0.86,
                    _truth_axis_label(x_name, x_truth, truth_values),
                    transform=panel.transAxes,
                    ha="right",
                    va="top",
                    fontsize=8.5,
                    color=truth_color,
                    bbox={
                        "facecolor": "white",
                        "edgecolor": truth_color,
                        "alpha": 0.85,
                        "pad": 2,
                    },
                )
def _truth_axis_label(
    parameter: str, value: float, truth_values: dict[str, Any]
) -> str:
    """Format the truth annotation for one parameter.

    The text starts with ``proxy`` when ``truth_kind_<parameter>`` equals
    ``"proxy"`` and with ``truth`` otherwise. A truthy
    ``truth_source_<parameter>`` is appended in parentheses.
    """
    is_proxy = truth_values.get(f"truth_kind_{parameter}", "direct") == "proxy"
    text = f"{'proxy' if is_proxy else 'truth'} = {value:.3g}"
    source = truth_values.get(f"truth_source_{parameter}")
    if not source:
        return text
    return f"{text} ({source})"
def _posterior_display_label(parameter: str) -> str:
    """Return the axis label for a posterior parameter, or the name itself if unknown."""
    known_labels = {
        "z_obs": "DSPS redshift",
        "log10_sfr_at_obs": "DSPS log10 SFR(t_obs)",
        "log10_sfr": "DSPS log10 SFH amplitude",
        "log10_formed_mass_msun": "DSPS log10 formed mass",
        "dust_av": "DSPS dust A_V",
        "log10_metallicity": "DSPS log10 stellar Z",
    }
    if parameter in known_labels:
        return known_labels[parameter]
    return parameter
[docs]
def plot_posterior_predictive(mcmc_result: Any, path: str | Path) -> None:
    """Save observed magnitudes against the posterior model median per band.

    Model error bars run from the 16% to the 84% quantile of
    ``mcmc_result.posterior_model_mags`` taken along the sample axis; bands
    are placed at consecutive integer positions.
    """
    model_mags = np.asarray(mcmc_result.posterior_model_mags)
    band_names = mcmc_result.band_names
    positions = np.arange(len(band_names))
    median = np.quantile(model_mags, 0.50, axis=0)
    lower = np.quantile(model_mags, 0.16, axis=0)
    upper = np.quantile(model_mags, 0.84, axis=0)
    observed = np.asarray(mcmc_result.observed_mag, dtype=float)
    sigma = np.asarray(mcmc_result.sigma_mag, dtype=float)
    tick_labels = [name.replace("euclid_", "") for name in band_names]

    fig, ax = plt.subplots(figsize=(8, 4.8))
    ax.errorbar(positions, observed, yerr=sigma, fmt="o", capsize=3, label="catalog")
    ax.errorbar(
        positions,
        median,
        yerr=[median - lower, upper - median],
        fmt="s",
        capsize=3,
        label="posterior model",
    )
    ax.set_xticks(positions)
    ax.set_xticklabels(tick_labels, rotation=20)
    ax.set_ylabel("AB magnitude")
    # Magnitudes: smaller values are brighter, so flip the axis.
    ax.invert_yaxis()
    ax.grid(alpha=0.25)
    ax.legend(fontsize=8)
    fig.tight_layout()
    fig.savefig(path, dpi=170)
    plt.close(fig)
[docs]
def plot_batch_posterior_intervals(summary: pd.DataFrame, path: str | Path) -> None:
    """Save per-parameter panels of posterior medians with 16-84% intervals.

    One panel is drawn per distinct ``parameter`` value, with points ordered
    by ``row_index``. Does nothing when ``summary`` is empty or lacks any of
    the required columns.
    """
    required = {"row_index", "parameter", "median", "q16", "q84"}
    if summary.empty or not required.issubset(summary.columns):
        return
    parameters = list(summary["parameter"].drop_duplicates())
    n_panels = len(parameters)
    fig, grid = plt.subplots(
        n_panels, 1, figsize=(9, max(2.4 * n_panels, 3)), sharex=True
    )
    panels = np.atleast_1d(grid)
    for panel, parameter in zip(panels, parameters, strict=True):
        rows = summary[summary["parameter"] == parameter].sort_values("row_index")
        row_index = rows["row_index"].to_numpy(dtype=float)
        center = rows["median"].to_numpy(dtype=float)
        lower = rows["q16"].to_numpy(dtype=float)
        upper = rows["q84"].to_numpy(dtype=float)
        panel.errorbar(
            row_index,
            center,
            yerr=[center - lower, upper - center],
            fmt="o",
            ms=4,
            capsize=2,
        )
        panel.set_ylabel(parameter)
        panel.grid(alpha=0.2)
    panels[-1].set_xlabel("catalog row index")
    fig.tight_layout()
    fig.savefig(path, dpi=170)
    plt.close(fig)
[docs]
def plot_batch_posterior_predictive(predictive: pd.DataFrame, path: str | Path) -> None:
    """Scatter posterior-median residuals against observed magnitude, per band.

    Does nothing when ``predictive`` is empty or has no
    ``residual_mag_median_model_minus_observed`` column. Bands are drawn in
    the order returned by ``ordered_bands`` and labelled with the
    ``euclid_`` substring removed.
    """
    residual_col = "residual_mag_median_model_minus_observed"
    if predictive.empty or residual_col not in predictive:
        return

    fig, ax = plt.subplots(figsize=(9, 5))
    for band in ordered_bands(predictive):
        subset = predictive[predictive["band"] == band]
        ax.scatter(
            subset["observed_mag_ab"],
            subset[residual_col],
            s=14,
            alpha=0.65,
            label=band.replace("euclid_", ""),
        )
    # Zero-residual reference line.
    ax.axhline(0, color="black", lw=1)
    ax.set_xlabel("observed AB magnitude")
    ax.set_ylabel("posterior median model - observed [mag]")
    ax.grid(alpha=0.2)
    ax.legend(fontsize=8)
    fig.tight_layout()
    fig.savefig(path, dpi=170)
    plt.close(fig)
[docs]
def plot_batch_mcmc_diagnostics(diagnostics: pd.DataFrame, path: str | Path) -> None:
    """Draw per-row bar charts of divergences and mean acceptance probability.

    Does nothing when ``diagnostics`` is empty. The left panel shows
    ``n_divergent`` and the right panel ``mean_accept_prob`` (y range 0-1,
    with a reference line at 0.8); a panel whose column is missing is
    switched off. ``row_index`` is only read when at least one of the two
    columns is present.
    """
    if diagnostics.empty:
        return

    fig, panels = plt.subplots(1, 2, figsize=(10, 4))
    divergence_ax, accept_ax = panels

    if "n_divergent" not in diagnostics:
        divergence_ax.axis("off")
    else:
        divergence_ax.bar(
            diagnostics["row_index"].astype(str), diagnostics["n_divergent"]
        )
        divergence_ax.set_ylabel("divergences")

    if "mean_accept_prob" not in diagnostics:
        accept_ax.axis("off")
    else:
        accept_ax.bar(
            diagnostics["row_index"].astype(str), diagnostics["mean_accept_prob"]
        )
        accept_ax.axhline(0.8, color="black", lw=1, alpha=0.4)
        accept_ax.set_ylabel("mean accept prob")
        accept_ax.set_ylim(0, 1)

    for panel in panels:
        panel.set_xlabel("catalog row index")
        panel.grid(alpha=0.2)
    fig.tight_layout()
    fig.savefig(path, dpi=170)
    plt.close(fig)
[docs]
def write_workflow_comparison(
    map_fits: pd.DataFrame,
    population_fits: pd.DataFrame,
    hmc_summary: pd.DataFrame,
    hmc_diagnostics: pd.DataFrame,
    free_parameters: list[str],
    out_dir: str | Path,
    hmc_samples: pd.DataFrame | None = None,
) -> None:
    """Write the MAP / population / HMC comparison tables, summary and plots.

    Outputs placed under ``out_dir`` (created through ``ensure_dir``):

    * ``map_vs_population_parameters.csv`` and
      ``map_vs_population_fit_quality.csv`` (always written),
    * ``hmc_vs_map_population.csv`` and ``hmc_diagnostics.csv`` (only when
      the corresponding frame is non-empty),
    * ``workflow_comparison_summary.json`` with galaxy counts, plus the
      median chi2 delta and total divergences when those columns exist,
    * the comparison figures and corner plots produced by the plotting
      helpers called at the end.
    """

    def count_galaxies(frame: pd.DataFrame) -> int:
        # Frames without a row_index column count as zero galaxies.
        if "row_index" not in frame:
            return 0
        return int(frame["row_index"].nunique())

    target = ensure_dir(out_dir)

    parameter_table = workflow_parameter_comparison(
        map_fits, population_fits, free_parameters
    )
    fit_table = workflow_fit_comparison(map_fits, population_fits)
    hmc_table = workflow_hmc_comparison(
        map_fits, population_fits, hmc_summary, free_parameters
    )

    parameter_table.to_csv(target / "map_vs_population_parameters.csv", index=False)
    fit_table.to_csv(target / "map_vs_population_fit_quality.csv", index=False)
    if not hmc_table.empty:
        hmc_table.to_csv(target / "hmc_vs_map_population.csv", index=False)
    if not hmc_diagnostics.empty:
        hmc_diagnostics.to_csv(target / "hmc_diagnostics.csv", index=False)

    summary: dict[str, Any] = {
        "n_map_galaxies": count_galaxies(map_fits),
        "n_population_galaxies": count_galaxies(population_fits),
        "n_hmc_galaxies": count_galaxies(hmc_summary),
    }
    if "delta_chi2_population_minus_map" in fit_table:
        median_delta = fit_table["delta_chi2_population_minus_map"].median()
        summary["median_delta_chi2_population_minus_map"] = float(median_delta)
    if "n_divergent" in hmc_diagnostics:
        summary["hmc_total_divergences"] = int(hmc_diagnostics["n_divergent"].sum())
    write_json(target / "workflow_comparison_summary.json", summary)

    plot_map_population_parameters(
        parameter_table, target / "map_vs_population_parameters.png"
    )
    plot_map_population_chi2(fit_table, target / "map_vs_population_chi2.png")
    plot_hmc_map_population(hmc_table, target / "hmc_vs_map_population.png")
    plot_workflow_parameter_corners(
        map_fits,
        population_fits,
        hmc_summary,
        hmc_samples,
        free_parameters,
        target,
    )
[docs]
def workflow_parameter_comparison(
    map_fits: pd.DataFrame, population_fits: pd.DataFrame, free_parameters: list[str]
) -> pd.DataFrame:
    """Build a long table comparing MAP and population values per parameter.

    The two frames are joined on ``row_index``. For every joined row and
    every name in ``free_parameters`` whose ``fit_<name>`` column exists in
    *both* inputs, one record is emitted with the two values and their
    difference (population minus MAP). Parameters present in only one frame
    are skipped. Returns an empty frame when no record is produced.
    """
    joined = map_fits.merge(
        population_fits, on="row_index", suffixes=("_map", "_population")
    )

    # Only columns shared by both inputs receive the merge suffixes, so
    # resolving the suffixed names once identifies the comparable parameters.
    comparable = []
    for name in free_parameters:
        map_col, pop_col = f"fit_{name}_map", f"fit_{name}_population"
        if map_col in joined.columns and pop_col in joined.columns:
            comparable.append((name, map_col, pop_col))

    records = [
        {
            "row_index": int(fit_row["row_index"]),
            "parameter": name,
            "map_value": float(fit_row[map_col]),
            "population_value": float(fit_row[pop_col]),
            "delta_population_minus_map": float(fit_row[pop_col] - fit_row[map_col]),
        }
        for _, fit_row in joined.iterrows()
        for name, map_col, pop_col in comparable
    ]
    return pd.DataFrame(records)
[docs]
def workflow_fit_comparison(
    map_fits: pd.DataFrame, population_fits: pd.DataFrame
) -> pd.DataFrame:
    """Join MAP and population fit-quality columns and add their deltas.

    From each frame only the columns ``row_index``, ``chi2``,
    ``reduced_chi2`` and ``gradient_norm`` that are present are kept. The
    frames are merged on ``row_index``; columns shared by both get the
    ``_map`` / ``_population`` suffixes. ``delta_chi2_population_minus_map``
    and ``delta_reduced_chi2_population_minus_map`` are added when both
    suffixed columns exist.
    """
    wanted = ("row_index", "chi2", "reduced_chi2", "gradient_norm")

    def pick(frame: pd.DataFrame) -> pd.DataFrame:
        return frame[[name for name in wanted if name in frame]].copy()

    merged = pick(map_fits).merge(
        pick(population_fits), on="row_index", suffixes=("_map", "_population")
    )

    for metric in ("chi2", "reduced_chi2"):
        map_col = f"{metric}_map"
        pop_col = f"{metric}_population"
        if map_col in merged.columns and pop_col in merged.columns:
            merged[f"delta_{metric}_population_minus_map"] = (
                merged[pop_col] - merged[map_col]
            )
    return merged
[docs]
def workflow_hmc_comparison(
    map_fits: pd.DataFrame,
    population_fits: pd.DataFrame,
    hmc_summary: pd.DataFrame,
    free_parameters: list[str],
) -> pd.DataFrame:
    """Compare HMC posterior summaries with MAP and population point estimates.

    For every row of ``hmc_summary`` whose ``parameter`` is listed in
    ``free_parameters`` one record is produced containing the HMC
    q16/median/q84, the MAP value, the population value and the two deltas
    relative to the HMC median.

    A summary row is skipped when the MAP frame has no entry for its
    ``row_index`` or no ``fit_<parameter>`` column. When the population
    frame lacks the row or the column, ``population_value`` (and its delta)
    is NaN. Only the first occurrence of each ``row_index`` in the fit
    frames is used.

    Returns an empty frame when ``hmc_summary`` is empty.
    """
    if hmc_summary.empty:
        return pd.DataFrame()

    map_by_row = map_fits.drop_duplicates("row_index").set_index("row_index")
    pop_by_row = population_fits.drop_duplicates("row_index").set_index("row_index")

    rows = []
    for _, row in hmc_summary.iterrows():
        parameter = row.get("parameter")
        row_index = int(row["row_index"])
        if parameter not in free_parameters or row_index not in map_by_row.index:
            continue

        fit_col = f"fit_{parameter}"
        # A parameter without a MAP column is treated like a galaxy without a
        # MAP row: there is nothing to compare against, so skip it instead of
        # raising KeyError.
        if fit_col not in map_by_row.columns:
            continue
        map_value = float(map_by_row.loc[row_index, fit_col])

        # The population estimate is optional on both axes (row and column).
        if row_index in pop_by_row.index and fit_col in pop_by_row.columns:
            population_value = float(pop_by_row.loc[row_index, fit_col])
        else:
            population_value = np.nan

        median = float(row["median"])
        rows.append(
            {
                "row_index": row_index,
                "parameter": parameter,
                "hmc_q16": float(row["q16"]),
                "hmc_median": median,
                "hmc_q84": float(row["q84"]),
                "map_value": map_value,
                "population_value": population_value,
                "delta_map_minus_hmc_median": map_value - median,
                "delta_population_minus_hmc_median": population_value - median,
            }
        )
    return pd.DataFrame(rows)
[docs]
def plot_workflow_parameter_corners(
    map_fits: pd.DataFrame,
    population_fits: pd.DataFrame,
    hmc_summary: pd.DataFrame,
    hmc_samples: pd.DataFrame | None,
    free_parameters: list[str],
    out: Path,
) -> None:
    """Write corner plots for the MAP, population and HMC parameter sets.

    Always requests ``corner_map_parameters.png`` and
    ``corner_population_parameters.png``. When ``hmc_summary`` is non-empty
    its medians are pivoted to one column per parameter and plotted as
    ``corner_hmc_posterior_medians.png``. When ``hmc_samples`` is given,
    non-empty and contains at least one free parameter column, a pooled
    ``corner_hmc_all_samples.png`` plus one plot per ``row_index`` under
    ``hmc_individual_corners/`` are produced.
    """
    point_estimates = (
        (map_fits, "corner_map_parameters.png"),
        (population_fits, "corner_population_parameters.png"),
    )
    for fits, filename in point_estimates:
        plot_corner(_fit_parameter_frame(fits, free_parameters), out / filename)

    if not hmc_summary.empty:
        medians = hmc_summary.pivot(
            index="row_index", columns="parameter", values="median"
        )
        # Keep the caller's parameter order and drop anything not requested.
        keep = [name for name in free_parameters if name in medians.columns]
        medians = medians[keep]
        plot_corner(
            medians.reset_index(drop=True), out / "corner_hmc_posterior_medians.png"
        )

    if hmc_samples is None or hmc_samples.empty:
        return
    sample_columns = [name for name in free_parameters if name in hmc_samples.columns]
    if not sample_columns:
        return

    plot_corner(hmc_samples[sample_columns], out / "corner_hmc_all_samples.png")
    per_galaxy_dir = ensure_dir(out / "hmc_individual_corners")
    for row_index, draws in hmc_samples.groupby("row_index"):
        target = per_galaxy_dir / f"corner_hmc_row_{int(row_index)}.png"
        plot_corner(draws[sample_columns], target)
def _fit_parameter_frame(
    fits: pd.DataFrame, free_parameters: list[str]
) -> pd.DataFrame:
    """Return the ``fit_<name>`` columns of ``fits`` as floats, keyed by name.

    Columns follow the order of ``free_parameters``; parameters without a
    ``fit_<name>`` column in ``fits`` are left out.
    """
    return pd.DataFrame(
        {
            name: fits[f"fit_{name}"].to_numpy(dtype=float)
            for name in free_parameters
            if f"fit_{name}" in fits
        }
    )
def _truth_parameter_frame(
    fits: pd.DataFrame, free_parameters: list[str]
) -> pd.DataFrame:
    """Return the ``truth_<name>`` columns of ``fits`` as floats, keyed by name.

    Columns follow the order of ``free_parameters``; parameters without a
    ``truth_<name>`` column in ``fits`` are left out.
    """
    return pd.DataFrame(
        {
            name: fits[f"truth_{name}"].to_numpy(dtype=float)
            for name in free_parameters
            if f"truth_{name}" in fits
        }
    )
[docs]
def paired_fit_truth_frames(
    fits: pd.DataFrame, config: dict[str, Any] | None = None
) -> tuple[pd.DataFrame, pd.DataFrame]:
    """Return aligned frames for parameters with both ``fit_`` and ``truth_`` columns.

    Every ``truth_<name>`` column of ``fits`` is considered in column order.
    A parameter is kept only when ``is_comparable_fit_parameter(config, name)``
    is truthy and a matching ``fit_<name>`` column exists. The first returned
    frame holds the fitted values, the second the truth values; both are
    float arrays keyed by the bare parameter name.
    """
    prefix = "truth_"
    fitted: dict[str, Any] = {}
    reference: dict[str, Any] = {}
    for truth_col in fits.columns:
        if not truth_col.startswith(prefix):
            continue
        name = truth_col[len(prefix):]
        fit_col = f"fit_{name}"
        if is_comparable_fit_parameter(config, name) and fit_col in fits:
            fitted[name] = fits[fit_col].to_numpy(dtype=float)
            reference[name] = fits[truth_col].to_numpy(dtype=float)
    return pd.DataFrame(fitted), pd.DataFrame(reference)
[docs]
def plot_corner_overlay(
    inferred: pd.DataFrame, truth: pd.DataFrame, path: str | Path
) -> None:
    """Overlay corner plots of inferred and truth parameter distributions.

    The plot is silently skipped (nothing is written) when any of these hold:

    * fewer than two columns are shared by ``inferred`` and ``truth``;
    * no row is finite in every shared column of both frames;
    * fewer than two shared columns take more than one distinct value;
    * either frame has no more rows than plotted columns;
    * the optional ``corner`` package cannot be imported;
    * ``corner.corner`` raises ``AssertionError`` or ``ValueError``.

    Both layers share per-column axis ranges computed from the combined data
    with 3% padding on each side. The figure is saved to ``path`` and closed.

    NOTE(review): rows are matched through the frames' indexes when building
    the validity mask, so ``inferred`` and ``truth`` are assumed to share an
    index -- confirm against callers.
    """
    common = [col for col in inferred.columns if col in truth.columns]
    if len(common) < 2:
        return

    # Treat infinities as missing and keep only rows complete in both frames.
    inferred = inferred[common].replace([np.inf, -np.inf], np.nan)
    truth = truth[common].replace([np.inf, -np.inf], np.nan)
    valid = inferred.notna().all(axis=1) & truth.notna().all(axis=1)
    inferred = inferred.loc[valid]
    truth = truth.loc[valid]
    if inferred.empty or truth.empty:
        return

    # Constant columns are dropped before plotting.
    varying = [
        col
        for col in common
        if pd.concat([inferred[col], truth[col]]).nunique(dropna=True) > 1
    ]
    if len(varying) < 2 or len(inferred) <= len(varying) or len(truth) <= len(varying):
        return
    inferred = inferred[varying]
    truth = truth[varying]

    ranges = []
    for col in varying:
        combined = pd.concat([inferred[col], truth[col]])
        lo, hi = float(combined.min()), float(combined.max())
        pad = 0.03 * (hi - lo)
        ranges.append((lo - pad, hi + pad))

    try:
        import corner
    except ImportError:
        return

    inferred_color = "#8fbcd4"
    truth_color = "#e6a0a8"
    labels = [_parameter_display_label(name) for name in varying]

    fig = None
    try:
        fig = corner.corner(
            inferred.to_numpy(dtype=float),
            labels=labels,
            range=ranges,
            show_titles=True,
            color=inferred_color,
            plot_datapoints=True,
            fill_contours=True,
            hist_kwargs={
                "density": True,
                "color": inferred_color,
                "alpha": 0.55,
            },
            contour_kwargs={"colors": inferred_color, "linewidths": 1.0},
        )
        corner.corner(
            truth.to_numpy(dtype=float),
            labels=labels,
            fig=fig,
            range=ranges,
            show_titles=False,
            color=truth_color,
            plot_datapoints=False,
            fill_contours=False,
            hist_kwargs={
                "density": True,
                "histtype": "step",
                "linewidth": 1.6,
                "color": truth_color,
            },
            contour_kwargs={"colors": truth_color, "linewidths": 1.4},
        )
    except (AssertionError, ValueError):
        # If the overlay call fails after the base figure was created, the
        # figure must still be closed or it stays registered with pyplot.
        if fig is not None:
            plt.close(fig)
        return

    from matplotlib.lines import Line2D

    fig.legend(
        handles=[
            Line2D([0], [0], color=inferred_color, lw=3, label="inferred"),
            Line2D([0], [0], color=truth_color, lw=3, label="truth/proxy"),
        ],
        loc="upper right",
        frameon=False,
        fontsize=10,
    )
    fig.savefig(path, dpi=170)
    plt.close(fig)
[docs]
def plot_map_population_parameters(comparison: pd.DataFrame, path: str | Path) -> None:
    """Scatter population MAP against independent MAP values, one panel per parameter.

    Does nothing when ``comparison`` is empty. Each panel also carries a
    one-to-one line spanning the finite range of both value columns.
    """
    if comparison.empty:
        return

    names = list(comparison["parameter"].drop_duplicates())
    n_panels = len(names)
    fig, grid = plt.subplots(1, n_panels, figsize=(4 * n_panels, 4), squeeze=False)

    for panel, name in zip(grid[0], names, strict=True):
        subset = comparison[comparison["parameter"] == name]
        map_values = subset["map_value"]
        population_values = subset["population_value"]
        panel.scatter(map_values, population_values, s=12, alpha=0.5)

        finite = pd.concat([map_values, population_values])
        finite = finite.replace([np.inf, -np.inf], np.nan).dropna()
        if not finite.empty:
            lo, hi = float(finite.min()), float(finite.max())
            panel.plot([lo, hi], [lo, hi], color="black", lw=1)

        panel.set_title(name)
        panel.set_xlabel("independent MAP")
        panel.set_ylabel("population MAP")
        panel.grid(alpha=0.2)

    fig.tight_layout()
    fig.savefig(path, dpi=170)
    plt.close(fig)
[docs]
def plot_map_population_chi2(comparison: pd.DataFrame, path: str | Path) -> None:
    """Compare reduced chi2 of the independent and population MAP fits.

    Does nothing when ``comparison`` is empty or lacks ``reduced_chi2_map``
    / ``reduced_chi2_population``. The left panel is a scatter with a
    one-to-one line; the right panel is a histogram of
    ``delta_reduced_chi2_population_minus_map`` when that column exists.
    """
    needed = {"reduced_chi2_map", "reduced_chi2_population"}
    if comparison.empty or not needed.issubset(comparison.columns):
        return

    fig, (scatter_ax, hist_ax) = plt.subplots(1, 2, figsize=(10, 4))

    map_chi2 = comparison["reduced_chi2_map"]
    population_chi2 = comparison["reduced_chi2_population"]
    scatter_ax.scatter(map_chi2, population_chi2, s=12, alpha=0.5)
    finite = pd.concat([map_chi2, population_chi2])
    finite = finite.replace([np.inf, -np.inf], np.nan).dropna()
    if not finite.empty:
        lo, hi = float(finite.min()), float(finite.max())
        scatter_ax.plot([lo, hi], [lo, hi], color="black", lw=1)
    scatter_ax.set_xlabel("independent MAP reduced chi2")
    scatter_ax.set_ylabel("population MAP reduced chi2")
    scatter_ax.grid(alpha=0.2)

    delta_col = "delta_reduced_chi2_population_minus_map"
    if delta_col in comparison:
        hist_ax.hist(comparison[delta_col].dropna(), bins=50, alpha=0.75)
        hist_ax.axvline(0, color="black", lw=1)
    hist_ax.set_xlabel("population - independent reduced chi2")
    hist_ax.set_ylabel("galaxies")
    hist_ax.grid(alpha=0.2)

    fig.tight_layout()
    fig.savefig(path, dpi=170)
    plt.close(fig)
[docs]
def plot_hmc_map_population(comparison: pd.DataFrame, path: str | Path) -> None:
    """Plot HMC intervals with MAP and population estimates, one panel per parameter.

    Does nothing when ``comparison`` is empty. Within a panel galaxies are
    placed at consecutive integer positions (sorted by ``row_index``) and
    the tick labels show the actual row indexes. Population markers are
    drawn only when a ``population_value`` column is present.
    """
    if comparison.empty:
        return

    names = list(comparison["parameter"].drop_duplicates())
    n_panels = len(names)
    fig, raw_axes = plt.subplots(
        n_panels, 1, figsize=(9, max(2.6 * n_panels, 3)), sharex=False
    )
    panels = np.atleast_1d(raw_axes)

    for panel, name in zip(panels, names, strict=True):
        rows = comparison[comparison["parameter"] == name]
        rows = rows.sort_values("row_index").reset_index(drop=True)
        slots = np.arange(len(rows))

        center = rows["hmc_median"].to_numpy(dtype=float)
        below = center - rows["hmc_q16"].to_numpy(dtype=float)
        above = rows["hmc_q84"].to_numpy(dtype=float) - center
        panel.errorbar(
            slots,
            center,
            yerr=[below, above],
            fmt="o",
            ms=4,
            capsize=2,
            label="HMC",
        )
        panel.scatter(slots, rows["map_value"], marker="x", s=35, label="MAP")
        if "population_value" in rows:
            panel.scatter(
                slots,
                rows["population_value"],
                marker="s",
                s=20,
                facecolors="none",
                edgecolors="tab:orange",
                label="population",
            )

        panel.set_xticks(slots)
        panel.set_xticklabels(rows["row_index"].astype(str), rotation=30)
        panel.set_ylabel(name)
        panel.grid(alpha=0.2)
        panel.legend(fontsize=8, ncol=3)

    panels[-1].set_xlabel("catalog row index")
    fig.tight_layout()
    fig.savefig(path, dpi=170)
    plt.close(fig)
[docs]
def plot_batch_dashboard(
    valid: pd.DataFrame, by_row: pd.DataFrame, path: str | Path
) -> None:
    """Write a 2x2 overview figure for a batch run.

    Panels: per-band residual boxplot, observed-vs-model scatter, histogram
    of log10 reduced chi2 (from ``by_row["reduced_chi2"]``), and a redshift
    panel. The redshift panel shows fitted vs truth redshift when both
    ``z_obs`` and ``redshift_truth`` exist, mean residual vs ``z_obs`` when
    only ``z_obs`` exists, and is switched off otherwise.
    """
    fig, grid = plt.subplots(2, 2, figsize=(11, 8))
    (box_ax, scatter_ax), (chi2_ax, redshift_ax) = grid

    plot_residual_boxplot(valid, box_ax)
    plot_observed_model_scatter(valid, scatter_ax)

    finite_chi2 = by_row["reduced_chi2"].replace([np.inf, -np.inf], np.nan).dropna()
    if not finite_chi2.empty:
        # The small offset keeps log10 defined for zero-valued chi2.
        chi2_ax.hist(np.log10(finite_chi2 + 1.0e-12), bins=50, alpha=0.75)
    chi2_ax.set_xlabel("log10 reduced chi2")
    chi2_ax.set_ylabel("galaxies")
    chi2_ax.grid(alpha=0.2)

    has_z = "z_obs" in by_row.columns
    if has_z and "redshift_truth" in by_row.columns:
        plot_redshift_scatter(by_row, redshift_ax)
    elif has_z:
        redshift_ax.scatter(
            by_row["z_obs"], by_row["mean_residual_mag"], s=8, alpha=0.35
        )
        redshift_ax.axhline(0, color="black", lw=1)
        redshift_ax.set_xlabel("z used by DSPS")
        redshift_ax.set_ylabel("mean residual [mag]")
        redshift_ax.grid(alpha=0.2)
    else:
        redshift_ax.axis("off")

    fig.tight_layout()
    fig.savefig(path, dpi=170)
    plt.close(fig)
[docs]
def plot_batch_residuals_by_band(valid: pd.DataFrame, path: str | Path) -> None:
    """Save the per-band residual boxplot of ``valid`` as a standalone figure."""
    figure, axis = plt.subplots(figsize=(9, 5))
    plot_residual_boxplot(valid, axis)
    figure.tight_layout()
    figure.savefig(path, dpi=170)
    plt.close(figure)
[docs]
def plot_batch_observed_vs_model(valid: pd.DataFrame, path: str | Path) -> None:
    """Save the observed-vs-model magnitude scatter of ``valid`` as a square figure."""
    figure, axis = plt.subplots(figsize=(6, 6))
    plot_observed_model_scatter(valid, axis)
    figure.tight_layout()
    figure.savefig(path, dpi=170)
    plt.close(figure)
[docs]
def plot_batch_redshift_truth(by_row: pd.DataFrame, path: str | Path) -> None:
    """Plot fitted vs truth redshift and the histogram of their difference.

    Does nothing unless ``by_row`` has both ``z_obs`` and ``redshift_truth``.
    The histogram uses the ``delta_z_obs_minus_truth`` column when present;
    otherwise the difference ``z_obs - redshift_truth`` is computed here, so
    a frame that passes the guard can always be plotted. Non-finite
    differences are dropped before binning.
    """
    if not {"z_obs", "redshift_truth"}.issubset(by_row.columns):
        return

    fig, axes = plt.subplots(1, 2, figsize=(10, 4))
    plot_redshift_scatter(by_row, axes[0])

    # The guard above only guarantees the two redshift columns; fall back to
    # deriving the delta instead of raising KeyError when it was not stored.
    if "delta_z_obs_minus_truth" in by_row.columns:
        delta = by_row["delta_z_obs_minus_truth"]
    else:
        delta = by_row["z_obs"] - by_row["redshift_truth"]
    dz = delta.replace([np.inf, -np.inf], np.nan).dropna()

    axes[1].hist(dz, bins=60, alpha=0.75)
    axes[1].axvline(0, color="black", lw=1)
    fit_label, truth_label = _redshift_axis_labels(by_row)
    label = f"{fit_label} - {truth_label}"
    axes[1].set_xlabel(label)
    axes[1].set_ylabel("galaxies")
    axes[1].grid(alpha=0.2)
    fig.tight_layout()
    fig.savefig(path, dpi=170)
    plt.close(fig)
[docs]
def plot_redshift_attractors(
    by_row: pd.DataFrame, attractors: pd.DataFrame, path: str | Path
) -> None:
    """Show where fitted redshifts pile up.

    Does nothing when either frame is empty or ``by_row`` has no ``z_obs``.
    The left panel scatters ``z_obs`` against ``redshift_truth`` (at most
    5000 sampled rows, with a one-to-one line) or, without truth, against
    the positional galaxy index. Horizontal lines mark ``z_fit_bin`` of the
    first ten ``attractors`` rows. The right panel is a horizontal bar
    chart of ``n_galaxies`` for those same rows.
    """
    if by_row.empty or attractors.empty or "z_obs" not in by_row:
        return

    fig, (scatter_ax, bar_ax) = plt.subplots(1, 2, figsize=(10.5, 4))

    if "redshift_truth" in by_row.columns and "z_obs" in by_row.columns:
        pairs = by_row[["redshift_truth", "z_obs"]]
        pairs = pairs.replace([np.inf, -np.inf], np.nan).dropna()
        if len(pairs) > 5000:
            pairs = pairs.sample(5000, random_state=4)
        scatter_ax.scatter(pairs["redshift_truth"], pairs["z_obs"], s=8, alpha=0.3)
        if not pairs.empty:
            lo = float(min(pairs["redshift_truth"].min(), pairs["z_obs"].min()))
            hi = float(max(pairs["redshift_truth"].max(), pairs["z_obs"].max()))
            scatter_ax.plot([lo, hi], [lo, hi], color="black", lw=1)
        scatter_ax.set_xlabel("Catalog truth redshift")
        scatter_ax.set_ylabel("Inferred redshift")
    else:
        scatter_ax.scatter(np.arange(len(by_row)), by_row["z_obs"], s=8, alpha=0.3)
        scatter_ax.set_xlabel("galaxy index")
        scatter_ax.set_ylabel("Inferred redshift")

    leading = attractors.head(10)
    for _, attractor in leading.iterrows():
        scatter_ax.axhline(
            float(attractor["z_fit_bin"]),
            color="#B85C38",
            lw=0.8,
            alpha=0.35,
        )
    scatter_ax.grid(alpha=0.2)

    bar_ax.barh(
        leading["z_fit_bin"].astype(str),
        leading["n_galaxies"],
        color="#2F5D8C",
        alpha=0.8,
    )
    # Keep the first attractor row at the top of the chart.
    bar_ax.invert_yaxis()
    bar_ax.set_xlabel("galaxies in fitted-z bin")
    bar_ax.set_ylabel("fitted-z bin")
    bar_ax.grid(alpha=0.2, axis="x")

    fig.suptitle("MAP redshift attractors", y=0.995)
    fig.tight_layout()
    fig.savefig(path, dpi=170)
    plt.close(fig)
[docs]
def plot_batch_parameter_truth(
    by_row: pd.DataFrame,
    path: str | Path,
    config: dict[str, Any] | None = None,
) -> None:
    """Plot fitted vs truth values and their differences for each parameter.

    A parameter is plotted when ``by_row`` has both ``truth_<name>`` and
    ``fit_<name>`` and ``is_comparable_fit_parameter(config, name)`` is
    truthy. Each parameter gets one row with a scatter (truth on x, fit on
    y, with a one-to-one line) and a histogram of fit minus truth. The
    histogram uses ``delta_fit_<name>_minus_truth`` when the column exists
    and otherwise derives it from the fit and truth columns. Rows with
    non-finite values are dropped; a parameter with no remaining rows keeps
    empty axes. Nothing is written when no parameter qualifies.
    """
    # Find all parameters that have both fit and truth
    params = []
    for col in by_row.columns:
        if col.startswith("truth_"):
            param_name = col[6:]
            if (
                f"fit_{param_name}" in by_row.columns
                and is_comparable_fit_parameter(config, param_name)
            ):
                params.append(param_name)
    if not params:
        return

    fig, axes = plt.subplots(len(params), 2, figsize=(10, 4 * len(params)))
    # A single parameter yields a 1-D axes array; normalise to (rows, 2).
    axes = np.atleast_2d(axes)
    for i, param in enumerate(params):
        truth_col = f"truth_{param}"
        fit_col = f"fit_{param}"
        delta_col = f"delta_fit_{param}_minus_truth"

        work = by_row[[truth_col, fit_col]].copy()
        # Selection only guarantees the fit/truth pair, so derive the delta
        # when it was not precomputed instead of raising KeyError.
        if delta_col in by_row.columns:
            work[delta_col] = by_row[delta_col]
        else:
            work[delta_col] = by_row[fit_col] - by_row[truth_col]
        work = work.replace([np.inf, -np.inf], np.nan).dropna()
        if work.empty:
            continue

        display_label = _parameter_display_label(param)
        truth_label = _truth_legend_label(param, by_row)

        ax_scatter = axes[i, 0]
        ax_hist = axes[i, 1]
        ax_scatter.scatter(work[truth_col], work[fit_col], s=8, alpha=0.35)
        lo, hi = (
            min(work[truth_col].min(), work[fit_col].min()),
            max(work[truth_col].max(), work[fit_col].max()),
        )
        ax_scatter.plot([lo, hi], [lo, hi], color="black", lw=1)
        ax_scatter.set_xlabel(f"{truth_label} {display_label}")
        ax_scatter.set_ylabel(f"inferred {display_label}")
        ax_scatter.grid(alpha=0.2)

        ax_hist.hist(work[delta_col], bins=60, alpha=0.75)
        ax_hist.axvline(0, color="black", lw=1)
        ax_hist.set_xlabel(f"inferred {display_label} - {truth_label}")
        ax_hist.set_ylabel("galaxies")
        ax_hist.grid(alpha=0.2)

    fig.tight_layout()
    fig.savefig(path, dpi=170)
    plt.close(fig)
[docs]
def plot_residuals_by_property(by_row: pd.DataFrame, path: str | Path) -> None:
    """Plot the mean magnitude residual against galaxy properties.

    One scatter panel (with a running median) is drawn for each of
    ``redshift_truth``, ``z_obs``, ``catalog_log_stellar_mass`` and
    ``fit_log10_formed_mass_msun`` that exists in ``by_row``. When
    ``catalog_color_kind`` exists, a final panel shows a boxplot of the
    residual per colour kind. Panels with no finite data are switched off.
    Nothing is written when none of these columns is present.
    """
    specs = [
        ("redshift_truth", "truth redshift"),
        ("z_obs", "DSPS fitted redshift"),
        ("catalog_log_stellar_mass", "catalog log stellar mass"),
        ("fit_log10_formed_mass_msun", "fit log formed mass"),
    ]
    available = [(column, label) for column, label in specs if column in by_row]
    has_color = "catalog_color_kind" in by_row
    if not available and not has_color:
        return

    n_panels = len(available) + int(has_color)
    fig, axes = plt.subplots(
        n_panels, 1, figsize=(8, max(2.8 * n_panels, 3)), squeeze=False
    )
    axes_flat = axes[:, 0]

    # zip stops after the scatter panels; the colour panel (if any) is last.
    for ax, (column, label) in zip(axes_flat, available):
        work = by_row[[column, "mean_residual_mag"]].replace([np.inf, -np.inf], np.nan)
        work = work.dropna()
        if work.empty:
            ax.axis("off")
            continue
        ax.scatter(work[column], work["mean_residual_mag"], s=10, alpha=0.35)
        _plot_running_median(
            work[column].to_numpy(dtype=float),
            work["mean_residual_mag"].to_numpy(dtype=float),
            ax,
        )
        ax.axhline(0, color="black", lw=1)
        ax.set_xlabel(label)
        ax.set_ylabel("mean model - obs [mag]")
        ax.grid(alpha=0.2)

    if has_color:
        ax = axes_flat[len(available)]
        work = by_row[["catalog_color_kind", "mean_residual_mag"]].dropna()
        if work.empty:
            ax.axis("off")
        else:
            # Group once so labels and data cannot get out of step.
            labels = []
            groups = []
            for key, group in work.groupby("catalog_color_kind"):
                labels.append(str(key))
                groups.append(group["mean_residual_mag"].to_numpy(dtype=float))
            # The ``labels=`` keyword of boxplot is deprecated since
            # Matplotlib 3.9; set the tick labels explicitly, which works on
            # every Matplotlib version.
            ax.boxplot(groups, showfliers=False)
            ax.set_xticklabels(labels)
            ax.axhline(0, color="black", lw=1)
            ax.set_xlabel("catalog color_kind")
            ax.set_ylabel("mean model - obs [mag]")
            ax.grid(alpha=0.2, axis="y")

    fig.tight_layout()
    fig.savefig(path, dpi=170)
    plt.close(fig)
def _plot_running_median(x: np.ndarray, y: np.ndarray, ax: plt.Axes) -> None:
    """Overlay a binned running median of ``y`` against ``x`` on ``ax``.

    Non-finite pairs are discarded; with fewer than eight finite pairs
    nothing is drawn. The remaining points are sorted by ``x`` and split
    into up to eight near-equal-count bins; the line connects the per-bin
    medians of ``x`` and ``y``.
    """
    finite = np.isfinite(x) & np.isfinite(y)
    if finite.sum() < 8:
        return

    xs, ys = x[finite], y[finite]
    ranking = np.argsort(xs)
    xs, ys = xs[ranking], ys[ranking]

    chunks = np.array_split(np.arange(len(xs)), min(8, len(xs)))
    populated = [chunk for chunk in chunks if len(chunk) != 0]
    centers = [float(np.median(xs[chunk])) for chunk in populated]
    medians = [float(np.median(ys[chunk])) for chunk in populated]
    ax.plot(centers, medians, color="black", lw=1.4)
[docs]
def plot_residual_boxplot(valid: pd.DataFrame, ax: plt.Axes) -> None:
    """Draw one residual box per band on ``ax``.

    Bands come from ``ordered_bands(valid)``. Each box summarises the
    non-null ``residual_mag_model_minus_observed`` values of that band;
    outliers are not drawn.
    """
    bands = ordered_bands(valid)
    data = [
        valid.loc[valid["band"] == band, "residual_mag_model_minus_observed"].dropna()
        for band in bands
    ]
    # The ``labels=`` keyword of boxplot is deprecated since Matplotlib 3.9;
    # setting the tick labels explicitly works on every Matplotlib version.
    ax.boxplot(data, showfliers=False)
    ax.set_xticklabels(bands)
    ax.axhline(0, color="black", lw=1)
    ax.set_ylabel("model - simulated [mag]")
    ax.tick_params(axis="x", rotation=30)
    ax.grid(alpha=0.2, axis="y")
[docs]
def plot_observed_model_scatter(valid: pd.DataFrame, ax: plt.Axes) -> None:
    """Scatter model against observed AB magnitudes on ``ax``, coloured by band.

    Frames with more than 5000 rows are reduced to 5000 rows with a fixed
    random seed. A one-to-one line spans the finite range of both magnitude
    columns.
    """
    if len(valid) > 5000:
        points = valid.sample(5000, random_state=1)
    else:
        points = valid

    for band in ordered_bands(points):
        subset = points[points["band"] == band]
        ax.scatter(
            subset["observed_mag_ab"],
            subset["model_mag_ab"],
            s=8,
            alpha=0.35,
            label=band,
        )

    finite = pd.concat([points["observed_mag_ab"], points["model_mag_ab"]])
    finite = finite.replace([np.inf, -np.inf], np.nan).dropna()
    if not finite.empty:
        lo, hi = float(finite.min()), float(finite.max())
        ax.plot([lo, hi], [lo, hi], color="black", lw=1)

    ax.set_xlabel("simulated AB mag")
    ax.set_ylabel("DSPS AB mag")
    ax.legend(fontsize=7, ncol=2)
    ax.grid(alpha=0.2)
[docs]
def plot_redshift_scatter(by_row: pd.DataFrame, ax: plt.Axes) -> None:
    """Scatter ``z_obs`` against ``redshift_truth`` on ``ax`` with a 1:1 line.

    Rows with non-finite values are dropped and at most 5000 rows (sampled
    with a fixed seed) are drawn. Axis labels come from
    ``_redshift_axis_labels``.
    """
    pairs = by_row[["redshift_truth", "z_obs"]]
    pairs = pairs.replace([np.inf, -np.inf], np.nan).dropna()
    if len(pairs) > 5000:
        pairs = pairs.sample(5000, random_state=2)

    truth_z, fitted_z = pairs["redshift_truth"], pairs["z_obs"]
    ax.scatter(truth_z, fitted_z, s=8, alpha=0.35)
    if not pairs.empty:
        lo = float(min(truth_z.min(), fitted_z.min()))
        hi = float(max(truth_z.max(), fitted_z.max()))
        ax.plot([lo, hi], [lo, hi], color="black", lw=1)

    fit_label, truth_label = _redshift_axis_labels(by_row)
    ax.set_xlabel(truth_label)
    ax.set_ylabel(fit_label)
    ax.grid(alpha=0.2)
def _redshift_axis_labels(by_row: pd.DataFrame) -> tuple[str, str]:
    """Return ``(fit_label, truth_label)`` for redshift comparison axes.

    The truth label embeds the first non-null ``redshift_truth_source``
    (default ``"z_true"``). The fit label reads "Inferred Redshift (Fit)"
    when ``by_row`` has a ``fit_z_obs`` column or the first non-null
    ``z_obs_source`` is ``"DSPS fit"``; otherwise it names that source
    (default ``"z_phz"``) as the input handed to DSPS.
    """
    truth_source = _first_string(by_row, "redshift_truth_source") or "z_true"
    fit_source = _first_string(by_row, "z_obs_source") or "z_phz"

    redshift_was_fitted = "fit_z_obs" in by_row.columns or fit_source == "DSPS fit"
    if redshift_was_fitted:
        fit_label = "Inferred Redshift (Fit)"
    else:
        fit_label = f"Input {fit_source} to DSPS"
    return fit_label, f"Ground Truth ({truth_source})"
def _first_string(frame: pd.DataFrame, column: str) -> str | None:
    """Return the first non-null value of ``column`` as ``str``.

    Returns ``None`` when the column is absent or holds only nulls.
    """
    if column in frame:
        present = frame[column].dropna()
        if not present.empty:
            return str(present.iloc[0])
    return None
def _truth_kind_from_frame(frame: pd.DataFrame, parameter: str) -> str:
    """Return the truth kind recorded for ``parameter`` in ``frame``.

    Uses the first non-null, non-empty value of ``truth_kind_<parameter>``.
    Without one, ``dust_av`` and ``log10_metallicity`` default to
    ``"proxy"`` and every other parameter to ``"direct"``.
    """
    recorded = _first_string(frame, f"truth_kind_{parameter}")
    if recorded:
        return recorded
    if parameter in {"dust_av", "log10_metallicity"}:
        return "proxy"
    return "direct"
def _truth_legend_label(parameter: str, frame: pd.DataFrame | None = None) -> str:
    """Return "Catalog proxy" or "Catalog truth" for ``parameter``.

    With a ``frame`` the kind comes from ``_truth_kind_from_frame``; without
    one, ``dust_av`` and ``log10_metallicity`` count as proxies and all
    other parameters as direct truth.
    """
    kind = None if frame is None else _truth_kind_from_frame(frame, parameter)
    if kind is None:
        if parameter in {"dust_av", "log10_metallicity"}:
            kind = "proxy"
        else:
            kind = "direct"
    return "Catalog proxy" if kind == "proxy" else "Catalog truth"
def _parameter_display_label(parameter: str) -> str:
    """Return the axis label for ``parameter``; unknown names pass through unchanged."""
    known_labels = {
        "z_obs": "Redshift",
        "log10_sfr_at_obs": "log10 SFR [Msun/yr]",
        "log10_sfr": "log10 SFH Amplitude",
        "log10_formed_mass_msun": "log10 formed mass [Msun]",
        "dust_av": "Dust Av [mag]",
        "log10_metallicity": "log10 Metallicity [absolute log10(Z)]",
    }
    if parameter in known_labels:
        return known_labels[parameter]
    return parameter
[docs]
def ordered_bands(df: pd.DataFrame) -> list[str]:
    """Return the band names of ``df`` as strings, in plotting order.

    With an ``effective_wavelength_angstrom`` column the bands are sorted by
    their median effective wavelength (ascending). Without it they are
    returned in order of first appearance in the ``band`` column.
    """
    if "effective_wavelength_angstrom" in df:
        median_wavelength = df.groupby("band")["effective_wavelength_angstrom"].median()
        names = median_wavelength.sort_values().index.tolist()
    else:
        names = df["band"].drop_duplicates().tolist()
    return [str(name) for name in names]