/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.compress.colgroup;

import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.util.Arrays;
import org.apache.commons.lang.NotImplementedException;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.compress.DMLCompressionException;
import org.apache.sysds.runtime.compress.colgroup.AColGroup;
import org.apache.sysds.runtime.compress.colgroup.AMorphingMMColGroup;
import org.apache.sysds.runtime.compress.colgroup.ColGroupEmpty;
import org.apache.sysds.runtime.compress.colgroup.ColGroupSDCSingle;
import org.apache.sysds.runtime.compress.colgroup.ColGroupSDCZeros;
import org.apache.sysds.runtime.compress.colgroup.FORUtil;
import org.apache.sysds.runtime.compress.colgroup.dictionary.ADictionary;
import org.apache.sysds.runtime.compress.colgroup.dictionary.Dictionary;
import org.apache.sysds.runtime.compress.colgroup.mapping.AMapToData;
import org.apache.sysds.runtime.compress.colgroup.mapping.MapToBit;
import org.apache.sysds.runtime.compress.colgroup.mapping.MapToFactory;
import org.apache.sysds.runtime.compress.colgroup.offset.AIterator;
import org.apache.sysds.runtime.compress.colgroup.offset.AOffset;
import org.apache.sysds.runtime.compress.colgroup.offset.OffsetFactory;
import org.apache.sysds.runtime.compress.cost.ComputationCostEstimator;
import org.apache.sysds.runtime.compress.utils.Util;
import org.apache.sysds.runtime.functionobjects.Builtin;
import org.apache.sysds.runtime.instructions.cp.CM_COV_Object;
import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
import org.apache.sysds.runtime.matrix.operators.CMOperator;
import org.apache.sysds.runtime.matrix.operators.ScalarOperator;
import org.apache.sysds.runtime.matrix.operators.UnaryOperator;

/**
 * Column group where rows listed in an offset structure take their values from a dictionary, and every other row holds
 * a single default tuple.
 * <p>
 * For the i-th stored offset {@code o = _indexes[i]}, row {@code o} holds dictionary entry {@code _data.getIndex(i)}.
 * Any row that is not a stored offset holds {@link #_defaultTuple} (see {@link #getIdx(int, int)}).
 */
public class ColGroupSDC extends AMorphingMMColGroup {
    private static final long serialVersionUID = 769993538831949086L;

    /** Row offsets whose values come from the dictionary instead of the default tuple. */
    protected AOffset _indexes;
    /** Maps the position of each offset in {@link #_indexes} to a dictionary entry. */
    protected AMapToData _data;
    /** Values of every row not contained in {@link #_indexes}, one entry per column of the group. */
    protected double[] _defaultTuple;

    /**
     * Creates an empty shell holding only the row count; the remaining fields are presumably filled afterwards by
     * {@link #readFields(DataInput)} (NOTE(review): confirm against the deserialization path).
     *
     * @param numRows number of rows of the group
     */
    protected ColGroupSDC(int numRows) {
        super(numRows);
    }

    private ColGroupSDC(int[] colIndices, int numRows, ADictionary dict, double[] defaultTuple, AOffset offsets,
        AMapToData data, int[] cachedCounts) {
        super(colIndices, numRows, dict, cachedCounts);
        final int nUnique = data.getUnique();
        if (nUnique != dict.getNumberOfValues(colIndices.length)) {
            if (nUnique != data.getMax()) {
                throw new DMLCompressionException("Invalid unique count compared to actual: " + data.getUnique() + " " + data.getMax());
            }
            throw new DMLCompressionException("Invalid construction of SDC group: number uniques: " + data.getUnique() + " vs." + dict.getNumberOfValues(colIndices.length));
        }
        _indexes = offsets;
        _data = data;
        _zeros = false;
        _defaultTuple = defaultTuple;
        // A single-valued bit map would mean only one distinct non-default value: that shape belongs to SDCSingle.
        if (data instanceof MapToBit && ((MapToBit) data).isEmpty()) {
            throw new DMLCompressionException("Error in SDC construction should have been SDCSingle");
        }
    }

