/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.matrix.data.sketch.countdistinctapprox;

import org.apache.commons.lang3.NotImplementedException;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
import org.apache.sysds.runtime.data.DenseBlock;
import org.apache.sysds.runtime.data.SparseBlock;
import org.apache.sysds.runtime.instructions.spark.data.CorrMatrixBlock;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.data.sketch.CountDistinctSketch;
import org.apache.sysds.runtime.matrix.data.sketch.countdistinctapprox.SmallestPriorityQueue;
import org.apache.sysds.runtime.matrix.operators.CountDistinctOperatorTypes;
import org.apache.sysds.runtime.matrix.operators.Operator;
import org.apache.sysds.utils.Hash;

/**
 * Approximate count-distinct sketch based on K-Minimum-Values (KMV).
 *
 * <p>Every value is hashed into the range {@code [1, M - 1]} and a {@link SmallestPriorityQueue}
 * retains the {@code k} smallest distinct hashes. If fewer than {@code k} hashes were retained,
 * their number is reported directly. Otherwise the estimate is
 * {@code (k - 1) / (kthSmallestHash / M)}, capped at an upper bound {@code D} that is derived from
 * the number of non-zeros (non-zeros plus one for the value zero).
 */
public class KMVSketch extends CountDistinctSketch {
    private static final Log LOG = LogFactory.getLog(KMVSketch.class.getName());

    /** Maximum number of smallest hashes retained per sketch. */
    private static final int MAX_SKETCH_SIZE = 64;

    /** Largest D whose square still fits into an int (46340^2 < Integer.MAX_VALUE < 46341^2). */
    private static final long MAX_UNSATURATED_D = 46340L;

    public KMVSketch(Operator op) {
        super(op);
    }

    /**
     * Estimates the number of distinct values in {@code blkIn} along the operator's direction.
     *
     * @param blkIn input matrix
     * @return a 1x1 block (row-col direction), a (rows x 1) block (row direction) or a
     *         (1 x columns) block (column direction) holding the estimated distinct counts
     * @throws DMLRuntimeException if the row-col estimate is not positive
     * @throws NotImplementedException if the row-col path receives a non-empty compressed matrix
     */
    @Override
    public MatrixBlock getValue(MatrixBlock blkIn) {
        if (this.op.getDirection().isRowCol()) {
            // Upper bound on the distinct count: every non-zero plus possibly the value zero.
            long D = blkIn.getNonZeros() + 1L;
            int M = hashModulus(D);
            int k = sketchSize(D);
            SmallestPriorityQueue spq = this.getKSmallestHashes(blkIn, k, M);
            if (LOG.isDebugEnabled()) {
                LOG.debug("M not forced to int size: " + D * D);
                LOG.debug("M: " + M);
                LOG.debug("kth smallest hash:" + spq.peek());
                LOG.debug("spq: " + spq);
            }
            long res = this.countDistinctValuesKMV(spq, k, M, D);
            if (res <= 0L) {
                throw new DMLRuntimeException("Impossible estimate of distinct values");
            }
            return new MatrixBlock(res);
        }
        int nRows = blkIn.getNumRows();
        int nCols = blkIn.getNumColumns();
        if (this.op.getDirection().isRow()) {
            MatrixBlock resultMatrix = new MatrixBlock(nRows, 1, false, nRows);
            resultMatrix.allocateBlock();
            // Reused buffer: estimateDistinct does not retain the array.
            double[] rowValues = new double[nCols];
            for (int i = 0; i < nRows; ++i) {
                for (int j = 0; j < nCols; ++j) {
                    rowValues[j] = blkIn.get(i, j);
                }
                resultMatrix.set(i, 0, this.estimateDistinct(rowValues));
            }
            return resultMatrix;
        }
        MatrixBlock resultMatrix = new MatrixBlock(1, nCols, false, nCols);
        resultMatrix.allocateBlock();
        double[] colValues = new double[nRows];
        for (int j = 0; j < nCols; ++j) {
            for (int i = 0; i < nRows; ++i) {
                colValues[i] = blkIn.get(i, j);
            }
            resultMatrix.set(0, j, this.estimateDistinct(colValues));
        }
        return resultMatrix;
    }

