/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.ml.inference.deployment;

import java.io.IOException;
import java.time.Duration;
import java.time.Instant;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Consumer;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.util.SetOnce;
import org.elasticsearch.ResourceNotFoundException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.ActionType;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.TransportSearchAction;
import org.elasticsearch.action.support.master.AcknowledgedResponse;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.util.concurrent.AbstractRunnable;
import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException;
import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.core.Strings;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.index.query.IdsQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.inference.InferenceResults;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.tasks.CancellableTask;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xcontent.NamedXContentRegistry;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.ClientHelper;
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelPrefixStrings;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.IndexLocation;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModelLocation;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.VocabularyConfig;
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.core.ml.utils.MlPlatformArchitecturesUtil;
import org.elasticsearch.xpack.ml.inference.deployment.AbstractPyTorchAction;
import org.elasticsearch.xpack.ml.inference.deployment.ClearCacheControlMessagePytorchAction;
import org.elasticsearch.xpack.ml.inference.deployment.InferencePyTorchAction;
import org.elasticsearch.xpack.ml.inference.deployment.ModelStats;
import org.elasticsearch.xpack.ml.inference.deployment.NlpInferenceInput;
import org.elasticsearch.xpack.ml.inference.deployment.ThreadSettingsControlMessagePytorchAction;
import org.elasticsearch.xpack.ml.inference.deployment.TrainedModelDeploymentTask;
import org.elasticsearch.xpack.ml.inference.nlp.NlpTask;
import org.elasticsearch.xpack.ml.inference.nlp.Vocabulary;
import org.elasticsearch.xpack.ml.inference.pytorch.PriorityProcessWorkerExecutorService;
import org.elasticsearch.xpack.ml.inference.pytorch.process.PyTorchProcess;
import org.elasticsearch.xpack.ml.inference.pytorch.process.PyTorchProcessFactory;
import org.elasticsearch.xpack.ml.inference.pytorch.process.PyTorchResultProcessor;
import org.elasticsearch.xpack.ml.inference.pytorch.process.PyTorchStateStreamer;
import org.elasticsearch.xpack.ml.inference.pytorch.results.ThreadSettings;
import org.elasticsearch.xpack.ml.notifications.InferenceAuditor;

/**
 * Manages the PyTorch inference processes that back trained model deployments on this node.
 * <p>
 * Every deployment task is tracked by a {@link ProcessContext} keyed on the task id. The manager starts a process
 * (resolving the model config and vocabulary first), routes inference and control requests to the process's
 * priority worker, exposes per-process stats, and stops processes either forcefully or after their pending work
 * has completed. A crashed process is restarted at most {@link #NUM_RESTART_ATTEMPTS} times; the starts count is
 * reset when the crashed process had been running for more than a day.
 */
public class DeploymentManager {
    private static final Logger logger = LogManager.getLogger(DeploymentManager.class);
    private static final AtomicLong requestIdCounter = new AtomicLong(1L);

    /** Thread pool for model start/load work; {@code ProcessContext#startAndLoad} asserts it runs on this pool. */
    private static final String UTILITY_THREAD_POOL_NAME = "ml_utility";
    /** Thread pool that runs result processing, the priority worker and state streaming for native processes. */
    private static final String NATIVE_INFERENCE_COMMS_THREAD_POOL_NAME = "ml_native_inference_comms";

    public static final int NUM_RESTART_ATTEMPTS = 3;

    private final Client client;
    private final NamedXContentRegistry xContentRegistry;
    private final PyTorchProcessFactory pyTorchProcessFactory;
    private final ExecutorService executorServiceForDeployment;
    private final ExecutorService executorServiceForProcess;
    private final ThreadPool threadPool;
    private final InferenceAuditor inferenceAuditor;
    private final ConcurrentMap<Long, ProcessContext> processContextByAllocation = new ConcurrentHashMap<>();
    private final int maxProcesses;

