/*
* SPDX-FileCopyrightText: 2023 The Refinery Authors
*
* SPDX-License-Identifier: EPL-2.0
*/
package tools.refinery.store.query.interpreter;
import org.junit.jupiter.api.Test;
import tools.refinery.store.model.Model;
import tools.refinery.store.model.ModelStore;
import tools.refinery.store.query.ModelQueryAdapter;
import tools.refinery.logic.dnf.Query;
import tools.refinery.logic.term.StatefulAggregate;
import tools.refinery.logic.term.StatefulAggregator;
import tools.refinery.logic.term.Variable;
import tools.refinery.store.query.view.AnySymbolView;
import tools.refinery.store.query.view.FunctionView;
import tools.refinery.store.query.view.KeyOnlyView;
import tools.refinery.store.representation.Symbol;
import tools.refinery.store.tuple.Tuple;
import java.util.Map;
import java.util.Optional;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.is;
import static tools.refinery.store.query.interpreter.tests.QueryAssertions.assertNullableResults;
/**
 * Checks how many times the query engine extracts the result of a stateful aggregate when model
 * changes are flushed together in one batch versus in several separate flushes.
 *
 * <p>The aggregate used by {@link #query} is instrumented: every call to
 * {@link InstrumentedAggregate#getResult()} increments {@link #extractCount}, and the tests assert
 * the exact value of that counter after each flush.
 */
class AggregatorBatchingTest {
    private static final Symbol<Boolean> person = Symbol.of("Person", 1);
    private static final Symbol<Integer> values = Symbol.of("values", 2, Integer.class, null);
    private static final AnySymbolView personView = new KeyOnlyView<>(person);
    private static final FunctionView<Integer> valuesView = new FunctionView<>(values);

    /**
     * For every {@code p1} matched by {@code personView}, assigns to the output the aggregate of
     * {@code valuesView} over {@code p1} and an anonymous second argument.
     *
     * <p>This is an instance field (not a constant) because the aggregator it creates is an inner
     * class that reports back to this test instance's {@link #extractCount}.
     */
    private final Query<Integer> query = Query.of(Integer.class, (builder, p1, output) -> builder
            .clause(
                    personView.call(p1),
                    output.assign(valuesView.aggregate(new InstrumentedAggregator(), p1, Variable.of()))
            ));

    /** Number of {@link InstrumentedAggregate#getResult()} calls observed by this test instance. */
    private int extractCount = 0;

    /**
     * Writes both persons and all five values before a single flush and expects the counter to go
     * from 1 to 5, i.e. four extractions for the whole batch.
     */
    @Test
    void batchTest() {
        var model = createModel();
        var personInterpretation = model.getInterpretation(person);
        var valuesInterpretation = model.getInterpretation(values);
        var queryEngine = model.getAdapter(ModelQueryAdapter.class);
        var resultSet = queryEngine.getResultSet(query);

        // One extraction is already expected before any model change.
        assertThat(extractCount, is(1));

        personInterpretation.put(Tuple.of(0), true);
        personInterpretation.put(Tuple.of(1), true);

        valuesInterpretation.put(Tuple.of(0, 0), 1);
        valuesInterpretation.put(Tuple.of(0, 1), 2);
        valuesInterpretation.put(Tuple.of(0, 2), 3);
        valuesInterpretation.put(Tuple.of(1, 0), 1);
        valuesInterpretation.put(Tuple.of(1, 1), -1);

        queryEngine.flushChanges();

        assertThat(extractCount, is(5));
        assertNullableResults(Map.of(
                Tuple.of(0), Optional.of(6),
                Tuple.of(1), Optional.of(0),
                Tuple.of(2), Optional.empty()
        ), resultSet);
    }

