import React, { useState, useEffect, useRef } from 'react';
import {
  Autocomplete,
  Box,
  Button,
  Card,
  CardContent,
  Dialog,
  DialogTitle,
  FormControlLabel,
  Grid,
  IconButton,
  Radio,
  Slider,
  TextField,
  Typography
} from '@mui/material';
import Tooltip from '@mui/material/Tooltip';

import { dispatch, useAppSelector } from '../../store/hooks';
import { modalsMiddleware } from '../../store/slices/modals';
import { ModalName } from '../../types/modals';
import CloseIcon from '@mui/icons-material/Close';
import {
  activationFunctionOptions,
  classificationMetricsOptions,
  layerOptions,
  penaltyOptions,
  regressionMetricsOptions,
  solverOptions
} from '../../helpers/managers';
import { IconInfoCircle, IconQuestionMark } from '@tabler/icons';
import { useTheme } from '@mui/material/styles';
import { modelSettingsMiddleware } from '../../store/slices/modelSettings';
import { modelSettingsSelector } from '../../store/slices/modelSettings';
import { isNaN } from 'formik';

/**
 * Props accepted by the model-configuration dialogs.
 *
 * All callbacks and payloads are loosely typed (`any`). The members are grouped
 * below by the model family their names refer to. Each family has a params
 * object with its setter, a "save object" that remembers the dialog's last
 * saved form state with its setter, and a reset callback.
 */
export interface ModelInfoDialogProps {
  // --- Shared context ---------------------------------------------------
  /** Name of the model being configured (NNModelDialog shows it in the title). */
  modelInfo: string;
  /** Problem kind; NNModelDialog compares it against `'classification'`. */
  problemType: string;
  /** Training request parameters; NNModelDialog reads and writes `params.models`. */
  params: any;
  /** Setter for `params`; NNModelDialog passes it updater functions. */
  setParams: any;

  // --- "LR" / logistic family -------------------------------------------
  paramsLR: any;
  setParamsLR: any;
  logisticSaveObject: any;
  setLogisticSaveObject: any;
  logisticReset: any;

  // --- "Boost" family ---------------------------------------------------
  paramsBoost: any;
  setParamsBoost: any;
  boostSaveObject: any;
  setBoostSaveObject: any;
  boostReset: any;

  // --- "Catboost" family ------------------------------------------------
  paramsCatboost: any;
  setParamsCatboost: any;
  catboostSaveObject: any;
  setCatboostSaveObject: any;
  catboostReset: any;

  // --- "Random" family (presumably random forest — confirm) -------------
  paramsRandom: any;
  setParamsRandom: any;
  randomSaveObject: any;
  setRandomSaveObject: any;
  randomReset: any;

  // --- Neural network ---------------------------------------------------
  /** Last saved NN form state; used to seed the NN dialog's fields. */
  nnSaveObject: any;
  setNNSaveObject: any;
  /** Called with `{}` when the NN dialog's Reset button is pressed. */
  nnReset: any;
}

/**
 * Returns the display label of an Autocomplete value.
 *
 * Values handled by this dialog are either option objects carrying a `label`
 * field (the metric options, or the `{ id, label }` objects built from model
 * settings) or plain label strings restored from a previous save. A cleared
 * Autocomplete reports `null`, which maps to `''` so the save validation
 * treats it as "nothing selected".
 */
const toOptionLabel = (value: any): string => {
  if (value === null || value === undefined) {
    return '';
  }
  return typeof value === 'object' ? value.label ?? '' : String(value);
};

/**
 * Autocomplete equality check that compares by label, so a string value such
 * as `'Adam'` matches the `{ id, label: 'Adam' }` option built from settings.
 */
const isSameOptionLabel = (option: any, value: any): boolean => toOptionLabel(option) === toOptionLabel(value);

