/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.matrix.data;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.concurrent.Callable;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.runtime.codegen.LibSpoofPrimitives;
import org.apache.sysds.runtime.data.SparseBlock;
import org.apache.sysds.runtime.data.SparseRow;
import org.apache.sysds.runtime.matrix.data.DnnParameters;
import org.apache.sysds.runtime.matrix.data.LibMatrixDNN;
import org.apache.sysds.runtime.matrix.data.LibMatrixDNNHelper;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;

public class LibMatrixDNNPooling {
    /** Logger for this class; {@code sort} uses it to warn when asked to sort more than 32 entries. */
    protected static final Log LOG = LogFactory.getLog((String)LibMatrixDNNPooling.class.getName());

    /**
     * Creates the forward-pooling tasks for rows {@code [0, params.N)}.
     *
     * <p>The rows are cut into contiguous ranges whose length is the per-thread
     * share of rows divided by two (rounded up), i.e. roughly two tasks per
     * thread. For every range the input format is checked and either a
     * {@code SparsePooling} or a {@code DensePooling} task is created.
     *
     * @param params   pooling parameters; {@code N}, {@code numThreads} and {@code input1} are read here
     * @param poolType pooling type handed through to every task
     * @return the tasks in ascending row order
     */
    public static ArrayList<Callable<Long>> getPoolingWorkers(DnnParameters params, LibMatrixDNN.PoolingType poolType) {
        final ArrayList<Callable<Long>> tasks = new ArrayList<Callable<Long>>();
        final int numThreads = OptimizerUtils.getConstrainedNumThreads(params.numThreads);
        final int rowsPerTask = (int)Math.ceil((double)params.N / (double)numThreads / 2.0);
        for (int lower = 0; lower < params.N; lower += rowsPerTask) {
            final Callable<Long> task;
            if (params.input1.isInSparseFormat()) {
                task = new SparsePooling(lower, Math.min(lower + rowsPerTask, params.N), params, poolType);
            } else {
                task = new DensePooling(lower, Math.min(lower + rowsPerTask, params.N), params, poolType);
            }
            tasks.add(task);
        }
        return tasks;
    }

    /**
     * Creates the pooling-backward tasks for rows {@code [0, params.N)}.
     *
     * <p>Rows are cut into contiguous ranges exactly as for the forward tasks
     * (per-thread share divided by two, rounded up). For {@code MAX} pooling the
     * task class is picked from the storage formats of {@code input1} and
     * {@code input2}; for every other pooling type only the format of
     * {@code input2} matters and an average-pooling task is created.
     *
     * @param params              pooling parameters; {@code N}, {@code numThreads}, {@code input1}, {@code input2} are read here
     * @param performReluBackward handed through to the max-pooling tasks
     * @param poolType            {@code MAX} selects the max-pooling tasks, anything else the average-pooling tasks
     * @return the tasks in ascending row order
     */
    public static ArrayList<Callable<Long>> getPoolingBackwardWorkers(DnnParameters params, boolean performReluBackward, LibMatrixDNN.PoolingType poolType) {
        final ArrayList<Callable<Long>> tasks = new ArrayList<Callable<Long>>();
        final int numThreads = OptimizerUtils.getConstrainedNumThreads(params.numThreads);
        final int rowsPerTask = (int)Math.ceil((double)params.N / (double)numThreads / 2.0);

        if (poolType != LibMatrixDNN.PoolingType.MAX) {
            final boolean sparseGradient = params.input2.isInSparseFormat();
            for (int lower = 0; lower < params.N; lower += rowsPerTask) {
                final int upper = Math.min(lower + rowsPerTask, params.N);
                if (sparseGradient) {
                    tasks.add(new AvgPoolingBackwardSparse(lower, upper, params));
                } else {
                    tasks.add(new AvgPoolingBackwardDense(lower, upper, params));
                }
            }
            return tasks;
        }

        final boolean sparseInput = params.input1.isInSparseFormat();
        final boolean sparseDout = params.input2.isInSparseFormat();
        for (int lower = 0; lower < params.N; lower += rowsPerTask) {
            final int upper = Math.min(lower + rowsPerTask, params.N);
            final Callable<Long> task;
            if (sparseInput) {
                if (sparseDout) {
                    task = new PoolingBackwardSparseSparse(lower, upper, params, performReluBackward);
                } else {
                    task = new PoolingBackwardSparseDense(lower, upper, params, performReluBackward);
                }
            } else if (sparseDout) {
                task = new PoolingBackwardDenseSparse(lower, upper, params, performReluBackward);
            } else {
                task = new PoolingBackwardDenseDense(lower, upper, params, performReluBackward);
            }
            tasks.add(task);
        }
        return tasks;
    }

    /**
     * Dense pooling over rows {@code [rl, ru)} that walks the input arrays
     * directly, without precomputed window index arrays.
     *
     * <p>Row {@code i} is read from {@code in} starting at
     * {@code ii + (i - rl) * C*H*W}. Every output cell starts from
     * {@code minVal} and is then combined with the covered input values: the
     * running maximum for {@code MAX}, otherwise the running sum of the values
     * scaled by {@code pFact}.
     *
     * <p>When {@code P2 == 1 && Q == 1 && W == 1} a short path reduces the first
     * {@code min(R, H)} values of each channel into one output cell.
     *
     * @param pType  pooling type; {@code MAX} selects max, anything else the scaled sum
     * @param minVal initial value of every output cell
     * @param pFact  multiplier applied to each window sum in the non-max case
     * @param in     dense input values
     * @param out    dense output values (written in place)
     * @param rl     first row (inclusive)
     * @param ru     last row (exclusive)
     * @param ii     offset into {@code in} of row {@code rl}
     * @param oi     offset into {@code out} of row {@code rl}
     * @param C      number of channels
     * @param P2     output height
     * @param Q      output width
     * @param R      window height
     * @param S      window width
     * @param H      input height
     * @param W      input width
     */
    public static void poolingDenseStride1Pad0(LibMatrixDNN.PoolingType pType, double minVal, double pFact, double[] in, double[] out, int rl, int ru, int ii, int oi, int C, int P2, int Q, int R, int S, int H, int W) {
        final boolean useMax = (pType == LibMatrixDNN.PoolingType.MAX);
        final int imageLen = C * H * W;

        if (P2 == 1 && Q == 1 && W == 1) {
            // single output cell per channel: reduce one column segment
            final int window = Math.min(R, H);
            for (int row = rl; row < ru; row++, oi += C) {
                int src = ii + (row - rl) * imageLen;
                for (int ch = 0; ch < C; ch++, src += H) {
                    out[oi + ch] = useMax
                        ? LibMatrixDNNPooling.max(minVal, in, src, window)
                        : LibMatrixDNNPooling.avg(minVal, in, src, window, pFact);
                }
            }
            return;
        }

        final int outLen = C * P2 * Q;
        final int planeLen = H * W;
        Arrays.fill(out, rl * outLen, ru * outLen, minVal);
        for (int row = rl; row < ru; row++) {
            int src = ii + (row - rl) * imageLen;
            int dst = oi + (row - rl) * outLen;
            for (int ch = 0; ch < C; ch++, src += planeLen) {
                for (int p = 0; p < P2; p++, dst += Q) {
                    final int hEnd = Math.min(p + R, H);
                    for (int h = p; h < hEnd; h++) {
                        final int lineStart = src + h * W;
                        for (int q = 0; q < Q; q++) {
                            // the window is clipped at the right border of the line
                            final int len = Math.min(S, W - q);
                            out[dst + q] = useMax
                                ? LibMatrixDNNPooling.max(out[dst + q], in, lineStart + q, len)
                                : LibMatrixDNNPooling.avg(out[dst + q], in, lineStart + q, len, pFact);
                        }
                    }
                }
            }
        }
    }

