/*
 * Decompiled with CFR 0.152.
 */
package org.apache.cassandra.analytics;

import com.datastax.driver.core.utils.UUIDs;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.sql.Date;
import java.sql.Timestamp;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import java.util.UUID;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import org.apache.cassandra.analytics.SharedClusterSparkIntegrationTestBase;
import org.apache.cassandra.bridge.type.InternalDuration;
import org.apache.cassandra.sidecar.testing.QualifiedName;
import org.apache.cassandra.spark.bulkwriter.SqlToCqlTypeConverter;
import org.apache.cassandra.spark.utils.ByteBufferUtils;
import org.apache.cassandra.spark.utils.ScalaConversionUtils;
import org.apache.cassandra.spark.utils.SparkTypeUtils;
import org.apache.cassandra.testing.TestUtils;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import scala.collection.mutable.Seq;

/**
 * Integration test that bulk-writes generated Spark {@link Dataset}s into Cassandra tables covering a
 * range of Spark SQL / CQL type combinations.
 *
 * <p>Each {@link TypeTestSetup} describes one table: its CQL definition, the Spark column names and
 * types, a per-column value generator, and (optionally) custom formatters used when comparing the
 * source rows with the rows read back from Cassandra. A setup that carries an expected failure
 * message only asserts that the write fails with that message.
 */
class BulkWriteDataTypesTest extends SharedClusterSparkIntegrationTestBase {
    /** Keyspace holding every table created by this test. */
    private static final String TEST_KEYSPACE = "spark_test";

    /** Base instant (epoch milliseconds) for the timestamp and date value generators. */
    private static final long BASE_TIMESTAMP_MILLIS = 1731457509115L;

    // Value generators: each maps the zero-based record number to the value stored in one column.

    static final Function<Integer, Object> INTEGER_MAPPER = recordNumber -> recordNumber;
    static final Function<Integer, Object> INTEGER_ARRAY_MAPPER =
        recordNumber -> Arrays.asList(recordNumber, recordNumber, recordNumber);
    static final Function<Integer, Object> BYTE_ARRAY_MAPPER =
        recordNumber -> Arrays.asList(courseBytes(recordNumber), courseBytes(recordNumber));
    static final Function<Integer, Object> STRING_SET_MAPPER =
        recordNumber -> ImmutableSet.of(paddedCourse(recordNumber), paddedCourse(recordNumber + 1));
    static final Function<Integer, Object> STRING_BINARY_MAP_MAPPER =
        recordNumber -> ImmutableMap.of(paddedCourse(recordNumber), courseBytes(recordNumber),
                                        paddedCourse(recordNumber + 1), courseBytes(recordNumber + 1));
    static final Function<Integer, Object> STRING_ARRAY_INTEGER_MAP_MAPPER =
        recordNumber -> ImmutableMap.of(paddedCourse(recordNumber), Collections.singletonList(recordNumber),
                                        paddedCourse(recordNumber + 1), Collections.singletonList(recordNumber + 1));
    static final Function<Integer, Object> STRING_ARRAY_BINARY_MAP_MAPPER =
        recordNumber -> ImmutableMap.of(paddedCourse(recordNumber), Collections.singletonList(courseBytes(recordNumber)),
                                        paddedCourse(recordNumber + 1), Collections.singletonList(courseBytes(recordNumber + 1)));
    static final Function<Integer, Object> LONG_MAPPER = recordNumber -> recordNumber.longValue();
    static final Function<Integer, Object> STRING_MAPPER = recordNumber -> course(recordNumber);
    static final Function<Integer, Object> BINARY_MAPPER = recordNumber -> courseBytes(recordNumber);
    /** Ignores the record number; every call produces a fresh time-based UUID string. */
    static final Function<Integer, Object> TIME_UUID_MAPPER = recordNumber -> UUIDs.timeBased().toString();
    /** Ignores the record number; every call produces a fresh random (version 4) UUID string. */
    static final Function<Integer, Object> RANDOM_UUID_MAPPER = recordNumber -> UUID.randomUUID().toString();
    /** {@link #BASE_TIMESTAMP_MILLIS} plus {@code recordNumber} seconds. */
    static final Function<Integer, Object> TIMESTAMP_MAPPER =
        recordNumber -> Timestamp.from(Instant.ofEpochMilli(BASE_TIMESTAMP_MILLIS)
                                              .plus(recordNumber.longValue(), ChronoUnit.SECONDS));
    /** The local (JVM default time zone) date of {@link #TIMESTAMP_MAPPER}'s value. */
    static final Function<Integer, Object> DATE_MAPPER =
        recordNumber -> Date.valueOf(((Timestamp) TIMESTAMP_MAPPER.apply(recordNumber)).toLocalDateTime().toLocalDate());
    /**
     * Duration of 1 month, {@code recordNumber} days and {@code recordNumber} seconds.
     * The nanosecond term is computed in {@code long}: {@code recordNumber * 10^9} exceeds
     * {@link Integer#MAX_VALUE} for {@code recordNumber >= 3} and previously wrapped around.
     */
    static final Function<Integer, Object> DURATION_MAPPER =
        recordNumber -> SparkTypeUtils.convertDuration(new InternalDuration(1, recordNumber, recordNumber * 1_000_000_000L));

