//
// This file is part of MorphoGraphX - http://www.MorphoGraphX.org
// Copyright (C) 2012-2016 Richard S. Smith and collaborators.
//
// If you use MorphoGraphX in your work, please cite:
//   http://dx.doi.org/10.7554/eLife.05864
//
// MorphoGraphX is free software, and is licensed under under the terms of the 
// GNU General (GPL) Public License version 2.0, http://www.gnu.org/licenses.
// 
#include <MeshProcessFibril.hpp>
#include <MeshProcessSignal.hpp>
#include <Information.hpp>
#include <Progress.hpp>
#include <PCA.hpp>
#include <GraphUtils.hpp>

namespace mgx 
{
  // Computes one fibril orientation tensor per cell label.
  //
  // Adapted from the ImageJ plug-in published in Boudaoud et al., Nature Protocols 2014:
  // 'FibrilTool, an ImageJ plug-in to quantify fibrillar structures in raw microscopy images'.
  //
  // Steps:
  //  1. If backRadius > 0, the signal blurred with backRadius is kept in gSignal and its
  //     gradient is later subtracted (difference of Gaussians). If blurRadius > 0 the signal
  //     is blurred with blurRadius before the gradients are taken.
  //  2. Vertices within `border` of the cell border are marked (mesh->markBorder). For each
  //     label > 0 the total area and the "interior" area (triangles with no marked corner) are
  //     summed, and the interior triangles are collected.
  //  3. Cells whose interior area is below minAreaRatio * total area are skipped.
  //  4. For each interior triangle, the signal gradient is projected onto the plane of the cell
  //     normal and crossed with that normal (the direction along the fibres). The area-weighted
  //     outer product of that direction is accumulated per cell.
  //  5. The PCA of the area-averaged matrix is stored in mesh->cellAxis() and in the
  //     "Measure Label Tensor FibrilOrientations" attribute.
  // The original vertex signal is restored before the PCA. Finally the
  // "Display Fibril Orientations" process is run with its GUI parameters.
  //
  // Returns false (with an error message) if the PCA fails; throws a QString if the display
  // process cannot be created.
  bool FibrilOrientations::run(Mesh* mesh, double border, double minAreaRatio, double blurRadius, double backRadius)
  {
    IntPoint3fAttr Normals;
    std::map<int, Matrix3f> Correlations;
    IntFloatMap Areas;
    IntFloatMap WeightedAreas;
    IntFloatMap InsideAreas;
    std::unordered_map<int, std::set<Triangle> > TriangleMap;
  
    const vvGraph& S = mesh->graph();

    // signal: backup of the original signal; gSignal: signal blurred with backRadius
    VtxFloatMap signal, gSignal;

    // Backup original signal
    if(backRadius > 0 or blurRadius > 0) {
      forall(const vertex &v, S)
        signal[v] = v->signal;
    }

    // Remove background (difference of Gaussian)
    if(backRadius > 0) {
      MeshGaussianBlur mgb(*this);
      mgb.run(mesh, backRadius);

      // Keep the wide blur and restore the original signal
      forall(const vertex &v, S) {
        gSignal[v] = v->signal;
        v->signal = signal[v];
      }
    }

    // Blur signal
    if(blurRadius > 0) {
      MeshGaussianBlur mgb(*this);
      mgb.run(mesh, blurRadius);
    }
  
    mesh->updateCentersNormals();
    Normals = mesh->labelNormal(); 
  
    // Mark the vertices within this distance from the border.
    mesh->markBorder(border);
  
    // Sum the area of each cell and collect its internal (non border) triangles.
    forall(const vertex &v, S) {
      forall(const vertex &n, S.neighbors(v)) {
        const vertex& m = S.nextTo(v, n); 
        if(!S.uniqueTri(v, n, m))
          continue; 
        int label = getLabel(v, n, m);
        if(label <= 0)
          continue;
  
        float area = triangleArea(v->pos, n->pos, m->pos);
        Areas[label] += area;
        if(n->minb == 0 and v->minb == 0 and m->minb == 0) {
          InsideAreas[label] += area;  
          Triangle t(v,n,m);
          TriangleMap[label].insert(t);
        }
      }
    }
  
    // Clear data in MorphoGraphX
    mesh->clearCellAxis();
  
    // Calculate orientation of the perpendicular to the average gradient of the signal
    for(auto it = TriangleMap.begin(); it != TriangleMap.end(); ++it) {
      const int label = it->first; 
      // if the total area of the triangles is too small, don't compute the PO
      if(InsideAreas[label] < Areas[label]*minAreaRatio)
        continue; 

      // The cell normal is the same for every triangle of the cell
      //Point3f nrml = normalized((v->pos - n->pos) ^ (v->pos - m->pos));
      const Point3d nrml(Normals[label]);

      // Iterate by reference: no need to copy the cell's triangle set
      forall(const Triangle &tri, it->second) {
        const vertex &v = tri.v[0];
        const vertex &n = tri.v[1];
        const vertex &m = tri.v[2];
  
        const double area = triangleArea(v->pos, n->pos, m->pos);
  
        // Set signal to known gradient for debugging
        //v->signal = v->pos.y(); n->signal = n->pos.y(); m->signal = m->pos.y();

        // Calculate gradient, minus the background gradient when removing background
        Point3d grad = triangleGradient(v->pos, n->pos, m->pos, double(v->signal), double(n->signal), double(m->signal)); 
        if(backRadius > 0)
          grad -= triangleGradient(v->pos, n->pos, m->pos, double(gSignal[v]), double(gSignal[n]), double(gSignal[m]));
        // subtract any component in normal direction
        grad -= (grad * nrml) * nrml;

        // Original code adapted from ImageJ: 
        //     	// Reduce the problem to a triangle centered on v, and with dn on the X axis
        //      float ds_n = n->signal - v->signal;
        //      float ds_m = m->signal - v->signal;
        //
        //      Point3f dn = n->pos - v->pos;
        //      Point3f dm = m->pos - v->pos;
        //
        //      Point3f nrml = dn ^ dm;
        //      float area = norm(nrml);
        //      nrml /= area;
        //      area /= 2.f;
        //
        //      // Create the 2D triangle
        //      // Define the axis in the triangle plane
        //      float len_dn = norm(dn);
        //      Point3f x = dn / len_dn;
        //      Point3f y = nrml ^ x;
        //
        //      // In this 3D context, the normal to the place is (a,b,c) and the equation of the plane is:
        //      // ax + by + cs = 0
        //      // Then the function of signal is s = -a/c x - b/c y
        //      // so the gradient is (-a/c, -b/c)
        //      //Point3f n2(len_dn, 0, ds_n);
        //      //Point3f m2(dm*x, dm*y, ds_m);
        //      //ds_n = 1.0; ds_m = 1.0;// test without any gradient!
        //      Point3f n2(1, 0, ds_n);
        //      dm /= norm(dm);
        //      Point3f m2(dm*x, dm*y, ds_m);
        //      Point3f grad2d = n2 ^ m2;
        //      grad2d /= grad2d.z();
        //
        //      Point3f grad = -grad2d.x() * x - grad2d.y() * y;

        // Direction along the fibres: perpendicular to the gradient, in the cell plane
        grad = nrml ^ grad;

        Matrix3f corr;
        for(int i = 0; i < 3; i++)
          for(int j = 0; j < 3; j++)
            corr(i,j) = grad[i] * grad[j];

        Correlations[label] += corr * area; // big triangles should count more than small ones
        WeightedAreas[label] += area;

        // Arezki version: we normalize the sum of correlations by the norm of gradient square
        // WeightedAreas[label] += area * normsq(grad); 
      }
    }

    // Restore signal
    if(blurRadius > 0) {
      forall(const vertex &v, S)
        v->signal = signal[v]; 
    }
 
    AttrMap<int, SymmetricTensor>& fibrilAttr = mesh->attributes().attrMap<int, SymmetricTensor>("Measure Label Tensor FibrilOrientations");
    fibrilAttr.clear();

    // Do the PCA
    forall(const IntFloatPair &p, WeightedAreas) {
      if(p.second == 0 )
        continue;
      // corr is an average of the correlations, weighted by triangles size
      Matrix3f corr = Correlations[p.first] / p.second;
  
      PCA pca;
      if(!pca(corr))
        return setErrorMessage("Error performing PCA");
      
      SymmetricTensor& tensor = mesh->cellAxis()[p.first];
      tensor.ev1() = pca.p1;
      tensor.ev2() = pca.p2;
      tensor.evals() = pca.ev;

      fibrilAttr[p.first] = tensor;
    }
  
    mesh->setCellAxisType("fibril");
    mesh->setCellAxisUnit("a.u.");
    mesh->updateLines();

    // Run fibril display, with the parameters from the GUI 
    DisplayFibrilOrientations *proc;
    if(!getProcess("Mesh/Cell Axis/Fibril Orientations/Display Fibril Orientations", proc))  
      throw(QString("MeshProcessHeatMap:: Unable to make DisplayFibrilOrientations process"));
    proc->run();
  
    return true;
  }
  REGISTER_PROCESS(FibrilOrientations);