    /**
     * Adds the scaled sum of {@code b[bi .. bi+len)} to a running value.
     *
     * @param aval              running value to add to
     * @param b                 source values
     * @param bi                start offset in {@code b}
     * @param len               number of values to sum
     * @param poolingMultiplier factor applied to the sum
     * @return {@code sum * poolingMultiplier + aval}
     */
    private static double avg(double aval, double[] b, int bi, int len, double poolingMultiplier) {
        final double windowSum = LibSpoofPrimitives.vectSum(b, bi, len);
        return windowSum * poolingMultiplier + aval;
    }

    /**
     * Returns the maximum of a running value and {@code b[bi .. bi+len)}.
     *
     * @param aval running value; returned unchanged when {@code len <= 0}
     * @param b    source values
     * @param bi   start offset in {@code b}
     * @param len  number of values to inspect
     * @return the largest value seen, combined via {@link Math#max(double, double)}
     */
    private static double max(double aval, double[] b, int bi, int len) {
        double best = aval;
        int pos = bi;
        while (pos < bi + len) {
            best = Math.max(best, b[pos]);
            pos++;
        }
        return best;
    }

    /**
     * Finds the position of the largest input value inside the pooling window
     * of output cell {@code (p, q)}.
     *
     * <p>The window bounds come from {@code params.start_indexes_*} and
     * {@code params.end_indexes_*}. A value only becomes the maximum if it is
     * strictly greater than the current best, which starts at {@code 0.0} when
     * {@code performReluBackward} is set and at negative infinity otherwise;
     * on ties the first position scanned wins.
     *
     * @param p                   output row
     * @param q                   output column
     * @param inputOffset         offset in {@code inputArray} of the channel plane being scanned
     * @param inputArray          dense input values
     * @param params              pooling parameters (window bounds and {@code W})
     * @param performReluBackward if true, only values greater than zero can be selected
     * @return absolute index into {@code inputArray} of the maximum, or {@code -1} if no value qualified
     */
    private static int getMaxIndex(int p, int q, int inputOffset, double[] inputArray, DnnParameters params, boolean performReluBackward) {
        final int hBegin = params.start_indexes_h[p];
        final int hEnd = params.end_indexes_h[p];
        final int wBegin = params.start_indexes_w[q];
        final int wEnd = params.end_indexes_w[q];

        double best = performReluBackward ? 0.0 : Double.NEGATIVE_INFINITY;
        int bestIdx = -1;
        for (int h = hBegin; h < hEnd; h++) {
            final int lineBase = inputOffset + h * params.W;
            for (int w = wBegin; w < wEnd; w++) {
                final double candidate = inputArray[lineBase + w];
                if (best < candidate) {
                    bestIdx = lineBase + w;
                    best = candidate;
                }
            }
        }
        if (performReluBackward && best == 0.0) {
            return -1;
        }
        return bestIdx;
    }

    /**
     * Sorts the first {@code size} (index, value) pairs by index and appends
     * them to a sparse row in that order.
     *
     * @param row  target sparse row
     * @param i    column indexes; reordered in place
     * @param v    values matching {@code i}; reordered in place
     * @param size number of leading entries of {@code i}/{@code v} to use
     */
    private static void add(SparseRow row, int[] i, double[] v, int size) {
        LibMatrixDNNPooling.sort(i, v, size);
        int pos = 0;
        while (pos < size) {
            row.append(i[pos], v[pos]);
            pos++;
        }
    }

    /**
     * Sorts the first {@code size} entries of {@code i} in ascending order and
     * applies the same permutation to {@code v}.
     *
     * <p>Sizes 2 to 7 go to fixed compare-and-swap sequences, size 1 needs no
     * work, and every other size goes to insertion sort. A warning is logged
     * for sizes above 32.
     *
     * @param i    keys to sort, modified in place
     * @param v    values kept aligned with {@code i}, modified in place
     * @param size number of leading entries to sort
     */
    private static void sort(int[] i, double[] v, int size) {
        if (size > 32) {
            LOG.warn("Not a optimal size for small array sort " + size);
        }
        if (size == 1) {
            return;
        }
        if (size == 2) {
            LibMatrixDNNPooling.comp(i, v, 0, 1);
        } else if (size == 3) {
            LibMatrixDNNPooling.sort3(i, v);
        } else if (size == 4) {
            LibMatrixDNNPooling.sort4(i, v);
        } else if (size == 5) {
            LibMatrixDNNPooling.sort5(i, v);
        } else if (size == 6) {
            LibMatrixDNNPooling.sort6(i, v);
        } else if (size == 7) {
            LibMatrixDNNPooling.sort7(i, v);
        } else {
            LibMatrixDNNPooling.insertSort(i, v, size);
        }
    }

    /**
     * Sorts the first three entries of {@code i} ascending with a fixed sequence
     * of three compare-and-swap steps, moving {@code v} in lockstep.
     * The order of the steps matters and must not be changed.
     *
     * @param i keys, modified in place
     * @param v values aligned with {@code i}, modified in place
     */
    private static void sort3(int[] i, double[] v) {
        // the first two steps bring the smallest key to slot 0
        LibMatrixDNNPooling.comp(i, v, 0, 2);
        LibMatrixDNNPooling.comp(i, v, 0, 1);
        // order the remaining two
        LibMatrixDNNPooling.comp(i, v, 1, 2);
    }

    /**
     * Sorts the first four entries of {@code i} ascending with a fixed sequence
     * of five compare-and-swap steps, moving {@code v} in lockstep.
     * The order of the steps matters and must not be changed.
     *
     * @param i keys, modified in place
     * @param v values aligned with {@code i}, modified in place
     */
    private static void sort4(int[] i, double[] v) {
        LibMatrixDNNPooling.comp(i, v, 0, 2);
        LibMatrixDNNPooling.comp(i, v, 1, 3);
        // after these two steps slot 0 holds the smallest and slot 3 the largest key
        LibMatrixDNNPooling.comp(i, v, 0, 1);
        LibMatrixDNNPooling.comp(i, v, 2, 3);
        // order the two middle keys
        LibMatrixDNNPooling.comp(i, v, 1, 2);
    }