    /**
     * @param client                client used to fetch model configs and vocabulary documents
     * @param xContentRegistry      registry used when parsing vocabulary documents
     * @param threadPool            source of the utility and native-inference executors
     * @param pyTorchProcessFactory factory that creates the native inference processes
     * @param maxProcesses          maximum number of process contexts this node will hold at once
     * @param inferenceAuditor      auditor notified about process crashes and restarts
     */
    public DeploymentManager(
        Client client,
        NamedXContentRegistry xContentRegistry,
        ThreadPool threadPool,
        PyTorchProcessFactory pyTorchProcessFactory,
        int maxProcesses,
        InferenceAuditor inferenceAuditor
    ) {
        this.client = Objects.requireNonNull(client);
        this.xContentRegistry = Objects.requireNonNull(xContentRegistry);
        this.pyTorchProcessFactory = Objects.requireNonNull(pyTorchProcessFactory);
        this.threadPool = Objects.requireNonNull(threadPool);
        this.inferenceAuditor = Objects.requireNonNull(inferenceAuditor);
        this.executorServiceForDeployment = threadPool.executor(UTILITY_THREAD_POOL_NAME);
        this.executorServiceForProcess = threadPool.executor(NATIVE_INFERENCE_COMMS_THREAD_POOL_NAME);
        this.maxProcesses = maxProcesses;
    }

    /**
     * Returns statistics for the process serving {@code task}.
     *
     * @return the stats, or empty when this node holds no process context for the task's id
     */
    public Optional<ModelStats> getStats(TrainedModelDeploymentTask task) {
        return Optional.ofNullable(processContextByAllocation.get(task.getId())).map(processContext -> {
            PyTorchResultProcessor.ResultStats stats = processContext.getResultProcessor().getResultStats();
            PyTorchResultProcessor.RecentStats recentStats = stats.recentStats();
            return new ModelStats(
                processContext.startTime,
                stats.timingStats().getCount(),
                stats.timingStats().getAverage(),
                stats.timingStatsExcludingCacheHits().getAverage(),
                stats.lastUsed(),
                // pending = requests still queued in the worker + requests sent to the process awaiting results
                processContext.priorityProcessWorker.queueSize() + stats.numberOfPendingResults(),
                stats.errorCount(),
                stats.cacheHitCount(),
                processContext.rejectedExecutionCount.intValue(),
                processContext.timeoutCount.intValue(),
                processContext.numThreadsPerAllocation,
                processContext.numAllocations,
                stats.peakThroughput(),
                recentStats.requestsProcessed(),
                recentStats.avgInferenceTime(),
                recentStats.cacheHitCount()
            );
        });
    }

    /**
     * Registers {@code processContext} under {@code id} unless a context is already registered.
     *
     * @return the previously registered context, or {@code null} if registration succeeded
     */
    ProcessContext addProcessContext(Long id, ProcessContext processContext) {
        return processContextByAllocation.putIfAbsent(id, processContext);
    }

    /**
     * Starts the deployment for {@code task} as its first start (starts count 1).
     */
    public void startDeployment(TrainedModelDeploymentTask task, ActionListener<TrainedModelDeploymentTask> finalListener) {
        startDeployment(task, null, finalListener);
    }

