#!/home/joaquin.m.belart/software/miniconda3/bin/python3.13

import argparse
import os
import sys

import numpy as np
import xdem
from rasterio.warp import calculate_default_transform, transform_bounds

# ---------------------------------------------------------------------------
# Supported single-step methods
# ---------------------------------------------------------------------------
# Command-line name -> xdem coregistration class.  These keys are the names
# accepted by --method, either alone or chained with '+'.  Insertion order is
# the order in which they are listed in help/error messages.
SINGLE_METHODS = dict(
    NuthKaab=xdem.coreg.NuthKaab,
    DhMinimize=xdem.coreg.DhMinimize,
    VerticalShift=xdem.coreg.VerticalShift,
    LZD=xdem.coreg.LZD,
    ICP=xdem.coreg.ICP,
    CPD=xdem.coreg.CPD,
    Deramp=xdem.coreg.Deramp,
    TerrainBias=xdem.coreg.TerrainBias,
    DirectionalBias=xdem.coreg.DirectionalBias,
)

# Comma-separated list of the accepted names, used in help and error text.
VALID_METHODS_STR = ", ".join(SINGLE_METHODS.keys())


def build_coreg(method_str: str):
    """
    Parse a method string (single or '+'-separated pipeline) and return
    the corresponding Coreg object.

    Whitespace around each name is ignored and names are matched
    case-insensitively against ``SINGLE_METHODS`` (an exact match wins).
    The process exits via ``sys.exit`` with an explanatory message when the
    string contains an empty segment (leading, trailing or doubled '+', or an
    empty string) or an unknown method name.

    Examples
    --------
    "NuthKaab"          -> NuthKaab()
    "ICP+NuthKaab"      -> ICP() + NuthKaab()
    "VerticalShift+ICP+NuthKaab" -> VerticalShift() + ICP() + NuthKaab()
    "nuthkaab"          -> NuthKaab()
    """
    parts = [p.strip() for p in method_str.split("+")]

    # An empty segment would otherwise be reported as "Unknown method(s): ."
    if not all(parts):
        sys.exit(
            f"Invalid method specification {method_str!r}: empty method name "
            f"(check for a leading, trailing or doubled '+').\n"
            f"Valid methods: {VALID_METHODS_STR}\n"
            f"Pipelines are specified as 'Method1+Method2[+...]'."
        )

    # Map each part to its canonical key; unmatched parts keep their spelling
    # so the error message shows exactly what the user typed.
    by_lower = {name.lower(): name for name in SINGLE_METHODS}
    resolved = [
        p if p in SINGLE_METHODS else by_lower.get(p.lower(), p) for p in parts
    ]

    unknown = [p for p in resolved if p not in SINGLE_METHODS]
    if unknown:
        sys.exit(
            f"Unknown method(s): {', '.join(unknown)}.\n"
            f"Valid methods: {VALID_METHODS_STR}\n"
            f"Pipelines are specified as 'Method1+Method2[+...]'."
        )

    steps = [SINGLE_METHODS[name]() for name in resolved]
    # Chain the steps with '+', which builds an xdem pipeline.
    result = steps[0]
    for step in steps[1:]:
        result = result + step
    return result


# ---------------------------------------------------------------------------
# Grid checks and preparation
# ---------------------------------------------------------------------------

def _bounds_in_crs(
    dem: xdem.DEM, target_crs
) -> tuple[float, float, float, float]:
    """
    Give the bounding box of *dem* as (left, bottom, right, top) in *target_crs*.

    When the DEM is already in *target_crs* its own bounds are returned
    unchanged; otherwise they are reprojected with rasterio's
    ``transform_bounds``.
    """
    edges = dem.bounds
    corners = (edges.left, edges.bottom, edges.right, edges.top)
    if dem.crs == target_crs:
        return corners
    return transform_bounds(dem.crs, target_crs, *corners)


def _native_res_in_crs(dem: xdem.DEM, target_crs) -> float:
    """
    Estimate the pixel size of *dem* in the units of *target_crs*.

    Returns the larger of the absolute x/y pixel sizes.  For a different CRS
    the sizes come from the affine transform that rasterio's
    ``calculate_default_transform`` proposes for warping the full DEM.
    """
    if dem.crs == target_crs:
        step_x, step_y = dem.res
    else:
        edges = dem.bounds
        affine, _new_width, _new_height = calculate_default_transform(
            dem.crs,
            target_crs,
            dem.width,
            dem.height,
            left=edges.left,
            bottom=edges.bottom,
            right=edges.right,
            top=edges.top,
        )
        step_x, step_y = affine.a, affine.e
    return max(abs(step_x), abs(step_y))


