#!/usr/bin/env python3
import argparse
import json
from pathlib import Path

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from docx import Document
from docx.shared import Inches

# Sans-serif font fallback chain used for chart text. CJK-capable fonts come
# first so that Chinese titles and labels (the default title is Chinese)
# can be rendered; DejaVu Sans is the last entry in the chain.
plt.rcParams["font.sans-serif"] = [
    "Noto Sans CJK SC",
    "WenQuanYi Micro Hei",
    "WenQuanYi Zen Hei",
    "DejaVu Sans",
]
# Draw negative tick labels with the ASCII hyphen instead of the Unicode
# minus sign.
plt.rcParams["axes.unicode_minus"] = False

# Lowercase file suffixes that identify a manifest entry as a table input.
TABLE_EXTENSIONS = (".xlsx", ".xlsm", ".xls", ".csv", ".tsv")


def resolve_input_path(value):
    """Turn the ``--input`` argument into the path of a table file.

    Args:
        value: A file path, a directory path, one of the keywords
            ``"auto"`` / ``"manifest"`` (case-insensitive), or a falsy
            value such as ``None``.

    Returns:
        ``str(Path(value))`` when ``value`` names an existing file.
        Otherwise the result of ``select_from_manifest`` applied to
        ``<value>/manifest.json`` (when ``value`` is a directory) or to
        ``/inputs/manifest.json`` (keyword or falsy ``value``).

    Raises:
        SystemExit: if ``value`` is neither an existing file nor an
            existing directory.
    """
    wants_default_manifest = not value or value.lower() in ("auto", "manifest")
    if wants_default_manifest:
        return select_from_manifest(Path("/inputs/manifest.json"))

    target = Path(value)
    if target.is_file():
        return str(target)
    if not target.is_dir():
        raise SystemExit(f"input path not found: {target}")
    # A directory is expected to carry its own manifest.json.
    return select_from_manifest(target / "manifest.json")


def select_from_manifest(path):
    """Return the ``"path"`` of the first table entry listed in a manifest.

    The manifest is read as UTF-8 JSON (a leading BOM is tolerated) and must
    be a list. An entry is a candidate when it is a JSON object whose
    ``"path"`` value ends, case-insensitively, with one of
    ``TABLE_EXTENSIONS``. Entries that are not objects are skipped.

    Args:
        path: ``pathlib.Path`` of the manifest file.

    Returns:
        The ``"path"`` value of the first candidate, in manifest order.

    Raises:
        SystemExit: if the manifest does not exist, is not a JSON list, or
            lists no table input.
    """
    if not path.exists():
        raise SystemExit(f"input manifest not found: {path}")
    items = json.loads(path.read_text(encoding="utf-8-sig"))
    if not isinstance(items, list):
        raise SystemExit(f"system input manifest must be a list: {path}")

    first_table = next(
        (
            item["path"]
            for item in items
            # Non-object entries (strings, numbers, null) have no .get();
            # skip them rather than crash with AttributeError.
            if isinstance(item, dict)
            and str(item.get("path", "")).lower().endswith(TABLE_EXTENSIONS)
        ),
        None,
    )
    if first_table is None:
        # Name the manifest that was actually searched; it is not always
        # /inputs/manifest.json (a directory passed via --input has its own).
        raise SystemExit(f"no xlsx/xlsm/xls/csv/tsv input found in {path}")
    return first_table


def read_table(path):
    """Load a table file into a pandas DataFrame.

    The reader is chosen from the lowercased file suffix: ``.xlsx``,
    ``.xlsm`` and ``.xls`` go through ``pd.read_excel``; everything else is
    read as delimited text with ``utf-8-sig`` encoding, tab-separated for
    ``.tsv`` and comma-separated otherwise.

    Args:
        path: Location of the table file (string or path-like).

    Returns:
        The parsed ``pandas.DataFrame``.
    """
    ext = Path(path).suffix.lower()
    if ext == ".tsv":
        return pd.read_csv(path, encoding="utf-8-sig", sep="\t")
    if ext not in (".xlsx", ".xlsm", ".xls"):
        return pd.read_csv(path, encoding="utf-8-sig", sep=",")
    return pd.read_excel(path)


