#!/usr/bin/env python3
import sys,re,os,math,bisect
import subprocess
from argparse import ArgumentParser
__version__="0.1.0"
def get_option(add_help=True):
    """Parse the command-line arguments of this script.

    Args:
        add_help: Forwarded to ``ArgumentParser``. When false, the automatic
            ``-h/--help`` option is not registered.

    Returns:
        The ``argparse.Namespace`` produced by ``parse_args()`` (which reads
        ``sys.argv``). Positional arguments beyond ``RM``, ``EIn`` and
        ``EOu`` are collected in ``others``.
    """
    # Previously add_help was accepted but silently ignored.
    argparser = ArgumentParser(add_help=add_help)
    argparser.add_argument('RM',help='real or model for oniom')
    argparser.add_argument('EIn',help='Gaussian EIn file')
    argparser.add_argument('EOu',help='Gaussian EOu file')
    # Dedicated Msg / FChk / Mat positionals are intentionally not declared;
    # any such trailing arguments end up in the catch-all 'others' below.
    argparser.add_argument('-l', '--lmp',type=str,default="lmp_serial",
                           help='path for LAMMPS executable')
    argparser.add_argument('-f', '--ffield',type=str,default="ffield.reax",
                           help='reaxff parameter file')
    argparser.add_argument('-m', '--model_non_periodic',action="store_true",
                         help='model will be treated as isolated cluster')
    argparser.add_argument('-b', '--boundary',type=str,default="p p p",
                           help='boundary for lammps')
    argparser.add_argument('-k', '--keep_intermediate',action="store_true",
                         help='keep intermediate files unremoved')
    argparser.add_argument('-v', '--version', action='version',version=__version__)
    argparser.add_argument('--debug',action="store_true",
                         help='debug option')
    argparser.add_argument('--pair_style',type=str,default="reax/c NULL safezone 1.6 mincap 100",
                           help='pair_style for lammps')
    argparser.add_argument('--fix',type=str,default="1 all qeq/reax 1 0.0 10.0 1e-6 reax/c",
                           help='fix for lammps')
    argparser.add_argument('others', nargs='*')
    return argparser.parse_args()
def cross(x,y):
    """Return the cross product of 3-component vectors x and y as a list."""
    x0, x1, x2 = x[0], x[1], x[2]
    y0, y1, y2 = y[0], y[1], y[2]
    return [
        x1*y2 - x2*y1,
        x2*y0 - x0*y2,
        x0*y1 - x1*y0,
    ]
def dot(x,y):
    """Return the scalar product of the first three components of x and y."""
    terms = [x[k]*y[k] for k in range(3)]
    # Summed left to right, term by term, in component order.
    return terms[0] + terms[1] + terms[2]
def normalize(x):
    """Return the first three components of x scaled to unit Euclidean length.

    A zero-length vector makes the division fail (``ZeroDivisionError`` for
    plain Python numbers).
    """
    length = math.sqrt(x[0]**2 + x[1]**2 + x[2]**2)
    return [x[k]/length for k in range(3)]
def getidxs(pat,lis):
    """Return the indices of the items of lis in which regex pat is found."""
    search = re.compile(pat).search
    hits = []
    for pos, text in enumerate(lis):
        if search(text):
            hits.append(pos)
    return hits
# Element symbols indexed by atomic number: index 0 is 'Bq', index 1 is 'H',
# and the table ends with 'Uue' (119) and 'Ubn' (120).
atom_symbol = tuple("""
Bq H  He Li Be B  C  N  O  F  Ne
   Na Mg Al Si P  S  Cl Ar K  Ca
   Sc Ti V  Cr Mn Fe Co Ni Cu Zn
   Ga Ge As Se Br Kr Rb Sr Y  Zr
   Nb Mo Tc Ru Rh Pd Ag Cd In Sn
   Sb Te I  Xe Cs Ba La Ce Pr Nd
   Pm Sm Eu Gd Tb Dy Ho Er Tm Yb
   Lu Hf Ta W  Re Os Ir Pt Au Hg
   Tl Pb Bi Po At Rn Fr Ra Ac Th
   Pa U  Np Pu Am Cm Bk Cf Es Fm
   Md No Lr Rf Db Sg Bh Hs Mt Ds
   Rg Cn Nh Fl Mc Lv Ts Og Uue Ubn
""".split())