    /**
     * Estimates the distinct count of one row or column given as its cell values.
     *
     * <p>The bound {@code D} is the slice's own non-zero count plus one, the same per-slice bound
     * used when building sketches, rather than a matrix-wide average.
     */
    private long estimateDistinct(double[] values) {
        long nnz = 0L;
        for (double v : values) {
            if (v != 0.0) {
                ++nnz;
            }
        }
        long D = nnz + 1L;
        int k = sketchSize(D);
        int M = hashModulus(D);
        SmallestPriorityQueue spq = new SmallestPriorityQueue(k);
        this.addHashes(values, 0, values.length, this.op.getHashType(), spq, M);
        return this.countDistinctValuesKMV(spq, k, M, D);
    }

    /**
     * Number of smallest hashes to retain for bound {@code D}: D clamped to [2, 64].
     * At least 2 is required because the estimator {@code (k - 1) / U_k} is always 0 for k = 1.
     */
    private static int sketchSize(long D) {
        return (int) Math.max(2L, Math.min((long) MAX_SKETCH_SIZE, D));
    }

    /**
     * Hash modulus {@code M} for bound {@code D}: D^2 saturated at Integer.MAX_VALUE, and at
     * least 2 so that {@code % (M - 1)} never divides by zero. The saturation check happens
     * before multiplying, so {@code D * D} cannot overflow a long.
     */
    private static int hashModulus(long D) {
        if (D > MAX_UNSATURATED_D) {
            return Integer.MAX_VALUE;
        }
        return (int) Math.max(2L, D * D);
    }

    /** Collects the k smallest distinct hashes of all values of {@code in}. */
    private SmallestPriorityQueue getKSmallestHashes(MatrixBlock in, int k, int M) {
        SmallestPriorityQueue spq = new SmallestPriorityQueue(k);
        this.addHashes(in, this.op.getHashType(), spq, M);
        return spq;
    }

    /**
     * Adds the hashes of all cell values of {@code in} to {@code spq}.
     *
     * <p>Dense blocks are traversed row by row via {@code values(r)}/{@code pos(r)}, so only the
     * logical cells are hashed. Sparse blocks hash exactly the stored entries of each row, plus
     * the hash of 0.0 once when some cells are implicitly zero, matching the dense path, which
     * hashes zero cells.
     *
     * @throws NotImplementedException for non-empty compressed matrices
     */
    private void addHashes(MatrixBlock in, Hash.HashType hashType, SmallestPriorityQueue spq, int m) {
        if (in.isEmpty()) {
            // An empty block is treated as holding only zeros: record a single entry.
            spq.add(0.0);
            return;
        }
        if (in instanceof CompressedMatrixBlock) {
            throw new NotImplementedException("Cannot approximate distinct count for compressed matrices");
        }
        int nRows = in.getNumRows();
        int nCols = in.getNumColumns();
        if (in.getSparseBlock() != null) {
            SparseBlock sb = in.getSparseBlock();
            long stored = 0L;
            for (int i = 0; i < nRows; ++i) {
                if (sb.isEmpty(i)) {
                    continue;
                }
                int pos = sb.pos(i);
                int len = sb.size(i);
                this.addHashes(sb.values(i), pos, pos + len, hashType, spq, m);
                stored += len;
            }
            if (stored < (long) nRows * nCols) {
                spq.add(hashValue(0.0, hashType, m));
            }
        } else {
            DenseBlock db = in.getDenseBlock();
            for (int i = 0; i < nRows; ++i) {
                int pos = db.pos(i);
                this.addHashes(db.values(i), pos, pos + nCols, hashType, spq, m);
            }
        }
    }

