#!/usr/bin/env python3
"""ka-promote — resolve fleet/<host>.yaml + emit cumulative.patch + manifest.lock.

First of the three writing verbs (ka-promote → ka-build → ka-install).
Read-mostly: only writes to <out>/<host>/<baseline_ref>/, where <out> is
--output-dir if given, else ${KA_BUILD_DIR:-<repo>/build}.

Usage:
    ka-promote <host>
    ka-promote <host> --output-dir <path>
    ka-promote <host> --validate-against <linux-checkout>
    ka-promote --list-hosts
    ka-promote --version

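Exit codes:
    0  success
    1  other errors (e.g. an uncaught exception)
    2  missing input (manifest, patch file, series-dir); also argparse usage errors
    3  --validate-against failed (checkout or ref problems, HEAD/ref tree
       mismatch, dirty checkout, or apply-check failure)
    4  manifest parse / schema error; also fleet/ not found next to the script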

Language note: pure python3 (not bash like ka-status). The data shape
here — YAML in, YAML out, dict construction, per-file hashing, glob
resolution — fits python naturally; bash + python -c heredocs would be
quoting hell for no readability gain. See issue #22 comment 1132.
"""

import argparse
import glob
import hashlib
import os
import re
import subprocess
import sys
from datetime import datetime, timezone

import yaml

VERSION = 1
SCHEMA_VERSION = 1
COVER_LETTER = "0000-cover-letter.patch"

# git format-patch signature trailer at EOF: "\n-- \n<version line>",
# possibly followed by newline(s). The version line starts with
# <MAJOR>.<MINOR> and may carry a patch level and a vendor suffix, e.g.
# "2.54.0", "2.39.3 (Apple Git-145)" or "2.45.1.windows.1". Requiring a
# leading digit keeps this from matching a removed hunk line "-- " that is
# followed by more hunk lines (those start with ' ', '+', '-' or '\').
# Strip from each source patch so that the cumulative is always
# well-formed regardless of include order. See issue #31.
_TRAILER_RE = re.compile(rb'\n-- \n\d+\.\d+[^\n]*\n*\Z')

# Canonical separator emitted between concatenated patches in the
# cumulative. Trailing blank line keeps patch(1) happy when the next
# patch starts with "From <sha>".
_CANONICAL_TRAILER = b'-- \n2.54.0\n\n'


def die(msg, code=1):
    """Report a fatal error on stderr and terminate with exit status *code*.

    Never returns: always raises SystemExit.
    """
    print("ka-promote: error:", msg, file=sys.stderr)
    raise SystemExit(code)


def find_repo_root():
    """Return the repository root: the parent of the directory holding this script.

    Dies with exit code 4 when that directory has no fleet/ subdirectory.
    """
    script_dir = os.path.dirname(os.path.abspath(__file__))
    candidate = os.path.dirname(script_dir)
    if os.path.isdir(os.path.join(candidate, "fleet")):
        return candidate
    die(f"fleet/ not found relative to {script_dir}", 4)


def list_hosts(fleet_dir):
    """Print each fleet host name (manifest basename minus ".yaml"), one per line.

    Hosts are emitted in the sorted order of their manifest paths.
    """
    manifests = sorted(glob.glob(os.path.join(fleet_dir, "*.yaml")))
    hosts = [os.path.basename(m)[:-len(".yaml")] for m in manifests]
    for host in hosts:
        print(host)


def load_manifest(path):
    """Read, hash, parse and schema-check a fleet host manifest.

    Returns (manifest_dict, sha256_hex) where the hash covers the exact raw
    bytes on disk, so any edit (including comments) changes it.

    Exits with code 2 when the file is missing or unreadable, and with code 4
    on a YAML parse error, a non-mapping root, a missing required key
    (host, baseline, includes), or an includes value that is not a non-empty
    list.
    """
    try:
        # Context manager so the handle is closed deterministically rather
        # than whenever the garbage collector gets to it.
        with open(path, "rb") as f:
            raw = f.read()
    except FileNotFoundError:
        die(f"manifest not found: {path}", 2)
    except OSError as e:
        # e.g. path is a directory or not readable: still missing input,
        # reported cleanly instead of as a traceback.
        die(f"cannot read manifest {path}: {e.strerror or e}", 2)
    sha = hashlib.sha256(raw).hexdigest()
    try:
        # safe_load: manifests must not be able to construct arbitrary objects.
        m = yaml.safe_load(raw)
    except yaml.YAMLError as e:
        die(f"manifest parse error: {e}", 4)
    if not isinstance(m, dict):
        die(f"manifest root must be a mapping: {path}", 4)
    for key in ("host", "baseline", "includes"):
        if key not in m:
            die(f"manifest missing required key '{key}': {path}", 4)
    if not isinstance(m["includes"], list) or not m["includes"]:
        die(f"manifest.includes must be a non-empty list: {path}", 4)
    return m, sha


def resolve_includes(includes, patches_root):
    """Expand manifest.includes into an ordered list of patch records.

    Each entry is relative to patches_root and is either a single
    "<name>.patch" file or a "<dir>/" series-dir. A series-dir expands to
    the directory's *.patch files in sorted order, excluding COVER_LETTER.

    Returns a list of dicts with keys apply_order (1-based), include, src,
    from_series, sha256 and size.

    Dies with code 4 on a non-string entry, an entry that does not end in
    '.patch' or '/', or a duplicate. A duplicate is either a repeated entry
    or two includes that resolve to the same file, e.g. a series-dir plus
    one of its patches listed explicitly, or "x.patch" vs "./x.patch".
    Applying the same patch twice would only fail later, at apply time.
    Dies with code 2 on a missing patch or series-dir, or on a series-dir
    with no applicable patches.
    """
    seen_entries = set()
    # realpath of each resolved file -> the include name that claimed it first
    seen_files = {}
    resolved = []

    # Append one resolved patch, rejecting files already claimed by an
    # earlier include.
    def add(include, src, from_series):
        key = os.path.realpath(src)
        if key in seen_files:
            die(f"duplicate include: {include} (same file as {seen_files[key]})", 4)
        seen_files[key] = include
        resolved.append({
            "apply_order": len(resolved) + 1,
            "include": include,
            "src": src,
            "from_series": from_series,
        })

    for entry in includes:
        if not isinstance(entry, str):
            die(f"includes entry must be a string, got {type(entry).__name__}: {entry!r}", 4)
        if entry in seen_entries:
            die(f"duplicate include: {entry}", 4)
        seen_entries.add(entry)
        src_path = os.path.join(patches_root, entry)
        if entry.endswith(".patch"):
            if not os.path.isfile(src_path):
                die(f"missing patch: {src_path}", 2)
            add(entry, src_path, False)
        elif entry.endswith("/"):
            dir_path = src_path.rstrip("/")
            if not os.path.isdir(dir_path):
                die(f"missing series-dir: {dir_path}", 2)
            # Escape the directory part so '[', '*' or '?' in a real dir name
            # are matched literally; only the "*.patch" suffix is a pattern.
            pattern = os.path.join(glob.escape(dir_path), "*.patch")
            files = [f for f in sorted(glob.glob(pattern))
                     if os.path.basename(f) != COVER_LETTER]
            if not files:
                die(f"series-dir has no applied patches (only cover-letter or empty): {dir_path}", 2)
            for f in files:
                add(entry + os.path.basename(f), f, True)
        else:
            die(f"include must end in '.patch' or '/': {entry}", 4)

    for r in resolved:
        with open(r["src"], "rb") as fh:
            data = fh.read()
        r["sha256"] = hashlib.sha256(data).hexdigest()
        r["size"] = len(data)
    return resolved


def strip_trailer(data):
    """Remove a trailing git format-patch signature and normalise the ending.

    Input may arrive with the sentinel (e.g. "...\n-- \n2.54.0\n\n") or
    without it ("...\n"). Whatever _TRAILER_RE matches at the end is
    replaced by a single newline. If the result still lacks a final
    newline, one is appended. The caller can then add the canonical
    trailer (mid-cumulative) or leave the patch bare (last one).
    """
    body = _TRAILER_RE.sub(b'\n', data)
    return body if body.endswith(b'\n') else body + b'\n'


def write_cumulative(resolved, out_path):
    """Concatenate the resolved patches into a single cumulative patch file.

    Each source passes through strip_trailer(). Every patch except the last
    is followed by _CANONICAL_TRAILER. Output goes to "<out_path>.tmp" and
    is renamed over out_path only once complete, so a failed run never
    leaves a truncated cumulative.patch behind.

    Returns (size_in_bytes, blake2b_hexdigest) of the bytes written. The
    digest is hashlib's default 64-byte blake2b, computed while writing
    rather than by re-reading the finished file.
    """
    digest = hashlib.blake2b()
    size = 0
    last = len(resolved) - 1
    tmp_path = out_path + ".tmp"
    try:
        with open(tmp_path, "wb") as out:
            for i, r in enumerate(resolved):
                with open(r["src"], "rb") as src:
                    chunk = strip_trailer(src.read())
                # Mid-cumulative patches need a separator so patch(1) knows
                # where they end and the next "From <sha>" begins. Last
                # patch stays bare — a trailing orphan sentinel reads as
                # the start of a malformed new patch at EOF (issue #31).
                if i != last:
                    chunk += _CANONICAL_TRAILER
                out.write(chunk)
                digest.update(chunk)
                size += len(chunk)
        # Atomic on POSIX when source and target share a filesystem
        # (same directory here).
        os.replace(tmp_path, out_path)
    finally:
        # Only still present if something above failed before the rename.
        if os.path.exists(tmp_path):
            os.unlink(tmp_path)
    return size, digest.hexdigest()


def write_lock(lock_path, *, host, manifest_rel, manifest_sha, baseline,
               resolved, cumulative_size, cumulative_b2sum):
    """Write manifest.lock: a YAML record of the inputs and output of this run.

    generated_at honours SOURCE_DATE_EPOCH (integer seconds since the Unix
    epoch) for reproducible output, otherwise uses the current UTC time. A
    malformed SOURCE_DATE_EPOCH is a clean fatal error rather than a
    traceback. Mapping keys are dumped sorted; resolved_patches keeps the
    order of *resolved*.
    """
    epoch = os.environ.get("SOURCE_DATE_EPOCH")
    if epoch:
        try:
            generated_at = datetime.fromtimestamp(int(epoch), tz=timezone.utc).isoformat()
        except (ValueError, OverflowError, OSError):
            # int() rejects non-integers; fromtimestamp rejects out-of-range values.
            die(f"invalid SOURCE_DATE_EPOCH {epoch!r}: expected integer seconds since the Unix epoch")
    else:
        generated_at = datetime.now(tz=timezone.utc).isoformat()
    lock = {
        "ka_promote_version": VERSION,
        "schema_version": SCHEMA_VERSION,
        "generated_at": generated_at,
        "host": host,
        "manifest": {"path": manifest_rel, "sha256": manifest_sha},
        "baseline": baseline,
        "resolved_patches": [
            {
                "apply_order": r["apply_order"],
                "include": r["include"],
                "sha256": r["sha256"],
                "size": r["size"],
                "from_series": r["from_series"],
            }
            for r in resolved
        ],
        "cumulative": {
            "path": "cumulative.patch",
            "size": cumulative_size,
            "b2sum": cumulative_b2sum,
        },
    }
    # Explicit encoding so the output does not depend on the locale.
    with open(lock_path, "w", encoding="utf-8") as f:
        yaml.dump(lock, f, sort_keys=True, default_flow_style=False)


def validate_against(checkout, baseline_ref, cumulative_path):
    """Dry-run the cumulative patch against a clean checkout of baseline_ref.

    Dies with exit code 3 unless all of these hold:
      - checkout is a git checkout;
      - baseline_ref resolves there;
      - HEAD's tree equals baseline_ref's tree;
      - `git status --porcelain` reports nothing;
      - `git apply --check` accepts the patch.
    Returns None on success.
    """
    # `.git` is a directory in a plain checkout, a file (gitdir pointer)
    # in a worktree. `os.path.exists` covers both.
    if not os.path.exists(os.path.join(checkout, ".git")):
        die(f"--validate-against: not a git checkout: {checkout}", 3)
    # git runs with cwd=checkout, so a relative cumulative path (e.g. from
    # KA_BUILD_DIR=./build) would otherwise be resolved inside the checkout.
    patch_path = os.path.abspath(cumulative_path)

    def git(*args):
        return subprocess.run(
            ["git", *args], cwd=checkout, capture_output=True, text=True
        )

    r = git("rev-parse", f"{baseline_ref}^{{tree}}")
    if r.returncode != 0:
        die(f"baseline ref '{baseline_ref}' not found in checkout {checkout}", 3)
    baseline_tree = r.stdout.strip()
    r = git("rev-parse", "HEAD^{tree}")
    if r.returncode != 0:
        die(f"cannot resolve HEAD in checkout {checkout}:\n{r.stderr}", 3)
    head_tree = r.stdout.strip()
    if head_tree != baseline_tree:
        die(f"checkout HEAD tree {head_tree} != baseline.ref {baseline_ref} tree {baseline_tree}. "
            "Refusing apply-check on diverged tree.", 3)
    # Working tree must match HEAD too — `git apply --check` runs against
    # the working tree, not HEAD, so a dirty tree gives false negatives.
    r = git("status", "--porcelain")
    if r.returncode != 0:
        # Empty stdout from a failed status must not count as "clean".
        die(f"git status failed in checkout {checkout}:\n{r.stderr}", 3)
    if r.stdout.strip():
        # Single f-string: the old f-string + .format() mix broke on
        # checkout paths containing '{' or '}'.
        die(f"checkout {checkout} has uncommitted changes. "
            f"`git reset --hard {baseline_ref} && git clean -fdx` first.", 3)
    r = git("apply", "--check", patch_path)
    if r.returncode != 0:
        die(f"git apply --check failed:\n{r.stderr}", 3)


def main():
    """CLI entry point: parse arguments, resolve one host, write its artifacts.

    Returns 0 on success. Failures exit through die() or argparse's own
    usage error, with the codes listed in the module docstring.
    """
    p = argparse.ArgumentParser(prog="ka-promote", add_help=True)
    p.add_argument("host", nargs="?", help="fleet host name (omit with --list-hosts/--version)")
    p.add_argument("--output-dir", help="override ${KA_BUILD_DIR:-<repo>/build}")
    p.add_argument("--validate-against", metavar="CHECKOUT",
                   help="run git apply --check against a clean baseline.ref checkout")
    p.add_argument("--list-hosts", action="store_true", help="list available fleet/<host>.yaml manifests")
    p.add_argument("--version", action="store_true", help="print ka-promote schema version + exit")
    args = p.parse_args()

    # --version needs nothing from the repo; answer it before
    # find_repo_root() can die on a missing fleet/ directory.
    if args.version:
        print(f"ka-promote version {VERSION} (schema {SCHEMA_VERSION})")
        return 0

    repo_root = find_repo_root()
    fleet_dir = os.path.join(repo_root, "fleet")
    patches_root = os.path.join(repo_root, "patches")

    if args.list_hosts:
        list_hosts(fleet_dir)
        return 0
    if not args.host:
        p.error("host is required (or use --list-hosts / --version)")

    manifest_path = os.path.join(fleet_dir, f"{args.host}.yaml")
    manifest, manifest_sha = load_manifest(manifest_path)

    if manifest.get("host") != args.host:
        die(f"manifest.host {manifest.get('host')!r} does not match filename {args.host!r}", 4)

    baseline = manifest["baseline"]
    # On a non-mapping baseline (e.g. a bare string), `"ref" in baseline`
    # would be a substring test and baseline["ref"] a TypeError.
    if not isinstance(baseline, dict) or "ref" not in baseline:
        die("manifest.baseline.ref is required", 4)
    baseline_ref = baseline["ref"]
    # baseline_ref becomes a path component of out_dir. YAML can yield a
    # non-string (`ref: 6.8` parses as a float), and an absolute path or a
    # '..' segment would let os.path.join place output outside the build root.
    if not isinstance(baseline_ref, str) or not baseline_ref:
        die(f"manifest.baseline.ref must be a non-empty string, got {baseline_ref!r}", 4)
    if os.path.isabs(baseline_ref) or ".." in baseline_ref.split("/"):
        die(f"manifest.baseline.ref must be relative and free of '..': {baseline_ref!r}", 4)

    resolved = resolve_includes(manifest["includes"], patches_root)

    out_root = args.output_dir or os.environ.get("KA_BUILD_DIR") or os.path.join(repo_root, "build")
    out_dir = os.path.join(out_root, args.host, baseline_ref)
    os.makedirs(out_dir, exist_ok=True)
    cumulative_path = os.path.join(out_dir, "cumulative.patch")

    size, b2sum = write_cumulative(resolved, cumulative_path)
    write_lock(
        os.path.join(out_dir, "manifest.lock"),
        host=args.host,
        manifest_rel=os.path.relpath(manifest_path, repo_root),
        manifest_sha=manifest_sha,
        baseline=baseline,
        resolved=resolved,
        cumulative_size=size,
        cumulative_b2sum=b2sum,
    )

    # Artifacts are written first so they remain inspectable if validation fails.
    if args.validate_against:
        validate_against(args.validate_against, baseline_ref, cumulative_path)

    series_count = sum(1 for r in resolved if r['from_series'])
    print(f"ka-promote: {args.host} -> {out_dir}")
    print(f"  cumulative: cumulative.patch ({size} bytes)")
    print(f"  b2sum:      {b2sum}")
    print(f"  patches:    {len(resolved)} resolved ({series_count} from series-dirs)")
    return 0


if __name__ == "__main__":
    # main()'s return value becomes the process exit status.
    raise SystemExit(main())