/**
 * Dialog for configuring the neural-network model.
 *
 * It has two modes, switched by the "One Model / Best Model" toggle
 * (`nnToggle`):
 * - `nnToggle === true` ("Best Model"): only Metrics and Time are editable.
 * - `nnToggle === false` ("One Model"): epochs, optimizer, batch norm,
 *   learning rate, and a per-layer nodes/activation/dropout table are
 *   editable. The optimizer and activation choices come from the model
 *   settings fetched when this mode is active.
 *
 * Field state is seeded from `nnSaveObject` (falling back to built-in
 * defaults). On Save, the form state is written back through
 * `setNNSaveObject`, `params.models` is set through `setParams`, and the
 * modal is closed.
 */
const NNModelDialog = ({
  modelInfo,
  problemType,
  params,
  setParams,
  nnSaveObject,
  setNNSaveObject,
  nnReset
}: ModelInfoDialogProps) => {
  const theme = useTheme();

  const initialValues = {
    nnToggle: true,
    metric: problemType === 'classification' ? classificationMetricsOptions[0] : regressionMetricsOptions[0],
    time: 1,
    epoch: 10,
    optimizer: 'Adam',
    batchNorm: true,
    learningRate: 0.001,
    layers: layerOptions[2],
    nodes: [256, 256, 256],
    activation: ['ReLU', 'ReLU', 'ReLU'],
    dropout_rates: [0.1, 0.1, 0.1]
  };

  const [nnToggle, setNNtoggle] = useState(nnSaveObject.nnToggle ?? initialValues.nnToggle);
  const [metrics, setMetrics] = useState(nnSaveObject.metric ?? initialValues.metric);
  const [time, setTime] = useState(nnSaveObject.time ?? initialValues.time);
  const [epoch, setEpoch] = useState(nnSaveObject.epoch ?? initialValues.epoch);
  const [optimizer, setOptimizer] = useState(nnSaveObject.optimizer ?? initialValues.optimizer);
  const [batchNorm, setBatchNorm] = useState(nnSaveObject.batchNorm ?? initialValues.batchNorm);
  const [learningRate, setLearningRate] = useState(nnSaveObject.learningRate ?? initialValues.learningRate);
  const [layers, setLayers] = useState(nnSaveObject.layers ?? initialValues.layers);
  const [nodes, setNodes] = useState(nnSaveObject.nodes ?? initialValues.nodes);
  const [activation, setActivation] = useState(nnSaveObject.activation ?? initialValues.activation);
  const [dropout_rates, setDropout_rates] = useState(nnSaveObject.dropout_rates ?? initialValues.dropout_rates);
  const [disableSave, setDisableSave] = useState(false);

  const { modelSettingsData, modelSettingsStatus } = useAppSelector(modelSettingsSelector.modelSettingsData);

  const getModelSettingsURL = '/model_settings/';
  const modelSettingsParams = {
    model_name: `${modelInfo}`,
    is_best: false
  };

  /** Restores every field to its default and notifies the parent via `nnReset({})`. */
  const handleReset = () => {
    // The One Model / Best Model toggle is deliberately left unchanged on reset.
    setMetrics(initialValues.metric);
    setTime(initialValues.time);
    setEpoch(initialValues.epoch);
    setOptimizer(initialValues.optimizer);
    setBatchNorm(initialValues.batchNorm);
    setLearningRate(initialValues.learningRate);
    setLayers(initialValues.layers);
    setNodes(initialValues.nodes);
    setActivation(initialValues.activation);
    setDropout_rates(initialValues.dropout_rates);
    nnReset({});
  };

  // Mirror the batch-norm choice into params.models.NeuralNets as soon as it changes.
  useEffect(() => {
    setParams((prevParams: any) => ({
      ...prevParams,
      models: {
        ...prevParams.models,
        NeuralNets: {
          ...prevParams.models.NeuralNets,
          batch_norm: batchNorm
        }
      }
    }));
  }, [batchNorm]);

  // "One Model" mode needs the optimizer/activation lists from the model settings endpoint.
  useEffect(() => {
    if (!nnToggle) {
      dispatch(modelSettingsMiddleware.getModelSettings(getModelSettingsURL, modelSettingsParams));
    }
  }, [nnToggle]);

  // Enable Save only when every field visible in the current mode holds a usable value.
  useEffect(() => {
    if (nnToggle) {
      if (metrics === undefined || !metrics || isNaN(time)) {
        setDisableSave(true);
      } else {
        setDisableSave(false);
      }
    } else {
      if (
        isNaN(epoch) ||
        optimizer === undefined ||
        !optimizer ||
        isNaN(learningRate) ||
        nodes.some((value: any) => isNaN(value)) ||
        activation.some((value: any) => value === undefined || value === '') ||
        layers === null
      ) {
        setDisableSave(true);
      } else {
        setDisableSave(false);
      }
    }
  }, [metrics, time, epoch, optimizer, learningRate, layers, nodes, activation, nnToggle]);

  const onModalClose = () => {
    dispatch(modalsMiddleware.closeModal(ModalName.NNInfoModal));
  };

  const handleNNToggle = () => {
    setNNtoggle(!nnToggle);
  };

  // Autocomplete handlers take the selection from the `value` argument. The event
  // target is not necessarily the option element (e.g. keyboard selection or the
  // clear button), so reading its text is unreliable.
  const handleNNMetrics = (_event: React.SyntheticEvent, value: any) => {
    setMetrics(toOptionLabel(value));
  };

  const handleNNMaxTime = (e: any) => {
    const value = parseFloat(e.target.value);
    setTime(value);
  };

  const handleEpoch = (e: any) => {
    // Clamp to the supported range 1-256 (same bounds as the input's min/max).
    const raw = parseInt(e.target.value, 10);
    const value = Math.min(Math.max(raw, 1), 256);
    // Reflect the clamped value back into the input when clamping changed it.
    if (value !== raw) {
      e.target.value = value;
    }
    setEpoch(value);
  };

  const handleOptimizer = (_event: React.SyntheticEvent, value: any) => {
    setOptimizer(toOptionLabel(value));
  };

  const handleBatchNorm = (event: any) => {
    setBatchNorm(event.target.value === 'true');
  };

  const handleLearningRate = (e: any) => {
    // Clamp to 0.00001-0.99999 (same bounds as the input's min/max).
    const raw = parseFloat(e.target.value);
    const value = Math.min(Math.max(raw, 0.00001), 0.99999);
    // Only rewrite the input when clamping actually changed the number. Comparing
    // against parseFloat (not parseInt) keeps in-progress text like "0.10" intact.
    if (value !== raw) {
      e.target.value = value;
    }
    setLearningRate(value);
  };

  /** Changing the layer count resets the per-layer nodes/activation/dropout arrays to defaults. */
  const handleLayers = (e: any, value: any) => {
    if (value) {
      const numLayers = value.value;
      setNodes(Array(numLayers).fill(256));
      setActivation(Array(numLayers).fill('ReLU'));
      setDropout_rates(Array(numLayers).fill(0.1));
      setLayers(value);
    }
  };

  const handleNodes = (e: any, i: any) => {
    // Clamp to 1-1024 nodes per layer (same bounds as the input's min/max).
    const raw = parseInt(e.target.value, 10);
    const value = Math.min(Math.max(raw, 1), 1024);
    if (value !== raw) {
      e.target.value = value;
    }
    const newNodes = [...nodes];
    newNodes[i] = value;
    setNodes(newNodes);
  };

  const handleActivationFunction = (value: any, i: any) => {
    const newActivation = [...activation];
    newActivation[i] = toOptionLabel(value);
    setActivation(newActivation);
  };

  const handleSlider = (value: number | number[], i: number) => {
    const newDropout_rates = [...dropout_rates];
    newDropout_rates[i] = value;
    setDropout_rates(newDropout_rates);
  };

  // Guard against a settings payload that lacks either list so Autocomplete never gets `undefined` options.
  const optimizerOptionsWithIds =
    modelSettingsData?.model_settings?.optimizer?.map((option: any, index: any) => ({
      id: index,
      label: option
    })) ?? [];
  const activationsOptionsWithId =
    modelSettingsData?.model_settings?.activations?.map((option: any, index: any) => ({
      id: index,
      label: option
    })) ?? [];

  /** Persists the form state, writes this model's entry into `params.models`, then closes the dialog. */
  const onSaveClick = () => {
    const selectedMetric = toOptionLabel(metrics);
    setNNSaveObject({
      nnToggle: nnToggle,
      metric: selectedMetric,
      time: time,
      epoch: epoch,
      optimizer: optimizer,
      batchNorm: batchNorm,
      learningRate: learningRate,
      layers: layers,
      nodes: nodes,
      activation: activation,
      dropout_rates: dropout_rates
    });

    // In "One Model" mode the Metrics/Time inputs are hidden, so fixed defaults are sent instead.
    const defaultMetric =
      problemType === 'classification' ? classificationMetricsOptions[0].label : regressionMetricsOptions[0].label;

    setParams((prevParams: any) => ({
      ...prevParams,
      // NOTE(review): `models` is replaced rather than merged, so only this model's entry survives — confirm intended.
      models: {
        [modelInfo]: {
          one_model: !nnToggle,
          time: nnToggle ? time : 1,
          metric: nnToggle ? selectedMetric : defaultMetric,
          epochs: epoch,
          optimizer: optimizer,
          batch_norm: batchNorm,
          learning_rate: learningRate,
          layers: parseInt(layers.label, 10),
          nodes: [...nodes],
          activations: [...activation],
          dropout_rates: [...dropout_rates]
        }
      }
    }));
    onModalClose();
  };

  return (
    <Dialog fullWidth maxWidth="lg" onClose={onModalClose} aria-labelledby="draggable-dialog-title" open>
      <div style={{ maxHeight: '100%' }}>
        <DialogTitle style={{ textAlign: 'center', position: 'relative', cursor: 'move' }} id="draggable-dialog-title">
          {`${modelInfo} Model Config`}
          <IconButton
            onClick={onModalClose}
            sx={{
              position: 'absolute',
              right: 20,
              top: 10
            }}
          >
            <CloseIcon />
          </IconButton>
        </DialogTitle>
        <Card>
          <CardContent sx={{ pt: 0 }}>
            <div style={{ display: 'flex', alignItems: 'center', paddingBottom: '10px', paddingTop: '10px' }}>
              <Tooltip title="The One Model  trains a specific model based on the parameters specified by the user.">
                <div
                  style={{
                    marginTop: '-15px'
                  }}
                >
                  <IconInfoCircle size={16} style={{ color: '#000', fill: '#fff' }} />
                </div>
              </Tooltip>
              <div style={{ color: '#000', fontSize: '22px', marginRight: '10px' }}>One Model</div>
              <div
                style={{
                  width: '70px',
                  height: '35px',
                  borderRadius: '35px',
                  backgroundColor: 'green',
                  display: 'flex',
                  alignItems: 'center',
                  justifyContent: 'flex-start',
                  position: 'relative',
                  cursor: 'pointer',
                  padding: '0 5px',
                  boxSizing: 'border-box'
                }}
                onClick={handleNNToggle}
              >
                <div
                  style={{
                    width: '25px',
                    height: '25px',
                    borderRadius: '50%',
                    backgroundColor: 'white',
                    transform: nnToggle ? 'translateX(38px)' : 'translateX(0)',
                    transition: 'transform 0.2s ease-in-out',
                    boxShadow: '0 2px 4px rgba(0, 0, 0, 0.2)',
                    cursor: 'pointer'
                  }}
                />
              </div>
              <div style={{ color: '#000', fontSize: '22px', marginLeft: '10px' }}>Best Model</div>
              <Tooltip title="The  Best Model gives the leading model determined by the algorithm based on the Metrics  and Time specified by the user.">
                <div
                  style={{
                    marginTop: '-15px'
                  }}
                >
                  <IconInfoCircle size={16} style={{ color: '#000', fill: '#fff' }} />
                </div>
              </Tooltip>
            </div>
            {nnToggle ? (
              <div>
                <Grid
                  container
                  justifyContent="space-between"
                  alignItems="center"
                  spacing={2}
                  style={{ paddingBottom: '10px' }}
                >
                  <Grid item sm={6} xs={12} sx={{ paddingX: '16px' }}>
                    <div style={{ display: 'flex', alignItems: 'center' }}>
                      <p style={{ fontSize: '18px', fontWeight: '600' }}>Metrics</p>
                      <Tooltip title="Select one of this parameters, which should be taken into account for model convergence.">
                        <div>
                          <IconInfoCircle size={16} style={{ color: '#000', fill: '#fff' }} />
                        </div>
                      </Tooltip>
                    </div>
                    <Autocomplete
                      style={{ width: '100%' }}
                      options={
                        problemType === 'classification' ? classificationMetricsOptions : regressionMetricsOptions
                      }
                      value={metrics}
                      onChange={handleNNMetrics}
                      isOptionEqualToValue={isSameOptionLabel}
                      renderInput={(params) => <TextField {...params} label="" />}
                    />
                  </Grid>
                  <Grid item sm={6} xs={12} sx={{ paddingX: '16px' }}>
                    <div style={{ display: 'flex', alignItems: 'center' }}>
                      <p style={{ fontSize: '18px', fontWeight: '600' }}>Time</p>
                      <Tooltip title="Select how long you are able to expect for completion of model training process in hours.">
                        <div>
                          <IconInfoCircle size={16} style={{ color: '#000', fill: '#fff' }} />
                        </div>
                      </Tooltip>
                    </div>
                    <TextField
                      style={{ width: '100%' }}
                      value={time}
                      onChange={handleNNMaxTime}
                      placeholder="Select Time"
                      inputProps={{ type: 'number', min: 1 }}
                    />
                  </Grid>
                </Grid>
              </div>
            ) : (
              <div>
                <Grid
                  container
                  justifyContent="space-between"
                  alignItems="center"
                  spacing={2}
                  style={{ paddingBottom: '10px' }}
                >
                  <Grid item sm={6} xs={12} sx={{ paddingX: '16px' }}>
                    <div style={{ display: 'flex', alignItems: 'center' }}>
                      <p style={{ fontSize: '18px', fontWeight: '600' }}>Epoch</p>
                      <Tooltip title="Select the epoch to train the network.">
                        <div>
                          <IconInfoCircle size={16} style={{ color: '#000', fill: '#fff' }} />
                        </div>
                      </Tooltip>
                    </div>
                    <TextField
                      style={{ width: '100%' }}
                      onChange={handleEpoch}
                      inputProps={{ type: 'number', min: 1, max: 256, step: 1 }}
                      value={epoch}
                    />
                  </Grid>
                </Grid>

                <Grid
                  container
                  justifyContent="space-between"
                  alignItems="center"
                  spacing={2}
                  style={{ paddingBottom: '10px' }}
                >
                  <Grid item sm={6} xs={12} sx={{ paddingX: '16px' }}>
                    <div style={{ display: 'flex', alignItems: 'center' }}>
                      <p style={{ fontSize: '18px', fontWeight: '600' }}>Optimizer</p>
                      <Tooltip title="Select the optimizer to train the network.">
                        <div>
                          <IconInfoCircle size={16} style={{ color: '#000', fill: '#fff' }} />
                        </div>
                      </Tooltip>
                    </div>
                    {modelSettingsStatus === 200 && (
                      <Autocomplete
                        style={{ width: '100%' }}
                        value={optimizer}
                        options={optimizerOptionsWithIds}
                        onChange={handleOptimizer}
                        isOptionEqualToValue={isSameOptionLabel}
                        renderInput={(params) => <TextField {...params} label="" />}
                      />
                    )}
                  </Grid>
                  <Grid
                    style={{ display: 'flex', flexDirection: 'column', alignItems: 'center' }}
                    item
                    sm={6}
                    xs={12}
                    sx={{ paddingX: '16px' }}
                  >
                    <div style={{ display: 'flex', alignItems: 'center' }}>
                      <p style={{ fontSize: '18px', fontWeight: '600' }}>Batch Norm</p>
                      <Tooltip title="Normalize every batch of data or not.">
                        <div>
                          <IconInfoCircle size={16} style={{ color: '#000', fill: '#fff' }} />
                        </div>
                      </Tooltip>
                    </div>
                    {/* Optional chaining: params.models.NN may be absent (e.g. after a save under another model name). */}
                    <div key={String(params?.models?.NN?.batch_norm)}>
                      <FormControlLabel
                        value="true"
                        control={<Radio color="primary" />}
                        label="True"
                        labelPlacement="end"
                        checked={batchNorm === true}
                        onChange={handleBatchNorm}
                        name="batch-norm-radio-true"
                      />
                      <FormControlLabel
                        value="false"
                        control={<Radio color="primary" />}
                        label="False"
                        labelPlacement="end"
                        checked={batchNorm === false}
                        onChange={handleBatchNorm}
                        name="batch-norm-radio-false"
                      />
                    </div>
                  </Grid>
                </Grid>
                <Grid
                  container
                  justifyContent="space-between"
                  alignItems="center"
                  spacing={2}
                  style={{ paddingBottom: '10px' }}
                >
                  <Grid item sm={6} xs={12} sx={{ paddingX: '16px' }}>
                    <div style={{ display: 'flex', alignItems: 'center' }}>
                      <p style={{ fontSize: '18px', fontWeight: '600' }}>Learning Rate</p>
                      <Tooltip title="Hyperparameter,  that defines the adjustment in the weights of network with respect to the gradient descent.">
                        <div>
                          <IconInfoCircle size={16} style={{ color: '#000', fill: '#fff' }} />
                        </div>
                      </Tooltip>
                    </div>
                    <TextField
                      style={{ width: '100%' }}
                      onChange={handleLearningRate}
                      value={learningRate}
                      inputProps={{ type: 'number', min: 0.00001, max: 0.99999, step: 0.001 }}
                    />
                  </Grid>
                  <Grid item sm={6} xs={12} sx={{ paddingX: '16px' }}>
                    <div style={{ display: 'flex', alignItems: 'center' }}>
                      <p style={{ fontSize: '18px', fontWeight: '600' }}>Layers</p>
                      <Tooltip title="Add dense layers to my network.">
                        <div>
                          <IconInfoCircle size={16} style={{ color: '#000', fill: '#fff' }} />
                        </div>
                      </Tooltip>
                    </div>
                    <Autocomplete
                      style={{ width: '100%' }}
                      options={layerOptions}
                      value={layers}
                      onChange={handleLayers}
                      renderInput={(params) => <TextField {...params} label="" type="number" />}
                    />
                  </Grid>
                </Grid>
                <div>
                  <Grid
                    container
                    justifyContent="space-between"
                    alignItems="center"
                    spacing={2}
                    style={{ paddingBottom: '10px' }}
                  >
                    <Grid item sm={4} xs={12} sx={{ paddingX: '16px' }}>
                      <div style={{ display: 'flex', alignItems: 'center', justifyContent: 'center' }}>
                        <p style={{ fontSize: '18px', fontWeight: '600' }}>Nodes</p>
                        <Tooltip title="How many nodes will have each layer of my network. Please select 1-1024">
                          <div>
                            <IconInfoCircle size={16} style={{ color: '#000', fill: '#fff' }} />
                          </div>
                        </Tooltip>
                      </div>
                    </Grid>
                    <Grid item sm={4} xs={12} sx={{ paddingX: '16px' }}>
                      <div style={{ display: 'flex', alignItems: 'center', justifyContent: 'center' }}>
                        <p style={{ fontSize: '18px', fontWeight: '600' }}>Activation Function</p>
                        <Tooltip title="Select an activation function for each layer.">
                          <div>
                            <IconInfoCircle size={16} style={{ color: '#000', fill: '#fff' }} />
                          </div>
                        </Tooltip>
                      </div>
                    </Grid>
                    <Grid item sm={4} xs={12} sx={{ paddingX: '16px' }}>
                      <div style={{ display: 'flex', alignItems: 'center', justifyContent: 'center' }}>
                        <p style={{ fontSize: '18px', fontWeight: '600' }}>Dropout Rate</p>
                        <Tooltip title="The proportion of disabled nodes of each layer.">
                          <div>
                            <IconInfoCircle size={16} style={{ color: '#000', fill: '#fff' }} />
                          </div>
                        </Tooltip>
                      </div>
                    </Grid>
                  </Grid>
                </div>
                {layers &&
                  parseInt(layers.label, 10) > 0 &&
                  Array.from({ length: parseInt(layers.label, 10) }, (_, i) => (
                    <div key={i}>
                      <Grid
                        container
                        justifyContent="space-between"
                        alignItems="center"
                        spacing={2}
                        style={{ paddingBottom: '10px' }}
                      >
                        <Grid item sm={4} xs={12} sx={{ paddingX: '16px' }}>
                          <div style={{ display: 'flex', alignItems: 'center', justifyContent: 'center' }}>
                            <TextField
                              style={{ width: '100%' }}
                              value={nodes[i]}
                              onChange={(e) => handleNodes(e, i)}
                              inputProps={{ type: 'number', min: 1, max: 1024, step: 1 }}
                            />
                          </div>
                        </Grid>
                        <Grid item sm={4} xs={12} sx={{ paddingX: '16px' }}>
                          <div style={{ display: 'flex', alignItems: 'center', justifyContent: 'center' }}>
                            {modelSettingsStatus === 200 && (
                              <Autocomplete
                                style={{ width: '100%' }}
                                options={activationsOptionsWithId}
                                value={activation[i]}
                                onChange={(_event, value) => handleActivationFunction(value, i)}
                                isOptionEqualToValue={isSameOptionLabel}
                                renderInput={(params) => <TextField {...params} label="" />}
                              />
                            )}
                          </div>
                        </Grid>
                        <Grid item sm={4} xs={12} sx={{ paddingX: '16px' }}>
                          <div style={{ display: 'flex', alignItems: 'center', justifyContent: 'center' }}>
                            <Grid item xs={12} container spacing={2} alignItems="center" sx={{ mt: 2.5 }}>
                              <Grid item>
                                <Typography variant="h6" color="primary">
                                  0
                                </Typography>
                              </Grid>
                              <Grid item xs>
                                <Slider
                                  color="secondary"
                                  value={dropout_rates[i]}
                                  onChange={(_event, value) => handleSlider(value, i)}
                                  valueLabelDisplay="on"
                                  aria-labelledby="discrete-slider-small-steps"
                                  marks
                                  step={0.1}
                                  min={0.1}
                                  max={1}
                                />
                              </Grid>
                              <Grid item>
                                <Typography variant="h6" color="primary">
                                  1
                                </Typography>
                              </Grid>
                            </Grid>
                          </div>
                        </Grid>
                      </Grid>
                    </div>
                  ))}
              </div>
            )}

            <Grid item xs={12} md={4} sx={{ paddingX: '16px' }}>
              <Box sx={{ display: 'flex', justifyContent: 'space-between' }}>
                <Button
                  variant="contained"
                  size="large"
                  sx={{
                    margin: '20px auto',
                    height: '52px',
                    background: theme.palette.warning.dark,
                    '&:hover': { background: theme.palette.warning.main },
                    color: 'grey.900'
                  }}
                  onClick={handleReset}
                >
                  Reset
                </Button>
                <Button
                  variant="contained"
                  size="large"
                  style={{
                    backgroundColor: disableSave ? '#CCCCCC' : '#202090',
                    color: disableSave ? '#666666' : '#fff'
                  }}
                  sx={{ margin: '20px auto', height: '52px' }}
                  onClick={onSaveClick}
                  disabled={disableSave}
                >
                  Save
                </Button>
              </Box>
            </Grid>
          </CardContent>
        </Card>
      </div>
    </Dialog>
  );
};

export default NNModelDialog;