    /**
     * Starts an inference process for {@code task}.
     * <p>
     * The sequence is as follows:
     * <ol>
     *   <li>Register a new {@link ProcessContext}.</li>
     *   <li>Fetch the model config.</li>
     *   <li>Verify node and model architectures.</li>
     *   <li>Require an {@link NlpConfig}.</li>
     *   <li>Look up the vocabulary document.</li>
     *   <li>Start the process and load the model on the utility pool.</li>
     *   <li>Start result processing.</li>
     * </ol>
     * If any step after registration fails, the context registered by this call is removed and forcefully stopped.
     *
     * @param startsCount   how many times this deployment has been started; {@code null} means 1
     * @param finalListener notified with {@code task} once the model is loaded, or with the failure
     */
    public void startDeployment(
        TrainedModelDeploymentTask task,
        Integer startsCount,
        ActionListener<TrainedModelDeploymentTask> finalListener
    ) {
        logger.info("[{}] Starting model deployment of model [{}]", task.getDeploymentId(), task.getModelId());

        if (processContextByAllocation.size() >= maxProcesses) {
            finalListener.onFailure(
                ExceptionsHelper.serverError(
                    "[{}] Could not start inference process as the node reached the max number [{}] of processes",
                    task.getDeploymentId(),
                    maxProcesses
                )
            );
            return;
        }

        final ProcessContext processContext = new ProcessContext(task, startsCount);
        if (addProcessContext(task.getId(), processContext) != null) {
            finalListener.onFailure(
                ExceptionsHelper.serverError("[{}] Could not create inference process as one already exists", task.getDeploymentId())
            );
            return;
        }

        final ActionListener<TrainedModelDeploymentTask> failedDeploymentListener = ActionListener.wrap(finalListener::onResponse, failure -> {
            // Only tear down the context registered by this call; a different context mapped under the same id
            // belongs to someone else and must not be killed here.
            if (processContextByAllocation.remove(task.getId(), processContext)) {
                processContext.forcefullyStopProcess();
            }
            finalListener.onFailure(failure);
        });

        final ActionListener<Boolean> modelLoadedListener = ActionListener.wrap(success -> {
            executorServiceForProcess.execute(() -> processContext.getResultProcessor().process(processContext.process.get()));
            finalListener.onResponse(task);
        }, failedDeploymentListener::onFailure);

        final ActionListener<TrainedModelConfig> getVerifiedModel = ActionListener.wrap(modelConfig -> {
            processContext.modelInput.set(modelConfig.getInput());
            processContext.prefixes.set(modelConfig.getPrefixStrings());

            if (modelConfig.getInferenceConfig() instanceof NlpConfig nlpConfig) {
                task.init(nlpConfig);

                SearchRequest searchRequest = vocabSearchRequest(nlpConfig.getVocabularyConfig(), modelConfig.getModelId());
                ClientHelper.executeAsyncWithOrigin(
                    client,
                    ClientHelper.ML_ORIGIN,
                    TransportSearchAction.TYPE,
                    searchRequest,
                    ActionListener.wrap(searchVocabResponse -> {
                        if (searchVocabResponse.getHits().getHits().length == 0) {
                            failedDeploymentListener.onFailure(
                                new ResourceNotFoundException(
                                    Messages.getMessage(
                                        "Could not find vocabulary document [{1}] for trained model [{0}]",
                                        modelConfig.getModelId(),
                                        VocabularyConfig.docId(modelConfig.getModelId())
                                    )
                                )
                            );
                            return;
                        }

                        Vocabulary vocabulary = parseVocabularyDocLeniently(searchVocabResponse.getHits().getAt(0));
                        NlpTask nlpTask = new NlpTask(nlpConfig, vocabulary);
                        NlpTask.Processor processor = nlpTask.createProcessor();
                        processContext.nlpTaskProcessor.set(processor);

                        // startAndLoad asserts it runs on the utility pool, so hop onto it from the search callback thread.
                        executorServiceForDeployment.execute(new AbstractRunnable() {
                            @Override
                            public void onFailure(Exception e) {
                                failedDeploymentListener.onFailure(e);
                            }

                            @Override
                            protected void doRun() {
                                processContext.startAndLoad(modelConfig.getLocation(), modelLoadedListener);
                            }
                        });
                    }, failedDeploymentListener::onFailure)
                );
            } else {
                failedDeploymentListener.onFailure(
                    new IllegalArgumentException(
                        Strings.format(
                            "[%s] must be a pytorch model; found inference config of kind [%s]",
                            modelConfig.getModelId(),
                            modelConfig.getInferenceConfig().getWriteableName()
                        )
                    )
                );
            }
        }, failedDeploymentListener::onFailure);

        ClientHelper.executeAsyncWithOrigin(
            client,
            ClientHelper.ML_ORIGIN,
            GetTrainedModelsAction.INSTANCE,
            new GetTrainedModelsAction.Request(task.getParams().getModelId()),
            ActionListener.wrap(getModelResponse -> {
                assert getModelResponse.getResources().results().size() == 1;
                TrainedModelConfig modelConfig = getModelResponse.getResources().results().get(0);
                verifyMlNodesAndModelArchitectures(modelConfig, client, threadPool, getVerifiedModel);
            }, failedDeploymentListener::onFailure)
        );
    }