    /**
     * Applies the same final model state as {@link #batchTest()}, but spread over four flushes,
     * and expects the counter to reach 11 instead of 5.
     */
    @Test
    void separateTest() {
        var model = createModel();
        var personInterpretation = model.getInterpretation(person);
        var valuesInterpretation = model.getInterpretation(values);
        var queryEngine = model.getAdapter(ModelQueryAdapter.class);
        var resultSet = queryEngine.getResultSet(query);

        // One extraction is already expected before any model change.
        assertThat(extractCount, is(1));

        // Flush 1: only the persons are added (counter 1 -> 3).
        personInterpretation.put(Tuple.of(0), true);
        personInterpretation.put(Tuple.of(1), true);

        queryEngine.flushChanges();

        assertThat(extractCount, is(3));

        // Flush 2: one value per person (counter 3 -> 5).
        valuesInterpretation.put(Tuple.of(0, 0), 1);
        valuesInterpretation.put(Tuple.of(1, 0), 1);

        queryEngine.flushChanges();

        assertThat(extractCount, is(5));
        assertNullableResults(Map.of(
                Tuple.of(0), Optional.of(1),
                Tuple.of(1), Optional.of(1),
                Tuple.of(2), Optional.empty()
        ), resultSet);

        // Flush 3: a second value per person (counter 5 -> 9).
        valuesInterpretation.put(Tuple.of(0, 1), 2);
        valuesInterpretation.put(Tuple.of(1, 1), -1);

        queryEngine.flushChanges();

        assertThat(extractCount, is(9));
        assertNullableResults(Map.of(
                Tuple.of(0), Optional.of(3),
                Tuple.of(1), Optional.of(0),
                Tuple.of(2), Optional.empty()
        ), resultSet);

        // Flush 4: a third value for person 0 only (counter 9 -> 11).
        valuesInterpretation.put(Tuple.of(0, 2), 3);

        queryEngine.flushChanges();

        assertThat(extractCount, is(11));
        assertNullableResults(Map.of(
                Tuple.of(0), Optional.of(6),
                Tuple.of(1), Optional.of(0),
                Tuple.of(2), Optional.empty()
        ), resultSet);
    }

    /**
     * Builds a model store containing both symbols and a query interpreter with {@link #query}
     * registered, then creates an empty model from it.
     *
     * @return a new empty model backed by the freshly built store
     */
    private Model createModel() {
        var store = ModelStore.builder()
                .symbols(person, values)
                .with(QueryInterpreterAdapter.builder()
                        .query(query))
                .build();
        return store.createEmptyModel();
    }

    /**
     * Aggregator over {@code Integer} inputs and results that hands out
     * {@link InstrumentedAggregate} instances.
     *
     * <p>Kept as a non-static inner class because the aggregates it creates update
     * {@link #extractCount} on the enclosing test instance.
     */
    class InstrumentedAggregator implements StatefulAggregator<Integer, Integer> {
        @Override
        public Class<Integer> getResultType() {
            return Integer.class;
        }

        @Override
        public Class<Integer> getInputType() {
            return Integer.class;
        }

        @Override
        public StatefulAggregate<Integer, Integer> createEmptyAggregate() {
            return new InstrumentedAggregate();
        }
    }

    /**
     * Running integer sum that counts every result extraction in {@link #extractCount}.
     */
    class InstrumentedAggregate implements StatefulAggregate<Integer, Integer> {
        private int sum;

        /** Creates an aggregate whose sum starts at zero. */
        public InstrumentedAggregate() {
            this(0);
        }

        /** Creates an aggregate with the given starting sum; used by {@link #deepCopy()}. */
        private InstrumentedAggregate(int sum) {
            this.sum = sum;
        }

        @Override
        public void add(Integer value) {
            sum += value;
        }

        @Override
        public void remove(Integer value) {
            sum -= value;
        }

        /**
         * Returns the current sum and records the extraction.
         *
         * @return the current sum
         */
        @Override
        public Integer getResult() {
            // Side effect on the enclosing test instance: this is the instrumentation.
            extractCount++;
            return sum;
        }

        @Override
        public boolean isEmpty() {
            // NOTE(review): this reports "empty" whenever the sum is zero, which also happens
            // after adding values that cancel out (e.g. 1 and -1). The expected extraction counts
            // in the tests may depend on this -- confirm before changing it.
            return sum == 0;
        }

        /**
         * Copies the current sum into a new aggregate that reports to the same test instance.
         *
         * @return an independent aggregate with the same sum
         */
        @Override
        public StatefulAggregate<Integer, Integer> deepCopy() {
            return new InstrumentedAggregate(sum);
        }
    }
}