//
// This file is part of 3DCellAtlas. 
// Copyright (C) 2015 George W. Bassel and collaborators.
//
// If you use 3DCellAtlas in your work, please cite:
//   http://dx.doi.org/10.1105/tpc.15.00175
//
// 3DCellAtlas is an AddOn for MorphoGraphX - http://www.MorphoGraphX.org
// Copyright (C) 2012-2015 Richard S. Smith and collaborators.
//
// 3DCellAtlas and MorphoGraphX are free software, and are licensed under the terms of the 
// GNU General Public License (GPL) version 2.0, http://www.gnu.org/licenses.
// 
#include "ClusterMap.hpp"
//#include "CellAtlasUtils.hpp"
#include "Triangulate.hpp"

namespace mgx
{

  // Returns the value range of inputMap over the given keys as (lower, upper),
  // widened on both sides by 5% of the spread. If the scan did not move both
  // sentinels (e.g. keys is empty) the result is (0,0).
  Point2d minMaxFromMap(AttrMap<int,double>& inputMap, std::set<int>& keys)
  {
    // sentinels: any realistic value replaces them on the first visit
    double lowest = 1E20;
    double highest = -1E20;

    forall(int key, keys){
      if(inputMap[key] > highest) highest = inputMap[key];
      if(inputMap[key] < lowest) lowest = inputMap[key];
    }

    // nothing usable was seen
    if(!(lowest < 1E20 && highest > -1E20)) return Point2d(0,0);

    const double spread = highest - lowest;
    return Point2d(lowest - 0.05*spread, highest + 0.05*spread);
  }

  void ClusterMap::initDataArray()
  {
   std::vector<std::vector<double> > viewDataNew(gridSize);
   forall(std::vector<double>& v, viewDataNew){
     std::vector<double> vv(gridSize);
     v = vv;
   }
   viewData = viewDataNew;
  }



  void ClusterMap::createCellView()
  {
 cout << "cell view " << cellLabels.size() << "/" << cluster.size() << endl;
    gridSize = 480;
    gridFactor = 1;

    if(largeDots){
      gridSize = 160;
      gridFactor = 3;
    }
    initDataArray();
 

    // find measure1 and measure2 in idx map
    int idx1 = attrIdxMap[selectedX];
    int idx2 = attrIdxMap[selectedY];
    if(idx1==0 or idx2==0) return;

    // vector idx starts from 0
    idx1--;
    idx2--;

    // find min and max value in maps
    Point2d minMax1 = xMinMax;//minMaxFromMap(cellFeatures[idx1], cellLabels);
    Point2d minMax2 = yMinMax;//minMaxFromMap(cellFeatures[idx2], cellLabels);

    if(!customMinMax){
      minMax1 = minMaxFromMap(cellFeatures[idx1], cellLabels);
      minMax2 = minMaxFromMap(cellFeatures[idx2], cellLabels);
      xMinMax = minMax1;
      yMinMax = minMax2;
    }



    // bin the values

    // step size of the heatmap
    //double step1 = (minMax1.y()-minMax1.x())/(double)gridSize;
    //double step2 = (minMax2.y()-minMax2.x())/(double)gridSize;

    // reset viewdata array
    //for(int i = 0; i<gridSize; i++){
    //  for(int j = 0; j<gridSize; j++){
    //    viewData[i][j] = 0;
    //  }
    //}
 cout << "cell view " << endl;
    // fill array with new data
    forall(int label, cellLabels){
      int binX = interpolateArrayIndex(cellFeatures[idx1][label],minMax1.x(),minMax1.y(),gridSize-1);
      int binY = interpolateArrayIndex(cellFeatures[idx2][label],minMax2.x(),minMax2.y(),gridSize-1);

      viewData[binX][binY]++;
      //std::cout << "fill  " << binX <<"/" << binY << std::endl;
      imageLabelMap[std::make_pair(binX,binY)] = label;
      //std::cout << "f2  " << cellFeatures[idx1][label] <<"/" << minMax1.y() << "//" << minMax1.x() << std::endl;
    }

    if(!customMinMax) calcHeatMax();
  }