    /**
     * Verifies ML node and model architectures, then hands {@code configToReturn} to {@code configToReturnListener}
     * on success, or forwards the failure.
     */
    void verifyMlNodesAndModelArchitectures(
        TrainedModelConfig configToReturn,
        Client client,
        ThreadPool threadPool,
        ActionListener<TrainedModelConfig> configToReturnListener
    ) {
        ActionListener<TrainedModelConfig> verifyConfigListener = new ActionListener<>() {
            @Override
            public void onResponse(TrainedModelConfig config) {
                assert Objects.equals(config, configToReturn);
                configToReturnListener.onResponse(configToReturn);
            }

            @Override
            public void onFailure(Exception e) {
                configToReturnListener.onFailure(e);
            }
        };

        callVerifyMlNodesAndModelArchitectures(configToReturn, verifyConfigListener, client, threadPool);
    }

    /**
     * Delegates to {@link MlPlatformArchitecturesUtil#verifyMlNodesAndModelArchitectures} on the utility executor.
     * NOTE(review): kept as a separate package-private method, presumably so tests can override it — confirm.
     */
    void callVerifyMlNodesAndModelArchitectures(
        TrainedModelConfig configToReturn,
        ActionListener<TrainedModelConfig> configToReturnListener,
        Client client,
        ThreadPool threadPool
    ) {
        MlPlatformArchitecturesUtil.verifyMlNodesAndModelArchitectures(
            configToReturnListener,
            client,
            threadPool.executor(UTILITY_THREAD_POOL_NAME),
            configToReturn
        );
    }

    /** Builds a single-hit, ids-only search for the vocabulary document of {@code modelId}. */
    private SearchRequest vocabSearchRequest(VocabularyConfig vocabularyConfig, String modelId) {
        return client.prepareSearch(vocabularyConfig.getIndex())
            .setQuery(new IdsQueryBuilder().addIds(VocabularyConfig.docId(modelId)))
            .setSize(1)
            .setTrackTotalHits(false)
            .request();
    }

    /**
     * Parses the vocabulary document held in {@code hit}.
     *
     * @throws IOException if creating, reading or closing the parser fails; the failure is logged before rethrowing
     */
    Vocabulary parseVocabularyDocLeniently(SearchHit hit) throws IOException {
        try (
            XContentParser parser = XContentHelper.createParserNotCompressed(
                LoggingDeprecationHandler.XCONTENT_PARSER_CONFIG.withRegistry(xContentRegistry),
                hit.getSourceRef(),
                XContentType.JSON
            )
        ) {
            return Vocabulary.PARSER.apply(parser, null);
        } catch (IOException e) {
            logger.error(() -> "failed to parse trained model vocabulary [" + hit.getId() + "]", e);
            throw e;
        }
    }

    /**
     * Removes the task's process context and stops its process immediately; logs a warning if none is registered.
     */
    public void stopDeployment(TrainedModelDeploymentTask task) {
        ProcessContext processContext = processContextByAllocation.remove(task.getId());
        if (processContext != null) {
            logger.info("[{}] Stopping deployment, reason [{}]", task.getDeploymentId(), task.stoppedReason().orElse("unknown"));
            processContext.forcefullyStopProcess();
        } else {
            logger.warn("[{}] No process context to stop", task.getDeploymentId());
        }
    }

    /**
     * Removes the task's process context and stops its process once queued work has drained; logs a warning if
     * none is registered.
     */
    public void stopAfterCompletingPendingWork(TrainedModelDeploymentTask task) {
        ProcessContext processContext = processContextByAllocation.remove(task.getId());
        if (processContext != null) {
            logger.info(
                "[{}] Stopping deployment after completing pending tasks, reason [{}]",
                task.getDeploymentId(),
                task.stoppedReason().orElse("unknown")
            );
            processContext.stopProcessAfterCompletingPendingWork();
        } else {
            logger.warn("[{}] No process context to stop gracefully", task.getDeploymentId());
        }
    }