    /** Returns {@code "course" + recordNumber}, e.g. {@code course7}. */
    private static String course(int recordNumber) {
        return "course" + recordNumber;
    }

    /** Returns the zero-padded course name, e.g. {@code course000007}. */
    private static String paddedCourse(int recordNumber) {
        return String.format("course%06d", recordNumber);
    }

    /** Returns the UTF-8 bytes of {@link #course(int)}. */
    private static byte[] courseBytes(int recordNumber) {
        return course(recordNumber).getBytes(StandardCharsets.UTF_8);
    }

    /** Decodes a UTF-8 byte array (Spark {@code BinaryType} value). */
    private static String bytesToString(byte[] bytes) {
        return new String(bytes, StandardCharsets.UTF_8);
    }

    /** Decodes the UTF-8 content of a buffer (blob value read back from Cassandra). */
    private static String bufferToString(ByteBuffer buffer) {
        return new String(ByteBufferUtils.getArray(buffer), StandardCharsets.UTF_8);
    }

    /**
     * Formats map entries as {@code [k1=v1, k2=v2]}, in the map's iteration order.
     *
     * @param map            map to format
     * @param valueFormatter converts each value to the object printed after {@code =}
     * @return the formatted entries
     */
    private static String formatEntries(Map<?, ?> map, Function<Object, Object> valueFormatter) {
        return map.entrySet()
                  .stream()
                  .map(entry -> String.format("%s=%s", entry.getKey(), valueFormatter.apply(entry.getValue())))
                  .collect(Collectors.joining(", ", "[", "]"));
    }

    /**
     * Writes the generated dataset for one setup. It then either asserts the expected failure
     * message, or validates the written data with the setup's formatters.
     *
     * @param tableName     table name; used only as the parameterized-test display name and target table
     * @param typeTestSetup the setup under test
     */
    @ParameterizedTest(name="{index} => {0}")
    @MethodSource(value={"testArguments"})
    void testType(String tableName, TypeTestSetup typeTestSetup) {
        SparkSession spark = getOrCreateSparkSession();
        Dataset<Row> df = generateDataset(spark, typeTestSetup);
        QualifiedName table = new QualifiedName(TEST_KEYSPACE, tableName);
        if (typeTestSetup.expectedFailureMessage != null) {
            Assertions.assertThatException()
                      .isThrownBy(() -> bulkWriterDataFrameWriter(df, table).save())
                      .withMessageContaining(typeTestSetup.expectedFailureMessage);
        } else {
            bulkWriterDataFrameWriter(df, table).save();
            sparkTestUtils.validateWrites(df.collectAsList(),
                                          queryAllData(table),
                                          typeTestSetup.columnMapperValidation,
                                          typeTestSetup.rowMapperValidation);
        }
    }

    /**
     * Builds a dataset of {@code typeTestSetup.numRows} rows. Every column is declared non-nullable.
     * Cell {@code (r, c)} is {@code valueFunction.get(c).apply(r)}.
     */
    Dataset<Row> generateDataset(SparkSession spark, TypeTestSetup typeTestSetup) {
        StructType schema = new StructType();
        for (int i = 0; i < typeTestSetup.columns.size(); i++) {
            schema = schema.add(typeTestSetup.columns.get(i), typeTestSetup.columnTypes.get(i), false);
        }
        // Rows are materialised eagerly, so non-deterministic generators (UUIDs) are evaluated once
        // and the same values are later compared against Cassandra.
        List<Row> rows = IntStream.range(0, typeTestSetup.numRows)
                                  .mapToObj(recordNum -> {
                                      Object[] values = typeTestSetup.valueFunction.stream()
                                                                                   .map(fn -> fn.apply(recordNum))
                                                                                   .toArray();
                                      return RowFactory.create(values);
                                  })
                                  .collect(Collectors.toList());
        return spark.createDataFrame(rows, schema);
    }