  // Displays the fibril orientations stored in mesh->cellAxis().
  //
  // Each cell gets one line along its principal direction ev1(). The line length is the degree
  // of orientation (evals[0] - evals[1]) / (evals[0] + evals[1]) times axisLineScale, or zero
  // when that degree does not exceed orientationThreshold. When displayHeatMap is
  // "Orientation", the degree of orientation is also written to the label heat map. If parent
  // labels are loaded, the values are mapped through the parents.
  //
  // Returns false (with an error message) when no fibril cell axis is stored in the mesh.
  bool DisplayFibrilOrientations::run(Mesh *mesh,const QString displayHeatMap,const QColor& qaxisColor,
                                             float axisLineWidth, float axisLineScale, float axisOffset,
                                             float orientationThreshold)
  {
    // if there is no cell axis stored, display error message
    const IntSymTensorAttr& cellAxis = mesh->cellAxis();
    if(cellAxis.size() == 0 or mesh->cellAxisType() != "fibril"){
      setErrorMessage(QString("No fibril axis stored in active mesh!"));
      return false;
    }

    // Degree of orientation (e0 - e1) / (e0 + e1). A cell without any signal gradient has an
    // all-zero tensor; return 0 instead of NaN so it cannot poison the heat map bounds.
    auto orientationOf = [](const SymmetricTensor &tensor) -> float {
      const float sum = tensor.evals()[0] + tensor.evals()[1];
      if(!(sum > 0.f))
        return 0.f;
      return (tensor.evals()[0] - tensor.evals()[1]) / sum;
    };
  
    // Check cell axis and prepare them for display. Populate cellAxisVis in Mesh.  
    IntMatrix3fAttr &cellAxisVis = mesh->cellAxisVis();
    cellAxisVis.clear();
    IntVec3ColorbAttr &cellAxisColor = mesh->cellAxisColor();
    cellAxisColor.clear();
    mesh->setAxisWidth(axisLineWidth);
    mesh->setAxisOffset(axisOffset);
    Colorb axisColor(qaxisColor);
  
    // Update the normals and center of the cell for visualisation.
    if(mesh->labelCenterVis().empty() or mesh->labelNormalVis().empty())
      mesh->updateCentersNormals();
  
    // Check if there is a cell center for each cell axis (bind by reference, no map copy)
    const IntPoint3fAttr &labelCenterVis = mesh->labelCenterVis();
    int nProblems = 0;
    forall(const IntSymTensorPair &p, cellAxis) {
      if(labelCenterVis.count(p.first) == 0)
        nProblems++;
    }
    if(nProblems != 0)
      SETSTATUS("Warning: non-existing cell center found for " << nProblems << " cell axis.");
  
    // Scale the cell axis for display
    forall(const IntSymTensorPair &p, cellAxis) {
      int cell = p.first;
      const SymmetricTensor& tensor = p.second;
      float orientation = orientationOf(tensor);
      cellAxisVis[cell][0] = (orientation > orientationThreshold ? orientation : 0.f) * tensor.ev1() * axisLineScale;
      cellAxisVis[cell][1] = Point3f(0,0,0);
      cellAxisVis[cell][2] = Point3f(0,0,0);
  
      cellAxisColor[cell][0] = axisColor;
    }
  
    // Heat map of the degree of orientation
    if(displayHeatMap == "Orientation") {
      IntFloatAttr labelHeatMap;
      float maxValue = 0;
      float minValue = 0;
      bool firstRun = true;
      forall(const IntSymTensorPair &p, cellAxis) {
        float value = orientationOf(p.second);
        labelHeatMap[p.first] = value;
  
        if(firstRun) {
          maxValue = value;
          minValue = value;
          firstRun = false;
        }
        if(value > maxValue)
          maxValue = value;
        if(value < minValue)
          minValue = value;
      }
  
      // Adjust heat map if run on parents
      if(mesh->parents().size() != 0){
        Information::out << "parent label loaded" << endl;
        mesh->labelHeat().clear();
        IntIntAttr & parentMap = mesh->parents();
        forall(const IntIntPair &p, parentMap)
          if(p.second > 0 and labelHeatMap.count(p.second) != 0)
            mesh->labelHeat()[p.first] = labelHeatMap[p.second];
      }
      else
        mesh->labelHeat() = labelHeatMap;
  
      mesh->heatMapBounds() = Point2f(minValue, maxValue);
      mesh->heatMapUnit() = QString("Orientation");
      mesh->setShowLabel("Label Heat");
      mesh->updateTriangles();
    }
  
    return true;
  }
  REGISTER_PROCESS(DisplayFibrilOrientations);

