#! /usr/bin/env python

# replica_exchange_protomol
#
# Copyright (C) 2011- The University of Notre Dame
# This software is distributed under the GNU General Public License.
# See the file COPYING for details.
#
# This program implements elastic replica exchange using
# the cctools work queue framework and the Protomol molecular
# dynamics package, as described in the following paper:
#
# Dinesh Rajan, Anthony Canino, Jesus A Izaguirre, and Douglas Thain,
# "Converting a High Performance Application to an Elastic Cloud Application",
# The 3rd IEEE International Conference on Cloud Computing Technology
# and Science, November 2011.

# From the cctools python bindings, use work queue.

from work_queue import *

# All other dependencies are standard python.

import time
import sys
import os
import pprint
import re
import getopt
import math
import random


#-------------------------------Constants-----------------------------------
# Initial values for the run parameters defined further below.
# Default number of Monte Carlo steps (overridden by -m).
DEFAULT_MONTE_CARLO_STEPS = 100
# Default directory under which all output is written (overridden by -p).
DEFAULT_OUTPUT_PATH = '.'
# Default number of MD steps written as "numsteps" in each config (overridden by -s).
DEFAULT_MDSTEPS = 10000
# Default value written as "boundaryConditions" in each config.
DEFAULT_BOUNDARY_CONDITIONS = "Vacuum"
# Default value used for the output frequency settings in each config.
DEFAULT_OUTPUT_FREQ = 10000
# NOTE(review): not referenced anywhere in this file as far as visible.
DEFAULT_PHYSICAL_TEMP = 300

#-------------------------------Global Variables----------------------------
# Flags; most are switched on by command line options in the main section.
# -i: run "ProtoMol" from the worker's PATH instead of shipping the binary.
protomol_local_install = False
# -x: ask ProtoMol for an xyz position file and fetch it back.
generate_xyz = False
# -d: ask ProtoMol for a dcd file and fetch it back.
generate_dcd = False
# -l: print debugging information.
debug_mode = False
# -q: use closer temperature spacing for the first and last quartile of replicas.
quart_temp_split = False
# Checked by generate_config() to decide whether a "velfile" line is written.
vel_file_input = False

# Wall-clock duration of each Monte Carlo step, appended by re_main().
mc_step_times = []
#-----------------------------Global Data-------------------------------------
# Input files, taken from the first three positional arguments.
pdb_file = ""
psf_file = ""
par_file = ""
# Base names for the fetched xyz/dcd output files (arguments of -x / -d).
xyz_file_name = ""
dcd_file_name = ""
# Work Queue project name (argument of -n); None leaves the queue unnamed.
proj_name = None
boundary_conditions = DEFAULT_BOUNDARY_CONDITIONS
monte_carlo_steps = DEFAULT_MONTE_CARLO_STEPS
# NOTE(review): re_main() assigns a local of the same name; no reader of this
# module-level value is visible in this file.
replica_id = None
md_steps = DEFAULT_MDSTEPS
output_freq = DEFAULT_OUTPUT_FREQ
output_path = DEFAULT_OUTPUT_PATH

#------------------------Stat Collection Variables------------------------------
# Counters and timers updated by re_main() and reported at the end of the run.
num_replica_exchanges = 0
total_monte_carlo_step_time = 0
num_task_resubmissions = 0
# One list per replica holding the temperatures that replica has been assigned.
replica_temp_execution_list = []

#-------------------------Define Replica Object---------------------------------
class Replica(object):
	"""Per-replica state kept by the master.

	A replica is created with an identifier and a temperature; every other
	field starts out unset and is filled in by code elsewhere in the file.
	"""

	def __init__(self, id, temp):
		# Identity and current temperature, exactly as supplied by the caller.
		self.id, self.temp = id, temp
		# -1 means no exchange partner has been recorded yet.
		self.exchgd_replica_id = -1
		# Everything below stays None until something assigns it.
		self.potential_energy = self.prev_temp = self.pdb = None
		self.new_xyz_output_file = self.new_dcd_output_file = None

	def __str__(self):
		# Short human-readable summary: id as integer, temperature as float.
		summary = "ID: %d Temp %f"
		return summary % (self.id, self.temp)


