//
// This file is part of MorphoDynamX - http://www.MorphoDynamX.org
// Copyright (C) 2012-2021 Richard S. Smith and collaborators.
//
// If you use MorphoDynamX in your work, please cite:
//   http://dx.doi.org/10.7554/eLife.05864
//
// MorphoDynamX is free software, and is licensed under the terms of the
// GNU General (GPL) Public License version 2.0, http://www.gnu.org/licenses.
//

// Include libtorch headers
#undef slots
#include <torch/script.h>
#include <ATen/cuda/CUDAUtils.h>
#define slots Q_SLOTS
#define USE_CUDA

#include <CNNProcess.hpp>

//#include <StackProcessSystem.hpp>
//#include <StackProcessCanvas.hpp>
#include <QFileDialog>
#include <cuda/CudaExport.hpp>
#include <CImg.h>
using namespace cimg_library;
typedef CImg<float> CImgF;

namespace mgx 
{
  // Scope guard around the CUDA memory that MorphoDynamX holds on to. That hold must be
  // released before running a CNN: the constructor releases it (freeMem()) and the
  // destructor takes it back (holdMem()) when the guard leaves scope, hopefully recovering
  // the memory afterwards.
  //
  // Non-copyable (and therefore non-movable): a copy would call holdMem() a second time
  // on destruction.
  struct HoldMem
  {
    HoldMem() { freeMem(); }
    ~HoldMem() { holdMem(); }

    HoldMem(const HoldMem &) = delete;
    HoldMem &operator=(const HoldMem &) = delete;
  };

//  // Call process to save file in H5 format, requires HDF AddOn to be loaded
//  bool UNet3DPredictPython::saveHDF5(const QString &stackNumber, const QString &storeName, const QString &fileName)
//  {
//    auto *process = getProcess("Stack/System/Save HDF5");
//    if(!process)
//      throw "saveHDF5 Unable to create HDF5 process, is HDF AddOn installed?";
//    process->setParm("Stack Number", stackNumber);
//    process->setParm("Store", storeName);
//    process->setParm("File Name", fileName);
//
//    return process->run();
//  }
//
//  bool UNet3DPredictPython::run(const QString &stackName, const QString &inputStore, const QString &outputStore, 
//                                  QString condaPath, QString cnnPath, QString workPath, bool segment)
//  {
//    Stack *stack = getStack(stackName.toInt());
//    if(!stack)
//      throw QString("%1::run Unable to get stack: %2").arg(name()).arg(stackName);
//
//    Store *input = inputStore == "Main" ? stack->main() : stack->work();
//    if(!input)
//      throw QString("%1::run Unable to get input store: %2").arg(name()).arg(inputStore);
//
//    Store *output = outputStore == "Main" ? stack->main() : stack->work();
//    if(!output)
//      throw QString("%1::run Unable to get output store: %2").arg(name()).arg(outputStore);
//
//    // Temporary work dir
//    QTemporaryDir tmpDir(QDir::tempPath() + "/mgxXXXXXX");
//
//    if(workPath.isEmpty())
//      workPath = tmpDir.path();
//    else
//      QDir().mkpath(workPath);
//
//    // Adjust paths 
//    if(cnnPath[0] != "/")
//      cnnPath = QDir::homePath() + "/" + cnnPath;
//    if(condaPath[0] != "/")
//      condaPath = QDir::homePath() + "/" + condaPath;
//
//    // Input file
//    QString inFileName("mgxInput.h5");
//    saveHDF5(stackName, inputStore, workPath + "/" + inFileName);
//
//    // Python script to run CNN
//    QString cnnScriptName(resourceDir().path() + "/CNN/mgxCNN.sh");
//
//    // Create file list
//    QString fileListName("inFiles.csv");
//    QFile fileList(workPath + "/" + fileListName);
//    if(!fileList.open(QIODevice::WriteOnly | QIODevice::Text))
//      throw QString("%1 Unable to open file list: %2").arg(name()).arg(workPath + "/" + fileListName);
//    fileList.write(inFileName.toLocal8Bit());
//    fileList.close();
//
//    // Free MGX hold on cuda memory
//    HoldMem hm;
//
//    // Segment the cells after prediction
//    QString seg = segment ? "Yes" : "No";
//
//    if(QProcess::execute(cnnScriptName + " " + condaPath + " " + cnnPath + " " + workPath + " " + fileListName + " " + seg) != 0)
//      return false;
//
//    // Save the current step
//    auto step = stack->step();
//
//    // Open the prediction stack
//    if(segment) {
//      StackOpen(*this).run(stack, input, workPath + "/" + inFileName.left(inFileName.length()-3) + "Predict.tif");
//      StackOpen(*this).run(stack, output, workPath + "/" + inFileName.left(inFileName.length()-3) + "Segment.tif");
//      ReverseStack(*this).run(input, input, false, true, false);
//      output->setLabels(true);
//    } else
//      StackOpen(*this).run(stack, output, workPath + "/" + inFileName.left(inFileName.length()-3) + "Predict.tif");
//
//    ChangeVoxelSize(*this).run(stack, step);
//    ReverseStack(*this).run(output, output, false, true, false);
//    output->setFile("");
//    output->show();
//
//    return true;
//  }
//  REGISTER_PROCESS(UNet3DPredictPython);

