#!/usr/bin/env python3
# Author:  Michael Heflin (original)
# Modified to include loading corrections (NTAL, NTOL, CDEC, SLEL, GRACE) 03/16/2026 MJS
import os
import argparse
import numpy as np
import time
import calendar

# Help text passed to argparse by _getParser() (description and epilog).
prolog="""
**PROGRAM**
    staBreakG.py

**PURPOSE**
    | Search for breaks using an Ftest
    | All possible break locations are tested
    | Ftest is applied to break which minimizes CHI^2/DOF
    | If Ftest is passed that break is kept
    | Search continues until no remaining break location passes Ftest
    | Optionally removes loading deformation (NTAL, NTOL, CDEC, SLEL, CMCF, GRACE)

**OUTPUT**
    | One line per accepted break, sorted by time:
    |   SITE YYYY MM DD HH MM SS F
    | Any existing output file is removed; none is written if no break is found
"""
# The multi-line example uses shell line continuations ("\\" in the Python
# source renders as a single backslash) so it can be pasted into a shell.
epilog="""
**EXAMPLE**
    staBreakG.py -i ALGO.series -o ALGO.break --ftest 150
    staBreakG.py -i ALGO.series -o ALGO.break --ftest 150 \\
        --ntal ntal.cf/algo --ntol ntol.cf/algo \\
        --cdec cdec.cf/algo --slel slel.cf/algo \\
        --cmcf cmcf/algo --grace mascon.cf/algo
"""

def _getParser():
    """
    Build the command-line parser for staBreakG.py.

    Returns
    -------
    argparse.ArgumentParser
        Parser producing a namespace with ``series`` and ``output`` (required
        paths), ``ftest`` (float threshold, default 150.0) and one optional
        path per loading model: ``ntal``, ``ntol``, ``cdec``, ``slel``,
        ``cmcf``, ``grace`` (None when not given).
    """
    parser = argparse.ArgumentParser(description=prolog, epilog=epilog,
                                     formatter_class=argparse.RawDescriptionHelpFormatter)
    parser.add_argument('-i', action='store', dest='series', required=True,
                        help='input .series file')
    parser.add_argument('-o', action='store', dest='output', required=True,
                        help='output break file')
    # type=float makes argparse reject a non-numeric threshold up front with a
    # usage error; argparse also converts the string default through type.
    parser.add_argument('--ftest', action='store', dest='ftest', default='150',
                        type=float,
                        help='ftest threshold value (default: 150)')

    # Optional loading correction files (option name == dest name).
    loading_options = (
        ('ntal', 'NTAL loading displacement file'),
        ('ntol', 'NTOL loading displacement file'),
        ('cdec', 'CDEC (N. America lakes) loading displacement file'),
        ('slel', 'SLEL (barystatic ocean mass) loading displacement file'),
        ('cmcf', 'CMCF (seasonal geocenter) loading displacement file'),
        ('grace', 'GRACE loading file'),
    )
    for name, description in loading_options:
        parser.add_argument('--' + name, action='store', dest=name, default=None,
                            help=description)
    return parser


def read_disp_file(filepath):
    """
    Read an NGL-format loading displacement file.

    Expected whitespace-separated columns:
        site  YYMMMDD  decyr  east(m)  north(m)  up(m)
    The first line is treated as a header and always skipped.  Blank lines,
    lines starting with '#', lines with fewer than six columns, and lines
    whose date or displacement fields fail to parse are ignored.

    Parameters
    ----------
    filepath : str
        Path of the displacement file.

    Returns
    -------
    dates : ndarray of int
        YYYYMMDD integer keys derived from the YYMMMDD field.
    east, north, up : ndarray of float
        Displacements (metres), element-wise aligned with ``dates``.
    """
    dates, east, north, up = [], [], [], []
    with open(filepath, 'r') as f:
        f.readline()  # skip header line
        for line in f:
            line = line.strip()
            if not line or line.startswith('#'):
                continue
            cols = line.split()
            if len(cols) < 6:
                continue
            # Parse every field before appending anything: a malformed value
            # must not leave the four output lists with different lengths.
            try:
                parsed = time.strptime(cols[1], "%y%b%d")
                date_int = int(time.strftime("%Y%m%d", parsed))
                e_val = float(cols[3])
                n_val = float(cols[4])
                u_val = float(cols[5])
            except ValueError:
                continue
            dates.append(date_int)
            east.append(e_val)
            north.append(n_val)
            up.append(u_val)
    # Explicit dtypes keep the date array integer-typed even for empty files.
    return (np.array(dates, dtype=int), np.array(east, dtype=float),
            np.array(north, dtype=float), np.array(up, dtype=float))