  void ClusterMap::createHeatMap2D() 
  {
    gridSize = 160;
    gridFactor = 3;
    initDataArray();

    // find measure1 and measure2 in idx map
    int idx1 = attrIdxMap[selectedX] -1;
    int idx2 = attrIdxMap[selectedY] -1;
    int idxHeat = attrIdxMap[selectedHeat] -1;
    if(idx1==-1 or idx2==-1) return;

    std::cout << "2d " << idx1 << "/" << idx2 << "//" << idxHeat << "///" << cellFeatures.size() << std::endl;


    // find min and max value in maps
    Point2d minMax1 = xMinMax;//minMaxFromMap(cellFeatures[idx1], cellLabels);
    Point2d minMax2 = yMinMax;//minMaxFromMap(cellFeatures[idx2], cellLabels);

    if(!customMinMax){
      minMax1 = minMaxFromMap(cellFeatures[idx1], cellLabels);
      minMax2 = minMaxFromMap(cellFeatures[idx2], cellLabels);
      xMinMax = minMax1;
      yMinMax = minMax2;
      //std::cout << "in" << std::endl;
    }


    // bin the values

    // step size of the heatmap
    double step1 = (minMax1.y()-minMax1.x())/(double)gridSize;
    double step2 = (minMax2.y()-minMax2.x())/(double)gridSize;

    //std::cout << "2d " << step1 << "/" << step2 << std::endl;

    // reset viewdata array
    //for(int i = 0; i<gridSize; i++){
    //  for(int j = 0; j<gridSize; j++){
    //    viewData[i][j] = 0;
    //  }
    //}

    //std::cout << "2d " << step1 << "/" << step2 << std::endl;

    // fill array with new data
    forall(int label, cellLabels){
      // calc center of gaussian
      int binX = interpolateArrayIndex(cellFeatures[idx1][label],minMax1.x(),minMax1.y(),gridSize-1);
      int binY = interpolateArrayIndex(cellFeatures[idx2][label],minMax2.x(),minMax2.y(),gridSize-1);

      // cut off radius around the center

// cut off radius around the center
      int radiusX = std::max(10.0,sigma*2);
      int radiusY = std::max(10.0,sigma*2);

      if(binX == 0 or binX == gridSize-1) radiusX = 0;
      if(binY == 0 or binY == gridSize-1) radiusY = 0;

      for(int i = -radiusX; i<=radiusX; i++){
        for(int j = -radiusY; j<=radiusY; j++){
          // calc gaussian
          if(binX+i>=0 and binX+i<gridSize and binY+j>=0 and binY+j<gridSize){
            double xValue = interpolateArrayIndex(cellFeatures[idx1][label]+i*step1,minMax1.x(),minMax1.y(),gridSize-1);
            double yValue = interpolateArrayIndex(cellFeatures[idx2][label]+j*step2,minMax2.x(),minMax2.y(),gridSize-1);
            if(idxHeat == -1)
              viewData[binX+i][binY+j] += gauss2D(xValue, yValue, binX, binY, sigma, sigma);
            else
              viewData[binX+i][binY+j] += cellFeatures[idxHeat][label] * gauss2D(xValue, yValue, binX, binY, sigma, sigma);
          }
        }
      }
    }
    if(!customMinMax) calcHeatMax();
    findMaximaHeatMap();
    relateCellsToMaxima();
    std::cout << "done createHeatMap2D" << std::endl;
  }

