package com.zhengmeng.ocrplatform.worker;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.zhengmeng.ocrplatform.engine.EngineProperties;
import com.zhengmeng.ocrplatform.engine.OcrEngineAdapter;
import com.zhengmeng.ocrplatform.engine.OcrEngineRegistry;
import com.zhengmeng.ocrplatform.engine.OcrEngineRequest;
import com.zhengmeng.ocrplatform.engine.OcrEngineResult;
import com.zhengmeng.ocrplatform.task.OcrTaskEntity;
import com.zhengmeng.ocrplatform.task.OcrTaskRepository;
import com.zhengmeng.ocrplatform.task.TaskStatus;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.scheduling.annotation.Scheduled;
import org.springframework.stereotype.Component;

import java.util.LinkedHashMap;
import java.util.Map;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;

/**
 * Scheduled worker that picks up {@link TaskStatus#QUEUED} OCR tasks and runs them through the
 * default {@link OcrEngineAdapter} from {@link OcrEngineRegistry}.
 *
 * <p>Each poll first asks {@link OcrTaskExecutionService} to recover timed-out recognizing tasks.
 * It then claims and executes the oldest queued tasks one at a time. A poll that starts while
 * another poll on this instance is still running returns immediately.
 *
 * <p>The bean is registered only when {@code ocr-platform.worker.enabled} is {@code true} or
 * absent.
 */
@Component
@ConditionalOnProperty(prefix = "ocr-platform.worker", name = "enabled", havingValue = "true", matchIfMissing = true)
public class OcrTaskWorker {
    private static final Logger log = LoggerFactory.getLogger(OcrTaskWorker.class);

    private final OcrTaskRepository taskRepository;
    private final OcrTaskExecutionService executionService;
    private final OcrEngineRegistry engineRegistry;
    private final EngineProperties engineProperties;
    private final WorkerProperties properties;
    /** Serializes request-payload snapshots. This is a plain mapper, not the Spring-configured one. */
    private final ObjectMapper objectMapper = new ObjectMapper();
    /** Set while a poll is in progress; used to skip overlapping polls on this instance. */
    private final AtomicBoolean running = new AtomicBoolean(false);

    /**
     * Creates the worker.
     *
     * @param taskRepository   source of queued tasks
     * @param executionService claims tasks and records completion / failure
     * @param engineRegistry   provides the default OCR engine adapter
     * @param engineProperties engine configuration, captured into each request payload snapshot
     * @param properties       worker settings (batch size)
     */
    public OcrTaskWorker(OcrTaskRepository taskRepository,
                         OcrTaskExecutionService executionService,
                         OcrEngineRegistry engineRegistry,
                         EngineProperties engineProperties,
                         WorkerProperties properties) {
        this.taskRepository = taskRepository;
        this.executionService = executionService;
        this.engineRegistry = engineRegistry;
        this.engineProperties = engineProperties;
        this.properties = properties;
    }

    /**
     * Runs one polling cycle: recovers timed-out recognizing tasks, then processes up to
     * {@code properties.batchSize()} of the oldest queued tasks.
     *
     * <p>NOTE(review): the repository query returns at most 10 tasks, so a configured batch size
     * above 10 is effectively capped at 10 per poll.
     */
    @Scheduled(fixedDelayString = "${ocr-platform.worker.poll-interval-ms:5000}")
    public void pollQueuedTasks() {
        if (!running.compareAndSet(false, true)) {
            return;
        }

        try {
            recoverTimedOutTasks();
            taskRepository.findTop10ByStatusOrderByCreatedAtAsc(TaskStatus.QUEUED).stream()
                    .limit(properties.batchSize())
                    .forEach(this::processTask);
        } finally {
            running.set(false);
        }
    }

    /**
     * Asks the execution service to recover timed-out recognizing tasks. A failure is logged and
     * suppressed so it does not prevent queued tasks from being processed in this poll.
     */
    private void recoverTimedOutTasks() {
        try {
            executionService.recoverTimedOutRecognizingTasks();
        } catch (RuntimeException ex) {
            log.error("Failed to recover timed-out recognizing OCR tasks", ex);
        }
    }

    /**
     * Claims a single task and, if the claim yields a run context, runs it through the default
     * engine and records the outcome.
     *
     * <p>Errors from the engine call, or from recording completion, are recorded via
     * {@code executionService.fail} against the claimed run number. Any other error, including one
     * thrown by that {@code fail} call itself, is logged and recorded on a best-effort basis by
     * {@link #recordFallbackFailure}. This method does not propagate exceptions, so one bad task
     * cannot abort the rest of the batch.
     *
     * @param task queued task snapshot as returned by the repository
     */
    private void processTask(OcrTaskEntity task) {
        String taskId = task.getTaskId();
        try {
            executionService.claim(taskId).ifPresent(context -> {
                OcrEngineAdapter adapter = engineRegistry.defaultAdapter();
                OcrEngineRequest request = new OcrEngineRequest(
                        context.taskId(),
                        context.document().getOriginalFilename(),
                        context.document().getContentType(),
                        context.document().getFileSize(),
                        context.document().getSha256(),
                        context.filePath()
                );
                String requestPayloadJson = buildRequestPayloadJson(context.runNo(), adapter, request);
                long startNanos = System.nanoTime();
                try {
                    OcrEngineResult result = adapter.recognize(request);
                    executionService.complete(taskId, context.runNo(), result, requestPayloadJson,
                            elapsedMillisSince(startNanos));
                    log.info("OCR task {} completed by engine {} on run #{}", taskId, result.engineCode(), context.runNo());
                } catch (Exception ex) {
                    executionService.fail(taskId, context.runNo(), adapter, ex, requestPayloadJson,
                            elapsedMillisSince(startNanos));
                    log.error("OCR task {} failed on run #{}", taskId, context.runNo(), ex);
                }
            });
        } catch (Exception ex) {
            log.error("OCR task {} failed", taskId, ex);
            recordFallbackFailure(task, ex);
        }
    }