    /** Adds the hashes of {@code data[from, to)} to {@code spq}. */
    private void addHashes(double[] data, int from, int to, Hash.HashType hashType, SmallestPriorityQueue spq, int m) {
        for (int i = from; i < to; ++i) {
            spq.add(hashValue(data[i], hashType, m));
        }
    }

    /**
     * Maps a value to a hash in [1, m - 1]. The magnitude is taken in long arithmetic because
     * {@code Math.abs(Integer.MIN_VALUE)} is still negative in int arithmetic.
     */
    private static int hashValue(double value, Hash.HashType hashType, int m) {
        long magnitude = Math.abs((long) Hash.hash(value, hashType));
        return (int) (magnitude % (m - 1)) + 1;
    }

    /**
     * Computes the KMV estimate from the retained hashes. Note: this polls (mutates) {@code spq}.
     *
     * @param spq retained smallest hashes
     * @param k   sketch size
     * @param M   hash modulus used when hashing
     * @param D   upper bound used to cap the estimate
     * @return the estimated distinct count, rounded to the nearest integer
     */
    private long countDistinctValuesKMV(SmallestPriorityQueue spq, int k, int M, long D) {
        if (spq.size() < k) {
            // Fewer than k entries: report their number directly
            // (presumably the queue kept every distinct hash it saw).
            return spq.size();
        }
        double kthSmallestHash = spq.poll();
        double U_k = kthSmallestHash / M;
        double estimate = (k - 1) / U_k;
        double ceilEstimate = Math.min(estimate, (double) D);
        if (LOG.isDebugEnabled()) {
            LOG.debug("U_k : " + U_k);
            LOG.debug("Estimate: " + estimate);
            LOG.debug("Ceil worst case: " + D);
        }
        return Math.round(ceilEstimate);
    }

    /**
     * Evaluates a sketch (hash block plus metadata block) into distinct-count estimates along
     * the operator's direction.
     */
    @Override
    public MatrixBlock getValueFromSketch(CorrMatrixBlock arg0) {
        MatrixBlock blkIn = arg0.getValue();
        if (this.op.getDirection().isRow()) {
            int nRows = blkIn.getNumRows();
            MatrixBlock blkOut = new MatrixBlock(nRows, 1, false, nRows);
            blkOut.allocateBlock();
            for (int i = 0; i < nRows; ++i) {
                this.getDistinctCountFromSketchByIndex(arg0, i, blkOut);
            }
            return blkOut;
        }
        if (this.op.getDirection().isCol()) {
            int nCols = blkIn.getNumColumns();
            MatrixBlock blkOut = new MatrixBlock(1, nCols, false, nCols);
            blkOut.allocateBlock();
            for (int j = 0; j < nCols; ++j) {
                this.getDistinctCountFromSketchByIndex(arg0, j, blkOut);
            }
            return blkOut;
        }
        MatrixBlock blkOut = new MatrixBlock(1, 1, false, 1L);
        blkOut.allocateBlock();
        this.getDistinctCountFromSketchByIndex(arg0, 0, blkOut);
        return blkOut;
    }

    /**
     * Writes the estimate for sketch index {@code idx} into {@code blkOut}.
     *
     * <p>Metadata row {@code idx} holds (nHashes, k, D). A sketch with no hashes counts as one
     * distinct value, and a sketch with fewer than k hashes reports nHashes directly.
     */
    private void getDistinctCountFromSketchByIndex(CorrMatrixBlock arg0, int idx, MatrixBlock blkOut) {
        MatrixBlock blkIn = arg0.getValue();
        MatrixBlock blkInCorr = arg0.getCorrection();
        if (this.op.getOperatorType() != CountDistinctOperatorTypes.KMV) {
            throw new IllegalArgumentException(this.getClass().getSimpleName() + " cannot use " + this.op.getOperatorType());
        }
        boolean hashesAlongRow = this.op.getDirection().isRow() || this.op.getDirection().isRowCol();
        // Sketches are written in poll() order starting at position 0; the code treats that
        // first polled value as the kth smallest hash.
        double kthSmallestHash = hashesAlongRow ? blkIn.get(idx, 0) : blkIn.get(0, idx);
        double nHashes = blkInCorr.get(idx, 0);
        double k = blkInCorr.get(idx, 1);
        double D = blkInCorr.get(idx, 2);
        double M = Math.min(D * D, 2.147483647E9);
        double ceilEstimate;
        if (nHashes == 0.0) {
            ceilEstimate = 1.0;
        } else if (nHashes < k) {
            ceilEstimate = nHashes;
        } else {
            double U_k = kthSmallestHash / M;
            double estimate = (k - 1.0) / U_k;
            ceilEstimate = Math.min(estimate, D);
        }
        if (hashesAlongRow) {
            blkOut.set(idx, 0, ceilEstimate);
        } else {
            blkOut.set(0, idx, ceilEstimate);
        }
    }