  // returns all neighobring triangles of a vertex that are within the radius
  std::set<Triangle> getNeighborTris(const vvGraph& S, const vertex& v, double radius)
  {
    std::set<Triangle> neighborTris;

    std::set<vertex> visitedVtxs;
    std::set<vertex> currentVtxs;

    visitedVtxs.insert(v);
    forall(const vertex& n, S.neighbors(v)){
      if(norm(v->pos - n->pos) > radius) 
        continue;
      currentVtxs.insert(n);
    }

    while(!currentVtxs.empty()){
      vertex n = *currentVtxs.begin();

      forall(const vertex &m, S.neighbors(n)) {
        const vertex& k = S.nextTo(n, m);

        double disM = norm(v->pos - m->pos);
        double disK = norm(v->pos - k->pos);

        if(visitedVtxs.find(m) == visitedVtxs.end() and disM <= radius) currentVtxs.insert(m);
        if(visitedVtxs.find(k) == visitedVtxs.end() and disK <= radius) currentVtxs.insert(k);

        if(!S.uniqueTri(n, m, k))
          continue; 
        if(disM > radius or disK > radius)
          continue;

        //if(n->minb == 0 and v->minb == 0 and m->minb == 0) {
          Triangle t(n,m,k);
          neighborTris.insert(t);
        //}
      }

      visitedVtxs.insert(n);
      currentVtxs.erase(n);
    }


    return neighborTris;
  } 

