import { gql, useMutation } from '@apollo/client';
import { LoadingButton } from '@mui/lab';
import { Box, Modal, Typography } from '@mui/material';
import { showMessageAction } from '@theora360/shared/src/redux-store/snackbar-store';
import { generateUuid } from '@theora360/shared/src/shared/utils';
import React, { useCallback, useEffect, useState } from 'react';
import { useForm } from 'react-hook-form';
import { FormContainer, TextFieldElement } from 'react-hook-form-mui';
import { useDispatch } from 'react-redux';

// GraphQL mutation that triggers training for the neural net identified by
// `_id`. The caller supplies the id (`savedModelId`) and display name
// (`savedModelName`) to use for the saved model produced by this run.
// NOTE(review): no selection set is requested, so `trainNeuralNet` presumably
// returns a scalar (e.g. a boolean/status) — confirm against the server schema.
const trainNeuralNetMutation = gql`
  mutation TrainNeuralNet(
    $_id: ID!
    $savedModelId: ID!
    $savedModelName: String!
  ) {
    trainNeuralNet(
      _id: $_id
      savedModelId: $savedModelId
      savedModelName: $savedModelName
    )
  }
`;

/**
 * Modal that starts a training run for a neural network. The user picks a
 * name for the saved model that the run will produce.
 *
 * Props:
 * - visible: falsy hides the modal; otherwise an object whose `neuralNet`
 *   (read for `_id` and `name`) is the network to train.
 * - onCancel: passed straight to MUI `Modal`'s `onClose`.
 * - onFinish: called after the trainNeuralNet mutation resolves successfully.
 */
function StartTrainingModal({ visible, onCancel, onFinish }) {
  const { neuralNet } = visible || {};
  const [loading, setLoading] = useState(false);
  const [trainNeuralNet] = useMutation(trainNeuralNetMutation);
  const form = useForm();

  // Clear any previously typed name whenever `visible` changes
  // (modal opened, closed, or pointed at a different network).
  useEffect(() => {
    form.reset();
  }, [visible, form]);

  const dispatch = useDispatch();
  const onSubmit = useCallback(
    async (values) => {
      setLoading(true);
      try {
        // Client-generated id for the saved model this run will create.
        const savedModelId = generateUuid();
        await trainNeuralNet({
          variables: {
            _id: neuralNet._id,
            savedModelId,
            // Pass only the known field so stray form values can never
            // override `_id` / `savedModelId`.
            savedModelName: values.savedModelName,
          },
        });
      } catch (err) {
        console.error(err);
        dispatch(
          showMessageAction({
            _id: 'start-training',
            severity: 'error',
            message: 'There was an error starting training',
          }),
        );
        return;
      } finally {
        // Runs on every path (including the early return above) so the form
        // never stays stuck in the loading state. It also runs before
        // onFinish, which may close or unmount this modal.
        setLoading(false);
      }

      // Deliberately outside the try: the mutation has already succeeded, so
      // a failure in onFinish must not be reported as "error starting
      // training" (a user retry would start a second training run).
      dispatch(
        showMessageAction({
          _id: 'start-training',
          severity: 'success',
          message: 'Neural network training started!',
        }),
      );
      onFinish();
    },
    [trainNeuralNet, dispatch, onFinish, neuralNet],
  );

  return (
    <Modal
      open={!!visible}
      onClose={onCancel}
      sx={{
        overflow: 'scroll',
        p: 4,
      }}
    >
      <Box
        sx={{
          backgroundColor: 'white',
          textAlign: 'center',
          borderRadius: 1,
          padding: 4,
          marginTop: 4,
          marginBottom: 4,
          width: 500,
          marginLeft: 'auto',
          marginRight: 'auto',
        }}
      >
        <Typography variant="h4" sx={{ mb: 3 }}>
          Start Training
        </Typography>
        <Typography variant="h6">Neural Network</Typography>
        <Typography sx={{ mb: 2 }}>{neuralNet?.name}</Typography>
        <FormContainer formContext={form} onSuccess={onSubmit}>
          <TextFieldElement
            style={{ marginBottom: 32, width: '100%' }}
            variant="standard"
            name="savedModelName"
            label="Saved Model Name"
            helperText="Choose a name for the result saved model from this training session."
            type="text"
            disabled={loading}
            required
          />
          <div style={{ height: 16 }} />
          <Box>
            <LoadingButton type="submit" variant="contained" loading={loading}>
              Start
            </LoadingButton>
          </Box>
        </FormContainer>
      </Box>
    </Modal>
  );
}

export default StartTrainingModal;