def apply_loading_corrections(obs_dates, obs_n, obs_e, obs_u, loading_files):
    """
    Subtract loading displacement corrections from an observed time series.

    Only epochs whose date appears in the observations AND in every supplied
    loading file are kept.  The corrections from all files are summed and
    subtracted component-wise.

    Parameters
    ----------
    obs_dates : sequence of int
        Observation dates as YYYYMMDD integers.
    obs_n, obs_e, obs_u : sequence of float
        Observed north, east and up values (metres), aligned with obs_dates.
    loading_files : sequence of str or None
        Paths readable by read_disp_file (NTAL/NTOL/CDEC/SLEL/CMCF/GRACE);
        None entries are skipped.

    Returns
    -------
    dates, north, east, up : ndarray
        Observations restricted to the common dates (original order kept),
        with the summed loading corrections subtracted.  Every occurrence of
        a repeated observation date is corrected; if a loading file repeats a
        date, its last row for that date is used (once).

    Raises
    ------
    ValueError
        If no date is common to the observations and all loading files.
    """
    obs_dates = np.array(obs_dates, dtype=int)
    obs_n = np.array(obs_n, dtype=float)
    obs_e = np.array(obs_e, dtype=float)
    obs_u = np.array(obs_u, dtype=float)

    # read_disp_file returns (dates, east, north, up).
    loaded = [read_disp_file(fpath) for fpath in loading_files if fpath is not None]

    # Intersect the observed dates with the dates of every loading dataset.
    common_set = set(obs_dates.tolist())
    for dates, _east, _north, _up in loaded:
        common_set &= set(dates.tolist())

    if not common_set:
        raise ValueError("No common dates found across all loading files and the series file. "
                         "Check that date formats match (YYYYMMDD integer).")

    # Filter observations to common dates; dtype=bool keeps the mask valid
    # as a boolean index in every case.
    obs_mask = np.array([d in common_set for d in obs_dates.tolist()], dtype=bool)
    obs_dates = obs_dates[obs_mask]
    obs_n = obs_n[obs_mask]
    obs_e = obs_e[obs_mask]
    obs_u = obs_u[obs_mask]
    obs_keys = obs_dates.tolist()

    corr_n = np.zeros(len(obs_dates))
    corr_e = np.zeros(len(obs_dates))
    corr_u = np.zeros(len(obs_dates))

    for dates, east, north, up in loaded:
        # Row of this dataset for each date (last row wins on repeats).  Every
        # remaining observation date is in common_set, so the lookup is total.
        row_of = {dk: k for k, dk in enumerate(dates.tolist())}
        rows = np.array([row_of[dk] for dk in obs_keys], dtype=int)
        corr_n += north[rows]
        corr_e += east[rows]
        corr_u += up[rows]

    return obs_dates, obs_n - corr_n, obs_e - corr_e, obs_u - corr_u


def _read_series(path):
    """
    Read an NGL tenv3-style .series file.

    The first line is skipped as a header; lines with fewer than 13
    whitespace-separated columns are ignored.  Parse errors in the remaining
    lines propagate (ValueError).

    NOTE(review): only columns 8/10/12 are read for east/north/up.  Columns
    7/9/11 are not used; if they carry an integer-metre part of the
    coordinates (as in NGL tenv3), confirm it is constant over the series.

    Returns
    -------
    T : list of float       column 2 (decimal year)
    D : list of float       (MJD in column 3 - 51544) * 86400, i.e. seconds
                            after MJD 51544
    N, E, V : list of float columns 10, 8 and 12 (north, east, up)
    DATES : list of int     YYYYMMDD from the YYMMMDD field in column 1
    """
    T, D, N, E, V, DATES = [], [], [], [], [], []
    with open(path, 'r') as series_file:
        series_file.readline()  # skip header
        for line in series_file:
            cols = line.split()
            if len(cols) < 13:
                continue
            T.append(float(cols[2]))
            N.append(float(cols[10]))
            E.append(float(cols[8]))
            V.append(float(cols[12]))
            D.append((float(cols[3]) - 51544) * 86400)
            parsed = time.strptime(cols[1], "%y%b%d")
            DATES.append(int(time.strftime("%Y%m%d", parsed)))
    return T, D, N, E, V, DATES


def _weighted_chi2(A, B):
    """
    Least-squares fit of A @ x ~= B and return its weighted residual sum.

    The weight is north + east + up/4 over the three columns of B (the up
    term is down-weighted, presumably because vertical scatter is larger).
    Returns None when numpy.linalg.lstsq reports no residuals, which happens
    when A has no more rows than columns or is rank-deficient.
    """
    resid = np.linalg.lstsq(A, B, rcond=-1)[1]
    if resid.size < 3:
        return None
    return resid[0] + resid[1] + resid[2] / 4