  //  Vertex based version of FibrilOrientations.
  //
  //  For every selected vertex vv, the triangles whose corners all lie within `radius` of vv
  //  are gathered with getNeighborTris(). Of those, the interior triangles (no corner marked
  //  by mesh->markBorder(border)) are used to build an area-weighted fibre-direction
  //  correlation matrix, with vv's normal as the reference plane. If blurRadius > 0 the signal
  //  is blurred first, and the original signal is restored afterwards. Vertices whose interior
  //  area is below minAreaRatio * total area are skipped. The PCA of each matrix is stored in
  //  mesh->vertexAxis(), and "Vertex Display Orientations" is run. `mesh2` is not used.
  //
  //  Returns false (with an error message) if no vertex is selected or the PCA fails.
  bool FibrilOrientationsVertex::run(Mesh* mesh, Mesh* mesh2, float border, float minAreaRatio, float radius, float blurRadius)
  {
    typedef std::unordered_map<vertex, float > VtxFloatMap;
    std::map<vertex, Matrix3f> Correlations;
    VtxFloatMap Areas;
    VtxFloatMap WeightedAreas;
    VtxFloatMap InsideAreas;

    // Interior triangles around each selected vertex
    std::unordered_map<vertex, std::set<Triangle> > TriangleMap;
  
    const vvGraph& S = mesh->graph();

    // Check the selection before the signal is blurred, so this early return leaves the mesh untouched
    bool anySelected = false;
    forall(const vertex &v, S) {
      if(v->selected)
        anySelected = true;
    }
    if(!anySelected)
      return setErrorMessage("Select some vertices first!");

    // update the normals
    mesh->updateCentersNormals();
 
    // First back up signal and blur
    VtxFloatMap signal;
    if(blurRadius > 0) {
      forall(const vertex &v, S)
        signal[v] = v->signal;
      MeshGaussianBlur mgb(*this);
      mgb.run(mesh, blurRadius); 
    }

    // Puts the backed-up signal back (no-op when nothing was blurred)
    auto restoreSignal = [&]() {
      if(blurRadius > 0) {
        forall(const vertex &v, S)
          v->signal = signal[v];
      }
    };

    // Mark the vertices within this distance from the border.
    mesh->markBorder(border);

    int counter = 0;
    forall(const vertex &v, S) {
      if(!v->selected) continue;
      if(!progressAdvance(counter)) {
        // Do not leave the blurred signal behind when the user cancels
        restoreSignal();
        userCancel();
      }
      counter++;

      const std::set<Triangle> nearTris = getNeighborTris(S, v, radius);
      std::set<Triangle> &insideTris = TriangleMap[v];

      forall(const Triangle &t, nearTris) {
        float area = triangleArea(t.v[0]->pos, t.v[1]->pos, t.v[2]->pos);
        Areas[v] += area;
        // Keep only triangles with no corner near a cell border
        if(t.v[0]->minb == 0 and t.v[1]->minb == 0 and t.v[2]->minb == 0) {
          InsideAreas[v] += area;  
          insideTris.insert(t);
        }
      }
    }
  
    // Clear data in MorphoGraphX
    mesh->vertexAxis().clear();
  
    // Calculate orientation of the perpendicular to the average gradient of the signal
    for(auto it = TriangleMap.begin(); it != TriangleMap.end(); ++it) {
      const vertex &vv = it->first; 
      // if the total area of the triangles is too small, don't compute the PO
      if(InsideAreas[vv] < Areas[vv]*minAreaRatio)
        continue; 

      // The reference plane is the one of the selected vertex, same for all its triangles
      Point3d nrml = vv->nrml;

      forall(const Triangle &tri, it->second) {
        const vertex &v = tri.v[0];
        const vertex &n = tri.v[1];
        const vertex &m = tri.v[2];
  
        double area = triangleArea(v->pos, n->pos, m->pos);
  
        // Set signal to known gradient for debugging
        //v->signal = v->pos.y(); n->signal = n->pos.y(); m->signal = m->pos.y();

        // Calculate gradient
        Point3d grad = triangleGradient(v->pos, n->pos, m->pos, double(v->signal), double(n->signal), double(m->signal));
        // subtract any component in normal direction
        grad -= (grad * nrml) * nrml;

        // Direction along the fibres
        grad = nrml ^ grad;
  
        Matrix3f corr;
        for(int i = 0; i < 3; i++)
          for(int j = 0; j < 3; j++)
            corr(i,j) = grad[i] * grad[j];
  
        Correlations[vv] += corr * area; // big triangles should count more than small ones
        WeightedAreas[vv] += area;

        // Arezki version: we normalize the sum of correlations by the norm of gradient square
        // WeightedAreas[vv] += area * normsq(grad); 
      }
    }

    // Restore signal
    restoreSignal();

    // Do the PCA
    typedef std::pair<vertex,float> VtxFloatPair;
    forall(const VtxFloatPair &p, WeightedAreas) {
      if(p.second == 0 )
        continue;
      // corr is an average of the correlations, weighted by triangles size
      Matrix3f corr = Correlations[p.first] / p.second;
  
      PCA pca;
      if(!pca(corr))
        return setErrorMessage("Error performing PCA");
      
      SymmetricTensor& tensor = mesh->vertexAxis()[p.first];
      tensor.ev1() = pca.p1;
      tensor.ev2() = pca.p2;
      tensor.evals() = pca.ev;
    }
  
    mesh->setAxisView("Vertex Axis");
    mesh->setCellAxisType("fibril");
    mesh->setCellAxisUnit("a.u.");
    mesh->updateLines();
    mesh->updateTriangles();
  
    // Run fibril display, with the parameters from the GUI 
    DisplayFibrilOrientationsVertex *proc;
    if(!getProcess("Mesh/Cell Axis/Fibril Orientations/Vertex Display Orientations", proc))  
      throw(QString("MeshProcessHeatMap:: Unable to make DisplayFibrilOrientationsVertex process"));
    proc->run();
  
    return true;
  }
  REGISTER_PROCESS(FibrilOrientationsVertex);


