#include <CellAtlasMeristem.hpp>

#include <map>

#include <Progress.hpp>
#include <Geometry.hpp>

#include <GraphUtils.hpp>

#include <CellAtlasUtils.hpp>

using namespace std;

namespace mgx
{


// Analyze Meristem: a reduced variant of "Analyze Cells 3D". No Bezier is used here; only cell
// centroids, volumes and distances to the surface mesh are computed, plus the shared wall areas
// between neighbouring cells. writeAttrMaps() stores the results in the "Cell Atlas 3D Data" and
// "Cell Atlas 3D Config" attribute maps of m1.
//
// m1           : segmented (labelled) mesh whose cells are analyzed
// m2           : surface mesh used for the surface-distance computation
// volThreshold : forwarded to analyzeCellCentroidsVolumes (presumably a minimum cell volume - TODO confirm)
// Returns false if building the label/triangle map or the centroid/volume analysis fails.
  bool AnalyzeMeristem::run(Mesh *m1, Mesh *m2, double volThreshold)
  {
    CellAtlasAttr *cellAtlasAttr = &m1->attributes().attrMap<int, CellAtlasData>("Cell Atlas 3D Data");
    CellAtlasConfigAttr *cellAtlasConfigAttr = &m1->attributes().attrMap<int, CellAtlasConfig>("Cell Atlas 3D Config");

    RootCellProcessing rcp;

    progressStart("Analyzing Meristem", 0);
    const vvGraph& segmentedMesh = m1->graph();
    const vvGraph& surfaceMesh = m2->graph();

    progressStart("Analyze Cells 3D - Find Cells", 0);
    std::vector<int> uniqueLabels = findAllLabels(segmentedMesh);

    rcp.rootData.numCells = uniqueLabels.size();
    rcp.rootData.uniqueLabels = uniqueLabels;

    // group the vertices of the segmented mesh by their cell label
    labelVertexMap lvMap;
    progressStart("Analyze Cells 3D - Create label/vertex map", 0);
    forall(const vertex &v, segmentedMesh)
      lvMap[v->label].push_back(v);

    RootCellAnalyzing rca(cellAtlasAttr, cellAtlasConfigAttr);

    std::map<int, triVector> cellTriangles;
    if(!rca.generateLabelTriangleMap(rcp.rootData.numCells, uniqueLabels, segmentedMesh, lvMap, cellTriangles))
      return false;

    if(!rca.analyzeCellCentroidsVolumes(cellTriangles, rcp, volThreshold))
      return false;

    std::map<int, double> surfaceArclengths;
    std::map<int, double> surfaceRadialDis;
    std::map<int, Point3d> bMapDummy; // no Bezier in this process; only needed to satisfy the interface

    rca.analyzeCellCalcDisSurface(rcp, surfaceMesh, false, surfaceArclengths, surfaceRadialDis, bMapDummy, true);

    // shared wall areas between neighbouring cells; userCancel() is invoked if this reports failure
    if(!neighborhoodGraphLocal(segmentedMesh, 0.0001, rcp.rootData.cellWallArea, rcp.rootData.wallArea, rcp.rootData.outsideWallArea))
      userCancel();

    writeAttrMaps(rcp, cellAtlasAttr, cellAtlasConfigAttr);

    return true;
  }
  REGISTER_PROCESS(AnalyzeMeristem);


