/*
  This file is part of CDO. CDO is a collection of Operators to
  manipulate and analyse Climate model Data.

  Copyright (C) 2003-2020 Uwe Schulzweida, <uwe.schulzweida AT mpimet.mpg.de>
  See COPYING file for copying and redistribution conditions.

  This program is free software; you can redistribute it and/or modify
  it under the terms of the GNU General Public License as published by
  the Free Software Foundation; version 2 of the License.

  This program is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  GNU General Public License for more details.
*/

/*
   This module contains the following operators:

        Timstat3        varquot2test
        Timstat3        meandiff2test
*/

#include <cdi.h>

#include "functs.h"
#include "process_int.h"
#include "cdo_vlist.h"
#include "param_conversion.h"
#include "statistic.h"

// Compile-time sizes of the per-operator work arrays used by Timstat3.
constexpr int NIN = 2;     // number of input streams / input fields
constexpr int NOUT = 1;    // number of output fields
constexpr int NFWORK = 4;  // number of floating-point work arrays (fwork)
constexpr int NIWORK = 2;  // number of integer work arrays (iwork)

/*
   Timstat3 - entry point for the operators varquot2test and meandiff2test.

   Both operators take two operator arguments (a constant and a risk) and two
   input streams whose variable lists must compare equal (CMP_ALL).  For every
   variable, level and grid point the function accumulates over all timesteps
   of each input stream

     fwork[2 * is + 0]  the sum of the values of stream `is`
     fwork[2 * is + 1]  the sum of the squared values of stream `is`
     iwork[is]          the number of accumulated values of stream `is`

   and afterwards writes exactly one timestep to the output stream.  Each
   output value is 1 when the test statistic reaches the computed fractile(s),
   0 when it does not, and the missing value when the statistic (or, for
   meandiff2test, the fractile) could not be computed.

   process: opaque handle that is forwarded to cdoInitialize().
   returns: always nullptr.
*/
void *
Timstat3(void *process)
{
  CdoStreamID streamID[NIN];
  int vlistID[NIN], vlistID2 = -1;
  int64_t vdate = 0;
  int vtime = 0;
  int is;  // function scope: its value is inspected after the EOF scan loop below
  Varray3D<int> iwork[NIWORK];
  FieldVector2D fwork[NFWORK];
  int reached_eof[NIN];
  constexpr int n_in = NIN;

  cdoInitialize(process);

  // clang-format off
  const auto VARQUOT2TEST  = cdoOperatorAdd("varquot2test",  0, 0, nullptr);
  const auto MEANDIFF2TEST = cdoOperatorAdd("meandiff2test", 0, 0, nullptr);
  // clang-format on

  const auto operatorID = cdoOperatorID();

  operatorInputArg("constant and risk (e.g. 0.05)");
  operatorCheckArgc(2);
  const auto rconst = parameter2double(cdoOperatorArgv(0));
  const auto risk = parameter2double(cdoOperatorArgv(1));

  // NOTE(review): the range checks are applied for varquot2test only, although
  // meandiff2test also uses `risk` - confirm that this is intended.
  if (operatorID == VARQUOT2TEST)
    {
      if (rconst <= 0) cdoAbort("Constant must be positive!");

      if (risk <= 0 || risk >= 1) cdoAbort("Risk must be greater than 0 and lower than 1!");
    }

  // Open all input streams and compare every further variable list with the first one.
  for (is = 0; is < NIN; ++is)
    {
      streamID[is] = cdoOpenRead(is);
      vlistID[is] = cdoStreamInqVlist(streamID[is]);
      if (is > 0)
        {
          vlistID2 = cdoStreamInqVlist(streamID[is]);
          vlistCompare(vlistID[0], vlistID2, CMP_ALL);
        }
    }

  // The output inherits its variable list and time axis from the first input.
  const auto vlistID3 = vlistDuplicate(vlistID[0]);

  const auto gridsizemax = vlistGridsizeMax(vlistID[0]);
  const auto nvars = vlistNvars(vlistID[0]);

  const auto maxrecs = vlistNrecs(vlistID[0]);
  std::vector<RecordInfo> recList(maxrecs);

  const auto taxisID1 = vlistInqTaxis(vlistID[0]);
  const auto taxisID3 = taxisDuplicate(taxisID1);

  vlistDefTaxis(vlistID3, taxisID3);
  const auto streamID3 = cdoOpenWrite(2);
  cdoDefVlist(streamID3, vlistID3);

  for (int i = 0; i < NIN; ++i) reached_eof[i] = 0;

  Field in[NIN], out[NOUT];
  for (int i = 0; i < NIN; ++i) in[i].resize(gridsizemax);
  for (int i = 0; i < NOUT; ++i) out[i].resize(gridsizemax);

  for (int iw = 0; iw < NFWORK; ++iw) fwork[iw].resize(nvars);
  for (int iw = 0; iw < NIWORK; ++iw) iwork[iw].resize(nvars);

  // Allocate the accumulators for every variable and level.
  for (int varID = 0; varID < nvars; ++varID)
    {
      const auto gridID = vlistInqVarGrid(vlistID[0], varID);
      // Sized with the maximum grid size because the evaluation loops below run up to gridsizemax.
      const auto gridsize = vlistGridsizeMax(vlistID[0]);
      const auto nlevels = zaxisInqSize(vlistInqVarZaxis(vlistID[0], varID));
      const auto missval = vlistInqVarMissval(vlistID[0], varID);

      for (int iw = 0; iw < NFWORK; ++iw) fwork[iw][varID].resize(nlevels);
      for (int iw = 0; iw < NIWORK; ++iw) iwork[iw][varID].resize(nlevels);

      for (int levelID = 0; levelID < nlevels; ++levelID)
        {
          for (int iw = 0; iw < NFWORK; ++iw)
            {
              fwork[iw][varID][levelID].grid = gridID;
              fwork[iw][varID][levelID].missval = missval;
              fwork[iw][varID][levelID].resize(gridsize);
            }

          for (int iw = 0; iw < NIWORK; ++iw) iwork[iw][varID][levelID].resize(gridsize, 0);
        }
    }

  // Accumulation phase: read timestep tsID from every stream that has not yet reached its end.
  int tsID = 0;
  while (true)
    {
      for (is = 0; is < NIN; ++is)
        {
          if (reached_eof[is]) continue;

          const auto nrecs = cdoStreamInqTimestep(streamID[is], tsID);
          if (nrecs == 0)
            {
              reached_eof[is] = 1;
              continue;
            }

          // Remember the most recently read date/time; it is written to the output timestep.
          vdate = taxisInqVdate(taxisID1);
          vtime = taxisInqVtime(taxisID1);

          for (int recID = 0; recID < nrecs; recID++)
            {
              int varID, levelID;
              cdoInqRecord(streamID[is], &varID, &levelID);

              const auto gridsize = gridInqSize(vlistInqVarGrid(vlistID[is], varID));

              in[is].missval = vlistInqVarMissval(vlistID[is], varID);

              // The record order of the first timestep of the first stream defines the output order.
              if (tsID == 0 && is == 0)
                {
                  recList[recID].varID = varID;
                  recList[recID].levelID = levelID;
                  recList[recID].lconst = vlistInqVarTimetype(vlistID[0], varID) == TIME_CONSTANT;
                }

              cdoReadRecord(streamID[is], in[is].vec_d.data(), &in[is].nmiss);

              for (size_t i = 0; i < gridsize; ++i)
                {
                  // Every value is accumulated; the missing-value test below is disabled.
                  // if ( ( ! DBL_IS_EQUAL(array1[i], missval1) ) && ( ! DBL_IS_EQUAL(array2[i], missval2) ) )
                  {
                    // NIN * is selects the accumulator pair of stream `is`
                    // (this matches the 2 * j indexing used in the evaluation because NIN == 2).
                    fwork[NIN * is + 0][varID][levelID].vec_d[i] += in[is].vec_d[i];
                    fwork[NIN * is + 1][varID][levelID].vec_d[i] += in[is].vec_d[i] * in[is].vec_d[i];
                    iwork[is][varID][levelID][i]++;
                  }
                }
            }
        }

      // Stop as soon as every input stream has reached its end.
      for (is = 0; is < NIN; ++is)
        if (!reached_eof[is]) break;

      if (is == NIN) break;

      tsID++;
    }

  taxisDefVdate(taxisID3, vdate);
  taxisDefVtime(taxisID3, vtime);
  cdoDefTimestep(streamID3, 0);

  // Evaluation phase: one output record per record of the first input timestep.
  for (int recID = 0; recID < maxrecs; recID++)
    {
      const auto varID = recList[recID].varID;
      const auto levelID = recList[recID].levelID;

      const auto missval1 = fwork[0][varID][levelID].missval;
      // Not referenced by name below; presumably required by the *MN macros - confirm in functs.h.
      const auto missval2 = missval1;

      if (operatorID == VARQUOT2TEST)
        {
          for (size_t i = 0; i < gridsizemax; ++i)
            {
              // The counts are stored as int. They are converted to double here so that
              // (n - 1) / 2 below is a floating-point division and is not truncated.
              const double fnvals0 = iwork[0][varID][levelID][i];
              const double fnvals1 = iwork[1][varID][levelID][i];

              // temp2 / temp3: sum of squares minus (sum * sum) / n for stream 0 / stream 1.
              const auto temp0 = DIVMN(MULMN(fwork[0][varID][levelID].vec_d[i], fwork[0][varID][levelID].vec_d[i]), fnvals0);
              const auto temp1 = DIVMN(MULMN(fwork[2][varID][levelID].vec_d[i], fwork[2][varID][levelID].vec_d[i]), fnvals1);
              const auto temp2 = SUBMN(fwork[1][varID][levelID].vec_d[i], temp0);
              const auto temp3 = SUBMN(fwork[3][varID][levelID].vec_d[i], temp1);
              const auto statistic = DIVMN(temp2, ADDMN(temp2, MULMN(rconst, temp3)));

              // The fractiles keep the missing value unless both streams contributed more than one value.
              double fractil_1 = missval1, fractil_2 = missval1;
              if (fnvals0 > 1 && fnvals1 > 1)
                cdo::beta_distr_constants((fnvals0 - 1) / 2, (fnvals1 - 1) / 2, 1 - risk, &fractil_1, &fractil_2, __func__);

              out[0].vec_d[i]
                = DBL_IS_EQUAL(statistic, missval1) ? missval1 : (statistic <= fractil_1 || statistic >= fractil_2) ? 1.0 : 0.0;
            }
        }
      else if (operatorID == MEANDIFF2TEST)
        {
          double mean_factor[NIN], var_factor[NIN];

          // The mean estimator is mean(stream 0) - mean(stream 1) - rconst; both variance factors are 1.
          mean_factor[0] = 1.0;
          mean_factor[1] = -1.0;
          var_factor[0] = var_factor[1] = 1.0;

          for (size_t i = 0; i < gridsizemax; ++i)
            {
              // temp0: sum over both streams of (sum of squares - (sum * sum) / n) / var_factor.
              double temp0 = 0.0;
              double deg_of_freedom = -n_in;
              for (int j = 0; j < n_in; j++)
                {
                  const double fnvals = iwork[j][varID][levelID][i];
                  auto tmp = DIVMN(MULMN(fwork[2 * j][varID][levelID].vec_d[i], fwork[2 * j][varID][levelID].vec_d[i]), fnvals);
                  temp0 = ADDMN(temp0, DIVMN(SUBMN(fwork[2 * j + 1][varID][levelID].vec_d[i], tmp), var_factor[j]));
                  deg_of_freedom = ADDMN(deg_of_freedom, fnvals);
                }

              if (!DBL_IS_EQUAL(temp0, missval1) && temp0 < 0)  // This is possible because of rounding errors
                temp0 = 0;

              const auto stddev_estimator = SQRTMN(DIVMN(temp0, deg_of_freedom));
              auto mean_estimator = -rconst;
              for (int j = 0; j < n_in; j++)
                {
                  const double fnvals = iwork[j][varID][levelID][i];
                  mean_estimator = ADDMN(mean_estimator, MULMN(mean_factor[j], DIVMN(fwork[2 * j][varID][levelID].vec_d[i], fnvals)));
                }

              // temp1: sum over both streams of mean_factor^2 * var_factor / n.
              double temp1 = 0.0;
              for (int j = 0; j < n_in; j++)
                {
                  const double fnvals = iwork[j][varID][levelID][i];
                  temp1 = ADDMN(temp1, DIVMN(MUL(MUL(mean_factor[j], mean_factor[j]), var_factor[j]), fnvals));
                }

              const auto norm = SQRTMN(temp1);

              const auto temp2 = DIVMN(DIVMN(mean_estimator, norm), stddev_estimator);
              const auto fractil = (deg_of_freedom < 1) ? missval1 : cdo::student_t_inv(deg_of_freedom, 1 - risk / 2, __func__);

              // The comparison result (bool) is converted to 1.0 / 0.0 by the conditional operator.
              out[0].vec_d[i]
                = DBL_IS_EQUAL(temp2, missval1) || DBL_IS_EQUAL(fractil, missval1) ? missval1 : std::fabs(temp2) >= fractil;
            }
        }

      cdoDefRecord(streamID3, varID, levelID);
      // NOTE(review): out[0].missval is never assigned in this function; confirm that
      // fieldNumMiss() counts against the intended missing value (missval1).
      cdoWriteRecord(streamID3, out[0].vec_d.data(), fieldNumMiss(out[0]));
    }

  cdoStreamClose(streamID3);
  for (is = 0; is < NIN; ++is) cdoStreamClose(streamID[is]);

  cdoFinish();

  return nullptr;
}