  // Sets heatMax to one fifth of the largest entry found in the
  // gridSize x gridSize viewData table.
  // Note: with gridSize == 0 nothing is scanned and heatMax becomes
  // -HUGE_VAL/5.
  void ClusterMap::calcHeatMax()
  {
    const int size = gridSize;
    double maxValue = -HUGE_VAL;

    for(int i = 0; i < size; i++) {
      for(int j = 0; j < size; j++) {
        if(maxValue < viewData[i][j]) maxValue = viewData[i][j];
      }
    }
    heatMax = maxValue/5.;
  }

// Switches the measures shown on the x axis, y axis and as heat value.
//
// In 2D mode every manual cluster keeps its position on screen: for each
// axis whose measure changes, the cluster's image coordinate under the old
// axis range is converted into a real value under the new axis range and
// stored in the slot of the new measure, while the slot of the old measure
// is reset to 0.
//
// All old values are read before any slot is written. Otherwise, swapping
// the two axes (measure1 == selectedY) would overwrite the old y value
// before it is read and then zero the freshly written x value.
//
// measure1    - measure for the x axis
// measure2    - measure for the y axis
// measureHeat - measure used as heat value
void ClusterMap::setActiveMeasures(QString measure1, QString measure2, QString measureHeat)
  {

    std::cout << "set active " << std::endl;
    if(mode2D){
      const bool xChanged = selectedX != measure1;
      const bool yChanged = selectedY != measure2;

      // ranges of the axes as they are currently displayed
      const Point2d xMinMaxOld = xMinMax;
      const Point2d yMinMaxOld = yMinMax;

      // The new ranges do not depend on the cluster, so compute them once.
      // As before, they are only refreshed here when there is at least one
      // cluster to remap.
      if(cluster.size() > 0){
        if(xChanged) xMinMax = minMaxFromMap(cellFeatures[attrIdxMap[measure1]-1], cellLabels);
        if(yChanged) yMinMax = minMaxFromMap(cellFeatures[attrIdxMap[measure2]-1], cellLabels);
      }

      forall(ManualCluster& c, cluster){
        // 1) read: image position of the cluster under the old axes
        double posX = 0;
        double posY = 0;
        if(xChanged){
          double value1 = c.val[attrIdxMap[selectedX]-1];
          posX = realToImageCoord(xMinMaxOld, value1, gridSize);
        }
        if(yChanged){
          std::cout << "y " << std::endl;
          double value2 = c.val[attrIdxMap[selectedY]-1];
          posY = realToImageCoord(yMinMaxOld, value2, gridSize);
        }

        // 2) clear the slots of the measures that are no longer displayed
        if(xChanged) c.val[attrIdxMap[selectedX]-1] = 0;
        if(yChanged) c.val[attrIdxMap[selectedY]-1] = 0;

        // 3) write the same image position expressed in the new measures
        if(xChanged) c.val[attrIdxMap[measure1]-1] = imageToRealCoord(xMinMax, posX, gridSize);
        if(yChanged) c.val[attrIdxMap[measure2]-1] = imageToRealCoord(yMinMax, posY, gridSize);
      }

    } else {
      // nD mode: cluster values are left untouched
      std::cout << "nD " << std::endl;
    }

    selectedX = measure1;
    selectedY = measure2;
    selectedHeat = measureHeat;

  }

