/* Simple molecular dynamics simulation of a harmonic oscillator

   Copyright 2010-2012 Cheng Zhang

   This program is free software: you can redistribute it and/or modify
   it under the terms of the GNU General Public License as published by
   the Free Software Foundation, either version 2 of the License, or
   (at your option) any later version.

   This program is distributed in the hope that it will be useful,
   but WITHOUT ANY WARRANTY; without even the implied warranty of
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
   GNU General Public License for more details.

   A copy of the GNU General Public License can be found in
   <http://www.gnu.org/licenses/>.
*/
import static java.lang.Math.*;

import java.awt.*;
import java.awt.event.*;
import java.awt.geom.*;
import java.lang.reflect.InvocationTargetException;
import java.text.*;
import java.util.Locale;
import java.util.Random;
import javax.swing.*;

/**
 * One-dimensional molecular dynamics of a single particle in a harmonic (spring) potential
 * or a double-well potential, with an optional thermostat.
 * Units are chosen so that Boltzmann's constant is 1 (see the Metropolis factor in mctherm).
 */
class MD {
  int step = -1; // step counter (advanced by the caller); -1 until init() is called
  boolean dwell = false; // use the double-well potential instead of the harmonic one
  boolean thermostat = false; // apply a thermostat after every integration step
  boolean vrescale = false; // thermostat kind: velocity rescaling if true, Monte Carlo if false
  double thermoamp = 0.5; // how much to change the temperature (not referenced in this source)
  double T = 1.0; // temperature
  double m = 1.0;  // mass of the particle
  double k = 1.0;  // spring constant (also scales the double-well potential)
  double dt = 0.002; // md integration step
  double thermdt = 0.002; // thermostat strength: coupling for v-rescale, max trial dv for MC
  static final double x0 = 0.0, v0 = 1.0; // initial position and velocity
  double x, v, f; // position, velocity and force
  double Ek, Ep, E; // kinetic, potential, and total energy (E is not assigned in this class)

  Random rng = new Random();

  /** Places the particle at (x0, v0), resets the step counter and recomputes force and energies. */
  void init() {
    x = x0;
    v = v0;
    step = 0;
    // vv() uses f from the previous call, and callers display Ek/Ep, so make them current now
    computeForce();
    Ek = kineticEnergy();
  }

  /** Simulated time, assuming {@code step} counts integration steps of size {@code dt}. */
  double getTime() { return step * dt; }

  /** Integrates Newton's equation by one velocity-Verlet step of size dt, then thermostats. */
  void vv() {
    v += (f / m) * dt * .5;
    x += v * dt;
    computeForce();
    v += (f / m) * dt * .5;
    Ek = kineticEnergy();
    if (thermostat) {
      Ek = vrescale ? vrescale(thermdt) : mctherm(thermdt);
    }
  }

  /**
   * Stochastic velocity-rescaling thermostat for one degree of freedom.
   *
   * @param thdt coupling strength for this step (expected to be >= 0)
   * @return the kinetic energy after rescaling
   */
  double vrescale(double thdt) {
    double ek1 = kineticEnergy();
    double ekav = .5 * T; // equipartition: <Ek> = T/2 for one degree of freedom (k_B = 1)
    double ek2 = ek1 + (ekav - ek1) * thdt + 2 * sqrt(ek1 * ekav * thdt) * rng.nextGaussian();
    if (ek2 < 0) ek2 = ek1; // reject a negative target energy
    if (ek1 > 0) {
      v *= sqrt(ek2 / ek1);
    } else {
      // a particle at rest cannot be rescaled (0 * sqrt(x/0) is NaN): give it a random direction
      v = (rng.nextBoolean() ? 1 : -1) * sqrt(2 * ek2 / m);
    }
    return ek2;
  }

  /**
   * Monte Carlo thermostat: proposes a velocity change drawn uniformly from [-thdt, thdt)
   * and accepts it with the Metropolis criterion on the kinetic-energy change.
   *
   * @return the kinetic energy after the move (unchanged if the move is rejected)
   */
  double mctherm(double thdt) {
    double v2 = v + (rng.nextDouble() * 2 - 1) * thdt;
    double ek1 = kineticEnergy(), ek2 = .5 * m * v2 * v2;
    if (ek2 < ek1 || rng.nextDouble() < exp(-(ek2 - ek1) / T)) {
      v = v2;
      return ek2;
    }
    return ek1;
  }

  /** Harmonic potential: sets Ep = k x^2 / 2 and f = -k x. */
  void force() {
    Ep = 0.5 * k * x * x;
    f = - k * x;
  }