 // For every cell of the Cell Atlas data, collect the labels of all OTHER cells whose centroid lies
 // inside that cell's cone, as decided by pointInCone(). The cone is built from the cell centroid,
 // its nearest surface point, the axis (centroid - nearest surface point) and coneAngle.
 // includedCells : output, cell label -> labels of the cells found inside its cone
 //                 (an existing entry for a label is overwritten).
 // cellAtlasAttr is accessed by consecutive index 0..size()-1.
 // A cell is never counted as lying inside its own cone, consistent with the per-label overload.
 // Always returns true.
 bool getIncludedCellCenters(CellAtlasAttr *cellAtlasAttr, double coneAngle, std::map<int, std::vector<int> >& includedCells)
  {
    const int numCells = (*cellAtlasAttr).size();

    for(int i = 0; i < numCells; i++){
      const int currentLabel = (*cellAtlasAttr)[i].cellLabel;
      const Point3d currentCentroid = (*cellAtlasAttr)[i].centroid;
      const Point3d currentConeBase = (*cellAtlasAttr)[i].nearestSurfacePoint;
      const Point3d axisVec = currentCentroid - currentConeBase;

      // fill the output entry directly instead of building and copying a temporary vector
      std::vector<int>& included = includedCells[currentLabel];
      included.clear();

      for(int j = 0; j < numCells; j++){
        if(j == i) continue; // skip the cell itself
        if(pointInCone(axisVec, currentCentroid, currentConeBase, coneAngle, (*cellAtlasAttr)[j].centroid))
          included.push_back((*cellAtlasAttr)[j].cellLabel);
      }
    }
    return true;
  }

