#!/usr/bin/env python
#@note: 09-22 fix bug in constraint atoms in lammps

import os
import numpy as np
import argparse
from ase.build import bcc100, bcc110
from ase.io import read, write
from ase.constraints import FixAtoms
from ase.neighborlist import NeighborList, natural_cutoffs

class Params:
    """Run-wide settings shared by every stage of the QM/MM cycle.

    The external command lines and the simpy location default to the values
    this script was written for, but can now be overridden per run, e.g.
    ``Params(vasp_run_cmd='srun vasp_gam > log')``, instead of editing code.

    Attributes (as used elsewhere in this script):
        cwd: project root holding ``inp/`` and ``run/``; empty until ``set_up``.
        qm_fix_atoms: 0-based atom indices to keep frozen (VASP and LAMMPS).
        vasp_run_cmd: shell command that runs VASP in the current directory.
        lammps_run_cmd: shell command that runs LAMMPS in the current directory.
        simpy_path: root of the simpy helper-script tree.
    """

    def __init__(self, *,
                 vasp_run_cmd='mpirun -np 24 /opt/sourcecoude/vasp5.4.4/vasp.5.4.4/bin/vasp_gam > log',
                 lammps_run_cmd='/opt/sourcecoude/lammps/lammps-12Dec18/src/lammps -in lammps_input > log',
                 simpy_path='/opt/sourcecoude/simpy'):
        self.cwd = ''
        # A fresh list per instance, so instances never share constraint state.
        self.qm_fix_atoms = []
        self.vasp_run_cmd = vasp_run_cmd
        self.lammps_run_cmd = lammps_run_cmd
        self.simpy_path = simpy_path

def gen_struct(p):
    """Build the starting Li bcc(100) slab as ``inp/POSCAR`` and fetch its POTCAR.

    The slab is a 3x3x6 Li ``bcc100`` cell built with ``vacuum=10.0`` and then
    shifted by -9.9 along z (presumably so the bottom layer sits just above
    z=0 with all the vacuum on top -- confirm). The simpy ``get_potcar.py``
    script is then run on the new POSCAR.

    ``inp`` is resolved relative to the current working directory, and the
    function steps back to the parent directory when done, so call it from
    the project root.
    """
    os.chdir('inp')

    slab = bcc100('Li', size=(3, 3, 6), vacuum=10.0)
    slab.translate([0, 0, -9.9])
    write('POSCAR', slab, vasp5=True, direct=True)

    potcar_cmd = 'python %s/tools/vasp/get_potcar.py POSCAR' % p.simpy_path
    os.system(potcar_cmd)

    os.chdir('..')

def vasp_opt(p):
    """Run the initial VASP optimisation in ``<cwd>/run/0-vasp-opt``.

    Stages POSCAR, POTCAR and KPOINTS, plus INCAR-opt renamed to INCAR, from
    ``<cwd>/inp``. When ``p.qm_fix_atoms`` is non-empty, POSCAR is rewritten
    with those atoms constrained via ``FixAtoms``. After ``p.vasp_run_cmd``
    finishes, CONTCAR is copied to ``p.cwd`` for the next stage and the
    working directory is restored to ``p.cwd``.

    Exit codes of the shell commands are not checked.
    """
    workdir = os.path.join(p.cwd, 'run', '0-vasp-opt')
    os.makedirs(workdir, exist_ok=True)
    os.chdir(workdir)

    staging = (
        'cp %s/inp/POSCAR .',
        'cp %s/inp/POTCAR .',
        'cp %s/inp/KPOINTS .',
        'cp %s/inp/INCAR-opt INCAR',
    )
    for template in staging:
        os.system(template % p.cwd)

    # len() rather than truthiness: qm_fix_atoms may be a numpy array.
    if len(p.qm_fix_atoms) != 0:
        structure = read('POSCAR')
        structure.set_constraint(FixAtoms(indices=p.qm_fix_atoms))
        write('POSCAR', structure, vasp5=True, direct=True)

    os.system(p.vasp_run_cmd)
    os.system('cp CONTCAR %s' % p.cwd)
    os.chdir(p.cwd)

