aboutsummaryrefslogtreecommitdiffstats
path: root/Solvers/VIATRA-Solver/hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra/src/hu/bme/mit/inf/dslreasoner/viatrasolver/logic2viatra/cardinality/ExtendedLinearExpressionBuilderFactory.xtend
blob: 6054affe85b8c314c4653a034a9abcf06f233357 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
package hu.bme.mit.inf.dslreasoner.viatrasolver.logic2viatra.cardinality

import com.google.common.collect.ImmutableList
import com.google.common.collect.ImmutableMap
import hu.bme.mit.inf.dslreasoner.logic.model.logiclanguage.Type
import java.util.ArrayList
import java.util.HashMap
import java.util.HashSet
import java.util.List
import java.util.Map
import java.util.Set
import org.eclipse.viatra.query.runtime.api.IPatternMatch
import org.eclipse.xtend.lib.annotations.Accessors
import org.eclipse.xtend.lib.annotations.FinalFieldsConstructor

/**
 * Callback attached to an expression through
 * {@link ExtendedLinearExpressionBuilder#build(BoundSaturationListener)}; such expressions are
 * also added to the polyhedron's expressions to saturate.
 * <p>
 * NOTE(review): the code that invokes {@code boundsSaturated} is not in this file. Presumably
 * it is called with the lower and upper bound computed for the registered expression, with
 * {@code null} meaning "no bound" — confirm against the polyhedron solver.
 */
interface BoundSaturationListener {
	def void boundsSaturated(Integer lower, Integer upper)
}

/**
 * Creates {@link ExtendedLinearExpressionBuilder}s and provides the {@link Dimension}
 * associated with a pattern match.
 */
interface ExtendedLinearExpressionBuilderFactory {
	/**
	 * Returns a new expression builder (in {@link ExtendedPolyhedronBuilder}: a fresh builder
	 * with no coefficients).
	 */
	def ExtendedLinearExpressionBuilder createBuilder()

	/**
	 * Returns the dimension associated with {@code patternMatch}.
	 * <p>
	 * In {@link ExtendedPolyhedronBuilder} the dimension is created on the first request and
	 * the same instance is returned for later requests with an equal match.
	 */
	def Dimension getDimension(IPatternMatch patternMatch)
}

/**
 * A {@link LinearTypeExpressionBuilder} that, besides types, can accumulate pattern matches
 * and raw dimensions into a linear expression.
 * <p>
 * The {@code add} methods return a builder (the implementation in this file returns itself)
 * so calls can be chained.
 */
interface ExtendedLinearExpressionBuilder extends LinearTypeExpressionBuilder {
	/** Adds {@code scale} times the expression associated with {@code type}. */
	override ExtendedLinearExpressionBuilder add(int scale, Type type)

	/** Adds {@code scale} times the dimension associated with {@code patternMatch}. */
	def ExtendedLinearExpressionBuilder add(int scale, IPatternMatch patternMatch)

	/** Adds {@code scale} to the coefficient of {@code dimension}. */
	def ExtendedLinearExpressionBuilder add(int scale, Dimension dimension)

	/**
	 * Builds the accumulated expression and, when {@code listener} is not {@code null},
	 * registers it for bound saturation with that listener.
	 */
	def LinearBoundedExpression build(BoundSaturationListener listener)
}

/**
 * Builds an extended {@link Polyhedron}. It starts from the dimensions, constraints and
 * expressions to saturate of an existing polyhedron. {@link ExtendedLinearExpressionBuilder}s
 * created by {@link #createBuilder()} can then add pattern-match dimensions, linear
 * constraints and saturation requests to it.
 * <p>
 * Constraints and expressions to saturate are kept in insertion-ordered sets. As a result,
 * {@link #buildPolyhedron()} emits them in a deterministic order: the elements copied from the
 * initial polyhedron come first, followed by the new ones in the order they were added. The
 * order therefore does not depend on hash codes.
 */
class ExtendedPolyhedronBuilder implements ExtendedLinearExpressionBuilderFactory {
	/** Expression registered for each type; {@code add(int, Type)} rejects types missing here. */
	val Map<Type, LinearBoundedExpression> typeBounds
	/** Expressions already built, keyed by their non-zero coefficients. */
	val Map<Map<Dimension, Integer>, LinearBoundedExpression> expressionsCache

	val ImmutableList.Builder<Dimension> dimensions = ImmutableList.builder
	// Insertion-ordered (instead of HashSet) so that buildPolyhedron() output is reproducible.
	val Set<LinearConstraint> constraints = newLinkedHashSet()
	val Set<LinearBoundedExpression> expressionsToSaturate = newLinkedHashSet()
	/** Dimension created by getDimension for each pattern match. */
	val Map<IPatternMatch, Dimension> patternMatchCounts = new HashMap
	/** (expression, listener) pairs registered through build(BoundSaturationListener), in registration order. */
	@Accessors(PUBLIC_GETTER) val List<Pair<LinearBoundedExpression, BoundSaturationListener>> saturationListeners = new ArrayList

