// Copyright (C) 2024 Reveal AI
//
// SPDX-License-Identifier: MIT

import React, { FC, useContext } from 'react';
import Grid from '@mui/material/Grid2';
import AddIcon from '@mui/icons-material/Add';
import Skeleton from '@mui/material/Skeleton';
import { useTheme } from '@mui/material/styles';
import { useWatch } from 'react-hook-form'
import {
    Form, SelectInput, SaveButton, CreateBase,
    required, AutocompleteInput,
} from 'react-admin';

import { Session, PromptStatus, AssistantStatus } from '@/types';
import { GlobalContext } from '@/context';
import { ModelSelectInput } from '@/provider/model';
import { PromptSelectInput } from '@/prompt/form';
import { AssistantSelectInput } from '@/assistant/form';

/**
 * Renders the "related object" picker that matches the session `type`
 * currently selected in the surrounding form:
 * - no type selected → a skeleton placeholder
 * - `'llm'`          → model picker limited to valid LLM models
 * - `'prompt'`       → picker over published prompts (sorted by group name)
 * - any other type   → picker over published assistants (sorted by group name)
 *
 * Must be rendered inside a react-admin `<Form>` (it reads form state via
 * `useWatch`).
 */
const SelectRelated: FC = () => {
    // Subscribe to the `type` field only, so changes to other form fields
    // (e.g. the related_* selections themselves) do not re-render this picker.
    const type: string | undefined = useWatch({ name: 'type' });

    if (!type) {
        // Shown until a session type is set (e.g. while `sessionTypes` is still
        // empty and the type SelectInput has not supplied its default yet).
        return (
            <Skeleton animation='wave'>
                <AutocompleteInput source='related_model.model_name' margin='none' />
            </Skeleton>
        );
    }

    if (type === 'llm') {
        return (
            <ModelSelectInput
                label={false}
                variant='outlined'
                validate={required()}
                filter={{ model_type: 'llm', valid_only: true }}
            />
        );
    }

    if (type === 'prompt') {
        return (
            <PromptSelectInput
                source='related_prompt'
                label={false}
                validate={required()}
                filter={{ status: PromptStatus.PUBLISHED }}
                sort={{ field: 'group__name', order: 'ASC' }}
            />
        );
    }

    // NOTE(review): every type other than 'llm'/'prompt' falls through to the
    // assistant picker, while SessionCreate's transform only keeps
    // `related_assistant` when type === 'assistant'. Confirm `sessionTypes`
    // only ever contains these three values.
    return (
        <AssistantSelectInput
            source='related_assistant'
            label={false}
            validate={required()}
            filter={{ status: AssistantStatus.PUBLISHED }}
            sort={{ field: 'group__name', order: 'ASC' }}
        />
    );
}

type SessionCreateProps = object

/**
 * Normalises the create-form values before they are submitted to the
 * `sessions` resource.
 *
 * - `name` is always set to `'New Chat'`; the create form has no name input.
 * - Only the relation that matches `type` is kept. The other two relation
 *   fields are set to `undefined`, so any other relation value present in the
 *   form values is dropped.
 *
 * Declared at module scope because it depends on nothing inside the component.
 */
const transformSession = (values: Session): Session => ({
    ...values,
    name: 'New Chat',
    related_model: values.type === 'llm' ? values.related_model : undefined,
    related_prompt: values.type === 'prompt' ? values.related_prompt : undefined,
    related_assistant: values.type === 'assistant' ? values.related_assistant : undefined,
});

/**
 * Compact, single-row form for starting a new chat session.
 *
 * The user picks a session type, then the related model/prompt/assistant
 * (see `SelectRelated`), then saves. The record is created in the `sessions`
 * resource and react-admin redirects to its show page (`redirect='show'`).
 */
const SessionCreate: FC<SessionCreateProps> = () => {
    const { sessionTypes } = useContext(GlobalContext);
    const theme = useTheme();

    return (
        <CreateBase
            redirect='show'
            resource='sessions'
            transform={transformSession}
        >
            <Form>
                <Grid container rowSpacing={{ xs: 0 }} columnSpacing={4}>
                    <Grid
                        size={{
                            xs: 12, sm: 12, md: 2
                        }}
                    >
                        {
                            sessionTypes.length > 0 ? (
                                // Defaults to the first available type so the
                                // related picker can render immediately.
                                <SelectInput
                                    source='type'
                                    label={false}
                                    choices={sessionTypes}
                                    defaultValue={sessionTypes[0]?.id}
                                    validate={required()}
                                    variant='outlined'
                                    margin='none'
                                />
                            ) : (
                                // Placeholder while `sessionTypes` is empty
                                // (presumably not yet loaded into GlobalContext).
                                <Skeleton animation='wave'><SelectInput source='type' /></Skeleton>
                            )
                        }
                    </Grid>
                    <Grid
                        size={{
                            xs: 12, sm: 12, md: 5
                        }}
                    >
                        <SelectRelated />
                    </Grid>
                    <Grid
                        size={{
                            xs: 12, sm: 12, md: 4
                        }}
                        sx={{
                            // Center the save button on the narrowest screens.
                            [theme.breakpoints.down('sm')]: {
                                textAlign: 'center',
                            }
                        }}
                    >
                        <SaveButton
                            label='label.session.new'
                            icon={<AddIcon />}
                        />
                    </Grid>
                </Grid>
            </Form>
        </CreateBase>
    );
};

export default SessionCreate;