    /** Creates the test keyspace (DC1, RF 1) and one table per {@link TypeTestSetup}. */
    protected void initializeSchemaForTest() {
        createTestKeyspace(TEST_KEYSPACE, TestUtils.DC1_RF1);
        for (TypeTestSetup setup : typesToTest()) {
            createTestTable(new QualifiedName(TEST_KEYSPACE, setup.tableName), setup.createTableSchema);
        }
    }

    /** Supplies {@code (tableName, setup)} pairs to {@link #testType}. */
    static Stream<Arguments> testArguments() {
        return typesToTest().stream().map(setup -> Arguments.of(setup.tableName, setup));
    }

    /** Returns a fresh, mutable list of every setup exercised by this test. */
    static List<TypeTestSetup> typesToTest() {
        return new ArrayList<>(Arrays.asList(
            simpleBigIntSchemaSetup(),
            simpleDateSchemaSetup(),
            simpleDurationSchemaSetup(),
            integersAndStringsSchemaSetup(),
            timeUUIDSchemaSetup(),
            randomUUIDFailureSchemaSetup(),
            byteArrayColumnSetup(),
            intListSchemaSetup(),
            byteListSchemaSetup(),
            stringSetSchemaSetup(),
            mapByteSchemaSetup(),
            mapListSchemaSetup(),
            nestedDataTypesSchemaSetup(),
            timeSchemaTimestampSetup(),
            timeWithLongSourceSchemaSetup(),
            customDateTypeSchemaSetup()));
    }

    /** Spark {@code LongType} into CQL {@code bigint}. */
    static TypeTestSetup simpleBigIntSchemaSetup() {
        return new TypeTestSetup("bigint_schema",
                                 Arrays.asList("id", "marks"),
                                 Arrays.asList(DataTypes.IntegerType, DataTypes.LongType),
                                 Arrays.asList(INTEGER_MAPPER, LONG_MAPPER),
                                 "CREATE TABLE %s (id int, marks bigint, PRIMARY KEY (id))");
    }

    /** Spark {@code DateType} into CQL {@code date}. The expected value is formatted via {@code DATE_CONVERTER}. */
    static TypeTestSetup simpleDateSchemaSetup() {
        TypeTestSetup setup = new TypeTestSetup("date_schema",
                                                Arrays.asList("id", "course"),
                                                Arrays.asList(DataTypes.IntegerType, DataTypes.DateType),
                                                Arrays.asList(INTEGER_MAPPER, DATE_MAPPER),
                                                "CREATE TABLE %s (id int, course date, PRIMARY KEY (id))");
        setup.rowMapperValidation = row -> String.format("%s:%s", row.get(0), SqlToCqlTypeConverter.DATE_CONVERTER.convertInternal(row.get(1)));
        return setup;
    }

    /** Spark {@code CalendarIntervalType} into CQL {@code duration}; expected to fail with the given message. */
    static TypeTestSetup simpleDurationSchemaSetup() {
        return new TypeTestSetup("duration_schema",
                                 Arrays.asList("id", "took"),
                                 Arrays.asList(DataTypes.IntegerType, DataTypes.CalendarIntervalType),
                                 Arrays.asList(INTEGER_MAPPER, DURATION_MAPPER),
                                 "CREATE TABLE %s (id int, took duration, PRIMARY KEY (id))",
                                 "Cannot save interval data type into external storage.");
    }

    /** Plain {@code int} and {@code text} columns. */
    static TypeTestSetup integersAndStringsSchemaSetup() {
        return new TypeTestSetup("simple_schema",
                                 Arrays.asList("id", "course", "marks"),
                                 Arrays.asList(DataTypes.IntegerType, DataTypes.StringType, DataTypes.IntegerType),
                                 Arrays.asList(INTEGER_MAPPER, STRING_MAPPER, INTEGER_MAPPER),
                                 "CREATE TABLE %s (id int, course text, marks int, PRIMARY KEY (id))");
    }

    /** Time-based UUID strings into CQL {@code timeuuid}. */
    static TypeTestSetup timeUUIDSchemaSetup() {
        return new TypeTestSetup("timeuuid_schema",
                                 Arrays.asList("id", "course"),
                                 Arrays.asList(DataTypes.IntegerType, DataTypes.StringType),
                                 Arrays.asList(INTEGER_MAPPER, TIME_UUID_MAPPER),
                                 "CREATE TABLE %s (id int, course timeuuid, PRIMARY KEY (id))");
    }

