package in.ac.iisc.cds.dsl.cdgvendor.solver;

import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Deque;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

import org.jgrapht.UndirectedGraph;
import org.jgrapht.alg.BronKerboschCliqueFinder;
import org.jgrapht.alg.CliqueMinimalSeparatorDecomposition;
import org.jgrapht.alg.ConnectivityInspector;
import org.jgrapht.alg.NeighborIndex;
import org.jgrapht.graph.AbstractBaseGraph;
import org.jgrapht.graph.DefaultEdge;
import org.jgrapht.graph.SimpleGraph;

import in.ac.iisc.cds.dsl.cdgvendor.model.ViewInfo;
import in.ac.iisc.cds.dsl.cdgvendor.model.formal.FormalCondition;

/**
 * Builds an undirected graph over a view's non-key columns (an edge joins every pair of columns
 * that co-appear in some condition), triangulates it when it is not chordal, and returns its
 * maximal cliques in an order in which they can be merged one after another.
 */
public class CliqueFinder {

    /** Non-key columns of the view, sorted; these are the vertices of the column graph. */
    private final List<String>    sortedColumns;
    /**
     * One array per column, index-aligned with {@link #sortedColumns}. Only the array length is
     * read by this class; it is used as that column's split count.
     */
    private final List<boolean[]> allTrueBS;

    /**
     * @param viewInfo  view whose non-key columns become the vertices of the column graph
     * @param allTrueBS per-column arrays; entry i must belong to the i-th column in sorted order,
     *                  because lookups index into this list by sorted-column position
     */
    public CliqueFinder(ViewInfo viewInfo, List<boolean[]> allTrueBS) {
        sortedColumns = new ArrayList<>(viewInfo.getViewNonkeys());
        Collections.sort(sortedColumns);

        this.allTrueBS = allTrueBS;
    }

    /**
     * Returns cliques other than those with single attribute single split.
     * Cliques are ordered so that first can be merged with second and then third and so on...
     *
     * @param conditions conditions whose co-appearing columns (as collected by
     *                   {@code AbstractCliqueFinder.getApppearingCols}) become graph edges
     * @return maximal cliques of the (triangulated if needed) column graph, excluding single-column
     *         cliques whose column has split count 1, in merge order; empty if none remain
     * @throws RuntimeException if no valid merge order can be chosen
     */
    public List<Set<String>> getOrderedNonTrivialCliques(List<FormalCondition> conditions) {

        //STEP 1: Construct an undirected graph with each column as a vertex and an edge between
        //every pair of columns that co-appear in some condition
        UndirectedGraph<String, DefaultEdge> jgraph = new SimpleGraph<>(DefaultEdge.class);
        for (String column : sortedColumns) {
            jgraph.addVertex(column);
        }
        for (FormalCondition condition : conditions) {
            Set<String> appearingCols = new HashSet<>();
            AbstractCliqueFinder.getApppearingCols(appearingCols, condition);
            List<String> colList = new ArrayList<>(appearingCols);

            for (int i = 0; i < colList.size(); i++) {
                String colI = colList.get(i);
                for (int j = i + 1; j < colList.size(); j++) {
                    jgraph.addEdge(colI, colList.get(j));
                }
            }
        }

        //STEP 2: Triangulate if not chordal. makeChordal returns a NEW graph, so the result must be
        //assigned back; otherwise the cliques below would come from the non-chordal graph.
        //(printMatrix(jgraph) can be called here to dump the adjacency matrix while debugging.)
        CliqueMinimalSeparatorDecomposition<String, DefaultEdge> cmsd = new CliqueMinimalSeparatorDecomposition<>(jgraph);
        if (!cmsd.isChordal()) {
            jgraph = makeChordal(jgraph);
        }

        //STEP 3: Get all maximal cliques
        BronKerboschCliqueFinder<String, DefaultEdge> bkcf = new BronKerboschCliqueFinder<>(jgraph);
        Collection<Set<String>> cliques = bkcf.getAllMaximalCliques();

        //STEP 4: Remove cliques consisting of a single column that has a single split
        List<Set<String>> nonTrivialCliques = new ArrayList<>();
        for (Set<String> clique : cliques) {
            if (clique.size() == 1 && getSplitCount(clique.iterator().next()) == 1) {
                continue;
            }
            nonTrivialCliques.add(clique);
        }

        //STEP 5: Get cliques in merge order
        return getCliquesInMergeOrder(nonTrivialCliques, jgraph);
    }

