//
// This file is part of MorphoGraphX - http://www.MorphoGraphX.org
// Copyright (C) 2012-2015 Richard S. Smith and collaborators.
//
// If you use MorphoGraphX in your work, please cite:
//   http://dx.doi.org/10.7554/eLife.05864
//
// MorphoGraphX is free software, and is licensed under the terms of the
// GNU General Public License (GPL) version 2.0, http://www.gnu.org/licenses.
// 
#include "MeshProcessCreation.hpp"

#include "Progress.hpp"
#include "SystemProcessLoad.hpp"
#include "BoundingBox.hpp"
#include "Geometry.hpp"
#include "Information.hpp"
#include "Polygonizer.hpp"

#include <algorithm>
#include <SetVector.hpp>

#include <unistd.h>
#include "Tie.hpp"

// How many marching-cubes surface evaluations pass between two polls of the
// progress bar (see evalCheckProgress below).
static const int EVAL_PROGRESS = 10000;

namespace mgx 
{
  
  // Fixed-size vertex tuples consumed by makeNhbd(): element 0 is the vertex
  // whose neighbourhood is built, elements 1..n-1 are its neighbours in the
  // order they will appear in that neighbourhood.
  typedef Vector<3, vertex> Vtx3;
  typedef Vector<4, vertex> Vtx4;
  typedef Vector<5, vertex> Vtx5;
  typedef Vector<7, vertex> Vtx7;
  
  // Throttled cancellation check for the marching-cubes evaluators.
  // Most calls return true immediately; once the call counter has passed
  // EVAL_PROGRESS it is reset and the progress bar is polled. Returns false
  // only when that poll (progressAdvance) reports failure.
  bool evalCheckProgress()
  {
    static int callsSinceCheck = 0;
    const bool pollDue = (callsSinceCheck++ >= EVAL_PROGRESS);
    if(!pollDue)
      return true;

    callsSinceCheck = 0;
    if(progressAdvance(0))
      return true;
    return false;
  }
  
  // Builds the neighbourhood of V[0] in graph S from the remaining tuple
  // entries: V[1] is connected with insertEdge, then every following entry is
  // spliced in directly after its predecessor, so V[0]'s neighbours end up in
  // the order V[1], V[2], ..., V[n-1].
  template <size_t n> void makeNhbd(vvGraph& S, Vector<n, vertex> V)
  {
    S.insertEdge(V[0], V[1]);
    for(size_t k = 1; k + 1 < n; ++k)
      S.spliceAfter(V[0], V[k], V[k + 1]);
  }
  
  // 3D surface evaluation function for marching cubes
  int getStackData(const Stack* stack, const Store* store, Point3f& p)
  {
    Point3f pf = stack->worldToImagef(p);
    // Be careful of float->int conversion when we are below 0
    for(int i = 0; i < 3; i++)
      if(pf[i] < 0)
        pf[i] -= 1;
    int x = int(pf.x());
    int y = int(pf.y());
    int z = int(pf.z());
  
    // Decide which stack to use
    const HVecUS& data = store->data();
    if(!stack->boundsOK(x, y, z))
      return (0);
    else
      return (data[stack->offset(x, y, z)]);
  }
  
  // Inside/outside test for threshold marching: p is inside when the voxel
  // containing it has a value of at least _threshold. Each call also feeds the
  // throttled cancellation check and calls userCancel() when it fails.
  bool MarchingCubeSurface::eval(Point3f p)
  {
    if(!evalCheckProgress())
      userCancel();
    const int voxel = getStackData(_stack, _store, p);
    return voxel >= _threshold;
  }
  
