//
// This file is part of MorphoGraphX - http://www.MorphoGraphX.org
// Copyright (C) 2012-2016 Richard S. Smith and collaborators.
//
// If you use MorphoGraphX in your work, please cite:
//   http://dx.doi.org/10.7554/eLife.05864
//
// MorphoGraphX is free software, and is licensed under under the terms of the 
// GNU General (GPL) Public License version 2.0, http://www.gnu.org/licenses.
// 
#include "MeshProcessSegmentation.hpp"
#include "MeshProcessSignal.hpp"

#include "Progress.hpp"

#include "ui_LoadHeatMap.h"

namespace mgx 
{
  // Grow the existing positive labels of the mesh into unlabelled (label 0)
  // vertices: seeded region growing ordered by vertex signal (watershed-style).
  //
  // mesh        mesh whose vertex labels are extended in place.
  // maxsteps    if > 0, at most this many queue entries are processed before the
  //             queue is rebuilt from scratch and the view refreshed; 0 = no limit.
  // to_segment  vertices to process; when empty, mesh->activeVertices() is used.
  //
  // Unlabelled vertices adjacent to a positively labelled vertex are processed in
  // order of increasing signal (the smallest multimap key is popped first). A popped
  // vertex whose positively labelled neighbours all share one label takes that
  // label; if they disagree it becomes a border vertex (label -1). A popped vertex
  // with no positively labelled neighbour is dropped without a label. Each newly
  // labelled vertex enqueues its unlabelled neighbours.
  //
  // Always returns true; cancellation goes through userCancel().
  bool SegmentMesh::run(Mesh* mesh, uint maxsteps, const std::vector<vertex>& to_segment)
  {
    // Priority queue keyed by signal; multimap allows equal signals
    typedef std::pair<float, vertex> cvpair;
    typedef std::multimap<float, vertex> cvmap;
  
    // Vertices to segment: the explicit list, or the mesh's active vertices
    vvGraph& S = mesh->graph();
    const std::vector<vertex>& vs = (to_segment.empty() ? mesh->activeVertices() : to_segment);
  
    bool all_active = vs.size() == S.size();
  
    // SS is the graph the growth runs on, and unlabeled sizes the progress bar
    vvGraph SS;
    size_t unlabeled = 0;
    if(all_active) {
      SS = S;
      forall(vertex v, S)
        if(v->label == 0)
          unlabeled++;
    } else {
      std::set<vertex> vset(vs.begin(), vs.end());
      // Add labeled neighbors of non-labeled vertices so that seeds bordering
      // the selection are part of the subgraph
      forall(vertex v, vs)
        if(v->label == 0) {
          unlabeled++;
          forall(vertex n, S.neighbors(v))
            if(n->label > 0 and vset.count(n) == 0)
              vset.insert(n);
        }
      SS = S.subgraph(vset);
    }
  
    // Start progress bar
    progressStart(QString("Segmenting mesh-%1").arg(mesh->userId()), unlabeled);
    int count = 0;
    mesh->updateTriangles();
    mesh->updateLines();
  
    // Per-vertex flag (mesh attribute "Vertex InQueue") that prevents a vertex
    // from being queued twice
    VtxBoolAttr& inQueue = mesh->attributes().attrMap<vertex, bool>("Vertex InQueue", false);
   
    cvmap Q;
    // Each pass rebuilds the queue; with maxsteps > 0 a pass may stop early and
    // the loop repeats until the queue is empty
    do {
      Q.clear();
      // Mark all vertices as not in queue
      forall(const vertex& v, SS)
        v->*inQueue = false;
  
      // Put all neighbors of labelled vertices into queue
      forall(const vertex& v, SS)
        if(v->label > 0) {
          forall(const vertex& n, SS.neighbors(v)) {
            if(n->label == 0 and !(n->*inQueue)) {
              n->*inQueue = true;
              Q.insert(cvpair(n->signal, n));
            }
          }
        }
  
      // Process queue
      long steps = (maxsteps > 0 ? maxsteps : LONG_MAX);
      if(Q.size() > 0)
        mesh->updateTriangles();
  
      while(Q.size() > 0 and steps-- > 0) {
        if(!progressAdvance(count))
          userCancel();
  
        // Lowest-signal vertex in the queue
        cvmap::iterator cvmax = Q.begin();
        vertex vmax = cvmax->second;
  
        // Collect the positive labels around vmax and detect disagreement
        int label = 0;
        bool foundlabel = false, difflabels = false;
        forall(const vertex& n, SS.neighbors(vmax)) {
          if(n->label > 0) {
            if(!foundlabel) {
              label = n->label;
              foundlabel = true;
            } else if(label != n->label)
              difflabels = true;
          }
          // RSS not sure why this is here
          // else if(n->label == 0 and !n->*inQueue) {
          //  n->*inQueue = true;
          //  Q.insert(cvpair(n->signal, n));
          //}
        }
        if(foundlabel) {
          // Where two regions meet the vertex becomes border (-1)
          if(difflabels)
            vmax->label = -1;
          else
            vmax->label = label;
          count++;
  
          // Let the front advance: enqueue unlabelled neighbours
          forall(const vertex& n, SS.neighbors(vmax))
            if(n->label == 0 and !(n->*inQueue)) {
              n->*inQueue = true;
              Q.insert(cvpair(n->signal, n));
            }
        }
  
        vmax->*inQueue = false;
        Q.erase(cvmax);
      }
  
      if(count > 0) {
        mesh->updateTriangles();
        mesh->updateLines();
      }
      // Only reached with a non-empty queue when maxsteps cut the pass short
      if(Q.size() > 0) {
        updateState();
        updateViewer();
      }
  
      SETSTATUS("Segmenting mesh " << mesh->userId() << ", vertices left in queue:" << Q.size());
    } while(Q.size() > 0);
  
    return true;
  }
  REGISTER_PROCESS(SegmentMesh);
  