    /**
     * Builds a sketch of {@code blkIn}. For the row and column directions, the hashes are
     * written into {@code blkIn} itself (row/column {@code idx} is overwritten). For row-col, a
     * copy of {@code blkIn} is reset and reused as the hash block. The metadata block holds
     * (nHashes, k, D) per index.
     */
    @Override
    public CorrMatrixBlock create(MatrixBlock blkIn) {
        if (this.op.getDirection().isRowCol()) {
            MatrixBlock blkOut = new MatrixBlock(blkIn);
            MatrixBlock blkOutCorr = new MatrixBlock(1, 3, false);
            this.createSketchByIndex(blkIn, blkOutCorr, 0, blkOut);
            return new CorrMatrixBlock(blkOut, blkOutCorr);
        }
        if (this.op.getDirection().isRow()) {
            MatrixBlock blkOutCorr = new MatrixBlock(blkIn.getNumRows(), 3, false);
            for (int i = 0; i < blkIn.getNumRows(); ++i) {
                this.createSketchByIndex(blkIn, blkOutCorr, i);
            }
            return new CorrMatrixBlock(blkIn, blkOutCorr);
        }
        if (this.op.getDirection().isCol()) {
            MatrixBlock blkOutCorr = new MatrixBlock(blkIn.getNumColumns(), 3, false);
            for (int j = 0; j < blkIn.getNumColumns(); ++j) {
                this.createSketchByIndex(blkIn, blkOutCorr, j);
            }
            return new CorrMatrixBlock(blkIn, blkOutCorr);
        }
        throw new DMLRuntimeException(String.format("Unexpected direction: %s", this.op.getDirection()));
    }

    /** Returns row/column {@code idx} of {@code blkIn} per direction, or {@code blkIn} for row-col. */
    private MatrixBlock sliceMatrixBlockByIndexDirection(MatrixBlock blkIn, int idx) {
        if (this.op.getDirection().isRow()) {
            return blkIn.slice(idx, idx);
        }
        if (this.op.getDirection().isCol()) {
            return blkIn.slice(0, blkIn.getNumRows() - 1, idx, idx);
        }
        return blkIn;
    }

    private void createSketchByIndex(MatrixBlock blkIn, MatrixBlock sketchMetaMB, int idx) {
        this.createSketchByIndex(blkIn, sketchMetaMB, idx, null);
    }

