/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.compress.lib;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import org.apache.commons.lang3.NotImplementedException;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
import org.apache.sysds.runtime.compress.DMLCompressionException;
import org.apache.sysds.runtime.compress.colgroup.AColGroup;
import org.apache.sysds.runtime.compress.colgroup.AColGroupCompressed;
import org.apache.sysds.runtime.compress.colgroup.ColGroupConst;
import org.apache.sysds.runtime.compress.colgroup.ColGroupDDC;
import org.apache.sysds.runtime.compress.colgroup.ColGroupEmpty;
import org.apache.sysds.runtime.compress.colgroup.ColGroupSDC;
import org.apache.sysds.runtime.compress.colgroup.ColGroupUncompressed;
import org.apache.sysds.runtime.compress.colgroup.IContainDefaultTuple;
import org.apache.sysds.runtime.compress.colgroup.IFrameOfReferenceGroup;
import org.apache.sysds.runtime.compress.colgroup.dictionary.DictionaryFactory;
import org.apache.sysds.runtime.compress.colgroup.dictionary.IDictionary;
import org.apache.sysds.runtime.compress.colgroup.indexes.ColIndexFactory;
import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex;
import org.apache.sysds.runtime.compress.colgroup.indexes.IIterate;
import org.apache.sysds.runtime.compress.estim.CompressedSizeInfo;
import org.apache.sysds.runtime.compress.estim.CompressedSizeInfoColGroup;
import org.apache.sysds.runtime.compress.estim.encoding.ConstEncoding;
import org.apache.sysds.runtime.compress.estim.encoding.DenseEncoding;
import org.apache.sysds.runtime.compress.estim.encoding.EmptyEncoding;
import org.apache.sysds.runtime.compress.estim.encoding.IEncode;
import org.apache.sysds.runtime.compress.estim.encoding.SparseEncoding;
import org.apache.sysds.runtime.compress.lib.CLALibUtils;
import org.apache.sysds.runtime.compress.utils.HashMapLongInt;
import org.apache.sysds.runtime.data.DenseBlock;
import org.apache.sysds.runtime.data.SparseBlock;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;

public final class CLALibCombineGroups {
    protected static final Log LOG = LogFactory.getLog(CLALibCombineGroups.class.getName());

    /** Maximum number of characters of each group's string form included in a combine-failure message. */
    private static final int MAX_DEBUG_STRING_LENGTH = 10000;

    private CLALibCombineGroups() {
        // static utility class, no instances
    }

    /**
     * Combine the column groups of {@code cmb} into the column partitioning described by {@code csi}.
     *
     * @param cmb  the compressed matrix whose column groups are combined
     * @param csi  the target partitioning; one output group is produced per entry of {@code csi.getInfo()}
     * @param pool executor used for parallel combining, may be null for single-threaded execution
     * @param k    degree of parallelism hint (only used on the single-threaded path for merge trees)
     * @return the combined column groups
     * @throws InterruptedException if interrupted while waiting for parallel tasks
     * @throws ExecutionException   if a parallel combine task failed
     */
    public static List<AColGroup> combine(CompressedMatrixBlock cmb, CompressedSizeInfo csi, ExecutorService pool, int k) throws InterruptedException, ExecutionException {
        if (pool == null || csi.getInfo().size() == 1) {
            return combineSingle(cmb, csi, pool, k);
        }
        return combineParallel(cmb, csi, pool);
    }

    /**
     * Combine one target group at a time on the calling thread.
     * NOTE: sorts the list returned by {@code csi.getInfo()} in place (ascending number of distinct values);
     * the output order follows that sorted order.
     */
    private static List<AColGroup> combineSingle(CompressedMatrixBlock cmb, CompressedSizeInfo csi, ExecutorService pool, int k) throws InterruptedException, ExecutionException {
        final int nRow = cmb.getNumRows();
        List<AColGroup> input = cmb.getColGroups();
        final boolean filterFor = CLALibUtils.shouldFilterFOR(input);
        // Frame-of-reference offsets are pulled out into this vector and added back after combining.
        final double[] reference = filterFor ? new double[cmb.getNumColumns()] : null;
        if (filterFor) {
            input = CLALibUtils.filterFOR(input, reference);
        }
        final List<CompressedSizeInfoColGroup> csiI = csi.getInfo();
        csiI.sort((x, y) -> Long.compare(x.getNumVals(), y.getNumVals()));
        // Parallelism left over for merge trees inside each target group.
        final int innerK = k - csiI.size();
        final List<AColGroup> ret = new ArrayList<AColGroup>(csiI.size());
        for (CompressedSizeInfoColGroup gi : csiI) {
            ret.add(combineForInfo(gi, input, nRow, reference, pool, innerK));
        }
        return ret;
    }