    /**
     * Queues an inference request on the task's process. {@code skipQueue} uses HIGH priority instead of NORMAL.
     * Failures, including a stopped task or missing context, are reported to {@code listener}.
     */
    public void infer(
        TrainedModelDeploymentTask task,
        InferenceConfig config,
        NlpInferenceInput input,
        boolean skipQueue,
        TimeValue timeout,
        TrainedModelPrefixStrings.PrefixType prefixType,
        CancellableTask parentActionTask,
        boolean chunkResponse,
        ActionListener<InferenceResults> listener
    ) {
        ProcessContext processContext = getProcessContext(task, listener::onFailure);
        if (processContext == null) {
            return;
        }

        long requestId = requestIdCounter.getAndIncrement();
        InferencePyTorchAction inferenceAction = new InferencePyTorchAction(
            task.getDeploymentId(),
            requestId,
            timeout,
            processContext,
            config,
            input,
            prefixType,
            threadPool,
            parentActionTask,
            chunkResponse,
            listener
        );

        PriorityProcessWorkerExecutorService.RequestPriority priority = skipQueue
            ? PriorityProcessWorkerExecutorService.RequestPriority.HIGH
            : PriorityProcessWorkerExecutorService.RequestPriority.NORMAL;

        executePyTorchAction(processContext, priority, inferenceAction);
    }

    /** Sends a thread-settings control message with the new allocation count at HIGHEST priority. */
    public void updateNumAllocations(
        TrainedModelDeploymentTask task,
        int numAllocationThreads,
        TimeValue timeout,
        ActionListener<ThreadSettings> listener
    ) {
        ProcessContext processContext = getProcessContext(task, listener::onFailure);
        if (processContext == null) {
            return;
        }

        long requestId = requestIdCounter.getAndIncrement();
        ThreadSettingsControlMessagePytorchAction controlMessageAction = new ThreadSettingsControlMessagePytorchAction(
            task.getDeploymentId(),
            requestId,
            numAllocationThreads,
            timeout,
            processContext,
            threadPool,
            listener
        );

        executePyTorchAction(processContext, PriorityProcessWorkerExecutorService.RequestPriority.HIGHEST, controlMessageAction);
    }

    /** Sends a clear-cache control message at HIGHEST priority and acknowledges once the process confirms. */
    public void clearCache(TrainedModelDeploymentTask task, TimeValue timeout, ActionListener<AcknowledgedResponse> listener) {
        ProcessContext processContext = getProcessContext(task, listener::onFailure);
        if (processContext == null) {
            return;
        }

        long requestId = requestIdCounter.getAndIncrement();
        ClearCacheControlMessagePytorchAction controlMessageAction = new ClearCacheControlMessagePytorchAction(
            task.getDeploymentId(),
            requestId,
            timeout,
            processContext,
            threadPool,
            listener.delegateFailureAndWrap((l, b) -> l.onResponse(AcknowledgedResponse.TRUE))
        );

        executePyTorchAction(processContext, PriorityProcessWorkerExecutorService.RequestPriority.HIGHEST, controlMessageAction);
    }

    /**
     * Submits {@code action} to the process's priority worker. A rejection increments the rejected-execution count;
     * any submission failure is reported through {@code action.onFailure}.
     */
    void executePyTorchAction(
        ProcessContext processContext,
        PriorityProcessWorkerExecutorService.RequestPriority priority,
        AbstractPyTorchAction<?> action
    ) {
        try {
            processContext.getPriorityProcessWorker().executeWithPriority(action, priority, action.getRequestId());
        } catch (EsRejectedExecutionException e) {
            processContext.getRejectedExecutionCount().incrementAndGet();
            action.onFailure(e);
        } catch (Exception e) {
            action.onFailure(e);
        }
    }

    /**
     * Looks up the live process context for {@code task}.
     *
     * @return the context, or {@code null} after passing a conflict exception to {@code errorConsumer} when the task
     *         is stopped or has no registered context
     */
    private ProcessContext getProcessContext(TrainedModelDeploymentTask task, Consumer<Exception> errorConsumer) {
        if (task.isStopped()) {
            errorConsumer.accept(
                ExceptionsHelper.conflictStatusException(
                    "[{}] is stopping or stopped due to [{}]",
                    task.getDeploymentId(),
                    task.stoppedReason().orElse("")
                )
            );
            return null;
        }

        ProcessContext processContext = processContextByAllocation.get(task.getId());
        if (processContext == null) {
            errorConsumer.accept(ExceptionsHelper.conflictStatusException("[{}] process context missing", task.getDeploymentId()));
            return null;
        }
        return processContext;
    }

    /**
     * Per-deployment state: the native process, its result processor, state streamer and priority worker, the NLP
     * processor, plus counters and settings reported through {@link #getStats}.
     */
    class ProcessContext {

