from __future__ import annotations

import re

from lxml import html as lxml_html

BLOCK_TAGS = {
    "p", "div", "br", "li", "tr",
    "h1", "h2", "h3", "h4", "h5", "h6",
    "section", "article", "blockquote",
}
REMOVE_TAGS = {"script", "style", "noscript", "iframe"}


def _cell_text(cell_el) -> str:
    """Plain text of a <td>/<th> cell with whitespace collapsed.

    Pipe characters inside cell text would break the markdown row format,
    so we escape them. Newlines inside a cell collapse to a space — Markdown
    table rows have to be single-line."""
    txt = cell_el.text_content() or ""
    # Drop NBSP and full-width spaces that are common in TRS-CMS markup.
    txt = txt.replace(" ", " ").replace("　", " ")
    txt = re.sub(r"\s+", " ", txt).strip()
    return txt.replace("|", "\\|")


def _render_table_markdown(table_el) -> str:
    """Render a <table> as a Markdown-style text block:

        | col1 | col2 | col3 |
        | --- | --- | --- |
        | a | b | c |

    Empty tables collapse to "" so the placeholder doesn't add noise.
    Header row is detected as either the first <tr> with any <th>, or
    falls back to the first row with the most cells (which matches the
    `tableTab` header pattern used by gd_wjk's row-headers)."""
    rows = table_el.findall(".//tr")
    if not rows:
        return ""

    matrix: list[list[str]] = []
    for tr in rows:
        cells = tr.findall("./td") + tr.findall("./th")
        if not cells:
            continue
        matrix.append([_cell_text(c) for c in cells])

    if not matrix:
        return ""

    width = max(len(row) for row in matrix)
    # Pad short rows so the markdown grid stays rectangular.
    matrix = [row + [""] * (width - len(row)) for row in matrix]

    out_lines = []
    for i, row in enumerate(matrix):
        out_lines.append("| " + " | ".join(row) + " |")
        if i == 0:
            out_lines.append("| " + " | ".join(["---"] * width) + " |")
    return "\n".join(out_lines)


def html_to_text(html_str: str) -> str:
    """Convert HTML fragment to paragraph-preserving plain text.

    - Remove script/style/noscript/iframe
    - Render every <table> as a markdown table block (cells were
      previously concatenated with no delimiter, making relief lists /
      org charts / fee schedules unreadable)
    - Insert '\\n' after every block-level tag's tail
    - Collapse runs of blank lines, strip per-line whitespace
    """
    if not html_str or not html_str.strip():
        return ""
    try:
        doc = lxml_html.fragment_fromstring(html_str, create_parent="div")
    except Exception:
        # Fallback — wrap in minimal html; lxml is tolerant
        doc = lxml_html.fromstring(f"<div>{html_str}</div>")
    # Drop script/style subtrees
    for el in list(doc.iter()):
        if el.tag in REMOVE_TAGS:
            parent = el.getparent()
            if parent is not None:
                parent.remove(el)

    # Replace each <table> with a <pre>-like placeholder containing the
    # markdown rendering. Walk in DOC ORDER but build the substitution
    # bottom-up — the lxml `iter("table")` pass below catches every table
    # including nested ones; we render them outside-in so a nested table's
    # text gets folded into the outer cell's _cell_text() naturally.
    # tostring() reading order is depth-first; processing reverse() means
    # innermost nested tables get markdown'd first.
    tables = list(doc.iter("table"))
    for tbl in reversed(tables):
        md = _render_table_markdown(tbl)
        parent = tbl.getparent()
        if parent is None:
            continue
        placeholder = lxml_html.Element("p")
        # Surround with newlines so the table block is visually separated
        # from the surrounding prose after text_content() flattens it.
        placeholder.text = ("\n" + md + "\n") if md else ""
        placeholder.tail = tbl.tail
        parent.replace(tbl, placeholder)

    # Tail-inject newlines on block tags
    for el in doc.iter():
        if el.tag in BLOCK_TAGS:
            el.tail = ("\n" + el.tail) if el.tail else "\n"
    text = doc.text_content()
    # Per-line strip but preserve markdown table lines verbatim.
    out_lines: list[str] = []
    for ln in text.splitlines():
        s = ln.rstrip()
        if s.lstrip().startswith("|") and s.lstrip().endswith("|"):
            out_lines.append(s.lstrip())  # markdown row
        else:
            out_lines.append(s.strip())
    return "\n".join(ln for ln in out_lines if ln)
