//
// This file is part of MorphoDynamX - http://www.MorphoDynamX.org
// Copyright (C) 2012-2015 Richard S. Smith and collaborators.
//
// If you use MorphoDynamX in your work, please cite:
//   http://dx.doi.org/10.7554/eLife.05864
//
// MorphoDynamX is free software, and is licensed under under the terms of the 
// GNU General (GPL) Public License version 2.0, http://www.gnu.org/licenses.
//
#ifndef CNNProcess_HPP
#define CNNProcess_HPP

#include <CNNConfig.hpp>
#include <Process.hpp>

namespace mgx
{
  ///\addtogroup Process
  ///@{
  /**
   * \class UNet3DPredict CNNProcess.hpp <CNNProcess.hpp>
   *
   * Improve a stack using 3D U-Net CNN cell wall predictions; by default the result is written to the stack's work store.
   */
//  class CNN_EXPORT UNet3DPredictPython : public Process 
//  {
//  public:
//    UNet3DPredictPython(const Process &process) : Process(process) 
//    {
//      setName("Stack/CNN/UNet3D Prediction Python");
//      setDesc("Improve stack with CNN cell wall predictions.\n"
//              "This process uses the tools developed in Eschweiler et al.\n"
//              "2019 IEEE 16th International Symposium on Biomedical Imaging (ISBI 2019)\n");
//      setIcon(QIcon(":/images/CNN.png"));
//      
//      addParm("CNN Path","Path to CNN software (relative path from your home directory)", "JStegmaierCellSegmentation");						
//      addParm("Conda Path","Path to Anaconda directory (relative path from your home directory)", "anaconda3");						
//      addParm("Work Path","Path for work files, empty for temporary directory", "");						
//      addParm("Segment Cells","Return segmentation in the work stack, and predictions in the main", "No", booleanChoice());						
//	  }
//  
//    bool run()
//    {
//      QString stackNumber = QString("%1").arg(currentStack()->id());
//      QString inputStore = currentStack()->currentStore() == currentStack()->main() ? "Main" : "Work";
//      return run(stackNumber, inputStore, "Work", parm("Conda Path"), parm("CNN Path"), parm("Work Path"), stringToBool(parm("Segment Cells")));
//    }
//  
//    bool run(const QString &stackName, const QString &inputStore, const QString &outputStore, 
//                                                                QString condaPath, QString cnnPath, QString workPath, bool segment);
//    
//    bool saveHDF5(const QString &stackNumber, const QString &storeName, const QString &fileName);
//
//  };

  class CNN_EXPORT UNet3DPredict : public Process
  {
  public:
    /// Register the process name, description, icon and user parameters.
    UNet3DPredict(const Process &process) : Process(process)
    {
      setName("Stack/CNN/UNet3D Prediction");
      setDesc("Improve stack with CNN cell wall prediction.\n"
              "If you use this process in your work, please cite:\n"
              "  Eschweiler et al. 2019 CNN-Based Preprocessing to Optimize Watershed-Based Cell Segmentation in 3D Confocal Microscopy Images\n"
              "  IEEE 16th International Symposium on Biomedical Imaging\n");
      setIcon(QIcon(":/images/CNN.png"));

      addParm("Model Path","Path to network model", QString("%1/CNN").arg(MGXRESPATH));
      addParm("Patch Size","Maximum patch size for processing", "128 128 128");
      addParm("Stride Factor", "Stride for patches, determines overlap", "0.5");
      addParm("Patch Stride", "If not 0, specifies patch stride directly", "0 0 0");
      addParm("Resample", "Resample the image before feeding to the network (xyz)", "1.0 1.0 1.0");
      addParm("Channel","Output channel to return", "2");
      addParm("Standardize", "Standardize image, required for PlantSeg networks", "No", booleanChoice());
      addParm("Use Cuda","Use cuda if available", "Yes", booleanChoice());
    }

    /**
     * Run the prediction on the current stack using the user parameters.
     *
     * The currently active store of the current stack is used as input and the
     * stack's work store as output. If the prediction reports success, the work
     * store is shown and its label flag is cleared.
     *
     * \throws QString if there is no current stack, the model path is empty, or
     *         the "Stride Factor" / "Channel" parameters are not valid numbers.
     */
    bool run()
    {
      Stack* stack = currentStack();
      if(!stack)
        throw QString("%1::run No current stack").arg(name());
      Store* input = stack->currentStore();
      Store* output = stack->work();

      modelPath = parm("Model Path");
      if(modelPath.isEmpty())
        throw QString("%1::run No model path specified").arg(name());

      // QString::toDouble()/toUInt() return 0 on malformed text; reject that
      // instead of silently running with a zero stride factor or channel 0.
      bool ok = false;
      double strideFactor = parm("Stride Factor").toDouble(&ok);
      if(!ok)
        throw QString("%1::run Invalid stride factor: %2").arg(name()).arg(parm("Stride Factor"));
      uint channel = parm("Channel").toUInt(&ok);
      if(!ok)
        throw QString("%1::run Invalid channel: %2").arg(name()).arg(parm("Channel"));

      bool result = run(stack, input, output, modelPath, stringToPoint3u(parm("Patch Size")), strideFactor,
          stringToPoint3u(parm("Patch Stride")), stringToPoint3d(parm("Resample")),
          channel, stringToBool(parm("Standardize")), stringToBool(parm("Use Cuda")));
      if(result) {
        output->show();
        output->setLabels(false);
      }
      return result;
    }

    /// Defined out of line (implementation not visible in this header).
    bool initialize(QWidget *parent);

    /**
     * Run the prediction with explicit arguments (implementation is out of line).
     *
     * \param stack        Stack to process.
     * \param input        Store whose image is fed to the network.
     * \param output       Store that receives the result.
     * \param modelPath    Path to the network model.
     * \param patchSize    Maximum patch size for processing.
     * \param strideFactor Stride for patches; determines overlap.
     * \param patchStride  If not 0, specifies the patch stride directly.
     * \param resample     Resampling (x, y, z) applied before feeding the network.
     * \param channel      Network output channel to return.
     * \param standardize  Standardize the image (required for PlantSeg networks).
     * \param useCuda      Use CUDA if available.
     * \return true on success; the parameterless run() only shows the output
     *         store in that case.
     */
    bool run(Stack *stack, Store* input, Store *output, QString modelPath,
        Point3u patchSize = Point3u(128,128,128), double strideFactor = 0.5, Point3u patchStride = Point3u(0,0,0),
        Point3d resample = Point3d(1.0,1.0,1.0), uint channel = 2, bool standardize = false, bool useCuda = true);

  protected:
    /// Model path last read from the "Model Path" parameter by run().
    QString modelPath;
  };
  ///@}
}
#endif
