package com.zhengmeng.ocrplatform.extract;

import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.zhengmeng.ocrplatform.extraction.OcrKeyValueEntity;
import com.zhengmeng.ocrplatform.extraction.OcrKeyValueRepository;
import com.zhengmeng.ocrplatform.recognition.OcrTextBlockEntity;
import com.zhengmeng.ocrplatform.recognition.OcrTextBlockRepository;
import com.zhengmeng.ocrplatform.task.OcrTaskService;
import com.zhengmeng.ocrplatform.task.OcrTaskSummary;
import com.zhengmeng.ocrplatform.task.TaskStatus;
import org.springframework.stereotype.Service;
import org.springframework.util.StringUtils;
import org.springframework.web.multipart.MultipartFile;

import java.io.IOException;
import java.math.BigDecimal;
import java.time.Duration;
import java.util.ArrayList;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Locale;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.regex.Pattern;

@Service
public class OcrExtractService {
    private static final int DEFAULT_SYNC_TIMEOUT_SECONDS = 120;

    private final OcrTaskService taskService;
    private final OcrTextBlockRepository textBlockRepository;
    private final OcrKeyValueRepository keyValueRepository;
    private final OcrBusinessProfileService profileService;
    private final ObjectMapper objectMapper;

    public OcrExtractService(OcrTaskService taskService,
                             OcrTextBlockRepository textBlockRepository,
                             OcrKeyValueRepository keyValueRepository,
                             OcrBusinessProfileService profileService,
                             ObjectMapper objectMapper) {
        this.taskService = taskService;
        this.textBlockRepository = textBlockRepository;
        this.keyValueRepository = keyValueRepository;
        this.profileService = profileService;
        this.objectMapper = objectMapper;
    }

    public ExtractResponse createExtractTask(List<MultipartFile> files,
                                             String keysJson,
                                             String businessProfileId,
                                             String sourceSystem,
                                             String businessType,
                                             boolean sync,
                                             int timeoutSeconds,
                                             String clientCode) {
        List<MultipartFile> safeFiles = files == null ? List.of() : files.stream()
                .filter(file -> file != null && !file.isEmpty())
                .toList();
        if (safeFiles.isEmpty()) {
            throw new IllegalArgumentException("At least one image or PDF file is required.");
        }
        List<ExtractKey> keys = parseKeys(keysJson);
        if (keys.isEmpty()) {
            throw new IllegalArgumentException("keys is required. Pass a JSON array of strings or objects.");
        }
        Optional<OcrBusinessProfileService.OcrBusinessProfile> profile = profileService.findProfile(businessProfileId);

        List<OcrTaskSummary> tasks = new ArrayList<>();
        for (MultipartFile file : safeFiles) {
            tasks.add(taskService.createTask(
                    defaultIfBlank(sourceSystem, "open-ocr-extract-api"),
                    defaultIfBlank(businessType, "OCR_EXTRACT"),
                    null,
                    5,
                    file,
                    clientCode
            ));
        }

        if (!sync) {
            return new ExtractResponse(
                    tasks.size() == 1 ? tasks.getFirst().taskId() : null,
                    tasks,
                    "ACCEPTED",
                    profile.map(OcrBusinessProfileService.OcrBusinessProfile::profileId).orElse(null),
                    profile.map(OcrBusinessProfileService.OcrBusinessProfile::profileName).orElse(null),
                    keys,
                    List.of(),
                    List.of(new ExtractWarning(null, "ASYNC_TASK_CREATED", "任务已创建，请调用 /api/v1/ocr/extract-tasks/{taskId}?keys=... 查询结果。"))
            );
        }

        int safeTimeout = timeoutSeconds <= 0 ? DEFAULT_SYNC_TIMEOUT_SECONDS : Math.min(timeoutSeconds, 300);
        List<OcrTaskSummary> completedTasks = waitForTasks(tasks, Duration.ofSeconds(safeTimeout), clientCode);
        List<ExtractResultItem> results = new ArrayList<>();
        List<ExtractWarning> warnings = new ArrayList<>();
        for (OcrTaskSummary task : completedTasks) {
            if (task.status() != TaskStatus.COMPLETED) {
                warnings.add(new ExtractWarning(null, "TASK_NOT_COMPLETED", "任务 " + task.taskId() + " 当前状态：" + task.status()));
                continue;
            }
            ExtractedTaskResult extracted = extractFromTask(task, keys, clientCode, profile);
            results.addAll(extracted.results());
            warnings.addAll(extracted.warnings());
        }
        return new ExtractResponse(
                completedTasks.size() == 1 ? completedTasks.getFirst().taskId() : null,
                completedTasks,
                completedTasks.stream().allMatch(task -> task.status() == TaskStatus.COMPLETED) ? "COMPLETED" : "PROCESSING",
                profile.map(OcrBusinessProfileService.OcrBusinessProfile::profileId).orElse(null),
                profile.map(OcrBusinessProfileService.OcrBusinessProfile::profileName).orElse(null),
                keys,
                results,
                warnings
        );
    }

