//
// This file is part of MorphoGraphX - http://www.MorphoGraphX.org
// Copyright (C) 2012-2016 Richard S. Smith and collaborators.
//
// If you use MorphoGraphX in your work, please cite:
//   http://dx.doi.org/10.7554/eLife.05864
//
// MorphoGraphX is free software, and is licensed under the terms of the
// GNU General (GPL) Public License version 2.0, http://www.gnu.org/licenses.
//

#ifndef SOLVER_HPP
#define SOLVER_HPP

#include <Mesh.hpp>
#include <DistMatrix.hpp>

using namespace std;

namespace mgx
{

  // Distributed neighborhood built over the cell graph.
  using DistNhbdC = DistNhbd<cellGraph>;

  // Distributed vector / matrix over DistNhbdC with Matrix2d blocks,
  // Point2d entries and double scalars.
  using DVector2d = DVector<DistNhbdC, Matrix2d, Point2d, double>;
  using DMatrix2d = DMatrix<DistNhbdC, Matrix2d, Point2d, double>;

  // Sequences of Matrix2d: thrust host-side storage and std::vector storage.
  // NOTE(review): the "f" suffix does not match the Matrix2d element type;
  // names kept for compatibility.
  using DVM2f = thrust::host_vector<Matrix2d>;
  using HVM2f = std::vector<Matrix2d>;

  /// Time-integration scheme used by Solver.
  enum SolvingMethod
  {
    Euler = 0,        ///< Forward (explicit) Euler method
    BackwardEuler = 1 ///< Backward (implicit) Euler method
  };


  /**
   * Base class for time-integrating per-cell quantities on a cell graph.
   *
   * Derived classes expose their state through the virtual hooks
   * getValues() / setValues() / updateDerivatives() / getDerivatives();
   * solve() (defined out of line) advances that state using the scheme
   * selected in `method`. The default hook implementations do nothing and
   * return empty vectors.
   *
   * `nhbd`, `cellVars` and `cellEdges` point into the Mesh passed to the
   * constructor; the Solver itself deletes nothing.
   * NOTE(review): ownership of vJ/eJ/vX/vB/J/X/B is decided by the
   * out-of-line code (presumably initialize()) -- confirm.
   */
  class Solver 
  {
  public:
    cellGraph &C;          // Cell graph being solved on (borrowed)
    SolvingMethod method;  // Integration scheme (forward or backward Euler)
    int variables;         // NOTE(review): presumably the number of variables per cell
    double eulerDt;        // Time step

    // CG parms. The constructor does not set these (presumably
    // processParms() does), so they are zero-initialized rather than left
    // indeterminate.
    double cgTolerance = 0.0;
    double cgMaxIter = 0.0;  // An iteration count, but stored as double (kept for compatibility)
    int maxNewtonSteps = 0;
    double newtonErrTolerance = 0.0;

    bool debugOutput;      // Debug flag, initialized from the constructor's `debug`

    // Distributed neighborhood object, taken from the mesh in the constructor
    DistNhbdC* nhbd = nullptr;

    // Per-cell solver data
    struct VtxData 
    {
      Point2d x,b;
      Matrix2d j;

      bool operator==(const VtxData &other) const
      {
        return x == other.x && b == other.b && j == other.j;
      }
    };

    // Per-cell-pair solver data
    struct EdgData
    {
      Matrix2d j;

      bool operator==(const EdgData &other) const
      {
        return j == other.j;
      }
    };

    // Distributed matrix vectors. Not set by the constructor; null until
    // assigned elsewhere.
    DistVertexAttr<DistNhbdC, VtxData, Matrix2d> *vJ = nullptr;
    DistEdgeAttr<DistNhbdC, EdgData, Matrix2d> *eJ = nullptr;
    DistVertexAttr<DistNhbdC, VtxData, Point2d> *vX = nullptr;
    DistVertexAttr<DistNhbdC, VtxData, Point2d> *vB = nullptr;

    DMatrix2d *J = nullptr;
    DVector2d *X = nullptr;
    DVector2d *B = nullptr;

    // Attr Maps (owned by the mesh's attribute system)
    typedef AttrMap<cell, VtxData> CellAttr2;
    CellAttr2 *cellVars = nullptr;

    typedef std::pair<cell, cell> CellCellPair;

    typedef AttrMap<CellCellPair, EdgData> MyEdgeAttr2;
    MyEdgeAttr2 *cellEdges = nullptr;