    /**
     * Builds the sketch for index {@code idx} and records (nHashes, k, D) in metadata row
     * {@code idx}. Hashes go to {@code blkOut} if given (after resetting it to 1 x k), otherwise
     * into {@code blkIn}. Empty or single-cell slices record nHashes = 0.
     */
    private void createSketchByIndex(MatrixBlock blkIn, MatrixBlock sketchMetaMB, int idx, MatrixBlock blkOut) {
        MatrixBlock sketchMB = blkOut == null ? blkIn : blkOut;
        MatrixBlock blkInSlice = this.sliceMatrixBlockByIndexDirection(blkIn, idx);
        long D = blkInSlice.getNonZeros() + 1L;
        int M = hashModulus(D);
        int k = sketchSize(D);
        if (blkOut != null) {
            sketchMB.reset(1, k);
        }
        if (blkInSlice.getLength() == 1L || blkInSlice.isEmpty()) {
            sketchMetaMB.set(idx, 0, 0.0);
            sketchMetaMB.set(idx, 1, k);
            sketchMetaMB.set(idx, 2, D);
            return;
        }
        SmallestPriorityQueue spq = this.getKSmallestHashes(blkInSlice, k, M);
        int nHashes = spq.size();
        assert (nHashes > 0);
        boolean hashesAlongColumn = this.op.getDirection().isCol();
        for (int i = 0; !spq.isEmpty(); ++i) {
            double toInsert = spq.poll();
            if (hashesAlongColumn) {
                sketchMB.set(i, idx, toInsert);
            } else {
                sketchMB.set(idx, i, toInsert);
            }
        }
        sketchMetaMB.set(idx, 0, nHashes);
        sketchMetaMB.set(idx, 1, k);
        sketchMetaMB.set(idx, 2, D);
    }

    /**
     * Merges two sketches index by index. The output hash block reuses the larger input block.
     * A fresh block is allocated if that block is too small to hold up to the largest recorded
     * k hashes per index.
     */
    @Override
    public CorrMatrixBlock union(CorrMatrixBlock arg0, CorrMatrixBlock arg1) {
        MatrixBlock matrix0 = arg0.getValue();
        MatrixBlock matrix1 = arg1.getValue();
        MatrixBlock corr0 = arg0.getCorrection();
        MatrixBlock corr1 = arg1.getCorrection();
        this.validateSketchMetadata(corr0);
        this.validateSketchMetadata(corr1);
        // A merged sketch may keep more hashes per index than either input currently stores.
        int hashExtent = Math.max(maxSketchSize(corr0), maxSketchSize(corr1));
        if (this.op.getDirection().isRow()) {
            int nRows = matrix0.getNumRows();
            MatrixBlock combined = matrix0.getNumColumns() > matrix1.getNumColumns() ? matrix0 : matrix1;
            if (combined.getNumColumns() < hashExtent) {
                combined = new MatrixBlock(nRows, hashExtent, false);
            }
            CorrMatrixBlock blkOut = new CorrMatrixBlock(combined, new MatrixBlock(nRows, 3, false));
            for (int i = 0; i < nRows; ++i) {
                this.unionSketchByIndex(arg0, arg1, i, blkOut);
            }
            return blkOut;
        }
        if (this.op.getDirection().isCol()) {
            int nCols = matrix0.getNumColumns();
            MatrixBlock combined = matrix0.getNumRows() > matrix1.getNumRows() ? matrix0 : matrix1;
            if (combined.getNumRows() < hashExtent) {
                combined = new MatrixBlock(hashExtent, nCols, false);
            }
            CorrMatrixBlock blkOut = new CorrMatrixBlock(combined, new MatrixBlock(nCols, 3, false));
            for (int j = 0; j < nCols; ++j) {
                this.unionSketchByIndex(arg0, arg1, j, blkOut);
            }
            return blkOut;
        }
        MatrixBlock combined = matrix0.getNumColumns() > matrix1.getNumColumns() ? matrix0 : matrix1;
        if (combined.getNumColumns() < hashExtent) {
            combined = new MatrixBlock(1, hashExtent, false);
        }
        CorrMatrixBlock blkOut = new CorrMatrixBlock(combined, new MatrixBlock(1, 3, false));
        this.unionSketchByIndex(arg0, arg1, 0, blkOut);
        return blkOut;
    }

    /** Largest k recorded in column 1 of a sketch metadata block (0 if it has no rows). */
    private static int maxSketchSize(MatrixBlock corr) {
        int max = 0;
        for (int r = 0; r < corr.getNumRows(); ++r) {
            max = Math.max(max, (int) corr.get(r, 1));
        }
        return max;
    }

