import random
import pandas as pd
from geopy.distance import geodesic
import h3
import psutil
import mysql.connector  # Import mysql.connector for MySQL connection
from mysql.connector import Error


def process_dataframe(
    df: pd.DataFrame,
    lat: float,
    long: float,
    *,
    target_count: int = 200,
    max_search_k: int = 10,
) -> pd.DataFrame:
    """Collect the listings surrounding a coordinate using H3 hexagons.

    The row closest to ``(lat, long)`` (geodesic distance) selects the main
    H3 hexagon. Every row of that hexagon is returned; if that is fewer than
    ``target_count`` rows, rows from the surrounding hexagon rings are added,
    ring by ring, until ``target_count`` rows are gathered or ``max_search_k``
    rings have been searched.

    Args:
        df: Listings with at least the columns ``"lat"``, ``"long"``,
            ``"h3_index"`` and ``"color_rank"``. A ``"distance"`` column
            (meters from ``(lat, long)``) is written onto this frame in place.
        lat: Latitude of the query point.
        long: Longitude of the query point.
        target_count: Row count to aim for when the main hexagon alone holds
            fewer rows. A main hexagon with more rows is returned untruncated.
        max_search_k: Largest ring distance searched around the main hexagon.

    Returns:
        The main hexagon's rows followed by the selected neighbor rows. Neighbor
        rows carry a ``"color_rank_diff"`` column (absolute difference to the
        nearest row's ``color_rank``); that column is NaN for main-hexagon rows
        and absent when no neighbor rows were added.

    Raises:
        ValueError: If ``df`` has no rows.
    """
    if df.empty:
        raise ValueError("process_dataframe requires a non-empty DataFrame")

    def find_nearest_entry(dataframe, lat, long):
        """Return ``(nearest_row, distance_in_meters)`` for the query point."""
        origin = (lat, long)
        # NOTE: this writes the "distance" column onto the caller's frame;
        # the returned rows expose it as well.
        dataframe["distance"] = [
            geodesic(origin, (row_lat, row_long)).meters
            for row_lat, row_long in zip(dataframe["lat"], dataframe["long"])
        ]
        min_distance = dataframe["distance"].min()
        nearest_entries = dataframe[dataframe["distance"] == min_distance]
        # Several rows can share the minimum distance: prefer the highest color_rank.
        nearest_entry = nearest_entries.sort_values("color_rank", ascending=False).iloc[0]
        return nearest_entry, min_distance

    def get_nearby_entries(dataframe, h3_index, target_color_rank, current_count, max_needed):
        """Gather up to ``max_needed - current_count`` rows from the rings around ``h3_index``."""
        remaining = max_needed - current_count
        ring_frames = []
        # h3.k_ring returns the filled disk (origin and all inner rings included),
        # so remember the hexagons already visited; otherwise rows of the main
        # hexagon and of earlier rings would be selected again as duplicates.
        seen_hexagons = {h3_index}
        k = 1  # start with direct neighbors

        while remaining > 0 and k <= max_search_k:
            disk = set(h3.k_ring(h3_index, k))
            new_hexagons = disk - seen_hexagons
            seen_hexagons |= disk

            nearby_df = dataframe[dataframe["h3_index"].isin(list(new_hexagons))].copy()
            if not nearby_df.empty:
                nearby_df["color_rank_diff"] = (nearby_df["color_rank"] - target_color_rank).abs()
                # NOTE(review): color_rank_diff is fully determined by color_rank, so the
                # second sort key never reorders rows; confirm whether closeness to
                # target_color_rank was meant to be the primary key.
                nearby_df = nearby_df.sort_values(
                    ["color_rank", "color_rank_diff"], ascending=[False, True]
                ).head(remaining)
                ring_frames.append(nearby_df)
                remaining -= len(nearby_df)

            k += 1  # expand search range

        if not ring_frames:
            return pd.DataFrame()
        return pd.concat(ring_frames)

    nearest_entry, distance = find_nearest_entry(df, lat, long)
    h3_index = nearest_entry["h3_index"]
    color_rank = nearest_entry["color_rank"]

    # All entries within the main hexagon.
    hexagon_entries = df[df["h3_index"] == h3_index].copy()

    # One pass over the rings is enough: repeating the search would only
    # return the same rows again and pad the result with duplicates.
    if len(hexagon_entries) < target_count:
        additional_entries = get_nearby_entries(
            df, h3_index, color_rank, len(hexagon_entries), target_count
        )
        if not additional_entries.empty:
            hexagon_entries = pd.concat([hexagon_entries, additional_entries]).head(target_count)

    # Print some information about the nearest entry and hexagon
    hex_centroid = h3.h3_to_geo(h3_index)
    print(f"Distance to nearest entry: {distance:.2f} meters")
    print(f"Hexagon centroid: {hex_centroid}")
    print(f"Final dataset size: {len(hexagon_entries)}")

    final_df = hexagon_entries
    print(final_df.shape)

    return final_df



# try:
#     connection = mysql.connector.connect(
#         host='127.0.0.1',
#         user='root',
#         password='',  # Using XAMPP default: no password
#         database='pricing'
#     )
#     if connection.is_connected():
#         print("Connected to the MySQL database for rental_properties data")
#         # Read data from the 'rental_properties' table
#         general = pd.read_sql_query("SELECT * FROM rental_properties", connection)
#
#     else:
#         raise Exception("Failed to connect to the MySQL database")
# except Error as e:
#     print(f"Error connecting to MySQL: {e}")
#     raise e
# finally:
#     if connection.is_connected():
#         connection.close()
#         print("MySQL connection closed")
#
#
# final_df = process_dataframe(general, lat=42.1713, long=-73.9698)
# print(final_df.dtypes)
# print(general.dtypes)