    public ExtractResponse getExtractTask(String taskId, String keysJson, String businessProfileId, String clientCode) {
        List<ExtractKey> keys = parseKeys(keysJson);
        if (keys.isEmpty()) {
            throw new IllegalArgumentException("keys is required. Pass the same JSON key list used when submitting.");
        }
        Optional<OcrBusinessProfileService.OcrBusinessProfile> profile = profileService.findProfile(businessProfileId);
        OcrTaskSummary task = taskService.findTask(taskId, clientCode)
                .orElseThrow(() -> new IllegalArgumentException("Task not found: " + taskId));
        if (task.status() != TaskStatus.COMPLETED) {
            return new ExtractResponse(
                    task.taskId(),
                    List.of(task),
                    task.status().name(),
                    profile.map(OcrBusinessProfileService.OcrBusinessProfile::profileId).orElse(null),
                    profile.map(OcrBusinessProfileService.OcrBusinessProfile::profileName).orElse(null),
                    keys,
                    List.of(),
                    List.of(new ExtractWarning(null, "TASK_NOT_COMPLETED", "任务尚未完成，当前状态：" + task.status()))
            );
        }
        ExtractedTaskResult extracted = extractFromTask(task, keys, clientCode, profile);
        return new ExtractResponse(
                task.taskId(),
                List.of(task),
                task.status().name(),
                profile.map(OcrBusinessProfileService.OcrBusinessProfile::profileId).orElse(null),
                profile.map(OcrBusinessProfileService.OcrBusinessProfile::profileName).orElse(null),
                keys,
                extracted.results(),
                extracted.warnings()
        );
    }

    private List<OcrTaskSummary> waitForTasks(List<OcrTaskSummary> tasks, Duration timeout, String clientCode) {
        long deadline = System.currentTimeMillis() + timeout.toMillis();
        List<OcrTaskSummary> current = tasks;
        while (System.currentTimeMillis() <= deadline) {
            current = current.stream()
                    .map(task -> taskService.findTask(task.taskId(), clientCode).orElse(task))
                    .toList();
            boolean allDone = current.stream().allMatch(task ->
                    task.status() == TaskStatus.COMPLETED
                            || task.status() == TaskStatus.FAILED
                            || task.status() == TaskStatus.CANCELLED);
            if (allDone) {
                return current;
            }
            try {
                Thread.sleep(1500);
            } catch (InterruptedException ex) {
                Thread.currentThread().interrupt();
                return current;
            }
        }
        return current.stream()
                .map(task -> taskService.findTask(task.taskId(), clientCode).orElse(task))
                .toList();
    }

    private ExtractedTaskResult extractFromTask(OcrTaskSummary task,
                                                List<ExtractKey> keys,
                                                String clientCode,
                                                Optional<OcrBusinessProfileService.OcrBusinessProfile> profile) {
        taskService.ensureResultReadable(task.taskId(), clientCode);
        Integer latestRunNo = textBlockRepository.findMaxRunNoByTaskId(task.taskId());
        List<OcrTextBlockEntity> textBlocks = latestRunNo == null
                ? List.of()
                : textBlockRepository.findByTaskIdAndRunNoOrderByPageNoAscReadingOrderAscIdAsc(task.taskId(), latestRunNo);
        Integer latestKeyValueRunNo = keyValueRepository.findMaxRunNoByTaskId(task.taskId());
        List<OcrKeyValueEntity> keyValues = latestKeyValueRunNo == null
                ? List.of()
                : keyValueRepository.findByTaskIdAndRunNoOrderByIdAsc(task.taskId(), latestKeyValueRunNo);

        List<ExtractResultItem> results = new ArrayList<>();
        List<ExtractWarning> warnings = new ArrayList<>();
        for (ExtractKey key : keys) {
            Optional<ExtractResultItem> fromKeyValue = matchKeyValue(task, key, keyValues);
            if (fromKeyValue.isPresent()) {
                results.add(fromKeyValue.get());
                continue;
            }
            Optional<ExtractResultItem> fromText = matchTextBlock(task, key, textBlocks, profile);
            if (fromText.isPresent()) {
                results.add(fromText.get());
                continue;
            }
            results.add(new ExtractResultItem(
                    task.taskId(),
                    key.key(),
                    key.label(),
                    "",
                    null,
                    "",
                    null,
                    null,
                    "NOT_FOUND"
            ));
            warnings.add(new ExtractWarning(key.key(), "KEY_NOT_FOUND", "未从 OCR 结果中命中字段：" + key.displayName()));
        }
        return new ExtractedTaskResult(results, warnings);
    }

