#!/usr/bin/env python

''' 
@todo: get stress for LinearCombinationCalculator 
@todo: use the same functional form as in Hyungjun's paper (dx.doi.org/10.1021/jz2016395) 
''' 

import os, sys
import numpy as np
from math import sqrt, pow
import socket
import argparse

from ase import Atoms, Atom
from ase.calculators.vasp import Vasp2
from ase.io import read, write, Trajectory
from ase.optimize import *
from ase.calculators.mixing import LinearCombinationCalculator
from ase.calculators.mixing import SumCalculator
from ase.calculators.emt import EMT
from ase.calculators.lj import LennardJones
from ase.neighborlist import NeighborList
from ase.calculators.calculator import Calculator, all_changes
from ase.calculators.calculator import PropertyNotImplementedError
from ase.calculators.loggingcalc import LoggingCalculator
from ase.constraints import ExpCellFilter

class Params:
    """Plain container for the settings shared by the run script.

    Attributes set here:
        atoms: the structure being worked on (starts as an empty ``Atoms``).
        fname: name of the structure file (starts empty).
        vasp_params: keyword -> value mapping collected for the VASP
            calculator (starts empty).
    """

    def __init__(self):
        self.atoms = Atoms()
        self.fname = ''
        self.vasp_params = dict()

def set_up_vasp_env(p):
    """Export the environment variables describing how to launch VASP.

    ``p.kpts`` is summed to choose the executable: a sum of exactly 3 picks
    the ``vasp_gam`` binary, a sum above 3 picks ``vasp_std``, and any other
    sum leaves ``ASE_VASP_COMMAND`` untouched.  ``VASP_PP_PATH`` is always
    set.
    """
    kpt_total = sum(p.kpts)

    command = None
    if kpt_total == 3:
        command = 'mpirun -np 24 /home/chengtao/soft/vasp/vasp.5.4.4.solv.skx.tst/src/vasp.5.4.4/bin/vasp_gam'
    elif kpt_total > 3:
        command = 'mpirun -np 24 /home/chengtao/soft/vasp/vasp.5.4.4.solv.skx.tst/src/vasp.5.4.4/bin/vasp_std'

    if command is not None:
        os.environ['ASE_VASP_COMMAND'] = command
    os.environ['VASP_PP_PATH'] = '/home/chengtao/soft/vasp/'

def _expand_magmom(raw_values):
    """Expand MAGMOM tokens such as ``['2*1.0', '0.5']`` into ``[1.0, 1.0, 0.5]``."""
    moments = []
    for token in raw_values:
        token = token.strip()
        if '*' in token:
            # VASP shorthand "N*value": repeat value N times.
            parts = token.split('*')
            moments.extend([float(parts[1])] * int(parts[0]))
        else:
            moments.append(float(token))
    return moments


def read_incar(p, fname='INCAR'):
    """Read ``key = value`` settings from an INCAR-style file into ``p``.

    Parameters:
        p: object with a ``vasp_params`` dict and an ``atoms`` sequence.
        fname: path of the file to read (defaults to ``INCAR`` in the
            current working directory).

    Behavior:
        * Text after ``#`` is treated as a comment and dropped.
        * Only lines containing exactly one ``=`` are considered.
        * Keywords are lower-cased.  A keyword with a single value is stored
          in ``p.vasp_params`` (as ``float`` when it parses as one, otherwise
          as the raw string).  Keywords with several values are not stored.
        * For ``magmom`` the value list (with ``N*value`` shorthand expanded)
          is printed and, when its length equals ``len(p.atoms)``, written to
          the ``magmom`` attribute of each atom.

    Raises:
        OSError: if the file cannot be opened.
        ValueError: if a MAGMOM token is not a number or ``N*value``.
    """
    with open(fname, 'r') as incar:
        for line in incar:
            # A comment-only line leaves no '=' behind and is skipped below.
            tokens = line.split('#')[0].split('=')
            if len(tokens) != 2:
                continue

            keyword = tokens[0].strip().lower()
            raw_values = tokens[1].strip().split()
            try:
                values = [float(item) for item in raw_values]
            except ValueError:
                # At least one token is not numeric: keep the raw strings.
                values = raw_values

            if len(values) == 1:
                p.vasp_params[keyword] = values[0]

            if keyword == 'magmom':
                # Always expand from the raw strings: the float-converted
                # list cannot be searched for the '*' shorthand.
                magmom = _expand_magmom(raw_values)
                print(len(magmom), magmom)

                if len(magmom) == len(p.atoms):
                    for idx, moment in enumerate(magmom):
                        p.atoms[idx].magmom = moment