    /**
     * Sorts the first five entries of {@code i} ascending with a fixed sequence
     * of nine compare-and-swap steps, moving {@code v} in lockstep.
     * The order of the steps matters and must not be changed.
     *
     * @param i keys, modified in place
     * @param v values aligned with {@code i}, modified in place
     */
    private static void sort5(int[] i, double[] v) {
        LibMatrixDNNPooling.comp(i, v, 0, 1);
        LibMatrixDNNPooling.comp(i, v, 2, 3);
        LibMatrixDNNPooling.comp(i, v, 1, 3);
        LibMatrixDNNPooling.comp(i, v, 2, 4);
        LibMatrixDNNPooling.comp(i, v, 1, 4);
        // slot 2 now holds the smallest of the original slots 2..4, so this step puts the overall minimum in slot 0
        LibMatrixDNNPooling.comp(i, v, 0, 2);
        LibMatrixDNNPooling.comp(i, v, 1, 2);
        // puts the overall maximum in slot 4
        LibMatrixDNNPooling.comp(i, v, 3, 4);
        LibMatrixDNNPooling.comp(i, v, 2, 3);
    }

    /**
     * Sorts the first six entries of {@code i} ascending with a fixed
     * 12-step compare-and-swap network, moving {@code v} in lockstep.
     *
     * <p>The network first sorts slots 0..2 and slots 3..5 independently and
     * then merges the two sorted triples. The previous sequence repeated
     * {@code comp(1,3)} and never compared slot 0 with slots 2 or 3, so a
     * minimum starting there was left out of place (e.g. keys
     * {@code [1,2,0,3,4,5]} came out as {@code [1,0,2,3,4,5]}).
     *
     * @param i keys, modified in place
     * @param v values aligned with {@code i}, modified in place
     */
    private static void sort6(int[] i, double[] v) {
        // sort the first triple (slots 0..2)
        LibMatrixDNNPooling.comp(i, v, 1, 2);
        LibMatrixDNNPooling.comp(i, v, 0, 2);
        LibMatrixDNNPooling.comp(i, v, 0, 1);
        // sort the second triple (slots 3..5)
        LibMatrixDNNPooling.comp(i, v, 4, 5);
        LibMatrixDNNPooling.comp(i, v, 3, 5);
        LibMatrixDNNPooling.comp(i, v, 3, 4);
        // merge: overall maximum to slot 5, overall minimum to slot 0
        LibMatrixDNNPooling.comp(i, v, 2, 5);
        LibMatrixDNNPooling.comp(i, v, 0, 3);
        LibMatrixDNNPooling.comp(i, v, 1, 4);
        // order the four middle slots
        LibMatrixDNNPooling.comp(i, v, 2, 4);
        LibMatrixDNNPooling.comp(i, v, 1, 3);
        LibMatrixDNNPooling.comp(i, v, 2, 3);
    }

    /**
     * Sorts the first seven entries of {@code i} ascending with a fixed sequence
     * of sixteen compare-and-swap steps, moving {@code v} in lockstep.
     * The order of the steps matters and must not be changed.
     *
     * @param i keys, modified in place
     * @param v values aligned with {@code i}, modified in place
     */
    private static void sort7(int[] i, double[] v) {
        LibMatrixDNNPooling.comp(i, v, 0, 1);
        LibMatrixDNNPooling.comp(i, v, 2, 3);
        LibMatrixDNNPooling.comp(i, v, 4, 5);
        LibMatrixDNNPooling.comp(i, v, 0, 6);
        LibMatrixDNNPooling.comp(i, v, 2, 4);
        // slot 2 holds the smallest of the original slots 2..5 here, so this step fixes the overall minimum in slot 0
        LibMatrixDNNPooling.comp(i, v, 0, 2);
        LibMatrixDNNPooling.comp(i, v, 1, 3);
        LibMatrixDNNPooling.comp(i, v, 5, 6);
        LibMatrixDNNPooling.comp(i, v, 1, 4);
        LibMatrixDNNPooling.comp(i, v, 2, 5);
        LibMatrixDNNPooling.comp(i, v, 1, 2);
        LibMatrixDNNPooling.comp(i, v, 4, 5);
        LibMatrixDNNPooling.comp(i, v, 2, 4);
        // fixes the overall maximum in slot 6
        LibMatrixDNNPooling.comp(i, v, 3, 6);
        LibMatrixDNNPooling.comp(i, v, 3, 5);
        LibMatrixDNNPooling.comp(i, v, 3, 4);
    }

    /**
     * Insertion sort over the first {@code size} entries of {@code i}
     * (ascending), applying the same moves to {@code v}. Entries with equal
     * keys keep their relative order.
     *
     * @param i    keys, modified in place
     * @param v    values aligned with {@code i}, modified in place
     * @param size number of leading entries to sort; values below 2 are a no-op
     */
    private static void insertSort(int[] i, double[] v, int size) {
        for (int p = 1; p < size; ++p) {
            final int key = i[p];
            final double keyValue = v[p];
            // j is declared outside the shifting loop because the insertion
            // position j + 1 is needed after the loop has finished
            int j = p - 1;
            while (j >= 0 && i[j] > key) {
                i[j + 1] = i[j];
                v[j + 1] = v[j];
                --j;
            }
            i[j + 1] = key;
            v[j + 1] = keyValue;
        }
    }

    /**
     * Compare-and-swap step: exchanges entries {@code f} and {@code t} of both
     * arrays when the key at {@code f} is greater than the key at {@code t}.
     *
     * @param i keys
     * @param v values aligned with {@code i}
     * @param f first position
     * @param t second position
     */
    private static void comp(int[] i, double[] v, int f, int t) {
        if (i[f] <= i[t]) {
            return;
        }
        LibMatrixDNNPooling.swap(i, v, f, t);
    }

    /**
     * Exchanges positions {@code f} and {@code t} in both arrays.
     *
     * @param i keys
     * @param v values aligned with {@code i}
     * @param f first position
     * @param t second position
     */
    private static void swap(int[] i, double[] v, int f, int t) {
        final int keyAtF = i[f];
        final double valueAtF = v[f];
        final int keyAtT = i[t];
        final double valueAtT = v[t];
        i[f] = keyAtT;
        v[f] = valueAtT;
        i[t] = keyAtF;
        v[t] = valueAtF;
    }

    /**
     * Max-pooling backward task for the case where both {@code input1} and
     * {@code input2} are in sparse format. The forward pass over
     * {@code input1} is inherited; only the two scatter steps are replaced so
     * that they read {@code input2} through its sparse block.
     */
    private static class PoolingBackwardSparseSparse extends PoolingBackwardSparseDense {
        /**
         * @param rl     first row (inclusive)
         * @param ru     last row (exclusive)
         * @param params pooling parameters; {@code input1} and {@code input2} must both be sparse
         * @param relu   forwarded to the parent task
         * @throws RuntimeException if either input is not in sparse format
         */
        public PoolingBackwardSparseSparse(int rl, int ru, DnnParameters params, boolean relu) {
            super(rl, ru, params, relu, params.input2, params.output);
            final boolean bothSparse = params.input1.isInSparseFormat() && params.input2.isInSparseFormat();
            if (!bothSparse) {
                throw new RuntimeException("Incorrect usage: Call optimized versions");
            }
        }

