aboutsummaryrefslogtreecommitdiffstats
path: root/Framework/hu.bme.mit.inf.dslreasoner.viatra2logic/src/hu/bme/mit/inf/dslreasoner/viatra2logic/Viatra2LogicTypeInferer.xtend
blob: b8a6b9c146e8479fac8800c95d9c39ece830f189 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
package hu.bme.mit.inf.dslreasoner.viatra2logic

import hu.bme.mit.inf.dslreasoner.ecore2logic.Ecore2Logic
import hu.bme.mit.inf.dslreasoner.ecore2logic.Ecore2Logic_Trace
import hu.bme.mit.inf.dslreasoner.logic.model.builder.LogicProblemBuilder
import hu.bme.mit.inf.dslreasoner.logic.model.builder.TracedOutput
import hu.bme.mit.inf.dslreasoner.logic.model.logiclanguage.ComplexTypeReference
import hu.bme.mit.inf.dslreasoner.logic.model.logiclanguage.IntTypeReference
import hu.bme.mit.inf.dslreasoner.logic.model.logiclanguage.PrimitiveTypeReference
import hu.bme.mit.inf.dslreasoner.logic.model.logiclanguage.RealTypeReference
import hu.bme.mit.inf.dslreasoner.logic.model.logiclanguage.Type
import hu.bme.mit.inf.dslreasoner.logic.model.logiclanguage.TypeReference
import hu.bme.mit.inf.dslreasoner.logic.model.logicproblem.LogicProblem
import java.util.HashMap
import java.util.HashSet
import java.util.List
import java.util.Set
import org.eclipse.emf.ecore.EEnum
import org.eclipse.emf.ecore.EcorePackage
import org.eclipse.viatra.query.patternlanguage.emf.EMFPatternLanguageStandaloneSetup
import org.eclipse.viatra.query.patternlanguage.emf.specification.XBaseEvaluator
import org.eclipse.viatra.query.runtime.emf.EMFQueryMetaContext
import org.eclipse.viatra.query.runtime.emf.types.EClassTransitiveInstancesKey
import org.eclipse.viatra.query.runtime.emf.types.EClassUnscopedTransitiveInstancesKey
import org.eclipse.viatra.query.runtime.emf.types.EDataTypeInSlotsKey
import org.eclipse.viatra.query.runtime.matchers.context.IInputKey
import org.eclipse.viatra.query.runtime.matchers.context.common.JavaTransitiveInstancesKey
import org.eclipse.viatra.query.runtime.matchers.planning.helpers.TypeHelper
import org.eclipse.viatra.query.runtime.matchers.psystem.IExpressionEvaluator
import org.eclipse.viatra.query.runtime.matchers.psystem.PBody
import org.eclipse.viatra.query.runtime.matchers.psystem.PVariable
import org.eclipse.viatra.query.runtime.matchers.psystem.basicdeferred.AggregatorConstraint
import org.eclipse.viatra.query.runtime.matchers.psystem.basicdeferred.ExpressionEvaluation
import org.eclipse.viatra.query.runtime.matchers.psystem.basicdeferred.PatternCallBasedDeferred
import org.eclipse.viatra.query.runtime.matchers.psystem.basicdeferred.PatternMatchCounter
import org.eclipse.viatra.query.runtime.matchers.psystem.queries.PParameter
import org.eclipse.viatra.query.runtime.matchers.psystem.queries.PQuery
import org.eclipse.xtext.xbase.XExpression
import org.eclipse.xtext.xbase.typesystem.IBatchTypeResolver
import org.eclipse.xtext.xbase.typesystem.references.UnknownTypeReference

import static extension hu.bme.mit.inf.dslreasoner.util.CollectionsUtil.*
import org.eclipse.xtext.xbase.typesystem.references.InnerTypeReference

class Viatra2LogicTypeInferer{
	val Ecore2Logic ecore2Logic
	val extension LogicProblemBuilder builder = new LogicProblemBuilder
	/**Typeresolver uses the same resolver as EMFPatternLanguageStandaloneSetup.*/
	val IBatchTypeResolver typeResolver = 
		(new EMFPatternLanguageStandaloneSetup).createInjector.getInstance(IBatchTypeResolver)
	val expressionExtractor = new XExpressionExtractor
	
