//
// This file is part of MorphoGraphX - http://www.MorphoGraphX.org
// Copyright (C) 2012-2015 Richard S. Smith and collaborators.
//
// If you use MorphoGraphX in your work, please cite:
//   http://dx.doi.org/10.7554/eLife.05864
//
// MorphoGraphX is free software, and is licensed under under the terms of the 
// GNU General (GPL) Public License version 2.0, http://www.gnu.org/licenses.
//
#include <CorrelationMaker.hpp>
#include <GraphUtils.hpp>
#include <MeshProcessHeatMap.hpp>
#include <math.h> 
#include <iostream>
#include <vector>
#include <fstream>
#include <limits>
#include <Progress.hpp>
#include <QFileDialog>
#include <QFileDialog>


namespace mgx
{

// Fill `data` with the per-label values of the measure named `measure`.
//
// Two pseudo-measures are copied straight from the current mesh:
//   "Current Heat Map" -> the mesh's labelHeat() entries,
//   "Parent Label"     -> the mesh's parents() entries.
// Any other name is read from the attribute
// "Measure Label Double <measure with all spaces removed>". If that attribute
// map is empty, the process "Mesh/Heat Map/Measures<measure>" is created and,
// when it exists, run, and the attribute is read again. `data` stays empty if
// neither step produced values.
void CorrelationMaker::getAttrMap(AttrMap<int, double>& data, QString measure)
{
  Mesh* mesh = currentMesh();

  // Pseudo-measures taken directly from the mesh.
  if(measure == "Current Heat Map") {
    data.clear();
    forall(auto entry, mesh->labelHeat()) {
      data[entry.first] = entry.second;
    }
    return;
  }
  if(measure == "Parent Label") {
    data.clear();
    forall(auto entry, mesh->parents()) {
      data[entry.first] = entry.second;
    }
    return;
  }

  // Stored measures: the attribute key uses the measure name without spaces.
  const QString attrPrefix("Measure Label Double ");
  const QString processRoot("Mesh/Heat Map/Measures");
  QString compactName(measure);
  compactName.replace(" ", "");
  const QString attrName = attrPrefix + compactName;
  const QString processName = processRoot + measure;

  data = mesh->attributes().attrMap<int, double>(attrName);
  if(data.size() != 0)
    return;

  // Nothing stored yet: try to compute it with the matching measure process.
  qDebug() << "Attr Map \t" << attrPrefix + measure << " doesnt exist\n";
  QStringList lastParms;
  Process *process = makeProcess(processName);
  if(!process)
    return;
  qDebug() << "Process call\t" << processName << "\n";
  getLastParms(processName, lastParms);
  process->run();
  data = mesh->attributes().attrMap<int, double>(attrName);
}


