/*
 * Decompiled with CFR 0.152.
 */
package org.apache.cassandra.auth;

import com.google.common.annotations.VisibleForTesting;
import java.net.InetAddress;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;
import java.util.stream.Collectors;
import org.apache.cassandra.auth.CIDRGroupsMappingTable;
import org.apache.cassandra.cql3.CIDR;

public class CIDRGroupsMappingIntervalTree<V>
implements CIDRGroupsMappingTable<V> {
    private final IPIntervalTree<V> tree;

    public CIDRGroupsMappingIntervalTree(boolean isIPv6, Map<CIDR, Set<V>> cidrMappings) {
        for (CIDR cidr : cidrMappings.keySet()) {
            if (isIPv6 == cidr.isIPv6()) continue;
            throw new IllegalArgumentException("Invalid CIDR format, expecting " + this.getIPTypeString(isIPv6) + ", received " + this.getIPTypeString(cidr.isIPv6()));
        }
        this.tree = IPIntervalTree.build(new ArrayList(cidrMappings.entrySet().stream().collect(Collectors.groupingBy(p -> ((CIDR)p.getKey()).getNetMask(), TreeMap::new, Collectors.toList())).descendingMap().values()));
    }

    @Override
    public Set<V> lookupLongestMatchForIP(InetAddress ip) {
        if (this.tree == null) {
            return Collections.emptySet();
        }
        return this.tree.query(ip);
    }

    static class IPIntervalTree<V> {
        private final IPIntervalNode<V>[] level0;
        private final int depth;

        private IPIntervalTree(IPIntervalNode<V>[] nodes, int depth) {
            this.level0 = nodes;
            this.depth = depth;
        }

        @VisibleForTesting
        int getDepth() {
            return this.depth;
        }

        private static <V> void optimizeLevels(List<Map.Entry<CIDR, V>> upperLevel, List<Map.Entry<CIDR, V>> lowerLevel) {
            ArrayList<Map.Entry<CIDR, V>> newUpper = new ArrayList<Map.Entry<CIDR, V>>(upperLevel.size() + lowerLevel.size());
            newUpper.addAll(upperLevel);
            ArrayList<Map.Entry<CIDR, V>> newLower = new ArrayList<Map.Entry<CIDR, V>>(lowerLevel.size());
            for (int i = 0; i < lowerLevel.size(); ++i) {
                boolean noOverlap = true;
                for (int j = 0; j < upperLevel.size(); ++j) {
                    if (!CIDR.overlaps(lowerLevel.get(i).getKey(), upperLevel.get(j).getKey())) continue;
                    newLower.add(lowerLevel.get(i));
                    noOverlap = false;
                    break;
                }
                if (!noOverlap) continue;
                newUpper.add(lowerLevel.get(i));
            }
            upperLevel.clear();
            lowerLevel.clear();
            upperLevel.addAll(newUpper);
            lowerLevel.addAll(newLower);
        }

        private static <V> void optimizeAllLevels(List<List<Map.Entry<CIDR, Set<V>>>> cidrsGroupedByNetMasks) {
            for (int i = 0; i < cidrsGroupedByNetMasks.size(); ++i) {
                List<Map.Entry<CIDR, V>> current = cidrsGroupedByNetMasks.get(0);
                for (int j = i + 1; j < cidrsGroupedByNetMasks.size(); ++j) {
                    List<Map.Entry<CIDR, V>> lower = cidrsGroupedByNetMasks.get(j);
                    IPIntervalTree.optimizeLevels(current, lower);
                }
            }
        }

        private static <V> void linkNodes(List<List<Map.Entry<CIDR, Set<V>>>> cidrMappings, IPIntervalNode<V>[][] result, int startIndex) {
            List<Map.Entry<CIDR, Set<V>>> cidrsAtLevel = cidrMappings.get(startIndex);
            int next = startIndex + 1;
            IPIntervalNode[] lowerLevel = next == result.length ? null : result[next];
            result[startIndex] = (IPIntervalNode[])cidrsAtLevel.stream().map(pair -> {
                CIDR cidr = (CIDR)pair.getKey();
                Set value = (Set)pair.getValue();
                IPIntervalNode node = new IPIntervalNode(cidr, value, lowerLevel);
                if (next + 1 < result.length && (node.left == null || node.right == null)) {
                    for (int i = next + 1; i < result.length; ++i) {
                        node.updateLeftIfNull(result[i]);
                        node.updateRightIfNull(result[i]);
                        if (node.left != null && node.right != null) break;
                    }
                }
                return node;
            }).sorted(Comparator.comparing(n -> n.cidr.getStartIpAddress(), CIDR::compareIPs)).toArray(IPIntervalNode[]::new);
        }

        public static <V> IPIntervalTree<V> build(List<List<Map.Entry<CIDR, Set<V>>>> cidrsGroupedByNetMasks) {
            if (cidrsGroupedByNetMasks.isEmpty()) {
                return null;
            }
            IPIntervalTree.optimizeAllLevels(cidrsGroupedByNetMasks);
            cidrsGroupedByNetMasks.removeIf(List::isEmpty);
            IPIntervalNode[][] result = new IPIntervalNode[cidrsGroupedByNetMasks.size()][];
            for (int i = cidrsGroupedByNetMasks.size() - 1; i >= 0; --i) {
                IPIntervalTree.linkNodes(cidrsGroupedByNetMasks, result, i);
            }
            return new IPIntervalTree<V>(result[0], cidrsGroupedByNetMasks.size());
        }

        public Set<V> query(InetAddress ip) {
            IPIntervalNode<V> closest = IPIntervalNode.binarySearchNodes(this.level0, ip);
            return IPIntervalNode.query(closest, ip);
        }
    }

    static class IPIntervalNode<V> {
        private final CIDR cidr;
        private final Set<V> values = new HashSet<V>();
        private IPIntervalNode<V>[] left;
        private IPIntervalNode<V>[] right;

        public IPIntervalNode(CIDR cidr, Set<V> values, IPIntervalNode<V>[] children) {
            this.cidr = cidr;
            if (values != null) {
                this.values.addAll(values);
            }
            this.updateChildren(children, true, true);
        }

        @VisibleForTesting
        CIDR cidr() {
            return this.cidr;
        }

        @VisibleForTesting
        IPIntervalNode<V>[] left() {
            return this.left;
        }

        @VisibleForTesting
        IPIntervalNode<V>[] right() {
            return this.right;
        }

        private void updateLeft(IPIntervalNode<V>[] newValue, boolean shouldUpdate) {
            if (shouldUpdate) {
                this.left = newValue;
            }
        }

        private void updateRight(IPIntervalNode<V>[] newValue, boolean shouldUpdate) {
            if (shouldUpdate) {
                this.right = newValue;
            }
        }

        private void updateChildren(IPIntervalNode<V>[] children, boolean updateLeft, boolean updateRight) {
            if (children == null) {
                this.updateLeft(null, updateLeft);
                this.updateRight(null, updateRight);
                return;
            }
            int index = IPIntervalNode.binarySearchNodesIndex(children, this.cidr.getStartIpAddress());
            IPIntervalNode<V> closest = children[index];
            if (index == 0 && CIDR.compareIPs(this.cidr.getEndIpAddress(), closest.cidr.getStartIpAddress()) < 0) {
                this.updateLeft(null, updateLeft);
                this.updateRight(children, updateRight);
            } else if (index == children.length - 1 && CIDR.compareIPs(this.cidr.getStartIpAddress(), closest.cidr.getEndIpAddress()) > 0) {
                this.updateLeft(children, updateLeft);
                this.updateRight(null, updateRight);
            } else if (CIDR.compareIPs(this.cidr.getStartIpAddress(), closest.cidr.getEndIpAddress()) > 0) {
                this.updateLeft(Arrays.copyOfRange(children, 0, index + 1), updateLeft);
                this.updateRight(Arrays.copyOfRange(children, index + 1, children.length), updateRight);
            } else {
                this.updateLeft(Arrays.copyOfRange(children, 0, index + 1), updateLeft);
                this.updateRight(Arrays.copyOfRange(children, index, children.length), updateRight);
            }
        }

        private void updateLeftIfNull(IPIntervalNode<V>[] children) {
            if (this.left != null) {
                return;
            }
            this.updateChildren(children, true, false);
        }

        private void updateRightIfNull(IPIntervalNode<V>[] children) {
            if (this.right != null) {
                return;
            }
            this.updateChildren(children, false, true);
        }

        static <V> int binarySearchNodesIndex(IPIntervalNode<V>[] nodes, InetAddress ip) {
            int start = 0;
            int end = nodes.length;
            while (start < end) {
                int mid = start + (end - start) / 2;
                IPIntervalNode<V> midNode = nodes[mid];
                int cmp = CIDR.compareIPs(ip, midNode.cidr.getStartIpAddress());
                if (cmp == 0) {
                    return mid;
                }
                if (cmp < 0) {
                    end = mid;
                    continue;
                }
                int compEnd = CIDR.compareIPs(ip, midNode.cidr.getEndIpAddress());
                if (compEnd <= 0) {
                    return mid;
                }
                start = mid + 1;
            }
            return Math.max(end - 1, 0);
        }

        static <V> IPIntervalNode<V> binarySearchNodes(IPIntervalNode<V>[] nodes, InetAddress ip) {
            int index = IPIntervalNode.binarySearchNodesIndex(nodes, ip);
            return nodes[index];
        }

        static <V> Set<V> query(IPIntervalNode<V> root, InetAddress ip) {
            IPIntervalNode<V> current = root;
            while (true) {
                IPIntervalNode<V>[] candidates;
                boolean lessThanEnd;
                boolean largerThanStart = CIDR.compareIPs(ip, current.cidr.getStartIpAddress()) >= 0;
                boolean bl = lessThanEnd = CIDR.compareIPs(ip, current.cidr.getEndIpAddress()) <= 0;
                if (largerThanStart && lessThanEnd) {
                    return current.values;
                }
                IPIntervalNode<V>[] iPIntervalNodeArray = candidates = largerThanStart ? current.right : current.left;
                if (candidates == null) {
                    return null;
                }
                current = IPIntervalNode.binarySearchNodes(candidates, ip);
            }
        }
    }
}

