// Copyright (C) 2024 Reveal AI
//
// SPDX-License-Identifier: MIT

import { Dispatch, SetStateAction } from 'react';

import { GenerateRequest, InvokeResponse, SessionMessage } from '@/types';

/**
 * Signature of the function returned by `useGenerate`.
 *
 * @param webSocket - socket the request is sent on and responses are read from.
 * @param generateRequest - request payload sent to the server.
 * @param lastSessionMessage - the session message (matched by `id`) whose answer is generated.
 * @param setSessionMessages - state setter used to write the streamed / final answer into the list.
 * @returns a promise that settles once the generation has finished or failed.
 */
type GenerateHook = (
    webSocket: WebSocket,
    generateRequest: GenerateRequest,
    lastSessionMessage: SessionMessage,
    // `SessionMessage[]` already admits the empty array; the former
    // `SessionMessage[] | []` was redundant, and setters typed with it still match.
    setSessionMessages: Dispatch<SetStateAction<SessionMessage[]>>
) => Promise<void>;

/**
 * Returns a function that streams one generated answer over a WebSocket.
 *
 * The returned `generate` sends `generateRequest` (JSON-encoded) on `webSocket`
 * and parses every incoming frame as an `InvokeResponse`:
 * - `chunk`: `data` is appended to the accumulated text, and the session message
 *   whose `id` equals `lastSessionMessage.id` gets that text as its `answer`.
 * - `message`: `data` replaces that session message entirely; the promise
 *   resolves and the socket is closed.
 * - `error`: the promise rejects with an `Error` whose message is `data`, and the
 *   socket is closed.
 * Frames with any other `type` are ignored.
 *
 * The promise also rejects when a frame is not valid JSON (with the parse error,
 * closing the socket), when the socket fires `error` (with that raw event,
 * closing the socket), or when the socket closes before a `message` frame.
 *
 * Note: `onmessage`, `onerror` and `onclose` on `webSocket` are overwritten.
 */
const useGenerate = (): GenerateHook => {
    const generate: GenerateHook = (
        webSocket,
        generateRequest,
        lastSessionMessage,
        setSessionMessages
    ): Promise<void> => new Promise((resolve, reject) => {
        // Concatenation of every `chunk` payload received so far.
        let fullText = '';

        // Rewrite only the session message currently being generated.
        const updateLastMessage = (
            update: (sessionMessage: SessionMessage) => SessionMessage
        ): void => {
            setSessionMessages((prev: SessionMessage[]) => prev.map((sessionMessage) => (
                sessionMessage.id === lastSessionMessage.id ? update(sessionMessage) : sessionMessage
            )));
        };

        // Settle before closing, so the `close` event this triggers finds the
        // promise already settled and the `onclose` reject below is a no-op.
        const fail = (reason: unknown): void => {
            reject(reason);
            webSocket.close();
        };

        webSocket.onmessage = (event) => {
            let response: InvokeResponse;
            try {
                response = JSON.parse(event.data as string) as InvokeResponse;
            } catch (error) {
                // A throw escaping this event handler would never reach the promise,
                // leaving the caller awaiting forever.
                fail(error);
                return;
            }

            switch (response.type) {
                case 'chunk':
                    fullText += response.data as string;
                    updateLastMessage((sessionMessage) => ({ ...sessionMessage, answer: fullText }));
                    break;
                case 'message': {
                    const finalMessage = response.data as SessionMessage;
                    updateLastMessage(() => finalMessage);
                    resolve();
                    webSocket.close();
                    break;
                }
                case 'error':
                    fail(new Error(response.data as string));
                    break;
                default:
                    // Other frame types are ignored.
                    break;
            }
        };

        webSocket.onerror = (error) => {
            fail(error);
        };

        webSocket.onclose = (event) => {
            // No-op once settled; otherwise the socket closed before the final
            // `message` frame and the promise would never settle.
            reject(new Error(`WebSocket closed before generation completed (code ${event.code})`));
        };

        webSocket.send(JSON.stringify(generateRequest));
    });

    return generate;
};

export default useGenerate;