#Function to drop repetitive values and gather just the unique ones.
def unique(inlist, keep_val=True):
  """Return the items of inlist with later duplicates removed.

  The first occurrence of every value is kept and the original order is
  preserved.  The result has the same type as the input: a list stays a
  list, a tuple stays a tuple.  A string is returned as a string when
  keep_val is true and as a list of characters otherwise.

  When inlist is a plain list it is ALSO de-duplicated in place, so callers
  that ignore the return value still see the effect.
  """
  try:
    string_types = (str, unicode)
  except NameError:
    #Python 3 has no separate unicode type.
    string_types = (str,)

  typ = type(inlist)
  if not typ == list:
    inlist = list(inlist)

  #Membership is tested against a list rather than a set so that unhashable
  #items keep working; only real comparison errors can escape from here,
  #nothing is silently swallowed.
  kept = []
  for item in inlist:
    if item not in kept:
      kept.append(item)

  #Slice assignment mutates the caller's list when one was passed in.
  inlist[:] = kept

  if typ not in string_types:
    inlist = typ(inlist)
  elif keep_val:
    inlist = ''.join(inlist)
  return inlist


#Function to check if a given executable exists in PATH.
def locate(executable):
    """Return the full path of executable found via PATH, or None.

    Each PATH entry is joined with executable in order; the first candidate
    that is a regular file with execute permission is returned.  None is
    returned when nothing matches or when PATH is not set at all.
    """
    def check_executable(prog):
        #isfile() rather than exists(): directories normally carry the
        #execute bit too and must not be mistaken for the program.
        return os.path.isfile(prog) and os.access(prog, os.X_OK)

    search_path = os.environ.get("PATH")
    if search_path is None:
        return None

    for path in search_path.split(os.pathsep):
        exe = os.path.join(path, executable)
        if check_executable(exe):
            return exe

    return None


#Function to strip the directory part from a '/'-separated location string.
def parse_file_name(file_name):
	"""Return whatever follows the last '/' in file_name.

	A string without any '/' is returned unchanged; a string ending in '/'
	yields the empty string.
	"""
	return file_name.rsplit('/', 1)[-1]


#Function to cut a file name at its first dot.
def remove_trailing_dots(file_name):
	"""Return the part of file_name that precedes the first '.'.

	Everything from the first dot onwards is dropped, not just the last
	extension.  A name without a dot is returned unchanged.
	"""
	stem = file_name.split('.', 1)[0]
	return stem