	new(Ecore2Logic ecore2Logic) {
		this.ecore2Logic = ecore2Logic
	}
	
	def Viatra2LogicTypeResult inferTypes(List<PQuery> pQueries, TracedOutput<LogicProblem, Ecore2Logic_Trace> ecore2LogicTrace, Viatra2LogicTrace viatra2LogicTrace) {
		val Viatra2LogicTypeResult result = new Viatra2LogicTypeResult(new HashMap,new HashMap);
		for(query : pQueries) {
			for(body: query.lookup(viatra2LogicTrace.query2Disjunction).bodies) {
				for(variable : body.uniqueVariables) {
					getOrMakeTypeDecision(result,variable,body,ecore2LogicTrace,viatra2LogicTrace,emptySet)
				}
			}
			for(parameter: query.parameters) {
				getOrMakeTypeDecision(result,query,parameter,ecore2LogicTrace,viatra2LogicTrace,emptySet)
			}
		}
		return result
	}
	
	private def TypeReference getOrMakeTypeDecision(
		Viatra2LogicTypeResult result,
		PVariable variable, PBody body,
		TracedOutput<LogicProblem,  Ecore2Logic_Trace> ecore2LogicTrace,
		Viatra2LogicTrace viatra2LogicTrace,
		Set<? extends PParameter> checkedInDecisionMaking)
	{
		if(result.containsSolution(body,variable)) {
			return result.getType(body,variable)
		} else {
			val inferredTypesByViatra = TypeHelper::inferUnaryTypesFor(body.uniqueVariables, body.constraints, EMFQueryMetaContext.DEFAULT)
			val constraintsForVariable = variable.lookup(inferredTypesByViatra)
			
			val typeConstraintsDerivedByTypeHelper = constraintsForVariable.map[transformTypeReference(ecore2LogicTrace)]
			val typesFromEval = variable.getTypesFromEval(typeResolver)
			val typesFromAggregatorResult = variable.getTypeFromPassivePatternCallConstraintResult(
				result,
				ecore2LogicTrace,
				viatra2LogicTrace,
				checkedInDecisionMaking)
				 
			val typesFromPositiveReasoning = (typeConstraintsDerivedByTypeHelper + typesFromEval + typesFromAggregatorResult).filterNull
			
			val types = if(!typesFromPositiveReasoning.empty) {
				typesFromPositiveReasoning
			} else {
				variable.getTypeFromPassivePatternCallConstraints(
					result,
					ecore2LogicTrace,
					viatra2LogicTrace,
					checkedInDecisionMaking)
			}
			
			val commonSubtype = this.calculateCommonSubtype(types)
			
			result.addType(body,variable,commonSubtype)
			return commonSubtype
		}
	}

	private def TypeReference getOrMakeTypeDecision(
		Viatra2LogicTypeResult result,
		PQuery query,
		PParameter parameter,
		TracedOutput<LogicProblem,  Ecore2Logic_Trace> ecore2LogicTrace,
		Viatra2LogicTrace viatra2LogicTrace,
		Set<? extends PParameter> checkedInDecisionMaking)
	{
		if(checkedInDecisionMaking.contains(parameter)) {
			return null
		}
		if(result.containsSolution(parameter)) {
			return result.getType(query, parameter)
		}
		
		var TypeReference typeReference;
		
		if(parameter.declaredUnaryType !== null) {
			 val key = parameter.declaredUnaryType
			 typeReference = key.transformTypeReference(ecore2LogicTrace)
		} else {
			val bodies = query.lookup(viatra2LogicTrace.query2Disjunction).bodies
			val newChecked = new HashSet(checkedInDecisionMaking) => [add(parameter)]
			val Iterable<TypeReference> variableTypes = bodies.map[body|
				val symbolicParameter = body.symbolicParameters.filter[patternParameter === parameter].head
				val variable = symbolicParameter.parameterVariable
				getOrMakeTypeDecision(result,variable,body,ecore2LogicTrace,viatra2LogicTrace,newChecked)
			]
			typeReference = calculateCommonSupertype(variableTypes)
		}
		result.addType(query,parameter,typeReference)
		return typeReference
	}
	
