//
// This file is part of MorphoGraphX - https://www.MorphoGraphX.org  (@RichardSmithLab)
//
// MorphoGraphX development is led by the Richard S. Smith lab at the John Innes Centre, Norwich, UK
//
// If you use MorphoGraphX in your work, please cite:
//   https://doi.org/10.7554/eLife.72601
//
// For support please see the image.sc forum:
//   https://forum.image.sc/tag/MorphoGraphX
//
// MorphoGraphX is copyright by its authors, contributors, and/or their employers.
//
// MorphoGraphX is free software, and is licensed under the terms of the 
// GNU General Public License https://www.gnu.org/licenses/.
//
#ifndef SYMMETRIC_TENSOR_HPP
#define SYMMETRIC_TENSOR_HPP
//#include <Geometry.hpp>
#include <Vector.hpp>
#include <gsl/gsl_matrix.h>
#include <gsl/gsl_vector.h>
#include <gsl/gsl_eigen.h>


namespace mgx 
{
  // RSS Currently this is only floats, it should be templated
  // 3-component single-precision vector; the storage type used by SymmetricTensor below.
  typedef Vector<3, float> Point3f;
  // 3-component double-precision vector; used for the eigen-solver inputs/outputs.
  typedef Vector<3, double> Point3d;
  // 3x3 double-precision matrix; the dense form a SymmetricTensor converts to/from.
  typedef Matrix<3, 3, double> Matrix3d;
  //AR: This isn't the best place for these functions, they should be moved to their own hpp/cpp later
  // Compute the eigenvectors and eigenvalues of a 3x3 matrix using GSL's
  // real-symmetric solver (gsl_eigen_symmv).
  //
  //   M          input matrix; it is only read here (copied into a GSL matrix).
  //              It is expected to be symmetric, since the symmetric solver is used.
  //   p1, p2, p3 receive the eigenvectors (columns 0, 1, 2 of the GSL result),
  //              ordered to match the eigenvalues in ev.
  //   ev         receives the three eigenvalues, sorted in descending order
  //              (GSL_EIGEN_SORT_VAL_DESC).
  //
  // The status code returned by gsl_eigen_symmv is not inspected.
  inline void calcEigenVecsValues(Matrix3d& M, Point3d& p1, Point3d& p2, Point3d& p3, Point3d& ev)
  {
    gsl_matrix* mat = gsl_matrix_alloc(3, 3);
    gsl_vector* eval = gsl_vector_alloc(3);
    gsl_matrix* evec = gsl_matrix_alloc(3, 3);
    gsl_eigen_symmv_workspace* w = gsl_eigen_symmv_alloc(3);

    // Copy the input into GSL storage; the solver works on (and overwrites) its own copy.
    for(int i = 0; i < 3; ++i)
      for(int j = 0; j < 3; ++j)
        gsl_matrix_set(mat, i, j, M[i][j]);

    gsl_eigen_symmv(mat, eval, evec, w);
    // Sort eigenvalues (and the matching eigenvector columns) largest first.
    gsl_eigen_symmv_sort(eval, evec, GSL_EIGEN_SORT_VAL_DESC);

    // Eigenvector k is column k of evec.
    p1 = Point3d(gsl_matrix_get(evec, 0, 0), gsl_matrix_get(evec, 1, 0), gsl_matrix_get(evec, 2, 0));
    p2 = Point3d(gsl_matrix_get(evec, 0, 1), gsl_matrix_get(evec, 1, 1), gsl_matrix_get(evec, 2, 1));
    p3 = Point3d(gsl_matrix_get(evec, 0, 2), gsl_matrix_get(evec, 1, 2), gsl_matrix_get(evec, 2, 2));
    ev = Point3d(gsl_vector_get(eval, 0), gsl_vector_get(eval, 1), gsl_vector_get(eval, 2));

    // Release all GSL resources; without this every call leaks four allocations.
    gsl_eigen_symmv_free(w);
    gsl_matrix_free(evec);
    gsl_vector_free(eval);
    gsl_matrix_free(mat);
  }

  // Outer product of two 3-vectors: entry (row, col) of the result is p1[row] * p2[col].
  inline Matrix3d OuterProduct(Point3d p1,Point3d p2)
  {
    Matrix3d result;
    for(int row = 0; row < 3; ++row) {
      for(int col = 0; col < 3; ++col)
        result[row][col] = p1[row] * p2[col];
    }
    return result;
  }


  class SymmetricTensor 
  {
  public:
    SymmetricTensor() : _ev1(0, 0, 0), _ev2(0, 0, 0), _evals(0, 0, 0) {}
  
    SymmetricTensor(const Point3f &ev1, const Point3f &ev2, const Point3f &evals)
      : _ev1(ev1), _ev2(ev2), _evals(evals) {}
  
    SymmetricTensor(const SymmetricTensor &copy)
      : _ev1(copy._ev1), _ev2(copy._ev2), _evals(copy._evals) {}
  
    SymmetricTensor &operator=(const SymmetricTensor &other)
    {
      _ev1 = other._ev1;
      _ev2 = other._ev2;
      _evals = other._evals;
      return *this;
    }

    bool operator==(const SymmetricTensor &other) const
    {
      if(_ev1 == other._ev1 and _ev2 == other._ev2 and _evals == other._evals)
        return true;
      return false;
    }

    bool operator!=(const SymmetricTensor &other) const 
    {
      return (!((*this) == other));
    }

  
    Point3f &ev1() { return _ev1; }
    const Point3f &ev1() const { return _ev1; }

    Point3f &ev2() { return _ev2; }
    const Point3f &ev2() const { return _ev2; }

    Point3f ev3() const { return _ev1 ^ _ev2; }

    const Point3f &evals() const { return _evals; }

    Point3f &evals() { return _evals; }

    Matrix3d toMatrix() 
    {
      Matrix3d mat;
      Matrix3d e1 = _evals[0]*OuterProduct(Point3d(_ev1),Point3d(_ev1));
      Matrix3d e2 = _evals[1]*OuterProduct(Point3d(_ev2),Point3d(_ev2));
      Matrix3d e3 = _evals[2]*OuterProduct(Point3d(_ev1 ^ _ev2),Point3d(_ev1 ^ _ev2));
      mat = e1+e2+e3;

      return mat;
    }   

    bool fromMatrix(Matrix3d mat) 
    {
      Point3d p1,p2,p3,ev;

      calcEigenVecsValues(mat, p1, p2, p3, ev);
      _ev1 = Point3f(p1.x(), p1.y(), p1.z());
      _ev2 = Point3f(p2.x(), p2.y(), p2.z());
      _evals = Point3f(ev.x(), ev.y(), ev.z());

      //Should be some error checking here
      return true;
    }

  protected:
    Point3f _ev1, _ev2;
    Point3f _evals;
  };
}
#endif