  // Remove the segmentation from the selection: every active vertex gets label 0
  // (unlabelled), then triangles and lines are flagged for redraw.
  // Always returns true.
  bool SegmentClear::run(Mesh* mesh)
  {
    const std::vector<vertex>& active = mesh->activeVertices();
    for(const vertex& v : active)
      v->label = 0;

    mesh->updateTriangles();
    mesh->updateLines();
    return true;
  }
  REGISTER_PROCESS(SegmentClear);
  
  // Renumber the positive labels of the active vertices to start, start + step,
  // start + 2*step, ... in order of first appearance in the active-vertex list.
  // Vertices with label <= 0 (unlabelled / border) are left untouched, and all
  // vertices sharing an old label receive the same new label.
  //
  // m      mesh to relabel.
  // start  first new label (values < 1 are raised to 1).
  // step   increment between consecutive labels (values < 1 are raised to 1).
  //
  // Afterwards the first label that was not handed out is passed to m->setLabel().
  // Always returns true; cancellation goes through userCancel().
  bool MeshRelabel::run(Mesh* m, int start, int step)
  {
    if(step < 1)
      step = 1;
    if(start < 1)
      start = 1;
    const std::vector<vertex>& vs = m->activeVertices();
    // The message has no %n placeholders, so it is not passed through arg()
    // (Qt warns about missing arguments otherwise).
    progressStart(QString("Relabeling current mesh."), vs.size());
    std::unordered_map<int, int> relabel_map;
    // Advance the progress bar roughly every 1% of the vertices. For fewer than
    // 100 vertices vs.size() / 100 is 0, which used to make `i % prog` divide by zero.
    int prog = vs.size() / 100;
    if(prog < 1)
      prog = 1;
    int i = 0;
    forall(const vertex& v, vs) {
      if(v->label > 0) {
        std::unordered_map<int, int>::iterator found = relabel_map.find(v->label);
        if(found != relabel_map.end())
          v->label = found->second;
        else {
          relabel_map[v->label] = start;
          v->label = start;
          start += step;
        }
      }
      if((i % prog) == 0 and !progressAdvance(i))
        userCancel();
      ++i;
    }
    m->setLabel(start);
    if(!progressAdvance(vs.size()))
      userCancel();
    m->updateTriangles();
    // start already points one step past the last label handed out
    const int lastLabel = relabel_map.empty() ? 0 : start - step;
    SETSTATUS(QString("Total number of labels created = %1. Last label = %2").arg(relabel_map.size()).arg(lastLabel));
    return true;
  }
  REGISTER_PROCESS(MeshRelabel);
  