    /** Combine every target group in its own task on {@code pool}; output order follows {@code csi.getInfo()}. */
    private static List<AColGroup> combineParallel(CompressedMatrixBlock cmb, CompressedSizeInfo csi, ExecutorService pool) throws InterruptedException, ExecutionException {
        final int nRow = cmb.getNumRows();
        List<AColGroup> input = cmb.getColGroups();
        final boolean filterFor = CLALibUtils.shouldFilterFOR(input);
        final double[] reference = filterFor ? new double[cmb.getNumColumns()] : null;
        if (filterFor) {
            input = CLALibUtils.filterFOR(input, reference);
        }
        final List<AColGroup> filteredGroups = input;
        final List<CompressedSizeInfoColGroup> csiI = csi.getInfo();
        final List<Future<AColGroup>> tasks = new ArrayList<Future<AColGroup>>(csiI.size());
        for (CompressedSizeInfoColGroup gi : csiI) {
            // Combine sequentially inside each task (no pool, k = 0). Submitting nested tasks to the same
            // pool from a pool task and blocking on them can deadlock once all workers are waiting.
            tasks.add(pool.submit(() -> combineForInfo(gi, filteredGroups, nRow, reference, null, 0)));
        }
        final List<AColGroup> ret = new ArrayList<AColGroup>(tasks.size());
        for (Future<AColGroup> task : tasks) {
            ret.add(task.get());
        }
        return ret;
    }

    /**
     * Build one output group: combine all input groups overlapping {@code gi}'s columns, morph the result to
     * the estimated best compression type, and re-add the frame-of-reference vector if one was extracted.
     */
    private static AColGroup combineForInfo(CompressedSizeInfoColGroup gi, List<AColGroup> groups, int nRow, double[] reference, ExecutorService pool, int k) throws InterruptedException, ExecutionException {
        final List<AColGroup> groupsToCombine = findGroupsInIndex(gi.getColumns(), groups);
        final AColGroup combined = combineN(groupsToCombine, nRow, pool, k).morph(gi.getBestCompressionType(), nRow);
        return reference == null ? combined : combined.addVector(reference);
    }

    /**
     * Select the groups whose column indices overlap {@code idx}.
     *
     * @param idx    the column indexes to look for
     * @param groups the candidate groups
     * @return the groups (in input order) containing any of the columns in {@code idx}
     */
    public static List<AColGroup> findGroupsInIndex(IColIndex idx, List<AColGroup> groups) {
        final List<AColGroup> ret = new ArrayList<AColGroup>();
        for (AColGroup g : groups) {
            if (g.getColIndices().containsAny(idx)) {
                ret.add(g);
            }
        }
        return ret;
    }

    /**
     * Combine a list of groups into one. A parallel merge tree is used only when a pool is available,
     * {@code k > 0} and there are more than three groups; otherwise groups are folded left to right.
     */
    public static AColGroup combineN(List<AColGroup> groups, int nRows, ExecutorService pool, int k) throws InterruptedException, ExecutionException {
        if (pool != null && k > 0 && groups.size() > 3) {
            return combineNMergeTree(groups, nRows, pool, k);
        }
        return combineNSingleAtATime(groups, nRows);
    }

    /**
     * Combine groups pairwise, level by level, in parallel. Adjacent groups are paired; an odd trailing
     * group is carried to the next level unchanged. Only the calling thread waits on futures, so pool tasks
     * never block on other pool tasks.
     *
     * @param k currently unused; kept for interface compatibility
     */
    public static AColGroup combineNMergeTree(List<AColGroup> groups, int nRows, ExecutorService pool, int k) throws InterruptedException, ExecutionException {
        List<AColGroup> level = groups;
        while (level.size() > 1) {
            final int n = level.size();
            final List<Future<AColGroup>> pairs = new ArrayList<Future<AColGroup>>(n / 2);
            for (int i = 0; i + 1 < n; i += 2) {
                final AColGroup left = level.get(i);
                final AColGroup right = level.get(i + 1);
                pairs.add(pool.submit(() -> combine(left, right, nRows)));
            }
            final List<AColGroup> next = new ArrayList<AColGroup>(n / 2 + 1);
            for (Future<AColGroup> pair : pairs) {
                next.add(pair.get());
            }
            if (n % 2 != 0) {
                next.add(level.get(n - 1));
            }
            level = next;
        }
        return level.get(0);
    }