#Function to generate a config file to send to workqueue. It returns the generated config file name.
def generate_config(replica_obj, monte_carlo_step, xyz_input_file = None, xyz_vel_input_file = None):
	"""Write the ProtoMol config file for one replica and one Monte Carlo step.

	replica_obj        -- replica whose id and temp are written to the config.
	monte_carlo_step   -- current step; used in the input file names, and
	                      step+1 in the names of the final pdb/vel files.
	xyz_input_file     -- accepted for interface compatibility; not used.
	xyz_vel_input_file -- accepted for interface compatibility; not used.

	The file is written to <output_path>/cfg/<replica id>.cfg, replacing any
	previous content, and that path is returned.  All other settings come
	from the module-level run parameters.
	"""
	#initialize the config file name based on the replica id.
	cfg_file_name = "%s/%s/%d.cfg" % ( output_path, "cfg", replica_obj.id)

	#Only bare file names go into the config, not the full local paths.
	pdb_base = remove_trailing_dots(parse_file_name(pdb_file))
	step_files = (pdb_base, replica_obj.id, monte_carlo_step)
	next_step_files = (pdb_base, replica_obj.id, monte_carlo_step+1)

	lines = [
		"randomtype 1",
		"numsteps %d" % md_steps,
		"outputfreq %d" % output_freq,
		"posfile %s.%d-%d.pdb" % step_files,
		"psffile %s" % parse_file_name(psf_file),
		"parfile %s" % parse_file_name(par_file),
	]

	if vel_file_input:
		lines.append("velfile %s.%d-%d.vel" % step_files)

	lines.extend([
		"dofinPDBPosFile true",
		"finPDBPosFile %s.%d-%d.pdb" % next_step_files,
		"finXYZVelFile %s.%d-%d.vel" % next_step_files,
		"temperature %f" % replica_obj.temp,
		"boundaryConditions %s" % boundary_conditions,
		"cellManager Cubic",
		"cellsize 69",
	])

	if generate_xyz:
		lines.append("XYZPosFile %d.xyz" % replica_obj.id)
		lines.append("XYZPosFileOutputFreq %d" % md_steps)
	if generate_dcd:
		lines.append("DCDFile %d.dcd" % replica_obj.id)
		lines.append("DCDFileOutputFreq %d" % output_freq)

	lines.extend([
		"allEnergiesFile %d.eng" % replica_obj.id,
		"allEnergiesFileOutputFreq %d" % output_freq,
		#A fresh random seed is drawn for every generated config.
		"seed %d" % random.randint(1, 1000000),
		"shake on",
		"integrator {",
		"level 0 langevinImpulse {",
		"temperature %f" % replica_obj.temp,
		"gamma 5",
		"timestep 2",
		"force bond",
		"force angle",
		"force dihedral",
		"force improper",
		"force LennardJones Coulomb",
		" -switchingFunction C2 -switchingFunction C1 -algorithm NonbondedCutoff",
		"   -switchon          10",
		"   -cutoff            12",
		"   -cutoff            12",
		"   -cutoff            12",
		"}",
		"}",
	])

	#Joining with newlines leaves no newline after the closing brace.
	write_str = "\n".join(lines)

	#Close the file explicitly so the content is on disk before the caller
	#hands the config to Work Queue.
	cfg_file_stream = open(cfg_file_name, "w")
	try:
		cfg_file_stream.write(write_str)
	finally:
		cfg_file_stream.close()

	return cfg_file_name

	
#Function to perform the metropolis criteria check for two replicas.
def metropolis( u_i, u_j, t_i, t_j ):
    """Decide whether replicas i and j should exchange temperatures.

    u_i, u_j -- potential energies of replica i and replica j.
    t_i, t_j -- temperatures of replica i and replica j.

    With beta = 1 / (K_b * T), the exchange is accepted outright when
    delta = (beta_i - beta_j) * (u_j - u_i) is negative, and otherwise with
    probability exp(-delta).  Returns True to exchange, False to keep the
    current assignment.
    """
    K_b = 0.001987191 #Boltzmann constant; value is in kcal/(mol*K)
    beta_i = 1 / (K_b * t_i)
    beta_j = 1 / (K_b * t_j)

    #The two differences are multiplied: beta is an inverse energy, so the
    #product is dimensionless as required for the exponent below.
    deltaE = (beta_i - beta_j) * (u_j - u_i)

    if deltaE < 0:
        return True

    acceptProb = math.exp(-deltaE)
    randNum = random.random()

    return randNum < acceptProb


