import { gql, useMutation } from '@apollo/client';
import RocketLaunchIcon from '@mui/icons-material/RocketLaunch';
import { LoadingButton } from '@mui/lab';
import { Box, Typography } from '@mui/material';
import { showMessageAction } from '@theora360/shared/src/redux-store/snackbar-store';
import React, { useCallback, useMemo, useState } from 'react';
import { useDispatch } from 'react-redux';
import StartTrainingModal from './StartTrainingModal';
import TrainLogsTerm from './TrainLogsTerm';
import TrainingLogsPlot from './TrainingLogsPlot';
import { NeuralNetTrainingStatuses } from './constants';

/**
 * GraphQL mutation that asks the server to cancel training of the neural
 * network identified by `$_id`. Used by the "Stop Training" button in
 * NetTrainTab below.
 */
const cancelTrainNeuralNetMutation = gql`
  mutation cancelTrainNeuralNet($_id: ID!) {
    cancelTrainNeuralNet(_id: $_id)
  }
`;

function NetTrainTab({ neuralNet, initTraining, status }) {
  const [starting, setStarting] = useState(false);

  const onTrain = () => setStarting({ neuralNet });
  const onCancelTrain = () => setStarting(undefined);
  const onFinishTrain = () => {
    initTraining();
    setStarting(undefined);
  };

  const [loading, setLoading] = useState(false);
  const [cancelTrainNeuralNet] = useMutation(cancelTrainNeuralNetMutation);
  const dispatch = useDispatch();
  const onStop = useCallback(async () => {
    setLoading(true);
    try {
      await cancelTrainNeuralNet({
        variables: {
          _id: neuralNet._id,
        },
      });

      dispatch(
        showMessageAction({
          _id: 'stop-training',
          severity: 'success',
          message: 'Neural network training stopped!',
        }),
      );
    } catch (err) {
      console.error(err);
      dispatch(
        showMessageAction({
          _id: 'stop-training',
          severity: 'error',
          message: 'There was an error stopping training',
        }),
      );
    }
    setLoading(false);
  }, [cancelTrainNeuralNet, dispatch, neuralNet]);

  const showStart = useMemo(() => {
    return ['STOPPED', undefined].includes(status.status);
  }, [status]);

  return (
    <>
      <Box sx={{ mt: 2, mb: 2, display: 'flex', alignItems: 'center' }}>
        <LoadingButton
          sx={{ mr: 2 }}
          variant="outlined"
          startIcon={<RocketLaunchIcon />}
          onClick={showStart ? onTrain : onStop}
          disabled={![undefined, 'STOPPED', 'TRAINING'].includes(status.status)}
          loading={loading}
        >
          {showStart ? 'Start Training' : 'Stop Training'}
        </LoadingButton>
        {status?.status && (
          <Box sx={{ mr: 2 }}>
            <Typography variant="h6">
              {NeuralNetTrainingStatuses[status.status]?.label}
            </Typography>
          </Box>
        )}
        {status?.step !== undefined && (
          <Box sx={{ mr: 2 }}>
            <Typography>{`Step ${status.step}`}</Typography>
          </Box>
        )}
      </Box>
      {status?.savedModelId && (
        <Box sx={{ mr: 2 }}>
          <Typography variant="h6">Saved Model</Typography>
          <Typography>{status.savedModelName}</Typography>
        </Box>
      )}
      <TrainingLogsPlot
        title="Neural Network Training"
        trainTraces={status.traces || []}
      />
      <Box sx={{ mb: 2 }} />
      <TrainLogsTerm lines={status.logs || []} />
      <StartTrainingModal
        visible={starting}
        onCancel={onCancelTrain}
        onFinish={onFinishTrain}
      />
    </>
  );
}

export default NetTrainTab;
