import os
import numpy as np
from ase.calculators.calculator import FileIOCalculator
from ase.calculators.calculator import SCFError, ReadError
import ase.units
import shutil
import pickle

class Jaguar(FileIOCalculator):
    """
    ASE file-I/O calculator for Schrodinger Jaguar.

    Writes ``<label>.in``, runs ``$SCHRODINGER/jaguar run`` on it and parses
    the energy and forces from ``<label>.out``.  When a restart file
    ``<label>.01.in`` from a previous run exists, its ``$guess``/``&guess``
    section may be reused as the initial guess for the next run (see
    :meth:`write_input`).
    """
    name = 'jaguar'

    implemented_properties = ['energy', 'forces']

    command = ""

    # define the defaults
    default_parameters = {'basis': '6-31G**'}

    # Maximum per-coordinate change (same units as the Atoms positions,
    # i.e. Angstrom in ASE) between the previous and the new geometry for the
    # previous wavefunction guess to be kept.
    guess_displacement_cutoff = 1.0

    def __init__(self, restart=None,
                 ignore_bad_restart_file=FileIOCalculator._deprecated,
                 label='jaguar', scratch=None, tpp=32, atoms=None, **kwargs):
        """
        tpp: int
             number of threads for the command line flag (``-TPP``); the
             flag is only added when tpp > 1.
        scratch:
             accepted but currently unused.

        Raises Exception if the SCHRODINGER environment variable is not set.
        """

        FileIOCalculator.__init__(self, restart, ignore_bad_restart_file,
                                  label, atoms, **kwargs)

        # .get() so that a missing variable reaches the explicit error below
        # instead of surfacing as a bare KeyError.
        schrodinger = os.environ.get('SCHRODINGER')
        if schrodinger is None:
            raise Exception('The environment variable SCHRODINGER has not been defined.')

        # Augment the command by various flags
        command = schrodinger + "/jaguar run "
        if tpp > 1:
            command += '-TPP %d ' % tpp
        self.command = command + 'PREFIX.in -WAIT'

    def read(self, label):
        """Restarting from files is not supported by this calculator."""
        raise NotImplementedError

    def read_results(self):
        """Parse energy and forces from ``<label>.out`` into ``self.results``.

        Energies are converted from Hartree to eV and forces from
        Hartree/Bohr to eV/Angstrom.  If the file contains several energy
        lines or force blocks, the last one read wins.

        Raises:
            ReadError: the output file is missing or a force block is
                truncated.
            SCFError: the output reports that the SCF failed to converge.
        """
        # convert energy from Eh. to eV.
        energy_convert = ase.units.Hartree
        # convert the forces from hartrees/bohr to eV / Angstrom
        force_convert = ase.units.Hartree / ase.units.Bohr

        filename = self.label + '.out'
        if not os.path.isfile(filename):
            raise ReadError('Could not find Jaguar output file: ' + filename)

        with open(filename, 'r') as fileobj:
            lineiter = iter(fileobj)
            for line in lineiter:
                if 'SCF failed to converge' in line:
                    raise SCFError()

                elif 'SCFE: SCF energy:' in line:
                    self.results['energy'] = float(line.split()[4]) * energy_convert

                elif 'forces (hartrees/bohr) : total' in line:
                    forces = self._read_force_block(lineiter, filename)
                    self.results['forces'] = np.array(forces) * force_convert

    @staticmethod
    def _read_force_block(lineiter, filename):
        """Read the per-atom rows that follow a force-table header.

        Skips the three lines after the header, then collects columns 2-4 of
        each row as (x, y, z) until a line containing '----' (which is
        consumed).  Raises ReadError if the file ends before that line.
        """
        for _ in range(3):
            if next(lineiter, None) is None:
                raise ReadError('Truncated force block in ' + filename)

        forces = []
        for line in lineiter:
            if '----' in line:
                return forces
            items = line.split()
            forces.append((float(items[2]), float(items[3]), float(items[4])))
        raise ReadError('Truncated force block in ' + filename)

    def write_input(self, atoms, properties=None, system_changes=None):
        """Write ``<label>.in`` (``&gen``, ``&zmat`` and optional guess).

        The guess section from ``<label>.01.in`` is reused only if a previous
        ``<label>.in`` exists, at least one ``input_N`` file is present in
        ``self.label[:-4]``, and no coordinate moved by
        ``guess_displacement_cutoff`` or more since that previous input.
        Otherwise the restart file is deleted (when it exists) and no guess
        is written.  As side effects, the previous ``<label>.out`` is copied
        to ``<label>_pre.out`` and its SCFE lines are appended to
        ``energy_tracking`` (best-effort).
        """
        FileIOCalculator.write_input(self, atoms, properties, system_changes)
        # properties defaults to None; guard the membership test below.
        properties = properties or []

        restart_file = self.label + '.01.in'
        input_file = self.label + '.in'

        # get initial guess from the restart file of the previous run
        initial_guess, found_initial_guess = self._read_initial_guess(restart_file)

        previous_inputs = self._previous_input_numbers()

        self._archive_previous_output()

        remove_wfns = False
        if not os.path.isfile(input_file):
            # First run: nothing to compare against.
            found_initial_guess = False
        elif not previous_inputs:
            remove_wfns = True
            found_initial_guess = False
        elif self._geometry_moved(atoms, input_file):
            # Atoms moved too far: do not reuse the previous wavefunction.
            remove_wfns = True
            found_initial_guess = False

        if remove_wfns:
            try:
                os.remove(restart_file)
            except FileNotFoundError:
                pass

        with open(input_file, 'w') as fileobj:
            #
            # write &gen section
            #
            fileobj.write('&gen\n')
            if 'forces' in properties:
                fileobj.write('igeopt=-1\n')
                # NOTE(review): this re-enables the guess even when the checks
                # above discarded it (and deleted the restart file). Preserved
                # from the original; confirm whether it is intended.
                found_initial_guess = True

            for prm, value in self.parameters.items():
                if value is not None:
                    fileobj.write('%s = %s\n' % (prm, value))

            fileobj.write('&\n')

            #
            # write &zmat section
            #
            fileobj.write('&zmat\n')
            for a in atoms:
                fileobj.write('   %s  %f  %f  %f\n' % (a.symbol, a.x, a.y, a.z))
            fileobj.write('&\n')

            #
            # write initial guess
            #
            if found_initial_guess:
                for line in initial_guess:
                    fileobj.write(line + "\n")

    @staticmethod
    def _read_initial_guess(restart_file):
        """Extract ``$guess``/``&guess`` sections from *restart_file*.

        Returns ``(lines, found)``.  The section header line is stored
        stripped and lowercased; body lines are stored stripped only.
        ``found`` becomes True once a section terminated by a lone '&' or
        '$' has been read.  An unterminated trailing section is dropped.
        A missing file yields ``([], False)``.
        """
        initial_guess = []
        found = False
        if not os.path.isfile(restart_file):
            return initial_guess, found

        section_start = None  # index in initial_guess of the open section
        with open(restart_file, 'r') as restart_fileobj:
            for raw in restart_fileobj:
                if section_start is None:
                    header = raw.strip().lower()
                    if header[0:6] in ("$guess", "&guess"):
                        section_start = len(initial_guess)
                        initial_guess.append(header)
                else:
                    line = raw.strip()
                    initial_guess.append(line)
                    if line in ("&", "$"):
                        found = True
                        section_start = None

        if section_start is not None:
            del initial_guess[section_start:]
        return initial_guess, found

    def _previous_input_numbers(self):
        """Return the sorted indices N of ``input_N`` files in ``self.label[:-4]``.

        Entries containing '.swp' are ignored, as are names whose second
        '_'-separated field is not an integer.
        """
        # NOTE(review): self.label[:-4] is presumably the run directory
        # holding the input_N files written by the driver; confirm.
        directory = self.label[:-4]
        numbers = []
        for name in os.listdir(directory):
            if not os.path.isfile(os.path.join(directory, name)):
                continue
            if 'input' not in name or '.swp' in name:
                continue
            parts = name.split("_")
            if len(parts) < 2:
                continue
            try:
                numbers.append(int(parts[1]))
            except ValueError:
                continue
        numbers.sort()
        return numbers

    def _archive_previous_output(self):
        """Best-effort bookkeeping of the previous ``<label>.out``.

        Copies it to ``<label>_pre.out`` and appends its lines containing
        'SCFE' to ``energy_tracking`` in the current working directory.
        I/O and decoding errors (e.g. no output yet on the first step) are
        ignored.
        """
        output_file = self.label + '.out'
        try:
            shutil.copy(output_file, self.label + '_pre.out')
            with open(output_file, 'r') as f:
                scfe_lines = [line for line in f if "SCFE" in line]
            with open("energy_tracking", 'a') as w:
                w.writelines(scfe_lines)
        except (OSError, ValueError):
            pass

    def _geometry_moved(self, atoms, input_file):
        """Return True unless every coordinate is within the guess cutoff.

        The previous geometry is read from the ``&zmat`` section of
        *input_file* (one ``symbol x y z`` line per atom, as written by
        :meth:`write_input`); only the first ``len(atoms)`` rows are
        compared.  A previous geometry that cannot be parsed, or that has
        fewer rows than *atoms*, counts as moved.
        """
        natoms = len(atoms)
        with open(input_file, 'r') as f:
            data = f.readlines()

        previous_xyz = np.zeros([natoms, 3])
        try:
            start = data.index('&zmat\n') + 1
            for i in range(natoms):
                xyz = data[start + i].split()[1:]
                for j in range(3):
                    previous_xyz[i, j] = float(xyz[j])
        except (ValueError, IndexError):
            return True

        # Absolute displacement: a signed difference would miss moves in the
        # positive direction.  "not all(< cutoff)" keeps NaN counting as moved.
        displacement = np.abs(previous_xyz - atoms.get_positions())
        return not np.all(displacement < self.guess_displacement_cutoff)