From 8a62e425b77c43a541d58d4615a13cd182d12a81 Mon Sep 17 00:00:00 2001 From: Kristóf Marussy Date: Fri, 31 May 2024 19:22:46 +0200 Subject: feat: generate multiple solutions Switch to partial interpretation based neighborhood calculation when multiple models are request to avoid returning isomorphic models. --- .../generator/cli/commands/GenerateCommand.java | 52 +++++++++++++++++++--- 1 file changed, 45 insertions(+), 7 deletions(-) (limited to 'subprojects/generator-cli') diff --git a/subprojects/generator-cli/src/main/java/tools/refinery/generator/cli/commands/GenerateCommand.java b/subprojects/generator-cli/src/main/java/tools/refinery/generator/cli/commands/GenerateCommand.java index b1c28df0..0fb97316 100644 --- a/subprojects/generator-cli/src/main/java/tools/refinery/generator/cli/commands/GenerateCommand.java +++ b/subprojects/generator-cli/src/main/java/tools/refinery/generator/cli/commands/GenerateCommand.java @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: 2023 The Refinery Authors + * SPDX-FileCopyrightText: 2023-2024 The Refinery Authors * * SPDX-License-Identifier: EPL-2.0 */ @@ -17,9 +17,12 @@ import java.io.IOException; import java.util.ArrayList; import java.util.List; import java.util.Map; +import java.util.regex.Pattern; @Parameters(commandDescription = "Generate a model from a partial model") public class GenerateCommand { + private static final Pattern EXTENSION_REGEX = Pattern.compile("(.+)\\.([^./\\\\]+)"); + @Inject private ProblemLoader loader; @@ -31,6 +34,7 @@ public class GenerateCommand { private List scopes = new ArrayList<>(); private List overrideScopes = new ArrayList<>(); private long randomSeed = 1; + private int count = 1; @Parameter(description = "input path", required = true) public void setInputPath(String inputPath) { @@ -57,21 +61,47 @@ public class GenerateCommand { this.randomSeed = randomSeed; } + @Parameter(names = {"-solution-number", "-n"}, description = "Maximum number of solutions") + public void setCount(int count) { + if (count <= 0) { + throw new IllegalArgumentException("Count must be positive"); + } + this.count = count; + } + public void run() throws IOException { + if (count > 1 && isStandardStream(outputPath)) { + throw new IllegalArgumentException("Must provide output path if count is larger than 1"); + } loader.extraPath(System.getProperty("user.dir")); var problem = isStandardStream(inputPath) ? loader.loadStream(System.in) : loader.loadFile(inputPath); problem = loader.loadScopeConstraints(problem, scopes, overrideScopes); + generatorFactory.partialInterpretationBasedNeighbourhoods(count >= 2); var generator = generatorFactory.createGenerator(problem); generator.setRandomSeed(randomSeed); + generator.setMaxNumberOfSolutions(count); generator.generate(); - var solution = generator.serializeSolution(); - var solutionResource = solution.eResource(); var saveOptions = Map.of(); - if (isStandardStream(outputPath)) { - printSolution(solutionResource, saveOptions); + if (count == 1) { + var solution = generator.serializeSolution(); + var solutionResource = solution.eResource(); + if (isStandardStream(outputPath)) { + printSolution(solutionResource, saveOptions); + } else { + try (var outputStream = new FileOutputStream(outputPath)) { + solutionResource.save(outputStream, saveOptions); + } + } } else { - try (var outputStream = new FileOutputStream(outputPath)) { - solutionResource.save(outputStream, saveOptions); + int solutionCount = generator.getSolutionCount(); + for (int i = 0; i < solutionCount; i++) { + generator.loadSolution(i); + var solution = generator.serializeSolution(); + var solutionResource = solution.eResource(); + var pathWithIndex = getFileNameWithIndex(outputPath, i + 1); + try (var outputStream = new FileOutputStream(pathWithIndex)) { + solutionResource.save(outputStream, saveOptions); + } } } } @@ -85,4 +115,12 @@ public class GenerateCommand { private void printSolution(Resource solutionResource, Map saveOptions) throws IOException { solutionResource.save(System.out, saveOptions); } + + private String getFileNameWithIndex(String simpleName, int index) { + var match = EXTENSION_REGEX.matcher(simpleName); + if (match.matches()) { + return "%s_%03d.%s".formatted(match.group(1), index, match.group(2)); + } + return "%s_%03d".formatted(simpleName, index); + } } -- cgit v1.2.3-70-g09d2