#!/usr/bin/env python3
import sys,re,os,math,bisect
import subprocess
from argparse import ArgumentParser
__version__="0.1.0"
b2A=0.52917721092 # bohr to Angstrom
h2kcal=627.5095   # hartree to kcal/mol
def get_option(add_help=True):
  """Parse this script's command line.

  Args:
    add_help: Forwarded to ``ArgumentParser``; pass False to suppress the
      automatic ``-h/--help`` option.

  Returns:
    The ``argparse.Namespace`` built from ``sys.argv`` with the attributes
    ``RM``, ``EIn``, ``EOu``, ``charge0``, ``charge1``, ``charge2``,
    ``charge12``, ``debug`` and ``others``.
  """
  argparser = ArgumentParser(add_help=add_help)
  argparser.add_argument('RM',help='real or model for oniom')
  argparser.add_argument('EIn',help='Gaussian EIn file')
  argparser.add_argument('EOu',help='Gaussian EOu file')
  # One "charge multiplicity" string per fragment label; all default to "0 1".
  for tag in ('0','1','2','12'):
    argparser.add_argument('-c'+tag, '--charge'+tag,type=str,default="0 1",
                           help='charge and spin multiplicity for F'+tag)
  argparser.add_argument('-v', '--version', action='version',version=__version__)
  argparser.add_argument('--debug',action="store_true",
                       help='debug option')
  # Any remaining positionals (e.g. Msg/FChk/MatEl file names) are accepted
  # so that parsing does not fail, but this script does not use them.
  argparser.add_argument('others', nargs='*')
  return argparser.parse_args()
def dvec(x1,x2):
  """Return the difference ``x1 - x2`` of the first three components as a list."""
  return [x1[k]-x2[k] for k in range(3)]
def cross(x,y):
  """Return the cross product ``x × y`` of two 3-vectors as a list."""
  x0,x1,x2=x[0],x[1],x[2]
  y0,y1,y2=y[0],y[1],y[2]
  return [
    x1*y2-x2*y1,
    x2*y0-x0*y2,
    x0*y1-x1*y0,
  ]
def dot(x,y):
  """Return the scalar product of the first three components of ``x`` and ``y``."""
  total=x[0]*y[0]
  total=total+x[1]*y[1]
  total=total+x[2]*y[2]
  return total
def normalize(x,scale=1.0):
  """Return ``x`` rescaled to Euclidean length ``scale`` (default 1.0).

  Only the first three components of ``x`` are used.  A zero-length
  vector raises ``ZeroDivisionError``.
  """
  length=math.sqrt(x[0]**2+x[1]**2+x[2]**2)
  return [component/length*scale for component in (x[0],x[1],x[2])]
def calc_cap_pos(x1,x2,scale=1/b2A):
  """Return the point at distance ``scale`` from ``x2`` in the direction of ``x1``.

  The default ``scale`` is ``1/b2A``, i.e. one Angstrom expressed in bohr
  given that ``b2A`` is the bohr-to-Angstrom factor.
  """
  offset=normalize(dvec(x1,x2),scale=scale)
  return [x2[k]+offset[k] for k in range(3)]
def _pattern_to_energy(p):
    """Convert capture group 1 of match ``p`` to a float.

    A Fortran-style ``D`` exponent marker is rewritten to ``E`` first so
    that ``float`` can parse it.
    """
    text=p.group(1)
    return float(text.replace("D","E"))
def _pattern_to_gradient(p):
    """Turn capture group 1 of match ``p`` into a list of sign-flipped triples.

    Each line of the captured text contributes one row holding the negated
    float values of its last three whitespace-separated fields.
    """
    return [
        [-float(field) for field in row.split()[-3:]]
        for row in p.group(1).splitlines()
    ]
# Element symbols indexed by atomic number (atom_symbol[1] == 'H', ...).
# Index 0 holds 'Bq' (presumably the ghost-atom label -- confirm); the last
# two entries are the placeholder names 'Uue' and 'Ubn'.
atom_symbol = (
'Bq','H' ,'He','Li','Be','B' ,'C' ,'N' ,'O' ,'F' ,'Ne',
     'Na','Mg','Al','Si','P' ,'S' ,'Cl','Ar','K' ,'Ca',
     'Sc','Ti','V' ,'Cr','Mn','Fe','Co','Ni','Cu','Zn',
     'Ga','Ge','As','Se','Br','Kr','Rb','Sr','Y' ,'Zr',
     'Nb','Mo','Tc','Ru','Rh','Pd','Ag','Cd','In','Sn',
     'Sb','Te','I' ,'Xe','Cs','Ba','La','Ce','Pr','Nd',
     'Pm','Sm','Eu','Gd','Tb','Dy','Ho','Er','Tm','Yb',
     'Lu','Hf','Ta','W' ,'Re','Os','Ir','Pt','Au','Hg',
     'Tl','Pb','Bi','Po','At','Rn','Fr','Ra','Ac','Th',
     'Pa','U' ,'Np','Pu','Am','Cm','Bk','Cf','Es','Fm',
     'Md','No','Lr','Rf','Db','Sg','Bh','Hs','Mt','Ds',
     'Rg','Cn','Nh','Fl','Mc','Lv','Ts','Og','Uue','Ubn'
)