    /** Fold the groups left to right into a single combined group; {@code groups} must be non-empty. */
    public static AColGroup combineNSingleAtATime(List<AColGroup> groups, int nRows) {
        AColGroup base = groups.get(0);
        for (int i = 1; i < groups.size(); ++i) {
            base = combine(base, groups.get(i), nRows);
        }
        return base;
    }

    /**
     * Combine two column groups into one group covering the union of their columns.
     * If exactly one side is uncompressed it is recompressed first. Compressed pairs whose distinct-value
     * product fits in an int are combined via their encodings; otherwise the result is an uncompressed group.
     * In debug mode the sums of inputs and output are cross-checked.
     *
     * @throws DMLCompressionException if either input is a frame-of-reference group, or the debug check fails
     */
    public static AColGroup combine(AColGroup a, AColGroup b, int nRow) {
        if (a instanceof IFrameOfReferenceGroup || b instanceof IFrameOfReferenceGroup) {
            throw new DMLCompressionException("Invalid call with frame of reference group to combine");
        }
        final IColIndex combinedColumns = ColIndexFactory.combine(a, b);
        final boolean aUncompressed = a instanceof ColGroupUncompressed;
        final boolean bUncompressed = b instanceof ColGroupUncompressed;
        if (aUncompressed && !bUncompressed) {
            a = a.recompress();
        } else if (bUncompressed && !aUncompressed) {
            b = b.recompress();
        }
        final long maxEst = (long) a.getNumValues() * (long) b.getNumValues();
        final AColGroup ret;
        if (a instanceof AColGroupCompressed && b instanceof AColGroupCompressed && maxEst < Integer.MAX_VALUE) {
            ret = combineCompressed(combinedColumns, (AColGroupCompressed) a, (AColGroupCompressed) b);
        } else {
            // nRow is known here, so no need to infer it from the group types.
            ret = combineUC(combinedColumns, a, b, nRow);
        }
        if (CompressedMatrixBlock.debug) {
            verifyCombinedSum(a, b, ret, nRow);
        }
        return ret;
    }

    /** Debug check that the combined group's sum matches the sum of its inputs within a relative 1e-6. */
    private static void verifyCombinedSum(AColGroup a, AColGroup b, AColGroup ret, int nRow) {
        try {
            final double sumCombined = ret.getSum(nRow);
            final double sumIndividualA = a.getSum(nRow);
            final double sumIndividualB = b.getSum(nRow);
            final double absError = Math.abs(sumCombined - sumIndividualA - sumIndividualB);
            if (absError > Math.abs((sumIndividualA + sumIndividualB) / 1000000.0)) {
                throw new DMLCompressionException("Invalid combine... not producing same sum: " + sumCombined + " vs  " + sumIndividualA + " : " + sumIndividualB + "  abs error: " + absError);
            }
        }
        catch (NotImplementedException e) {
            throw e;
        }
        catch (Exception e) {
            final StringBuilder sb = new StringBuilder();
            sb.append("Failed to combine:\n\n");
            appendTruncated(sb, a.toString());
            sb.append("\n\n");
            appendTruncated(sb, b.toString());
            sb.append("\n\n");
            appendTruncated(sb, ret.toString());
            throw new DMLCompressionException(sb.toString(), e);
        }
    }

    /** Append {@code s}, cut to {@link #MAX_DEBUG_STRING_LENGTH} characters followed by "..." if longer. */
    private static void appendTruncated(StringBuilder sb, String s) {
        if (s.length() < MAX_DEBUG_STRING_LENGTH) {
            sb.append(s);
        } else {
            sb.append(s, 0, MAX_DEBUG_STRING_LENGTH);
            sb.append("...");
        }
    }

