import { useCallback, useEffect, useRef, useState } from "react";
import useWebSocket, { ReadyState } from "react-use-websocket";
import { useAccessToken } from "../_helpers/use-access-token";
import { getNewSessionId } from '../_services/websocket-session';

import config from "../config";

/**
 * A decoded JSON frame exchanged over the WebSocket.
 *
 * `message_type` names the kind of message and `payload` carries its data.
 * `error` is optional and, when present, holds error details. Any other
 * top-level fields are permitted through the string index signature.
 */
export type Message = Record<string, any> & {
  message_type: string;
  payload: Record<string, any>;
  error?: Record<string, any>;
};

/** Handler invoked with a single decoded {@link Message}. */
export type OnMessageCallback = (incoming: Message) => void;

// Heartbeat settings handed to react-use-websocket.
// NOTE(review): the timeout (1 min) is much shorter than the ping interval
// (14 min). If the installed react-use-websocket treats `timeout` as "close the
// socket when no message has arrived for this long", an idle connection would
// be closed well before a keep-alive is ever sent — confirm against the
// library version in use and the backend's own ping behaviour.
const HEARTBEAT_TIMEOUT_MS = 60 * 1000; // 1 minute in milliseconds
const HEARTBEAT_INTERVAL_MS = 14 * 60 * 1000; // 14 minutes in milliseconds

/**
 * React hook that owns the email-agent WebSocket connection.
 *
 * On mount, and whenever `accessToken` changes, it calls `joinSession`. That
 * call obtains a token via `check()` and a session id via `getNewSessionId()`,
 * then opens `wss://<backendHost>/ws/email_agent/<sessionId>/?token=<token>`.
 * Incoming frames are parsed as JSON and dispatched to handlers registered
 * with `on`, keyed by `message_type` (falling back to `response_type`).
 *
 * @returns `isConnected` / `isLoading` flags derived from the socket state,
 *   plus `joinSession`, `on`, `off`, `unbindAll` and `sendMessage`.
 */
