//
// This file is part of MorphoGraphX - https://www.MorphoGraphX.org  (@RichardSmithLab)
//
// MorphoGraphX development is led by the Richard S. Smith lab at the John Innes Centre, Norwich, UK
//
// If you use MorphoGraphX in your work, please cite:
//   https://doi.org/10.7554/eLife.72601
//
// For support please see the image.sc forum:
//   https://forum.image.sc/tag/MorphoGraphX
//
// MorphoGraphX is copyright by its authors, contributors, and/or their employers.
//
// MorphoGraphX is free software, and is licensed under the terms of the 
// GNU General Public License https://www.gnu.org/licenses/.
//
#ifndef IMPLICIT_DEFORMATION_HPP
#define IMPLICIT_DEFORMATION_HPP

#include "ImplicitDeformationConfig.hpp"

#include <Geometry.hpp>
#include <Process.hpp>
#include "Mesh.hpp"

#include "Triangulate.hpp"

using namespace std;

namespace mgx 
{


class ImplicitDeformation
{

  public:

  // needed for Attr Maps
  bool operator==(const ImplicitDeformation &other) const
  {
    if(Mat == other.Mat and bvec == other.bvec and xvec == other.xvec and
      data == other.data and values == other.values and rb_coeffs == other.rb_coeffs and poly_coeffs == other.poly_coeffs)
      return true;
    return false;
  }

  //Initial position
  std::vector<Point3d> data;
  //Final position
  std::vector<Point3d> values;
  //Radial basis function coefficients (3 for each point corresponding to the x, y and z deformation)
  std::vector<std::vector<double> > rb_coeffs;
  std::vector<std::vector<double> > poly_coeffs;

  // needed for solving
  std::vector<std::vector<double> > Mat;
  std::vector<double> bvec;
  std::vector<double> xvec;

  bool singularMatrix;

ImplicitDeformation(){
  //coeffs.initCoeffs();
  singularMatrix = false;
}

void initialize(std::vector<Point3d>& data, std::vector<Point3d>& values);

void clear();

void createCoeffs();

void solve(int cgStep);

void solveInv();

void load(std::vector<Point3d>& data, std::vector<Point3d>& values, std::vector<std::vector<double> >& rb_coeffs, std::vector<std::vector<double> >& poly_coeffs){
  this->data = data;
  this->values = values;
  this->rb_coeffs = rb_coeffs;
  this->poly_coeffs = poly_coeffs;
}

  //Initialize to identity
  void initCoeffs(){

    rb_coeffs.resize(3);
    poly_coeffs.resize(3);

    for(int i=0;i<3;i++){
      rb_coeffs[i].resize(data.size());
      for(uint j=0;j<data.size();j++){
        rb_coeffs[i][j] = 0;
      }

      poly_coeffs[i].resize(4);
      poly_coeffs[i][0] = 0;
      for(int j=1;j<4;j++)
        if(j == i+1)
          poly_coeffs[i][j] = 1;
        else
          poly_coeffs[i][j] = 0;
    }

  }


//Radial basis function used in the deformation
//Form of the appropriate RBF depends on dimension
//#define RBF_power 3 //Works best when set to the dimension of the data, but other powers can be tried
double RBF(Point3d Centre, Point3d X, int RBF_power = 3)
{

    double Rad = sqrt((X-Centre)*(X-Centre));
    
    if(RBF_power%2 != 0)
      return pow(Rad,RBF_power);
    else
      return Rad>1 ? pow(Rad,RBF_power)*log(Rad) : pow(Rad-1,RBF_power)*log(pow(Rad,Rad));
}

//Only for odd powers
Point3d RBFgrad(Point3d Centre, Point3d X, int RBF_power = 3){
	double Rad = sqrt((X-Centre)*(X-Centre));
	if(RBF_power%2 != 0){
		Point3d analytic = RBF_power*pow(Rad,RBF_power-2)*(X-Centre);
		return analytic;
	}
	return Point3d(1e20,1e20,1e20);
}

//Evaluates one-dimension of the deformation
double RBFcomponent(std::vector<Point3d> &data,std::vector<double> &rb_coeffs,std::vector<double> &poly_coeffs, Point3d pos){
    if(data.size()!=rb_coeffs.size() || poly_coeffs.size()!=4)
        return 1e20;
    double val = 0;
    for(uint i=0;i<data.size();i++)
        val+= rb_coeffs[i]*RBF(data[i],pos);

    val+=poly_coeffs[0];
    for(uint i=1;i<poly_coeffs.size();i++)
        val+=poly_coeffs[i]*pos[i-1];     

    return val;
}

Point3d GradientComponent(std::vector<Point3d> &data, std::vector<double> &rb_coeffs, std::vector<double> &poly_coeffs, Point3d pos){
   Point3d grad(0,0,0);
   for(uint i=0;i<data.size();i++){
     grad += rb_coeffs[i]*RBFgrad(data[i],pos);
   }
//   for(int i=1;i<poly_coeffs.size();i++)
   grad.x() += poly_coeffs[1];
   grad.y() += poly_coeffs[2];
   grad.z() += poly_coeffs[3];
   return grad;
}


