/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.compress.lib;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
import org.apache.sysds.runtime.compress.DMLCompressionException;
import org.apache.sysds.runtime.compress.colgroup.AColGroup;
import org.apache.sysds.runtime.compress.colgroup.APreAgg;
import org.apache.sysds.runtime.compress.colgroup.dictionary.AIdentityDictionary;
import org.apache.sysds.runtime.compress.colgroup.dictionary.IdentityDictionary;
import org.apache.sysds.runtime.compress.lib.CLALibSelectionMult;
import org.apache.sysds.runtime.compress.lib.CLALibUtils;
import org.apache.sysds.runtime.data.DenseBlock;
import org.apache.sysds.runtime.data.SparseBlock;
import org.apache.sysds.runtime.functionobjects.Plus;
import org.apache.sysds.runtime.matrix.data.LibMatrixBincell;
import org.apache.sysds.runtime.matrix.data.LibMatrixMult;
import org.apache.sysds.runtime.matrix.data.LibMatrixReorg;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
import org.apache.sysds.runtime.util.CommonThreadPool;
import org.apache.sysds.utils.stats.Timing;

public final class CLALibLeftMultBy {
    /** Logger shared by all static routines of this class. */
    private static final Log LOG = LogFactory.getLog(CLALibLeftMultBy.class.getName());

    /** Static utility holder; instances are never created. */
    private CLALibLeftMultBy() {
        // intentionally empty
    }

    /**
     * Computes {@code t(left) %*% right} where {@code right} is compressed and {@code left} is an
     * uncompressed block.
     *
     * <p>{@code left} is materialized in transposed form and the product is delegated to
     * {@link #leftMultByMatrix(CompressedMatrixBlock, MatrixBlock, MatrixBlock, int)}.
     *
     * @param right compressed right-hand side
     * @param left  uncompressed block whose transpose is the left-hand side
     * @param ret   output block to reuse, or null
     * @param k     degree of parallelism
     * @return the result block
     */
    public static MatrixBlock leftMultByMatrixTransposed(CompressedMatrixBlock right, MatrixBlock left, MatrixBlock ret, int k) {
        final boolean anyEmpty = left.isEmpty() || right.isEmpty();
        if (anyEmpty)
            return prepareEmptyReturnMatrix(right, left, ret, true);

        // A multi-column input needs a real transpose, which is reported as a warning.
        if (left.getNumColumns() > 1)
            LOG.warn("Transposing matrix block for transposed left matrix multiplication");

        final MatrixBlock leftT = new MatrixBlock(left.getNumColumns(), left.getNumRows(), false);
        LibMatrixReorg.transpose(left, leftT, k);
        return leftMultByMatrix(right, leftT, ret, k);
    }

    /**
     * Computes {@code t(left) %*% right} where both operands are compressed.
     *
     * @param right compressed right-hand side
     * @param left  compressed block whose transpose is the left-hand side
     * @param ret   output block to reuse, or null
     * @param k     degree of parallelism
     * @return the result block
     * @throws DMLCompressionException wrapping any failure of the underlying kernels
     */
    public static MatrixBlock leftMultByMatrixTransposed(CompressedMatrixBlock right, CompressedMatrixBlock left, MatrixBlock ret, int k) {
        final MatrixBlock out;
        try {
            if (left.isEmpty() || right.isEmpty())
                return prepareEmptyReturnMatrix(right, left, ret, true);
            out = prepareReturnMatrix(right, left, ret, true);
            leftMultByCompressedTransposedMatrix(right, left, out, k);
        }
        catch (Exception e) {
            throw new DMLCompressionException("Failed CLA Compressed Transposed LMM", e);
        }
        return out;
    }

    /**
     * Computes {@code left %*% right} where {@code right} is compressed.
     *
     * <p>Empty operands produce an empty output. Selection matrices are routed to
     * {@link CLALibSelectionMult}. Everything else goes through the column-group kernel.
     *
     * @param right compressed right-hand side
     * @param left  uncompressed left-hand side
     * @param ret   output block to reuse, or null
     * @param k     degree of parallelism
     * @return the result block
     * @throws DMLCompressionException wrapping any failure of the underlying kernels
     */
    public static MatrixBlock leftMultByMatrix(CompressedMatrixBlock right, MatrixBlock left, MatrixBlock ret, int k) {
        try {
            if (left.isEmpty() || right.isEmpty())
                return prepareEmptyReturnMatrix(right, left, ret, false);
            if (CLALibSelectionMult.isSelectionMatrix(left))
                return CLALibSelectionMult.leftSelection(right, left, ret, k);
            final MatrixBlock out = prepareReturnMatrix(right, left, ret, false);
            return LMM(right.getColGroups(), left, out, k, right.isOverlapping());
        }
        catch (Exception e) {
            throw new DMLCompressionException("Failed CLA LMM", e);
        }
    }

    /**
     * Returns an empty, sparse output block sized for {@code m2 %*% m1} (or {@code t(m2) %*% m1}).
     *
     * <p>{@code ret} is reused as-is when it already has the right shape and is allocated.
     * Otherwise it is reset, or a new block is created when it is null.
     *
     * @param m1          right operand, which determines the column count
     * @param m2          left operand, which determines the row count
     * @param ret         block to reuse, or null
     * @param doTranspose whether {@code m2} is used transposed
     * @return the empty output block
     */
    private static MatrixBlock prepareEmptyReturnMatrix(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, boolean doTranspose) {
        final int outRows = doTranspose ? m2.getNumColumns() : m2.getNumRows();
        final int outCols = m1.getNumColumns();
        if (ret == null)
            return new MatrixBlock(outRows, outCols, true, 0L);
        final boolean reusable = ret.getNumColumns() == outCols && ret.getNumRows() == outRows && ret.isAllocated();
        if (!reusable)
            ret.reset(outRows, outCols, true, 0L);
        return ret;
    }

    /**
     * Returns a dense, allocated output block sized for {@code m2 %*% m1} (or {@code t(m2) %*% m1}).
     *
     * <p>{@code ret} is reused when it already has the right shape and is allocated. Otherwise it is
     * reset, or a new block is created when it is null. The dense block is always allocated before
     * returning.
     *
     * @param m1          right operand, which determines the column count
     * @param m2          left operand, which determines the row count
     * @param ret         block to reuse, or null
     * @param doTranspose whether {@code m2} is used transposed
     * @return the dense output block
     */
    private static MatrixBlock prepareReturnMatrix(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, boolean doTranspose) {
        final int numRowsOutput = doTranspose ? m2.getNumColumns() : m2.getNumRows();
        final int numColumnsOutput = m1.getNumColumns();
        // Use long arithmetic: the int product overflows (to a negative/wrong estimate) for
        // outputs with more than Integer.MAX_VALUE cells.
        final long estNnz = (long) numRowsOutput * numColumnsOutput;
        if (ret == null) {
            ret = new MatrixBlock(numRowsOutput, numColumnsOutput, false, estNnz);
        }
        else if (ret.getNumColumns() != numColumnsOutput || ret.getNumRows() != numRowsOutput || !ret.isAllocated()) {
            ret.reset(numRowsOutput, numColumnsOutput, false, estNnz);
        }
        ret.allocateDenseBlock();
        return ret;
    }

