/*
 * Decompiled with CFR 0.152.
 */
package org.apache.celeborn.common.network.sasl.registration;

import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import com.google.protobuf.ByteString;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.TimeoutException;
import org.apache.celeborn.common.exception.CelebornException;
import org.apache.celeborn.common.network.client.TransportClient;
import org.apache.celeborn.common.network.client.TransportClientBootstrap;
import org.apache.celeborn.common.network.protocol.TransportMessage;
import org.apache.celeborn.common.network.sasl.CelebornSaslClient;
import org.apache.celeborn.common.network.sasl.SaslClientBootstrap;
import org.apache.celeborn.common.network.sasl.SaslCredentials;
import org.apache.celeborn.common.network.sasl.SaslTimeoutException;
import org.apache.celeborn.common.network.sasl.registration.RegistrationInfo;
import org.apache.celeborn.common.network.util.TransportConf;
import org.apache.celeborn.common.protocol.MessageType;
import org.apache.celeborn.common.protocol.PbAuthType;
import org.apache.celeborn.common.protocol.PbAuthenticationInitiationRequest;
import org.apache.celeborn.common.protocol.PbAuthenticationInitiationResponse;
import org.apache.celeborn.common.protocol.PbRegisterApplicationRequest;
import org.apache.celeborn.common.protocol.PbRegisterApplicationResponse;
import org.apache.celeborn.common.protocol.PbSaslMechanism;
import org.apache.celeborn.common.protocol.PbSaslRequest;
import org.apache.celeborn.common.util.JavaUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class RegistrationClientBootstrap
implements TransportClientBootstrap {
    private static final Logger LOG = LoggerFactory.getLogger(RegistrationClientBootstrap.class);
    private static final String VERSION = "1.0";
    private static final List<PbSaslMechanism> SASL_MECHANISMS = Lists.newArrayList((Object[])new PbSaslMechanism[]{PbSaslMechanism.newBuilder().setMechanism("ANONYMOUS").addAuthTypes(PbAuthType.CLIENT_AUTH).build(), PbSaslMechanism.newBuilder().setMechanism("DIGEST-MD5").addAuthTypes(PbAuthType.CONNECTION_AUTH).build()});
    private final TransportConf conf;
    private final String appId;
    private final SaslCredentials saslCredentials;
    private final RegistrationInfo registrationInfo;

    public RegistrationClientBootstrap(TransportConf conf, String appId, SaslCredentials saslCredentials, RegistrationInfo registrationInfo) {
        this.conf = (TransportConf)Preconditions.checkNotNull((Object)conf, (Object)"conf");
        this.appId = (String)Preconditions.checkNotNull((Object)appId, (Object)"appId");
        this.saslCredentials = (SaslCredentials)Preconditions.checkNotNull((Object)saslCredentials, (Object)"saslCredentials");
        this.registrationInfo = (RegistrationInfo)Preconditions.checkNotNull((Object)registrationInfo, (Object)"registrationInfo");
    }

    @Override
    public void doBootstrap(TransportClient client) throws RuntimeException {
        if (this.registrationInfo.getRegistrationState() == RegistrationInfo.RegistrationState.REGISTERED) {
            LOG.info("client has already registered, skip register.");
            this.doSaslBootstrap(client);
            return;
        }
        try {
            LOG.info("authentication initiation started for {}", (Object)this.appId);
            this.doAuthInitiation(client);
            LOG.info("authentication initiation successful for {}", (Object)this.appId);
            this.doClientAuthentication(client);
            LOG.info("client authenticated for {}", (Object)this.appId);
            this.register(client);
            LOG.info("Registration for {}", (Object)this.appId);
            this.registrationInfo.setRegistrationState(RegistrationInfo.RegistrationState.REGISTERED);
            client.setClientId(this.appId);
        }
        catch (IOException | CelebornException e) {
            throw new RuntimeException(e);
        }
        finally {
            if (this.registrationInfo.getRegistrationState() != RegistrationInfo.RegistrationState.REGISTERED) {
                this.registrationInfo.setRegistrationState(RegistrationInfo.RegistrationState.FAILED);
            }
        }
    }

    private void doAuthInitiation(TransportClient client) throws IOException, CelebornException {
        ByteBuffer authInitResponseBuffer;
        PbAuthenticationInitiationRequest authInitRequest = PbAuthenticationInitiationRequest.newBuilder().setVersion(VERSION).setAuthEnabled(true).addAllSaslMechanisms(SASL_MECHANISMS).build();
        TransportMessage msg = new TransportMessage(MessageType.AUTHENTICATION_INITIATION_REQUEST, authInitRequest.toByteArray());
        try {
            authInitResponseBuffer = client.sendRpcSync(msg.toByteBuffer(), this.conf.saslTimeoutMs());
        }
        catch (RuntimeException ex) {
            if (ex.getCause() instanceof TimeoutException) {
                throw new SaslTimeoutException(ex.getCause());
            }
            throw ex;
        }
        PbAuthenticationInitiationResponse authInitResponse = (PbAuthenticationInitiationResponse)TransportMessage.fromByteBuffer(authInitResponseBuffer).getParsedPayload();
        if (!this.validateServerResponse(authInitResponse)) {
            String exMsg = "Registration failed due to incompatibility with the server. InitRequest: " + authInitRequest + " InitResponse: " + authInitResponse;
            throw new CelebornException(exMsg);
        }
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    private void doClientAuthentication(TransportClient client) throws IOException {
        CelebornSaslClient saslClient = new CelebornSaslClient("ANONYMOUS", null, null);
        try {
            byte[] payload = saslClient.firstToken();
            while (!saslClient.isComplete()) {
                ByteBuffer response;
                TransportMessage msg = new TransportMessage(MessageType.SASL_REQUEST, PbSaslRequest.newBuilder().setMethod("ANONYMOUS").setAuthType(PbAuthType.CLIENT_AUTH).setPayload(ByteString.copyFrom((byte[])payload)).build().toByteArray());
                try {
                    LOG.info("Sending SASL message for client authentication");
                    response = client.sendRpcSync(msg.toByteBuffer(), this.conf.saslTimeoutMs());
                }
                catch (RuntimeException ex) {
                    if (ex.getCause() instanceof TimeoutException) {
                        throw new SaslTimeoutException(ex.getCause());
                    }
                    throw ex;
                }
                payload = saslClient.response(JavaUtils.bufferToArray(response));
            }
        }
        finally {
            try {
                saslClient.dispose();
            }
            catch (RuntimeException e) {
                LOG.warn("Error while disposing SASL client", (Throwable)e);
            }
        }
    }

    private void register(TransportClient client) throws IOException, CelebornException {
        ByteBuffer response;
        TransportMessage msg = new TransportMessage(MessageType.REGISTER_APPLICATION_REQUEST, PbRegisterApplicationRequest.newBuilder().setId(this.appId).setSecret(this.saslCredentials.getPassword()).build().toByteArray());
        try {
            response = client.sendRpcSync(msg.toByteBuffer(), this.conf.saslTimeoutMs());
        }
        catch (RuntimeException ex) {
            if (ex.getCause() instanceof TimeoutException) {
                throw new SaslTimeoutException(ex.getCause());
            }
            throw ex;
        }
        PbRegisterApplicationResponse registerApplicationResponse = (PbRegisterApplicationResponse)TransportMessage.fromByteBuffer(response).getParsedPayload();
        if (!registerApplicationResponse.getStatus()) {
            throw new CelebornException("Application registration failed. AppId = " + this.appId);
        }
    }

    private void doSaslBootstrap(TransportClient client) {
        SaslClientBootstrap bootstrap = new SaslClientBootstrap(this.conf, this.appId, this.saslCredentials);
        bootstrap.doBootstrap(client);
    }

    private boolean validateServerResponse(PbAuthenticationInitiationResponse authInitResponse) {
        if (!authInitResponse.getVersion().equals(VERSION)) {
            return false;
        }
        Map<PbAuthType, Set<String>> serverSupportedMechs = RegistrationClientBootstrap.findSupportedSaslMechs(authInitResponse.getSaslMechanismsList());
        Set<String> clientAuthMechs = serverSupportedMechs.get((Object)PbAuthType.CLIENT_AUTH);
        if (clientAuthMechs == null) {
            return false;
        }
        if (!clientAuthMechs.contains("ANONYMOUS")) {
            return false;
        }
        Set<String> connectionAuthMechs = serverSupportedMechs.get((Object)PbAuthType.CONNECTION_AUTH);
        if (connectionAuthMechs == null) {
            return false;
        }
        return connectionAuthMechs.contains("DIGEST-MD5");
    }

    private static Map<PbAuthType, Set<String>> findSupportedSaslMechs(List<PbSaslMechanism> serverSupportedMechs) {
        HashMap<PbAuthType, Set<String>> supportedMechs = new HashMap<PbAuthType, Set<String>>();
        for (PbSaslMechanism mech : serverSupportedMechs) {
            for (PbAuthType authType : mech.getAuthTypesList()) {
                Set mechanisms = supportedMechs.computeIfAbsent(authType, k -> Sets.newHashSet());
                mechanisms.add(mech.getMechanism());
            }
        }
        return supportedMechs;
    }
}

