/*
 * Decompiled with CFR 0.152.
 */
package org.apache.solr.ltr.model;

import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Stack;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.Explanation;
import org.apache.solr.ltr.feature.Feature;
import org.apache.solr.ltr.model.LTRScoringModel;
import org.apache.solr.ltr.model.ModelException;
import org.apache.solr.ltr.norm.Normalizer;
import org.apache.solr.util.SolrPluginUtils;

/**
 * A learning-to-rank model that scores documents with an additive ensemble of weighted
 * regression trees.
 *
 * <p>The model score is the sum, over all configured {@link RegressionTree}s, of the tree's
 * weight times the value of the leaf reached by walking that tree with the feature vector. At
 * each internal node the walk goes left when the vector's value for the node's feature is
 * {@code <=} the node's (slack-adjusted) threshold, and right otherwise.
 *
 * <p>Trees are supplied through {@link #setTrees(Object)} as nested maps. Each map's entries are
 * handed to {@code SolrPluginUtils.invokeSetters} on a fresh {@link RegressionTree} or
 * {@link RegressionTreeNode}. Scoring and explaining dereference fields (weight, root,
 * threshold, children) that only {@link #validate()} checks. They therefore assume a model
 * that has passed validation.
 */
public class MultipleAdditiveTreesModel extends LTRScoringModel {

    /** Position of each model feature (by name) in the feature vectors used for scoring. */
    private final Map<String, Integer> fname2index = new HashMap<>();

    /** The tree ensemble; {@code null} until {@link #setTrees(Object)} has been called. */
    private List<RegressionTree> trees;

    /**
     * Creates the model and indexes its features by name, so that tree nodes can resolve the
     * feature they split on.
     *
     * <p>All arguments are forwarded unchanged to the {@link LTRScoringModel} constructor.
     *
     * @param features the model's features; the i-th feature's value is looked up at index i of
     *     the feature vectors passed to {@link #score(float[])}
     */
    public MultipleAdditiveTreesModel(String name, List<Feature> features, List<Normalizer> norms, String featureStoreName, List<Feature> allFeatures, Map<String, Object> params) {
        super(name, features, norms, featureStoreName, allFeatures, params);
        for (int i = 0; i < features.size(); ++i) {
            fname2index.put(features.get(i).getName(), i);
        }
    }

    /**
     * Builds a {@link RegressionTree} and applies the given properties to it.
     *
     * @param map properties passed to {@code SolrPluginUtils.invokeSetters}; may be {@code null},
     *     in which case the tree is left unconfigured (and will fail validation)
     * @return the new tree, never {@code null}
     */
    private RegressionTree createRegressionTree(Map<String, Object> map) {
        final RegressionTree tree = new RegressionTree();
        if (map != null) {
            SolrPluginUtils.invokeSetters(tree, map.entrySet());
        }
        return tree;
    }

    /**
     * Builds a {@link RegressionTreeNode} and applies the given properties to it.
     *
     * @param map properties passed to {@code SolrPluginUtils.invokeSetters}; may be {@code null},
     *     in which case the node is an empty leaf with value 0
     * @return the new node, never {@code null}
     */
    private RegressionTreeNode createRegressionTreeNode(Map<String, Object> map) {
        final RegressionTreeNode node = new RegressionTreeNode();
        if (map != null) {
            SolrPluginUtils.invokeSetters(node, map.entrySet());
        }
        return node;
    }