    /**
     * Dispatches the compressed-by-compressed transposed product to the single-threaded kernel
     * ({@code k <= 1}) or the pooled kernel ({@code k > 1}).
     */
    private static MatrixBlock leftMultByCompressedTransposedMatrix(CompressedMatrixBlock right, CompressedMatrixBlock left, MatrixBlock ret, int k) throws Exception {
        return k <= 1
            ? leftMultByCompressedTransposedMatrixSingleThread(right, left, ret)
            : leftMultByCompressedTransposedMatrixParallel(right, left, ret, k);
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    /**
     * Parallel kernel for {@code t(left) %*% right} with both operands compressed.
     *
     * <p>One task is submitted per filtered left column group. Each task multiplies that group
     * against every filtered right group into a private temporary block. The filtered-out
     * contributions ({@code cL}/{@code cR}) are added via outer products directly into {@code ret}.
     * The per-task temporaries are then summed into {@code ret} on the calling thread.
     *
     * <p>NOTE(review): {@code CLALibUtils.filterGroups} presumably extracts constant column
     * contributions into {@code cL}/{@code cR}; confirm against CLALibUtils.
     *
     * @param right compressed right operand
     * @param left  compressed left operand (used transposed)
     * @param ret   allocated output block of shape (left.cols x right.cols)
     * @param k     thread-pool size
     * @return {@code ret}
     * @throws Exception if any task fails
     */
    private static MatrixBlock leftMultByCompressedTransposedMatrixParallel(CompressedMatrixBlock right, CompressedMatrixBlock left, MatrixBlock ret, int k) throws Exception {
        final int sd = right.getNumRows();
        final int cr = right.getNumColumns();
        final int rl = left.getNumColumns();
        final List<AColGroup> rightCG = right.getColGroups();
        final List<AColGroup> leftCG = left.getColGroups();
        final boolean containsRight = CLALibUtils.shouldPreFilter(rightCG);
        final double[] cR = containsRight ? new double[cr] : null;
        final List<AColGroup> fRight = CLALibUtils.filterGroups(rightCG, cR);
        final boolean containsLeft = CLALibUtils.shouldPreFilter(leftCG);
        final double[] cL = containsLeft ? new double[rl] : null;
        final List<AColGroup> fLeft = CLALibUtils.filterGroups(leftCG, cL);
        ret.allocateDenseBlock();
        ret.setNonZeros((long) ret.getNumRows() * (long) ret.getNumColumns());
        final ExecutorService pool = CommonThreadPool.get(k);
        try {
            final List<Future<MatrixBlock>> partials = new ArrayList<>();
            for (int j = 0; j < fLeft.size(); j++) {
                // Capture the group itself. The decompiled version referenced an undefined `jj`
                // here, which did not compile.
                final AColGroup leftGroup = fLeft.get(j);
                partials.add(pool.submit(() -> {
                    final MatrixBlock retT = new MatrixBlock(ret.getNumRows(), ret.getNumColumns(), false);
                    retT.allocateDenseBlock();
                    for (AColGroup rightGroup : fRight)
                        rightGroup.leftMultByAColGroup(leftGroup, retT, sd);
                    retT.examSparsity(true);
                    return retT;
                }));
            }
            // The tasks above only read ret's dimensions, so ret can be updated here concurrently.
            if (containsLeft && containsRight)
                outerProductWithScaling(cL, cR, sd, ret);
            if (containsLeft) {
                for (Future<?> f : outerProductParallelTasks(cL, CLALibUtils.getColSum(fRight, cr, sd), ret, pool))
                    f.get();
            }
            if (containsRight) {
                for (Future<?> f : outerProductParallelTasks(CLALibUtils.getColSum(fLeft, rl, sd), cR, ret, pool))
                    f.get();
            }
            for (Future<MatrixBlock> future : partials) {
                final MatrixBlock mb = future.get();
                if (mb.isEmpty())
                    continue;
                // Fast flat add only when BOTH blocks are single contiguous arrays.
                if (!mb.isInSparseFormat() && mb.getDenseBlock().isContiguous() && ret.getDenseBlock().isContiguous()) {
                    final double[] retV = ret.getDenseBlockValues();
                    LibMatrixMult.vectAdd(mb.getDenseBlockValues(), retV, 0, 0, retV.length);
                }
                else {
                    LibMatrixBincell.bincellOpInPlaceRight(ret, mb, new BinaryOperator(Plus.getPlusFnObject()));
                }
            }
            ret.recomputeNonZeros(k);
        }
        finally {
            pool.shutdown();
        }
        return ret;
    }

    /**
     * Single-threaded kernel for {@code t(left) %*% right} with both operands compressed.
     *
     * <p>Every filtered left group is multiplied with every filtered right group directly into
     * {@code ret}. The filtered-out contributions ({@code cL}/{@code cR}) are then added via outer
     * products. NOTE(review): these are presumably constant column contributions extracted by
     * {@code CLALibUtils.filterGroups}; confirm.
     *
     * @param right compressed right operand
     * @param left  compressed left operand (used transposed)
     * @param ret   output block of shape (left.cols x right.cols)
     * @return {@code ret}
     */
    private static MatrixBlock leftMultByCompressedTransposedMatrixSingleThread(CompressedMatrixBlock right, CompressedMatrixBlock left, MatrixBlock ret) {
        final int sd = right.getNumRows();
        final int cr = right.getNumColumns();
        final int rl = left.getNumColumns();
        final List<AColGroup> rightGroups = right.getColGroups();
        final List<AColGroup> leftGroups = left.getColGroups();

        final boolean hasRightExtra = CLALibUtils.shouldPreFilter(rightGroups);
        final double[] cR = hasRightExtra ? new double[cr] : null;
        final List<AColGroup> fRight = CLALibUtils.filterGroups(rightGroups, cR);

        final boolean hasLeftExtra = CLALibUtils.shouldPreFilter(leftGroups);
        final double[] cL = hasLeftExtra ? new double[rl] : null;
        final List<AColGroup> fLeft = CLALibUtils.filterGroups(leftGroups, cL);

        // Order kept as in the original: nnz is set before the dense allocation.
        ret.setNonZeros((long) ret.getNumRows() * (long) ret.getNumColumns());
        ret.allocateDenseBlock();
        for (AColGroup lg : fLeft)
            for (AColGroup rg : fRight)
                rg.leftMultByAColGroup(lg, ret, sd);

        if (hasLeftExtra && hasRightExtra)
            outerProductWithScaling(cL, cR, sd, ret);
        if (hasLeftExtra)
            outerProductSingleThread(cL, CLALibUtils.getColSum(fRight, cr, sd), ret);
        if (hasRightExtra)
            outerProductSingleThread(CLALibUtils.getColSum(fLeft, rl, sd), cR, ret);
        ret.recomputeNonZeros();
        return ret;
    }

    /**
     * Core left matrix multiplication {@code that %*% colGroups} into {@code ret}.
     *
     * <p>Groups are split into non-pre-aggregating and pre-aggregating lists. When the groups need
     * pre-filtering, the filtered-out part is collected into {@code constV}. It is then applied as
     * the outer product of the row sums of {@code that} with {@code constV}. The row sums are either
     * accumulated by the kernels or, when no groups remain, computed via {@code that.rowSum(k)}.
     *
     * @param colGroups   column groups of the compressed right operand
     * @param that        uncompressed left operand
     * @param ret         preallocated output block
     * @param k           degree of parallelism
     * @param overlapping whether the column groups overlap
     * @return {@code ret}, with non-zeros recomputed and sparsity examined
     * @throws Exception if any kernel fails
     */
    private static MatrixBlock LMM(List<AColGroup> colGroups, MatrixBlock that, MatrixBlock ret, int k, boolean overlapping) throws Exception {
        final int outCols = ret.getNumColumns();
        final int thatRows = that.getNumRows();
        final boolean shouldFilter = CLALibUtils.shouldPreFilter(colGroups);
        final ArrayList<AColGroup> noPreAgg = new ArrayList<AColGroup>();
        final ArrayList<APreAgg> preAgg = new ArrayList<APreAgg>();
        if (!shouldFilter) {
            CLALibUtils.splitPreAgg(colGroups, noPreAgg, preAgg);
            final boolean sequential = k == 1 || colGroups.size() == 1;
            if (sequential)
                LMMTaskExec(noPreAgg, preAgg, that, ret, 0, thatRows, null);
            else
                LMMParallel(noPreAgg, preAgg, that, ret, null, overlapping, k);
        }
        else {
            final double[] constV = new double[outCols];
            CLALibUtils.filterGroupsAndSplitPreAgg(colGroups, constV, noPreAgg, preAgg);
            final double[] rowSums;
            if (noPreAgg.isEmpty() && preAgg.isEmpty()) {
                rowSums = that.rowSum(k).getDenseBlockValues();
            }
            else {
                final int nGroups = preAgg.size() + noPreAgg.size();
                rowSums = new double[thatRows];
                if (k == 1 || nGroups == 1)
                    LMMTaskExec(noPreAgg, preAgg, that, ret, 0, thatRows, rowSums);
                else
                    LMMParallel(noPreAgg, preAgg, that, ret, rowSums, overlapping, k);
            }
            if (rowSums != null) {
                // The outer product needs a dense output.
                if (ret.isEmpty())
                    ret.allocateDenseBlock();
                else
                    ret.sparseToDense();
                outerProduct(rowSums, constV, ret, k);
            }
        }
        ret.recomputeNonZeros(k);
        ret.examSparsity(k);
        return ret;
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    /**
     * Runs the left matrix multiplication on a thread pool of size {@code k}.
     *
     * <p>Temporary per-task outputs are used when the groups overlap and there is more than one of
     * them, or when there are fewer than {@code k / 2} groups and {@code ret} has fewer than 1000
     * columns. Otherwise tasks write into {@code ret} directly. The pool is always shut down.
     *
     * @param npa         non-pre-aggregating groups
     * @param pa          pre-aggregating groups
     * @param that        uncompressed left operand
     * @param ret         output block
     * @param rowSums     row-sum accumulator of {@code that}, or null if not needed
     * @param overlapping whether the groups overlap
     * @param k           degree of parallelism
     * @throws Exception if any task fails
     */
    private static void LMMParallel(List<AColGroup> npa, List<APreAgg> pa, MatrixBlock that, MatrixBlock ret, double[] rowSums, boolean overlapping, int k) throws Exception {
        final ExecutorService pool = CommonThreadPool.get(k);
        try {
            final int nG = npa.size() + pa.size();
            // Parenthesized to make the original && / || precedence explicit. The decompiled
            // version also assigned this to an unused local `bl`, which is removed here.
            final boolean useTmp = (overlapping && nG > 1) || (nG * 2 < k && ret.getNumColumns() < 1000);
            if (useTmp)
                LMMParallelTempOut(npa, pa, that, ret, rowSums, overlapping, k, pool);
            else
                LMMParallelNoTempOut(npa, pa, that, ret, rowSums, overlapping, k, pool);
        }
        finally {
            pool.shutdown();
        }
    }

    /**
     * Parallel left matrix multiplication where tasks write into {@code ret} without per-task
     * result blocks.
     *
     * <p>The rows of {@code that} are split into about {@code k} row blocks. For each block,
     * {@code LLMNoTempOutRowBlockTasks} submits the group and row-sum tasks. All tasks are awaited
     * before returning.
     *
     * @param npa         non-pre-aggregating groups
     * @param pa          pre-aggregating groups
     * @param that        uncompressed left operand
     * @param ret         output block
     * @param rowSums     row-sum accumulator, or null
     * @param overlapping unused here; kept for signature parity with the temp-out variant
     * @param k           degree of parallelism
     * @param pool        executor that runs the tasks
     * @throws Exception if any task fails
     */
    private static void LMMParallelNoTempOut(List<AColGroup> npa, List<APreAgg> pa, MatrixBlock that, MatrixBlock ret, double[] rowSums, boolean overlapping, int k, ExecutorService pool) throws Exception {
        final int s = Math.min(pa.size(), k);
        final int rt = that.getNumRows();
        final int ct = that.getNumColumns();
        final int rowBlockSize = Math.max(rt / k, 1);
        // Typed list (was a raw ArrayList passed to an ArrayList<Future<?>> parameter).
        final ArrayList<Future<?>> tasks = new ArrayList<>();
        for (int blo = 0; blo < rt; blo += rowBlockSize) {
            final int end = Math.min(blo + rowBlockSize, rt);
            LLMNoTempOutRowBlockTasks(npa, pa, that, ret, rowSums, pool, s, ct, tasks, blo, end, k);
        }
        for (Future<?> f : tasks)
            f.get();
    }

    /**
     * Submits every task for rows [start, end) of {@code that} and appends the futures to
     * {@code tasks}. The tasks are: one set per non-pre-aggregating group, {@code s} strided
     * pre-aggregation tasks, and an optional row-sum task.
     */
    private static void LLMNoTempOutRowBlockTasks(List<AColGroup> npa, List<APreAgg> pa, MatrixBlock that, MatrixBlock ret, double[] rowSums, ExecutorService pool, int s, int ct, ArrayList<Future<?>> tasks, int start, int end, int k) {
        for (AColGroup g : npa)
            noTmpNoAggGroups(that, ret, pool, ct, tasks, start, end, g, k);

        // Task `off` handles pre-aggregating groups off, off + s, off + 2s, ...
        for (int off = 0; off < s; off++) {
            final int taskOff = off;
            tasks.add(pool.submit(() -> LMMWithPreAgg(pa, that, ret, start, end, 0, ct, taskOff, s, null)));
        }

        if (rowSums != null)
            tasks.add(pool.submit(() -> rowSum(that, rowSums, start, end, 0, ct)));
    }

    /**
     * Submits the work for one non-pre-aggregating group over rows [start, end).
     *
     * <p>If {@code ret} has at least 1,000,000 columns, a single task writes into {@code ret}
     * directly. Otherwise the columns of {@code that} are split into blocks. Each block is
     * multiplied into its own full-size temporary block, and one extra task sums those
     * temporaries into {@code ret} via {@code addInPlaceFuture}.
     */
    private static void noTmpNoAggGroups(MatrixBlock that, MatrixBlock ret, ExecutorService pool, int ct, ArrayList<Future<?>> tasks, int start, int end, AColGroup g, int k) {
        final int retNRow = ret.getNumRows();
        final int retNCol = ret.getNumColumns();
        if (retNCol >= 1000000) {
            tasks.add(pool.submit(() -> g.leftMultByMatrixNoPreAgg(that, ret, start, end, 0, ct)));
            return;
        }

        final ArrayList<Future<MatrixBlock>> partials = new ArrayList<Future<MatrixBlock>>();
        final int colBlockSize = Math.max(ct / Math.max(k, 2), 64000);
        for (int startC = 0; startC < ct; startC += colBlockSize) {
            final int cStart = startC;
            final int cEnd = Math.min(startC + colBlockSize, ct);
            partials.add(pool.submit(() -> {
                final Timing timer = new Timing();
                final MatrixBlock tmpBlock = new MatrixBlock(retNRow, retNCol, new double[retNRow * retNCol]);
                g.leftMultByMatrixNoPreAgg(that, tmpBlock, start, end, cStart, cEnd);
                LOG.debug("noPreAggTiming: " + timer);
                return tmpBlock;
            }));
        }
        tasks.add(pool.submit(() -> addInPlaceFuture(ret, partials)));
    }

    /**
     * Waits for each future in order and adds its block into {@code ret}. Null results are
     * ignored by {@code addInPlace}.
     *
     * @return always null; the Object return type lets this serve as a Callable body
     */
    private static Object addInPlaceFuture(MatrixBlock ret, List<Future<MatrixBlock>> npaSubTask) throws Exception {
        for (Future<MatrixBlock> pending : npaSubTask) {
            final MatrixBlock partial = pending.get();
            addInPlace(partial, ret);
        }
        return null;
    }

    /**
     * Parallel left matrix multiplication in which each task produces its own result block. The
     * blocks are summed into {@code ret} on the calling thread.
     *
     * <p>The available parallelism {@code k} is split in turn across row blocks, group tasks and
     * column blocks of {@code that}. Column blocks are only used for dense {@code that}.
     *
     * @param npa         non-pre-aggregating groups
     * @param pa          pre-aggregating groups
     * @param that        uncompressed left operand
     * @param ret         output block
     * @param rowSums     row-sum accumulator, or null
     * @param overlapping unused here; kept for signature parity with the no-temp variant
     * @param k           degree of parallelism
     * @param pool        executor that runs the tasks
     * @throws Exception if any task fails
     */
    private static void LMMParallelTempOut(List<AColGroup> npa, List<APreAgg> pa, MatrixBlock that, MatrixBlock ret, double[] rowSums, boolean overlapping, int k, ExecutorService pool) throws Exception {
        final int rt = that.getNumRows();
        final int ct = that.getNumColumns();
        final int rowBlockSize = Math.max(rt / k, 1);
        final int threadsUsedOnRows = (int) Math.ceil((double) rt / (double) rowBlockSize);
        // Parallelism left per row block, then per group task. The decompiled version also did a
        // final dead `k /= threadsUsedOnColBlocks`; that store and its unused local are removed.
        final int kPerRowBlock = Math.max(1, k / threadsUsedOnRows);
        final int s = Math.min(npa.size() + pa.size(), kPerRowBlock);
        final int kPerGroupTask = Math.max(1, kPerRowBlock / s);
        final int colBlockSize = Math.max(ct / kPerGroupTask, 64000);
        final ArrayList<Future<MatrixBlock>> tasks = new ArrayList<Future<MatrixBlock>>();
        final int retCols = ret.getNumColumns();
        final int retRows = ret.getNumRows();
        for (int blo = 0; blo < rt; blo += rowBlockSize) {
            final int start = blo;
            final int end = Math.min(blo + rowBlockSize, rt);
            for (AColGroup g : npa)
                tasks.add(pool.submit(new LMMNoPreAggTask(g, that, retRows, retCols, start, end)));
            for (int off = 0; off < s; ++off) {
                if (that.isInSparseFormat()) {
                    tasks.add(pool.submit(new LMMPreAggTask(pa, that, retRows, retCols, start, end, 0, ct, off, s, null)));
                    continue;
                }
                for (int startC = 0; startC < ct; startC += colBlockSize) {
                    final int endC = Math.min(startC + colBlockSize, ct);
                    tasks.add(pool.submit(new LMMPreAggTask(pa, that, retRows, retCols, start, end, startC, endC, off, s, null)));
                }
            }
            if (rowSums != null)
                tasks.add(pool.submit(new LMMRowSums(that, start, end, rowSums)));
        }
        addInPlaceFuture(ret, tasks);
    }

    /**
     * Adds the dense values of {@code a} element-wise into {@code out}. A null {@code a} is ignored.
     *
     * <p>Assumes both blocks are dense with the same shape and block layout (not checked).
     *
     * <p>Only non-zero entries of {@code a} are written. This method also runs inside a pool task
     * (via {@code addInPlaceFuture} from {@code noTmpNoAggGroups}) while {@code LMMWithPreAgg}
     * tasks update other columns of the same rows of {@code out}. Writing every cell as
     * {@code out + 0.0} would race with those writers and could restore a stale value, losing
     * their update. Skipping zeros keeps the result identical in the sequential case. It avoids
     * touching cells this partial does not contribute to, assuming non-overlapping groups
     * (NOTE(review): confirm).
     *
     * @param a   partial result to add, may be null
     * @param out accumulator block
     * @return always null; the Object return type lets this serve as a Callable body
     */
    private static Object addInPlace(MatrixBlock a, MatrixBlock out) throws Exception {
        if (a == null)
            return null;
        final DenseBlock dba = a.getDenseBlock();
        final DenseBlock dbb = out.getDenseBlock();
        final int blocks = dba.numBlocks();
        for (int b = 0; b < blocks; ++b) {
            final double[] av = dba.valuesAt(b);
            final double[] bv = dbb.valuesAt(b);
            final int len = av.length;
            for (int i = 0; i < len; ++i) {
                final double v = av[i];
                if (v != 0.0)
                    bv[i] += v;
            }
        }
        return null;
    }

    /**
     * Single-threaded left matrix multiplication for rows [rl, ru) of {@code that}, processed in
     * blocks of 4 rows.
     *
     * <p>If {@code rowSums} is non-null, the row sums of {@code that} over the given rows are
     * accumulated into it as well. {@code LMM} uses them to apply the filtered-out group values.
     *
     * @param npa     non-pre-aggregating groups
     * @param pa      pre-aggregating groups
     * @param that    uncompressed left operand
     * @param ret     output block
     * @param rl      first row (inclusive)
     * @param ru      last row (exclusive)
     * @param rowSums row-sum accumulator, or null if not needed
     * @throws Exception if a kernel fails
     */
    private static void LMMTaskExec(List<AColGroup> npa, List<APreAgg> pa, MatrixBlock that, MatrixBlock ret, int rl, int ru, double[] rowSums) throws Exception {
        final int cu = that.getNumColumns();
        if (npa.isEmpty() && pa.isEmpty()) {
            rowSum(that, rowSums, rl, ru, 0, cu);
            return;
        }
        for (int r = rl; r < ru; r += 4) {
            final int re = Math.min(r + 4, ru);
            for (int i = 0; i < npa.size(); ++i)
                npa.get(i).leftMultByMatrixNoPreAgg(that, ret, r, re, 0, cu);
            if (!pa.isEmpty())
                LMMWithPreAgg(pa, that, ret, r, re, 0, cu, 0, 1, rowSums);
        }
        // Row sums are otherwise only accumulated inside LMMWithPreAgg. Without pre-aggregating
        // groups they were never computed, so the constant contribution added by LMM was lost.
        if (pa.isEmpty() && rowSums != null)
            rowSum(that, rowSums, rl, ru, 0, cu);
    }

    /**
     * Adds {@code leftRowSum (x) rightColumnSum} into the dense {@code result}. Uses a pool when
     * {@code k > 1}.
     */
    private static void outerProduct(double[] leftRowSum, double[] rightColumnSum, MatrixBlock result, int k) throws InterruptedException, ExecutionException {
        if (k <= 1) {
            outerProductSingleThread(leftRowSum, rightColumnSum, result);
            return;
        }
        outerProductParallel(leftRowSum, rightColumnSum, result, k);
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    /**
     * Runs the tiled outer-product update on a fresh pool of size {@code k}, waits for all tiles
     * and always shuts the pool down.
     */
    private static void outerProductParallel(double[] leftRowSum, double[] rightColumnSum, MatrixBlock result, int k) throws InterruptedException, ExecutionException {
        final ExecutorService pool = CommonThreadPool.get(k);
        try {
            final List<Future<?>> tiles = outerProductParallelTasks(leftRowSum, rightColumnSum, result, pool);
            for (Future<?> tile : tiles)
                tile.get();
        }
        finally {
            pool.shutdown();
        }
    }

    /**
     * Adds the outer-product tile [rl, ru) x [cl, cu) into {@code result}. Picks the flat-array
     * kernel when the dense block is contiguous, otherwise the per-row kernel.
     */
    private static void outerProductRange(double[] leftRowSum, double[] rightColumnSum, MatrixBlock result, int rl, int ru, int cl, int cu) {
        final DenseBlock db = result.getDenseBlock();
        if (!db.isContiguous())
            outerProductRangeGeneric(leftRowSum, rightColumnSum, db, rl, ru, cl, cu);
        else
            outerProductRangeContiguous(leftRowSum, rightColumnSum, result.getDenseBlockValues(), rl, ru, cl, cu);
    }

    /**
     * Row-major flat-array kernel. For rows [rl, ru) and columns [cl, cu), adds
     * {@code leftRowSum[r] * rightColumnSum[c]} to {@code result[r * rightColumnSum.length + c]}.
     * Rows with a zero left factor are skipped.
     */
    private static void outerProductRangeContiguous(double[] leftRowSum, double[] rightColumnSum, double[] result, int rl, int ru, int cl, int cu) {
        final int width = rightColumnSum.length;
        for (int r = rl; r < ru; r++) {
            final double scale = leftRowSum[r];
            if (scale == 0.0)
                continue;
            final int base = r * width;
            for (int c = cl; c < cu; c++)
                result[base + c] += scale * rightColumnSum[c];
        }
    }

    /**
     * Per-row kernel for possibly multi-array dense blocks. Adds
     * {@code leftRowSum[r] * rightColumnSum[c]} into each cell of the tile [rl, ru) x [cl, cu).
     * Rows with a zero left factor are skipped.
     */
    private static void outerProductRangeGeneric(double[] leftRowSum, double[] rightColumnSum, DenseBlock res, int rl, int ru, int cl, int cu) {
        for (int r = rl; r < ru; r++) {
            final double scale = leftRowSum[r];
            if (scale == 0.0)
                continue;
            final double[] rowVals = res.values(r);
            final int base = res.pos(r);
            for (int c = cl; c < cu; c++)
                rowVals[base + c] += scale * rightColumnSum[c];
        }
    }

    /**
     * Sequential tiled outer-product update. Rows are processed in blocks of {@code blkz}, and
     * columns in blocks sized by {@code outerProdGetColBz}.
     */
    private static void outerProductSingleThread(double[] leftRowSum, double[] rightColumnSum, MatrixBlock result) {
        final int blkz = 1024;
        final int nRow = leftRowSum.length;
        final int nCol = rightColumnSum.length;
        for (int rl = 0; rl < nRow; rl += blkz) {
            final int ru = Math.min(nRow, rl + blkz);
            final int colBz = outerProdGetColBz(blkz, rl, rl, ru);
            for (int cl = 0; cl < nCol; cl += colBz) {
                final int cu = Math.min(nCol, cl + colBz);
                outerProductRange(leftRowSum, rightColumnSum, result, rl, ru, cl, cu);
            }
        }
    }

    /**
     * Submits one task per outer-product tile to {@code pool} and returns their futures
     * (not awaited). Tiles are disjoint, so the tasks write disjoint regions of {@code result}.
     *
     * @param leftRowSum     left vector (one entry per output row)
     * @param rightColumnSum right vector (one entry per output column)
     * @param result         dense output block
     * @param pool           executor that runs the tiles
     * @return futures of the submitted tile tasks
     */
    private static List<Future<?>> outerProductParallelTasks(double[] leftRowSum, double[] rightColumnSum, MatrixBlock result, ExecutorService pool) {
        final int blkz = 1024;
        // Typed list: the decompiled raw ArrayList was returned as List<Future<?>> unchecked.
        final List<Future<?>> tasks = new ArrayList<>();
        for (int row = 0; row < leftRowSum.length; row += blkz) {
            final int rl = row;
            final int ru = Math.min(leftRowSum.length, row + blkz);
            final int colBz = outerProdGetColBz(blkz, row, rl, ru);
            for (int col = 0; col < rightColumnSum.length; col += colBz) {
                final int cl = col;
                final int cu = Math.min(rightColumnSum.length, col + colBz);
                tasks.add(pool.submit(() -> outerProductRange(leftRowSum, rightColumnSum, result, rl, ru, cl, cu)));
            }
        }
        return tasks;
    }

    /**
     * Column-block size for an outer-product row block starting at {@code row}.
     *
     * <p>A full row block ({@code ru >= row + blkz}) uses {@code blkz}. A truncated row block
     * uses {@code 1024 * 1024 - (ru - rl) * 1024 + 1024}.
     */
    private static int outerProdGetColBz(int blkz, int row, int rl, int ru) {
        if (ru >= row + blkz)
            return blkz;
        final int rowsInBlock = ru - rl;
        return 1024 * 1024 - rowsInBlock * 1024 + 1024;
    }

    /**
     * Adds {@code scaling * (leftRowSum (x) rightColumnSum)} into the dense {@code result}.
     * Picks the flat-array or per-row kernel depending on contiguity.
     */
    private static void outerProductWithScaling(double[] leftRowSum, double[] rightColumnSum, int scaling, MatrixBlock result) {
        final DenseBlock db = result.getDenseBlock();
        if (!db.isContiguous())
            outerProductWithScalingGeneric(leftRowSum, rightColumnSum, scaling, db);
        else
            outerProductWithScalingContiguous(leftRowSum, rightColumnSum, scaling, result.getDenseBlockValues());
    }

    /**
     * Flat-array kernel. For every cell (r, c), adds
     * {@code (leftRowSum[r] * scaling) * rightColumnSum[c]} to
     * {@code result[r * rightColumnSum.length + c]}.
     */
    private static void outerProductWithScalingContiguous(double[] leftRowSum, double[] rightColumnSum, int scaling, double[] result) {
        final int width = rightColumnSum.length;
        for (int r = 0; r < leftRowSum.length; r++) {
            final double scaled = leftRowSum[r] * (double) scaling;
            final int base = r * width;
            for (int c = 0; c < width; c++)
                result[base + c] += scaled * rightColumnSum[c];
        }
    }

    /**
     * Per-row kernel. For every cell (r, c), adds
     * {@code (leftRowSum[r] * scaling) * rightColumnSum[c]} into {@code res}.
     */
    private static void outerProductWithScalingGeneric(double[] leftRowSum, double[] rightColumnSum, int scaling, DenseBlock res) {
        for (int r = 0; r < leftRowSum.length; r++) {
            final int base = res.pos(r);
            final double[] rowVals = res.values(r);
            final double scaled = leftRowSum[r] * (double) scaling;
            for (int c = 0; c < rightColumnSum.length; c++)
                rowVals[base + c] += scaled * rightColumnSum[c];
        }
    }

    /**
     * Dispatches the pre-aggregating multiplication to the sparse or dense kernel based on the
     * format of {@code that}. Any failure is rethrown as an unchecked RuntimeException.
     */
    private static void LMMWithPreAgg(List<APreAgg> preAggCGs, MatrixBlock that, MatrixBlock ret, int rl, int ru, int cl, int cu, int off, int skip, double[] rowSums) {
        try {
            if (that.isInSparseFormat())
                LMMWithPreAggSparse(preAggCGs, that, ret, rl, ru, cl, cu, off, skip, rowSums);
            else
                LMMWithPreAggDense(preAggCGs, that, ret, rl, ru, cl, cu, off, skip, rowSums);
        }
        catch (Exception e) {
            throw new RuntimeException("Failed LLM pre aggregate", e);
        }
    }

    /**
     * Sparse-input kernel. Processes groups off, off + skip, ... one row at a time via
     * {@code preAggSparseRow}, reusing two scratch blocks. Then adds the row sums of rows
     * [rl, ru) restricted to columns [cl, cu) if {@code rowSum} is non-null.
     */
    private static void LMMWithPreAggSparse(List<APreAgg> preAggCGs, MatrixBlock that, MatrixBlock ret, int rl, int ru, int cl, int cu, int off, int skip, double[] rowSum) throws Exception {
        final MatrixBlock preA = new MatrixBlock();
        final MatrixBlock fTmp = new MatrixBlock();
        final SparseBlock sb = that.getSparseBlock();
        for (int j = off; j < preAggCGs.size(); j += skip) {
            final APreAgg group = preAggCGs.get(j);
            final int nCol = group.getNumCols();
            final int nVal = group.getNumValues();
            for (int r = rl; r < ru; r++)
                preAggSparseRow(that, ret, cl, cu, preA, fTmp, sb, nCol, nVal, group, r);
        }
        if (rowSum != null)
            rowSumSparse(that.getSparseBlock(), rowSum, rl, ru, cl, cu);
    }

    /**
     * Pre-aggregates the single sparse row {@code r} of {@code that} (columns [cl, cu)) for
     * group {@code g}, then multiplies the pre-aggregate with the group's dictionary into
     * {@code ret}. Empty rows are skipped.
     */
    private static void preAggSparseRow(MatrixBlock that, MatrixBlock ret, int cl, int cu, MatrixBlock preA, MatrixBlock fTmp, SparseBlock sb, int nCol, int nVal, APreAgg g, int r) {
        if (sb.isEmpty(r))
            return;
        final int rowEnd = r + 1;
        final boolean preAllocated = preA.isAllocated();
        preA.reset(1, nVal);
        if (!preAllocated)
            preA.allocateDenseBlock();
        allocateOrResetTmpRes(ret, fTmp, 1);
        final double[] preAV = preA.getDenseBlockValues();
        preA.setNonZeros(g.getPreAggregateSize());
        fTmp.setNonZeros(1L);
        g.preAggregateSparse(sb, preAV, r, rowEnd, cl, cu);
        g.mmWithDictionary(preA, fTmp, ret, 1, r, rowEnd);
    }

    /**
     * Resets {@code fTmp} to {@code rows x ret.getNumColumns()}. Its dense block is allocated
     * only if it was not allocated before the reset.
     */
    private static void allocateOrResetTmpRes(MatrixBlock ret, MatrixBlock fTmp, int rows) {
        final boolean wasAllocated = fTmp.isAllocated();
        fTmp.reset(rows, ret.getNumColumns());
        if (!wasAllocated)
            fTmp.allocateDenseBlock();
    }

    /**
     * Dense-input kernel for the pre-aggregating left matrix multiplication. Covers rows
     * [rl, ru) and columns [cl, cu) of {@code that}, and processes groups off, off + skip, ...
     *
     * <p>Rows are handled in blocks of 4, groups in blocks of 4 (strided by {@code skip}), and
     * columns of {@code that} in blocks of 2048. For each row/group block:
     * <ol>
     * <li>a scratch pre-aggregate buffer is prepared per non-identity group;</li>
     * <li>identity groups write into {@code ret} directly, while the others pre-aggregate into
     * their buffer;</li>
     * <li>each non-identity buffer is multiplied with its dictionary into {@code ret}.</li>
     * </ol>
     * Row sums (when {@code rowSum != null}) are accumulated while processing the last group
     * block.
     */
    private static void LMMWithPreAggDense(List<APreAgg> preAggCGs, MatrixBlock that, MatrixBlock ret, int rl, int ru, int cl, int cu, int off, int skip, double[] rowSum) throws InterruptedException, ExecutionException {
        final int colBZ = 2048;
        final int rowBlockSize = 4;
        final int colGroupBlocking = 4;
        final int nColGroups = preAggCGs.size();
        final double[][] preAgg = new double[colGroupBlocking][];
        final MatrixBlock tmpRes = new MatrixBlock();
        for (int rlt = rl; rlt < ru; rlt += rowBlockSize) {
            final int rut = Math.min(rlt + rowBlockSize, ru);
            for (int gl = off; gl < nColGroups; gl += colGroupBlocking * skip) {
                final int gu = Math.min(gl + colGroupBlocking * skip, nColGroups);
                for (int j = gl, p = 0; j < gu; j += skip, p++)
                    preAllocate(preAggCGs, j, rut, rlt, preAgg, p);
                for (int clt = cl; clt < cu; clt += colBZ) {
                    final int cut = Math.min(clt + colBZ, cu);
                    for (int j = gl, p = 0; j < gu; j += skip, p++)
                        preAggregate(that, ret, preAggCGs, rut, rlt, clt, cut, j, preAgg, p);
                    if (gu == nColGroups)
                        rowSum(that, rowSum, rlt, rut, clt, cut);
                }
                for (int j = gl, p = 0; j < gu; j += skip, p++) {
                    final APreAgg cg = preAggCGs.get(j);
                    if (!isIdentity(cg)) {
                        allocateOrResetTmpRes(ret, tmpRes, rowBlockSize);
                        postMultiply(ret, tmpRes, preAgg, p, cg, rut, rlt);
                    }
                }
            }
        }
    }

    /**
     * Single identity test shared by preAllocate, preAggregate and the post-multiply step.
     *
     * <p>Previously preAggregate tested {@code IdentityDictionary} while the other two tested
     * {@code AIdentityDictionary}. An {@code AIdentityDictionary} that is not an
     * {@code IdentityDictionary} therefore got no scratch buffer (NPE or stale buffer) and its
     * result was never multiplied into {@code ret}. Such dictionaries now take the generic
     * pre-aggregate path.
     */
    private static boolean isIdentity(APreAgg cg) {
        return cg.getDictionary() instanceof IdentityDictionary;
    }

    /**
     * Ensures {@code preAgg[p]} holds at least (rut - rlt) * preAggregateSize zeroed entries for
     * group {@code j}. Identity groups are skipped because they need no buffer.
     */
    private static void preAllocate(List<APreAgg> preAggCGs, int j, int rut, int rlt, double[][] preAgg, int p) {
        final APreAgg cg = preAggCGs.get(j);
        if (isIdentity(cg))
            return;
        final int len = (rut - rlt) * cg.getPreAggregateSize();
        if (preAgg[p] == null || preAgg[p].length < len)
            preAgg[p] = new double[len];
        else
            Arrays.fill(preAgg[p], 0, len, 0.0);
    }

    /**
     * Pre-aggregates rows [rlt, rut), columns [clt, cut) of {@code that} for group {@code j}.
     * Identity groups write into {@code ret} directly; the others write into {@code preAgg[p]}.
     */
    private static void preAggregate(MatrixBlock that, MatrixBlock ret, List<APreAgg> preAggCGs, int rut, int rlt, int clt, int cut, int j, double[][] preAgg, int p) {
        final APreAgg cg = preAggCGs.get(j);
        if (isIdentity(cg))
            cg.leftMMIdentityPreAggregateDense(that, ret, rlt, rut, clt, cut);
        else
            cg.preAggregateDense(that, preAgg[p], rlt, rut, clt, cut);
    }

    /**
     * Multiplies the pre-aggregate {@code preAgg[p]} for rows [rlt, rut) with the dictionary of
     * {@code cg} into {@code ret}, using {@code tmpRes} as scratch space.
     */
    private static void postMultiply(MatrixBlock ret, MatrixBlock tmpRes, double[][] preAgg, int p, APreAgg cg, int rut, int rlt) {
        final int preAggNCol = cg.getPreAggregateSize();
        final MatrixBlock preAggThis = new MatrixBlock(rut - rlt, preAggNCol, preAgg[p]);
        cg.mmWithDictionary(preAggThis, tmpRes, ret, 1, rlt, rut);
    }

    /**
     * Returns row sums of {@code mb} over rows [rl, ru) and columns [cl, cu). The array has
     * length {@code ru}; entries below {@code rl} stay zero.
     *
     * @throws DMLCompressionException if {@code mb} is empty
     */
    public static double[] rowSum(MatrixBlock mb, int rl, int ru, int cl, int cu) {
        final double[] sums = new double[ru];
        rowSum(mb, sums, rl, ru, cl, cu);
        return sums;
    }

    /**
     * Adds the row sums of {@code mb} over rows [rl, ru) and columns [cl, cu) into {@code rowSum}.
     * Does nothing if {@code rowSum} is null, but always rejects an empty {@code mb}.
     *
     * @throws DMLCompressionException if {@code mb} is empty
     */
    private static void rowSum(MatrixBlock mb, double[] rowSum, int rl, int ru, int cl, int cu) {
        if (mb.isEmpty())
            throw new DMLCompressionException("Invalid empty block to rowsum");
        if (rowSum == null)
            return;
        if (!mb.isInSparseFormat())
            rowSumDense(mb, rowSum, rl, ru, cl, cu);
        else
            rowSumSparse(mb.getSparseBlock(), rowSum, rl, ru, cl, cu);
    }

    /** Adds the sparse row sums of rows [rl, ru), restricted to columns [cl, cu), into {@code rowSum}. */
    private static void rowSumSparse(SparseBlock sb, double[] rowSum, int rl, int ru, int cl, int cu) {
        int row = rl;
        while (row < ru) {
            rowSumSparseSingleRow(sb, rowSum, cl, cu, row);
            row++;
        }
    }

    /**
     * Adds the stored values of sparse row {@code i} whose column index lies in [cl, cu) into
     * {@code rowSum[i]}. Empty rows are skipped.
     *
     * <p>Assumes the column indexes within a row are sorted ascending; both range scans stop at
     * the first index outside the bound.
     */
    private static void rowSumSparseSingleRow(SparseBlock sb, double[] rowSum, int cl, int cu, int i) {
        if (sb.isEmpty(i)) {
            return;
        }
        final int begin = sb.pos(i);
        final int end = begin + sb.size(i);
        final double[] values = sb.values(i);
        final int[] columns = sb.indexes(i);

        // First stored entry whose column index is >= cl.
        int lo = begin;
        while (lo < end && columns[lo] < cl) {
            lo++;
        }

        // One past the last stored entry whose column index is < cu. If the row's last entry is
        // already below cu, the whole remaining tail qualifies and no scan is needed.
        int hi;
        if (columns[end - 1] < cu) {
            hi = end;
        } else {
            hi = lo;
            while (hi < end && columns[hi] < cu) {
                hi++;
            }
        }

        // Add element by element directly into the target, left to right.
        for (int k = lo; k < hi; k++) {
            rowSum[i] += values[k];
        }
    }

    /**
     * Accumulates, for every row in [rl, ru) of the dense matrix {@code that}, the sum of
     * columns [cl, cu) into {@code rowSum[row]}.
     *
     * <p>For a contiguous dense block, the single backing array ({@code values(0)}) is fetched once
     * and reused. Otherwise the backing array is looked up per row.
     */
    private static void rowSumDense(MatrixBlock that, double[] rowSum, int rl, int ru, int cl, int cu) {
        final DenseBlock db = that.getDenseBlock();
        final boolean contiguous = db.isContiguous();
        final double[] sharedValues = contiguous ? db.values(0) : null;
        for (int row = rl; row < ru; row++) {
            final double[] rowValues = contiguous ? sharedValues : db.values(row);
            CLALibLeftMultBy.rowSumDenseSingleRow(rowSum, cl, cu, db, rowValues, row);
        }
    }

    /**
     * Sums {@code thatV} over columns [cl, cu) of row {@code r}, offset by {@code db.pos(r)}, and
     * adds the total to {@code rowSum[r]}.
     *
     * <p>The partial sum is built locally first and added to the target once at the end.
     */
    private static void rowSumDenseSingleRow(double[] rowSum, int cl, int cu, DenseBlock db, double[] thatV, int r) {
        final int base = db.pos(r);
        final int stop = base + cu;
        double partial = 0.0;
        int idx = base + cl;
        while (idx < stop) {
            partial += thatV[idx];
            idx++;
        }
        rowSum[r] += partial;
    }

    /**
     * Task that adds the full-width row sums (all columns) of rows [rl, ru) of a matrix into a
     * caller-supplied array.
     *
     * <p>The result is communicated only through that array, and {@link #call()} always returns
     * {@code null}. NOTE(review): several tasks presumably share one array with disjoint row
     * ranges; confirm at the submitting call site.
     */
    private static class LMMRowSums
    implements Callable<MatrixBlock> {
        /** Matrix whose rows are summed. */
        private final MatrixBlock _that;
        /** First row (inclusive). */
        private final int _rl;
        /** Last row (exclusive). */
        private final int _ru;
        /** Target array; index {@code row} receives the sum of that row. */
        private final double[] _rowSums;

        protected LMMRowSums(MatrixBlock that, int rl, int ru, double[] rowSums) {
            _that = that;
            _rl = rl;
            _ru = ru;
            _rowSums = rowSums;
        }

        @Override
        public MatrixBlock call() throws Exception {
            final MatrixBlock source = _that;
            final int allCols = source.getNumColumns();
            if (!source.isInSparseFormat()) {
                CLALibLeftMultBy.rowSumDense(source, _rowSums, _rl, _ru, 0, allCols);
            } else {
                CLALibLeftMultBy.rowSumSparse(source.getSparseBlock(), _rowSums, _rl, _ru, 0, allCols);
            }
            return null;
        }
    }

    /**
     * Task that left-multiplies rows [rl, ru) of a matrix with a single column group without
     * pre-aggregation, writing into a private {@code retR x retC} dense result block.
     *
     * <p>The result block is allocated eagerly in the constructor and returned by
     * {@link #call()}.
     */
    private static class LMMNoPreAggTask
    implements Callable<MatrixBlock> {
        /** Column group to multiply with. */
        private final AColGroup _cg;
        /** Left-hand-side matrix. */
        private final MatrixBlock _that;
        /** Task-local dense output block. */
        private final MatrixBlock _ret;
        /** First row of {@code _that} (inclusive). */
        private final int _rl;
        /** Last row of {@code _that} (exclusive). */
        private final int _ru;

        protected LMMNoPreAggTask(AColGroup cg, MatrixBlock that, int retR, int retC, int rl, int ru) {
            _cg = cg;
            _that = that;
            final MatrixBlock output = new MatrixBlock(retR, retC, false);
            output.allocateDenseBlock();
            _ret = output;
            _rl = rl;
            _ru = ru;
        }

        @Override
        public MatrixBlock call() throws Exception {
            final int colEnd = _that.getNumColumns();
            _cg.leftMultByMatrixNoPreAgg(_that, _ret, _rl, _ru, 0, colEnd);
            return _ret;
        }
    }

    /**
     * Task that runs the pre-aggregating left multiplication ({@code LMMWithPreAgg}) for a list of
     * column groups over rows [rl, ru) and columns [cl, cu) of the left matrix.
     *
     * <p>Each invocation of {@link #call()} allocates a fresh zero-filled {@code retR x retC}
     * dense result backed by a plain {@code double[]}, and returns it.
     */
    private static class LMMPreAggTask
    implements Callable<MatrixBlock> {
        /** Column groups that support pre-aggregation. */
        private final List<APreAgg> _pa;
        /** Left-hand-side matrix. */
        private final MatrixBlock _that;
        /** Rows of the task-local result. */
        private final int _retR;
        /** Columns of the task-local result. */
        private final int _retC;
        /** Row range of {@code _that}: [_rl, _ru). */
        private final int _rl;
        private final int _ru;
        /** Column range of {@code _that}: [_cl, _cu). */
        private final int _cl;
        private final int _cu;
        /** Row sums forwarded to LMMWithPreAgg (presumably for groups with a default/reference tuple; TODO confirm). */
        private final double[] _rowSums;
        /**
         * Forwarded unchanged to LMMWithPreAgg. NOTE(review): presumably these select a strided
         * subset of {@code _pa}; confirm against LMMWithPreAgg.
         */
        private final int _off;
        private final int _skip;

        protected LMMPreAggTask(List<APreAgg> pa, MatrixBlock that, int retR, int retC, int rl, int ru, int cl, int cu, int off, int skip, double[] rowSums) {
            _pa = pa;
            _that = that;
            _retR = retR;
            _retC = retC;
            _rl = rl;
            _ru = ru;
            _cl = cl;
            _cu = cu;
            _off = off;
            _skip = skip;
            _rowSums = rowSums;
        }

        @Override
        public MatrixBlock call() throws Exception {
            final MatrixBlock result = new MatrixBlock(_retR, _retC, new double[_retR * _retC]);
            CLALibLeftMultBy.LMMWithPreAgg(_pa, _that, result, _rl, _ru, _cl, _cu, _off, _skip, _rowSums);
            return result;
        }
    }
}