  // Renders viewData into a square ARGB32 image of gridSize*gridFactor
  // pixels per side; each grid cell becomes a gridFactor x gridFactor square.
  //
  // heatMapMode    - colour every cell via calcRGB(heatMax, value) and mark
  //                  the entries of maximaHeatMap with a white cross
  // hide0          - cell view only: paint cells whose parent label is 0 black
  // highlightNon0  - currently not used by this function
  // backgroundMode - if > 0, black cells are replaced by a darkened (1/4
  //                  intensity) colour taken from imageBackgroundMap
  //
  // setBackground() is called first, so imageBackgroundMap is refreshed on
  // every call. Grid row j is drawn at image row yOffset-j*factor, i.e. the
  // grid's second index grows upwards in the image.
  QImage ClusterMap::createImage(bool heatMapMode, bool hide0, bool highlightNon0, int backgroundMode){

    const int size = gridSize;
    const int factor = gridFactor; // scaling factor

    QImage image( size*factor, size*factor, QImage::Format_ARGB32 );
    const QRgb white = qRgba(255,255,255,255);
    const QRgb black = qRgba(0,0,0,255);

    const int yOffset = size*factor-1;

    setBackground();

    for(int i = 0; i < size; i++) {
      for(int j = 0; j < size; j++) {
        // The colour is identical for all sub-pixels of a cell, so it is
        // determined once per cell instead of once per pixel.
        QRgb value;
        if(heatMapMode) {
          value = calcRGB(heatMax,viewData[i][j]);
        } else {
          if(showParentLabels){
            value = getColorFromLabel(parentLabels[imageLabelMap[std::make_pair(i,j)]]);
          } else {
            value = calcRGB(heatMax,viewData[i][j]);
          }
          // empty bins, and optionally cells without parent, are drawn black
          if(viewData[i][j] == 0 || (hide0 && parentLabels[imageLabelMap[std::make_pair(i,j)]]==0)) value = black;
        }
        if(value == black && backgroundMode>0){
          value = getColorFromLabel(imageBackgroundMap[std::make_pair(i,j)]);
          value = qRgba(qRed(value)/4,qGreen(value)/4,qBlue(value)/4,255);
        }

        // fill the factor x factor square belonging to this cell
        for(int k = 0; k < factor; k++) {
          for(int l = 0; l < factor; l++) {
            image.setPixel(i*factor+k, yOffset-(j*factor+l), value);
          }
        }
      }
    }

    if(!heatMapMode) return image;

    // draw a small cross on all heatmap maxima (centred in their cell)
    const int maxSize = maximaHeatMap.size();
    for(int i = 0; i < maxSize; i++) {
      Point2d currentPoint = maximaHeatMap[i];
      drawCross(image, currentPoint[0]*factor+factor/2, yOffset-currentPoint[1]*factor-factor/2, size*factor, 2, white);
    }

    return image;
  }

  // Converts an image/grid coordinate into a real value: value/gridSize is
  // taken as the fraction of the interval [minMax.x(), minMax.y()].
  double imageToRealCoord(Point2d minMax, double value, double gridSize)
  {
    const double span = minMax.y()-minMax.x();
    const double scaled = value/gridSize * span;
    return scaled + minMax.x();
  }

  // Converts a real value into an image/grid coordinate: the position of
  // value inside [minMax.x(), minMax.y()] is scaled by gridSize.
  // A zero-width interval maps everything to 0.
  double realToImageCoord(Point2d minMax, double value, double gridSize)
  {
    const double span = minMax.y()-minMax.x();
    if(span == 0) return 0;

    const double offset = value - minMax.x();
    return offset/span * gridSize;
  }

  // Appends a new manual cluster with label 0. value1/value2 are image
  // coordinates; they are converted to real values with the current axis
  // ranges and stored in the slots of measure1 and measure2.
  void ClusterMap::addCluster(QString measure1, QString measure2, double value1, double value2)
  {
    const double realX = imageToRealCoord(xMinMax, value1, gridSize);
    const double realY = imageToRealCoord(yMinMax, value2, gridSize);

    ManualCluster created;
    created.val[attrIdxMap[measure1]-1] = realX;
    created.val[attrIdxMap[measure2]-1] = realY;
    created.label = 0;
    cluster.push_back(created);

    // log the position in image and in real coordinates
    std::cout << "new cluster " << value1 << "/" << value2 << std::endl;
    std::cout << "new cluster " << realX << "/" << realY << std::endl;
  }
  // Stores a new position for one coordinate of cluster[clusterIdx].
  // value is an image coordinate; it is converted with the x axis range if
  // xCoord is true and with the y axis range otherwise, and written to the
  // slot of measure. clusterIdx is not range-checked.
  void ClusterMap::updateCluster(int clusterIdx, QString measure, double value, bool xCoord)
  {
    if(xCoord){
      cluster[clusterIdx].val[attrIdxMap[measure]-1] = imageToRealCoord(xMinMax, value, gridSize);
      return;
    }
    cluster[clusterIdx].val[attrIdxMap[measure]-1] = imageToRealCoord(yMinMax, value, gridSize);
  }

// Checks whether a manual cluster lies close to the image position pos.
// The clusters are placed on the grid using their values of measure1 and
// measure2; the index of the closest one is written to maxIdx.
// Returns true if that cluster is less than 10 grid units away; otherwise
// maxIdx is set to -1 and false is returned.
bool ClusterMap::nearbyManualCluster(QString measure1, QString measure2, Point2d pos, int& maxIdx)
  {
    double closest = 1E20;

    for(size_t n = 0; n < cluster.size(); n++){
      ManualCluster& candidate = cluster[n];

      // grid position of the candidate
      double valueX = candidate.val[attrIdxMap[measure1]-1];
      double valueY = candidate.val[attrIdxMap[measure2]-1];
      double gridX = interpolateArrayIndex(valueX, xMinMax.x(), xMinMax.y(), gridSize-1);
      double gridY = interpolateArrayIndex(valueY, yMinMax.x(), yMinMax.y(), gridSize-1);
      Point2d gridPos(gridX,gridY);

      double dis = norm(pos - gridPos);
      if(dis < closest){
        closest = dis;
        maxIdx = n;
      }
    }

    if(closest < 10) return true;

    maxIdx = -1;
    return false;
  }

