package com.gengoai.apollo.ml;

import com.gengoai.Validation;
import com.gengoai.apollo.math.linalg.NDArrayFactory;
import com.gengoai.function.SerializableFunction;
import com.gengoai.function.Unchecked;
import com.gengoai.io.Resources;
import com.gengoai.io.resource.Resource;
import com.gengoai.json.Json;
import com.gengoai.sql.NamedPreparedStatement;
import com.gengoai.sql.SQL;
import com.gengoai.sql.SQLContext;
import com.gengoai.sql.SQLElement;
import com.gengoai.sql.SQLiteDialect;
import com.gengoai.sql.object.Column;
import com.gengoai.sql.object.SQLDMLOperation;
import com.gengoai.sql.object.Table;
import com.gengoai.sql.object.Trigger;
import com.gengoai.sql.object.TriggerTime;
import com.gengoai.sql.sqlite.SQLiteConnectionRegistry;
import com.gengoai.sql.statement.InsertType;
import com.gengoai.sql.statement.Select;
import com.gengoai.sql.statement.UpdateStatement;
import com.gengoai.stream.MStream;
import com.gengoai.stream.StreamingContext;
import com.gengoai.string.Strings;
import com.gengoai.tuple.Tuple2;
import com.gengoai.tuple.Tuples;
import java.lang.invoke.SerializedLambda;
import java.sql.Connection;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Random;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Consumer;
import java.util.stream.Stream;
import lombok.NonNull;

/* loaded from: input_file:com/gengoai/apollo/ml/SQLiteDataSet.class */
public class SQLiteDataSet extends DataSet {
    /** Serialization version identifier. */
    private static final long serialVersionUID = 1;
    /** Name of the metadata row whose value holds the number of rows in the data table. */
    private static final String SIZE_NAME = "__size__";
    /** Column of the data table holding one datum serialized as JSON. */
    private static final Column json = new Column("json", "JSON");
    /** Primary-key column of the metadata table (the metadata entry's name). */
    private static final Column name = new Column("name", "TEXT").primaryKey();
    /** Column of the metadata table holding the entry's value. */
    private static final Column value = new Column("value", "BLOB");
    /** Definition of the "data" table: a single JSON column. */
    private static final Table dataTable = new Table("data", (SQLElement) null, List.of(json), Collections.emptyList());
    /** Definition of the "metadata" table: name/value pairs. */
    private static final Table metadataTable = new Table("metadata", (SQLElement) null, List.of(name, value), Collections.emptyList());
    /** SQL context through which every statement of this dataset is executed. */
    private final SQLContext executor;
    /** When true, select() orders the rows by SQL.F.random(); set by shuffle(Random). */
    private boolean isShuffled;

    /**
     * Creates a dataset stored in the resource returned by {@code Resources.temporaryFile().deleteOnExit()} and
     * fills it with the given stream by delegating to the (Resource, Stream) constructor.
     *
     * @param stream the data to store
     * @throws NullPointerException if {@code stream} is null
     */
    public SQLiteDataSet(@NonNull Stream<Datum> stream) {
        this(Resources.temporaryFile().deleteOnExit(), stream);
        // The delegated constructor performs the same null check on stream (same message), so a null stream has
        // already been rejected by the time this line is reached.
        if (stream == null) {
            throw new NullPointerException("stream is marked non-null but is null");
        }
    }

    public SQLiteDataSet(@NonNull Resource resource, @NonNull Stream<Datum> stream) {
        this(resource);
        if (resource == null) {
            throw new NullPointerException("location is marked non-null but is null");
        }
        if (stream == null) {
            throw new NullPointerException("stream is marked non-null but is null");
        }
        dataTable.insert().batch(this.executor, stream, (datum, namedPreparedStatement) -> {
            namedPreparedStatement.setString(json.getName(), Json.dumps(datum));
        }, 1000);
    }

