/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.math.optimisers;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.config.Configurable;
import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl;
import org.tribuo.math.Parameters;
import org.tribuo.math.StochasticGradientOptimiser;
import org.tribuo.math.la.Tensor;

/**
 * A {@link StochasticGradientOptimiser} wrapper which averages parameter values across a
 * gradient run.
 * <p>
 * Each call to {@link #step(Tensor[], double)} is delegated to the wrapped optimiser. The
 * update it returns is accumulated, multiplied by the 1-based step count, into an internal set
 * of tensors. This class writes to the model parameters only in {@link #finalise()}. That method
 * subtracts the accumulator divided by the total number of steps.
 * <p>
 * NOTE(review): this assumes the caller adds each returned update into the parameters. Under that
 * assumption, the finalised parameters equal the mean of the parameter values held before each
 * step. Confirm against the trainer that drives this optimiser.
 * <p>
 * This class performs no synchronisation on its mutable state.
 */
public class ParameterAveraging
implements StochasticGradientOptimiser {
    /** The wrapped optimiser which produces the per-step updates. */
    @Config(mandatory=true, description="Inner optimiser to average parameters across.")
    private StochasticGradientOptimiser optimiser;

    /** Number of {@link #step} calls since the accumulator was last (re)created. */
    private int iterations = 0;

    /** Running sum of each step's update multiplied by that step's 1-based index. */
    private Tensor[] weights;

    /** The parameters passed to {@link #initialise}; rewritten in place by {@link #finalise()}. */
    private Parameters parameters;

    /**
     * Wraps the supplied optimiser with parameter averaging.
     *
     * @param optimiser The optimiser whose updates are averaged.
     */
    public ParameterAveraging(StochasticGradientOptimiser optimiser) {
        this.optimiser = optimiser;
    }

    /**
     * No-argument constructor; the {@code optimiser} field is expected to be injected through
     * its {@code @Config} annotation.
     */
    private ParameterAveraging() {
    }

    /**
     * Initialises the wrapped optimiser and creates a fresh, empty accumulator shaped like the
     * supplied parameters.
     *
     * @param parameters The parameters to optimise; retained so {@link #finalise()} can update them.
     */
    @Override
    public void initialise(Parameters parameters) {
        optimiser.initialise(parameters);
        weights = parameters.getEmptyCopy();
        // Start the step count alongside the fresh accumulator so a re-initialise does not
        // mis-scale later updates or the final division.
        iterations = 0;
        this.parameters = parameters;
    }

    /**
     * Delegates to the wrapped optimiser and records its update in the averaging accumulator.
     *
     * @param updates The gradient updates for this step.
     * @param weight  The example weight, forwarded unchanged to the wrapped optimiser.
     * @return The update produced by the wrapped optimiser, returned unmodified.
     */
    @Override
    public Tensor[] step(Tensor[] updates, double weight) {
        iterations++;
        Tensor[] output = optimiser.step(updates, weight);
        // Weight each update by its 1-based step index; finalise() divides the sum by the count.
        final int stepIndex = iterations;
        for (int i = 0; i < output.length; i++) {
            weights[i].intersectAndAddInPlace(output[i], a -> a * stepIndex);
        }
        return output;
    }

    /**
     * Replaces the current parameter values with their averaged values by subtracting the
     * accumulator divided by the number of steps taken.
     * <p>
     * Does nothing if no steps were taken since the last initialise/reset.
     * <p>
     * NOTE(review): the wrapped optimiser's {@code finalise()} is not invoked here.
     */
    @Override
    public void finalise() {
        if (iterations == 0) {
            // Nothing to average. Dividing by a zero step count would compute -0.0 / 0 = NaN
            // and add it into the parameters.
            return;
        }
        final int stepCount = iterations;
        Tensor[] current = parameters.get();
        for (int i = 0; i < current.length; i++) {
            current[i].intersectAndAddInPlace(weights[i], a -> -a / stepCount);
        }
    }

    @Override
    public String toString() {
        return "ParameterAveraging(optimiser=" + optimiser.toString() + ")";
    }

    /**
     * Resets the wrapped optimiser and discards all averaging state, including the reference
     * to the parameters supplied to {@link #initialise}.
     */
    @Override
    public void reset() {
        optimiser.reset();
        iterations = 0;
        weights = null;
        parameters = null;
    }

    /**
     * Creates a new, uninitialised averaging optimiser which wraps a copy of the inner optimiser.
     * No averaging state is copied.
     *
     * @return A fresh copy of this optimiser's configuration.
     */
    @Override
    public ParameterAveraging copy() {
        return new ParameterAveraging(optimiser.copy());
    }

    /**
     * Builds the configured-object provenance for this optimiser.
     *
     * @return The provenance, tagged with the host type "StochasticGradientOptimiser".
     */
    public ConfiguredObjectProvenance getProvenance() {
        return new ConfiguredObjectProvenanceImpl((Configurable)this, "StochasticGradientOptimiser");
    }
}

