"""音频与时间戳辅助工具。

职责概览：
- 将音频文件按秒切分为临时 wav 切片。
- 把标点注入 token 级时间戳序列。
- 将 token 级时间戳按句末标点合并为句子级时间戳。
"""

import contextlib
import os
import tempfile
from typing import List

from pydub import AudioSegment


def split_audio_to_chunks(file_path: str, chunk_seconds: int) -> List[str]:
    """Split an audio file into consecutive WAV chunks written to temporary files.

    Args:
        file_path: Path of the audio file, decoded with ``AudioSegment.from_file``.
        chunk_seconds: Length of each chunk in seconds. Values below 1 are clamped
            to 1; fractional values are accepted and rounded down to whole
            milliseconds. The final chunk holds whatever audio remains and may be
            shorter.

    Returns:
        Paths of the created ``.wav`` files, in playback order. The files are NOT
        deleted automatically; the caller owns them. Zero-length audio yields an
        empty list.

    Raises:
        Whatever ``AudioSegment.from_file`` or ``AudioSegment.export`` raise. If an
        export fails, the chunk files already created by this call are removed
        (best effort) before the error propagates.
    """
    audio = AudioSegment.from_file(file_path)
    chunk_ms = int(max(1, chunk_seconds) * 1000)
    chunks: List[str] = []
    try:
        for start in range(0, len(audio), chunk_ms):
            # Only reserve a unique name here and close the handle right away:
            # reopening a NamedTemporaryFile by name while it is still open fails
            # on Windows, and export() opens the path itself.
            with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
                path = tmp.name
            # Record the path before exporting so a failed export is cleaned up too.
            chunks.append(path)
            # export() returns the file object it opened for the path; close it
            # explicitly instead of relying on garbage collection.
            audio[start : start + chunk_ms].export(path, format="wav").close()
    except BaseException:
        for path in chunks:
            with contextlib.suppress(OSError):
                os.remove(path)
        raise
    return chunks


def insert_punctuations_into_segments(segments: List[dict], full_text: str) -> List[dict]:
    """Interleave the punctuation of *full_text* into a token-level timestamp list.

    *segments* are token dicts (read via ``"text"`` and ``"end"``) in transcript
    order; *full_text* is the same transcript with punctuation and possibly
    whitespace added. The text is walked left to right while tokens are aligned
    to it. Every non-whitespace character not covered by a token is emitted as a
    zero-duration entry ``{"start": t, "end": t, "text": ch}``. Here ``t`` is the
    ``"end"`` of the most recently emitted token (``0.0`` before the first one),
    rounded to 3 decimals.

    Alignment rules:

    * Whitespace in *full_text* is skipped.
    * Token text is matched with surrounding whitespace stripped, so tokens such
      as ``" hello"`` align with ``"hello"`` in *full_text*.
    * A token whose text starts at the current position is consumed before any
      punctuation check, so marks the recognizer already emitted are not duplicated.
    * Tokens with empty, ``None`` or whitespace-only text are passed through
      without consuming any character of *full_text*.
    * If a token's text is not found ahead, the token is emitted in place of the
      current character (that character is dropped).
    * Tokens left over once *full_text* is exhausted are appended unchanged.

    Token dicts are emitted as-is (not copied). If *full_text* is empty, a shallow
    copy of *segments* is returned.

    Args:
        segments: Ordered token dicts with ``"start"``, ``"end"`` and ``"text"``.
        full_text: Punctuated transcript covering the same tokens.

    Returns:
        A new list mixing the original token dicts with inserted character entries.
    """
    if not full_text:
        return list(segments)

    punctuation_chars = frozenset("。.!?！？,，、")
    augmented: List[dict] = []
    token_idx = 0
    last_end = 0.0
    i = 0
    n = len(full_text)

    def emit_char(ch: str) -> None:
        # Characters without a token get a zero-length span at the last known end time.
        t = round(last_end, 3)
        augmented.append({"start": t, "end": t, "text": ch})

    def emit_token(seg: dict) -> None:
        # Emit the current token and advance; a missing/None "end" keeps last_end.
        nonlocal last_end, token_idx
        augmented.append(seg)
        end = seg.get("end")
        if end is not None:
            last_end = end
        token_idx += 1

    while i < n:
        ch = full_text[i]
        if ch.isspace():
            i += 1
            continue

        seg = segments[token_idx] if token_idx < len(segments) else None
        key = (seg.get("text") or "").strip() if seg is not None else ""

        if key and full_text.startswith(key, i):
            emit_token(seg)
            i += len(key)
        elif ch in punctuation_chars or seg is None:
            emit_char(ch)
            i += 1
        elif not key:
            # Nothing to align against: pass the token through and retry this character.
            emit_token(seg)
        else:
            j = full_text.find(key, i)
            if j == -1:
                # Recognizer text diverges here: let the token stand in for this character.
                emit_token(seg)
                i += 1
            else:
                # Characters between here and the match have no token of their own.
                for k in range(i, j):
                    if not full_text[k].isspace():
                        emit_char(full_text[k])
                emit_token(seg)
                i = j + len(key)

    augmented.extend(segments[token_idx:])
    return augmented


def merge_sentences_from_tokens(tokens: List[dict]) -> List[dict]:
    """Merge token-level timestamps into sentence-level timestamps.

    Tokens are concatenated in order until a token whose text, with surrounding
    whitespace stripped, is one sentence-ending mark (``。 . ! ? ！ ？``). That
    token closes the sentence.

    The resulting sentence:

    * starts at the ``"start"`` of the sentence's first token;
    * ends at the mark's ``"end"``, falling back to the previous token's ``"end"``
      when the mark has no ``"end"`` key;
    * has both times rounded to 3 decimals;
    * has as text the joined token texts with outer whitespace stripped.

    A run of consecutive marks (``"？！"``, ``"..."``) is kept in full: a mark that
    arrives right after a sentence was closed is appended to that sentence, and
    its ``"end"`` (if later) extends the sentence. Marks seen before any sentence
    exists are discarded. Tokens with empty text are ignored. Trailing tokens not
    followed by a mark form a final sentence.

    Args:
        tokens: Ordered token dicts with ``"start"``, ``"end"`` and ``"text"``.

    Returns:
        A new list of ``{"start", "end", "text"}`` dicts, one per sentence.
    """
    sentence_end_chars = frozenset("。.!?！？")
    merged: List[dict] = []
    cur_text_parts: List[str] = []
    cur_start = None
    cur_end = None

    for t in tokens:
        txt = t.get("text", "")
        if not txt:
            continue
        # Strip before testing so padded marks such as " ." also close a sentence.
        if txt.strip() in sentence_end_chars:
            if cur_start is not None:
                cur_text_parts.append(txt)
                merged.append({
                    "start": round(cur_start, 3),
                    "end": round(t.get("end", cur_end), 3),
                    "text": "".join(cur_text_parts).strip(),
                })
            elif merged and not cur_text_parts:
                # A mark right after a closed sentence ("？！", "..."): extend that
                # sentence instead of silently dropping the mark.
                prev = merged[-1]
                prev["text"] += txt.strip()
                mark_end = t.get("end")
                if mark_end is not None:
                    prev["end"] = round(max(prev["end"], mark_end), 3)
            cur_text_parts = []
            cur_start = None
            cur_end = None
            continue

        if cur_start is None:
            cur_start = t.get("start")
        cur_end = t.get("end")
        cur_text_parts.append(txt)

    if cur_start is not None and cur_text_parts:
        merged.append({
            "start": round(cur_start, 3),
            "end": round(cur_end, 3),
            "text": "".join(cur_text_parts).strip(),
        })

    return merged
