package org.apache.rya.mongodb.aggregation;

import com.google.common.collect.Sets;
import com.mongodb.MongoNamespace;
import com.mongodb.client.MongoCollection;
import java.util.Arrays;
import java.util.List;
import org.bson.Document;
import org.eclipse.rdf4j.model.IRI;
import org.eclipse.rdf4j.model.ValueFactory;
import org.eclipse.rdf4j.model.impl.SimpleValueFactory;
import org.eclipse.rdf4j.model.vocabulary.OWL;
import org.eclipse.rdf4j.model.vocabulary.RDF;
import org.eclipse.rdf4j.query.algebra.Extension;
import org.eclipse.rdf4j.query.algebra.ExtensionElem;
import org.eclipse.rdf4j.query.algebra.Join;
import org.eclipse.rdf4j.query.algebra.MultiProjection;
import org.eclipse.rdf4j.query.algebra.Not;
import org.eclipse.rdf4j.query.algebra.Projection;
import org.eclipse.rdf4j.query.algebra.ProjectionElem;
import org.eclipse.rdf4j.query.algebra.ProjectionElemList;
import org.eclipse.rdf4j.query.algebra.QueryRoot;
import org.eclipse.rdf4j.query.algebra.StatementPattern;
import org.eclipse.rdf4j.query.algebra.ValueConstant;
import org.eclipse.rdf4j.query.algebra.Var;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.mockito.Mockito;

/**
 * Unit tests for {@link SparqlToPipelineTransformVisitor}.
 *
 * <p>Each test builds a small SPARQL algebra tree under a {@link QueryRoot}. It runs the visitor
 * over that tree against a mocked {@link MongoCollection}. It then checks which nodes were replaced
 * by {@link AggregationPipelineQueryNode}s and which binding names the resulting nodes report.
 */
public class SparqlToPipelineTransformVisitorTest {
    private static final ValueFactory VF = SimpleValueFactory.getInstance();
    private static final String LUBM = "urn:lubm";
    private static final IRI UNDERGRAD = VF.createIRI(LUBM, "UndergraduateStudent");
    private static final IRI PROFESSOR = VF.createIRI(LUBM, "Professor");
    private static final IRI COURSE = VF.createIRI(LUBM, "Course");
    private static final IRI TAKES = VF.createIRI(LUBM, "takesCourse");
    private static final IRI TEACHES = VF.createIRI(LUBM, "teachesCourse");

    /** Mocked collection handed to the visitor; recreated before every test. */
    private MongoCollection<Document> collection;

    /** Wraps {@code iri} in a constant {@link Var} named after the IRI's string form. */
    private static Var constant(IRI iri) {
        return new Var(iri.stringValue(), iri);
    }

    /**
     * Creates a fresh mock collection whose only stubbed method is {@code getNamespace()}, which
     * returns database "db" and collection "collection".
     */
    @Before
    @SuppressWarnings("unchecked") // Mockito.mock(Class) can only return the raw MongoCollection type.
    public void setUp() {
        collection = Mockito.mock(MongoCollection.class);
        Mockito.when(collection.getNamespace()).thenReturn(new MongoNamespace("db", "collection"));
    }

    /** Runs a new {@link SparqlToPipelineTransformVisitor} over {@code queryRoot}, in place. */
    private void applyVisitor(QueryRoot queryRoot) throws Exception {
        queryRoot.visit(new SparqlToPipelineTransformVisitor(collection));
    }

    @Test
    public void testStatementPattern() throws Exception {
        QueryRoot queryRoot = new QueryRoot(
                new StatementPattern(new Var("x"), constant(RDF.TYPE), constant(UNDERGRAD)));
        applyVisitor(queryRoot);
        Assert.assertTrue(queryRoot.getArg() instanceof AggregationPipelineQueryNode);
        Assert.assertEquals(Sets.newHashSet("x"), queryRoot.getArg().getAssuredBindingNames());
    }

    @Test
    public void testJoin() throws Exception {
        QueryRoot queryRoot = new QueryRoot(new Join(
                new StatementPattern(new Var("x"), constant(RDF.TYPE), constant(UNDERGRAD)),
                new StatementPattern(new Var("x"), constant(TAKES), new Var("course"))));
        applyVisitor(queryRoot);
        Assert.assertTrue(queryRoot.getArg() instanceof AggregationPipelineQueryNode);
        Assert.assertEquals(Sets.newHashSet("x", "course"), queryRoot.getArg().getAssuredBindingNames());
    }

    @Test
    public void testNestedJoins() throws Exception {
        QueryRoot queryRoot = new QueryRoot(new Join(
                new StatementPattern(new Var("y"), constant(RDF.TYPE), constant(PROFESSOR)),
                new Join(
                        new Join(
                                new StatementPattern(new Var("x"), constant(RDF.TYPE), constant(UNDERGRAD)),
                                new StatementPattern(new Var("x"), constant(TAKES), new Var("c"))),
                        new StatementPattern(new Var("y"), constant(TEACHES), new Var("c")))));
        applyVisitor(queryRoot);
        Assert.assertTrue(queryRoot.getArg() instanceof AggregationPipelineQueryNode);
        Assert.assertEquals(Sets.newHashSet("x", "y", "c"), queryRoot.getArg().getAssuredBindingNames());
    }

