//
// This file is part of MorphoGraphX - https://www.MorphoGraphX.org  (@RichardSmithLab)
//
// MorphoGraphX development is led by the Richard S. Smith lab at the John Innes Centre, Norwich, UK
//
// If you use MorphoGraphX in your work, please cite:
//   https://doi.org/10.7554/eLife.72601
//
// For support please see the image.sc forum:
//   https://forum.image.sc/tag/MorphoGraphX
//
// MorphoGraphX is copyright by its authors, contributors, and/or their employers.
//
// MorphoGraphX is free software, and is licensed under the terms of the 
// GNU General Public License https://www.gnu.org/licenses/.
//
#ifndef CUDA_HPP
#define CUDA_HPP

#include <Config.hpp>

#include <thrust/version.h>
#include <thrust/host_vector.h>
#include <thrust/device_vector.h>
#include <thrust/device_ptr.h>

#define COMPILE_CUDA // RSS: Can't we make this depend on THRUST_BACKEND_CUDA?
#ifdef THRUST_BACKEND_CUDA
  #include <cuda_runtime.h>
#endif

#include <CudaGlobal.hpp>
#include <Geometry.hpp>
#include <ThrustTypes.hpp>

#ifdef WIN32
#  include <windows.h>
#else
#  include <unistd.h>
#endif

namespace mgx 
{
  // ---------------------------------------------------------------------------
  // Progress reporting.
  //
  // On Windows these are inline no-op stubs: progressAdvance() always reports
  // "continue" (true) and progressCancelled() always reports false. On other
  // platforms they are only declared here; the definitions live elsewhere.
  // ---------------------------------------------------------------------------
  //  #if ((defined(WIN32) || defined(WIN64)) && defined(MINGW))
  #if (defined(WIN32) || defined(WIN64))
  // Define fake functions for Windows. Parameter names are omitted so the
  // stubs do not trigger unused-parameter warnings.
  inline bool progressAdvance() { return true; }
  inline bool progressAdvance(int /*step*/) { return true; }
  inline void progressStart(const std::string &/*msg*/, int /*steps*/) {}
  inline void progressSetMsg(const std::string &/*msg*/) {}
  inline bool progressCancelled() { return false; }
  // Must match the non-Windows name below so portable callers compile.
  inline void progressStop() {}
  // Backward-compatible alias for existing callers of the old stub name.
  inline void progStop() { progressStop(); }
  #else
  extern bool progressAdvance();
  extern bool progressAdvance(int step);
  extern void progressStart(const std::string &msg, int steps);
  extern void progressSetMsg(const std::string &msg);
  extern bool progressCancelled();
  extern void progressStop();
  #endif

  // Thrust device-pointer shorthands for commonly used element types.
  typedef thrust::device_ptr<int> DevicePi;
  typedef thrust::device_ptr<uint> DevicePui;
  typedef thrust::device_ptr<double> DevicePd;
  typedef thrust::device_ptr<Point3d> DeviceP3d;
  typedef thrust::device_ptr<Matrix3d> DevicePM3d;
  
  // Counting iterator tagged for the device system. Thrust 1.6 renamed
  // device_space_tag to device_system_tag.
  // This requires changing after thrust is updated everywhere
  #if THRUST_VERSION >= 100600
    typedef thrust::counting_iterator<int, thrust::device_system_tag> DCountIter;
  #else
    typedef thrust::counting_iterator<int, thrust::device_space_tag> DCountIter;
  #endif
  
  // Dimensions: single axes and axis combinations.
  enum DimEnum { X, Y, Z, XY, YZ, XZ, XYZ };
  
  // Check for Cuda errors.
  // NOTE(review): declared with C linkage although it takes a std::string; the
  // linkage must match the out-of-line definition, so it is left as-is.
  extern "C" cuda_EXPORT int checkCudaError(const std::string& msg);
  // NOTE(review): defined elsewhere; presumably the thread count assigned to
  // position threadPos out of totThreads -- confirm against the definition.
  cuda_EXPORT uint getThreadCount(uint threadPos, uint totThreads);
  cuda_EXPORT int errMsg(const std::string& s);
  cuda_EXPORT size_t userMem();
  // One mebibyte (1024 * 1024 bytes). This is a preprocessor macro, so it is
  // visible outside namespace mgx as well.
  #define MByte size_t(1024 * 1024)
  
  // Get a raw pointer to the first element of a device vector.
  // data() is used instead of &DVec[0] so that an empty vector is never
  // indexed. With the CUDA backend the pointer refers to device memory and
  // must not be dereferenced on the host.
  template <typename T> T* devP(thrust::device_vector<T> &DVec) { return thrust::raw_pointer_cast(DVec.data()); }
  // Pointer overload; DVec must not be null.
  template <typename T> T* devP(thrust::device_vector<T> *DVec) { return devP(*DVec); }

  // Atomic add for doubles: atomically performs *address += val and returns
  // the value *address held BEFORE the addition (same contract as CUDA
  // atomicAdd), on both backends.
  #ifdef THRUST_BACKEND_CUDA 
    __device__ 
    inline double atomicAddMGX(double* address, double val)
    {
    #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600
      // Native double-precision atomicAdd (requires SM60+).
      return atomicAdd(address, val);
    #else
      // Compare-and-swap loop on the 64-bit pattern of the double, retried
      // until no other thread modified the value between read and swap.
      unsigned long long int* address_as_ull = (unsigned long long int*)address;
      unsigned long long int old = *address_as_ull, assumed;
      do {
        assumed = old;
        old = atomicCAS(address_as_ull, assumed, __double_as_longlong(val + __longlong_as_double(assumed)));
      } while (assumed != old);
      return __longlong_as_double(old);
    #endif
    }
  #else
    inline double atomicAddMGX(double* address, double val)
    {
      double old;
    #if defined(_OPENMP) && _OPENMP >= 201107
      // OpenMP 3.1+: capture the previous value atomically with the update.
      #pragma omp atomic capture
      { old = *address; *address += val; }
    #else
      // Older OpenMP (e.g. 2.0) has no atomic capture; serialize instead.
      // Without OpenMP this pragma is ignored and the code runs serially.
      #pragma omp critical(mgx_atomicAddMGX)
      { old = *address; *address += val; }
    #endif
      return old;
    }
  #endif

  // Memory management helpers (defined elsewhere). userMem() and errMsg()
  // are declared above with cuda_EXPORT.
  extern void freeMem();
  extern void holdMem();
  extern size_t memLeft(size_t memrq, size_t roomrq, size_t mem);

  // Scope guard: calls freeMem() on construction and holdMem() on destruction.
  // NOTE(review): presumably releases reserved memory for the duration of the
  // scope and re-reserves it afterwards -- confirm against freeMem/holdMem.
  struct HoldMem
  {
    HoldMem() { freeMem(); }
    ~HoldMem() { holdMem(); }
  private:
    // Non-copyable: a copy would call holdMem() again on destruction without
    // a matching freeMem(). Declared but intentionally not defined.
    HoldMem(const HoldMem &);
    HoldMem &operator=(const HoldMem &);
  };
}
#endif
