// Copyright (C) 2024 Reveal AI
//
// SPDX-License-Identifier: MIT

import React, { FC, useContext } from 'react';
import Grid from '@mui/material/Grid2';
import Stack from '@mui/material/Stack';
import AddIcon from '@mui/icons-material/Add';
import Skeleton from '@mui/material/Skeleton';
import { useTheme } from '@mui/material/styles';
import { useWatch } from 'react-hook-form'
import {
    Form, SelectInput, ReferenceInput, SaveButton, CreateBase,
    required, AutocompleteInput,
} from 'react-admin';

import { Session } from '@/types';
import { GlobalContext } from '@/context';
import { ModelSelectInput } from '@/provider/model';

/**
 * Renders the input for the record related to the currently selected session `type`.
 *
 * - No `type` selected yet: a skeleton placeholder.
 * - `type === 'llm'`: a model picker filtered to `model_type: 'llm'` with `valid_only`.
 * - Any other `type`: an autocomplete over the `${type}s` resource, bound to the
 *   `related_${type}` form field and showing each option's `name`.
 */
const SelectRelated: FC = () => {
    // Subscribe to the `type` field only. An argument-less useWatch() watches the
    // whole form and would re-render this component on every field change.
    const type = useWatch({ name: 'type' });

    if (!type) {
        return (
            <Skeleton animation='wave'>
                <AutocompleteInput source='related_model.model_name' margin='none' />
            </Skeleton>
        );
    }

    // `type` is always truthy below this point, so the former `disabled={!type}`
    // props were always false and have been dropped.
    if (type === 'llm') {
        return (
            <ModelSelectInput
                label={false}
                variant='outlined'
                validate={required()}
                filter={{ model_type: 'llm', valid_only: true }}
            />
        );
    }

    // `key={type}` remounts the reference input when the type changes, because
    // its `source` and `reference` both change. Without the remount, the
    // autocomplete could keep state from the previously selected resource.
    // NOTE(review): `mode_type` looks like a typo of `model_type` (the key used in
    // the llm branch). Confirm against the backend filter before changing it.
    return (
        <ReferenceInput
            key={type}
            source={`related_${type}`}
            reference={`${type}s`}
            filter={{ mode_type: type }}
        >
            <AutocompleteInput
                label={false}
                variant='outlined'
                validate={required()}
                debounce={500}
                optionText='name'
                margin='none'
            />
        </ReferenceInput>
    );
}

type SessionCreateProps = object

// Responsive column sizes for the two grid cells (inputs, save button).
const inputsColumnSize = { xs: 12, sm: 8, md: 8, lg: 6, xl: 4 };
const buttonColumnSize = { xs: 12, sm: 4, md: 3, lg: 2 };

/**
 * Normalises submitted values before they are sent to the `sessions` resource.
 *
 * Any incoming `name` is discarded and replaced with `'New Chat'`. Only the
 * `related_*` field that matches the selected `type` is kept; the other two are
 * set to `undefined`. The remaining fields are copied through unchanged. They
 * cannot clash with the keys set here, because those keys were pulled out
 * during destructuring.
 */
const toSessionPayload = ({
    name,
    related_model,
    related_prompt,
    related_assistant,
    ...rest
}: Session): Session => {
    const selectedType = rest.type;
    return {
        name: 'New Chat',
        related_model: selectedType === 'llm' ? related_model : undefined,
        related_prompt: selectedType === 'prompt' ? related_prompt : undefined,
        related_assistant: selectedType === 'assistant' ? related_assistant : undefined,
        ...rest,
    };
};

/**
 * Compact "new session" form with three parts:
 * - a session-type selector, whose choices come from `GlobalContext`;
 * - the dependent `SelectRelated` input;
 * - a save button.
 *
 * A skeleton is shown instead of the type selector while no session types are
 * available. On save, the record goes through `toSessionPayload`, is created in
 * the `sessions` resource, and the app redirects to its show view.
 */
const SessionCreate: FC<SessionCreateProps> = () => {
    const { sessionTypes } = useContext(GlobalContext);
    const theme = useTheme();

    const typeSelector = sessionTypes.length > 0
        ? (
            <SelectInput
                source='type'
                label={false}
                choices={sessionTypes}
                defaultValue={sessionTypes[0]?.id}
                validate={required()}
                variant='outlined'
                margin='none'
            />
        )
        : <Skeleton animation='wave'><SelectInput source='type' /></Skeleton>;

    return (
        <CreateBase
            redirect='show'
            resource='sessions'
            transform={toSessionPayload}
        >
            <Form>
                <Grid container rowSpacing={{ xs: 0 }} columnSpacing={4}>
                    <Grid size={inputsColumnSize}>
                        <Stack direction='row' spacing={2}>
                            {typeSelector}
                            <SelectRelated />
                        </Stack>
                    </Grid>
                    {/* Below `sm` both cells are full width, so the button sits on its own row; centre it. */}
                    <Grid
                        size={buttonColumnSize}
                        sx={{
                            [theme.breakpoints.down('sm')]: {
                                textAlign: 'center',
                            }
                        }}
                    >
                        <SaveButton
                            label='label.session.new'
                            icon={<AddIcon />}
                        />
                    </Grid>
                </Grid>
            </Form>
        </CreateBase>
    );
};

export default SessionCreate;