 // Jacobian = deformation gradient
Matrix3d DefJacobian2(Point3d pos){
  Matrix3d Jacobian;
  //Jacobian.resize(3);
  //Jacobian[0].resize(3);
  //Jacobian[1].resize(3);
  //Jacobian[2].resize(3);

  Point3d grad_x = GradientComponent(data,rb_coeffs[0],poly_coeffs[0],pos);
  Point3d grad_y = GradientComponent(data,rb_coeffs[1],poly_coeffs[1],pos);
  Point3d grad_z = GradientComponent(data,rb_coeffs[2],poly_coeffs[2],pos);
  Jacobian[0][0] = grad_x.x();
  Jacobian[0][1] = grad_x.y();
  Jacobian[0][2] = grad_x.z();

  Jacobian[1][0] = grad_y.x();
  Jacobian[1][1] = grad_y.y();
  Jacobian[1][2] = grad_y.z();

  Jacobian[2][0] = grad_z.x();
  Jacobian[2][1] = grad_z.y();
  Jacobian[2][2] = grad_z.z();

  return Jacobian;
}


 // Jacobian = deformation gradient
std::vector<std::vector<double> > DefJacobian(Point3d pos){
	std::vector<std::vector<double> > Jacobian;
	Jacobian.resize(3);
	Jacobian[0].resize(3);
	Jacobian[1].resize(3);
	Jacobian[2].resize(3);

	Point3d grad_x = GradientComponent(data,rb_coeffs[0],poly_coeffs[0],pos);
	Point3d grad_y = GradientComponent(data,rb_coeffs[1],poly_coeffs[1],pos);
	Point3d grad_z = GradientComponent(data,rb_coeffs[2],poly_coeffs[2],pos);
	Jacobian[0][0] = grad_x.x();
	Jacobian[0][1] = grad_x.y();
	Jacobian[0][2] = grad_x.z();

	Jacobian[1][0] = grad_y.x();
	Jacobian[1][1] = grad_y.y();
	Jacobian[1][2] = grad_y.z();

	Jacobian[2][0] = grad_z.x();
	Jacobian[2][1] = grad_z.y();
	Jacobian[2][2] = grad_z.z();

	return Jacobian;
}

std::vector<std::vector<double> > GrowthTensor(Point3d pos){
	std::vector<std::vector<double> > GT;
	GT = DefJacobian(pos);
	GT[0][0]-=1.0;
	GT[1][1]-=1.0;
	GT[2][2]-=1.0;
	return GT;
}

Matrix3d calcGrowthTensor(Point3d pos){
  Matrix3d tensor;

  Matrix3d J = DefJacobian2(pos);
  Matrix3d JT = transpose(J);
  
  //tensor = 0.5 * (JT + J);
  tensor = JT * J;

  //tensor[0] = tensor[0] - Point3d(1,0,0);
  //tensor[1] = tensor[1] - Point3d(0,1,0);
  //tensor[2] = tensor[2] - Point3d(0,0,1);
  

  return tensor;
}

Matrix3d calcProjectedGrowthTensorOnPos(Point3d pos, Point3d planeNrml){

  Matrix3d rotMat;
  Matrix3d tensor = calcGrowthTensor(pos);

  return calcProjectedGrowthTensor(tensor, planeNrml);
}


double JacDet(Point3d pos){
	std::vector<std::vector<double> > J = DefJacobian(pos);
	return J[0][0]*J[1][1]*J[2][2] + J[0][1]*J[1][2]*J[2][0] + J[0][2]*J[1][0]*J[2][1]
	      -J[0][0]*J[1][2]*J[2][1] - J[0][1]*J[1][0]*J[2][2] - J[0][2]*J[1][1]*J[2][0];
}


//Evaluates the 3D deformation
Point3d RBFdeformation(Point3d pos){
    Point3d def;
    def.x() = RBFcomponent(data,rb_coeffs[0],poly_coeffs[0],pos);
    def.y() = RBFcomponent(data,rb_coeffs[1],poly_coeffs[1],pos);
    def.z() = RBFcomponent(data,rb_coeffs[2],poly_coeffs[2],pos);
    return def;
}

//Following functions create A, x and b for the system Ax=b to be solved
void CreateMatrix()
{
  std::vector<std::vector<double> > Mat;

  int dSize = data.size();
  int mSize = 3*dSize+3*4;

  Mat.resize(mSize);
  for(int i=0;i<mSize;i++){
    Mat[i].resize(mSize);
  }
  for(int i=0;i<mSize;i++){
    for(int j=i;j<mSize;j++){
      if(j<3*dSize){
        if(i%3 == j%3){
          Mat[i][j] = RBF(data[j/3],data[i/3]);
          Mat[j][i] = RBF(data[i/3],data[j/3]);;
        }
      }
      else if(i<3*dSize && (i%3 == j%3)){
        //int sample_ind_j = (j-3*dSize);//j/3;
        //int sample_ind_i = (i-3*dSize);

        if(j/3 == (int)(dSize) or i/3==(int)(dSize))
          Mat[i][j] = Mat[j][i] = 1.0;
        else if(j/3 == (int)(dSize+1) or i/3==(int)(dSize+1))
          Mat[i][j] = Mat[j][i] = data[i/3][0];
        else if(j/3 == (int)(dSize+2) or i/3==(int)(dSize+2))
          Mat[i][j] = Mat[j][i] = data[i/3][1];
        else
          Mat[i][j] = Mat[j][i] = data[i/3][2];
      }
    }
  }

  this->Mat = Mat;
}

void CreateBVector()
{
  std::vector<double> bvec;

  uint dSize = data.size();
  uint mSize = 3*dSize+3*4;

  bvec.resize(mSize);
  for(uint i=0;i<bvec.size();i++)
    if(i<3*dSize)
       bvec[i] = values[i/3][i%3];
    else
       bvec[i] = 0;

  this->bvec = bvec;
}

void CreatePVector(uint dSize)
{
  std::vector<double> lvec;
  int mSize = 3*dSize+3*4;
  lvec.resize(mSize);
  for(uint i=0;i<lvec.size();i++)
    if(i<3*dSize)
      lvec[i]=0.0;
    else if(i<(3*dSize+3))
      lvec[i]=0.0;
    lvec[lvec.size()-1] = 0.1;
    lvec[lvec.size()-5] = 0.1;
    lvec[lvec.size()-9] = 0.1;

  xvec = lvec;
}




//Functions to perform the CG solve (basic matrix vector ops + CG solve)
std::vector<double> MatVecMult(std::vector<std::vector<double> > &M, std::vector<double> &v){

  std::vector<double> result;
  int mSize = v.size();
  //std::cout << "matvec" << v.size() << "/" << M.size() << "//" << M[0].size() << std::endl;
  result.resize(mSize);

    for(int i=0;i<mSize;i++){
      result[i] =0;
      for(int j=0;j<mSize;j++)
        result[i]+=M[i][j]*v[j];  
    }
  return result;
}

double Dot(std::vector<double> &v1,std::vector<double> &v2){
  double accum = 0;
  for(uint i=0;i<v1.size();i++)
    accum+=v1[i]*v2[i];
  return accum;
}

std::vector<double> BinOp(std::vector<double> &v1,std::vector<double> &v2, double scalar){
  vector<double> result;
  result.resize(v1.size());
  for(uint i=0; i<v1.size(); i++)
    result[i]=v1[i]+scalar*v2[i];
  return result;
}

//CG solve
  std::vector<double> CGsolve(std::vector<double> &b,std::vector<std::vector<double> > &A, std::vector<double> x){

	  std::vector<double> Ax = MatVecMult(A,x);
	  std::vector<double> r = BinOp(b,Ax,-1);
	  std::vector<double> p = r;
	  std::vector<double> tx = x;
	  double rsold = Dot(r,r);
	  int mSize = b.size();
	  double minres = 1e20;

	  for(int i=0;i<mSize;i++){
	    std::vector<double> Ap = MatVecMult(A,p);
	    double alpha = rsold/ Dot(p,Ap);
	    x = BinOp(x,p,alpha);
	    r = BinOp(r,Ap,-1*alpha);
	    double rsnew = Dot(r,r);

		if(rsnew<minres){
			minres = rsnew;
			tx = x;
			std::cout<<rsnew<<"!"<<std::endl;
		} //Early exit - works best if just continue iterating

	    //if(rsnew<1e-60) //Early exit - works best if just continue iterating
	    //  break; // might be better to break here?
	    p = BinOp(r,p,rsnew/rsold);
	    rsold = rsnew;
	  }
	  return tx;
	}


};