  /**
   * Double-well potential with minima at x = +-1 and a barrier at x = 0:
   * sets Ep = k (x^4/4 - x^2/2) and f = -dEp/dx = k x (1 - x^2).
   */
  void forceDwell() {
    double x2 = x * x;
    Ep = k * x2 * (0.25 * x2 - 0.5); // must be the antiderivative of -f for energy conservation
    f = k * x * (1.0 - x2);
  }

  /** Updates f and Ep for the currently selected potential. */
  private void computeForce() {
    if (dwell) {
      forceDwell();
    } else {
      force();
    }
  }

  /** Kinetic energy of the current velocity. */
  private double kineticEnergy() {
    return .5 * m * v * v;
  }
}

/**
 * Applet front end for {@link MD}: a control panel (east), a canvas animating the spring and
 * its phase-space trajectory (center), and a status line with time and energies (south).
 */
public class Spring extends JApplet implements ActionListener {
  /** Number of MD integration steps performed per timer tick. */
  static final int STEPS_PER_TICK = 10;

  MD md = new MD();
  int delay = 50; // timer period in milliseconds
  Timer timer;

  MyCanvas canvas;
  JPanel spnl, cpnl; // status panel (south) and control panel (east)
  // Fixed US symbols: field text is read back with Double.parseDouble, which only accepts '.',
  // so a default-locale decimal separator such as ',' would make every field unreadable.
  DecimalFormat df = new DecimalFormat("0.000", new DecimalFormatSymbols(Locale.US));
  JTextField tPos   = new JTextField(" " + df.format(MD.x0));
  JTextField tVel   = new JTextField(" " + df.format(MD.v0));
  JTextField tMDdt  = new JTextField(" " + df.format(md.dt));
  JCheckBox  bDwell = new JCheckBox("Double well");
  JCheckBox  bTstat = new JCheckBox("Thermostat");
  JCheckBox  bVrscl = new JCheckBox("v-rescale");
  JTextField tTemp  = new JTextField("1.0");
  JTextField tThdt  = new JTextField(" " + df.format(md.thermdt));
  JButton    bReset = new JButton("Reset");
  JToggleButton bStart = new JToggleButton("Start");
  JLabel     lStatus = new JLabel("Status");

  /** Builds the control panel, the drawing canvas and the status line (called on the EDT from init()). */
  void initgui() {
    Container box = getContentPane();
    box.setLayout(new BorderLayout());

    cpnl = new JPanel(); // create a child panel for controls
    box.add(cpnl, BorderLayout.EAST);

    // add stuff to the panel
    cpnl.setLayout(new GridLayout(16, 2));

    cpnl.add(bStart);
    bStart.addActionListener(this);

    cpnl.add(bReset);
    bReset.addActionListener(this);

    cpnl.add(new JLabel(" Position:"));
    tPos.addActionListener(this);
    cpnl.add(tPos);

    cpnl.add(new JLabel(" Velocity:"));
    tVel.addActionListener(this);
    cpnl.add(tVel);

    cpnl.add(new JLabel(" MD dt:"));
    tMDdt.addActionListener(this);
    cpnl.add(tMDdt);

    cpnl.add(new JLabel(" Temperature:"));
    tTemp.addActionListener(this);
    cpnl.add(tTemp);

    bDwell.addActionListener(this);
    cpnl.add(bDwell);

    bTstat.addActionListener(this);
    cpnl.add(bTstat);

    bVrscl.addActionListener(this);
    cpnl.add(bVrscl);

    cpnl.add(new JLabel(" Therm. dt:"));
    tThdt.addActionListener(this);
    cpnl.add(tThdt);

    canvas = new MyCanvas();
    canvas.setBackground(getBackground());
    box.add(canvas, BorderLayout.CENTER);

    spnl = new JPanel();
    box.add(spnl, BorderLayout.SOUTH);
    lStatus.setFont(new Font("Courier", Font.PLAIN, 12));
    spnl.add(lStatus);
  }

  /** Initializes the simulation and builds the GUI on the event dispatch thread. */
  @Override
  public void init() {
    try {
      SwingUtilities.invokeAndWait(new Runnable() {
        public void run() {
          md.init();
          initgui();
        }
      });
    } catch (InterruptedException e) {
      Thread.currentThread().interrupt(); // preserve the interrupt for the applet container
      System.err.println("interrupted while initializing gui");
    } catch (InvocationTargetException e) {
      // the wrapper's own message is useless; report what initgui() actually threw
      System.err.println("failed to init gui: " + e.getCause());
    }

    timer = new Timer(delay, this);
  }

