//
// This file is part of MorphoGraphX - https://www.MorphoGraphX.org  (@RichardSmithLab)
//
// MorphoGraphX development is led by the Richard S. Smith lab at the John Innes Centre, Norwich, UK
//
// If you use MorphoGraphX in your work, please cite:
//   https://doi.org/10.7554/eLife.72601
//
// For support please see the image.sc forum:
//   https://forum.image.sc/tag/MorphoGraphX
//
// MorphoGraphX is copyright by its authors, contributors, and/or their employers.
//
// MorphoGraphX is free software, and is licensed under the terms of the 
// GNU General Public License https://www.gnu.org/licenses/.
//
#ifndef DIST_OBJECT_HPP
#define DIST_OBJECT_HPP

/**
 * Distributed object library.
 *
 * These classes allow data that is stored with a VVGraph to be transferred
 * to/from the GPU. The data can either be in the vertex/edge content itself,
 * or stored in attribute maps keyed by vertices or by vertex pairs (edges).
 */

#include <QString>
#include <Attributes.hpp>
#include <ThrustTypes.hpp>

namespace mgx
{
  // Copy the contents of *src into *dst. This header uses it in both directions:
  // host_vector -> device_vector (write) and device_vector -> host_vector (read).
  // NOTE(review): the meaning of the returned int is defined by the implementation,
  // which is not visible here; callers in this header ignore it.
  template <typename SrcT, typename DstT>
  int copyGPU(SrcT *src, DstT *dst);

  // Allocate the vector pointed to by *vec so that it holds n elements. The
  // destructors in this header call it with n == 0, presumably to release the
  // buffer; alloc() calls it repeatedly, so it presumably handles an existing
  // allocation (TODO confirm against the implementation).
  template <typename VecT>
  int allocGPU(VecT **vec, size_t n);

  /**
   * \class DistNhbdT DistObject.hpp <DistObject.hpp>
   *
   * Distributed neighborhood object. This class mirrors the graph structure onto the GPU.
   */
  template <typename vvGraphT>
  class DistNhbd
  {
  public:
    // Type of the graph
    typedef vvGraphT vvGraph;
    // Type of a vertex
    typedef typename vvGraphT::vertex_t Vertex;
    // Type of vertex content
    typedef typename Vertex::content_t VertexContent;
    // Type of an edge
    typedef typename vvGraphT::edge_t Edge;
    // Type of vertex content
    typedef typename Edge::content_t EdgeContent;
  
  private:
    uint _n;
    uint _nbs;
    vvGraphT &_S;
    thrust::device_vector<uint> *_data;      // Neighborhood information
    std::unordered_map<Vertex, uint> _vNum;
  
  public:
    // Set offset of data item in vertex
    DistNhbd(vvGraphT &S) : _n(S.size()), _S(S), _data(0)  {}
  
    // Clean up
    virtual ~DistNhbd()  { allocGPU(&_data, 0); }
  
    // Return # of vertices
    uint n() { return _n; };
  
    // Return max # of neighbors
    uint nbs() { return _nbs; };
  
    // Return pointer to device buffer
    thrust::device_vector<uint> *data() 
    { 
      return _data; 
    };
  
    // Return graph
    vvGraphT &graph() { return _S; };
  
    // Return vertex number from our ordering map
    uint vNum(const Vertex &v) { return _vNum[v]; };
   
    // Allocate and send to GPU
    void write()
    {
      // Record vertex numbers and find max # of neighbors
      _vNum.clear();
      _n = _nbs = 0;
      forall(const Vertex &v, _S)  {
        _vNum[v] = _n++;
        if(_S.valence(v) > _nbs)
          _nbs = _S.valence(v);
      }
      if(_n * _nbs == 0)
        return;
      thrust::host_vector<uint> hData(_n * _nbs);
      uint d = 0;
      // Grab the neighborhood information 
      forall(const Vertex &v, _S) {
        // Put in edge data, 
        forall(const Vertex &w, _S.neighbors(v))
          hData[d++] = _vNum[w];
    
        // Use current index to indicate empty
        for(uint k = _S.valence(v); k < _nbs; k++)
          hData[d++] = _vNum[v];
      }
      // Write to device
      try {
        allocGPU(&_data, _n * _nbs) ;
      } catch (const std::exception &e) {
        throw(QString("Unable to allocate distributed neighborhood object of sizes %1 %2 %3, err:%4")
          .arg(_n).arg(_nbs).arg(sizeof(uint)).arg(e.what()));
      } catch(...) {
        throw(QString("Unable to allocate distributed neighborhood object of sizes %1 %2 %3")
          .arg(_n).arg(_nbs).arg(sizeof(uint)));
      }
      copyGPU(&hData, _data);
    }
  };
  