# Parse the command line once; the rest of the script reads from `args`.
args=get_option()
# define intermediate file names
# `prefix` is the EOu file name without directory and without ".EOu".
prefix=os.path.basename(args.EOu).replace(".EOu","")
f0prefix="tempf0"
# what remains after dropping "Gau-" from the prefix (process id)
pid=prefix.replace("Gau-","")

# Read head file: the template "<RM>.head" must exist next to the job.
headfile=f"{args.RM}.head"
if not os.path.exists(headfile):
  sys.exit(f"headfile {headfile} not found")
with open(headfile) as f:
  head=f.read()
# Append the extra keywords to every span running from a '#' to end of line.
head=re.sub(r"(#.*)",r"\1 force nosym test units=au",head)
# Then add a title section and the charge/multiplicity line for F0.
head="".join((head,"\ngenerated by extnfa\n\n",args.charge0,"\n"))
# read computational level from head
# Each entry pairs a method keyword searched for in the (lower-cased) header
# with the regex prefix that precedes the energy value in the g16 output.
# The first keyword found wins, so longer names are listed before the names
# they contain (e.g. "ccsd(t)" before "ccsd").
level_key=(("oniom"   ,r"ONIOM: extrapolated energy = *"),
           ("eomccsd" ,r" E\(EOM-CCSD\) = *"),
           ("sac-ci"  ,r" Total energy= *"),
           ("ccsd(t)" ,r" CCSD\(T\)= *"),
           ("ccsd"    ,r"amplitudes converged. E\(Corr\)= *"),
           ("b2plyp"  ,r"[0-9] E\(B2PLYP\) = *"),
           ("mp2"     ,r" EUMP2 = *"),
           ("casscf"  ,r"\n TOTAL +"),
           ("external",r"\n Energy= +"),
           (" td"     ,r" E\(TD-HF/TD-(?:DFT|KS)\) = *"),
           (" tda"    ,r" E\(CIS/TDA\) = *"),
           ("amber"   ,r"\n Energy= +"),
           (" uff"    ,r"\n Energy= +"),
           ("dreiding",r"\n Energy= +"))
# fallback used when none of the keywords above occurs in the header
otherm=("SCF",r"SCF Done: +E\(.*\) *= *")
# The keywords are literal text, so they are escaped before being embedded in
# a regex: unescaped, "ccsd(t)" is parsed as a group and can never match the
# text "ccsd(t)", which made such headers fall through to the "ccsd" entry.
head_lower=head.lower()
level,key=next(((m,k) for m,k in level_key
        if re.search(re.escape(m)+r"\W",head_lower)),otherm)
# define regex pattern
patE=re.compile(key+r"([-0-9.+DEe]+)")  # energy
# NOTE: inside the character class below, "+-." is a range ('+' through '.'),
# so it admits '+', ',', '-' and '.'.
patF=re.compile(r"Center +Atomic +Forces.+\n.*\n -{3,} *\n((?: +[0-9 +-.]+\n){1,}) -{3,}") # negative gradient
# Read EIn file
with open(args.EIn) as ein:
  lines=ein.read().splitlines()
# The first field of the first line is the number of atoms.
header_fields=lines[0].split()
natoms=int(header_fields[0])
# Atom records follow the first line; whatever comes after them is treated
# as connectivity records.
istart=1
iend=istart+natoms
atoms,connect=lines[istart:iend],lines[iend:-1]
# NOTE(review): the final line of the file is excluded from `connect` --
# confirm that this matches the EIn layout.
# read connectivity
# pairs[a] lists the 0-based indices of the atoms bonded to atom a.  Each
# atom gets its own list object ([[]]*natoms would share a single list).
pairs=[[] for _ in range(natoms)]
for record in connect:
  fields=record.split()
  targ=int(fields[0])-1
  # Fields 1, 3, 5, ... are partner atom numbers (1-based in the file);
  # the fields in between are skipped.
  for partner in [int(token)-1 for token in fields[1::2]]:
    # record the bond in both directions
    pairs[targ].append(partner)
    pairs[partner].append(targ)
