import cv2
import numpy as np
from sklearn.cluster import KMeans
import matplotlib.pyplot as plt
from skimage import segmentation, measure
import os


class BuildingChangeDetector:
    """
    Building-focused change detector for a pair of images of the same scene.

    ``process_images`` drives the pipeline: register the current image onto
    the reference, preprocess both, paint road-like and vegetation-like pixels
    neutral gray (128, 128, 128), extract edge/corner-based building regions,
    and flag regions whose building footprint or colour changed.
    """

    def __init__(self):
        """
        Building-Focused Change Detection System
        Specialized for detecting building area changes and color changes while ignoring roads and trees
        """
        # NOTE(review): neither attribute is read by any method of this class;
        # they are kept in case external code relies on them.
        self.kernel_size = 7
        self.morph_kernel = cv2.getStructuringElement(
            cv2.MORPH_RECT, (self.kernel_size, self.kernel_size)
        )

    def preprocess_image(self, image):
        """
        Image preprocessing: noise reduction and contrast enhancement.

        Args:
            image: BGR image (3-D array) or single-channel image (2-D array).

        Returns:
            The image after a 5x5 Gaussian blur followed by histogram
            equalization -- of the LAB lightness channel for colour input,
            of the image itself for single-channel input.
        """
        # Gaussian filtering for noise reduction
        denoised = cv2.GaussianBlur(image, (5, 5), 0)

        # Histogram equalization for contrast enhancement
        if len(image.shape) == 3:
            # Equalize lightness only, so chroma (a/b) is left untouched
            lab = cv2.cvtColor(denoised, cv2.COLOR_BGR2LAB)
            lab[:, :, 0] = cv2.equalizeHist(lab[:, :, 0])
            enhanced = cv2.cvtColor(lab, cv2.COLOR_LAB2BGR)
        else:
            enhanced = cv2.equalizeHist(denoised)

        return enhanced

    def remove_roads_comprehensive(self, image):
        """
        Comprehensive road removal: gray, gray+black, gray+yellow roads.

        Args:
            image: BGR image.

        Returns:
            (result, final_road_mask): a copy of ``image`` with road pixels
            painted (128, 128, 128), and the uint8 road mask (0 or 255).
        """
        hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)

        # Road color detection in multiple color spaces
        road_mask = np.zeros(image.shape[:2], dtype=np.uint8)

        # 1. Gray roads (various shades). Each channel is bounded independently,
        #    so these are per-channel boxes rather than strict "gray" tests.
        gray_ranges_bgr = [
            ([60, 60, 60], [140, 140, 140]),  # Medium gray
            ([40, 40, 40], [100, 100, 100]),  # Dark gray
            ([100, 100, 100], [180, 180, 180]),  # Light gray
            ([20, 20, 20], [70, 70, 70]),  # Very dark gray/black
        ]

        for lower, upper in gray_ranges_bgr:
            mask = cv2.inRange(image, np.array(lower), np.array(upper))
            road_mask = cv2.bitwise_or(road_mask, mask)

        # 2. Low saturation areas (roads typically have low saturation)
        low_saturation = hsv[:, :, 1] < 40
        low_sat_mask = np.uint8(low_saturation * 255)

        # 3. Yellow line detection for road markings
        yellow_lower = np.array([20, 100, 100])
        yellow_upper = np.array([30, 255, 255])
        yellow_mask = cv2.inRange(hsv, yellow_lower, yellow_upper)

        # Dilate yellow lines to include surrounding road area
        yellow_dilated = cv2.dilate(
            yellow_mask,
            cv2.getStructuringElement(cv2.MORPH_RECT, (20, 20)),
            iterations=2,
        )

        # Combine all road detection methods
        combined_road_mask = cv2.bitwise_or(road_mask, low_sat_mask)
        combined_road_mask = cv2.bitwise_or(combined_road_mask, yellow_dilated)

        # Detect linear structures (roads are linear)
        # Create various directional kernels for road detection
        kernels = [
            cv2.getStructuringElement(cv2.MORPH_RECT, (40, 5)),  # Horizontal roads
            cv2.getStructuringElement(cv2.MORPH_RECT, (5, 40)),  # Vertical roads
            np.array(
                [
                    [1, 0, 0, 0, 0, 0, 1],
                    [0, 1, 0, 0, 0, 1, 0],
                    [0, 0, 1, 0, 1, 0, 0],
                    [0, 0, 0, 1, 0, 0, 0],
                    [0, 0, 1, 0, 1, 0, 0],
                    [0, 1, 0, 0, 0, 1, 0],
                    [1, 0, 0, 0, 0, 0, 1],
                ],
                dtype=np.uint8,
            ),  # Diagonal roads
        ]

        linear_structures = np.zeros_like(combined_road_mask)
        for kernel in kernels:
            linear_result = cv2.morphologyEx(combined_road_mask, cv2.MORPH_OPEN, kernel)
            linear_structures = cv2.bitwise_or(linear_structures, linear_result)

        # NOTE(review): an opening keeps (up to a one-pixel anchor shift for the
        # even-sized kernels) only pixels already set in combined_road_mask, so
        # OR-ing it back barely changes the mask. Restricting the mask to
        # linear_structures instead would be a deliberate behaviour change.
        final_road_mask = cv2.bitwise_or(combined_road_mask, linear_structures)

        # Clean up road mask
        kernel_close = cv2.getStructuringElement(cv2.MORPH_RECT, (15, 15))
        final_road_mask = cv2.morphologyEx(
            final_road_mask, cv2.MORPH_CLOSE, kernel_close
        )

        # Apply road mask to image
        result = image.copy()
        result[final_road_mask > 0] = [128, 128, 128]

        return result, final_road_mask

    def remove_vegetation_advanced(self, image):
        """
        Advanced tree and vegetation removal using colour and texture cues.

        A pixel is treated as vegetation when it is green-ish (several HSV
        ranges, or LAB a-channel < 120) AND locally textured (differs from a
        bilateral-filtered copy by more than 25 gray levels). The result is
        then cleaned with a closing and an opening.

        Args:
            image: BGR image.

        Returns:
            (result, final_vegetation_mask): a copy of ``image`` with
            vegetation painted (128, 128, 128), and the uint8 mask (0 or 255).
        """
        hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
        lab = cv2.cvtColor(image, cv2.COLOR_BGR2LAB)

        # Multiple green ranges for comprehensive vegetation detection
        vegetation_mask = np.zeros(hsv.shape[:2], dtype=np.uint8)

        # Green detection in HSV
        green_ranges = [
            ([30, 40, 20], [90, 255, 255]),  # All green shades
            ([25, 30, 15], [95, 255, 240]),  # Extended green range
            ([35, 20, 10], [85, 255, 200]),  # Dark green (tree shadows)
            ([40, 60, 40], [80, 255, 255]),  # Bright green (sunlit leaves)
        ]

        for lower, upper in green_ranges:
            mask = cv2.inRange(hsv, np.array(lower), np.array(upper))
            vegetation_mask = cv2.bitwise_or(vegetation_mask, mask)

        # 'a' channel in 8-bit LAB is the green-red axis, offset by 128, so
        # values below 120 lean green
        a_channel = lab[:, :, 1]
        green_lab_mask = a_channel < 120
        green_lab_mask = np.uint8(green_lab_mask * 255)

        # Combine HSV and LAB vegetation detection
        vegetation_mask = cv2.bitwise_or(vegetation_mask, green_lab_mask)

        # Texture-based vegetation detection (trees have organic, non-geometric patterns)
        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

        # Pixels that differ strongly from an edge-preserving smoothed copy
        # are treated as finely textured
        bilateral = cv2.bilateralFilter(gray, 15, 40, 40)
        texture_diff = cv2.absdiff(gray, bilateral)
        _, organic_texture = cv2.threshold(texture_diff, 25, 255, cv2.THRESH_BINARY)

        # Require both colour and texture evidence
        final_vegetation_mask = cv2.bitwise_and(vegetation_mask, organic_texture)

        # Morphological operations to clean up vegetation mask
        vegetation_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (20, 20))
        final_vegetation_mask = cv2.morphologyEx(
            final_vegetation_mask, cv2.MORPH_CLOSE, vegetation_kernel
        )
        final_vegetation_mask = cv2.morphologyEx(
            final_vegetation_mask,
            cv2.MORPH_OPEN,
            cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (10, 10)),
        )

        # Apply vegetation mask
        result = image.copy()
        result[final_vegetation_mask > 0] = [128, 128, 128]

        return result, final_vegetation_mask

    def extract_building_regions(self, image):
        """
        Extract building regions from straight lines, corners and edges.

        Hough line segments, Harris corners and directionally-closed Canny
        edges are merged, closed with a 30x30 kernel, and every external
        contour enclosing more than 1000 px^2 is filled.

        Args:
            image: BGR image.

        Returns:
            uint8 mask (0 or 255) of filled building-like regions.
        """
        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

        # 1. Detect straight lines (building edges)
        edges = cv2.Canny(gray, 30, 150, apertureSize=3)
        lines = cv2.HoughLinesP(
            edges, 1, np.pi / 180, threshold=80, minLineLength=50, maxLineGap=20
        )

        line_mask = np.zeros_like(gray)
        if lines is not None:
            for line in lines:
                x1, y1, x2, y2 = line[0]
                cv2.line(line_mask, (x1, y1), (x2, y2), 255, 3)

        # 2. Corner detection for building corners (relative Harris threshold)
        corners = cv2.cornerHarris(gray, 2, 3, 0.04)
        corner_mask = np.zeros_like(gray)
        corner_mask[corners > 0.01 * corners.max()] = 255

        # 3. Rectangular structure detection using morphological operations
        rect_kernel_h = cv2.getStructuringElement(cv2.MORPH_RECT, (25, 5))
        rect_kernel_v = cv2.getStructuringElement(cv2.MORPH_RECT, (5, 25))

        rect_h = cv2.morphologyEx(edges, cv2.MORPH_CLOSE, rect_kernel_h)
        rect_v = cv2.morphologyEx(edges, cv2.MORPH_CLOSE, rect_kernel_v)
        rectangular_features = cv2.bitwise_or(rect_h, rect_v)

        # 4. Combine all building features
        building_features = cv2.bitwise_or(line_mask, corner_mask)
        building_features = cv2.bitwise_or(building_features, rectangular_features)

        # 5. Create building regions by filling enclosed areas
        building_mask = cv2.morphologyEx(
            building_features,
            cv2.MORPH_CLOSE,
            cv2.getStructuringElement(cv2.MORPH_RECT, (30, 30)),
        )

        # Fill holes in building regions
        contours, _ = cv2.findContours(
            building_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
        )
        filled_mask = np.zeros_like(building_mask)

        for contour in contours:
            if cv2.contourArea(contour) > 1000:  # Only keep large building-like areas
                cv2.fillPoly(filled_mask, [contour], 255)

        return filled_mask

    @staticmethod
    def _hsv_change_magnitude(hsv1, hsv2):
        """
        Per-pixel HSV change between two 8-bit OpenCV HSV images.

        Hue is circular (OpenCV stores 8-bit hue in [0, 180)), so the hue
        difference is the shorter way round the circle, rescaled so that
        opposite hues map to 255. Hue is only trusted as much as the less
        saturated of the two pixels, because the hue of a near-gray pixel is
        noise. The result is the larger of that hue term and the absolute
        saturation/value differences.

        Returns:
            uint8 array of shape ``hsv1.shape[:2]``.
        """
        hue_diff = np.abs(
            hsv1[:, :, 0].astype(np.int16) - hsv2[:, :, 0].astype(np.int16)
        )
        hue_diff = np.minimum(hue_diff, 180 - hue_diff)

        saturation = (
            np.minimum(hsv1[:, :, 1], hsv2[:, :, 1]).astype(np.float32) / 255.0
        )
        hue_change = hue_diff.astype(np.float32) * 255.0 / 90.0 * saturation

        sv_change = np.abs(
            hsv1[:, :, 1:].astype(np.int16) - hsv2[:, :, 1:].astype(np.int16)
        ).max(axis=2)

        magnitude = np.maximum(hue_change, sv_change.astype(np.float32))
        return np.clip(magnitude, 0, 255).astype(np.uint8)

    @staticmethod
    def _lab_change_magnitude(lab1, lab2):
        """
        Per-pixel Euclidean distance between two 8-bit LAB images.

        Computed on the stored 8-bit values, so it approximates the CIE76
        colour difference up to OpenCV's 8-bit scaling. Clipped to [0, 255].

        Returns:
            uint8 array of shape ``lab1.shape[:2]``.
        """
        delta = lab1.astype(np.float32) - lab2.astype(np.float32)
        distance = np.sqrt(np.sum(delta * delta, axis=2))
        return np.clip(distance, 0, 255).astype(np.uint8)

    def detect_building_changes(self, img1, img2, building_mask1, building_mask2):
        """
        Detect both area changes and color changes in buildings.

        Args:
            img1, img2: Same-sized BGR images (reference and aligned current).
            building_mask1, building_mask2: Same-sized uint8 building masks
                (0/255), e.g. from ``extract_building_regions``.

        Returns:
            uint8 mask (0/255) of pixels whose combined score
            ``0.7 * area_change + 0.3 * colour_change`` exceeds 50.
        """
        # 1. AREA CHANGES: pixels covered by a building in only one image
        area_diff = cv2.absdiff(building_mask1, building_mask2)

        # 2. COLOR CHANGES: measure change magnitudes directly in each colour
        #    space. (Converting a channel-wise diff back to BGR is not
        #    meaningful: hue wraps around, and LAB a/b are offset by 128.)
        hsv1 = cv2.cvtColor(img1, cv2.COLOR_BGR2HSV)
        hsv2 = cv2.cvtColor(img2, cv2.COLOR_BGR2HSV)
        hsv_change_gray = self._hsv_change_magnitude(hsv1, hsv2)

        lab1 = cv2.cvtColor(img1, cv2.COLOR_BGR2LAB)
        lab2 = cv2.cvtColor(img2, cv2.COLOR_BGR2LAB)
        lab_change_gray = self._lab_change_magnitude(lab1, lab2)

        # BGR direct color change
        bgr_diff = cv2.absdiff(img1, img2)
        bgr_change_gray = cv2.cvtColor(bgr_diff, cv2.COLOR_BGR2GRAY)

        # Combine color changes from different color spaces
        color_change = cv2.addWeighted(hsv_change_gray, 0.4, lab_change_gray, 0.4, 0)
        color_change = cv2.addWeighted(color_change, 0.8, bgr_change_gray, 0.2, 0)

        # Only consider color changes within building areas
        combined_building_mask = cv2.bitwise_or(building_mask1, building_mask2)
        building_color_change = cv2.bitwise_and(color_change, combined_building_mask)

        # 3. COMBINE AREA AND COLOR CHANGES (area changes weigh more)
        final_change = cv2.addWeighted(area_diff, 0.7, building_color_change, 0.3, 0)

        # Threshold for significant changes only
        _, change_mask = cv2.threshold(final_change, 50, 255, cv2.THRESH_BINARY)

        return change_mask

    def register_images(self, img1, img2):
        """
        Align ``img2`` onto ``img1`` with an ORB + RANSAC homography.

        Args:
            img1: Reference BGR image.
            img2: BGR image to align.

        Returns:
            ``img2`` warped into ``img1``'s frame (same size as ``img1``).
            If feature extraction, matching or homography estimation fails,
            prints a warning and returns ``img2`` unchanged, possibly with
            a different size than ``img1``.
        """
        gray1 = cv2.cvtColor(img1, cv2.COLOR_BGR2GRAY)
        gray2 = cv2.cvtColor(img2, cv2.COLOR_BGR2GRAY)

        # Use ORB feature detector with more features
        orb = cv2.ORB_create(nfeatures=2000)

        kp1, des1 = orb.detectAndCompute(gray1, None)
        kp2, des2 = orb.detectAndCompute(gray2, None)

        if des1 is None or des2 is None:
            print(
                "Warning: Unable to extract sufficient feature points, returning original image"
            )
            return img2

        # Cross-checked brute-force matching on binary ORB descriptors
        bf = cv2.BFMatcher(cv2.NORM_HAMMING, crossCheck=True)
        matches = sorted(bf.match(des1, des2), key=lambda x: x.distance)

        if len(matches) < 20:
            print("Warning: Too few matching points, returning original image")
            return img2

        # Use the (up to) 100 best matches for the homography estimate
        good_matches = matches[:100]
        src_pts = np.float32([kp1[m.queryIdx].pt for m in good_matches]).reshape(
            -1, 1, 2
        )
        dst_pts = np.float32([kp2[m.trainIdx].pt for m in good_matches]).reshape(
            -1, 1, 2
        )

        # Homography maps img2 points (dst) onto img1 points (src)
        try:
            M, _ = cv2.findHomography(
                dst_pts, src_pts, cv2.RANSAC, 3.0, maxIters=10000, confidence=0.99
            )

            if M is not None:
                height, width = img1.shape[:2]
                aligned_img2 = cv2.warpPerspective(img2, M, (width, height))
                return aligned_img2
            else:
                print(
                    "Warning: Unable to calculate homography matrix, returning original image"
                )
                return img2

        except Exception as e:
            print(f"Error during registration: {e}")
            return img2

    @staticmethod
    def _is_building_like(area, width, height, min_area):
        """
        Shape test for a change contour with the given area and bounding box.

        Accepts contours that fill more than 30% of their bounding box, have
        a bounding-box aspect ratio below 8, and are strictly larger than
        ``min_area``. Degenerate (zero-width/height) boxes are rejected.
        """
        if width <= 0 or height <= 0:
            return False
        extent = area / (width * height)
        aspect_ratio = max(width, height) / min(width, height)
        return extent > 0.3 and aspect_ratio < 8.0 and area > min_area

    def filter_building_changes_only(self, change_mask, min_area=5000):
        """
        Keep only large, building-shaped connected components of a change mask.

        A component is kept (in full) when its pixel count is at least
        ``min_area`` and one of its external contours passes
        ``_is_building_like``.

        Args:
            change_mask: uint8 binary mask (0/255).
            min_area: Minimum component / contour area in pixels.

        Returns:
            uint8 mask (0/255) containing only the accepted components.
        """
        num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(change_mask)

        filtered_mask = np.zeros_like(change_mask)
        valid_changes = 0

        for label in range(1, num_labels):  # label 0 is the background
            if stats[label, cv2.CC_STAT_AREA] < min_area:
                continue

            # Computed once per surviving component and reused for the output
            component_mask = (labels == label).astype(np.uint8) * 255
            contours, _ = cv2.findContours(
                component_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
            )

            for contour in contours:
                area_contour = cv2.contourArea(contour)
                if area_contour < min_area or cv2.arcLength(contour, True) <= 0:
                    continue

                _, _, w, h = cv2.boundingRect(contour)
                if self._is_building_like(area_contour, w, h, min_area):
                    filtered_mask[component_mask > 0] = 255
                    valid_changes += 1
                    break

        print(f"Found {valid_changes} significant building changes after filtering")
        return filtered_mask

    def process_images(self, image1_path, image2_path, output_path=None):
        """
        Main processing pipeline focused on building changes only.

        Args:
            image1_path: Path of the reference image.
            image2_path: Path of the current image.
            output_path: Optional path where the annotated figure is saved.

        Returns:
            (final_change_mask, result_image): uint8 change mask (0/255) and
            the aligned current image annotated with the detected changes.

        Raises:
            ValueError: If either image cannot be read.
        """
        print(f"Reading images: {image1_path}, {image2_path}")
        img1 = cv2.imread(image1_path)
        img2 = cv2.imread(image2_path)

        if img1 is None or img2 is None:
            raise ValueError("Unable to read image files")

        print("Performing precise image registration...")
        aligned_img2 = self.register_images(img1, img2)

        # register_images() returns img2 untouched when alignment fails; the
        # per-pixel comparisons below require both images to have one size.
        if aligned_img2.shape[:2] != img1.shape[:2]:
            print(
                "Warning: Image sizes differ after registration, resizing current image to match reference"
            )
            height, width = img1.shape[:2]
            aligned_img2 = cv2.resize(aligned_img2, (width, height))

        print("Preprocessing images...")
        processed_img1 = self.preprocess_image(img1)
        processed_img2 = self.preprocess_image(aligned_img2)

        print("Removing roads (gray, gray+black, gray+yellow)...")
        img1_no_road, _ = self.remove_roads_comprehensive(processed_img1)
        img2_no_road, _ = self.remove_roads_comprehensive(processed_img2)

        print("Removing trees and vegetation...")
        img1_clean, _ = self.remove_vegetation_advanced(img1_no_road)
        img2_clean, _ = self.remove_vegetation_advanced(img2_no_road)

        print("Extracting building regions...")
        building_mask1 = self.extract_building_regions(img1_clean)
        building_mask2 = self.extract_building_regions(img2_clean)

        print("Detecting building area and color changes...")
        change_mask = self.detect_building_changes(
            img1_clean, img2_clean, building_mask1, building_mask2
        )

        # Remove speckle (5x5 open), then merge nearby fragments (10x10 close)
        print("Applying building-specific filtering...")
        building_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (10, 10))
        change_mask = cv2.morphologyEx(
            change_mask,
            cv2.MORPH_OPEN,
            cv2.getStructuringElement(cv2.MORPH_RECT, (5, 5)),
        )
        change_mask = cv2.morphologyEx(change_mask, cv2.MORPH_CLOSE, building_kernel)

        print("Filtering for significant building changes only...")
        final_change_mask = self.filter_building_changes_only(
            change_mask, min_area=5000
        )

        result_image = self.create_result_image(aligned_img2, final_change_mask)
        self.visualize_results(result_image, output_path)

        # Statistics
        total_pixels = final_change_mask.shape[0] * final_change_mask.shape[1]
        changed_pixels = np.sum(final_change_mask > 0)
        change_percentage = (changed_pixels / total_pixels) * 100

        print(f"\n=== BUILDING CHANGE DETECTION RESULTS ===")
        print(f"Total building change pixels: {changed_pixels}")
        print(f"Building change percentage: {change_percentage:.3f}%")

        # Label 0 is the background, hence the -1
        num_labels, _ = cv2.connectedComponents(final_change_mask)
        print(f"Number of significant building changes detected: {num_labels-1}")

        return final_change_mask, result_image

    @staticmethod
    def _label_bottom(y, label_height, padding=20):
        """
        Bottom y of a label box drawn above a region whose top edge is ``y``.

        Normally the label sits directly above the region (bottom == y). When
        there is no room above, it is pushed down so its top is at row 0.
        """
        return max(y, label_height + padding)

    def create_result_image(self, img2, change_mask):
        """
        Create result image with building change annotations.

        Every external contour of ``change_mask`` is outlined in red on a copy
        of ``img2`` and labelled "Building Change <n>".

        Args:
            img2: BGR image to annotate (not modified).
            change_mask: uint8 binary mask of the same size.

        Returns:
            The annotated copy of ``img2``.
        """
        result_image = img2.copy()

        contours, _ = cv2.findContours(
            change_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
        )

        font_scale = 1.2
        thickness = 3
        for i, contour in enumerate(contours):
            # Draw thick red outline
            cv2.drawContours(result_image, [contour], -1, (0, 0, 255), 5)

            x, y, _, _ = cv2.boundingRect(contour)

            label = f"Building Change {i+1}"
            label_width, label_height = cv2.getTextSize(
                label, cv2.FONT_HERSHEY_SIMPLEX, font_scale, thickness
            )[0]

            # Keep the label on-canvas when the change touches the top edge
            label_bottom = self._label_bottom(y, label_height)

            # Draw label background
            cv2.rectangle(
                result_image,
                (x, label_bottom - label_height - 20),
                (x + label_width + 20, label_bottom),
                (0, 0, 255),
                -1,
            )

            # Draw label text
            cv2.putText(
                result_image,
                label,
                (x + 10, label_bottom - 10),
                cv2.FONT_HERSHEY_SIMPLEX,
                font_scale,
                (255, 255, 255),
                thickness,
            )

        return result_image

    def visualize_results(self, result_image, save_path=None):
        """
        Display the final result with building changes highlighted.

        Args:
            result_image: BGR image to show.
            save_path: Optional path; when truthy the figure is saved there
                (300 dpi) before being shown.

        Returns:
            ``result_image`` unchanged.
        """
        figure = plt.figure(figsize=(15, 10))
        plt.imshow(cv2.cvtColor(result_image, cv2.COLOR_BGR2RGB))
        plt.title(
            "Building Area and Color Changes Detection", fontsize=18, fontweight="bold"
        )
        plt.axis("off")

        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches="tight")

        plt.show()
        # Release the figure: on non-interactive backends show() returns
        # immediately and pyplot would otherwise keep every figure alive.
        plt.close(figure)

        return result_image


