diff --git a/commons/src/main/java/org/apache/causeway/commons/graph/GraphUtils.java b/commons/src/main/java/org/apache/causeway/commons/graph/GraphUtils.java index b701c2ef29a..9a8ccf22adc 100644 --- a/commons/src/main/java/org/apache/causeway/commons/graph/GraphUtils.java +++ b/commons/src/main/java/org/apache/causeway/commons/graph/GraphUtils.java @@ -26,10 +26,10 @@ import java.util.Map; import java.util.Objects; import java.util.Optional; -import java.util.TreeMap; import java.util.function.BiPredicate; import java.util.function.Consumer; import java.util.function.Function; +import java.util.function.Predicate; import java.util.stream.Collectors; import java.util.stream.IntStream; @@ -40,6 +40,7 @@ import org.apache.causeway.commons.functional.IndexedConsumer; import org.apache.causeway.commons.graph.GraphUtils.GraphKernel.GraphCharacteristic; import org.apache.causeway.commons.internal.assertions._Assert; +import org.apache.causeway.commons.internal.base._Casts; import org.apache.causeway.commons.internal.collections._PrimitiveCollections.IntList; import org.apache.causeway.commons.internal.primitives._Longs; @@ -361,7 +362,43 @@ public void visitNeighborsIndexed(final int nodeIndex, * Returns an isomorphic graph with this graph's nodes replaced by given mapping function. */ public Graph map(final Function nodeMapper) { - return new Graph(kernel, nodes.map(nodeMapper), edgeAttributeByPackedEdgeIndex()); + var graph = new Graph(kernel, nodes.map(nodeMapper), edgeAttributeByPackedEdgeIndex()); + _Assert.assertEquals(kernel.nodeCount(), graph.nodes().size()); + return graph; + } + + /** + * Returns a sub-graph with any nodes removed from this graph, that do not pass the filter. + */ + public Graph filter(final Predicate filter) { + if(nodes.isEmpty()) return this; + + var nodeType = _Casts.>uncheckedCast(nodes.getFirst().get().getClass()); + var builder = new GraphBuilder(nodeType, kernel().characteristics); + var isUndirected = kernel().isUndirected(); + + nodes.forEach(IndexedConsumer.zeroBased((nodeIndex, node)->{ + if(filter.test(node)) { + builder.addNode(node); + Graph.this.visitNeighborsIndexed(nodeIndex, (neighborIndex, neighbor)->{ + if(isUndirected + && neighborIndex { private final Class nodeType; private final ImmutableEnumSet characteristics; private final boolean isUndirected; + private final Map indexByNode; private final List nodeList; private final IntList fromNode = new IntList(4); // best guess initial edge capacity private final IntList toNode = new IntList(4); // best guess initial edge capacity @@ -484,8 +522,9 @@ public static GraphBuilder undirected(final Class nodeType) { } /** - * Adds a new node to the graph. - * @apiNote nodes are not required to be unique with respect to {@link Objects#equals}. + * Adds a new node to the graph, respecting node equality, that is, + * no duplicates are added. + * @apiNote duplicates with respect to {@link Objects#equals} are not added */ public GraphBuilder addNode(final @NonNull T node) { addNodeHonoringIndexMap(node); @@ -559,9 +598,8 @@ private GraphBuilder(final Class nodeType, final ImmutableEnumSet(); - //XXX map implementation is not required to be ordered, could use a HashMap here as well. - // This is purely a performance question! - this.edgeAttributeByPackedEdgeIndex = new TreeMap<>(); + this.indexByNode = new HashMap<>(); + this.edgeAttributeByPackedEdgeIndex = new HashMap<>(); } public Graph build() { @@ -604,7 +642,12 @@ private int indexOfWithAdd(final T node) { } private int addNodeHonoringIndexMap(final T node) { + final Integer nodeIndex = indexByNode.get(node); + // skip adding if the node is a duplicate + if(nodeIndex!=null) return nodeIndex; + final int nextIndex = nodeList.size(); + indexByNode.put(node, nextIndex); nodeList.add(node); if(nodeIndexByNode!=null) { nodeIndexByNode.put(node, nextIndex); diff --git a/commons/src/test/java/org/apache/causeway/commons/graph/GraphUtilsTest.java b/commons/src/test/java/org/apache/causeway/commons/graph/GraphUtilsTest.java index 4806bba54f3..68ab62f0724 100644 --- a/commons/src/test/java/org/apache/causeway/commons/graph/GraphUtilsTest.java +++ b/commons/src/test/java/org/apache/causeway/commons/graph/GraphUtilsTest.java @@ -24,6 +24,7 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertLinesMatch; import org.springframework.util.StringUtils; @@ -89,6 +90,38 @@ void builderUndirected() { TextUtils.readLines(textForm).filter(StringUtils::hasLength)); } + @Test + void nodeEqualityDirected() { + var graph = GraphUtils.GraphBuilder.directed(Customer.class) + .addNode(new Customer("A")) + .addNode(new Customer("B")) + .addNode(new Customer("A")) + .addEdge(0, 1) + .build(); + + //debug + //System.err.println(graph.toString(Customer::getName)); + + assertEquals(2, + graph.nodes().size()); + } + + @Test + void nodeEqualityUndirected() { + var graph = GraphUtils.GraphBuilder.undirected(Customer.class) + .addNode(new Customer("A")) + .addNode(new Customer("B")) + .addNode(new Customer("A")) + .addEdge(0, 1) + .build(); + + //debug + //System.err.println(graph.toString(Customer::getName)); + + assertEquals(2, + graph.nodes().size()); + } + @Test void builderWithEdgeAttributes() { var gBuilder = GraphUtils.GraphBuilder.undirected(Customer.class); @@ -172,4 +205,54 @@ void kernelSubgraph() { assertEquals(2, subgraph3.edgeCount()); } + @Test + void filterDirectedGraph() { + var graph = GraphUtils.GraphBuilder.directed(Customer.class) + .addNode(new Customer("A")) + .addNode(new Customer("B")) + .addNode(new Customer("C")) + .addNode(new Customer("D")) + .addEdge(0, 1, 0.1) // A -> B (weight=0.1) + .addEdge(1, 2) // B -> C + .addEdge(2, 0, 0.7) // C -> A (weight=0.7) + .build() + .filter(node->!node.getName().equals("C")); // now remove C + + var textForm = graph.toString(Customer::getName); + + //debug + //System.err.println(textForm); + + assertLinesMatch( + Can.of("A -> B (0.1)", "B", "D").toList(), + TextUtils.readLines(textForm).filter(StringUtils::hasLength).toList()); + } + + @Test + void filterUndirectedGraph() { + var a = new Customer("A"); + var b = new Customer("B"); + var c = new Customer("C"); + var d = new Customer("D"); + + var graph = GraphUtils.GraphBuilder.undirected(Customer.class) + .addEdge(a, b, 0.1) // A - B (weight=0.1) + .addEdge(c, a) // A - C + .addEdge(c, b, 0.7) // B - C (weight=0.7) + .addNode(d) + .build() + .filter(node->!node.getName().equals("C")) // now remove C + ; + + var textForm = graph.toString( + NodeFormatter.of(Customer::getName), + edgeAttr->String.format("(weight=%.1f)", (double)edgeAttr)); + //debug + //System.err.println(textForm); + + assertLinesMatch( + Can.of("A - B (weight=0.1)", "D").toList(), + TextUtils.readLines(textForm).filter(StringUtils::hasLength).toList()); + } + }