# Major replica exchange and scheduling is handled here.
def re_main(replica_list, num_replicas):
	"""Drive the whole replica exchange run through Work Queue.

	For every Monte Carlo step this submits one ProtoMol task per replica,
	waits until all of them have finished (resubmitting failed ones), reads
	each replica's potential energy from its energies file and then attempts
	one temperature exchange between a randomly picked replica and its
	neighbor.

	replica_list -- list of Replica objects; a replica's id is also used as
	                its index into this list.
	num_replicas -- number of replicas to run from replica_list.

	Side effects: updates num_replica_exchanges, total_monte_carlo_step_time,
	num_task_resubmissions, mc_step_times and replica_temp_execution_list,
	and sets vel_file_input for generate_config().
	"""

	#Stat collection variables
	global num_replica_exchanges
	global total_monte_carlo_step_time
	global num_task_resubmissions
	#generate_config() reads this module-level flag, so it must be rebound
	#globally; a plain assignment here would only create an unused local and
	#the "velfile" line would never reach the config.
	global vel_file_input

	#Setup WorkQueue and initialize parameters
	wq = WorkQueue()
	wq.specify_algorithm(WORK_QUEUE_SCHEDULE_FCFS)

	if proj_name:
		wq.specify_name(proj_name)

	wq.activate_fast_abort(3.0)

	#Base name shared by all per-replica pdb and velocity files.
	pdb_base = remove_trailing_dots(parse_file_name(pdb_file))

	#Iterate through the given number of monte carlo steps
	for i in range(monte_carlo_steps):
		#Initialize step time.
		step_time = time.time()

		print("Beginning Monte Carlo Step #%d" % i)

		#Velocities from the previous step exist only after the first step.
		vel_file_input = (i > 0)

		#-------Perform computation for each replica at current monte carlo step--------
		#Each computation is a task in the workqueue; each replica will be
		#run on one of the workers instantiated.
		for j in range(num_replicas):
			replica = replica_list[j]
			replica_id = replica.id
			#Each replica does computation at its current temperature
			replica_temperature = replica.temp

			#File name stem of the current replica.
			replica_pdb = "%s.%d" % (pdb_base, j)

			#After the first step the inputs are the files produced by the
			#replica this one exchanged with (itself when no exchange happened).
			if i > 0 and replica.exchgd_replica_id > -1:
				source_pdb = "%s.%d" % (pdb_base, replica.exchgd_replica_id)
			else:
				source_pdb = replica_pdb

			if debug_mode:
				print("Replica %d at %f" % (replica_id, replica_temperature))

			#Local file  - path on the master where a file is read from or stored.
			#Remote file - name of the same file in the worker's sandbox.

			#Assign local and remote config files.
			local_config_file = "%s/cfg/%d.cfg" % (output_path, replica_id)
			remote_config_file = "%d-%d.cfg" % (replica_id, i)

			if (i > 0):
				#Assign local and remote velocity input files.
				local_velocity_input_file = "%s/cfg/%s/%s-%d.vel" % (output_path, replica.temp, source_pdb, i)
				remote_velocity_input_file = "%s-%d.vel" % (replica_pdb, i)

			#Assign local and remote (output) energies files.
			local_energies_file = "%s/cfg/eng/%d/%d.eng" % (output_path, replica_id, replica_id)
			remote_energies_file = "%d.eng" % replica_id

			#Assign local and remote velocity output files.
			local_velocity_output_file = "%s/cfg/%s/%s-%d.vel" % (output_path, replica.temp, replica_pdb, i+1)
			remote_velocity_output_file = "%s-%d.vel" % (replica_pdb, i+1)

			#Assign local and remote xyz output files.
			if generate_xyz:
				local_xyz_output_file = "%s/%s/%s.%d-%d.xyz" % (output_path, replica.temp, xyz_file_name, replica.id, i)
				remote_xyz_output_file = "%d.xyz" % (replica_id)

			#Assign local and remote dcd output files.
			if generate_dcd:
				local_dcd_output_file = "%s/%s/%s.%d-%d.dcd" % (output_path, replica.temp, dcd_file_name, replica.id, i)
				remote_dcd_output_file = "%d.dcd" % (replica_id)

			#Create a task using given task string for remote worker to execute.
			#Without a worker-side install the binary is shipped and run from
			#the task's working directory.
			if protomol_local_install:
				task_str = "ProtoMol %s" % remote_config_file
			else:
				task_str = "./ProtoMol %s" % remote_config_file

			task = Task(task_str)
			#The tag carries the replica id so the result can be matched later.
			task.specify_tag('%s' % replica_id)

			if not protomol_local_install:
				task.specify_input_file("ProtoMol", "ProtoMol")

			#Set the input files required by the task.
			task.specify_input_file(local_config_file, remote_config_file)

			local_pdb_input_file = "%s/cfg/%s/%s-%d.pdb" % (output_path, replica.temp, source_pdb, i)
			remote_pdb_input_file = "%s-%d.pdb" % (replica_pdb, i)
			task.specify_input_file(local_pdb_input_file, remote_pdb_input_file)
			task.specify_input_file(psf_file, parse_file_name(psf_file))
			task.specify_input_file(par_file, parse_file_name(par_file))

			if (i > 0):
				task.specify_input_file(local_velocity_input_file, remote_velocity_input_file)

			#Output files are given as (local, remote), the same order as the
			#energies, velocity and pdb outputs below.
			if generate_xyz:
				task.specify_output_file(local_xyz_output_file, remote_xyz_output_file)
			if generate_dcd:
				task.specify_output_file(local_dcd_output_file, remote_dcd_output_file)

			task.specify_output_file(local_energies_file, remote_energies_file)
			task.specify_output_file(local_velocity_output_file, remote_velocity_output_file)

			pdb_output_file = "%s/cfg/%s/%s-%d.pdb" % (output_path, replica.temp, replica_pdb, i+1)
			task.specify_output_file(pdb_output_file, parse_file_name(pdb_output_file))

			#Call function to generate config file.
			generate_config(replica, i)

			#Submit the task to WorkQueue for execution at remote worker.
			wq.submit(task)

		#Wait for all replicas to finish execution,
		while not wq.empty():
			task = wq.wait(10)
			if not task:
				#Nothing finished within the timeout; keep waiting.
				continue

			#Get replica id from finished task.
			replica_id = int(task.tag)

			# Check if task (replica) failed. If so, resubmit.
			if task.result != 0:
				num_task_resubmissions += 1
				print("Replica failed!")
				time.sleep(3)

				#Resubmit the task.
				wq.submit(task)
				continue

			#Task was successful: the potential energy is the second field
			#on the first line of the energies file.
			energies_file = "%s/cfg/eng/%d/%d.eng" % (output_path, replica_id, replica_id)
			energies_stream = open(energies_file, "r")
			try:
				slist = energies_stream.readline().split()
			finally:
				energies_stream.close()

			finished = replica_list[replica_id]
			finished.potential_energy = float(slist[1])

			#Store temperature and exchanged replica id values from the current run.
			finished.prev_temp = finished.temp
			finished.exchgd_replica_id = replica_id

		#Attempt and perform replica exchange; a single replica has no
		#neighbor to exchange with.
		if num_replicas > 1:
			#Pick a random replica.
			cur = random.randint(0, num_replicas-1)
			neighbor = cur + 1

			if debug_mode: print("Replica %d picked to check for exchange" % cur)
			#The last replica pairs with the one below it instead.
			if (neighbor == num_replicas):
				neighbor = cur - 1

			#Check for metropolis criteria with the randomly chosen replica and its immediate neighbor.
			if (metropolis(replica_list[cur].potential_energy, replica_list[neighbor].potential_energy, replica_list[cur].temp, replica_list[neighbor].temp)):
				if debug_mode: print("Replicas %d and %d exchanged" % (cur, neighbor))
				#Swap the temperatures of the two replicas being exchanged.
				replica_list[cur].temp, replica_list[neighbor].temp = replica_list[neighbor].temp, replica_list[cur].temp

				replica_list[cur].exchgd_replica_id = replica_list[neighbor].id
				replica_list[neighbor].exchgd_replica_id = replica_list[cur].id

				replica_temp_execution_list[cur].append(replica_list[cur].temp)
				replica_temp_execution_list[neighbor].append(replica_list[neighbor].temp)

				num_replica_exchanges += 1

		#Update time stats.
		step_time = time.time() - step_time
		total_monte_carlo_step_time += step_time
		mc_step_times.append(step_time)


