import { useState, useEffect } from "react";
import {
  TextField, Typography, RadioGroup,
  Box, Radio, FormControlLabel, Grid,
  Select, MenuItem, FormControl, Paper
} from "@mui/material";
import ConfigInput from "./ConfigInput";
import { apiUrl } from "../secrets";
import Alert from '@mui/material/Alert';
import { storeConfig, getInferenceConfigFromFirebase, isStringInLatestVersionStatus } from "../utils/versionManagement";
import FloatingButton from "./UIElements/FloatingButton";
import { verifyService } from "../utils/serviceManagement";
import { useNavigate } from "react-router-dom";
import getTokens from '../utils/auth';

const StyledTextarea = ({ value, onChange, name }) => {
  return (
    <Grid item xs={12} sm={6}>
      <Paper elevation={1} sx={{ p: 2 }}>
        <Typography variant="h6" sx={{ mb: 2 }}>{name}</Typography>
        <TextField
          value={value}
          onChange={onChange}
          fullWidth
          multiline
          variant="outlined"
          minRows={4}
          maxRows={6}
        />
      </Paper>
    </Grid>
  );
};

// Base URL for this user's API endpoints.
// NOTE(review): evaluated once, when this module is first imported, so it captures
// whatever `userId` localStorage holds at that moment (the string "null" if unset);
// a later login in the same page session will not update it — confirm this is acceptable.
const url = `${apiUrl}/${localStorage.getItem('userId')}`;

// Metric names offered for each metric type.
const metricsMap = {
  binary: ["f1_score", "accuracy", "precision", "recall"],
  multiClass: [
    "accuracy",
    "precision_weighted", "recall_weighted", "f1_weighted",
    "precision_macro", "recall_macro", "f1_macro",
    "precision_micro", "recall_micro", "f1_micro",
  ],
  similarity: ["similarity"],
  exactMatch: ["exact"],
  summaryQuality: ["summary_quality"],
  gpt4Evaluator: ["score"],
  custom: ["validation_loss"],
};

