aboutsummaryrefslogtreecommitdiffstats
path: root/subprojects/generator-cli/src/main/java
diff options
context:
space:
mode:
authorLibravatar Kristóf Marussy <kristof@marussy.com>2024-05-31 19:22:46 +0200
committerLibravatar Kristóf Marussy <kristof@marussy.com>2024-06-01 20:17:47 +0200
commit8a62e425b77c43a541d58d4615a13cd182d12a81 (patch)
tree2f84b7d5e3f3dbc4dbda7da06ee93732418024ce /subprojects/generator-cli/src/main/java
parentfix(reasoning): candidate rounding mode (diff)
downloadrefinery-8a62e425b77c43a541d58d4615a13cd182d12a81.tar.gz
refinery-8a62e425b77c43a541d58d4615a13cd182d12a81.tar.zst
refinery-8a62e425b77c43a541d58d4615a13cd182d12a81.zip
feat: generate multiple solutions
Switch to partial interpretation based neighborhood calculation when multiple models are request to avoid returning isomorphic models.
Diffstat (limited to 'subprojects/generator-cli/src/main/java')
-rw-r--r--subprojects/generator-cli/src/main/java/tools/refinery/generator/cli/commands/GenerateCommand.java52
1 files changed, 45 insertions, 7 deletions
diff --git a/subprojects/generator-cli/src/main/java/tools/refinery/generator/cli/commands/GenerateCommand.java b/subprojects/generator-cli/src/main/java/tools/refinery/generator/cli/commands/GenerateCommand.java
index b1c28df0..0fb97316 100644
--- a/subprojects/generator-cli/src/main/java/tools/refinery/generator/cli/commands/GenerateCommand.java
+++ b/subprojects/generator-cli/src/main/java/tools/refinery/generator/cli/commands/GenerateCommand.java
@@ -1,5 +1,5 @@
1/* 1/*
2 * SPDX-FileCopyrightText: 2023 The Refinery Authors <https://refinery.tools/> 2 * SPDX-FileCopyrightText: 2023-2024 The Refinery Authors <https://refinery.tools/>
3 * 3 *
4 * SPDX-License-Identifier: EPL-2.0 4 * SPDX-License-Identifier: EPL-2.0
5 */ 5 */
@@ -17,9 +17,12 @@ import java.io.IOException;
17import java.util.ArrayList; 17import java.util.ArrayList;
18import java.util.List; 18import java.util.List;
19import java.util.Map; 19import java.util.Map;
20import java.util.regex.Pattern;
20 21
21@Parameters(commandDescription = "Generate a model from a partial model") 22@Parameters(commandDescription = "Generate a model from a partial model")
22public class GenerateCommand { 23public class GenerateCommand {
24 private static final Pattern EXTENSION_REGEX = Pattern.compile("(.+)\\.([^./\\\\]+)");
25
23 @Inject 26 @Inject
24 private ProblemLoader loader; 27 private ProblemLoader loader;
25 28
@@ -31,6 +34,7 @@ public class GenerateCommand {
31 private List<String> scopes = new ArrayList<>(); 34 private List<String> scopes = new ArrayList<>();
32 private List<String> overrideScopes = new ArrayList<>(); 35 private List<String> overrideScopes = new ArrayList<>();
33 private long randomSeed = 1; 36 private long randomSeed = 1;
37 private int count = 1;
34 38
35 @Parameter(description = "input path", required = true) 39 @Parameter(description = "input path", required = true)
36 public void setInputPath(String inputPath) { 40 public void setInputPath(String inputPath) {
@@ -57,21 +61,47 @@ public class GenerateCommand {
57 this.randomSeed = randomSeed; 61 this.randomSeed = randomSeed;
58 } 62 }
59 63
64 @Parameter(names = {"-solution-number", "-n"}, description = "Maximum number of solutions")
65 public void setCount(int count) {
66 if (count <= 0) {
67 throw new IllegalArgumentException("Count must be positive");
68 }
69 this.count = count;
70 }
71
60 public void run() throws IOException { 72 public void run() throws IOException {
73 if (count > 1 && isStandardStream(outputPath)) {
74 throw new IllegalArgumentException("Must provide output path if count is larger than 1");
75 }
61 loader.extraPath(System.getProperty("user.dir")); 76 loader.extraPath(System.getProperty("user.dir"));
62 var problem = isStandardStream(inputPath) ? loader.loadStream(System.in) : loader.loadFile(inputPath); 77 var problem = isStandardStream(inputPath) ? loader.loadStream(System.in) : loader.loadFile(inputPath);
63 problem = loader.loadScopeConstraints(problem, scopes, overrideScopes); 78 problem = loader.loadScopeConstraints(problem, scopes, overrideScopes);
79 generatorFactory.partialInterpretationBasedNeighbourhoods(count >= 2);
64 var generator = generatorFactory.createGenerator(problem); 80 var generator = generatorFactory.createGenerator(problem);
65 generator.setRandomSeed(randomSeed); 81 generator.setRandomSeed(randomSeed);
82 generator.setMaxNumberOfSolutions(count);
66 generator.generate(); 83 generator.generate();
67 var solution = generator.serializeSolution();
68 var solutionResource = solution.eResource();
69 var saveOptions = Map.of(); 84 var saveOptions = Map.of();
70 if (isStandardStream(outputPath)) { 85 if (count == 1) {
71 printSolution(solutionResource, saveOptions); 86 var solution = generator.serializeSolution();
87 var solutionResource = solution.eResource();
88 if (isStandardStream(outputPath)) {
89 printSolution(solutionResource, saveOptions);
90 } else {
91 try (var outputStream = new FileOutputStream(outputPath)) {
92 solutionResource.save(outputStream, saveOptions);
93 }
94 }
72 } else { 95 } else {
73 try (var outputStream = new FileOutputStream(outputPath)) { 96 int solutionCount = generator.getSolutionCount();
74 solutionResource.save(outputStream, saveOptions); 97 for (int i = 0; i < solutionCount; i++) {
98 generator.loadSolution(i);
99 var solution = generator.serializeSolution();
100 var solutionResource = solution.eResource();
101 var pathWithIndex = getFileNameWithIndex(outputPath, i + 1);
102 try (var outputStream = new FileOutputStream(pathWithIndex)) {
103 solutionResource.save(outputStream, saveOptions);
104 }
75 } 105 }
76 } 106 }
77 } 107 }
@@ -85,4 +115,12 @@ public class GenerateCommand {
85 private void printSolution(Resource solutionResource, Map<?, ?> saveOptions) throws IOException { 115 private void printSolution(Resource solutionResource, Map<?, ?> saveOptions) throws IOException {
86 solutionResource.save(System.out, saveOptions); 116 solutionResource.save(System.out, saveOptions);
87 } 117 }
118
119 private String getFileNameWithIndex(String simpleName, int index) {
120 var match = EXTENSION_REGEX.matcher(simpleName);
121 if (match.matches()) {
122 return "%s_%03d.%s".formatted(match.group(1), index, match.group(2));
123 }
124 return "%s_%03d".formatted(simpleName, index);
125 }
88} 126}