import os
import time
from collections import Counter
from collections import deque
from datetime import datetime

import cv2
from ultralytics import YOLO
import paho.mqtt.client as mqtt
# NOTE(review): this InfluxDB 1.x `InfluxDBClient` is immediately shadowed by the
# influxdb_client import on the next line; only the 2.x client is actually used.
from influxdb import InfluxDBClient
from influxdb_client import InfluxDBClient, Point, WriteOptions

# InfluxDB configuration. Each value may be overridden by an environment
# variable of the same name; the literals are the original defaults.
# SECURITY: the default token below is committed to source control. Rotate it
# and supply INFLUX_TOKEN through the environment instead.
INFLUX_URL = os.environ.get("INFLUX_URL", "http://localhost:8086")
INFLUX_TOKEN = os.environ.get(
    "INFLUX_TOKEN",
    "--k98NX5UQ2qBCGAO80lLc_-teD-AUtKNj4uQfz0M8WyjHt04AT9d0dr6w8pup93ukw6YcJxWURmo2v6CAP_2g==",
)
INFLUX_ORG = os.environ.get("INFLUX_ORG", "GAAIM")
INFLUX_BUCKET = os.environ.get("INFLUX_BUCKET", "AGVIGNETTE")

# Connect to InfluxDB; batch_size=1 makes the write API flush every point individually.
client = InfluxDBClient(url=INFLUX_URL, token=INFLUX_TOKEN, org=INFLUX_ORG)
write_api = client.write_api(write_options=WriteOptions(batch_size=1))

# MQTT setup (broker address overridable via MQTT_BROKER).
MQTT_BROKER = os.environ.get("MQTT_BROKER", "192.168.10.51")
MQTT_TOPIC = "fruit/classification"

mqtt_client = mqtt.Client()
mqtt_client.connect(MQTT_BROKER, 1883, 6000)  # port 1883, keepalive 6000 s
# Service the connection in a background thread (incoming packets, keepalive
# pings, automatic reconnects); paho does not process network traffic unless
# one of its loop_* functions runs.
mqtt_client.loop_start()

# Camera index (default camera is 0)
camera_index = 0
# Number of MQTT classification messages published so far (incremented by the main loop).
i = 0

# Load the custom-trained YOLO weights (absolute, machine-specific path).
model = YOLO(r"/Users/vel/Desktop/CvModel/CV_AG/runs/detect/train5/weights/best.pt")

# Initialize the camera
cap = cv2.VideoCapture(camera_index)
if not cap.isOpened():
    print("Unable to open the camera. Please check the device.")
    # `exit()` is a site-module helper intended for the interactive shell and
    # exits with status 0; raising SystemExit(1) reports the failure to the
    # calling shell or supervisor.
    raise SystemExit(1)

# Report what the device negotiated; some backends report 0 FPS.
fps = int(cap.get(cv2.CAP_PROP_FPS))
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
print(f"Camera resolution: {width}x{height}, FPS: {fps} FPS")

# YOLO class index -> human-readable label (index = position in the list).
class_labels = dict(enumerate([
    "Bruised",
    "DefectiveLemon",
    "GoodLemon",
    "NotRipeLemon",
    "Rotten",
]))

# Labels that get a track ID, majority voting and state confirmation;
# any other label is drawn as plain text.
id_tracked_classes = [
    "DefectiveLemon",
    "GoodLemon",
    "NotRipeLemon",
]

# Smoothing parameters
HISTORY_LENGTH = 7       # sliding-window size (frames) for majority voting
CONFIRMATION_FRAMES = 7  # consecutive frames a new label must persist to be confirmed
lemon_time = 0           # time.time() of the last MQTT publish (0 = nothing published yet)

# Per-track-ID bookkeeping
lemon_history = {}   # {track_id: deque of recent labels, maxlen=HISTORY_LENGTH}
lemon_states = {}    # {track_id: confirmed label}
lemon_duration = {}  # {track_id: {"current_label": str, "duration": int}}

# Pre-create the "Live Detection" window with WINDOW_NORMAL so the user can
# resize it; cv2.imshow() later draws into this same named window.
cv2.namedWindow("Live Detection", flags=cv2.WINDOW_NORMAL)

def update_lemon_label_with_majority(obj_id, current_label):
    """Record *current_label* for *obj_id* and return the majority label.

    The most recent HISTORY_LENGTH labels per track ID are kept in the
    module-level ``lemon_history`` dict; the bounded deque drops the oldest
    entry automatically.

    Ties are broken deterministically in favour of the label whose first
    occurrence in the window is oldest. (Voting via ``max(set(...))`` broke
    ties by set iteration order, which for strings depends on per-process hash
    randomization and could therefore differ from run to run.)

    Args:
        obj_id: Tracker ID of the object.
        current_label: Label detected for the object in the current frame.

    Returns:
        The most frequent label in the object's recent history.
    """
    history = lemon_history.get(obj_id)
    if history is None:
        history = lemon_history[obj_id] = deque(maxlen=HISTORY_LENGTH)
    history.append(current_label)

    # Counter keeps first-encounter order, and most_common() lists equal
    # counts in that order, so the tie-break is stable.
    return Counter(history).most_common(1)[0][0]