        /**
         * Adds the non-zero values of row {@code n}, channel {@code c} of the
         * sparse {@code doutput} to the dense output, each at the position
         * recorded in {@code maxIx} for its output cell.
         */
        @Override
        protected void maxpoolingBackwardDense(int[] maxIx, int outOffset, int n, int c, int C, int Q, int PQ, int CPQ) {
            final SparseBlock grad = this.doutput.getSparseBlock();
            final double[] target = this.output.getDenseBlockValues();
            if (grad.isEmpty(n)) {
                return;
            }
            final int rowStart = grad.pos(n);
            final int rowLen = grad.size(n);
            final int[] cols = grad.indexes(n);
            final double[] vals = grad.values(n);
            // positions within the row that delimit channel c; a negative lookup result is mapped to the row end
            int from = (c == 0) ? 0 : grad.posFIndexGTE(n, c * PQ);
            int to = (c + 1 == C) ? rowLen : grad.posFIndexGTE(n, (c + 1) * PQ);
            if (from < 0) {
                from = rowLen;
            }
            if (to < 0) {
                to = rowLen;
            }
            for (int k = rowStart + from; k < rowStart + to; k++) {
                final int p = cols[k] % PQ / Q;
                final int q = cols[k] % Q;
                target[outOffset + maxIx[p * Q + q]] += vals[k];
            }
        }

        /**
         * Same as the dense variant, but adds the values to sparse output
         * row {@code n} at column {@code maxIx[cell] + offset}.
         */
        @Override
        protected void maxpoolingBackwardSparse(int[] maxIx, int offset, int n, int c, int C, int Q, int P2, int CPQ) {
            final SparseBlock grad = this.doutput.getSparseBlock();
            if (grad.isEmpty(n)) {
                return;
            }
            final int PQ = P2 * Q;
            final SparseBlock target = this.output.getSparseBlock();
            target.allocate(n, PQ);
            final SparseRow targetRow = target.get(n);
            final int rowStart = grad.pos(n);
            final int rowLen = grad.size(n);
            final int[] cols = grad.indexes(n);
            final double[] vals = grad.values(n);
            int from = (c == 0) ? 0 : grad.posFIndexGTE(n, c * PQ);
            int to = (c + 1 == C) ? rowLen : grad.posFIndexGTE(n, (c + 1) * PQ);
            if (from < 0) {
                from = rowLen;
            }
            if (to < 0) {
                to = rowLen;
            }
            for (int k = rowStart + from; k < rowStart + to; k++) {
                final int p = cols[k] % PQ / Q;
                final int q = cols[k] % Q;
                targetRow.add(maxIx[p * Q + q] + offset, vals[k]);
            }
        }
    }

    /**
     * Max-pooling backward task for a sparse {@code input1} over rows
     * {@code [rl, ru)}. For every row and channel it first recomputes, per
     * output cell, the position of the window maximum from the sparse input
     * (treating cells not stored in the sparse row as zeros), then scatters
     * the matching values of {@code doutput} to those positions in
     * {@code output}.
     */
    private static class PoolingBackwardSparseDense implements Callable<Long> {
        private final int _rl;
        private final int _ru;
        private final DnnParameters _params;
        private final boolean reluBack;
        protected final MatrixBlock doutput;
        protected final MatrixBlock output;

        /** Plain field initialisation without any validation; used by subclasses. */
        protected PoolingBackwardSparseDense(int rl, int ru, DnnParameters params, boolean relu, MatrixBlock dout, MatrixBlock out) {
            _rl = rl;
            _ru = ru;
            _params = params;
            reluBack = relu;
            doutput = dout;
            output = out;
        }

        /**
         * @param rl     first row (inclusive)
         * @param ru     last row (exclusive)
         * @param params pooling parameters; {@code input2} must have a dense block and {@code input1} must be sparse
         * @param relu   if true, negative input values are treated as zero in the forward pass
         * @throws RuntimeException if one of the format expectations is violated
         */
        public PoolingBackwardSparseDense(int rl, int ru, DnnParameters params, boolean relu) {
            this(rl, ru, params, relu, params.input2, params.output);
            if (doutput.getDenseBlock() == null) {
                throw new RuntimeException("Incorrect usage: empty inputs");
            }
            if (!params.input1.isInSparseFormat()) {
                throw new RuntimeException("Incorrect usage: sparse input1 expected");
            }
        }

        /**
         * Processes all (row, channel) pairs of this task.
         *
         * @return {@code P * Q * C} multiplied by the number of rows of this task
         */
        @Override
        public Long call() throws Exception {
            final DnnParameters prm = _params;
            final int P2 = prm.P;
            final int Q = prm.Q;
            final int W = prm.W;
            final int C = prm.C;
            final int R = prm.R;
            final int S = prm.S;
            final int padh = prm.pad_h;
            final int padw = prm.pad_w;
            final int strideh = prm.stride_h;
            final int stridew = prm.stride_w;
            final int PQ = prm.P * prm.Q;
            final int CPQ = prm.C * prm.P * prm.Q;
            final int HW = prm.H * prm.W;
            final int CHW = prm.C * prm.H * prm.W;
            // scratch buffers reused for every (row, channel) pair
            final double[] maxVal = new double[PQ];
            final int[] maxIx = new int[PQ];
            for (int n = _rl; n < _ru; n++) {
                for (int c = 0; c < C; c++) {
                    final boolean emptyRow = maxpoolingForward(maxVal, maxIx, n, c, padh, padw, strideh, stridew, C, P2, Q, R, S, HW, W);
                    if (emptyRow) {
                        continue;
                    }
                    if (output.isInSparseFormat()) {
                        maxpoolingBackwardSparse(maxIx, c * HW, n, c, C, Q, P2, CPQ);
                    } else {
                        maxpoolingBackwardDense(maxIx, n * CHW + c * HW, n, c, C, Q, PQ, CPQ);
                    }
                }
            }
            return (long)(P2 * Q * C) * (long)(_ru - _rl);
        }