    public SQLiteDataSet(@NonNull Resource resource) {
        this.isShuffled = false;
        if (resource == null) {
            throw new NullPointerException("dataset is marked non-null but is null");
        }
        this.executor = SQLContext.create(SQLiteConnectionRegistry.getConnection("jdbc:sqlite:" + Strings.prependIfNotPresent(resource.path(), "/")), new SQLiteDialect());
        if (!dataTable.exists(this.executor)) {
            dataTable.createIfNotExists(this.executor);
        }
        if (metadataTable.exists(this.executor)) {
            metadataTable.select(new SQLElement[]{name, value}).where(name.neq(SQL.L(SIZE_NAME))).query(this.executor, resultSet -> {
                return Map.of(resultSet.getString(name.getName()), resultSet.getObject(value.getName()));
            }).forEach(map -> {
                map.forEach(Unchecked.biConsumer((str, obj) -> {
                    if (str.equals("ndArrayFactory")) {
                        this.ndArrayFactory = NDArrayFactory.valueOf(obj.toString());
                    } else {
                        this.metadata.put(str, (ObservationMetadata) Json.parse(obj.toString(), ObservationMetadata.class));
                    }
                }));
            });
        } else {
            this.executor.batch(new UpdateStatement[]{metadataTable.create(), metadataTable.insert(InsertType.INSERT_OR_REPLACE).values(new SQLElement[]{SQL.L(SIZE_NAME), SQL.N(0)}), Trigger.builder().name("data_insert_size_inc").table(dataTable).operation(SQLDMLOperation.INSERT).when(TriggerTime.AFTER).updateStatement(metadataTable.update().set(value, SQL.sql(new String[]{"cast(value as INTEGER)+1"})).where(name.eq(SQL.L(SIZE_NAME)))).build().create(), Trigger.builder().name("data_delete_size_dec").table(dataTable).operation(SQLDMLOperation.DELETE).when(TriggerTime.AFTER).updateStatement(metadataTable.update().set(value, SQL.sql(new String[]{"cast(value as INTEGER)-1"})).where(name.eq(SQL.L(SIZE_NAME)))).build().create()});
        }
    }

    /**
     * Iterates over the dataset in consecutive in-memory batches.
     * <p>
     * Each call to {@code next()} pulls up to {@code i} items from a single underlying {@link #iterator()} and
     * wraps them in an {@link InMemoryDataSet}; the last batch may be smaller.
     *
     * @param i the maximum number of items per batch; must be greater than zero
     * @return an iterator over the batches
     */
    @Override // com.gengoai.apollo.ml.DataSet
    public Iterator<DataSet> batchIterator(final int i) {
        Validation.checkArgument(i > 0);
        return new Iterator<DataSet>() {
            private final Iterator<Datum> itr = SQLiteDataSet.this.iterator();

            @Override
            public boolean hasNext() {
                return this.itr.hasNext();
            }

            @Override
            public DataSet next() {
                if (!this.itr.hasNext()) {
                    throw new NoSuchElementException();
                }
                List<Datum> batch = new ArrayList<>();
                // Check the size first so a full batch does not cost an extra hasNext() on the source iterator.
                while (batch.size() < i && this.itr.hasNext()) {
                    batch.add(this.itr.next());
                }
                return new InMemoryDataSet(batch);
            }
        };
    }

    /**
     * Returns this dataset unchanged; no copy is made.
     *
     * @return this dataset
     */
    @Override // com.gengoai.apollo.ml.DataSet
    public DataSet cache() {
        return this;
    }

    /**
     * Reports the kind of dataset this is.
     *
     * @return always {@link DataSetType#OnDisk}
     */
    @Override // com.gengoai.apollo.ml.DataSet
    public DataSetType getType() {
        return DataSetType.OnDisk;
    }

    /**
     * Returns an iterator over the items produced by {@link #stream()}.
     *
     * @return an iterator over the data
     */
    @Override // java.lang.Iterable
    public Iterator<Datum> iterator() {
        return stream().iterator();
    }