    /** Random (version 4) UUID strings into CQL {@code timeuuid}; the bulk write is expected to fail. */
    static TypeTestSetup randomUUIDFailureSchemaSetup() {
        return new TypeTestSetup("timeuuid_schema_bad_uuid",
                                 Arrays.asList("id", "course"),
                                 Arrays.asList(DataTypes.IntegerType, DataTypes.StringType),
                                 Arrays.asList(INTEGER_MAPPER, RANDOM_UUID_MAPPER),
                                 "CREATE TABLE %s (id int, course timeuuid, PRIMARY KEY (id))",
                                 "Bulk Write to Cassandra has failed");
    }

    /** Spark {@code BinaryType} into CQL {@code blob}; both sides are compared as UTF-8 text. */
    static TypeTestSetup byteArrayColumnSetup() {
        TypeTestSetup setup = new TypeTestSetup("bytearray_column",
                                                Arrays.asList("id", "binarydata"),
                                                Arrays.asList(DataTypes.IntegerType, DataTypes.BinaryType),
                                                Arrays.asList(INTEGER_MAPPER, BINARY_MAPPER),
                                                "CREATE TABLE %s (id int, binarydata blob, PRIMARY KEY (id))");
        setup.rowMapperValidation = row -> String.format("%s:%s", row.get(0), bytesToString((byte[]) row.get(1)));
        setup.columnMapperValidation = columns -> String.format("%s:%s", columns[0], bufferToString((ByteBuffer) columns[1]));
        return setup;
    }

    /** Spark {@code array<int>} into CQL {@code list<int>}. */
    static TypeTestSetup intListSchemaSetup() {
        TypeTestSetup setup = new TypeTestSetup("list_column",
                                                Arrays.asList("id", "listdata"),
                                                Arrays.asList(DataTypes.IntegerType, DataTypes.createArrayType(DataTypes.IntegerType)),
                                                Arrays.asList(INTEGER_MAPPER, INTEGER_ARRAY_MAPPER),
                                                "CREATE TABLE %s (id int, listdata LIST<int>, PRIMARY KEY (id))");
        setup.rowMapperValidation = row -> String.format("%s:%s", row.get(0), row.getList(1));
        return setup;
    }

    /** Spark {@code array<binary>} into CQL {@code list<blob>}; elements are compared as UTF-8 text. */
    static TypeTestSetup byteListSchemaSetup() {
        TypeTestSetup setup = new TypeTestSetup("byte_list_column",
                                                Arrays.asList("id", "listdata"),
                                                Arrays.asList(DataTypes.IntegerType, DataTypes.createArrayType(DataTypes.BinaryType)),
                                                Arrays.asList(INTEGER_MAPPER, BYTE_ARRAY_MAPPER),
                                                "CREATE TABLE %s (id int, listdata LIST<blob>, PRIMARY KEY (id))");
        setup.rowMapperValidation = row -> {
            List<?> byteList = row.getList(1);
            List<String> decoded = byteList.stream().map(b -> bytesToString((byte[]) b)).collect(Collectors.toList());
            return String.format("%s:%s", row.get(0), decoded);
        };
        setup.columnMapperValidation = columns -> {
            List<?> bufferList = (List<?>) columns[1];
            List<String> decoded = bufferList.stream().map(b -> bufferToString((ByteBuffer) b)).collect(Collectors.toList());
            return String.format("%s:%s", columns[0], decoded);
        };
        return setup;
    }

    /** Spark {@code array<string>} (generated from a Java set) into CQL {@code set<text>}. */
    static TypeTestSetup stringSetSchemaSetup() {
        TypeTestSetup setup = new TypeTestSetup("set_list_column",
                                                Arrays.asList("id", "setdata"),
                                                Arrays.asList(DataTypes.IntegerType, DataTypes.createArrayType(DataTypes.StringType)),
                                                Arrays.asList(INTEGER_MAPPER, STRING_SET_MAPPER),
                                                "CREATE TABLE %s (id int, setdata set<text>, PRIMARY KEY (id))");
        setup.rowMapperValidation = row -> String.format("%s:%s", row.get(0), row.getList(1));
        return setup;
    }

