/*
 * Decompiled with CFR 0.152.
 */
package org.apache.jena.sparql.engine.join;

import java.util.ArrayList;
import java.util.BitSet;
import java.util.Collections;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.function.Consumer;
import java.util.stream.Collectors;
import org.apache.jena.atlas.io.IndentedWriter;
import org.apache.jena.atlas.io.Printable;
import org.apache.jena.atlas.iterator.Iter;
import org.apache.jena.sparql.core.Var;
import org.apache.jena.sparql.engine.binding.Binding;
import org.apache.jena.sparql.engine.binding.BindingFactory;
import org.apache.jena.sparql.engine.join.BitSetMapper;
import org.apache.jena.sparql.engine.join.HashProbeTable;
import org.apache.jena.sparql.engine.join.JoinKey;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * A lookup index over {@link Binding}s used during join evaluation.
 * <p>
 * Bit sets used by this index are interpreted relative to {@link #getSuperJoinKey()} via
 * {@link BitSetMapper}. Each stored row is classified by which variables of the main join
 * key it binds (see {@link #put(Binding)}):
 * <ul>
 *   <li>rows binding none of the main join key variables go to the main table via
 *       {@code putNoKey};</li>
 *   <li>rows binding all of the main join key variables go to the main table;</li>
 *   <li>rows binding only a non-empty subset of them go to a "skew" table dedicated to
 *       exactly that subset, created on demand.</li>
 * </ul>
 * {@link #getCandidates(Binding)} consults the main table and then every skew table.
 * <p>
 * No synchronization is performed; the skew-table map is created lazily on first use.
 */
class JoinIndex
implements Iterable<Binding>,
Printable {
    private static final Logger logger = LoggerFactory.getLogger(JoinIndex.class);

    /** The key whose variable positions all bit sets of this index refer to. */
    private final JoinKey superJoinKey;
    /** Positions (within {@link #superJoinKey}) of the main table's key variables. */
    private final BitSet mainJoinKeyBitSet;
    private final HashProbeTable mainTable;
    /** Lazily created; keyed by the subset of main-key positions a row actually binds. */
    private Map<BitSet, HashProbeTable> skewTables;

    /**
     * Creates an empty index.
     *
     * @param superJoinKey      the join key that all bit sets of this index are relative to
     * @param mainJoinKeyBitSet positions within {@code superJoinKey} of the main table's key
     * @param mainJoinKey       the main table's join key, or {@code null} to derive it from
     *                          {@code superJoinKey} and {@code mainJoinKeyBitSet}
     * @throws NullPointerException if {@code superJoinKey} is null
     */
    public JoinIndex(JoinKey superJoinKey, BitSet mainJoinKeyBitSet, JoinKey mainJoinKey) {
        this.superJoinKey = Objects.requireNonNull(superJoinKey, "superJoinKey");
        this.mainJoinKeyBitSet = mainJoinKeyBitSet;
        JoinKey effectiveMainJoinKey = mainJoinKey != null
                ? mainJoinKey
                : JoinKey.create(BitSetMapper.toList(superJoinKey, mainJoinKeyBitSet));
        this.mainTable = new HashProbeTable(effectiveMainJoinKey);
    }

    /** Returns the join key that all bit sets of this index are relative to. */
    public JoinKey getSuperJoinKey() {
        return this.superJoinKey;
    }

    /**
     * Returns the positions of the main key variables within the super join key.
     * The internal instance is returned; callers must not modify it, as {@link #put} relies on it.
     */
    public BitSet getMainJoinKeyBitSet() {
        return this.mainJoinKeyBitSet;
    }

    /** Returns the main hash probe table. */
    public HashProbeTable getMainTable() {
        return this.mainTable;
    }

    /**
     * Returns the bit sets identifying the existing skew tables, or an empty set if none exist.
     * This is a live view of the internal map's key set, not a copy.
     */
    public Set<BitSet> getSkewKeys() {
        return this.skewTables == null ? Collections.emptySet() : this.skewTables.keySet();
    }

    /** Returns the skew table for the given bit set, or {@code null} if there is none. */
    public HashProbeTable getSkewTable(BitSet key) {
        return this.skewTables == null ? null : this.skewTables.get(key);
    }

    /**
     * Returns the skew tables keyed by bit set, or an empty map if none exist.
     * The internal map is returned directly, not a copy.
     */
    public Map<BitSet, HashProbeTable> getSkewTables() {
        return this.skewTables == null ? Collections.emptyMap() : this.skewTables;
    }

    /**
     * Returns the skew tables keyed by the {@link JoinKey} of the variables each one covers,
     * in skew-table creation order.
     * <p>
     * The bit sets are decoded against the super join key, which is the key they were built
     * against in {@link #put} and {@link #getOrCreateSkewTable}.
     *
     * @throws IllegalStateException if two skew tables map to the same join key
     */
    public Map<JoinKey, HashProbeTable> getSkewTablesByJoinKey() {
        Map<JoinKey, HashProbeTable> result = new LinkedHashMap<>();
        for (Map.Entry<BitSet, HashProbeTable> entry : this.getSkewTables().entrySet()) {
            JoinKey joinKey = JoinKey.create(BitSetMapper.toList(this.superJoinKey, entry.getKey()));
            if (result.putIfAbsent(joinKey, entry.getValue()) != null) {
                throw new IllegalStateException("Duplicate skew table join key: " + joinKey);
            }
        }
        return result;
    }

    /**
     * Returns the skew table for {@code tableKey}, creating it (and the skew-table map) if needed.
     * A newly created table's join key holds the super join key variables selected by {@code tableKey}.
     */
    public HashProbeTable getOrCreateSkewTable(BitSet tableKey) {
        if (this.skewTables == null) {
            this.skewTables = new LinkedHashMap<>();
        }
        return this.skewTables.computeIfAbsent(tableKey,
                key -> new HashProbeTable(JoinKey.create(BitSetMapper.toList(this.superJoinKey, key))));
    }

    /**
     * Adds a row to the index.
     * <p>
     * The row's bound super-key positions are intersected with the main key positions. The
     * row then goes to one of three places:
     * <ul>
     *   <li>an empty intersection goes to the main table's no-key area;</li>
     *   <li>a full match goes to the main table;</li>
     *   <li>anything else goes to the skew table for exactly that intersection.</li>
     * </ul>
     */
    public void put(Binding row) {
        BitSet rawRowKey = BitSetMapper.toBitSet(this.superJoinKey, row);
        // Clone so the shared main-key bit set is never mutated.
        BitSet effectiveRowKey = (BitSet) this.mainJoinKeyBitSet.clone();
        effectiveRowKey.and(rawRowKey);
        if (effectiveRowKey.isEmpty()) {
            this.mainTable.putNoKey(row);
        } else if (effectiveRowKey.equals(this.mainJoinKeyBitSet)) {
            this.mainTable.put(row);
        } else {
            this.getOrCreateSkewTable(effectiveRowKey).put(row);
        }
    }

    /**
     * Returns the candidate matches for {@code row}.
     * <p>
     * The result concatenates the main table's candidates with each skew table's candidates,
     * with skew tables in creation order. Skew tables are probed via
     * {@code getCandidates(row, false)}; NOTE(review): the meaning of the {@code false}
     * flag is defined in HashProbeTable — confirm there. When trace logging is enabled, the
     * iterators are fully materialized so their contents can be logged.
     */
    public Iterator<Binding> getCandidates(Binding row) {
        if (logger.isTraceEnabled()) {
            BitSet joinKeyBitSet = BitSetMapper.toBitSet(this.superJoinKey, row);
            logger.trace("Lookup with {}", BitSetMapper.toList(this.superJoinKey, joinKeyBitSet));
        }
        Iterator<Binding> it = this.getMainTable().getCandidates(row);
        if (this.skewTables != null) {
            for (HashProbeTable skewTable : this.skewTables.values()) {
                Iterator<Binding> subIt = skewTable.getCandidates(row, false);
                if (logger.isTraceEnabled()) {
                    subIt = JoinIndex.printIteratorItems(subIt, "sub-iterator", logger::trace);
                }
                it = Iter.concat(it, subIt);
            }
        }
        if (logger.isTraceEnabled()) {
            it = JoinIndex.printIteratorItems(it, "Lookup result for " + String.valueOf(row), logger::trace);
        }
        return it;
    }

    /**
     * Iterates by probing with the empty binding.
     * Presumably this yields every stored row, but that depends on how HashProbeTable handles
     * an empty probe — confirm there.
     */
    @Override
    public Iterator<Binding> iterator() {
        return this.getCandidates(BindingFactory.empty());
    }

    /** Clears the main table and all skew tables. The skew tables themselves are also removed. */
    public void clear() {
        this.mainTable.clear();
        if (this.skewTables != null) {
            this.skewTables.clear();
        }
    }

    @Override
    public String toString() {
        return Printable.toString(this);
    }

    /** Writes a multi-line, indented description of the main table and all skew tables. */
    @Override
    public void output(IndentedWriter out) {
        out.ensureStartOfLine();
        out.println("JoinIndex " + String.valueOf(this.mainTable.getJoinKey()));
        out.incIndent();
        out.println("Main table: " + String.valueOf(this.mainTable));
        Map<BitSet, HashProbeTable> tables = this.getSkewTables();
        if (tables.isEmpty()) {
            out.println("Skew tables: none");
        } else {
            out.println("Skew tables");
            out.incIndent();
            for (HashProbeTable table : tables.values()) {
                out.println("|- " + String.valueOf(table));
            }
            out.decIndent();
        }
        out.decIndent();
    }

    /**
     * Drains {@code it} into a list and logs a summary line (when {@code label} is non-null).
     * It then logs one line per item and returns an iterator over the buffered items.
     *
     * @param it    the iterator to drain; it is exhausted afterwards
     * @param label optional summary prefix; {@code null} suppresses the summary line
     * @param sink  receiver of the log lines
     * @return an iterator over the same items, in the same order
     */
    private static <T> Iterator<T> printIteratorItems(Iterator<T> it, String label, Consumer<String> sink) {
        List<T> items = new ArrayList<>();
        it.forEachRemaining(items::add);
        if (label != null) {
            sink.accept(label + ": " + items.size() + " items");
        }
        for (T item : items) {
            sink.accept("- " + String.valueOf(item));
        }
        return items.iterator();
    }
}