def vasp_run(p, n_run):
    """Run VASP stage ``n_run`` in ``<cwd>/run/<n_run>-vasp-run``.

    The stage starts from the CONTCAR left in ``p.cwd`` by the previous
    stage, copied in as POSCAR. POTCAR, KPOINTS and INCAR-md (renamed to
    INCAR) come from ``<cwd>/inp``. ``p.qm_fix_atoms`` are re-applied as
    ``FixAtoms`` constraints when present.

    After ``p.vasp_run_cmd``, the simpy ``xdat2xyz.pl`` script is run in the
    stage folder and the new CONTCAR is copied back to ``p.cwd``. The working
    directory is then restored to ``p.cwd``.
    """
    workdir = os.path.join(p.cwd, 'run', '%d-vasp-run' % n_run)
    os.makedirs(workdir, exist_ok=True)
    os.chdir(workdir)

    for template in ('cp %s/inp/POTCAR .',
                     'cp %s/inp/KPOINTS .',
                     'cp %s/CONTCAR POSCAR',
                     'cp %s/inp/INCAR-md INCAR'):
        os.system(template % p.cwd)

    # len() rather than truthiness: qm_fix_atoms may be a numpy array.
    if len(p.qm_fix_atoms) != 0:
        structure = read('POSCAR')
        structure.set_constraint(FixAtoms(indices=p.qm_fix_atoms))
        write('POSCAR', structure, vasp5=True, direct=True)

    os.system(p.vasp_run_cmd)
    os.system('perl %s/tools/vasp/xdat2xyz.pl' % p.simpy_path)
    os.system('cp CONTCAR %s' % p.cwd)
    os.chdir(p.cwd)

def _find_constrained_atoms(atoms, qm_fix_atoms, min_li_neighbors=8):
    """Return the 0-based indices of atoms to freeze during the LAMMPS run.

    The result starts with ``qm_fix_atoms`` in the given order. It continues
    with every Li atom that has at least ``min_li_neighbors`` Li neighbours
    and is not already listed. Neighbours are atoms whose natural cutoff
    spheres (scaled by 1.2) overlap.
    """
    constrained = [int(i) for i in qm_fix_atoms]
    seen = set(constrained)

    # Counting coordination needs a full neighbour list without the atom
    # itself. ASE's defaults (self_interaction=True, bothways=False) return
    # the atom plus only the half of its neighbours with a higher index.
    # ASE also adds ``skin`` to every cutoff, returning extra atoms beyond
    # it, so skin=0 keeps the count to the intended cutoff. Presumably the
    # threshold 8 is the bcc first-shell coordination -- confirm.
    nl = NeighborList(natural_cutoffs(atoms, mult=1.2), skin=0.0,
                      self_interaction=False, bothways=True)
    nl.update(atoms)

    for idx, atom in enumerate(atoms):
        if atom.symbol != 'Li' or idx in seen:
            continue
        neighbors, _ = nl.get_neighbors(idx)
        n_li = sum(1 for j in neighbors if atoms[j].symbol == 'Li')
        if n_li >= min_li_neighbors:
            constrained.append(idx)
            seen.add(idx)
    return constrained


def _lammps_constraint_lines(fixed, n_atoms):
    """Return the LAMMPS commands that freeze ``fixed`` and define group ``sim``.

    ``fixed`` holds 0-based ASE indices, but LAMMPS atom IDs are positive
    integers. Each index is therefore shifted by +1.

    NOTE(review): this assumes the data file written by e_2_contcar.py lists
    atoms in POSCAR order with IDs 1..N -- confirm.
    """
    if len(fixed) > 0:
        ids = ' '.join(str(int(i) + 1) for i in fixed)
        return [
            'group li0 id %s\n' % ids,
            'group sim subtract all li0\n',
            'fix 909 li0 setforce 0.0 0.0 0.0\n',
        ]
    ids = ' '.join(str(i + 1) for i in range(n_atoms))
    return ['group sim id %s\n' % ids]