def check_and_prepare_grids(
    ref: xdem.DEM,
    slave: xdem.DEM,
) -> tuple[xdem.DEM, xdem.DEM]:
    """
    Print CRS/extent/resolution diagnostics, then return (ref_work, slv_work)
    on a common pixel grid defined by:

    * CRS: projected is preferred over geographic; if both share the same
      type the reference CRS is kept.
    * Extent: the spatial overlap of the two DEMs.  Each DEM is first cropped
      (in its native CRS) to the overlap plus a margin of two native pixels,
      so the working grid can extend slightly beyond the strict overlap.
    * Resolution: coarser of the two native resolutions in the target CRS.

    The DEM that already uses the target CRS is reprojected first and defines
    the exact grid; the other DEM is then warped onto that grid.

    Exits with an error if the two DEMs do not overlap.
    """
    print("Checking DEM compatibility …")

    def _unit_name(crs) -> str:
        """Best-effort unit label for *crs*, used only for the printout."""
        # pyproj-style CRS objects expose per-axis unit names.
        try:
            return crs.axis_info[0].unit_name
        except (AttributeError, IndexError):
            pass
        if crs.is_geographic:
            return "deg"
        # rasterio-style CRS objects have no ``axis_info`` but report the
        # linear unit name; fall back to "m" if it is missing or unknown.
        try:
            linear = crs.linear_units
        except Exception:  # labelling is cosmetic; never abort over it
            linear = None
        if isinstance(linear, str) and linear and linear.lower() != "unknown":
            return linear
        return "m"

    def _info(dem: xdem.DEM, label: str) -> None:
        geo = dem.crs.is_geographic
        b = dem.bounds
        unit = _unit_name(dem.crs)
        epsg = dem.crs.to_epsg()
        crs_id = f"EPSG:{epsg}" if epsg else dem.crs.to_string()
        print(
            f"  {label}: CRS={crs_id} ({'geographic' if geo else 'projected'}), "
            f"res={abs(dem.res[0]):.4g}×{abs(dem.res[1]):.4g} {unit}, "
            f"bounds=({b.left:.6g} {b.bottom:.6g} {b.right:.6g} {b.top:.6g}), "
            f"shape={dem.shape[-2]}×{dem.shape[-1]}"
        )

    _info(ref,   "Reference")
    _info(slave, "Slave    ")

    # --- Choose target CRS ---
    # Only a geographic reference paired with a projected slave switches the
    # grid to the slave's CRS.
    grid_from_slave = ref.crs.is_geographic and not slave.crs.is_geographic
    if grid_from_slave:
        target_crs = slave.crs
        crs_source = "slave (projected preferred over geographic)"
    else:
        target_crs = ref.crs
        crs_source = "reference"
    epsg_t = target_crs.to_epsg()
    crs_str = f"EPSG:{epsg_t}" if epsg_t else target_crs.to_string()
    print(f"  Target CRS  : {crs_str}  ({crs_source})")

    # --- Overlap in target CRS ---
    ref_b = _bounds_in_crs(ref, target_crs)
    slv_b = _bounds_in_crs(slave, target_crs)
    left   = max(ref_b[0], slv_b[0])
    bottom = max(ref_b[1], slv_b[1])
    right  = min(ref_b[2], slv_b[2])
    top    = min(ref_b[3], slv_b[3])
    if left >= right or bottom >= top:
        sys.exit(
            "Error: the reference and slave DEMs do not spatially overlap — "
            "nothing to co-register."
        )
    ref_area = (ref_b[2] - ref_b[0]) * (ref_b[3] - ref_b[1])
    slv_area = (slv_b[2] - slv_b[0]) * (slv_b[3] - slv_b[1])
    ov_area  = (right - left) * (top - bottom)
    print(
        f"  Overlap     : ({left:.6g} {bottom:.6g} {right:.6g} {top:.6g})  "
        f"[{100*ov_area/ref_area:.1f}% of ref, {100*ov_area/slv_area:.1f}% of slave]"
    )

    # --- Target resolution: coarser of the two ---
    ref_res_t = _native_res_in_crs(ref, target_crs)
    slv_res_t = _native_res_in_crs(slave, target_crs)
    target_res = max(ref_res_t, slv_res_t)
    coarser_src = "reference" if ref_res_t >= slv_res_t else "slave"
    print(
        f"  Target res  : {target_res:.4g}  "
        f"(ref≈{ref_res_t:.4g}, slave≈{slv_res_t:.4g}; coarser from {coarser_src})"
    )

    # --- Crop each DEM to the overlap in its native CRS before reprojecting ---
    def _crop(dem: xdem.DEM) -> xdem.DEM:
        if dem.crs != target_crs:
            nb = transform_bounds(target_crs, dem.crs, left, bottom, right, top)
        else:
            nb = (left, bottom, right, top)
        # Two native pixels of margin so edge pixels survive the warp.
        buf = 2 * max(abs(dem.res[0]), abs(dem.res[1]))
        return dem.crop((nb[0] - buf, nb[1] - buf, nb[2] + buf, nb[3] + buf))

    ref_cropped = _crop(ref)
    slv_cropped = _crop(slave)

    # --- Reproject to a shared pixel grid ---
    # The DEM already in the target CRS sets the exact grid; the other is
    # warped to match it.
    if grid_from_slave:
        slv_work = slv_cropped.reproject(crs=target_crs, res=target_res)
        ref_work = ref_cropped.reproject(ref=slv_work)
    else:
        ref_work = ref_cropped.reproject(crs=target_crs, res=target_res)
        slv_work = slv_cropped.reproject(ref=ref_work)

    print(
        f"  Working grid: {ref_work.shape[-2]}×{ref_work.shape[-1]} px  "
        f"at {target_res:.4g} in {crs_str}"
    )
    return ref_work, slv_work