    /**
     * Creates the most specific column group for the given content.
     *
     * @return an empty group if there is no dictionary and the default is all zero; an SDCSingle group if there is no
     *         dictionary; an SDCZeros group if only the default is all zero; otherwise an SDC group
     */
    protected static AColGroup create(int[] colIndices, int numRows, ADictionary dict, double[] defaultTuple,
        AOffset offsets, AMapToData data, int[] cachedCounts) {
        final boolean allZero = FORUtil.allZero(defaultTuple);
        if (dict == null) {
            return allZero ? new ColGroupEmpty(colIndices)
                : ColGroupSDCSingle.create(colIndices, numRows, null, defaultTuple, offsets, null);
        }
        if (allZero) {
            return ColGroupSDCZeros.create(colIndices, numRows, dict, offsets, data, cachedCounts);
        }
        return new ColGroupSDC(colIndices, numRows, dict, defaultTuple, offsets, data, cachedCounts);
    }

    @Override
    public AColGroup.CompressionType getCompType() {
        return AColGroup.CompressionType.SDC;
    }

    @Override
    public AColGroup.ColGroupType getColGroupType() {
        return AColGroup.ColGroupType.SDC;
    }

    /** Returns the default value unless {@code r} is a stored offset, in which case the dictionary value is used. */
    @Override
    public double getIdx(int r, int colIdx) {
        final AIterator it = _indexes.getIterator(r);
        if (it == null || it.value() != r) {
            return _defaultTuple[colIdx];
        }
        final int dictRow = _data.getIndex(it.getDataIndex());
        return _dict.getValue(dictRow * _colIndexes.length + colIdx);
    }

    @Override
    public ADictionary getDictionary() {
        throw new NotImplementedException("Not implemented getting the dictionary out, and i think we should consider removing the option");
    }

    @Override
    protected double[] preAggSumRows() {
        return _dict.sumAllRowsToDoubleWithDefault(_defaultTuple);
    }

    @Override
    protected double[] preAggSumSqRows() {
        return _dict.sumAllRowsToDoubleSqWithDefault(_defaultTuple);
    }

    @Override
    protected double[] preAggProductRows() {
        throw new NotImplementedException("Should implement preAgg with extra cell");
    }

    @Override
    protected double[] preAggBuiltinRows(Builtin builtin) {
        return _dict.aggregateRowsWithDefault(builtin, _defaultTuple);
    }

    /** Aggregates the dictionary into {@code c}, then folds every default value into the result. */
    @Override
    protected double computeMxx(double c, Builtin builtin) {
        double ret = _dict.aggregate(c, builtin);
        for (double d : _defaultTuple) {
            ret = builtin.execute(ret, d);
        }
        return ret;
    }

    @Override
    protected void computeColMxx(double[] c, Builtin builtin) {
        _dict.aggregateCols(c, builtin, _colIndexes);
        for (int x = 0; x < _colIndexes.length; x++) {
            final int col = _colIndexes[x];
            c[col] = builtin.execute(c[col], _defaultTuple[x]);
        }
    }

    @Override
    protected void computeRowSums(double[] c, int rl, int ru, double[] preAgg) {
        computeRowSums(c, rl, ru, preAgg, _data, _indexes, _numRows);
    }