  // Returns the value of measure for the cell with the given label, or 0 if
  // attrIdxMap yields index 0 for that measure.
  double ClusterMap::getValue(QString measure, int label)
  {
    const int attr = attrIdxMap[measure];
    if(attr != 0) return cellFeatures[attr-1][label]; // attrIdxMap is 1-based

    return 0.;
  }

  // Returns the value stored for measure in cluster[clusterIdx], or 0 if
  // attrIdxMap yields index 0 for that measure. clusterIdx is not
  // range-checked.
  double ClusterMap::getClusterValue(QString measure, int clusterIdx)
  {
    const int attr = attrIdxMap[measure];
    if(attr != 0) return cluster[clusterIdx].val[attr-1]; // attrIdxMap is 1-based

    return 0.;
  }

  // Finds the manual cluster closest to the image position imageP.
  // idx1/idx2 are the (0-based) slots in ManualCluster::val that hold the x
  // and y value. On return minDis is the distance to and minIdx the index of
  // the closest cluster; without clusters they are HUGE_VAL and -1.
  void ClusterMap::getNearestCluster(Point2d imageP, int idx1, int idx2, double& minDis, int& minIdx)
  {
    double bestDis = HUGE_VAL;
    int bestIdx = -1;

    for(size_t n = 0; n < cluster.size(); n++){
      ManualCluster& candidate = cluster[n];

      // cluster position in image coordinates
      double valueX = candidate.val[idx1];
      double valueY = candidate.val[idx2];
      double gridX = realToImageCoord(xMinMax, valueX, gridSize-1);
      double gridY = realToImageCoord(yMinMax, valueY, gridSize-1);
      Point2d pCluster(gridX,gridY);

      double dis = norm(imageP - pCluster);
      if(dis < bestDis){
        bestDis = dis;
        bestIdx = n;
      }
    }

    minDis = bestDis;
    minIdx = bestIdx;
  }


  // Assigns every cell the parent label of its nearest manual cluster
  // (cluster label + 1), measured in image coordinates of the two selected
  // measures. Cells for which no cluster is found get no entry. parentLabels
  // is replaced by the result; it is left untouched if one of the selected
  // measures has index 0 in attrIdxMap.
  void ClusterMap::setParentsCellView()
  {
    cout << "cell view " << cellLabels.size() << "/" << cluster.size() << endl;

    // attrIdxMap is 1-based; 0 is treated as "measure not available"
    const int attrX = attrIdxMap[selectedX];
    const int attrY = attrIdxMap[selectedY];
    if(attrX == 0 || attrY == 0) return;

    const int slotX = attrX - 1;
    const int slotY = attrY - 1;

    std::map<int,int> assigned;
    forall(int l, cellLabels){
      // position of the cell in image coordinates
      double valueX = getValue(selectedX, l);
      double valueY = getValue(selectedY, l);
      double gridX = realToImageCoord(xMinMax, valueX, gridSize-1);
      double gridY = realToImageCoord(yMinMax, valueY, gridSize-1);
      Point2d pCell(gridX,gridY);

      double nearestDis = HUGE_VAL;
      int nearestIdx = -1;
      getNearestCluster(pCell, slotX, slotY, nearestDis, nearestIdx);

      if(nearestIdx > -1){
        assigned[l] = cluster[nearestIdx].label+1;
      }
    }

    parentLabels = assigned;
  }

