//
// This file is part of MorphoGraphX - http://www.MorphoGraphX.org
// Copyright (C) 2012-2015 Richard S. Smith and collaborators.
//
// If you use MorphoGraphX in your work, please cite:
//   http://dx.doi.org/10.7554/eLife.05864
//
// MorphoGraphX is free software, and is licensed under under the terms of the 
// GNU General (GPL) Public License version 2.0, http://www.gnu.org/licenses.
//
#ifndef CellTypeRecognitionHPP
#define CellTypeRecognitionHPP

#include "Process.hpp"
#include "CellCluster.hpp"
#include "GraphUtils.hpp"
#include "SVMClassifier.hpp"
#include "MeshProcessCellMesh.hpp"
#include "MeshProcessHeatMap.hpp"
#include "MeshProcessMeasures.hpp"
#include "MeshProcessMeasures3D.hpp"
#include <MeshProcessLineage.hpp> // for selecting parents (specification process) 
#include <QDialog>
#include <memory>
#include <QMessageBox>
#include <QCheckBox>
#include <ui_SVMGUI.h>
#include <libsvm/svm.h>

#include <opencv2/core/core.hpp>
#include <opencv2/highgui/highgui.hpp>
#include <opencv2/imgproc/imgproc.hpp>
#include <algorithm>
#include <iosfwd>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include <memory>

#include "ui_HeatMapClusterGUI.h"
#include "ClusterMap.hpp"

//using namespace caffe;
using std::string;

namespace mgx
{

// Shared state of the SVM classification processes declared below.
// NOTE(review): these are `static` (internal linkage) variables defined in a
// header, so every translation unit that includes this header gets its own,
// independent copy. State written by one process is only visible to another if
// both are implemented in the same .cpp file — confirm. Otherwise consider
// `extern` declarations here with a single definition in one .cpp.
// libsvm model; static storage, so it starts out as a null pointer.
static svm_model *model;
// Cell/measure data container (referred to as "cc" by CellTypeRecognitionClear).
static CellCluster cc;
// SVM classifier helper object.
static SVMClassifier svmc;


  /**
   * \class CellTypeRecognitionClear CellTypeRecognition.hpp <CellTypeRecognition.hpp>
   *
   * Process "Z Reset Data": clears the cc data structure and/or the model so
   * that new test data can be processed.
   */
  class CellTypeRecognitionClear : public Process
  {
    public:
      CellTypeRecognitionClear(const Process &process) : Process(process)
      {
        setName("Mesh/Cell Types/Classification/Z Reset Data");
        setDesc("Clears the internal input and model data.");
        setIcon(QIcon(":/images/CellTypeRecognitionSpecification.jpeg"));

        // Both switches are Yes/No choices and default to "Yes".
        addParm("Clear Input", "", "Yes", booleanChoice());
        addParm("Clear Model", "", "Yes", booleanChoice());
      }

      /// Process entry point: converts the two Yes/No parameters to booleans
      /// and forwards them to run(bool, bool).
      bool run()
      {
        const bool wipeInput = stringToBool(parm("Clear Input"));
        const bool wipeModel = stringToBool(parm("Clear Model"));
        return run(wipeInput, wipeModel);
      }

      /// Clears the input data and/or the model (implemented out of line).
      bool run(bool clearInput, bool clearModel);
  };




  /**
   * Process "A Select Measures": select and view the measure data (attribute
   * maps) used as features for training.
   */
  class CellTypeRecognitionFeatureSelection : public Process
  {
    Q_OBJECT
    public:
      // Designer-generated UI of the measure selection dialog.
      Ui_SVMDialog ui;

      // Names of the measures that are currently selected.
      std::set<QString> selectedMeasures;

      CellTypeRecognitionFeatureSelection(const Process &process) : Process(process)
      {
        setName("Mesh/Cell Types/Classification/A Select Measures");
        setDesc("Select and View the Measure Data for Training");
        setIcon(QIcon(":/images/CellTypeRecognitionSpecification.jpeg"));

        addParm("Mesh Type", "", "2D", QStringList() << "2D" << "3D");
        addParm("Selected Measures", "", "");
      }

      // All members below are implemented out of line.
      bool findProcesses(Mesh *mesh);
      bool initialize(QWidget *parent);

      /// Reads the attr map tree widget and returns a string with all selected entries.
      QString readSelectedMeasures();

      bool scaleData();
      bool calcSelectedMeasures(Mesh *mesh, bool forceRecalc, bool setScale);
      void populateMapTree(Mesh *mesh);

      void setSelectedMeasures(std::set<QString> newMeasures);