def _summarize(df, category, value):
    """Return the rows used for analysis and their per-category totals.

    Rows with a missing category, or a missing / non-numeric value, are
    dropped. The second item of the returned pair has one row per category
    holding the summed value, sorted from largest to smallest.

    Raises:
        SystemExit: if ``category`` or ``value`` is not a column of ``df``.
    """
    missing = [col for col in (category, value) if col not in df.columns]
    if missing:
        raise SystemExit(f"column not found in input table: {', '.join(missing)}")
    # Explicit copy so the coercion below never writes into a slice of df.
    data = df[[category, value]].dropna().copy()
    data[value] = pd.to_numeric(data[value], errors="coerce")
    data = data.dropna(subset=[value])
    summary = data.groupby(category, as_index=False)[value].sum()
    return data, summary.sort_values(value, ascending=False)


def _render_chart(summary, category, value, title, out_chart):
    """Save a horizontal bar chart of the 12 largest categories."""
    Path(out_chart).parent.mkdir(parents=True, exist_ok=True)
    fig = plt.figure(figsize=(9, 5), dpi=160)
    try:
        # orient="h" keeps categories on the y axis even when the category
        # column is numeric and seaborn could not infer the orientation.
        sns.barplot(
            data=summary.head(12),
            x=value,
            y=category,
            color="#4F81BD",
            orient="h",
        )
        plt.title(title)
        plt.xlabel(value)
        plt.ylabel(category)
        plt.tight_layout()
        plt.savefig(out_chart)
    finally:
        # Release the figure even if plotting or saving fails.
        plt.close(fig)


def _write_report(summary, record_count, category, value, title, chart_path, out_docx):
    """Write the Word brief: intro, headline finding, chart and top-20 table."""
    doc = Document()
    doc.add_heading(title, 0)
    doc.add_paragraph(f"本次分析共纳入 {record_count} 条记录，按“{category}”汇总“{value}”。")
    if not summary.empty:
        top = summary.iloc[0]
        total = summary[value].sum()
        # Guard against a zero grand total when computing the leader's share.
        share = top[value] / total if total else 0
        doc.add_paragraph(
            f"汇总结果显示，“{top[category]}”居首，数值为 {top[value]:,.2f}，"
            f"占汇总总量的 {share:.1%}。"
        )
    doc.add_heading("图表", level=1)
    doc.add_picture(chart_path, width=Inches(6.3))
    doc.add_heading("数据表", level=1)
    table = doc.add_table(rows=1, cols=2)
    header = table.rows[0].cells
    header[0].text = category
    header[1].text = value
    for _, row in summary.head(20).iterrows():
        cells = table.add_row().cells
        cells[0].text = str(row[category])
        cells[1].text = f"{row[value]:,.2f}"
    doc.add_paragraph("注：本简报为自动生成初稿，正式报送前应核对数据口径、时间范围和来源。")
    Path(out_docx).parent.mkdir(parents=True, exist_ok=True)
    doc.save(out_docx)


def main():
    """Command-line entry point: summarise a table and emit chart + brief.

    Reads the table selected by ``--input``, sums ``--value`` per
    ``--category``, saves a bar chart to ``--out-chart`` and a Word document
    to ``--out-docx``. The record count quoted in the document is the number
    of rows that survived cleaning, i.e. the rows actually aggregated.

    Raises:
        SystemExit: on argument errors, an unresolvable input, or when a
            requested column is missing from the table.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--input", help="Table path, /inputs directory, or 'auto'. Defaults to first table in /inputs/manifest.json.")
    ap.add_argument("--category", required=True)
    ap.add_argument("--value", required=True)
    ap.add_argument("--out-docx", required=True)
    ap.add_argument("--out-chart", required=True)
    ap.add_argument("--title", default="统计分析简报")
    args = ap.parse_args()

    df = read_table(resolve_input_path(args.input))
    data, summary = _summarize(df, args.category, args.value)
    _render_chart(summary, args.category, args.value, args.title, args.out_chart)
    _write_report(
        summary,
        len(data),
        args.category,
        args.value,
        args.title,
        args.out_chart,
        args.out_docx,
    )


# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