# Parse the command line. A "-" in the boundary option stands in for a blank,
# so e.g. "p-p-s" becomes the LAMMPS boundary string "p p s".
args=get_option()
args.boundary=args.boundary.replace("-"," ")
# Check the required input files with explicit exits rather than `assert`,
# which is stripped when Python runs with -O. sys.exit(message) prints the
# message to stderr and terminates with status 1.
if not os.path.isfile("tv.txt"):
    sys.exit("tv.txt not found")
if not os.path.isfile(args.ffield):
    sys.exit(f"{args.ffield} not found")
# define intermediate file names
prefix=os.path.split(args.EOu)[-1].replace(".EOu","")
datfile=prefix+".dat"
dmpfile=prefix+".dmp"
b2A=0.52917721092 # bohr to Angstrom
h2kcal=627.5095   # hartree to kcal/mol
# Read EIn file: the first field of its first line is the atom count and the
# following natoms lines describe one atom each.
with open(args.EIn) as f:
    lines=f.read().splitlines()
if not lines or not lines[0].split():
    sys.exit(f"{args.EIn}: missing header line")
istart=1
natoms=int(lines[0].split()[0])
iend=istart+natoms
atoms=lines[istart:iend]
if len(atoms)!=natoms:
    sys.exit(f"{args.EIn}: expected {natoms} atom lines, found {len(atoms)}")
# Move a dump file left over from a previous run out of the way.
if os.path.exists(dmpfile):
    os.replace(dmpfile,dmpfile+".old")
# Read the translation vectors unconditionally: a missing tv.txt now fails
# right here instead of surfacing later as an undefined `tvs`.
with open("tv.txt") as f:
    tvs=f.read().splitlines()
if len(tvs)<3:
    sys.exit("tv.txt must contain at least three lines")
# Only the last three whitespace-separated fields of each line are used.
tv0=[float(t) for t in tvs[0].split()[-3:]]
tv1=[float(t) for t in tvs[1].split()[-3:]]
tv2=[float(t) for t in tvs[2].split()[-3:]]
# a, b, c are orthonormal axes expressed in the input coordinate frame:
# a is along tv0, c is normal to the tv0/tv1 plane, and b = c x a.
a=normalize(tv0)
c=normalize(cross(a,tv1))
b=normalize(cross(c,a))
# Build one data-file "Atoms" line per atom, with coordinates projected onto
# the (a, b, c) axes and scaled by b2A.
atomT={}       # first field of an atom line -> type id (1-based, first-seen order)
elements=[]    # element symbol of each type id, in type-id order
lines=[]
for serial,record in enumerate(atoms,start=1):
    at,*rest = record.split()
    if at not in atomT:
        atomT[at]=len(atomT)+1
        elements.append(atom_symbol[int(at)])
    pos=[float(v) for v in rest[:3]]
    px,py,pz=(dot(axis,pos)*b2A for axis in (a,b,c))
    # Columns: atom id, type id, a literal 0, then the three coordinates.
    lines.append(f"{serial:6d}{atomT[at]:3d} 0 {px:20.12f}{py:20.12f}{pz:20.12f}")