      bool run();

      bool runMeasureProcess(Mesh* m, QString name, bool forceRecalc, bool justGrapAttrMap);

    public slots:
      // Dialog callbacks; the on_<object>_<signal> names follow Qt's
      // auto-connection naming scheme.
      void on_svmTreeWidget_itemClicked(QTreeWidgetItem *, int);
      void on_svmMapTreeWidget_itemClicked(QTreeWidgetItem *, int);
      void on_selectAllButton_clicked();
      void on_unselectAllButton_clicked();
      void changeMeshType(const QString& s);
  };

  class CellTypeRecognitionSimilarity : public Process
  { 
    public:

      CellTypeRecognitionSimilarity(const Process &process) : Process(process) 
      { 
		setName("Mesh/Cell Types/Classification/Tools/Similarity Heat Map");
		setDesc("Generate a Heat Map of a Similarity measure using the selected measures \n"
		"in the process A Select Measures based on the selected cells");
		setIcon(QIcon(":/images/CellTypeRecognitionSpecification.jpeg"));

		addParm("Distance","","Euclidean", QStringList() << "Euclidean");
		addParm("Mesh","","active", QStringList() << "active" << "other");
  
		  }

      bool run(){
        Mesh *m1 = currentMesh();
        Mesh *m2 = otherMesh();
        return run(m1, m2, parm("Distance"), parm("Mesh")); 
      }

      bool run(Mesh* m1, Mesh* m2, QString distanceType, QString meshNr);

  };

class CorrespondingVertices : public Process
  { 
    public:

      CorrespondingVertices(const Process &process) : Process(process) 
      { 
		setName("Mesh/Cell Types/Classification/Tools/Corresponding Vertices");
		setDesc("Find corresponding vertices in the other mesh based on the used attribute maps"); 
		setIcon(QIcon(":/images/CellTypeRecognitionSpecification.jpeg")); 
	  }

      bool run(){
        Mesh *m1 = currentMesh();
        Mesh *m2 = otherMesh();
        return run(m1, m2); 
      }

      bool run(Mesh* m1, Mesh* m2);

  };


  /**
   * Process "B Specify Cell Types": reuses SetParent (from the lineage
   * processes) to assign a parent label to the selected cells.
   */
  class CellTypeRecognitionSpecification : public SetParent
  {
    public:
      CellTypeRecognitionSpecification(const Process &process) : SetParent(process)
      {
        // Only the name and description are overridden; parameters and run()
        // are inherited from SetParent.
        setName("Mesh/Cell Types/Classification/B Specify Cell Types");
        setDesc("Set the parent for selected cells");
      }
  };

  class CellTypeRecognitionDataFile : public Process
  {
    public:
      CellTypeRecognitionDataFile(const Process &process): Process(process) 
      {
		setName("Mesh/Cell Types/Classification/C Write Training Data");
		setDesc("Write Training Data to a file based on the selected measures and cells");
		setIcon(QIcon(":/images/CellTypeRecognitionSpecification.jpeg")); 

		addParm("Filename","","");
		addParm("File Type","","New", QStringList() << "New" << "Append");
	  }

      bool initialize( QWidget* parent);

      bool run(){
        Mesh *m1 = currentMesh();
        return run(m1, parm("Filename"), parm("File Type"));
      }
      bool writeDataFile(Mesh *mesh, QString filename, QString choice);
      bool run(Mesh *m1, QString filename, QString choice);

  };
  class CellTypeRecognitionTrain : public Process
  {
    public:
      CellTypeRecognitionTrain(const Process &process): Process(process) 
      {
		setName("Mesh/Cell Types/Classification/SVM/E Train SVM Model");
		setDesc("Train the SVM model by loading a training file. Optimize SVM parameter using cross validaton.");
		setIcon(QIcon(":/images/CellTypeRecognitionSpecification.jpeg")); 

		addParm("Filename","Location of training data file","");
		addParm("SVM Type","Type of the SVM \n"
		"SVC: classification \n"
		"SVR: regression \n"
		"One class: distribution estimation","c_svc", QStringList() << "c_svc" << "nu_svc" << "one_class" << "epsilon_svr" << "nu_svr");
		addParm("Kernel Type","Kernel Type of the SVM \n"
		"rbf: non-linear mapping into a higher space \n"
		"linear: special case of rbf, use when many dimensions","rbf", QStringList() << "linear" << "polynomial" << "rbf" << "sigmoid" << "precomputed");
		addParm("Feature Selection","Feature reduction","No", QStringList() << "Genetic Feature Selection" << "Full Reduction" << "Quick Reduction" << "Full Inclusion" << "Quick Inclusion" << "No");
		addParm("Optimize Parms","Optimize gamma and C using a grid search and use best cross validation SVM","Yes",booleanChoice());
		addParm("Feature Threshold","gamma","0.001");
		addParm("Optimization Grid Size","C","6");
		addParm("gamma","k Fold","0.1");
		addParm("C","Min Features","1");
		addParm("k Fold","Max Features","3");
		addParm("Min Features","","1");
		addParm("Max Features","","10");		  
	  }

