import cv2
import numpy as np
from sklearn.cluster import KMeans
import matplotlib.pyplot as plt
from skimage import segmentation, measure
import os


class BuildingChangeDetector:
    """
    Illegal Building Detection System.

    Traditional computer vision approach for detecting building changes
    between two drone images. The pipeline implemented by `detect_changes`
    is: registration -> preprocessing -> vegetation removal -> small-object
    removal -> building extraction -> differencing -> area filtering.

    All images are handled as OpenCV arrays in BGR channel order, as
    returned by `cv2.imread`.
    """

    def __init__(self):
        """
        Create the detector and the structuring element shared by the
        morphological clean-up steps.
        """
        self.kernel_size = 5
        # Build the kernel from kernel_size so the two attributes cannot drift apart
        self.morph_kernel = cv2.getStructuringElement(
            cv2.MORPH_RECT, (self.kernel_size, self.kernel_size)
        )

    def preprocess_image(self, image):
        """
        Image preprocessing: noise reduction and contrast enhancement.

        Args:
            image: BGR color image (3-D array) or single-channel image (2-D array).

        Returns:
            The blurred, histogram-equalized image with the same layout as the input.
        """
        # Gaussian filtering for noise reduction
        denoised = cv2.GaussianBlur(image, (5, 5), 0)

        # Histogram equalization for contrast enhancement
        if image.ndim == 3:
            # Equalize only the lightness channel in LAB space so colors are preserved
            lab = cv2.cvtColor(denoised, cv2.COLOR_BGR2LAB)
            lab[:, :, 0] = cv2.equalizeHist(lab[:, :, 0])
            enhanced = cv2.cvtColor(lab, cv2.COLOR_LAB2BGR)
        else:
            enhanced = cv2.equalizeHist(denoised)

        return enhanced

    def remove_vegetation(self, image):
        """
        Remove vegetation based on HSV color space green mask.

        Args:
            image: BGR color image.

        Returns:
            Tuple `(result, green_mask)`: a copy of the image with the masked
            pixels painted neutral gray, and the binary (0/255) green mask.
        """
        hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)

        # Define HSV range for green colors (vegetation)
        lower_green = np.array([35, 25, 25])
        upper_green = np.array([85, 255, 255])

        # Create green mask
        green_mask = cv2.inRange(hsv, lower_green, upper_green)

        # Morphological operations to remove noise: close small holes, then drop specks
        green_mask = cv2.morphologyEx(green_mask, cv2.MORPH_CLOSE, self.morph_kernel)
        green_mask = cv2.morphologyEx(green_mask, cv2.MORPH_OPEN, self.morph_kernel)

        # Set vegetation areas to neutral gray
        result = image.copy()
        result[green_mask > 0] = [128, 128, 128]

        return result, green_mask

    def remove_vehicles_and_people(self, image, min_area=100, max_area=5000):
        """
        Remove vehicles and people based on size and shape features.

        Contours found on the Canny edge map whose area lies strictly between
        `min_area` and `max_area` and whose bounding-box aspect ratio lies
        strictly between 0.5 and 3.0 are filled with neutral gray.

        Args:
            image: BGR color image.
            min_area: Lower (exclusive) contour-area bound in pixels.
            max_area: Upper (exclusive) contour-area bound in pixels. Both
                bounds need to be adjusted based on actual image size.

        Returns:
            Tuple `(result, mask)`: the cleaned image copy and the binary
            (0/255) mask of the removed regions.
        """
        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

        # Edge detection
        edges = cv2.Canny(gray, 50, 150)

        # Find contours
        contours, _ = cv2.findContours(
            edges, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
        )

        # Create mask to remove small objects (vehicles and people)
        mask = np.zeros(gray.shape, dtype=np.uint8)

        for contour in contours:
            # Filter small objects based on area (vehicles and people are usually smaller)
            if not min_area < cv2.contourArea(contour) < max_area:
                continue

            _, _, w, h = cv2.boundingRect(contour)
            aspect_ratio = w / float(h)

            # Vehicles usually have specific aspect ratios
            if 0.5 < aspect_ratio < 3.0:
                cv2.drawContours(mask, [contour], -1, 255, -1)

        # Set detected vehicle/people areas to background color
        result = image.copy()
        result[mask > 0] = [128, 128, 128]

        return result, mask

    def extract_buildings(self, image):
        """
        Building extraction based on texture and structural features.

        Combines the per-pixel maximum of four oriented Gabor responses with a
        Canny edge map, thresholds the blend with Otsu's method and closes the
        result with a 15x15 rectangle.

        Args:
            image: BGR color image.

        Returns:
            Binary (0/255) single-channel building mask.
        """
        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

        # Use Gabor filters at 0/45/90/135 degrees to detect building textures
        gabor_responses = []
        for theta in range(0, 180, 45):
            kernel = cv2.getGaborKernel(
                (21, 21),
                5,
                np.radians(theta),
                2 * np.pi * 0.5,
                0.5,
                0,
                ktype=cv2.CV_32F,
            )
            # ddepth is a depth, not a channel-qualified type: use CV_8U rather
            # than CV_8UC3 (the latter only worked because its depth bits match)
            gabor_responses.append(cv2.filter2D(gray, cv2.CV_8U, kernel))

        # Keep the strongest orientation response at every pixel
        gabor_result = np.maximum.reduce(gabor_responses)

        # Edge detection
        edges = cv2.Canny(gray, 30, 100)

        # Combine Gabor and edge information
        building_features = cv2.addWeighted(gabor_result, 0.6, edges, 0.4, 0)

        # Thresholding
        _, building_mask = cv2.threshold(
            building_features, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU
        )

        # Morphological operations to fill building interiors
        building_mask = cv2.morphologyEx(
            building_mask,
            cv2.MORPH_CLOSE,
            cv2.getStructuringElement(cv2.MORPH_RECT, (15, 15)),
        )

        return building_mask

    @staticmethod
    def _resize_to_match(image, reference):
        """
        Return `image` with the same height/width as `reference`.

        The input is returned untouched when the sizes already agree.
        """
        height, width = reference.shape[:2]
        if image.shape[:2] == (height, width):
            return image
        return cv2.resize(image, (width, height), interpolation=cv2.INTER_LINEAR)

    def register_images(self, img1, img2):
        """
        Image registration using feature point matching for image alignment.

        ORB features are matched with a cross-checked brute-force Hamming
        matcher and a RANSAC homography maps `img2` onto `img1`.

        Args:
            img1: Reference BGR image.
            img2: BGR image to align to `img1`.

        Returns:
            `img2` warped into the frame of `img1`. If registration is not
            possible a warning is printed and the unaligned `img2` is returned,
            resized to the size of `img1` when the sizes differ so that later
            pixel-wise comparisons do not fail on a shape mismatch.
        """
        # Convert to grayscale
        gray1 = cv2.cvtColor(img1, cv2.COLOR_BGR2GRAY)
        gray2 = cv2.cvtColor(img2, cv2.COLOR_BGR2GRAY)

        # Use ORB feature detector
        orb = cv2.ORB_create(nfeatures=1000)

        # Detect keypoints and descriptors
        kp1, des1 = orb.detectAndCompute(gray1, None)
        kp2, des2 = orb.detectAndCompute(gray2, None)

        if des1 is None or des2 is None:
            print(
                "Warning: Unable to extract sufficient feature points, returning original image"
            )
            return self._resize_to_match(img2, img1)

        # Feature matching
        bf = cv2.BFMatcher(cv2.NORM_HAMMING, crossCheck=True)
        matches = sorted(bf.match(des1, des2), key=lambda m: m.distance)

        if len(matches) < 10:
            print("Warning: Too few matching points, returning original image")
            return self._resize_to_match(img2, img1)

        # Extract matching points (query = img1, train = img2)
        src_pts = np.float32([kp1[m.queryIdx].pt for m in matches]).reshape(-1, 1, 2)
        dst_pts = np.float32([kp2[m.trainIdx].pt for m in matches]).reshape(-1, 1, 2)

        # Calculate homography matrix mapping img2 coordinates onto img1
        try:
            M, _ = cv2.findHomography(
                dst_pts, src_pts, cv2.RANSAC, 5.0, maxIters=5000
            )

            if M is None:
                print(
                    "Warning: Unable to calculate homography matrix, returning original image"
                )
                return self._resize_to_match(img2, img1)

            # Apply transformation
            height, width = img1.shape[:2]
            return cv2.warpPerspective(img2, M, (width, height))

        except cv2.error as e:
            print(f"Error during registration: {e}")
            return self._resize_to_match(img2, img1)

    def detect_changes(self, img1, img2):
        """
        Change detection between two images.

        Args:
            img1: Reference (earlier) BGR image.
            img2: Comparison (later) BGR image.

        Returns:
            Tuple `(change_mask, aligned_img2)`: the binary (0/255) mask of
            large changed regions and `img2` after registration to `img1`.
        """
        # Image registration
        print("Performing image registration...")
        aligned_img2 = self.register_images(img1, img2)
        # NOTE(review): pixels outside the warped footprint of img2 are black
        # after warpPerspective and may show up as changes near the borders.

        # Preprocessing
        print("Preprocessing images...")
        processed_img1 = self.preprocess_image(img1)
        processed_img2 = self.preprocess_image(aligned_img2)

        # Remove vegetation
        print("Removing vegetation...")
        img1_no_veg, _ = self.remove_vegetation(processed_img1)
        img2_no_veg, _ = self.remove_vegetation(processed_img2)

        # Remove vehicles and people
        print("Removing vehicles and people...")
        img1_clean, _ = self.remove_vehicles_and_people(img1_no_veg)
        img2_clean, _ = self.remove_vehicles_and_people(img2_no_veg)

        # Extract buildings
        print("Extracting buildings...")
        building_mask1 = self.extract_buildings(img1_clean)
        building_mask2 = self.extract_buildings(img2_clean)

        # Compute differences
        print("Detecting changes...")

        # Method 1: Building mask-based differences
        building_diff = cv2.absdiff(building_mask1, building_mask2)

        # Method 2: Grayscale-based differences (within building areas)
        gray1 = cv2.cvtColor(img1_clean, cv2.COLOR_BGR2GRAY)
        gray2 = cv2.cvtColor(img2_clean, cv2.COLOR_BGR2GRAY)

        # Calculate differences only in building areas
        combined_building_mask = cv2.bitwise_or(building_mask1, building_mask2)
        gray_diff = cv2.absdiff(gray1, gray2)
        gray_diff = cv2.bitwise_and(gray_diff, combined_building_mask)

        # Combine two types of differences
        combined_diff = cv2.addWeighted(building_diff, 0.6, gray_diff, 0.4, 0)

        # Thresholding to get change areas
        _, change_mask = cv2.threshold(
            combined_diff, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU
        )

        # Morphological operations to remove noise
        change_mask = cv2.morphologyEx(
            change_mask,
            cv2.MORPH_OPEN,
            cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3)),
        )
        change_mask = cv2.morphologyEx(
            change_mask,
            cv2.MORPH_CLOSE,
            cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (7, 7)),
        )

        # Filter small change areas - only keep significant large area changes
        change_mask = self.filter_small_changes(change_mask, min_area=2000)

        return change_mask, aligned_img2

    def filter_small_changes(self, change_mask, min_area=2000):
        """
        Filter out small change areas, keep only large significant building changes.

        Args:
            change_mask: Single-channel mask; non-zero pixels are foreground.
            min_area: Minimum connected-component area (in pixels) to keep.

        Returns:
            A new mask of the same shape and dtype in which the pixels of every
            component with at least `min_area` pixels are 255 and all others 0.
        """
        filtered_mask = np.zeros_like(change_mask)

        # Nothing to label in an empty mask
        if not np.any(change_mask):
            return filtered_mask

        # Label components and get their areas in a single pass instead of
        # rescanning the whole label image once per component
        num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(change_mask)

        # Per-label keep flag; label 0 is the background and is never kept
        keep = stats[:, cv2.CC_STAT_AREA] >= min_area
        keep[0] = False

        # Look up each pixel's label in the keep table
        filtered_mask[keep[labels]] = 255

        return filtered_mask

    def create_result_image(self, img2, change_mask):
        """
        Create result image: annotate change areas on the second image.

        Args:
            img2: BGR image to annotate (not modified; a copy is drawn on).
            change_mask: Binary single-channel mask of changed regions.

        Returns:
            Copy of `img2` with each external contour of the mask outlined in
            red and tagged "Change Area N".
        """
        # Create result image based on the second image
        result_image = img2.copy()

        # Find contours of change areas
        contours, _ = cv2.findContours(
            change_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
        )

        # Annotate change areas on result image
        for index, contour in enumerate(contours, start=1):
            # Draw red contour
            cv2.drawContours(result_image, [contour], -1, (0, 0, 255), 4)

            # Calculate bounding rectangle of contour
            x, y, _, _ = cv2.boundingRect(contour)

            # Add label to change area
            label = f"Change Area {index}"
            label_w, label_h = cv2.getTextSize(
                label, cv2.FONT_HERSHEY_SIMPLEX, 1.0, 3
            )[0]

            # Place the label box directly above the contour, but clamp it to
            # the top edge so labels of regions touching the border stay visible
            box_height = label_h + 15
            box_top = max(y - box_height, 0)
            box_bottom = box_top + box_height

            # Draw label background
            cv2.rectangle(
                result_image,
                (x, box_top),
                (x + label_w + 15, box_bottom),
                (0, 0, 255),
                -1,
            )

            # Draw label text
            cv2.putText(
                result_image,
                label,
                (x + 7, box_bottom - 8),
                cv2.FONT_HERSHEY_SIMPLEX,
                1.0,
                (255, 255, 255),
                3,
            )

        return result_image

    def visualize_results(self, img2, change_mask, save_path=None):
        """
        Simplified result visualization: only show the final result image.

        Args:
            img2: BGR image to annotate and display.
            change_mask: Binary single-channel mask of changed regions.
            save_path: Optional file path; when given, the figure is saved
                there (the parent directory is created if needed).

        Returns:
            The annotated BGR result image.
        """
        # Create result image
        result_image = self.create_result_image(img2, change_mask)

        # Display single result image
        fig = plt.figure(figsize=(12, 8))
        try:
            plt.imshow(cv2.cvtColor(result_image, cv2.COLOR_BGR2RGB))
            plt.title(
                "Detected Illegal Building Changes", fontsize=16, fontweight="bold"
            )
            plt.axis("off")

            if save_path:
                # savefig does not create missing directories
                save_dir = os.path.dirname(save_path)
                if save_dir:
                    os.makedirs(save_dir, exist_ok=True)
                plt.savefig(save_path, dpi=300, bbox_inches="tight")

            plt.show()
        finally:
            # Release the figure so repeated calls do not accumulate open figures
            plt.close(fig)

        return result_image

    def process_images(self, image1_path, image2_path, output_path=None):
        """
        Process two images to detect building changes.

        Args:
            image1_path: Path of the reference (earlier) image.
            image2_path: Path of the comparison (later) image.
            output_path: Optional path where the visualization figure is saved.

        Returns:
            Tuple `(change_mask, result_image)`.

        Raises:
            ValueError: If either image file is missing or cannot be decoded;
                the message lists the offending path(s).
        """
        # Read images
        print(f"Reading images: {image1_path}, {image2_path}")
        paths = (image1_path, image2_path)

        # Report missing files up front so the error names the offending path
        missing = [str(path) for path in paths if not os.path.isfile(path)]
        if missing:
            raise ValueError("Unable to read image files: " + ", ".join(missing))

        img1 = cv2.imread(image1_path)
        img2 = cv2.imread(image2_path)

        # imread signals decode failures by returning None instead of raising
        unreadable = [
            str(path) for path, img in zip(paths, (img1, img2)) if img is None
        ]
        if unreadable:
            raise ValueError("Unable to read image files: " + ", ".join(unreadable))

        # Detect changes
        change_mask, aligned_img2 = self.detect_changes(img1, img2)

        # Visualize results
        result_image = self.visualize_results(aligned_img2, change_mask, output_path)

        # Statistical information about changes
        total_pixels = change_mask.shape[0] * change_mask.shape[1]
        changed_pixels = np.count_nonzero(change_mask)
        change_percentage = (changed_pixels / total_pixels) * 100

        print("Detection completed!")
        print(f"Changed pixels: {changed_pixels}")
        print(f"Change percentage: {change_percentage:.2f}%")

        # Count number of change areas (label 0 is the background)
        num_labels, _ = cv2.connectedComponents(change_mask)
        print(f"Detected {num_labels-1} potential illegal building areas")

        return change_mask, result_image


# Usage example
def main():
    """
    Usage example: run the detector on a hard-coded pair of image paths.

    The visualization figure is written to `output_path` by the detector;
    the raw change mask and the annotated image are additionally written to
    the current working directory. Any error raised while processing is
    caught and printed.
    """
    # Input/output locations (replace with actual image paths)
    first_image = "./images/image_1.jpg"  # First image path
    second_image = "./images/image_2.jpg"  # Second image path
    output_path = "./images/illegal_building_detection_result.png"  # Result save path

    detector = BuildingChangeDetector()

    try:
        change_mask, result_image = detector.process_images(
            first_image, second_image, output_path
        )

        if not output_path:
            return

        # Save results
        artifacts = (
            ("change_mask.png", change_mask),
            ("final_result_with_annotations.png", result_image),
        )
        for filename, image in artifacts:
            cv2.imwrite(filename, image)

        print(f"Results saved to: {output_path}")
        print("Change mask saved as: change_mask.png")
        print("Final annotated result saved as: final_result_with_annotations.png")

    except Exception as e:
        print(f"Error during processing: {e}")


# Run the usage example only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