    @Test
    public void testComplexJoin() throws Exception {
        QueryRoot queryRoot = new QueryRoot(new Join(
                new Join(
                        new StatementPattern(new Var("x"), constant(RDF.TYPE), constant(UNDERGRAD)),
                        new StatementPattern(new Var("x"), constant(TAKES), new Var("c"))),
                new Join(
                        new StatementPattern(new Var("y"), constant(RDF.TYPE), constant(PROFESSOR)),
                        new StatementPattern(new Var("y"), constant(TEACHES), new Var("c")))));
        applyVisitor(queryRoot);
        // Expected shape: the outer Join survives, and each sub-join becomes its own pipeline node.
        Assert.assertTrue(queryRoot.getArg() instanceof Join);
        Join join = (Join) queryRoot.getArg();
        Assert.assertTrue(join.getLeftArg() instanceof AggregationPipelineQueryNode);
        Assert.assertTrue(join.getRightArg() instanceof AggregationPipelineQueryNode);
        AggregationPipelineQueryNode left = (AggregationPipelineQueryNode) join.getLeftArg();
        AggregationPipelineQueryNode right = (AggregationPipelineQueryNode) join.getRightArg();
        Assert.assertEquals(Sets.newHashSet("x", "c"), left.getAssuredBindingNames());
        Assert.assertEquals(Sets.newHashSet("y", "c"), right.getAssuredBindingNames());
    }

    @Test
    public void testProjection() throws Exception {
        QueryRoot queryRoot = new QueryRoot(new Projection(
                new Join(
                        new Join(
                                new StatementPattern(new Var("course"), constant(RDF.TYPE), constant(COURSE)),
                                new StatementPattern(new Var("x"), new Var("p"), new Var("course"))),
                        new StatementPattern(new Var("x"), constant(RDF.TYPE), constant(UNDERGRAD))),
                new ProjectionElemList(
                        new ProjectionElem("p", "relation"),
                        new ProjectionElem("course"))));
        applyVisitor(queryRoot);
        Assert.assertTrue(queryRoot.getArg() instanceof AggregationPipelineQueryNode);
        // "p" is projected under the new name "relation"; "course" keeps its name.
        Assert.assertEquals(Sets.newHashSet("relation", "course"), queryRoot.getArg().getAssuredBindingNames());
    }

    @Test
    public void testEmptyProjection() throws Exception {
        QueryRoot queryRoot = new QueryRoot(new Projection(
                new StatementPattern(constant(UNDERGRAD), constant(RDF.TYPE), constant(OWL.CLASS)),
                new ProjectionElemList()));
        applyVisitor(queryRoot);
        // Expected shape: the empty Projection stays, and only its argument is converted.
        Assert.assertTrue(queryRoot.getArg() instanceof Projection);
        Projection projection = (Projection) queryRoot.getArg();
        Assert.assertTrue(projection.getArg() instanceof AggregationPipelineQueryNode);
        Assert.assertEquals(Sets.newHashSet(), projection.getArg().getAssuredBindingNames());
    }

    @Test
    public void testMultiProjection() throws Exception {
        QueryRoot queryRoot = new QueryRoot(new MultiProjection(
                new Join(
                        new Join(
                                new StatementPattern(new Var("course"), constant(RDF.TYPE), constant(COURSE)),
                                new StatementPattern(new Var("x"), new Var("p"), new Var("course"))),
                        new StatementPattern(new Var("x"), constant(RDF.TYPE), constant(UNDERGRAD))),
                Arrays.asList(
                        new ProjectionElemList(
                                new ProjectionElem("p", "relation"),
                                new ProjectionElem("course")),
                        new ProjectionElemList(
                                new ProjectionElem("p", "relation"),
                                new ProjectionElem("x", "student")))));
        applyVisitor(queryRoot);
        Assert.assertTrue(queryRoot.getArg() instanceof AggregationPipelineQueryNode);
        AggregationPipelineQueryNode pipeline = (AggregationPipelineQueryNode) queryRoot.getArg();
        // Only "relation" appears in every projection list; the others are possible but not assured.
        Assert.assertEquals(Sets.newHashSet("relation"), pipeline.getAssuredBindingNames());
        Assert.assertEquals(Sets.newHashSet("relation", "course", "student"), pipeline.getBindingNames());
    }

    @Test
    public void testExtension() throws Exception {
        QueryRoot queryRoot = new QueryRoot(new Extension(
                new StatementPattern(new Var("x"), constant(TAKES), new Var("c")),
                new ExtensionElem(new Var("x"), "renamed"),
                new ExtensionElem(new ValueConstant(TAKES), "constant")));
        applyVisitor(queryRoot);
        Assert.assertTrue(queryRoot.getArg() instanceof AggregationPipelineQueryNode);
        Assert.assertEquals(Sets.newHashSet("x", "c", "renamed", "constant"),
                queryRoot.getArg().getAssuredBindingNames());
    }

    @Test
    public void testUnsupportedExtension() throws Exception {
        StatementPattern statementPattern = new StatementPattern(new Var("x"), constant(TAKES), new Var("c"));
        // Unlike testExtension, this list also contains a Not(...) expression element.
        List<ExtensionElem> elements = Arrays.asList(
                new ExtensionElem(new Var("x"), "renamed"),
                new ExtensionElem(new Not(new ValueConstant(VF.createLiteral(true))), "notTrue"),
                new ExtensionElem(new ValueConstant(TAKES), "constant"));
        QueryRoot queryRoot = new QueryRoot(new Extension(statementPattern, elements));
        applyVisitor(queryRoot);
        // Expected shape: the Extension node and its elements are kept; only its argument is converted.
        Assert.assertTrue(queryRoot.getArg() instanceof Extension);
        Extension extension = (Extension) queryRoot.getArg();
        Assert.assertEquals(elements, extension.getElements());
        Assert.assertTrue(extension.getArg() instanceof AggregationPipelineQueryNode);
        AggregationPipelineQueryNode pipeline = (AggregationPipelineQueryNode) extension.getArg();
        Assert.assertEquals(Sets.newHashSet("x", "c"), pipeline.getAssuredBindingNames());
    }
}