export const useWebsocket = () => {
  const { check, accessToken, logout } = useAccessToken();
  const [socketUrl, setSocketUrl] = useState<string | null>(null);
  const [callbacks, setCallbacks] = useState<Record<string, OnMessageCallback>>({});
  const [defaultMessageCallback, setDefaultMessageCallback] = useState<OnMessageCallback | null>(null);

  /**
   * Registers `callback` for every type in `messageTypes`. A type that already
   * has a handler keeps it (first registration wins). An empty list registers
   * `callback` as the handler for messages without a type, unless one is set.
   */
  const on = useCallback((messageTypes: string[], callback: OnMessageCallback) => {
    if (messageTypes.length === 0) {
      // A function passed straight to a state setter is treated as an updater
      // and gets *called*; return it from an updater so it is stored instead.
      setDefaultMessageCallback((prev) => prev ?? callback);
      return;
    }
    // Functional update: checks against the latest queued state rather than
    // the value captured by this render.
    setCallbacks((prev) => {
      const missing = messageTypes.filter((messageType) => !prev[messageType]);
      if (missing.length === 0) {
        return prev;
      }
      const next = { ...prev };
      for (const messageType of missing) {
        next[messageType] = callback;
      }
      return next;
    });
  }, []);

  /** Removes the handlers registered for each type in `messageTypes`. */
  const off = useCallback((messageTypes: string[]) => {
    if (messageTypes.length === 0) {
      return;
    }
    setCallbacks((prev) => {
      const next = { ...prev };
      for (const messageType of messageTypes) {
        delete next[messageType];
      }
      return next;
    });
  }, []);

  /** Clears every typed handler and the default handler. */
  const unbindAll = useCallback(() => {
    setCallbacks({});
    setDefaultMessageCallback(null);
  }, []);

  /**
   * Routes a decoded message to the handler for its type. Untyped messages go
   * to the default handler; typed messages with no handler are dropped.
   */
  const dispatchMessage = (message: Message) => {
    const messageType: string | undefined = message.message_type ?? message.response_type;

    if (!messageType) {
      defaultMessageCallback?.(message);
      return;
    }

    callbacks[messageType]?.(message);
  };

  const { sendJsonMessage, readyState } = useWebSocket(
    socketUrl,
    {
      shouldReconnect: () => true,
      reconnectAttempts: 10,
      reconnectInterval: 2000,
      share: true,
      onMessage: (event) => {
        let parsed: unknown;
        try {
          parsed = JSON.parse(event.data);
        } catch (error) {
          // Without this guard a non-JSON frame throws inside the socket's
          // message handler.
          console.error("Failed to parse WebSocket message:", error);
          return;
        }
        if (typeof parsed !== "object" || parsed === null) {
          console.error("Ignoring non-object WebSocket message:", parsed);
          return;
        }
        dispatchMessage(parsed as Message);
      },

      heartbeat: {
        // send 'request_type' too for backwards compatibility with backend
        message: '{"message_type": "keep_alive", "request_type": "keep_alive"}',
        timeout: HEARTBEAT_TIMEOUT_MS,
        interval: HEARTBEAT_INTERVAL_MS,
      },
    },
    socketUrl != null,
  );

  // Latest readyState, so retries scheduled by `sendMessage` (whose closures
  // come from an earlier render) see the current connection state.
  const readyStateRef = useRef(readyState);
  useEffect(() => {
    readyStateRef.current = readyState;
  }, [readyState]);

  /**
   * Opens the socket if it is not already configured. A missing token is
   * retried up to 5 times with a growing delay, and so is a failed session-id
   * lookup. If either still fails after the retries, the user is logged out.
   */
  const joinSession = useCallback(async (tries = 0) => {
    if (socketUrl) {
      return;
    }
    console.log('Attempting to join session, try:', tries);

    const token = await check();
    // Log only whether a token exists — never the credential itself.
    console.log('Token available after check:', Boolean(token));

    if (!token && tries < 5) {
      setTimeout(() => joinSession(tries + 1), 2000 + 500 * tries);
      return;
    }
    if (!token) {
      console.error("Failed to get token, logging out");
      logout();
      return;
    }

    try {
      const { sessionId } = await getNewSessionId();
      const url = `wss://${config.backendHost}/ws/email_agent/${sessionId}/?token=${encodeURIComponent(token)}`;
      // The URL embeds the token, so log the session id instead.
      console.log('Setting WebSocket URL for session:', sessionId);
      setSocketUrl(url);
    } catch (error) {
      console.error("Failed to get session ID:", error);
      if (tries < 5) {
        setTimeout(() => joinSession(tries + 1), 2000 + 500 * tries);
      } else {
        console.error("Max retries reached for getting session ID");
        logout();
      }
    }
  }, [check, socketUrl, logout]);

  /**
   * Sends `message` as JSON once the socket is open. While it is not open the
   * send is retried `retries` more times with a growing delay. The caller's
   * object is never modified.
   */
  const sendMessage = async (message: Message, retries = 5): Promise<void> => {
    await check();
    if (readyStateRef.current !== ReadyState.OPEN) {
      if (retries > 0) {
        setTimeout(() => {
          sendMessage(message, retries - 1).catch((error) => {
            console.error("Failed to send WebSocket message:", error);
          });
        }, 1000 + (5 - retries) * 500);
      } else {
        console.warn("WebSocket not open; dropping message:", message.message_type);
      }
      return;
    }

    const outgoing: Message = { ...message };
    if (!outgoing.request && outgoing.payload) {
      // Mirrors `payload` into `request`; presumably also for backend
      // backwards compatibility — confirm.
      outgoing.request = outgoing.payload;
    }
    sendJsonMessage({
      ...outgoing,
      // backwards compatibility with backend
      // send old 'request_type' along with new 'message_type'
      request_type: outgoing.message_type,
    });
  };

  useEffect(() => {
    void joinSession();
  }, [joinSession, accessToken]);

  return {
    isConnected: readyState === ReadyState.OPEN,
    isLoading: readyState === ReadyState.CONNECTING,
    joinSession,
    on,
    off,
    unbindAll,
    sendMessage,
  }
}
