/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.mllib.optimization;

import breeze.linalg.DenseVector;
import breeze.linalg.DenseVector$;
import breeze.linalg.NumericOps;
import breeze.linalg.norm$;
import breeze.linalg.operators.HasOps$;
import breeze.storage.Zero$;
import java.io.Serializable;
import java.util.Map;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.internal.LogEntry;
import org.apache.spark.internal.LogEntry$;
import org.apache.spark.internal.LogKey;
import org.apache.spark.internal.LogKeys;
import org.apache.spark.internal.Logging;
import org.apache.spark.internal.MDC;
import org.apache.spark.internal.MessageWithContext;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors$;
import org.apache.spark.mllib.optimization.Gradient;
import org.apache.spark.mllib.optimization.Updater;
import org.apache.spark.rdd.RDD;
import org.slf4j.Logger;
import org.slf4j.event.Level;
import scala.Function0;
import scala.Function2;
import scala.MatchError;
import scala.None$;
import scala.Some;
import scala.StringContext;
import scala.Tuple2;
import scala.Tuple3;
import scala.collection.IterableOnceOps;
import scala.collection.immutable.Nil$;
import scala.collection.immutable.Seq;
import scala.collection.mutable.ArrayBuffer;
import scala.reflect.ClassTag;
import scala.reflect.ClassTag$;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.IntRef;
import scala.runtime.ModuleSerializationProxy;
import scala.runtime.ScalaRunTime$;

/**
 * Singleton module ({@link #MODULE$}) that runs mini-batch stochastic gradient descent over an
 * {@code RDD} of {@code (double, Vector)} pairs.
 *
 * <p>This class is decompiled output of a Scala {@code object}: {@link #MODULE$} is the sole
 * instance, {@link #writeReplace()} serializes it through a {@link ModuleSerializationProxy}, and
 * the {@link Logging} members below are compiler-generated forwarders to the trait's static
 * implementations.
 */