 // fill the combo boxes with all available measures and/or attr maps
void CorrelationMaker::fillMeasures(){

    ui.xDimCmb->clear();
    ui.yDimCmb->clear();
    ui.corrCmb->clear();
    QStringList defaultEntries;
    defaultEntries << "none";

    int idx = 0;
    int selIdxX = 0;
    int selIdxY = 0;
    int selIdxCorr = 0;
    // find measure processes
    QStringList procs = mgx::listProcesses();

    QStringList addProcs;
    addProcs << "Current Heat Map";
    std::set<QString> procsSet;
    QStringList addAttr;
    forall(const QString& name, procs) {
      mgx::ProcessDefinition* def = mgx::getProcessDefinition(name);
      QStringList list = def->name.split("/");

      QString subfolder = "Measures";

      if(list[0] == "Mesh" and list[1] == "Heat Map" and list[2] == subfolder){
        QString newProc = "/" + list[3] + "/" + list[4];
        addProcs << newProc;
        procsSet.insert(newProc);
        if(newProc == selectedX) selIdxX = idx;
        if(newProc == selectedY) selIdxY = idx;
        if(newProc == selectedCor) selIdxCorr = idx;
        idx++;
      }
      
    }

    // now the attr maps
    Mesh* m = currentMesh();
    Attributes *attributes;
    attributes = &m->attributes();
    QStringList attr = attributes->getAttrList(); 
    QStringList measureAttr;

    forall(const QString &text, attr) {
      QStringList list = text.split(" ");

      if(list.size() < 4) continue;
      if(list[0] != "Measure") continue;
      QString name;
      for(int i = 3; i<list.size(); i++){
        name = name + list[i] + " ";
      }
      name = name.trimmed();
      if(procsSet.find(name) == procsSet.end()){
        addProcs << name;
          if(name == selectedX) selIdxX = idx;
          if(name == selectedY) selIdxY = idx;
          if(name == selectedCor) selIdxCorr = idx;
          idx++;
      }
    }


    ui.corrCmb->addItems(addProcs);
    ui.xDimCmb->addItems(addProcs);
    ui.yDimCmb->addItems(addProcs);

    ui.corrCmb->setCurrentIndex(selIdxCorr);
    ui.xDimCmb->setCurrentIndex(selIdxX);
    ui.yDimCmb->setCurrentIndex(selIdxY);

  }


// Slot: copy the current state of every dialog widget into the member
// variables that OutputHistogram() and the Compute*() functions read.
void CorrelationMaker::changeValue(){
  // Which measures to bin by (X), accumulate (Y) and correlate with (Cor).
  selectedX   = ui.xDimCmb->currentText();
  selectedY   = ui.yDimCmb->currentText();
  selectedCor = ui.corrCmb->currentText();

  // Binning parameters and value ranges.
  bin_num   = ui.spinBoxBins->value();
  bin_width = ui.spinBoxBinWidth->value();
  x_min = ui.spinBoxXMin->value();
  x_max = ui.spinBoxXMax->value();
  y_min = ui.spinBoxYMin->value();
  y_max = ui.spinBoxYMax->value();
}


// Show the correlation dialog and, if the user accepts it, write the CSV
// output via OutputHistogram().
//
// The widgets are initialised from the current member values. While the
// dialog is open, edits are copied back into the members by changeValue().
// Returns false if the dialog is cancelled or `parent` is null; otherwise
// returns the result of OutputHistogram().
bool CorrelationMaker::initialize(QWidget* parent)
{
  dlg = new QDialog(parent);
  ui.setupUi(dlg);

  // Populate the combo boxes before connecting their signals, so that
  // restoring the previous selection does not call changeValue().
  fillMeasures();
  connect(ui.yDimCmb, SIGNAL(currentIndexChanged(QString)), this, SLOT(changeValue()));
  connect(ui.xDimCmb, SIGNAL(currentIndexChanged(QString)), this, SLOT(changeValue()));
  connect(ui.corrCmb, SIGNAL(currentIndexChanged(QString)), this, SLOT(changeValue()));

  // Same for the spin boxes: show the stored parameters, then connect.
  ui.spinBoxBins->setValue(bin_num);
  ui.spinBoxBinWidth->setValue(bin_width);
  ui.spinBoxXMax->setValue(x_max);
  ui.spinBoxXMin->setValue(x_min);
  ui.spinBoxYMax->setValue(y_max);
  ui.spinBoxYMin->setValue(y_min);

  connect(ui.spinBoxBins, SIGNAL(valueChanged(double)), this, SLOT(changeValue()));
  connect(ui.spinBoxBinWidth, SIGNAL(valueChanged(double)), this, SLOT(changeValue()));
  connect(ui.spinBoxXMax, SIGNAL(valueChanged(double)), this, SLOT(changeValue()));
  connect(ui.spinBoxXMin, SIGNAL(valueChanged(double)), this, SLOT(changeValue()));
  connect(ui.spinBoxYMax, SIGNAL(valueChanged(double)), this, SLOT(changeValue()));
  connect(ui.spinBoxYMin, SIGNAL(valueChanged(double)), this, SLOT(changeValue()));

  if(dlg->exec() != QDialog::Accepted || !parent)
    return false;

  // The connections above only fire on changes. An entry the user never
  // touched (e.g. the one fillMeasures() selected) or a value clamped by a
  // spin box's range would otherwise leave the members out of sync with what
  // the accepted dialog shows, so read every widget once more.
  changeValue();
  return OutputHistogram();
}

