"""Import 清远市 / 县区 部门清单 from Excel into `local_department`.

Source files live in `清远市及县区政府组成部门和乡镇/*.xlsx` at the repo root.
Each sheet has header row `DEPTID, DEPTNAME, SHORTNAME` (连山 has an extra
broken SHORTNAME column we ignore).

This is an OA snapshot import — matches the shape site_department.mapped
expects (a pre-existing `local_department.dept_id` to FK into). Upsert-by-
`dept_id`, idempotent: re-running just refreshes names.

File → (dept_level, admin_area_note) mapping:
  01清远市    → level 1 (市本级)
  02清城      → level 2 (清城区)
  03清新      → level 2 (清新区)
  04英德      → level 2 (县级市 · 英德)
  05佛冈      → level 2
  06连州      → level 2
  07连山      → level 2
  08连南      → level 2
  09阳山      → level 2

Run:
  uv run python scripts/import_local_departments.py
  uv run python scripts/import_local_departments.py --dry-run
"""
from __future__ import annotations

import argparse
import os
import sys
from datetime import datetime
from pathlib import Path

import openpyxl
from sqlalchemy import func, select
from sqlalchemy.orm import Session

# allow running from repo root
ROOT = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(ROOT))

from govcrawler.db import get_sessionmaker  # noqa: E402
from govcrawler.models import LocalDepartment  # noqa: E402

# Folder at the repo root holding the per-area source .xlsx workbooks.
SRC_DIR = ROOT / "清远市及县区政府组成部门和乡镇"

# Source workbook → dept_level. The 01 workbook holds the city-level
# (市本级) departments → level 1; every other workbook is a
# district/county-level (区/县/县级市) list → level 2. The ordering is
# significant: order_id follows it, so listings come out
# 市 → 清城 → 清新 → … → 阳山.
FILE_MAP: list[tuple[str, int]] = [("01清远市.xlsx", 1)] + [
    (district_file, 2)
    for district_file in (
        "02清城.xlsx",
        "03清新.xlsx",
        "04英德.xlsx",
        "05佛冈.xlsx",
        "06连州.xlsx",
        "07连山.xlsx",
        "08连南.xlsx",
        "09阳山.xlsx",
    )
]


def _iter_rows(xlsx_path: Path):
    """Yield (dept_id, dept_name, short_name) tuples, skipping the header row.

    Only the first 3 columns are read, so 连山's stray 4th SHORTNAME column is
    ignored. Rows whose DEPTID cell is not a plain int (header text, blanks,
    stray booleans) are skipped, as are rows with an empty DEPTNAME.

    Fixes vs the original:
      * short rows are padded with None before unpacking — openpyxl's
        read-only mode trims trailing empty cells, so ``row[1]`` could raise
        IndexError on a 1-cell row;
      * the workbook is closed when iteration ends — read-only workbooks keep
        the underlying file handle open until ``close()`` is called;
      * ``bool`` is rejected explicitly (it subclasses ``int`` in Python).
    """
    wb = openpyxl.load_workbook(xlsx_path, data_only=True, read_only=True)
    try:
        ws = wb.worksheets[0]  # every source file has a single sheet
        for row in ws.iter_rows(values_only=True):
            if not row:
                continue
            # Pad to at least 3 cells so unpacking never raises IndexError.
            dept_id, dept_name, short_name = (tuple(row) + (None, None, None))[:3]
            if isinstance(dept_id, bool) or not isinstance(dept_id, int):
                continue  # header, blank, or garbage cell
            if not dept_name:
                continue
            yield (
                dept_id,
                str(dept_name).strip(),
                str(short_name).strip() if short_name else None,
            )
    finally:
        wb.close()  # read-only mode holds the file open until closed


