// Copyright (C) 2010, Guy Barrand. All rights reserved.
// See the file tools.license for terms.

#ifndef toolx_mpi_wrmpi
#define toolx_mpi_wrmpi

#include "wait_buffer"

#include <tools/impi>
#include <tools/vdata>
#include <tools/realloc>
#include <tools/mnmx>

#ifdef TOOLS_MEM
#include <tools/mem>
#endif

// Select the MPI datatypes used below to pack/unpack tools::uint64 and
// tools::int64. The branches must be kept in sync with tools/typedefs ; the
// commented-out typedefs recall which C++ type is presumably chosen there for
// each platform, so that the MPI datatype matches it.
// Both macros are #undef'ed at the end of this header.
#if defined(_MSC_VER) || defined(__MINGW32__)
//typedef __int64 int64;
//typedef unsigned __int64 uint64;
#define TOOLX_MPI_UINT64 MPI_UNSIGNED_LONG_LONG
#define TOOLX_MPI_INT64  MPI_LONG_LONG
#elif defined(_LP64)
// 64 Bit Platforms
//typedef long int64;
//typedef unsigned long uint64;
#define TOOLX_MPI_UINT64 MPI_UNSIGNED_LONG
#define TOOLX_MPI_INT64  MPI_LONG
#else
// 32-Bit Platforms
//typedef long long int64;
//typedef unsigned long long uint64;
#define TOOLX_MPI_UINT64 MPI_UNSIGNED_LONG_LONG
#define TOOLX_MPI_INT64  MPI_LONG_LONG
#endif

