/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.knn.index.engine.faiss;

import java.util.HashMap;
import java.util.Map;
import java.util.Set;
import org.opensearch.common.ValidationException;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.engine.AbstractMethodResolver;
import org.opensearch.knn.index.engine.Encoder;
import org.opensearch.knn.index.engine.KNNEngine;
import org.opensearch.knn.index.engine.KNNMethodConfigContext;
import org.opensearch.knn.index.engine.KNNMethodContext;
import org.opensearch.knn.index.engine.MethodComponent;
import org.opensearch.knn.index.engine.MethodComponentContext;
import org.opensearch.knn.index.engine.ResolvedMethodContext;
import org.opensearch.knn.index.engine.TrainingConfigValidationInput;
import org.opensearch.knn.index.engine.TrainingConfigValidationOutput;
import org.opensearch.knn.index.engine.faiss.FaissHNSWMethod;
import org.opensearch.knn.index.engine.faiss.FaissIVFMethod;
import org.opensearch.knn.index.mapper.CompressionLevel;
import org.opensearch.knn.index.mapper.Mode;

/**
 * Method resolver for the Faiss engine.
 *
 * <p>Picks the Faiss method component (HNSW or IVF) and, when an encoder should be resolved,
 * installs a default encoder derived from the compression level / mode. It then validates the
 * encoder configuration and records the resolved compression level on the config context.
 */
public class FaissMethodResolver extends AbstractMethodResolver {
    /** Compression levels handed to {@code validateCompressionSupported} as the set Faiss accepts. */
    private static final Set<CompressionLevel> SUPPORTED_COMPRESSION_LEVELS = Set.of(CompressionLevel.x1, CompressionLevel.x2, CompressionLevel.x8, CompressionLevel.x16, CompressionLevel.x32);

    /**
     * Resolves the Faiss method context for a field.
     *
     * @param knnMethodContext user-supplied method context, passed to {@code initResolvedKNNMethodContext}
     * @param knnMethodConfigContext field configuration; its compression level is overwritten with the resolved level
     * @param shouldRequireTraining when {@code true}, {@code "ivf"} is passed to {@code initResolvedKNNMethodContext},
     *     otherwise {@code "hnsw"}
     * @param spaceType space type used to initialise the resolved method context
     * @return the resolved method context together with the resolved compression level
     * @throws ValidationException if the configured compression level is not in the Faiss-supported set, if the
     *     encoder reports an invalid configuration, or if a default encoder cannot be found for the method; the
     *     inherited helpers may raise further errors
     */
    @Override
    public ResolvedMethodContext resolveMethod(KNNMethodContext knnMethodContext, KNNMethodConfigContext knnMethodConfigContext, boolean shouldRequireTraining, SpaceType spaceType) {
        this.validateConfig(knnMethodConfigContext);
        // NOTE(review): the last argument is presumably the method name used when the caller gave none — confirm.
        KNNMethodContext resolvedKNNMethodContext = this.initResolvedKNNMethodContext(knnMethodContext, KNNEngine.FAISS, spaceType, shouldRequireTraining ? "ivf" : "hnsw");

        // Any method name other than "hnsw" is treated as IVF; the encoder map follows the same choice.
        boolean isHnsw = "hnsw".equals(resolvedKNNMethodContext.getMethodComponentContext().getName());
        MethodComponent method = isHnsw ? FaissHNSWMethod.HNSW_COMPONENT : FaissIVFMethod.IVF_COMPONENT;
        Map<String, Encoder> encoderMap = isHnsw ? FaissHNSWMethod.SUPPORTED_ENCODERS : FaissIVFMethod.SUPPORTED_ENCODERS;

        this.resolveEncoder(resolvedKNNMethodContext, knnMethodConfigContext, encoderMap);
        CompressionLevel resolvedCompressionLevel = this.resolveCompressionLevelFromMethodContext(resolvedKNNMethodContext, knnMethodConfigContext, encoderMap);
        this.validateEncoderConfig(resolvedKNNMethodContext, knnMethodConfigContext, encoderMap);
        // Must run before setCompressionLevel: it compares the user-configured level with the resolved one.
        this.validateCompressionConflicts(knnMethodConfigContext.getCompressionLevel(), resolvedCompressionLevel);
        knnMethodConfigContext.setCompressionLevel(resolvedCompressionLevel);
        this.resolveMethodParams(resolvedKNNMethodContext.getMethodComponentContext(), knnMethodConfigContext, method);
        return ResolvedMethodContext.builder().knnMethodContext(resolvedKNNMethodContext).compressionLevel(resolvedCompressionLevel).build();
    }

