#include <cmath>
#include <exception>
#include <iterator>
#include <list>
#include <random>
#include <string>

namespace armspp {

/// Exception type thrown by the sampler in this namespace.
///
/// Carries a human-readable message that is returned by what().
class exception : public std::exception {
public:
  /// @param message_ Description of the failure. A null pointer is treated
  ///                 as an empty message (constructing std::string from a
  ///                 null pointer would be undefined behaviour).
  explicit exception(const char* message_)
    : message(message_ != nullptr ? message_ : "") {}

  /// Three-argument overload; the second and third arguments are accepted
  /// but ignored, only the message is stored.
  /// NOTE(review): the unused arguments look like a file name and a line
  /// number -- confirm against callers.
  exception(const char* message_, const char*, int)
    : message(message_ != nullptr ? message_ : "") {}

  ~exception() noexcept override = default;

  /// @return The stored message; the pointer stays valid for the lifetime
  ///         of this exception object.
  const char* what() const noexcept override {
    return message.c_str();
  }

private:
  std::string message;
};

/// Sampler for a univariate log-density based on a piecewise-linear envelope
/// with squeeze/rejection tests and an optional Metropolis step.
///
/// The envelope is stored as a list of points that alternates between
/// "intersection" points (not on the log-density; the first and last of them
/// are the lower and upper bounds) and points at which the log-density has
/// been evaluated. Points that fail the squeeze test are evaluated and added
/// to the envelope until `maxPoints` is reached.
///
/// @tparam Scalar   Floating-point type used for abscissae and log-densities.
/// @tparam Functor  Callable as `Scalar logPdf(Scalar x)`.
/// @tparam Iterator Iterator over the initial abscissae; it is indexed with
///                  `operator[]`, dereferenced and incremented.
template<typename Scalar, typename Functor, typename Iterator>
class ARMS {
public:
  /// Builds the initial envelope from `nInitial` abscissae.
  ///
  /// @param logPdf     Log-density. Stored by reference, so it must outlive
  ///                   the sampler.
  /// @param lower      Lower bound of the support.
  /// @param upper      Upper bound of the support.
  /// @param convex     Non-negative adjustment applied to chord gradients
  ///                   when the envelope is violated in Metropolis mode.
  /// @param xInitial   Strictly increasing abscissae, all strictly inside
  ///                   (lower, upper).
  /// @param nInitial   Number of initial abscissae; at least 3.
  /// @param maxPoints  Maximum number of envelope points; must be at least
  ///                   `2 * nInitial + 1`.
  /// @param metropolis Whether to perform the Metropolis step.
  /// @param xPrevious  Previous value of the chain; only used (and only
  ///                   validated) when `metropolis` is true, in which case it
  ///                   must lie in [lower, upper].
  /// @throws exception if any of the above constraints is violated, or if
  ///         the initial envelope cannot be constructed (e.g. "Envelope
  ///         violation" when `metropolis` is false).
  ARMS(
    Functor& logPdf,
    Scalar lower,
    Scalar upper,
    Scalar convex,
    Iterator xInitial,
    int nInitial,
    int maxPoints,
    bool metropolis,
    Scalar xPrevious
  ) : logPdf_(logPdf),
      lower_(lower),
      upper_(upper),
      convex_(convex),
      maxPoints_(maxPoints),
      metropolis_(metropolis),
      yMaximum_(0),
      xPrevious_(xPrevious),
      yPrevious_(0) {

    if (nInitial < 3) {
      throw exception("Too few initial points");
    }

    // Each initial abscissa contributes itself plus one intersection point;
    // the extra point is the second bound.
    int nEnvelopeInitial = 2 * nInitial + 1;

    if (nEnvelopeInitial > maxPoints_) {
      throw exception("Too many initial points");
    }

    if (xInitial[0] <= lower_ || xInitial[nInitial - 1] >= upper_) {
      throw exception("Initial points do not satisfy bounds");
    }

    for (int i = 1; i < nInitial; ++i) {
      if (xInitial[i] <= xInitial[i - 1]) {
        throw exception("Initial points are not ordered");
      }
    }

    if (convex_ < 0) {
      throw exception("Convexity parameter is negative");
    }

    // The Metropolis step locates the envelope piece containing xPrevious_
    // by walking the point list; a value outside the bounds (or NaN) would
    // walk past the end of the list or extrapolate the envelope.
    if (metropolis_ && !(xPrevious_ >= lower_ && xPrevious_ <= upper_)) {
      throw exception("Previous value does not satisfy bounds");
    }

    // Left bound
    points_.emplace_back(lower, 0, 0, 0, false);
    for (int i = 1; i < nEnvelopeInitial - 1; ++i) {
      if (i % 2 == 1) {
        // Point on log density
        points_.emplace_back(
          *xInitial, logPdf(*xInitial),
          0, 0, true
        );
        ++xInitial;
      } else {
        // Intersection point (coordinates are filled in below)
        points_.emplace_back(0, 0, 0, 0, false);
      }
    }
    // Right bound
    points_.emplace_back(upper, 0, 0, 0, false);

    // Bounds and intersection points sit at the even positions of the list.
    // Step two at a time, but never advance an iterator past end().
    for (auto q = points_.begin(); q != points_.end(); ) {
      updateIntersection_(q);
      ++q;  // evaluated point, or end() after the right bound
      if (q != points_.end()) {
        ++q;
      }
    }

    accumulate_();

    if (metropolis_) {
      yPrevious_ = logPdf_(xPrevious_);
    }
  }

