aboutsummaryrefslogtreecommitdiffstats
path: root/subprojects/store-query-viatra/src/main/java/tools/refinery/store/query/viatra/internal/pquery/DNF2PQuery.java
blob: e3c586a07d13c2955912a67be186da84f666e099 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
package tools.refinery.store.query.viatra.internal.pquery;

import org.eclipse.viatra.query.runtime.matchers.psystem.PBody;
import org.eclipse.viatra.query.runtime.matchers.psystem.PVariable;
import org.eclipse.viatra.query.runtime.matchers.psystem.basicdeferred.Equality;
import org.eclipse.viatra.query.runtime.matchers.psystem.basicdeferred.ExportedParameter;
import org.eclipse.viatra.query.runtime.matchers.psystem.basicdeferred.Inequality;
import org.eclipse.viatra.query.runtime.matchers.psystem.basicdeferred.NegativePatternCall;
import org.eclipse.viatra.query.runtime.matchers.psystem.basicenumerables.BinaryTransitiveClosure;
import org.eclipse.viatra.query.runtime.matchers.psystem.basicenumerables.PositivePatternCall;
import org.eclipse.viatra.query.runtime.matchers.psystem.basicenumerables.TypeConstraint;
import org.eclipse.viatra.query.runtime.matchers.psystem.queries.PParameter;
import org.eclipse.viatra.query.runtime.matchers.tuple.Tuples;
import tools.refinery.store.query.*;
import tools.refinery.store.query.atom.DNFAtom;
import tools.refinery.store.query.atom.EquivalenceAtom;
import tools.refinery.store.query.atom.DNFCallAtom;
import tools.refinery.store.query.atom.RelationViewAtom;
import tools.refinery.store.query.view.RelationView;

import java.util.*;
import java.util.stream.Collectors;

/**
 * Translates {@link DNF} queries into VIATRA {@link SimplePQuery} instances.
 * <p>
 * Translations are cached per instance: translating the same {@link DNF} again returns the previously built
 * {@link SimplePQuery}. Likewise, each {@link RelationView} is wrapped in at most one {@link RelationViewWrapper}.
 * Queries referenced through {@link DNFCallAtom}s are translated recursively, and circular references are rejected.
 * <p>
 * NOTE(review): the caches are plain, unsynchronized collections; this class presumably expects single-threaded use.
 */
public class DNF2PQuery {
	/**
	 * Queries whose translation is currently in progress, in the order they were entered. Used to detect circular
	 * references and to report the reference path that led to them.
	 */
	private final Set<DNF> translating = new LinkedHashSet<>();

	/** Completed translations, keyed by the source query. */
	private final Map<DNF, SimplePQuery> dnf2PQueryMap = new HashMap<>();

	/** Wrappers created so far, so that each relation view is wrapped only once. */
	private final Map<RelationView<?>, RelationViewWrapper> view2WrapperMap = new HashMap<>();

	/**
	 * Translates a DNF query, reusing an earlier translation of the same query if there is one.
	 *
	 * @param dnfQuery the query to translate
	 * @return the translated query
	 * @throws IllegalStateException    if {@code dnfQuery} is reached again while its own translation is still in
	 *                                  progress (a circular reference between queries)
	 * @throws IllegalArgumentException if the query contains a constraint or call kind that cannot be translated
	 */
	public SimplePQuery translate(DNF dnfQuery) {
		if (translating.contains(dnfQuery)) {
			var path = translating.stream().map(DNF::getName).collect(Collectors.joining(" -> "));
			throw new IllegalStateException("Circular reference %s -> %s detected".formatted(path,
					dnfQuery.getName()));
		}
		// We can't use computeIfAbsent here, because translating referenced queries calls this method in a reentrant
		// way, which would cause a ConcurrentModificationException with computeIfAbsent.
		var pQuery = dnf2PQueryMap.get(dnfQuery);
		if (pQuery == null) {
			translating.add(dnfQuery);
			try {
				pQuery = doTranslate(dnfQuery);
				dnf2PQueryMap.put(dnfQuery, pQuery);
			} finally {
				// Always unmark, so that a failed translation does not look like a cycle on a later attempt.
				translating.remove(dnfQuery);
			}
		}
		return pQuery;
	}