 std::vector<int> getIncludedCellCenters(int label, double coneAngle, AttrMap<int, Point3d>& centroids, AttrMap<int, Point3d>& nearestSurfacePoint)
  {
    std::vector<int> includedCells;

    Point3d currentCentroid = centroids[label];
    Point3d currentConeBase = nearestSurfacePoint[label];

    Point3d axisVec = currentCentroid - currentConeBase;

    forall(auto p, centroids){
      int checkLabel = p.first;
      if(checkLabel == label) continue;

      if(pointInCone(axisVec, currentCentroid, currentConeBase, coneAngle, p.second))
        includedCells.push_back(checkLabel);
    }

    return includedCells;
  }

/*
  bool getIncludedVertices(RootCellProcessing& rcp, double coneAngle, const vvGraph& S, std::map<int, std::vector<int> >& includedVtxs)
  {
    int numCells = rcp.rootData.uniqueLabels.size();
    // get L1
    //#pragma omp parallel for
    for(int i=0; i<numCells; i++){
      int currentLabel = rcp.rootData.uniqueLabels[i];

      Point3d currentCentroid = rcp.rootData.cellCentroids[currentLabel];
      Point3d currentConeBase = rcp.rootData.nearestSurfacePoint[currentLabel];

      Point3d axisVec = currentCentroid - currentConeBase;

      //int inCounter = 0;
      std::vector<int> included;

      for(u_int i = 0; i<S.size();i++){
        const vertex& v = S[i];

        if(pointInCone(axisVec, currentCentroid, currentConeBase, coneAngle, v->pos))
          included.push_back(v->label);
      }
      includedVtxs[currentLabel] = included;

    }

  }
*/
// Detect the cell layers L1, L2 and L3 and store them as parent labels 1, 2 and 3 in m1->parents()
// (the parent map is cleared first). Requires the Cell Atlas data (run Analyze Cells first).
//
// L1: cells whose cone towards the nearest surface point contains no other cell centroid.
// detectByNeighborhood == true:
//   L2: cells sharing a wall (area > 0.001) with an L1 cell and not L1 themselves,
//   L3: cells sharing such a wall with an L2 cell and not L1/L2;
//   (each is one pass over the config's wallArea pairs, using only pairs whose FIRST cell is
//    in the previous layer).
// detectByNeighborhood == false (cone method):
//   L2: non-L1 cells whose cone contains only L1 cells,
//   L3: remaining cells whose cone contains only L1 or L2 cells.
// Cells that fall in none of the layers receive no layer label.
  bool DetectLayers::run(Mesh *m1, double coneAngle, bool detectByNeighborhood)
  {
    CellAtlasAttr *cellAtlasAttr = &m1->attributes().attrMap<int, CellAtlasData>("Cell Atlas 3D Data");
    CellAtlasConfigAttr *cellAtlasConfigAttr = &m1->attributes().attrMap<int, CellAtlasConfig>("Cell Atlas 3D Config");

    int numCells = (*cellAtlasAttr).size();

    if(numCells == 0 or (*cellAtlasConfigAttr).size() == 0){
      return setErrorMessage("The Cell Atlas Data Attribute Map is empty. Please run Analyze Cells first.");
    }

    progressStart("Detecting Layers", 0);
    m1->parents().clear();

    // cell label -> labels of the other cells whose centroid lies in that cell's cone
    std::map<int, std::vector<int> > includedCells;
    getIncludedCellCenters(cellAtlasAttr, coneAngle, includedCells);

    // L1: nothing lies between the cell and the surface
    for(int i = 0; i < numCells; i++){
      int currentLabel = (*cellAtlasAttr)[i].cellLabel;
      if(includedCells[currentLabel].empty())
        m1->parents()[currentLabel] = 1;
    }

    if(detectByNeighborhood){ // detect L2 and L3 by neighborhood
      const double threshold = 0.001; // minimum shared wall area for two cells to count as neighbours
      // L2 is all cells that neighbor L1 and are not L1
      forall(const IntIntDouPair& p, (*cellAtlasConfigAttr)[0].wallArea){
        if(m1->parents()[p.first.first] == 1 and m1->parents()[p.first.second] != 1 and p.second>threshold)
          m1->parents()[p.first.second] = 2;
      }
      // L3 is all cells that neighbor L2 and are not L1 or L2
      forall(const IntIntDouPair& p, (*cellAtlasConfigAttr)[0].wallArea){
        if(m1->parents()[p.first.first] == 2 and m1->parents()[p.first.second] != 1 and m1->parents()[p.first.second] != 2 and p.second>threshold)
          m1->parents()[p.first.second] = 3;
      }

    } else { // cone method: reuse the label -> included cells map computed above
      // L2: non-L1 cells whose cone contains only L1 cells
      for(int i = 0; i < numCells; i++){
        int currentLabel = (*cellAtlasAttr)[i].cellLabel;
        if(m1->parents()[currentLabel] == 1) continue;

        // bind by reference: no per-cell copy of the included-cell list
        const std::vector<int>& included = includedCells[currentLabel];

        bool isL2 = true;
        forall(int otherLabel, included)
          if(m1->parents()[otherLabel] != 1) isL2 = false;

        if(isL2)
          m1->parents()[currentLabel] = 2;
      }

      // L3: same as L2, but L2 cells in the cone are allowed as well
      for(int i = 0; i < numCells; i++){
        int currentLabel = (*cellAtlasAttr)[i].cellLabel;
        if(m1->parents()[currentLabel] == 1 or m1->parents()[currentLabel] == 2) continue;

        const std::vector<int>& included = includedCells[currentLabel];

        bool isL3 = true;
        forall(int otherLabel, included)
          if(m1->parents()[otherLabel] != 1 and m1->parents()[otherLabel] != 2) isL3 = false;

        if(isL3)
          m1->parents()[currentLabel] = 3;
      }
    }
    m1->updateTriangles();
    m1->setUseParents(true);
    return true;
  }
  REGISTER_PROCESS(DetectLayers);

// Detect an arbitrary number of cell layers from the Cell Atlas data.
// L1 (parent 1) is determined with the cone test: no other cell centroid inside the cell's cone.
// Each further layer n (parent n, up to `layers`) consists of the not-yet-labelled cells that share
// a wall of area > minWall with a cell of layer n-1. Only wallArea pairs whose FIRST label is the
// layer n-1 cell are considered.
// Returns true without doing anything if layers < 1. useSurface and minVolume are currently unused.
  bool DetectCellLayers::run(Mesh *m1, bool useSurface, double coneAngle, int layers, double minWall, double minVolume)
  {
    if(layers < 1) return true;

    CellAtlasAttr *cellAtlasAttr = &m1->attributes().attrMap<int, CellAtlasData>("Cell Atlas 3D Data");
    CellAtlasConfigAttr& cellAtlasConfigAttr = m1->attributes().attrMap<int, CellAtlasConfig>("Cell Atlas 3D Config");

    int numCells = (*cellAtlasAttr).size();

    // Validate BEFORE touching cellAtlasConfigAttr[0]. With std::map-like semantics, operator[] on an
    // empty map would insert a default entry, which would defeat this emptiness check.
    if(numCells == 0 or cellAtlasConfigAttr.empty()){
      return setErrorMessage("The Cell Atlas Data Attribute Map is empty. Please run Analyze Cells first.");
    }

    // bind by reference instead of copying the whole wall-area map
    const std::map<IntInt, double>& wallArea = cellAtlasConfigAttr[0].wallArea;

    progressStart("Detecting Layers", 0);
    m1->parents().clear();

    std::map<int, std::vector<int> > includedCells;
    getIncludedCellCenters(cellAtlasAttr, coneAngle, includedCells);

    // detect L1
    for(int i = 0; i < numCells; i++){
      int currentLabel = (*cellAtlasAttr)[i].cellLabel;
      if(includedCells[currentLabel].empty())
        m1->parents()[currentLabel] = 1;
    }

    int parentLabel = 1;

    while(layers > 1){
      layers--;
      parentLabel++;

      // collect first, assign afterwards, so that cells labelled in this pass cannot propagate
      // the layer further within the same pass
      std::set<int> neighbors;

      for(int i = 0; i < numCells; i++){
        int currentLabel = (*cellAtlasAttr)[i].cellLabel;
        if(m1->parents()[currentLabel] != parentLabel-1) continue;
        // go through neighbors of cells in the previous layer and collect unlabelled ones
        forall(IntIntDouPair p, wallArea){
          if(p.first.first == currentLabel and p.second > minWall and m1->parents()[p.first.second] < 1){
            neighbors.insert(p.first.second);
          }
        }
      }

      forall(int l, neighbors){
        m1->parents()[l] = parentLabel;
      }
    }

    m1->updateTriangles();
    m1->setUseParents(true);
    return true;
  }
  REGISTER_PROCESS(DetectCellLayers);


// Detect cell layers like DetectCellLayers, but based on the generic measure attribute maps
// "Measure Label Vector CellCentroids", "Measure Label Vector SurfacePoint" and "Shared Wall Areas"
// instead of the Cell Atlas data.
// L1 (parent 1): cells with no other centroid inside their cone towards the nearest surface point.
// Layer n (parent n, up to `layers`): not-yet-labelled cells sharing a wall of area > minWall with a
// layer n-1 cell. Only wall pairs whose FIRST label is the layer n-1 cell are considered.
// useSurface and minVolume are currently unused.
  bool DetectCellLayers2::run(Mesh *m1, bool useSurface, double coneAngle, int layers, double minWall, double minVolume)
  {
    if(layers < 1) return setErrorMessage("Layers must be >=1");

    AttrMap<int, Point3d>& centroids = m1->attributes().attrMap<int, Point3d>("Measure Label Vector CellCentroids");
    AttrMap<int, Point3d>& nearestSurfacePoint = m1->attributes().attrMap<int, Point3d>("Measure Label Vector SurfacePoint");
    AttrMap<IntInt, double>& wallArea = m1->attributes().attrMap<IntInt, double>("Shared Wall Areas");

    // The cone test needs both maps. Without surface points every cone would be anchored at
    // whatever operator[] yields for a missing entry.
    if(centroids.empty() or nearestSurfacePoint.empty()){
      return setErrorMessage("Please run Cell Analysis 3D first.");
    }

    progressStart("Detecting Layers", 0);
    m1->parents().clear();

    // detect L1
    forall(auto c, centroids){
      int currentLabel = c.first;
      if(getIncludedCellCenters(currentLabel, coneAngle, centroids, nearestSurfacePoint).empty())
        m1->parents()[currentLabel] = 1;
    }

    int parentLabel = 1;

    while(layers > 1){
      layers--;
      parentLabel++;

      // collect first, assign afterwards, so a layer only grows by one neighbour step per pass
      std::set<int> neighbors;

      forall(auto c, centroids){
        int currentLabel = c.first;
        if(m1->parents()[currentLabel] != parentLabel-1) continue;
        // go through neighbors and collect unlabelled ones
        forall(IntIntDouPair w, wallArea){
          if(w.first.first == currentLabel and w.second > minWall and m1->parents()[w.first.second] < 1){
            neighbors.insert(w.first.second);
          }
        }
      }

      forall(int l, neighbors){
        m1->parents()[l] = parentLabel;
      }
    }

    m1->updateTriangles();
    m1->setUseParents(true);
    return true;
  }
  REGISTER_PROCESS(DetectCellLayers2);