    /**
     * Adds the pre-aggregated row value of every row in {@code [rl, ru)} into {@code c}.
     *
     * @param c       output array indexed by absolute row
     * @param rl      first row (inclusive)
     * @param ru      last row (exclusive)
     * @param preAgg  per-dictionary-entry values; the final element is the value of the default tuple
     * @param data    mapping from offset position to dictionary entry
     * @param indexes offsets of the non-default rows
     * @param nRows   total number of rows (unused here)
     */
    protected static final void computeRowSums(double[] c, int rl, int ru, double[] preAgg, AMapToData data,
        AOffset indexes, int nRows) {
        final double def = preAgg[preAgg.length - 1];
        final AIterator it = indexes.getIterator(rl);
        int r = rl;
        if (it != null) {
            if (it.value() >= ru) {
                // No stored offset inside [rl, ru); only remember the iterator position for the next range.
                indexes.cacheIterator(it, ru);
            } else if (ru > indexes.getOffsetToLast()) {
                // The last stored offset lies strictly inside [rl, ru): walk until it is consumed.
                final int maxId = data.size() - 1;
                while (true) {
                    if (it.value() == r) {
                        c[r] += preAgg[data.getIndex(it.getDataIndex())];
                        r++;
                        if (it.getDataIndex() >= maxId) {
                            break;
                        }
                        it.next();
                    } else {
                        c[r] += def;
                        r++;
                    }
                }
            } else {
                // More offsets follow after ru; process the range and cache the iterator for the next call.
                for (; r < ru; r++) {
                    if (it.value() == r) {
                        c[r] += preAgg[data.getIndex(it.getDataIndex())];
                        it.next();
                    } else {
                        c[r] += def;
                    }
                }
                indexes.cacheIterator(it, ru);
            }
        }
        // Any rows left after the last stored offset hold the default tuple.
        for (; r < ru; r++) {
            c[r] += def;
        }
    }

    @Override
    protected void computeRowMxx(double[] c, Builtin builtin, int rl, int ru, double[] preAgg) {
        computeRowMxx(c, builtin, rl, ru, preAgg, _data, _indexes, _numRows, preAgg[preAgg.length - 1]);
    }

    /**
     * Folds the pre-aggregated row value of every row in {@code [rl, ru)} into {@code c} using {@code builtin}.
     *
     * @param c       output array indexed by absolute row
     * @param builtin aggregation function (e.g. min/max) applied as {@code c[r] = builtin(c[r], value)}
     * @param rl      first row (inclusive)
     * @param ru      last row (exclusive)
     * @param preAgg  per-dictionary-entry aggregated values
     * @param data    mapping from offset position to dictionary entry
     * @param indexes offsets of the non-default rows
     * @param nRows   total number of rows (unused here)
     * @param def     aggregated value of the default tuple
     */
    protected static final void computeRowMxx(double[] c, Builtin builtin, int rl, int ru, double[] preAgg,
        AMapToData data, AOffset indexes, int nRows, double def) {
        final AIterator it = indexes.getIterator(rl);
        int r = rl;
        if (it != null) {
            if (it.value() >= ru) {
                // No stored offset inside [rl, ru); only remember the iterator position for the next range.
                indexes.cacheIterator(it, ru);
            } else if (ru > indexes.getOffsetToLast()) {
                // The last stored offset lies strictly inside [rl, ru): walk until it is consumed.
                final int maxId = data.size() - 1;
                while (true) {
                    if (it.value() == r) {
                        c[r] = builtin.execute(c[r], preAgg[data.getIndex(it.getDataIndex())]);
                        r++;
                        if (it.getDataIndex() >= maxId) {
                            break;
                        }
                        it.next();
                    } else {
                        c[r] = builtin.execute(c[r], def);
                        r++;
                    }
                }
            } else {
                // More offsets follow after ru; process the range and cache the iterator for the next call.
                for (; r < ru; r++) {
                    if (it.value() == r) {
                        c[r] = builtin.execute(c[r], preAgg[data.getIndex(it.getDataIndex())]);
                        it.next();
                    } else {
                        c[r] = builtin.execute(c[r], def);
                    }
                }
                indexes.cacheIterator(it, ru);
            }
        }
        // Any rows left after the last stored offset hold the default tuple.
        for (; r < ru; r++) {
            c[r] = builtin.execute(c[r], def);
        }
    }

