/*
 * Decompiled with CFR 0.152.
 */
package org.apache.cassandra.spark.bulkwriter;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.common.collect.Range;
import com.google.common.collect.RangeMap;
import com.google.common.collect.RangeSet;
import com.google.common.collect.TreeRangeMap;
import com.google.common.collect.TreeRangeSet;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.math.BigInteger;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
import org.apache.cassandra.spark.bulkwriter.BroadcastableTokenPartitioner;
import org.apache.cassandra.spark.bulkwriter.DecoratedKey;
import org.apache.cassandra.spark.bulkwriter.RingInstance;
import org.apache.cassandra.spark.bulkwriter.token.TokenRangeMapping;
import org.apache.cassandra.spark.utils.RangeUtils;
import org.apache.spark.Partitioner;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class TokenPartitioner
extends Partitioner {
    private static final Logger LOGGER = LoggerFactory.getLogger(TokenPartitioner.class);
    private static final long serialVersionUID = -8787074052066841747L;
    private transient int nrPartitions;
    private transient RangeMap<BigInteger, Integer> partitionMap;
    private transient Map<Integer, Range<BigInteger>> reversePartitionMap;
    private final transient TokenRangeMapping<RingInstance> tokenRangeMapping;
    private final Integer numberSplits;

    public TokenPartitioner(TokenRangeMapping<RingInstance> tokenRangeMapping, Integer userSpecifiedNumberSplits, int defaultParallelism, Integer cores) {
        this(tokenRangeMapping, userSpecifiedNumberSplits, defaultParallelism, cores, true);
    }

    @VisibleForTesting
    public TokenPartitioner(TokenRangeMapping<RingInstance> tokenRangeMapping, Integer userSpecifiedNumberSplits, int defaultParallelism, Integer cores, boolean randomize) {
        this.tokenRangeMapping = tokenRangeMapping;
        this.numberSplits = this.calculateSplits(tokenRangeMapping, userSpecifiedNumberSplits, defaultParallelism, cores);
        this.setupTokenRangeMap(randomize);
        this.validate();
        this.logPartitionInfo();
    }

    private void logPartitionInfo() {
        LOGGER.info("Number of partitions: {}", (Object)this.nrPartitions);
        LOGGER.info("Partition map: {}", this.partitionMap);
        LOGGER.info("Reverse partition map: {}", this.reversePartitionMap);
    }

    public TokenPartitioner(BroadcastableTokenPartitioner broadcastable) {
        this.tokenRangeMapping = null;
        this.numberSplits = broadcastable.numSplits();
        this.partitionMap = TreeRangeMap.create();
        this.reversePartitionMap = new HashMap<Integer, Range<BigInteger>>();
        this.nrPartitions = 0;
        broadcastable.getPartitionEntries().forEach((range, partitionId) -> {
            this.partitionMap.put(range, partitionId);
            this.reversePartitionMap.put((Integer)partitionId, (Range<BigInteger>)range);
            if (partitionId >= this.nrPartitions) {
                this.nrPartitions = partitionId + 1;
            }
        });
        this.logPartitionInfo();
    }

    public int numPartitions() {
        return this.nrPartitions;
    }

    public int getPartition(Object key) {
        DecoratedKey decoratedKey = (DecoratedKey)key;
        Integer partition = (Integer)this.partitionMap.get((Comparable)decoratedKey.getToken());
        return partition == null ? 0 : partition;
    }

    public int numSplits() {
        return this.numberSplits;
    }

    public Range<BigInteger> getTokenRange(int partitionId) {
        return this.reversePartitionMap.get(partitionId);
    }

    private void setupTokenRangeMap(boolean randomize) {
        this.partitionMap = TreeRangeMap.create();
        this.reversePartitionMap = new HashMap<Integer, Range<BigInteger>>();
        AtomicInteger nextPartitionId = new AtomicInteger(0);
        List<Range> subRanges = this.tokenRangeMapping.getRangeMap().asMapOfRanges().keySet().stream().flatMap(tr -> RangeUtils.split((Range)tr, (int)this.numberSplits).stream()).collect(Collectors.toList());
        if (randomize) {
            Collections.shuffle(subRanges);
        }
        subRanges.forEach(tr -> {
            int partitionId = nextPartitionId.getAndIncrement();
            this.partitionMap.put(tr, (Object)partitionId);
            this.reversePartitionMap.put(partitionId, (Range<BigInteger>)tr);
        });
        this.nrPartitions = nextPartitionId.get();
    }

    private void validate() {
        this.validateMapSizes();
        this.validateCompleteRangeCoverage();
        this.validateRangesDoNotOverlap();
    }

    private void validateRangesDoNotOverlap() {
        List sortedRanges = this.partitionMap.asMapOfRanges().keySet().stream().sorted(Comparator.comparing(Range::lowerEndpoint)).collect(Collectors.toList());
        Range previous = null;
        for (Range current : sortedRanges) {
            if (previous != null) {
                Preconditions.checkState((!current.isConnected(previous) || current.intersection(previous).isEmpty() ? 1 : 0) != 0, (String)"Two ranges in partition map are overlapping %s %s", (Object[])new Object[]{previous, current});
            }
            previous = current;
        }
    }

    private void validateCompleteRangeCoverage() {
        TreeRangeSet missingRangeSet = TreeRangeSet.create();
        missingRangeSet.add(Range.closed((Comparable)this.tokenRangeMapping.partitioner().minToken(), (Comparable)this.tokenRangeMapping.partitioner().maxToken()));
        this.partitionMap.asMapOfRanges().keySet().forEach(arg_0 -> ((RangeSet)missingRangeSet).remove(arg_0));
        List missingRanges = missingRangeSet.asRanges().stream().filter(Range::isEmpty).collect(Collectors.toList());
        Preconditions.checkState((boolean)missingRanges.isEmpty(), (Object)("There should be no missing ranges, but found " + missingRanges.toString()));
    }

    private void validateMapSizes() {
        Preconditions.checkState((this.nrPartitions == this.partitionMap.asMapOfRanges().keySet().size() ? 1 : 0) != 0, (Object)String.format("Number of partitions %d not matching with partition map size %d", this.nrPartitions, this.partitionMap.asMapOfRanges().keySet().size()));
        Preconditions.checkState((this.nrPartitions == this.reversePartitionMap.keySet().size() ? 1 : 0) != 0, (Object)String.format("Number of partitions %d not matching with reverse partition map size %d", this.nrPartitions, this.reversePartitionMap.keySet().size()));
        Preconditions.checkState((this.nrPartitions >= this.tokenRangeMapping.getRangeMap().asMapOfRanges().keySet().size() ? 1 : 0) != 0, (Object)String.format("Number of partitions %d supposed to be more than number of token ranges %d", this.nrPartitions, this.tokenRangeMapping.getRangeMap().asMapOfRanges().keySet().size()));
        Preconditions.checkState((this.nrPartitions >= this.tokenRangeMapping.getTokenRanges().keySet().size() ? 1 : 0) != 0, (Object)String.format("Number of partitions %d supposed to be more than number of instances %d", this.nrPartitions, this.tokenRangeMapping.getTokenRanges().keySet().size()));
        Preconditions.checkState((this.partitionMap.asMapOfRanges().keySet().size() == this.reversePartitionMap.keySet().size() ? 1 : 0) != 0, (Object)String.format("You must be kidding me! Partition map %d and reverse map %d are not of same size", this.partitionMap.asMapOfRanges().keySet().size(), this.reversePartitionMap.keySet().size()));
    }

    public int calculateSplits(TokenRangeMapping<RingInstance> tokenRangeMapping, Integer numberSplits, int defaultParallelism, Integer cores) {
        if (numberSplits >= 0) {
            return numberSplits;
        }
        int tasksToRun = Math.max(cores, defaultParallelism);
        Map rangeListMap = tokenRangeMapping.getRangeMap().asMapOfRanges();
        LOGGER.info("Initial ranges: {}", (Object)rangeListMap);
        int ranges = rangeListMap.size();
        LOGGER.info("Number of ranges: {}", (Object)ranges);
        int calculatedSplits = this.divCeil(tasksToRun, ranges);
        LOGGER.info("Calculated number of splits as {}", (Object)calculatedSplits);
        return calculatedSplits;
    }

    int divCeil(int a, int b) {
        return (a + b - 1) / b;
    }

    private void writeObject(ObjectOutputStream out) throws IOException {
        out.defaultWriteObject();
        Map partitionEntries = this.partitionMap.asMapOfRanges();
        out.writeInt(partitionEntries.size());
        for (Map.Entry entry : partitionEntries.entrySet()) {
            out.writeObject(entry.getKey());
            out.writeInt((Integer)entry.getValue());
        }
    }

    private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
        in.defaultReadObject();
        this.partitionMap = TreeRangeMap.create();
        this.reversePartitionMap = new HashMap<Integer, Range<BigInteger>>();
        this.nrPartitions = 0;
        int size = in.readInt();
        for (int i = 0; i < size; ++i) {
            Range range = (Range)in.readObject();
            int partitionId = in.readInt();
            this.partitionMap.put(range, (Object)partitionId);
            this.reversePartitionMap.put(partitionId, (Range<BigInteger>)range);
            if (partitionId < this.nrPartitions) continue;
            this.nrPartitions = partitionId + 1;
        }
    }
}