  // Threshold segmentation: resets `mesh` and builds one triangle surface
  // around all voxels of `store` whose value is >= threshold (a threshold
  // <= 0 is raised to 1), running marching cubes with the given cube size
  // over the bounding box of those voxels.
  // Returns false if no voxel reaches the threshold or if the triangles
  // cannot be added to the mesh.
  bool MarchingCubeSurface::run(Mesh* mesh, const Store* store, float cubeSize, int threshold)
  {
    // Store threshold for eval; 0 would also count the background as inside
    if(threshold <= 0)
      threshold = 1;
    _threshold = threshold;

    // Setup progress bar
    progressStart(QString("Marching Cubes Surface on Stack %1").arg(mesh->userId()), 0);

    // Clear mesh
    vvGraph& S = mesh->graph();
    mesh->reset();

    // Stack and data used by eval()
    _store = store;
    _stack = store->stack();
    const HVecUS& data = store->data();
    Point3u size = _stack->size();

    // Count "inside" voxels and grow their bounding box (image coordinates)
    long pixCount = 0;
    BoundingBox3u bBox;
    progressStart(QString("Finding object bounds-%1").arg(mesh->userId()), size.z());
    const ushort* pdata = data.data();
    for(uint z = 0; z < size.z(); z++) {
      if(!progressAdvance(z))
        userCancel();
      for(uint y = 0; y < size.y(); y++) {
        for(uint x = 0; x < size.x(); x++, ++pdata) {
          if(*pdata < threshold)
            continue;
          pixCount++;

          // The first inside voxel seeds the box, later ones extend it
          if(pixCount == 1)
            bBox = BoundingBox3u(Point3u(x, y, z));
          else
            bBox |= Point3u(x, y, z);
        }
      }
    }

    // Without any inside voxel bBox was never seeded, so there is nothing
    // meaningful to march over; leave the (already reset) mesh empty.
    if(pixCount == 0) {
      mesh->updateAll();
      setErrorMessage("Error, no voxels at or above the threshold");
      return false;
    }

    // Polygonizer object; it fills vertList/triList during march()
    std::vector<Point3f> vertList;
    std::vector<Point3i> triList;
    Polygonizer pol(*this, cubeSize, _stack->step(), vertList, triList);

    // Get bounding box in world coordinates
    BoundingBox3f bBoxWorld = _stack->imageToWorld(bBox);

    SETSTATUS("Computing segmentation for threshold " << threshold << " with bBox of: " << bBoxWorld[0] 
      << " - " << bBoxWorld[1] << " cube size:" << cubeSize << " pixel count:" << pixCount);

    // Call polygonizer
    triList.clear();
    Point3f marchBBox[2] = { bBoxWorld.pmin(), bBoxWorld.pmax() };
    pol.march(marchBBox);

    // One mesh vertex per polygonizer point; saveId is the point's index
    std::vector<vertex> vertices;
    vertices.resize(vertList.size());
    for(size_t i = 0; i < vertList.size(); ++i) {
      vertex v;
      v->pos = Point3d(vertList[i]);
      v->saveId = i;
      v->label = 0;
      v->signal = 1.0;
      vertices[i] = v;
    }

    // Insert vertices and triangles into mesh
    if(!Mesh::meshFromTriangles(S, vertices, triList)) {
      setErrorMessage("Error, cannot add all the triangles");
      return false;
    }

    SETSTATUS("Triangles added:" << triList.size() << ", vertices:" << S.size());

    mesh->signalBounds() = Point2f(0.0, 1.0);
    mesh->updateAll();
    mesh->setShowLabel("Label");

    return true;
  }
  REGISTER_PROCESS(MarchingCubeSurface);
  
  // Inside/outside test for label marching: p is inside exactly when the voxel
  // containing it carries the label currently being meshed (_label). Each call
  // also feeds the throttled cancellation check and calls userCancel() when it
  // fails.
  bool MarchingCube3D::eval(Point3f p)
  {
    if(!evalCheckProgress())
      userCancel();
    const int voxel = getStackData(_stack, _store, p);
    return voxel == _label;
  }
  
