/*
 * Decompiled with CFR 0.152.
 */
package io.trino.plugin.geospatial;

import com.google.common.collect.ImmutableMap;
import com.google.common.io.Resources;
import io.trino.Session;
import io.trino.jmh.Benchmarks;
import io.trino.metadata.Metadata;
import io.trino.metadata.QualifiedObjectName;
import io.trino.metadata.TableHandle;
import io.trino.plugin.geospatial.GeoPlugin;
import io.trino.plugin.memory.MemoryConnectorFactory;
import io.trino.spi.Plugin;
import io.trino.spi.connector.ConnectorFactory;
import io.trino.testing.LocalQueryRunner;
import io.trino.testing.MaterializedResult;
import io.trino.testing.TestingSession;
import java.io.File;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.BenchmarkMode;
import org.openjdk.jmh.annotations.Fork;
import org.openjdk.jmh.annotations.Level;
import org.openjdk.jmh.annotations.Measurement;
import org.openjdk.jmh.annotations.Mode;
import org.openjdk.jmh.annotations.OutputTimeUnit;
import org.openjdk.jmh.annotations.Param;
import org.openjdk.jmh.annotations.Scope;
import org.openjdk.jmh.annotations.Setup;
import org.openjdk.jmh.annotations.State;
import org.openjdk.jmh.annotations.TearDown;
import org.openjdk.jmh.annotations.Warmup;
import org.testng.Assert;
import org.testng.annotations.Test;

/**
 * JMH benchmark of spatial joins between a table of random points and a table of polygons
 * loaded from the {@code us-states.tsv} resource, both stored in the {@code memory} connector.
 *
 * <p>{@link #benchmarkJoin} builds both geometries inside the {@code ST_Contains} predicate,
 * while {@link #benchmarkUserOptimizedJoin} builds them in subqueries before joining.
 */
@State(Scope.Thread)
@OutputTimeUnit(TimeUnit.MILLISECONDS)
@BenchmarkMode(Mode.AverageTime)
@Fork(3)
@Warmup(iterations = 10)
@Measurement(iterations = 10)
public class BenchmarkSpatialJoin {
    /**
     * Counts (point, polygon) pairs where the polygon contains the point, constructing the
     * geometries directly in the join predicate.
     *
     * @param context benchmark state holding the query runner and tables
     * @return the single-row {@code count(*)} result
     */
    @Benchmark
    public MaterializedResult benchmarkJoin(Context context) {
        // ST_Point takes (x, y): the points table derives longitude from the x range and
        // latitude from the y range, so longitude must come first.
        return context.getQueryRunner().execute("SELECT count(*) FROM points, polygons "
                + "WHERE ST_Contains(ST_GeometryFromText(wkt), ST_Point(longitude, latitude))");
    }

    /**
     * Same count as {@link #benchmarkJoin}, but the point and polygon geometries are
     * projected in subqueries and the join predicate only references the projected columns.
     *
     * @param context benchmark state holding the query runner and tables
     * @return the single-row {@code count(*)} result
     */
    @Benchmark
    public MaterializedResult benchmarkUserOptimizedJoin(Context context) {
        // ST_Point takes (x, y) = (longitude, latitude); see benchmarkJoin.
        return context.getQueryRunner().execute("SELECT count(*) FROM "
                + "(SELECT ST_Point(longitude, latitude) as point FROM points) t1, "
                + "(SELECT ST_GeometryFromText(wkt) as geometry FROM polygons) t2 "
                + "WHERE ST_Contains(geometry, point)");
    }

    /**
     * Smoke test that sets up the benchmark state by hand and runs each benchmark query once.
     */
    @Test
    public void verify() throws Exception {
        Context context = new Context();
        // @Param values are only injected by JMH; outside JMH the field would stay 0,
        // so use the smallest configured size.
        context.pointCount = 10;
        try {
            context.setUp();
            context.createPointsTable();
            BenchmarkSpatialJoin benchmark = new BenchmarkSpatialJoin();
            benchmark.benchmarkJoin(context);
            benchmark.benchmarkUserOptimizedJoin(context);
        }
        finally {
            // tearDown() tolerates a runner that was never created, so a failure inside
            // setUp() is reported instead of being masked by a NullPointerException.
            context.tearDown();
        }
    }

    public static void main(String[] args) throws Exception {
        new BenchmarkSpatialJoin().verify();
        Benchmarks.benchmark(BenchmarkSpatialJoin.class).run();
    }