      std::set<QString> fileFeatures;
      svm_node *x_space;
      svm_parameter param;     
      svm_problem prob;

      bool initialize(QWidget* parent);

      bool run(){
        Mesh *m1 = currentMesh();

        bool featureReduction = (parm("Feature Selection") == "Full Reduction" or parm("Feature Selection") == "Quick Reduction")? true : false;
        bool fullParameterSearch = (parm("Feature Selection") == "Full Reduction" or parm("Feature Selection") == "Full Inclusion")? true : false;
        bool featureInclusion = (parm("Feature Selection") == "Full Inclusion" or parm("Feature Selection") == "Quick Inclusion")? true : false;
        bool genetic = (parm("Feature Selection") == "Genetic Feature Selection")? true : false;

        return run(m1, parm("Filename"), stringToBool(parm("Optimize Parms")), parm("gamma").toDouble(), parm("C").toDouble(), parm("SVM Type"), parm("Kernel Type"), featureReduction,
          parm("Feature Threshold").toDouble(), parm("Optimization Grid Size").toInt(), fullParameterSearch, featureInclusion, genetic,parm("k Fold").toInt());
      }
      //bool backwardSelection(QString filename);
      //double do_cross_validation();
      //void optimizeSVMParameter();
      bool run(Mesh *m1, QString filename, bool optimizeParms, double pGamma, double pC, QString svmType, QString kernelType, bool featureReduction,
        double featureThreshold, int optGridSize, bool fullParameterSearch, bool featureInclusion, bool genetic, int kFold);

  };

  class LoadModelSVM : public Process
  {
    public:
      LoadModelSVM(const Process &process): Process(process) 
      {
		setName("Mesh/Cell Types/Classification/SVM/F Load Model");
		setDesc("Load an SVM model from a .model file");
		setIcon(QIcon(":/images/CellTypeRecognitionSpecification.jpeg")); 

		addParm("Filename","",""); 
	  }

      bool initialize(QWidget* parent);

      bool run(){
        return run(parm("Filename"));
      }

      bool run(QString filename);

  };


  class CellTypeRecognitionClassification : public Process
  {
    public:
      CellTypeRecognitionClassification(const Process &process): Process(process) 
      {
		setName("Mesh/Cell Types/Classification/SVM/G Classification");
		setDesc("Classify all cells of the current mesh by overwriting their parent label. Requires a trained SVM");
		setIcon(QIcon(":/images/CellTypeRecognitionSpecification.jpeg")); 

		addParm("Classification Threshold","Threshold for classification, lower values express more uncertainty","0.7");
		addParm("Recalculate Measures","Recalculate all necessary measures and overwrite the attr maps","No",booleanChoice());
	  }
      bool run(){
        Mesh *m1 = currentMesh();
        return run(m1, parm("Classification Threshold").toDouble(), stringToBool(parm("Recalculate Measures")));
      }
      bool run(Mesh *m1, double classifyThreshold, bool forceRecalc);

  };


  /// Bookkeeping record for one cluster (presumably used by the KMeans
  /// processes — confirm). Member order is significant for aggregate init.
  struct ClusterData
  {
    int trainingCount;              // training entry count (meaning defined in the .cpp)
    std::vector<double> position;   // cluster position; presumably in feature space — confirm
    int associatedParent;           // parent label associated with this cluster
  };

   // Shared state of the KMeans processes: cluster positions and per-cluster data.
   // NOTE(review): `static` in a header gives every including translation unit its
   // own copy; confirm all readers/writers live in the same .cpp file.
   static std::vector<std::vector<double> > kMeansClusters;
   static std::vector<ClusterData> clusterD;

  class KMeansTraining : public Process
  {
    public:
      KMeansTraining(const Process& process) : Process(process) 
      { 
		setName("Mesh/Cell Types/Classification/KMeans/E Training");
		setDesc("Train K Means Cluster");
		setIcon(QIcon(":/images/CellAtlas.png")); 
		
		addParm("Steps","Number of Optimization steps (set to -1 for steps until convergence)","-1");
		addParm("Convergence Threshold","Max percentage of changed cells to accept convergence","0.01");
		addParm("Use Scaled Measures","Use Scaled Measures","Yes",booleanChoice());
	  }