    /**
     * Replaces the ensemble with trees built from the given configuration.
     *
     * <p>The new list is assembled locally and installed only once every element has been
     * converted. A failure therefore leaves the previously configured trees untouched instead of
     * a half-built list.
     *
     * @param trees a {@link List} whose elements are {@code Map<String, Object>} tree definitions
     *     (presumably keyed by {@link RegressionTree} property names such as {@code weight} and
     *     {@code root} — confirm against the model JSON); a {@code null} element yields an
     *     unconfigured tree that {@link #validate()} rejects
     * @throws ClassCastException if {@code trees} is not a list or an element is not a map
     * @throws NullPointerException if {@code trees} is {@code null}
     */
    @SuppressWarnings("unchecked")
    public void setTrees(Object trees) {
        final List<?> treeConfigs = (List<?>) trees;
        final List<RegressionTree> built = new ArrayList<>(treeConfigs.size());
        for (Object treeConfig : treeConfigs) {
            built.add(createRegressionTree((Map<String, Object>) treeConfig));
        }
        this.trees = built;
    }

    /**
     * Validates the base model and then every tree of the ensemble.
     *
     * @throws ModelException if no trees were declared, a tree lacks a weight or root, or any
     *     node is malformed (see {@link RegressionTree#validate()})
     */
    @Override
    protected void validate() throws ModelException {
        super.validate();
        if (trees == null) {
            throw new ModelException("no trees declared for model " + this.name);
        }
        for (RegressionTree tree : trees) {
            tree.validate();
        }
    }

    /**
     * Scores a document as the sum of every tree's weighted leaf value.
     *
     * @param modelFeatureValuesNormalized feature values, indexed in model-feature order
     * @return the ensemble score; 0 for an empty ensemble
     */
    @Override
    public float score(float[] modelFeatureValuesNormalized) {
        float score = 0.0f;
        for (RegressionTree tree : trees) {
            score += tree.score(modelFeatureValuesNormalized);
        }
        return score;
    }

    /**
     * Walks from {@code regressionTreeNode} down to a leaf and returns that leaf's value.
     *
     * @return the reached leaf's value, or 0 as soon as an internal node's feature index falls
     *     outside {@code featureVector}. This includes index -1 for a feature name the model
     *     does not know.
     */
    private static float scoreNode(float[] featureVector, RegressionTreeNode regressionTreeNode) {
        RegressionTreeNode node = regressionTreeNode;
        while (!node.isLeaf()) {
            final int idx = node.featureIndex;
            if (idx < 0 || idx >= featureVector.length) {
                return 0.0f;
            }
            node = featureVector[idx] <= node.threshold.floatValue() ? node.left : node.right;
        }
        return node.value;
    }

    /**
     * Checks the structure of the tree rooted at {@code regressionTreeNode} with an iterative
     * depth-first traversal, so that deep trees cannot overflow the call stack.
     *
     * <p>Leaves must have no children. Internal nodes must have a threshold plus both a left and
     * a right child.
     *
     * @param regressionTreeNode the root to check; must not be {@code null}
     * @throws ModelException describing the first malformed node found
     */
    private static void validateNode(RegressionTreeNode regressionTreeNode) throws ModelException {
        final Deque<RegressionTreeNode> pending = new ArrayDeque<>();
        pending.push(regressionTreeNode);
        while (!pending.isEmpty()) {
            final RegressionTreeNode node = pending.pop();
            if (node.isLeaf()) {
                if (node.left != null || node.right != null) {
                    // Relies on RegressionTreeNode.toString() being null-safe, so that a
                    // malformed child cannot turn this ModelException into an NPE.
                    throw new ModelException("MultipleAdditiveTreesModel tree node is leaf with left=" + node.left + " and right=" + node.right);
                }
                continue;
            }
            if (null == node.threshold) {
                throw new ModelException("MultipleAdditiveTreesModel tree node is missing threshold");
            }
            if (null == node.left) {
                throw new ModelException("MultipleAdditiveTreesModel tree node is missing left");
            }
            pending.push(node.left);
            if (null == node.right) {
                throw new ModelException("MultipleAdditiveTreesModel tree node is missing right");
            }
            pending.push(node.right);
        }
    }