    /**
     * Applies the given function to every datum and writes the result back, as JSON, to the row the datum was
     * read from (matched by {@code rowid}). The dataset is modified in place.
     * <p>
     * Rows are read through {@link #parallelIdStream()}; the updates are queued on one prepared statement under a
     * lock on this dataset and executed and committed every {@code SQLContext.DEFAULT_BATCH_SIZE} rows, with a
     * final flush for the remainder. Auto-commit is disabled for the duration of the call.
     * <p>
     * If anything fails, work that has not been committed yet is rolled back, the connection's previous
     * auto-commit mode is restored and the prepared statement is closed before the failure propagates. Batches
     * that were already committed stay committed.
     *
     * @param serializableFunction the transformation to apply to each datum
     * @return this dataset
     * @throws NullPointerException if {@code serializableFunction} is null
     * @throws RuntimeException     wrapping any {@link SQLException} raised while updating
     */
    @Override // com.gengoai.apollo.ml.DataSet
    public DataSet map(@NonNull SerializableFunction<? super Datum, ? extends Datum> serializableFunction) {
        if (serializableFunction == null) {
            throw new NullPointerException("function is marked non-null but is null");
        }
        Connection connection = this.executor.getConnection();
        try {
            NamedPreparedStatement updateRow = new NamedPreparedStatement(connection, this.executor.render(dataTable.update().set(json, SQL.namedArgument("json")).where(SQL.C("rowid").eq(SQL.namedArgument("rowid")))));
            try {
                boolean autoCommit = connection.getAutoCommit();
                connection.setAutoCommit(false);
                try {
                    // Number of updates queued on the statement since the last executeBatch().
                    AtomicLong pending = new AtomicLong(0L);
                    parallelIdStream().forEach(Unchecked.consumer(tuple2 -> {
                        long rowId = tuple2.v1;
                        // Run the user function outside the lock so it can execute in parallel.
                        Datum mapped = serializableFunction.apply(tuple2.v2);
                        synchronized (this) {
                            updateRow.setObject("json", Json.dumps(mapped));
                            updateRow.setLong("rowid", rowId);
                            updateRow.addBatch();
                            if (pending.incrementAndGet() % SQLContext.DEFAULT_BATCH_SIZE == 0) {
                                updateRow.executeBatch();
                                connection.commit();
                                pending.set(0L);
                            }
                        }
                    }));
                    if (pending.get() > 0) {
                        updateRow.executeBatch();
                        connection.commit();
                    }
                } catch (SQLException | RuntimeException failure) {
                    rollbackAndRestore(connection, autoCommit, failure);
                    throw failure;
                }
                connection.setAutoCommit(autoCommit);
            } finally {
                // Always release the statement, including when the mapping or an update failed.
                updateRow.close();
            }
            return this;
        } catch (SQLException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Rolls back the open transaction and restores the previous auto-commit mode after a failed update; errors
     * raised by the cleanup itself are attached to {@code failure} as suppressed exceptions so they do not mask it.
     */
    private static void rollbackAndRestore(Connection connection, boolean autoCommit, Exception failure) {
        try {
            try {
                connection.rollback();
            } finally {
                connection.setAutoCommit(autoCommit);
            }
        } catch (SQLException cleanupFailure) {
            failure.addSuppressed(cleanupFailure);
        }
    }

    /**
     * Streams every row of the data table as a (rowid, datum) pair, where the datum is parsed from the row's
     * JSON column. The rows are fetched with {@code queryParallel} keyed on {@code rowid}.
     *
     * @return a stream of rowid / datum pairs
     */
    public MStream<Tuple2<Long, Datum>> parallelIdStream() {
        return StreamingContext.local().stream(
                dataTable.select(new String[]{"rowid", json.getName()})
                         .queryParallel(this.executor,
                                        row -> {
                                            long rowId = row.getLong("rowid");
                                            Datum datum = (Datum) Json.parse(row.getString(json.getName()), Datum.class);
                                            return Tuples.$(Long.valueOf(rowId), datum);
                                        },
                                        "rowid"));
    }

    /**
     * Streams the data using {@code queryParallel} keyed on {@code rowid}; each datum is parsed from the first
     * column of the rows selected by {@link #select()}.
     *
     * @return a stream over the data
     */
    @Override // com.gengoai.apollo.ml.DataSet
    public MStream<Datum> parallelStream() {
        return StreamingContext.local().stream(
                select().queryParallel(this.executor,
                                       row -> (Datum) Json.parse(row.getString(1), Datum.class),
                                       "rowid"));
    }

    /**
     * Copies the database to the given resource with SQLite's {@code VACUUM main INTO} and opens the copy as a
     * new dataset.
     *
     * @param resource where to write the copy
     * @return a new {@link SQLiteDataSet} opened on {@code resource}
     * @throws NullPointerException if {@code resource} is null
     */
    @Override // com.gengoai.apollo.ml.DataSet
    public DataSet persist(@NonNull Resource resource) {
        if (resource == null) {
            throw new NullPointerException("copy is marked non-null but is null");
        }
        // The path is embedded in a single-quoted SQL string literal, so any single quote in it must be doubled;
        // otherwise a path such as "it's.db" would terminate the literal early and break (or alter) the statement.
        String escapedPath = resource.path().replace("'", "''");
        // NOTE(review): the constructor opens "/"-prefixed resource.path() while this writes to the raw path;
        // the two presumably agree only when path() already starts with "/" — confirm.
        SQL.update(String.format("VACUUM main INTO '%s'", escapedPath)).update(this.executor);
        return new SQLiteDataSet(resource);
    }

    /**
     * Adds all given metadata to this dataset (via the superclass) and writes each entry to the metadata table
     * with an insert-or-replace, the value serialized with {@code Json.dumps}.
     *
     * @param map the metadata to add, keyed by name
     * @return this dataset
     * @throws NullPointerException if {@code map} is null
     */
    @Override // com.gengoai.apollo.ml.DataSet
    public DataSet putAllMetadata(@NonNull Map<String, ObservationMetadata> map) {
        if (map == null) {
            throw new NullPointerException("metadata is marked non-null but is null");
        }
        super.putAllMetadata(map);
        metadataTable.insert(InsertType.INSERT_OR_REPLACE)
                     .batch(this.executor,
                            map.entrySet(),
                            (item, statement) -> {
                                statement.setString(name.getName(), (String) item.getKey());
                                statement.setObject(value.getName(), Json.dumps(item.getValue()));
                            });
        return this;
    }

    /**
     * Removes the metadata stored under the given name, both from this dataset (via the superclass) and from the
     * metadata table.
     *
     * @param str the name of the metadata entry to remove
     * @return this dataset
     * @throws NullPointerException if {@code str} is null
     */
    @Override // com.gengoai.apollo.ml.DataSet
    public DataSet removeMetadata(@NonNull String str) {
        if (str == null) {
            throw new NullPointerException("source is marked non-null but is null");
        }
        super.removeMetadata(str);
        // Delete the metadata row whose name column equals the given name.
        metadataTable.delete(this.executor, name.eq(SQL.L(str)));
        return this;
    }

    /**
     * Builds the select over all columns of the data table that backs the streaming methods; once the dataset
     * has been shuffled the rows are ordered by {@code SQL.F.random()}.
     */
    private Select select() {
        if (!this.isShuffled) {
            return dataTable.selectAll();
        }
        return dataTable.selectAll().orderBy(new SQLElement[]{SQL.F.random()});
    }

    /**
     * Sets the NDArray factory on this dataset (via the superclass) and stores its {@code name()} in the metadata
     * table under the key {@code ndArrayFactory} with an insert-or-replace.
     *
     * @param nDArrayFactory the factory to use
     * @return this dataset
     * @throws NullPointerException if {@code nDArrayFactory} is null
     */
    @Override // com.gengoai.apollo.ml.DataSet
    public DataSet setNDArrayFactory(@NonNull NDArrayFactory nDArrayFactory) {
        if (nDArrayFactory == null) {
            throw new NullPointerException("ndArrayFactory is marked non-null but is null");
        }
        super.setNDArrayFactory(nDArrayFactory);
        // The constructor reads this row back with NDArrayFactory.valueOf when an existing database is opened.
        metadataTable.insert(InsertType.INSERT_OR_REPLACE).values(new SQLElement[]{SQL.L("ndArrayFactory"), SQL.L(nDArrayFactory.name())}).update(this.executor);
        return this;
    }

    /**
     * Marks this dataset as shuffled so that select() orders the rows by {@code SQL.F.random()}. No data is
     * rewritten, and nothing in this class resets the flag afterwards.
     *
     * @param random not used by this implementation
     * @return this dataset
     */
    @Override // com.gengoai.apollo.ml.DataSet
    public DataSet shuffle(Random random) {
        this.isShuffled = true;
        return this;
    }

    /**
     * Returns the value stored in the {@code __size__} row of the metadata table. The constructor installs
     * triggers that increment / decrement that value after each insert into / delete from the data table.
     *
     * @return the stored row count
     */
    @Override // com.gengoai.apollo.ml.DataSet
    public long size() {
        return metadataTable.select(new SQLElement[]{value}).where(name.eq(SQL.L(SIZE_NAME))).queryScalarLong(this.executor).longValue();
    }

    /**
     * Streams the data; each datum is parsed from the first column of the rows selected by {@link #select()}.
     *
     * @return a stream over the data
     */
    @Override // com.gengoai.apollo.ml.DataSet
    public MStream<Datum> stream() {
        return StreamingContext.local().stream(
                select().query(this.executor,
                               row -> (Datum) Json.parse(row.getString(1), Datum.class)));
    }

    /**
     * Updates the metadata stored under the given name (via the superclass) and then writes the resulting
     * metadata, serialized with {@code Json.dumps}, to the metadata table with an insert-or-replace.
     *
     * @param str      the name of the metadata entry to update
     * @param consumer the update to apply
     * @return this dataset
     * @throws NullPointerException if either argument is null
     */
    @Override // com.gengoai.apollo.ml.DataSet
    public DataSet updateMetadata(@NonNull String str, @NonNull Consumer<ObservationMetadata> consumer) {
        if (str == null) {
            throw new NullPointerException("source is marked non-null but is null");
        }
        if (consumer == null) {
            throw new NullPointerException("updater is marked non-null but is null");
        }
        super.updateMetadata(str, consumer);
        // Persist the post-update state as read back through getMetadata(str).
        metadataTable.insert(InsertType.INSERT_OR_REPLACE).update(this.executor, Map.of(name.getName(), str, value.getName(), Json.dumps(getMetadata(str))));
        return this;
    }

    /*
     * Decompiler rendering of the synthetic lambda-deserialization hook of this class: it matches the
     * implementation-method name carried by the SerializedLambda against the serializable lambdas used in this
     * class, verifies the recorded interface / signature strings, and returns the corresponding lambda; when
     * nothing matches it throws IllegalArgumentException.
     *
     * NOTE(review): as written this is not valid Java source — `boolean z = -1`, a switch over a boolean with
     * repeated `case true:` labels, and `this` used inside a static method. Presumably the compiler regenerates
     * this method for the serializable lambdas above, so it likely should be dropped before recompiling — confirm.
     */
    private static /* synthetic */ Object $deserializeLambda$(SerializedLambda serializedLambda) {
        String implMethodName = serializedLambda.getImplMethodName();
        boolean z = -1;
        switch (implMethodName.hashCode()) {
            case -1452170807:
                if (implMethodName.equals("lambda$putAllMetadata$8f7f58b8$1")) {
                    z = 3;
                    break;
                }
                break;
            case -462043813:
                if (implMethodName.equals("lambda$stream$738f0283$1")) {
                    z = 6;
                    break;
                }
                break;
            case -11962916:
                if (implMethodName.equals("lambda$new$c28202cf$1")) {
                    z = true;
                    break;
                }
                break;
            case 662788894:
                if (implMethodName.equals("lambda$map$556e4c4$1")) {
                    z = 7;
                    break;
                }
                break;
            case 815903508:
                if (implMethodName.equals("lambda$parallelStream$738f0283$1")) {
                    z = 2;
                    break;
                }
                break;
            case 1463945740:
                if (implMethodName.equals("lambda$new$dd347367$1")) {
                    z = 5;
                    break;
                }
                break;
            case 1984618359:
                if (implMethodName.equals("lambda$new$6b3078fe$1")) {
                    z = 4;
                    break;
                }
                break;
            case 1986514436:
                if (implMethodName.equals("lambda$parallelIdStream$a6df1e07$1")) {
                    z = false;
                    break;
                }
                break;
        }
        switch (z) {
            case false:
                if (serializedLambda.getImplMethodKind() == 6 && serializedLambda.getFunctionalInterfaceClass().equals("com/gengoai/sql/ResultSetMapper") && serializedLambda.getFunctionalInterfaceMethodName().equals("map") && serializedLambda.getFunctionalInterfaceMethodSignature().equals("(Ljava/sql/ResultSet;)Ljava/lang/Object;") && serializedLambda.getImplClass().equals("com/gengoai/apollo/ml/SQLiteDataSet") && serializedLambda.getImplMethodSignature().equals("(Ljava/sql/ResultSet;)Lcom/gengoai/tuple/Tuple2;")) {
                    return resultSet -> {
                        long j = resultSet.getLong("rowid");
                        return Tuples.$(Long.valueOf(j), (Datum) Json.parse(resultSet.getString(json.getName()), Datum.class));
                    };
                }
                break;
            case true:
                if (serializedLambda.getImplMethodKind() == 6 && serializedLambda.getFunctionalInterfaceClass().equals("com/gengoai/sql/ResultSetMapper") && serializedLambda.getFunctionalInterfaceMethodName().equals("map") && serializedLambda.getFunctionalInterfaceMethodSignature().equals("(Ljava/sql/ResultSet;)Ljava/lang/Object;") && serializedLambda.getImplClass().equals("com/gengoai/apollo/ml/SQLiteDataSet") && serializedLambda.getImplMethodSignature().equals("(Ljava/sql/ResultSet;)Ljava/util/Map;")) {
                    return resultSet2 -> {
                        return Map.of(resultSet2.getString(name.getName()), resultSet2.getObject(value.getName()));
                    };
                }
                break;
            case true:
                if (serializedLambda.getImplMethodKind() == 6 && serializedLambda.getFunctionalInterfaceClass().equals("com/gengoai/sql/ResultSetMapper") && serializedLambda.getFunctionalInterfaceMethodName().equals("map") && serializedLambda.getFunctionalInterfaceMethodSignature().equals("(Ljava/sql/ResultSet;)Ljava/lang/Object;") && serializedLambda.getImplClass().equals("com/gengoai/apollo/ml/SQLiteDataSet") && serializedLambda.getImplMethodSignature().equals("(Ljava/sql/ResultSet;)Lcom/gengoai/apollo/ml/Datum;")) {
                    return resultSet3 -> {
                        return (Datum) Json.parse(resultSet3.getString(1), Datum.class);
                    };
                }
                break;
            case true:
                if (serializedLambda.getImplMethodKind() == 6 && serializedLambda.getFunctionalInterfaceClass().equals("com/gengoai/function/CheckedBiConsumer") && serializedLambda.getFunctionalInterfaceMethodName().equals("accept") && serializedLambda.getFunctionalInterfaceMethodSignature().equals("(Ljava/lang/Object;Ljava/lang/Object;)V") && serializedLambda.getImplClass().equals("com/gengoai/apollo/ml/SQLiteDataSet") && serializedLambda.getImplMethodSignature().equals("(Ljava/util/Map$Entry;Lcom/gengoai/sql/NamedPreparedStatement;)V")) {
                    return (entry, namedPreparedStatement) -> {
                        namedPreparedStatement.setString(name.getName(), (String) entry.getKey());
                        namedPreparedStatement.setObject(value.getName(), Json.dumps(entry.getValue()));
                    };
                }
                break;
            case true:
                if (serializedLambda.getImplMethodKind() == 7 && serializedLambda.getFunctionalInterfaceClass().equals("com/gengoai/function/CheckedBiConsumer") && serializedLambda.getFunctionalInterfaceMethodName().equals("accept") && serializedLambda.getFunctionalInterfaceMethodSignature().equals("(Ljava/lang/Object;Ljava/lang/Object;)V") && serializedLambda.getImplClass().equals("com/gengoai/apollo/ml/SQLiteDataSet") && serializedLambda.getImplMethodSignature().equals("(Ljava/lang/String;Ljava/lang/Object;)V")) {
                    SQLiteDataSet sQLiteDataSet = (SQLiteDataSet) serializedLambda.getCapturedArg(0);
                    return (str, obj) -> {
                        if (str.equals("ndArrayFactory")) {
                            this.ndArrayFactory = NDArrayFactory.valueOf(obj.toString());
                        } else {
                            this.metadata.put(str, (ObservationMetadata) Json.parse(obj.toString(), ObservationMetadata.class));
                        }
                    };
                }
                break;
            case true:
                if (serializedLambda.getImplMethodKind() == 6 && serializedLambda.getFunctionalInterfaceClass().equals("com/gengoai/function/CheckedBiConsumer") && serializedLambda.getFunctionalInterfaceMethodName().equals("accept") && serializedLambda.getFunctionalInterfaceMethodSignature().equals("(Ljava/lang/Object;Ljava/lang/Object;)V") && serializedLambda.getImplClass().equals("com/gengoai/apollo/ml/SQLiteDataSet") && serializedLambda.getImplMethodSignature().equals("(Lcom/gengoai/apollo/ml/Datum;Lcom/gengoai/sql/NamedPreparedStatement;)V")) {
                    return (datum, namedPreparedStatement2) -> {
                        namedPreparedStatement2.setString(json.getName(), Json.dumps(datum));
                    };
                }
                break;
            case true:
                if (serializedLambda.getImplMethodKind() == 6 && serializedLambda.getFunctionalInterfaceClass().equals("com/gengoai/sql/ResultSetMapper") && serializedLambda.getFunctionalInterfaceMethodName().equals("map") && serializedLambda.getFunctionalInterfaceMethodSignature().equals("(Ljava/sql/ResultSet;)Ljava/lang/Object;") && serializedLambda.getImplClass().equals("com/gengoai/apollo/ml/SQLiteDataSet") && serializedLambda.getImplMethodSignature().equals("(Ljava/sql/ResultSet;)Lcom/gengoai/apollo/ml/Datum;")) {
                    return resultSet4 -> {
                        return (Datum) Json.parse(resultSet4.getString(1), Datum.class);
                    };
                }
                break;
            case true:
                if (serializedLambda.getImplMethodKind() == 7 && serializedLambda.getFunctionalInterfaceClass().equals("com/gengoai/function/CheckedConsumer") && serializedLambda.getFunctionalInterfaceMethodName().equals("accept") && serializedLambda.getFunctionalInterfaceMethodSignature().equals("(Ljava/lang/Object;)V") && serializedLambda.getImplClass().equals("com/gengoai/apollo/ml/SQLiteDataSet") && serializedLambda.getImplMethodSignature().equals("(Lcom/gengoai/function/SerializableFunction;Lcom/gengoai/sql/NamedPreparedStatement;Ljava/util/concurrent/atomic/AtomicLong;Ljava/sql/Connection;Lcom/gengoai/tuple/Tuple2;)V")) {
                    SQLiteDataSet sQLiteDataSet2 = (SQLiteDataSet) serializedLambda.getCapturedArg(0);
                    SerializableFunction serializableFunction = (SerializableFunction) serializedLambda.getCapturedArg(1);
                    NamedPreparedStatement namedPreparedStatement3 = (NamedPreparedStatement) serializedLambda.getCapturedArg(2);
                    AtomicLong atomicLong = (AtomicLong) serializedLambda.getCapturedArg(3);
                    Connection connection = (Connection) serializedLambda.getCapturedArg(4);
                    return tuple2 -> {
                        long longValue = ((Long) tuple2.v1).longValue();
                        Datum datum2 = (Datum) serializableFunction.apply(tuple2.v2);
                        synchronized (this) {
                            namedPreparedStatement3.setObject("json", Json.dumps(datum2));
                            namedPreparedStatement3.setLong("rowid", longValue);
                            namedPreparedStatement3.addBatch();
                            if (atomicLong.incrementAndGet() % SQLContext.DEFAULT_BATCH_SIZE == 0) {
                                namedPreparedStatement3.executeBatch();
                                connection.commit();
                                atomicLong.set(0L);
                            }
                        }
                    };
                }
                break;
        }
        throw new IllegalArgumentException("Invalid lambda deserialization");
    }
}