    private Optional<ExtractResultItem> matchKeyValue(OcrTaskSummary task, ExtractKey key, List<OcrKeyValueEntity> keyValues) {
        for (OcrKeyValueEntity entity : keyValues) {
            if (!matches(key, entity.getFieldKey(), entity.getFieldName())) {
                continue;
            }
            String value = StringUtils.hasText(entity.getNormalizedValue()) ? entity.getNormalizedValue() : entity.getValueText();
            return Optional.of(new ExtractResultItem(
                    task.taskId(),
                    key.key(),
                    key.label(),
                    nullToBlank(value),
                    entity.getConfidence(),
                    nullToBlank(entity.getValueText()),
                    entity.getPageNo(),
                    null,
                    "ENGINE_KEY_VALUE"
            ));
        }
        return Optional.empty();
    }

    private Optional<ExtractResultItem> matchTextBlock(OcrTaskSummary task,
                                                       ExtractKey key,
                                                       List<OcrTextBlockEntity> textBlocks,
                                                       Optional<OcrBusinessProfileService.OcrBusinessProfile> profile) {
        Optional<OcrBusinessProfileService.FieldRule> fieldRule = profile
                .flatMap(found -> profileService.findFieldRule(found, key.key(), key.label()));
        List<String> candidates = matchTerms(key, fieldRule);
        for (int index = 0; index < textBlocks.size(); index++) {
            OcrTextBlockEntity block = textBlocks.get(index);
            String text = nullToBlank(block.getTextContent());
            String matchedTerm = candidates.stream()
                    .filter(term -> containsRelaxed(text, term))
                    .findFirst()
                    .orElse(null);
            if (matchedTerm == null) {
                continue;
            }
            String value = extractValueAfterTerm(text, matchedTerm);
            if (!StringUtils.hasText(value)) {
                OcrBusinessProfileService.ValueStrategy strategy = fieldRule.map(OcrBusinessProfileService.FieldRule::valueStrategy).orElse(null);
                value = findSpatialValue(textBlocks, index, candidates, strategy);
                if (!StringUtils.hasText(value)) {
                    value = findAdjacentValue(textBlocks, index, candidates, strategy);
                }
            }
            return Optional.of(new ExtractResultItem(
                    task.taskId(),
                    key.key(),
                    key.label(),
                    cleanupValue(value),
                    block.getConfidence(),
                    text,
                    block.getPageNo(),
                    block.getBboxJson(),
                    "TEXT_BLOCK_MATCH"
            ));
        }
        return Optional.empty();
    }

    private boolean matches(ExtractKey key, String... values) {
        Set<String> terms = new LinkedHashSet<>(matchTerms(key, Optional.empty()).stream().map(this::normalizeForMatch).toList());
        for (String value : values) {
            String normalized = normalizeForMatch(value);
            if (terms.contains(normalized)) {
                return true;
            }
        }
        return false;
    }

    private boolean containsRelaxed(String text, String term) {
        String normalizedText = normalizeForMatch(text);
        String normalizedTerm = normalizeForMatch(term);
        if (!StringUtils.hasText(normalizedTerm) || !StringUtils.hasText(normalizedText)) {
            return false;
        }
        if (normalizedText.contains(normalizedTerm)) {
            return true;
        }
        return normalizedTerm.length() >= 8
                && normalizedText.length() >= 5
                && normalizedTerm.startsWith(normalizedText);
    }

    private String extractValueAfterTerm(String text, String term) {
        String trimmed = nullToBlank(text).trim();
        int index = trimmed.indexOf(term);
        if (index >= 0) {
            String after = trimmed.substring(index + term.length()).trim();
            return cleanupValue(after);
        }
        return "";
    }

    private String findAdjacentValue(List<OcrTextBlockEntity> textBlocks,
                                     int index,
                                     List<String> candidates,
                                     OcrBusinessProfileService.ValueStrategy strategy) {
        int maxDistance = strategy == null || strategy.maxDistance() == null ? 2 : Math.max(1, strategy.maxDistance());
        String prefer = strategy == null ? "next" : defaultIfBlank(strategy.prefer(), "next");
        List<Integer> offsets = offsets(prefer, maxDistance);
        for (Integer offset : offsets) {
            int candidateIndex = index + offset;
            if (candidateIndex < 0 || candidateIndex >= textBlocks.size()) {
                continue;
            }
            String value = cleanupValue(textBlocks.get(candidateIndex).getTextContent());
            if (looksLikeValue(value, candidates, strategy)) {
                return value;
            }
        }
        return "";
    }

