from ase.io.trajectory import Trajectory
import numpy as np


class Resampling:
    """Adaptive resampling of the QM region around an H3O+ / OH- ion pair.

    Atom labels follow the convention used in :meth:`ion_dyn_matrix`:
    ``"O"``/``"H"`` mark QM atoms and ``"Ow"``/``"Hw"`` mark MM atoms.
    """

    def min_vec(self, vec_target, vec_ref, cell_par):
        """Return the minimum-image copy of ``vec_target`` seen from ``vec_ref``.

        Each of the three Cartesian components of ``vec_target - vec_ref`` is
        folded into ``(-L/2, L/2]`` using the box lengths ``cell_par[0:3]``.
        Any further entries of ``cell_par`` (e.g. cell angles) are ignored,
        so this presumably assumes an orthorhombic cell.

        Args:
            vec_target: 3-vector position to be imaged.
            vec_ref: 3-vector reference position.
            cell_par: sequence whose first three entries are the box lengths.

        Returns:
            ``(new_vec, dist)``: the imaged position of ``vec_target`` and
            its distance to ``vec_ref``.
        """
        # Kept for backward compatibility: the original stored its inputs.
        self.vec_target = vec_target
        self.vec_ref = vec_ref
        self.cell_par = cell_par

        ref = np.asarray(vec_ref, dtype=float)
        lengths = np.asarray(cell_par[:3], dtype=float)
        # Modulo with a positive divisor lands in [0, L), so only the upper
        # half needs to be shifted back by one box length.
        wrapped = np.mod(np.asarray(vec_target, dtype=float) - ref, lengths)
        wrapped = np.where(wrapped > lengths / 2.0, wrapped - lengths, wrapped)
        new_vec = wrapped + ref

        disp = new_vec - ref
        dist = np.sqrt(np.sum(disp * disp))
        return new_vec, dist

    def ion_dyn_matrix(self, atoms, qmregion, molid, num_H2O_1, num_H2O_2, atom1, atom2):
        """Resample the QM region around the H3O+ (atom1) and OH- (atom2) oxygens.

        Steps:
          1. Pick the ``num_H2O_1`` oxygens nearest ``atom1`` and the
             ``num_H2O_2`` oxygens nearest ``atom2``. If the two sets overlap,
             top up with the oxygens with the smallest summed distance to
             both centres (ellipsoid sampling).
          2. Take up to the 3 nearest H of every selected O (more on ties).
             Each H is re-assigned to the molecule of its nearest O (proton
             transfer). An H joins the QM region if that O is a QM oxygen.
          3. Identify H3O+ (4 atoms per molid) and OH- (2 atoms per molid).
             Fall back to the previous QM region unless exactly one of each is
             found and the region has ``3 * (num_H2O_1 + num_H2O_2)`` atoms.
          4. Also fall back if an O and an H of the same molecule would both
             leave the QM region while closer than 0.92 A.
          5. Relabel symbols ``O``/``Ow``/``H``/``Hw`` if the region changed.

        Args:
            atoms: structure exposing ``get_chemical_symbols()``,
                ``get_positions()``, ``cell.cellpar()`` and ``len()``
                (presumably an ``ase.Atoms``).
            qmregion: atom indices of the current QM region.
            molid: per-atom molecule IDs. It must support ``.copy()`` and
                element-wise ``==``, i.e. presumably a NumPy array.
            num_H2O_1: number of O in the cluster around ``atom1``.
            num_H2O_2: number of O in the cluster around ``atom2``.
            atom1: index of the H3O+ oxygen.
            atom2: index of the OH- oxygen.

        Returns:
            ``(new_QM_list, new_molid, new_chemical_symbols, proto_O,
            hydro_O, COM)``. ``proto_O``/``hydro_O`` are the detected
            H3O+/OH- oxygens, defaulting to ``atom1``/``atom2``. ``COM`` is
            the mass-weighted centre of all atoms, with every non-O atom
            weighted as H.

        Raises:
            Exception: if more distinct O than requested are selected, or if
                there are not enough O atoms to fill the QM region.
        """
        self.atoms = atoms
        self.qmregion = qmregion
        self.atom1 = atom1  # Oxygen ID of H3O+
        self.atom2 = atom2  # Oxygen ID of OH-
        self.molid = molid
        self.num_H2O_1 = num_H2O_1  # number of O in the H3O+ cluster
        self.num_H2O_2 = num_H2O_2  # number of O in the OH- cluster

        self.symbols = atoms.get_chemical_symbols()
        cell_par = self.atoms.cell.cellpar()
        n_qm_O = self.num_H2O_1 + self.num_H2O_2

        ### Split atoms into O / H; anything not labelled "O"/"Ow" counts as H.
        O_list, H_list, QM_O_list = [], [], []
        for i, sym in enumerate(self.symbols):
            if sym in ("O", "Ow"):
                O_list.append(i)
                if sym == "O":
                    QM_O_list.append(i)
            else:
                H_list.append(i)

        pos = self.atoms.get_positions()
        pos_O_array = pos[O_list]
        pos_H_array = pos[H_list]

        mass = {'H': 1.007940, 'O': 15.999400}
        COM = (np.sum(pos_O_array * mass['O'], axis=0) + np.sum(pos_H_array * mass['H'], axis=0)) / (
            len(O_list) * mass['O'] + len(H_list) * mass['H'])

        ### Select the QM oxygens around each ion centre.
        # NOTE(review): plain Cartesian distances are used here (no minimum
        # image via min_vec); confirm positions never need PBC wrapping.
        dist_O_1 = np.sqrt(np.sum(np.square(pos_O_array - pos[self.atom1]), axis=1))
        dist_O_2 = np.sqrt(np.sum(np.square(pos_O_array - pos[self.atom2]), axis=1))
        new_QM_O_list_1 = [O_list[i] for i in np.where(dist_O_1 <= np.sort(dist_O_1)[self.num_H2O_1 - 1])[0]]
        new_QM_O_list_2 = [O_list[i] for i in np.where(dist_O_2 <= np.sort(dist_O_2)[self.num_H2O_2 - 1])[0]]
        # The clusters may overlap (recombination): keep each O only once so
        # the QM list has no duplicate indices, preserving selection order.
        new_QM_O_list = list(dict.fromkeys(new_QM_O_list_1 + new_QM_O_list_2))

        if len(new_QM_O_list) > n_qm_O:
            raise Exception("len(list(set(new_QM_O_list))) > self.num_H2O_1 + self.num_H2O_2")
        if len(new_QM_O_list) < n_qm_O:
            # Ellipsoid sampling: the level sets of d1 + d2 are ellipsoids with
            # atom1/atom2 as foci. argsort visits tied oxygens one by one
            # instead of re-picking the first of them.
            chosen = set(new_QM_O_list)
            for idx in np.argsort(dist_O_1 + dist_O_2, kind="stable"):
                if len(new_QM_O_list) == n_qm_O:
                    break
                if O_list[idx] not in chosen:
                    new_QM_O_list.append(O_list[idx])
                    chosen.add(O_list[idx])
            if len(new_QM_O_list) < n_qm_O:
                raise Exception("not enough O atoms to fill self.num_H2O_1 + self.num_H2O_2 QM molecules")

        ### Collect QM hydrogens and update molid for transferred protons.
        new_molid = self.molid.copy()
        qm_O_set = set(new_QM_O_list)
        new_QM_H_list = []
        for o in new_QM_O_list:
            dist_H = np.sqrt(np.sum(np.square(pos_H_array - pos[o]), axis=1))
            for j in np.where(dist_H <= np.sort(dist_H)[2])[0]:  # 3 nearest H of O[o]
                h = H_list[j]
                dist_O = np.sqrt(np.sum(np.square(pos_O_array - pos[h]), axis=1))
                near_O = O_list[np.argmin(dist_O)]  # first nearest O of this H
                if near_O in qm_O_set:
                    new_QM_H_list.append(h)
                # Proton transfer (to a QM or an MM oxygen): the H now belongs
                # to the molecule of its nearest O.
                if self.molid[h] != self.molid[near_O]:
                    new_molid[h] = new_molid[near_O]

        new_QM_H_list = sorted(set(new_QM_H_list))  # drop duplicate H IDs
        new_QM_list = new_QM_O_list + new_QM_H_list

        ### Identify H3O+ / OH- from molecule sizes in new_molid.
        proto_O, hydro_O = self.atom1, self.atom2
        list_proto, list_hydro = [], []
        # NOTE(review): only oxygens of the *previous* QM region are scanned;
        # confirm an ion hopping onto a newly added QM O need not be detected.
        for o in QM_O_list:
            n_members = np.count_nonzero(new_molid == new_molid[o])
            if n_members == 4:  # H3O+
                list_proto.append(o)
            elif n_members == 2:  # OH-
                list_hydro.append(o)

        ### Prioritize the identity of H3O+ and OH- if there are multiples.
        if list_proto or list_hydro:  # otherwise: pure water, nothing to do
            single_pair = len(list_proto) == 1 and len(list_hydro) == 1
            if single_pair and len(new_QM_list) == n_qm_O * 3:
                proto_O, hydro_O = list_proto[0], list_hydro[0]
            else:  # For stable dynamics
                print("!!! Proceed with previous QM_region for stable dynamics for this time step", flush=True)
                # Report the rejected candidate before it is replaced.
                print("# of H3O+: %s, # of OH-: %s, len(new_QM_list): %s"%(len(list_proto), len(list_hydro), len(new_QM_list)), flush=True)
                new_QM_list = self.qmregion
                new_molid = self.molid.copy()
                if single_pair:
                    proto_O, hydro_O = list_proto[0], list_hydro[0]

        ## Keep the original qmregion if an O and an H of the same molecule
        ## would both leave the QM region while closer than OH_cutoff (A).
        OH_cutoff = 0.92
        removed_atoms = list(set(self.qmregion) - set(new_QM_list))
        for i in removed_atoms:
            if self.symbols[i] != "O":
                continue
            for j in removed_atoms:
                if self.symbols[j] == "H" and self.molid[j] == self.molid[i]:
                    _, dist = self.min_vec(pos[i], pos[j], cell_par)
                    if dist < OH_cutoff:
                        new_QM_list = self.qmregion
                        new_molid = self.molid.copy()
                        print("ReQM set to new_QM_list = self.qmregion since d(%s,%s)=%0.2f A < %s"%(i, j, dist, OH_cutoff), flush=True)

        ### Relabel QM/MM symbols if the region changed.
        new_chemical_symbols = self.symbols.copy()
        if sorted(new_QM_list) != sorted(qmregion):
            qm_set = set(new_QM_list)  # O(1) membership instead of list scans
            for i, sym in enumerate(new_chemical_symbols):
                in_qm = i in qm_set
                if 'O' in sym:
                    new_chemical_symbols[i] = 'O' if in_qm else 'Ow'
                elif 'H' in sym and in_qm:
                    new_chemical_symbols[i] = 'H'
                else:
                    new_chemical_symbols[i] = 'Hw'

        return new_QM_list, new_molid, new_chemical_symbols, proto_O, hydro_O, COM