import argparse
import asyncio
import logging
import os
import pathlib
from typing import Optional

import cv2
import numpy as np
from cv2.typing import MatLike

# Configure the root logger once at import time: INFO and above, one
# "date time - LEVEL - message" line per record.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=logging.INFO,
)


async def read_image_to_mem(image_path) -> Optional[MatLike]:
    """Load an image file into memory with ``cv2.imread`` (default flags).

    Args:
        image_path: Path to the image file, as ``str`` or ``os.PathLike``.

    Returns:
        The decoded image, or ``None`` if the file could not be read or
        decoded. Every failure is logged.
    """
    try:
        # cv2.imread reports most failures (missing file, unsupported or
        # corrupt data) by returning None instead of raising, so the result
        # must be checked explicitly below.
        image = cv2.imread(os.fspath(image_path))
    except Exception as error:
        logging.exception(f"Error reading image {image_path}: {error}")
        return None
    if image is None:
        logging.error(f"Could not read image: {image_path}")
    return image


async def convert_image_to_mask(
    image: MatLike, *, threshold: int = 1, kernel_size: int = 5
) -> Optional[MatLike]:
    """Build a binary (0/255) foreground mask from an image.

    Pixels whose grayscale intensity is strictly above ``threshold`` become
    255 and all others 0. A morphological opening with a square kernel then
    removes foreground specks too small to contain that kernel.

    Args:
        image: BGR colour image, or an image that is already single-channel.
        threshold: Intensity strictly above which a pixel counts as foreground.
        kernel_size: Side length, in pixels, of the square opening kernel.

    Returns:
        The mask, or ``None`` if any OpenCV step failed. The error is logged.
    """
    try:
        # A 2-D array is already grayscale; running BGR2GRAY on it would raise.
        if image.ndim == 2:
            gray = image
        else:
            gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        _, mask = cv2.threshold(gray, threshold, 255, cv2.THRESH_BINARY)
        kernel = np.ones((kernel_size, kernel_size), np.uint8)
        return cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)
    except Exception as error:
        logging.exception(f"Error: {error}")
    return None


async def save_image_to_persist_storage(mask_image: MatLike, path: str) -> None:
    """Write an image to ``path`` with ``cv2.imwrite``.

    The encoder is chosen from the file extension of ``path``. Saving is
    best-effort: failures are logged, never raised.

    Args:
        mask_image: Image data to write.
        path: Destination file path.
    """
    logging.info(f"Saving image - : {path}")
    try:
        written = cv2.imwrite(path, mask_image)
    except Exception as error:
        logging.exception(f"Error: {error}")
        return
    # cv2.imwrite can signal failure through its boolean return value
    # (e.g. for a destination directory that does not exist) without raising.
    if not written:
        logging.error(f"Failed to write image: {path}")


async def do_shape_comparison(mask_image_1: MatLike, mask_image_2: MatLike) -> MatLike:
    """Return the pixel-wise bitwise XOR of two masks.

    For 0/255 binary masks, the result is 255 wherever exactly one of the
    masks is set (the silhouettes differ) and 0 elsewhere.
    """
    return cv2.bitwise_xor(mask_image_1, mask_image_2)


async def do_color_comparsion(
    image_1: MatLike, image_2: MatLike, *, threshold: int = 25
) -> MatLike:
    """Return a binary mask of pixels whose colour differs between two images.

    Both BGR images are converted to L*a*b*, and only the a* and b*
    (chromaticity) channels are compared. Lightness (L*) is ignored, so
    differences in brightness alone largely do not register. The absolute
    a* and b* differences are averaged per pixel and then thresholded.

    Args:
        image_1: First BGR image.
        image_2: Second BGR image; must have the same size as ``image_1``.
        threshold: Averaged a*/b* difference strictly above which a pixel is
            marked as changed.

    Returns:
        A single-channel mask: 255 where the colour changed, 0 elsewhere.
    """
    # The L* channel is intentionally discarded. Nothing in the comparison
    # uses lightness, so equalising it (e.g. with CLAHE) would have no effect.
    _, a1, b1 = cv2.split(cv2.cvtColor(image_1, cv2.COLOR_BGR2Lab))
    _, a2, b2 = cv2.split(cv2.cvtColor(image_2, cv2.COLOR_BGR2Lab))

    diff_a = cv2.absdiff(a1, a2)
    diff_b = cv2.absdiff(b1, b2)
    color_change = cv2.addWeighted(diff_a, 0.5, diff_b, 0.5, 0)

    _, color_change = cv2.threshold(
        color_change, threshold, 255, cv2.THRESH_BINARY
    )
    return color_change