        /**
         * Fills {@code maxVal}/{@code maxIx} for row {@code n}, channel
         * {@code c} of the sparse input. Cells between stored entries are
         * fed in as zeros so that every cell of the channel is considered.
         *
         * @return {@code true} if the input row is empty (nothing was computed), {@code false} otherwise
         */
        protected boolean maxpoolingForward(double[] maxVal, int[] maxIx, int n, int c, int padh, int padw, int strideh, int stridew, int C, int P2, int Q, int R, int S, int HW, int W) {
            final SparseBlock in = _params.input1.getSparseBlock();
            if (in.isEmpty(n)) {
                return true;
            }
            Arrays.fill(maxVal, -Double.MAX_VALUE);
            final int rowStart = in.pos(n);
            final int rowLen = in.size(n);
            final int[] cols = in.indexes(n);
            final double[] vals = in.values(n);
            // positions within the row that delimit channel c; a negative lookup result is mapped to the row end
            int from = (c == 0) ? 0 : in.posFIndexGTE(n, c * HW);
            int to = (c + 1 == C) ? rowLen : in.posFIndexGTE(n, (c + 1) * HW);
            if (from < 0) {
                from = rowLen;
            }
            if (to < 0) {
                to = rowLen;
            }
            int prevCol = c * HW - 1;
            for (int k = rowStart + from; k < rowStart + to; k++) {
                final int col = cols[k];
                // zeros for the gap between the previous stored cell and this one
                update0(prevCol + 1, col, maxVal, maxIx, padh, padw, strideh, stridew, P2, Q, R, S, HW, W);
                double cellValue = vals[k];
                if (reluBack && cellValue < 0.0) {
                    cellValue = 0.0;
                }
                update(cellValue, maxVal, maxIx, col % HW / W, col % W, padh, padw, strideh, stridew, P2, Q, R, S, W);
                prevCol = col;
            }
            // zeros for the tail of the channel
            update0(prevCol + 1, (c + 1) * HW, maxVal, maxIx, padh, padw, strideh, stridew, P2, Q, R, S, HW, W);
            return false;
        }

        /**
         * Adds the dense {@code doutput} values of row {@code n}, channel
         * {@code c} to the dense output at {@code outOffset + maxIx[cell]}.
         */
        protected void maxpoolingBackwardDense(int[] maxIx, int outOffset, int n, int c, int C, int Q, int PQ, int CPQ) {
            final double[] grad = doutput.getDenseBlockValues();
            final double[] target = output.getDenseBlockValues();
            final int gradBase = n * CPQ + c * PQ;
            for (int cell = 0; cell < PQ; cell++) {
                target[outOffset + maxIx[cell]] += grad[gradBase + cell];
            }
        }

        /**
         * Adds the dense {@code doutput} values of row {@code n}, channel
         * {@code c} to sparse output row {@code n} at column
         * {@code maxIx[cell] + offset}.
         */
        protected void maxpoolingBackwardSparse(int[] maxIx, int offset, int n, int c, int C, int Q, int P2, int CPQ) {
            final double[] grad = doutput.getDenseBlockValues();
            final SparseBlock target = output.getSparseBlock();
            target.allocate(n, P2 * Q);
            final SparseRow targetRow = target.get(n);
            final int gradBase = n * CPQ + c * P2 * Q;
            for (int p = 0, cell = 0; p < P2; p++) {
                for (int q = 0; q < Q; q++, cell++) {
                    targetRow.add(maxIx[cell] + offset, grad[gradBase + cell]);
                }
            }
        }

        /** Feeds a zero value for every cell index in {@code [lix, uix)}. */
        private static void update0(int lix, int uix, double[] maxVal, int[] maxIx, int padh, int padw, int strideh, int stridew, int P2, int Q, int R, int S, int HW, int W) {
            int cellIx = lix;
            while (cellIx < uix) {
                update(0.0, maxVal, maxIx, cellIx % HW / W, cellIx % W, padh, padw, strideh, stridew, P2, Q, R, S, W);
                cellIx++;
            }
        }

        /**
         * Offers the value at input cell {@code (h, w)} to every output cell in
         * the range derived from the pad/stride/window parameters; a cell takes
         * it only if it is strictly greater than its current maximum.
         */
        private static void update(double val, double[] maxVal, int[] maxIx, int h, int w, int padh, int padw, int strideh, int stridew, int P2, int Q, int R, int S, int W) {
            final int pLo = Math.max((h + padh - R + strideh) / strideh, 0);
            final int pHi = Math.min((h + padh + strideh) / strideh, P2);
            final int qLo = Math.max((w + padw - S + stridew) / stridew, 0);
            final int qHi = Math.min((w + padw + stridew) / stridew, Q);
            final int cellInPlane = h * W + w;
            for (int p = pLo; p < pHi; p++) {
                final int lineBase = p * Q;
                for (int q = qLo; q < qHi; q++) {
                    if (maxVal[lineBase + q] < val) {
                        maxVal[lineBase + q] = val;
                        maxIx[lineBase + q] = cellInPlane;
                    }
                }
            }
        }
    }

    /**
     * Average-pooling backward task for a sparse {@code input2} and a dense
     * output, over rows {@code [rl, ru)}. Every stored value is multiplied by
     * {@code (R*S)^-1} and added to each input position of its pooling window.
     */
    private static class AvgPoolingBackwardSparse implements Callable<Long> {
        public int _rl;
        public int _ru;
        private final DnnParameters _params;
        MatrixBlock output;
        MatrixBlock dout;
        int CHW;
        int P;
        int Q;
        int HW;
        final double _poolingMultiplier;

        /**
         * @param rl     first row (inclusive)
         * @param ru     last row (exclusive)
         * @param params pooling parameters; {@code output} must have a dense block
         * @throws RuntimeException if the output has no dense block
         */
        public AvgPoolingBackwardSparse(int rl, int ru, DnnParameters params) {
            _rl = rl;
            _ru = ru;
            _params = params;
            dout = params.input2;
            output = params.output;
            CHW = params.C * params.H * params.W;
            HW = params.H * params.W;
            P = params.P;
            Q = params.Q;
            _poolingMultiplier = Math.pow(params.R * params.S, -1.0);
            if (output.getDenseBlock() == null) {
                throw new RuntimeException("Incorrect usage: empty inputs");
            }
        }

        /**
         * Scatters the scaled values of all rows of this task into the output.
         *
         * @return the result of {@code output.recomputeNonZeros(_rl, _ru - 1)}
         */
        @Override
        public Long call() throws Exception {
            LibMatrixDNNHelper.CellIndex3 cell = new LibMatrixDNNHelper.CellIndex3();
            final double[] target = output.getDenseBlockValues();
            final SparseBlock grad = dout.sparseBlock;
            for (int n = _rl; n < _ru; n++) {
                if (grad.isEmpty(n)) {
                    continue;
                }
                final int rowStart = grad.pos(n);
                final int rowLen = grad.size(n);
                final int[] cols = grad.indexes(n);
                final double[] vals = grad.values(n);
                for (int k = rowStart; k < rowStart + rowLen; k++) {
                    // split the flat column index into its three components
                    cell = LibMatrixDNNHelper.computeTensorIndexes(cols[k], P, Q, cell);
                    final int c = cell.ix1;
                    final int p = cell.ix2;
                    final int q = cell.ix3;
                    final int planeBase = n * CHW + c * HW;
                    final int hBegin = _params.start_indexes_h[p];
                    final int hEnd = _params.end_indexes_h[p];
                    final int wBegin = _params.start_indexes_w[q];
                    final int wEnd = _params.end_indexes_w[q];
                    final double share = _poolingMultiplier * vals[k];
                    for (int h = hBegin; h < hEnd; h++) {
                        for (int w = wBegin; w < wEnd; w++) {
                            target[planeBase + h * _params.W + w] += share;
                        }
                    }
                }
            }
            return output.recomputeNonZeros(_rl, _ru - 1);
        }
    }