	private def Iterable<? extends TypeReference> getTypesFromEval(PVariable v, IBatchTypeResolver typeResolver) {
		val constraints = v.getReferringConstraintsOfType(
			typeof(ExpressionEvaluation)
		).filter[
			it.outputVariable === v
		]
		val res = constraints.map[getTypeFromEval]
		return res
	}
	
	def TypeReference getTypeFromEval(ExpressionEvaluation evaluation) {
		val XExpression expression = expressionExtractor.extractExpression(evaluation.evaluator)
		val returnType = typeResolver.resolveTypes(expression).getReturnType(expression);
		if(returnType === null || returnType instanceof UnknownTypeReference) {
			return null
		} else {
			val javaIdentifier = returnType.wrapperTypeIfPrimitive.javaIdentifier
			if(javaIdentifier == Boolean.name) {
				return LogicBool
			} else if(javaIdentifier == Integer.name || javaIdentifier == Short.name) {
				return LogicInt
			} else if(javaIdentifier == Double.name || javaIdentifier == Float.name){
				return LogicReal
			} else if(javaIdentifier == String.name) {
				return LogicString
			} else {
				throw new UnsupportedOperationException('''Unsupported eval type: "«javaIdentifier»"!''')
			}
		}
	}
	
	private def getTypeFromPassivePatternCallConstraintResult(
		PVariable v,
		Viatra2LogicTypeResult result,
		TracedOutput<LogicProblem,  Ecore2Logic_Trace> ecore2LogicTrace,
		Viatra2LogicTrace viatra2LogicTrace,
		Set<? extends PParameter> checkedInDecisionMaking
	) {
		val referringConstraints = v.referringConstraints
		
		val referringCountMatcherTargeting = referringConstraints
			.filter(PatternMatchCounter)
			.filter[it.resultVariable === v]
			.map[builder.LogicInt]
		val referringAggregatorConstraintsTargeting = referringConstraints
			.filter(AggregatorConstraint)
			.filter[it.resultVariable === v]
			.map[ // get the type of the referred column
				getOrMakeTypeDecision(
					result,
					it.referredQuery,
					it.referredQuery.parameters.get(aggregatedColumn),
					ecore2LogicTrace,
					viatra2LogicTrace,
					checkedInDecisionMaking)]
					
		return referringCountMatcherTargeting + referringAggregatorConstraintsTargeting
	}
	
	private def getTypeFromPassivePatternCallConstraints(
		PVariable v,
		Viatra2LogicTypeResult result,
		TracedOutput<LogicProblem,  Ecore2Logic_Trace> ecore2LogicTrace,
		Viatra2LogicTrace viatra2LogicTrace,
		Set<? extends PParameter> checkedInDecisionMaking
	) {
		val referringConstraints = v.referringConstraints
		if(referringConstraints.size === 1) {
			val onlyConstraint = referringConstraints.head
			
			if(onlyConstraint instanceof PatternCallBasedDeferred) {
				val indexOfVariable = v.lookup(onlyConstraint.actualParametersTuple.invertIndex)
				val parameter = onlyConstraint.referredQuery.parameters.get(indexOfVariable)
				val res = getOrMakeTypeDecision(result, onlyConstraint.referredQuery, parameter, ecore2LogicTrace,viatra2LogicTrace,checkedInDecisionMaking)
				return #[res]
			} else {
				throw new IllegalArgumentException('''A non-PatternCallBasedDeferred type constraint is referring to the variable "«v.name»"!''')
			}
		} else {
			throw new IllegalArgumentException('''Multiple («referringConstraints.size», «FOR c:referringConstraints SEPARATOR ", "»«c»«ENDFOR») constraints are referring to variable "«v.name»", but no type is inferred!''')
		}
	}
	