def _upsert(session: Session, *, dept_id: int, dept_name: str, short_name: str | None,
            dept_level: int, order_id: int) -> str:
    """Create or refresh a single ``local_department`` row.

    Looks the row up by primary key; inserts it when absent, otherwise
    overwrites every tracked column. Returns one of ``'created'``,
    ``'updated'`` or ``'unchanged'`` for the caller's stats.
    """
    new_values = {
        "dept_name": dept_name,
        "short_name": short_name,
        # The source sheets carry no independent full name — mirror dept_name.
        "full_name": dept_name,
        "dept_level": dept_level,
        "order_id": order_id,
        "state": 1,
        "updated_at": datetime.utcnow(),
    }
    existing = session.get(LocalDepartment, dept_id)
    if existing is None:
        session.add(LocalDepartment(dept_id=dept_id, **new_values))
        return "created"
    # Detect a real change before blindly overwriting; the timestamp column
    # always differs, so it is excluded from the comparison.
    dirty = False
    for field, value in new_values.items():
        if field != "updated_at" and getattr(existing, field) != value:
            dirty = True
        setattr(existing, field, value)
    return "updated" if dirty else "unchanged"


def _collect_records() -> tuple[dict[int, dict], list[tuple[str, int]]]:
    """Parse every workbook listed in FILE_MAP.

    Returns ``(records, per_file_counts)`` where ``records`` maps
    dept_id → the keyword payload for :func:`_upsert` (on a dept_id collision
    the last file seen wins) and ``per_file_counts`` preserves FILE_MAP order
    for the summary printout. Missing files are warned about and counted as 0.
    """
    records: dict[int, dict] = {}
    per_file_counts: list[tuple[str, int]] = []
    order_seq = 0  # global ordering across files: 市 first, then districts

    for fname, level in FILE_MAP:
        path = SRC_DIR / fname
        if not path.is_file():
            print(f"WARN: missing {path.name}, skipped")
            per_file_counts.append((fname, 0))
            continue
        n = 0
        for dept_id, dept_name, short_name in _iter_rows(path):
            order_seq += 1
            prev = records.get(dept_id)
            if prev and prev["dept_name"] != dept_name:
                print(
                    f"WARN: dept_id {dept_id} seen twice with different names: "
                    f"{prev['dept_name']!r} vs {dept_name!r} (from {fname}); "
                    f"keeping the latter"
                )
            records[dept_id] = {
                "dept_id": dept_id,
                "dept_name": dept_name,
                "short_name": short_name,
                "dept_level": level,
                "order_id": order_seq,
            }
            n += 1
        per_file_counts.append((fname, n))
    return records, per_file_counts


def main() -> int:
    """Entry point: parse the workbooks, then upsert into ``local_department``.

    Returns a process exit code: 0 on success (including ``--dry-run``),
    2 when the source directory is missing.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--dry-run", action="store_true", help="parse + summarise, no DB writes")
    args = ap.parse_args()

    if not SRC_DIR.is_dir():
        print(f"ERR: source dir not found: {SRC_DIR}", file=sys.stderr)
        return 2

    # Gather all rows first so we can de-dup across files before writing.
    records, per_file_counts = _collect_records()

    print("== parse summary ==")
    for fname, n in per_file_counts:
        print(f"  {fname:20s}  {n:4d} rows")
    print(f"  -> total distinct dept_id: {len(records)}")

    if args.dry_run:
        print("(dry-run, not writing)")
        return 0

    SM = get_sessionmaker()
    stats = {"created": 0, "updated": 0, "unchanged": 0}
    with SM() as session:
        for rec in records.values():
            stats[_upsert(session, **rec)] += 1
        session.commit()

    print("== import result ==")
    for k, v in stats.items():
        print(f"  {k:10s} {v}")

    # Quick sanity readback of the total row count. Uses the `func` name
    # imported at the top of the file instead of an inline __import__ hack.
    with SM() as s:
        total = s.scalar(select(func.count(LocalDepartment.dept_id)))
        print(f"  local_department total rows now: {total}")
    return 0


if __name__ == "__main__":
    raise SystemExit(main())