    /**
     * Max-pooling backward task for a dense {@code input1} and a sparse
     * {@code input2}, over rows {@code [rl, ru)}. For every stored value of
     * {@code input2} the window maximum of the matching output cell is located
     * in the dense input and the value is added at that position of the output.
     */
    private static class PoolingBackwardDenseSparse
    implements Callable<Long> {
        public int _rl;
        public int _ru;
        private final DnnParameters _params;
        MatrixBlock output;
        boolean performReluBackward;
        double[] inputArray;
        MatrixBlock dout;
        final int CHW;
        final int P;
        final int Q;
        final int HW;
        final int C;

        /**
         * @param rl                  first row (inclusive)
         * @param ru                  last row (exclusive)
         * @param params              pooling parameters; {@code input1} must be dense and non-empty, {@code input2} sparse
         * @param performReluBackward forwarded to the window-maximum search
         * @throws RuntimeException if {@code input1} has no dense values or {@code input2} is not sparse
         */
        public PoolingBackwardDenseSparse(int rl, int ru, DnnParameters params, boolean performReluBackward) {
            this._rl = rl;
            this._ru = ru;
            this._params = params;
            this.performReluBackward = performReluBackward;
            this.inputArray = params.input1.getDenseBlockValues();
            this.dout = params.input2;
            this.output = params.output;
            this.C = params.C;
            this.CHW = params.C * params.H * params.W;
            this.HW = params.H * params.W;
            this.P = params.P;
            this.Q = params.Q;
            if (this.inputArray == null) {
                throw new RuntimeException("Incorrect usage: empty inputs");
            }
            if (!params.input2.isInSparseFormat()) {
                throw new RuntimeException("Incorrect usage: Call optimized versions");
            }
        }

        /**
         * Processes all rows of this task, writing either to the sparse or to
         * the dense representation of the output.
         *
         * @return {@code P * Q * C} multiplied by the number of rows of this task
         */
        @Override
        public Long call() throws Exception {
            SparseBlock sblock = this.dout.sparseBlock;
            if (this.output.isInSparseFormat()) {
                SparseBlock out = this.output.getSparseBlock();
                // Staging buffers for one (channel, p) group. A group has one slot per q,
                // so Q slots suffice as long as the column indexes within a sparse row
                // are unique (assumed here).
                int[] i = new int[this.Q];
                double[] v = new double[this.Q];
                for (int n = this._rl; n < this._ru; ++n) {
                    if (sblock.isEmpty(n)) continue;
                    out.allocate(n, this.P * this.Q * this.C);
                    SparseRow elm = out.get(n);
                    int apos = sblock.pos(n);
                    int alen = sblock.size(n);
                    int[] aix = sblock.indexes(n);
                    double[] avals = sblock.values(n);
                    int oldGroup = 0;
                    int pointer = 0;
                    int nCHW = n * this.CHW;
                    for (int j = apos; j < apos + alen; ++j) {
                        int maxIndex;
                        // group == c * P + p identifies the (channel, p) pair of this entry
                        int group = aix[j] / this.Q;
                        int inputOffset = nCHW + group / this.P * this.HW;
                        int p = group % this.P;
                        int q = aix[j] % this.Q;
                        // Flush whenever the (channel, p) pair changes. Comparing p alone would
                        // merge runs with equal p from different channels (always the case for
                        // P == 1) and could overflow the Q-sized buffers.
                        if (group != oldGroup) {
                            LibMatrixDNNPooling.add(elm, i, v, pointer);
                            oldGroup = group;
                            pointer = 0;
                        }
                        if ((maxIndex = LibMatrixDNNPooling.getMaxIndex(p, q, inputOffset, this.inputArray, this._params, this.performReluBackward)) == -1) continue;
                        // column index relative to the start of row n
                        i[pointer] = maxIndex - nCHW;
                        v[pointer] = avals[j];
                        ++pointer;
                    }
                    LibMatrixDNNPooling.add(elm, i, v, pointer);
                }
            } else {
                LibMatrixDNNHelper.CellIndex3 ix = new LibMatrixDNNHelper.CellIndex3();
                double[] out = this.output.getDenseBlockValues();
                for (int n = this._rl; n < this._ru; ++n) {
                    if (sblock.isEmpty(n)) continue;
                    int apos = sblock.pos(n);
                    int alen = sblock.size(n);
                    int[] aix = sblock.indexes(n);
                    double[] avals = sblock.values(n);
                    for (int j = apos; j < apos + alen; ++j) {
                        // split the flat column index into its three components
                        ix = LibMatrixDNNHelper.computeTensorIndexes(aix[j], this.P, this.Q, ix);
                        int inputOffset = n * this.CHW + ix.ix1 * this.HW;
                        int maxIndex = LibMatrixDNNPooling.getMaxIndex(ix.ix2, ix.ix3, inputOffset, this.inputArray, this._params, this.performReluBackward);
                        if (maxIndex == -1) continue;
                        out[maxIndex] += avals[j];
                    }
                }
            }
            return (long)(this.P * this.Q * this.C) * (long)(this._ru - this._rl);
        }
    }

    /**
     * Max-pooling backward task for dense {@code input1} and dense
     * {@code input2}, over rows {@code [rl, ru)}. For every output cell the
     * window maximum is located in the input and the matching {@code input2}
     * value is added at that position of the output.
     */
    private static class PoolingBackwardDenseDense implements Callable<Long> {
        public int _rl;
        public int _ru;
        private final DnnParameters _params;
        boolean performReluBackward;
        double[] inputArray;
        double[] doutArray;
        MatrixBlock output;
        int C;
        int CHW;
        int P;
        int Q;
        int HW;
        int CPQ;
        int PQ;

        /**
         * @param rl                  first row (inclusive)
         * @param ru                  last row (exclusive)
         * @param params              pooling parameters; both inputs must provide dense values
         * @param performReluBackward forwarded to the window-maximum search
         * @throws RuntimeException if either input has no dense values
         */
        public PoolingBackwardDenseDense(int rl, int ru, DnnParameters params, boolean performReluBackward) {
            _rl = rl;
            _ru = ru;
            _params = params;
            this.performReluBackward = performReluBackward;
            inputArray = params.input1.getDenseBlockValues();
            doutArray = params.input2.getDenseBlockValues();
            output = params.output;
            C = params.C;
            CHW = params.C * params.H * params.W;
            HW = params.H * params.W;
            P = params.P;
            Q = params.Q;
            CPQ = params.C * params.P * params.Q;
            PQ = params.P * params.Q;
            if (inputArray == null || doutArray == null) {
                throw new RuntimeException("Incorrect usage: empty inputs");
            }
        }