    // Read/write attr map
    /**
     * Deserialize an ImplicitDeformation written by writeAttr.
     *
     * Clears data, values, rb_coeffs and poly_coeffs, then reads starting at
     * byte offset `pos` (advanced past everything consumed):
     *   uint n, then n Point3d                   (data); n == 0 ends the record
     *   uint m, then m Point3d                   (values)
     *   uint rows, uint cols, rows*cols double   (rb_coeffs, row-major)
     *   uint rows, uint cols, rows*cols double   (poly_coeffs, row-major)
     * Mat, bvec, xvec and singularMatrix are not touched.
     *
     * Returns false as soon as an element read fails. NOTE(review): this
     * assumes the element readAttr overloads return false on failure, e.g. on
     * truncated input. In that case `data` may be left partially filled.
     */
    bool inline readAttr(ImplicitDeformation &data, const QByteArray &ba, size_t &pos) 
    {
      data.data.clear();
      data.values.clear();
      data.rb_coeffs.clear();
      // Must be cleared too: rows kept by resize() below would otherwise
      // retain stale values and the new ones would be appended after them.
      data.poly_coeffs.clear();

      uint sz = 0, sz2 = 0;

      if(!readAttr(sz, ba, pos))
        return false;
      if(sz == 0)
        return true;

      for(uint i = 0; i < sz; i++){
        Point3d p(0,0,0);
        if(!readAttr(p, ba, pos))
          return false;
        data.data.push_back(p);
      }

      if(!readAttr(sz, ba, pos))
        return false;
      for(uint i = 0; i < sz; i++){
        Point3d p(0,0,0);
        if(!readAttr(p, ba, pos))
          return false;
        data.values.push_back(p);
      }

      if(!readAttr(sz, ba, pos) or !readAttr(sz2, ba, pos))
        return false;
      data.rb_coeffs.resize(sz);
      for(uint i = 0; i < sz; i++){
        for(uint j = 0; j < sz2; j++){
          double c = 0;
          if(!readAttr(c, ba, pos))
            return false;
          data.rb_coeffs[i].push_back(c);
        }
      }

      if(!readAttr(sz, ba, pos) or !readAttr(sz2, ba, pos))
        return false;
      data.poly_coeffs.resize(sz);
      for(uint i = 0; i < sz; i++){
        for(uint j = 0; j < sz2; j++){
          double c = 0;
          if(!readAttr(c, ba, pos))
            return false;
          data.poly_coeffs[i].push_back(c);
        }
      }

      return true;
    }
    bool inline writeAttr(const ImplicitDeformation &data, QByteArray &ba) 
    {

      std::cout << "bbb0 " << std::endl;

      uint sz, sz2;

      if(data.data.empty()){
        sz = 0;
        writeChar((char *)&sz, sizeof(uint), ba);
        return true;
      } 

      sz = data.data.size();
      writeChar((char *)&sz, sizeof(uint), ba);
      for(uint i = 0; i < sz; i++){
        writeAttr(data.data[i], ba);
      }

      sz = data.values.size();
      writeChar((char *)&sz, sizeof(uint), ba);
      for(uint i = 0; i < sz; i++){
        writeAttr(data.values[i], ba);
      }

      std::cout << "bbb " << sz << std::endl;

      sz = data.rb_coeffs.size();
      sz2 = data.rb_coeffs[0].size();
      writeChar((char *)&sz, sizeof(uint), ba);
      writeChar((char *)&sz2, sizeof(uint), ba);
      for(uint i = 0; i < sz; i++){
        for(uint j = 0; j < sz2; j++){
          writeAttr(data.rb_coeffs[i][j], ba);
        }
      }

      std::cout << "bbb2 " << sz << "/" << sz2 << std::endl;

      sz = data.poly_coeffs.size();
      sz2 = data.poly_coeffs[0].size();
      writeChar((char *)&sz, sizeof(uint), ba);
      writeChar((char *)&sz2, sizeof(uint), ba);
      for(uint i = 0; i < sz; i++){
        for(uint j = 0; j < sz2; j++){
          writeAttr(data.poly_coeffs[i][j], ba);
        }
      }

      std::cout << "bbb3 " << sz << "/" << sz2 << std::endl;

      return true;
    }  

}

#endif