  /** Copies the simulation state into the text fields, status line and canvas, then repaints them. */
  @Override
  public void paint(Graphics g) {
    // write the current position:
    tPos.setText(" " + df.format(md.x));
    tVel.setText(" " + df.format(md.v));
    lStatus.setText("t: " + df.format(md.getTime()) + "; E: " + df.format(md.Ek + md.Ep)
        + "; Ek: " + df.format(md.Ek) + "; Ep: " + df.format(md.Ep)
        + "; T: " + df.format(md.T));
    canvas.setxv(md.x, md.v);
    canvas.repaint();
    spnl.repaint();
    cpnl.repaint();
  }

  /** Handles timer ticks (advance the simulation) and all control events. */
  @Override
  public void actionPerformed(ActionEvent e) {
    Object src = e.getSource();

    if (src == timer) {
      for (int i = 0; i < STEPS_PER_TICK; i++) { // integrate a few steps
        md.vv();
        md.step++; // count every integration step so getTime() == steps * dt
      }
      repaint();
      return;
    }

    if (src == bDwell) md.dwell = !md.dwell;
    if (src == bTstat) md.thermostat = !md.thermostat;
    if (src == bVrscl) md.vrescale = !md.vrescale;
    // pick up any position/velocity the user typed in
    md.x = readField(tPos, md.x);
    md.v = readField(tVel, md.v);
    if (src == tMDdt) {
      double mdt = readField(tMDdt, md.dt);
      if (mdt < 1e-8) { mdt = 1e-8; tMDdt.setText(" " + mdt); }
      md.dt = mdt;
    }
    if (src == tTemp) {
      double temp = readField(tTemp, md.T);
      if (temp < 1e-8) { temp = 1e-8; tTemp.setText(" " + temp); }
      md.T = temp;
    }
    if (src == tThdt) {
      double tdt = readField(tThdt, md.thermdt);
      if (tdt < 0) { tdt = 0; tThdt.setText(" " + tdt); }
      md.thermdt = tdt;
    }
    if (src == bStart) {
      boolean on = bStart.isSelected();
      if (on) {
        timer.restart();
        bStart.setText("Pause");
      } else {
        timer.stop();
        bStart.setText("Resume");
      }
    }
    if (src == bReset) {
      if (timer.isRunning()) timer.stop();
      bStart.setSelected(false);
      bStart.setText("Start");
      md.init();
      canvas.onceTraj = false;
    }
    // x, v or the potential may have changed: the next half-kick and the display need a fresh force
    refreshForce();
    repaint();
  }

  /**
   * Reads a number from a text field.
   *
   * @param field the field to read
   * @param current the value currently in effect
   * @return {@code current} if the field still shows it (so display rounding never leaks
   *     into the state) or if the text is not a finite number; otherwise the parsed value
   */
  private double readField(JTextField field, double current) {
    String text = field.getText().trim();
    if (text.equals(df.format(current)) || text.equals(String.valueOf(current))) {
      return current;
    }
    try {
      double value = Double.parseDouble(text);
      if (!Double.isNaN(value) && !Double.isInfinite(value)) {
        return value;
      }
    } catch (NumberFormatException ignored) {
      // invalid input: fall through and restore the value in effect
    }
    field.setText(" " + current);
    return current;
  }

  /** Recomputes force and energies of the current state for the selected potential. */
  private void refreshForce() {
    if (md.dwell) {
      md.forceDwell();
    } else {
      md.force();
    }
    md.Ek = .5 * md.m * md.v * md.v;
  }

  /** Resumes the animation when the applet becomes visible, if it was running. */
  @Override
  public void start() { if (bStart.isSelected()) timer.restart(); }

  /** Pauses the animation while the applet is not visible. */
  @Override
  public void stop() { if (timer.isRunning()) timer.stop(); }
}

/** Panel that draws the spring, the ball and the phase-space (x, v) trajectory. */
class MyCanvas extends JPanel {
  double x, v; // state to draw, set via setxv()

  // graphics detail
  Image img, imgTraj; // off-screen frame buffer and the persistent trajectory layer
  Graphics gi, giTraj; // graphics contexts of img and imgTraj
  Dimension szi; // size the off-screen images were created with
  int wh, hh; // half width, half height

  double amp = 0.4; // screen scale: one unit of x (or v) spans amp * half-width pixels
  int radius = 20; // radius of the ball in pixels

  public MyCanvas() { super(); }

  /** Sets the position and velocity shown on the next repaint. */
  void setxv(double x1, double v1) { x = x1; v = v1; }