#Function to create directories to hold files from the simulations.
def make_directories(output_path, temp_list, num_replicas):
	"""Create the directory layout for a run and seed the starting pdb files.

	output_path  -- root directory for everything the run writes.
	temp_list    -- temperatures (as strings), one per replica, in replica order.
	num_replicas -- number of replicas; one energies directory is made for each.

	For the temperature at position n, <output_path>/cfg/<temp>/ is created
	and the input pdb is copied there as <pdb base>.<n>-0.pdb.  For replica
	r, <output_path>/cfg/eng/<r>/ is created.  When xyz or dcd output is
	enabled, <output_path>/<temp>/ is created as well.

	Raises IOError/OSError if the input pdb cannot be copied.
	"""
	#Imported here so the function does not depend on a file-level import.
	import shutil

	pdb_base = remove_trailing_dots(parse_file_name(pdb_file))

	#For each temperature value, create a directory with its name under cfg (config).
	for count, temp in enumerate(temp_list):
		cfg_dir = "%s/cfg/%s" % (output_path, temp)
		if not os.path.exists(cfg_dir):
			os.makedirs(cfg_dir)

		#re_main() names its local xyz/dcd output files under
		#<output_path>/<temp>/, so that directory has to exist as well.
		if generate_xyz or generate_dcd:
			trajectory_dir = "%s/%s" % (output_path, temp)
			if not os.path.exists(trajectory_dir):
				os.makedirs(trajectory_dir)

		seed_pdb = "%s/cfg/%s/%s.%d-%d.pdb" % (output_path, temp, pdb_base, count, 0)
		print("cp %s %s" % (pdb_file, seed_pdb))
		#Copy without a shell so paths containing spaces work and a failed
		#copy is reported here rather than as failing tasks later.
		shutil.copy(pdb_file, seed_pdb)

	#For each replica, create eng (energies) dirs.
	for j in range(num_replicas):
		eng_dir = "%s/cfg/%s/%s" % (output_path, "eng", j)
		if not os.path.exists(eng_dir):
			os.makedirs(eng_dir)