public final class GradientDescent$
implements Logging,
Serializable {
    public static final GradientDescent$ MODULE$ = new GradientDescent$();
    private static transient Logger org$apache$spark$internal$Logging$$log_;

    static {
        Logging.$init$((Logging)MODULE$);
    }

    // ---------------------------------------------------------------------------------------
    // Logging trait forwarders (generated for the Scala mixin; each delegates to Logging.xxx$).
    // ---------------------------------------------------------------------------------------

    public String logName() {
        return Logging.logName$((Logging)this);
    }

    public Logger log() {
        return Logging.log$((Logging)this);
    }

    public Logging.LogStringContext LogStringContext(StringContext sc) {
        return Logging.LogStringContext$((Logging)this, (StringContext)sc);
    }

    public void withLogContext(Map<String, String> context, Function0<BoxedUnit> body) {
        Logging.withLogContext$((Logging)this, context, body);
    }

    public MDC MDC(LogKey key, Object value) {
        return Logging.MDC$((Logging)this, (LogKey)key, (Object)value);
    }

    public void logInfo(Function0<String> msg) {
        Logging.logInfo$((Logging)this, msg);
    }

    public void logInfo(LogEntry entry) {
        Logging.logInfo$((Logging)this, (LogEntry)entry);
    }

    public void logInfo(LogEntry entry, Throwable throwable) {
        Logging.logInfo$((Logging)this, (LogEntry)entry, (Throwable)throwable);
    }

    public void logDebug(Function0<String> msg) {
        Logging.logDebug$((Logging)this, msg);
    }

    public void logDebug(LogEntry entry) {
        Logging.logDebug$((Logging)this, (LogEntry)entry);
    }

    public void logDebug(LogEntry entry, Throwable throwable) {
        Logging.logDebug$((Logging)this, (LogEntry)entry, (Throwable)throwable);
    }

    public void logTrace(Function0<String> msg) {
        Logging.logTrace$((Logging)this, msg);
    }

    public void logTrace(LogEntry entry) {
        Logging.logTrace$((Logging)this, (LogEntry)entry);
    }

    public void logTrace(LogEntry entry, Throwable throwable) {
        Logging.logTrace$((Logging)this, (LogEntry)entry, (Throwable)throwable);
    }

    public void logWarning(Function0<String> msg) {
        Logging.logWarning$((Logging)this, msg);
    }

    public void logWarning(LogEntry entry) {
        Logging.logWarning$((Logging)this, (LogEntry)entry);
    }

    public void logWarning(LogEntry entry, Throwable throwable) {
        Logging.logWarning$((Logging)this, (LogEntry)entry, (Throwable)throwable);
    }

    public void logError(Function0<String> msg) {
        Logging.logError$((Logging)this, msg);
    }

    public void logError(LogEntry entry) {
        Logging.logError$((Logging)this, (LogEntry)entry);
    }

    public void logError(LogEntry entry, Throwable throwable) {
        Logging.logError$((Logging)this, (LogEntry)entry, (Throwable)throwable);
    }

    public void logInfo(Function0<String> msg, Throwable throwable) {
        Logging.logInfo$((Logging)this, msg, (Throwable)throwable);
    }

    public void logDebug(Function0<String> msg, Throwable throwable) {
        Logging.logDebug$((Logging)this, msg, (Throwable)throwable);
    }

    public void logTrace(Function0<String> msg, Throwable throwable) {
        Logging.logTrace$((Logging)this, msg, (Throwable)throwable);
    }

    public void logWarning(Function0<String> msg, Throwable throwable) {
        Logging.logWarning$((Logging)this, msg, (Throwable)throwable);
    }

    public void logError(Function0<String> msg, Throwable throwable) {
        Logging.logError$((Logging)this, msg, (Throwable)throwable);
    }

    public boolean isTraceEnabled() {
        return Logging.isTraceEnabled$((Logging)this);
    }

    public void logBasedOnLevel(Level level, Function0<MessageWithContext> f) {
        Logging.logBasedOnLevel$((Logging)this, (Level)level, f);
    }

    public void initializeLogIfNecessary(boolean isInterpreter) {
        Logging.initializeLogIfNecessary$((Logging)this, (boolean)isInterpreter);
    }

    public boolean initializeLogIfNecessary(boolean isInterpreter, boolean silent) {
        return Logging.initializeLogIfNecessary$((Logging)this, (boolean)isInterpreter, (boolean)silent);
    }

    public boolean initializeLogIfNecessary$default$2() {
        return Logging.initializeLogIfNecessary$default$2$((Logging)this);
    }

    public void initializeForcefully(boolean isInterpreter, boolean silent) {
        Logging.initializeForcefully$((Logging)this, (boolean)isInterpreter, (boolean)silent);
    }

    public Logger org$apache$spark$internal$Logging$$log_() {
        return org$apache$spark$internal$Logging$$log_;
    }

    public void org$apache$spark$internal$Logging$$log__$eq(Logger x$1) {
        org$apache$spark$internal$Logging$$log_ = x$1;
    }

    // ---------------------------------------------------------------------------------------
    // Mini-batch SGD
    // ---------------------------------------------------------------------------------------

    /**
     * Runs mini-batch stochastic gradient descent.
     *
     * <p>Iterations are numbered {@code 1 .. numIterations + 1}. Each iteration samples the data
     * without replacement (seed {@code 42 + iteration}), tree-aggregates the summed gradient, the
     * summed loss and the example count, and, if the sample is non-empty, appends
     * {@code lossSum / miniBatchSize + regVal} to the loss history. Weights are updated on every
     * iteration except the final {@code numIterations + 1} pass, which only records the loss of
     * the last weights. The loop stops early once {@link #isConverged} reports convergence
     * between two consecutive updates.
     *
     * @param data              pairs whose first element is passed to {@code Gradient.compute}
     *                          as a double (presumably the label) and whose second is the
     *                          feature {@code Vector}
     * @param gradient          per-example loss/gradient function
     * @param updater           computes new weights and the regularization value
     * @param stepSize          step size passed to {@code updater.compute}
     * @param numIterations     number of weight updates to attempt
     * @param regParam          regularization parameter passed to {@code updater.compute}
     * @param miniBatchFraction sampling fraction for each iteration's mini batch
     * @param initialWeights    starting weights; returned unchanged when {@code data} is empty
     * @param convergenceTol    relative tolerance on the weight change used for early stopping
     * @return the final weights paired with the recorded stochastic loss history
     */
    public Tuple2<Vector, double[]> runMiniBatchSGD(RDD<Tuple2<Object, Vector>> data, Gradient gradient, Updater updater, double stepSize, int numIterations, double regParam, double miniBatchFraction, Vector initialWeights, double convergenceTol) {
        if (miniBatchFraction < 1.0 && convergenceTol > 0.0) {
            this.logWarning((Function0<String>)(Function0 & Serializable)() -> "Testing against a convergenceTol when using miniBatchFraction < 1.0 can be unstable because of the stochasticity in sampling.");
        }
        if ((double)numIterations * miniBatchFraction < 1.0) {
            this.logWarning(LogEntry$.MODULE$.from((Function0 & Serializable)() -> MODULE$.unusedExamplesMessage(numIterations, miniBatchFraction)));
        }
        ArrayBuffer stochasticLossHistory = new ArrayBuffer(numIterations + 1);
        long numExamples = data.count();
        if (numExamples == 0L) {
            this.logWarning((Function0<String>)(Function0 & Serializable)() -> "GradientDescent.runMiniBatchSGD returning initial weights, no data found");
            return new Tuple2((Object)initialWeights, stochasticLossHistory.toArray((ClassTag)ClassTag$.MODULE$.Double()));
        }
        if ((double)numExamples * miniBatchFraction < 1.0) {
            this.logWarning((Function0<String>)(Function0 & Serializable)() -> "The miniBatchFraction is too small");
        }
        Vector weights = Vectors$.MODULE$.dense(initialWeights.toArray());
        int n = weights.size();
        // Regularization value of the starting weights (updater called with a zero gradient, step 0).
        double regVal = updater.compute(weights, Vectors$.MODULE$.zeros(weights.size()), 0.0, 1, regParam)._2$mcD$sp();
        // Weights produced by the two most recent updates; null until that many updates exist.
        // (The decompiled source typed these as None$, which cannot hold a Some.)
        Vector previousWeights = null;
        Vector currentWeights = null;
        boolean converged = false;
        // IntRef so the lazily-built log message below can read the current iteration.
        IntRef i = IntRef.create((int)1);
        while (!converged && i.elem <= numIterations + 1) {
            Broadcast bcWeights = data.context().broadcast((Object)weights, ClassTag$.MODULE$.apply(Vector.class));
            Tuple3 tuple3 = (Tuple3)data.sample(false, miniBatchFraction, (long)(42 + i.elem)).treeAggregate(
                (Object)new Tuple3(null, (Object)BoxesRunTime.boxToDouble((double)0.0), (Object)BoxesRunTime.boxToLong((long)0L)),
                (Function2 & Serializable)(acc, point) -> {
                    // Raw Function2 gives Object parameters; recover the tuple types explicitly.
                    Tuple3 c = (Tuple3)acc;
                    Tuple2 v = (Tuple2)point;
                    // The gradient accumulator stays null until the first element is seen.
                    DenseVector vec = c._1() == null ? DenseVector$.MODULE$.zeros$mDc$sp(n, (ClassTag)ClassTag$.MODULE$.Double(), Zero$.MODULE$.DoubleZero()) : (DenseVector)c._1();
                    // NOTE(review): relies on gradient.compute accumulating into the wrapped vec in place.
                    double l = gradient.compute((Vector)v._2(), v._1$mcD$sp(), (Vector)bcWeights.value(), Vectors$.MODULE$.fromBreeze((breeze.linalg.Vector<Object>)vec));
                    return new Tuple3((Object)vec, (Object)BoxesRunTime.boxToDouble((double)(BoxesRunTime.unboxToDouble((Object)c._2()) + l)), (Object)BoxesRunTime.boxToLong((long)(BoxesRunTime.unboxToLong((Object)c._3()) + 1L)));
                },
                (Function2 & Serializable)(left, right) -> {
                    Tuple3 c1 = (Tuple3)left;
                    Tuple3 c2 = (Tuple3)right;
                    // Either side may still carry a null accumulator (no elements seen).
                    DenseVector vec;
                    if (c1._1() == null) {
                        vec = (DenseVector)c2._1();
                    } else if (c2._1() == null) {
                        vec = (DenseVector)c1._1();
                    } else {
                        ((NumericOps)c1._1()).$plus$eq(c2._1(), HasOps$.MODULE$.impl_OpAdd_InPlace_DV_DV_Double());
                        vec = (DenseVector)c1._1();
                    }
                    return new Tuple3((Object)vec, (Object)BoxesRunTime.boxToDouble((double)(BoxesRunTime.unboxToDouble((Object)c1._2()) + BoxesRunTime.unboxToDouble((Object)c2._2()))), (Object)BoxesRunTime.boxToLong((long)(BoxesRunTime.unboxToLong((Object)c1._3()) + BoxesRunTime.unboxToLong((Object)c2._3()))));
                },
                2, true, ClassTag$.MODULE$.apply(Tuple3.class));
            if (tuple3 == null) {
                throw new MatchError((Object)tuple3);
            }
            DenseVector gradientSum = (DenseVector)tuple3._1();
            double lossSum = BoxesRunTime.unboxToDouble((Object)tuple3._2());
            long miniBatchSize = BoxesRunTime.unboxToLong((Object)tuple3._3());
            bcWeights.destroy();
            if (miniBatchSize > 0L) {
                // lossSum was computed with the weights (and regVal) from the previous iteration.
                stochasticLossHistory.$plus$eq((Object)BoxesRunTime.boxToDouble((double)(lossSum / (double)miniBatchSize + regVal)));
                // The extra (numIterations + 1)-th pass only records the loss; no update.
                if (i.elem != numIterations + 1) {
                    Tuple2<Vector, Object> update = updater.compute(weights, Vectors$.MODULE$.fromBreeze((breeze.linalg.Vector<Object>)((breeze.linalg.Vector)gradientSum.$div((Object)BoxesRunTime.boxToDouble((double)miniBatchSize), HasOps$.MODULE$.impl_Op_DV_S_eq_DV_Double_OpDiv()))), stepSize, i.elem, regParam);
                    weights = (Vector)update._1();
                    regVal = update._2$mcD$sp();
                    previousWeights = currentWeights;
                    currentWeights = weights;
                    if (previousWeights != null && currentWeights != null) {
                        converged = this.isConverged(previousWeights, currentWeights, convergenceTol);
                    }
                }
            } else {
                this.logWarning(LogEntry$.MODULE$.from((Function0 & Serializable)() -> MODULE$.emptyMiniBatchMessage(i.elem, numIterations)));
            }
            ++i.elem;
        }
        this.logInfo(LogEntry$.MODULE$.from((Function0 & Serializable)() -> MODULE$.finishedMessage(stochasticLossHistory)));
        return new Tuple2((Object)weights, stochasticLossHistory.toArray((ClassTag)ClassTag$.MODULE$.Double()));
    }

    /**
     * Same as the nine-argument overload with {@code convergenceTol = 0.001}.
     */
    public Tuple2<Vector, double[]> runMiniBatchSGD(RDD<Tuple2<Object, Vector>> data, Gradient gradient, Updater updater, double stepSize, int numIterations, double regParam, double miniBatchFraction, Vector initialWeights) {
        return this.runMiniBatchSGD(data, gradient, updater, stepSize, numIterations, regParam, miniBatchFraction, initialWeights, 0.001);
    }

    /** Warning text for {@code numIterations * miniBatchFraction < 1.0}. */
    private MessageWithContext unusedExamplesMessage(int numIterations, double miniBatchFraction) {
        return this.LogStringContext(new StringContext((Seq)ScalaRunTime$.MODULE$.wrapRefArray((Object[])new String[]{"Not all examples will be used if numIterations * miniBatchFraction < 1.0: "}))).log((Seq)Nil$.MODULE$)
            .$plus(this.LogStringContext(new StringContext((Seq)ScalaRunTime$.MODULE$.wrapRefArray((Object[])new String[]{"numIterations=", " and "}))).log((Seq)ScalaRunTime$.MODULE$.wrapRefArray((Object[])new MDC[]{this.MDC((LogKey)LogKeys.NUM_ITERATIONS, BoxesRunTime.boxToInteger((int)numIterations))})))
            .$plus(this.LogStringContext(new StringContext((Seq)ScalaRunTime$.MODULE$.wrapRefArray((Object[])new String[]{"miniBatchFraction=", ""}))).log((Seq)ScalaRunTime$.MODULE$.wrapRefArray((Object[])new MDC[]{this.MDC((LogKey)LogKeys.MINI_BATCH_FRACTION, BoxesRunTime.boxToDouble((double)miniBatchFraction))})));
    }

    /** Warning text for an iteration whose sampled mini batch was empty. */
    private MessageWithContext emptyMiniBatchMessage(int iteration, int numIterations) {
        return this.LogStringContext(new StringContext((Seq)ScalaRunTime$.MODULE$.wrapRefArray((Object[])new String[]{"Iteration "}))).log((Seq)Nil$.MODULE$)
            .$plus(this.LogStringContext(new StringContext((Seq)ScalaRunTime$.MODULE$.wrapRefArray((Object[])new String[]{"(", "/", "). "}))).log((Seq)ScalaRunTime$.MODULE$.wrapRefArray((Object[])new MDC[]{this.MDC((LogKey)LogKeys.INDEX, BoxesRunTime.boxToInteger((int)iteration)), this.MDC((LogKey)LogKeys.NUM_ITERATIONS, BoxesRunTime.boxToInteger((int)numIterations))})))
            .$plus(this.LogStringContext(new StringContext((Seq)ScalaRunTime$.MODULE$.wrapRefArray((Object[])new String[]{"The size of sampled batch is zero"}))).log((Seq)Nil$.MODULE$));
    }

    /** Info text summarizing the last 10 recorded stochastic losses. */
    private MessageWithContext finishedMessage(ArrayBuffer lossHistory) {
        return this.LogStringContext(new StringContext((Seq)ScalaRunTime$.MODULE$.wrapRefArray((Object[])new String[]{"GradientDescent.runMiniBatchSGD finished. Last 10 stochastic losses "}))).log((Seq)Nil$.MODULE$)
            .$plus(this.LogStringContext(new StringContext((Seq)ScalaRunTime$.MODULE$.wrapRefArray((Object[])new String[]{"", ""}))).log((Seq)ScalaRunTime$.MODULE$.wrapRefArray((Object[])new MDC[]{this.MDC((LogKey)LogKeys.LOSSES, ((IterableOnceOps)lossHistory.takeRight(10)).mkString(", "))})));
    }

    /**
     * Reports convergence when the norm of the weight change is below
     * {@code convergenceTol * max(norm(currentWeights), 1.0)}.
     *
     * <p>NOTE(review): uses breeze's default {@code norm}, presumably the L2 norm.
     *
     * @param previousWeights weights before the latest update
     * @param currentWeights  weights after the latest update
     * @param convergenceTol  relative tolerance
     * @return true if the relative change is below the tolerance
     */
    private boolean isConverged(Vector previousWeights, Vector currentWeights, double convergenceTol) {
        DenseVector previousBDV = previousWeights.asBreeze().toDenseVector$mcD$sp((ClassTag)ClassTag$.MODULE$.Double());
        DenseVector currentBDV = currentWeights.asBreeze().toDenseVector$mcD$sp((ClassTag)ClassTag$.MODULE$.Double());
        double solutionVecDiff = BoxesRunTime.unboxToDouble((Object)norm$.MODULE$.apply(previousBDV.$minus((Object)currentBDV, HasOps$.MODULE$.impl_OpSub_DV_DV_eq_DV_Double()), norm$.MODULE$.normDoubleToNormalNorm(norm$.MODULE$.canNorm(HasOps$.MODULE$.DV_canIterateValues(), norm$.MODULE$.scalarNorm_Double()))));
        double currentNorm = BoxesRunTime.unboxToDouble((Object)norm$.MODULE$.apply((Object)currentBDV, norm$.MODULE$.normDoubleToNormalNorm(norm$.MODULE$.canNorm(HasOps$.MODULE$.DV_canIterateValues(), norm$.MODULE$.scalarNorm_Double()))));
        return solutionVecDiff < convergenceTol * Math.max(currentNorm, 1.0);
    }

    /** Serializes the singleton through a proxy so deserialization yields {@link #MODULE$}. */
    private Object writeReplace() {
        return new ModuleSerializationProxy(GradientDescent$.class);
    }

    private GradientDescent$() {
    }
}

