aboutsummaryrefslogtreecommitdiffstats
path: root/subprojects/store-reasoning/src/main/java/tools/refinery/store/reasoning/internal/PartialClauseRewriter.java
blob: bc379c642b8230cf500928f73f685205b3c33abe (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
/*
 * SPDX-FileCopyrightText: 2023 The Refinery Authors <https://refinery.tools/>
 *
 * SPDX-License-Identifier: EPL-2.0
 */
package tools.refinery.store.reasoning.internal;

import tools.refinery.store.query.dnf.Dnf;
import tools.refinery.store.query.dnf.DnfClause;
import tools.refinery.store.query.literal.AbstractCallLiteral;
import tools.refinery.store.query.literal.Literal;
import tools.refinery.store.query.term.Variable;
import tools.refinery.store.reasoning.literal.Concreteness;
import tools.refinery.store.reasoning.literal.ModalConstraint;
import tools.refinery.store.reasoning.literal.Modality;
import tools.refinery.store.reasoning.representation.PartialRelation;

import java.util.*;

/**
 * Rewrites the literals of a single {@link DnfClause} by expanding calls to modal constraints and recursively
 * rewriting calls to other {@link Dnf} queries.
 * <p>
 * Literals are processed in clause order using a work list:
 * <ul>
 *     <li>calls whose target is a {@link Dnf} are rewritten recursively by the owning
 *     {@link PartialQueryRewriter};</li>
 *     <li>calls whose target is a {@link ModalConstraint} over a {@link Dnf} are first lifted to the requested
 *     modality and concreteness, then rewritten recursively;</li>
 *     <li>calls whose target is a {@link ModalConstraint} over a {@link PartialRelation} are delegated to the
 *     relation rewriter of that relation, and the literals it returns are pushed to the front of the work list
 *     (in their original order) so that they are themselves rewritten before the rest of the clause;</li>
 *     <li>every other literal is kept unchanged.</li>
 * </ul>
 * <p>
 * Instances are stateful and never reset their state, so each instance should rewrite a single clause only.
 */
class PartialClauseRewriter {
	private final PartialQueryRewriter rewriter;
	/** Literals emitted so far, in output order. */
	private final List<Literal> completedLiterals = new ArrayList<>();
	/** Literals still waiting to be rewritten; the head of the deque is processed next. */
	private final Deque<Literal> workList = new ArrayDeque<>();
	/** Output variables of every literal in {@link #completedLiterals}, in insertion order. */
	private final Set<Variable> positiveVariables = new LinkedHashSet<>();
	/** Read-only view of {@link #positiveVariables} handed to relation rewriters. */
	private final Set<Variable> unmodifiablePositiveVariables = Collections.unmodifiableSet(positiveVariables);

	/**
	 * Creates a clause rewriter.
	 *
	 * @param rewriter the query rewriter used to rewrite called {@link Dnf} queries and to look up the lifter and
	 *                 the relation rewriters
	 */
	public PartialClauseRewriter(PartialQueryRewriter rewriter) {
		this.rewriter = rewriter;
	}

	/**
	 * Rewrites all literals of the given clause.
	 *
	 * @param clause the clause to rewrite
	 * @return the rewritten literals; this is the internal (mutable) list of this rewriter, not a copy
	 * @throws IllegalArgumentException if a {@link ModalConstraint} wraps a constraint that is neither a
	 *                                  {@link Dnf} nor a {@link PartialRelation}
	 */
	public List<Literal> rewriteClause(DnfClause clause) {
		workList.addAll(clause.literals());
		while (!workList.isEmpty()) {
			var literal = workList.removeFirst();
			rewrite(literal);
		}
		return completedLiterals;
	}

	/**
	 * Dispatches a single literal to the appropriate rewriting strategy based on its call target.
	 */
	private void rewrite(Literal literal) {
		if (!(literal instanceof AbstractCallLiteral callLiteral)) {
			markAsDone(literal);
			return;
		}
		var target = callLiteral.getTarget();
		if (target instanceof Dnf dnf) {
			rewriteRecursively(callLiteral, dnf);
		} else if (target instanceof ModalConstraint modalConstraint) {
			var modality = modalConstraint.modality();
			var concreteness = modalConstraint.concreteness();
			var constraint = modalConstraint.constraint();
			if (constraint instanceof Dnf dnf) {
				rewriteRecursively(callLiteral, modality, concreteness, dnf);
			} else if (constraint instanceof PartialRelation partialRelation) {
				rewrite(callLiteral, modality, concreteness, partialRelation);
			} else {
				throw new IllegalArgumentException("Cannot interpret modal constraint: " + modalConstraint);
			}
		} else {
			markAsDone(literal);
		}
	}

	/**
	 * Emits a literal to the output and records its output variables as positive variables, so that relation
	 * rewriters invoked for later literals can see them.
	 */
	private void markAsDone(Literal literal) {
		completedLiterals.add(literal);
		positiveVariables.addAll(literal.getOutputVariables());
	}

	/**
	 * Lifts a {@link Dnf} wrapped in a modal constraint to the given modality and concreteness, then rewrites the
	 * call to the lifted query recursively.
	 */
	private void rewriteRecursively(AbstractCallLiteral callLiteral, Modality modality, Concreteness concreteness,
									Dnf dnf) {
		var liftedDnf = rewriter.getLifter().lift(modality, concreteness, dnf);
		rewriteRecursively(callLiteral, liftedDnf);
	}

	/**
	 * Rewrites the called {@link Dnf} with the owning query rewriter and emits a call to the rewritten query.
	 */
	private void rewriteRecursively(AbstractCallLiteral callLiteral, Dnf dnf) {
		var rewrittenDnf = rewriter.rewrite(dnf);
		var rewrittenLiteral = callLiteral.withTarget(rewrittenDnf);
		// Go through markAsDone (instead of adding to completedLiterals directly) so that the variables bound by
		// this call are also recorded as positive variables for the relation rewriters of subsequent literals.
		markAsDone(rewrittenLiteral);
	}

	/**
	 * Delegates a modal call to a {@link PartialRelation} to that relation's rewriter and schedules the resulting
	 * literals for rewriting.
	 */
	private void rewrite(AbstractCallLiteral callLiteral, Modality modality, Concreteness concreteness,
						 PartialRelation partialRelation) {
		var relationRewriter = rewriter.getRelationRewriter(partialRelation);
		var literals = relationRewriter.rewriteLiteral(
				unmodifiablePositiveVariables, callLiteral, modality, concreteness);
		// Push in reverse so that, after all addFirst calls, the returned literals sit at the head of the work list
		// in their original order and are processed before the remaining literals of the clause.
		int length = literals.size();
		for (int i = length - 1; i >= 0; i--) {
			workList.addFirst(literals.get(i));
		}
	}
}