  /**
   * \class DistVertex DistObject.hpp <DistObject.hpp>
   *
   * Distributed vertex object. This class mirrors vertex data onto the GPU. Use this class directly
   * if there is no host side data, otherwise use one of the derived classes that associates data
   * to fields within the vertices themselves or fields in the attributes system.
   *
   * The device buffer holds one DataT per vertex, indexed by the vertex number from
   * DistNhbdT::vNum(). The neighborhood object must have been written for the
   * current graph before write()/read() are used.
   */
  template <typename DistNhbdT, typename DataT>
  class DistVertex
  {
    DistNhbdT &_nhbd; // Distributed neighborhood object
    thrust::device_vector<DataT> *_data;     // Distributed data (initially null)
    bool _hasHostData;// Is there host-side storage? or is it GPU only?
  
  protected:
    // Constructor for object with host data, only from derived classes
    DistVertex(DistNhbdT &nhbd, bool hasData) : _nhbd(nhbd), _data(0), _hasHostData(hasData)  {}
  
  public:
    // Constructor for object with no host data (GPU only)
    DistVertex(DistNhbdT &nhbd) : _nhbd(nhbd), _data(0), _hasHostData(false)  {}
  
    // Clean up
    virtual ~DistVertex()  { allocGPU(&_data, 0); }
  
    // Host-side storage for a vertex. The base class has none and returns null;
    // derived classes with host storage override this.
    virtual DataT *vData(const typename DistNhbdT::Vertex &) { return 0; }
  
    // Return neighborhood object
    DistNhbdT &nhbd() { return _nhbd; }
  
    // Return pointer to device buffer, (re)allocating it to nhbd().n() entries first
    thrust::device_vector<DataT> *data() 
    { 
      alloc();
      return _data;
    }

    // Allocate space on GPU for nhbd().n() entries. Throws QString on failure.
    void alloc() 
    { 
      try {
        allocGPU(&_data, _nhbd.n()); 
      } catch (const std::exception &e) {
        throw(QString("Unable to allocate distributed vertex object of sizes %1 %2, err:%3")
          .arg(_nhbd.n()).arg(sizeof(DataT)).arg(e.what()));
      } catch(...) {
        throw(QString("Unable to allocate distributed vertex object of sizes %1 %2")
          .arg(_nhbd.n()).arg(sizeof(DataT)));
      }
    }
  
    // Copy vertex data to GPU. Throws QString if there is no host storage.
    void write()
    {
      if(!_hasHostData)
        throw(QString("Error: Distributed vector has no host storage"));
      if(_nhbd.n() <= 0)
        return;
  
      thrust::host_vector<DataT> hData(_nhbd.n());
      forall(const typename DistNhbdT::Vertex &v, _nhbd.graph())
        hData[_nhbd.vNum(v)] = *vData(v);

      // Write to device
      alloc();
      copyGPU(&hData, _data);
    }
  
    // Copy vertex data to Host. Throws QString if there is no host storage, or if
    // the device buffer is missing or does not match the neighborhood size.
    void read()
    {
      if(!_hasHostData)
        throw(QString("Error: Distributed vector has no host storage"));
      if(_nhbd.n() <= 0)
        return;
      // The buffer may never have been allocated (null), or may be stale if the
      // neighborhood was re-written since; both must fail cleanly, not crash.
      if(!_data || _data->size() != _nhbd.n())
        throw(QString("Error: Distributed vector size mismatch on device buffer"));
  
      // Read data from device
      thrust::host_vector<DataT> hData(_nhbd.n());
      copyGPU(_data, &hData);
      forall(const typename DistNhbdT::Vertex &v, _nhbd.graph())
        *vData(v) = hData[_nhbd.vNum(v)];
    }
  };
  
