/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.compress.lib;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
import org.apache.sysds.runtime.compress.DMLCompressionException;
import org.apache.sysds.runtime.compress.colgroup.AColGroup;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.util.CommonThreadPool;
import org.apache.sysds.utils.stats.InfrastructureAnalyzer;

public class CLALibReshape {
    protected static final Log LOG = LogFactory.getLog(CLALibReshape.class.getName());

    /**
     * Number of input rows that must be exceeded before a row-wise reshape is attempted directly on the compressed
     * column groups. Inputs at or below this size are decompressed and reshaped as a regular {@link MatrixBlock}.
     * Public and non-final.
     */
    public static int COMPRESSED_RESHAPE_THRESHOLD = 1000;

    /** Compressed input matrix. */
    final CompressedMatrixBlock in;
    /** Number of columns of the input. */
    final int clen;
    /** Number of rows of the input. */
    final int rlen;
    /** Requested number of output rows. */
    final int rows;
    /** Requested number of output columns. */
    final int cols;
    /** Row-wise flag, forwarded to {@link MatrixBlock#reshape} on the decompress path. */
    final boolean rowwise;
    /**
     * Requested degree of parallelism. A thread pool is only obtained when {@code k > 1} and the compressed path
     * actually runs.
     */
    final int k;

    private CLALibReshape(CompressedMatrixBlock in, int rows, int cols, boolean rowwise, int k) {
        this.in = in;
        this.rlen = in.getNumRows();
        this.clen = in.getNumColumns();
        this.rows = rows;
        this.cols = cols;
        this.rowwise = rowwise;
        this.k = k;
    }

    /**
     * Reshapes a compressed matrix using the local parallelism reported by {@link InfrastructureAnalyzer}.
     *
     * @param in      the compressed input matrix
     * @param rows    number of output rows
     * @param cols    number of output columns
     * @param rowwise whether to reshape row-wise
     * @return the reshaped matrix; compressed if the reshape could be done on the column groups, otherwise the
     *         result of reshaping the decompressed input
     * @throws DMLCompressionException if the shape is invalid or the reshape fails
     */
    public static MatrixBlock reshape(CompressedMatrixBlock in, int rows, int cols, boolean rowwise) {
        return new CLALibReshape(in, rows, cols, rowwise, InfrastructureAnalyzer.getLocalParallelism()).apply();
    }

    /**
     * Reshapes a compressed matrix with an explicit degree of parallelism.
     *
     * @param in      the compressed input matrix
     * @param rows    number of output rows
     * @param cols    number of output columns
     * @param rowwise whether to reshape row-wise
     * @param k       degree of parallelism; values {@code <= 1} run single-threaded
     * @return the reshaped matrix; compressed if the reshape could be done on the column groups, otherwise the
     *         result of reshaping the decompressed input
     * @throws DMLCompressionException if the shape is invalid or the reshape fails
     */
    public static MatrixBlock reshape(CompressedMatrixBlock in, int rows, int cols, boolean rowwise, int k) {
        return new CLALibReshape(in, rows, cols, rowwise, k).apply();
    }

    /**
     * Validates the target shape, then reshapes either in compressed space or via decompression.
     * Every failure, including an invalid shape, is wrapped in a {@link DMLCompressionException}.
     */
    private MatrixBlock apply() {
        try {
            checkValidity();
            if (shouldItBeCompressedOutputs()) {
                return applyCompressed();
            }
            return in.decompress().reshape(rows, cols, rowwise);
        }
        catch (Exception e) {
            if (e instanceof InterruptedException) {
                // Future#get was interrupted; restore the flag so the interruption is not lost by wrapping.
                Thread.currentThread().interrupt();
            }
            throw new DMLCompressionException("Failed reshaping of compressed matrix", e);
        }
    }

    /**
     * Splits every column group into {@code rlen / rows} groups and assembles them into a new compressed block of
     * shape {@code rows x cols}, carrying over the input's non-zero count.
     */
    private MatrixBlock applyCompressed() throws Exception {
        final int multiplier = rlen / rows;
        final List<AColGroup> retGroups;
        if (k <= 1) {
            retGroups = applySingleThread(multiplier);
        }
        else {
            // Obtain the pool only here so the decompress path never pays for it.
            final ExecutorService pool = CommonThreadPool.get(k);
            try {
                retGroups = in.getColGroups().size() == 1
                    ? applyParallelPushDown(multiplier, pool)
                    : applyParallel(multiplier, pool);
            }
            finally {
                pool.shutdown();
            }
        }
        final CompressedMatrixBlock ret = new CompressedMatrixBlock(rows, cols);
        ret.allocateColGroupList(retGroups);
        ret.setNonZeros(in.getNonZeros());
        return ret;
    }

    /** Splits each column group sequentially on the calling thread. */
    private List<AColGroup> applySingleThread(int multiplier) {
        final List<AColGroup> groups = in.getColGroups();
        final List<AColGroup> retGroups = new ArrayList<>(groups.size() * multiplier);
        for (AColGroup g : groups) {
            Collections.addAll(retGroups, g.splitReshape(multiplier, rlen, clen));
        }
        return retGroups;
    }

    /**
     * Hands the pool to each group's own {@code splitReshapePushDown}. This is used when the input has a single
     * column group, so there is no group-level parallelism to exploit.
     */
    private List<AColGroup> applyParallelPushDown(int multiplier, ExecutorService pool) throws Exception {
        final List<AColGroup> groups = in.getColGroups();
        final List<AColGroup> retGroups = new ArrayList<>(groups.size() * multiplier);
        for (AColGroup g : groups) {
            Collections.addAll(retGroups, g.splitReshapePushDown(multiplier, rlen, clen, pool));
        }
        return retGroups;
    }

    /** Submits one split task per column group and collects the results in the original group order. */
    private List<AColGroup> applyParallel(int multiplier, ExecutorService pool) throws Exception {
        final List<AColGroup> groups = in.getColGroups();
        final List<Future<AColGroup[]>> tasks = new ArrayList<>(groups.size());
        for (AColGroup g : groups) {
            tasks.add(pool.submit(() -> g.splitReshape(multiplier, rlen, clen)));
        }
        final List<AColGroup> retGroups = new ArrayList<>(groups.size() * multiplier);
        for (Future<AColGroup[]> task : tasks) {
            Collections.addAll(retGroups, task.get());
        }
        return retGroups;
    }

    /**
     * Rejects negative output dimensions and shapes whose cell count differs from the input's.
     * The cell counts are compared in {@code long} to avoid int overflow.
     */
    private void checkValidity() {
        if (rows < 0 || cols < 0) {
            throw new DMLRuntimeException(
                "Reshape matrix requires non-negative output dimensions (" + rows + ":" + cols + ").");
        }
        if ((long) rlen * clen != (long) rows * cols) {
            throw new DMLRuntimeException("Reshape matrix requires consistent numbers of input/output cells (" + rlen
                + ":" + clen + ", " + rows + ":" + cols + ").");
        }
    }

    /**
     * Decides whether to reshape in compressed space. This requires a large enough input, a row-wise reshape, and
     * an output row count that evenly divides the input row count. Under those conditions every output row covers
     * exactly {@code rlen / rows} whole input rows.
     */
    private boolean shouldItBeCompressedOutputs() {
        return rlen > COMPRESSED_RESHAPE_THRESHOLD && rowwise && rows > 0 && rlen % rows == 0;
    }
}