        /**
         * Processes all rows of this task, writing either to the sparse or to
         * the dense representation of the output.
         *
         * @return {@code P * Q * C} multiplied by the number of rows of this task
         */
        @Override
        public Long call() throws Exception {
            if (output.isInSparseFormat()) {
                scatterToSparse(output.getSparseBlock());
            } else {
                scatterToDense(output.getDenseBlockValues());
            }
            return (long)(P * Q * C) * (long)(_ru - _rl);
        }

        /** Sparse-output path: stages the entries of one (channel, p) line, then appends them sorted by column. */
        private void scatterToSparse(SparseBlock target) {
            final int[] colBuf = new int[Q];
            final double[] valBuf = new double[Q];
            for (int n = _rl; n < _ru; n++) {
                target.allocate(n, P * Q * C);
                final SparseRow targetRow = target.get(n);
                final int rowBase = n * CHW;
                for (int c = 0; c < C; c++) {
                    final int planeBase = rowBase + c * HW;
                    final int gradBase = n * CPQ + c * PQ;
                    for (int p = 0; p < P; p++) {
                        int filled = 0;
                        for (int q = 0; q < Q; q++) {
                            final int maxIndex = LibMatrixDNNPooling.getMaxIndex(p, q, planeBase, inputArray, _params, performReluBackward);
                            if (maxIndex != -1) {
                                // column index relative to the start of row n
                                colBuf[filled] = maxIndex - rowBase;
                                valBuf[filled] = doutArray[gradBase + p * Q + q];
                                filled++;
                            }
                        }
                        LibMatrixDNNPooling.add(targetRow, colBuf, valBuf, filled);
                    }
                }
            }
        }

        /** Dense-output path: adds every value directly at the position of its window maximum. */
        private void scatterToDense(double[] target) {
            for (int n = _rl; n < _ru; n++) {
                for (int c = 0; c < C; c++) {
                    final int planeBase = n * CHW + c * HW;
                    final int gradBase = n * CPQ + c * PQ;
                    for (int p = 0; p < P; p++) {
                        for (int q = 0; q < Q; q++) {
                            final int maxIndex = LibMatrixDNNPooling.getMaxIndex(p, q, planeBase, inputArray, _params, performReluBackward);
                            if (maxIndex != -1) {
                                target[maxIndex] += doutArray[gradBase + p * Q + q];
                            }
                        }
                    }
                }
            }
        }
    }

    /**
     * Average-pooling backward task for a dense {@code input2} and a dense
     * output, over rows {@code [rl, ru)}. Every {@code input2} value is
     * multiplied by {@code (R*S)^-1} and added to each input position of its
     * pooling window.
     */
    private static class AvgPoolingBackwardDense implements Callable<Long> {
        public int _rl;
        public int _ru;
        private final DnnParameters _params;
        double[] doutArray;
        MatrixBlock output;
        final int C;
        final int CHW;
        final int P;
        final int Q;
        final int HW;
        final int CPQ;
        final int PQ;
        final double _poolingMultiplier;

        /**
         * @param rl     first row (inclusive)
         * @param ru     last row (exclusive)
         * @param params pooling parameters; {@code input2} and {@code output} must both be dense and allocated
         * @throws RuntimeException if {@code input2} has no dense values or the output has no dense block
         */
        public AvgPoolingBackwardDense(int rl, int ru, DnnParameters params) {
            _rl = rl;
            _ru = ru;
            _params = params;
            doutArray = params.input2.getDenseBlockValues();
            output = params.output;
            C = params.C;
            CHW = params.C * params.H * params.W;
            HW = params.H * params.W;
            P = params.P;
            Q = params.Q;
            CPQ = params.C * params.P * params.Q;
            PQ = params.P * params.Q;
            _poolingMultiplier = Math.pow(params.R * params.S, -1.0);
            if (doutArray == null || output.getDenseBlock() == null) {
                throw new RuntimeException("Incorrect usage: empty inputs");
            }
        }

        /**
         * Scatters the scaled values of all rows of this task into the output.
         *
         * @return the result of {@code output.recomputeNonZeros(_rl, _ru - 1)}
         */
        @Override
        public Long call() throws Exception {
            final double[] target = output.getDenseBlockValues();
            for (int n = _rl; n < _ru; n++) {
                for (int c = 0; c < C; c++) {
                    final int planeBase = n * CHW + c * HW;
                    final int gradBase = n * CPQ + c * PQ;
                    for (int p = 0; p < P; p++) {
                        for (int q = 0; q < Q; q++) {
                            final int hBegin = _params.start_indexes_h[p];
                            final int hEnd = _params.end_indexes_h[p];
                            final int wBegin = _params.start_indexes_w[q];
                            final int wEnd = _params.end_indexes_w[q];
                            final double share = _poolingMultiplier * doutArray[gradBase + p * Q + q];
                            for (int h = hBegin; h < hEnd; h++) {
                                for (int w = wBegin; w < wEnd; w++) {
                                    target[planeBase + h * _params.W + w] += share;
                                }
                            }
                        }
                    }
                }
            }
            return output.recomputeNonZeros(_rl, _ru - 1);
        }
    }

