#include <cmath>
#include <exception>
#include <limits>
#include <stdexcept>
#include <vector>

/// Abstract function interface used by ESCSystem: supplies a starting state and
/// evaluates the function at a given state vector.
class ESCFunction
{
  public:
    // Virtual so that deleting a derived function through an ESCFunction*
    // runs the derived destructor (otherwise it is undefined behavior).
    virtual ~ESCFunction() { }

    /// Returns the initial state. Its size is the dimensionality that value()
    /// expects.
    virtual std::vector<float> init() = 0;

    /// Evaluates the function at `state`. The implementations in this file throw
    /// std::runtime_error when `state` has the wrong size.
    virtual float value(const std::vector<float> &state) const = 0;
};

/// One-dimensional Gaussian: f(x) = a + b * exp(-(x - c)^2 / (2 * d^2)).
class Gauss1D : public ESCFunction
{
  private:
    float a_;  // constant offset added to the Gaussian
    float b_;  // multiplier on the exponential term
    float c_;  // x position of the peak
    float d_;  // width parameter in the (2 * d^2) denominator

  public:
    Gauss1D(float a=0, float b=1, float c=0, float d=1) : a_(a), b_(b), c_(c), d_(d) { }

    /// Returns a single-element state {0}.
    std::vector<float> init()
    {
      return std::vector<float>(1, 0.0f);
    }

    /// Evaluates the Gaussian at state[0].
    /// @throws std::runtime_error unless the state has exactly one element.
    float value(const std::vector<float> &state) const
    {
      if (state.size() == 1)
      {
        const float offset = state[0] - c_;
        const float spread = 2*d_*d_;
        return a_ + b_*std::exp(-offset*offset/spread);
      }

      throw std::runtime_error("invalid state size");
    }
};

/// Two-dimensional axis-aligned Gaussian:
///   f(x, y) = a + b * exp(-((x - c[0])^2 / (2 d[0]^2) + (y - c[1])^2 / (2 d[1]^2)))
class Gauss2D : public ESCFunction
{
  private:
    float a_, b_;
    std::vector<float> c_, d_;  // peak position and per-axis widths; always exactly 2 elements

  public:
    /// @param a  constant offset added to the Gaussian
    /// @param b  multiplier on the exponential term
    /// @param c  peak position {cx, cy}; an empty vector selects {0, 0}
    /// @param d  per-axis widths {dx, dy}; an empty vector selects {1, 1}
    /// @throws std::runtime_error if a non-empty c or d does not have exactly 2 elements
    Gauss2D(float a=0, float b=1,
            const std::vector<float> &c=std::vector<float>(),
            const std::vector<float> &d=std::vector<float>())
      : a_(a), b_(b),
        c_(c.empty() ? std::vector<float>(2, 0.0f) : c),
        d_(d.empty() ? std::vector<float>(2, 1.0f) : d)
    {
      // value() indexes elements 0 and 1 of both vectors; reject anything else
      // here instead of reading out of bounds later.
      if (c_.size() != 2 || d_.size() != 2)
        throw std::runtime_error("invalid parameter size");
    }

    /// Returns the two-element starting state {0, 0}.
    std::vector<float> init()
    {
      return std::vector<float>(2, 0.0f);
    }

    /// Evaluates the Gaussian at (state[0], state[1]).
    /// @throws std::runtime_error unless the state has exactly two elements.
    float value(const std::vector<float> &state) const
    {
      if (state.size() != 2)
        throw std::runtime_error("invalid state size");

      const float dx = state[0] - c_[0];
      const float dy = state[1] - c_[1];
      const float exponent = dx*dx/(2*d_[0]*d_[0]) + dy*dy/(2*d_[1]*d_[1]);

      return a_ + b_*std::exp(-exponent);
    }
};


/// Holds a current state vector and evaluates an ESCFunction at it.
///
/// The function is NOT owned: ESCSystem never deletes it, so the caller must
/// keep it alive for as long as the ESCSystem is used.
class ESCSystem
{
  protected:
    ESCFunction *function_;     // non-owning; never null after construction
    std::vector<float> state_;  // current state, sized by function_->init()

  public:
    /// @throws std::runtime_error if `function` is null.
    ESCSystem(ESCFunction *function) : function_(function)
    {
      if (!function)
        throw std::runtime_error("no function specified");

      reset();
    }

    /// Restores the state to the function's initial state.
    void reset()
    {
      state_ = function_->init();
    }

    /// Adds `vel` element-wise to the state and returns the function value at
    /// the new state.
    /// @throws std::runtime_error if `vel` does not match the state size.
    float step(const std::vector<float> &vel)
    {
      if (state_.size() != vel.size())
        throw std::runtime_error("invalid state size");

      // Sizes are equal here, so one bound suffices.
      for (size_t ii=0; ii < state_.size(); ++ii)
        state_[ii] += vel[ii];

      return function_->value(state_);
    }

    /// Replaces the state with `pos` and returns the function value there.
    /// @throws std::runtime_error if `pos` does not match the state size.
    float set(const std::vector<float> &pos)
    {
      if (state_.size() != pos.size())
        throw std::runtime_error("invalid state size");

      state_ = pos;
      return function_->value(state_);
    }

    /// Returns the function value at the current state.
    float value() const
    {
      return function_->value(state_);
    }

    /// Read-only access to the current state; const so it can be queried
    /// through a const ESCSystem&.
    const std::vector<float> &state() const
    {
      return state_;
    }
};