  // Label segmentation: runs marching cubes once per label of the labelled
  // stack in `store`, adding one surface per label to `mesh` (its vertices
  // carry that label). Labels with fewer than minVoxels voxels are skipped
  // (minVoxels <= 0 is treated as 1).
  // If singleLabel > 0 only that label is (re)meshed: vertices with that label,
  // plus those mesh->getConnectedVertices() adds, are erased first and the rest
  // of the mesh is kept. Otherwise the mesh is reset, and afterwards `smooth`
  // passes move every vertex to the average of its neighbours.
  // Returns false if singleLabel cannot be a 16-bit voxel value or if the
  // triangles of a label cannot be added to the mesh.
  bool MarchingCube3D::run(Mesh* mesh, const Store* store, float cubeSize, int minVoxels, int smooth,
                                  int singleLabel)
  {
    // Voxel values are 16-bit, so the per-label tables need one slot per value
    const int labelSlots = 65536;

    const bool single = singleLabel > 0;

    // A label no voxel can hold would index past the per-label tables below;
    // reject it before the mesh is modified
    if(single and singleLabel >= labelSlots) {
      setErrorMessage("Error, label out of range for a 16-bit stack");
      return false;
    }

    // Setup progress bar
    progressStart(QString("Marching Cubes 3D on Stack %1").arg(mesh->userId()), 0);

    // Clear mesh
    vvGraph& S = mesh->graph();
    vvGraph D;

    // If re-marching a single label, delete the old mesh part, otherwise reset it
    if(single) {
      forall(const vertex& v, S)
        if(v->label == singleLabel)
          D.insert(v);
      mesh->getConnectedVertices(S, D);
      forall(const vertex& v, D)
        S.erase(v);
    } else
      mesh->reset();

    // Stack and data used by eval()
    _store = store;
    _stack = store->stack();
    const HVecUS& data = store->data();
    Point3u size = _stack->size();

    // Count voxels per label and grow each label's bounding box (image coordinates)
    if(minVoxels <= 0)
      minVoxels = 1;
    std::vector<long> pixCnt(labelSlots, 0);
    std::vector<BoundingBox3u> bBoxImg(labelSlots, BoundingBox3u());
    const ushort* pdata = data.data();
    for(uint z = 0; z < size.z(); z++) {
      for(uint y = 0; y < size.y(); y++) {
        for(uint x = 0; x < size.x(); x++, ++pdata) {
          int label = *pdata;
          if(single and label != singleLabel)
            continue;
          pixCnt[label]++;

          // The first voxel of a label seeds its box, later ones extend it
          if(pixCnt[label] == 1)
            bBoxImg[label] = BoundingBox3u(Point3u(x, y, z));
          else
            bBoxImg[label] |= Point3u(x, y, z);
        }
      }
    }

    // Polygonizer object; it fills vertList/triList during march()
    std::vector<Point3f> vertList;
    std::vector<Point3i> triList;
    Polygonizer pol(*this, cubeSize, _stack->step(), vertList, triList);

    // Call marching cubes for each label (label 0 is background)
    int maxLabel = 0;
    int startLabel = 1;
    int endLabel = labelSlots - 1;
    if(single) {
      startLabel = singleLabel;
      endLabel = singleLabel;
    }

    for(int label = startLabel; label <= endLabel; ++label) {
      if(pixCnt[label] < minVoxels)
        continue;
      if(label > maxLabel)
        maxLabel = label;

      // Get bounding box in world coordinates
      BoundingBox3f bBox = _stack->imageToWorld(bBoxImg[label]);
      Point3f bBoxWorld[2] = { bBox.pmin(), bBox.pmax() };

      SETSTATUS("Computing segmentation for label " << label << " with bbox of: " << bBoxWorld[0] << " - "
                                                    << bBoxWorld[1] << " cube size:" << cubeSize
                                                    << " pixels:" << pixCnt[label]);

      // Store label for eval
      _label = label;

      // Call polygonizer
      triList.clear();
      pol.march(bBoxWorld);

      // Create a mesh vertex for every polygonizer point this label's
      // triangles use, and renumber the triangles to index `vertices`.
      // saveId keeps the original polygonizer index.
      std::vector<vertex> vertices;
      std::map<int, int> vertMap;
      std::map<int, int>::iterator vertIter;
      forall(Point3i& tri, triList) {
        for(int k = 0; k < 3; k++) {
          if((vertIter = vertMap.find(tri[k])) != vertMap.end()) {
            tri[k] = vertIter->second;
          } else {
            vertex v;
            v->pos = Point3d(vertList[tri[k]]);
            v->saveId = tri[k];
            v->label = label;
            v->signal = 1.0;
            vertMap[tri[k]] = vertices.size();
            tri[k] = vertices.size();
            vertices.push_back(v);
          }
        }
      }

      // Insert vertices and triangles into mesh
      if(!Mesh::meshFromTriangles(S, vertices, triList)) {
        setErrorMessage("Error, cannot add all the triangles");
        return false;
      }

      SETSTATUS("Label:" << label << ", triangles added:" << triList.size() << ", vertices:" << S.size());
    }

    // Smooth points but keep them attached: each pass moves every vertex to
    // the average of its neighbours, all averages being taken from the
    // positions before the pass. New positions are stored in graph iteration
    // order, so no per-vertex key (e.g. saveId) has to be unique, and a vertex
    // without neighbours keeps its position instead of becoming 0/0.
    if(!single) {
      for(int pass = 0; pass < smooth; pass++) {
        std::vector<Point3d> newPos;
        newPos.reserve(S.size());
        forall(const vertex& v, S) {
          Point3d sum(0.0, 0.0, 0.0);
          int count = 0;
          forall(const vertex& n, S.neighbors(v)) {
            sum += n->pos;
            count++;
          }
          if(count > 0)
            newPos.push_back(sum / count);
          else
            newPos.push_back(v->pos);
        }
        size_t idx = 0;
        forall(const vertex& v, S)
          v->pos = newPos[idx++];
      }
    }

    // Update next label if required
    if(maxLabel > mesh->viewLabel())
      mesh->setLabel(maxLabel + 1);

    mesh->signalBounds() = Point2f(0.0, 1.0);
    mesh->updateAll();
    mesh->setShowLabel("Label");

    return true;
  }
  REGISTER_PROCESS(MarchingCube3D);
  