        private static final String PROCESS_NAME = "inference process";
        private static final TimeValue COMPLETION_TIMEOUT = TimeValue.timeValueMinutes(3L);

        private final TrainedModelDeploymentTask task;
        private final SetOnce<PyTorchProcess> process = new SetOnce<>();
        private final SetOnce<NlpTask.Processor> nlpTaskProcessor = new SetOnce<>();
        private final SetOnce<TrainedModelInput> modelInput = new SetOnce<>();
        private final SetOnce<TrainedModelPrefixStrings> prefixes = new SetOnce<>();
        private final PyTorchResultProcessor resultProcessor;
        private final PyTorchStateStreamer stateStreamer;
        private final PriorityProcessWorkerExecutorService priorityProcessWorker;
        private final AtomicInteger rejectedExecutionCount = new AtomicInteger();
        private final AtomicInteger timeoutCount = new AtomicInteger();
        private final AtomicInteger startsCount = new AtomicInteger();
        private volatile Instant startTime;
        private volatile Integer numThreadsPerAllocation;
        private volatile Integer numAllocations;
        private volatile boolean isStopped;

        /**
         * @param startsCount how many times this deployment has been started; {@code null} means 1
         */
        ProcessContext(TrainedModelDeploymentTask task, Integer startsCount) {
            this.task = Objects.requireNonNull(task);
            this.resultProcessor = new PyTorchResultProcessor(task.getDeploymentId(), threadSettings -> {
                this.numThreadsPerAllocation = threadSettings.numThreadsPerAllocation();
                this.numAllocations = threadSettings.numAllocations();
            });
            this.stateStreamer = new PyTorchStateStreamer(client, executorServiceForProcess, xContentRegistry);
            this.priorityProcessWorker = new PriorityProcessWorkerExecutorService(
                threadPool.getThreadContext(),
                PROCESS_NAME,
                task.getParams().getQueueCapacity()
            );
            this.startsCount.set(startsCount == null ? 1 : startsCount);
        }

        PyTorchResultProcessor getResultProcessor() {
            return resultProcessor;
        }

        /**
         * Creates the native process and loads the model into it. Once loaded (and not stopped in the meantime), the
         * priority worker is started and {@code loadedListener} is completed. Must run on the utility thread pool.
         */
        synchronized void startAndLoad(TrainedModelLocation modelLocation, ActionListener<Boolean> loadedListener) {
            assert Thread.currentThread().getName().contains(UTILITY_THREAD_POOL_NAME)
                : Strings.format("Must execute from [%s] but thread is [%s]", UTILITY_THREAD_POOL_NAME, Thread.currentThread().getName());

            if (isStopped) {
                logger.debug("[{}] model stopped before it is started", task.getDeploymentId());
                loadedListener.onFailure(new IllegalArgumentException("model stopped before it is started"));
                return;
            }

            logger.debug("[{}] start and load", task.getDeploymentId());
            process.set(
                pyTorchProcessFactory.createProcess(
                    task,
                    executorServiceForProcess,
                    () -> resultProcessor.awaitCompletion(COMPLETION_TIMEOUT.getMinutes(), TimeUnit.MINUTES),
                    onProcessCrashHandleRestarts(startsCount, task.getDeploymentId())
                )
            );
            startTime = Instant.now();
            logger.debug("[{}] process started", task.getDeploymentId());
            try {
                loadModel(modelLocation, loadedListener.delegateFailureAndWrap((delegate, success) -> {
                    if (isStopped) {
                        logger.debug("[{}] model loaded but process is stopped", task.getDeploymentId());
                        killProcessIfPresent();
                        delegate.onFailure(new IllegalStateException("model loaded but process is stopped"));
                        return;
                    }

                    logger.debug("[{}] model loaded, starting priority process worker thread", task.getDeploymentId());
                    startPriorityProcessWorker();
                    delegate.onResponse(success);
                }));
            } catch (Exception e) {
                loadedListener.onFailure(e);
            }
        }