    /** Combine two compressed groups through their encodings, choosing the output group type from the result. */
    private static AColGroup combineCompressed(IColIndex combinedColumns, AColGroupCompressed ac, AColGroupCompressed bc) {
        final IEncode ae = ac.getEncoding();
        final IEncode be = bc.getEncoding();
        final Pair<IEncode, HashMapLongInt> cec = ae.combineWithMap(be);
        final IEncode ce = cec.getLeft();
        final HashMapLongInt filter = cec.getRight();
        if (ce instanceof EmptyEncoding) {
            return new ColGroupEmpty(combinedColumns);
        }
        if (ce instanceof ConstEncoding) {
            final IDictionary cd = DictionaryFactory.combineDictionaries(ac, bc, filter);
            return ColGroupConst.create(combinedColumns, cd);
        }
        if (ce instanceof DenseEncoding) {
            final DenseEncoding ced = (DenseEncoding) ce;
            final IDictionary cd = DictionaryFactory.combineDictionaries(ac, bc, filter);
            return ColGroupDDC.create(combinedColumns, cd, ced.getMap(), null);
        }
        if (ce instanceof SparseEncoding) {
            final SparseEncoding sed = (SparseEncoding) ce;
            final IDictionary cd = DictionaryFactory.combineDictionariesSparse(ac, bc, filter);
            final double[] defaultTuple = constructDefaultTuple(ac, bc);
            return ColGroupSDC.create(combinedColumns, sed.getNumRows(), cd, defaultTuple, sed.getOffsets(), sed.getMap(), null);
        }
        throw new NotImplementedException("Not implemented combine for " + ac.getClass().getSimpleName() + " - " + bc.getClass().getSimpleName());
    }

    /** Decompress both groups into one uncompressed group; sparse output if the estimated density is below 0.4. */
    private static AColGroup combineUC(IColIndex combinedColumns, AColGroup a, AColGroup b, int nRow) {
        final double sparsityCombined = (a.getSparsity() * (double) a.getNumCols() + b.getSparsity() * (double) b.getNumCols()) / (double) combinedColumns.size();
        if (sparsityCombined < 0.4) {
            return combineUCSparse(combinedColumns, a, b, nRow);
        }
        return combineUCDense(combinedColumns, a, b, nRow);
    }

    /** Decompress both groups into a shared sparse block, remapping their columns into the combined layout. */
    private static AColGroup combineUCSparse(IColIndex combinedColumns, AColGroup a, AColGroup b, int nRow) {
        final MatrixBlock target = new MatrixBlock(nRow, combinedColumns.size(), true);
        target.allocateBlock();
        final SparseBlock db = target.getSparseBlock();
        final IColIndex aTempCols = ColIndexFactory.getColumnMapping(combinedColumns, a.getColIndices());
        a.copyAndSet(aTempCols).decompressToSparseBlock(db, 0, nRow, 0, 0);
        final IColIndex bTempCols = ColIndexFactory.getColumnMapping(combinedColumns, b.getColIndices());
        b.copyAndSet(bTempCols).decompressToSparseBlock(db, 0, nRow, 0, 0);
        target.recomputeNonZeros();
        return ColGroupUncompressed.create(combinedColumns, target, false);
    }

    /** Decompress both groups into a shared dense block, remapping their columns into the combined layout. */
    private static AColGroup combineUCDense(IColIndex combinedColumns, AColGroup a, AColGroup b, int nRow) {
        final MatrixBlock target = new MatrixBlock(nRow, combinedColumns.size(), false);
        target.allocateBlock();
        final DenseBlock db = target.getDenseBlock();
        final IColIndex aTempCols = ColIndexFactory.getColumnMapping(combinedColumns, a.getColIndices());
        a.copyAndSet(aTempCols).decompressToDenseBlock(db, 0, nRow, 0, 0);
        final IColIndex bTempCols = ColIndexFactory.getColumnMapping(combinedColumns, b.getColIndices());
        b.copyAndSet(bTempCols).decompressToDenseBlock(db, 0, nRow, 0, 0);
        target.recomputeNonZeros();
        return ColGroupUncompressed.create(combinedColumns, target, false);
    }

    /**
     * Merge the default tuples of two groups in ascending column order.
     * Both groups are assumed to implement {@link IContainDefaultTuple} (the cast fails otherwise).
     * On equal column ids the value from {@code bc} is taken first.
     */
    public static double[] constructDefaultTuple(AColGroupCompressed ac, AColGroupCompressed bc) {
        final double[] ret = new double[ac.getNumCols() + bc.getNumCols()];
        final IIterate ai = ac.getColIndices().iterator();
        final IIterate bi = bc.getColIndices().iterator();
        final double[] defa = ((IContainDefaultTuple) ac).getDefaultTuple();
        final double[] defb = ((IContainDefaultTuple) bc).getDefaultTuple();
        int i = 0;
        while (ai.hasNext() && bi.hasNext()) {
            if (ai.v() < bi.v()) {
                ret[i++] = defa[ai.i()];
                ai.next();
            } else {
                ret[i++] = defb[bi.i()];
                bi.next();
            }
        }
        while (ai.hasNext()) {
            ret[i++] = defa[ai.i()];
            ai.next();
        }
        while (bi.hasNext()) {
            ret[i++] = defb[bi.i()];
            bi.next();
        }
        return ret;
    }
}