      bool run(){
        Mesh *m = currentMesh();
        return run(m, parm("Steps").toInt(), parm("Convergence Threshold").toDouble(), stringToBool(parm("Use Scaled Measures")));
      }
      bool run(Mesh *m, int steps, double threshold, bool useScaledMeasures);

  };

  /**
   * Process "KMeans/F Classification": classifies the cells of the current
   * mesh with the trained K Means clusters (see run(Mesh*), implemented out
   * of line).
   */
  class KMeansClassification : public Process
  {
    public:
      KMeansClassification(const Process& process) : Process(process)
      {
        // Must not duplicate KMeansTraining's name ("…/KMeans/E Training");
        // the copy-pasted name made the two processes collide.
        setName("Mesh/Cell Types/Classification/KMeans/F Classification");
        setDesc("Classify cells using the trained K Means clusters");
        setIcon(QIcon(":/images/CellAtlas.png"));

        // NOTE(review): these parameters were copied from KMeansTraining and are
        // not forwarded by run() below. They are kept in case run(Mesh*) reads
        // them via parm(); remove them if it does not.
        addParm("Steps","Number of Optimization steps (set to -1 for steps until convergence)","-1");
        addParm("Convergence Threshold","Max percentage of changed cells to accept convergence","0.01");
        addParm("Use Scaled Measures","Use Scaled Measures","Yes",booleanChoice());
      }

      /// Process entry point: classifies the current mesh.
      bool run()
      {
        Mesh *m = currentMesh();
        return run(m);
      }

      bool run(Mesh *m);
  };

  class AugmentTrainingFile : public Process
  {
    public:
      AugmentTrainingFile(const Process& process) : Process(process) 
      { 
		setName("Mesh/Cell Types/Classification/D Training File Augmentation");
		setDesc("Balances the training data file by duplicating entries of underrepresented classes");
		setIcon(QIcon(":/images/CellAtlas.png"));

		addParm("Filename","Location of the training file","");
		addParm("Entry Repeats","Number of repeats of existing entries","0");
		addParm("Noise Factor","Amount of noise","0.05");
		addParm("Balance File","duplicate entries of underrepresented classes","Yes",booleanChoice());
		addParm("Shuffle File Entries","Shuffle File Entries","Yes",booleanChoice());
		addParm("Max Entries per Class","Max Entries per Class","100");  
	  }

      bool initialize(QWidget* parent);

      bool run(){
        Mesh *m = currentMesh();
        return run(m, parm("Filename"), parm("Entry Repeats").toInt(), parm("Noise Factor").toDouble(), stringToBool(parm("Balance File")), stringToBool(parm("Shuffle File Entries")), parm("Max Entries per Class").toInt());
      }
      bool run(Mesh *m, QString filename, int repeats, double noise, bool balance, bool shuffle, int maxEntries);

  };



  class ClusterCells : public Process
  {
    Q_OBJECT
    public:
      ClusterMap clMap;
      Ui_HeatMapClusterDialog ui;

      QString selectedX, selectedY, selectedHeat;
      bool redraw;
      double sigma;
      bool gaussianMode;

      std::map<int, int> backupParents;
      std::set<int> selectedCells;
      QString subfolder;

    ClusterCells(const Process& process) : Process(process) 
    {
	  setName("Mesh/Cell Types/Classification/Tools/Cell Property Map 2D");
	  setDesc("Launches the Cell Property Map GUI to cluster cells based on a 2D plot");
	  setIcon(QIcon(":/images/Cluster.png"));

	  addParm("Parms String","Parms String","");	
		
	}
      bool initialize(QWidget* parent);
      bool run(){
        Mesh *m1 = currentMesh();
        return run(m1);
      }
      bool run(Mesh *m);


      void getAttrMap(AttrMap<int, double>& data, QString measure);
      //bool updateScreen();

    protected slots:
    
      void setImage();
      void changeHeatmap();
      void changeSigma(double sigma);
      void setPosition(const QPoint& p);
      void setClusterLabel(QString label);
      void setReleasePosition(const QPoint& p);
      void setMousePosition(const QPoint& p);
      void fillMeasures();
      bool parseParmsString(QString parmsString);
      QString createParmsString();
      //void kMeans();
      void kMeansN();
      void resetClusters();
      void clustersFromParents();
      void changeMinMax();
      void restoreParents();
      void updateScreen();

    protected:
      QDialog* dlg;
      Point2d mousePos;
  };




}
#endif