  /**
   * \class DistVertexGraph DistObject.hpp <DistObject.hpp>
   *
   * Distributed vertex object whose host-side storage is a data member of each
   * vertex's content in the graph, selected by a pointer-to-member.
   */
  template <typename DistNhbdT, typename DataT>
  class DistVertexGraph : public DistVertex<DistNhbdT, DataT>
  {
    typedef DistVertex<DistNhbdT, DataT> BaseT;
    typedef typename DistNhbdT::VertexContent VContentT;

    // Member of the vertex content that holds the mirrored value
    DataT VContentT::*_off;
  
  public:
    // Mirror member `off` of every vertex's content onto the GPU
    DistVertexGraph(DistNhbdT &nhbd, DataT VContentT::*off) 
      : BaseT(nhbd, true), _off(off) {}
  
    // Address of the mirrored member inside the content of vertex v
    DataT *vData(const typename DistNhbdT::Vertex &v)
    {
      auto &content = *v;
      return &(content.*_off);
    }
  };
  
  /**
   * \class DistVertexAttr DistObject.hpp <DistObject.hpp>
   *
   * Distributed vertex object whose host-side storage is a data member of the
   * content held in an attribute map keyed by vertex.
   */
  template <typename DistNhbdT, typename ContentT, typename DataT>
  class DistVertexAttr : public DistVertex<DistNhbdT, DataT>
  {
  private:
    typedef AttrMap<typename DistNhbdT::Vertex, ContentT> VertexAttrMapT;

    VertexAttrMapT *_map;  // Attribute map with the per-vertex content (never deleted here)
    DataT ContentT::*_off; // Member of ContentT that holds the mirrored value
  
  public:
    // Constructor for object with memory in an attribute map keyed by vertex
    DistVertexAttr(DistNhbdT &nhbd, VertexAttrMapT *map, DataT ContentT::*off) 
      : DistVertex<DistNhbdT, DataT>(nhbd, true), _map(map), _off(off) {}
  
    // Address of the mirrored member inside the attribute content of vertex v
    DataT *vData(const typename DistNhbdT::Vertex &v)
    {
      auto &content = v->*_map;
      return &(content.*_off);
    }

    // Set (or reset) the attribute map. It may change when the mesh is reset.
    void setAttrMap(VertexAttrMapT *map)
    {
      _map = map;
    }
  };
  
  /**
   * \class DistEdge DistObject.hpp <DistObject.hpp>
   *
   * Distributed edge object. This class mirrors edge data onto the GPU. Use this class directly
   * if there is no host side data, otherwise use one of the derived classes that associates data
   * to fields within the edges themselves or fields in the attributes system.
   */
  template <typename DistNhbdT, typename DataT>
  class DistEdge 
  {
  public:
    DistNhbdT &_nhbd; // Distributed neighborhood object
    bool _hasHostData;// Is there host-side storage? (or is it GPU only)
    thrust::device_vector<DataT> *_data;     // Distributed data

  public:
    // Constructor for object with no host data (only GPU)
    DistEdge(DistNhbdT &nhbd) : _nhbd(nhbd), _hasHostData(false), _data(0) {}

  protected:
    // Constructor for object with host data, only from derived classes
    DistEdge(DistNhbdT &nhbd, bool hasData) : _nhbd(nhbd), _hasHostData(hasData), _data(0) {}
  
  public:
    // Clean up
    virtual ~DistEdge()  { allocGPU(&_data, 0); }

    // Return neighborhood object
    DistNhbdT &nhbd() { return _nhbd; }; 

    // Get edge data
    virtual DataT *eData(const typename DistNhbdT::Vertex &v1, const typename DistNhbdT::Vertex &v2) { return 0; }

    // Return pointer to device buffer
    thrust::device_vector<DataT> *data() 
    { 
      alloc();
      return _data;
    };

    // Allocate space on GPU
    void alloc() 
    { 
      try {
        allocGPU(&_data, _nhbd.n() * _nhbd.nbs());
      } catch (const std::exception &e) {
        throw(QString("Unable to allocate distributed neighborhood object of sizes %1 %2 %3, err:%4")
          .arg(_nhbd.n()).arg(_nhbd.nbs()).arg(sizeof(DataT)).arg(e.what()));
      } catch(...) {
        throw(QString("Unable to allocate distributed edge object of sizes %1 %2 %3")
          .arg(_nhbd.n()).arg(_nhbd.nbs()).arg(sizeof(DataT)));
      }
    }