    private String findSpatialValue(List<OcrTextBlockEntity> textBlocks,
                                    int index,
                                    List<String> candidates,
                                    OcrBusinessProfileService.ValueStrategy strategy) {
        OcrTextBlockEntity labelBlock = textBlocks.get(index);
        Optional<Bbox> labelBox = parseBbox(labelBlock.getBboxJson());
        if (labelBox.isEmpty()) {
            return "";
        }
        double rowTolerance = Math.max(28, labelBox.get().height() * 1.6);
        SpatialCandidate best = null;
        for (int candidateIndex = 0; candidateIndex < textBlocks.size(); candidateIndex++) {
            if (candidateIndex == index) {
                continue;
            }
            OcrTextBlockEntity candidateBlock = textBlocks.get(candidateIndex);
            if (labelBlock.getPageNo() != candidateBlock.getPageNo()) {
                continue;
            }
            Optional<Bbox> candidateBox = parseBbox(candidateBlock.getBboxJson());
            if (candidateBox.isEmpty()) {
                continue;
            }
            Bbox valueBox = candidateBox.get();
            if (Math.abs(valueBox.centerY() - labelBox.get().centerY()) > rowTolerance) {
                continue;
            }
            if (valueBox.left() < labelBox.get().right() - 6) {
                continue;
            }
            String value = cleanupValue(candidateBlock.getTextContent());
            if (!looksLikeValue(value, candidates, strategy)) {
                continue;
            }
            double horizontalGap = Math.max(0, valueBox.left() - labelBox.get().right());
            double verticalGap = Math.abs(valueBox.centerY() - labelBox.get().centerY());
            double readingPenalty = Math.abs(candidateIndex - index) * 0.25;
            double score = horizontalGap + verticalGap * 4 + readingPenalty;
            if (best == null || score < best.score()) {
                best = new SpatialCandidate(value, score);
            }
        }
        return best == null ? "" : best.value();
    }

    private List<Integer> offsets(String prefer, int maxDistance) {
        List<Integer> result = new ArrayList<>();
        if ("previous_or_next".equalsIgnoreCase(prefer)) {
            for (int distance = 1; distance <= maxDistance; distance++) {
                result.add(-distance);
                result.add(distance);
            }
            return result;
        }
        for (int distance = 1; distance <= maxDistance; distance++) {
            result.add(distance);
        }
        for (int distance = 1; distance <= maxDistance; distance++) {
            result.add(-distance);
        }
        return result;
    }

    private boolean looksLikeValue(String value, List<String> candidates, OcrBusinessProfileService.ValueStrategy strategy) {
        String cleaned = cleanupValue(value);
        if (!StringUtils.hasText(cleaned)) {
            return false;
        }
        String normalized = normalizeForMatch(cleaned);
        Set<String> allowedShortValues = new LinkedHashSet<>(
                strategy == null || strategy.allowShortValues() == null
                        ? List.of()
                        : strategy.allowShortValues()
        );
        if (allowedShortValues.stream().map(this::normalizeForMatch).anyMatch(normalized::equals)) {
            return true;
        }
        if (Set.of("是", "否", "无", "有", "男", "女", "1", "2", "3").stream()
                .map(this::normalizeForMatch)
                .anyMatch(normalized::equals)) {
            return true;
        }
        if (normalized.length() <= 1) {
            return false;
        }
        if (candidates.stream().map(this::normalizeForMatch).anyMatch(normalized::contains)) {
            return false;
        }
        return !normalized.matches(".*(地址|名称|姓名|类型|号码|编号|日期|有效期|情况|期限|范围|文本|标志|方式|对象|时间|类别|账号|户名)$");
    }

    private Optional<Bbox> parseBbox(String bboxJson) {
        if (!StringUtils.hasText(bboxJson)) {
            return Optional.empty();
        }
        try {
            JsonNode node = objectMapper.readTree(bboxJson);
            if (node.isArray() && node.size() >= 4) {
                return Optional.of(new Bbox(
                        node.get(0).asDouble(),
                        node.get(1).asDouble(),
                        node.get(2).asDouble(),
                        node.get(3).asDouble()
                ));
            }
            if (node.isObject()) {
                return Optional.of(new Bbox(
                        doubleValue(node, "x1", doubleValue(node, "left", 0)),
                        doubleValue(node, "y1", doubleValue(node, "top", 0)),
                        doubleValue(node, "x2", doubleValue(node, "right", 0)),
                        doubleValue(node, "y2", doubleValue(node, "bottom", 0))
                ));
            }
        } catch (Exception ignored) {
            return Optional.empty();
        }
        return Optional.empty();
    }

