/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.neuralsearch.processor.combination;

import com.google.common.math.DoubleMath;
import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.Generated;
import org.apache.commons.lang3.Range;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

/**
 * Helper for score combination techniques: reads the optional {@code weights} parameter,
 * validates user-supplied parameters and resolves the weight of an individual sub-query.
 */
class ScoreCombinationUtil {
    private static final String PARAM_NAME_WEIGHTS = "weights";
    /** Tolerance used when checking that the weights add up to 1.0. */
    private static final float DELTA_FOR_SCORE_ASSERTION = 0.01f;
    private static final float MIN_WEIGHT = 0.0f;
    private static final float MAX_WEIGHT = 1.0f;

    ScoreCombinationUtil() {
    }

    /**
     * Extracts the list of weights from the technique parameters.
     *
     * @param params user-supplied parameters of the combination technique, may be {@code null}
     * @return unmodifiable list of weights; empty when {@code params} is {@code null}, empty,
     *         or has no {@code weights} entry
     * @throws IllegalArgumentException if {@code weights} is not a list of numbers, if any weight
     *         is outside [0.0, 1.0], or if the weights do not add up to 1.0 (within the tolerance)
     */
    public List<Float> getWeights(Map<String, Object> params) {
        // No weights configured at all: nothing to convert or validate.
        if (Objects.isNull(params) || params.isEmpty() || !params.containsKey(PARAM_NAME_WEIGHTS)) {
            return List.of();
        }
        Object rawWeights = params.get(PARAM_NAME_WEIGHTS);
        if (!(rawWeights instanceof List)) {
            throw new IllegalArgumentException(weightsTypeErrorMessage());
        }
        List<Float> weightsList = ((List<?>) rawWeights).stream()
            .map(ScoreCombinationUtil::toFloatWeight)
            .collect(Collectors.toUnmodifiableList());
        this.validateWeights(weightsList);
        return weightsList;
    }

    /**
     * Checks that only supported parameters were supplied and that {@code weights}, when
     * present, is a list.
     *
     * @param actualParams    parameters supplied by the user, may be {@code null}
     * @param supportedParams names of the parameters the technique accepts
     * @throws IllegalArgumentException if an unsupported parameter is present or if the
     *         {@code weights} parameter is not a list
     */
    public void validateParams(Map<String, Object> actualParams, Set<String> supportedParams) {
        if (Objects.isNull(actualParams) || actualParams.isEmpty()) {
            return;
        }
        boolean hasUnsupportedParam = actualParams.keySet().stream().anyMatch(paramName -> !supportedParams.contains(paramName));
        if (hasUnsupportedParam) {
            throw new IllegalArgumentException(
                String.format(
                    Locale.ROOT,
                    "provided parameter for combination technique is not supported. supported parameters are [%s]",
                    String.join(",", supportedParams)
                )
            );
        }
        // The key is matched case-insensitively but the value is looked up under the exact
        // name, so a differently-cased "weights" key is rejected here as well.
        boolean hasWeightsParam = actualParams.keySet().stream().anyMatch(PARAM_NAME_WEIGHTS::equalsIgnoreCase);
        if (hasWeightsParam && !(actualParams.get(PARAM_NAME_WEIGHTS) instanceof List)) {
            throw new IllegalArgumentException(weightsTypeErrorMessage());
        }
    }

    /**
     * Returns the weight for the sub-query at the given position.
     *
     * @param weights         configured weights, possibly empty
     * @param indexOfSubQuery zero-based position of the sub-query
     * @return the configured weight, or 1.0 when no weight exists at that position
     */
    public float getWeightForSubQuery(List<Float> weights, int indexOfSubQuery) {
        return indexOfSubQuery < weights.size() ? weights.get(indexOfSubQuery).floatValue() : 1.0f;
    }

    /**
     * Verifies that, when weights are configured, there is exactly one weight per score.
     *
     * @param scores  scores of the sub-queries
     * @param weights configured weights; an empty list disables the check
     * @throws IllegalArgumentException if the number of weights differs from the number of scores
     */
    protected void validateIfWeightsMatchScores(float[] scores, List<Float> weights) {
        if (weights.isEmpty()) {
            return;
        }
        if (scores.length != weights.size()) {
            throw new IllegalArgumentException(
                String.format(
                    Locale.ROOT,
                    "number of weights [%d] must match number of sub-queries [%d] in hybrid query",
                    weights.size(),
                    scores.length
                )
            );
        }
    }

    /** Ensures every weight lies in [0.0, 1.0] and that the weights add up to 1.0 within the tolerance. */
    private void validateWeights(List<Float> weightsList) {
        // Written as a negated "in range" test so that NaN is reported as out of range.
        boolean isOutOfRange = weightsList.stream().anyMatch(weight -> !(weight >= MIN_WEIGHT && weight <= MAX_WEIGHT));
        if (isOutOfRange) {
            throw new IllegalArgumentException(
                String.format(
                    Locale.ROOT,
                    "all weights must be in range [0.0 ... 1.0], submitted weights: %s",
                    Arrays.toString(weightsList.toArray())
                )
            );
        }
        // Accumulate in float, in list order, so rounding matches a plain float summation.
        float sumOfWeights = 0.0f;
        for (float weight : weightsList) {
            sumOfWeights += weight;
        }
        if (Math.abs(1.0 - (double) sumOfWeights) > (double) DELTA_FOR_SCORE_ASSERTION) {
            throw new IllegalArgumentException(
                String.format(
                    Locale.ROOT,
                    "sum of weights for combination must be equal to 1.0, submitted weights: %s",
                    Arrays.toString(weightsList.toArray())
                )
            );
        }
    }

    /** Converts one element of the weights list to a float; any numeric type is accepted. */
    private static Float toFloatWeight(Object weight) {
        if (!(weight instanceof Number)) {
            throw new IllegalArgumentException(weightsTypeErrorMessage());
        }
        return ((Number) weight).floatValue();
    }

    /** Builds the error message used whenever {@code weights} is not a collection of numbers. */
    private static String weightsTypeErrorMessage() {
        return String.format(Locale.ROOT, "parameter [%s] must be a collection of numbers", PARAM_NAME_WEIGHTS);
    }
}