    /** Number of rows that hold the default tuple, i.e. rows not listed in {@link #_indexes}. */
    private int defaultRowCount() {
        return _numRows - _data.size();
    }

    @Override
    protected void computeSum(double[] c, int nRows) {
        super.computeSum(c, nRows);
        final int count = defaultRowCount();
        for (double d : _defaultTuple) {
            c[0] += d * (double) count;
        }
    }

    @Override
    public void computeColSums(double[] c, int nRows) {
        super.computeColSums(c, nRows);
        final int count = defaultRowCount();
        for (int x = 0; x < _colIndexes.length; x++) {
            c[_colIndexes[x]] += _defaultTuple[x] * (double) count;
        }
    }

    @Override
    protected void computeSumSq(double[] c, int nRows) {
        super.computeSumSq(c, nRows);
        final int count = defaultRowCount();
        for (int x = 0; x < _colIndexes.length; x++) {
            c[0] += _defaultTuple[x] * _defaultTuple[x] * (double) count;
        }
    }

    @Override
    protected void computeColSumsSq(double[] c, int nRows) {
        super.computeColSumsSq(c, nRows);
        final int count = defaultRowCount();
        for (int x = 0; x < _colIndexes.length; x++) {
            c[_colIndexes[x]] += _defaultTuple[x] * _defaultTuple[x] * (double) count;
        }
    }

    @Override
    protected void computeProduct(double[] c, int nRows) {
        _dict.productWithDefault(c, getCounts(), _defaultTuple, defaultRowCount());
    }

    /**
     * Multiplies the column products of this group into {@code c}. Each default value appears once per default row,
     * so it contributes {@code default^count}, not a single factor.
     */
    @Override
    protected void computeColProduct(double[] c, int nRows) {
        super.computeColProduct(c, nRows);
        final int count = defaultRowCount();
        for (int x = 0; x < _colIndexes.length; x++) {
            c[_colIndexes[x]] *= Math.pow(_defaultTuple[x], count);
        }
    }

    @Override
    protected void computeRowProduct(double[] c, int rl, int ru, double[] preAgg) {
        throw new NotImplementedException();
    }

    /** Counts only the rows listed in {@link #_indexes}, per dictionary entry. */
    @Override
    public int[] getCounts(int[] counts) {
        return _data.getCounts(counts);
    }

    @Override
    public long getNumberNonZeros(int nRows) {
        long nnz = super.getNumberNonZeros(nRows);
        final int count = defaultRowCount();
        for (int x = 0; x < _colIndexes.length; x++) {
            if (_defaultTuple[x] != 0.0) {
                nnz += count;
            }
        }
        return nnz;
    }

    @Override
    public long estimateInMemorySize() {
        long size = super.estimateInMemorySize();
        size += _indexes.getInMemorySize();
        size += _data.getInMemorySize();
        // One double per column for the default tuple.
        return size + 8L * _colIndexes.length;
    }

    @Override
    public AColGroup scalarOperation(ScalarOperator op) {
        final double[] newDefaultTuple = new double[_defaultTuple.length];
        for (int i = 0; i < _defaultTuple.length; i++) {
            newDefaultTuple[i] = op.executeScalar(_defaultTuple[i]);
        }
        final ADictionary nDict = _dict.applyScalarOp(op);
        return create(_colIndexes, _numRows, nDict, newDefaultTuple, _indexes, _data, getCachedCounts());
    }

    @Override
    public AColGroup unaryOperation(UnaryOperator op) {
        final double[] newDefaultTuple = new double[_defaultTuple.length];
        for (int i = 0; i < _defaultTuple.length; i++) {
            newDefaultTuple[i] = op.fn.execute(_defaultTuple[i]);
        }
        final ADictionary nDict = _dict.applyUnaryOp(op);
        return create(_colIndexes, _numRows, nDict, newDefaultTuple, _indexes, _data, getCachedCounts());
    }