  /// Draws one sample.
  ///
  /// Each iteration consumes two uniform variates from `rng` (three when the
  /// Metropolis step is reached). In Metropolis mode the internal "previous"
  /// state is updated when the proposal is accepted.
  ///
  /// @param rng Uniform random bit generator usable with
  ///            std::uniform_real_distribution.
  /// @return The sampled value.
  /// @throws exception if no value is accepted within MAX_ITERATIONS
  ///         iterations, or if updating the envelope fails.
  template<typename RNG>
  Scalar operator()(RNG& rng) {
    using std::prev;
    using std::next;

    for (int iteration = 0; iteration < MAX_ITERATIONS; ++iteration) {
      // Candidate drawn from the (normalised) exponentiated envelope
      WorkingPoint w = invert_(uniform_(rng));

      // Height for the rejection test: uniform below the envelope at w.p.x,
      // mapped back to the log scale
      Scalar u = uniform_(rng) * w.p.expY;
      Scalar y = logShift_(u);

      if (
        !metropolis_
        && w.left != points_.begin()
        && next(w.right) != points_.end()
      ) {
        // Squeeze test: compare against the chord through the nearest
        // evaluated points on either side, avoiding a log-density evaluation
        typename std::list<Point>::iterator ql, qr;

        if (w.left->isEvaluated) {
          ql = w.left;
        } else {
          ql = prev(w.left);
        }
        if (w.right->isEvaluated) {
          qr = w.right;
        } else {
          qr = next(w.right);
        }
        Scalar ySqueeze = (
          (qr->y * (w.p.x - ql->x) + ql->y * (qr->x - w.p.x))
          / (qr->x - ql->x)
        );
        if (y <= ySqueeze) {
          /* accept point at squeezing step */
          return w.p.x;
        }
      }

      Scalar yNew = logPdf_(w.p.x);

      /* perform rejection test */
      if (!metropolis_ || y >= yNew) {
        // Either not metropolis, or we will reject: in both cases the newly
        // evaluated point is used to refine the envelope
        w.p.y = yNew;
        w.p.expY = expShift_(w.p.y);
        w.p.isEvaluated = true;
        addPoint(w);
        if (y < w.p.y) {
          return w.p.x;
        }
      } else {
        // Metropolis step: find the envelope piece containing xPrevious_
        typename std::list<Point>::iterator qPrevious;
        qPrevious = points_.begin();
        while (next(qPrevious)->x < xPrevious_) ++qPrevious;
        // Height of the envelope at xPrevious_
        Scalar zPrevious = (
          qPrevious->y
          + (next(qPrevious)->y - qPrevious->y) * (xPrevious_ - qPrevious->x) / (next(qPrevious)->x - qPrevious->x)
        );
        Scalar zNew = w.p.y;
        if (yPrevious_ < zPrevious) zPrevious = yPrevious_;
        if (yNew < zNew) zNew = yNew;

        Scalar logRatio = yNew - zNew - yPrevious_ + zPrevious;
        Scalar ratio;
        if (logRatio > 0) {
          ratio = 1;
        } else if (logRatio > -Y_CEILING) {
          ratio = std::exp(logRatio);
        } else {
          // Treat very small ratios as zero instead of exponentiating them
          ratio = 0;
        }

        if (uniform_(rng) > ratio) {
          // Rejected by Metropolis step
          return xPrevious_;
        } else {
          // Accepted by Metropolis step; update internal state
          xPrevious_ = w.p.x;
          yPrevious_ = yNew;
          return w.p.x;
        }
      }
    }

    throw exception("Maximum number of iterations exceeded");
  }