  bool GrabLabelsSegment::run(Mesh* mesh1, Mesh* mesh2, float tolerence)
  {
    const vvGraph& S1 = mesh1->graph();
    const vvGraph& S2 = mesh2->graph();
  
    // Find center of all cells (should be closed surfaces)
    typedef std::pair<int, Point4f> IntPoint4fPair;
    std::map<int, Point4f> posMap1;
    forall(const vertex& v, S1)
      if(v->label > 0)
        posMap1[v->label] += Point4f(v->pos.x(), v->pos.y(), v->pos.z(), 1.0);
    forall(const IntPoint4fPair& p, posMap1)
      posMap1[p.first] /= posMap1[p.first].t();
  
    std::map<int, Point4f> posMap2;
    forall(const vertex& v, S2)
      if(v->label > 0)
        posMap2[v->label] += Point4f(v->pos.x(), v->pos.y(), v->pos.z(), 1.0);
    forall(const IntPoint4fPair& p, posMap2)
      posMap2[p.first] /= posMap2[p.first].t();
  
    // For each label in mesh 1, find closest in mesh 2.
    std::map<int, int> labelMap;
    forall(const IntPoint4fPair& p1, posMap1) {
      float minDist = 1e37;
      int minLabel = 0;
      forall(const IntPoint4fPair& p2, posMap2) {
        float dist = (p1.second - p2.second).norm();
        if(dist < minDist and dist < tolerence) {
          minDist = dist;
          minLabel = p2.first;
        }
      }
      if(minLabel > 0)
        labelMap[p1.first] = minLabel;
    }
    SETSTATUS(QString("Updating %1 matching cells.").arg(labelMap.size()));
  
    // Update mesh labels
    forall(const vertex& v, S1)
      if(labelMap[v->label] > 0)
        v->label = labelMap[v->label];
      else
        v->label = 0;
  
    mesh1->updateTriangles();
    return true;
  }
  REGISTER_PROCESS(GrabLabelsSegment);
  
  // Give all active vertices one common label.
  //
  // label > 0   use this label.
  // label <= 0  choose automatically: if the positively labelled active vertices
  //             all share one label, reuse it; if they carry different labels,
  //             fail with an error message; if none is labelled, use
  //             mesh->nextLabel().
  //
  // Returns true on success, or the result of setErrorMessage() on conflict.
  bool LabelSelected::run(Mesh* mesh, int label)
  {
    const std::vector<vertex>& vs = mesh->activeVertices();
    if(label <= 0) {
      // Any non-positive value means "automatic". Normalise to 0 so that the scan
      // adopts the first label it meets. Previously a negative value reported a
      // conflict on the first labelled vertex.
      label = 0;
      forall(const vertex& v, vs)
        if(v->label > 0) {
          if(label == 0)
            label = v->label;
          else if(label != v->label)
            return setErrorMessage("At most one label can be contained in selected vertices");
        }
      if(label == 0)
        label = mesh->nextLabel();
    }
    forall(const vertex& v, vs)
      v->label = label;
    mesh->updateTriangles();
    mesh->updateLines();
    return true;
  }
  REGISTER_PROCESS(LabelSelected);
  
  // Relabel 3D cells: every connected region of the mesh graph (one per 3D cell)
  // receives its own label, label_start, label_start + label_step, ...
  // A label_start of -1 requests mesh->nextLabel() as the first label.
  // If the label after the last region is >= mesh->viewLabel(), that label + 1
  // is passed to mesh->setLabel(). Always returns true.
  bool RelabelCells3D::run(Mesh* mesh, int label_start, int label_step)
  {
    // Split the whole graph into its connected regions
    const vvGraph& graph = mesh->graph();
    VVGraphVec regions;
    VtxIntMap regionOfVertex;
    mesh->getConnectedRegions(graph, regions, regionOfVertex);

    int current = (label_start == -1) ? mesh->nextLabel() : label_start;

    // One label per region
    forall(const vvGraph& region, regions) {
      forall(const vertex& v, region)
        v->label = current;
      current += label_step;
    }

    if(mesh->viewLabel() <= current)
      mesh->setLabel(current + 1);

    // Labels drive the mesh colouring, so refresh triangles and selection
    mesh->updateTriangles();
    mesh->updateSelection();

    return true;
  }
  REGISTER_PROCESS(RelabelCells3D);
  