#Main section: parse the command line, set up the replicas, run the
#simulation and report statistics.
if __name__ == "__main__":

	#Create help string for user.
	usage_str = "Usage: %s <PDB_FILE> <PSF_FILE> <PAR_FILE> <MIN_TEMP> <MAX_TEMP> <NUM_OF_REPLICAS>" % sys.argv[0]
	help_str  = "-h		-	help\n"
	help_str += "-n		-	specify a project name for using exclusive workers\n"
	help_str += "-x		-	specify the name of the xyz file for output\n"
	help_str += "-d		-	specify the name of the dcd file for output\n"
	help_str += "-l		-	print debuging information\n"
	help_str += "-m		-	specify the number of monte carlo steps\n"
	help_str += "-s		-	specify the number of mdsteps\n"
	help_str += "-p		-	specify the output_path for the output files generated\n"
	help_str += "-i		-	assume ProtoMol is installed and available in PATH on worker site\n"
	help_str += "-q		-	assign closer temperature values to the first and last quartile of the replicas.\n"

	#Check to see if there is error in the given command line arguments.
	#Long option names are given without leading dashes; a trailing "="
	#marks the ones that take an argument.
	try:
		opts, args = getopt.getopt(sys.argv[1:], "m:s:n:p:x:d:ihlq", ["help", "proj=", "dir=", "xyz=", "dcd="])
	except getopt.GetoptError:
		#sys.exc_info() avoids version-specific "except ... as/," syntax.
		print(str(sys.exc_info()[1]))
		print(usage_str)
		sys.exit(1)

	#Parse command line arguments.
	for o, a in opts:
		if o in ("-h", "--help"):
			print(help_str)
			sys.exit(0)
		elif o == "-l":
			debug_mode = True
		elif o in ("-x", "--xyz"):
			generate_xyz = True
			xyz_file_name = a
		elif o in ("-d", "--dcd"):
			generate_dcd = True
			dcd_file_name = a
		elif o in ("-n", "--proj"):
			proj_name = a
		elif o in ("-p", "--dir"):
			output_path = a
		elif o == "-m":
			monte_carlo_steps = int(a)
		elif o == "-s":
			md_steps = int(a)
		elif o == "-q":
			quart_temp_split = True
		elif o == "-i":
			protomol_local_install = True

	if debug_mode:
		print("dmode on")
		print("mdsteps: %d" % md_steps)
		set_debug_flag("wq")
		set_debug_flag("debug")

	if len(args) != 6:
		print(usage_str)
		sys.exit(1)

	#A local ProtoMol binary is copied into the current directory, from
	#where re_main() ships it to the workers.
	protomol_path = locate('ProtoMol')
	if protomol_path:
		command = "cp %s ." % protomol_path
		os.system(command)
	else:
		print("Error: Please add ProtoMol to your PATH.")
		sys.exit(1)

	#Begin timing after parsing command line arguments.
	total_run_time = time.time()

	#Parse command line to get input file names,
	#      temperature range for simulation,
	#      number of replicas to run.
	pdb_file = args[0]
	psf_file = args[1]
	par_file = args[2]
	min_temp = int(args[3])
	max_temp = int(args[4])
	num_replicas = int(args[5])
	replica_list = []

	if num_replicas < 1:
		print("Error: NUM_OF_REPLICAS must be at least 1.")
		print(usage_str)
		sys.exit(1)

	#Split up the temperature range for assigning to each replica.
	#A single replica simply runs at min_temp.
	if num_replicas > 1:
		inc = float(max_temp - min_temp) / float(num_replicas-1)
	else:
		inc = 0.0
	temp_list = []

	#Assign temperature to each replica.
	for x in range(num_replicas):
		#Quart split assigns closer temperature values
		#	to the top and bottom 25% of replicas.
		if quart_temp_split:
			if x < math.ceil(0.25 * num_replicas):
				replica_temp = min_temp + (x * inc / 3)
			elif x >= math.ceil(0.75 * num_replicas):
				replica_temp = max_temp - (((num_replicas-1) - x) * inc / 3)
			else:
				replica_temp = min_temp + (x * inc)

		#If not quart split, split temperature range uniformly
		#	among all replicas.
		else:
			replica_temp = min_temp + (x * inc)

		#Store the temperature values and replicas.
		temp_list.append(str(replica_temp))
		replica_list.append(Replica(x, replica_temp))

		#Start this replica's row of the replica exchange matrix with its
		#initial temperature.
		replica_temp_execution_list.append([replica_temp])

	#Create directories for storing data from the run.
	make_directories(output_path, temp_list, num_replicas)

	#Start the run.
	re_main(replica_list, num_replicas)

	#Keep track of the total run time.
	total_run_time = (time.time() - total_run_time)

	#Per-step figures are reported as zero when no step was run; the
	#acceptance rate is computed in floating point so it is not truncated.
	if monte_carlo_steps > 0:
		average_step_time = total_monte_carlo_step_time / monte_carlo_steps
		acceptance_rate = (num_replica_exchanges * 100.0) / monte_carlo_steps
	else:
		average_step_time = 0.0
		acceptance_rate = 0.0

	#Print stats on completion.
	print("Total Run Time:                  %f"  % total_run_time)
	print("Total Monte Carlo Step Time:     %f"  % total_monte_carlo_step_time)
	print("Average Monte Carlo Step Time:   %f"  % average_step_time)
	print("Number of failures:   			%d"  % num_task_resubmissions)
	print("Replica Exchanges:               %d"  % num_replica_exchanges)
	print("Acceptance Rate:                 %f" % acceptance_rate)

	#Write stats to a stats file
	stat_file_name = "%s/%s.stat" % (output_path,  remove_trailing_dots(parse_file_name(pdb_file)))
	stat_file_stream = open(stat_file_name, "w")
	try:
		stat_file_stream.write("%s\n" % "Printing replica temperature execution matrix:")

		#Sort and format the replica exchange matrix: each row becomes the
		#sorted set of distinct temperatures the replica has run at.
		for itr in range(num_replicas):
			replica_temp_execution_list[itr] = unique(sorted(replica_temp_execution_list[itr]))
			stat_file_stream.write("Replica %d: %s\n" % (itr, replica_temp_execution_list[itr]))

		#Write run time for each step to stats file, numbered from 1.
		stat_file_stream.write("\n\n%s\n" % "Printing run times for each monte carlo step:")
		for step_index, elapsed in enumerate(mc_step_times):
			stat_file_stream.write("%d %f\n" % (step_index + 1, elapsed))
	finally:
		stat_file_stream.close()

	sys.exit(0)