    // Send edge data to GPU
    void write()
    {
      if(!_hasHostData)
        throw(QString("Error: DistEdge is temp (no graph storage)"));
      // No data, just return
      if(_nhbd.n() * _nhbd.nbs() <= 0)
        return;
  
      thrust::host_vector<DataT> hData(_nhbd.n() * _nhbd.nbs());
      // Grab data
      uint d = 0;
      forall(const typename DistNhbdT::Vertex &v, _nhbd.graph()) {
        forall(const typename DistNhbdT::Vertex &w, _nhbd.graph().neighbors(v))
          hData[d++] = *eData(v,w);

        // Pad rest of nhbd with 0s
        for(uint k = _nhbd.graph().valence(v); k < _nhbd.nbs(); k++)
          hData[d++] = DataT();
      }
      // Write to GPU
      alloc();
      copyGPU(&hData, _data);
    }
    
    // Copy edge data to host
    void read()
    {
      if(!_hasHostData)
        throw(QString("Warning DistEdge is temp (no graph storage)"));
      // Return if no data
      if(_nhbd.n() * _nhbd.nbs() <= 0)
        return;
  
      // Read from GPU
      thrust::host_vector<DataT> hData(_nhbd.n() * _nhbd.nbs());
      copyGPU(_data, &hData);
  
      // Grab data
      uint s = 0;
      forall(const typename DistNhbdT::Vertex &v, _nhbd.graph()) {
        forall(const typename DistNhbdT::Vertex &w, _nhbd.graph().neighbors(v))
          *eData(v,w) = hData[s++];

        // Skip 0s
        for(uint k = _nhbd.graph().valence(v); k < _nhbd.nbs(); k++)
          s++;
      }
    }
  };

  /**
   * \class DistEdgeGraph DistObject.hpp <DistObject.hpp>
   *
   * Distributed edge object whose host-side storage is a data member of each
   * edge's content in the graph, selected by a pointer-to-member.
   */
  template <typename DistNhbdT, typename DataT>
  class DistEdgeGraph : public DistEdge<DistNhbdT, DataT>
  {
    // Member of the edge content that holds the mirrored value
    DataT DistNhbdT::EdgeContent::*_off;
  
  public:
    // Mirror member `off` of every edge's content onto the GPU
    DistEdgeGraph(DistNhbdT &nhbd, DataT DistNhbdT::EdgeContent::*off) 
      : DistEdge<DistNhbdT, DataT>(nhbd, true), _off(off) {}

    // Address of the mirrored member inside the content of edge (v1, v2)
    DataT *eData(const typename DistNhbdT::Vertex &v1, const typename DistNhbdT::Vertex &v2) 
    {
      auto &content = *this->nhbd().graph().edge(v1, v2);
      return &(content.*_off);
    }
  };
  
  /**
   * \class DistEdgeAttr DistObject.hpp <DistObject.hpp>
   *
   * Distributed edge object whose host-side storage is a data member of the
   * content held in an attribute map keyed by a (v1, v2) pair of vertices.
   */
  template <typename DistNhbdT, typename ContentT, typename DataT>
  class DistEdgeAttr : public DistEdge<DistNhbdT, DataT>
  {
  public:
    // Key of the attribute map: the two end vertices of an edge
    typedef std::pair<typename DistNhbdT::Vertex, typename DistNhbdT::Vertex> EdgePair;
    AttrMap<EdgePair, ContentT> *_map; // Attribute map with the per-edge content (never deleted here)
    DataT ContentT::*_off;             // Member of ContentT that holds the mirrored value
  
  public:
    // Constructor for object with memory in the attributes system
    DistEdgeAttr(DistNhbdT &nhbd, AttrMap<EdgePair, ContentT> *map, DataT ContentT::*off) 
      : DistEdge<DistNhbdT, DataT>(nhbd, true), _map(map), _off(off) {}

    // Address of the mirrored member inside the attribute content of edge (v1, v2)
    DataT *eData(const typename DistNhbdT::Vertex &v1, const typename DistNhbdT::Vertex &v2) 
    {
      auto &content = EdgePair(v1, v2)->*_map;
      return &(content.*_off);
    }

    // Set (or reset) the attribute map. It may change when the mesh is reset.
    void setAttrMap(AttrMap<EdgePair, ContentT> *map)
    {
      _map = map;
    }
  };
}
#endif