def build_inlier_mask(
    ref: xdem.DEM,
    slave_repr: xdem.DEM,
    stable_mask_path: str | None,
    invert_mask: bool,
    min_slope: float,
    max_slope: float,
    min_dh: float | None,
    max_dh: float,
) -> tuple[np.ndarray, dict]:
    """
    Combine several per-pixel criteria into one boolean inlier mask.

    A pixel is an inlier (True) only if it passes every criterion:
      1. the optional external stable-terrain mask (raster or vector file),
      2. reference-DEM slope within [min_slope, max_slope],
      3. |ref − slave| within [min_dh, max_dh] (no lower bound if min_dh is None),
      4. finite data in both DEMs.

    Parameters
    ----------
    ref : xdem.DEM
        Reference DEM; its grid defines the mask grid.
    slave_repr : xdem.DEM
        Slave DEM already on the reference grid.
    stable_mask_path : str or None
        Stable-terrain mask file.  Paths ending in .gpkg/.shp/.geojson/.json/
        .fgb/.kml/.gml are read as vectors (polygon interiors = stable); any
        other path is read as a raster reprojected onto *ref* with
        nearest-neighbour resampling (non-zero = stable, nodata = not stable).
    invert_mask : bool
        Negate the loaded stable mask before combining it.
    min_slope, max_slope : float
        Slope bounds for the reference DEM, compared against ``ref.slope()``.
    min_dh : float or None
        Lower bound on |dh|; ``None`` disables it.
    max_dh : float
        Upper bound on |dh|.

    Returns
    -------
    inlier_mask : np.ndarray (bool, 2-D)
        Combined mask on the reference grid.
    component_masks : dict
        One boolean array per criterion (True = passes), under the keys
        "polygon" (only when a mask file is given), "slope", "dh", "nodata".

    Prints the number of selected pixels and exits via ``sys.exit`` if none
    remain.
    """
    grid_shape = ref.data.shape[-2:]
    components: dict[str, np.ndarray] = {}
    inliers = np.full(grid_shape, True)  # every pixel starts as an inlier

    # Criterion 1 (optional): user-supplied stable-terrain mask.
    if stable_mask_path is not None:
        import geoutils as gu

        suffix = os.path.splitext(stable_mask_path)[1].lower()
        is_vector = suffix in {".gpkg", ".shp", ".geojson", ".json", ".fgb", ".kml", ".gml"}
        if is_vector:
            stable = gu.Vector(stable_mask_path).create_mask(ref, as_array=True).squeeze()
        else:
            on_grid = gu.Raster(stable_mask_path).reproject(ref, resampling="nearest")
            # Masked (nodata) cells are filled with 0, i.e. "not stable"
            # before any inversion.
            stable = on_grid.data.filled(0).squeeze().astype(bool)
        components["polygon"] = ~stable if invert_mask else stable
        inliers &= components["polygon"]

    # Criterion 2: slope of the reference DEM (NaN slope never passes).
    slope_vals = ref.slope().data.filled(np.nan).squeeze()
    components["slope"] = np.logical_and(slope_vals >= min_slope, slope_vals <= max_slope)
    inliers &= components["slope"]

    # Criterion 3: absolute elevation difference.
    ref_vals = ref.data.filled(np.nan).squeeze()
    slv_vals = slave_repr.data.filled(np.nan).squeeze()
    abs_dh = np.abs(ref_vals - slv_vals)
    dh_ok = np.ones(grid_shape, dtype=bool)
    if min_dh is not None:
        dh_ok &= abs_dh >= min_dh
    dh_ok &= abs_dh <= max_dh
    components["dh"] = dh_ok
    inliers &= dh_ok

    # Criterion 4: finite data in both DEMs, applied explicitly so the count
    # printed below reflects the pixels that can actually be used.
    both_finite = np.isfinite(ref_vals) & np.isfinite(slv_vals)
    components["nodata"] = both_finite
    inliers &= both_finite

    selected = inliers.sum()
    total = inliers.size
    print(
        f"  Inlier mask: {selected:,} / {total:,} pixels "
        f"({100 * selected / total:.1f} %) selected as stable terrain."
    )
    if selected == 0:
        sys.exit(
            "Error: inlier mask contains no valid pixels. "
            "Relax the filtering thresholds or check the stable mask."
        )
    return inliers, components


