/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.compress.lib;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
import org.apache.sysds.runtime.compress.DMLCompressionException;
import org.apache.sysds.runtime.compress.colgroup.AColGroup;
import org.apache.sysds.runtime.compress.colgroup.ColGroupUtils;
import org.apache.sysds.runtime.compress.lib.CLALibUtils;
import org.apache.sysds.runtime.compress.utils.IntArrayList;
import org.apache.sysds.runtime.data.DenseBlock;
import org.apache.sysds.runtime.data.SparseBlock;
import org.apache.sysds.runtime.data.SparseBlockCSR;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.util.CommonThreadPool;

public interface CLALibSelectionMult {
    public static final Log LOG = LogFactory.getLog((String)CLALibSelectionMult.class.getName());

    public static MatrixBlock leftSelection(CompressedMatrixBlock right, MatrixBlock left, MatrixBlock ret, int k) {
        try {
            if (right.getNonZeros() <= -1L) {
                right.recomputeNonZeros();
            }
            boolean sparseOut = right.getSparsity() < 0.3;
            ret = CLALibSelectionMult.allocateReturn(right, left, ret, sparseOut);
            List<AColGroup> preFilter = right.getColGroups();
            boolean shouldFilter = CLALibUtils.shouldPreFilter(preFilter);
            if (shouldFilter) {
                CLALibSelectionMult.filteredLeftSelection(left, ret, k, sparseOut, preFilter);
            } else {
                CLALibSelectionMult.normalLeftSelection(left, ret, k, sparseOut, preFilter);
            }
            ret.recomputeNonZeros(k);
            return ret;
        }
        catch (Exception e) {
            throw new DMLCompressionException("Failed left selection Multiplication", e);
        }
    }

    public static boolean isSelectionMatrix(MatrixBlock mb) {
        if (mb.isEmpty()) {
            return false;
        }
        if (mb.getNonZeros() <= (long)mb.getNumRows() && mb.isInSparseFormat()) {
            SparseBlock sb = mb.getSparseBlock();
            for (int i = 0; i < mb.getNumRows(); ++i) {
                if (sb.isEmpty(i)) continue;
                if (sb.size(i) != 1) {
                    return false;
                }
                if (sb instanceof SparseBlockCSR) continue;
                double[] values = sb.values(i);
                int spos = sb.pos(i);
                int sEnd = spos + sb.size(i);
                for (int j = spos; j < sEnd; ++j) {
                    if (values[j] == 1.0) continue;
                    return false;
                }
            }
            if (sb instanceof SparseBlockCSR) {
                for (double d : sb.values(0)) {
                    if (d == 1.0) continue;
                    return false;
                }
            }
            return true;
        }
        return false;
    }

    private static MatrixBlock allocateReturn(CompressedMatrixBlock right, MatrixBlock left, MatrixBlock ret, boolean sparseOut) {
        if (ret == null) {
            ret = new MatrixBlock();
        }
        ret.reset(left.getNumRows(), right.getNumColumns(), sparseOut);
        ret.allocateBlock();
        return ret;
    }

    private static void normalLeftSelection(MatrixBlock left, MatrixBlock ret, int k, boolean sparseOut, List<AColGroup> preFilter) throws Exception {
        int rowLeft = left.getNumRows();
        boolean pointsNeeded = CLALibSelectionMult.areSortedCoordinatesNeeded(preFilter);
        if (k <= 1 || rowLeft < 1000) {
            CLALibSelectionMult.leftSelectionSingleThread(preFilter, left, ret, rowLeft, pointsNeeded, sparseOut);
        } else {
            CLALibSelectionMult.leftSelectionParallel(preFilter, left, ret, k, rowLeft, pointsNeeded, sparseOut);
        }
    }