    /**
     * Computes a minimal triangulation of {@code jgraph} using JGraphT and returns it.
     * Callers must use the returned graph; the reference passed in is not replaced.
     *
     * @param jgraph graph to triangulate
     * @return the triangulated graph
     * @throws RuntimeException if the triangulation is unexpectedly not chordal
     */
    private UndirectedGraph<String, DefaultEdge> makeChordal(UndirectedGraph<String, DefaultEdge> jgraph) {

        CliqueMinimalSeparatorDecomposition<String, DefaultEdge> decomposition = new CliqueMinimalSeparatorDecomposition<>(jgraph);
        UndirectedGraph<String, DefaultEdge> triangulated = decomposition.getMinimalTriangulation();

        // Sanity check on the library result
        CliqueMinimalSeparatorDecomposition<String, DefaultEdge> check = new CliqueMinimalSeparatorDecomposition<>(triangulated);
        if (!check.isChordal()) {
            throw new RuntimeException("Should not be reaching here - Not Chordal");
        }

        return triangulated;
    }

    /**
     * Debug helper: prints, to stdout, the adjacency matrix of every connected component of
     * {@code jgraph} that has more than one vertex (vertices sorted by name).
     *
     * @param jgraph graph to print
     */
    private void printMatrix(UndirectedGraph<String, DefaultEdge> jgraph) {

        ConnectivityInspector<String, DefaultEdge> ins = new ConnectivityInspector<>(jgraph);

        for (Set<String> connectedSet : ins.connectedSets()) {
            if (connectedSet.size() <= 1) {
                continue;
            }

            List<String> list = new ArrayList<>(connectedSet);
            Collections.sort(list);

            System.out.println("\n\n" + list);
            int n = list.size();
            int[][] mat = new int[n][n];

            for (int i = 0; i < n; i++) {
                for (int j = 0; j < n; j++) {
                    mat[i][j] = jgraph.containsEdge(list.get(i), list.get(j)) ? 1 : 0;
                }
            }
            for (int i = 0; i < n; i++) {
                for (int j = 0; j < n; j++) {
                    System.out.print(mat[i][j] + " ");
                }
                System.out.println();
            }
        }

    }

    /**
     * Enumerates chordless cycles of length at least 4, each reported once as a vertex list.
     *
     * @cite Dias, Elisângela Silva, et al. "Efficient enumeration of chordless cycles." arXiv preprint arXiv:1309.1051 (2013). Page 6
     * @param jgraph graph to search
     * @return chordless non-triangle cycles, each as an independent list of vertices in cycle order
     */
    private List<List<String>> getChordlessNonTriangleCycles(UndirectedGraph<String, DefaultEdge> jgraph) {

        // Seed with induced paths <x, u, y> where u < x < y and x, y are not adjacent
        Deque<List<String>> paths = new ArrayDeque<>();
        NeighborIndex<String, DefaultEdge> neighborIndex = new NeighborIndex<>(jgraph);
        for (String u : jgraph.vertexSet()) {
            List<String> adjVertices = neighborIndex.neighborListOf(u);
            for (String x : adjVertices) {
                for (String y : adjVertices) {
                    if (u.compareTo(x) < 0 && x.compareTo(y) < 0 && !jgraph.containsEdge(x, y)) {
                        List<String> triple = new ArrayList<>();
                        triple.add(x);
                        triple.add(u);
                        triple.add(y);
                        paths.add(triple);
                    }
                }
            }
        }

        List<List<String>> cycles = new ArrayList<>();
        while (!paths.isEmpty()) {

            List<String> p = paths.poll(); // FIFO, same order as the original index-0 removal

            String u0 = p.get(0);
            String u1 = p.get(1);
            String ulast = p.get(p.size() - 1);

            for (String v : neighborIndex.neighborListOf(ulast)) {
                if (v.compareTo(u1) <= 0 || isAdjacentToInterior(v, p, jgraph)) {
                    continue;
                }

                // Each extension needs its own copy; extending p in place would make every
                // branch share (and corrupt) the same list.
                List<String> extended = new ArrayList<>(p);
                extended.add(v);
                if (jgraph.containsEdge(v, u0)) {
                    cycles.add(extended);
                } else {
                    paths.add(extended);
                }
            }
        }

        return cycles;
    }