function InferenceConfiguration({ onNext, onBack, subscriptionPlan, config,
                                projectName, project, versionId, version, type, functionCalling}) {
  const [metricType, setMetricType] = useState( type !== "chat" ? "binary" : "gpt4Evaluator");
  const [classificationNCls, setClassificationNCls] = useState(3);
  const [classificationPositiveClass, setClassificationPositiveClass] = useState("");
  const [task, setTask] = useState(""); // for gpt4Evaluator
  const [classificationBetas, setClassificationBetas] = useState({ type: "chip", values: [] });
  const [metricGoal, setMetricGoal] = useState("maximize");
  const [metricName, setMetricName] = useState(type !== "chat" ? "f1_score" : "score");
  const [customLossConfig, setCustomLossConfig] = useState("{}");
  const [customInferenceConfig, setInferenceConfig] = useState("{}");
  const [maxTokensConfig, setMaxTokensConfig] = useState("{}");
  const [errorMessage, setErrorMessage] = useState("");
  const [metricList, setMetricList] = useState(type !== "chat" ?  metricsMap["binary"] : metricsMap["gpt4Evaluator"])

  const navigate = useNavigate();

  const allMetrics = [
    { value: "binary", label: "Binary Classification" },
    { value: "multiClass", label: "MultiClass Classification" },
    { value: "similarity", label: "Text Similarity" },
    { value: "exactMatch", label: "Exact Match" },
    { value: "summaryQuality", label: "Summary Quality" },
    { value: "gpt4Evaluator", label: "GPT-4 Evaluator" },
    { value: "custom", label: "Custom" },
  ];

  const chatMetrics = [
    { value: "similarity", label: "Text Similarity" },
    { value: "gpt4Evaluator", label: "GPT-4 Evaluator" },
    { value: "custom", label: "Custom" }
  ];

  const functionCallingMetrics = [
    { value: "gpt4Evaluator", label: "GPT-4 Evaluator" },
    { value: "custom", label: "Custom" }
  ];

  let metrics = allMetrics;
  if(type === 'chat'){
    if(functionCalling){
      metrics = functionCallingMetrics;
    }
    else{
      metrics = chatMetrics;
    }
  }

  // Use the getConfigFromFirebase function in a useEffect hook
  useEffect(() => {
    getInferenceConfigFromFirebase(project, versionId, version)
      .then((config) => {
        if (config) {
          setMetricType(config.metricType);
          setMetricList(metricsMap[config.metricType]);
          setClassificationNCls(config.classificationNCls);
          setClassificationPositiveClass(config.classificationPositiveClass);
          setTask(config.task);
          setClassificationBetas(config.classificationBetas);
          setMetricGoal(config.metricGoal);
          setMetricName(config.metricName);
          setCustomLossConfig(config.customLossConfig);
          setInferenceConfig(config.customInferenceConfig);
          setMaxTokensConfig(config.maxTokensConfig);
        }
      });
  }, []);


  const verifyJSON = (json) => {
    try {
      JSON.parse(json);
    } catch (e) {
      return false;
    }
    return true;
  };

  const getBetaValues = (config) => {
    if (config.type === "chip") {
      if (config.values.length === 0) {
        return null;
      }
      return config.values;
    } else {
      const range = config.values;
      const arr = [];
      for (let i = range.min; i <= range.max; i++) {
        arr.push(i);
      }
      return arr;
    }
  }

  const handleCreateFineTune = () => {
    isStringInLatestVersionStatus(project,"dataset created").then((datasetCreated)=>{
      if (!datasetCreated) {
        isStringInLatestVersionStatus(project, "failed to create dataset").then((datasetCreationFailed) => {
          if (datasetCreationFailed) {
            setErrorMessage("dataset creation failed for this version. delete the current version and retry from beginning. Contact support if needed")
          }
          else {
            setErrorMessage("dataset is getting created. Please retry after some time")
          }
        });
      }
  });

    if(subscriptionPlan === "free" && metricType === "gpt4Evaluator"){
      setErrorMessage("Upgrade to Starter or Pro plan to unlock GPT-4 Evaluator");
      return;
    }

    if((subscriptionPlan === "starter" || subscriptionPlan === "free") && metricType === "custom"){
      setErrorMessage("Upgrade to Pro plan to unlock Custom Evaluator");
      return;
    }

    let inferenceConfig = {}
    let metricConfig = {}

    switch(metricType){
      case "binary":
        metricConfig["metric"] = "BinaryClassification";
        break;
      case "multiClass":
        metricConfig["metric"] = "MultiClassClassification";
        break;
      case "similarity":
        metricConfig["metric"] = "TextSimilarity";
        break;
      case "gpt4Evaluator":
        metricConfig["metric"] = "GPT4Evaluator";
        break;
      case "exactMatch":
        metricConfig["metric"] = "ExactMatch";
        break;
      case "summaryQuality":
          metricConfig["metric"] = "SummaryQuality";
          break;
      case "custom":
        metricConfig["metric"] = "Custom";
        break;
      default:
        metricConfig["metric"] = "BinaryClassification";
        break;
    }

    if (metricType === "binary" || metricType === "multiClass") {
      metricConfig["config"] = {};
      if (metricType === "multiClass") {
        metricConfig["config"]["classification_n_classes"] = classificationNCls;
      }
      if (metricType === "binary") {
        if (classificationPositiveClass !== "") {
          metricConfig["config"]["positive_class"] = classificationPositiveClass;
        }
        else {
          setErrorMessage("Provide positive class for binary classification");
          return;
        }
      }
      let precision_weights = getBetaValues(classificationBetas)
      metricConfig["config"]["precision_weights"] =  precision_weights != null?precision_weights:[] ;
    }

    if (metricType === "gpt4Evaluator") {
      metricConfig["config"] = {};
      if (task !== "") {
        metricConfig["config"]["task"] = task;
      }
      else {
        setErrorMessage("Provide task description for GPT-4 Evaluator");
        return;
      }
    }

    if(metricType === 'custom'){
      if (verifyJSON(customLossConfig)) {
        if (customLossConfig !== "{}") {
          metricConfig["config"]= JSON.parse(customLossConfig);
        }
      } else {
        setErrorMessage("Invalid JSON for custom loss config");
        return;
      }
    }

    inferenceConfig["metric_config"] = metricConfig;

    if (metricType !== "binary" && metricType !== "multiClass") {
      if (verifyJSON(customInferenceConfig)) {
        if (customInferenceConfig !== "{}")
          inferenceConfig["inference_config"] = JSON.parse(customInferenceConfig);
      } else {
        setErrorMessage("Invalid JSON for inference config");
        return;
      }

      if (verifyJSON(maxTokensConfig)) {
        if (maxTokensConfig !== "{}")
          inferenceConfig["max_tokens_config"] = JSON.parse(maxTokensConfig);
      } else {
        setErrorMessage("Invalid JSON for max tokens config");
        return;
      }
    }

    inferenceConfig["metric_configuration"] = {
      "goal": metricGoal,
      "name": metricName === "validation_loss" ? metricName : metricName + '/validation'
    }

    // console.log(config)

    inferenceConfig["project"] = projectName;
    inferenceConfig["firebase_project_id"] = project;
    if (versionId) {
      inferenceConfig["firebase_version_id"] = versionId;
    }
    else {
      setErrorMessage("Error in creating version. Close this wizard and try again");
      throw new Error();
    }
    inferenceConfig["total_rows"] = 75;

    inferenceConfig["type"] = type;

    const newConfig = { ...config, ...inferenceConfig };

    // console.log(newConfig);

    // console.log(JSON.stringify(newConfig));

    verifyService(navigate).then((canContinue) => {
      if(canContinue){
        getTokens().then((headers) => {
        fetch(`${url}/createFineTune`, {
          "method": "POST", "body": JSON.stringify(newConfig),
           headers: {...{
            "Content-Type": "application/json",
          }, ...headers},
        })
        .then((res) => {
          if (res.status === 200) {
            return res.json();
          }
          else {
            setErrorMessage("Error creating finetune");
          }
        })
        .then((data) => {
          if (data['error']) {
            setErrorMessage(data['error']);
          }
          else if (data['message']) {
            if (data['message'] === "Request accepted. Processing finetune creation in the background.") {
              //Store inference config to firebase
              const createConfigObject = () => {
                return {
                  metricType: metricType,
                  classificationNCls: classificationNCls,
                  classificationPositiveClass: classificationPositiveClass,
                  task: task,
                  classificationBetas: classificationBetas,
                  metricGoal: metricGoal,
                  metricName: metricName,
                  customLossConfig: customLossConfig,
                  customInferenceConfig: customInferenceConfig,
                  maxTokensConfig: maxTokensConfig,
                };
              };

              const config = createConfigObject();
              storeConfig(project, versionId, config, "inference").then(() => {
              //add version details to firebase
              window.location.href = `/project?id=${project}&name=${projectName}`;
              });
            }
            else {
              setErrorMessage(data['message']);
            }
          }
          else {
            setErrorMessage("Error creating finetune");
          }
        })
        .catch((err) => {
          setErrorMessage("Error creating finetune");
        });
      });
      }
    });
  }

  const handleMetricTypeChange = (e) => {
    const value = e.target.value;
    if (value !== metricType) {
      const metricMap = {
        binary: { name: "f1_score", goal: "maximize" },
        multiClass: { name: "f1_weighted", goal: "maximize" },
        similarity: { name: "similarity", goal: "maximize" },
        gpt4Evaluator: { name: "score", goal: "maximize" },
        exactMatch: { name: "exact", goal: "maximize" },
        summaryQuality: { name: "summary_quality", goal: "maximize" },
        custom: { name: "validation_loss", goal: "minimize" },
      };
      const { name, goal } = metricMap[value];

      setMetricList(metricsMap[value]);
      setMetricName(name);
      setMetricGoal(goal);
      setMetricType(value);
    }
  }

  let subscriptionAlert = "";
  if (subscriptionPlan === "free") {
    subscriptionAlert = "Upgrade to Starter or Pro plan to unlock GPT-4 Evaluator";
  }
  else if (subscriptionPlan === "starter") {
    subscriptionAlert = "Upgrade to Pro plan to unlock Custom Evaluator";
  }

  return (
    <Box display={"flex"} flexDirection={"column"} gap={3} py={3}>
      {subscriptionPlan !== "pro" && (<Alert severity="info">{subscriptionAlert}</Alert>)}
      <Box pb={2} className="form-help-text" sx={{ display: { xs: 'none', md: 'block' } }}>
        Docs: <a href="https://www.easyllm.tech/docs/loss-and-metrics-configuration.html" target="_blank" className="form-help-hypertext" rel="noreferrer"> Configuration </a> and <a href="https://www.easyllm.tech/docs/blog/crafting-custom-metrics-for-measuring-performance-of-finetuned-large-language-models.html" className="form-help-hypertext" target="_blank" rel="noreferrer">Tutorial</a>
      </Box>
      <Grid container spacing={2}>
        <Grid item xs={12} sm={6}>
          <Paper elevation={1} sx={{ padding: 4 }}> {/* Metric Paper*/}
            <Grid container spacing={2}>
              <Grid item xs={12} sm={6}> {/* Metric Selector*/}
                <Typography variant="subtitle1"  >
                  Metric
                </Typography>
                <FormControl>
                  <Select
                    labelId="metric-select-label"
                    id="metric-select"
                    value={metricType}
                    onChange={(e) => { handleMetricTypeChange(e); }}
                  >
                    {
                      metrics.map((menuItem) => (
                        <MenuItem key={menuItem.value} value={menuItem.value}>
                          {menuItem.label}
                        </MenuItem>
                      ))
                    }
                  </Select>
                </FormControl>
              </Grid>
              <Grid item xs={12} sm={6}> {/* Number of Classes or Positive Class*/}
                {(metricType === "binary" || metricType === "multiClass") && (
                  <>
                    {(metricType === "multiClass") && (
                      <>
                        <Typography variant="subtitle1"  >
                          Number of Classes
                        </Typography>
                        <TextField
                          fullWidth
                          size="small"
                          value={classificationNCls}
                          type="number"
                          onChange={(e) => {
                            const value = parseInt(e.target.value);
                            (isNaN(value)) ? setClassificationNCls(0) : setClassificationNCls(value)
                          }}
                        />
                      </>
                    )}
                    {(metricType === "binary") && (
                      <>
                        <Typography variant="subtitle1"  >
                          Classification Positive Class
                        </Typography>
                        <TextField
                          fullWidth
                          size="small"
                          value={classificationPositiveClass}
                          onChange={(e) => setClassificationPositiveClass(e.target.value)}
                        />
                      </>
                    )}
                  </>
                )}
              </Grid>
              <Grid item xs={12} sm={6}> {/* Metric Goal*/}
                <Box display={"flex"} flexDirection={"column"} gap={1} py={1}>
                  <Typography variant="subtitle1">
                    Metric Goal
                  </Typography>
                  <RadioGroup
                    row
                    aria-label="metric-goal"
                    name="metric-goal"
                    value={metricGoal}
                    onChange={(e) => {
                      const value = e.target.value;
                      if (value === "minimize") {
                        setMetricGoal("minimize");
                      } else {
                        setMetricGoal("maximize");
                      }
                    }}
                  >
                    <FormControlLabel value="minimize" control={<Radio />}
                      disabled={(["custom"].indexOf(metricType) === -1)}
                      label="Minimize" />
                    <FormControlLabel value="maximize" control={<Radio />}
                      disabled={(["custom"].indexOf(metricType) === -1)}
                      label="Maximize" />
                  </RadioGroup>
                </Box>
              </Grid>
              <Grid item xs={12} sm={6}> {/* Metric Name*/}
                <Box display={"flex"} flexDirection={"column"} gap={1} py={1}>
                  <Typography variant="subtitle1"  >
                    Metric Name
                  </Typography>
                  {(metricType === 'custom') && (
                    <TextField
                      labelId="metric-name-select-label"
                      id="metric-name-select"
                      value={metricName}
                      onChange={(e) => setMetricName(e.target.value)}
                    />
                  )}
                  {(metricType !== 'custom') && (
                    <Select
                      labelId="metric-name-select-label"
                      id="metric-name-select"
                      value={metricName}
                      onChange={(e) => setMetricName(e.target.value)}
                    >
                      {
                        metricList.map((menuItem) => (
                          <MenuItem key={menuItem} value={menuItem}>
                            {menuItem}
                          </MenuItem>
                        ))
                      }
                    </Select>
                  )}
                </Box>
              </Grid>
            </Grid>
          </Paper>
        </Grid>
        {(metricType === "binary" || metricType === "multiClass") && (
          <Grid item xs={12} sm={6}> {/* F-Score Precision Weights*/}
            <ConfigInput config={classificationBetas} setConfig={setClassificationBetas} name="F-Score Precision Weights"
              inputKey="classification-betas" maxListLength={5} allowedRange={{ "min": 0.0, "max": 1.0 }} helperText="Allowed range 0.0 to 1.0" />
          </Grid>
        )}
        {/* Describe the Task*/}
        {(metricType === 'gpt4Evaluator') && (
          <StyledTextarea  
            value={task}
            onChange={(e) => setTask(e.target.value)}
            name="Describe the Task"
          />
        )}
        {/* Custom Metric Config*/}
        {metricType === "custom" && (
          <StyledTextarea
            value={customLossConfig}
            onChange={(e) => setCustomLossConfig(e.target.value)}
            name="Custom Metric Config"
          />
        )}
        {(["custom", "similarity", "exactMatch", "summaryQuality", "gpt4Evaluator"].indexOf(metricType) !== -1) && (
          <>
            {/* Inference Config*/}
            <StyledTextarea
              value={customInferenceConfig}
              onChange={(e) => setInferenceConfig(e.target.value)}
              name="Inference Config"
            />
            {/* Max Tokens Config*/}
            <StyledTextarea
              value={maxTokensConfig}
              onChange={(e) => setMaxTokensConfig(e.target.value)}
              name="Max Tokens Config"
            />
          </>
        )}
      </Grid>
      <FloatingButton onClick={handleCreateFineTune} text="Create FineTune" hideIcon />
      <Box sx={{ mb: 2 }}>
        {errorMessage && (
          <Alert severity="error">{errorMessage}</Alert>
        )}
      </Box>
    </Box>
  );
}

export default InferenceConfiguration;