	/**
	 * @param polyhedron polyhedron whose dimensions, constraints and expressions to saturate are copied
	 * @param typeBounds expression used for each type by {@code add(int, Type)}; kept by reference
	 * @param initialExpressionsCache initial expression cache; copied, so the argument is never modified
	 */
	new(Polyhedron polyhedron, Map<Type, LinearBoundedExpression> typeBounds,
		Map<Map<Dimension, Integer>, LinearBoundedExpression> initialExpressionsCache) {
		this.typeBounds = typeBounds
		this.expressionsCache = new HashMap(initialExpressionsCache)
		dimensions.addAll(polyhedron.dimensions)
		constraints.addAll(polyhedron.constraints)
		expressionsToSaturate.addAll(polyhedron.expressionsToSaturate)
	}

	/** Returns a new, empty expression builder that writes into this polyhedron builder. */
	override createBuilder() {
		new ExtendedLinearExpressionBuilderImpl(this)
	}

	/**
	 * Returns the dimension for {@code patternMatch}. On the first request for a match, a new
	 * dimension is created: it is named after {@code patternMatch.toString}, has lower bound 0
	 * and a {@code null} upper bound, and is appended to the dimensions of the polyhedron.
	 */
	override getDimension(IPatternMatch patternMatch) {
		patternMatchCounts.computeIfAbsent(patternMatch) [ key |
			val dimension = new Dimension(key.toString, 0, null)
			dimensions.add(dimension)
			dimension
		]
	}

	/**
	 * Creates a {@link Polyhedron} from the current dimensions, constraints and expressions to
	 * saturate. Each call snapshots the current state, so it may be called more than once.
	 */
	def buildPolyhedron() {
		new Polyhedron(
			dimensions.build,
			ImmutableList.copyOf(constraints),
			ImmutableList.copyOf(expressionsToSaturate)
		)
	}

	/** Accumulates integer coefficients per dimension and turns them into a linear expression. */
	@FinalFieldsConstructor
	private static class ExtendedLinearExpressionBuilderImpl implements ExtendedLinearExpressionBuilder {
		val ExtendedPolyhedronBuilder polyhedronBuilder

		/** Coefficient accumulated so far for each dimension; may contain zeros. */
		val Map<Dimension, Integer> coefficients = new HashMap

		/**
		 * Adds {@code scale} times the bound expression registered for {@code type}.
		 *
		 * @throws IllegalArgumentException if no expression is registered for {@code type}
		 */
		override add(int scale, Type type) {
			val expression = polyhedronBuilder.typeBounds.get(type)
			if (expression === null) {
				throw new IllegalArgumentException("Unknown Type: " + type)
			}
			add(scale, expression)
		}

		/** Adds {@code scale} times the dimension of {@code patternMatch}, creating it if needed. */
		override add(int scale, IPatternMatch patternMatch) {
			val dimension = polyhedronBuilder.getDimension(patternMatch)
			add(scale, dimension)
		}

		/**
		 * Adds {@code scale} times an arbitrary expression: a {@link Dimension} directly, a
		 * {@link LinearConstraint} by expanding its coefficients.
		 *
		 * @throws IllegalArgumentException for any other kind of expression
		 */
		private def add(int scale, LinearBoundedExpression expression) {
			switch (expression) {
				Dimension: add(scale, expression)
				LinearConstraint: add(scale, expression.coefficients)
				default: throw new IllegalArgumentException("Unknown LinearBoundedExpression: " + expression)
			}
		}

		/** Adds every (dimension, coefficient) entry of {@code coefficients}, multiplied by {@code scale}. */
		private def add(int scale, Map<Dimension, Integer> coefficients) {
			for (pair : coefficients.entrySet) {
				add(scale * pair.value, pair.key)
			}
			this
		}

		/** Adds {@code scale} to the coefficient of {@code dimension}. */
		override add(int scale, Dimension dimension) {
			coefficients.merge(dimension, scale)[a, b|a + b]
			this
		}

		/**
		 * Builds the expression for the non-zero accumulated coefficients.
		 * <p>
		 * Expressions are cached by their coefficient map, so identical requests return the same
		 * object. A single dimension with coefficient 1 is returned as the dimension itself.
		 * Otherwise a new {@link LinearConstraint} is created and registered in the polyhedron's
		 * constraints.
		 */
		override build() {
			val filteredCoefficients = ImmutableMap.copyOf(coefficients.filter [ _, coefficient |
				coefficient != 0
			])
			polyhedronBuilder.expressionsCache.computeIfAbsent(filteredCoefficients) [ map |
				if (map.size == 1) {
					val pair = map.entrySet.head
					if (pair.value == 1) {
						// The expression is exactly one dimension: reuse it instead of wrapping it.
						return pair.key
					}
				}
				val constraint = new LinearConstraint(map)
				polyhedronBuilder.constraints.add(constraint)
				constraint
			]
		}

		/**
		 * Builds the expression as {@link #build()} does. When {@code listener} is not
		 * {@code null}, the expression is also marked for saturation and the
		 * (expression, listener) pair is recorded in {@code saturationListeners}.
		 */
		override build(BoundSaturationListener listener) {
			val expression = build()
			if (listener !== null) {
				polyhedronBuilder.expressionsToSaturate.add(expression)
				polyhedronBuilder.saturationListeners.add(expression -> listener)
			}
			expression
		}
	}
}