    /**
     * Best-effort failure recording for errors raised outside the guarded engine call. If
     * recording itself fails, the error is logged and suppressed so the caller's batch keeps going.
     *
     * @param task  the task snapshot taken before claiming
     * @param cause the error being recorded
     */
    private void recordFallbackFailure(OcrTaskEntity task, Exception cause) {
        String taskId = task.getTaskId();
        try {
            // No claimed run context is available here, so the run number is taken as the next
            // attempt after the pre-claim snapshot. Presumably this matches the number claim()
            // assigns; TODO confirm against OcrTaskExecutionService.
            int fallbackRunNo = task.getAttemptCount() + 1;
            OcrEngineAdapter adapter = engineRegistry.defaultAdapter();
            executionService.fail(taskId,
                    fallbackRunNo,
                    adapter,
                    cause,
                    buildRequestPayloadJson(fallbackRunNo, adapter, taskId),
                    null);
        } catch (RuntimeException recordEx) {
            log.error("Failed to record failure for OCR task {}", taskId, recordEx);
        }
    }

    /**
     * Returns the milliseconds elapsed since {@code startNanos}, clamped to be non-negative.
     *
     * @param startNanos a reading taken from {@link System#nanoTime()}
     */
    private static long elapsedMillisSince(long startNanos) {
        return Math.max(0L, TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startNanos));
    }

    /**
     * Builds the JSON snapshot of an engine request, including the engine configuration.
     *
     * @param runNo   run number the request belongs to
     * @param adapter engine adapter handling the request; may be {@code null}
     * @param request the engine request
     * @return the JSON payload, or {@code null} if serialization fails (a warning is logged)
     */
    private String buildRequestPayloadJson(int runNo, OcrEngineAdapter adapter, OcrEngineRequest request) {
        Map<String, Object> payload = payloadHeader(runNo, adapter);
        payload.put("taskId", request.taskId());
        payload.put("originalFilename", request.originalFilename());
        payload.put("contentType", request.contentType());
        payload.put("fileSize", request.fileSize());
        payload.put("sha256", request.sha256());
        // A missing path should not turn a diagnostic snapshot into a task failure.
        Object filePath = request.filePath();
        payload.put("filePath", filePath == null ? null : filePath.toString());
        payload.put("engine", engineSnapshot());
        return toJson(payload);
    }

    /**
     * Builds a placeholder payload for cases where no engine request could be constructed.
     *
     * @param runNo   run number the failure is recorded against
     * @param adapter engine adapter, may be {@code null}
     * @param taskId  task identifier
     * @return the JSON payload, or {@code null} if serialization fails (a warning is logged)
     */
    private String buildRequestPayloadJson(int runNo, OcrEngineAdapter adapter, String taskId) {
        Map<String, Object> payload = payloadHeader(runNo, adapter);
        payload.put("taskId", taskId);
        payload.put("error", "Request snapshot unavailable");
        payload.put("engine", engineSnapshot());
        return toJson(payload);
    }

    /**
     * Starts an insertion-ordered payload map with the run number and the adapter's engine code and
     * version. The code and version are {@code null} when no adapter is given.
     */
    private static Map<String, Object> payloadHeader(int runNo, OcrEngineAdapter adapter) {
        Map<String, Object> payload = new LinkedHashMap<>();
        payload.put("runNo", runNo);
        payload.put("engineCode", adapter == null ? null : adapter.engineCode());
        payload.put("engineVersion", adapter == null ? null : adapter.engineVersion());
        return payload;
    }

    /**
     * Serializes a payload map to JSON.
     *
     * @return the JSON string, or {@code null} if serialization fails (logged with the payload's
     *         {@code taskId} entry)
     */
    private String toJson(Map<String, Object> payload) {
        try {
            return objectMapper.writeValueAsString(payload);
        } catch (JsonProcessingException ex) {
            log.warn("Failed to serialize engine request payload for task {}", payload.get("taskId"), ex);
            return null;
        }
    }

    /**
     * Captures the current engine configuration (default engine code and the PaddleLayout /
     * PP-ChatOCR settings) for embedding in a request payload.
     */
    private Map<String, Object> engineSnapshot() {
        Map<String, Object> engine = new LinkedHashMap<>();
        engine.put("defaultEngineCode", engineProperties.defaultCode());
        engine.put("paddleLayoutEnabled", engineProperties.paddleLayout().enabled());
        engine.put("paddleLayoutEndpoint", engineProperties.paddleLayout().endpoint());
        engine.put("paddleLayoutTimeoutMs", engineProperties.paddleLayout().timeout().toMillis());
        engine.put("ppChatOcrEnabled", engineProperties.ppChatOcr().enabled());
        engine.put("ppChatOcrVisualEndpoint", engineProperties.ppChatOcr().visualEndpoint());
        engine.put("ppChatOcrChatEndpoint", engineProperties.ppChatOcr().chatEndpoint());
        engine.put("ppChatOcrTimeoutMs", engineProperties.ppChatOcr().timeout().toMillis());
        engine.put("ppChatOcrExtractionKeys", engineProperties.ppChatOcr().extractionKeys());
        return engine;
    }
}