    /** Spark {@code map<string, binary>} into a frozen CQL {@code map<text, blob>} clustering column. */
    static TypeTestSetup mapByteSchemaSetup() {
        TypeTestSetup setup = new TypeTestSetup("map_byte_column",
                                                Arrays.asList("id", "mapdata"),
                                                Arrays.asList(DataTypes.IntegerType, DataTypes.createMapType(DataTypes.StringType, DataTypes.BinaryType)),
                                                Arrays.asList(INTEGER_MAPPER, STRING_BINARY_MAP_MAPPER),
                                                "CREATE TABLE %s (id int, mapdata frozen<map<text,blob>>, PRIMARY KEY (id, mapdata))");
        setup.rowMapperValidation = row -> {
            Map<?, ?> map = row.getJavaMap(1);
            return String.format("%s:%s", row.get(0), formatEntries(map, v -> bytesToString((byte[]) v)));
        };
        setup.columnMapperValidation = columns -> {
            Map<?, ?> map = (Map<?, ?>) columns[1];
            return String.format("%s:%s", columns[0], formatEntries(map, v -> bufferToString((ByteBuffer) v)));
        };
        return setup;
    }

    /** Spark {@code map<string, array<int>>} into CQL {@code map<text, frozen<list<int>>>}. */
    static TypeTestSetup mapListSchemaSetup() {
        TypeTestSetup setup = new TypeTestSetup("map_list_column",
                                                Arrays.asList("id", "mapdata"),
                                                Arrays.asList(DataTypes.IntegerType,
                                                              DataTypes.createMapType(DataTypes.StringType, DataTypes.createArrayType(DataTypes.IntegerType))),
                                                Arrays.asList(INTEGER_MAPPER, STRING_ARRAY_INTEGER_MAP_MAPPER),
                                                "CREATE TABLE %s (id int, mapdata map<text,frozen<list<int>>>, PRIMARY KEY (id))");
        setup.rowMapperValidation = row -> {
            // Entries are sorted by key so the text is independent of the Spark map's iteration order.
            Map<String, Object> sortedByKey = new TreeMap<>();
            Map<?, ?> map = row.getJavaMap(1);
            map.forEach((key, seq) -> sortedByKey.put((String) key, ScalaConversionUtils.mutableSeqAsJavaList((Seq<?>) seq)));
            return String.format("%s:%s", row.get(0), sortedByKey);
        };
        return setup;
    }

    /** Spark {@code map<string, array<binary>>} into CQL {@code map<text, frozen<list<blob>>>}. */
    static TypeTestSetup nestedDataTypesSchemaSetup() {
        TypeTestSetup setup = new TypeTestSetup("map_list_byte_column",
                                                Arrays.asList("id", "mapdata"),
                                                Arrays.asList(DataTypes.IntegerType,
                                                              DataTypes.createMapType(DataTypes.StringType, DataTypes.createArrayType(DataTypes.BinaryType))),
                                                Arrays.asList(INTEGER_MAPPER, STRING_ARRAY_BINARY_MAP_MAPPER),
                                                "CREATE TABLE %s (id int, mapdata map<text,frozen<list<blob>>>, PRIMARY KEY (id))");
        setup.rowMapperValidation = row -> {
            Map<?, ?> map = row.getJavaMap(1);
            String value = formatEntries(map, seq -> ScalaConversionUtils.mutableSeqAsJavaList((Seq<?>) seq)
                                                                         .stream()
                                                                         .map(b -> bytesToString((byte[]) b))
                                                                         .collect(Collectors.toList()));
            return String.format("%s:%s", row.get(0), value);
        };
        setup.columnMapperValidation = columns -> {
            Map<?, ?> map = (Map<?, ?>) columns[1];
            String value = formatEntries(map, list -> ((List<?>) list).stream()
                                                                      .map(b -> bufferToString((ByteBuffer) b))
                                                                      .collect(Collectors.toList()));
            return String.format("%s:%s", columns[0], value);
        };
        return setup;
    }

    /** Spark {@code TimestampType} into CQL {@code time}. The expected value is formatted via {@code TIME_CONVERTER}. */
    static TypeTestSetup timeSchemaTimestampSetup() {
        TypeTestSetup setup = new TypeTestSetup("time_schema_timestamp",
                                                Arrays.asList("id", "course"),
                                                Arrays.asList(DataTypes.IntegerType, DataTypes.TimestampType),
                                                Arrays.asList(INTEGER_MAPPER, TIMESTAMP_MAPPER),
                                                "CREATE TABLE %s (id int, course time, PRIMARY KEY (id))");
        setup.rowMapperValidation = row -> String.format("%s:%s", row.get(0), SqlToCqlTypeConverter.TIME_CONVERTER.convertInternal(row.get(1)));
        return setup;
    }