        /**
         * Builds the crash callback for the native process. It unregisters this context and stops its result
         * processing. While {@code startsCount} is at most {@link #NUM_RESTART_ATTEMPTS}, it restarts the deployment
         * with an incremented count; otherwise it marks the task failed.
         */
        private Consumer<String> onProcessCrashHandleRestarts(AtomicInteger startsCount, String deploymentId) {
            return reason -> {
                if (isThisProcessOlderThan1Day()) {
                    startsCount.set(1);
                    String logMessage = "["
                        + task.getDeploymentId()
                        + "] inference process crashed due to reason ["
                        + reason
                        + "]. This process was started more than 24 hours ago; the starts count is reset to 1.";
                    logger.error(logMessage);
                } else {
                    logger.error("[{}] inference process crashed due to reason [{}]", task.getDeploymentId(), reason);
                }

                processContextByAllocation.remove(task.getId());
                isStopped = true;
                resultProcessor.stop();
                stateStreamer.cancel();

                if (startsCount.get() <= NUM_RESTART_ATTEMPTS) {
                    String logAndAuditMessage = "Inference process ["
                        + task.getDeploymentId()
                        + "] failed due to ["
                        + reason
                        + "]. This is the ["
                        + startsCount.get()
                        + "] failure in 24 hours, and the process will be restarted.";
                    logger.info(logAndAuditMessage);
                    threadPool.executor(UTILITY_THREAD_POOL_NAME).execute(() -> inferenceAuditor.warning(deploymentId, logAndAuditMessage));
                    priorityProcessWorker.shutdownNow();

                    ActionListener<TrainedModelDeploymentTask> restartListener = ActionListener.wrap(
                        restartedTask -> logger.debug("Completed restart of inference process, the [{}] start", startsCount),
                        e -> finishClosingProcess(
                            startsCount,
                            "Failed to restart inference process because of error [" + e.getMessage() + "]",
                            deploymentId
                        )
                    );
                    startDeployment(task, startsCount.incrementAndGet(), restartListener);
                } else {
                    finishClosingProcess(startsCount, reason, deploymentId);
                }
            };
        }

        /**
         * @return whether this process was started more than a day ago; {@code false} if it never recorded a start
         */
        private boolean isThisProcessOlderThan1Day() {
            // The crash callback is registered with the process before startTime is assigned, so an early crash
            // can observe a null startTime.
            Instant started = startTime;
            return started != null && started.isBefore(Instant.now().minus(Duration.ofDays(1L)));
        }

        /** Gives up on the deployment: audits an error, fails queued work, closes the NLP processor, fails the task. */
        private void finishClosingProcess(AtomicInteger startsCount, String reason, String deploymentId) {
            String logAndAuditMessage = "["
                + task.getDeploymentId()
                + "] inference process failed after ["
                + startsCount.get()
                + "] starts in 24 hours, not restarting again.";
            logger.warn(logAndAuditMessage);
            threadPool.executor(UTILITY_THREAD_POOL_NAME).execute(() -> inferenceAuditor.error(deploymentId, logAndAuditMessage));
            priorityProcessWorker.shutdownNowWithError(new IllegalStateException(reason));
            closeNlpTaskProcessor();
            task.setFailed("inference process crashed due to reason [" + reason + "]");
        }

        void startPriorityProcessWorker() {
            executorServiceForProcess.submit(priorityProcessWorker::start);
        }

        /**
         * Stops immediately: marks the context stopped and shuts down the priority worker without draining it. It
         * waits up to 10 seconds for the worker, then kills the process and closes the NLP processor.
         */
        synchronized void forcefullyStopProcess() {
            logger.debug(() -> Strings.format("[%s] Forcefully stopping process", task.getDeploymentId()));
            prepareInternalStateForShutdown();

            priorityProcessWorker.shutdownNow();
            try {
                if (priorityProcessWorker.awaitTermination(10L, TimeUnit.SECONDS)) {
                    priorityProcessWorker.notifyQueueRunnables();
                }
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                logger.info(Strings.format("[%s] Interrupted waiting for process worker after shutdownNow", PROCESS_NAME));
            }

            killProcessIfPresent();
            closeNlpTaskProcessor();
        }

        private void prepareInternalStateForShutdown() {
            isStopped = true;
            resultProcessor.stop();
            stateStreamer.cancel();
        }

        private void killProcessIfPresent() {
            try {
                if (process.get() == null) {
                    return;
                }
                process.get().kill(true);
            } catch (IOException e) {
                logger.error(() -> "[" + task.getDeploymentId() + "] Failed to kill process", e);
            }
        }

