
API Reference

Quick reference

| Function | Description |
|---|---|
| compute_scaledRI | Compute Scaled RI (RI₀₋₃ by default) from six isoGDGT abundances |
| predict_proxy_from_T | Forward: temperature → proxy percentiles (pure Python) |
| predict_T_from_proxyObs | Inverse: proxy → temperature with full uncertainty (runs Stan) |
| download_posteriors | Download forward posteriors from Zenodo |
| download_training_data | Download training CSVs + CMEMS NO₃ field |
| list_posteriors | Print and return .nc stems in the local cache |
| build_fwd_data | Build validated Stan data dict for forward calibration |
| get_posterior | Run forward calibration Stan sampling |
| save_posterior | Persist forward posterior as compressed NetCDF |
| load_posterior | Load a forward or invT posterior from the cache |
| summarize_sampler_diagnostics | Divergences, R-hat, ESS, E-BFMI |

Prediction

Compute Scaled Ring Index

Compute Scaled Ring Index from six isoGDGT abundances.

Accepts raw LC/MS peak areas or fractional abundances — both give identical results because the formula divides by the total sum of all six GDGTs, so any common scale factor drops out. Default cren_rings=3 produces scaledRI_cren3 (RI₀₋₃), the canonical proxy used in TEXAS calibration posteriors.

Parameters

gdgt0, gdgt1, gdgt2, gdgt3, cren, cren_prime : float or array-like
    isoGDGT abundances: GDGT-0, GDGT-1, GDGT-2, GDGT-3, crenarchaeol, and the crenarchaeol regioisomer (cren'). Raw LC/MS peak areas and fractional abundances give the same result (see above).
cren_rings : int
    Ring count assigned to both crenarchaeol and its regioisomer.
    3 → scaledRI_cren3 / RI₀₋₃ (default, recommended).
    4 → scaledRI / RI₀₋₄ (Zhang et al. 2016 convention).

Returns

numpy.ndarray or float
    Scaled Ring Index, dimensionless, nominally in [0, 1].

Notes

The formula is:

RI      = (1·GDGT1 + 2·GDGT2 + 3·GDGT3 + cren_rings·cren + cren_rings·cren')
          / (GDGT0 + GDGT1 + GDGT2 + GDGT3 + cren + cren')
scaledRI = RI / cren_rings

Examples

>>> float(compute_scaledRI(0.45, 0.10, 0.08, 0.05, 0.30, 0.02))
0.4566...

>>> import pandas as pd
>>> df = pd.read_csv("my_gdgt_data.csv")
>>> df["scaledRI_cren3"] = compute_scaledRI(
...     df["GDGT-0"], df["GDGT-1"], df["GDGT-2"], df["GDGT-3"],
...     df["cren"], df["cren_prime"],
... )
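A standalone NumPy sketch can check the formula above by hand. The helper `scaled_ri` is a hypothetical copy written for illustration (not an import from TEXAS). It confirms two claims from the description: multiplying all six abundances by a common factor leaves the result unchanged, and the first example's inputs give (0.41 + 0.96) / (1.00 × 3) ≈ 0.4567.

```python
import numpy as np


def scaled_ri(gdgt0, gdgt1, gdgt2, gdgt3, cren, cren_prime, cren_rings=3):
    """Hypothetical standalone copy of the Scaled RI formula (not imported from TEXAS)."""
    g0, g1, g2, g3, cr, cp = (
        np.asarray(x, dtype=float) for x in (gdgt0, gdgt1, gdgt2, gdgt3, cren, cren_prime)
    )
    ri = (g1 + 2 * g2 + 3 * g3 + cren_rings * (cr + cp)) / (g0 + g1 + g2 + g3 + cr + cp)
    return ri / cren_rings


fractions = (0.45, 0.10, 0.08, 0.05, 0.30, 0.02)
peak_areas = tuple(f * 2.5e6 for f in fractions)  # same profile on an arbitrary LC/MS scale

ri3 = scaled_ri(*fractions)                  # (0.41 + 0.96) / 1.00 / 3
ri3_areas = scaled_ri(*peak_areas)           # common factor cancels in the ratio
ri4 = scaled_ri(*fractions, cren_rings=4)    # (0.41 + 1.28) / 1.00 / 4

print(f"RI0-3 from fractions : {ri3:.4f}")
print(f"RI0-3 from peak areas: {ri3_areas:.4f}")
print(f"RI0-4 from fractions : {ri4:.4f}")
```

Array inputs work the same way, because every operation is element-wise.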

Source code in src/TEXAS/predict.py
def compute_scaledRI(
    gdgt0,
    gdgt1,
    gdgt2,
    gdgt3,
    cren,
    cren_prime,
    *,
    cren_rings: int = 3,
) -> np.ndarray:
    """
    Compute Scaled Ring Index from six isoGDGT abundances.

    Accepts raw LC/MS peak areas or fractional abundances — both give identical
    results because the formula divides by the total sum of all six GDGTs, so
    any common scale factor drops out.
    Default ``cren_rings=3`` produces **scaledRI_cren3** (RI₀₋₃), the canonical
    proxy used in TEXAS calibration posteriors.

    Parameters
    ----------
    gdgt0, gdgt1, gdgt2, gdgt3, cren, cren_prime : float or array-like
        isoGDGT abundances — GDGT-0, GDGT-1, GDGT-2, GDGT-3, crenarchaeol,
        crenarchaeol regioisomer (cren').  Raw LC/MS peak areas and fractional
        abundances give the same result (see above).
    cren_rings : int
        Ring count assigned to both crenarchaeol and its regioisomer.
        ``3`` → scaledRI_cren3 / RI₀₋₃ (default, recommended).
        ``4`` → scaledRI / RI₀₋₄ (Zhang et al. 2016 convention).

    Returns
    -------
    numpy.ndarray or float
        Scaled Ring Index, dimensionless, nominally in [0, 1].

    Notes
    -----
    The formula is::

        RI      = (1·GDGT1 + 2·GDGT2 + 3·GDGT3 + cren_rings·cren + cren_rings·cren')
                  / (GDGT0 + GDGT1 + GDGT2 + GDGT3 + cren + cren')
        scaledRI = RI / cren_rings

    Examples
    --------
    >>> float(compute_scaledRI(0.45, 0.10, 0.08, 0.05, 0.30, 0.02))
    0.4566...

    >>> import pandas as pd
    >>> df = pd.read_csv("my_gdgt_data.csv")
    >>> df["scaledRI_cren3"] = compute_scaledRI(
    ...     df["GDGT-0"], df["GDGT-1"], df["GDGT-2"], df["GDGT-3"],
    ...     df["cren"],   df["cren_prime"],
    ... )
    """
    g0 = np.asarray(gdgt0, dtype=float)
    g1 = np.asarray(gdgt1, dtype=float)
    g2 = np.asarray(gdgt2, dtype=float)
    g3 = np.asarray(gdgt3, dtype=float)
    cr = np.asarray(cren, dtype=float)
    cp = np.asarray(cren_prime, dtype=float)

    numerator = g1 + 2 * g2 + 3 * g3 + cren_rings * cr + cren_rings * cp
    denominator = (g0 + g1 + g2 + g3 + cr + cp) * cren_rings
    return numerator / denominator

Predict proxy from T

Forward prediction: temperature → proxy percentiles (Scaled RI, TEX86, or any fitted proxy).

Samples n_draws self-consistent parameter sets from the forward calibration posterior (all parameters drawn from the same posterior index, preserving correlations) and evaluates the calibration curve at each requested temperature. Corresponds to the forward model described in Eq. 1 / Eq. 6–7 of the manuscript.

Parameters

temperatures : array-like
    Temperatures (°C) at which to evaluate the calibration curve.
posterior : xr.Dataset or str
    Forward calibration posterior: either a loaded xr.Dataset or a saved-file name string (looked up in the posterior cache).
n_draws : int
    Number of posterior draws to sample. Default 500.
percentiles : list of float
    Percentiles to return, e.g. [5, 50, 95].
return_full : bool
    If True, also return the full (n_draws × len(temperatures)) ensemble array and run metadata under keys "ensemble" and "metadata".
seed : int
    Random seed for reproducible draw sampling.
gdgt23ratio : array-like, optional
    GDGT-2/GDGT-3 ratio values (one per temperature point). Required only when the posterior was fitted with the multivariate model (β_{G₂/₃} correction).
no3 : array-like, optional
    Nitrate concentration values (one per temperature point). Required only when the posterior was fitted with NO₃ correction.
no3_cutoff : float, optional
    Nitrate threshold (μmol/L) below which the NO₃ correction applies. Defaults to the value stored in the posterior attributes.
suffix : str, optional
    Force a specific parameter suffix (e.g. "crtp"). Auto-detected by priority order when omitted.

Returns

dict with keys:
    "x_vals": temperature array (°C)
    "pN": one key per requested percentile, e.g. "p5", "p50", "p95"
    "ensemble": full array, shape (n_draws, len(temperatures)), if return_full=True
    "metadata": run metadata dict, if return_full=True
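The joint-draw idea in the description (one posterior index per draw, shared by every parameter) can be sketched in NumPy. Everything model-specific here is hypothetical: the logistic curve, its parameter names, and the synthetic "posterior" are stand-ins, not the TEXAS calibration model or `generate_ensemble_auto`. Only the sampling pattern and the percentile summary follow the text above.

```python
import numpy as np


def logistic_curve(t, lower, upper, k, t0):
    """Hypothetical stand-in for the calibration curve (not the TEXAS model)."""
    return lower + (upper - lower) / (1.0 + np.exp(-k * (t - t0)))


def forward_percentiles(temperatures, posterior, n_draws=500, percentiles=(5, 50, 95), seed=42):
    """Evaluate the curve for jointly drawn parameter sets and summarise by percentile."""
    x_vals = np.asarray(temperatures, dtype=float)
    n_total = len(next(iter(posterior.values())))
    rng = np.random.default_rng(seed)
    # One index per draw, reused for every parameter, so that correlations
    # between parameters in the posterior are preserved.
    idx = rng.integers(0, n_total, size=n_draws)
    params = {name: np.asarray(values)[idx, None] for name, values in posterior.items()}
    ensemble = logistic_curve(x_vals[None, :], **params)  # shape (n_draws, len(x_vals))
    out = {"x_vals": x_vals}
    for p in percentiles:
        out[f"p{p:g}"] = np.percentile(ensemble, p, axis=0)
    return out, ensemble


# Synthetic "posterior" samples, for illustration only.
rng = np.random.default_rng(0)
n = 4000
fake_posterior = {
    "lower": rng.normal(0.35, 0.01, n),
    "upper": rng.normal(0.75, 0.01, n),
    "k": rng.normal(0.15, 0.01, n),
    "t0": rng.normal(15.0, 0.5, n),
}
result, ens = forward_percentiles([0.0, 10.0, 20.0, 30.0], fake_posterior, n_draws=500)
print({key: np.round(val, 3) for key, val in result.items()})
```

Drawing each parameter with its own independent index would ignore posterior correlations and typically distort the spread of the ensemble, which is why the shared index matters.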

Source code in src/TEXAS/predict.py
def predict_proxy_from_T(
    temperatures: Union[np.ndarray, List[float]],
    posterior: Union[xr.Dataset, str],
    *,
    n_draws: int = 500,
    percentiles: List[float] = [5, 50, 95],
    return_full: bool = False,
    seed: int = 42,
    gdgt23ratio: Optional[np.ndarray] = None,
    no3: Optional[np.ndarray] = None,
    no3_cutoff: Optional[float] = None,
    suffix: Optional[str] = None,
) -> Dict[str, np.ndarray]:
    """
    Forward prediction: temperature → proxy percentiles (Scaled RI, TEX86, or any fitted proxy).

    Samples `n_draws` self-consistent parameter sets from the forward
    calibration posterior (all parameters drawn from the same posterior
    index, preserving correlations) and evaluates the calibration curve
    at each requested temperature.  Corresponds to the forward model
    described in Eq. 1 / Eq. 6–7 of the manuscript.

    Parameters
    ----------
    temperatures : array-like
        Temperatures (°C) at which to evaluate the calibration curve.
    posterior : xr.Dataset or str
        Forward calibration posterior — either a loaded xr.Dataset or
        a saved-file name string (looked up in the posterior cache).
    n_draws : int
        Number of posterior draws to sample.  Default 500.
    percentiles : list of float
        Percentiles to return, e.g. [5, 50, 95].
    return_full : bool
        If True, also return the full (n_draws × len(temperatures))
        ensemble array and run metadata under keys ``"ensemble"`` and
        ``"metadata"``.
    seed : int
        Random seed for reproducible draw sampling.
    gdgt23ratio : array-like, optional
        GDGT-2/GDGT-3 ratio values (one per temperature point).
        Required only when the posterior was fitted with the multivariate
        model (β_{G₂/₃} correction).
    no3 : array-like, optional
        Nitrate concentration values (one per temperature point).
        Required only when the posterior was fitted with NO₃ correction.
    no3_cutoff : float, optional
        Nitrate threshold (μmol/L) below which the NO₃ correction applies.
        Defaults to the value stored in the posterior attributes.
    suffix : str, optional
        Force a specific parameter suffix (e.g. ``"crtp"``).  Auto-detected
        by priority order when omitted.

    Returns
    -------
    dict with keys:
        ``"x_vals"``  — temperature array (°C)
        ``"pN"``      — one key per requested percentile, e.g. ``"p5"``, ``"p50"``, ``"p95"``
        ``"ensemble"``  — full array, shape (n_draws, len(temperatures)), if return_full=True
        ``"metadata"``  — run metadata dict, if return_full=True
    """
    if isinstance(posterior, str):
        posterior = load_posterior(posterior)

    return generate_ensemble_auto(
        post_ds=posterior,
        x_vals=np.asarray(temperatures, dtype=float),
        model_type="forward",
        gdgt23ratio=gdgt23ratio,
        no3=no3,
        no3_cutoff=no3_cutoff,
        return_full_ensemble=return_full,
        suffix=suffix,
        # passed through **kwargs to generate_ensemble:
        n_draws=n_draws,
        percentiles=percentiles,
        seed=seed,
    )

Predict T from proxy observations

Inverse reconstruction: scaled RI → temperature percentiles.

Runs the TEXAS-Bay inverse Stan model to infer paleotemperature from observed scaled Ring Index values. Marginalises over M draws from the forward calibration posterior to propagate calibration uncertainty into the temperature reconstruction. Corresponds to Section 8 (Applications to Paleothermometry) of the manuscript.

Parameters

proxyObs : array-like, shape (N,)
    Observed proxy values from downcore or coretop samples (e.g. scaledRI, TEX86).
prior_mu_t : float or array-like, shape (N,)
    Prior mean temperature (°C). A scalar applies the same prior to all N observations; an array sets a site-specific prior per sample.
prior_sigma_t : float
    Prior temperature uncertainty (°C). Use a diffuse value (e.g. 10) when little prior information is available.
fwd_posterior : str or xr.Dataset, optional
    The forward calibration posterior. Accepts either:

    - str: name of the saved posterior (without the .nc extension) in the posterior cache directory. The file is loaded automatically.
    - xr.Dataset: a pre-loaded posterior Dataset. No file I/O or Zenodo download is attempted; pass this when the cache is unavailable (e.g. Google Colab with a Drive-mounted .nc):

          ds = xr.open_dataset("my_drive/posterior.nc")
          result = predict_T_from_proxyObs(..., fwd_posterior=ds)
temptype : str, optional
    Temperature type: "SST" or "thermoT". Used for metadata and output file naming.
site_name : str, optional
    Label attached to result metadata and output filenames.
predictors : dict, optional
    Non-thermal predictor arrays for the N observations, e.g. {"gdgt23ratio": array, "no3": array}. Must be provided when the forward posterior was fitted with the multivariate model. Overridden by the no3 / gdgt23ratio shorthands when both are given.
no3 : float or array-like, optional
    Nitrate concentration (µmol/L) for the N observations.

    - Array (length N): per-observation values; use modern WOA23 values extracted at each sample's location (ocean_prop_ds column "no3_sf2tc_avg").
    - Scalar: broadcast to all N observations. Pass a value above no3_cutoff (e.g. no3=10.0 when no3_cutoff=1.0) to effectively disable the NO₃ correction, since all observations then fall outside the correction window.

    Overrides any "no3" key in predictors. Ignored when site_lat / site_lon / no3_dataset are also provided (the lookup result takes priority).

gdgt23ratio : float or array-like, optional
    GDGT-2/GDGT-3 ratio for the N observations. Scalar or array, with the same broadcast rules as no3. Overrides any "gdgt23ratio" key in predictors.
site_lat : float or array-like, optional
    Decimal latitude(s) of the study site(s). Scalar for a single drill core; array of length N to assign a distinct location to each observation. Requires site_lon and no3_dataset.
site_lon : float or array-like, optional
    Decimal longitude(s) of the study site(s). Same shape rules as site_lat.
no3_dataset : xr.Dataset, optional
    WOA23-derived dataset with a (lat, lon) grid, typically the ocean_prop_ds generated in the preprocessing notebook (SI_code1). Must contain no3_dataset_var. When provided together with site_lat / site_lon, the NO₃ value at those coordinates is looked up via bilinear interpolation and used as the predictor. The result is a scalar (one drill site) or an array (per-observation sites); a scalar is broadcast to all N observations.
no3_dataset_var : str
    Variable name to extract from no3_dataset. Default "no3_sf2tc_avg".
config : InvTConfig, optional
    Controls the number of forward-posterior draws (M), seed, etc. Defaults to InvTConfig() (M=100).
chains : int
    Number of MCMC chains. Default 4.
iter_warmup : int
    Warmup iterations per chain. Default 500.
iter_sampling : int
    Sampling iterations per chain. Default 1000.
seed : int
    Random seed. Default 42.
constraint_type : str
    Temperature constraint applied in the Stan model:

    - "unconstrained" (default): no lower bound; P5 can be unrealistically cold near the calibration curve's lower asymptote.
    - "hard_constraint": hard lower bound via <lower=min_temp>; prevents sub-freezing samples, but the Jacobian biases P50 warm for polar sites.
    - "truncated_prior" (recommended when min_temp is set): proper truncated Normal prior via inverse-CDF reparameterization; P50 is data-driven and P5 is bounded at min_temp without warm bias.
    - "reparameterized", "soft": experimental variants.
min_temp : float, optional
    Lower temperature bound (°C). Required for "hard_constraint" and "truncated_prior". Typically −1.8 (seawater freezing point). When provided without an explicit constraint_type, automatically selects "truncated_prior".
threads_per_chain : int, optional
    Enable within-chain parallelism via Stan's reduce_sum.
save_results : bool
    If True, save the quantile posterior .nc and results .npz to the invT cache directory.
save_draws : bool
    If True, also save the raw posterior draws (pre-quantile) as a separate {base}_draws.nc file in the invT cache directory. The file contains t_est with dims (chain, draw, obs_idx) and is suitable for kernel-density plots or custom quantile calculation. Default False.
filename_tag : str or list of str, optional
    Extra tag(s) appended to the output filename.
cache_dir : Path or str, optional
    Directory where .nc and .npz files are written when save_results or save_draws is True. Defaults to the standard invT cache (~/.texas/cache/TEXAS_invT_posterior_cache/ for pip installs, or data/cache/TEXAS_invT_posterior_cache/ in the repo).
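The inverse-CDF idea behind "truncated_prior" can be illustrated outside Stan with Python's standard library. This is a sketch of truncated-Normal sampling in general, assuming a lower bound only; it is not the Stan code TEXAS uses, and the example prior (−1.0 ± 2.0 °C with a bound at −1.8 °C) is hypothetical. A uniform variate is squeezed into the part of the CDF above the bound and mapped back through the inverse CDF, so no draw falls below the bound and no mass piles up at it.

```python
import random
from statistics import NormalDist, median


def sample_truncated_normal(mu, sigma, lower, n, seed=42):
    """Draw n values from Normal(mu, sigma) truncated below at `lower` (inverse-CDF method)."""
    dist = NormalDist(mu, sigma)
    cdf_lower = dist.cdf(lower)                 # prior mass that lies below the bound
    rng = random.Random(seed)
    draws = []
    for _ in range(n):
        u = rng.random()                        # u ~ Uniform[0, 1)
        p = cdf_lower + u * (1.0 - cdf_lower)   # map u into [cdf_lower, 1)
        p = min(max(p, 1e-12), 1.0 - 1e-12)     # keep inv_cdf's argument strictly inside (0, 1)
        draws.append(dist.inv_cdf(p))
    return draws


# Hypothetical polar-site prior, bounded at the seawater freezing point.
draws = sample_truncated_normal(mu=-1.0, sigma=2.0, lower=-1.8, n=20000)
print(f"min = {min(draws):.3f} °C, median = {median(draws):.3f} °C")
```

Clipping Normal draws at the bound instead would stack all the removed mass exactly at −1.8 °C; the inverse-CDF mapping redistributes it proportionally over the allowed range.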

Returns

dict with keys:
    "proxyObs": input proxy array
    "proxy_name": proxy type label (e.g. "scaledRI", "TEX86")
    "p5": 5th percentile temperature (°C), shape (N,)
    "p50": median temperature (°C), shape (N,)
    "p95": 95th percentile temperature (°C), shape (N,)
    "metadata": run metadata dict (model name, attrs, etc.)
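To make "marginalises over M draws from the forward calibration posterior" concrete, here is a grid-approximation sketch. It assumes the marginalisation amounts to averaging a Gaussian likelihood over the M forward draws, and it swaps Stan's MCMC for a 1-D temperature grid per observation. The curve, its parameter names, the residual term `sigma`, and the synthetic draws are all hypothetical and are not the TEXAS-Bay model; only the prior (`prior_mu_t`, `prior_sigma_t`) and the p5/p50/p95 outputs mirror the documented interface.

```python
import numpy as np


def curve(t, lower, upper, k, t0):
    """Hypothetical increasing calibration curve (a stand-in, not the TEXAS model)."""
    return lower + (upper - lower) / (1.0 + np.exp(-k * (t - t0)))


def invert_on_grid(proxy_obs, prior_mu_t, prior_sigma_t, fwd_draws, t_grid=None):
    """Grid approximation of p(T | proxy), with the likelihood averaged over M forward draws."""
    if t_grid is None:
        t_grid = np.linspace(-5.0, 40.0, 2251)                          # 0.02 °C spacing
    y = np.atleast_1d(np.asarray(proxy_obs, dtype=float))                # (N,)
    mu = np.broadcast_to(np.asarray(prior_mu_t, dtype=float), y.shape)   # scalar or (N,) prior means
    # Predicted proxy for each forward draw m and grid temperature: shape (M, G).
    pred = curve(t_grid[None, :], fwd_draws["lower"][:, None], fwd_draws["upper"][:, None],
                 fwd_draws["k"][:, None], fwd_draws["t0"][:, None])
    sigma = fwd_draws["sigma"][:, None]                                  # (M, 1) residual s.d.
    # Gaussian likelihood for every (observation, draw, grid point), then mean over draws.
    z = (y[:, None, None] - pred[None, :, :]) / sigma[None, :, :]
    like = np.mean(np.exp(-0.5 * z**2) / sigma[None, :, :], axis=1)      # (N, G)
    prior = np.exp(-0.5 * ((t_grid[None, :] - mu[:, None]) / prior_sigma_t) ** 2)
    cdf = np.cumsum(like * prior, axis=1)
    cdf /= cdf[:, -1:]                                                   # each row ends at 1
    out = {"proxyObs": y}
    for p in (5, 50, 95):
        out[f"p{p}"] = np.array([np.interp(p / 100.0, row, t_grid) for row in cdf])
    return out


# Synthetic forward draws and noise-free synthetic observations, for illustration only.
rng = np.random.default_rng(1)
M = 100
fwd_draws = {
    "lower": rng.normal(0.35, 0.01, M),
    "upper": rng.normal(0.75, 0.01, M),
    "k": rng.normal(0.15, 0.01, M),
    "t0": rng.normal(15.0, 0.5, M),
    "sigma": np.full(M, 0.02),
}
obs = curve(np.array([5.0, 10.0, 20.0]), 0.35, 0.75, 0.15, 15.0)
result = invert_on_grid(obs, prior_mu_t=15.0, prior_sigma_t=10.0, fwd_draws=fwd_draws)
for key in ("p5", "p50", "p95"):
    print(key, np.round(result[key], 1))
```

The coldest observation sits near the curve's lower asymptote, where the slope is shallow, so its interval comes out wider; this is the same effect the constraint_type notes describe for P5.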

Source code in src/TEXAS/predict.py
def predict_T_from_proxyObs(
    proxyObs: Union[np.ndarray, List[float]],
    prior_mu_t: Union[np.ndarray, float],
    prior_sigma_t: float,
    fwd_posterior: Optional[Union[str, xr.Dataset]] = None,
    *,
    proxy_name: Optional[str] = None,
    temptype: Optional[str] = None,
    site_name: Optional[str] = None,
    predictors: Optional[Dict[str, np.ndarray]] = None,
    # ── Predictor shorthands (override anything in predictors dict) ───────
    no3: Optional[Union[float, np.ndarray]] = None,
    gdgt23ratio: Optional[Union[float, np.ndarray]] = None,
    # ── Modern-ocean NO₃ lookup from WOA23-derived dataset ───────────────
    site_lat: Optional[Union[float, np.ndarray]] = None,
    site_lon: Optional[Union[float, np.ndarray]] = None,
    no3_dataset: Optional[xr.Dataset] = None,
    no3_dataset_var: str = "no3_sf2tc_avg",
    # ─────────────────────────────────────────────────────────────────────
    config: Optional[InvTConfig] = None,
    chains: int = 4,
    iter_warmup: int = 500,
    iter_sampling: int = 1000,
    seed: int = 42,
    constraint_type: Literal[
        "unconstrained", "hard_constraint", "truncated_prior", "reparameterized", "soft"
    ] = "unconstrained",
    min_temp: Optional[float] = None,
    threads_per_chain: Optional[int] = None,
    save_results: bool = False,
    save_draws: bool = False,
    filename_tag: Optional[Union[str, Sequence[str]]] = None,
    cache_dir: Optional[Union[str, Path]] = None,
) -> Dict[str, Any]:
    """
    Inverse reconstruction: scaled RI → temperature percentiles.

    Runs the TEXAS-Bay inverse Stan model to infer paleotemperature from
    observed scaled Ring Index values.  Marginalises over M draws from
    the forward calibration posterior to propagate calibration uncertainty
    into the temperature reconstruction.  Corresponds to Section 8
    (Applications to Paleothermometry) of the manuscript.

    Parameters
    ----------
    proxyObs : array-like, shape (N,)
        Observed proxy values from downcore or coretop samples (e.g. scaledRI, TEX86).
    prior_mu_t : float or array-like, shape (N,)
        Prior mean temperature (°C).  Scalar applies the same prior to all
        N observations; array sets a site-specific prior per sample.
    prior_sigma_t : float
        Prior temperature uncertainty (°C).  Use a diffuse value (e.g. 10)
        when little prior information is available.
    fwd_posterior : str or xr.Dataset, optional
        The forward calibration posterior.  Accepts either:

        - **str** — name of the saved posterior (without ``.nc`` extension)
          in the posterior cache directory.  The file is loaded automatically.
        - **xr.Dataset** — a pre-loaded posterior Dataset.  No file I/O or
          Zenodo download is attempted; pass this when the cache is unavailable
          (e.g. Google Colab with a Drive-mounted ``.nc``)::

              ds = xr.open_dataset("my_drive/posterior.nc")
              result = predict_T_from_proxyObs(..., fwd_posterior=ds)

    temptype : str, optional
        Temperature type: ``"SST"`` or ``"thermoT"``.  Used for metadata
        and output file naming.
    site_name : str, optional
        Label attached to result metadata and output filenames.
    predictors : dict, optional
        Non-thermal predictor arrays for the N observations, e.g.
        ``{"gdgt23ratio": array, "no3": array}``.  Must be provided when
        the forward posterior was fitted with the multivariate model.
        Overridden by *no3* / *gdgt23ratio* shorthands when both are given.
    no3 : float or array-like, optional
        Nitrate concentration (µmol/L) for the N observations.

        - **Array** (length N): per-observation values — use modern WOA23
          values extracted at each sample's location (``ocean_prop_ds``
          column ``"no3_sf2tc_avg"``).
        - **Scalar**: broadcast to all N observations.  Pass a value above
          ``no3_cutoff`` (e.g. ``no3=10.0`` when ``no3_cutoff=1.0``) to
          effectively disable the NO₃ correction — all observations fall
          outside the correction window.

        Overrides any ``"no3"`` key in *predictors*.  Ignored when
        *site_lat* / *site_lon* / *no3_dataset* are also provided (the
        lookup result takes priority).
    gdgt23ratio : float or array-like, optional
        GDGT-2/GDGT-3 ratio for the N observations.  Scalar or array,
        same broadcast rules as *no3*.  Overrides any ``"gdgt23ratio"``
        key in *predictors*.
    site_lat : float or array-like, optional
        Decimal latitude(s) of the study site(s).  Scalar for a single
        drill core; array of length N to assign a distinct location to
        each observation.  Requires *site_lon* and *no3_dataset*.
    site_lon : float or array-like, optional
        Decimal longitude(s) of the study site(s).  Same shape rules as
        *site_lat*.
    no3_dataset : xr.Dataset, optional
        WOA23-derived dataset with a ``(lat, lon)`` grid, typically the
        ``ocean_prop_ds`` generated in the preprocessing notebook
        (SI_code1).  Must contain *no3_dataset_var*.  When provided
        together with *site_lat* / *site_lon*, the NO₃ value at those
        coordinates is looked up via bilinear interpolation and used as
        the predictor.  The result is a scalar (one drill site) or array
        (per-obs sites), and is broadcast to all N observations when scalar.
    no3_dataset_var : str
        Variable name to extract from *no3_dataset*.
        Default ``"no3_sf2tc_avg"``.
    config : InvTConfig, optional
        Controls number of forward-posterior draws (M), seed, etc.
        Defaults to ``InvTConfig()`` (M=100).
    chains : int
        Number of MCMC chains.  Default 4.
    iter_warmup : int
        Warmup iterations per chain.  Default 500.
    iter_sampling : int
        Sampling iterations per chain.  Default 1000.
    seed : int
        Random seed.  Default 42.
    constraint_type : str
        Temperature constraint applied in the Stan model:

        - ``"unconstrained"`` (default): no lower bound; P5 can be unrealistically cold
          near the calibration curve's lower asymptote.
        - ``"hard_constraint"``: hard lower bound via ``<lower=min_temp>``; prevents
          sub-freezing samples but the Jacobian biases P50 warm for polar sites.
        - ``"truncated_prior"`` (recommended when ``min_temp`` is set): proper
          truncated Normal prior via inverse-CDF reparameterization — P50 is
          data-driven and P5 is bounded at ``min_temp`` without warm bias.
        - ``"reparameterized"``, ``"soft"``: experimental variants.
    min_temp : float, optional
        Lower temperature bound (°C). Required for ``"hard_constraint"`` and
        ``"truncated_prior"``. Typically −1.8 (seawater freezing point).
        When provided without an explicit ``constraint_type``, automatically
        selects ``"truncated_prior"``.
    threads_per_chain : int, optional
        Enable within-chain parallelism via Stan's ``reduce_sum``.
    save_results : bool
        If True, save the quantile posterior ``.nc`` and results ``.npz`` to the
        invT cache directory.
    save_draws : bool
        If True, also save the raw posterior draws (pre-quantile) as a separate
        ``{base}_draws.nc`` file in the invT cache directory.  The file contains
        ``t_est`` with dims ``(chain, draw, obs_idx)`` and is suitable for
        kernel-density plots or custom quantile calculation.  Default False.
    filename_tag : str or list of str, optional
        Extra tag(s) appended to the output filename.
    cache_dir : Path or str, optional
        Directory where ``.nc`` and ``.npz`` files are written when
        *save_results* or *save_draws* is True.  Defaults to the standard
        invT cache (``~/.texas/cache/TEXAS_invT_posterior_cache/`` for pip
        installs, or ``data/cache/TEXAS_invT_posterior_cache/`` in the repo).

    Returns
    -------
    dict with keys:
        ``"proxyObs"``   — input proxy array
        ``"proxy_name"`` — proxy type label (e.g. ``"scaledRI"``, ``"TEX86"``)
        ``"p5"``         — 5th percentile temperature (°C), shape (N,)
        ``"p50"``        — median temperature (°C), shape (N,)
        ``"p95"``        — 95th percentile temperature (°C), shape (N,)
        ``"metadata"``   — run metadata dict (model name, attrs, etc.)
    """
    # ── Normalise fwd_posterior: split str vs pre-loaded Dataset ─────────────
    if isinstance(fwd_posterior, xr.Dataset):
        _fwd_ds: Optional[xr.Dataset] = fwd_posterior
        _fwd_name: Optional[str] = None
    else:
        _fwd_ds = None
        _fwd_name = fwd_posterior  # str or None

    # ── Resolve NO₃ predictor ─────────────────────────────────────────────────
    # Priority: site_lat/lon lookup > no3= explicit > predictors["no3"] > zeros
    predictors = dict(predictors or {})

    if site_lat is not None or site_lon is not None:
        if site_lat is None or site_lon is None:
            raise ValueError(
                "site_lat and site_lon must both be provided for a WOA23 lookup."
            )
        if no3_dataset is None:
            raise ValueError(
                "no3_dataset must be provided when using site_lat/site_lon. "
                "Pass the WOA23-derived ocean_prop_ds from your preprocessing notebook."
            )
        no3 = lookup_no3_from_woa(
            lat=site_lat,
            lon=site_lon,
            woa_dataset=no3_dataset,
            variable=no3_dataset_var,
        )
        _lat_repr = f"{site_lat}" if np.isscalar(site_lat) else f"array[{np.asarray(site_lat).size}]"
        _lon_repr = f"{site_lon}" if np.isscalar(site_lon) else f"array[{np.asarray(site_lon).size}]"
        _no3_repr = f"{float(no3):.3g}" if np.asarray(no3).ndim == 0 else f"array[{np.asarray(no3).size}], mean={float(np.nanmean(no3)):.3g}"
        print(f"🌊 WOA23 NO₃ lookup: lat={_lat_repr}, lon={_lon_repr} → {_no3_repr} µmol/L")

    if no3 is not None:
        predictors["no3"] = no3
    if gdgt23ratio is not None:
        predictors["gdgt23ratio"] = gdgt23ratio

    # Warn if predictors are passed but the forward posterior doesn't use them
    _ds_for_check = _fwd_ds
    if _ds_for_check is None and _fwd_name:
        try:
            _ds_for_check = load_posterior(_fwd_name)
        except Exception:
            pass
    if _ds_for_check is not None:
        _attrs = _ds_for_check.attrs
        if predictors.get("gdgt23ratio") is not None and not _attrs.get("use_gdgt23ratio", False):
            warnings.warn(
                "gdgt23ratio was passed but the forward posterior has no GDGT-2/3 ratio "
                "parameters (use_gdgt23ratio=False) — the predictor will be silently ignored. "
                "To apply the GDGT-2/3 correction, use a multivariate posterior "
                "(e.g. gen_logi_fixed_hier_crtp_multiv_priorApprox_*).",
                UserWarning, stacklevel=2,
            )
        if predictors.get("no3") is not None and not _attrs.get("use_no3", False):
            warnings.warn(
                "no3 was passed but the forward posterior has no NO₃ parameters "
                "(use_no3=False) — the predictor will be silently ignored. "
                "To apply the NO₃ correction, use a multivariate posterior "
                "(e.g. gen_logi_fixed_hier_crtp_multiv_priorApprox_*).",
                UserWarning, stacklevel=2,
            )

    return _predict_temperature_from_proxyObs(
        proxyObs=proxyObs,
        prior_mu_t=prior_mu_t,
        prior_sigma_t=prior_sigma_t,
        fwd_posterior_name=_fwd_name,
        fwd_posterior=_fwd_ds,
        site_name=site_name,
        temptype=temptype,
        predictors=predictors,
        config=config,
        chains=chains,
        iter_warmup=iter_warmup,
        iter_sampling=iter_sampling,
        seed=seed,
        save_results=save_results,
        save_draws=save_draws,
        filename_tag=filename_tag,
        cache_dir=cache_dir,
        threads_per_chain=threads_per_chain,
        model_type="direct",
        constraint_type=constraint_type,
        min_temp=min_temp,
        proxy_name=proxy_name,
    )
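The scalar-versus-array rule for no3 and gdgt23ratio, and the cutoff logic described for the NO₃ correction, can be sketched as follows. `broadcast_predictor` and `in_correction_window` are hypothetical helpers written for illustration: the code above passes predictor values through unchanged, so the real broadcasting presumably happens downstream in `_predict_temperature_from_proxyObs`, which is not shown. The strict `<` comparison follows the "below which" wording used for no3_cutoff.

```python
import numpy as np


def broadcast_predictor(value, n_obs):
    """Return `value` as a length-n_obs float array: scalars are repeated, arrays must match."""
    arr = np.asarray(value, dtype=float)
    if arr.ndim == 0:
        return np.full(n_obs, float(arr))
    if arr.shape != (n_obs,):
        raise ValueError(f"expected a scalar or length-{n_obs} array, got shape {arr.shape}")
    return arr


def in_correction_window(no3, no3_cutoff):
    """True where nitrate lies below the cutoff, i.e. where a NO3 correction would apply."""
    return np.asarray(no3, dtype=float) < no3_cutoff


n_obs = 4
# A scalar above the cutoff puts every observation outside the correction window.
print(in_correction_window(broadcast_predictor(10.0, n_obs), no3_cutoff=1.0))
# Per-observation values (e.g. from a site lookup) can fall on either side of it.
print(in_correction_window(broadcast_predictor([0.2, 0.8, 3.0, 12.0], n_obs), no3_cutoff=1.0))
```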

Download and cache

Download posteriors

Download forward calibration posteriors from Zenodo.

Parameters

names : list of str, optional
    Subset of POSTERIOR_REGISTRY keys to download. Downloads all five posteriors when omitted (~158 MB total; the two EIV multivariate posteriors are ~78 MB each, so pass names= to download only the univariate ones if you don't need the EIV model).
cache_dir : Path or str, optional
    Destination directory. Defaults to the standard posterior cache.
force : bool
    Re-download files that already exist locally.

Returns

list of Path
    Local paths of the downloaded .nc files.

Examples

Download only the univariate SST posterior (~0.3 MB):

>>> download_posteriors(["gen_logi_fixed_hier_crtp_univ_priorApprox_SST_scaledRI_cren3"])

Source code in src/TEXAS/utils/download.py
def download_posteriors(
    names: Optional[List[str]] = None,
    cache_dir: Optional[Path | str] = None,
    force: bool = False,
) -> List[Path]:
    """Download forward calibration posteriors from Zenodo.

    Parameters
    ----------
    names : list of str, optional
        Subset of ``POSTERIOR_REGISTRY`` keys to download.  Downloads all
        five posteriors when omitted (~158 MB total; the two EIV multivariate
        posteriors are ~78 MB each — pass ``names=`` to download only the
        univariate ones if you don't need the EIV model).
    cache_dir : Path or str, optional
        Destination directory.  Defaults to the standard posterior cache.
    force : bool
        Re-download files that already exist locally.

    Returns
    -------
    list of Path
        Local paths of the downloaded ``.nc`` files.

    Examples
    --------
    Download only the univariate SST posterior (~0.3 MB):

    >>> download_posteriors(["gen_logi_fixed_hier_crtp_univ_priorApprox_SST_scaledRI_cren3"])
    """
    targets = names if names is not None else list(POSTERIOR_REGISTRY)
    dest_dir = Path(cache_dir) if cache_dir else POSTERIOR_CACHE_DIR
    dest_dir.mkdir(parents=True, exist_ok=True)

    for name in targets:
        if name not in POSTERIOR_REGISTRY:
            available = "\n  ".join(POSTERIOR_REGISTRY)
            raise KeyError(
                f"'{name}' is not in POSTERIOR_REGISTRY.\n"
                f"Available posteriors:\n  {available}"
            )

    missing = [n for n in targets if not (dest_dir / f"{n}.nc").exists() or force]
    if not missing:
        print("All requested posteriors already cached.")
        return [dest_dir / f"{n}.nc" for n in targets]

    total_mb = sum(POSTERIOR_REGISTRY[n]["size_mb"] for n in missing)
    if total_mb >= 5:
        print(f"Downloading {len(missing)} posterior(s) — total ~{total_mb:.0f} MB")

    paths = []
    for name in targets:
        entry = POSTERIOR_REGISTRY[name]
        dest = dest_dir / f"{name}.nc"
        _download_file(_file_url(entry["filename"]), dest, entry["size_mb"], force=force)
        paths.append(dest)

    return paths

Download all

Download everything from Zenodo: forward posteriors + training data.

Files are downloaded individually; already-cached files are skipped unless force=True. Total download is ~158 MB (dominated by the two EIV multivariate posteriors at ~78 MB each).

Parameters

cache_dir : Path or str, optional
    Destination for .nc posteriors. Defaults to the standard posterior cache directory.
data_dir : Path or str, optional
    Destination for training data files. Defaults to data/spreadsheets/.
force : bool
    Re-download files that already exist locally.
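The skip-unless-force rule for already-cached files can be sketched with `pathlib`. `plan_downloads` and the two-entry registry are hypothetical: they mimic the shape of the `POSTERIOR_REGISTRY` entries used in download_posteriors (a `size_mb` field, one `{name}.nc` file per entry) but are not the package's registry, and nothing is actually downloaded.

```python
import tempfile
from pathlib import Path


def plan_downloads(registry, dest_dir, names=None, force=False):
    """Return (names to fetch, total size in MB), skipping cached .nc files unless force=True."""
    targets = list(registry) if names is None else list(names)
    unknown = [n for n in targets if n not in registry]
    if unknown:
        raise KeyError(f"not in registry: {unknown}")
    dest_dir = Path(dest_dir)
    to_fetch = [n for n in targets if force or not (dest_dir / f"{n}.nc").exists()]
    total_mb = sum(registry[n]["size_mb"] for n in to_fetch)
    return to_fetch, total_mb


# Hypothetical two-entry registry with the same "size_mb" field as the real entries.
registry = {"univ_SST": {"size_mb": 0.3}, "multiv_EIV_SST": {"size_mb": 78.0}}

with tempfile.TemporaryDirectory() as tmp:
    (Path(tmp) / "univ_SST.nc").touch()          # pretend this posterior is already cached
    fetch, mb = plan_downloads(registry, tmp)
    fetch_all, mb_all = plan_downloads(registry, tmp, force=True)

print(fetch, mb)
print(fetch_all, mb_all)
```

Checking only for file existence is cheap but does not detect a partially written file; a size or checksum comparison would be a stricter test.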

Source code in src/TEXAS/utils/download.py
def download_all(
    cache_dir: Optional[Path | str] = None,
    data_dir: Optional[Path | str] = None,
    force: bool = False,
) -> None:
    """Download everything from Zenodo: forward posteriors + training data.

    Files are downloaded individually; already-cached files are skipped unless
    *force=True*.  Total download is ~158 MB (dominated by the two EIV
    multivariate posteriors at ~78 MB each).

    Parameters
    ----------
    cache_dir : Path or str, optional
        Destination for ``.nc`` posteriors.  Defaults to the standard
        posterior cache directory.
    data_dir : Path or str, optional
        Destination for training data files.  Defaults to ``data/spreadsheets/``.
    force : bool
        Re-download files that already exist locally.
    """
    download_posteriors(cache_dir=cache_dir, force=force)
    download_training_data(dest_dir=data_dir, force=force)

Download training data

Download GDGT training data files from Zenodo.

Downloads the coretop/culture/mesocosm training CSVs and the CMEMS NO₃ uncertainty field used in the EIV calibration. These are needed only to re-run the SI preprocessing and calibration notebooks from scratch; they are NOT required for inverse temperature reconstructions. Use download_posteriors() for those instead.

Parameters

dest_dir : Path or str, optional
    Destination directory. Defaults to data/spreadsheets/ in the repo (or ~/.texas/data/spreadsheets/ when pip-installed).
force : bool, default False
    Re-download files that already exist locally.

Returns

list of Path
    Local paths of the downloaded files.
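The skip-if-cached rule shared by download_training_data and download_posteriors (a file is fetched when it is absent locally, or always when force=True) can be sketched as below. This is a standalone illustration, not the TEXAS implementation: `plan_downloads`, the registry entries, and the file names are hypothetical, modelled on the `{"filename": ..., "size_mb": ...}` entries used in the source.

```python
from pathlib import Path
import tempfile

# Hypothetical registry mirroring the {"filename", "size_mb"} entries used above.
REGISTRY = {
    "coretops": {"filename": "coretops.csv", "size_mb": 1.2},
    "no3_field": {"filename": "cmems_no3.nc", "size_mb": 6.0},
}

def plan_downloads(registry, dest: Path, force: bool = False):
    """Return (names still to fetch, their total size in MB).

    A file is scheduled when it does not exist under dest, or always when force=True.
    """
    missing = [
        name for name, entry in registry.items()
        if force or not (dest / entry["filename"]).exists()
    ]
    total_mb = sum(registry[n]["size_mb"] for n in missing)
    return missing, total_mb

with tempfile.TemporaryDirectory() as tmp:
    dest = Path(tmp)
    (dest / "coretops.csv").write_text("already cached")  # simulate one cached file
    missing, total = plan_downloads(REGISTRY, dest)
    forced, forced_total = plan_downloads(REGISTRY, dest, force=True)

print(missing, total)            # only the uncached NO3 field is scheduled
print(forced, forced_total)      # force=True schedules everything
```

Summing sizes over the missing files only is what lets the real functions print a size summary (when it reaches 5 MB) that reflects what will actually be transferred.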

Source code in src/TEXAS/utils/download.py
def download_training_data(
    dest_dir: Optional[Path | str] = None,
    force: bool = False,
) -> List[Path]:
    """Download GDGT training data files from Zenodo.

    Downloads the coretop/culture/mesocosm training CSVs and the CMEMS
    NO₃ uncertainty field used in the EIV calibration.  These are needed
    only to re-run the SI preprocessing and calibration notebooks from
    scratch; they are NOT required for inverse temperature reconstructions —
    use :func:`download_posteriors` for that.

    Parameters
    ----------
    dest_dir : Path or str, optional
        Destination directory.  Defaults to ``data/spreadsheets/`` in the
        repo (or ``~/.texas/data/spreadsheets/`` when pip-installed).
    force : bool
        Re-download files that already exist locally.

    Returns
    -------
    list of Path
        Local paths of the downloaded files.
    """
    dest = Path(dest_dir) if dest_dir else SPREADSHEETS_DIR
    dest.mkdir(parents=True, exist_ok=True)

    missing = [
        name for name, entry in TRAINING_DATA_REGISTRY.items()
        if not (dest / entry["filename"]).exists() or force
    ]
    if not missing:
        print("All training data files already present.")
        return [dest / entry["filename"] for entry in TRAINING_DATA_REGISTRY.values()]

    total_mb = sum(TRAINING_DATA_REGISTRY[n]["size_mb"] for n in missing)
    if total_mb >= 5:
        print(f"Downloading {len(missing)} training data file(s) — total ~{total_mb:.0f} MB")

    paths = []
    for name, entry in TRAINING_DATA_REGISTRY.items():
        out = dest / entry["filename"]
        _download_file(_file_url(entry["filename"]), out, entry["size_mb"], force=force)
        paths.append(out)

    return paths

List posteriors

List available posterior files in the cache directory.

Prints a summary and returns a dict of stem names that can be passed directly to predict_T_from_proxyObs(fwd_posterior=...).

Parameters

model_type : "forward", "invT", or "both"
    Which cache to inspect. Default "both".
cache_dir : Path or str, optional
    Override the default cache root. When given, the forward (TEXAS_posterior_cache/) and invT (TEXAS_invT_posterior_cache/) subdirectories are looked for under this path.

Returns

dict
    {"forward": [...], "invT": [...]}: lists of stem names (without .nc).
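When cache_dir is given, the lookup amounts to globbing *.nc in two fixed subdirectories and keeping the file stems. The sketch below is a silent (non-printing) re-creation of that logic for illustration; `list_stems` and the model file names are hypothetical, and only the two subdirectory names come from the source.

```python
from pathlib import Path
import tempfile

# Subdirectory names taken from list_posteriors() / set_cache_dir().
FORWARD_SUBDIR = "TEXAS_posterior_cache"
INVT_SUBDIR = "TEXAS_invT_posterior_cache"

def list_stems(root: Path) -> dict:
    """Return {"forward": [...], "invT": [...]} of sorted .nc stems under root."""
    def _stems(directory: Path) -> list:
        # A missing directory yields an empty list instead of raising.
        if not directory.exists():
            return []
        return [f.stem for f in sorted(directory.glob("*.nc"))]

    return {
        "forward": _stems(root / FORWARD_SUBDIR),
        "invT": _stems(root / INVT_SUBDIR),
    }

with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    fwd = root / FORWARD_SUBDIR
    fwd.mkdir()
    for name in ("model_b.nc", "model_a.nc", "notes.txt"):  # hypothetical files
        (fwd / name).touch()
    stems = list_stems(root)

print(stems)  # non-.nc files are ignored; the invT cache does not exist yet
```

Because the returned names have no extension, they can be handed straight to predict_T_from_proxyObs(fwd_posterior=...), as the docstring notes.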

Source code in src/TEXAS/stan/io.py
def list_posteriors(
    model_type: Literal["forward", "invT", "both"] = "both",
    cache_dir: Optional[Union[str, Path]] = None,
) -> Dict[str, list]:
    """
    List available posterior files in the cache directory.

    Prints a summary and returns a dict of stem names that can be passed
    directly to ``predict_T_from_proxyObs(fwd_posterior=...)``.

    Parameters
    ----------
    model_type : "forward", "invT", or "both"
        Which cache to inspect.  Default ``"both"``.
    cache_dir : Path or str, optional
        Override the default cache root.  When given, both forward and invT
        subdirectories are looked for under this path.

    Returns
    -------
    dict
        ``{"forward": [...], "invT": [...]}`` — lists of stem names (no ``.nc``).
    """
    if cache_dir:
        root = Path(cache_dir)
        fwd_dir = root / "TEXAS_posterior_cache"
        invt_dir = root / "TEXAS_invT_posterior_cache"
    else:
        fwd_dir = DEFAULT_FORWARD_DIR
        invt_dir = DEFAULT_INVT_DIR

    result: Dict[str, list] = {"forward": [], "invT": []}

    def _list(directory: Path, label: str) -> list:
        files = sorted(directory.glob("*.nc")) if directory.exists() else []
        stems = [f.stem for f in files]
        print(f"{label} posteriors  [{directory}]")
        if stems:
            for name in stems:
                print(f"  {name}")
        else:
            print("  (none)")
        return stems

    if model_type in ("forward", "both"):
        result["forward"] = _list(fwd_dir, "Forward calibration")
    if model_type in ("invT", "both"):
        if model_type == "both":
            print()
        result["invT"] = _list(invt_dir, "Inverse temperature (invT)")

    return result

Set cache directory

Override TEXAS cache directories at runtime.

Call this before any posterior I/O. For a persistent override, set the TEXAS_CACHE_DIR environment variable instead.

Parameters:

path : str | Path, required
    Root directory for all TEXAS caches. Two subdirectories will be used inside it: TEXAS_posterior_cache/ and TEXAS_invT_posterior_cache/.
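set_cache_dir does more than reassign TEXAS.utils.paths attributes: it also patches DEFAULT_FORWARD_DIR and DEFAULT_INVT_DIR in TEXAS.stan.io, because (per the source comment) those defaults are bound at import time. The sketch below demonstrates why that propagation step matters, using throwaway modules built with types.ModuleType; `paths`, `io_mod`, and `set_cache_dir_sketch` are hypothetical stand-ins, not the real TEXAS modules.

```python
import types
from pathlib import Path

# Stand-in for TEXAS.utils.paths (hypothetical).
paths = types.ModuleType("paths")
paths.POSTERIOR_CACHE_DIR = Path("/old/TEXAS_posterior_cache")

# Stand-in for TEXAS.stan.io: assumed to copy the value at import time,
# so io_mod holds its own reference to the *old* Path object.
io_mod = types.ModuleType("io_mod")
io_mod.DEFAULT_FORWARD_DIR = paths.POSTERIOR_CACHE_DIR

def set_cache_dir_sketch(root, propagate=True):
    root = Path(root)
    paths.POSTERIOR_CACHE_DIR = root / "TEXAS_posterior_cache"
    if propagate:
        # Without this step io_mod keeps pointing at the old directory.
        io_mod.DEFAULT_FORWARD_DIR = paths.POSTERIOR_CACHE_DIR

set_cache_dir_sketch("/new", propagate=False)
stale = io_mod.DEFAULT_FORWARD_DIR   # still the /old path
set_cache_dir_sketch("/new", propagate=True)
fresh = io_mod.DEFAULT_FORWARD_DIR   # now the /new path
```

The same reasoning explains the advice to call set_cache_dir before any posterior I/O: code that has already captured a directory path will not see a later override.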
Source code in src/TEXAS/utils/paths.py
def set_cache_dir(path: "str | Path") -> None:
    """Override TEXAS cache directories at runtime.

    Call this before any posterior I/O.  For a persistent override, set the
    ``TEXAS_CACHE_DIR`` environment variable instead.

    Args:
        path: Root directory for all TEXAS caches.  Two subdirectories will be
              used inside it: ``TEXAS_posterior_cache/`` and
              ``TEXAS_invT_posterior_cache/``.
    """
    import TEXAS.utils.paths as _paths
    root = Path(path)
    _paths.CACHE_ROOT           = root
    _paths.CACHE_DIR            = root
    _paths.POSTERIOR_CACHE_DIR  = root / "TEXAS_posterior_cache"
    _paths.INVT_CACHE_DIR       = root / "TEXAS_invT_posterior_cache"
    # Propagate into io.py module-level defaults (bound at import time)
    try:
        import TEXAS.stan.io as _io
        _io.DEFAULT_FORWARD_DIR = _paths.POSTERIOR_CACHE_DIR
        _io.DEFAULT_INVT_DIR    = _paths.INVT_CACHE_DIR
    except ImportError:
        pass

Data builders

Build forward data

Build the Stan data dictionary for forward calibration models.

Handles all forward model variants:
  • culmeso / Q1_culmeso / v1_culmeso : pass t_cul, proxy_cul, t_meso, proxy_meso
  • culmesocore : add t_crtp, proxy_crtp
  • hier_crtp_multiv : add gdgt23ratio_crtp, no3_crtp
  • hier_crtp_multiv_priorApprox : add culmeso_posterior (extracts hyperpriors)
  • hier_crtp_univ_priorApprox : add culmeso_posterior (no predictors needed)

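Parameters:

t_cul, proxy_cul : array-like, optional
    Culture temperature and proxy arrays. Must be provided together and have equal length.
t_meso, proxy_meso : array-like, optional
    Mesocosm temperature and proxy arrays. Must be provided together and have equal length.
t_crtp, proxy_crtp : array-like, optional
    Coretop temperature and proxy arrays. Must be provided together and have equal length. At least one of the culture, mesocosm, or coretop pairs is required.
gdgt23ratio_crtp : array-like, optional
    GDGT-2/GDGT-3 ratio for coretop samples (length N_crtp; requires coretop data). Sets use_gdgt23ratio=1 unless the array is all zeros or all NaN.
sd_gdgt23ratio_crtp : array-like, optional
    Per-site measurement SE of gdgt23ratio (same units, linear; non-negative; requires gdgt23ratio_crtp). Required for the _eiv model and always included in the data dict; defaults to zeros when not provided, which disables the G₂/₃ EIV measurement model.
no3_crtp : array-like, optional
    Nitrate concentration for coretop samples (length N_crtp; requires coretop data). Sets use_no3=1 unless the array is all zeros or all NaN.
sd_no3_crtp : array-like, optional
    Per-site measurement SE of NO₃ (μmol/L, linear space; non-negative; requires no3_crtp). Required for the _eiv model and always included (defaults to zeros). Sites with sd=0 receive only the lognormal prior and skip the normal measurement model.
no3_cutoff : float, optional
    NO₃ threshold for the nonthermal correction. Priority: (1) this argument, (2) culmeso_posterior attrs, (3) auto-calculated via the Spearman method. A warning is issued if the cutoff is ≤ 0 or outside the range of no3_crtp.
proxy_residuals_crtp : ndarray, optional
    Pre-computed proxy residuals for the NO₃ threshold calculation. If omitted, residuals are computed internally by fitting a generalized logistic curve. Warning: the Stan models use a generalized logistic form, so residuals from other functional forms may shift the threshold.
sd_proxyObs : array-like, optional
    Per-site analytical SE of the scaled RI (length N_crtp), required for the _eiv model. When coretop data are given and this is omitted, every site defaults to 0.03 (Schouten et al. 2013).
R2_thermal : float, optional
    Thermal-only R² in [0, 1], used for the _eiv sigma prior scaling. Compute it from a thermal-only coretop run first.
culmeso_posterior : xr.Dataset, optional
    Dataset from a completed culmeso forward run. Auto-extracts prior_mean_*/prior_sd_* hyperpriors and no3_cutoff (if saved in attrs).
prior_mean_*, prior_sd_* : float, optional
    Manual hyperprior values for t0, k, b, and v. These override values auto-extracted from culmeso_posterior for individual parameters.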
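The three-level priority for no3_cutoff, together with the two sanity warnings, can be sketched as follows. This is a minimal illustration, not the TEXAS code: `resolve_no3_cutoff` is hypothetical, and because the Spearman-based auto-calculation is not shown in this excerpt, the fallback `median_cutoff` is a deliberately simple placeholder, NOT the Spearman method.

```python
import statistics
import warnings

def resolve_no3_cutoff(no3_values, explicit=None, posterior_attrs=None, auto_fn=None):
    """Choose an NO3 threshold using the documented priority order.

    Returns (cutoff, source). auto_fn is a hypothetical stand-in for the
    Spearman-based auto-calculation.
    """
    attrs = posterior_attrs or {}
    if explicit is not None:                      # (1) explicit argument wins
        cutoff, source = float(explicit), "user-provided"
    elif attrs.get("no3_cutoff") is not None:     # (2) value stored on the posterior
        cutoff, source = float(attrs["no3_cutoff"]), "from culmeso_posterior attrs"
    elif auto_fn is not None:                     # (3) fall back to auto-calculation
        cutoff, source = float(auto_fn(no3_values)), "auto-calculated"
    else:
        raise ValueError("no explicit, stored, or auto-calculated cutoff available")

    finite = [x for x in no3_values if x == x]    # NaN != NaN, so this drops NaNs
    if cutoff <= 0:
        warnings.warn("no3_cutoff must be > 0 (it is the denominator in log10(no3/cutoff))")
    if finite and not (min(finite) <= cutoff <= max(finite)):
        warnings.warn("no3_cutoff lies outside the observed NO3 range")
    return cutoff, source

# Placeholder auto-calculation (median of finite values); NOT the Spearman method.
def median_cutoff(values):
    return statistics.median([x for x in values if x == x])

no3 = [0.5, 2.0, float("nan"), 12.0]
print(resolve_no3_cutoff(no3, auto_fn=median_cutoff))
```

Checking the cutoff against the observed NO₃ range matters because a threshold outside that range would leave the correction switched off (or on) for every site.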

Returns:

dict : Dict[str, Any]
    Stan-ready data dict with proxyObs_* keys, N_* counts, use_* flags, and hyperpriors, ready for get_posterior().
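The use_gdgt23ratio / use_no3 flags and the treatment of zero SEs follow two small rules, re-created below in plain Python. `use_flag` and `eiv_sites` are hypothetical helper names; the expressions mirror `int(not (np.all(np.isnan(arr)) or np.all(arr == 0)))` and the `sd > 0` test used in the source.

```python
import math

def use_flag(values):
    """1 unless every value is NaN or every value is zero (mirrors build_fwd_data)."""
    vals = [float(v) for v in values]
    all_nan = all(math.isnan(v) for v in vals)
    all_zero = all(v == 0 for v in vals)   # NaN == 0 is False
    return int(not (all_nan or all_zero))

def eiv_sites(sd_values):
    """Indices of sites with SE > 0; sites with sd == 0 are treated as exact."""
    return [i for i, sd in enumerate(sd_values) if sd > 0]

print(use_flag([0.0, 0.0]), use_flag([1.5, float("nan")]), use_flag([0.0, float("nan")]))
print(eiv_sites([0.0, 0.4, 0.0, 1.1]))
```

Note the edge case in the last use_flag call: an array holding only zeros and NaNs still switches the flag on, because it is neither all-zero nor all-NaN. Clean such inputs before calling build_fwd_data if that is not intended.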

Source code in src/TEXAS/data/builder.py
def build_fwd_data(
    *,
    # Culture observations
    t_cul=None,
    proxy_cul=None,
    # Mesocosm observations
    t_meso=None,
    proxy_meso=None,
    # Coretop observations
    t_crtp=None,
    proxy_crtp=None,
    # Optional coretop predictors
    gdgt23ratio_crtp=None,
    sd_gdgt23ratio_crtp=None,   # per-site SE of G₂/₃ — required for _eiv model; defaults to zeros
    no3_crtp=None,
    sd_no3_crtp=None,            # per-site SE of NO₃ (μmol/L, linear) — required for _eiv model; defaults to zeros
    no3_cutoff: Optional[float] = None,
    proxy_residuals_crtp: Optional[np.ndarray] = None,
    # Analytical measurement SE of RI — required for _eiv model
    sd_proxyObs=None,            # per-site SE of scaled RI; defaults to 0.03 (Schouten et al. 2013)
    # Thermal-only R² — required for _eiv sigma prior scaling
    R2_thermal: Optional[float] = None,   # compute from a thermal-only coretop run first
    # Stage-1 culmeso posterior — auto-extracts hyperpriors and no3_cutoff
    culmeso_posterior: Optional[xr.Dataset] = None,
    # Manual hyperprior overrides (alternative or supplement to culmeso_posterior)
    prior_mean_t0=None, prior_sd_t0=None,
    prior_mean_k=None,  prior_sd_k=None,
    prior_mean_b=None,  prior_sd_b=None,
    prior_mean_v=None,  prior_sd_v=None,
) -> Dict[str, Any]:
    """
    Build the Stan data dictionary for forward calibration models.

    Handles all forward model variants:
      - culmeso / Q1_culmeso / v1_culmeso  : pass t_cul, proxy_cul, t_meso, proxy_meso
      - culmesocore                         : add t_crtp, proxy_crtp
      - hier_crtp_multiv                    : add gdgt23ratio_crtp, no3_crtp
      - hier_crtp_multiv_priorApprox        : add culmeso_posterior (extracts hyperpriors)
      - hier_crtp_univ_priorApprox          : add culmeso_posterior (no predictors needed)

    Args:
        t_cul, proxy_cul:           Culture temperature and proxy arrays.
        t_meso, proxy_meso:         Mesocosm temperature and proxy arrays.
        t_crtp, proxy_crtp:         Coretop temperature and proxy arrays.
        gdgt23ratio_crtp:           GDGT-2/GDGT-3 ratio for coretop samples.
                                    Sets use_gdgt23ratio=1 if non-zero/non-NaN.
        sd_gdgt23ratio_crtp:        Per-site measurement SE of gdgt23ratio (same units,
                                    linear). Required for the _eiv model; always included
                                    in the data dict (defaults to zeros when not provided,
                                    which disables the G₂/₃ EIV measurement model).
        no3_crtp:                   Nitrate concentration for coretop samples.
                                    Sets use_no3=1 if non-zero/non-NaN.
        sd_no3_crtp:                Per-site measurement SE of NO₃ (μmol/L, linear space).
                                    Required for the _eiv model. Always included (defaults
                                    to zeros; sites with sd=0 receive only the lognormal
                                    prior and skip the normal measurement model).
        no3_cutoff:                 NO3 threshold for the nonthermal correction.
                                    Priority: (1) this arg, (2) culmeso_posterior attrs,
                                    (3) auto-calculated via Spearman method.
        proxy_residuals_crtp:       Pre-computed proxy residuals for NO3 threshold
                                    calculation. If omitted, residuals are computed
                                    internally by fitting a generalized logistic curve.
                                    Warning: Stan models use generalized logistic — residuals
                                    from other functional forms may shift the threshold.
        culmeso_posterior:          xr.Dataset from a completed culmeso forward run.
                                    Auto-extracts prior_mean_*/prior_sd_* hyperpriors and
                                    no3_cutoff (if saved in attrs).
        prior_mean_*/prior_sd_*:    Manual hyperprior values. Override auto-extracted
                                    values from culmeso_posterior for individual params.

    Returns:
        dict: Stan-ready data dict with proxyObs_* keys, N_* counts, use_* flags,
              and hyperpriors — ready for get_posterior().
    """
    import warnings

    # ── Validate dataset pairs and sizes ──────────────────────────────────────
    _pairs = [
        ("culture",  "cul",  t_cul,  proxy_cul),
        ("mesocosm", "meso", t_meso, proxy_meso),
        ("coretop",  "crtp", t_crtp, proxy_crtp),
    ]
    for label, sfx, t, p in _pairs:
        if (t is None) != (p is None):
            raise ValueError(
                f"{label}: t_{sfx} and proxy_{sfx} must be provided together."
            )
        if t is not None and len(np.asarray(t)) != len(np.asarray(p)):
            raise ValueError(
                f"{label}: t_{sfx} (len={len(np.asarray(t))}) and "
                f"proxy_{sfx} (len={len(np.asarray(p))}) must have the same length."
            )

    if all(t is None for _, _, t, _ in _pairs):
        raise ValueError(
            "build_fwd_data() requires at least one dataset "
            "(culture, mesocosm, or coretop)."
        )
    if (gdgt23ratio_crtp is not None or no3_crtp is not None) and t_crtp is None:
        raise ValueError(
            "gdgt23ratio_crtp / no3_crtp require coretop data (t_crtp, proxy_crtp)."
        )
    if sd_gdgt23ratio_crtp is not None and gdgt23ratio_crtp is None:
        raise ValueError(
            "sd_gdgt23ratio_crtp requires gdgt23ratio_crtp to be provided."
        )
    if sd_no3_crtp is not None and no3_crtp is None:
        raise ValueError(
            "sd_no3_crtp requires no3_crtp to be provided."
        )

    # ── Build data dict ────────────────────────────────────────────────────────
    data: Dict[str, Any] = {}
    summary_lines: List[str] = []

    def _range_str(t_arr, p_arr):
        return (
            f"N={len(t_arr)}"
            f"  (t: {t_arr.min():.1f}–{t_arr.max():.1f}°C,"
            f" proxy: {p_arr.min():.3f}–{p_arr.max():.3f})"
        )

    # Culture
    if t_cul is not None:
        t_arr = np.asarray(t_cul, dtype=float)
        p_arr = np.asarray(proxy_cul, dtype=float)
        data.update({"N_cul": len(t_arr), "t_cul": t_arr, "proxyObs_cul": p_arr})
        summary_lines.append(f"   culture:   {_range_str(t_arr, p_arr)}")
    else:
        summary_lines.append("   culture:   —")

    # Mesocosm
    if t_meso is not None:
        t_arr = np.asarray(t_meso, dtype=float)
        p_arr = np.asarray(proxy_meso, dtype=float)
        data.update({"N_meso": len(t_arr), "t_meso": t_arr, "proxyObs_meso": p_arr})
        summary_lines.append(f"   mesocosm:  {_range_str(t_arr, p_arr)}")
    else:
        summary_lines.append("   mesocosm:  —")

    # Coretop + predictors
    N_crtp = 0
    if t_crtp is not None:
        t_arr  = np.asarray(t_crtp,    dtype=float)
        p_arr  = np.asarray(proxy_crtp, dtype=float)
        N_crtp = len(t_arr)
        data.update({"N_crtp": N_crtp, "t_crtp": t_arr, "proxyObs_crtp": p_arr})
        summary_lines.append(f"   coretop:   {_range_str(t_arr, p_arr)}")

        pred_parts: List[str] = []

        # GDGT-2/3 ratio
        if gdgt23ratio_crtp is not None:
            g_arr = np.asarray(gdgt23ratio_crtp, dtype=float)
            if len(g_arr) != N_crtp:
                raise ValueError(
                    f"gdgt23ratio_crtp length ({len(g_arr)}) != N_crtp ({N_crtp})."
                )
            use_g = int(not (np.all(np.isnan(g_arr)) or np.all(g_arr == 0)))
            data.update({"gdgt23ratio_crtp": g_arr, "use_gdgt23ratio": use_g})
            pred_parts.append(f"gdgt23ratio {'✓' if use_g else '✗'}")

            if sd_gdgt23ratio_crtp is not None:
                sd_g = np.asarray(sd_gdgt23ratio_crtp, dtype=float)
                if len(sd_g) != N_crtp:
                    raise ValueError(
                        f"sd_gdgt23ratio_crtp length ({len(sd_g)}) != N_crtp ({N_crtp})."
                    )
                if np.any(sd_g < 0):
                    raise ValueError("sd_gdgt23ratio_crtp must be non-negative.")
                data["sd_gdgt23ratio_crtp"] = sd_g
                pred_parts.append(
                    f"sd_gdgt23ratio: {sd_g.min():.3g}–{sd_g.max():.3g} (EIV)"
                )
            else:
                # Always include sd_gdgt23ratio_crtp so _eiv model always has its required
                # data variable. Zeros disable the EIV measurement model for G₂/₃.
                data["sd_gdgt23ratio_crtp"] = np.zeros(N_crtp, dtype=float)
        else:
            data.update({"gdgt23ratio_crtp": np.zeros(N_crtp), "use_gdgt23ratio": 0,
                         "sd_gdgt23ratio_crtp": np.zeros(N_crtp, dtype=float)})

        # NO3
        if no3_crtp is not None:
            n_arr = np.asarray(no3_crtp, dtype=float)
            if len(n_arr) != N_crtp:
                raise ValueError(
                    f"no3_crtp length ({len(n_arr)}) != N_crtp ({N_crtp})."
                )
            use_n = int(not (np.all(np.isnan(n_arr)) or np.all(n_arr == 0)))
            data.update({"no3_crtp": n_arr, "use_no3": use_n})

            if sd_no3_crtp is not None:
                sd_n = np.asarray(sd_no3_crtp, dtype=float)
                if len(sd_n) != N_crtp:
                    raise ValueError(
                        f"sd_no3_crtp length ({len(sd_n)}) != N_crtp ({N_crtp})."
                    )
                if np.any(sd_n < 0):
                    raise ValueError("sd_no3_crtp must be non-negative.")
                data["sd_no3_crtp"] = sd_n
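                n_with_se = int((sd_n > 0).sum())
                # Guard: .min() on an empty selection raises when every SE is zero
                sd_pos = sd_n[sd_n > 0]
                sd_lo = f"{sd_pos.min():.3g}" if sd_pos.size else "0"
                pred_parts.append(
                    f"sd_no3: {sd_lo}–{sd_n.max():.3g} µmol/L "
                    f"({n_with_se}/{N_crtp} sites have SE; EIV)"
                )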
            else:
                # sd_no3_crtp not provided: treat all NO₃ sites as exact (no latent variable).
                # The Stan data block always requires sd_no3_crtp; default to zeros so
                # build_fwd_data() produces a complete data dict for EIV models.
                data["sd_no3_crtp"] = np.zeros(N_crtp, dtype=float)
                pred_parts.append("sd_no3: not provided (all sites treated as exact)")

            if use_n:
                # Determine no3_cutoff — priority: explicit > posterior attrs > auto-calc
                if no3_cutoff is not None:
                    final_cutoff  = float(no3_cutoff)
                    cutoff_source = "user-provided"
                elif (culmeso_posterior is not None
                      and culmeso_posterior.attrs.get("no3_cutoff") is not None):
                    final_cutoff  = float(culmeso_posterior.attrs["no3_cutoff"])
                    cutoff_source = "from culmeso_posterior attrs"
                else:
                    print("   🔍 Auto-calculating no3_cutoff (Spearman method)...")
                    if proxy_residuals_crtp is not None:
                        warnings.warn(
                            "Pre-computed proxy_residuals_crtp provided. "
                            "Note: Stan models use generalized logistic functional form — "
                            "residuals from other models may produce a suboptimal threshold.",
                            UserWarning, stacklevel=2,
                        )
                    final_cutoff = _auto_no3_cutoff(
                        n_arr, t_arr, p_arr,
                        proxy_residuals=proxy_residuals_crtp,
                    )
                    cutoff_source = "auto-calculated (Spearman)"
                    print(f"      → no3_cutoff = {final_cutoff:.3g} µmol/L")

                # Validate cutoff
                if final_cutoff <= 0:
                    warnings.warn(
                        f"no3_cutoff={final_cutoff} ≤ 0. _no3ratio Stan models require "
                        "cutoff > 0 (used as log10(no3/cutoff) denominator). "
                        "Specify a positive no3_cutoff= explicitly.",
                        UserWarning, stacklevel=2,
                    )
                no3_min = float(np.nanmin(n_arr))
                no3_max = float(np.nanmax(n_arr))
                if not (no3_min <= final_cutoff <= no3_max):
                    warnings.warn(
                        f"no3_cutoff={final_cutoff:.3g} is outside the range of no3_crtp "
                        f"({no3_min:.3g}–{no3_max:.3g} µmol/L). The NO3 correction may "
                        "not apply to any samples. Check your data or threshold value.",
                        UserWarning, stacklevel=2,
                    )

                data["no3_cutoff"] = final_cutoff
                pred_parts.append(
                    f"no3 ✓  (no3_cutoff={final_cutoff:.3g}, {cutoff_source})"
                )
            else:
                data["no3_cutoff"] = 1.0
                pred_parts.append("no3 ✗  (all zeros/NaN — correction disabled)")
        else:
            data.update({"no3_crtp": np.zeros(N_crtp), "use_no3": 0, "no3_cutoff": 1.0,
                         "sd_no3_crtp": np.zeros(N_crtp, dtype=float)})

        summary_lines.append(
            f"   predictors: {', '.join(pred_parts) if pred_parts else 'none'}"
        )
    else:
        summary_lines.append("   coretop:   —")
        summary_lines.append("   predictors: none")

    # ── sd_proxyObs — per-site RI analytical SE (_eiv model) ─────────────────
    if sd_proxyObs is not None:
        sd_p = np.asarray(sd_proxyObs, dtype=float)
        if N_crtp and len(sd_p) != N_crtp:
            raise ValueError(
                f"sd_proxyObs length ({len(sd_p)}) != N_crtp ({N_crtp})."
            )
        if np.any(sd_p < 0):
            raise ValueError("sd_proxyObs must be non-negative.")
        data["sd_proxyObs"] = sd_p
    elif N_crtp:
        data["sd_proxyObs"] = np.full(N_crtp, 0.03, dtype=float)  # default Rs (Schouten et al. 2013)

    # ── R2_thermal — thermal-only R² for sigma prior scaling (_eiv model) ──────
    if R2_thermal is not None:
        r2 = float(R2_thermal)
        if not (0.0 <= r2 <= 1.0):
            raise ValueError(f"R2_thermal must be in [0, 1], got {r2}.")
        data["R2_thermal"] = r2

    # ── Hyperpriors ────────────────────────────────────────────────────────────
    manual_hp = {
        "prior_mean_t0": prior_mean_t0, "prior_sd_t0": prior_sd_t0,
        "prior_mean_k":  prior_mean_k,  "prior_sd_k":  prior_sd_k,
        "prior_mean_b":  prior_mean_b,  "prior_sd_b":  prior_sd_b,
        "prior_mean_v":  prior_mean_v,  "prior_sd_v":  prior_sd_v,
    }

    if culmeso_posterior is not None:
        hp      = _extract_culmeso_hyperpriors(culmeso_posterior)
        suffix  = hp.pop("_suffix")
        n_draws = hp.pop("_n_draws")
        data.update(hp)
        for k, v in manual_hp.items():   # manual overrides win
            if v is not None:
                data[k] = float(v)
        param_lines = [
            f"      {p}: mean={data[f'prior_mean_{p}']:.3g}  sd={data[f'prior_sd_{p}']:.3g}"
            for p in _FWD_HYPERPARAMS if f"prior_mean_{p}" in data
        ]
        summary_lines.append(
            f"   hyperpriors from culmeso posterior  (suffix={suffix}, {n_draws} draws):"
        )
        summary_lines.extend(param_lines)
    else:
        provided = {k: float(v) for k, v in manual_hp.items() if v is not None}
        data.update(provided)
        if provided:
            summary_lines.append("   hyperpriors: manual")
            for k, v in provided.items():
                summary_lines.append(f"      {k}: {v:.3g}")
        else:
            summary_lines.append("   hyperpriors: none")

    # ── Print summary ──────────────────────────────────────────────────────────
    print("📦 build_fwd_data summary:")
    for line in summary_lines:
        print(line)

    return data

Build invT input data

Build the data dictionary for Stan's inverse model and sampler configuration.

WORKFLOW:

1. Load the forward calibration posterior from a .nc file (or accept a pre-loaded Dataset).
2. Randomly sample M parameter sets from that posterior.
3. Extract the calibration curve parameters (t0, k, b, v, sigma).
4. Package optional environmental predictors (GDGT-2/3, NO3) if used.
5. Return the data dict (for Stan) and sampler_kwargs (for CmdStanPy).
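Steps 1 and 2 pool all chains into a single "sample" axis and then draw M parameter sets from the pool. The stdlib sketch below illustrates that idea; `flatten_and_sample` and the list-of-dicts layout are hypothetical stand-ins for the xarray (chain, draw) Dataset. Two assumptions are worth flagging. First, sampling without replacement is assumed, because the actual sampling call lies outside this excerpt. Second, the real function seeds NumPy's global RNG with np.random.seed(config.seed), whereas this sketch deliberately uses a local random.Random to avoid global state.

```python
import random

def flatten_and_sample(draws_by_chain, m, seed):
    """Pool (chain, draw) into one axis, then pick m draws without replacement.

    draws_by_chain: list of chains, each a list of parameter dicts
    (hypothetical layout standing in for a Dataset with dims (chain, draw)).
    """
    # Chain-major flattening, analogous to stacking ("chain", "draw") -> "sample".
    pooled = [d for chain in draws_by_chain for d in chain]
    if m > len(pooled):
        raise ValueError(f"requested {m} draws but only {len(pooled)} available")
    rng = random.Random(seed)   # local RNG: reproducible without touching global state
    return rng.sample(pooled, m)

# Toy posterior: 4 chains x 5 draws, each draw one self-consistent parameter set.
chains = [
    [{"t0": 10 + c + 0.1 * d, "k": 0.2} for d in range(5)]
    for c in range(4)
]
picked = flatten_and_sample(chains, m=6, seed=42)
again = flatten_and_sample(chains, m=6, seed=42)
```

Keeping each draw as a whole parameter set (rather than sampling t0, k, b, v independently) preserves the posterior correlations between calibration parameters, which is the point of propagating calibration uncertainty this way.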

Parameters:

proxyObs : ndarray or list of float, default None
    Observed proxy values to predict temperature from (length N). Any proxy is accepted (scaledRI, TEX86, ringIndex, etc.), but it should be the same proxy used to fit the forward posterior. Omitting it raises TypeError; the deprecated keyword scaledRI is still accepted as an alias and emits a DeprecationWarning.
prior_mu_t : ndarray or float, default None
    Prior mean temperature (°C): a scalar applied to all N observations, or an array of length N.
prior_sigma_t : float, default None
    Prior temperature uncertainty (standard deviation in °C, e.g. 10).
fwd_posterior_name : str, optional
    Name of a saved forward calibration (without the .nc extension). Either this or fwd_posterior must be given; it is not required when fwd_posterior is supplied directly.
predictors : dict of str to ndarray, optional
    Optional environmental covariates, {'gdgt23ratio': array, 'no3': array}.
config : InvTConfig, optional
    Configuration object controlling the number of draws M (n_draws), the random seed, the mode (only 'ensemble' is supported), and the parameter suffix. Defaults to InvTConfig().
fwd_posterior : xr.Dataset, optional
    Pre-loaded forward posterior. When provided, fwd_posterior_name is ignored and no file I/O is performed. Useful in Google Colab or any pip-install context where the posterior cache is not available.
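The scalar-or-array rule for prior_mu_t can be sketched as below. `broadcast_prior` is a hypothetical helper written for illustration; it reproduces the length check and error message from the source but returns plain lists instead of NumPy arrays.

```python
import numbers

def broadcast_prior(prior_mu_t, n):
    """Return a length-n list of prior means from a scalar or a per-site sequence."""
    if isinstance(prior_mu_t, numbers.Real):
        # One prior mean shared by every observation
        return [float(prior_mu_t)] * n
    mu = [float(x) for x in prior_mu_t]
    if len(mu) != n:
        raise ValueError(f"prior_mu_t length ({len(mu)}) must match proxyObs length ({n})")
    return mu

proxy_obs = [0.52, 0.61, 0.58]
print(broadcast_prior(20.0, len(proxy_obs)))              # same prior for all sites
print(broadcast_prior([18.0, 22.0, 21.5], len(proxy_obs)))  # site-specific priors
```

A site-specific prior is useful when independent information (for example, a latitude-dependent climatology) suggests different expected temperatures along a record; a scalar prior keeps every sample on an equal footing.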

Returns:

A tuple (data, sampler_kwargs):

data : Dict[str, Any]
    Dictionary for Stan's data block.
sampler_kwargs : Dict[str, Any]
    Dictionary for CmdStanPy sampling configuration.
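When config.n_draws is None, the source picks M as 25% of the pooled draws, clamped to [100, 500]. The ~4.5% (M=500) and ~5.8% (M=300) Monte Carlo errors quoted in the source comment match the familiar 1/√M scaling. The helper names below are hypothetical; the sketch only reproduces that arithmetic.

```python
import math

def auto_n_draws(total_available):
    """Default M: 25% of the pooled draws, clamped to [100, 500]."""
    return min(500, max(100, total_available // 4))

def mc_relative_error(m):
    """Rough relative Monte Carlo error of an M-draw average, ~1/sqrt(M)."""
    return 1.0 / math.sqrt(m)

for total in (200, 1600, 4000):
    m = auto_n_draws(total)
    print(total, m, f"{mc_relative_error(m):.1%}")
```

One consequence of the clamp is that a posterior with fewer than 100 pooled draws would still request M=100; how the full function handles that case falls outside this excerpt.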

Source code in src/TEXAS/data/builder.py
def build_invT_inputData(
    proxyObs: Union[np.ndarray, List[float]] = None,
    prior_mu_t: Union[np.ndarray, float] = None,
    prior_sigma_t: float = None,
    *,
    scaledRI: Union[np.ndarray, List[float]] = None,  # deprecated alias
    fwd_posterior_name: Optional[str] = None,
    predictors: Optional[Dict[str, np.ndarray]] = None,
    config: Optional[InvTConfig] = None,
    fwd_posterior: Optional[xr.Dataset] = None,
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """
    Build the `data` dictionary for Stan's inverse model and sampler configuration.

    WORKFLOW:
    ─────────
    1. Load forward calibration posterior from .nc file (or accept a pre-loaded Dataset)
    2. Randomly sample M parameter sets from that posterior
    3. Extract calibration curve parameters (t0, k, b, v, sigma)
    4. Package optional environmental predictors (GDGT-2/3, NO3) if used
    5. Return data dict (for Stan) + sampler_kwargs (for CmdStanPy)

    Args:
        proxyObs: Observed proxy values to predict temperature from (length N).
            Any proxy is accepted: scaledRI, TEX86, ringIndex, etc.
        prior_mu_t: Prior mean temperature (scalar or array of length N)
        prior_sigma_t: Prior temperature uncertainty (e.g., 10°C)
        fwd_posterior_name: Name of saved forward calibration (without .nc extension).
            Not required when fwd_posterior is supplied directly.
        predictors: Optional environmental covariates {'gdgt23ratio': array, 'no3': array}
        config: Configuration object controlling M, seed, etc.
        fwd_posterior: Pre-loaded forward posterior xr.Dataset. When provided,
            fwd_posterior_name is ignored and no file I/O is performed. Useful
            when running from Google Colab or any pip-install context where the
            posterior cache is not available.

    Returns:
        data: Dictionary for Stan's data block
        sampler_kwargs: Dictionary for CmdStanPy sampling configuration
    """
    # Backward-compat: accept deprecated scaledRI kwarg
    if scaledRI is not None and proxyObs is None:
        import warnings
        warnings.warn(
            "The 'scaledRI' parameter is deprecated; use 'proxyObs' instead.",
            DeprecationWarning, stacklevel=2,
        )
        proxyObs = scaledRI
    if proxyObs is None:
        raise TypeError("build_invT_inputData() missing required argument: 'proxyObs'")

    if fwd_posterior is None and fwd_posterior_name is None:
        raise ValueError(
            "Provide either fwd_posterior_name (cache lookup) "
            "or fwd_posterior (pre-loaded xr.Dataset)."
        )

    config = config or InvTConfig()
    np.random.seed(config.seed)  # Ensure reproducible M-sample selection
    predictors = predictors or {}

    # ═══════════════════════════════════════════════════════════════════════════
    # STEP 1: LOAD FORWARD CALIBRATION POSTERIOR
    # ═══════════════════════════════════════════════════════════════════════════
    # Accept a pre-loaded Dataset (fwd_posterior) or load from the cache by name.
    # Expected structure: xr.Dataset with dims (chain, draw) and data_vars like:
    #   - t0_crtp, k_crtp, b_crtp, v_crtp, sigma_proxyObs_crtp
    #   - beta_G23_crtp, beta_NO3_crtp (if multivariate)
    # ───────────────────────────────────────────────────────────────────────────
    if fwd_posterior is not None:
        post: xr.Dataset = fwd_posterior
    else:
        post: xr.Dataset = load_posterior(fwd_posterior_name)
    vars_ = list(post.data_vars)

    # ═══════════════════════════════════════════════════════════════════════════
    # STEP 2: FLATTEN MCMC DIMENSIONS (chain, draw) → (sample)
    # ═══════════════════════════════════════════════════════════════════════════
    # Stan posteriors have shape (chain, draw, ...). We flatten to a single
    # dimension 'sample' so we can randomly select M draws across all chains.
    # ───────────────────────────────────────────────────────────────────────────
    if "chain" in post.dims:
        # Standard case: (chain, draw) → (sample)
        post = post.stack(sample=("chain", "draw")).reset_index("sample")
        draw_dim_name = "sample"
    else:
        # Edge case: already flattened or single-chain posterior
        if "sample" in post.dims:
            draw_dim_name = "sample"
        elif "draw" in post.dims:
            draw_dim_name = "draw"
        else:
            raise ValueError("Could not find a 'draw' or 'sample' dimension in the posterior file.")

    # ═══════════════════════════════════════════════════════════════════════════
    # STEP 3: PREPARE OBSERVATION DATA
    # ═══════════════════════════════════════════════════════════════════════════
    y = np.asarray(proxyObs, dtype=float)  # Observed proxy values
    N = y.size  # Number of observations (e.g., coretop samples or PETM section)

    # Prior on temperature: Can be scalar (same prior for all N) or array (site-specific).
    # np.ndim() also treats 0-d arrays and NumPy scalars as scalars (np.isscalar does not).
    mu_t = (np.full(N, float(prior_mu_t))
            if np.ndim(prior_mu_t) == 0
            else np.asarray(prior_mu_t, dtype=float))
    if mu_t.shape[0] != N:
        raise ValueError(f"prior_mu_t length ({mu_t.shape[0]}) must match proxyObs length ({N})")

    if config.mode != "ensemble":
        raise NotImplementedError(f"Only 'ensemble' mode is supported, not '{config.mode}'")

    # ═══════════════════════════════════════════════════════════════════════════
    # STEP 4: IDENTIFY PARAMETER SUFFIX
    # ═══════════════════════════════════════════════════════════════════════════
    # Forward models may produce parameters like:
    #   - t0_crtp, k_crtp (coretop-only calibration)
    #   - t0_culmesocore, k_culmesocore (culture+mesocosm+coretop)
    # We prioritize higher-data combinations (crtp > culmesocore > culmes > ...)
    # ───────────────────────────────────────────────────────────────────────────
    PRIORITY_SUFFIXES = ["crtp", "culmesocore", "culmes", "meso", "cul"]
    used_suffix = config.suffix
    if not used_suffix:
        # Auto-detect the highest-priority suffix present in the posterior
        for sfx in PRIORITY_SUFFIXES:
            if all(f"{p}_{sfx}" in vars_ for p in _FORWARD_PARAMS):
                used_suffix = sfx
                break
        if not used_suffix:
            raise ValueError("Could not find a valid parameter suffix in the forward posterior.")
    else:
        # User specified a suffix - verify it exists
        missing = [v for v in [f"{p}_{used_suffix}" for p in _FORWARD_PARAMS] if v not in vars_]
        if missing:
            raise ValueError(f"User suffix '{used_suffix}' is missing required parameters: {missing}")

    # ═══════════════════════════════════════════════════════════════════════════
    # STEP 5: RANDOMLY SAMPLE M PARAMETER SETS FROM THE FORWARD POSTERIOR
    # ═══════════════════════════════════════════════════════════════════════════
    # This is the KEY STEP for Bayesian uncertainty propagation!
    # 
    # Instead of using a single "best-fit" calibration, we randomly draw M
    # plausible calibration curves from the forward posterior. Each draw m
    # represents one self-consistent set of parameters {t0, k, b, v, sigma}.
    # 
    # The inverse model averages predictions across all M draws, naturally
    # accounting for calibration uncertainty.
    # ───────────────────────────────────────────────────────────────────────────
    total_available = post.sizes[draw_dim_name]

    if config.n_draws is None:
        # Auto-select M: use up to 25% of available draws, bounded [100, 500].
        # For typical runs (4 chains × 1000 = 4000 draws): 4000//4=1000, capped
        # at 500 → M=500 (~4.5% Monte Carlo error). Use InvTConfig(n_draws=300)
        # to match the publication default (M=300, ~5.8% error).
        M = min(500, max(100, total_available // 4))
        print(
            f"⚙️  Auto M={M} ({total_available} posterior draws available; "
            f"approx. error ≈ {100/M**0.5:.1f}%). "
            f"Override with InvTConfig(n_draws=...)."
        )
    else:
        M = config.n_draws
        if M > total_available:
            print(
                f"⚠️  n_draws={M} exceeds available draws ({total_available}); "
                f"sampling with replacement (draws will repeat)."
            )

    # Draw without replacement unless more draws are requested than exist
    draw_indices = np.random.choice(total_available, M, replace=M > total_available)
    P = post.isel({draw_dim_name: draw_indices})  # Subset to M samples
    # Now P has shape (sample=M, ...) instead of (sample=total_available, ...)

    # ═══════════════════════════════════════════════════════════════════════════
    # STEP 6: INITIALIZE DATA DICTIONARY FOR STAN
    # ═══════════════════════════════════════════════════════════════════════════
    data: Dict[str, Any] = {
        "N": N,                     # Number of observations
        "proxyObs": y,              # Observed proxy values
        "prior_mu_t": mu_t,         # Prior temperature mean
        "prior_sigma_t": float(prior_sigma_t),  # Prior temperature uncertainty
        "M": M,                     # Number of forward posterior samples
    }

    used_posts: List[str] = []  # Track which parameters were extracted

    # ═══════════════════════════════════════════════════════════════════════════
    # STEP 7: EXTRACT CORE CALIBRATION PARAMETERS
    # ═══════════════════════════════════════════════════════════════════════════
    # Extract the M samples of each core parameter:
    #   t0[m] = inflection temperature for draw m
    #   k[m]  = growth rate for draw m
    #   b[m]  = lower asymptote for draw m
    # These will be passed to Stan as FIXED data (not sampled again).
    # ───────────────────────────────────────────────────────────────────────────
    for p in _FORWARD_PARAMS:
        key = f"{p}_{used_suffix}"  # e.g., "t0_crtp"
        data[p] = np.asarray(P[key].values, dtype=float)  # Shape: (M,)
        used_posts.append(key)

    # ═══════════════════════════════════════════════════════════════════════════
    # STEP 8: EXTRACT EXTRA PARAMETERS (v) IF PRESENT
    # ═══════════════════════════════════════════════════════════════════════════
    # Generalized logistic models include a v (shape) parameter. Q is fixed to 1
    # in all active models and is no longer sampled or passed to invT Stan models.
    # Standard logistic models also fix v=1 and don't include it in the posterior.
    # ───────────────────────────────────────────────────────────────────────────
    for p in _EXTRA_PARAMS:
        key = f"{p}_{used_suffix}"
        if key in P:
            data[p] = np.asarray(P[key].values, dtype=float)  # Shape: (M,)
            used_posts.append(key)

    # ═══════════════════════════════════════════════════════════════════════════
    # STEP 9: EXTRACT RESIDUAL ERROR PARAMETER
    # ═══════════════════════════════════════════════════════════════════════════
    # sigma_proxyObs[m] = residual (measurement + model) error for posterior draw m.
    # Older posteriors name it sigma_scaledRI_*; if neither is present, fall back to 0.1.
    # ───────────────────────────────────────────────────────────────────────────
    sigma_key = f"sigma_proxyObs_{used_suffix}"
    if sigma_key not in P:
        sigma_key = f"sigma_scaledRI_{used_suffix}"  # backward compat with old posteriors
    if sigma_key in P:
        data["sigma_proxyObs"] = np.asarray(P[sigma_key].values, dtype=float)  # Shape: (M,)
        used_posts.append(sigma_key)
    else:
        # Fallback to constant error if not estimated in forward model
        data["sigma_proxyObs"] = np.full(M, 0.1, dtype=float)

    # ═══════════════════════════════════════════════════════════════════════════
    # STEP 10: PROCESS OPTIONAL ENVIRONMENTAL PREDICTORS
    # ═══════════════════════════════════════════════════════════════════════════
    # If the forward model included environmental covariates (GDGT-2/3, NO3),
    # we need to:
    #   1. Extract their coefficients (beta0) from the forward posterior
    #   2. Pass the observed predictor values for our N observations
    #   3. Set use_* flags so Stan knows to apply nonthermal corrections
    # ───────────────────────────────────────────────────────────────────────────
    predictor_usage = {}
    for pred in OPTIONAL_PREDICTORS:  # ['gdgt23ratio', 'no3']
        # Check if this predictor was used in the forward calibration
        use_flag = bool(post.attrs.get(f"use_{pred}", False))
        predictor_usage[pred] = use_flag

        # Get observed predictor values for our N observations (or zeros if not provided).
        # Scalars (e.g. no3=10.0) are broadcast to all N observations.
        arr = ensure_numpy(predictors.get(pred, np.zeros(N, dtype=float)))
        if arr.ndim == 0:  # scalar → broadcast
            arr = np.full(N, float(arr), dtype=float)
        if arr.shape[0] != N:
            raise ValueError(f"Predictor '{pred}' length ({arr.shape[0]}) must equal N ({N})")

        # Add to Stan data
        data[pred] = arr  # Observed predictor values: shape (N,)
        data[f"use_{pred}"] = 1 if use_flag else 0  # Flag for Stan conditional logic

        # Stan always expects beta_G23 and beta_NO3 as vector[M], even when the
        # predictor is disabled (use_* = 0). Provide zeros when not used so the
        # data block dimensions match regardless of which predictors are active.
        beta_name = PREDICTOR_BETA_NAMES[pred]          # e.g., "beta_G23"
        if use_flag:
            # Extract the coefficient from forward posterior
            beta_key = f"{beta_name}_{used_suffix}"          # e.g., "beta_G23_crtp"
            if beta_key not in P:
                raise ValueError(f"Expected '{beta_key}' in forward posterior but not found.")
            data[beta_name] = np.asarray(P[beta_key].values, dtype=float)  # Shape: (M,)
            used_posts.append(beta_key)
        else:
            # Predictor unused → pass zeros so Stan's vector[M] declaration is satisfied
            data[beta_name] = np.zeros(M, dtype=float)  # Shape: (M,)

    # ═══════════════════════════════════════════════════════════════════════════
    # STEP 11: HANDLE NITRATE CUTOFF (Special case for NO3 predictor)
    # ═══════════════════════════════════════════════════════════════════════════
    # NO3 uses a threshold model: correction only applies when NO3 > cutoff.
    # Priority: (1) forward posterior attrs, (2) InvTConfig, (3) default 0.0
    # ───────────────────────────────────────────────────────────────────────────
    if data.get("use_no3"):
        if np.allclose(data["no3"], 0):
            # All NO3 values are zero → no correction needed
            data["no3_cutoff"] = 0.0
        else:
            # 1. Try to get cutoff from forward posterior attributes
            cutoff_from_attrs = post.attrs.get("no3_cutoff")

            if cutoff_from_attrs is not None:
                final_cutoff = float(cutoff_from_attrs)
                print(f"💡 Using no3_cutoff from forward posterior attributes: {final_cutoff}")
            # 2. Fallback to InvTConfig if not in attributes
            elif config.no3_cutoff is not None:
                final_cutoff = config.no3_cutoff
                print(f"💡 Using no3_cutoff from InvTConfig: {final_cutoff}")
            # 3. Default to 0.0 if neither source available
            else:
                final_cutoff = 0.0
                print(f"⚠️ no3_cutoff not specified. Using default value: {final_cutoff}")

            data["no3_cutoff"] = final_cutoff
    else:
        # NO3 not used → cutoff irrelevant
        data["no3_cutoff"] = 0.0

    # ═══════════════════════════════════════════════════════════════════════════
    # STEP 12: PACKAGE SAMPLER CONFIGURATION AND METADATA
    # ═══════════════════════════════════════════════════════════════════════════
    # Return both the Stan data and configuration for CmdStanPy.
    # Metadata helps with provenance tracking and debugging.
    # ───────────────────────────────────────────────────────────────────────────
    fwd_model_name = post.attrs.get("stan_model_name", "")
    sampler_kwargs: Dict[str, Any] = {
        "chains": 4,           # Number of MCMC chains
        "iter_warmup": 500,    # Warmup iterations per chain
        "iter_sampling": 1000, # Sampling iterations per chain
        "seed": int(config.seed),
        "_metadata": {
            "posteriors_used": used_posts,  # Which parameters were extracted
            "calibration_model_name": fwd_model_name,
            "used_suffix": used_suffix,  # e.g., "crtp"
            "predictor_usage": predictor_usage,  # Which covariates are active
            "no3ratio": "_no3ratio" in fwd_model_name,  # Forward model used centred NO₃ form
        },
    }

    return data, sampler_kwargs
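Steps 2 and 5 above flatten the (chain, draw) axes into one sample axis and then pick M parameter sets, with M chosen automatically as up to 25% of the draws clamped to [100, 500]. The NumPy sketch below restates that selection rule on a plain array. `select_draws`, `auto_n_draws`, and `mc_error_pct` are hypothetical helpers, not part of TEXAS; the sketch assumes a seeded `numpy.random.Generator` and samples with replacement only when M exceeds the available draws.

```python
import numpy as np


def auto_n_draws(total_available: int) -> int:
    """Up to 25% of the available draws, clamped to the range [100, 500]."""
    return min(500, max(100, total_available // 4))


def mc_error_pct(m: int) -> float:
    """Approximate relative Monte Carlo error in percent, scaling as 1/sqrt(M)."""
    return 100.0 / m ** 0.5


def select_draws(samples, n_draws=None, seed=0):
    """Flatten (chain, draw, ...) to (sample, ...) and pick M draws (hypothetical helper)."""
    samples = np.asarray(samples, dtype=float)
    n_chain, n_draw = samples.shape[:2]
    # C-order reshape keeps chain as the outer index, like stacking ("chain", "draw")
    flat = samples.reshape(n_chain * n_draw, *samples.shape[2:])
    total = flat.shape[0]
    m = auto_n_draws(total) if n_draws is None else int(n_draws)
    rng = np.random.default_rng(seed)
    # Only resample with replacement when more draws are requested than exist
    idx = rng.choice(total, size=m, replace=m > total)
    return flat[idx]


# A fake t0 posterior: 4 chains x 1000 draws
t0_draws = np.random.default_rng(1).normal(20.0, 2.0, size=(4, 1000))
subset = select_draws(t0_draws, seed=42)
print(subset.shape, f"~{mc_error_pct(subset.shape[0]):.1f}% MC error")
```

With 4000 draws the rule lands on the 500 cap; passing `n_draws=300` reproduces the publication default at roughly 5.8% Monte Carlo error.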

WOA23 NO₃ lookup

Look up modern NO₃ at one or more lat/lon coordinates from a WOA23-derived xarray Dataset.

The dataset is typically the preprocessed ocean_prop_ds generated in SI_code1, which contains thermocline-depth-integrated WOA23 climatology on a regular (lat, lon) grid. The returned value(s) are time-invariant (climatological mean) and intended as a modern-ocean proxy for the NO₃ correction in paleo reconstructions.

Parameters

lat : float or array-like
Latitude(s) in decimal degrees (−90 to 90). Pass a scalar for a single drill site, or an array of length N to match N observations.

lon : float or array-like
Longitude(s) in decimal degrees. Both −180–180 and 0–360 conventions are accepted; the function normalises the input to match the dataset's convention automatically.

woa_dataset : xr.Dataset
WOA23-derived dataset with a (lat, lon) grid containing variable. Dimensions must be named "lat" and "lon".

variable : str
Name of the NO₃ variable to extract. Default "no3_sf2tc_avg" (thermocline depth-integrated annual average from SI_code1).

method : {"linear", "nearest"}
Interpolation method. "linear" (default) performs bilinear interpolation and is preferred for smooth fields. "nearest" snaps to the closest grid cell and is useful when the dataset is sparse or has NaN-masked shelves.
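The longitude handling in the source listing reduces to two modulo expressions, one per target convention. `normalise_lon` below is a hypothetical stand-alone helper that restates them so they can be checked in isolation; it is not exported by TEXAS.

```python
import numpy as np


def normalise_lon(lon, to_0_360: bool):
    """Map longitudes onto the dataset's convention (restates lookup_no3_from_woa)."""
    lon = np.asarray(lon, dtype=float)
    if to_0_360:
        # e.g. -23.7 -> 336.3
        return lon % 360.0
    # e.g. 336.3 -> -23.7; note that 180 maps to -180 (same meridian)
    return ((lon + 180.0) % 360.0) - 180.0


print(normalise_lon(-23.7, to_0_360=True))
print(normalise_lon([336.3, 180.0, -180.0], to_0_360=False))
```

Because both expressions are idempotent, it is safe to pass longitudes that are already in the dataset's convention.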

Returns

np.ndarray NO₃ value(s) in µmol/L. Shape matches the scalar/array input: a 0-d array for scalar inputs, 1-d array of length N for array inputs. NaN is returned for locations outside the dataset's valid range (e.g. continental shelves masked in WOA23).
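The NaN behaviour near masked shelves follows from how the two methods combine grid cells: bilinear interpolation takes a weighted sum of the four surrounding cells, so a single NaN corner makes the result NaN, while nearest-neighbour only reads one cell. The sketch below is a plain bilinear scheme on a hypothetical 2×2 grid, not xarray's own implementation (which delegates to SciPy); we assume the NaN propagation is analogous there.

```python
import numpy as np


def bilinear(lat_grid, lon_grid, values, lat, lon):
    """Plain bilinear interpolation on an ascending regular grid (illustrative)."""
    i = int(np.clip(np.searchsorted(lat_grid, lat) - 1, 0, len(lat_grid) - 2))
    j = int(np.clip(np.searchsorted(lon_grid, lon) - 1, 0, len(lon_grid) - 2))
    ty = (lat - lat_grid[i]) / (lat_grid[i + 1] - lat_grid[i])
    tx = (lon - lon_grid[j]) / (lon_grid[j + 1] - lon_grid[j])
    # Weighted sum of the four surrounding grid cells
    return ((1 - ty) * (1 - tx) * values[i, j] + (1 - ty) * tx * values[i, j + 1]
            + ty * (1 - tx) * values[i + 1, j] + ty * tx * values[i + 1, j + 1])


def nearest(lat_grid, lon_grid, values, lat, lon):
    """Snap to the closest grid cell."""
    i = int(np.abs(lat_grid - lat).argmin())
    j = int(np.abs(lon_grid - lon).argmin())
    return values[i, j]


lat_grid = np.array([0.0, 1.0])
lon_grid = np.array([0.0, 1.0])
no3 = np.array([[1.0, 2.0], [3.0, np.nan]])  # one masked (NaN) corner, e.g. a shelf cell

print(bilinear(lat_grid, lon_grid, no3, 0.2, 0.2))  # NaN: the masked corner contaminates the sum
print(nearest(lat_grid, lon_grid, no3, 0.2, 0.2))   # 1.0: the nearest cell is unmasked
```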

Raises

KeyError
If variable is not found in woa_dataset.

ValueError
If woa_dataset does not have "lat" and "lon" dimensions.

Examples

Single drill site:

>>> no3_val = lookup_no3_from_woa(15.3, -23.7, ocean_prop_ds)
>>> # returns a 0-d array; predict_T_from_proxyObs broadcasts it to all N observations
>>> result = predict_T_from_proxyObs(..., no3=no3_val)

Multi-site stack (per-obs lookup):

>>> no3_arr = lookup_no3_from_woa(core_df["lat"].values,
...                                core_df["lon"].values,
...                                ocean_prop_ds)
>>> result = predict_T_from_proxyObs(..., no3=no3_arr)

Source code in src/TEXAS/data/ocean_lookup.py
def lookup_no3_from_woa(
    lat: Union[float, np.ndarray],
    lon: Union[float, np.ndarray],
    woa_dataset: xr.Dataset,
    variable: str = "no3_sf2tc_avg",
    method: Literal["linear", "nearest"] = "linear",
) -> np.ndarray:
    """
    Look up modern NO₃ at one or more lat/lon coordinates from a WOA23-derived
    xarray Dataset.

    The dataset is typically the preprocessed ``ocean_prop_ds`` generated in
    SI_code1, which contains thermocline-depth-integrated WOA23 climatology on
    a regular ``(lat, lon)`` grid.  The returned value(s) are time-invariant
    (climatological mean) and intended as a modern-ocean proxy for the NO₃
    correction in paleo reconstructions.

    Parameters
    ----------
    lat : float or array-like
        Latitude(s) in decimal degrees (−90 to 90).  Pass a scalar for a
        single drill site; pass an array of length N to match N observations.
    lon : float or array-like
        Longitude(s) in decimal degrees.  Both −180–180 and 0–360 conventions
        are accepted — the function normalises to match the dataset's convention
        automatically.
    woa_dataset : xr.Dataset
        WOA23-derived dataset with a ``(lat, lon)`` grid containing *variable*.
        Dimensions must be named ``"lat"`` and ``"lon"``.
    variable : str
        Name of the NO₃ variable to extract.  Default ``"no3_sf2tc_avg"``
        (thermocline depth-integrated annual average from SI_code1).
    method : {"linear", "nearest"}
        Interpolation method.  ``"linear"`` (default) performs bilinear
        interpolation and is preferred for smooth fields.  ``"nearest"`` snaps
        to the closest grid cell and is useful when the dataset is sparse or
        has NaN-masked shelves.

    Returns
    -------
    np.ndarray
        NO₃ value(s) in µmol/L.  Shape matches the scalar/array input:
        a 0-d array for scalar inputs, 1-d array of length N for array inputs.
        NaN is returned for locations outside the dataset's valid range (e.g.
        continental shelves masked in WOA23).

    Raises
    ------
    KeyError
        If *variable* is not found in *woa_dataset*.
    ValueError
        If *woa_dataset* does not have ``"lat"`` and ``"lon"`` dimensions.

    Examples
    --------
    Single drill site:

    >>> no3_val = lookup_no3_from_woa(15.3, -23.7, ocean_prop_ds)
    >>> # returns scalar-equivalent float; broadcasts to all N obs automatically
    >>> result = predict_T_from_proxyObs(..., no3=no3_val)

    Multi-site stack (per-obs lookup):

    >>> no3_arr = lookup_no3_from_woa(core_df["lat"].values,
    ...                                core_df["lon"].values,
    ...                                ocean_prop_ds)
    >>> result = predict_T_from_proxyObs(..., no3=no3_arr)
    """
    # ── Validate dataset ──────────────────────────────────────────────────────
    if "lat" not in woa_dataset.dims or "lon" not in woa_dataset.dims:
        raise ValueError(
            "woa_dataset must have 'lat' and 'lon' dimensions. "
            f"Found: {list(woa_dataset.dims)}"
        )
    if variable not in woa_dataset:
        raise KeyError(
            f"Variable '{variable}' not found in woa_dataset. "
            f"Available: {list(woa_dataset.data_vars)}"
        )

    da: xr.DataArray = woa_dataset[variable]

    # ── Normalise longitude convention ────────────────────────────────────────
    # Dataset may use 0–360; input may use −180–180 (or vice-versa).
    # Detect the dataset's convention from its lon coordinate range.
    ds_lon = da["lon"].values
    ds_uses_0_360 = float(ds_lon.max()) > 180.0

    lon_arr = np.asarray(lon, dtype=float)
    if ds_uses_0_360:
        # Normalise input to 0–360
        lon_arr = lon_arr % 360.0
    else:
        # Normalise input to −180–180
        lon_arr = ((lon_arr + 180.0) % 360.0) - 180.0

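    lat_arr = np.asarray(lat, dtype=float)
    # Broadcast so a scalar lat can pair with an array of lon (and vice versa)
    lat_arr, lon_arr = np.broadcast_arrays(lat_arr, lon_arr)
    scalar_input = lat_arr.ndim == 0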

    # ── Interpolate ───────────────────────────────────────────────────────────
    if scalar_input:
        result = da.interp(
            lat=float(lat_arr),
            lon=float(lon_arr),
            method=method,
        )
        out = np.asarray(result.values, dtype=float)
    else:
        # Vectorised interpolation: build coordinate DataArrays so xarray
        # performs point-wise (not grid) interpolation.
        lat_da = xr.DataArray(lat_arr, dims="obs")
        lon_da = xr.DataArray(lon_arr, dims="obs")
        result = da.interp(lat=lat_da, lon=lon_da, method=method)
        out = np.asarray(result.values, dtype=float)

    # ── Warn on NaN (masked shelf / land) ────────────────────────────────────
    n_nan = int(np.isnan(out).sum())  # works for 0-d and 1-d results alike
    if n_nan > 0:
        import warnings
        warnings.warn(
            f"lookup_no3_from_woa: {n_nan} location(s) returned NaN — likely "
            "on a continental shelf or land mask in the WOA23 dataset. "
            "Consider using method='nearest' or check your lat/lon values.",
            UserWarning,
            stacklevel=2,
        )

    return out
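The vectorised branch relies on the difference between point-wise and grid ("outer") selection: when lat and lon are DataArrays that share the obs dimension, xarray pairs them element by element, whereas independent coordinate arrays would produce an N×N grid. The NumPy analogue below uses integer indexing into a hypothetical 3×4 field to show the difference in output shape.

```python
import numpy as np

grid = np.arange(12.0).reshape(3, 4)   # hypothetical (lat, lon) field
lat_idx = np.array([0, 1, 2])          # one row index per observation
lon_idx = np.array([3, 2, 0])          # one column index per observation

# Outer ("grid") selection: every lat paired with every lon -> shape (3, 3)
outer = grid[np.ix_(lat_idx, lon_idx)]

# Point-wise selection: lat[k] paired with lon[k] -> shape (3,)
pointwise = grid[lat_idx, lon_idx]

print(outer.shape, pointwise.shape)
print(pointwise)
```

The point-wise result is the diagonal of the outer one, which is why sharing the obs dimension both saves memory and returns exactly one NO₃ value per observation.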

Forward calibration

Get posterior

Run forward calibration Stan sampling and return the posterior.

Wraps StanSampler with automatic predictor detection, CPU configuration, and metadata attachment. The returned dataset can be passed directly to predict_proxy_from_T or saved with save_posterior.

Parameters

data : dict
Stan data dict built by build_fwd_data(). Predictor flags (use_gdgt23ratio, use_no3) are auto-detected from the arrays present; you do not need to set them manually.

stan_file : str
Stan model name (without .stan), e.g. "gen_logi_fixed_hier_crtp_multiv_priorApprox_eiv".

temptype : str
Temperature variable type, e.g. "SST" or "thermoT". Stored in the posterior metadata.

proxy_name : str
Proxy type, e.g. "scaledRI_cren3". Required; it is stored in the .nc attrs and validated downstream when the posterior is used for inverse reconstruction.

iter_warmup : int, optional
HMC warmup iterations per chain (default: the CmdStan default, 1000).

iter_sampling : int, optional
Post-warmup sampling iterations per chain (default: 1000).

chains : int, optional
Number of independent chains (default: 4).

parallel_chains : int, optional
Chains to run simultaneously (auto-detected from CPU count).

threads_per_chain : int, optional
Threads per chain for reduce_sum models (auto-enabled for models whose filename contains reduce_sum).

adapt_delta : float, optional
Target acceptance rate (default: 0.8). Increase toward 0.99 to reduce divergences at the cost of more leapfrog steps.

max_treedepth : int, optional
Maximum tree depth for HMC (default: 10).

**kwargs
Additional keyword arguments forwarded to CmdStanModel.sample (e.g. seed, which defaults to 42).

Returns

posterior : xr.Dataset
Forward calibration posterior with parameter draws and metadata attrs (model name, temptype, proxy_name, priors, diagnostics).

diagnostics : str
Human-readable sampler diagnostic summary (divergences, R-hat, ESS, E-BFMI).

Raises

ValueError If active predictors are present but a univariate model is requested, or if an EIV model is requested without R2_thermal.
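Both ValueError cases are decided before any Stan model is compiled, using substring tests on stan_file. The function below is a hypothetical stand-alone restatement of those two guards (it is not exported by TEXAS), handy for checking a model name against a data dict up front.

```python
def check_predictors_vs_model(data: dict, stan_file: str) -> None:
    """Hypothetical restatement of the two guards at the top of get_posterior."""
    active = [name for name in ("gdgt23ratio", "no3") if data.get(f"use_{name}", 0)]
    is_univ = "univ" in stan_file and "multiv" not in stan_file
    if active and is_univ:
        raise ValueError(f"Active predictor(s) {active} need a multivariate model, got '{stan_file}'")
    # EIV variants need R2_thermal from a prior thermal-only fit
    if "_eiv" in stan_file and "R2_thermal" not in data:
        raise ValueError(f"EIV model '{stan_file}' requires R2_thermal in the data dict")


check_predictors_vs_model({"use_no3": 1}, "gen_logi_fixed_hier_crtp_multiv_priorApprox")  # passes
try:
    check_predictors_vs_model({"use_no3": 1}, "gen_logi_fixed_hier_crtp_univ_priorApprox")
except ValueError as err:
    print(err)
```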

Examples

>>> data = build_fwd_data(t_crtp=..., proxy_crtp=..., ...)
>>> posterior, diag = get_posterior(
...     data,
...     stan_file="gen_logi_fixed_hier_crtp_univ_priorApprox",
...     temptype="SST",
...     proxy_name="scaledRI_cren3",
... )
>>> save_posterior(posterior)
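get_posterior treats every sampler argument left at None as "let CmdStanPy/CmdStan decide" and forwards only the ones the caller set. A minimal sketch of that pattern follows; build_sample_kwargs is hypothetical, and os.cpu_count() is an assumed stand-in for the package's suggest_stan_sampling_kwargs(), whose heuristic is not shown here.

```python
import os
from typing import Any, Dict, Optional


def build_sample_kwargs(
    chains: Optional[int] = None,
    iter_warmup: Optional[int] = None,
    iter_sampling: Optional[int] = None,
    parallel_chains: Optional[int] = None,
    adapt_delta: Optional[float] = None,
    max_treedepth: Optional[int] = None,
    **kwargs: Any,
) -> Dict[str, Any]:
    """Forward only the arguments the caller set; None means 'use the default'."""
    explicit = {
        "chains": chains, "iter_warmup": iter_warmup, "iter_sampling": iter_sampling,
        "parallel_chains": parallel_chains, "adapt_delta": adapt_delta,
        "max_treedepth": max_treedepth,
    }
    kwargs.update({k: v for k, v in explicit.items() if v is not None})
    kwargs.setdefault("seed", 42)  # same fallback seed as get_posterior
    # Assumed stand-in for the auto-detected parallel_chains value
    kwargs.setdefault("parallel_chains", os.cpu_count() or 1)
    return kwargs


print(build_sample_kwargs(adapt_delta=0.95, parallel_chains=2))
```

Keeping unset arguments out of the dict means CmdStan's own defaults (e.g. adapt_delta=0.8, max_treedepth=10) apply without being duplicated in Python.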

Source code in src/TEXAS/stan/sampler.py
def get_posterior(
    data: dict,
    stan_file: str,
    temptype: str,
    proxy_name: str,
    *,
    iter_warmup: Optional[int] = None,
    iter_sampling: Optional[int] = None,
    threads_per_chain: Optional[int] = None,
    chains: Optional[int] = None,
    parallel_chains: Optional[int] = None,
    adapt_delta: Optional[float] = None,
    max_treedepth: Optional[int] = None,
    **kwargs
) -> Tuple[xr.Dataset, str]:
    """
    Run forward calibration Stan sampling and return the posterior.

    Wraps ``StanSampler`` with automatic predictor detection, CPU
    configuration, and metadata attachment.  The returned dataset can be
    passed directly to ``predict_proxy_from_T`` or saved with
    ``save_posterior``.

    Parameters
    ----------
    data : dict
        Stan data dict built by ``build_fwd_data()``.  Predictor flags
        (``use_gdgt23ratio``, ``use_no3``) are auto-detected from the
        arrays present; you do not need to set them manually.
    stan_file : str
        Stan model name (without ``.stan``), e.g.
        ``"gen_logi_fixed_hier_crtp_multiv_priorApprox_eiv"``.
    temptype : str
        Temperature variable type, e.g. ``"SST"`` or ``"thermoT"``.
        Stored in the posterior metadata.
    proxy_name : str
        Proxy type, e.g. ``"scaledRI_cren3"``.  Required — stored in
        the ``.nc`` attrs and validated downstream when the posterior is
        used for inverse reconstruction.
    iter_warmup : int, optional
        HMC warmup iterations per chain (default: CmdStan default, 1000).
    iter_sampling : int, optional
        Post-warmup sampling iterations per chain (default: 1000).
    chains : int, optional
        Number of independent chains (default: 4).
    parallel_chains : int, optional
        Chains to run simultaneously (auto-detected from CPU count).
    threads_per_chain : int, optional
        Threads per chain for ``reduce_sum`` models (auto-enabled for
        models whose filename contains ``reduce_sum``).
    adapt_delta : float, optional
        Target acceptance rate (default: 0.8).  Increase toward 0.99 to
        reduce divergences at the cost of more leapfrog steps.
    max_treedepth : int, optional
        Maximum tree depth for HMC (default: 10).
    **kwargs
        Additional keyword arguments forwarded to ``CmdStanModel.sample``.

    Returns
    -------
    posterior : xr.Dataset
        Forward calibration posterior with parameter draws and metadata
        attrs (model name, temptype, proxy_name, priors, diagnostics).
    diagnostics : str
        Human-readable sampler diagnostic summary (divergences, R-hat,
        ESS, E-BFMI).

    Raises
    ------
    ValueError
        If active predictors are present but a univariate model is
        requested, or if an EIV model is requested without ``R2_thermal``.

    Examples
    --------
    >>> data = build_fwd_data(t_crtp=..., proxy_crtp=..., ...)
    >>> posterior, diag = get_posterior(
    ...     data,
    ...     stan_file="gen_logi_fixed_hier_crtp_univ_priorApprox",
    ...     temptype="SST",
    ...     proxy_name="scaledRI_cren3",
    ... )
    >>> save_posterior(posterior)
    """
    rng_seed = kwargs.setdefault("seed", 42)
    np.random.seed(rng_seed)

    # Normalize/auto-complete predictors & flags
    data = auto_detect_predictors(data)

    # Guard: reject univ model when active predictors are present
    _use_g23 = data.get("use_gdgt23ratio", 0)
    _use_no3 = data.get("use_no3", 0)
    _is_univ = "univ" in str(stan_file) and "multiv" not in str(stan_file)
    if (_use_g23 or _use_no3) and _is_univ:
        _active = []
        if _use_g23:
            _active.append("gdgt23ratio")
        if _use_no3:
            _active.append("no3")
        raise ValueError(
            f"Active predictor(s) {_active} detected in data but stan_file='{stan_file}' "
            f"is a univariate model. Use a multivariate model (e.g. replace 'univ' with "
            f"'multiv', such as 'gen_logi_fixed_hier_crtp_multiv_priorApprox') or "
            f"omit the predictor arrays from the data dict."
        )

    # Guard: _eiv model requires R2_thermal (no sensible default exists)
    _is_eiv = "_eiv" in str(stan_file)
    if _is_eiv and "R2_thermal" not in data:
        raise ValueError(
            f"EIV model '{stan_file}' requires R2_thermal (R² from a thermal-only coretop fit) "
            f"but it is missing from the data dict.\n"
            f"Compute it from a non-EIV run first, then pass: "
            f"build_fwd_data(..., R2_thermal=<value>)."
        )

    # Auto-detect optimal CPU settings for this machine.
    # parallel_chains: always apply (CmdStanPy already does min(chains, cpu_count)
    #   but being explicit avoids surprises).
    # threads_per_chain: only beneficial for reduce_sum model variants; the
    #   standard forward models are single-threaded per chain.
    _auto = suggest_stan_sampling_kwargs()
    _uses_reduce_sum = "reduce_sum" in str(stan_file)

    # Push explicit sampling args to CmdStanPy via kwargs
    if iter_warmup is not None:
        kwargs["iter_warmup"] = iter_warmup
    if iter_sampling is not None:
        kwargs["iter_sampling"] = iter_sampling
    if chains is not None:
        kwargs["chains"] = chains
    if parallel_chains is not None:
        kwargs["parallel_chains"] = parallel_chains
    else:
        kwargs.setdefault("parallel_chains", _auto["parallel_chains"])

    # threads_per_chain requires STAN_THREADS compilation; only auto-enable for
    # reduce_sum variants where it actually helps.
    cpp_options: Dict = {}
    if threads_per_chain is not None:
        kwargs["threads_per_chain"] = threads_per_chain
        cpp_options["STAN_THREADS"] = True
    elif _uses_reduce_sum and "threads_per_chain" in _auto:
        threads_per_chain = _auto["threads_per_chain"]
        kwargs["threads_per_chain"] = threads_per_chain
        cpp_options["STAN_THREADS"] = True
        print(
            f"⚙️  Auto CPU config: {kwargs['parallel_chains']} parallel chains, "
            f"{threads_per_chain} threads/chain (reduce_sum model detected)"
        )
    else:
        print(
            f"⚙️  Auto CPU config: {kwargs['parallel_chains']} parallel chains "
            f"(single-threaded per chain)"
        )

    if adapt_delta is not None:
        kwargs["adapt_delta"] = adapt_delta
    if max_treedepth is not None:
        kwargs["max_treedepth"] = max_treedepth

    _has_tqdm = importlib.util.find_spec("tqdm") is not None
    kwargs.setdefault("show_progress", _has_tqdm)
    kwargs.setdefault("show_console", not _has_tqdm)

    print(f"   proxy_name: {proxy_name}")

    compiler = StanCompiler()
    sampler = StanSampler(compiler)

    ds, diag = sampler.sample(
        data=data,
        stan_file=stan_file,
        temptype=temptype,
        proxy_name=proxy_name,
        cpp_options=cpp_options or None,
        **kwargs
    )

    return ds, diag

Save posterior

Save a forward-model posterior to disk as compressed NetCDF.

The filename is auto-generated from the posterior's metadata attrs: {model}_{temptype}[_gdgt23ratio][_no3_{cutoff}][_{proxy_name}]{suffix}.nc

Parameters

posterior : xr.Dataset
Forward calibration posterior returned by get_posterior(). It should carry the stan_model_name, temptype, and proxy_name attrs. If proxy_name is missing, a UserWarning is issued and the attr is stored as "unknown". If use_no3 is set, no3_cutoff must also be set, otherwise a ValueError is raised.

cache_dir : str or Path, optional
Directory to write the file. Defaults to the standard forward posterior cache (data/cache/TEXAS_posterior_cache/ for source installs, ~/.texas/cache/TEXAS_posterior_cache/ for pip installs).

overwrite : bool
If False, raise FileExistsError when the output path already exists. Default True.

filename_suffix : str, optional
Extra tag appended before .nc, e.g. "032326" for a date-stamped run. Leading/trailing underscores are stripped.

Returns

Path Absolute path of the saved .nc file.
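Because load_posterior expects the file stem, it helps to know in advance what name save_posterior will write. posterior_filename is a hypothetical helper that restates the naming rule from the listing on a plain attrs dict; the cutoff value in the example is made up.

```python
def posterior_filename(attrs: dict, filename_suffix: str = "") -> str:
    """Rebuild {model}_{temptype}[_gdgt23ratio][_no3_{cutoff}][_{proxy_name}]{suffix}.nc."""
    name = attrs.get("stan_model_name", "unknown_model")
    ttype = attrs.get("temptype", "unknown")
    if attrs.get("use_gdgt23ratio", 0):
        ttype += "_gdgt23ratio"
    if attrs.get("use_no3", 0):
        cutoff = attrs.get("no3_cutoff")
        if cutoff is None:
            raise ValueError("no3_cutoff must be set when use_no3=1")
        ttype += f"_no3_{cutoff}"
    proxy = attrs.get("proxy_name") or ""
    # "unknown" proxy names are left out of the filename
    proxy_tag = f"_{proxy}" if proxy and proxy != "unknown" else ""
    if filename_suffix:
        filename_suffix = f"_{filename_suffix.strip('_')}"
    return f"{name}_{ttype}{proxy_tag}{filename_suffix}.nc"


attrs = {
    "stan_model_name": "gen_logi_fixed_hier_crtp_multiv_priorApprox",
    "temptype": "SST",
    "use_no3": 1,
    "no3_cutoff": 5.0,  # hypothetical cutoff value
    "proxy_name": "scaledRI_cren3",
}
print(posterior_filename(attrs, filename_suffix="_032326_"))
```

Dropping the .nc extension from the printed name gives the model_name to pass to load_posterior.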

Source code in src/TEXAS/stan/io.py
def save_posterior(
    posterior: xr.Dataset,
    cache_dir: Optional[Union[str, Path]] = None,
    overwrite: bool = True,
    filename_suffix: str = "",
) -> Path:
    """
    Save a forward-model posterior to disk as compressed NetCDF.

    The filename is auto-generated from the posterior's metadata attrs:
    ``{model}_{temptype}[_gdgt23ratio][_no3_{cutoff}][_{proxy_name}]{suffix}.nc``

    Parameters
    ----------
    posterior : xr.Dataset
        Forward calibration posterior returned by ``get_posterior()``.
        Must have ``stan_model_name``, ``temptype``, and ``proxy_name``
        attrs set (``proxy_name`` is required — a warning is raised if
        missing).
    cache_dir : str or Path, optional
        Directory to write the file.  Defaults to the standard forward
        posterior cache (``data/cache/TEXAS_posterior_cache/`` for
        source installs, ``~/.texas/cache/TEXAS_posterior_cache/`` for
        pip installs).
    overwrite : bool
        If ``False``, raise ``FileExistsError`` when the output path
        already exists.  Default ``True``.
    filename_suffix : str, optional
        Extra tag appended before ``.nc``, e.g. ``"032326"`` for a
        date-stamped run.  Leading/trailing underscores are stripped.

    Returns
    -------
    Path
        Absolute path of the saved ``.nc`` file.
    """
    if not isinstance(posterior, xr.Dataset):
        raise TypeError("posterior must be an xarray.Dataset")

    outdir = Path(cache_dir) if cache_dir else DEFAULT_FORWARD_DIR
    outdir.mkdir(exist_ok=True, parents=True)

    name  = posterior.attrs.get("stan_model_name", "unknown_model")
    ttype = posterior.attrs.get("temptype", "unknown")
    if posterior.attrs.get("use_gdgt23ratio", 0):
        ttype += "_gdgt23ratio"
    if posterior.attrs.get("use_no3", 0):
        cutoff = posterior.attrs.get("no3_cutoff")
        if cutoff is None:
            raise ValueError("no3_cutoff must be set when use_no3=1")
        ttype += f"_no3_{cutoff}"

    proxy_name = posterior.attrs.get("proxy_name", "") or ""
    proxy_tag = f"_{proxy_name}" if proxy_name and proxy_name != "unknown" else ""

    # sanitize suffix
    if filename_suffix:
        filename_suffix = f"_{filename_suffix.strip('_')}"

    outpath = outdir / f"{name}_{ttype}{proxy_tag}{filename_suffix}.nc"
    if outpath.exists() and not overwrite:
        raise FileExistsError(f"{outpath} exists and overwrite=False")

    posterior.attrs["filename"] = outpath.name
    if not posterior.attrs.get("proxy_name"):
        import warnings
        warnings.warn(
            "proxy_name is not set on this posterior. "
            "Pass proxy_name= to get_posterior() (e.g. proxy_name='scaledRI_cren3'). "
            "It is stored in the .nc file and used downstream to validate that "
            "the correct proxy type is passed to predict_T_from_proxyObs().",
            UserWarning, stacklevel=2,
        )
        posterior.attrs["proxy_name"] = "unknown"
    encoding = {var: {"zlib": True} for var in posterior.data_vars}
    sanitized = _sanitize_attrs_for_netcdf(posterior)
    sanitized.to_netcdf(outpath, encoding=encoding)
    print(f"Saved forward posterior to {outpath}  [proxy_name='{posterior.attrs['proxy_name']}']")
    return outpath

Load posterior

Load a posterior from disk: {model_name}.nc in the appropriate cache directory.

Parameters

model_name : str
Name of the model file (without the .nc extension). Required.

model_type : {"forward", "invT"}, optional
Type of posterior to load. Default "forward".

cache_dir : str or Path, optional
Custom cache directory; overrides the default locations. Default None.

Returns

xarray.Dataset
The posterior.

Raises

FileNotFoundError
If the posterior file does not exist.

Source code in src/TEXAS/stan/io.py
def load_posterior(
    model_name: str,
    model_type: Literal["forward", "invT"] = "forward",
    cache_dir: Optional[Union[str, Path]] = None,
) -> xr.Dataset:
    """
    Load a posterior from disk: `{model_name}.nc` in the appropriate cache directory.

    Args:
        model_name: Name of the model file (without .nc extension)
        model_type: Type of posterior ("forward" or "invT")
        cache_dir: Custom cache directory (overrides default locations)

    Returns:
        xarray.Dataset containing the posterior

    Raises:
        FileNotFoundError: If the posterior file doesn't exist
    """
    # Determine cache directory
    if cache_dir:
        indir = Path(cache_dir)
    elif model_type == "forward":
        indir = DEFAULT_FORWARD_DIR
    elif model_type == "invT":
        indir = DEFAULT_INVT_DIR
    else:
        # This shouldn't happen due to type hints, but just in case
        raise ValueError(f"Invalid model_type: {model_type}. Must be 'forward' or 'invT'")

    # Ensure directory exists
    indir.mkdir(exist_ok=True, parents=True)

    # Construct file path
    fpath = indir / f"{model_name}.nc"

    if not fpath.exists():
        available = sorted(indir.glob("*.nc"))
        available_str = "\n    ".join(f.stem for f in available) if available else "(none)"
        raise FileNotFoundError(
            f"Posterior file not found: '{model_name}.nc'\n"
            f"Searched in: {indir}\n"
            f"Files present in that directory:\n    {available_str}\n\n"
            f"Options:\n"
            f"  1. The file is in a different directory — load it yourself and pass the Dataset:\n"
            f"       import xarray as xr\n"
            f"       ds = xr.open_dataset('/your/path/{model_name}.nc')\n"
            f"       predict_T_from_proxyObs(..., fwd_posterior=ds)\n\n"
            f"  2. Search a different cache directory:\n"
            f"       load_posterior('{model_name}', cache_dir='/your/path/here')\n\n"
            f"  3. Download from Zenodo:\n"
            f"       from TEXAS.utils.download import download_posteriors\n"
            f"       download_posteriors(['{model_name}'])"
        )

    return xr.load_dataset(fpath)
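The lookup order in load_posterior is: an explicit cache_dir wins; otherwise model_type picks the default forward or invT directory. The sketch below mirrors that order with plain pathlib so it can be exercised against a temporary directory. `resolve_posterior_path` and its `forward_dir`/`invt_dir` arguments are hypothetical stand-ins for DEFAULT_FORWARD_DIR and DEFAULT_INVT_DIR, and unlike the real function the sketch does not create the directory.

```python
import tempfile
from pathlib import Path
from typing import Optional, Union


def resolve_posterior_path(
    model_name: str,
    model_type: str = "forward",
    cache_dir: Optional[Union[str, Path]] = None,
    *,
    forward_dir: Path,
    invt_dir: Path,
) -> Path:
    """Hypothetical sketch of load_posterior's directory resolution."""
    if cache_dir:
        indir = Path(cache_dir)  # an explicit directory overrides the defaults
    elif model_type == "forward":
        indir = forward_dir
    elif model_type == "invT":
        indir = invt_dir
    else:
        raise ValueError(f"Invalid model_type: {model_type}. Must be 'forward' or 'invT'")

    fpath = indir / f"{model_name}.nc"
    if not fpath.exists():
        # report which posteriors *are* present, as the real error message does
        stems = sorted(p.stem for p in indir.glob("*.nc")) or ["(none)"]
        raise FileNotFoundError(
            f"'{model_name}.nc' not found in {indir}; available: {', '.join(stems)}"
        )
    return fpath


# Usage against a throwaway directory layout:
with tempfile.TemporaryDirectory() as tmp:
    fwd, inv = Path(tmp, "forward"), Path(tmp, "invT")
    fwd.mkdir()
    inv.mkdir()
    (fwd / "demo_model.nc").touch()
    found_name = resolve_posterior_path("demo_model", forward_dir=fwd, invt_dir=inv).name
    try:
        resolve_posterior_path("missing", forward_dir=fwd, invt_dir=inv)
        missing_msg = ""
    except FileNotFoundError as exc:
        missing_msg = str(exc)
    # cache_dir takes precedence even though model_type says "invT"
    override_ok = resolve_posterior_path("demo_model", "invT", cache_dir=fwd,
                                         forward_dir=fwd, invt_dir=inv).exists()
```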

Ensemble

Generate ensemble (auto)

Sample draws from a forward posterior and compute calibration-curve percentiles.

Inspects the posterior's stan_model_name attr and data_vars to determine the model function, parameter names, and optional-predictor flags automatically, then delegates to generate_ensemble.

Parameters

post_ds : xr.Dataset
    Forward calibration posterior returned by get_posterior() or loaded with load_posterior().
x_vals : np.ndarray
    Temperature values (°C) at which to evaluate the calibration curve.
model_type : {"auto", "forward", "inverse"}
    Force forward or inverse dispatch; "auto" (default) infers the type from the posterior (an invT posterior is recognized by "invT_" in stan_model_name or a t_est variable). InvT posteriors are not supported; use predict_T_from_proxyObs() instead.
gdgt23ratio : np.ndarray, optional
    GDGT-2/3 ratio values; required when the posterior was fitted with a multivariate (GDGT-2/3) model.
no3 : np.ndarray, optional
    NO₃ concentrations (µmol/L); required when the posterior uses the NO₃ correction.
no3_cutoff : float, optional
    Override the NO₃ cutoff stored in the posterior attrs.
return_full_ensemble : bool
    If True, return the full M × N draw matrix in addition to the percentiles. Default False.
suffix : str, optional
    Force a specific parameter suffix (e.g. "crtp"); overrides auto-detection.
**kwargs
    Forwarded to generate_ensemble.

Returns

dict
    Percentile keys "p1" … "p99", each a numpy array of length len(x_vals), plus "ensemble" (the M × N draw matrix) when return_full_ensemble=True.

Raises

NotImplementedError
    If called with an invT posterior (model_type "auto" or "inverse").
ValueError
    If model_type="forward" is passed with an invT posterior, or model_type="inverse" with a forward posterior.

Source code in src/TEXAS/ensemble/generator.py
def generate_ensemble_auto(
    post_ds: xr.Dataset,
    x_vals: np.ndarray,
    model_type: Literal["auto","forward","inverse"] = "auto",
    gdgt23ratio: Optional[np.ndarray] = None,
    no3: Optional[np.ndarray] = None,
    no3_cutoff: Optional[float] = None,
    return_full_ensemble: bool = False,
    suffix: Optional[str] = None,
    **kwargs
) -> Dict[str, np.ndarray]:
    """
    Sample draws from a forward posterior and compute calibration-curve percentiles.

    Inspects the posterior's ``stan_model_name`` attr and ``data_vars`` to
    determine the model function, parameter names, and optional-predictor
    flags automatically, then delegates to ``generate_ensemble``.

    Parameters
    ----------
    post_ds : xr.Dataset
        Forward calibration posterior returned by ``get_posterior()`` or
        loaded with ``load_posterior()``.
    x_vals : np.ndarray
        Temperature values (°C) at which to evaluate the calibration curve.
    model_type : {"auto", "forward", "inverse"}
        Force forward or inverse dispatch; ``"auto"`` (default) infers from
        the posterior.  InvT posteriors are not supported — use
        ``predict_T_from_proxyObs()`` instead.
    gdgt23ratio : np.ndarray, optional
        GDGT-2/3 ratio values; required when the posterior was fitted with
        a multivariate (GDGT-2/3) model.
    no3 : np.ndarray, optional
        NO₃ concentrations (µmol/L); required when the posterior uses the
        NO₃ correction.
    no3_cutoff : float, optional
        Override the NO₃ cutoff from the posterior attrs.
    return_full_ensemble : bool
        If ``True``, return the full M × N draw matrix in addition to
        percentiles.  Default ``False``.
    suffix : str, optional
        Force a specific parameter suffix (e.g. ``"crtp"``); overrides
        auto-detection.
    **kwargs
        Forwarded to ``generate_ensemble``.

    Returns
    -------
    dict
        Keys ``"p1"`` … ``"p99"`` (and optionally ``"ensemble"``) — each
        a numpy array of length ``len(x_vals)``.

    Raises
    ------
    NotImplementedError
        If called with an invT posterior.
    """
    model_name = post_ds.attrs.get("stan_model_name", "")
    is_inv = ("invT_" in model_name) or ("t_est" in post_ds.data_vars)

    if model_type in ("auto", "inverse") and is_inv:
        raise NotImplementedError(
            "generate_ensemble_auto() does not support invT posteriors. "
            "Use predict_T_from_proxyObs() for inverse temperature reconstruction."
        )

    elif model_type in ("auto", "forward") and not is_inv:
        det = detect_model_and_params(post_ds, suffix=suffix)  # <-- use explicit suffix
        # precedence: arg > detector > None (resolved inside generate_ensemble)
        resolved_no3_cutoff = no3_cutoff if no3_cutoff is not None else det.get("no3_cutoff")

        return generate_ensemble(
            post_ds=post_ds,
            model_function=det["model_function"],
            x_vals=x_vals,
            param_names=det["param_names"],
            suffix=det.get("suffix"),
            gdgt23ratio=gdgt23ratio,
            no3=no3,
            no3_cutoff=resolved_no3_cutoff,
            is_multivariate=det.get("is_multivariate"),
            use_gdgt23ratio_flag=det.get("use_gdgt23ratio"),
            use_no3_flag=det.get("use_no3"),
            return_full_ensemble=return_full_ensemble,
            **kwargs
        )
    else:
        raise ValueError(
            f"Cannot dispatch generate_ensemble_auto(model_type={model_type!r}, is_inv={is_inv})"
        )
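Conceptually, the forward ensemble evaluates the calibration curve once per posterior draw and then takes percentiles across draws at each temperature. The sketch below shows that step with NumPy broadcasting. It is not generate_ensemble itself: the logistic form (lower asymptote b, upper asymptote fixed at 1), the percentile set, and the helper names are assumptions chosen for illustration, and the real TEXAS model functions may be parameterized differently.

```python
import numpy as np


def logistic_fixed_upper_sketch(x, t0, b, k):
    # ASSUMED form: rises from lower asymptote b to an upper asymptote fixed at 1
    return b + (1.0 - b) / (1.0 + np.exp(-k * (x - t0)))


def ensemble_percentiles(draws, x_vals, percentiles=(1, 5, 16, 50, 84, 95, 99),
                         return_full_ensemble=False):
    """Evaluate one curve per draw (rows) at every x (columns), then summarize."""
    x = np.asarray(x_vals, dtype=float)[None, :]  # shape (1, N)
    t0, b, k = (np.asarray(draws[p], dtype=float)[:, None] for p in ("t0", "b", "k"))
    ens = logistic_fixed_upper_sketch(x, t0, b, k)  # shape (M, N)
    out = {f"p{p}": np.percentile(ens, p, axis=0) for p in percentiles}
    if return_full_ensemble:
        out["ensemble"] = ens
    return out


# Usage with synthetic draws (not a real TEXAS posterior):
rng = np.random.default_rng(0)
M = 2000
draws = {"t0": rng.normal(15.0, 1.0, M),
         "b": rng.uniform(0.1, 0.2, M),
         "k": rng.normal(0.3, 0.02, M)}
x_vals = np.linspace(0.0, 30.0, 31)
result = ensemble_percentiles(draws, x_vals, return_full_ensemble=True)
```

Broadcasting a column of draws against a row of temperatures yields the whole M × N matrix in one call, so the percentiles reduce along axis 0 (the draws).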

Detect model and params

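Auto-detect which logistic model and parameters to use, choosing the parameter suffix with the shared priority in choose_suffix(). Returns a dict with the model function, the parameter base names, the chosen suffix, the is_multivariate / use_gdgt23ratio / use_no3 flags, and the NO₃ cutoff (0.0 when NO₃ is not used).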

Source code in src/TEXAS/ensemble/detection.py
def detect_model_and_params(posterior_ds: xr.Dataset, suffix: Optional[str] = None):
    """
    Auto-detect which logistic model and parameters to use.
    Uses a shared suffix priority via choose_suffix().
    """
    vars_ = set(posterior_ds.data_vars)
    attrs = posterior_ds.attrs

    # Optional predictors: read the flag VALUES (truthy, as save_posterior does),
    # so an attr stored as 0 does not select a multivariate model.
    has_gdz = bool(attrs.get("use_gdgt23ratio", 0))
    has_no3 = bool(attrs.get("use_no3", 0))
    # Any shape parameter v (with or without suffix) makes "v" a suffix candidate
    has_v_any = ("v" in vars_) or any(v.startswith("v_") for v in vars_)

    # Build candidate basenames up-front for suffix selection
    basenames = ["t0", "b", "k"]
    if has_v_any:
        basenames.append("v")
    if has_gdz:
        basenames.append("beta_G23")
    if has_no3:
        basenames.append("beta_NO3")

    # Choose suffix by priority (or validate the preferred one)
    suffix = choose_suffix(posterior_ds, basenames, preferred=suffix)

    # Re-check v WITH the chosen suffix: it decides generalized vs simple logistic
    has_v = (f"v_{suffix}" in vars_) or (suffix == "" and "v" in vars_)
    detected_no3_cutoff = attrs.get("no3_cutoff", 0.0) if has_no3 else 0.0

    # Build param list and select model
    params = ["t0", "b", "k"]
    if has_v:
        params.append("v")
    if has_gdz:
        params.append("beta_G23")
    if has_no3:
        params.append("beta_NO3")

    is_multivariate = has_gdz or has_no3
    if has_v:
        model_fn = (generalized_logistic_fixed_upper_multivariate
                    if is_multivariate else generalized_logistic_fixed_upper)
    else:
        model_fn = (simple_logistic_fixed_upper_multivariate
                    if is_multivariate else logistic_fixed_upper)

    return {
        "model_function": model_fn,
        "param_names": params,
        "suffix": suffix,
        "is_multivariate": is_multivariate,
        "use_gdgt23ratio": has_gdz,
        "use_no3": has_no3,
        "no3_cutoff": detected_no3_cutoff,
    }
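The detection logic reduces to a small decision table: a v_<suffix> variable selects the generalized logistic family, and the optional-predictor flags choose between the univariate and multivariate variant. The sketch below restates that table in plain Python, returning the model-function name as a string. `select_model` is a hypothetical helper; it skips suffix selection (choose_suffix) by taking the suffix directly, and it reads the flags as truthy attrs.

```python
from typing import Iterable, List, Mapping, Tuple


def select_model(data_vars: Iterable[str], attrs: Mapping, suffix: str) -> Tuple[str, List[str]]:
    """Hypothetical restatement of the model/parameter choice in detect_model_and_params."""
    vars_ = set(data_vars)
    has_v = f"v_{suffix}" in vars_ or (suffix == "" and "v" in vars_)
    has_gdz = bool(attrs.get("use_gdgt23ratio", 0))
    has_no3 = bool(attrs.get("use_no3", 0))

    params = ["t0", "b", "k"]
    if has_v:
        params.append("v")          # shape parameter of the generalized logistic
    if has_gdz:
        params.append("beta_G23")   # GDGT-2/3 ratio coefficient
    if has_no3:
        params.append("beta_NO3")   # nitrate coefficient

    multivariate = has_gdz or has_no3
    if has_v:
        name = "generalized_logistic_fixed_upper" + ("_multivariate" if multivariate else "")
    else:
        name = "simple_logistic_fixed_upper_multivariate" if multivariate else "logistic_fixed_upper"
    return name, params


print(select_model(["t0_crtp", "b_crtp", "k_crtp", "v_crtp"], {"use_no3": 1}, "crtp"))
# ('generalized_logistic_fixed_upper_multivariate', ['t0', 'b', 'k', 'v', 'beta_NO3'])
```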

Diagnostics

Sampler diagnostics

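Extract divergent transitions (divergent__), max-treedepth hits (treedepth__), E-BFMI, R-hat, and bulk ESS from a CmdStanPy fit and return them as a dict of stan_diag_* entries, suitable for storing as posterior attrs. Each check gets a PASS/FAIL status alongside its underlying values: divergences below 1 % of draws, max-treedepth hits below 5 %, minimum E-BFMI above 0.2 (UNKNOWN if no E-BFMI values can be read from the fit), maximum R-hat below 1.01, and minimum bulk ESS above 100. An overall status is added as well.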

Source code in src/TEXAS/diagnostics.py
def summarize_sampler_diagnostics(fit) -> dict:
    """
    Extract divergent__, treedepth__, E-BFMI, R_hat, and ESS_bulk
    from a CmdStanPy fit and return them as stan_diag_* attrs.
    """
    diag = {}
    # 1) method variables
    mv = fit.method_variables()
    total_draws = mv["divergent__"].size

    # divergent transitions
    n_div = int(np.sum(mv["divergent__"]))
    pct_div = 100 * n_div / total_draws
    diag["stan_diag_n_divergent"] = n_div
    diag["stan_diag_pct_divergent"] = pct_div
    diag["stan_diag_divergent_status"] = "PASS" if pct_div < 1.0 else "FAIL"

    # treedepth
    td = mv["treedepth__"]
    max_td = 10  # CmdStan's default max_depth; adjust if sampling used a different value
    n_td = int(np.sum(td >= max_td))
    pct_td = 100 * n_td / total_draws
    diag["stan_diag_n_max_treedepth"] = n_td
    diag["stan_diag_pct_max_treedepth"] = pct_td
    diag["stan_diag_treedepth_status"] = "PASS" if pct_td < 5.0 else "FAIL"

    # E-BFMI
    try:
        bfmi_vals = fit.bfmi if hasattr(fit, "bfmi") else fit.bfmi_
    except Exception:
        bfmi_vals = None
    min_ebfmi = float(np.min(bfmi_vals)) if bfmi_vals is not None else -1.0
    if min_ebfmi == -1.0:
        ebfmi_status = "UNKNOWN"
    else:
        ebfmi_status = "PASS" if min_ebfmi > 0.2 else "FAIL"
    diag["stan_diag_min_ebfmi"] = min_ebfmi
    diag["stan_diag_ebfmi_status"] = ebfmi_status

    # R-hat & ESS
    summary_df = fit.summary()
    max_rhat = float(summary_df["R_hat"].max())
    n_high_rhat = int((summary_df["R_hat"] > 1.01).sum())
    min_ess = float(summary_df["ESS_bulk"].min())
    diag["stan_diag_max_rhat"] = max_rhat
    diag["stan_diag_n_high_rhat"] = n_high_rhat
    diag["stan_diag_rhat_status"] = "PASS" if max_rhat < 1.01 else "FAIL"
    diag["stan_diag_min_ess_bulk"] = min_ess
    diag["stan_diag_ess_status"] = "PASS" if min_ess > 100 else "FAIL"

    # overall: every check with a known status must pass (E-BFMI may be UNKNOWN)
    checks = [
        diag[k]
        for k in [
            "stan_diag_divergent_status",
            "stan_diag_treedepth_status",
            "stan_diag_ebfmi_status",
            "stan_diag_rhat_status",
            "stan_diag_ess_status",
        ]
        if diag[k] != "UNKNOWN"
    ]
    diag["stan_diag_overall_status"] = "PASS" if all(c == "PASS" for c in checks) else "FAIL"

    return diag
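The E-BFMI branch depends on the fit object exposing a `bfmi` or `bfmi_` attribute; when neither lookup succeeds, the status is reported as UNKNOWN. One alternative, sketched here as an assumption rather than what the function does, is to compute E-BFMI per chain from the `energy__` sampler variable with the usual estimator: the sum of squared successive energy differences divided by the sum of squared deviations from the chain mean. `ebfmi_per_chain` and `ebfmi_status` are hypothetical helpers that assume draws lie on the first axis; check the array shape your CmdStanPy version returns.

```python
import numpy as np


def ebfmi_per_chain(energy) -> np.ndarray:
    """E-BFMI for each chain from an energy array with draws on axis 0."""
    e = np.asarray(energy, dtype=float)
    e = e.reshape(e.shape[0], -1)                    # (draws, chains); folds a trailing axis
    num = np.sum(np.diff(e, axis=0) ** 2, axis=0)    # squared draw-to-draw energy changes
    den = np.sum((e - e.mean(axis=0)) ** 2, axis=0)  # squared deviations from chain mean
    return num / den


def ebfmi_status(energy, threshold: float = 0.2) -> str:
    # Same PASS rule as summarize_sampler_diagnostics: min over chains must exceed 0.2
    return "PASS" if float(np.min(ebfmi_per_chain(energy))) > threshold else "FAIL"


# Two toy chains: one oscillating (high E-BFMI), one drifting (lower E-BFMI)
energy = np.array([[1.0, 0.0],
                   [2.0, 1.0],
                   [1.0, 2.0],
                   [2.0, 3.0]])
print(ebfmi_per_chain(energy).tolist())  # [3.0, 0.6]
print(ebfmi_status(energy))              # PASS
```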

Summary table

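Build a pandas DataFrame with one row per xarray.Dataset, labelled by its filename attr ("unknown" if absent), and one column per stan_diag_* attr with the stan_diag_ prefix stripped.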

Source code in src/TEXAS/diagnostics.py
def create_summary_table(datasets: list) -> pd.DataFrame:
    """
    Build a DataFrame summarizing the stan_diag_* attrs from each xarray.Dataset.
    """
    rows = []
    for ds in datasets:
        row = {"model": ds.attrs.get("filename", "unknown")}
        for k, v in ds.attrs.items():
            if k.startswith("stan_diag_"):
                row[k.replace("stan_diag_", "")] = v
        rows.append(row)
    return pd.DataFrame(rows)

Plotting

Plot prior distributions

Plot prior densities and any number of posterior distributions in a grid, one subplot per parameter group (t0, k, b, etc.).
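Each subplot collects every posterior variable whose name starts with the group name followed by an underscore, and the grid is at most three columns wide. The sketch below isolates those two rules. `group_params` and `grid_shape` are hypothetical helpers, the variable names in the usage lines are invented, and the sketch omits the later step that drops the beta groups when no dataset sets use_gdgt23ratio or use_no3. The underscore boundary is what keeps, for example, beta_G23_crtp out of the b group.

```python
import math
from typing import Dict, Iterable, List, Sequence, Tuple

DEFAULT_GROUPS = ("t0", "k", "b", "v", "a", "beta_G23", "beta_NO3")


def group_params(names: Iterable[str],
                 include_groups: Sequence[str] = DEFAULT_GROUPS) -> Dict[str, List[str]]:
    """Map each non-empty group (in include_groups order) to its sorted members."""
    grouped: Dict[str, List[str]] = {g: [] for g in include_groups}
    for name in names:
        for prefix in include_groups:
            if name.startswith(prefix + "_"):  # "b_crtp" -> b; "beta_G23_crtp" -> beta_G23 only
                grouped[prefix].append(name)
    return {g: sorted(members) for g, members in grouped.items() if members}


def grid_shape(n_groups: int, max_cols: int = 3) -> Tuple[int, int]:
    """(rows, cols) for n_groups subplots, at most max_cols wide."""
    ncols = min(max_cols, n_groups)
    nrows = math.ceil(n_groups / ncols) if ncols > 0 else 1
    return nrows, ncols


names = ["t0_crtp", "k_crtp", "b_crtp", "beta_G23_crtp", "beta_NO3_crtp", "lp__"]
groups = group_params(names)
print(list(groups))             # ['t0', 'k', 'b', 'beta_G23', 'beta_NO3']
print(grid_shape(len(groups)))  # (2, 3)
```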

Parameters:

param_source_map (Optional[Dict[str, int]], default None): Optional dict mapping a parameter group name to the index of the dataset in posterior_datasets that should be used as the sole source for that group. All other datasets are skipped for that group.

Use this when different parameters come from different posteriors, e.g. logistic params (t0, k, b…) from a culmeso run and beta coefficients from a multivariate crtp run:

    plot_prior_distributions(
        posterior_datasets=[culmeso_ds, crtp_multiv_ds],
        param_source_map={"beta_G23": 1, "beta_NO3": 1},
    )

When a group is not in param_source_map, all datasets are searched as usual.
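Prior strings have the form `name: dist(a, b)`, optionally followed by a Stan truncation `T[lower, upper]`; when priors_list is omitted they are read from each posterior's attrs["priors"]. The sketch below reproduces the parsing step with the same regular expression used in the source. `parse_prior` is a hypothetical helper and the prior strings in the usage lines are invented. An empty truncation bound becomes None, so `T[0,]` means a lower bound of 0 and no upper bound.

```python
import re
from typing import List, Optional, Tuple

# Same pattern as plot_prior_distributions: dist(a, b) with an optional T[...] suffix
PRIOR_RE = re.compile(r"(\w+)\(([^,]+),\s*([^)]+)\)(?:\s*T\[(.*)\])?")


def parse_prior(entry: str) -> Optional[Tuple[str, str, float, float, Optional[List[Optional[float]]]]]:
    """Return (name, dist, a, b, truncation bounds), or None if the entry is not plottable.

    The entry must contain a colon separating the name from the expression.
    """
    name, expr = (part.strip() for part in entry.split(":", 1))
    match = PRIOR_RE.match(expr)
    if not match:
        return None
    dist, a_str, b_str, trunc = match.groups()
    try:
        a, b = float(a_str), float(b_str)
    except ValueError:  # e.g. hierarchical priors such as normal(mu_t0, sigma_t0)
        return None
    bounds = None
    if trunc:
        bounds = [float(v) if v else None for v in re.split(r",\s*", trunc)]
    return name, dist, a, b, bounds


print(parse_prior("t0_crtp: normal(15, 10) T[0,]"))     # ('t0_crtp', 'normal', 15.0, 10.0, [0.0, None])
print(parse_prior("k_crtp: lognormal(-1, 0.5)"))        # ('k_crtp', 'lognormal', -1.0, 0.5, None)
print(parse_prior("t0_crtp: normal(mu_t0, sigma_t0)"))  # None
```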
Source code in src/TEXAS/plotting/prior_plot.py
def plot_prior_distributions(
    priors_list: Optional[Union[List[str], Dict[str,str]]] = None,
    posterior_datasets: Optional[List["xr.Dataset"]] = None,
    posterior_labels_list: Optional[List[str]] = None,
    show_suptitle: bool = True,
    kde_bw: float = 0.3,
    focus_on_posterior: bool = True,
    include_groups: Sequence[str] = ("t0","k","b","v","a","beta_G23","beta_NO3"),
    suffix_include: Optional[List[str]] = None,
    zoomin_suffix: Optional[Union[str,List[str]]] = None,
    zoomin_dataset_idx: Optional[int] = None,
    use_linestyle_by_param: bool = False,
    show_histogram: bool = True,
    show_annotation: bool = False,
    set_linewidth: float = 1.5,
    set_fig_width_factor: float = 3,
    set_fig_height_factor: float = 3.5,
    set_leg_max_ncol: int = 3,
    color_list: Optional[Sequence[str]] = None,
    param_source_map: Optional[Dict[str, int]] = None,
    annotation_style: Literal["ci95", "ci68", "sigma"] = "ci95",
    show_subplot_legend: bool = True,
    show_figure_legend: bool = True,
    show_prior_expression: bool = True,
):
    """
    Plot priors + any number of posterior distributions in a grid,
    split by parameter group (t0, k, b, etc.).

    Args:
        param_source_map: Optional dict mapping a param group name to the index of
            the dataset in ``posterior_datasets`` that should be used as the sole
            source for that group.  All other datasets are skipped for that group.

            Use this when different parameters come from different posteriors — e.g.
            logistic params (t0, k, b…) from a ``culmeso`` run and beta coefficients
            from a multivariate ``crtp`` run::

                plot_prior_distributions(
                    posterior_datasets=[culmeso_ds, crtp_multiv_ds],
                    param_source_map={"beta_G23": 1, "beta_NO3": 1},
                )

            When a group is not in ``param_source_map``, all datasets are searched
            as usual.
    """
    # ── Resolve any string / Path entries in posterior_datasets ──────────────
    if posterior_datasets:
        from pathlib import Path as _Path
        from ..stan.io import load_posterior as _load_posterior
        resolved = []
        for item in posterior_datasets:
            if isinstance(item, (str, _Path)):
                resolved.append(_load_posterior(str(item)))
            else:
                resolved.append(item)
        posterior_datasets = resolved

    fig, axes = None, None
    parsed_priors = {}

    # If priors_list not supplied, pull prior strings from the posterior datasets'
    # attrs["priors"] (set automatically during sampling).  Merge across all
    # datasets and deduplicate so each prior name appears only once.
    if priors_list is None:
        seen = {}
        for ds in (posterior_datasets or []):
            for entry in ds.attrs.get("priors", []):
                name = entry.split(":")[0].strip()
                seen.setdefault(name, entry)   # first dataset wins on conflict
        priors_list = list(seen.values())

    for prior in priors_list:
        name, dist_expr = prior.split(":", 1)
        name = name.strip()
        dist_expr = dist_expr.strip()

        if any(sym in dist_expr for sym in ["mu_", "sigma_", "logit"]):
            continue

        match = re.match(r"(\w+)\(([^,]+),\s*([^)]+)\)(?:\s*T\[(.*)\])?", dist_expr)
        if not match:
            continue

        dist_name, a_str, b_str, trunc = match.groups()
        try:
            a = float(a_str)
            b = float(b_str)
        except ValueError:
            continue

        parsed_priors[name] = {
            "dist": dist_name,
            "a": a,
            "b": b,
            "trunc": trunc,
        }

    all_param_names = set(parsed_priors.keys())
    _use_gdgt23ratio_detection = 0
    _use_no3_detection = 0
    if posterior_datasets:
        for ds in posterior_datasets:
            all_param_names.update(ds.data_vars)

            # collect "use_gdgt23ratio" and "use_no3" from all datasets
            if ds.attrs.get("use_gdgt23ratio", 0) == 1:
                _use_gdgt23ratio_detection += 1

            if ds.attrs.get("use_no3", 0) == 1:
                _use_no3_detection += 1

    grouped = {key: [] for key in include_groups}
    for name in all_param_names:
        for prefix in include_groups:
            if name.startswith(prefix + "_"):
                grouped[prefix].append(name)
    param_groups = [g for g in include_groups if grouped[g]]

    # Drop the beta_G23 / beta_NO3 groups unless at least one dataset actually used that predictor
    if ('beta_G23' in param_groups) and (_use_gdgt23ratio_detection == 0):
        param_groups.pop(param_groups.index('beta_G23'))

    if ('beta_NO3' in param_groups) and (_use_no3_detection == 0):
        param_groups.pop(param_groups.index('beta_NO3'))

    # ncols/nrows computed AFTER popping so the grid is sized correctly.
    # Max layout is 2 rows × 3 cols (6 params: t0, k, b, v, beta_G23, beta_NO3).
    ncols = min(3, len(param_groups))
    nrows = int(np.ceil(len(param_groups) / ncols)) if ncols > 0 else 1

    # squeeze=False guarantees axes is always 2-D (nrows × ncols) so indexing
    # axes[row, col] is safe regardless of grid size.
    fig, axes = plt.subplots(nrows=nrows, ncols=ncols,
                             figsize=(set_fig_width_factor * ncols, set_fig_height_factor * nrows),
                             squeeze=False, clear=True,
                             sharex=False, sharey=False)

    # Keyed by idx_ds → (line_handle, label); populated during plotting so the
    # figure-level legend always has exactly one entry per dataset in list order.
    _fig_legend_entries: Dict[int, tuple] = {}
    _prior_handle = None

    for idx, base in enumerate(param_groups):
        row_idx, col_idx = divmod(idx, ncols)
        ax = axes[row_idx, col_idx]
        ax.clear()


        param_names = sorted(grouped[base])   # sort for deterministic order
        if suffix_include:
            param_names = [p for p in param_names if any(p.endswith(suf) for suf in suffix_include)]
        all_samples = []
        x_min, x_max = None, None

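        # Match the group exactly or as "<group>_..." so group "b" never picks up a beta_* prior
        prior_key = next((k for k in parsed_priors if k == base or k.startswith(base + "_")), None)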
        x = None  # Initialize x variable
        trunc_bounds = None  # Store truncation bounds for later

        if prior_key:
            prior_info = parsed_priors[prior_key]
            dist, a, b, trunc = prior_info["dist"], prior_info["a"], prior_info["b"], prior_info["trunc"]

            if dist == "normal":
                std_range = 4 if trunc is None else 3
                x_min = a - std_range * b
                x_max = a + std_range * b
                # Parse truncation bounds but don't apply them to x yet
                if trunc:
                    trunc_bounds = [float(v) if v else None for v in re.split(r",\s*", trunc)]
                x = np.linspace(x_min, x_max, 5000)
                y = stats.norm.pdf(x, a, b)
                if trunc_bounds:
                    if trunc_bounds[0] is not None:
                        y[x < trunc_bounds[0]] = 0
                    if len(trunc_bounds) > 1 and trunc_bounds[1] is not None:
                        y[x > trunc_bounds[1]] = 0

            elif dist == "beta":
                x_min, x_max = 0.0, 1.0
                x = np.linspace(x_min, x_max, 5000)
                y = stats.beta.pdf(x, a, b)

            elif dist == "cauchy":
                x_min = a - 10 * b
                x_max = a + 10 * b
                x = np.linspace(x_min, x_max, 5000)
                y = stats.cauchy.pdf(x, a, b)

            elif dist == "lognormal":
                x_min = max(1e-6, np.exp(a - 4 * b))
                x_max = np.exp(a + 4 * b)
                x = np.linspace(x_min, x_max, 5000)
                y = stats.lognorm.pdf(x, s=b, scale=np.exp(a))

            else:
                continue

            (prior_line,) = ax.plot(x, y, color='black', lw=set_linewidth, label="Prior")
            if _prior_handle is None:
                _prior_handle = prior_line
            if show_prior_expression:
                expr = _format_prior_expr(dist, a, b, trunc)
                ax.text(0.98, 0.98, expr, transform=ax.transAxes,
                        fontsize=9, va='top', ha='right', color='black')

        # Collect all samples first to determine x range if no prior available.
        # If param_source_map specifies a source dataset for this group, only
        # pull samples from that dataset; otherwise search all datasets.
        _source_idx = param_source_map.get(base) if param_source_map else None
        for name in param_names:
            if posterior_datasets:
                for idx_ds, ds in enumerate(posterior_datasets):
                    if _source_idx is not None and idx_ds != _source_idx:
                        continue
                    if name not in ds.data_vars:
                        continue
                    samples = ds[name].values.flatten()
                    if posterior_labels_list is not None:
                        stan_model_labels = posterior_labels_list[idx_ds]
                    else:
                        # Fallback to dataset filename if labels not provided
                        stan_model_labels = ds.attrs.get('filename', 'Unknown Model')
                    use_gdgt23ratio_check = ds.attrs.get('use_gdgt23ratio', 0)
                    use_no3_check = ds.attrs.get('use_no3', 0)
                    all_samples.append((samples, idx_ds, name, stan_model_labels,
                                        use_gdgt23ratio_check, use_no3_check))

        # Sort by (idx_ds, param_name) so all lines for dataset 0 are plotted
        # before dataset 1, giving a consistent order in every subplot.
        all_samples.sort(key=lambda t: (t[1], t[2]))

        # Expand x to cover posterior tails (P1 to P99 plus 5% padding) if they exceed the prior range
        if all_samples and x is not None:
            all_param_samples = np.concatenate([s for s, _, _, _, _, _ in all_samples])
            p_lo = np.percentile(all_param_samples, 1)
            p_hi = np.percentile(all_param_samples, 99)
            # Add 5% padding on each side to ensure KDE tails are visible
            data_range = p_hi - p_lo
            if data_range > 0:
                pad = 0.05 * data_range
            else:
                # All percentiles coincide: pad by 10% of their magnitude (0.1 at zero)
                pad = 0.1 * abs(p_lo) if p_lo != 0 else 0.1
            p_lo_padded = p_lo - pad
            p_hi_padded = p_hi + pad

            needs_expansion = (p_lo_padded < x_min) or (p_hi_padded > x_max)
            if needs_expansion:
                x_min = min(x_min, p_lo_padded)
                x_max = max(x_max, p_hi_padded)
                x = np.linspace(x_min, x_max, 5000)

        # If x was not defined by prior, create it from posterior data range
        if x is None and all_samples:
            # Get the combined range of all samples for this parameter group
            all_param_samples = np.concatenate([s for s, _, _, _, _, _ in all_samples])
            data_min, data_max = all_param_samples.min(), all_param_samples.max()

            # Robust padding calculation that works for any value range
            data_range = data_max - data_min
            if data_range > 0:
                # Use 15% padding relative to data range
                padding = 0.15 * data_range
            else:
                # If all values are identical, use absolute padding based on magnitude
                abs_magnitude = abs(data_min) if data_min != 0 else 1.0
                padding = 0.1 * abs_magnitude

            x_min = data_min - padding
            x_max = data_max + padding
            x = np.linspace(x_min, x_max, 5000)
        elif x is None:
            # Fallback: create a default range if no samples either
            x = np.linspace(-1, 1, 5000)

        # Determine number of distinct models (used to color by model)
        if posterior_datasets:
            num_models = len(posterior_datasets)
        else:
            num_models = 0

        # Validate and assign plotting colors
        if color_list is not None:
            if len(color_list) != num_models:
                raise ValueError(f"color_list must have exactly {num_models} colors to match the number of posterior datasets.")
            default_colors = color_list
        else:
            # Use default tab10 colors; repeat if not enough
            default_colors = plt.cm.tab10.colors
            if num_models > len(default_colors):
                from itertools import cycle, islice
                default_colors = list(islice(cycle(default_colors), num_models))

        linestyles = ['-', '--', '-.', ':', (0, (3, 1, 1, 1))]

        unique_param_names = sorted(set(pname for _, _, pname, _, _, _ in all_samples))
        _n_annotated = [0]  # counts lines actually drawn; drives annotation y-position

        for (samples, idx_ds, param_label, stan_model_label,
             use_gdgt23ratio_check, use_no3_check) in all_samples:
            color = default_colors[idx_ds % len(default_colors)]

            if use_linestyle_by_param:
                ls_idx = unique_param_names.index(param_label)
                linestyle = linestyles[ls_idx % len(linestyles)]
            else:
                linestyle = '-'

            kde = stats.gaussian_kde(samples, bw_method=kde_bw)
            kde_y = kde(x)

            def _plot_line():
                (line,) = ax.plot(x, kde_y, color=color, lw=set_linewidth, linestyle=linestyle, label=param_label)
                if idx_ds not in _fig_legend_entries:
                    _fig_legend_entries[idx_ds] = (line, stan_model_label)
                if show_histogram:
                    ax.hist(samples, bins=100, density=True, alpha=0.2, color=color)
                if show_annotation:
                    med = np.median(samples)
                    if annotation_style == "ci95":
                        # central 95% interval: 2.5th to 97.5th percentile
                        text = _format_ci(med, np.percentile(samples, 2.5),
                                               np.percentile(samples, 97.5))
                    elif annotation_style == "ci68":
                        text = _format_ci(med, np.percentile(samples, 16),
                                               np.percentile(samples, 84))
                    else:
                        text = _format_stat(med, np.std(samples, ddof=1))
                    ypos = 0.98 - _n_annotated[0] * 0.065
                    ax.text(0.02, ypos, text, transform=ax.transAxes,
                            fontsize=8, va='top', ha='left', color=color)
                    _n_annotated[0] += 1

            if param_label.startswith("beta_G23"):
                if use_gdgt23ratio_check == 1:
                    _plot_line()
            elif param_label.startswith("beta_NO3"):
                if use_no3_check == 1:
                    _plot_line()
            else:
                _plot_line()


        if all_samples:
            combined = np.concatenate([s for s, _, _, _, _, _ in all_samples])

            # Handle dataset-specific zooming (priority over suffix-based zooming)
            if focus_on_posterior and zoomin_dataset_idx is not None:
                zoom_min, zoom_max = compute_dataset_specific_range(all_samples, zoomin_dataset_idx)
                if zoom_min is not None and zoom_max is not None:
                    ax.set_xlim([zoom_min, zoom_max])
                else:
                    # Fallback to standard range if dataset-specific fails
                    zoom_min, zoom_max = compute_sample_range(combined)
                    if zoom_min is not None:
                        ax.set_xlim([zoom_min, zoom_max])
            else:
                # Handle zoomin_suffix as either string or list (legacy behavior)
                should_zoom = False
                if focus_on_posterior and zoomin_suffix:
                    if isinstance(zoomin_suffix, str):
                        should_zoom = zoomin_suffix in base
                    else:
                        should_zoom = any(suffix in base for suffix in zoomin_suffix)

                if should_zoom:
                    # For zoomin_suffix matches, use suffix-specific P5-P95 range for aggressive zooming
                    if isinstance(zoomin_suffix, str):
                        zoom_min, zoom_max = compute_suffix_specific_range(all_samples, zoomin_suffix)
                    else:
                        # For a list of suffixes, use the first one that appears in this parameter name
                        zoom_min, zoom_max = None, None
                        for suffix in zoomin_suffix:
                            if suffix in base:
                                zoom_min, zoom_max = compute_suffix_specific_range(all_samples, suffix)
                                break

                    if zoom_min is not None and zoom_max is not None:
                        ax.set_xlim([zoom_min, zoom_max])
                    else:
                        # Fall back to the standard range if the suffix-specific range cannot be computed
                        zoom_min, zoom_max = compute_sample_range(combined)
                        if zoom_min is not None:
                            ax.set_xlim([zoom_min, zoom_max])
                elif focus_on_posterior:
                    # For other cases, use the standard percentile-based range
                    zoom_min, zoom_max = compute_sample_range(combined)
                    if zoom_min is not None:
                        if x_min is not None and x_max is not None:
                            ax.set_xlim([max(x_min, zoom_min), min(x_max, zoom_max)])
                        else:
                            ax.set_xlim([zoom_min, zoom_max])
                elif x_min is not None and x_max is not None:
                    ax.set_xlim([x_min, x_max])
        else:
            if x_min is not None and x_max is not None:
                ax.set_xlim([x_min, x_max])


        # Map raw parameter names (e.g. "t0_crtp") to LaTeX display labels for the per-subplot legend
        ax_legends_labels_dict = {
            "t0_crtp": r"T$_0$",
            "k_crtp": "k",
            "b_crtp": "b",
            "v_crtp": r"$\nu$",
            "a_crtp": "a",
            "beta_G23_crtp": r"$\beta_{G_{2/3}}$",
            "beta_NO3_crtp": r"$\beta_{NO_3}$",

            "t0_culmeso": r"T$_{0, culmeso}$",
            "k_culmeso": r"k$_{culmeso}$",
            "b_culmeso": r"b$_{culmeso}$",
            "v_culmeso": r"$\nu_{culmeso}$",
            "a_culmeso": r"a$_{culmeso}$",
            "beta_G23_culmeso": r"$\beta_{G_{2/3},culmeso}$",
            "beta_NO3_culmeso": r"$\beta_{NO_3,culmeso}$",
        }

        if all_samples:
            handles, labels_in_ax = ax.get_legend_handles_labels()
            revised_labels_in_ax = []
            for lbl in labels_in_ax:
                revised_lbl = lbl
                for key, val in ax_legends_labels_dict.items():
                    if key in lbl:
                        # Swap the raw parameter name for its display label, keeping the rest of the label text
                        revised_lbl = val + revised_lbl.replace(key, "")
                        break
                revised_labels_in_ax.append(revised_lbl)
            if handles and show_subplot_legend:
                ax.legend(handles, revised_labels_in_ax, loc='upper right', fontsize=8, ncol=1, frameon=False)

        revised_base_dict = {
            "t0": r"T$_0$",
            "k": "k",
            "b": "b",
            "v": r"$\nu$",
            "a": "a",
            "beta_G23": r"$\beta_{G_{2/3}}$",
            "beta_NO3": r"$\beta_{NO_3}$",
        }
        revised_base = revised_base_dict.get(base, base)

        ax.set_xlabel(f"{revised_base}")
        ax.grid(True)

    if posterior_datasets:
        # Build figure legend: Prior first, then one entry per dataset in
        # posterior_datasets order (guaranteed by idx_ds key).
        fig_handles, fig_labels = [], []
        if _prior_handle is not None:
            fig_handles.append(_prior_handle)
            fig_labels.append("Prior")
        for i in range(len(posterior_datasets)):
            if i in _fig_legend_entries:
                h_ds, lbl_ds = _fig_legend_entries[i]
                fig_handles.append(h_ds)
                fig_labels.append(lbl_ds)

        # max(1, ...) guards against division by zero when no legend entries were collected
        legend_ncol = max(1, min(len(fig_labels), set_leg_max_ncol))
        legend_nrow = int(np.ceil(len(fig_labels) / legend_ncol))

        top_margin = 0.95 if show_suptitle else 1.0

        # 1. Let tight_layout organize the subplots in the original figure size
        fig.tight_layout(rect=[0, 0, 1, top_margin])

        # 2. Add extra physical height to the figure specifically for the legend
        #    (0.25 inches of buffer + 0.25 inches per legend row)
        legend_extra_inches = 0.25 + 0.25 * legend_nrow
        w, h = fig.get_size_inches()
        new_h = h + legend_extra_inches
        fig.set_size_inches(w, new_h)

        # 3. Calculate what fraction of the NEW height belongs to the legend space
        bottom_frac = legend_extra_inches / new_h

        # Push the subplots up so they sit above the reserved legend space, leaving an
        # extra 0.1 figure-fraction margin for tick and axis labels
        fig.subplots_adjust(bottom=bottom_frac + 0.1)

        # 4. Anchor the legend to the top of the reserved space, drawing downward
        if show_figure_legend:
            fig.legend(
                handles=fig_handles,
                labels=fig_labels,
                loc='upper center',
                # Anchor right below the subplots (inside the figure boundaries)
                bbox_to_anchor=(0.5, bottom_frac),
                ncol=legend_ncol,
                fontsize=10,
                frameon=True,
                borderaxespad=0.0,
                handletextpad=0.4,
                labelspacing=0.3,
            )

    if show_suptitle:
        fig.suptitle("Prior and Posterior Distributions", fontsize=14)

    # Hide any subplot left empty (e.g. trailing cells when the parameter count does not fill the grid).
    for ax in axes.flat:
        if not ax.has_data():
            ax.set_visible(False)

    return fig, axes