    /**
     * Chooses k for a merged sketch. A sketch with nHashes >= k is saturated: it only holds its
     * k smallest hashes, so the merge must not keep more than that many. A sketch with fewer
     * hashes is reported as an exact count elsewhere in this class, so merging two of those can
     * use the larger capacity.
     */
    private static double mergedSketchSize(double nHashes0, double k0, double nHashes1, double k1) {
        boolean saturated0 = nHashes0 >= k0;
        boolean saturated1 = nHashes1 >= k1;
        if (saturated0 && saturated1) {
            return Math.min(k0, k1);
        }
        if (saturated0) {
            return k0;
        }
        if (saturated1) {
            return k1;
        }
        return Math.max(k0, k1);
    }

    /**
     * Merges the sketches of both inputs at index {@code idx} into {@code blkOut}.
     *
     * @throws DMLRuntimeException if row (row direction) or column (column direction) counts differ
     */
    public void unionSketchByIndex(CorrMatrixBlock arg0, CorrMatrixBlock arg1, int idx, CorrMatrixBlock blkOut) {
        MatrixBlock corr0 = arg0.getCorrection();
        MatrixBlock corr1 = arg1.getCorrection();
        this.validateSketchMetadata(corr0);
        this.validateSketchMetadata(corr1);
        MatrixBlock matrix0 = arg0.getValue();
        MatrixBlock matrix1 = arg1.getValue();
        boolean rowsMisaligned = this.op.getDirection().isRow() && matrix0.getNumRows() != matrix1.getNumRows();
        boolean colsMisaligned = this.op.getDirection().isCol() && matrix0.getNumColumns() != matrix1.getNumColumns();
        if (rowsMisaligned || colsMisaligned) {
            throw new DMLRuntimeException("Cannot take the union of sketches: rows/columns are not aligned");
        }
        MatrixBlock combined = blkOut.getValue();
        MatrixBlock combinedCorr = blkOut.getCorrection();
        double nHashes0 = corr0.get(idx, 0);
        double k0 = corr0.get(idx, 1);
        double D0 = corr0.get(idx, 2);
        double nHashes1 = corr1.get(idx, 0);
        double k1 = corr1.get(idx, 1);
        double D1 = corr1.get(idx, 2);
        double k = mergedSketchSize(nHashes0, k0, nHashes1, k1);
        // Each D counts the value zero once; subtract the duplicate.
        double D = D0 + D1 - 1.0;
        boolean hashesAlongRow = this.op.getDirection().isRow() || this.op.getDirection().isRowCol();
        // NOTE(review): each input hashed into [1, M_i - 1] with M_i derived from its own D_i;
        // when D0 != D1 (and M is not saturated) the ranges differ, so the merged estimate is
        // only approximate.
        SmallestPriorityQueue hashUnion = new SmallestPriorityQueue(Math.max(1, (int) k));
        for (int i = 0; i < nHashes0; ++i) {
            hashUnion.add(hashesAlongRow ? matrix0.get(idx, i) : matrix0.get(i, idx));
        }
        for (int i = 0; i < nHashes1; ++i) {
            hashUnion.add(hashesAlongRow ? matrix1.get(idx, i) : matrix1.get(i, idx));
        }
        // Record the real number of merged hashes; polling below empties the queue.
        int nHashes = hashUnion.size();
        for (int i = 0; !hashUnion.isEmpty(); ++i) {
            double val = hashUnion.poll();
            if (hashesAlongRow) {
                combined.set(idx, i, val);
            } else {
                combined.set(i, idx, val);
            }
        }
        combinedCorr.set(idx, 0, nHashes);
        combinedCorr.set(idx, 1, k);
        combinedCorr.set(idx, 2, D);
    }

    @Override
    public CorrMatrixBlock intersection(CorrMatrixBlock arg0, CorrMatrixBlock arg1) {
        throw new NotImplementedException(String.format("%s intersection has not been implemented yet", KMVSketch.class.getSimpleName()));
    }
}