  /** Draws one frame: trajectory layer, axes, spring and ball. */
  void drawAll(Graphics g) {
    Graphics2D g2 = (Graphics2D) g;

    // draw the trajectory first so everything else appears on top of it
    drawTraj(giTraj);
    g.drawImage(imgTraj, 0, 0, null);

    // draw the coordinate axes
    g2.setPaint(Color.black);
    g2.setStroke(new BasicStroke(1)); // thin pen for the axes
    g2.drawLine(0, hh, wh * 2, hh);
    g2.drawLine(wh, 0, wh, hh * 2);

    // put labels
    g2.drawString("X", wh * 2 - 50, hh - 5);
    g2.drawString("V", wh + 10, 10);

    int xpos = (int) (wh * (1 + amp * x));
    drawSpring(g2, xpos - radius, hh);
    drawBall(g2, xpos, hh);
  }

  /** Interpolates one color channel: fgfactor 1 gives fg, 0 gives bg. */
  private static int blend(int fg, int bg, float fgfactor) {
    return (int) (bg + (fg - bg) * fgfactor);
  }

  /** Draws a shaded ball centered at (x, y) as shrinking disks that get lighter toward the center. */
  void drawBall(Graphics2D g2, int x, int y) {
    for (int r = radius; r > 0; r--) {
      int hx = (radius - r) / 3; // shift smaller disks up-left to place the highlight
      Shape circle = new Ellipse2D.Double(x - r - hx, y - r - hx, r * 2, r * 2);
      float f = 1.f * r / radius;
      g2.setPaint(new Color(blend(25, 200, f), blend(25, 200, f), blend(255, 240, f)));
      g2.fill(circle);
    }
  }

  /** Draws the spring as a 10-period sine wave from the left edge to column imax, centered on row yc. */
  void drawSpring(Graphics2D g2, int imax, int yc) {
    if (imax <= 0) return; // ball touches or passes the wall: no room for a spring
    final int cnt = 10; // number of coils
    g2.setPaint(Color.black);
    g2.setStroke(new BasicStroke(3)); // use a thicker pen
    int ixp = 0, iyp = 0;
    for (int i = 0; i <= imax; i++) {
      double phase = 2.0 * PI * cnt * i / imax;
      int iy = (int) (sin(phase) * radius) + yc;
      if (i > 0) {
        g2.drawLine(ixp, iyp, i, iy);
      }
      ixp = i;
      iyp = iy;
    }
  }

  int xpTraj = 0, vpTraj = 0; // previous coordinates for trajectory
  boolean onceTraj = false; // false: the trajectory layer must be cleared before drawing

  /** Extends the phase-space trajectory (x to the right, v upward) by one segment on g. */
  void drawTraj(Graphics g) {
    int ix = (int) (wh + wh * amp * x);
    // screen y grows downward: subtract so positive v points up, toward the "V" axis label
    int iv = (int) (hh - wh * amp * v);
    if (onceTraj) {
      g.drawLine(xpTraj, vpTraj, ix, iv);
    } else {
      g.setColor(getBackground());
      g.fillRect(0, 0, getWidth(), getHeight());
      g.setColor(Color.black);
      onceTraj = true;
    }
    xpTraj = ix;
    vpTraj = iv;
  }

  /** Renders the frame off-screen and copies it to the screen. */
  @Override
  protected void paintComponent(Graphics g) {
    super.paintComponent(g);

    Dimension sz = getSize();
    if (sz.width <= 0 || sz.height <= 0) return; // nothing visible to draw
    if (gi == null || !sz.equals(szi)) {
      if (!allocateBuffers(sz)) return; // not displayable yet; retry on the next paint
    }

    // draw on gi instead of directly on g
    gi.setColor(getBackground());
    gi.fillRect(0, 0, sz.width, sz.height);
    drawAll(gi);

    // put image to screen
    g.drawImage(img, 0, 0, null);
  }

  /**
   * (Re)creates the off-screen frame and trajectory images for size sz.
   *
   * @return false if the component cannot create images yet (createImage returned null)
   */
  private boolean allocateBuffers(Dimension sz) {
    Image newImg = createImage(sz.width, sz.height);
    Image newTraj = createImage(sz.width, sz.height);
    if (newImg == null || newTraj == null) return false;

    // release the contexts and images for the previous size
    if (gi != null) gi.dispose();
    if (giTraj != null) giTraj.dispose();
    if (img != null) img.flush();
    if (imgTraj != null) imgTraj.flush();

    szi = sz;
    wh = sz.width / 2;
    hh = sz.height / 2;
    img = newImg;
    gi = img.getGraphics();
    imgTraj = newTraj;
    giTraj = imgTraj.getGraphics();
    onceTraj = false; // new layer starts blank, and old pixel coordinates no longer apply
    return true;
  }
}