  void ClusterMap::setBackground()
  {
    // find measure1 and measure2 in idx map
      int idx1 = attrIdxMap[selectedX];
      int idx2 = attrIdxMap[selectedY];
      if(idx1==0 or idx2==0) return;

      // vector idx starts from 0
      idx1--;
      idx2--;

    for(int i = 0; i < gridSize; i++) {
      for(int j = 0; j < gridSize; j++) {
        Point2d pCell(i,j);

        double minDis = HUGE_VAL;
        int minIdx = -1;
        getNearestCluster(pCell, idx1, idx2, minDis, minIdx);
        if(minIdx>-1)
          imageBackgroundMap[std::make_pair(i,j)] = cluster[minIdx].label+1;
      }
    }

  }

  // Assigns parent labels in heat map mode in two steps:
  //  1. every heat map maximum is related to its nearest manual cluster
  //     (stored in maximaIdxManualClusterIdxMap, -1 if there are no clusters)
  //  2. every cell gets the parent label (cluster label + 1) of the cluster
  //     related to the cell's maximum (cellLabelMaximaIdxMap)
  // Cells whose maximum has no valid cluster get no entry, the same way
  // setParentsCellView leaves out cells without a cluster. parentLabels is
  // replaced by the result.
  void ClusterMap::setParentsHeatMapView()
  {
    
    std::map<int,int> newParents;
    const int numClusters = cluster.size();

    // relate max in heatmap to manual clusters
    for(size_t i =0; i<maximaHeatMap.size(); i++){
      Point2d maxHM = maximaHeatMap[i];

      double minDis = HUGE_VAL;
      int minIdx = -1;

      for(size_t j =0; j<cluster.size(); j++){
        // cluster position in real coordinates ...
        Point2d clusterPos(getClusterValue(selectedX,j),getClusterValue(selectedY,j));

        // ... and in image coordinates
        double imageCX = realToImageCoord(xMinMax, clusterPos.x(), gridSize-1);
        double imageCY = realToImageCoord(yMinMax, clusterPos.y(), gridSize-1);
        Point2d clusterPos2(imageCX,imageCY);

        double dis = norm(clusterPos2 - maxHM);
        if(dis<minDis){
          minDis = dis;
          minIdx = j;
        }
      }
      maximaIdxManualClusterIdxMap[i]=minIdx;
    }

    // relate single cells to manual cluster
    forall(int l, cellLabels){
      const int clusterIdx = maximaIdxManualClusterIdxMap[cellLabelMaximaIdxMap[l]];
      // The index is -1 when no cluster exists, and an entry left over from
      // an earlier run may point past the end; never index cluster with it.
      if(clusterIdx >= 0 && clusterIdx < numClusters){
        newParents[l] = cluster[clusterIdx].label+1;
      }
    }
    parentLabels = newParents;
  }


// find all local maxima in a cellMorphLandscape (=heatmap) and save their positions
void ClusterMap::findMaximaHeatMap()
{
  std::vector<Point2d> maxima;

  int size = gridSize;

  // check every point of the landscape
  for(int i=0; i<size; i++){
    for(int j=0; j<size; j++){
      bool isMax = true;
      for(int in=-1; in<=1; in++){
        for(int jn=-1; jn<=1; jn++){
          if(i+in >= 0 and i+in < size and j+jn >= 0 and j+jn < size and (jn!=0 or in!=0))
            if(viewData[i][j] <= viewData[i+in][j+jn]) isMax = false;
        }
      }
      if(isMax){
        Point2d maxPos(i,j);
        maxima.push_back(maxPos);
      }
    }
  }

  maximaHeatMap = maxima;
  }



// Relates each cell to a maximum of the heat map and stores the index of that
// maximum (position in maximaHeatMap) in cellLabelMaximaIdxMap.
//
// A cell starts at its own image position. If the nearest maximum is further
// away than sigma, the position is moved step by step to the highest value
// in its 3x3 neighbourhood (at most 100 steps), and the nearest maximum is
// determined again after each step.
//
// If maximaHeatMap is empty, every cell is mapped to index 0.
void ClusterMap::relateCellsToMaxima()
{
 
  // NOTE(review): nearestMaximum is declared but never used in this function
  std::map<int,double> nearestMaximum;
  int numMax = maximaHeatMap.size();

  forall(int currentLabel, cellLabels){
      Point2d currentPos;
      double minDis = 1E20;
      int minNum = 0;

      // start position: the cell's values of the selected measures in image coordinates
      double imageX = realToImageCoord(xMinMax, getValue(selectedX,currentLabel), gridSize-1);
      double imageY = realToImageCoord(yMinMax, getValue(selectedY,currentLabel), gridSize-1);

      currentPos[0] = imageX;
      currentPos[1] = imageY;

      for(int j = 0; j<numMax; j++){ // find nearest maximum
        Point2d clusterPos = maximaHeatMap[j];

        if(minDis > norm(currentPos - clusterPos)){
          minDis = norm(currentPos - clusterPos);
          minNum = j;
        }
      }

      // guards against endless walking, e.g. on flat regions of the map
      int stopCounter = 0;

      // if nearest maximum too far, then follow gradient
      while(minDis > sigma and stopCounter<100){
        stopCounter++;
        double maxValue = -HUGE_VAL;
        int changeX = 0;
        int changeY = 0;
        // find highest value in neighborhood and go there
        // (the offset (0,0) is part of the search, so the position may stay where it is;
        //  with equal values the first offset in loop order wins)
        for(int i = -1; i<=1; i++){
          for(int j = -1; j<=1; j++){
            // only neighbours inside the grid are considered
            if(currentPos[0]+i >= 0 and currentPos[0]+i < gridSize and currentPos[1]+j >=0 and currentPos[1]+j < gridSize){
              // currentPos holds doubles; they are converted implicitly when used as index here
              if(maxValue < viewData[currentPos[0]+i][currentPos[1]+j]){
                maxValue = viewData[currentPos[0]+i][currentPos[1]+j];
                changeX = i;
                changeY = j;
              }
            }
          }
        }
        // find nearest max again
        minDis = HUGE_VAL;
        minNum = 0;
        currentPos[0] += changeX;
        currentPos[1] += changeY;
        for(int j = 0; j<numMax; j++){ // find nearest maximum
          Point2d clusterPos = maximaHeatMap[j];

          if(minDis > norm(currentPos - clusterPos)){
            minDis = norm(currentPos - clusterPos);
            minNum = j;
          }
        }
      }
      // index of the maximum this cell ended up closest to
      cellLabelMaximaIdxMap[currentLabel] = minNum;
    }
}


/*
void ClusterMap::convertClusters2DToND()
{
  mode2D = false;

  

}


// take current GUI position of manual clusters and save them to idx 0 and 1
void ClusterMap::convertClustersNDTo2D()
{
  resetCellFeatures();
  mode2D = true;

  // save current features
  AttrMap<int, double>& featureX = cellFeatures[attrIdxMap[selectedX]-1];
  AttrMap<int, double>& featureY = cellFeatures[attrIdxMap[selectedX]-1];

  addCellFeature(selectedX, featureX);
  addCellFeature(selectedY, featureY);
}
*/
}