  // Entry point taking a mesh and a parameter string. The dialog-driven work
  // happens in initialize(), which calls OutputHistogram(). This method
  // ignores its arguments and always reports success.
  bool CorrelationMaker::run(Mesh *m, QString parmsString)
  {
    static_cast<void>(m);
    static_cast<void>(parmsString);
    return true;
  }
  REGISTER_PROCESS(CorrelationMaker);


  // Load the three selected measures (X = selectedX, Y = selectedY,
  // Cor = selectedCor) and compute:
  //   - the histogram of Y binned by X,
  //   - the histogram of Cor binned by X,
  //   - the binwise correlation and mean ratio of Y and Cor, binned by X.
  // Then ask the user for a CSV file and write the three tables to it.
  // Returns false if the file dialog is cancelled, the file cannot be opened
  // (an error message is set), or a histogram table fails to print.
  bool CorrelationMaker::OutputHistogram(){
    AttrMap<int, double> dataX, dataY, dataZ;
    getAttrMap(dataX, selectedX);
    getAttrMap(dataY, selectedY);
    getAttrMap(dataZ, selectedCor);
    std::cout<<"Creating histogram with: "<<selectedX.toUtf8().constData()<<" "<<selectedY.toUtf8().constData()<<" corr with "<<selectedCor.toUtf8().constData()<<std::endl;
    std::cout<<"And binning parameters: "<<x_min<<"-"<<x_max<<", "<<y_min<<"-"<<y_max<<" in "<<bin_num<<" bins with a window of "<<bin_width<<" bins"<<std::endl;

    HistoData histoXY = ComputeHistoInformation(dataX, dataY);
    HistoData histoXZ = ComputeHistoInformation(dataX, dataZ);
    CorrData corr = ComputeCorrelationInformation(dataX, dataY, dataZ);

    QString filename = QFileDialog::getSaveFileName(nullptr, "Choose spreadsheet file to save", QDir::currentPath(), "CSV files (*.csv)");
    if(filename.isEmpty())
      return false;
    if(!filename.endsWith(".csv", Qt::CaseInsensitive))
      filename += ".csv";

    QFile file(filename);
    if(!file.open(QIODevice::WriteOnly)) {
      setErrorMessage(QString("File '%1' cannot be opened for writing").arg(filename));
      return false;
    }
    QTextStream out(&file);
    bool success = PrintHisto(histoXY, &out, selectedX, selectedY);
    out << endl;
    success = PrintHisto(histoXZ, &out, selectedX, selectedCor) && success;
    out << endl;

    // Stream the QStrings directly. QTextStream's const char* overload
    // assumes Latin-1, so UTF-8 bytes from toUtf8() would be garbled.
    out << "Binwise correlations and ratios of " << selectedY << " and " << selectedCor << " following binning by " << selectedX << endl;
    out << "Bin number, Correlation, Average ratio, Count, Interval_start, Interval_end" << endl;
    const int nbins = static_cast<int>(corr.bins.size());
    for(int i = 0; i < nbins; i++){
      // Each bin covers the interval window [i - bin_width, i + bin_width],
      // clipped to the binning range.
      int lower_bin = i-bin_width <0 ? 0 : i-bin_width;
      int upper_bin = i+bin_width+1 >= nbins ? nbins : i+bin_width+1;
      out << i << "," << corr.bins[i] << "," << corr.ratios[i] << "," << corr.counts[i] << "," << corr.bin_bounds[lower_bin] << "," << corr.bin_bounds[upper_bin] << endl;
    }

    return success;
}

// Write the histogram `h` to *out_p as a CSV table.
//
// Output:
//   - a title line and a header row;
//   - one row per bin: accumulated value, mean, standard deviation, count,
//     and the bounds of the interval window the bin covers (bin_width
//     intervals on each side, clipped to the range);
//   - summary rows for values below/above the X bin range and below/above
//     the Y thresholds [y_min, y_max];
//   - the number of labels of `binned_by` that were not binned.
//
// binned_by: name of the measure used for binning.
// binned:    name of the measure whose values were accumulated.
// Returns false, writing nothing, if out_p is null or the vectors in `h`
// have inconsistent sizes. Returns true otherwise.
bool CorrelationMaker::PrintHisto(HistoData h,QTextStream *out_p,QString binned_by, QString binned){
  // Read the vectors in place instead of copying them.
  const auto &bins = h.bins;
  const auto &bin_bounds = h.bin_bounds;
  const auto &bin_variance = h.bin_variance;
  const auto &counts = h.counts;
  const int nbins = static_cast<int>(bins.size());
  if(!out_p
     || static_cast<int>(bin_bounds.size()) < nbins + 1
     || static_cast<int>(bin_variance.size()) < nbins
     || static_cast<int>(counts.size()) < nbins)
    return false;
  QTextStream &out = *out_p;

  // Stream the QStrings directly. QTextStream's const char* overload assumes
  // Latin-1, so UTF-8 bytes from toUtf8() would be garbled.
  out << "Values of " << binned << " binned by " << binned_by << endl;
  out << "Bin number, Accumulated value, Average, Std-dev, Count, Interval start, Interval end" << endl;
  for(int i = 0; i < nbins; i++){
    int lower_bin = i-bin_width <0 ? 0 : i-bin_width;
    int upper_bin = i+bin_width+1 >= nbins ? nbins : i+bin_width+1;
    double mean = 0;
    double stddev = 0;
    if(counts[i] > 0){
      mean = bins[i]/counts[i];
      // A one-pass variance can round to a tiny negative value; print 0, not nan.
      stddev = bin_variance[i] > 0 ? std::sqrt(bin_variance[i]) : 0;
    }
    out << i << "," << bins[i] << "," << mean << "," << stddev << "," << counts[i] << "," << bin_bounds[lower_bin] << "," << bin_bounds[upper_bin] << endl;
  }

  const double under_bin = h.under_bin;
  const double under_count = h.under_count;
  const double over_bin = h.over_bin;
  const double over_count = h.over_count;
  const double under_y = h.under_y;
  const double under_y_count = h.under_y_count;
  const double over_y = h.over_y;
  const double over_y_count = h.over_y_count;
  const int in_x_not_y = h.in_first_not_second;

  out << "Below min bin" << endl;
  out << " " << "," << under_bin << "," << (under_count > 0 ? under_bin/under_count : 0) << "," << "NA ," << under_count << "," << "-inf" << "," << bin_bounds[0] << endl;

  out << "Over max bin" << endl;
  out << " " << "," << over_bin << "," << (over_count > 0 ? over_bin/over_count : 0) << "," << "NA ," << over_count << "," << bin_bounds[nbins] << "," << "+inf" << endl;

  out << "Below min thresh" << endl;
  out << " " << "," << under_y << "," << (under_y_count > 0 ? under_y/under_y_count : 0) << "," << "NA ," << under_y_count << "," << "-inf" << "," << y_min << endl;

  out << "Over max thresh" << endl;
  out << " " << "," << over_y << "," << (over_y_count > 0 ? over_y/over_y_count : 0) << "," << "NA ," << over_y_count << "," << y_max << "," << "+inf" << endl;

  // Count the labels of the binning measure itself, not whatever selectedX holds.
  out << "Num of cells in " << binned_by << " not binned," << in_x_not_y << endl;
  return true;
  }



// Bin the values of dataY by the values of dataX for the same label.
//
// The range [x_min, x_max) is split into bin_num equal intervals. Bin j
// collects every label whose X value falls in the interval window
// [j - bin_width, j + bin_width], clipped to the range, so neighbouring bins
// overlap when bin_width > 0. Each bin records the sum, the count and the
// population variance of its Y values.
//
// Labels are left out of the bins and tallied separately when:
//   - dataY has no finite value for the label (in_first_not_second);
//   - the Y value is outside [y_min, y_max] (under_y / over_y);
//   - the X value is outside [x_min, x_max) (under_bin / over_bin).
//
// Side effects: if x_min == x_max, the X range is computed from the finite
// values of dataX and written back to x_min/x_max. If y_min == y_max, the
// Y thresholds are set to -1e20/1e20.
//
// TODO: count labels present in Y but not in X; option for sum vs average.
HistoData CorrelationMaker::ComputeHistoInformation(AttrMap<int, double> dataX, AttrMap<int, double> dataY){
  // Automatic X range.
  if(x_min == x_max){
    double lo = std::numeric_limits<double>::max();
    double hi = -std::numeric_limits<double>::max();
    forall(IntDouble x, dataX){
      // Non-finite values would make every bound inf/nan.
      if(!std::isfinite(x.second))
        continue;
      if(x.second > hi)
        hi = x.second;
      if(x.second < lo)
        lo = x.second;
    }
    if(lo > hi){
      // No finite X values: keep a degenerate range; nothing will be binned.
      x_min = 0;
      x_max = 0;
    } else {
      // Pad outwards by 1% of each end's magnitude so the extremes fall
      // inside the half-open range [x_min, x_max). fabs keeps the padding
      // outward for negative values. The upper end is exclusive, so a zero
      // maximum falls back to 1% of the span, or 1 if all values are equal.
      double pad_lo = 0.01*std::fabs(lo);
      double pad_hi = 0.01*std::fabs(hi);
      if(pad_hi == 0)
        pad_hi = hi > lo ? 0.01*(hi - lo) : 1.0;
      x_min = lo - pad_lo;
      x_max = hi + pad_hi;
    }
  }
  // Equal Y thresholds mean "no threshold".
  if(y_min == y_max){
    y_min = -1e20;
    y_max = 1e20;
  }

  // Accumulators (zero-initialised) and the bin_num+1 interval bounds, in
  // double precision.
  std::vector<double> bins, bin_variance, counts, bin_bounds;
  bins.resize(bin_num);
  bin_variance.resize(bin_num);
  counts.resize(bin_num);
  bin_bounds.resize(bin_num+1);
  for(int i = 0; i < bin_num; i++)
    bin_bounds[i] = (double(i)/double(bin_num))*(x_max - x_min) + x_min;
  bin_bounds[bin_num] = x_max;

  double under_bin = 0, under_count = 0;
  double over_bin = 0, over_count = 0;
  double under_y = 0, under_y_count = 0;
  double over_y = 0, over_y_count = 0;
  int in_x_not_y = 0;

  forall(IntDouble x, dataX){
    const double x_val = x.second;
    if(dataY.find(x.first) == dataY.end()){
      in_x_not_y++;
      continue;
    }
    // Keep full double precision (values were previously narrowed to float).
    const double y_val = dataY[x.first];
    if(!std::isfinite(y_val)){
      in_x_not_y++;
      continue;
    }
    if(y_val < y_min){
      under_y += y_val;
      under_y_count++;
      continue;
    }
    if(y_val > y_max){
      over_y += y_val;
      over_y_count++;
      continue;
    }

    if(x_val < bin_bounds[0]){
      under_bin += y_val;
      under_count++;
    } else if(x_val >= bin_bounds[bin_num]){
      over_bin += y_val;
      over_count++;
    } else {
      // A value may land in several overlapping bin windows.
      for(int j = 0; j < bin_num; j++){
        int lower_bin = j-bin_width <0 ? 0 : j-bin_width;
        int upper_bin = j+bin_width+1 >= bin_num ? bin_num : j+bin_width+1;
        if(x_val >= bin_bounds[lower_bin] && x_val < bin_bounds[upper_bin]){
          bins[j] += y_val;
          bin_variance[j] += y_val*y_val;  // sum of squares for the variance
          counts[j]++;
        }
      }
    }
  }

  // Population variance E[y^2] - E[y]^2. Rounding can push this one-pass
  // formula slightly below zero, so clamp to keep sqrt() well-defined.
  for(int i = 0; i < bin_num; i++){
    if(counts[i] <= 0)
      continue;
    const double mean = bins[i]/counts[i];
    const double variance = bin_variance[i]/counts[i] - mean*mean;
    bin_variance[i] = variance > 0 ? variance : 0;
  }

  HistoData to_return;
  to_return.bins = bins;
  to_return.bin_bounds = bin_bounds;
  to_return.bin_variance = bin_variance;
  to_return.counts = counts;
  to_return.under_bin = under_bin;
  to_return.under_count = under_count;
  to_return.over_bin = over_bin;
  to_return.over_count = over_count;
  to_return.under_y = under_y;
  to_return.under_y_count = under_y_count;
  to_return.over_y = over_y;
  to_return.over_y_count = over_y_count;
  to_return.in_first_not_second = in_x_not_y;
  return to_return;
}



// Binwise Pearson correlation and mean ratio Y/Z of dataY and dataZ, with
// labels binned by their dataX value.
//
// The range [x_min, x_max) is split into bin_num equal intervals. Bin j
// collects every label whose X value falls in the interval window
// [j - bin_width, j + bin_width], clipped to the range. For each bin,
// `bins` holds the correlation of Y and Z and `ratios` holds the mean of Y/Z.
// The correlation is NaN when it is undefined (Y or Z constant in the bin,
// including a single sample).
//
// Labels are left out of the bins and tallied separately when:
//   - dataY or dataZ has no finite value for the label (in_first_not_second);
//   - Y or Z is outside [y_min, y_max] (under_y / over_y, summing Y);
//   - the X value is outside [x_min, x_max) (under_bin / over_bin, summing Y).
//
// Side effects: if x_min == x_max, the X range is computed from the finite
// values of dataX and written back to x_min/x_max. If y_min == y_max, the
// thresholds are set to -1e20/1e20.
//
// TODO: Z should get its own min/max; for now the Y thresholds apply to both.
// TODO: count labels present in Y but not in X.
// NOTE(review): a Z value of 0 makes Y/Z infinite or NaN and poisons that
// bin's mean ratio.
CorrData CorrelationMaker::ComputeCorrelationInformation(AttrMap<int, double> dataX, AttrMap<int, double> dataY, AttrMap<int, double> dataZ){
  // Automatic X range.
  if(x_min == x_max){
    double lo = std::numeric_limits<double>::max();
    double hi = -std::numeric_limits<double>::max();
    forall(IntDouble x, dataX){
      // Non-finite values would make every bound inf/nan.
      if(!std::isfinite(x.second))
        continue;
      if(x.second > hi)
        hi = x.second;
      if(x.second < lo)
        lo = x.second;
    }
    if(lo > hi){
      // No finite X values: keep a degenerate range; nothing will be binned.
      x_min = 0;
      x_max = 0;
    } else {
      // Pad outwards by 1% of each end's magnitude so the extremes fall
      // inside the half-open range [x_min, x_max). fabs keeps the padding
      // outward for negative values. The upper end is exclusive, so a zero
      // maximum falls back to 1% of the span, or 1 if all values are equal.
      double pad_lo = 0.01*std::fabs(lo);
      double pad_hi = 0.01*std::fabs(hi);
      if(pad_hi == 0)
        pad_hi = hi > lo ? 0.01*(hi - lo) : 1.0;
      x_min = lo - pad_lo;
      x_max = hi + pad_hi;
    }
  }
  // Equal thresholds mean "no threshold".
  if(y_min == y_max){
    y_min = -1e20;
    y_max = 1e20;
  }
  const double z_min = y_min;
  const double z_max = y_max;

  // Per-bin accumulators (zero-initialised) and the bin_num+1 interval bounds.
  std::vector<double> bins;     // sum of y*z, later the correlation
  std::vector<double> sumY, sumZ, sumYY, sumZZ;
  std::vector<double> ratios;   // sum of y/z, later the mean ratio
  std::vector<double> counts;
  std::vector<double> bin_bounds;
  bins.resize(bin_num);
  sumY.resize(bin_num);
  sumZ.resize(bin_num);
  sumYY.resize(bin_num);
  sumZZ.resize(bin_num);
  ratios.resize(bin_num);
  counts.resize(bin_num);
  bin_bounds.resize(bin_num+1);
  for(int i = 0; i < bin_num; i++)
    bin_bounds[i] = (double(i)/double(bin_num))*(x_max - x_min) + x_min;
  bin_bounds[bin_num] = x_max;

  double under_bin = 0, under_count = 0;
  double over_bin = 0, over_count = 0;
  double under_y = 0, under_y_count = 0;
  double over_y = 0, over_y_count = 0;
  int in_x_not_y = 0;

  forall(IntDouble x, dataX){
    const double x_val = x.second;
    if(dataY.find(x.first) == dataY.end() || dataZ.find(x.first) == dataZ.end()){
      in_x_not_y++;
      continue;
    }
    // Keep full double precision (values were previously narrowed to float).
    const double y_val = dataY[x.first];
    const double z_val = dataZ[x.first];
    if(!std::isfinite(y_val) || !std::isfinite(z_val)){
      in_x_not_y++;
      continue;
    }
    if(y_val < y_min || z_val < z_min){
      under_y += y_val;
      under_y_count++;
      continue;
    }
    if(y_val > y_max || z_val > z_max){
      over_y += y_val;
      over_y_count++;
      continue;
    }

    if(x_val < bin_bounds[0]){
      under_bin += y_val;
      under_count++;
    } else if(x_val >= bin_bounds[bin_num]){
      over_bin += y_val;
      over_count++;
    } else {
      // A value may land in several overlapping bin windows.
      for(int j = 0; j < bin_num; j++){
        int lower_bin = j-bin_width <0 ? 0 : j-bin_width;
        int upper_bin = j+bin_width+1 >= bin_num ? bin_num : j+bin_width+1;
        if(x_val >= bin_bounds[lower_bin] && x_val < bin_bounds[upper_bin]){
          bins[j] += z_val*y_val;
          sumY[j] += y_val;
          sumZ[j] += z_val;
          ratios[j] += y_val/z_val;
          sumYY[j] += y_val*y_val;
          sumZZ[j] += z_val*z_val;
          counts[j]++;
        }
      }
    }
  }

  // Turn the sums into the correlation and the mean ratio of each bin.
  for(int i = 0; i < bin_num; i++){
    if(counts[i] <= 0){
      bins[i] = 0;
      continue;
    }
    const double n = counts[i];
    const double meanY = sumY[i]/n;
    const double meanZ = sumZ[i]/n;
    const double varY = sumYY[i]/n - meanY*meanY;
    const double varZ = sumZZ[i]/n - meanZ*meanZ;
    // (sum(y*z) - n*meanY*meanZ) / n is the population covariance.
    const double cov = (bins[i] - n*meanY*meanZ)/n;
    // The correlation is undefined when either variance is zero; rounding
    // can also push the one-pass variance to <= 0. Report NaN consistently
    // instead of an accidental inf/nan.
    bins[i] = (varY > 0 && varZ > 0)
                ? cov/std::sqrt(varY*varZ)
                : std::numeric_limits<double>::quiet_NaN();
    ratios[i] /= n;
  }

  CorrData to_return;
  to_return.bins = bins;
  to_return.ratios = ratios;
  to_return.bin_bounds = bin_bounds;
  to_return.counts = counts;
  to_return.under_bin = under_bin;
  to_return.under_count = under_count;
  to_return.over_bin = over_bin;
  to_return.over_count = over_count;
  to_return.under_y = under_y;
  to_return.under_y_count = under_y_count;
  to_return.over_y = over_y;
  to_return.over_y_count = over_y_count;
  to_return.in_first_not_second = in_x_not_y;
  return to_return;
}



}