  // Interactive set-up: let the user pick the TorchScript model file.
  //
  // Returns true without prompting when there is no parent widget, or when the
  // "Model Path" parameter is non-empty and does not name a directory (it is then used
  // as given). Otherwise a file dialog is opened, starting in that directory, or in the
  // current working directory when the parameter is empty. The chosen file is written
  // back to "Model Path". Returns false if the dialog is cancelled.
  bool UNet3DPredict::initialize(QWidget *parent)
  {
    // No parent widget means no GUI to open a dialog on
    if(parent == nullptr)
      return true;

    modelPath = parm("Model Path");
    if(modelPath.isEmpty())
      modelPath = QDir::currentPath();
    else if(!QFileInfo(modelPath).isDir())
      return true; // Not a directory: use the configured path as is

    const QString fileFilter = "Torchscript files (*.pt);;All files (*.*)";
    modelPath = QFileDialog::getOpenFileName(parent, "Choose network model to load", modelPath, fileFilter);
    if(modelPath.isEmpty())
      return false; // Dialog cancelled

    setParm("Model Path", modelPath);
    return true;
  }

  // Run a 3D U-Net TorchScript model on the input store of a stack and write the selected
  // network output channel, rescaled to the full 16-bit range, into the output store.
  //
  // Pipeline: the input is converted to float in [0,1], optionally resampled and
  // standardized, then predicted in overlapping patches. Overlapping predictions are
  // blended with a weight that decreases linearly with distance from the patch center.
  // If torch reports "CUDA out of memory" (and every patch dimension is > 32), the patch
  // size and stride are reduced by 25% and the patch prediction restarts.
  //
  //   stack        Stack supplying the image size and voxel size
  //   input        Store holding the image to predict on
  //   output       Store receiving the prediction (resized to the input data size)
  //   modelPath    TorchScript model file
  //   patchSize    Patch size, rounded down to a multiple of 8 and clipped to the image size
  //   strideFactor Stride as a fraction of the patch size, in (0,1]; used when patchStride is
  //                (0,0,0) and when a patch dimension is clipped to the image size
  //   patchStride  Explicit patch stride, or (0,0,0) to derive it from strideFactor
  //   resample     Per-axis resampling factor applied before prediction and undone after (1 = none)
  //   channel      0-based network output channel to keep; falls back to 0 if out of range
  //   standardize  Standardize the input to zero mean / unit standard deviation (PlantSeg models)
  //   useCuda      Run the model on the GPU; falls back to the CPU if CUDA is unavailable
  //
  // Throws QString on invalid parameters, torch errors, or a constant (null) prediction.
  bool UNet3DPredict::run(Stack *stack, Store* input, Store* output, QString modelPath, Point3u patchSize,
                    double strideFactor, Point3u patchStride, Point3d resample, uint channel, bool standardize, bool useCuda)
  {
    // Release the CUDA memory held by MorphoDynamX while the network runs; it is re-held on return
    HoldMem hm;

    if(useCuda and !at::cuda::is_available()) {
      useCuda = false;
      Information::out << "Cuda is not available, using CPU" << endl;
    }

    torch::jit::script::Module torchModel;
    try {
      // Deserialize the ScriptModule from the model file onto the chosen device
      torchModel = torch::jit::load(modelPath.toStdString(), useCuda ? at::kCUDA : at::kCPU);
    } catch(const std::exception &e) {
      throw QString("%1::run Error loading torch model: %2").arg(name()).arg(e.what());
    } catch(...) {
      throw QString("%1::run Unknown error loading torch model").arg(name());
    }

    // Convert the input data to float, scaled by 1/65535
    Point3u imageSize = stack->size();
    size_t imageTotalSize = size_t(imageSize.x()) * imageSize.y() * imageSize.z();
    if(imageTotalSize == 0)
      throw QString("%1::run Stack is empty").arg(name());
    FloatVec inputDataF(imageTotalSize);
    auto &data = input->data();
    for(size_t i = 0; i < imageTotalSize; i++)
      inputDataF[i] = float(data[i])/65535.0;

    // Resample if required
    const bool resampling = resample != Point3d(1.0, 1.0, 1.0);
    if(resampling) {
      Point3u newSize = imageSize;
      auto newVoxelSize = stack->step();
      for(int i = 0; i < 3; i++) {
        if(resample[i] == 1.0)
          continue;
        if(resample[i] <= 0)
          throw QString("%1::run Resample size must be > 0").arg(name());

        // Keep at least one voxel along every axis (also avoids dividing by zero below)
        newSize[i] = std::max(1u, uint(imageSize[i] * resample[i]));
        newVoxelSize[i] /= float(newSize[i])/imageSize[i];
      }

      Information::out << "Resizing sample to (" << newSize << ") giving a voxel size of (" << newVoxelSize << ")" << endl;
      size_t newTotalSize = size_t(newSize.x()) * newSize.y() * newSize.z();

      CImgF image(inputDataF.data(), imageSize.x(), imageSize.y(), imageSize.z(), 1, false);
      image.resize(newSize.x(), newSize.y(), newSize.z(), 1, 5); // 5 = cubic
      inputDataF.resize(newTotalSize);
      memcpy(inputDataF.data(), image.data(), newTotalSize * sizeof(float));
      imageSize = newSize;
      imageTotalSize = newTotalSize;
    }

    // PlantSeg models expect standardized input: (x - mean) / max(std, eps), as in the
    // pytorch-3dunet Standardize transform. The sums use OpenMP reductions to avoid data races.
    if(standardize) {
      double sum = 0;
      #pragma omp parallel for reduction(+:sum)
      for(size_t i = 0; i < imageTotalSize; i++)
        sum += inputDataF[i];
      const double mean = sum / imageTotalSize;

      double sumSq = 0;
      #pragma omp parallel for reduction(+:sumSq)
      for(size_t i = 0; i < imageTotalSize; i++) {
        const double d = inputDataF[i] - mean;
        sumSq += d * d;
      }
      // Clamp so that a constant image does not divide by zero
      const double stdDev = std::max(std::sqrt(sumSq / imageTotalSize), 1e-10);

      #pragma omp parallel for
      for(size_t i = 0; i < imageTotalSize; i++)
        inputDataF[i] = (inputDataF[i] - mean) / stdDev;
    }

    // Get the initial patch stride
    if(patchStride == Point3u(0,0,0)) {
      if(strideFactor <= 0 or strideFactor > 1.0)
        throw QString("%1::run strideFactor must be between 0 and 1").arg(name());

      for(int i = 0; i < 3; i++)
        patchStride[i] = ceil(strideFactor * patchSize[i]);
    }

    // Restart from here, with a smaller patch size, if torch reports CUDA out of memory
    restart:
    {
      // Compute the patch geometry, the number of patch steps, and the patch center
      size_t patchTotalSize = 1;
      Point3f patchCenter;
      Point3u patchSteps;
      uint steps = 1;
      for(int i = 0; i < 3; ++i) {
        patchSize[i] = (patchSize[i] / 8) * 8; // Make a multiple of 8
        if(patchSize[i] >= imageSize[i]) {
          patchSize[i] = imageSize[i]; // Clip patches to image size
          patchStride[i] = ceil(strideFactor * patchSize[i]);
        }
        if(patchSize[i] == 0)
          throw QString("%1::run Patch size must be at least 8 voxels").arg(name());
        // A zero stride would never advance and divides by zero below
        if(patchStride[i] == 0)
          patchStride[i] = 1;
        else if(patchStride[i] > patchSize[i])
          patchStride[i] = patchSize[i];
        patchTotalSize *= patchSize[i];
        patchCenter[i] = float(patchSize[i]/2.0f);
        patchSteps[i] = (imageSize[i] - patchSize[i]) / patchStride[i] + 1;
        if((imageSize[i] - patchSize[i]) % patchStride[i] > 0)
          patchSteps[i]++;
        steps *= patchSteps[i];
      }
      float patchCenterDistance = norm(patchCenter);

      Information::out << "Image Size: " << imageSize << " Orig Size " << stack->size()
                       << " Patch Size " << patchSize << " Patch Stride: " << patchStride << " Steps:" << steps << endl;

      // Patch blending weights: 1 at the center, falling linearly with distance. The patch
      // voxel (0,0,0) would get exactly 0, and the image corner is covered only by that voxel,
      // so floor the weight at a small positive value to avoid a 0/0 (NaN) during normalization.
      const float minWeight = 1e-4f;
      FloatVec weights(imageTotalSize, 0);
      FloatVec patchWeights(patchTotalSize);
      float *patchWeightP = patchWeights.data();
      for(uint z = 0; z < patchSize.z(); z++)
        for(uint y = 0; y < patchSize.y(); y++)
          for(uint x = 0; x < patchSize.x(); x++) {
            float w = (patchCenterDistance - norm(Point3f(x, y, z) - patchCenter)) / patchCenterDistance;
            *patchWeightP++ = w > minWeight ? w : minWeight;
          }

      // Temporary input patch
      FloatVec currentPatch(patchTotalSize);

      // Loop over the patches
      FloatVec outputDataF(imageTotalSize, 0);
      bool onceChannels = true;
      progressStart("Running patches through network", steps);
      uint step = 0;
      for(uint z = 0, k = 0; k < patchSteps.z(); k++, z += patchStride.z())
        for(uint y = 0, j = 0; j < patchSteps.y(); j++, y += patchStride.y())
          for(uint x = 0, i = 0; i < patchSteps.x(); i++, x += patchStride.x()) {
            if(!progressAdvance(step++))
              userCancel();

            // Patch origin, shifted back so the last patch ends at the image border
            Point3u basePos(min(x, uint(imageSize.x() - patchSize.x())), min(y, uint(imageSize.y() - patchSize.y())),
                                                                            min(z, uint(imageSize.z() - patchSize.z())));
            Information::out << "Processing region " << basePos << endl;

            // Fill the current patch input
            float *currentP = currentPatch.data();
            for(uint zp = basePos.z(); zp < basePos.z() + patchSize.z(); zp++)
              for(uint yp = basePos.y(); yp < basePos.y() + patchSize.y(); yp++)
                for(uint xp = basePos.x(); xp < basePos.x() + patchSize.x(); xp++)
                  *currentP++ = inputDataF[offset(xp, yp, zp, imageSize.x(), imageSize.y())];

            // Convert the float patch to a (1, 1, z, y, x) tensor owning its own copy
            auto tensorImage = torch::from_blob(currentPatch.data(), {1, 1, patchSize.z(), patchSize.y(), patchSize.x()}, torch::kFloat).clone();
            tensorImage.set_requires_grad(false);

            std::vector<torch::jit::IValue> inputs;
            if(useCuda)
              inputs.emplace_back(tensorImage.to(at::kCUDA));
            else
              inputs.emplace_back(tensorImage);

            // Prediction for the requested channel as a contiguous float tensor on the CPU.
            // Kept alive in this variable while its data pointer is read below.
            torch::Tensor channelTensor;
            try {
              // Don't store gradient data
              torch::NoGradGuard noGrad;

              torch::jit::IValue networkPrediction = torchModel.forward(inputs);

              // Two possibilities for output: a tuple with one tensor per channel, or a single
              // tensor with the channels along dimension 1
              std::vector<torch::jit::IValue> outputTuple;
              torch::Tensor outputTensor;
              int nOutChan = 0;
              if(networkPrediction.isTuple()) {
                outputTuple = networkPrediction.toTuple()->elements();
                nOutChan = outputTuple.size();
              } else {
                outputTensor = networkPrediction.toTensor();
                nOutChan = outputTensor.size(1);
              }
              if(onceChannels) {
                onceChannels = false;
                Information::out << "Output channels found: "<< nOutChan << endl;
              }
              if(nOutChan <= 0)
                throw QString("%1::run No output channels returned").arg(name());

              // channel is a 0-based index, so channel == nOutChan is already out of range
              if(int(channel) >= nOutChan) {
                channel = 0;
                Information::out << "Channel requested is greater than number of channels, using channel: " << channel << endl;
              }

              // Select the channel and bring it to the CPU as contiguous float data
              if(networkPrediction.isTuple())
                channelTensor = outputTuple[channel].toTensor();
              else
                channelTensor = outputTensor.select(1, channel);
              channelTensor = channelTensor.toType(torch::kFloat).to(at::kCPU).contiguous();
            } catch(const QString &) {
              throw; // Our own error, keep its message
            } catch(const std::exception &e) {
              QString errorMsg(e.what());
              if(errorMsg.contains("CUDA out of memory") and patchSize.x() > 32 and patchSize.y() > 32 and patchSize.z() > 32) {
                patchSize *= .75;
                patchStride *= .75;
                Information::out << "Torch Error, trying to reduce patch size" << endl;
                goto restart;
              }
              throw QString("%1::run Error running torch model: %2").arg(name()).arg(e.what());
            } catch(...) {
              throw QString("%1::run Unknown error running torch model").arg(name());
            }

            if(size_t(channelTensor.numel()) < patchTotalSize)
              throw QString("%1::run Network output is smaller than the input patch").arg(name());

            // Accumulate the weighted prediction and the weights
            const float *resultP = channelTensor.data_ptr<float>();
            const float *weightP = patchWeights.data();
            for(uint zp = basePos.z(); zp < basePos.z() + patchSize.z(); zp++)
              for(uint yp = basePos.y(); yp < basePos.y() + patchSize.y(); yp++)
                for(uint xp = basePos.x(); xp < basePos.x() + patchSize.x(); xp++, resultP++, weightP++) {
                  size_t idx = offset(xp, yp, zp, imageSize.x(), imageSize.y());
                  if(idx >= weights.size()) {
                    Information::out << "Stack offset:" <<  stack->offset(xp, yp, zp) << " wt size:" << weights.size()
                      << "img sz: " << imageSize << " total sz:" << imageTotalSize << " store sz:" << input->size() << endl;
                    continue;
                  }
                  outputDataF[idx] += *weightP * *resultP;
                  weights[idx] += *weightP;
                }
          }

      // Normalize by the accumulated weights (guarded against voxels with no weight)
      #pragma omp parallel for
      for(size_t i = 0; i < imageTotalSize; i++)
        outputDataF[i] = weights[i] > 0 ? outputDataF[i] / weights[i] : 0.0f;

      // Undo the resampling
      if(resampling) {
        Point3u newSize = stack->size();
        size_t newTotalSize = size_t(newSize.x()) * newSize.y() * newSize.z();

        CImgF image(outputDataF.data(), imageSize.x(), imageSize.y(), imageSize.z(), 1, false);
        image.resize(newSize.x(), newSize.y(), newSize.z(), 1, 5); // 5 = cubic
        outputDataF.resize(newTotalSize);
        memcpy(outputDataF.data(), image.data(), newTotalSize * sizeof(float));
        imageSize = newSize;
        imageTotalSize = newTotalSize;
      }

      // Find the data range and scale the output data to 16 bit
      float minS = std::numeric_limits<float>::max();
      float maxS = std::numeric_limits<float>::lowest();
      for(size_t i = 0; i < imageTotalSize; i++) {
        float s = outputDataF[i];
        if(maxS < s)
          maxS = s;
        if(minS > s)
          minS = s;
      }
      Information::out << "Prediction stack data values, min:" << minS << " max:" << maxS << endl;

      float range = maxS - minS;
      if(!(range > 0))
        throw QString("%1::run Null output returned from network on channel: %2").arg(name()).arg(channel);

      auto &outputData = output->data();
      outputData.resize(input->data().size());
      // Values reach 65535, so convert to unsigned short (conversion to short would overflow)
      #pragma omp parallel for
      for(size_t i = 0; i < imageTotalSize; i++)
        outputData[i] = static_cast<unsigned short>((outputDataF[i] - minS)/range * 65535.0);
    }

    output->changed();
    output->copyMetaData(input);

    return true;
  }
  REGISTER_PROCESS(UNet3DPredict);
}