# read orients
# Each atom record becomes [symbol, x, y, z, label], where the label is the
# last whitespace-separated field of the record.
orients=[]
for record in atoms:
  number,*rest=record.split()
  x,y,z=(float(value) for value in rest[:3])
  orients.append([atom_symbol[int(number)],x,y,z,rest[-1]])
# Sort atom indices into fragments by their (case-insensitive) label.
labels=[entry[-1].upper() for entry in orients]
f1_index=[i for i,label in enumerate(labels) if label=="F1"]
f2_index=[i for i,label in enumerate(labels) if label=="F2"]
f12_index=[i for i,label in enumerate(labels) if label=="F12"]
# Atoms bonded to an F12 atom that are not themselves in F12.
neighbors={p for i in f12_index for p in pairs[i]}-set(f12_index)
# capping hydrogen index
h1_index=list(neighbors.intersection(f1_index))
h2_index=list(neighbors.intersection(f2_index))
# boundary atom index in f12
b1_index=[list(set(pairs[i]).intersection(f12_index)) for i in h1_index]
b2_index=[list(set(pairs[i]).intersection(f12_index)) for i in h2_index]

def _cap_lines(outer_index,boundary_index):
    """Return one "H-Bq x y z" line per (outer atom, bonded F12 atom) pair.

    The position is computed by ``calc_cap_pos`` from the outer atom's and
    the boundary atom's coordinates.
    """
    cap_lines=[]
    for outer,bonded in zip(outer_index,boundary_index):
        for inner in bonded:
            pos=calc_cap_pos(orients[outer][1:4],orients[inner][1:4])
            cap_lines.append(" ".join(["H-Bq",*map(str,pos)]))
    return cap_lines

h1_pos=_cap_lines(h1_index,b1_index)
h2_pos=_cap_lines(h2_index,b2_index)
# Geometry section: every atom as "symbol x y z", then the F1 caps, then
# the F2 caps.
f0_pos=[" ".join(map(str,ori[:4])) for ori in orients]
f0_pos.extend(h1_pos)
f0_pos.extend(h2_pos)
#with open(f12prefix+".inp","w") as f:
#  print(*head,sep="\n",file=f)
#  print(*f0_pos,sep="\n",file=f)
#  print(file=f)
# Full g16 input: header, geometry lines, and a terminating blank line.
f0_input="".join((head,"\n".join(f0_pos),"\n\n"))
if args.debug:
    # echo the generated input, framed by marker lines
    for text in ("inpot for external (debug)",
                 "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~",
                 f0_input,
                 "~~~~end of external input~~~~"):
        print(text)
# Run g16 with the generated input on stdin and capture what it prints.
# check=False: a non-zero exit status is reported below instead of raising.
cp=subprocess.run("g16",input=f0_input,capture_output=True,text=True,shell=True,check=False)
stdout=cp.stdout
if cp.returncode!=0 or args.debug:
    # echo the captured output, framed by marker lines
    for text in ("output for external (debug)",
                 "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~",
                 stdout,
                 "~~~~end of external output~~~~"):
        print(text)
try:
    # Keep the last energy and the last force table found in the output.
    energy=[_pattern_to_energy(p) for p in patE.finditer(stdout)][-1]
    gradients=[_pattern_to_gradient(p) for p in patF.finditer(stdout)][-1]
except (IndexError,ValueError):
    # IndexError: a pattern was not found at all.
    # ValueError: matched text could not be converted to a number.
    # Show the g16 output before exiting, unless it was already printed
    # right after the run (debug mode or non-zero return code).
    if not args.debug and cp.returncode==0:
        print("output for external (debug)")
        print("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
        print(cp.stdout)
        print("~~~~end of external output~~~~")
    sys.exit("canot read energy or gradient from external g16")
# Write the EOu file: the first line holds the energy followed by three
# zeros (presumably the dipole slots of the External format -- confirm),
# then one line of three gradient components per entry in `gradients`.
zero=0.0
with open(args.EOu,"w") as eou:
  eou.write("".join(f"{value:20.12e}" for value in (energy,zero,zero,zero))+"\n")
  for gx,gy,gz in gradients:
    eou.write("".join(f"{value:20.12e}" for value in (gx,gy,gz))+"\n")
#if not args.keep_intermediate:
#  os.remove(datfile)
#  os.remove(dmpfile)