    /**
     * Returns true if {@code v} is adjacent to any interior vertex of {@code path}
     * (indices 1 .. size-2), i.e. excluding the first and last vertex.
     */
    private boolean isAdjacentToInterior(String v, List<String> path, UndirectedGraph<String, DefaultEdge> jgraph) {
        for (int j = 1; j < path.size() - 1; j++) {
            if (jgraph.containsEdge(v, path.get(j))) {
                return true;
            }
        }
        return false;
    }

    /**
     * Orders cliques so that each one (after the first) can be merged with the union of those
     * before it. A clique is chosen next when it shares columns with the already visited ones and
     * every path from its remaining columns to the visited columns passes through the shared ones.
     * If no unvisited clique shares any column, the first unvisited clique is taken.
     *
     * @param unorderedCliques cliques to order; the first one always stays first
     * @param jgraph           graph the cliques were taken from
     * @return the cliques in merge order (empty if {@code unorderedCliques} is empty)
     * @throws RuntimeException if adjacent cliques exist but none satisfies the separator check
     */
    private List<Set<String>> getCliquesInMergeOrder(List<Set<String>> unorderedCliques, UndirectedGraph<String, DefaultEdge> jgraph) {

        List<Set<String>> orderedCliques = new ArrayList<>();
        if (unorderedCliques.isEmpty()) {
            // e.g. every column was a single-split singleton and got filtered out
            return orderedCliques;
        }

        Set<String> visitedCols = new HashSet<>();
        boolean[] visited = new boolean[unorderedCliques.size()];

        orderedCliques.add(unorderedCliques.get(0));
        visitedCols.addAll(unorderedCliques.get(0));
        visited[0] = true;

        while (orderedCliques.size() < unorderedCliques.size()) {

            boolean adjacentCliqueExists = false;
            int foundCliqueIdx = -1;
            for (int i = 1; i < visited.length; i++) {
                if (visited[i]) {
                    continue;
                }
                Set<String> sharedCols = getColsAlreadyVisited(unorderedCliques.get(i), visitedCols);
                if (sharedCols.isEmpty()) {
                    continue;
                }

                adjacentCliqueExists = true;

                // everyPathHasBlockade copies its set arguments, so visitedCols is not modified
                if (everyPathHasBlockade(unorderedCliques.get(i), visitedCols, sharedCols, jgraph)) {
                    foundCliqueIdx = i;
                    break;
                }
            }

            if (adjacentCliqueExists && foundCliqueIdx == -1)
                throw new RuntimeException("We have adjacent unvisted clique(s) but couldn't choose next clique");

            //Picking any unvisitedClique
            if (foundCliqueIdx == -1) {
                foundCliqueIdx = 0;
                while (foundCliqueIdx < visited.length && visited[foundCliqueIdx]) {
                    foundCliqueIdx++;
                }
                if (foundCliqueIdx == visited.length)
                    throw new RuntimeException("Couldn't choose next clique although some are left");
            }

            orderedCliques.add(unorderedCliques.get(foundCliqueIdx));
            visitedCols.addAll(unorderedCliques.get(foundCliqueIdx));
            visited[foundCliqueIdx] = true;

        }

        return orderedCliques;

    }