  // Merge over-segmented regions whose shared wall is not much brighter than the
  // two cells it separates.
  //
  // mesh        labelled mesh (border vertices carry label -1); merged labels are
  //             rewritten in place on the active vertices.
  // borderDist  passed to mesh->markBorder(), and also the radius around junction
  //             vertices inside which wall segments are ignored.
  // threshold   cells a < b are merged (a relabelled to b) when
  //             (length-weighted mean wall signal) / (mean of the two cell means)
  //             is below threshold.
  //
  // Always returns true; cancellation goes through userCancel().
  bool MeshCombineRegions::run(Mesh* mesh, float borderDist, float threshold)
  {
    typedef std::pair<int, int> SigPair;
    typedef std::map<int, std::set<vertex> > JunctionMap;

    vvGraph& S = mesh->graph();
    const std::vector<vertex>& vs = mesh->activeVertices();
    progressStart(QString("Combining Regions on Mesh %1").arg(mesh->userId()), S.size());
    // NOTE(review): presumably sets v->minb for vertices near a wall (tested
    // below); confirm against Mesh::markBorder.
    mesh->markBorder(borderDist);

    // 1. Classify border vertices (label -1) by the distinct labels seen across
    //    their border edges: fewer than 3 -> plain wall vertex, otherwise a
    //    junction, indexed under every label it touches.
    std::set<vertex> wallVertices;
    JunctionMap junction;
    forall(const vertex& v, vs) {
      if(v->label != -1)
        continue;
      std::set<int> touched;
      forall(const vertex& n, S.neighbors(v)) {
        if(n->label == -1) {
          vertex m = S.nextTo(v, n);
          touched.insert(m->label);
        }
      }
      if(touched.size() < 3)
        wallVertices.insert(v);
      else {
        forall(int l, touched)
          junction[l].insert(v);
      }
    }

    // 2. Mean signal of each cell over its labelled vertices with minb == 0.
    std::map<int, double> labelAvg;
    std::map<int, int> labelCount;
    int cnt = 0;
    forall(const vertex& v, vs) {
      if(!progressAdvance(cnt++))
        userCancel();
      if(v->label <= 0 or v->minb != 0)
        continue;
      labelAvg[v->label] += v->signal;
      labelCount[v->label]++;
    }
    for(const std::pair<const int, int>& lc : labelCount)
      labelAvg[lc.first] /= lc.second;   // counts are >= 1 by construction

    // 3. Length-weighted signal along the wall between each pair of cells
    //    (lower label first), skipping segments near a junction of the lower label.
    std::map<SigPair, double> wallSig, wallLen;
    forall(const vertex& v, wallVertices) {
      if(v->label != -1)
        continue;
      forall(const vertex& n, S.neighbors(v)) {
        if(n->label != -1)
          continue;
        vertex m = S.nextTo(v, n);
        vertex k = S.prevTo(v, n);
        if(m->label > 0 and k->label > 0 and m->label < k->label and labelAvg.count(m->label) == 1
           and labelAvg.count(k->label) == 1) {
          const SigPair key(m->label, k->label);
          Point3d mid = (v->pos + n->pos) / 2.0;
          bool nearJunction = false;
          JunctionMap::const_iterator jit = junction.find(m->label);
          if(jit != junction.end()) {
            forall(const vertex& j, jit->second) {
              if(norm(j->pos - mid) < borderDist) {
                nearJunction = true;
                break;
              }
            }
          }
          if(!nearJunction) {
            wallLen[key] += norm(v->pos - n->pos);
            wallSig[key] += v->signal * norm(v->pos - n->pos);
          }
        }
      }
    }

    // 4. Candidate merges: walls whose mean signal is not much higher than the
    //    average of the two adjacent cells. mergeState: 1 = pending, >= 2 = done.
    std::set<SigPair> toMerge;
    std::map<SigPair, int> mergeState;
    for(const std::pair<const SigPair, double>& ws : wallSig) {
      const SigPair& key = ws.first;
      double wallMean = ws.second / wallLen[key];
      double cellMean = (labelAvg[key.first] + labelAvg[key.second]) / 2;
      if(float(wallMean / cellMean) < threshold) {
        toMerge.insert(key);
        mergeState[key] = 1;
      }
    }

    // 5. Apply merges one at a time (label first -> label second) and rewrite the
    //    pending pairs that still mention the label that just disappeared.
    //    Testing emptiness first avoids dereferencing begin() of an empty set
    //    when no wall passed the threshold.
    int count = 0;
    while(!toMerge.empty()) {
      if(!progressAdvance(count++))
        userCancel();
      const SigPair pr = *toMerge.begin();
      forall(const vertex& n, vs) {
        if(n->label == pr.first)
          n->label = pr.second;
      }
      mergeState[pr]++;

      std::set<SigPair> remaining;
      forall(const SigPair& pr2, toMerge) {
        if(mergeState[pr2] != 1)
          continue;
        SigPair updated = pr2;
        if(pr2.first == pr.first) {
          // pr.first now belongs to pr.second: pair the partner with pr.second,
          // keeping the lower label first (previously the pair was kept unchanged
          // when pr2.second < pr.second, so that merge was lost).
          if(pr.second < pr2.second)
            updated = SigPair(pr.second, pr2.second);
          else
            updated = SigPair(pr2.second, pr.second);
          mergeState[updated] = 1;
        } else if(pr2.second == pr.first) {
          updated.second = pr.second;
          mergeState[updated] = 1;
        }
        remaining.insert(updated);
      }
      toMerge.swap(remaining);
    }
    mesh->updateAll();
    return true;
  }
  REGISTER_PROCESS(MeshCombineRegions);
  