  // Mark the meristem region around the single selected (top) cell. The selected label is stored
  // as labelFirstCell in the Cell Atlas config. Expects layer labels (parents 1/2/3, e.g. from
  // DetectLayers) to be present.
  //
  // The reference length is the mean distance from the top cell to its (up to 4) nearest L3 cells.
  // disDown and radius are scaled by it. A sphere of the scaled radius is centred below the top
  // cell, along the mean direction towards those L3 cells, at distance disDown + radius (both scaled).
  // Resulting parent labels:
  //   unlabelled (0) -> 3;  inside the sphere -> 4;
  //   within the scaled radius of the top cell: 1 -> 5, 2 -> 6, 3 -> 7 (a 4 stays 4).
  bool MarkMeristem::run(Mesh *m1, double disDown, double radius)
  {
    CellAtlasAttr *cellAtlasAttr = &m1->attributes().attrMap<int, CellAtlasData>("Cell Atlas 3D Data");
    CellAtlasConfigAttr *cellAtlasConfigAttr = &m1->attributes().attrMap<int, CellAtlasConfig>("Cell Atlas 3D Config");

    int numCells = (*cellAtlasAttr).size();

    if(numCells == 0 or (*cellAtlasConfigAttr).size() == 0){
      return setErrorMessage("The Cell Atlas Data Attribute Map is empty. Please run Analyze Cells first.");
    }

    vvGraph& m1Graph = m1->graph();

    // find selected top cell
    progressStart("Label Meristem", 0);

    int selectedLabel1 = findSelectedLabel(m1Graph);

    cout << "Selected Cell: " << selectedLabel1 << endl;
    if(selectedLabel1 == 0){
      setErrorMessage("Warning: No top cell selected");
      return false;
    }
    int labelFirstCell = selectedLabel1;
    (*cellAtlasConfigAttr)[0].labelFirstCell = labelFirstCell;

    int idxSelectedLabel1 = (*cellAtlasConfigAttr)[0].labelIdxMap[labelFirstCell];
    Point3d topCentroid = (*cellAtlasAttr)[idxSelectedLabel1].centroid;

    // L3 cells ordered by distance to the top cell. The value is the attribute-map INDEX, not the
    // label, because cellAtlasAttr is indexed 0..numCells-1. A multimap keeps cells at equal
    // distance instead of overwriting them.
    std::multimap<double, int> distanceIdxMap;
    for(int i = 0; i < numCells; i++){
      int currentLabel = (*cellAtlasAttr)[i].cellLabel;
      if(m1->parents()[currentLabel] == 3)
        distanceIdxMap.emplace(norm(topCentroid - (*cellAtlasAttr)[i].centroid), i);
    }

    // average distance and direction over the (up to) nrOfCellForAvg closest L3 cells
    const int nrOfCellForAvg = 4;
    int counter = 0;
    double meanDistance = 0;
    Point3d dirDown(0.0, 0.0, 0.0);

    typedef std::pair<double, int> DoubleInt;
    forall(const DoubleInt& p, distanceIdxMap){
      if(counter >= nrOfCellForAvg) break;
      meanDistance += p.first;
      dirDown += (*cellAtlasAttr)[p.second].centroid - topCentroid;
      counter++;
    }

    // without L3 cells there is no reference length or direction (normalising would divide by 0)
    if(counter == 0)
      return setErrorMessage("No L3 cells found. Please run Detect Layers first.");

    meanDistance /= counter; // average over the cells actually used, not a fixed 4
    dirDown /= norm(dirDown);

    double disDownWeighted = disDown*meanDistance;
    double radiusWeighted = radius*meanDistance;

    // kept in double precision, consistent with the top-cell distance below
    Point3d orgCenter = topCentroid + dirDown * (disDownWeighted + radiusWeighted);

    for(int i = 0; i < numCells; i++){
      int currentLabel = (*cellAtlasAttr)[i].cellLabel;

      if(m1->parents()[currentLabel] == 0)
        m1->parents()[currentLabel] = 3;

      double currentDisOrg = norm(orgCenter - (*cellAtlasAttr)[i].centroid);
      double currentDisTop = norm(topCentroid - (*cellAtlasAttr)[i].centroid);
      if(currentDisOrg <= radiusWeighted){
        m1->parents()[currentLabel] = 4;
      }
      if(currentDisTop < radiusWeighted){
        if(m1->parents()[currentLabel] == 1)
          m1->parents()[currentLabel] = 5;
        if(m1->parents()[currentLabel] == 2)
          m1->parents()[currentLabel] = 6;
        if(m1->parents()[currentLabel] == 3)
          m1->parents()[currentLabel] = 7;
      }
    }

    m1->updateTriangles();
    m1->setUseParents(true);

    return true;
  }
  REGISTER_PROCESS(MarkMeristem);


// Convenience pipeline: AnalyzeMeristem -> DetectLayers (cone method) -> MarkMeristem.
// Stops and returns false as soon as one step fails, instead of continuing on incomplete data
// and reporting success.
// NOTE(review): confirm that an error message set by a sub-process object reaches this process.
bool LabelMeristem::run(Mesh *m1, Mesh *m2, double minVolume, double disDown, double radius, double coneAngle)
  {
    AnalyzeMeristem am(*this);
    if(!am.run(m1, m2, minVolume))
      return false;

    DetectLayers dl(*this);
    if(!dl.run(m1, coneAngle, false))
      return false;

    MarkMeristem mm(*this);
    if(!mm.run(m1, disDown, radius))
      return false;

    return true;
  }
  REGISTER_PROCESS(LabelMeristem);