  // Draws the per-vertex fibril orientations stored in mesh->vertexAxis().
  //
  // Each vertex gets one line along its principal direction ev1(). The line length is the
  // degree of orientation (evals[0] - evals[1]) / (evals[0] + evals[1]) times axisLineScale,
  // or zero when that degree does not exceed orientationThreshold. The second and third axes
  // are zeroed, and every line is drawn in qaxisColor. displayHeatMap is accepted but not used.
  //
  // Returns false (with an error message) when no fibril vertex axis is stored in the mesh.
  bool DisplayFibrilOrientationsVertex::run(Mesh *mesh,const QString displayHeatMap,const QColor& qaxisColor,
                                             float axisLineWidth, float axisLineScale, float axisOffset,
                                             float orientationThreshold)
  {
    const VtxSymTensorAttr& vertexAxis = mesh->vertexAxis();
    const bool noFibrilAxis = vertexAxis.size() == 0 or mesh->cellAxisType() != "fibril";
    if(noFibrilAxis) {
      setErrorMessage(QString("No fibril axis stored in active mesh!"));
      return false;
    }

    // Reset the per-vertex display data and the line style before refilling
    VtxMatrix3fAttr &vtxAxisVis = mesh->vertexAxisVis();
    vtxAxisVis.clear();
    VtxVec3ColorbAttr &vtxAxisColor = mesh->vertexAxisColor();
    vtxAxisColor.clear();
    mesh->setAxisWidth(axisLineWidth);
    mesh->setAxisOffset(axisOffset);
    const Colorb lineColor(qaxisColor);

    typedef std::pair<vertex, SymmetricTensor> VtxSymTensorPair;
    forall(const VtxSymTensorPair &entry, vertexAxis) {
      vertex vtx = entry.first;
      const SymmetricTensor& t = entry.second;

      // Degree of orientation, (e0 - e1) / (e0 + e1)
      const float degree = (t.evals()[0] - t.evals()[1])/(t.evals()[0] + t.evals()[1]);
      float length = 0.f;
      if(degree > orientationThreshold)
        length = degree;

      vtxAxisVis[vtx][0] = length * t.ev1() * axisLineScale;
      vtxAxisVis[vtx][1] = Point3f(0,0,0);
      vtxAxisVis[vtx][2] = Point3f(0,0,0);
      vtxAxisColor[vtx][0] = lineColor;
    }

    mesh->setShowAxis("Vertex Axis");
    mesh->updateTriangles();
    return true;
  }
  REGISTER_PROCESS(DisplayFibrilOrientationsVertex);


