aboutsummaryrefslogtreecommitdiffstats
path: root/subprojects/language/src/main/java/tools/refinery/language/typesystem/TypedModule.java
diff options
context:
space:
mode:
Diffstat (limited to 'subprojects/language/src/main/java/tools/refinery/language/typesystem/TypedModule.java')
-rw-r--r--subprojects/language/src/main/java/tools/refinery/language/typesystem/TypedModule.java568
1 files changed, 568 insertions, 0 deletions
diff --git a/subprojects/language/src/main/java/tools/refinery/language/typesystem/TypedModule.java b/subprojects/language/src/main/java/tools/refinery/language/typesystem/TypedModule.java
new file mode 100644
index 00000000..de923e0d
--- /dev/null
+++ b/subprojects/language/src/main/java/tools/refinery/language/typesystem/TypedModule.java
@@ -0,0 +1,568 @@
1/*
2 * SPDX-FileCopyrightText: 2024 The Refinery Authors <https://refinery.tools/>
3 *
4 * SPDX-License-Identifier: EPL-2.0
5 */
6package tools.refinery.language.typesystem;
7
8import com.google.inject.Inject;
9import org.eclipse.emf.common.util.Diagnostic;
10import org.eclipse.emf.ecore.EObject;
11import org.eclipse.emf.ecore.EStructuralFeature;
12import org.eclipse.xtext.validation.CheckType;
13import org.eclipse.xtext.validation.FeatureBasedDiagnostic;
14import tools.refinery.language.expressions.BuiltinTermInterpreter;
15import tools.refinery.language.expressions.TermInterpreter;
16import tools.refinery.language.model.problem.*;
17import tools.refinery.language.scoping.imports.ImportAdapterProvider;
18import tools.refinery.language.validation.ProblemValidator;
19
20import java.util.*;
21
22public class TypedModule {
23 private static final String OPERAND_TYPE_ERROR_MESSAGE = "Cannot determine operand type.";
24
25 @Inject
26 private SignatureProvider signatureProvider;
27
28 @Inject
29 private ImportAdapterProvider importAdapterProvider;
30
31 private TermInterpreter interpreter;
32 private final Map<Variable, List<AssignmentExpr>> assignments = new LinkedHashMap<>();
33 private final Map<Variable, FixedType> variableTypes = new HashMap<>();
34 private final Map<Expr, ExprType> expressionTypes = new HashMap<>();
35 private final Set<Variable> variablesToProcess = new LinkedHashSet<>();
36 private final List<FeatureBasedDiagnostic> diagnostics = new ArrayList<>();
37
38 void setProblem(Problem problem) {
39 interpreter = importAdapterProvider.getTermInterpreter(problem);
40 gatherAssignments(problem);
41 checkTypes(problem);
42 }
43
44 private void gatherAssignments(Problem problem) {
45 var iterator = problem.eAllContents();
46 while (iterator.hasNext()) {
47 var eObject = iterator.next();
48 if (!(eObject instanceof AssignmentExpr assignmentExpr)) {
49 continue;
50 }
51 if (assignmentExpr.getLeft() instanceof VariableOrNodeExpr variableOrNodeExpr &&
52 variableOrNodeExpr.getVariableOrNode() instanceof Variable variable) {
53 var assignmentList = assignments.computeIfAbsent(variable, ignored -> new ArrayList<>(1));
54 assignmentList.add(assignmentExpr);
55 }
56 iterator.prune();
57 }
58 }
59
60 private void checkTypes(Problem problem) {
61 for (var statement : problem.getStatements()) {
62 switch (statement) {
63 case PredicateDefinition predicateDefinition -> checkTypes(predicateDefinition);
64 case Assertion assertion -> checkTypes(assertion);
65 default -> {
66 // Nothing to type check.
67 }
68 }
69 }
70 }
71
72 private void checkTypes(PredicateDefinition predicateDefinition) {
73 for (var conjunction : predicateDefinition.getBodies()) {
74 for (var literal : conjunction.getLiterals()) {
75 coerceIntoLiteral(literal);
76 }
77 }
78 }
79
80 private void checkTypes(Assertion assertion) {
81 var relation = assertion.getRelation();
82 var value = assertion.getValue();
83 if (relation == null) {
84 return;
85 }
86 var type = signatureProvider.getSignature(relation).resultType();
87 if (type == ExprType.LITERAL) {
88 if (value == null) {
89 return;
90 }
91 expectType(value, BuiltinTermInterpreter.BOOLEAN_TYPE);
92 return;
93 }
94 if (value == null) {
95 var message = "Assertion value of type %s is required.".formatted(type);
96 error(message, assertion, ProblemPackage.Literals.ASSERTION__RELATION, 0, ProblemValidator.TYPE_ERROR);
97 }
98 expectType(value, type);
99 }
100
101 public List<FeatureBasedDiagnostic> getDiagnostics() {
102 return diagnostics;
103 }
104
105 public FixedType getVariableType(Variable variable) {
106 // We can't use computeIfAbsent here, because translating referenced queries calls this method in a reentrant
107 // way, which would cause a ConcurrentModificationException with computeIfAbsent.
108 @SuppressWarnings("squid:S3824")
109 var type = variableTypes.get(variable);
110 //noinspection Java8MapApi
111 if (type == null) {
112 type = computeVariableType(variable);
113 variableTypes.put(variable, type);
114 }
115 return type;
116 }
117
118 private FixedType computeVariableType(Variable variable) {
119 if (variable instanceof Parameter) {
120 return computeUnassignedVariableType(variable);
121 }
122 var assignmnentList = assignments.get(variable);
123 if (assignmnentList == null || assignmnentList.isEmpty()) {
124 return computeUnassignedVariableType(variable);
125 }
126 if (variablesToProcess.contains(variable)) {
127 throw new IllegalStateException("Circular reference to variable: " + variable.getName());
128 }
129 if (assignmnentList.size() > 1) {
130 var message = "Multiple assignments for variable '%s'.".formatted(variable.getName());
131 for (var assignment : assignmnentList) {
132 error(message, assignment, ProblemPackage.Literals.BINARY_EXPR__LEFT, 0,
133 ProblemValidator.INVALID_ASSIGNMENT_ISSUE);
134 }
135 return ExprType.INVALID;
136 }
137 var assignment = assignmnentList.getFirst();
138 variablesToProcess.add(variable);
139 try {
140 var assignedType = getExpressionType(assignment.getRight());
141 if (assignedType instanceof MutableType) {
142 var message = "Cannot determine type of variable '%s'.".formatted(variable.getName());
143 error(message, assignment, ProblemPackage.Literals.BINARY_EXPR__RIGHT, 0, ProblemValidator.TYPE_ERROR);
144 return ExprType.INVALID;
145 }
146 if (assignedType instanceof DataExprType dataExprType) {
147 return dataExprType;
148 }
149 if (assignedType != ExprType.INVALID) {
150 var message = "Expected data expression for variable '%s', got %s instead."
151 .formatted(variable.getName(), assignedType);
152 error(message, assignment, ProblemPackage.Literals.BINARY_EXPR__RIGHT, 0, ProblemValidator.TYPE_ERROR);
153 }
154 return ExprType.INVALID;
155 } finally {
156 variablesToProcess.remove(variable);
157 }
158 }
159
160 private FixedType computeUnassignedVariableType(Variable variable) {
161 if (variable instanceof Parameter parameter &&
162 parameter.getParameterType() instanceof DatatypeDeclaration datatypeDeclaration) {
163 return signatureProvider.getDataType(datatypeDeclaration);
164 }
165 // Parameters without an explicit datatype annotation are node variables.
166 return ExprType.NODE;
167 }
168
169 public ExprType getExpressionType(Expr expr) {
170 // We can't use computeIfAbsent here, because translating referenced queries calls this method in a reentrant
171 // way, which would cause a ConcurrentModificationException with computeIfAbsent.
172 @SuppressWarnings("squid:S3824")
173 var type = expressionTypes.get(expr);
174 //noinspection Java8MapApi
175 if (type == null) {
176 type = computeExpressionType(expr);
177 expressionTypes.put(expr, type);
178 }
179 return type.unwrapIfSet();
180 }
181
182 private ExprType computeExpressionType(Expr expr) {
183 return switch (expr) {
184 case LogicConstant logicConstant -> computeExpressionType(logicConstant);
185 case IntConstant ignored -> BuiltinTermInterpreter.INT_TYPE;
186 case RealConstant ignored -> BuiltinTermInterpreter.REAL_TYPE;
187 case StringConstant ignored -> BuiltinTermInterpreter.STRING_TYPE;
188 case InfiniteConstant ignored -> new MutableType();
189 case VariableOrNodeExpr variableOrNodeExpr -> computeExpressionType(variableOrNodeExpr);
190 case AssignmentExpr assignmentExpr -> computeExpressionType(assignmentExpr);
191 case Atom atom -> computeExpressionType(atom);
192 case NegationExpr negationExpr -> computeExpressionType(negationExpr);
193 case ArithmeticUnaryExpr arithmeticUnaryExpr -> computeExpressionType(arithmeticUnaryExpr);
194 case CountExpr countExpr -> computeExpressionType(countExpr);
195 case AggregationExpr aggregationExpr -> computeExpressionType(aggregationExpr);
196 case ComparisonExpr comparisonExpr -> computeExpressionType(comparisonExpr);
197 case LatticeBinaryExpr latticeBinaryExpr -> computeExpressionType(latticeBinaryExpr);
198 case RangeExpr rangeExpr -> computeExpressionType(rangeExpr);
199 case ArithmeticBinaryExpr arithmeticBinaryExpr -> computeExpressionType(arithmeticBinaryExpr);
200 case CastExpr castExpr -> computeExpressionType(castExpr);
201 default -> {
202 error("Unknown expression: " + expr.getClass().getSimpleName(), expr, null, 0,
203 ProblemValidator.UNKNOWN_EXPRESSION_ISSUE);
204 yield ExprType.INVALID;
205 }
206 };
207 }
208
209 private ExprType computeExpressionType(LogicConstant expr) {
210 return switch (expr.getLogicValue()) {
211 case TRUE, FALSE -> BuiltinTermInterpreter.BOOLEAN_TYPE;
212 case UNKNOWN, ERROR -> new MutableType();
213 case null -> ExprType.INVALID;
214 };
215 }
216
217 private ExprType computeExpressionType(VariableOrNodeExpr expr) {
218 var target = expr.getVariableOrNode();
219 if (target == null || target.eIsProxy()) {
220 return ExprType.INVALID;
221 }
222 return switch (target) {
223 case Node ignored -> ExprType.NODE;
224 case Variable variable -> {
225 if (variablesToProcess.contains(variable)) {
226 var message = "Circular reference to variable '%s'.".formatted(variable.getName());
227 error(message, expr, ProblemPackage.Literals.VARIABLE_OR_NODE_EXPR__VARIABLE_OR_NODE, 0,
228 ProblemValidator.INVALID_ASSIGNMENT_ISSUE);
229 yield ExprType.INVALID;
230 }
231 yield getVariableType(variable);
232 }
233 default -> {
234 error("Unknown variable: " + target.getName(), expr,
235 ProblemPackage.Literals.VARIABLE_OR_NODE_EXPR__VARIABLE_OR_NODE, 0,
236 ProblemValidator.UNKNOWN_EXPRESSION_ISSUE);
237 yield ExprType.INVALID;
238 }
239 };
240 }
241
242 private ExprType computeExpressionType(AssignmentExpr expr) {
243 // Force the left side to type check. Since the left side is a variable, it will force the right side to also
244 // type check in order to infer the variable type.
245 return getExpressionType(expr.getLeft()) == ExprType.INVALID ? ExprType.INVALID : ExprType.LITERAL;
246 }
247
248 private ExprType computeExpressionType(Atom atom) {
249 var relation = atom.getRelation();
250 if (relation == null || relation.eIsProxy()) {
251 return ExprType.INVALID;
252 }
253 if (relation instanceof DatatypeDeclaration) {
254 var message = "Invalid call to data type. Use 'as %s' for casting.".formatted(
255 relation.getName());
256 error(message, atom, ProblemPackage.Literals.ATOM__RELATION, 0, ProblemValidator.TYPE_ERROR);
257 }
258 var signature = signatureProvider.getSignature(relation);
259 var parameterTypes = signature.parameterTypes();
260 var arguments = atom.getArguments();
261 int size = Math.min(parameterTypes.size(), arguments.size());
262 boolean ok = parameterTypes.size() == arguments.size();
263 for (int i = 0; i < size; i++) {
264 var parameterType = parameterTypes.get(i);
265 var argument = arguments.get(i);
266 if (!expectType(argument, parameterType)) {
267 // Avoid short-circuiting to let us type check all arguments.
268 ok = false;
269 }
270 }
271 return ok ? signature.resultType() : ExprType.INVALID;
272 }
273
274 private ExprType computeExpressionType(NegationExpr negationExpr) {
275 var body = negationExpr.getBody();
276 if (body == null) {
277 return ExprType.INVALID;
278 }
279 var actualType = getExpressionType(body);
280 if (actualType == ExprType.LITERAL) {
281 // Negation of literals yields another (non-enumerable) literal.
282 return ExprType.LITERAL;
283 }
284 if (actualType == DataExprType.INVALID) {
285 return ExprType.INVALID;
286 }
287 if (actualType instanceof MutableType) {
288 error(OPERAND_TYPE_ERROR_MESSAGE, body, null, 0, ProblemValidator.TYPE_ERROR);
289 return ExprType.INVALID;
290 }
291 if (actualType instanceof DataExprType dataExprType) {
292 var result = interpreter.getNegationType(dataExprType);
293 if (result.isPresent()) {
294 return result.get();
295 }
296 }
297 var message = "Data type %s does not support negation.".formatted(actualType);
298 error(message, negationExpr, null, 0, ProblemValidator.TYPE_ERROR);
299 return ExprType.INVALID;
300 }
301
302 private ExprType computeExpressionType(ArithmeticUnaryExpr expr) {
303 var op = expr.getOp();
304 var body = expr.getBody();
305 if (op == null || body == null) {
306 return ExprType.INVALID;
307 }
308 var actualType = getExpressionType(body);
309 if (actualType == DataExprType.INVALID) {
310 return ExprType.INVALID;
311 }
312 if (actualType instanceof MutableType) {
313 error(OPERAND_TYPE_ERROR_MESSAGE, body, null, 0, ProblemValidator.TYPE_ERROR);
314 return ExprType.INVALID;
315 }
316 if (actualType instanceof DataExprType dataExprType) {
317 var result = interpreter.getUnaryOperationType(op, dataExprType);
318 if (result.isPresent()) {
319 return result.get();
320 }
321 }
322 var message = "Unsupported operator for data type %s.".formatted(actualType);
323 error(message, expr, null, 0, ProblemValidator.TYPE_ERROR);
324 return ExprType.INVALID;
325 }
326
327 private ExprType computeExpressionType(CountExpr countExpr) {
328 return coerceIntoLiteral(countExpr.getBody()) ? BuiltinTermInterpreter.INT_TYPE : ExprType.INVALID;
329 }
330
331 private ExprType computeExpressionType(AggregationExpr expr) {
332 var aggregator = expr.getAggregator();
333 if (aggregator == null || aggregator.eIsProxy()) {
334 return null;
335 }
336 // Avoid short-circuiting to let us type check both the value and the condition.
337 boolean ok = coerceIntoLiteral(expr.getCondition());
338 var value = expr.getValue();
339 var actualType = getExpressionType(value);
340 if (actualType == ExprType.INVALID) {
341 return ExprType.INVALID;
342 }
343 if (actualType instanceof MutableType) {
344 error(OPERAND_TYPE_ERROR_MESSAGE, value, null, 0, ProblemValidator.TYPE_ERROR);
345 return ExprType.INVALID;
346 }
347 if (actualType instanceof DataExprType dataExprType) {
348 var aggregatorName = signatureProvider.getAggregatorName(aggregator);
349 var result = interpreter.getAggregationType(aggregatorName, dataExprType);
350 if (result.isPresent()) {
351 return ok ? result.get() : ExprType.INVALID;
352 }
353 }
354 var message = "Unsupported aggregator for type %s.".formatted(actualType);
355 error(message, expr, ProblemPackage.Literals.AGGREGATION_EXPR__AGGREGATOR, 0, ProblemValidator.TYPE_ERROR);
356 return ExprType.INVALID;
357 }
358
359 private ExprType computeExpressionType(ComparisonExpr expr) {
360 var left = expr.getLeft();
361 var right = expr.getRight();
362 var op = expr.getOp();
363 if (op == ComparisonOp.NODE_EQ || op == ComparisonOp.NODE_NOT_EQ) {
364 // Avoid short-circuiting to let us type check both arguments.
365 boolean leftOk = expectType(left, ExprType.NODE);
366 boolean rightOk = expectType(right, ExprType.NODE);
367 return leftOk && rightOk ? ExprType.LITERAL : ExprType.INVALID;
368 }
369 if (!(getCommonDataType(expr) instanceof DataExprType commonType)) {
370 return ExprType.INVALID;
371 }
372 // Data equality and inequality are always supported for data types.
373 if (op != ComparisonOp.EQ && op != ComparisonOp.NOT_EQ && !interpreter.isComparisonSupported(commonType)) {
374 var message = "Data type %s does not support comparison.".formatted(commonType);
375 error(message, expr, null, 0, ProblemValidator.TYPE_ERROR);
376 return ExprType.INVALID;
377 }
378 return BuiltinTermInterpreter.BOOLEAN_TYPE;
379 }
380
381 private ExprType computeExpressionType(LatticeBinaryExpr expr) {
382 // Lattice operations are always supported for data types.
383 return getCommonDataType(expr);
384 }
385
386 private ExprType computeExpressionType(RangeExpr expr) {
387 var left = expr.getLeft();
388 var right = expr.getRight();
389 if (left instanceof InfiniteConstant && right instanceof InfiniteConstant) {
390 // `*..*` is equivalent to `unknown` if neither subexpression have been typed yet.
391 var mutableType = new MutableType();
392 if (expressionTypes.putIfAbsent(left, mutableType) == null &&
393 expressionTypes.put(right, mutableType) == null) {
394 return mutableType;
395 }
396 }
397 if (!(getCommonDataType(expr) instanceof DataExprType commonType)) {
398 return ExprType.INVALID;
399 }
400 if (!interpreter.isRangeSupported(commonType)) {
401 var message = "Data type %s does not support ranges.".formatted(commonType);
402 error(message, expr, null, 0, ProblemValidator.TYPE_ERROR);
403 return ExprType.INVALID;
404 }
405 return commonType;
406 }
407
408 private ExprType computeExpressionType(ArithmeticBinaryExpr expr) {
409 var op = expr.getOp();
410 var left = expr.getLeft();
411 var right = expr.getRight();
412 if (op == null || left == null || right == null) {
413 return ExprType.INVALID;
414 }
415 // Avoid short-circuiting to let us type check both arguments.
416 var leftType = getExpressionType(left);
417 var rightType = getExpressionType(right);
418 if (leftType == ExprType.INVALID || rightType == ExprType.INVALID) {
419 return ExprType.INVALID;
420 }
421 if (rightType instanceof MutableType rightMutableType) {
422 if (leftType instanceof DataExprType leftExprType) {
423 rightMutableType.setActualType(leftExprType);
424 rightType = leftExprType;
425 } else {
426 error(OPERAND_TYPE_ERROR_MESSAGE, right, null, 0, ProblemValidator.TYPE_ERROR);
427 return ExprType.INVALID;
428 }
429 }
430 if (leftType instanceof MutableType leftMutableType) {
431 if (rightType instanceof DataExprType rightExprType) {
432 leftMutableType.setActualType(rightExprType);
433 leftType = rightExprType;
434 } else {
435 error(OPERAND_TYPE_ERROR_MESSAGE, left, null, 0, ProblemValidator.TYPE_ERROR);
436 return ExprType.INVALID;
437 }
438 }
439 if (leftType instanceof DataExprType leftExprType && rightType instanceof DataExprType rightExprType) {
440 var result = interpreter.getBinaryOperationType(op, leftExprType, rightExprType);
441 if (result.isPresent()) {
442 return result.get();
443 }
444 }
445 var messageBuilder = new StringBuilder("Unsupported operator for ");
446 if (leftType.equals(rightType)) {
447 messageBuilder.append("data type ").append(leftType);
448 } else {
449 messageBuilder.append("data types ").append(leftType).append(" and ").append(rightType);
450 }
451 messageBuilder.append(".");
452 error(messageBuilder.toString(), expr, null, 0, ProblemValidator.TYPE_ERROR);
453 return ExprType.INVALID;
454 }
455
456 private ExprType computeExpressionType(CastExpr expr) {
457 var body = expr.getBody();
458 var targetRelation = expr.getTargetType();
459 if (body == null || !(targetRelation instanceof DatatypeDeclaration targetDeclaration)) {
460 return null;
461 }
462 var actualType = getExpressionType(body);
463 if (actualType == ExprType.INVALID) {
464 return ExprType.INVALID;
465 }
466 var targetType = signatureProvider.getDataType(targetDeclaration);
467 if (actualType instanceof MutableType mutableType) {
468 // Type ascription for polymorphic literal (e.g., `unknown as int` for the set of all integers).
469 mutableType.setActualType(targetType);
470 return targetType;
471 }
472 if (actualType.equals(targetType)) {
473 return targetType;
474 }
475 if (actualType instanceof DataExprType dataExprType && interpreter.isCastSupported(dataExprType, targetType)) {
476 return targetType;
477 }
478 var message = "Casting from %s to %s is not supported.".formatted(actualType, targetType);
479 error(message, expr, null, 0, ProblemValidator.TYPE_ERROR);
480 return ExprType.INVALID;
481 }
482
483 private FixedType getCommonDataType(BinaryExpr expr) {
484 var commonType = getCommonType(expr);
485 if (!(commonType instanceof DataExprType) && commonType != ExprType.INVALID) {
486 var message = "Expected data expression, got %s instead.".formatted(commonType);
487 error(message, expr, null, 0, ProblemValidator.TYPE_ERROR);
488 return ExprType.INVALID;
489 }
490 return commonType;
491 }
492
493 private FixedType getCommonType(BinaryExpr expr) {
494 var left = expr.getLeft();
495 var right = expr.getRight();
496 if (left == null || right == null) {
497 return ExprType.INVALID;
498 }
499 var leftType = getExpressionType(left);
500 if (leftType instanceof FixedType fixedLeftType) {
501 return expectType(right, fixedLeftType) ? fixedLeftType : ExprType.INVALID;
502 } else {
503 var rightType = getExpressionType(right);
504 if (rightType instanceof FixedType fixedRightType) {
505 return expectType(left, leftType, fixedRightType) ? fixedRightType : ExprType.INVALID;
506 } else {
507 error(OPERAND_TYPE_ERROR_MESSAGE, left, null, 0, ProblemValidator.TYPE_ERROR);
508 error(OPERAND_TYPE_ERROR_MESSAGE, right, null, 0, ProblemValidator.TYPE_ERROR);
509 return ExprType.INVALID;
510 }
511 }
512 }
513
514 private boolean coerceIntoLiteral(Expr expr) {
515 if (expr == null) {
516 return false;
517 }
518 var actualType = getExpressionType(expr);
519 if (actualType == ExprType.LITERAL) {
520 return true;
521 }
522 return expectType(expr, actualType, BuiltinTermInterpreter.BOOLEAN_TYPE);
523 }
524
525 private boolean expectType(Expr expr, FixedType expectedType) {
526 if (expr == null) {
527 return false;
528 }
529 var actualType = getExpressionType(expr);
530 return expectType(expr, actualType, expectedType);
531 }
532
533 private boolean expectType(Expr expr, ExprType actualType, FixedType expectedType) {
534 if (expectedType == ExprType.INVALID) {
535 // Silence any further errors is the expected type failed to compute.
536 return false;
537 }
538 if (actualType.equals(expectedType)) {
539 return true;
540 }
541 if (actualType == ExprType.INVALID) {
542 // We have already emitted an error previously.
543 return false;
544 }
545 if (actualType instanceof MutableType mutableType && expectedType instanceof DataExprType dataExprType) {
546 mutableType.setActualType(dataExprType);
547 return true;
548 }
549 var builder = new StringBuilder()
550 .append("Expected ")
551 .append(expectedType)
552 .append(" expression");
553 if (!(actualType instanceof MutableType)) {
554 builder.append(", got ")
555 .append(actualType)
556 .append(" instead");
557 }
558 builder.append(".");
559 error(builder.toString(), expr, null, 0, ProblemValidator.TYPE_ERROR);
560 return false;
561 }
562
563 private void error(String message, EObject object, EStructuralFeature feature, int index, String code,
564 String... issueData) {
565 diagnostics.add(new FeatureBasedDiagnostic(Diagnostic.ERROR, message, object, feature, index,
566 CheckType.NORMAL, code, issueData));
567 }
568}