  // Converts the active cutting surface into a mesh. The uSize x vSize grid of
  // surface points becomes a quad mesh, each quad is split into four triangles
  // around a new centre vertex, and the result replaces mesh's graph.
  // Returns false if there is no active cutting surface or it supplies fewer
  // points than its grid needs. Returns true without touching the mesh when
  // the grid is degenerate (no points, or fewer than 2 rows/columns).
  bool CuttingSurfMesh::run(Mesh* mesh)
  {
    CuttingSurface* cutSurf = cuttingSurface();
    if(!cutSurf) {
      setErrorMessage("You need an active cutting surface to convert it into a mesh");
      return false;
    }
    std::vector<Point3f> points;
    int uSize = 0, vSize = 0;
    cutSurf->getSurfPoints(&mesh->stack()->frame(), points, uSize, vSize);
    if(points.empty() or uSize < 2 or vSize < 2)
      return true;

    // Points are read as points[u * vSize + v] below, so the list must cover
    // the whole grid
    if(points.size() < size_t(uSize) * size_t(vSize)) {
      setErrorMessage("Error, cutting surface returned too few points");
      return false;
    }

    // Create a new graph
    vvGraph S;

    std::vector<vertex> empty_vector(vSize, vertex(0));   // VC++ bug requires this
    std::vector<std::vector<vertex> > vtx(uSize, empty_vector);

    // Create the vertices and fill in the position
    for(int u = 0; u < uSize; u++) {
      for(int v = 0; v < vSize; v++) {
        vertex a;
        S.insert(a);
        vtx[u][v] = a;
        a->pos = Point3d(points[u * vSize + v]);
        a->signal = 1.0;
      }
    }

    // Connect neighborhoods. Interior vertices list neighbours in the order
    // u+1, v+1, u-1, v-1; border vertices use the same cyclic order with the
    // missing neighbours left out.
    for(int u = 0; u < uSize; u++)
      for(int v = 0; v < vSize; v++) {
        if(u == 0 and v == vSize - 1)       // top left corner;
          makeNhbd(S, Vtx3(vtx[u][v], vtx[u][v - 1], vtx[u + 1][v]));
        else if(u == 0 and v == 0)       // bottom left corner
          makeNhbd(S, Vtx3(vtx[u][v], vtx[u + 1][v], vtx[u][v + 1]));
        else if(u == uSize - 1 and v == 0)       // bottom right corner
          makeNhbd(S, Vtx3(vtx[u][v], vtx[u][v + 1], vtx[u - 1][v]));
        else if(u == uSize - 1 and v == vSize - 1)       // top right corner
          makeNhbd(S, Vtx3(vtx[u][v], vtx[u - 1][v], vtx[u][v - 1]));
        else if(v == vSize - 1)       // top edge
          makeNhbd(S, Vtx4(vtx[u][v], vtx[u - 1][v], vtx[u][v - 1], vtx[u + 1][v]));
        else if(u == 0)       // left edge
          makeNhbd(S, Vtx4(vtx[u][v], vtx[u][v - 1], vtx[u + 1][v], vtx[u][v + 1]));
        else if(v == 0)       // bottom edge
          makeNhbd(S, Vtx4(vtx[u][v], vtx[u + 1][v], vtx[u][v + 1], vtx[u - 1][v]));
        else if(u == uSize - 1)       // right edge
          makeNhbd(S, Vtx4(vtx[u][v], vtx[u][v + 1], vtx[u - 1][v], vtx[u][v - 1]));
        else       // Interior vertex
          makeNhbd(S, Vtx5(vtx[u][v], vtx[u + 1][v], vtx[u][v + 1], vtx[u - 1][v], vtx[u][v - 1]));
      }

    // At this point we have a quad mesh, triangulate it (nicely).
    // Iterate over the unmodified copy T while S gains the centre vertices.
    vvGraph T = S;
    forall(const vertex& v, T)
      forall(const vertex& n, T.neighbors(v)) {
        vertex m = S.nextTo(v, n);
        vertex w = S.nextTo(m, v);
        // Quad v-n-w-m found; the ordering tests accept only one of the
        // quad's four rotations, so each quad is split exactly once
        if(w == S.prevTo(n, v) and v > w and n > m) {
          vertex a;
          a->pos = (v->pos + w->pos + n->pos + m->pos) / 4.0f;
          // Same signal as the grid vertices, so the mesh renders uniformly
          a->signal = 1.0;
          S.insert(a);
          makeNhbd(S, Vtx5(a, v, n, w, m));

          S.spliceAfter(v, n, a);
          S.spliceAfter(n, w, a);
          S.spliceAfter(w, m, a);
          S.spliceAfter(m, v, a);
        }
      }

    mesh->graph().swap(S);

    SETSTATUS("Mesh " << mesh->userId() << " made from cutting surface");

    mesh->setMeshType("MGXM");
    mesh->signalBounds() = Point2f(0.0, 1.0);
    mesh->setShowLabel("Label");
    mesh->updateAll();
    return true;
  }
  REGISTER_PROCESS(CuttingSurfMesh);