	/**
	 * Builds a fresh {@link SimplePQuery} for {@code dnfQuery}: one {@link PParameter} per DNF parameter and one
	 * {@link PBody} per clause.
	 */
	private SimplePQuery doTranslate(DNF dnfQuery) {
		var pQuery = new SimplePQuery(dnfQuery.getUniqueName());

		// Create each PParameter once, so that every body exports the very same PParameter instances.
		Map<Variable, PParameter> parameters = new HashMap<>();
		List<PParameter> parameterList = new ArrayList<>();
		for (Variable variable : dnfQuery.getParameters()) {
			parameterList.add(parameters.computeIfAbsent(variable, key -> new PParameter(key.getUniqueName())));
		}
		pQuery.setParameters(parameterList);

		for (DNFAnd clause : dnfQuery.getClauses()) {
			PBody body = new PBody(pQuery);
			List<ExportedParameter> symbolicParameters = new ArrayList<>();
			for (var param : dnfQuery.getParameters()) {
				// Body variables are identified by the DNF variable's unique name, the same key the atom
				// translations below use, so exported parameters and constraint arguments line up.
				PVariable pVar = body.getOrCreateVariableByName(param.getUniqueName());
				symbolicParameters.add(new ExportedParameter(body, pVar, parameters.get(param)));
			}
			body.setSymbolicParameters(symbolicParameters);
			pQuery.addBody(body);
			for (DNFAtom constraint : clause.constraints()) {
				translateDNFAtom(constraint, body);
			}
		}

		return pQuery;
	}

	/**
	 * Adds the VIATRA constraint corresponding to a single DNF atom to {@code body}.
	 *
	 * @throws IllegalArgumentException if the atom is of an unsupported kind (or {@code null})
	 */
	private void translateDNFAtom(DNFAtom constraint, PBody body) {
		if (constraint instanceof EquivalenceAtom equivalenceAtom) {
			translateEquivalenceAtom(equivalenceAtom, body);
		} else if (constraint instanceof RelationViewAtom relationViewAtom) {
			translateRelationViewAtom(relationViewAtom, body);
		} else if (constraint instanceof DNFCallAtom dnfCallAtom) {
			translateDNFCallAtom(dnfCallAtom, body);
		} else {
			// String concatenation (rather than toString()) keeps a null atom from turning into an NPE.
			throw new IllegalArgumentException("Unknown constraint: " + constraint);
		}
	}

	/**
	 * Adds an {@link Equality} (if the atom is positive) or an {@link Inequality} (otherwise) between the atom's two
	 * variables.
	 */
	private void translateEquivalenceAtom(EquivalenceAtom equivalence, PBody body) {
		PVariable varSource = body.getOrCreateVariableByName(equivalence.left().getUniqueName());
		PVariable varTarget = body.getOrCreateVariableByName(equivalence.right().getUniqueName());
		// The constructed constraint is discarded: its constructor is relied upon to attach it to the body.
		if (equivalence.positive()) {
			new Equality(body, varSource, varTarget);
		} else {
			new Inequality(body, varSource, varTarget);
		}
	}

	/** Adds a {@link TypeConstraint} over the (wrapped) relation view targeted by the atom. */
	private void translateRelationViewAtom(RelationViewAtom relationViewAtom, PBody body) {
		var variablesTuple = Tuples.flatTupleOf(translateSubstitution(relationViewAtom.getSubstitution(), body));
		new TypeConstraint(body, variablesTuple, wrapView(relationViewAtom.getTarget()));
	}

	/** Returns the cached wrapper for {@code relationView}, creating it on first use. */
	private RelationViewWrapper wrapView(RelationView<?> relationView) {
		return view2WrapperMap.computeIfAbsent(relationView, RelationViewWrapper::new);
	}

	/**
	 * Adds a positive, transitive, or negative pattern call to the translation of the referenced query.
	 *
	 * @throws IllegalStateException    if the referenced query leads to a circular reference
	 * @throws IllegalArgumentException if the call kind is not one of the supported kinds
	 */
	private void translateDNFCallAtom(DNFCallAtom queryCallAtom, PBody body) {
		var variablesTuple = Tuples.flatTupleOf(translateSubstitution(queryCallAtom.getSubstitution(), body));
		var translatedReferred = translate(queryCallAtom.getTarget());
		switch (queryCallAtom.getKind()) {
		case POSITIVE -> new PositivePatternCall(body, variablesTuple, translatedReferred);
		// NOTE(review): BinaryTransitiveClosure presumably requires a two-element substitution; not checked here.
		case TRANSITIVE -> new BinaryTransitiveClosure(body, variablesTuple, translatedReferred);
		case NEGATIVE -> new NegativePatternCall(body, variablesTuple, translatedReferred);
		// Fail loudly instead of silently dropping the call (which would make the query match too much).
		default -> throw new IllegalArgumentException("Unknown call kind: " + queryCallAtom.getKind());
		}
	}

	/**
	 * Maps each variable of a substitution to the body variable with the same unique name, creating body variables
	 * as needed.
	 *
	 * @return the body variables, in substitution order, ready to be packed into a flat tuple
	 */
	private static Object[] translateSubstitution(List<? extends Variable> substitution, PBody body) {
		Object[] variables = new Object[substitution.size()];
		for (int i = 0; i < variables.length; i++) {
			variables[i] = body.getOrCreateVariableByName(substitution.get(i).getUniqueName());
		}
		return variables;
	}
}