        private void closeNlpTaskProcessor() {
            if (nlpTaskProcessor.get() != null) {
                nlpTaskProcessor.get().close();
            }
        }

        /**
         * Stops gracefully. It lets the priority worker drain (bounded by {@link #COMPLETION_TIMEOUT}), closes the
         * process, waits for outstanding results, and then closes the NLP processor.
         */
        private synchronized void stopProcessAfterCompletingPendingWork() {
            logger.debug(() -> Strings.format("[%s] Stopping process after completing its pending work", task.getDeploymentId()));
            prepareInternalStateForShutdown();
            signalAndWaitForWorkerTermination();
            stopProcessGracefully();
            closeNlpTaskProcessor();
        }

        private void signalAndWaitForWorkerTermination() {
            try {
                awaitTerminationAfterCompletingWork();
            } catch (TimeoutException e) {
                logger.warn(Strings.format("[%s] Timed out waiting for process worker to complete, forcing a shutdown", task.getDeploymentId()), e);
                priorityProcessWorker.shutdown();
                priorityProcessWorker.notifyQueueRunnables();
            }
        }

        private void awaitTerminationAfterCompletingWork() throws TimeoutException {
            try {
                priorityProcessWorker.shutdown();

                if (priorityProcessWorker.awaitTermination(COMPLETION_TIMEOUT.getMinutes(), TimeUnit.MINUTES) == false) {
                    throw new TimeoutException(
                        Strings.format("Timed out waiting for process worker to complete for process %s", PROCESS_NAME)
                    );
                }
                priorityProcessWorker.notifyQueueRunnables();
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                logger.info(Strings.format("[%s] Interrupted waiting for process worker to complete", PROCESS_NAME));
            }
        }

        private void stopProcessGracefully() {
            try {
                closeProcessIfPresent();
                resultProcessor.awaitCompletion(COMPLETION_TIMEOUT.getMinutes(), TimeUnit.MINUTES);
            } catch (TimeoutException e) {
                logger.warn(Strings.format("[%s] Timed out waiting for results processor to stop", task.getDeploymentId()), e);
            }
        }

        private void closeProcessIfPresent() {
            try {
                if (process.get() == null) {
                    return;
                }
                process.get().close();
            } catch (IOException e) {
                logger.error(Strings.format("[%s] Failed to stop process gracefully, attempting to kill it", task.getDeploymentId()), e);
                killProcessIfPresent();
            }
        }

        /**
         * Streams the model from an {@link IndexLocation} into the process. The listener is always completed on the
         * deployment (utility) executor. Other location types and a stopped context fail the listener.
         */
        void loadModel(TrainedModelLocation modelLocation, ActionListener<Boolean> listener) {
            if (isStopped) {
                listener.onFailure(new IllegalArgumentException("Process has stopped, model loading canceled"));
                return;
            }
            if (modelLocation instanceof IndexLocation indexLocation) {
                process.get()
                    .loadModel(
                        task.getParams().getModelId(),
                        indexLocation.getIndexName(),
                        stateStreamer,
                        ActionListener.wrap(
                            r -> executorServiceForDeployment.submit(() -> listener.onResponse(r)),
                            e -> executorServiceForDeployment.submit(() -> listener.onFailure(e))
                        )
                    );
            } else {
                listener.onFailure(
                    new IllegalStateException("unsupported trained model location [" + modelLocation.getClass().getSimpleName() + "]")
                );
            }
        }

        AtomicInteger getTimeoutCount() {
            return timeoutCount;
        }

        PriorityProcessWorkerExecutorService getPriorityProcessWorker() {
            return priorityProcessWorker;
        }

        AtomicInteger getRejectedExecutionCount() {
            return rejectedExecutionCount;
        }

        SetOnce<TrainedModelInput> getModelInput() {
            return modelInput;
        }

        SetOnce<PyTorchProcess> getProcess() {
            return process;
        }

        SetOnce<NlpTask.Processor> getNlpTaskProcessor() {
            return nlpTaskProcessor;
        }

        SetOnce<TrainedModelPrefixStrings> getPrefixStrings() {
            return prefixes;
        }
    }
}