	def TypeReference calculateCommonSubtype(Iterable<TypeReference> types) {
		val primitiveTypeReferences = types.filter(PrimitiveTypeReference)
		val complexTypeReferences = types.filter(ComplexTypeReference)
		if(complexTypeReferences.isEmpty) {
			// If there is an int type, ...
			if(primitiveTypeReferences.exists[it instanceof IntTypeReference]) {
				// ... and all types are either real or int, then return int!
				if(primitiveTypeReferences.forall[it instanceof RealTypeReference || it instanceof IntTypeReference]) {
					return primitiveTypeReferences.filter(IntTypeReference).head
				}
				// Otherwise, the types are inconsistent, because they mixing numeric and non-numeric types.
				else throw new IllegalArgumentException('''Inconsistent types: «primitiveTypeReferences.map[eClass.name].toSet.toList»''')
			}
			// If there is no Real, then the types should be homogenious
			val head = primitiveTypeReferences.head
			if(primitiveTypeReferences.exists[it.eClass !== head.eClass]) {
				throw new IllegalArgumentException('''Inconsistent types: «primitiveTypeReferences.map[eClass.name].toSet.toList»''')
			}
			return head
		} else if(primitiveTypeReferences.isEmpty) {
			val complexTypes = complexTypeReferences.map[it.referred].toSet
			if(complexTypes.size === 1) {
				return builder.toTypeReference(complexTypes.head)
			}
			// Collect possible subtypes
			val subtypeSets = complexTypes.map[it.transitiveClosureStar[it.subtypes].toSet]
			val commonTypeSet = new HashSet(subtypeSets.head)
			val otherSets = subtypeSets.tail
			for(otherSet : otherSets) {
				commonTypeSet.retainAll(otherSet)
			}
			if(commonTypeSet.empty) {
				throw new IllegalArgumentException('''Inconsistent types: «complexTypes.map[name].toList»''')
			}
			
			
			return calculateCommonComplexSupertype(commonTypeSet)
			
		} else {
			throw new IllegalArgumentException('''
			Inconsistent types, mixing primitive and complex types:
				«primitiveTypeReferences.map[eClass.name].toSet.toList»
					and
				«complexTypeReferences.map[it.referred].toSet.map[name].toList»''')
			
		}
	}
	
	
	def TypeReference calculateCommonSupertype(Iterable<TypeReference> types) {
		val primitiveTypeReferences = types.filter(PrimitiveTypeReference)
		val complexTypeReferences = types.filter(ComplexTypeReference)
		if(complexTypeReferences.isEmpty) {
			// If there is a real type, ...
			if(primitiveTypeReferences.exists[it instanceof RealTypeReference]) {
				// ... and all types are either real or int, then return real!
				if(primitiveTypeReferences.forall[it instanceof RealTypeReference || it instanceof IntTypeReference]) {
					return primitiveTypeReferences.filter(RealTypeReference).head
				}
				// Otherwise, the types are inconsistent, because they mixing numeric and non-numeric types.
				else throw new IllegalArgumentException('''Inconsistent types: «primitiveTypeReferences.map[eClass.name].toSet.toList»''')
			}
			// If there is no Real, then the types should be homogenious
			val head = primitiveTypeReferences.head
			if(primitiveTypeReferences.exists[it.eClass !== head.eClass]) {
				throw new IllegalArgumentException('''Inconsistent types: «primitiveTypeReferences.map[eClass.name].toSet.toList»''')
			}
			return head
		} else if(primitiveTypeReferences.isEmpty) {
			val complexTypes = complexTypeReferences.map[it.referred].toSet
			return calculateCommonComplexSupertype(complexTypes)
			
		} else {
			throw new IllegalArgumentException('''
			Inconsistent types, mixing primitive and complex types:
				«primitiveTypeReferences.map[eClass.name].toSet.toList»
					and
				«complexTypeReferences.map[it.referred].toSet.map[name].toList»''')
			
		}
	}
	def TypeReference calculateCommonComplexSupertype(Set<Type> complexTypes) {
		if(complexTypes.size === 1) {
			return builder.toTypeReference(complexTypes.head)
		}
		// Collect possible supertypes
		val supertypeSets = complexTypes.map[it.transitiveClosureStar[it.supertypes].toSet]
		val commonTypeSet = new HashSet(supertypeSets.head)
		val otherSets = supertypeSets.tail
		for(otherSet : otherSets) {
			commonTypeSet.retainAll(otherSet)
		}
		if(commonTypeSet.empty) {
			throw new IllegalArgumentException('''Inconsistent types: «complexTypes.map[name].toList»''')
		}
		// Remove type that already have covered
		val coveredTypes = commonTypeSet.map[it.supertypes].flatten
		commonTypeSet.removeAll(coveredTypes)
		return builder.toTypeReference(commonTypeSet.head)
	}
	