class myLowGradient(Calculator):
    """Pairwise low-gradient (LG) dispersion correction as an ASE calculator.

    Every neighbor pair contributes ``Tap(r) * e_lg(r)`` to the energy, with
    ``e_lg(r) = -C_ij / (r**6 + Re6_ij)`` and ``Tap`` a 7th-order taper.
    The pair constants are built from per-element values as
    ``C_ij = sqrt(c_i * c_j)`` and ``Re6_ij = (re_i * re_j)**3``.

    Per-element values are read, on every ``calculate`` call, from a file
    named ``ffield`` in the current working directory.  Each non-comment line
    with exactly three whitespace-separated fields is taken as
    ``symbol  c  re``.

    Only the ``rc`` entry of ``default_parameters`` (the neighbor cutoff,
    10.0 when ``None``) is read by ``calculate``; ``clg`` and ``re`` are not.
    """
    implemented_properties = ['energy', 'forces']
    default_parameters = {'clg': 1.0, 're': 3.0, 'rc': None}
    nolabel = True

    def __init__(self, **kwargs):
        Calculator.__init__(self, **kwargs)

    def init_tap(self, ri, ro):
        """Return the taper polynomial coefficients and those of its derivative.

        The taper is ``Tap(r) = sum(tap[k] * r**k for k in range(8))``; it is
        the expansion of ``1 - 35x^4 + 84x^5 - 70x^6 + 20x^7`` with
        ``x = (r - ri) / (ro - ri)``, so ``Tap(ri) = 1`` and ``Tap(ro) = 0``.

        Parameters:
            ri: inner radius (taper equals 1 here).
            ro: outer radius (taper equals 0 here); must differ from ``ri``.

        Returns:
            ``(tap, dtap)``: two length-8 arrays in ascending power order,
            with ``dtap[k] = (k + 1) * tap[k + 1]`` and ``dtap[7] = 0``.
        """
        ri2 = ri**2
        ri3 = ri**3

        ro2 = ro**2
        ro3 = ro**3

        tap = np.zeros(8)
        dtap = np.zeros(8)

        d7 = float(ro - ri) ** 7
        tap[7] = 20.0/d7
        tap[6] = -70.0*(ri+ro)/d7
        tap[5] = 84.0*(ri2+3.0*ri*ro+ro2)/d7
        # Leading term is ri^3 (not ri^3*ro): every term must be cubic.
        tap[4] = -35.0*(ri3+9.0*ri2*ro+9.0*ri*ro2+ro3)/d7
        tap[3] = 140.0*(ri3*ro+3.0*ri2*ro2+ri*ro3)/d7
        tap[2] = -210.0*(ri3*ro2+ri2*ro3)/d7
        tap[1] = 140.0*ri3*ro3/d7
        # The ri*ro^6 term carries a minus sign (from expanding (ro-ri)^7).
        tap[0] = (-35.0*ri3*ro2*ro2+21.0*ri2*ro3*ro2-7.0*ri*ro3*ro3+ro3*ro3*ro)/d7

        # Term-by-term derivative; dtap[7] stays 0.
        dtap[:7] = np.arange(1, 8) * tap[1:]
        return tap, dtap

    @staticmethod
    def _taper(r, ri, ro, tap, dtap):
        """Evaluate the taper and its radial derivative at distance ``r``."""
        if r > ro:
            return 0.0, 0.0
        if r <= ri:
            return 1.0, 0.0
        # Horner evaluation of both polynomials.
        taper = tap[7]
        dtaper = dtap[7]
        for k in range(6, -1, -1):
            taper = taper*r + tap[k]
            dtaper = dtaper*r + dtap[k]
        return taper, dtaper

    @staticmethod
    def _read_ffield(fname='ffield'):
        """Parse ``symbol c re`` lines from ``fname`` into ``{symbol: [c, re]}``."""
        params = {}
        with open(fname, 'r') as handle:
            for line in handle:
                stripped = line.strip()
                if stripped.startswith('#'):
                    continue
                tokens = stripped.split()
                if len(tokens) == 3:
                    params[tokens[0]] = [float(tokens[1]), float(tokens[2])]
        return params

    def calculate(self, atoms=None, properties=['energy'], system_changes=all_changes):
        """Compute the LG energy and forces and store them in ``self.results``.

        Stores ``energy``, ``free_energy`` (same value) and ``forces``.

        Raises:
            OSError: if ``ffield`` cannot be opened.
            KeyError: if an element of the structure has no ``ffield`` entry.
        """
        Calculator.calculate(self, atoms, properties, system_changes)

        natoms = len(self.atoms)
        params = self._read_ffield()

        # Taper range; hard-coded independently of the neighbor cutoff rc.
        ri = 0.0
        ro = 10.0
        tap, dtap = self.init_tap(ri, ro)

        rc = self.parameters.rc
        if rc is None:
            rc = 10.0

        # Full symmetric pair tables: the neighbor list can report a periodic
        # image pair (i, j) with j < i, so a triangular table is not enough.
        symbols = self.atoms.get_chemical_symbols()
        c_elem = np.array([params[s][0] for s in symbols], dtype=float)
        re_elem = np.array([params[s][1] for s in symbols], dtype=float)
        clg = np.sqrt(np.outer(c_elem, c_elem))
        re6 = np.outer(re_elem, re_elem) ** 3

        self.nl = NeighborList([rc/2] * natoms, self_interaction=False)
        self.nl.update(self.atoms)

        positions = self.atoms.positions
        cell = self.atoms.cell

        energy = 0.0
        forces = np.zeros((natoms, 3))

        for i in range(natoms):
            neighbors, offsets = self.nl.get_neighbors(i)
            shifts = np.dot(offsets, cell)
            for j, shift in zip(neighbors, shifts):
                # Vector from atom i to (the periodic image of) atom j.
                rvec = positions[j] - positions[i] + shift
                r2 = (rvec**2).sum()
                r = sqrt(r2)
                r6 = r2**3

                taper, dtaper = self._taper(r, ri, ro, tap, dtap)

                denom = r6 + re6[i][j]
                e_lg = -clg[i][j]/denom
                energy += taper*e_lg

                # d(e_lg)/dr = -6 * e_lg * r^5 / (r^6 + Re6)
                de_lg = -6.0*e_lg*r6/denom/r
                # Product rule for E = Tap(r) * e_lg(r).
                de_dr = dtaper*e_lg + taper*de_lg
                # F_i = -dE/dr_i = (dE/dr) * rvec / r, and F_j = -F_i.
                f = de_dr/r*rvec
                forces[i] += f
                forces[j] -= f

        self.results['energy'] = energy
        self.results['free_energy'] = energy
        self.results['forces'] = forces