    @Override
    public AColGroup binaryRowOpLeft(BinaryOperator op, double[] v, boolean isRowSafe) {
        final double[] newDefaultTuple = new double[_defaultTuple.length];
        for (int i = 0; i < _defaultTuple.length; i++) {
            newDefaultTuple[i] = op.fn.execute(v[_colIndexes[i]], _defaultTuple[i]);
        }
        final ADictionary newDict = _dict.binOpLeft(op, v, _colIndexes);
        return create(_colIndexes, _numRows, newDict, newDefaultTuple, _indexes, _data, getCachedCounts());
    }

    @Override
    public AColGroup binaryRowOpRight(BinaryOperator op, double[] v, boolean isRowSafe) {
        final double[] newDefaultTuple = new double[_defaultTuple.length];
        for (int i = 0; i < _defaultTuple.length; i++) {
            newDefaultTuple[i] = op.fn.execute(_defaultTuple[i], v[_colIndexes[i]]);
        }
        final ADictionary newDict = _dict.binOpRight(op, v, _colIndexes);
        return create(_colIndexes, _numRows, newDict, newDefaultTuple, _indexes, _data, getCachedCounts());
    }

    /** Writes the base group, then the offsets, the mapping, and one double per default-tuple entry. */
    @Override
    public void write(DataOutput out) throws IOException {
        super.write(out);
        _indexes.write(out);
        _data.write(out);
        for (double d : _defaultTuple) {
            out.writeDouble(d);
        }
    }

    /** Reads the fields in the same order as {@link #write(DataOutput)}. */
    @Override
    public void readFields(DataInput in) throws IOException {
        super.readFields(in);
        _indexes = OffsetFactory.readIn(in);
        _data = MapToFactory.readIn(in);
        _defaultTuple = new double[_colIndexes.length];
        for (int i = 0; i < _colIndexes.length; i++) {
            _defaultTuple[i] = in.readDouble();
        }
    }

    @Override
    public long getExactSizeOnDisk() {
        long ret = super.getExactSizeOnDisk();
        ret += _data.getExactSizeOnDisk();
        ret += _indexes.getExactSizeOnDisk();
        // One double per column for the default tuple.
        return ret + 8L * _colIndexes.length;
    }

    /** Replaces {@code pattern} with {@code replace} in both the dictionary and the default tuple. */
    @Override
    public AColGroup replace(double pattern, double replace) {
        final ADictionary replaced = _dict.replace(pattern, replace, _colIndexes.length);
        final double[] newDefaultTuple = new double[_defaultTuple.length];
        for (int i = 0; i < _defaultTuple.length; i++) {
            newDefaultTuple[i] = matchesPattern(_defaultTuple[i], pattern) ? replace : _defaultTuple[i];
        }
        return create(_colIndexes, _numRows, replaced, newDefaultTuple, _indexes, _data, getCachedCounts());
    }

    /** Compares a value against a pattern; unlike {@code ==}, a NaN pattern matches NaN values. */
    private static boolean matchesPattern(double v, double pattern) {
        return Double.isNaN(pattern) ? Double.isNaN(v) : v == pattern;
    }

    /**
     * Adds the default tuple into {@code constV} and returns an SDCZeros group whose dictionary has the default tuple
     * subtracted from every entry.
     */
    @Override
    public AColGroup extractCommon(double[] constV) {
        for (int i = 0; i < _colIndexes.length; i++) {
            constV[_colIndexes[i]] += _defaultTuple[i];
        }
        return subtractDefaultTuple();
    }

    /**
     * Returns an SDCZeros group whose dictionary entries have the default tuple subtracted. The default tuple itself
     * is not added anywhere, so the caller must account for it.
     */
    public AColGroup subtractDefaultTuple() {
        final ADictionary subtractedDict = _dict.subtractTuple(_defaultTuple);
        return ColGroupSDCZeros.create(_colIndexes, _numRows, subtractedDict, _indexes, _data, getCounts());
    }