 // Return the last decimal digit of numberToTest as a value in [0, 9]. For negative inputs this is
 // the non-negative remainder modulo 10 (e.g. -3 -> 7), as the previous search loop produced.
 // It is now computed in closed form, so inputs close to INT_MIN no longer cause a signed overflow
 // (the loop evaluated numberToTest - last).
 int lastDigit(int numberToTest)
{
  int remainder = numberToTest % 10; // in (-10, 10); has the sign of numberToTest since C++11
  if(remainder < 0)
    remainder += 10;
  return remainder;
}


// Mark a primordium and its boundary to the meristem.
// The two reference cells come from the two selected cells, or from labelPrim / labelSaddle if not
// both are selected. Of the two, the one farther from the meristem top cell (labelFirstCell,
// stored by Mark Meristem) is used as the primordium centre and the other as the saddle.
// For each cell, ratio = d(meristem) / d(primordium) is compared with the saddle's ratio ratioMP:
//   - boundary cells get parent 9 + primLabel,
//   - primordium-side cells get 11/12/13 + primLabel, depending on the last digit of their current
//     parent (1 -> 11, 2 -> 12, anything else -> 13),
//   - other meristem-side cells keep their parent.
// primLabel is 0 if primLabelSame, otherwise 2000 * (incremented primordium counter).
// NOTE(review): with disLowerBound = ratioParm * ratioMP, boundary cells can only appear on the
// meristem side if ratioParm < 1 and only on the primordium side if ratioParm > 1 - confirm intent.
bool MarkPrim::run(Mesh *m1, double ratioParm, double absDisParm, bool primLabelSame, int labelSaddle, int labelPrim)
  {
    CellAtlasAttr *cellAtlasAttr = &m1->attributes().attrMap<int, CellAtlasData>("Cell Atlas 3D Data");
    CellAtlasConfigAttr *cellAtlasConfigAttr = &m1->attributes().attrMap<int, CellAtlasConfig>("Cell Atlas 3D Config");

    int numCells = (*cellAtlasAttr).size();

    if(numCells == 0 or (*cellAtlasConfigAttr).size() == 0){
      return setErrorMessage("The Cell Atlas Data Attribute Map is empty. Please run Analyze Cells first.");
    }

    vvGraph& m1Graph = m1->graph();

    // find selected cells
    progressStart("Mark Meristem - Find Selected Cell", 0);

    int selectedLabel1 = 0;
    int selectedLabel2 = 0;
    findTwoSelectedLabels(m1Graph, selectedLabel1, selectedLabel2);

    cout << "Selected Cells: " << selectedLabel1 << " and " << selectedLabel2 << endl;
    const bool bothSelected = (selectedLabel1 != 0 and selectedLabel2 != 0);
    if(!bothSelected and (labelSaddle == 0 or labelPrim == 0)){
      setErrorMessage("Warning: No cells selected and specified. Please select or specifiy the saddle and primordium center cells.");
      return false;
    }

    // The meristem reference cell is stored by Mark Meristem. Without it, labelIdxMap[0] would
    // silently pick an arbitrary reference cell. Assumes 0 means "not set" - TODO confirm the
    // config's default.
    int labelFirstCell = (*cellAtlasConfigAttr)[0].labelFirstCell;
    if(labelFirstCell == 0){
      setErrorMessage("Error: This process requires the Mark Meristem process.");
      return false;
    }

    // selection takes precedence over the labels given as parameters
    const int primCandidate = bothSelected ? selectedLabel1 : labelPrim;
    const int saddleCandidate = bothSelected ? selectedLabel2 : labelSaddle;
    int idxSelectedLabel1 = (*cellAtlasConfigAttr)[0].labelIdxMap[primCandidate];
    int idxSelectedLabel2 = (*cellAtlasConfigAttr)[0].labelIdxMap[saddleCandidate];

    Point3d primCoord = (*cellAtlasAttr)[idxSelectedLabel1].centroid;
    Point3d saddleCoord = (*cellAtlasAttr)[idxSelectedLabel2].centroid;

    int primLabel = 0;
    if(!primLabelSame){
      (*cellAtlasConfigAttr)[0].primCounter++;
      primLabel = 2000*(*cellAtlasConfigAttr)[0].primCounter;
    }

    int idxFirstCell = (*cellAtlasConfigAttr)[0].labelIdxMap[labelFirstCell];
    Point3d meriCoord = (*cellAtlasAttr)[idxFirstCell].centroid;

    // the reference cell farther away from the meristem is the primordium centre
    if(norm(primCoord - meriCoord) < norm(saddleCoord - meriCoord)){
      Point3d tmp = primCoord;
      primCoord = saddleCoord;
      saddleCoord = tmp;
    }

    double disMeristemSaddle = norm(meriCoord - saddleCoord);
    double disPrimSaddle = norm(primCoord - saddleCoord);
    double ratioMP = disMeristemSaddle / disPrimSaddle;

    double disLowerBound = ratioParm * ratioMP;

    for(int i = 0; i < numCells; i++){
      int currentLabel = (*cellAtlasAttr)[i].cellLabel;

      double currentDisM = norm(meriCoord - (*cellAtlasAttr)[i].centroid);
      double currentDisP = norm(primCoord - (*cellAtlasAttr)[i].centroid);
      double currentRatio = currentDisM / currentDisP;

      if(currentRatio < ratioMP){ // meristem side
        if(currentRatio > disLowerBound and currentDisM < absDisParm * disMeristemSaddle)
          m1->parents()[currentLabel] = 9 + primLabel; // boundary
        // otherwise: meristem, the label stays as it is
      } else { // primordium side
        if(currentRatio < disLowerBound and currentDisP < absDisParm * disPrimSaddle)
          m1->parents()[currentLabel] = 9 + primLabel; // boundary
        else if(lastDigit(m1->parents()[currentLabel]) == 1)
          m1->parents()[currentLabel] = 11 + primLabel;
        else if(lastDigit(m1->parents()[currentLabel]) == 2)
          m1->parents()[currentLabel] = 12 + primLabel;
        else
          m1->parents()[currentLabel] = 13 + primLabel;
      }
    }

    m1->updateTriangles();
    m1->setUseParents(true);

    return true;
  }
  REGISTER_PROCESS(MarkPrim);

}
