//
// This file is part of MorphoGraphX - https://www.MorphoGraphX.org  (@RichardSmithLab)
//
// MorphoGraphX development is led by the Richard S. Smith lab at the John Innes Centre, Norwich, UK
//
// If you use MorphoGraphX in your work, please cite:
//   https://doi.org/10.7554/eLife.72601
//
// For support please see the image.sc forum:
//   https://forum.image.sc/tag/MorphoGraphX
//
// MorphoGraphX is copyright by its authors, contributors, and/or their employers.
//
// MorphoGraphX is free software, and is licensed under the terms of the 
// GNU General Public License https://www.gnu.org/licenses/.
//
#ifndef CNNProcess_HPP
#define CNNProcess_HPP

#include "CNNConfig.hpp"
#include <Process.hpp>

namespace mgx
{
  ///\addtogroup Process
  ///@{
  /**
   * \class UNet3DPredict CNNProcess.hpp <CNNProcess.hpp>
   *
   * Use CNN cell wall predictions to (hopefully) improve the stack
   */
//  class CNN_EXPORT UNet3DPredictPython : public Process 
//  {
//  public:
//    UNet3DPredictPython(const Process &process) : Process(process) 
//    {
//      setName("Stack/CNN/UNet3D Prediction Python");
//      setDesc("Improve stack with CNN cell wall predictions.\n"
//              "This process uses the tools developed in Eschweiler et al.\n"
//              "2019 IEEE 16th International Symposium on Biomedical Imaging (ISBI 2019)\n");
//      setIcon(QIcon(":/images/CNN.png"));
//      
//      addParm("CNN Path","Path to CNN software (relative path from your home directory)", "JStegmaierCellSegmentation");            
//      addParm("Conda Path","Path to Anaconda directory (relative path from your home directory)", "anaconda3");            
//      addParm("Work Path","Path for work files, empty for temporary directory", "");            
//      addParm("Segment Cells","Return segmentation in the work stack, and predictions in the main", "No", booleanChoice());            
//    }
//  
//    bool run()
//    {
//      QString stackNumber = QString("%1").arg(currentStack()->id());
//      QString inputStore = currentStack()->currentStore() == currentStack()->main() ? "Main" : "Work";
//      return run(stackNumber, inputStore, "Work", parm("Conda Path"), parm("CNN Path"), parm("Work Path"), stringToBool(parm("Segment Cells")));
//    }
//  
//    bool run(const QString &stackName, const QString &inputStore, const QString &outputStore, 
//                                                                QString condaPath, QString cnnPath, QString workPath, bool segment);
//    
//    bool saveHDF5(const QString &stackNumber, const QString &storeName, const QString &fileName);
//
//  };

  /**
   * Process registered as "Stack/CNN/UNet3D Prediction".
   *
   * Per its own description it improves a stack using CNN cell wall
   * prediction. The constructor only registers the name, description, icon
   * and user-editable parameters; initialize() and both run() overloads are
   * defined outside this header.
   */
  class CNN_EXPORT UNet3DPredict : public Process
  {
  public:
    /**
     * Forward \p process to the Process base constructor and register this
     * process' metadata and parameters.
     *
     * The order of the addParm() calls is kept as-is.
     * NOTE(review): presumably the argument order of addParm() is
     * (name, description, default[, choices]) -- confirm in Process.hpp.
     */
    UNet3DPredict(const Process &process) : Process(process)
    {
      setName("Stack/CNN/UNet3D Prediction");
      setDesc("Improve stack with CNN cell wall prediction.\n"
              "If you use this process in your work, please cite:\n"
              "  Eschweiler et al. 2019 CNN-Based Preprocessing to Optimize Watershed-Based Cell Segmentation in 3D Confocal Microscopy Images\n"
              "  IEEE 16th International Symposium on Biomedical Imaging\n");
      setIcon(QIcon(":/images/CNN.png"));

      // The model path text is built from the MGXRESPATH macro with "/CNN" appended.
      addParm("Model Path","Path to network model", QString("%1/CNN").arg(MGXRESPATH));
      addParm("Patch Size","Maximum patch size for processing", "256 256 128");
      addParm("Stride Factor", "Stride for patches, determines overlap", "0.5");
      addParm("Patch Stride", "If not 0, specifies patch stride directly", "0 0 0");
      addParm("Resample", "Resample the image before feeding to the network (xyz)", "1.0 1.0 1.0");
      addParm("Standardize", "Standardize image", "No", booleanChoice());
      addParm("Use Cuda","Use cuda if available", "Yes", booleanChoice());
    }

    /// Declared here, defined elsewhere. \param parent Qt widget passed in by the caller.
    bool initialize(QWidget *parent);

    /// Parameterless entry point; declared here, defined elsewhere.
    bool run();

    /**
     * Explicit-argument overload; declared here, defined elsewhere.
     *
     * The arguments mirror the parameters registered in the constructor
     * (patch size, stride factor, patch stride, resample, standardize, cuda).
     * Note that the default patch size here (96, 96, 48) is not the same as
     * the "Patch Size" parameter default ("256 256 128").
     *
     * NOTE(review): the \p modelPath argument has the same name as the
     * protected member modelPath -- verify the definition uses the intended one.
     *
     * \param stack        stack pointer supplied by the caller
     * \param input        store pointer supplied by the caller
     * \param outputData   vector passed by non-const reference (presumably
     *                     filled with the prediction output -- confirm in the
     *                     definition)
     * \param modelPath    path to the network model
     * \param patchSize    patch size, default (96, 96, 48)
     * \param strideFactor stride factor, default 0.5
     * \param patchStride  explicit patch stride, default (0, 0, 0)
     * \param resample     resampling factors, default (1.0, 1.0, 1.0)
     * \param standardize  standardize flag, default false
     * \param useCuda      cuda flag, default true
     */
    bool run(Stack *stack, Store* input, std::vector<HVecUS> &outputData, QString modelPath,
        Point3u patchSize = Point3u(96, 96, 48), double strideFactor = 0.5, Point3u patchStride = Point3u(0, 0, 0),
        Point3d resample = Point3d(1.0, 1.0, 1.0), bool standardize = false, bool useCuda = true);

    /// Mutable reference to the outputDataArrayF member.
    std::vector<FloatVec> &outputDataF() { return outputDataArrayF; }
    /// Mutable reference to the outputDataArray member.
    std::vector<HVecUS> &outputData() { return outputDataArray; }

  protected:
    std::vector<FloatVec> outputDataArrayF; ///< Storage exposed through outputDataF()
    std::vector<HVecUS> outputDataArray;    ///< Storage exposed through outputData()
    QString modelPath;                      ///< Not set or read anywhere in this header
  };

  /**
   * Process registered as "Stack/CNN/UNet Display".
   *
   * Its description reads "Display output from UNet prediction"; the actual
   * work happens in run(), which is defined outside this header.
   */
  class CNN_EXPORT UNetDisplay : public Process
  {
  public:
    /// Forward \p process to the Process base constructor, then register
    /// the name, description, icon and the single "Channel" parameter.
    UNetDisplay(const Process &process)
      : Process(process)
    {
      setName("Stack/CNN/UNet Display");
      setDesc("Display output from UNet prediction");
      setIcon(QIcon(":/images/CNN.png"));

      // NOTE(review): "2" is presumably the default channel text -- confirm
      // the argument order of addParm() in Process.hpp.
      addParm("Channel", "Output channel to display", "2");
    }

    /// Entry point; declared here, defined elsewhere.
    bool run();
  };

  ///@}
}
#endif