  /// Quantile function of the current (normalised, exponentiated) envelope.
  ///
  /// @param probability Cumulative probability, as a fraction of the total
  ///                    area under the envelope.
  /// @return The abscissa at which the envelope reaches that probability.
  Scalar envelopeQuantile(Scalar probability) {
    return invert_(probability).p.x;
  }

private:
  /// One point of the envelope.
  struct Point {
    Point(
      Scalar x_,
      Scalar y_,
      Scalar expY_,
      Scalar cumulative_,
      bool isEvaluated_
    ) : x(x_),
        y(y_),
        expY(expY_),
        cumulative(cumulative_),
        isEvaluated(isEvaluated_) {}

    Scalar x;           // abscissa
    Scalar y;           // log-scale ordinate (log-density or envelope height)
    Scalar expY;        // expShift_(y)
    Scalar cumulative;  // area under the exponentiated envelope up to x
    bool isEvaluated;   // true if y is a log-density value at x
  };

  /// A candidate point together with the envelope piece that contains it.
  struct WorkingPoint {
    WorkingPoint(
      Point p_,
      typename std::list<Point>::iterator left_,
      typename std::list<Point>::iterator right_
    ) : p(p_),
        left(left_),
        right(right_) {}

    Point p;
    typename std::list<Point>::iterator left;   // envelope point left of p
    typename std::list<Point>::iterator right;  // envelope point right of p
  };

  Functor& logPdf_;
  Scalar lower_;
  Scalar upper_;
  Scalar convex_;
  int maxPoints_;
  bool metropolis_;

  std::uniform_real_distribution<Scalar> uniform_;

  // Internal state
  std::list<Point> points_;
  Scalar yMaximum_;   // largest y over points_, set by accumulate_()
  Scalar xPrevious_;  // previous chain value (Metropolis mode only)
  Scalar yPrevious_;  // log-density at xPrevious_ (Metropolis mode only)

  // Minimum relative distance of a new point from the ends of its interval
  const Scalar X_EPSILON = 0.00001;
  // Minimum dl/dr in updateIntersection_, and threshold on |dy| below which
  // a piece is integrated as a straight line
  const Scalar Y_EPSILON = 0.1;
  // Relative threshold on |d expY| below which a piece is inverted linearly
  const Scalar EXP_Y_EPSILON = 0.001;
  // Offset used by expShift_/logShift_, and cut-off for the Metropolis ratio
  const Scalar Y_CEILING = 50.;
  // Maximum number of proposals per call to operator()
  const int MAX_ITERATIONS = 10000;

  /// Inserts an evaluated point (plus one new intersection point) into the
  /// envelope, then recomputes the affected intersections and the cumulative
  /// areas. Does nothing if there is no room for two more points.
  ///
  /// @param w Evaluated candidate and the envelope piece containing it.
  /// @throws exception propagated from updateIntersection_.
  void addPoint(WorkingPoint w) {
    using std::prev;
    using std::next;

    if (points_.size() > static_cast<unsigned int>(maxPoints_ - 2)) {
      // No further room for points
      return;
    }

    typename std::list<Point>::iterator q, ql, qr;

    // The new point
    q = points_.insert(w.right, w.p);

    if (prev(q)->isEvaluated) {
      /* left end of piece is on log density; right end is not */
      /* set up new intersection */
      points_.insert(q, Point(0, 0, 0, 0, false));
    } else {
      /* left end of interval is not on log density; right end is */
      /* set up new intersection */
      points_.insert(next(q), Point(0, 0, 0, 0, false));
    }

    // q now has a non-evaluated point on each side. There is an evaluated
    // neighbour further out unless that adjacent point is a bound.
    const bool hasLeftNeighbour = prev(q) != points_.begin();
    const bool hasRightNeighbour = next(q, 2) != points_.end();

    /* now adjust position of q within interval if too close to an endpoint */
    ql = prev(q, hasLeftNeighbour ? 2 : 1);
    qr = next(q, hasRightNeighbour ? 2 : 1);

    if (q->x < (1. - X_EPSILON) * ql->x + X_EPSILON * qr->x){
      /* q too close to left end of interval */
      q->x = (1. - X_EPSILON) * ql->x + X_EPSILON * qr->x;
      q->y = logPdf_(q->x);
    } else if (q->x > X_EPSILON * ql->x + (1. - X_EPSILON) * qr->x){
      /* q too close to right end of interval */
      q->x = X_EPSILON * ql->x + (1. - X_EPSILON) * qr->x;
      q->y = logPdf_(q->x);
    }

    // Revise the intersections whose chords involve q. The outer ones only
    // exist when there is an evaluated neighbour on that side; guarding on
    // that avoids stepping an iterator before begin() or past end().
    updateIntersection_(prev(q));
    updateIntersection_(next(q));
    if (hasLeftNeighbour) {
      updateIntersection_(prev(q, 3));
    }
    if (hasRightNeighbour) {
      updateIntersection_(next(q, 3));
    }
    accumulate_();
  }