def update_lemon_state_with_duration(obj_id, current_label):
    """Return the label to keep for *obj_id*, switching only after a stable streak.

    ``lemon_duration[obj_id]`` tracks how many consecutive calls have seen the
    same label. A differing label restarts the streak at 1. Once the streak
    reaches CONFIRMATION_FRAMES, *current_label* is returned. Otherwise the
    previously stored ``lemon_states`` value is returned, or *current_label*
    if the ID has no stored state yet.
    """
    streak = lemon_duration.get(obj_id)
    if streak is not None and streak["current_label"] == current_label:
        streak["duration"] += 1
    else:
        # New ID or label changed: start a fresh streak.
        lemon_duration[obj_id] = {"current_label": current_label, "duration": 1}

    if lemon_duration[obj_id]["duration"] < CONFIRMATION_FRAMES:
        return lemon_states.get(obj_id, current_label)
    return current_label

# Decision-zone geometry, in frame pixels (x grows to the right).
DECISION_LINE_LEFT_X = 600   # left guide line
DECISION_LINE_RIGHT_X = 760  # right guide line
GUIDE_TRIGGER_X = 100        # guides are drawn for each box whose x1 exceeds this
LOCK_X = 700                 # boxes whose x1 exceeds this show "Locked: <label>"
PUBLISH_MIN_X = 600          # tracked labels are published while
PUBLISH_MAX_X = 780          #   PUBLISH_MIN_X < x1 < PUBLISH_MAX_X
PUBLISH_COOLDOWN_S = 0.3     # minimum seconds between any two MQTT publishes

# Process the video stream in real time. try/finally guarantees the camera,
# window, MQTT and InfluxDB resources are released even on errors or Ctrl+C.
# NOTE(review): entries in lemon_history / lemon_duration / lemon_states are
# never evicted, so they grow with every new track ID over a long run.
try:
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            print("Unable to read camera input. Terminating program.")
            break

        # Perform object tracking using BoT-SORT
        results = model.track(source=frame, conf=0.5, tracker='botsort.yaml', show=False, device='mps')

        for result in results:
            frame = result.orig_img  # Current frame
            for box in result.boxes:
                x1, y1, x2, y2 = map(int, box.xyxy[0])  # Detection box coordinates
                obj_id = int(box.id) if box.id is not None else -1  # Tracking object ID
                label = class_labels.get(int(box.cls), "Unknown")

                # Reset for every box so an untracked detection can never
                # reuse (and re-publish) a label left over from a previous box,
                # nor hit a NameError when it is the first detection seen.
                final_label = None
                if label in id_tracked_classes and obj_id != -1:
                    majority_label = update_lemon_label_with_majority(obj_id, label)
                    final_label = update_lemon_state_with_duration(obj_id, majority_label)
                    lemon_states[obj_id] = final_label  # store the confirmed state
                    display_text = f"ID {obj_id} | {final_label}"
                else:
                    # Untracked labels are shown without an ID.
                    display_text = label

                # Draw detection box and label
                cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
                cv2.putText(frame, display_text, (x1, y1 - 10),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2)

                # Decision-zone guide lines at x = 600 and x = 760, plus caption.
                if x1 > GUIDE_TRIGGER_X:
                    cv2.line(frame, (DECISION_LINE_LEFT_X, 0), (DECISION_LINE_LEFT_X, height), (255, 0, 0), 2)
                    cv2.line(frame, (DECISION_LINE_RIGHT_X, 0), (DECISION_LINE_RIGHT_X, height), (255, 0, 0), 2)
                    cv2.putText(frame, "Decision Point", (630, height // 2),
                                cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 2)

                # Show the locked label once the box has crossed the lock line
                if x1 > LOCK_X and obj_id in lemon_states:
                    cv2.putText(frame, f"Locked: {lemon_states[obj_id]}", (x1, y1 - 40),
                                cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 255), 2)
                else:
                    cv2.putText(frame, "Waiting to Lock", (x1, y1 - 40),
                                cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 0), 2)

                # Every tracked label is published identically, so one branch
                # replaces the three copies (DefectiveLemon / NotRipeLemon / GoodLemon).
                if final_label in id_tracked_classes and PUBLISH_MIN_X < x1 < PUBLISH_MAX_X:
                    now = time.time()
                    # NOTE(review): the cooldown is global, not per lemon. A lemon
                    # lingering in the zone is re-published each period, and a
                    # second lemon arriving within the cooldown is skipped — confirm intended.
                    if now - lemon_time > PUBLISH_COOLDOWN_S:
                        # Line-protocol style payload with a nanosecond timestamp.
                        mqtt_message = f"lemon_classification classification=\"{final_label}\" {time.time_ns()}"
                        mqtt_client.publish(MQTT_TOPIC, mqtt_message)
                        lemon_time = now
                        i += 1

        # Display the processed video stream
        cv2.imshow("Live Detection", frame)

        # Exit the loop when ESC key is pressed
        if cv2.waitKey(1) & 0xFF == 27:  # 27 is the ASCII value for ESC key
            print("ESC key detected. Exiting the program.")
            break
finally:
    # Release resources even if the loop raised.
    cap.release()
    cv2.destroyAllWindows()
    mqtt_client.disconnect()
    # paho returns an error code (does not raise) if the loop was never started.
    mqtt_client.loop_stop()
    write_api.close()
    client.close()

print("Camera video processing complete. Program terminated.")