	/**
	 * Transforms a Viatra type reference to a logic type.
	 */
	def dispatch TypeReference transformTypeReference(EDataTypeInSlotsKey k,TracedOutput<LogicProblem, Ecore2Logic_Trace> ecore2LogicTrace) {
		val w = k.wrappedKey 
		if(w == EcorePackage.Literals.EINT || w == EcorePackage.Literals.ESHORT || w == EcorePackage.Literals.ELONG) {
			return builder.LogicInt
		} else if(w == EcorePackage.Literals.EDOUBLE || w == EcorePackage.Literals.EFLOAT) {
			return builder.LogicReal
		} else if(w == EcorePackage.Literals.EBOOLEAN) {
			return builder.LogicBool
		} else if(w == EcorePackage.Literals.ESTRING) {
			return builder.LogicString
		} else if(w instanceof EEnum) {
			val c = this.ecore2Logic.TypeofEEnum(ecore2LogicTrace.trace,w)
			return builder.toTypeReference(c);
		} else throw new UnsupportedOperationException('''Unknown reference type «w.class.name»''')
	}
	def dispatch TypeReference transformTypeReference(JavaTransitiveInstancesKey k,TracedOutput<LogicProblem, Ecore2Logic_Trace> ecore2LogicTrace) {
		val c = k.wrapperInstanceClass
		if(c == Integer || c == Long || c == Short) {
			return LogicInt
		} else if(c == Float || c == Double) {
			return LogicReal
		} else if(c == Boolean) {
			return LogicBool
		} else if(c == String) {
			return LogicString
		} else if(c.superclass == Enum){
			val enums = ecore2Logic.allEnumsInScope(ecore2LogicTrace.trace)
			for(enum : enums) {
				if(c == enum.instanceClass) {
					return builder.toTypeReference(ecore2Logic.TypeofEEnum(ecore2LogicTrace.trace,enum))
				}
			}
			throw new IllegalArgumentException('''Enum type «c.simpleName» is not mapped to logic!''')
		} else {
			return null
		}
	}
	def dispatch TypeReference transformTypeReference(EClassTransitiveInstancesKey k,TracedOutput<LogicProblem, Ecore2Logic_Trace> ecore2LogicTrace) {
		val c = k.wrappedKey
				
		if(this.ecore2Logic.allClassesInScope(ecore2LogicTrace.trace).toList.contains(c)) {
			return builder.toTypeReference(this.ecore2Logic.TypeofEClass(ecore2LogicTrace.trace,k.wrappedKey))
		} else {
			return null
		}
	}
	def dispatch TypeReference transformTypeReference(EClassUnscopedTransitiveInstancesKey k, TracedOutput<LogicProblem, Ecore2Logic_Trace> ecore2LogicTrace) {
		val c = k.wrappedKey
		
		if(this.ecore2Logic.allClassesInScope(ecore2LogicTrace.trace).toList.contains(c)) {
			return builder.toTypeReference(this.ecore2Logic.TypeofEClass(ecore2LogicTrace.trace,k.wrappedKey))
		} else {
			return null
		}
	}
	
	def dispatch TypeReference transformTypeReference(IInputKey k,TracedOutput<LogicProblem, Ecore2Logic_Trace> ecore2LogicTrace) {
		throw new IllegalArgumentException('''Unsupported type: «k.class.simpleName»''')
	}
}