async def draw_changes_area_on_image(
    cordination: MatLike, image: MatLike, area_size=200
) -> MatLike:
    """Outline every sufficiently large changed region on a copy of ``image``.

    Args:
        cordination: Single-channel 8-bit binary mask of changed pixels. Its
            external contours define the regions to mark.
        image: Image to annotate; it is copied, not modified.
        area_size: Contours whose area (in pixels) is below this value are
            skipped.

    Returns:
        A copy of ``image`` with a red bounding box and a white-on-red
        "building change" label for every retained region.
    """
    label = "building change"
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 1.2
    thickness = 2
    red = (0, 0, 255)
    white = (255, 255, 255)  # white reads better than green on red

    # 1. Find the areas of change.
    contours, _ = cv2.findContours(
        cordination, mode=cv2.RETR_EXTERNAL, method=cv2.CHAIN_APPROX_SIMPLE
    )

    # The label never changes, so measure it once instead of per contour.
    (text_w, text_h), _ = cv2.getTextSize(label, font, font_scale, thickness)
    # Height of the filled label background; the text baseline sits 10 px
    # above the bottom edge of that background.
    label_height = text_h + 15

    canvas: MatLike = image.copy()

    for contour in contours:
        if cv2.contourArea(contour) < area_size:
            continue

        # 2. Bounding box of the changed region.
        x, y, w, h = cv2.boundingRect(contour)
        cv2.rectangle(canvas, (x, y), (x + w, y + h), red, 2)

        # 3. The label normally sits just above the box. If there is not
        # enough room above (region at the top edge), it would be drawn
        # off-canvas, so shift it down to sit just inside the box instead.
        shift = 0 if y >= label_height else label_height

        cv2.rectangle(
            canvas,
            (x, y - label_height + shift),
            (x + text_w + 5, y + shift),
            red,
            -1,
        )
        cv2.putText(
            canvas,
            text=label,
            org=(x + 2, y - 10 + shift),
            fontFace=font,
            fontScale=font_scale,
            color=white,
            thickness=thickness,
        )

    return canvas


async def do_combine_diff_mask(shape_diff: MatLike, color_diff: MatLike) -> MatLike:
    """Merge the shape and colour difference masks and drop small speckles.

    The masks are OR-ed together, then a morphological opening with a 5x5
    square kernel removes foreground regions too thin to contain the kernel.
    """
    opening_kernel = np.ones((5, 5), dtype=np.uint8)
    union = cv2.bitwise_or(shape_diff, color_diff)
    return cv2.morphologyEx(union, cv2.MORPH_OPEN, opening_kernel)


async def do_comparsion(image_1: str, image_2: str) -> Optional[MatLike]:
    """Compare two image files and save an annotated copy of the second one.

    Files written next to the inputs:
      * ``<stem>_mask<suffix>`` for each input: its binary foreground mask.
      * ``<image_2 stem>_final<suffix>``: ``image_2`` with changed regions
        outlined.

    Args:
        image_1: Path of the first image.
        image_2: Path of the second image; the annotations are drawn on a
            copy of it.

    Returns:
        The annotated copy of ``image_2``, or ``None`` if reading, masking,
        or comparing failed.
    """
    try:
        image_1_vec = await read_image_to_mem(image_1)
        image_2_vec = await read_image_to_mem(image_2)
        if image_1_vec is None or image_2_vec is None:
            return None

        image_1_mask = await convert_image_to_mask(image_1_vec)
        image_2_mask = await convert_image_to_mask(image_2_vec)
        if image_1_mask is None or image_2_mask is None:
            return None

        image_1_path = pathlib.Path(image_1)
        image_2_path = pathlib.Path(image_2)
        await save_image_to_persist_storage(
            image_1_mask,
            str(image_1_path.with_stem(f"{image_1_path.stem}_mask")),
        )
        await save_image_to_persist_storage(
            image_2_mask,
            str(image_2_path.with_stem(f"{image_2_path.stem}_mask")),
        )

        # The pixel-wise comparisons below need identically shaped inputs;
        # report a mismatch clearly instead of through an opaque cv2 error.
        if image_1_vec.shape != image_2_vec.shape:
            logging.error(
                f"Image shapes differ: {image_1} {image_1_vec.shape} vs "
                f"{image_2} {image_2_vec.shape}"
            )
            return None

        color_diff = await do_color_comparsion(image_1_vec, image_2_vec)
        shape_diff = await do_shape_comparison(image_1_mask, image_2_mask)
        combined_mask = await do_combine_diff_mask(shape_diff, color_diff)

        final_canvas = await draw_changes_area_on_image(combined_mask, image_2_vec)
        await save_image_to_persist_storage(
            final_canvas,
            str(image_2_path.with_stem(f"{image_2_path.stem}_final")),
        )
        return final_canvas

    except Exception as error:
        logging.exception(f"Error: {error}")
    return None


async def main(args: argparse.Namespace):
    """Validate the two input paths, then run the comparison on them.

    Args:
        args: Parsed CLI arguments; ``args.images`` holds two image paths.

    Raises:
        SystemExit: With status 1 if either path is not an existing file.
    """
    image_1, image_2 = args.images
    missing = [path for path in (image_1, image_2) if not pathlib.Path(path).is_file()]
    if missing:
        logging.error(f"Images not found: {', '.join(missing)}, process quit")
        # Raise SystemExit directly: the `exit` builtin is injected by the
        # `site` module and is absent under e.g. `python -S`.
        raise SystemExit(1)

    logging.info(f"Comparing images: {image_1} and {image_2}")
    final_result = await do_comparsion(image_1, image_2)
    if final_result is not None:
        # Log a summary; interpolating the image would dump its pixel array.
        logging.info(
            f"Comparison finished, annotated image shape: {final_result.shape}"
        )


if __name__ == "__main__":
    # CLI entry point: parse exactly two image paths and run the async pipeline.
    parser = argparse.ArgumentParser(
        description="Compare two images and highlight the changed regions"
    )
    parser.add_argument(
        "-i",
        "--images",
        nargs=2,
        required=True,
        metavar=("image1", "image2"),
        help="The image paths: image1 and image2",
    )
    args = parser.parse_args()
    asyncio.run(main(args))

# Usage:
# python3 main.py -i images/2/T1.jpg images/2/T2.jpg