  // Automatic segmentation pipeline on the active vertices:
  //   1. blur the signal (gaussianRadiusCell) and run MeshLocalMinima
  //      (localMinimaRadius), which presumably seeds the labels grown below;
  //   2. restore the original signal and blur it again (gaussianRadiusWall);
  //      the mesh keeps this wall-blurred signal afterwards;
  //   3. grow the seeds with SegmentMesh;
  //   4. merge over-segmented regions with MeshCombineRegions(borderDist, threshold).
  // `normalize` and `normalizeRadius` are currently ignored because the
  // normalization step is disabled (see the commented-out block).
  // When updateView is set, the viewer is refreshed after each step.
  // Returns false as soon as a sub-process reports failure, true otherwise.
  bool MeshAutoSegment::run(Mesh* mesh, bool updateView, bool normalize, float gaussianRadiusCell,
                                   float localMinimaRadius, float gaussianRadiusWall, float normalizeRadius,
                                   float borderDist, float threshold)
  {
    const VtxVec& vs = mesh->activeVertices();

    // Keep the raw signal: the first blur is only used for seed detection
    FloatVec MeshSignal(vs.size());
    for(size_t i = 0; i < vs.size(); ++i)
      MeshSignal[i] = vs[i]->signal;

    // Blur the mesh, normally with same radius as local minima
    MeshGaussianBlur blurCells(*this);
    if(!blurCells.run(mesh, gaussianRadiusCell))
      return false;
    if(updateView) {
      mesh->setShowLabel("Label");
      updateState();
      updateViewer();
    }

    // Find local minima
    MeshLocalMinima minima(*this);
    if(!minima.run(mesh, localMinimaRadius))
      return false;
    if(updateView) {
      updateState();
      updateViewer();
    }

    // Restore mesh signal
    for(size_t i = 0; i < vs.size(); ++i)
      vs[i]->signal = MeshSignal[i];

    // Blur the mesh for segmentation and merging over-segmented regions
    MeshGaussianBlur blurWalls(*this);
    if(!blurWalls.run(mesh, gaussianRadiusWall))
      return false;
    if(updateView) {
      updateState();
      updateViewer();
    }

    // Normalize Mesh Signal if required (currently disabled, so `normalize`
    // and `normalizeRadius` have no effect)
    //if(normalize) {
      //MeshNormalize normal(*this);
      //normal.run(mesh, normalizeRadius);
      //if(updateView) {
        //updateState();
        //updateViewer();
      //}
    //}

    // Segment Mesh
    SegmentMesh segMesh(*this);
    if(!segMesh.run(mesh, 0, vs))
      return false;
    if(updateView) {
      mesh->setShowMeshLines(true);
      mesh->setMeshView("Cells");
      updateState();
      updateViewer();
    }

    // Merge over-segmented regions
    MeshCombineRegions combine(*this);
    if(!combine.run(mesh, borderDist, threshold))
      return false;

    return true;
  }
  REGISTER_PROCESS(MeshAutoSegment);
}