    /**
     * Per-thread benchmark state: a {@link LocalQueryRunner} with the geospatial plugin and a
     * {@code memory} catalog. The {@code polygons} table is created once in {@link #setUp()};
     * the {@code points} table is recreated around every benchmark invocation.
     */
    @State(Scope.Thread)
    public static class Context {
        private LocalQueryRunner queryRunner;

        /** Number of random points inserted into {@code memory.default.points}. */
        @Param({"10", "100", "1000", "10000"})
        private int pointCount;

        public LocalQueryRunner getQueryRunner() {
            return queryRunner;
        }

        /**
         * Creates the query runner, installs the geospatial plugin, registers the
         * {@code memory} catalog and loads the polygons table from {@code us-states.tsv}.
         *
         * @throws Exception if the resource cannot be located or read
         */
        @Setup
        public void setUp() throws Exception {
            Session session = TestingSession.testSessionBuilder()
                    .setCatalog("memory")
                    .setSchema("default")
                    .build();
            queryRunner = LocalQueryRunner.create(session);

            Plugin geoPlugin = new GeoPlugin();
            queryRunner.installPlugin(geoPlugin);

            ConnectorFactory memoryConnectorFactory = new MemoryConnectorFactory();
            Map<String, String> catalogProperties = ImmutableMap.of();
            queryRunner.createCatalog("memory", memoryConnectorFactory, catalogProperties);

            String polygonValues = readPolygonValues();
            queryRunner.execute(String.format(
                    "CREATE TABLE memory.default.polygons AS SELECT * FROM (VALUES %s) as t (name, wkt)",
                    polygonValues));
        }

        /** Reads {@code us-states.tsv} and renders its non-empty lines as comma-joined SQL VALUES rows. */
        private static String readPolygonValues() throws Exception {
            Path path = new File(Resources.getResource("us-states.tsv").toURI()).toPath();
            try (Stream<String> lines = Files.lines(path)) {
                return lines
                        .filter(line -> !line.isEmpty())
                        .map(Context::toValuesRow)
                        .collect(Collectors.joining(","));
            }
        }

        /** Converts one tab-separated {@code name, wkt} line into a {@code ('name', 'wkt')} row. */
        private static String toValuesRow(String line) {
            String[] parts = line.split("\t");
            if (parts.length < 2) {
                throw new IllegalArgumentException("Expected a tab-separated name and WKT, got: " + line);
            }
            return String.format("(%s, %s)", toSqlLiteral(parts[0]), toSqlLiteral(parts[1]));
        }

        /** Quotes a value as a SQL string literal, doubling any embedded single quotes. */
        private static String toSqlLiteral(String value) {
            return "'" + value.replace("'", "''") + "'";
        }

        /**
         * Creates {@code memory.default.points} with {@link #pointCount} rows whose longitude is
         * drawn with {@code random()} between -124 and -65 and latitude between 27 and 49.
         */
        @Setup(Level.Invocation)
        public void createPointsTable() {
            queryRunner.execute(String.format("CREATE TABLE memory.default.points AS "
                    + "SELECT 'p' || cast(elem AS VARCHAR) as name, "
                    + "xMin + (xMax - xMin) * random() as longitude, "
                    + "yMin + (yMax - yMin) * random() as latitude "
                    + "FROM (SELECT -124 AS xMin, -65 AS xMax, 27 AS yMin, 49 AS yMax) "
                    + "CROSS JOIN UNNEST(sequence(1, %d)) AS t(elem)", pointCount));
        }

        /**
         * Drops {@code memory.default.points} inside a transaction, failing the assertion if the
         * table is missing.
         */
        @TearDown(Level.Invocation)
        public void dropPointsTable() {
            queryRunner.inTransaction(queryRunner.getDefaultSession(), transactionSession -> {
                Metadata metadata = queryRunner.getMetadata();
                Optional<TableHandle> tableHandle = metadata.getTableHandle(
                        transactionSession, QualifiedObjectName.valueOf("memory.default.points"));
                Assert.assertTrue(tableHandle.isPresent(), "Table memory.default.points does not exist");
                metadata.dropTable(transactionSession, tableHandle.get());
                return null;
            });
        }

        /** Closes the query runner if one was created; safe to call more than once. */
        @TearDown
        public void tearDown() {
            if (queryRunner != null) {
                queryRunner.close();
                queryRunner = null;
            }
        }
    }
}