    /**
     * Installs a default {@code "encoder"} parameter on the resolved method context when
     * {@code shouldEncoderBeResolved} allows it and the effective compression level is not x1.
     *
     * @param resolvedKNNMethodContext context whose method component parameters receive the encoder
     * @param knnMethodConfigContext field configuration used to derive the compression level and parameter defaults
     * @param encoderMap encoders supported by the selected Faiss method, keyed by encoder name
     * @throws ValidationException if {@code encoderMap} has no entry for the default encoder name
     */
    private void resolveEncoder(KNNMethodContext resolvedKNNMethodContext, KNNMethodConfigContext knnMethodConfigContext, Map<String, Encoder> encoderMap) {
        if (!this.shouldEncoderBeResolved(resolvedKNNMethodContext, knnMethodConfigContext)) {
            return;
        }
        CompressionLevel resolvedCompressionLevel = this.getDefaultCompressionLevel(knnMethodConfigContext);
        if (resolvedCompressionLevel == CompressionLevel.x1) {
            // No compression requested: leave the encoder unspecified.
            return;
        }
        MethodComponentContext encoderComponentContext = this.createDefaultEncoderContext(resolvedCompressionLevel);
        String encoderName = encoderComponentContext.getName();
        Encoder encoder = encoderMap.get(encoderName);
        if (encoder == null) {
            // Previously an NPE; surface it the same way as the other validation failures in this class.
            ValidationException validationException = new ValidationException();
            validationException.addValidationError("Unable to resolve default encoder \"" + encoderName + "\" for method \"" + resolvedKNNMethodContext.getMethodComponentContext().getName() + "\" and compression level " + resolvedCompressionLevel);
            throw validationException;
        }
        Map<String, Object> resolvedParams = MethodComponent.getParameterMapWithDefaultsAdded(encoderComponentContext, encoder.getMethodComponent(), knnMethodConfigContext);
        encoderComponentContext.getParameters().putAll(resolvedParams);
        resolvedKNNMethodContext.getMethodComponentContext().getParameters().put("encoder", encoderComponentContext);
    }

    /**
     * Builds the encoder component context implied by a compression level, before defaults are added.
     *
     * <ul>
     *   <li>x2 maps to {@code "sq"} with {@code type = "fp16"}.</li>
     *   <li>x8, x16 and x32 map to {@code "binary"} with {@code bits = level.numBitsForFloat32()}.</li>
     *   <li>Any other level maps to {@code "flat"} with no parameters.</li>
     * </ul>
     *
     * @param compressionLevel the effective compression level
     * @return a new context whose parameter map is a fresh, mutable {@link HashMap}
     */
    private MethodComponentContext createDefaultEncoderContext(CompressionLevel compressionLevel) {
        if (CompressionLevel.x2 == compressionLevel) {
            MethodComponentContext sqContext = new MethodComponentContext("sq", new HashMap<String, Object>());
            sqContext.getParameters().put("type", "fp16");
            return sqContext;
        }
        if (CompressionLevel.x8 == compressionLevel || CompressionLevel.x16 == compressionLevel || CompressionLevel.x32 == compressionLevel) {
            MethodComponentContext binaryContext = new MethodComponentContext("binary", new HashMap<String, Object>());
            binaryContext.getParameters().put("bits", compressionLevel.numBitsForFloat32());
            return binaryContext;
        }
        // Not reached for the levels validateConfig admits once x1 has been filtered out; kept as a safe fallback.
        return new MethodComponentContext("flat", new HashMap<String, Object>());
    }

    /**
     * Rejects a configured compression level that Faiss does not support.
     *
     * @param knnMethodConfigContext field configuration whose compression level is checked
     * @throws ValidationException if {@code validateCompressionSupported} reports a problem
     */
    private void validateConfig(KNNMethodConfigContext knnMethodConfigContext) {
        CompressionLevel compressionLevel = knnMethodConfigContext.getCompressionLevel();
        ValidationException validationException = this.validateCompressionSupported(compressionLevel, SUPPORTED_COMPRESSION_LEVELS, KNNEngine.FAISS, null);
        if (validationException != null) {
            throw validationException;
        }
    }

    /**
     * Asks the resolved encoder to validate its own configuration.
     *
     * <p>Does nothing when no encoder is specified, or when the specified encoder name is absent from
     * {@code encoderMap}.
     *
     * @param resolvedKnnMethodContext resolved method context carrying the encoder
     * @param knnMethodConfigContext field configuration forwarded to the encoder's validation
     * @param encoderMap encoders supported by the selected method, keyed by name
     * @throws ValidationException carrying the encoder's error message when it reports the config as invalid
     */
    protected void validateEncoderConfig(KNNMethodContext resolvedKnnMethodContext, KNNMethodConfigContext knnMethodConfigContext, Map<String, Encoder> encoderMap) {
        if (!this.isEncoderSpecified(resolvedKnnMethodContext)) {
            return;
        }
        Encoder encoder = encoderMap.get(this.getEncoderName(resolvedKnnMethodContext));
        if (encoder == null) {
            return;
        }
        TrainingConfigValidationInput validationInput = TrainingConfigValidationInput.builder()
            .knnMethodContext(resolvedKnnMethodContext)
            .knnMethodConfigContext(knnMethodConfigContext)
            .build();
        TrainingConfigValidationOutput validationOutput = encoder.validateEncoderConfig(validationInput);
        // A null validity flag means "no verdict", which is treated as valid.
        if (Boolean.FALSE.equals(validationOutput.getValid())) {
            ValidationException validationException = new ValidationException();
            validationException.addValidationError(validationOutput.getErrorMessage());
            throw validationException;
        }
    }

    /**
     * Returns the compression level that drives default-encoder selection.
     *
     * @param knnMethodConfigContext field configuration
     * @return the configured level if {@code CompressionLevel.isConfigured} says one is set; otherwise x32 in
     *     {@link Mode#ON_DISK} mode; otherwise x1
     */
    private CompressionLevel getDefaultCompressionLevel(KNNMethodConfigContext knnMethodConfigContext) {
        if (CompressionLevel.isConfigured(knnMethodConfigContext.getCompressionLevel())) {
            return knnMethodConfigContext.getCompressionLevel();
        }
        if (knnMethodConfigContext.getMode() == Mode.ON_DISK) {
            return CompressionLevel.x32;
        }
        return CompressionLevel.x1;
    }
}