    private static void filteredLeftSelection(MatrixBlock left, MatrixBlock ret, int k, boolean sparseOut, List<AColGroup> preFilter) throws Exception {
        double[] constV = new double[ret.getNumColumns()];
        List<AColGroup> morphed = CLALibUtils.filterGroups(preFilter, constV);
        CLALibSelectionMult.normalLeftSelection(left, ret, k, sparseOut, morphed);
        double[] rowSums = left.rowSum(k).getDenseBlockValues();
        CLALibSelectionMult.outerProduct(rowSums, constV, ret, sparseOut);
    }

    private static void leftSelectionSingleThread(List<AColGroup> right, MatrixBlock left, MatrixBlock ret, int rowLeft, boolean pointsNeeded, boolean sparseOut) {
        ColGroupUtils.P[] points = pointsNeeded ? ColGroupUtils.getSortedSelection(left.getSparseBlock(), 0, rowLeft) : null;
        for (AColGroup g : right) {
            g.selectionMultiply(left, points, ret, 0, rowLeft);
        }
        if (sparseOut) {
            ret.getSparseBlock().sort();
        }
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    private static void leftSelectionParallel(List<AColGroup> right, MatrixBlock left, MatrixBlock ret, int k, int rowLeft, boolean pointsNeeded, boolean sparseOut) throws InterruptedException, ExecutionException {
        ExecutorService pool = CommonThreadPool.get(k);
        try {
            ArrayList tasks = new ArrayList();
            int blkz = Math.max(rowLeft / k, 1000);
            for (int i = 0; i < rowLeft; i += blkz) {
                int n = i;
                int end = Math.min(rowLeft, i + blkz);
                ColGroupUtils.P[] points = pointsNeeded ? ColGroupUtils.getSortedSelection(left.getSparseBlock(), n, end) : null;
                tasks.add(pool.submit(() -> {
                    for (AColGroup g : right) {
                        g.selectionMultiply(left, points, ret, start, end);
                    }
                    if (sparseOut) {
                        SparseBlock sb = ret.getSparseBlock();
                        for (int j = start; j < end; ++j) {
                            if (sb.isEmpty(j)) continue;
                            sb.sort(j);
                        }
                    }
                }));
            }
            for (Future future : tasks) {
                future.get();
            }
        }
        finally {
            pool.shutdown();
        }
    }

    private static boolean areSortedCoordinatesNeeded(List<AColGroup> right) {
        for (AColGroup g : right) {
            if (g.getCompType() != AColGroup.CompressionType.SDC) continue;
            return true;
        }
        return false;
    }

    private static void outerProduct(double[] rows, double[] cols, MatrixBlock ret, boolean sparse) {
        if (sparse) {
            CLALibSelectionMult.outerProductSparse(rows, cols, ret);
        } else {
            CLALibSelectionMult.outerProductDense(rows, cols, ret);
        }
    }

    private static void outerProductDense(double[] rows, double[] cols, MatrixBlock ret) {
        DenseBlock db = ret.getDenseBlock();
        for (int r = 0; r < rows.length; ++r) {
            double rv = rows[r];
            double[] dbV = db.values(r);
            int pos = db.pos(r);
            if (rv == 0.0) continue;
            for (int c = 0; c < cols.length; ++c) {
                int n = pos + c;
                dbV[n] = dbV[n] + rv * cols[c];
            }
        }
    }

    private static void outerProductSparse(double[] rows, double[] cols, MatrixBlock ret) {
        SparseBlock sb = ret.getSparseBlock();
        IntArrayList skipCols = new IntArrayList();
        for (int c = 0; c < cols.length; ++c) {
            if (cols[c] == 0.0) continue;
            skipCols.appendValue(c);
        }
        int skipSz = skipCols.size();
        if (skipSz == 0) {
            return;
        }
        int[] skipC = skipCols.extractValues();
        for (int r = 0; r < rows.length; ++r) {
            double rv = rows[r];
            if (rv == 0.0) continue;
            for (int ci = 0; ci < skipSz; ++ci) {
                int c = skipC[ci];
                sb.add(r, c, rv * cols[c]);
            }
        }
    }
}

