/*
 * Decompiled with CFR 0.152.
 */
package biz.k11i.xgboost.gbm;

import biz.k11i.xgboost.gbm.GBBase;
import biz.k11i.xgboost.tree.RegTree;
import biz.k11i.xgboost.util.FVec;
import biz.k11i.xgboost.util.ModelReader;
import java.io.IOException;
import java.io.Serializable;

/**
 * Gradient-boosted tree booster.
 *
 * <p>Holds every {@link RegTree} read from a model stream together with a
 * per-output-group view of those trees, and produces predictions by summing
 * the leaf values of the trees that belong to a group.
 */
public class GBTree
extends GBBase {
    /** Header parameters read at the start of {@link #loadModel}. */
    ModelParam mparam;
    /** All trees, in the order they were read from the model. */
    private RegTree[] trees;
    /** For each tree in {@link #trees}, the output group it is assigned to. */
    private int[] tree_info;
    /** {@code _groupTrees[g]} holds the trees assigned to group {@code g}, in model order. */
    RegTree[][] _groupTrees;

    GBTree() {
    }

    /**
     * Reads the booster from {@code reader}: header, trees, tree-to-group
     * assignments, and (optionally) the prediction buffer, which is skipped.
     *
     * @param reader       source of the serialized model
     * @param with_pbuffer whether the stream contains a prediction buffer that
     *                     has to be skipped when {@code num_pbuffer} is non-zero
     * @throws IOException if the reader fails
     */
    @Override
    public void loadModel(ModelReader reader, boolean with_pbuffer) throws IOException {
        this.mparam = new ModelParam(reader);

        this.trees = new RegTree[this.mparam.num_trees];
        for (int i = 0; i < this.mparam.num_trees; ++i) {
            this.trees[i] = new RegTree();
            this.trees[i].loadModel(reader);
        }

        // Always assign tree_info. Leaving it null for an empty model made the
        // grouping below throw a NullPointerException, and a reload would have
        // reused the previous model's assignments.
        this.tree_info = this.mparam.num_trees != 0
                ? reader.readIntArray(this.mparam.num_trees)
                : new int[0];

        if (this.mparam.num_pbuffer != 0L && with_pbuffer) {
            // Two regions of 4 * predBufferSize() bytes each are skipped.
            // NOTE(review): presumably the prediction buffer and its counter,
            // both with 4-byte entries -- confirm against the model format.
            long regionBytes = 4L * this.mparam.predBufferSize();
            reader.skip(regionBytes);
            reader.skip(regionBytes);
        }

        this._groupTrees = new RegTree[this.mparam.num_output_group][];
        for (int group = 0; group < this.mparam.num_output_group; ++group) {
            this._groupTrees[group] = GBTree.selectGroup(this.trees, this.tree_info, group);
        }
    }

    /**
     * Collects, in order, the trees whose entry in {@code treeInfo} equals {@code group}.
     * Trees assigned to any other value are ignored.
     */
    private static RegTree[] selectGroup(RegTree[] allTrees, int[] treeInfo, int group) {
        int count = 0;
        for (int info : treeInfo) {
            if (info == group) {
                ++count;
            }
        }
        RegTree[] selected = new RegTree[count];
        int next = 0;
        for (int j = 0; j < treeInfo.length; ++j) {
            if (treeInfo[j] == group) {
                selected[next++] = allTrees[j];
            }
        }
        return selected;
    }

    /**
     * Computes one raw score per output group.
     *
     * @param feat        feature vector to score
     * @param ntree_limit maximum number of trees to use per group; {@code 0}
     *                    means all trees
     * @return array of length {@code num_output_group} with the summed leaf
     *         values of each group
     */
    @Override
    public double[] predict(FVec feat, int ntree_limit) {
        int groups = this.mparam.num_output_group;
        double[] preds = new double[groups];
        for (int gid = 0; gid < groups; ++gid) {
            preds[gid] = this.pred(feat, gid, 0, ntree_limit);
        }
        return preds;
    }

    /**
     * Computes the raw score of a single-output model.
     *
     * @param feat        feature vector to score
     * @param ntree_limit maximum number of trees to use; {@code 0} means all trees
     * @return summed leaf values of group 0
     * @throws IllegalStateException if the model has more than one output group
     */
    @Override
    public double predictSingle(FVec feat, int ntree_limit) {
        if (this.mparam.num_output_group != 1) {
            throw new IllegalStateException("Can't invoke predictSingle() because this model outputs multiple values: " + this.mparam.num_output_group);
        }
        return this.pred(feat, 0, 0, ntree_limit);
    }

    /**
     * Sums the leaf values of the first trees of one output group.
     *
     * @param feat        feature vector to score
     * @param bst_group   index of the output group
     * @param root_index  root index forwarded to {@link RegTree#getLeafValue}
     * @param ntree_limit maximum number of trees to use; {@code 0} means all
     *                    trees, and a value above the group's tree count is
     *                    treated as "all trees"
     * @return sum of the selected trees' leaf values
     */
    double pred(FVec feat, int bst_group, int root_index, int ntree_limit) {
        RegTree[] groupTrees = this._groupTrees[bst_group];
        // Clamp so an over-large limit uses every tree instead of running off
        // the end of the array.
        int treeleft = ntree_limit == 0 ? groupTrees.length : Math.min(ntree_limit, groupTrees.length);
        double psum = 0.0;
        for (int i = 0; i < treeleft; ++i) {
            psum += groupTrees[i].getLeafValue(feat, root_index);
        }
        return psum;
    }

    /**
     * Returns, for each tree used, the result of {@link RegTree#getLeafIndex}
     * for {@code feat}.
     *
     * @param feat        feature vector to route through the trees
     * @param ntree_limit maximum number of trees to use; {@code 0} means all trees
     * @return one entry per tree used, in model order
     */
    @Override
    public int[] predictLeaf(FVec feat, int ntree_limit) {
        return this.predPath(feat, 0, ntree_limit);
    }

    /**
     * Evaluates {@link RegTree#getLeafIndex} on the first trees of the model
     * (across all output groups).
     *
     * @param feat        feature vector to route through the trees
     * @param root_index  root index forwarded to {@link RegTree#getLeafIndex}
     * @param ntree_limit maximum number of trees to use; {@code 0} means all
     *                    trees, and a value above the total tree count is
     *                    treated as "all trees"
     * @return one entry per tree used, in model order
     */
    int[] predPath(FVec feat, int root_index, int ntree_limit) {
        // Clamp so an over-large limit cannot index past the last tree.
        int treeleft = ntree_limit == 0 ? this.trees.length : Math.min(ntree_limit, this.trees.length);
        int[] leafIndex = new int[treeleft];
        for (int i = 0; i < treeleft; ++i) {
            leafIndex[i] = this.trees[i].getLeafIndex(feat, root_index);
        }
        return leafIndex;
    }

    /**
     * Header of the serialized booster. Fields are populated from the reader
     * in the exact order used by the constructor.
     */
    static class ModelParam
    implements Serializable {
        /** Number of trees stored in the model. */
        final int num_trees;
        final int num_roots;
        final int num_feature;
        /** Used to size the prediction buffer; see {@link #predBufferSize()}. */
        final long num_pbuffer;
        /** Number of output groups, i.e. the length of a prediction. */
        final int num_output_group;
        final int size_leaf_vector;
        /** 31 ints read from the stream and kept as-is. */
        final int[] reserved;

        /**
         * Reads the header from {@code reader}.
         *
         * @param reader source of the serialized model
         * @throws IOException if the reader fails
         */
        ModelParam(ModelReader reader) throws IOException {
            this.num_trees = reader.readInt();
            this.num_roots = reader.readInt();
            this.num_feature = reader.readInt();
            // Value is read and discarded.
            // NOTE(review): presumably alignment padding before the long -- confirm.
            reader.readInt();
            this.num_pbuffer = reader.readLong();
            this.num_output_group = reader.readInt();
            this.size_leaf_vector = reader.readInt();
            this.reserved = reader.readIntArray(31);
            // Value is read and discarded.
            // NOTE(review): presumably trailing padding -- confirm.
            reader.readInt();
        }

        /**
         * Returns {@code num_output_group * num_pbuffer * (size_leaf_vector + 1)},
         * computed in {@code long} arithmetic.
         */
        long predBufferSize() {
            return (long)this.num_output_group * this.num_pbuffer * (long)(this.size_leaf_vector + 1);
        }
    }
}