def reax_run(p, n_run):
    """Run ReaxFF/LAMMPS stage ``n_run`` in ``<cwd>/run/<n_run>-lammps-run``.

    Steps:

    1. Copy the CONTCAR in ``p.cwd`` in as POSCAR, and lammps_input, ffield
       and control.reaxc in from ``<cwd>/inp``.
    2. Run simpy ``e_2_contcar.py`` on POSCAR. It presumably writes the
       LAMMPS structure input -- confirm.
    3. Write a ``constraint`` file. It freezes ``p.qm_fix_atoms`` plus every
       Li atom with at least 8 Li neighbours, and puts all other atoms in
       group ``sim``.
    4. Run ``p.lammps_run_cmd``.
    5. Read the last frame of lammps.xyz and give it the POSCAR cell and full
       periodicity.
    6. Write that frame as CONTCAR, copy it to ``p.cwd``, and restore the
       working directory to ``p.cwd``.
    """
    folder = os.path.join(p.cwd, 'run', '%d-lammps-run' % n_run)
    os.makedirs(folder, exist_ok=True)
    os.chdir(folder)

    os.system('cp %s/CONTCAR POSCAR' % p.cwd)
    os.system('cp %s/inp/lammps_input .' % p.cwd)
    os.system('cp %s/inp/ffield .' % p.cwd)
    os.system('cp %s/inp/control.reaxc .' % p.cwd)
    os.system('python %s/lib/e_2_contcar.py POSCAR' % p.simpy_path)
    atoms = read('POSCAR')

    fixed = _find_constrained_atoms(atoms, p.qm_fix_atoms)
    with open('constraint', 'w') as fh:
        fh.writelines(_lammps_constraint_lines(fixed, len(atoms)))

    os.system(p.lammps_run_cmd)

    # ase.io.read returns the last image by default; the xyz file carries no
    # cell, so reuse the one from POSCAR.
    final = read('lammps.xyz')
    final.set_cell(read('POSCAR').get_cell())
    final.set_pbc([True, True, True])

    write('CONTCAR', final, vasp5=True, direct=True)
    os.system('cp CONTCAR %s' % p.cwd)
    os.chdir(p.cwd)

def set_up(p):
    """Record the project root and load optional frozen-atom indices.

    ``p.cwd`` is set to the current working directory. If ``inp/qm-fix``
    exists (relative to that directory), all integers in it are read in
    file order into a flat numpy array stored in ``p.qm_fix_atoms``.
    Otherwise ``p.qm_fix_atoms`` is left untouched.
    """
    p.cwd = os.getcwd()
    fix_file = 'inp/qm-fix'
    if not os.path.exists(fix_file):
        return
    p.qm_fix_atoms = np.loadtxt(fix_file, dtype=int).flatten()

def results_analysis():
    """Placeholder for post-processing of the QM/MM cycle; currently a no-op."""

if __name__ == '__main__':
    # Entry point: optional structure generation, one VASP optimisation, then
    # `-c` alternating VASP-MD / LAMMPS-MD cycles.
    parser = argparse.ArgumentParser()
    # Previously range(1, c) ran only c-1 cycles (and `-c 1` ran none),
    # contradicting this help text. The default of 1 keeps the old default
    # of a single cycle.
    parser.add_argument("-c", type=int, default=1,
                        help="n cycles of QM + MM run (default: 1)")
    # These replace the hard-coded `if 0:` toggles; both stay off by default.
    parser.add_argument("--gen-struct", action="store_true",
                        help="build inp/POSCAR and POTCAR before running")
    parser.add_argument("--analysis", action="store_true",
                        help="call results_analysis() after the cycles")
    args = parser.parse_args()

    p = Params()
    set_up(p)

    if args.gen_struct:
        gen_struct(p)
    vasp_opt(p)

    # Stage folders are numbered from 1; 0 is the optimisation.
    for i in range(1, args.c + 1):
        vasp_run(p, i)
        reax_run(p, i)

    if args.analysis:
        results_analysis()