    private double doubleValue(JsonNode node, String field, double defaultValue) {
        JsonNode value = node.path(field);
        return value.isNumber() ? value.asDouble() : defaultValue;
    }

    private String cleanupValue(String value) {
        return nullToBlank(value)
                .replaceFirst("^[：:;；,，\\s]+", "")
                .replaceFirst("^[□■☑✓√]+", "")
                .trim();
    }

    private String normalizeForMatch(String value) {
        return nullToBlank(value)
                .toLowerCase(Locale.ROOT)
                .replaceAll("[\\s*★☆·•:：;；,，_\\-（）()\\[\\]【】]+", "");
    }

    private List<String> matchTerms(ExtractKey key, Optional<OcrBusinessProfileService.FieldRule> fieldRule) {
        List<String> terms = new ArrayList<>(key.matchTerms());
        fieldRule.map(OcrBusinessProfileService.FieldRule::matchTerms).ifPresent(terms::addAll);
        return terms.stream().filter(StringUtils::hasText).distinct().toList();
    }

    private List<ExtractKey> parseKeys(String keysJson) {
        if (!StringUtils.hasText(keysJson)) {
            return List.of();
        }
        try {
            JsonNode root = objectMapper.readTree(keysJson);
            if (!root.isArray()) {
                throw new IllegalArgumentException("keys must be a JSON array.");
            }
            List<ExtractKey> keys = new ArrayList<>();
            for (JsonNode item : root) {
                if (item.isTextual()) {
                    String key = item.asText();
                    keys.add(new ExtractKey(key, key, null, null));
                    continue;
                }
                String key = text(item, "key", text(item, "fieldKey", ""));
                String label = text(item, "label", text(item, "fieldName", key));
                String description = text(item, "description", "");
                List<String> aliases = new ArrayList<>();
                JsonNode aliasNode = item.path("aliases");
                if (aliasNode.isArray()) {
                    aliasNode.forEach(alias -> aliases.add(alias.asText()));
                }
                if (StringUtils.hasText(key) || StringUtils.hasText(label)) {
                    keys.add(new ExtractKey(defaultIfBlank(key, label), defaultIfBlank(label, key), description, aliases));
                }
            }
            return keys;
        } catch (Exception ex) {
            throw new IllegalArgumentException("Invalid keys JSON: " + ex.getMessage(), ex);
        }
    }

    private String text(JsonNode node, String field, String defaultValue) {
        JsonNode value = node.path(field);
        return value.isMissingNode() || value.isNull() ? defaultValue : value.asText(defaultValue);
    }

    private String defaultIfBlank(String value, String defaultValue) {
        return StringUtils.hasText(value) ? value.trim() : defaultValue;
    }

    private String nullToBlank(String value) {
        return value == null ? "" : value;
    }

    public record ExtractKey(
            String key,
            String label,
            String description,
            List<String> aliases
    ) {
        List<String> matchTerms() {
            List<String> terms = new ArrayList<>();
            if (StringUtils.hasText(key)) {
                terms.add(key);
            }
            if (StringUtils.hasText(label)) {
                terms.add(label);
            }
            if (aliases != null) {
                aliases.stream().filter(StringUtils::hasText).forEach(terms::add);
            }
            return terms.stream().distinct().toList();
        }

        String displayName() {
            return StringUtils.hasText(label) ? label : key;
        }
    }

    public record ExtractResponse(
            String taskId,
            List<OcrTaskSummary> tasks,
            String status,
            String businessProfileId,
            String businessProfileName,
            List<ExtractKey> keys,
            List<ExtractResultItem> results,
            List<ExtractWarning> warnings
    ) {
    }

    public record ExtractResultItem(
            String taskId,
            String key,
            String label,
            String value,
            BigDecimal confidence,
            String sourceText,
            Integer pageNo,
            String bboxJson,
            String source
    ) {
    }

    public record ExtractWarning(
            String key,
            String code,
            String message
    ) {
    }

    private record ExtractedTaskResult(
            List<ExtractResultItem> results,
            List<ExtractWarning> warnings
    ) {
    }

    private record Bbox(double left, double top, double right, double bottom) {
        double height() {
            return Math.max(1, bottom - top);
        }

        double centerY() {
            return (top + bottom) / 2;
        }
    }

    private record SpatialCandidate(String value, double score) {
    }
}