  /// Recomputes the coordinates of a bound or intersection point from the
  /// chords through the neighbouring evaluated points.
  ///
  /// @param q Iterator to a non-evaluated point of the envelope.
  /// @throws exception "Envelope violation" if a convexity violation is
  ///         detected while not in Metropolis mode; "arms error 31" if no
  ///         chord gradient is available; "intersection point outside
  ///         interval" if the result is not between its neighbours.
  void updateIntersection_(typename std::list<Point>::iterator q) {
    using std::prev;
    using std::next;

    Scalar gl = 0, gr = 0, grl = 0;
    bool il = false, ir = false, irl = false;
    if (q != points_.begin() && prev(q, 2) != points_.begin()) {
      /* chord gradient can be calculated at left end of interval */
      gl = (prev(q)->y - prev(q, 3)->y) / (prev(q)->x - prev(q, 3)->x);
      il = true;
    }
    if (next(q) != points_.end() && next(q, 3) != points_.end()) {
      /* chord gradient can be calculated at right end of interval */
      gr = (next(q)->y - next(q, 3)->y) / (next(q)->x - next(q, 3)->x);
      ir = true;
    }
    if (q != points_.begin() && next(q) != points_.end()) {
      /* chord gradient can be calculated across interval */
      grl = (next(q)->y - prev(q)->y) / (next(q)->x - prev(q)->x);
      irl = true;
    }

    if (irl && il && gl < grl) {
      /* convexity on left exceeds current threshold */
      if (!metropolis_) throw exception("Envelope violation");
      /* adjust left gradient */
      gl = gl + (1.0 + convex_) * (grl - gl);
    }

    if (irl && ir && gr > grl) {
      /* convexity on right exceeds current threshold */
      if (!metropolis_) throw exception("Envelope violation");
      /* adjust right gradient */
      gr = gr + (1.0 + convex_) * (grl - gr);
    }

    Scalar dl = 0, dr = 0;
    if (il && irl) {
      dr = (gl - grl) * (next(q)->x - prev(q)->x);
      if (dr < Y_EPSILON) {
        /* adjust dr to avoid numerical problems */
        dr = Y_EPSILON;
      }
    }
    if (ir && irl) {
      dl = (grl - gr) * (next(q)->x - prev(q)->x);
      if (dl < Y_EPSILON) {
        /* adjust dl to avoid numerical problems */
        dl = Y_EPSILON;
      }
    }

    if (il && ir && irl) {
      /* gradients on both sides */
      q->x = (dl * next(q)->x + dr * prev(q)->x) / (dl + dr);
      q->y = (dl * next(q)->y + dr * prev(q)->y + dl * dr) / (dl + dr);
    } else if (il && irl) {
      /* gradient only on left side, but not right hand bound */
      q->x = next(q)->x;
      q->y = next(q)->y + dr;
    } else if (ir && irl) {
      /* gradient only on right side, but not left hand bound */
      q->x = prev(q)->x;
      q->y = prev(q)->y + dl;
    } else if (il) {
      /* right hand bound */
      q->y = prev(q)->y + gl * (q->x - prev(q)->x);
    } else if (ir) {
      /* left hand bound */
      q->y = next(q)->y - gr * (next(q)->x - q->x);
    } else {
      /* gradient on neither side - should be impossible */
      throw exception("arms error 31");
    }
    if(((q != points_.begin()) && (q->x < prev(q)->x)) ||
       ((next(q) != points_.end()) && (q->x > next(q)->x))) {
      /* intersection point outside interval (through imprecision) */
      throw exception("intersection point outside interval");
    }
  }