    /**
     * Describes the path a feature vector takes through the tree rooted at
     * {@code regressionTreeNode}, ending with the reached leaf value. The walk stops early, with a
     * "Return Zero" message, if a split feature is out of the vector's range.
     *
     * <p>Note: this prints the stored threshold, which includes the node split slack (unlike
     * {@link RegressionTreeNode#toString()}, which subtracts it).
     */
    private static String explainNode(float[] featureVector, RegressionTreeNode regressionTreeNode) {
        final StringBuilder path = new StringBuilder();
        RegressionTreeNode node = regressionTreeNode;
        while (!node.isLeaf()) {
            final int idx = node.featureIndex;
            if (idx < 0 || idx >= featureVector.length) {
                return path.append('\'').append(node.feature).append("' does not exist in FV, Return Zero").toString();
            }
            final float featureValue = featureVector[idx];
            path.append('\'').append(node.feature).append("':").append(featureValue);
            if (featureValue <= node.threshold.floatValue()) {
                path.append(" <= ").append(node.threshold).append(", Go Left | ");
                node = node.left;
            } else {
                path.append(" > ").append(node.threshold).append(", Go Right | ");
                node = node.right;
            }
        }
        return path.append("val: ").append(node.value).toString();
    }

    /**
     * Explains a document's score tree by tree.
     *
     * <p>The feature vector is rebuilt from the values of {@code featureExplanations}, in
     * iteration order. This assumes they are in model-feature order — TODO confirm with the
     * caller. {@code finalScore} is reported as given; it is not recomputed.
     */
    @Override
    public Explanation explain(LeafReaderContext context, int doc, float finalScore, List<Explanation> featureExplanations) {
        final float[] fv = new float[featureExplanations.size()];
        int index = 0;
        for (Explanation featureExplain : featureExplanations) {
            fv[index++] = featureExplain.getValue().floatValue();
        }
        final List<Explanation> details = new ArrayList<>(trees.size());
        index = 0;
        for (RegressionTree tree : trees) {
            details.add(Explanation.match(tree.score(fv), "tree " + index + " | " + tree.explain(fv)));
            ++index;
        }
        return Explanation.match(finalScore, toString() + " model applied to features, sum of:", details);
    }

    /**
     * Renders the model as {@code SimpleName(name=...,trees=[...])}.
     *
     * <p>It renders {@code trees=null} when no trees have been set, instead of throwing.
     */
    @Override
    public String toString() {
        final StringBuilder sb = new StringBuilder(getClass().getSimpleName());
        sb.append("(name=").append(getName());
        sb.append(",trees=");
        if (trees == null) {
            sb.append("null");
        } else {
            sb.append('[');
            for (int ii = 0; ii < trees.size(); ++ii) {
                if (ii > 0) {
                    sb.append(',');
                }
                sb.append(trees.get(ii));
            }
            sb.append(']');
        }
        sb.append(')');
        return sb.toString();
    }

    /** One weighted tree of the ensemble, configured through its setters. */
    public class RegressionTree {
        /** Multiplier applied to the reached leaf value; {@code null} until set. */
        private Float weight;
        /** Root node; {@code null} until {@link #setRoot(Object)} is called. */
        private RegressionTreeNode root;

        public void setWeight(float weight) {
            this.weight = Float.valueOf(weight);
        }

        /** @throws NumberFormatException if {@code weight} is not a parseable float */
        public void setWeight(String weight) {
            this.weight = Float.valueOf(weight);
        }

        /**
         * Builds the root node from its configuration.
         *
         * @param root a {@code Map<String, Object>} node definition, or {@code null} for an
         *     empty leaf with value 0
         */
        @SuppressWarnings("unchecked")
        public void setRoot(Object root) {
            this.root = createRegressionTreeNode((Map<String, Object>) root);
        }

        /** Returns {@code weight} times the leaf value reached by {@code featureVector}. */
        public float score(float[] featureVector) {
            return weight.floatValue() * scoreNode(featureVector, root);
        }

        /** Returns a human-readable description of the path {@code featureVector} takes. */
        public String explain(float[] featureVector) {
            return explainNode(featureVector, root);
        }