# ---------------------------------------------------------------------------
# Offset extraction helpers
# ---------------------------------------------------------------------------

def get_translation_offsets(
    coreg,
) -> tuple[float | None, float | None, float | None]:
    """
    Extract the total (dx, dy, dz) translation in metres from a fitted coreg
    object, iterating over pipeline steps if necessary.

    Any component returns None when the method does not report that axis.
    """
    steps = list(getattr(coreg, "pipeline", [coreg]))
    dx_total, dy_total, dz_total = 0.0, 0.0, 0.0
    has_xy, has_z = False, False

    for step in steps:
        meta: dict = {}
        for attr in ("meta", "_meta"):
            candidate = getattr(step, attr, None)
            if isinstance(candidate, dict):
                meta = candidate
                break

        # Horizontal translation keys used by NuthKaab / ICP
        for xk, yk in (("shift_x", "shift_y"), ("tx", "ty")):
            if xk in meta and yk in meta:
                dx_total += float(meta[xk])
                dy_total += float(meta[yk])
                has_xy = True
                break

        # Vertical shift keys used by NuthKaab / VerticalShift / ICP
        for zk in ("shift_z", "bias", "tz"):
            if zk in meta:
                dz_total += float(meta[zk])
                has_z = True
                break

    return (
        dx_total if has_xy else None,
        dy_total if has_xy else None,
        dz_total if has_z else None,
    )


def _offset_suffix(
    dx: float | None, dy: float | None, dz: float | None
) -> str:
    """Return a filename suffix like '_x+5.0m_y-2.0m_z+3.0m', or '' if no offsets."""
    parts = []
    if dx is not None:
        parts.append(f"x{dx:+.1f}m")
    if dy is not None:
        parts.append(f"y{dy:+.1f}m")
    if dz is not None:
        parts.append(f"z{dz:+.1f}m")
    return ("_" + "_".join(parts)) if parts else ""


# ---------------------------------------------------------------------------
# Elevation-difference raster output
# ---------------------------------------------------------------------------

def save_dh_rasters(
    ref: xdem.DEM,
    slave_coreg: xdem.DEM,
    inlier_mask: np.ndarray,
    out_stem: str,
) -> None:
    """
    Write two float32 GeoTIFFs of dh = ref − slave_coreg on the grid of *ref*.

    ``<out_stem>_dh_align.tif``
        dh wherever both DEMs give a finite value.

    ``<out_stem>_dh_align_masked.tif``
        Same, additionally restricted to pixels where *inlier_mask* is True.

    *slave_coreg* is reprojected onto *ref* first.  Pixels without a value
    are written as -9999, which is also declared as the rasters' nodata.
    """
    import geoutils as gu

    fill = -9999.0
    aligned = slave_coreg.reproject(ref)
    dh = (
        ref.data.filled(np.nan).squeeze()
        - aligned.data.filled(np.nan).squeeze()
    ).astype(np.float32)
    has_dh = np.isfinite(dh)

    def _write(keep: np.ndarray, suffix: str) -> None:
        target = out_stem + suffix
        gu.Raster.from_array(
            data=np.where(keep, dh, fill).astype(np.float32),
            transform=ref.transform,
            crs=ref.crs,
            nodata=fill,
        ).save(target)
        print(f"  Saved: {target}")

    _write(has_dh, "_dh_align.tif")
    _write(inlier_mask & has_dh, "_dh_align_masked.tif")


# ---------------------------------------------------------------------------
# Plotting
# ---------------------------------------------------------------------------

def _mask_display(ax, data_mask, title, nodata_mask=None):
    """
    Render a binary mask on *ax*.

    Parameters
    ----------
    data_mask : 2-D bool array
        True = pixel passes this criterion (inlier / kept).
    nodata_mask : 2-D bool array or None
        True = pixel has valid data in both DEMs.  Pixels that are False
        here are drawn in a neutral grey regardless of data_mask.
    title : str
    """
    rgb = np.zeros((*data_mask.shape, 3), dtype=np.uint8)
    # Default: show as excluded (red-ish)
    rgb[..., 0] = 210
    rgb[..., 1] = 80
    rgb[..., 2] = 70
    # Kept pixels: green
    rgb[data_mask, 0] = 80
    rgb[data_mask, 1] = 170
    rgb[data_mask, 2] = 100
    # No-data pixels: neutral grey (overrides both)
    if nodata_mask is not None:
        grey_pixels = ~nodata_mask
        rgb[grey_pixels, 0] = 200
        rgb[grey_pixels, 1] = 200
        rgb[grey_pixels, 2] = 200

    ax.imshow(rgb, interpolation="nearest", aspect="equal")
    ax.set_title(title, fontsize=10)
    ax.axis("off")