  // Builds a "voxel face" mesh from a labelled stack. For every non-zero voxel,
  // each of its six faces that borders a different value (another label, the
  // background, or the stack border) becomes two triangles. The triangles of
  // each label are assembled into one cell whose vertices carry the label, and
  // the result replaces mesh's graph. If a cell's triangles cannot be
  // assembled, that cell's vertices are flagged as selected.
  bool VoxelFaceMesh::run(Stack *stack, Store *store, Mesh *mesh)
  {
    // Create a new graph
    vvGraph S;

    // Triangles per label. A triangle stores three corner ids, where corner id
    // `off` is the padded-grid offset of the voxel whose low (x, y, z) corner it is.
    typedef Vector<3, ulong> Point3ul;
    typedef std::vector<Point3ul> TriVec;
    typedef std::map<int, TriVec> CellTriMap;
    // value_type is pair<const int, TriVec>; binding to it avoids copying each TriVec
    typedef CellTriMap::value_type CellEntry;
    typedef std::map<ulong, vertex> CornerVtxMap;
    typedef CornerVtxMap::value_type CornerEntry;
    CellTriMap cellTris;

    // Make a copy of the data padded with a one-voxel border of zeros, so
    // every voxel has six neighbours and border faces are generated too
    Point3u size = stack->size() + Point3u(2, 2, 2);
    HVecUS data(size.x() * size.y() * size.z());
    for(size_t i = 0; i < data.size(); ++i)
      data[i] = 0;

    for(uint z = 1; z < size.z() - 1; ++z)
      for(uint y = 1; y < size.y() - 1; ++y)
        for(uint x = 1; x < size.x() - 1; ++x)
          data[getOffset(x, y, z, size)] = store->data()[getOffset(x-1, y-1, z-1, stack->size())];

    // Loop through the (unpadded) voxels
    for(uint z = 1; z < size.z() - 1; ++z)
      for(uint y = 1; y < size.y() - 1; ++y)
        for(uint x = 1; x < size.x() - 1; ++x) {
          ulong off = getOffset(x, y, z, size);
          int label = data[off];
          // Skip background
          if(label == 0)
            continue;

          // Offsets for the six face neighbours
          ulong offx = getOffset(x-1, y, z, size);
          ulong offX = getOffset(x+1, y, z, size);
          ulong offy = getOffset(x, y-1, z, size);
          ulong offY = getOffset(x, y+1, z, size);
          ulong offz = getOffset(x, y, z-1, size);
          ulong offZ = getOffset(x, y, z+1, size);

          // Remaining corner ids of this voxel
          ulong offXY = getOffset(x+1, y+1, z, size);
          ulong offXZ = getOffset(x+1, y, z+1, size);
          ulong offYZ = getOffset(x, y+1, z+1, size);
          ulong offXYZ = getOffset(x+1, y+1, z+1, size);

          // Emit two triangles for every face bordering a different value.
          // (x, y, z stay inside the padding here, so no extra bounds tests.)
          if(label != data[offx]) {
            cellTris[label].push_back(Point3ul(off, offZ, offYZ));
            cellTris[label].push_back(Point3ul(off, offYZ, offY));
          }
          if(label != data[offX]) {
            cellTris[label].push_back(Point3ul(offX, offXYZ, offXZ));
            cellTris[label].push_back(Point3ul(offX, offXY, offXYZ));
          }
          if(label != data[offy]) {
            cellTris[label].push_back(Point3ul(off, offXZ, offZ));
            cellTris[label].push_back(Point3ul(off, offX, offXZ));
          }
          if(label != data[offY]) {
            cellTris[label].push_back(Point3ul(offY, offYZ, offXYZ));
            cellTris[label].push_back(Point3ul(offY, offXYZ, offXY));
          }
          if(label != data[offz]) {
            cellTris[label].push_back(Point3ul(off, offXY, offX));
            cellTris[label].push_back(Point3ul(off, offY, offXY));
          }
          if(label != data[offZ]) {
            cellTris[label].push_back(Point3ul(offZ, offXZ, offXYZ));
            cellTris[label].push_back(Point3ul(offZ, offXYZ, offYZ));
          }
        }

    // Loop though triangle lists and create cells
    int failedCells = 0;
    forall(const CellEntry &cell, cellTris) {
      const TriVec& tris = cell.second;

      // Create one vertex per distinct corner id used by this cell
      CornerVtxMap vMap;
      for(uint i = 0; i < tris.size(); ++i) {
        for(int j = 0; j < 3; ++j) {
          ulong off = tris[i][j];
          if(vMap.count(off) == 0) {
            vertex v;
            v->label = cell.first;
            Point3u img(off % size.x(), off / size.x() % size.y(), off / size.x() / size.y());
            // Undo the padding and move from the voxel position to its low
            // corner (presumably imageToWorld returns voxel centres; TODO confirm)
            v->pos = Point3d(stack->imageToWorld(img) - stack->step() * (3.0/2.0));
            vMap[off] = v;
          }
        }
      }

      // Number the vertices in map order; saveId is also the index into vVec
      std::vector<vertex> vVec;
      vVec.reserve(vMap.size());
      ulong saveId = 0;
      forall(const CornerEntry &ent, vMap) {
        ent.second->saveId = saveId++;
        vVec.push_back(ent.second);
      }

      // Create triangle list indexing vVec
      std::vector<Point3i> tVec;
      tVec.reserve(tris.size());
      for(uint i = 0; i < tris.size(); ++i)
        tVec.push_back(Point3i(vMap[tris[i].x()]->saveId, vMap[tris[i].y()]->saveId, vMap[tris[i].z()]->saveId));

      // Build the cell exactly once; on failure flag its vertices as selected
      if(!Mesh::meshFromTriangles(S, vVec, tVec)) {
        for(uint i = 0; i < vVec.size(); i++)
          vVec[i]->selected = true;
        ++failedCells;
      }
    }

    // Update the graph
    mesh->graph().swap(S);

    if(failedCells > 0)
      SETSTATUS("Warning, " << failedCells << " cell(s) could not be meshed, their vertices were marked selected");
    SETSTATUS("Mesh " << mesh->userId() << " created with " << mesh->graph().size() << " vertices.");

    mesh->setMeshType("MGXM");
    mesh->signalBounds() = Point2f(0.0, 1.0);
    mesh->updateAll();

    return true;
  }
  REGISTER_PROCESS(VoxelFaceMesh); 