# Box extents: projections of the translation vectors on their own axes.
lx=dot(a,tv0)
ly=dot(b,tv1)
lz=dot(c,tv2)
# Tilt factors: off-axis projections of tv1 and tv2.
xy=dot(a,tv1)
xz=dot(a,tv2)
yz=dot(b,tv2)
# Assemble the LAMMPS input deck that is later piped to the executable.
# For a non-periodic model layer, or an explicit "s s s" boundary, a large
# fixed prism box is created first and the data file is merged into it;
# otherwise the data file is read directly with the requested boundary.
if (args.RM == 'M' and args.model_non_periodic) or args.boundary == 's s s':
    input_text=f'''units    real
atom_style  charge
boundary s s s
region box prism -2000.00 2000.00 -2000.00 2000.00 -2000.00 2000.00 0.0 0.0 0.0
create_box {len(atomT)} box
read_data   {datfile} add merge
'''
else:
    input_text=f'''units    real
atom_style  charge
boundary    {args.boundary}
read_data   {datfile} 
'''
# The dump rows are later matched to atoms purely by position in the file,
# so the dump must be sorted by atom id: LAMMPS writes custom dumps in
# indeterminate order unless "dump_modify sort" is given.
# After "run 0" the potential energy is appended as the last line of the dump.
input_text+=f'''pair_style  {args.pair_style}
pair_coeff  * * {args.ffield} {" ".join(elements)}
fix         {args.fix}
dump        1 all custom 1 {dmpfile} fx fy fz
dump_modify 1 sort id
dump_modify 1 format float %20.12f
run         0
variable   a equal pe
print     $a append {dmpfile}
'''
# Write the LAMMPS data file: header (counts, box bounds, tilt factors),
# a "Masses" section giving every type the value 1.0, then the "Atoms" lines.
head=[
    "title",
    "",
    f"{natoms} atoms",
    f"{len(atomT)} atom types",
    "",
    f"0.0 {lx} xlo xhi",
    f"0.0 {ly} ylo yhi",
    f"0.0 {lz} zlo zhi",
    f"{xy} {xz} {yz} xy xz yz",
    "",
    "Masses",
    "",
]
head.extend(f"{tid} 1.0" for tid in range(1,len(atomT)+1))
head.extend(["","Atoms",""])
with open(datfile,"w") as f:
    # Each section is newline-joined and newline-terminated.
    f.write("\n".join(head)+"\n")
    f.write("\n".join(lines)+"\n")
# Run the LAMMPS command through the shell with the generated deck on stdin;
# stderr is merged into stdout. A non-zero exit status is reported and
# terminates this script with status 1.
run_opts=dict(shell=True,check=True,stderr=subprocess.STDOUT,text=True,input=input_text)
try:
    subprocess.run(args.lmp,**run_opts)
except subprocess.CalledProcessError as err:
    print(err)
    sys.exit(1)
# Parse the dump file: the rows after the last "fx fy fz" header are the
# per-atom forces, and the final line of the file is the energy.
with open(dmpfile) as f:
    lines=f.read().splitlines()
headers=getidxs(r"fx fy fz",lines)
if not headers:
    sys.exit(f"{dmpfile}: no 'fx fy fz' section found")
idx=headers[-1]
forces=lines[idx+1:-1]
# A wrong row count would silently assign forces to the wrong atoms.
if len(forces)!=natoms:
    sys.exit(f"{dmpfile}: expected {natoms} force lines, found {len(forces)}")
energy=float(lines[-1])
zero=0.0
# NOTE(review): rows are consumed in file order, which matches the EIn atom
# order only if the dump is written sorted by atom id - confirm.
with open(args.EOu,"w") as f:
    # First line: energy divided by h2kcal, followed by three zeros.
    print(f"{energy/h2kcal:20.12e}{zero:20.12e}{zero:20.12e}{zero:20.12e}",file=f)
    for fxyz in forces:
        fa,fb,fc=map(float,fxyz.split())
        # Rotate the (a, b, c) components back to the input frame.
        fx = a[0]*fa + b[0]*fb + c[0]*fc
        fy = a[1]*fa + b[1]*fb + c[1]*fc
        fz = a[2]*fa + b[2]*fb + c[2]*fc
        # Negated force, divided by h2kcal and multiplied by b2A.
        print(f"{-fx/h2kcal*b2A:20.12e}{-fy/h2kcal*b2A:20.12e}{-fz/h2kcal*b2A:20.12e}",file=f)
# Delete the generated data and dump files unless -k/--keep_intermediate
# was given.
if not args.keep_intermediate:
    for scratch in (datfile,dmpfile):
        os.remove(scratch)