    /**
     * No side effects method.
     * Checks whether every path connecting setA and setB (ignoring police vertices themselves)
     * passes through a police vertex.
     *
     * @param setA      vertex set; callers pass a clique, so a single representative suffices
     * @param setB      vertex set on the other side
     * @param policeSet blocking vertices, removed from both sets and from a copy of the graph
     * @param jgraph    graph to check; must be an {@link AbstractBaseGraph} (it is cloned)
     * @return true if either side is empty after removing police vertices, or if no remaining
     *         setB vertex is reachable from setA once police vertices are removed
     */
    private boolean everyPathHasBlockade(Set<String> setA, Set<String> setB, Set<String> policeSet, UndirectedGraph<String, DefaultEdge> jgraph) {

        Set<String> freeA = new HashSet<>(setA);
        freeA.removeAll(policeSet);

        Set<String> freeB = new HashSet<>(setB);
        freeB.removeAll(policeSet);

        if (freeA.isEmpty() || freeB.isEmpty())
            return true;

        @SuppressWarnings("unchecked") // clone() of an AbstractBaseGraph<String, DefaultEdge> has the same type
        UndirectedGraph<String, DefaultEdge> jgraphCopy = (UndirectedGraph<String, DefaultEdge>) ((AbstractBaseGraph<String, DefaultEdge>) jgraph).clone();
        jgraphCopy.removeAllVertices(policeSet);

        // Probing one setA vertex is enough: the remaining setA vertices are pairwise adjacent
        // (setA is a clique at the call site), so they all lie in the same component.
        ConnectivityInspector<String, DefaultEdge> inspector = new ConnectivityInspector<>(jgraphCopy);
        Set<String> reachable = inspector.connectedSetOf(freeA.iterator().next());

        return Collections.disjoint(reachable, freeB);
    }

    /**
     * Returns those columns of clique which are already visited.
     *
     * @param clique      clique to inspect (not modified)
     * @param visitedCols columns visited so far
     * @return a new set holding the intersection of {@code clique} and {@code visitedCols}
     */
    private Set<String> getColsAlreadyVisited(Set<String> clique, Collection<String> visitedCols) {

        Set<String> shared = new HashSet<>(clique);
        shared.retainAll(visitedCols);
        return shared;
    }

    /**
     * Returns the split count of a column: the length of its entry in {@link #allTrueBS}.
     *
     * @param columnName column to look up in {@link #sortedColumns}
     * @return the column's split count
     * @throws RuntimeException if the column is not one of the sorted columns
     */
    private int getSplitCount(String columnName) {

        int colIdx = sortedColumns.indexOf(columnName);

        if (colIdx < 0)
            throw new RuntimeException("Not found coumnName: " + columnName + " in sortedColumns: " + sortedColumns);

        return allTrueBS.get(colIdx).length;

    }

    /**
     * Returns the sum, over all cliques, of the product of the split counts of the clique's columns.
     * NOTE(review): the product and sum are not overflow-checked.
     *
     * @param cliques cliques whose columns must all be among the sorted columns
     * @return total variable count
     * @throws RuntimeException if a clique contains a column that is not a sorted column
     */
    public long getReducedVariableCount(List<Set<String>> cliques) {
        // Split count per column
        Map<String, Integer> columnsToBucketCount = new HashMap<>();
        for (int i = 0; i < sortedColumns.size(); i++) {
            columnsToBucketCount.put(sortedColumns.get(i), allTrueBS.get(i).length);
        }

        long varcount = 0;
        for (Set<String> clique : cliques) {
            long cliqueVarcount = 1;
            for (String attribute : clique) {
                Integer bucketCount = columnsToBucketCount.get(attribute);
                if (bucketCount == null)
                    throw new RuntimeException("Not found attribute: " + attribute + " in sortedColumns: " + sortedColumns);
                cliqueVarcount *= bucketCount;
            }
            varcount += cliqueVarcount;
        }

        return varcount;
    }

}