if __name__ == '__main__':
    # Command-line driver: reads POSCAR and INCAR from the current directory,
    # then evaluates or relaxes the structure with the LG correction ('lg'),
    # VASP ('qm'), or their sum ('qm-lg').
    parser = argparse.ArgumentParser()
    parser.add_argument('-k', type=int, nargs=3, help='kpoints')
    parser.add_argument('-run', nargs=1, help='runtype: lg, qm, qm-lg')
    parser.add_argument('-steps', type=int, nargs=1, help='max steps of opt')
    parser.add_argument('-fmax', type=float, nargs=1, help='max force in opt')
    parser.add_argument('-opt', nargs=1, 
        help='optimizer: BFGS, BFGSLineSearch, LBFGS, LBFGSLineSearch, GPMin, MDMin, FIRE')
    parser.add_argument('-opttype', nargs=1, help='nvt or npt')
    
    args = parser.parse_args()

    # Optimizer name -> optimizer class.
    opt = {'BFGS': BFGS, 'BFGSLineSearch': BFGSLineSearch, 'LBFGS': LBFGS,
           'LBFGSLineSearch': LBFGSLineSearch, 'GPMin': GPMin, 'MDMin': MDMin, 'FIRE':FIRE}
    opt_fun = 'BFGS'

    p = Params()
    p.fname = 'POSCAR'
    p.atoms = read(p.fname)

    # set up vdw correction
    p.vdw = myLowGradient()

    vasp_folder = 'vasp-run'
    read_incar(p)
    p.kpts = (1,1,1)
    if args.k:
        p.kpts=args.k
    else:
        print('Use default gamma point (1 1 1)')
    set_up_vasp_env(p)
    p.qm = Vasp2(kpts=p.kpts, directory=vasp_folder, pp='pbe')
    p.qm.set(**p.vasp_params)

    run_type = 'lg'
    opt_type = 'nvt'
    fmax = 0.001
    n_steps = 200

    if args.run:
        run_type = args.run[0]
    if args.opt:
        print(args.opt)
        opt_fun = args.opt[0]
    else:
        print('Using default BFGS')
    # Compare against None so that an explicit 0 is honored.
    if args.steps is not None:
        n_steps = args.steps[0]
    if args.fmax is not None:
        fmax = args.fmax[0]
    if args.opttype:
        opt_type = args.opttype[0]

    # The optimizer settings only matter for the relaxation modes; reject
    # bad values up front instead of failing later on an undefined optimizer.
    if run_type in ('qm', 'qm-lg'):
        if opt_fun not in opt:
            parser.error('unknown optimizer %r; available: %s'
                         % (opt_fun, ', '.join(opt)))
        if opt_type not in ('nvt', 'npt'):
            parser.error('unknown opttype %r; available: nvt, npt' % opt_type)

    if run_type == 'lg':
        # Single-point evaluation with the LG correction only.
        p.atoms.set_calculator(p.vdw)
        pe = p.atoms.get_potential_energy()
        print(pe)

    elif run_type == 'qm':
        p.atoms.set_calculator(p.qm)
        if opt_type == 'nvt':
            dyn = opt[opt_fun](p.atoms, trajectory='vasp.traj', restart='vasp-restart.pckl')
        else:
            # 'npt': relax through a cell filter and log the atoms manually.
            ecf = ExpCellFilter(p.atoms)
            dyn = opt[opt_fun](ecf)
            traj = Trajectory('vasp.traj', 'w', p.atoms)
            dyn.attach(traj)
        dyn.run(fmax=fmax, steps=n_steps)
        pe = p.atoms.get_potential_energy()
        print(pe)

    elif run_type == 'qm-lg':
        if opt_type == 'npt':
            # TODO: variable-cell relaxation for the combined calculator
            # (see the stress todo at the top of the file).
            print('Not implemented yet')
            sys.exit(1)
        # The return value is intentionally unused: p.atoms is passed to the
        # constructor, which is presumably what attaches the combined
        # calculator to the atoms. NOTE(review): confirm for the installed
        # ASE version.
        LinearCombinationCalculator([p.vdw, p.qm], [1.0, 1.0], p.atoms)
        dyn = opt[opt_fun](p.atoms, trajectory='vasp.traj', restart='vasp-restart.pckl')
        dyn.run(fmax=fmax, steps=n_steps)
        pe = p.atoms.get_potential_energy()
        print(pe)

    else:
        print('Warning: undefined run type. Available types are:')
        print('lg, qm, qm-lg')