  // Replaces mesh's graph with one small polyhedron per non-zero voxel of
  // `store`. Each polyhedron is centred on the voxel and has 14 vertices:
  //  - 6 on the axes at distance `radius` (top, bottom, left, right, near, far);
  //  - 8 cube corners at +/- radius/sqrt(3) per axis, which is also distance
  //    `radius` from the centre.
  // Every vertex carries the voxel value as its label.
  bool MeshFromLocalMaxima::run(Mesh* mesh, const Store* store, float radius)
  {
    const Stack* stack = mesh->stack();
    const HVecUS& data = store->data();
    Point3u dims = stack->size();

    vvGraph& S = mesh->graph();
    S.clear();

    // Per-axis offset of the cube corners (radius / sqrt(3)); loop invariant
    float radf = radius * pow(1.0 / 3.0, .5);

    for(uint z = 0; z < dims.z(); z++)
      for(uint y = 0; y < dims.y(); y++)
        for(uint x = 0; x < dims.x(); x++) {
          const uint label = data[stack->offset(x, y, z)];
          if(label == 0)
            continue;

          Point3d centre(stack->imageToWorld(Point3i(x, y, z)));

          // Axis vertices: top, bottom, left, right, near(front), far(back)
          vertex t, b, l, r, n, f;
          // Cube corners, named near/far + top/bottom + left/right
          vertex ntl, nbl, nbr, ntr, ftl, fbl, fbr, ftr;

          // Insert all 14 vertices (axis vertices first) and tag them with the label
          vertex all[14] = { t, b, l, r, n, f, ntl, nbl, nbr, ntr, ftl, fbl, fbr, ftr };
          for(int k = 0; k < 14; ++k) {
            S.insert(all[k]);
            all[k]->label = label;
          }

          t->pos = centre + Point3d(0.0f, radius, 0.0f);
          b->pos = centre + Point3d(0.0f, -radius, 0.0f);
          l->pos = centre + Point3d(-radius, 0.0f, 0.0f);
          r->pos = centre + Point3d(radius, 0.0f, 0.0f);
          n->pos = centre + Point3d(0.0f, 0.0f, radius);
          f->pos = centre + Point3d(0.0f, 0.0f, -radius);

          ntl->pos = centre + Point3d(-radf, radf, radf);
          nbl->pos = centre + Point3d(-radf, -radf, radf);
          nbr->pos = centre + Point3d(radf, -radf, radf);
          ntr->pos = centre + Point3d(radf, radf, radf);
          ftl->pos = centre + Point3d(-radf, radf, -radf);
          fbl->pos = centre + Point3d(-radf, -radf, -radf);
          fbr->pos = centre + Point3d(radf, -radf, -radf);
          ftr->pos = centre + Point3d(radf, radf, -radf);

          // Each axis vertex is surrounded by the four corners of its face
          makeNhbd(S, Vtx5(t, ntl, ntr, ftr, ftl));
          makeNhbd(S, Vtx5(b, nbr, nbl, fbl, fbr));
          makeNhbd(S, Vtx5(l, nbl, ntl, ftl, fbl));
          makeNhbd(S, Vtx5(r, ntr, nbr, fbr, ftr));
          makeNhbd(S, Vtx5(n, ntl, nbl, nbr, ntr));
          makeNhbd(S, Vtx5(f, ftr, fbr, fbl, ftl));

          // Each corner has six neighbours: three adjacent corners alternating
          // with the three axis vertices of the faces it touches
          makeNhbd(S, Vtx7(ntl, ntr, t, ftl, l, nbl, n));
          makeNhbd(S, Vtx7(nbl, nbr, n, ntl, l, fbl, b));
          makeNhbd(S, Vtx7(nbr, ntr, n, nbl, b, fbr, r));
          makeNhbd(S, Vtx7(ntr, nbr, r, ftr, t, ntl, n));
          makeNhbd(S, Vtx7(ftl, ftr, f, fbl, l, ntl, t));
          makeNhbd(S, Vtx7(fbl, ftl, f, fbr, b, nbl, l));
          makeNhbd(S, Vtx7(fbr, fbl, f, ftr, r, nbr, b));
          makeNhbd(S, Vtx7(ftr, ftl, t, ntr, r, fbr, f));
        }

    SETSTATUS(S.size() << " total vertices");
    mesh->updateAll();
    //mesh->setCells(false);
    return true;
  }
  REGISTER_PROCESS(MeshFromLocalMaxima);

}
