"""Focused tests for ingest_trace_recorder opensearch-py parameter passing.

验证 _sync_seq / _atomic_inc_seq 通过 params dict 传递
_source 和 retry_on_conflict，而非关键字参数（opensearch-py 3.0.0 不支持）。
"""

from __future__ import annotations

import asyncio
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from app.core.ingest_trace_recorder import IngestTraceRecorder


def _make_recorder() -> IngestTraceRecorder:
    """Return an IngestTraceRecorder wired to a MagicMock ES client.

    Both the client and its ``raw`` attribute are MagicMock instances, so
    tests can swap individual ``raw`` methods for AsyncMock objects.
    """
    mock_client = MagicMock()
    mock_client.raw = MagicMock()
    return IngestTraceRecorder(mock_client, trace_id="test-trace", doc_id="test-doc")


class TestSyncExistingSeqParams:
    """_sync_seq must pass _source through the params dict."""

    @staticmethod
    def _run(coro):
        """Run *coro* to completion on a private event loop and return its result.

        A fresh loop is created and closed for every call instead of using
        asyncio.get_event_loop(), which is deprecated when no loop is running
        and raises RuntimeError once the current loop has been cleared (for
        example by an earlier asyncio.run() call in the same test session).
        No global event-loop state is touched, so test ordering cannot affect
        the outcome.
        """
        loop = asyncio.new_event_loop()
        try:
            return loop.run_until_complete(coro)
        finally:
            loop.close()

    def test_get_called_with_params(self):
        """get() receives _source inside params and the returned seq is adopted."""
        recorder = _make_recorder()
        recorder._es.raw.get = AsyncMock(
            return_value={
                "_source": {"latest_seq": 5, "started_at": "2026-01-01T00:00:00Z"},
            }
        )

        self._run(recorder._sync_seq())

        recorder._es.raw.get.assert_called_once()
        call_kwargs = recorder._es.raw.get.call_args.kwargs
        # _source must live inside params; it must not be a top-level keyword argument.
        assert "params" in call_kwargs, "get() must use params dict"
        assert "_source" in call_kwargs["params"], "params must contain _source"
        assert "_source" not in call_kwargs, "_source must not be a top-level kwarg"
        # The recorder's sequence number must match the value returned by get().
        assert recorder._seq == 5

    def test_get_params_source_is_comma_separated(self):
        """The _source value must be a string (not a list), as opensearch-py params expect."""
        recorder = _make_recorder()
        recorder._es.raw.get = AsyncMock(
            return_value={"_source": {"latest_seq": 3}}
        )

        self._run(recorder._sync_seq())

        params = recorder._es.raw.get.call_args.kwargs["params"]
        assert isinstance(params["_source"], str), "_source param should be a string"


class TestAtomicIncSeqParams:
    """_atomic_inc_seq must pass _source and retry_on_conflict through the params dict."""

    @staticmethod
    def _run(coro):
        """Run *coro* to completion on a private event loop and return its result.

        A fresh loop is created and closed for every call instead of using
        asyncio.get_event_loop(), which is deprecated when no loop is running
        and raises RuntimeError once the current loop has been cleared (for
        example by an earlier asyncio.run() call in the same test session).
        No global event-loop state is touched, so test ordering cannot affect
        the outcome.
        """
        loop = asyncio.new_event_loop()
        try:
            return loop.run_until_complete(coro)
        finally:
            loop.close()

    def test_update_called_with_params(self):
        """update() receives both options inside params and the new seq is returned."""
        recorder = _make_recorder()
        recorder._seq_synced = True
        recorder._es.raw.update = AsyncMock(
            return_value={
                "get": {"_source": {"latest_seq": 2}},
            }
        )

        result = self._run(recorder._atomic_inc_seq())

        recorder._es.raw.update.assert_called_once()
        call_kwargs = recorder._es.raw.update.call_args.kwargs
        # params must carry both _source and retry_on_conflict.
        assert "params" in call_kwargs, "update() must use params dict"
        params = call_kwargs["params"]
        assert "_source" in params, "params must contain _source"
        assert "retry_on_conflict" in params, "params must contain retry_on_conflict"
        assert params["retry_on_conflict"] == 3
        # Neither option may be passed as a top-level keyword argument.
        assert "_source" not in call_kwargs, "_source must not be a top-level kwarg"
        assert "retry_on_conflict" not in call_kwargs, "retry_on_conflict must not be a top-level kwarg"
        # The seq value from the update response is returned.
        assert result == 2

    def test_fallback_get_when_update_lacks_get_source(self):
        """When the update response has no get._source, fall back to reading via get()."""
        recorder = _make_recorder()
        recorder._seq_synced = True
        # Update response without get._source.
        recorder._es.raw.update = AsyncMock(return_value={"result": "updated"})
        # Response served to the fallback get().
        recorder._es.raw.get = AsyncMock(
            return_value={"_source": {"latest_seq": 7}}
        )

        result = self._run(recorder._atomic_inc_seq())

        # The fallback get() must also pass _source through params.
        recorder._es.raw.get.assert_called_once()
        get_kwargs = recorder._es.raw.get.call_args.kwargs
        assert "params" in get_kwargs
        assert "_source" in get_kwargs["params"]
        assert "_source" not in get_kwargs
        assert result == 7

    def test_local_fallback_on_exception(self):
        """When update() raises, fall back to a local increment instead of blocking the flow."""
        recorder = _make_recorder()
        recorder._seq_synced = True
        recorder._seq = 10
        recorder._es.raw.update = AsyncMock(side_effect=TypeError("bad kwarg"))

        result = self._run(recorder._atomic_inc_seq())

        assert result == 11  # incremented locally
        assert recorder._seq == 11