    /**
     * @param m         Mesh providing the cell neighborhood and the "VtxData" /
     *                  "EdgData" attribute maps; dereferenced, so must be non-null.
     * @param pC        Cell graph to solve on (stored by reference).
     * @param method    Integration scheme.
     * @param variables Stored in `variables`.
     * @param eulerDt   Time step.
     * @param _Dx       Unused; kept for source compatibility.
     * @param debug     Initial value of `debugOutput`.
     */
    Solver(Mesh* m, cellGraph& pC, SolvingMethod method, int variables, double eulerDt, double /*_Dx*/ = 0.001, bool debug = true)
      : C(pC), method(method), variables(variables), eulerDt(eulerDt), debugOutput(debug)
    {
      nhbd = &m->nhbdC();

      cellVars = &m->attributes().attrMap<cell, VtxData>("VtxData");
      cellEdges = &m->attributes().attrMap<CellCellPair, EdgData>("EdgData");
    }

    // Solver is used polymorphically (virtual hooks below), so deleting a
    // derived solver through a Solver* must run the derived destructor.
    virtual ~Solver() = default;

    // Parameters; the enumerator values index the lists returned below.
    enum ParmNames { pSolvingMethod, pDt, pCGTolerance, pCGMaxIter, pMaxNewtonSteps, pErrTolerance, pDebug, pNumParms }; 

    // Apply parameters given in ParmNames order (defined out of line).
    void processParms(const QStringList &parms);

    // Parameter names, one per ParmNames entry.
    QStringList parmNames() const 
    {
      QVector <QString> vec(pNumParms);

      vec[pSolvingMethod] = "SolvingMethod";
      vec[pDt] = "Dt";
      vec[pCGTolerance] = "CGTolerance";
      vec[pCGMaxIter] = "CGMaxIter";
      vec[pMaxNewtonSteps] = "MaxNewtonSteps";
      vec[pErrTolerance] = "ErrTolerance";
      vec[pDebug] = "Debug";

      return vec.toList();
    }

    // Human-readable parameter descriptions, one per ParmNames entry.
    QStringList parmDescs() const 
    {
      QVector <QString> vec(pNumParms);

      vec[pSolvingMethod] = "Time integration method (forward or backward Euler)";
      vec[pDt] = "Time step";
      vec[pCGTolerance] = "Convergence tolerance of the CG (conjugate gradient) solver";
      vec[pCGMaxIter] = "Maximum number of CG iterations";
      vec[pMaxNewtonSteps] = "Maximum number of Newton steps";
      vec[pErrTolerance] = "Error tolerance of the Newton iteration";
      vec[pDebug] = "Enable debug output";

      return vec.toList();
    }

    // Default parameter values, one per ParmNames entry.
    QStringList parmDefaults() const 
    {
      QVector <QString> vec(pNumParms);

      vec[pSolvingMethod] = "Euler";
      vec[pDt] = "0.005";
      // NOTE(review): confirm these solver-tolerance defaults against
      // processParms() and typical models.
      vec[pCGTolerance] = "1e-6";
      vec[pCGMaxIter] = "100";
      vec[pMaxNewtonSteps] = "10";
      vec[pErrTolerance] = "1e-6";
      vec[pDebug] = "Yes";

      return vec.toList();
    }


    // set values and derivatives
    void setInput(std::vector<double>& values, std::vector<double>& derivatives);

    // solve according to parms and values
    void solve();

    // Hooks for derived classes; the defaults do nothing / return empty vectors.
    virtual std::vector<double> getValues(const cell &c) {std::vector<double> d; return d;}
    virtual void setValues(const cell& c, std::vector<double> values) {}
    virtual void updateDerivatives(const cell& c) {}
    virtual std::vector<double> getDerivatives(const cell& c) {std::vector<double> d; return d;}

  private:

    // init for backward euler
    void initialize();

    // create jacobian for backward euler
    void buildJacobian(const cell& c);

    // solver forward euler
    void solveEuler();

    // solver backward euler
    void solveBackwardEuler();

  };
  // Read/write vertex (per-cell) data
  bool inline readAttr(Solver::VtxData &m, const QByteArray &ba, size_t &pos) 
  {
    return readChar((char *)&m, sizeof(Solver::VtxData), ba, pos);
  }
  bool inline writeAttr(const Solver::VtxData &m, QByteArray &ba) 
  {
    return writeChar((char *)&m, sizeof(Solver::VtxData), ba);
  }

  // Read/write Edge data
  bool inline readAttr(Solver::EdgData &m, const QByteArray &ba, size_t &pos) 
  {
    return readChar((char *)&m, sizeof(Solver::EdgData), ba, pos);
  }
  bool inline writeAttr(const Solver::EdgData &m, QByteArray &ba) 
  {
    return writeChar((char *)&m, sizeof(Solver::EdgData), ba);
  }
}

#endif