namespace toolx {
namespace mpi {

class wrmpi : public virtual tools::impi {
  typedef tools::impi parent;
public:
  typedef unsigned int num_t;
protected:
  static const std::string& s_class() {
    static const std::string s_v("toolx::mpi::wrmpi");
    return s_v;
  }
public: //tools::impi
  virtual bool pack(char a_val) {
    tools::uint32 sz = tools::uint32(sizeof(char));
    if(m_pos+sz>m_max) {if(!expand2(m_size+sz)) return false;}
    if(::MPI_Pack(&a_val,1,MPI_CHAR,m_buffer,m_size,&m_ipos,m_comm)!=MPI_SUCCESS) {
      m_out << "toolx::mpi::wrmpi : MPI_Pack(char) failed." << std::endl;
      return false;
    }
    m_pos += sz;
    return true;
  }
  virtual bool pack(short a_val) {
    tools::uint32 sz = tools::uint32(sizeof(short));
    if(m_pos+sz>m_max) {if(!expand2(m_size+sz)) return false;}
    if(::MPI_Pack(&a_val,1,MPI_SHORT,m_buffer,m_size,&m_ipos,m_comm)!=MPI_SUCCESS) {
      m_out << "toolx::mpi::wrmpi : MPI_Pack(short) failed." << std::endl;
      return false;
    }
    m_pos += sz;
    return true;
  }
  virtual bool pack(int a_val) {
    tools::uint32 sz = tools::uint32(sizeof(int));
    if(m_pos+sz>m_max) {if(!expand2(m_size+sz)) return false;}
    if(::MPI_Pack(&a_val,1,MPI_INT,m_buffer,m_size,&m_ipos,m_comm)!=MPI_SUCCESS) {
      m_out << "toolx::mpi::wrmpi : MPI_Pack(int) failed." << std::endl;
      return false;
    }
    m_pos += sz;
    return true;
  }
  virtual bool pack(unsigned int a_val) {
    tools::uint32 sz = tools::uint32(sizeof(unsigned int));
    if(m_pos+sz>m_max) {if(!expand2(m_size+sz)) return false;}
    if(::MPI_Pack(&a_val,1,MPI_UNSIGNED,m_buffer,m_size,&m_ipos,m_comm)!=MPI_SUCCESS) {
      m_out << "toolx::mpi::wrmpi : MPI_Pack(unsigned int) failed." << std::endl;
      return false;
    }
    m_pos += sz;
    return true;
  }
  virtual bool pack(tools::uint64 a_val) {
    tools::uint32 sz = tools::uint32(sizeof(tools::uint64));
    if(m_pos+sz>m_max) {if(!expand2(m_size+sz)) return false;}
    if(::MPI_Pack(&a_val,1,TOOLX_MPI_UINT64,m_buffer,m_size,&m_ipos,m_comm)!=MPI_SUCCESS) {
      m_out << "toolx::mpi::wrmpi : MPI_Pack(uint64) failed." << std::endl;
      return false;
    }
    m_pos += sz;
    return true;
  }
  virtual bool pack(tools::int64 a_val) {
    tools::uint32 sz = tools::uint32(sizeof(tools::int64));
    if(m_pos+sz>m_max) {if(!expand2(m_size+sz)) return false;}
    if(::MPI_Pack(&a_val,1,TOOLX_MPI_INT64,m_buffer,m_size,&m_ipos,m_comm)!=MPI_SUCCESS) {
      m_out << "toolx::mpi::wrmpi : MPI_Pack(int64) failed." << std::endl;
      return false;
    }
    m_pos += sz;
    return true;
  }
  virtual bool pack(float a_val) {
    tools::uint32 sz = tools::uint32(sizeof(float));
    if(m_pos+sz>m_max) {if(!expand2(m_size+sz)) return false;}
    if(::MPI_Pack(&a_val,1,MPI_FLOAT,m_buffer,m_size,&m_ipos,m_comm)!=MPI_SUCCESS) {
      m_out << "toolx::mpi::wrmpi : MPI_Pack(float) failed." << std::endl;
      return false;
    }
    m_pos += sz;
    return true;
  }
  virtual bool pack(double a_val) {
    tools::uint32 sz = tools::uint32(sizeof(double));
    if(m_pos+sz>m_max) {if(!expand2(m_size+sz)) return false;}
    if(::MPI_Pack(&a_val,1,MPI_DOUBLE,m_buffer,m_size,&m_ipos,m_comm)!=MPI_SUCCESS) {
      m_out << "toolx::mpi::wrmpi : MPI_Pack(double) failed." << std::endl;
      return false;
    }
    m_pos += sz;
    return true;
  }
  virtual bool bpack(bool a_val) {
    tools::uint32 sz = tools::uint32(sizeof(unsigned char));
    if(m_pos+sz>m_max) {if(!expand2(m_size+sz)) return false;}
    unsigned char val = a_val?1:0;
    if(::MPI_Pack(&val,1,MPI_UNSIGNED_CHAR,m_buffer,m_size,&m_ipos,m_comm)!=MPI_SUCCESS) {
      m_out << "toolx::mpi::wrmpi : MPI_Pack(bool) failed." << std::endl;
      return false;
    }
    m_pos += sz;
    return true;
  }
  virtual bool spack(const std::string& a_s) {
    if(!pack((num_t)a_s.size())) return false;
    tools::uint32 sz = (tools::uint32)a_s.size();
    if((m_pos+sz)>m_max) {if(!expand2(m_size+sz)) return false;}
#if defined(TOOLX_USE_MPI_PACK_NOT_CONST) || defined(TOOLS_USE_MPI_PACK_NOT_CONST)
    if(::MPI_Pack(const_cast<char*>(a_s.c_str()),a_s.size(),MPI_CHAR,m_buffer,m_size,&m_ipos,m_comm)!=MPI_SUCCESS) {
#else
    if(::MPI_Pack(a_s.c_str(),a_s.size(),MPI_CHAR,m_buffer,m_size,&m_ipos,m_comm)!=MPI_SUCCESS) {
#endif
      m_out << "toolx::mpi::wrmpi : MPI_Pack(std::string) failed." << std::endl;
      return false;
    }
    m_pos += sz;
    return true;
  }
  virtual bool vpack(const std::vector<unsigned int>& a_v) {
    if(!pack((num_t)a_v.size())) return false;
    tools::uint32 sz = (tools::uint32)(a_v.size()*sizeof(unsigned int));
    if((m_pos+sz)>m_max) {if(!expand2(m_size+sz)) return false;}
#if defined(TOOLX_USE_MPI_PACK_NOT_CONST) || defined(TOOLS_USE_MPI_PACK_NOT_CONST)
    if(::MPI_Pack(const_cast<unsigned int*>(tools::vec_data(a_v)),a_v.size(),
                  MPI_UNSIGNED,m_buffer,m_size,&m_ipos,m_comm)!=MPI_SUCCESS) {
#else
    if(::MPI_Pack(tools::vec_data(a_v),a_v.size(),MPI_UNSIGNED,m_buffer,m_size,&m_ipos,m_comm)!=MPI_SUCCESS) {
#endif
      m_out << "toolx::mpi::wrmpi : MPI_Pack(std::vector<unsigned int>) failed." << std::endl;
      return false;
    }
    m_pos += sz;
    return true;
  }
  virtual bool vpack(const std::vector<int>& a_v) {
    if(!pack((num_t)a_v.size())) return false;
    tools::uint32 sz = (tools::uint32)(a_v.size()*sizeof(int));
    if((m_pos+sz)>m_max) {if(!expand2(m_size+sz)) return false;}
#if defined(TOOLX_USE_MPI_PACK_NOT_CONST) || defined(TOOLS_USE_MPI_PACK_NOT_CONST)
    if(::MPI_Pack(const_cast<int*>(tools::vec_data(a_v)),a_v.size(),MPI_INT,m_buffer,m_size,&m_ipos,m_comm)!=MPI_SUCCESS) {
#else
    if(::MPI_Pack(tools::vec_data(a_v),a_v.size(),MPI_INT,m_buffer,m_size,&m_ipos,m_comm)!=MPI_SUCCESS) {
#endif
      m_out << "toolx::mpi::wrmpi : MPI_Pack(std::vector<int>) failed." << std::endl;
      return false;
    }
    m_pos += sz;
    return true;
  }
  virtual bool vpack(const std::vector<double>& a_v) {
    if(!pack((num_t)a_v.size())) return false;
    tools::uint32 sz = (tools::uint32)(a_v.size()*sizeof(double));
    if((m_pos+sz)>m_max) {if(!expand2(m_size+sz)) return false;}
#if defined(TOOLX_USE_MPI_PACK_NOT_CONST) || defined(TOOLS_USE_MPI_PACK_NOT_CONST)
    if(::MPI_Pack(const_cast<double*>(tools::vec_data(a_v)),a_v.size(),MPI_DOUBLE,m_buffer,m_size,&m_ipos,m_comm)!=MPI_SUCCESS) {
#else
    if(::MPI_Pack(tools::vec_data(a_v),a_v.size(),MPI_DOUBLE,m_buffer,m_size,&m_ipos,m_comm)!=MPI_SUCCESS) {
#endif
      m_out << "toolx::mpi::wrmpi : MPI_Pack(std::vector<double>) failed." << std::endl;
      return false;
    }
    m_pos += sz;
    return true;
  }
  virtual bool pack(tools::uint32 a_size,const char* a_buffer) {
    if(!pack((num_t)a_size)) return false;
    tools::uint32 sz = (tools::uint32)(a_size*sizeof(char));
    if((m_pos+sz)>m_max) {if(!expand2(m_size+sz)) return false;}
#if defined(TOOLX_USE_MPI_PACK_NOT_CONST) || defined(TOOLS_USE_MPI_PACK_NOT_CONST)
    if(::MPI_Pack(const_cast<char*>(a_buffer),a_size,MPI_CHAR,m_buffer,m_size,&m_ipos,m_comm)!=MPI_SUCCESS) {
#else
    if(::MPI_Pack(a_buffer,a_size,MPI_CHAR,m_buffer,m_size,&m_ipos,m_comm)!=MPI_SUCCESS) {
#endif
      m_out << "toolx::mpi::wrmpi : MPI_Pack(char*) failed." << std::endl;
      return false;
    }
    m_pos += sz;
    return true;
  }

  virtual bool pack(tools::uint32 a_size,const int* a_buffer) {
    if(!pack((num_t)a_size)) return false;
    tools::uint32 sz = (tools::uint32)(a_size*sizeof(int));
    if((m_pos+sz)>m_max) {if(!expand2(m_size+sz)) return false;}
#if defined(TOOLX_USE_MPI_PACK_NOT_CONST) || defined(TOOLS_USE_MPI_PACK_NOT_CONST)
    if(::MPI_Pack(const_cast<int*>(a_buffer),a_size,MPI_INT,m_buffer,m_size,&m_ipos,m_comm)!=MPI_SUCCESS) {
#else
    if(::MPI_Pack(a_buffer,a_size,MPI_INT,m_buffer,m_size,&m_ipos,m_comm)!=MPI_SUCCESS) {
#endif
      m_out << "toolx::mpi::wrmpi : MPI_Pack(int*) failed." << std::endl;
      return false;
    }
    m_pos += sz;
    return true;
  }
public: //tools::impi
  virtual bool unpack(char& a_val) {
    if(::MPI_Unpack(m_buffer,m_size,&m_ipos,&a_val,1,MPI_CHAR,m_comm)!=MPI_SUCCESS) {
      m_out << "toolx::mpi::wrmpi : MPI_Unpack(char) failed." << std::endl;
      a_val = 0;
      return false;
    }
    return true;
  }
  virtual bool unpack(short& a_val) {
    if(::MPI_Unpack(m_buffer,m_size,&m_ipos,&a_val,1,MPI_SHORT,m_comm)!=MPI_SUCCESS) {
      m_out << "toolx::mpi::wrmpi : MPI_Unpack(short) failed." << std::endl;
      a_val = 0;
      return false;
    }
    return true;
  }
  virtual bool unpack(int& a_val) {
    if(::MPI_Unpack(m_buffer,m_size,&m_ipos,&a_val,1,MPI_INT,m_comm)!=MPI_SUCCESS) {
      m_out << "toolx::mpi::wrmpi : MPI_Unpack(int) failed." << std::endl;
      a_val = 0;
      return false;
    }
    return true;
  }
  virtual bool unpack(unsigned int& a_val) {
    if(::MPI_Unpack(m_buffer,m_size,&m_ipos,&a_val,1,MPI_UNSIGNED,m_comm)!=MPI_SUCCESS) {
      m_out << "toolx::mpi::wrmpi : MPI_Unpack(unsigned int) failed." << std::endl;
      a_val = 0;
      return false;
    }
    return true;
  }
  virtual bool unpack(tools::uint64& a_val) {
    if(::MPI_Unpack(m_buffer,m_size,&m_ipos,&a_val,1,TOOLX_MPI_UINT64,m_comm)!=MPI_SUCCESS) {
      m_out << "toolx::mpi::wrmpi : MPI_Unpack(uint64) failed." << std::endl;
      a_val = 0;
      return false;
    }
    return true;
  }
  virtual bool unpack(tools::int64& a_val) {
    if(::MPI_Unpack(m_buffer,m_size,&m_ipos,&a_val,1,TOOLX_MPI_INT64,m_comm)!=MPI_SUCCESS) {
      m_out << "toolx::mpi::wrmpi : MPI_Unpack(int64) failed." << std::endl;
      a_val = 0;
      return false;
    }
    return true;
  }
  virtual bool unpack(float& a_val) {
    if(::MPI_Unpack(m_buffer,m_size,&m_ipos,&a_val,1,MPI_FLOAT,m_comm)!=MPI_SUCCESS) {
      m_out << "toolx::mpi::wrmpi : MPI_Unpack(float) failed." << std::endl;
      a_val = 0;
      return false;
    }
    return true;
  }
  virtual bool unpack(double& a_val) {
    if(::MPI_Unpack(m_buffer,m_size,&m_ipos,&a_val,1,MPI_DOUBLE,m_comm)!=MPI_SUCCESS) {
      m_out << "toolx::mpi::wrmpi : MPI_Unpack(double) failed." << std::endl;
      a_val = 0;
      return false;
    }
    return true;
  }
  virtual bool bunpack(bool& a_val) {
    typedef unsigned char bool_t;
    bool_t val;
    if(::MPI_Unpack(m_buffer,m_size,&m_ipos,&val,1,MPI_UNSIGNED_CHAR,m_comm)!=MPI_SUCCESS) {
      m_out << "toolx::mpi::wrmpi : MPI_Unpack(bool) failed." << std::endl;
      a_val = false;
      return false;
    }
    a_val = val==1?true:false;
    return true;
  }
  virtual bool vunpack(std::vector<unsigned int>& a_v) {
    num_t num;
    if(!unpack(num)) {a_v.clear();return false;}
    a_v.resize(num);
    if(::MPI_Unpack(m_buffer,m_size,&m_ipos,tools::vec_data(a_v),a_v.size(),MPI_UNSIGNED,m_comm)!=MPI_SUCCESS) {
      m_out << "toolx::mpi::wrmpi : MPI_Unpack(std::vector<unsigned int>) failed." << std::endl;
      a_v.clear();
      return false;
    }
    return true;
  }
  virtual bool vunpack(std::vector<int>& a_v) {
    num_t num;
    if(!unpack(num)) {a_v.clear();return false;}
    a_v.resize(num);
    if(::MPI_Unpack(m_buffer,m_size,&m_ipos,tools::vec_data(a_v),a_v.size(),MPI_INT,m_comm)!=MPI_SUCCESS) {
      m_out << "toolx::mpi::wrmpi : MPI_Unpack(std::vector<int>) failed." << std::endl;
      a_v.clear();
      return false;
    }
    return true;
  }
  virtual bool vunpack(std::vector<double>& a_v) {
    num_t num;
    if(!unpack(num)) {a_v.clear();return false;}
    a_v.resize(num);
    if(::MPI_Unpack(m_buffer,m_size,&m_ipos,tools::vec_data(a_v),a_v.size(),MPI_DOUBLE,m_comm)!=MPI_SUCCESS) {
      m_out << "toolx::mpi::wrmpi : MPI_Unpack(std::vector<double>) failed." << std::endl;
      a_v.clear();
      return false;
    }
    return true;
  }
  virtual bool sunpack(std::string& a_s) {
    num_t num;
    if(!unpack(num)) {a_s.clear();return false;}
    a_s.resize(num);
    if(::MPI_Unpack(m_buffer,m_size,&m_ipos,const_cast<char*>(a_s.c_str()),a_s.size(),MPI_CHAR,m_comm)!=MPI_SUCCESS) {
      m_out << "toolx::mpi::wrmpi : MPI_Unpack(std::string) failed." << std::endl;
      a_s.clear();
      return false;
    }
    return true;
  }
  virtual bool unpack(tools::uint32& a_size,char*& a_buffer) {
    num_t num;
    if(!unpack(num)) {a_size = 0;a_buffer = 0;return false;}
    a_buffer = new char[num];
    if(::MPI_Unpack(m_buffer,m_size,&m_ipos,a_buffer,num,MPI_CHAR,m_comm)!=MPI_SUCCESS) {
      m_out << "toolx::mpi::wrmpi : MPI_Unpack(char*) failed." << std::endl;
      delete [] a_buffer;
      a_size = 0;
      a_buffer = 0;
      return false;
    }
    a_size = num;
    return true;
  }
  virtual bool unpack(tools::uint32& a_size,int*& a_buffer) {
    num_t num;
    if(!unpack(num)) {a_size = 0;a_buffer = 0;return false;}
    a_buffer = new int[num];
    if(::MPI_Unpack(m_buffer,m_size,&m_ipos,a_buffer,num,MPI_INT,m_comm)!=MPI_SUCCESS) {
      m_out << "toolx::mpi::wrmpi : MPI_Unpack(int*) failed." << std::endl;
      delete [] a_buffer;
      a_size = 0;
      a_buffer = 0;
      return false;
    }
    a_size = num;
    return true;
  }

  virtual void pack_reset() {
    delete [] m_buffer;
    m_size = 128;
    m_buffer = new char[m_size];
    //if(!m_buffer) {}
    m_max = m_buffer+m_size;
    m_pos = m_buffer;
    m_ipos = 0; //IMPORTANT
  }

  virtual bool send_buffer(int a_dest,int a_tag) { // used in tools/mpi/pntuple.
    if(::MPI_Send(m_buffer,m_ipos,MPI_CHAR,a_dest,a_tag,m_comm)!=MPI_SUCCESS) {
      m_out << "toolx::mpi::wrmpi::send_buffer : MPI_Send() failed for rank destination " << a_dest << "." << std::endl;
      return false;
    }
    return true;
  }

  virtual bool wait_buffer(int a_rank,int a_src,int a_tag,int& a_probe_src,bool a_verbose = false) {
    int buffer_size = 0;
    char* _buffer = 0;
    if (!mpi::wait_buffer(m_out,a_rank,a_src,a_tag,m_comm,buffer_size,_buffer,a_probe_src,a_verbose)) {
      m_out << "toolx::mpi::wrmpi::wait_buffer : failed for rank " << a_rank << " and source " << a_src <<"." << std::endl;
      return false;
    }
    // we take ownership of buffer :
    delete [] m_buffer;
    m_size = buffer_size;
    m_buffer = _buffer;
    m_max = m_buffer+m_size;
    m_pos = m_buffer;
    m_ipos = 0; //IMPORTANT
    return true;
  }

  virtual bool wait_buffer(int a_rank,int a_tag,int& a_probe_src,bool a_verbose = false) {
    return wait_buffer(a_rank,MPI_ANY_SOURCE,a_tag,a_probe_src,a_verbose);
  }

public:
  wrmpi(std::ostream& a_out,const MPI_Comm& a_comm,tools::uint32 a_size = 128) // we expect a_size!=0
  :m_out(a_out)
  ,m_comm(a_comm)
  ,m_size(0)
  ,m_buffer(0)
  ,m_max(0)
  ,m_pos(0)
  ,m_ipos(0)
  {
#ifdef TOOLS_MEM
    tools::mem::increment(s_class().c_str());
#endif
    m_size = a_size;
    m_buffer = new char[m_size];
    //if(!m_buffer) {}
    m_max = m_buffer+m_size;
    m_pos = m_buffer;
  }
   wrmpi(std::ostream& a_out,const MPI_Comm& a_comm,tools::uint32 a_size,char* a_buffer) //we take ownership of a_buffer.
  :m_out(a_out)
  ,m_comm(a_comm)
  ,m_size(a_size)
  ,m_buffer(a_buffer)
  ,m_max(0)
  ,m_pos(0)
  ,m_ipos(0)
  {
#ifdef TOOLS_MEM
    tools::mem::increment(s_class().c_str());
#endif
    m_max = m_buffer+m_size;
    m_pos = m_buffer;
  }
  virtual ~wrmpi(){
    delete [] m_buffer;
#ifdef TOOLS_MEM
    tools::mem::decrement(s_class().c_str());
#endif
  }
protected:
  wrmpi(const wrmpi& a_from)
  :parent(a_from)
  ,m_out(a_from.m_out)
  ,m_comm(a_from.m_comm)
  ,m_size(0)
  ,m_buffer(0)
  ,m_max(0)
  ,m_pos(0)
  ,m_ipos(0)
  {
#ifdef TOOLS_MEM
    tools::mem::increment(s_class().c_str());
#endif
  }
  wrmpi& operator=(const wrmpi&){return *this;}
public:
  int ipos() const {return m_ipos;}
#if defined(TOOLX_USE_MPI_PACK_NOT_CONST) || defined(TOOLS_USE_MPI_PACK_NOT_CONST)
  char* buffer() {return m_buffer;}
#else
  const char* buffer() const {return m_buffer;}
#endif

protected:
  bool expand2(tools::uint32 a_new_size) {return expand(tools::mx<tools::uint32>(2*m_size,a_new_size));} //CERN-ROOT logic.

  bool expand(tools::uint32 a_new_size) {
    tools::diff_pointer_t len = m_pos-m_buffer;
    if(!tools::realloc<char>(m_buffer,a_new_size,m_size)) {
      m_out << "toolx::mpi::wrmpi::expand :"
            << " can't realloc " << a_new_size << " bytes."
            << std::endl;
      m_size = 0;
      m_max = 0;
      m_pos = 0;
      //m_wb.set_eob(m_max);
      return false;
    }
    m_size = a_new_size;
    m_max = m_buffer + m_size;
    m_pos = m_buffer + len;
    return true;
  }

protected:
  std::ostream& m_out;
  const MPI_Comm& m_comm;
  tools::uint32 m_size;
  char* m_buffer;
  char* m_max;
  char* m_pos;
  //
  int m_ipos;
};

}}

// The 64 bits MPI datatype macros defined at the top of this header are only
// needed by the class above ; remove them so that they are not seen by the
// code including this header.
#ifdef TOOLX_MPI_UINT64
#undef TOOLX_MPI_UINT64
#endif

#ifdef TOOLX_MPI_INT64
#undef TOOLX_MPI_INT64
#endif

#endif