  // Selects a sparse set of vertices inside each cell.
  //
  // The active vertices are grouped by label. If ignoreNoParents is set, only labels that have
  // an entry in the parent map are kept. For every label >= 0, vertices are then picked
  // greedily, in set order. A vertex is picked when it is farther than minPointDis from every
  // vertex already picked in that cell and, if minBorderDis > 0, farther than minBorderDis
  // from every vertex of a different label adjacent to the cell. All vertices of the mesh are
  // deselected first, so only the picked vertices end up selected.
  bool SelectVerticesOfCell::run(Mesh* m, double minPointDis,  double minBorderDis, bool ignoreNoParents)                                   
  {
    vvGraph& S = m->graph();

    //if(!meshHasEdges(S)) return setErrorMessage("Error: A surface mesh as active mesh is required");

    const std::vector<vertex> av = m->activeVertices();

    // Active vertices per label, and the vertices picked per label
    std::map<int, std::set<vertex> > labelVtxMap;
    std::map<int, std::set<vertex> > labelVtxSelMap;

    forall(const vertex &v, av){
      if(ignoreNoParents and m->parents().find(v->label) == m->parents().end()) continue;
      labelVtxMap[v->label].insert(v);
    }

    forall(const vertex& v, S){
      v->selected = false;
    }

    int progr = 0;

    // foreach cell separately (iterate by reference: no copy of each cell's vertex set)
    forall(const auto &p, labelVtxMap){
      if(!progressAdvance(progr))
        userCancel();
      progr++;

      const int label = p.first;
      if(label < 0) continue;
      const std::set<vertex> &cellVtxs = p.second;

      // Vertices of other labels touching this cell
      std::set<vertex> cellBorder;
      forall(const vertex &v, cellVtxs){
        forall(const vertex &n, S.neighbors(v)){
          if(n->label == v->label) continue;
          cellBorder.insert(n);
        }
      }

      std::set<vertex> &picked = labelVtxSelMap[label];
      forall(const vertex &v, cellVtxs){
        // check distance to nearest picked vertex
        double minDis = 1E20;
        forall(const vertex &n, picked){
          double dis = norm(v->pos - n->pos);
          if(dis < minDis) minDis = dis;
        }

        // check distance to nearest border vertex
        double borderDis = 1E20;
        if(minBorderDis > 0){
          forall(const vertex &n, cellBorder){
            double dis = norm(v->pos - n->pos);
            if(dis < borderDis) borderDis = dis;
          }
        }

        // only add the point if both criteria are satisfied
        if(minDis > minPointDis and borderDis > minBorderDis) picked.insert(v);
      }
    }

    forall(const auto &p, labelVtxSelMap){
      forall(const vertex &v, p.second){
        v->selected = true;
      }
    }
  
    m->updateAll();
  
    return true;
  }
  REGISTER_PROCESS(SelectVerticesOfCell);

}