def _search_breaks(T, N, E, V, ftest):
    """
    Greedy F-test search for step offsets shared by north, east and up.

    The base model is intercept + linear trend in (T - 2015).  Each pass adds
    one step column that is 1 for epochs 0..i and 0 afterwards, tries every
    admissible i, and keeps the candidate with the largest
        F = ((chi2_old - chi2_new) / chi2_new) * ((ndat - p1) / (p1 - p0)).
    If that F exceeds ``ftest`` the step is accepted and the search repeats;
    otherwise it stops.  The last epoch and already-accepted indices are not
    candidates (their step column would duplicate an existing column).

    Returns
    -------
    list of (int, float)
        (i, F) for each accepted break in acceptance order; the step lies
        between epochs i and i + 1.  Empty when there are too few epochs.
    """
    ndat = len(T)
    p0 = 2
    if ndat <= p0:
        return []  # no degrees of freedom for even the base model

    A = np.zeros((ndat, 2))
    A[:, 0] = 1.0
    A[:, 1] = np.array(T, dtype=float) - 2015.0
    B = np.column_stack((np.array(N, dtype=float),
                         np.array(E, dtype=float),
                         np.array(V, dtype=float)))

    c0 = _weighted_chi2(A, B)
    if c0 is None:
        return []

    excluded = {ndat - 1}
    breaks = []
    p1 = p0 + 1
    # With p1 >= ndat the fit has no residual degrees of freedom left.
    while p1 < ndat:
        A = np.hstack((A, np.zeros((ndat, 1))))
        fmax = 0
        best = None
        for i in range(ndat):
            # Grow the step column cumulatively: rows 0..i are now 1.
            A[i, p1 - 1] = 1
            if i in excluded:
                continue
            c1 = _weighted_chi2(A, B)
            if c1 is None:
                continue
            f = ((c0 - c1) / c1) * ((ndat - p1) / (p1 - p0))
            if f > fmax:
                fmax = f
                best = (i, c1, np.copy(A))
        # "not >" (rather than "<=") also stops the search for a NaN threshold.
        if best is None or not (fmax > ftest):
            break
        imax, c0, A = best
        breaks.append((imax, fmax))
        excluded.add(imax)
        p0 = p1
        p1 += 1
    return breaks


def main(args=None):
    """
    Command-line entry point: detect offsets in a .series file.

    Reads the series, optionally subtracts loading displacements, runs the
    greedy F-test break search, and writes one line per accepted break
    (``SITE YYYY MM DD HH MM SS F``, sorted by time) to the output file.
    Any pre-existing output file is removed first; when no break is accepted
    no output file is written.

    Parameters
    ----------
    args : list of str, optional
        Command-line arguments; None means argparse reads sys.argv.
    """
    parser = _getParser()
    results = parser.parse_args(args)
    ftest = float(results.ftest)  # fail fast on a bad threshold

    T, D, N, E, V, DATES = _read_series(results.series)

    loading_files = [results.ntal, results.ntol, results.cdec,
                     results.slel, results.cmcf, results.grace]
    if any(f is not None for f in loading_files):
        DATES_arr, N_corr, E_corr, V_corr = apply_loading_corrections(
            DATES, N, E, V, loading_files)
        # apply_loading_corrections keeps observation order, so selecting the
        # same dates here keeps T and D aligned with the corrected values.
        date_set = set(DATES_arr.tolist())
        keep = [i for i, d in enumerate(DATES) if d in date_set]
        T = [T[i] for i in keep]
        D = [D[i] for i in keep]
        N = N_corr.tolist()
        E = E_corr.tolist()
        V = V_corr.tolist()

    # Each break is reported at the epoch just after the step.
    found = sorted(((D[i + 1], fval) for i, fval in _search_breaks(T, N, E, V, ftest)),
                   key=lambda pair: pair[0])

    if os.path.exists(results.output):
        os.remove(results.output)
    if not found:
        return

    # D values are offsets in seconds; they are rendered relative to
    # 2000-01-01 12:00:00 UTC, as in the original output format.
    ref_epoch = calendar.timegm(
        time.strptime("2000JAN01 12:00:00", "%Y%b%d %H:%M:%S"))
    # Site name = file name up to its first '.', without any directory part.
    site = os.path.basename(results.series).split('.', 1)[0]
    with open(results.output, 'w') as outFile:
        for seconds, fval in found:
            stamp = time.strftime("%Y %m %d %H %M %S", time.gmtime(ref_epoch + seconds))
            print("{:s} {:s} {:.2f}".format(site, stamp, fval), file=outFile)


if __name__ == '__main__':
    # Script entry point: main() with no arguments lets argparse read sys.argv.
    main()
