aboutsummaryrefslogtreecommitdiffstats
path: root/subprojects/store-query-viatra/src/main/java/tools/refinery/store/query/viatra/internal/rete/RefineryReteEngine.java
blob: c088219b2059166c7c7aebc57738a51e323c48ca (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
/*
 * SPDX-FileCopyrightText: 2023 The Refinery Authors <https://refinery.tools/>
 *
 * SPDX-License-Identifier: EPL-2.0
 */
package tools.refinery.store.query.viatra.internal.rete;

import org.apache.log4j.Logger;
import org.eclipse.viatra.query.runtime.matchers.backend.IQueryBackendFactory;
import org.eclipse.viatra.query.runtime.matchers.context.IQueryBackendContext;
import org.eclipse.viatra.query.runtime.rete.matcher.ReteEngine;
import org.eclipse.viatra.query.runtime.rete.matcher.TimelyConfiguration;
import org.eclipse.viatra.query.runtime.rete.network.Network;
import org.eclipse.viatra.query.runtime.rete.network.NodeProvisioner;
import org.eclipse.viatra.query.runtime.rete.network.ReteContainer;

import java.io.IOException;
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;

/**
 * A {@link ReteEngine} that replaces VIATRA's internal node factory and connection factories with Refinery-specific
 * implementations ({@code RefineryNodeFactory} and {@code RefineryConnectionFactory}).
 * <p>
 * The replacement classes are read from {@code .class} resources located next to {@link Network} and are defined
 * into {@link Network}'s package through a private {@link MethodHandles.Lookup}. The factories are then written into
 * the {@code nodeFactory} / {@code connectionFactory} fields of {@link Network}, {@link ReteContainer} and
 * {@link NodeProvisioner} via method-handle field setters. NOTE(review): this presumably works around those fields and
 * factory base types not being accessible from outside the VIATRA package — confirm against the VIATRA version in use.
 */
public class RefineryReteEngine extends ReteEngine {
	private static final MethodHandle REFINERY_NODE_FACTORY_CONSTRUCTOR;
	private static final MethodHandle REFINERY_CONNECTION_FACTORY_CONSTRUCTOR;
	private static final MethodHandle NETWORK_NODE_FACTORY_SETTER;
	private static final MethodHandle RETE_CONTAINER_CONNECTION_FACTORY_SETTER;
	private static final MethodHandle NODE_PROVISIONER_NODE_FACTORY_SETTER;
	private static final MethodHandle NODE_PROVISIONER_CONNECTION_FACTORY_SETTER;

	static {
		MethodHandles.Lookup lookup;
		try {
			// A private lookup in Network's package is needed both to define new classes in that package and to
			// obtain setters for the (non-public) factory fields below.
			lookup = MethodHandles.privateLookupIn(Network.class, MethodHandles.lookup());
		} catch (IllegalAccessException e) {
			throw new IllegalStateException("Cannot create private lookup", e);
		}
		var refineryNodeFactoryClass = defineClassFromFile(lookup, "RefineryNodeFactory");
		var refineryConnectionFactoryClass = defineClassFromFile(lookup, "RefineryConnectionFactory");
		try {
			REFINERY_NODE_FACTORY_CONSTRUCTOR = lookup.findConstructor(refineryNodeFactoryClass,
					MethodType.methodType(void.class, Logger.class));
			REFINERY_CONNECTION_FACTORY_CONSTRUCTOR = lookup.findConstructor(refineryConnectionFactoryClass,
					MethodType.methodType(void.class, ReteContainer.class));
		} catch (NoSuchMethodException | IllegalAccessException e) {
			throw new IllegalStateException("Cannot get constructor", e);
		}
		// findSetter requires the exact declared field type, so the superclass of each Refinery factory must be the
		// declared type of the field it replaces.
		var nodeFactoryClass = refineryNodeFactoryClass.getSuperclass();
		var connectionFactoryClass = refineryConnectionFactoryClass.getSuperclass();
		try {
			NETWORK_NODE_FACTORY_SETTER = lookup.findSetter(Network.class, "nodeFactory", nodeFactoryClass);
			RETE_CONTAINER_CONNECTION_FACTORY_SETTER = lookup.findSetter(ReteContainer.class, "connectionFactory",
					connectionFactoryClass);
			NODE_PROVISIONER_NODE_FACTORY_SETTER = lookup.findSetter(NodeProvisioner.class, "nodeFactory",
					nodeFactoryClass);
			NODE_PROVISIONER_CONNECTION_FACTORY_SETTER = lookup.findSetter(NodeProvisioner.class, "connectionFactory",
					connectionFactoryClass);
		} catch (NoSuchFieldException | IllegalAccessException e) {
			throw new IllegalStateException("Cannot get field setter", e);
		}
	}

	/**
	 * Reads {@code <name>.class} from the resources next to {@link Network} and defines it in the lookup's package.
	 *
	 * @param lookup private lookup whose package the class bytes must belong to
	 * @param name   simple name of the class file, without the {@code .class} extension
	 * @return the newly defined class
	 * @throws IllegalStateException if the resource is missing or unreadable, or the class cannot be defined
	 */
	private static Class<?> defineClassFromFile(MethodHandles.Lookup lookup, String name) {
		byte[] classBytes;
		try (var resource = Network.class.getResourceAsStream(name + ".class")) {
			if (resource == null) {
				throw new IllegalStateException("Cannot find %s class file".formatted(name));
			}
			classBytes = resource.readAllBytes();
		} catch (IOException e) {
			throw new IllegalStateException("Cannot read %s class file".formatted(name), e);
		}
		try {
			return lookup.defineClass(classBytes);
		} catch (IllegalAccessException e) {
			throw new IllegalStateException("Cannot define %s class".formatted(name), e);
		}
	}

	/**
	 * Creates the engine via the {@link ReteEngine} constructor and then installs the Refinery factories into the
	 * freshly created network and all of its containers.
	 */
	public RefineryReteEngine(IQueryBackendContext context, int reteThreads, boolean deleteAndReDeriveEvaluation,
							  TimelyConfiguration timelyConfiguration) {
		super(context, reteThreads, deleteAndReDeriveEvaluation, timelyConfiguration);
		installFactories();
	}

	/**
	 * Replaces the node factory of {@code reteNet}, and the connection and node factories of every container (and its
	 * provisioner) returned by {@code reteNet.getContainers()}. A single node factory instance is shared by the network
	 * and all provisioners; each container gets its own connection factory. Only containers that exist at the time of
	 * this call are patched.
	 */
	private void installFactories() {
		var nodeFactory = invokeConstructor(REFINERY_NODE_FACTORY_CONSTRUCTOR, getLogger(),
				"Cannot construct node factory");
		invokeSetter(NETWORK_NODE_FACTORY_SETTER, reteNet, nodeFactory);
		for (var container : reteNet.getContainers()) {
			var connectionFactory = invokeConstructor(REFINERY_CONNECTION_FACTORY_CONSTRUCTOR, container,
					"Cannot construct connection factory");
			var provisioner = container.getProvisioner();
			invokeSetter(RETE_CONTAINER_CONNECTION_FACTORY_SETTER, container, connectionFactory);
			invokeSetter(NODE_PROVISIONER_NODE_FACTORY_SETTER, provisioner, nodeFactory);
			invokeSetter(NODE_PROVISIONER_CONNECTION_FACTORY_SETTER, provisioner, connectionFactory);
		}
	}

	/**
	 * Invokes a single-argument constructor handle.
	 *
	 * @param constructor    the constructor handle
	 * @param argument       the constructor argument
	 * @param failureMessage message of the {@link IllegalStateException} thrown on failure
	 * @return the constructed object
	 * @throws IllegalStateException wrapping any non-{@link Error} throwable raised by the invocation
	 */
	private static Object invokeConstructor(MethodHandle constructor, Object argument, String failureMessage) {
		try {
			return constructor.invoke(argument);
		} catch (Error e) {
			// Fatal JVM errors should not be wrapped.
			throw e;
		} catch (Throwable e) {
			throw new IllegalStateException(failureMessage, e);
		}
	}

	/**
	 * Invokes a field setter handle on {@code target}.
	 *
	 * @param setter the field setter handle
	 * @param target the object whose field is written
	 * @param value  the new field value
	 * @throws IllegalStateException wrapping any non-{@link Error} throwable raised by the invocation
	 */
	private static void invokeSetter(MethodHandle setter, Object target, Object value) {
		try {
			setter.invoke(target, value);
		} catch (Error e) {
			// Fatal JVM errors should not be wrapped.
			throw e;
		} catch (Throwable e) {
			throw new IllegalStateException("Cannot set factory", e);
		}
	}

	/**
	 * Returns {@link RefineryReteBackendFactory#INSTANCE} in place of the factory reported by {@link ReteEngine}.
	 */
	@Override
	public IQueryBackendFactory getFactory() {
		return RefineryReteBackendFactory.INSTANCE;
	}
}