  /// Recomputes yMaximum_, every point's expY and the cumulative area under
  /// the exponentiated envelope.
  void accumulate_() {
    using std::prev;

    auto q = points_.begin();

    // Update yMaximum_
    yMaximum_ = points_.begin()->y;
    for (; q != points_.end(); ++q) {
      if (q->y > yMaximum_) yMaximum_ = q->y;
    }

    // Update expY (depends on the yMaximum_ computed above)
    for (q = points_.begin(); q != points_.end(); ++q) {
      q->expY = expShift_(q->y);
    }

    /* integrate exponentiated envelope */
    q = points_.begin();
    q->cumulative = 0;
    for(++q; q != points_.end(); ++q) {
      q->cumulative = prev(q)->cumulative + area_(q);
    }
  }

  /// Inverts the cumulative area of the exponentiated envelope.
  ///
  /// @param probability Fraction of the total area under the envelope.
  /// @return The point on the envelope at that cumulative area, together
  ///         with the envelope points bounding the piece it lies on.
  /// @throws exception "generated point outside interval" if the computed
  ///         abscissa falls outside its piece.
  WorkingPoint invert_(Scalar probability) {
    using std::prev;
    using std::next;

    // Search from the right for the piece (prev(q), q) containing u
    auto q = prev(points_.end());
    Scalar u = probability * q->cumulative;
    while (prev(q)->cumulative > u) q = prev(q);

    Point p(0, 0, 0, u, false);

    Scalar proportion = (u - prev(q)->cumulative) / (q->cumulative - prev(q)->cumulative);

    /* get the required x-value */
    if (prev(q)->x == q->x){
      /* interval is of zero length */
      p.x = q->x;
      p.y = q->y;
      p.expY = q->expY;
    } else {
      Scalar xl = prev(q)->x;
      Scalar xr = q->x;
      Scalar yl = prev(q)->y;
      Scalar yr = q->y;
      Scalar eyl = prev(q)->expY;
      Scalar eyr = q->expY;
      if (std::abs(yr - yl) < Y_EPSILON) {
        /* linear approximation was used for this piece in area_ */
        if (std::abs(eyr - eyl) > EXP_Y_EPSILON * std::abs(eyr + eyl)) {
          p.x = xl + ((xr - xl) / (eyr - eyl)) * (-eyl + std::sqrt(
            (1. - proportion) * eyl * eyl + proportion * eyr * eyr
          ));
        } else {
          p.x = xl + (xr - xl) * proportion;
        }
        p.expY = ((p.x - xl) / (xr - xl)) * (eyr - eyl) + eyl;
        p.y = logShift_(p.expY);
      } else {
        /* piece was integrated exactly in area_ */
        p.x = xl + ((xr - xl) / (yr - yl)) * (
          -yl + logShift_((1. - proportion) * eyl + proportion * eyr)
        );
        p.y = ((p.x - xl) / (xr - xl)) * (yr - yl) + yl;
        p.expY = expShift_(p.y);
      }

      /* guard against imprecision yielding point outside interval */
      if ((p.x < xl) || (p.x > xr)) {
        throw exception("generated point outside interval");
      }
    }

    return WorkingPoint(p, prev(q), q);
  }

  // Utility functions

  /// exp(y - yMaximum_ + Y_CEILING), or 0 when y is at least 2 * Y_CEILING
  /// below yMaximum_.
  Scalar expShift_(Scalar y) const {
    if (y - yMaximum_ > -2.0 * Y_CEILING) {
      return std::exp(y - yMaximum_ + Y_CEILING);
    }
    return 0;
  }

  /// Inverse of expShift_: log(expY) + yMaximum_ - Y_CEILING.
  Scalar logShift_(Scalar expY) const {
    return std::log(expY) + yMaximum_ - Y_CEILING;
  }

  /// Area under the exponentiated envelope between prev(q) and q.
  /// q must not be the first point of the list.
  Scalar area_(typename std::list<Point>::iterator q) const {
    using std::prev;

    if (prev(q)->x == q->x) {
      /* interval is zero length */
      return 0;
    } else if (std::abs(q->y - prev(q)->y) < Y_EPSILON) {
      /* integrate straight line piece */
      return 0.5 * (q->expY + prev(q)->expY) * (q->x - prev(q)->x);
    }
    /* integrate exponential piece */
    return ((q->expY - prev(q)->expY) / (q->y - prev(q)->y)) * (q->x - prev(q)->x);
  }
};

};