        @Override
        public String toString() {
            return new StringBuilder()
                    .append("(weight=").append(weight)
                    .append(",root=").append(root)
                    .append(')')
                    .toString();
        }

        /**
         * Checks that this tree has a weight and a root, and that every node is well formed.
         *
         * @throws ModelException describing the first problem found
         */
        public void validate() throws ModelException {
            if (weight == null) {
                throw new ModelException("MultipleAdditiveTreesModel tree doesn't contain a weight");
            }
            if (root == null) {
                throw new ModelException("MultipleAdditiveTreesModel tree doesn't contain a tree");
            }
            validateNode(root);
        }
    }

    /**
     * A node of a {@link RegressionTree}. A node with no feature set is a leaf whose
     * {@code value} is the tree's output. Otherwise it is an internal split on {@code feature} at
     * {@code threshold}, with {@code left} taken for values {@code <=} the threshold.
     */
    public class RegressionTreeNode {
        /**
         * Added to every configured threshold, so that feature values marginally above the
         * configured threshold still take the left branch. Presumably this absorbs float rounding
         * in exported models — confirm with the model producer. {@link #toString()} subtracts it
         * again.
         */
        private static final float NODE_SPLIT_SLACK = 1.0E-6f;

        private float value = 0.0f;
        private String feature;
        /** Index of {@code feature} in the feature vector, or -1 if the model lacks it. */
        private int featureIndex = -1;
        /** Split threshold including {@link #NODE_SPLIT_SLACK}; {@code null} for leaves. */
        private Float threshold;
        private RegressionTreeNode left;
        private RegressionTreeNode right;

        public void setValue(float value) {
            this.value = value;
        }

        /** @throws NumberFormatException if {@code value} is not a parseable float */
        public void setValue(String value) {
            this.value = Float.parseFloat(value);
        }

        /**
         * Sets the split feature and resolves its vector index. A name unknown to the model maps
         * to -1, which makes any walk through this node score 0.
         */
        public void setFeature(String feature) {
            this.feature = feature;
            final Integer idx = fname2index.get(feature);
            this.featureIndex = idx == null ? -1 : idx;
        }

        public void setThreshold(float threshold) {
            this.threshold = Float.valueOf(threshold + NODE_SPLIT_SLACK);
        }

        /** @throws NumberFormatException if {@code threshold} is not a parseable float */
        public void setThreshold(String threshold) {
            this.threshold = Float.valueOf(Float.parseFloat(threshold) + NODE_SPLIT_SLACK);
        }

        /** @param left a {@code Map<String, Object>} node definition, or {@code null} */
        @SuppressWarnings("unchecked")
        public void setLeft(Object left) {
            this.left = createRegressionTreeNode((Map<String, Object>) left);
        }

        /** @param right a {@code Map<String, Object>} node definition, or {@code null} */
        @SuppressWarnings("unchecked")
        public void setRight(Object right) {
            this.right = createRegressionTreeNode((Map<String, Object>) right);
        }

        /** A node is a leaf exactly when no split feature has been set. */
        public boolean isLeaf() {
            return feature == null;
        }

        /**
         * Renders a leaf as its value, and an internal node as
         * {@code (feature=...,threshold=...,left=...,right=...)} with the configured (un-slacked)
         * threshold. It is null-safe, so validation error messages can render malformed nodes.
         */
        @Override
        public String toString() {
            if (isLeaf()) {
                return String.valueOf(value);
            }
            final StringBuilder sb = new StringBuilder();
            sb.append("(feature=").append(feature);
            sb.append(",threshold=");
            if (threshold == null) {
                sb.append("null");
            } else {
                sb.append(threshold.floatValue() - NODE_SPLIT_SLACK);
            }
            sb.append(",left=").append(left);
            sb.append(",right=").append(right);
            sb.append(')');
            return sb.toString();
        }
    }
}