# Usage example
def main():
    """
    Run building-focused change detection on the example image pair.

    Reads ./images/image_1.jpg (reference) and ./images/image_2.jpg
    (current), saves the annotated figure to
    ./images/building_changes_result.png, and writes the raw change mask and
    annotated image to the current working directory.

    Returns:
        0 on success, 1 if processing failed or a result could not be written.
    """
    detector = BuildingChangeDetector()

    # Process images (replace with actual image paths)
    image1_path = "./images/image_1.jpg"  # Reference image
    image2_path = "./images/image_2.jpg"  # Current image
    output_path = "./images/building_changes_result.png"  # Result save path

    try:
        change_mask, result_image = detector.process_images(
            image1_path, image2_path, output_path
        )

        if output_path:
            # cv2.imwrite reports failure via its return value, not an exception
            mask_saved = cv2.imwrite("building_change_mask.png", change_mask)
            annotated_saved = cv2.imwrite(
                "building_changes_annotated.png", result_image
            )
            if not (mask_saved and annotated_saved):
                print("Error during processing: unable to write result images")
                return 1

            print(f"\nResults saved:")
            print(f"- Main result: {output_path}")
            print(f"- Change mask: building_change_mask.png")
            print(f"- Annotated result: building_changes_annotated.png")

    except Exception as e:
        # Top-level boundary: report the failure and signal it via exit status
        print(f"Error during processing: {e}")
        return 1

    return 0


if __name__ == "__main__":
    raise SystemExit(main())