    /**
     * Task that computes forward pooling (max or average) for the batch rows
     * {@code [_rl, _ru)} when {@code params.input1} is accessed through its
     * sparse block. Results are written into the dense value array of
     * {@code params.output}.
     *
     * <p>Each non-empty row is walked cell by cell in {@code (c, h, w)} order,
     * so cells that are absent from the sparse row contribute an explicit
     * {@code 0.0} to every pooling window that covers them.
     */
    private static class SparsePooling
    implements Callable<Long> {
        private final int _rl;
        private final int _ru;
        private final DnnParameters _params;
        private double[] outputArray;
        private final int C;
        private final int P;
        private final int Q;
        private final int W;
        private final int H;
        private final int CPQ;
        private final int PQ;
        private final LibMatrixDNN.PoolingType _poolingType;
        private final double _poolingMultiplier;

        /**
         * Creates a task for the row range {@code [rl, ru)}.
         *
         * @param rl first batch row handled by this task (inclusive)
         * @param ru last batch row handled by this task (exclusive)
         * @param params operator parameters providing input, output and the
         *        pooling window index arrays
         * @param poolingType pooling variant; {@code MAX} selects the maximum,
         *        any other value takes the averaging path
         */
        public SparsePooling(int rl, int ru, DnnParameters params, LibMatrixDNN.PoolingType poolingType) {
            _rl = rl;
            _ru = ru;
            _params = params;
            outputArray = params.output.getDenseBlockValues();
            C = params.C;
            P = params.P;
            Q = params.Q;
            H = params.H;
            W = params.W;
            PQ = P * Q;
            CPQ = C * P * Q;
            _poolingType = poolingType;
            _poolingMultiplier = Math.pow(params.R * params.S, -1.0);
        }

        /**
         * Runs the pooling for all rows of this task.
         *
         * @return the value returned by
         *         {@code _params.output.recomputeNonZeros(_rl, _ru - 1)}
         */
        @Override
        public Long call() throws Exception {
            final boolean isMax = _poolingType == LibMatrixDNN.PoolingType.MAX;
            if (isMax) {
                // Seed this task's output slice so that Math.max has a baseline.
                Arrays.fill(outputArray, _rl * CPQ, _ru * CPQ, _params.minValForMaxPoolOperations);
            }
            for (int row = _rl; row < _ru; ++row) {
                if (_params.input1.sparseBlock.isEmpty(row)) {
                    // Empty input row: the whole output row is set to zero.
                    Arrays.fill(outputArray, row * CPQ, (row + 1) * CPQ, 0.0);
                    continue;
                }
                final int apos = _params.input1.sparseBlock.pos(row);
                final int alen = _params.input1.sparseBlock.size(row);
                final int[] aix = _params.input1.sparseBlock.indexes(row);
                final double[] avals = _params.input1.sparseBlock.values(row);
                final int lastExclusive = apos + alen;
                int cursor = apos; // position of the next unconsumed sparse entry
                int cell = 0;      // linearized (c, h, w) index of the current cell
                for (int ch = 0; ch < C; ++ch) {
                    final int outOffset = row * CPQ + ch * PQ;
                    for (int h = 0; h < H; ++h) {
                        for (int w = 0; w < W; ++w, ++cell) {
                            double cellValue = 0.0;
                            if (aix[cursor] == cell) {
                                cellValue = avals[cursor];
                                // Advance only while another entry remains, so the
                                // cursor always stays on a valid position of this row.
                                if (cursor + 1 < lastExclusive) {
                                    ++cursor;
                                }
                            }
                            for (int p = 0; p < P; ++p) {
                                if (h >= _params.start_indexes_h[p] && h < _params.end_indexes_h[p]) {
                                    final int outOffsetWithP = outOffset + p * Q;
                                    for (int q = 0; q < Q; ++q) {
                                        if (w >= _params.start_indexes_w[q] && w < _params.end_indexes_w[q]) {
                                            final int target = outOffsetWithP + q;
                                            if (isMax) {
                                                outputArray[target] = Math.max(outputArray[target], cellValue);
                                            } else {
                                                outputArray[target] += _poolingMultiplier * cellValue;
                                            }
                                        }
                                    }
                                }
                            }
                        }
                    }
                }
            }
            return _params.output.recomputeNonZeros(_rl, _ru - 1);
        }
    }

    /**
     * Task that computes forward pooling (max or average) for the batch rows
     * {@code [_rl, _ru)} using the dense value arrays of {@code params.input1}
     * and {@code params.output}.
     *
     * <p>When {@code params.isStride1Pad0()} holds, the work is delegated to
     * {@code LibMatrixDNNPooling.poolingDenseStride1Pad0}; otherwise the
     * pooling windows are taken from the precomputed start/end index arrays
     * in {@code params}.
     */
    private static class DensePooling
    implements Callable<Long> {
        private final int _rl;
        private final int _ru;
        private final DnnParameters _params;
        private final LibMatrixDNN.PoolingType _poolingType;
        private final double _poolingMultiplier;

        /**
         * Creates a task for the row range {@code [rl, ru)}.
         *
         * @param rl first batch row handled by this task (inclusive)
         * @param ru last batch row handled by this task (exclusive)
         * @param params operator parameters providing input, output and the
         *        pooling window index arrays
         * @param poolingType pooling variant to compute
         */
        public DensePooling(int rl, int ru, DnnParameters params, LibMatrixDNN.PoolingType poolingType) {
            _rl = rl;
            _ru = ru;
            _params = params;
            _poolingType = poolingType;
            _poolingMultiplier = 1.0 / (double)(params.R * params.S);
        }

        /**
         * Runs the pooling for all rows of this task.
         *
         * @return the value returned by
         *         {@code _params.output.recomputeNonZeros(_rl, _ru - 1)}
         */
        @Override
        public Long call() throws Exception {
            final int numChannels = _params.C;
            final int outHeight = _params.P;
            final int outWidth = _params.Q;
            final int filterHeight = _params.R;
            final int filterWidth = _params.S;
            final int inHeight = _params.H;
            final int inWidth = _params.W;
            final int HW = _params.H * _params.W;
            final int CHW = _params.C * _params.H * _params.W;
            final int CPQ = numChannels * outHeight * outWidth;
            final double[] in = _params.input1.getDenseBlockValues();
            final double[] out = _params.output.getDenseBlockValues();

            // Initial value of every output cell: zero for averaging, otherwise
            // the configured minimum used as the baseline for max pooling.
            final double initValue;
            if (_poolingType == LibMatrixDNN.PoolingType.AVG) {
                initValue = 0.0;
            } else {
                initValue = _params.minValForMaxPoolOperations;
            }
            final boolean isMax = _poolingType == LibMatrixDNN.PoolingType.MAX;

            if (_params.isStride1Pad0()) {
                LibMatrixDNNPooling.poolingDenseStride1Pad0(_poolingType, initValue, _poolingMultiplier, in, out, _rl, _ru, _rl * CHW, _rl * CPQ, numChannels, outHeight, outWidth, filterHeight, filterWidth, inHeight, inWidth);
                return _params.output.recomputeNonZeros(_rl, _ru - 1);
            }

            Arrays.fill(out, _rl * CPQ, _ru * CPQ, initValue);
            final int[] hl = _params.start_indexes_h;
            final int[] hu = _params.end_indexes_h;
            final int[] wl = _params.start_indexes_w;
            final int[] wu = _params.end_indexes_w;
            for (int row = _rl; row < _ru; ++row) {
                int inOff = row * CHW;  // start of the current input channel slice
                int outOff = row * CPQ; // start of the current output line (one p)
                for (int ch = 0; ch < numChannels; ++ch, inOff += HW) {
                    for (int p = 0; p < outHeight; ++p, outOff += outWidth) {
                        final int hBegin = hl[p];
                        final int hEnd = hu[p];
                        for (int h = hBegin; h < hEnd; ++h) {
                            final int lineOff = inOff + h * inWidth;
                            if (isMax) {
                                for (int q = 0; q < outWidth; ++q) {
                                    out[outOff + q] = LibMatrixDNNPooling.max(out[outOff + q], in, lineOff + wl[q], wu[q] - wl[q]);
                                }
                            } else {
                                for (int q = 0; q < outWidth; ++q) {
                                    out[outOff + q] = LibMatrixDNNPooling.avg(out[outOff + q], in, lineOff + wl[q], wu[q] - wl[q], _poolingMultiplier);
                                }
                            }
                        }
                    }
                }
            }
            return _params.output.recomputeNonZeros(_rl, _ru - 1);
        }
    }
}