def make_plots(
    ref: xdem.DEM,
    slave_repr: xdem.DEM,
    slave_coreg: xdem.DEM,
    inlier_mask: np.ndarray,
    component_masks: dict,
    output_stem: str,
) -> None:
    """
    Produce and save three diagnostic PNG files:

    ``<output_stem>_plot_overlap.png``
        DEM coverage map (ref-only / slave-only / overlap) and the signed
        elevation-difference (dh = ref − slave) over the overlapping area.

    ``<output_stem>_plot_masks.png``
        Individual masking components (polygon, slope, dh threshold) plus
        the combined inlier mask.

    ``<output_stem>_plot_stats.png``
        Histograms of dh on stable terrain before and after co-registration,
        annotated with median and NMAD, plus a bar chart comparing |median|
        and NMAD before vs. after.

    Parameters
    ----------
    ref : xdem.DEM
        Reference DEM; its grid is the one all arrays are compared on.
    slave_repr : xdem.DEM
        Slave DEM before co-registration.  Its array is subtracted from the
        reference array directly, so it must already be on the *ref* grid.
    slave_coreg : xdem.DEM
        Co-registered slave DEM; reprojected onto *ref* here before use.
    inlier_mask : np.ndarray
        Boolean stable-terrain mask on the *ref* grid; the statistics use only
        these pixels.
    component_masks : dict
        Per-criterion masks as produced by ``build_inlier_mask``; must contain
        "nodata", "slope" and "dh", and may contain "polygon".
    output_stem : str
        Path prefix for the PNG files.

    Also prints the before/after median and NMAD of dh to stdout.
    """
    # Plotting libraries are imported lazily so the script runs without them
    # unless --plots is requested.
    import matplotlib.pyplot as plt
    import matplotlib.patches as mpatches
    import matplotlib.colors as mcolors
    from xdem.spatialstats import nmad as xdem_nmad

    ref_arr = ref.data.filled(np.nan).squeeze()
    slv_arr = slave_repr.data.filled(np.nan).squeeze()

    # Reproject co-registered slave to the reference grid for dh-after
    slave_coreg_repr = slave_coreg.reproject(ref)
    slv_coreg_arr = slave_coreg_repr.data.filled(np.nan).squeeze()

    nodata_mask = component_masks["nodata"]  # True where both DEMs have data

    dh_before = ref_arr - slv_arr          # signed
    dh_after  = ref_arr - slv_coreg_arr

    # Stats on inlier pixels only
    dh_b = dh_before[inlier_mask & np.isfinite(dh_before)]
    dh_a = dh_after[inlier_mask & np.isfinite(dh_after)]

    # NOTE(review): if dh_a ends up empty (e.g. a very large shift leaves no
    # finite inlier pixels after co-registration), median/NMAD become NaN and
    # the histogram bins built in _hist are non-finite — confirm this cannot
    # happen, or guard it.
    med_b  = float(np.median(dh_b))
    nmad_b = float(xdem_nmad(dh_b))
    med_a  = float(np.median(dh_a))
    nmad_a = float(xdem_nmad(dh_a))

    print(f"  Before coreg  –  median dh = {med_b:+.3f} m,  NMAD = {nmad_b:.3f} m")
    print(f"  After  coreg  –  median dh = {med_a:+.3f} m,  NMAD = {nmad_a:.3f} m")

    # ------------------------------------------------------------------ #
    # Figure 1 – Overlap                                                   #
    # ------------------------------------------------------------------ #
    ref_valid = np.isfinite(ref_arr)
    slv_valid = np.isfinite(slv_arr)

    # 4-category coverage map: 0=nodata, 1=ref only, 2=slave only, 3=overlap
    cov = np.zeros(ref_arr.shape, dtype=np.uint8)
    cov[ref_valid & ~slv_valid] = 1
    cov[~ref_valid & slv_valid] = 2
    cov[ref_valid &  slv_valid] = 3

    # One colour per category, in category order (0..3).
    cmap_cov = mcolors.ListedColormap(
        ["#d0d0d0", "#4878CF", "#D65F5F", "#6ACC65"]
    )
    # Bin edges centred on the integers 0..3 so each value maps to one colour.
    norm_cov = mcolors.BoundaryNorm([-0.5, 0.5, 1.5, 2.5, 3.5], 4)

    # dh colourmap clipped to ±3 NMAD for display
    clip = max(3 * nmad_b, 1.0)
    dh_disp = np.where(nodata_mask, dh_before, np.nan)

    fig1, axes1 = plt.subplots(1, 2, figsize=(13, 5))

    # Left panel: coverage categories (the returned image handle is unused).
    im0 = axes1[0].imshow(cov, cmap=cmap_cov, norm=norm_cov,
                           interpolation="nearest", aspect="equal")
    axes1[0].set_title("DEM coverage", fontsize=11)
    axes1[0].axis("off")
    legend_patches = [
        mpatches.Patch(color="#4878CF", label="Reference only"),
        mpatches.Patch(color="#D65F5F", label="Slave only"),
        mpatches.Patch(color="#6ACC65", label="Overlap"),
        mpatches.Patch(color="#d0d0d0", label="No data"),
    ]
    axes1[0].legend(handles=legend_patches, loc="lower right",
                    fontsize=8, framealpha=0.8)

    # Right panel: signed dh before co-registration on a symmetric scale.
    im1 = axes1[1].imshow(dh_disp, cmap="RdBu_r",
                           vmin=-clip, vmax=clip,
                           interpolation="nearest", aspect="equal")
    axes1[1].set_title(f"dh = ref − slave  (clipped ±{clip:.1f} m)", fontsize=11)
    axes1[1].axis("off")
    fig1.colorbar(im1, ax=axes1[1], label="dh (m)", shrink=0.8)

    fig1.suptitle("DEM overlap", fontsize=13, fontweight="bold")
    fig1.tight_layout()
    path1 = output_stem + "_plot_overlap.png"
    fig1.savefig(path1, dpi=150, bbox_inches="tight")
    plt.close(fig1)
    print(f"  Saved: {path1}")

    # ------------------------------------------------------------------ #
    # Figure 2 – Masks                                                     #
    # ------------------------------------------------------------------ #
    # Determine which panels to draw
    panels = []
    if "polygon" in component_masks:
        panels.append(("polygon", "Polygon / stable-terrain mask"))
    panels.append(("slope", "Slope filter"))
    panels.append(("dh",    "dh threshold filter"))
    panels.append(("combined", "Combined inlier mask"))

    # Two columns, as many rows as needed (3 or 4 panels -> 2 rows).
    ncols = 2
    nrows = int(np.ceil(len(panels) / ncols))
    fig2, axes2 = plt.subplots(nrows, ncols,
                                figsize=(6 * ncols, 5 * nrows),
                                squeeze=False)

    for idx, (key, title) in enumerate(panels):
        ax = axes2[idx // ncols][idx % ncols]
        # "combined" is not stored in component_masks; use the final mask.
        arr = inlier_mask if key == "combined" else component_masks[key]
        _mask_display(ax, arr, title, nodata_mask=nodata_mask)

    # Legend (drawn once outside the subplots)
    # Colours match the RGB values hard-coded in _mask_display.
    kept_patch   = mpatches.Patch(color=(80/255, 170/255, 100/255), label="Kept (inlier)")
    excl_patch   = mpatches.Patch(color=(210/255, 80/255, 70/255),  label="Excluded")
    nodat_patch  = mpatches.Patch(color=(200/255, 200/255, 200/255), label="No data")
    fig2.legend(handles=[kept_patch, excl_patch, nodat_patch],
                loc="lower center", ncol=3, fontsize=10,
                bbox_to_anchor=(0.5, 0.0), framealpha=0.9)

    # Hide any unused axes
    for idx in range(len(panels), nrows * ncols):
        axes2[idx // ncols][idx % ncols].axis("off")

    fig2.suptitle("Masking components", fontsize=13, fontweight="bold")
    # Leave the bottom 5 % of the figure free for the shared legend.
    fig2.tight_layout(rect=[0, 0.05, 1, 1])
    path2 = output_stem + "_plot_masks.png"
    fig2.savefig(path2, dpi=150, bbox_inches="tight")
    plt.close(fig2)
    print(f"  Saved: {path2}")

    # ------------------------------------------------------------------ #
    # Figure 3 – Statistics (dh histograms + bar chart)                   #
    # ------------------------------------------------------------------ #
    fig3, axes3 = plt.subplots(1, 3, figsize=(15, 5))

    def _hist(ax, vals, med, nmad, label, color):
        # Symmetric x-range around 0 wide enough to show median ± 4 NMAD,
        # at least ±1 m.
        clip_h = max(abs(med) + 4 * nmad, 1.0)
        bins = np.linspace(-clip_h, clip_h, 80)
        ax.hist(vals, bins=bins, color=color, alpha=0.75, edgecolor="none",
                density=True, label=label)
        ax.axvline(med, color="k", lw=1.8, ls="-",  label=f"Median = {med:+.3f} m")
        ax.axvspan(med - nmad, med + nmad, alpha=0.18, color="k",
                   label=f"±NMAD  = {nmad:.3f} m")
        ax.set_xlabel("dh (m)", fontsize=10)
        ax.set_ylabel("Density", fontsize=10)
        ax.legend(fontsize=8)
        ax.set_title(label, fontsize=11)

    _hist(axes3[0], dh_b, med_b, nmad_b,
          "Before co-registration", "#4878CF")
    _hist(axes3[1], dh_a, med_a, nmad_a,
          "After co-registration",  "#6ACC65")

    # Bar chart: median and NMAD before vs after
    ax_bar = axes3[2]
    x = np.array([0.0, 1.0])
    width = 0.32
    bars_med  = ax_bar.bar(x - width / 2,
                           [abs(med_b),  abs(med_a)],
                           width, label="|Median|", color=["#4878CF", "#6ACC65"],
                           alpha=0.85, edgecolor="k", linewidth=0.7)
    bars_nmad = ax_bar.bar(x + width / 2,
                           [nmad_b, nmad_a],
                           width, label="NMAD", color=["#4878CF", "#6ACC65"],
                           alpha=0.50, edgecolor="k", linewidth=0.7, hatch="//")

    # Value labels just above each bar; the offset scales with the NMADs.
    for bar, val in zip(list(bars_med) + list(bars_nmad),
                        [abs(med_b), abs(med_a), nmad_b, nmad_a]):
        ax_bar.text(bar.get_x() + bar.get_width() / 2,
                    bar.get_height() + 0.005 * max(nmad_b, nmad_a, 0.01),
                    f"{val:.3f}", ha="center", va="bottom", fontsize=8)

    ax_bar.set_xticks(x)
    ax_bar.set_xticklabels(["Before", "After"], fontsize=11)
    ax_bar.set_ylabel("Elevation difference (m)", fontsize=10)
    ax_bar.set_title("Median & NMAD comparison", fontsize=11)

    # Custom legend for the bar chart
    # (neutral grey so the legend describes fill style, not before/after colour)
    solid_patch  = mpatches.Patch(facecolor="grey", edgecolor="k",
                                   alpha=0.85, label="|Median|")
    hatch_patch  = mpatches.Patch(facecolor="grey", edgecolor="k",
                                   alpha=0.50, hatch="//", label="NMAD")
    ax_bar.legend(handles=[solid_patch, hatch_patch], fontsize=9)

    fig3.suptitle(
        f"Co-registration statistics (on {inlier_mask.sum():,} stable-terrain pixels)",
        fontsize=12, fontweight="bold",
    )
    fig3.tight_layout()
    path3 = output_stem + "_plot_stats.png"
    fig3.savefig(path3, dpi=150, bbox_inches="tight")
    plt.close(fig3)
    print(f"  Saved: {path3}")


def main():
    """
    Command-line entry point.

    Loads the reference and slave DEMs and puts them on a common grid. It then
    builds the inlier mask, fits and applies the requested coregistration, and
    writes the aligned DEM, the dh rasters and (optionally) the diagnostic
    plots into the output folder.
    """
    parser = argparse.ArgumentParser(
        prog="xdemcoreg.py",
        description=(
            "Co-register a slave DEM to a reference DEM using xDEM.\n\n"
            "Supported single methods (use '+' to chain into a pipeline):\n"
            f"  {VALID_METHODS_STR}\n\n"
            "Pipeline example:  --method ICP+NuthKaab"
        ),
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )

    # --- Positional arguments ---
    parser.add_argument("reference", help="Path to the reference DEM (GeoTIFF).")
    parser.add_argument("slave", help="Path to the slave (to-be-aligned) DEM (GeoTIFF).")

    # --- Output ---
    parser.add_argument(
        "-o", "--output",
        default=None,
        help=(
            "Name of the output folder (created if necessary). "
            "All output files are written inside it, prefixed with the folder's "
            "basename.  Example: -o xdemcoreg_myr writes "
            "xdemcoreg_myr/xdemcoreg_myr_DEM_align_x+…_y+…_z+….tif, "
            "xdemcoreg_myr/xdemcoreg_myr_dh_align.tif, etc. "
            "Default: <slave_stem>_coreg/ next to the slave file."
        ),
    )

    # --- Coregistration method ---
    parser.add_argument(
        "--method",
        default="NuthKaab",
        metavar="METHOD[+METHOD...]",
        help=(
            "Coregistration method or '+'-chained pipeline. "
            f"Valid methods: {VALID_METHODS_STR}. "
            "Default: NuthKaab."
        ),
    )

    # --- Stable mask ---
    parser.add_argument(
        "--stable-mask",
        default=None,
        metavar="FILE",
        help=(
            "Path to a stable-terrain mask: raster (GeoTIFF, etc.) or vector "
            "(.gpkg, .shp, .geojson, .fgb, .kml, .gml). "
            "For rasters, non-zero pixels are treated as stable ground (inliers). "
            "For vectors, polygon interiors are treated as stable ground. "
            "If omitted, only slope and dh filters are applied."
        ),
    )
    parser.add_argument(
        "--invert-mask",
        action="store_true",
        help="Invert the stable mask (treat non-stable pixels as inliers).",
    )

    # --- Slope thresholds ---
    parser.add_argument(
        "--min-slope",
        type=float,
        default=0.1,
        metavar="DEG",
        help="Minimum slope threshold in degrees. Default: 0.1.",
    )
    parser.add_argument(
        "--max-slope",
        type=float,
        default=40.0,
        metavar="DEG",
        help="Maximum slope threshold in degrees. Default: 40.",
    )

    # --- Elevation-difference thresholds ---
    parser.add_argument(
        "--min-dh",
        type=float,
        default=None,
        metavar="M",
        help=(
            "Minimum absolute elevation difference (m) to include a pixel. "
            "Default: no lower bound."
        ),
    )
    parser.add_argument(
        "--max-dh",
        type=float,
        default=100.0,
        metavar="M",
        help="Maximum absolute elevation difference (m) to include a pixel. Default: 100.",
    )

    # --- Plots ---
    parser.add_argument(
        "--plots",
        action="store_true",
        help=(
            "Produce diagnostic plots saved inside the output folder:\n"
            "  <stem>_plot_overlap.png  – DEM coverage and dh map\n"
            "  <stem>_plot_masks.png    – masking components\n"
            "  <stem>_plot_stats.png    – median/NMAD before & after"
        ),
    )

    args = parser.parse_args()

    # --- Validate threshold logic ---
    if args.min_slope >= args.max_slope:
        parser.error(f"--min-slope ({args.min_slope}) must be less than --max-slope ({args.max_slope}).")
    if args.min_dh is not None and args.min_dh >= args.max_dh:
        parser.error(f"--min-dh ({args.min_dh}) must be less than --max-dh ({args.max_dh}).")

    # --- Output folder and file stem ---
    if args.output is None:
        # As documented in --help: the default folder sits next to the slave
        # file (not in the current working directory).
        slave_dir = os.path.dirname(args.slave)
        slave_base = os.path.splitext(os.path.basename(args.slave))[0]
        out_folder = os.path.join(slave_dir, slave_base + "_coreg")
    else:
        # Strip trailing separators so basename() yields the folder name;
        # keep the original string if stripping leaves nothing (e.g. "/").
        out_folder = args.output.rstrip("/\\") or args.output

    # abspath() resolves "." and ".." so the prefix is a real folder name
    # rather than e.g. "." (which would produce hidden files).
    out_name = os.path.basename(os.path.abspath(out_folder))
    os.makedirs(out_folder, exist_ok=True)
    out_stem = os.path.join(out_folder, out_name)

    # --- Load DEMs ---
    print(f"Loading reference DEM: {args.reference}")
    ref = xdem.DEM(args.reference)

    print(f"Loading slave DEM:     {args.slave}")
    slave = xdem.DEM(args.slave)

    # --- Check grids and prepare a common overlap grid ---
    ref_work, slv_work = check_and_prepare_grids(ref, slave)

    # --- Build inlier mask ---
    print("Building inlier mask …")
    inlier_mask, component_masks = build_inlier_mask(
        ref=ref_work,
        slave_repr=slv_work,
        stable_mask_path=args.stable_mask,
        invert_mask=args.invert_mask,
        min_slope=args.min_slope,
        max_slope=args.max_slope,
        min_dh=args.min_dh,
        max_dh=args.max_dh,
    )

    # --- Build coregistration object ---
    print(f"Building coregistration pipeline: {args.method}")
    coreg = build_coreg(args.method)

    # --- Fit ---
    print("Fitting coregistration …")
    coreg.fit(
        reference_elev=ref_work,
        to_be_aligned_elev=slv_work,
        inlier_mask=inlier_mask,
    )

    # --- Apply ---
    print("Applying coregistration to slave DEM …")
    slave_coreg = coreg.apply(slv_work)

    # --- Extract translation offsets for the filename ---
    dx, dy, dz = get_translation_offsets(coreg)
    off_sfx = _offset_suffix(dx, dy, dz)
    if off_sfx:
        labels = [
            f"{name}={value:+.1f} m"
            for name, value in (("dx", dx), ("dy", dy), ("dz", dz))
            if value is not None
        ]
        print(f"  Detected offsets: {', '.join(labels)}")

    # --- Save aligned DEM ---
    dem_path = out_stem + f"_DEM_align{off_sfx}.tif"
    print(f"Saving co-registered DEM to: {dem_path}")
    slave_coreg.save(dem_path)

    # --- Save elevation-difference rasters ---
    print("Saving elevation-difference rasters …")
    save_dh_rasters(ref_work, slave_coreg, inlier_mask, out_stem)

    # --- Plots (optional) ---
    if args.plots:
        print("Generating diagnostic plots …")
        make_plots(
            ref=ref_work,
            slave_repr=slv_work,
            slave_coreg=slave_coreg,
            inlier_mask=inlier_mask,
            component_masks=component_masks,
            output_stem=out_stem,
        )

    print(f"\nAll outputs written to: {out_folder}/")


if __name__ == "__main__":
    main()