    /** Adds the default value, weighted by the number of default rows, to the base central moment. */
    @Override
    public CM_COV_Object centralMoment(CMOperator op, int nRows) {
        final CM_COV_Object ret = super.centralMoment(op, nRows);
        // Only the first column is used, so this assumes a single-column group (NOTE(review): confirm callers).
        op.fn.execute(ret, _defaultTuple[0], defaultRowCount());
        return ret;
    }

    @Override
    public AColGroup rexpandCols(int max, boolean ignore, boolean cast, int nRows) {
        final ADictionary d = _dict.rexpandCols(max, ignore, cast, _colIndexes.length);
        return rexpandCols(max, ignore, cast, nRows, d, _indexes, _data, getCachedCounts(), _defaultTuple[0]);
    }

    /**
     * Builds the rexpanded (one-hot) group from an already rexpanded dictionary {@code d} and the scalar default
     * value {@code def}.
     *
     * @throws DMLRuntimeException if a dictionary is present, {@code def <= 0}, and {@code ignore} is false
     */
    protected static AColGroup rexpandCols(int max, boolean ignore, boolean cast, int nRows, ADictionary d,
        AOffset indexes, AMapToData data, int[] counts, double def) {
        if (d == null) {
            if (def <= 0.0 || def > (double) max) {
                return ColGroupEmpty.create(max);
            }
            final double[] retDef = new double[max];
            retDef[(int) def - 1] = 1.0;
            return ColGroupSDCSingle.create(Util.genColsIndices(max), nRows, new Dictionary(new double[max]), retDef,
                indexes, null);
        }
        if (def <= 0.0) {
            if (!ignore) {
                throw new DMLRuntimeException("Invalid content of zero in rexpand");
            }
            return ColGroupSDCZeros.create(Util.genColsIndices(max), nRows, d, indexes, data, counts);
        }
        if (def > (double) max) {
            return ColGroupSDCZeros.create(Util.genColsIndices(max), nRows, d, indexes, data, counts);
        }
        final double[] retDef = new double[max];
        retDef[(int) def - 1] = 1.0;
        return create(Util.genColsIndices(max), nRows, d, retDef, indexes, data, counts);
    }

    /** Estimates cost from the number of rows actually scanned, which is the number of stored offsets. */
    @Override
    public double getCost(ComputationCostEstimator e, int nRows) {
        return e.getCost(nRows, _data.size(), getNumCols(), getNumValues(), _dict.getSparsity());
    }

    @Override
    protected AColGroup sliceMultiColumns(int idStart, int idEnd, int[] outputCols) {
        final ColGroupSDC ret = (ColGroupSDC) super.sliceMultiColumns(idStart, idEnd, outputCols);
        ret._defaultTuple = Arrays.copyOfRange(_defaultTuple, idStart, idEnd);
        return ret;
    }

    @Override
    protected AColGroup sliceSingleColumn(int idx) {
        final ColGroupSDC ret = (ColGroupSDC) super.sliceSingleColumn(idx);
        ret._defaultTuple = new double[] {_defaultTuple[idx]};
        return ret;
    }

    /** Checks the dictionary and the default tuple; a NaN pattern matches NaN default values. */
    @Override
    public boolean containsValue(double pattern) {
        if (pattern == 0.0 && _zeros) {
            return true;
        }
        if (_dict.containsValue(pattern)) {
            return true;
        }
        for (double v : _defaultTuple) {
            if (matchesPattern(v, pattern)) {
                return true;
            }
        }
        return false;
    }

    @Override
    public String toString() {
        final StringBuilder sb = new StringBuilder();
        sb.append(super.toString());
        sb.append(String.format("\n%15s", "Default: "));
        sb.append(Arrays.toString(_defaultTuple));
        sb.append(String.format("\n%15s", "Indexes: "));
        sb.append(_indexes.toString());
        sb.append(String.format("\n%15s", "Data: "));
        sb.append(_data.toString());
        return sb.toString();
    }
}