    /** Spark {@code LongType} into CQL {@code time}. */
    static TypeTestSetup timeWithLongSourceSchemaSetup() {
        return new TypeTestSetup("time_schema_long",
                                 Arrays.asList("id", "course"),
                                 Arrays.asList(DataTypes.IntegerType, DataTypes.LongType),
                                 Arrays.asList(INTEGER_MAPPER, LONG_MAPPER),
                                 "CREATE TABLE %s (id int, course time, PRIMARY KEY (id))");
    }

    /**
     * Spark {@code LongType} into the custom CQL type {@code org.apache.cassandra.db.marshal.DateType}.
     * The expected value is formatted via {@code TIMESTAMP_CONVERTER}.
     */
    static TypeTestSetup customDateTypeSchemaSetup() {
        TypeTestSetup setup = new TypeTestSetup("c_12_ts_as_custom_schema_timestamp",
                                                Arrays.asList("id", "course"),
                                                Arrays.asList(DataTypes.IntegerType, DataTypes.LongType),
                                                Arrays.asList(INTEGER_MAPPER, LONG_MAPPER),
                                                "CREATE TABLE %s (id int, course 'org.apache.cassandra.db.marshal.DateType', PRIMARY KEY (id))");
        setup.rowMapperValidation = row -> String.format("%s:%s", row.get(0), SqlToCqlTypeConverter.TIMESTAMP_CONVERTER.convertInternal(row.get(1)));
        return setup;
    }

    /**
     * Describes one table/type combination exercised by {@link #testType}.
     *
     * <p>{@link #columns}, {@link #columnTypes} and {@link #valueFunction} are parallel lists: entry
     * {@code i} of each describes column {@code i}.
     */
    static class TypeTestSetup {
        /** Cassandra table name; also the display name of the parameterized test case. */
        final String tableName;
        /** Column names, in Spark schema order. */
        final List<String> columns;
        /** Spark SQL type of each column. */
        final List<DataType> columnTypes;
        /** Generator for each column, applied to the zero-based record number. */
        final List<Function<Integer, Object>> valueFunction;
        /** CQL {@code CREATE TABLE} template with a {@code %s} placeholder for the table name (presumably filled by {@code createTestTable}). */
        final String createTableSchema;
        /** Substring of the expected write failure message, or {@code null} when the write should succeed. */
        final String expectedFailureMessage;
        /** Number of rows generated for the dataset. */
        final int numRows = 10000;
        /** Formats one row read back from Cassandra (as a column array); defaults to values joined with ':'. */
        Function<Object[], String> columnMapperValidation = TypeTestSetup::joinWithColons;
        /** Formats one source Spark row; defaults to values joined with ':'. */
        Function<Row, String> rowMapperValidation = row -> {
            Object[] data = new Object[row.size()];
            for (int i = 0; i < data.length; i++) {
                data[i] = row.get(i);
            }
            return joinWithColons(data);
        };

        /** Creates a setup whose write is expected to succeed. */
        TypeTestSetup(String tableName, List<String> columns, List<DataType> columnTypes, List<Function<Integer, Object>> valueFunction, String createTableSchema) {
            this(tableName, columns, columnTypes, valueFunction, createTableSchema, null);
        }

        /**
         * Creates a setup.
         *
         * @param expectedFailureMessage expected failure message substring, or {@code null} if the write should succeed
         * @throws IllegalArgumentException if the column, type and generator lists differ in length
         */
        TypeTestSetup(String tableName, List<String> columns, List<DataType> columnTypes, List<Function<Integer, Object>> valueFunction, String createTableSchema, String expectedFailureMessage) {
            if (columns.size() != columnTypes.size() || columns.size() != valueFunction.size()) {
                throw new IllegalArgumentException(String.format(
                    "Setup %s: %d columns, %d column types and %d value functions must have equal sizes",
                    tableName, columns.size(), columnTypes.size(), valueFunction.size()));
            }
            this.tableName = tableName;
            this.columns = columns;
            this.columnTypes = columnTypes;
            this.valueFunction = valueFunction;
            this.createTableSchema = createTableSchema;
            this.expectedFailureMessage = expectedFailureMessage;
        }

        /** Formats each value with {@code %s} and joins the results with ':'. */
        private static String joinWithColons(Object[] values) {
            return String.format(String.join(":", Collections.nCopies(values.length, "%s")), values);
        }
    }
}

