Reduce likelihood of multiple coordinators on concurrent startup

Closes #41290

Signed-off-by: Alexander Schwartz <aschwart@redhat.com>
Signed-off-by: Alexander Schwartz <alexander.schwartz@gmx.net>
Co-authored-by: Pedro Ruivo <pruivo@users.noreply.github.com>
This commit is contained in:
Alexander Schwartz 2025-08-04 13:41:46 +02:00 committed by GitHub
parent 4e699e10da
commit 1b5e05c8f5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 447 additions and 3 deletions

View File

@ -106,6 +106,11 @@
<artifactId>microprofile-metrics-api</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>com.h2database</groupId>
<artifactId>h2</artifactId>
<scope>test</scope>
</dependency>
</dependencies>
<build>
@ -142,6 +147,14 @@
</annotationProcessorPaths>
</configuration>
</plugin>
<plugin>
<artifactId>maven-surefire-plugin</artifactId>
<configuration>
<systemPropertyVariables>
<java.util.logging.manager>org.jboss.logmanager.LogManager</java.util.logging.manager>
</systemPropertyVariables>
</configuration>
</plugin>
</plugins>
</build>
</project>

View File

@ -17,13 +17,25 @@
package org.keycloak.jgroups.protocol;
import org.jgroups.Address;
import org.jgroups.Event;
import org.jgroups.PhysicalAddress;
import org.jgroups.View;
import org.jgroups.protocols.JDBC_PING2;
import org.jgroups.protocols.PingData;
import org.jgroups.stack.Protocol;
import org.jgroups.util.ExtendedUUID;
import org.jgroups.util.NameCache;
import org.jgroups.util.Responses;
import org.jgroups.util.UUID;
import org.keycloak.connections.jpa.JpaConnectionProviderFactory;
import java.sql.Connection;
import java.sql.SQLException;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
/**
* Enhanced JDBC_PING2 to handle entries transactionally.
@ -34,6 +46,108 @@ public class KEYCLOAK_JDBC_PING2 extends JDBC_PING2 {
private JpaConnectionProviderFactory factory;
@Override
protected void handleView(View new_view, View old_view, boolean coord_changed) {
// If we are the coordinator, it is good to learn about new entries that have been added before we delete them.
// If we are not the coordinator, it is good to learn the new entries added by the coordinator.
// This avoids a "JGRP000032: %s: no physical address for %s, dropping message" that leads to split clusters at concurrent startup.
learnExistingAddresses();
// This is an updated logic where we do not call removeAll but instead remove those obsolete entries.
// This avoids the short moment where the table is empty and a new node might not see any other node.
if (is_coord) {
if (remove_old_coords_on_view_change) {
Address old_coord = old_view != null ? old_view.getCreator() : null;
if (old_coord != null)
remove(cluster_name, old_coord);
}
Address[] left = View.diff(old_view, new_view)[1];
if (coord_changed || update_store_on_view_change || left.length > 0) {
writeAll(left);
if (remove_all_data_on_view_change) {
removeAllNotInCurrentView();
}
if (remove_all_data_on_view_change || remove_old_coords_on_view_change) {
startInfoWriter();
}
}
} else if (coord_changed && !remove_all_data_on_view_change) {
// I'm no longer the coordinator, usually due to a merge.
// The new coordinator will update my status to non-coordinator, and remove me fully
// if 'remove_all_data_on_view_change' is enabled and I'm no longer part of the view.
// Maybe this branch even be removed completely, but for JDBC_PING 'remove_all_data_on_view_change' is always set to true.
PhysicalAddress physical_addr = (PhysicalAddress) down(new Event(Event.GET_PHYSICAL_ADDRESS, local_addr));
PingData coord_data = new PingData(local_addr, true, NameCache.get(local_addr), physical_addr).coord(is_coord);
write(Collections.singletonList(coord_data), cluster_name);
}
}
@Override
protected void removeAll(String clustername) {
// This is unsafe as even if we would fill the table a moment later, a new node might see an empty table and become a coordinator
throw new RuntimeException("Not implemented as it is unsafe");
}
private void removeAllNotInCurrentView() {
try {
List<PingData> list = readFromDB(getClusterName());
for (PingData data : list) {
Address addr = data.getAddress();
if (view != null && !view.containsMember(addr)) {
addDiscoveryResponseToCaches(addr, data.getLogicalName(), data.getPhysicalAddr());
remove(cluster_name, addr);
}
}
} catch (Exception e) {
log.error(String.format("%s: failed reading from the DB", local_addr), e);
}
}
protected void learnExistingAddresses() {
try {
List<PingData> list = readFromDB(getClusterName());
for (PingData data : list) {
Address addr = data.getAddress();
if (local_addr != null && !local_addr.equals(addr)) {
addDiscoveryResponseToCaches(addr, data.getLogicalName(), data.getPhysicalAddr());
}
}
} catch (Exception e) {
log.error(String.format("%s: failed reading from the DB", local_addr), e);
}
}
@Override
public synchronized boolean isInfoWriterRunning() {
// Do not rely on the InfoWriter, instead always write the missing information on find if it is missing. Find is also triggered by MERGE.
return false;
}
@Override
public void findMembers(List<Address> members, boolean initial_discovery, Responses responses) {
if (initial_discovery) {
try {
List<PingData> pingData = readFromDB(cluster_name);
PhysicalAddress physical_addr = (PhysicalAddress) down(new Event(Event.GET_PHYSICAL_ADDRESS, local_addr));
PingData coord_data = new PingData(local_addr, true, NameCache.get(local_addr), physical_addr).coord(is_coord);
write(Collections.singletonList(coord_data), cluster_name);
while (pingData.stream().noneMatch(PingData::isCoord)) {
// Do a quick check if more nodes have arrived, to have a more complete list of nodes to start with.
List<PingData> newPingData = readFromDB(cluster_name);
if (newPingData.stream().map(PingData::getAddress).collect(Collectors.toSet()).equals(pingData.stream().map(PingData::getAddress).collect(Collectors.toSet()))
|| pingData.stream().anyMatch(PingData::isCoord)) {
break;
}
pingData = newPingData;
}
} catch (Exception e) {
log.error(String.format("%s: failed reading from the DB", local_addr), e);
}
}
super.findMembers(members, initial_discovery, responses);
}
@Override
protected void writeToDB(PingData data, String clustername) throws SQLException {
lock.lock();
@ -70,6 +184,51 @@ public class KEYCLOAK_JDBC_PING2 extends JDBC_PING2 {
}
/* START: JDBC_PING2 does not handle ExtendedUUID yet, see
https://github.com/belaban/JGroups/pull/901 - until this is backported, we convert all of them.
*/
@Override
public <T extends Protocol> T addr(Address addr) {
addr = toUUID(addr);
return super.addr(addr);
}
@Override
public <T extends Protocol> T setAddress(Address addr) {
addr = toUUID(addr);
return super.setAddress(addr);
}
@Override
protected void delete(Connection conn, String clustername, Address addressToDelete) throws SQLException {
super.delete(conn, clustername, toUUID(addressToDelete));
}
@Override
protected void delete(String clustername, Address addressToDelete) throws SQLException {
super.delete(clustername, toUUID(addressToDelete));
}
@Override
protected void insert(Connection connection, PingData data, String clustername) throws SQLException {
if (data.getAddress() instanceof ExtendedUUID) {
data = new PingData(toUUID(data.getAddress()), data.isServer(), data.getLogicalName(), data.getPhysicalAddr()).coord(data.isCoord());
}
super.insert(connection, data, clustername);
}
private static Address toUUID(Address addr) {
if (addr instanceof ExtendedUUID eUUID) {
addr = new UUID(eUUID.getMostSignificantBits(), eUUID.getLeastSignificantBits());
}
return addr;
}
/* END: JDBC_PING2 does not handle ExtendedUUID yet, see
https://github.com/belaban/JGroups/pull/901 - until this is backported, we convert all of them.
*/
@Override
protected void loadDriver() {
//no-op, using JpaConnectionProviderFactory

View File

@ -17,6 +17,9 @@
package org.keycloak.spi.infinispan.impl.embedded;
import static org.infinispan.configuration.global.TransportConfiguration.CLUSTER_NAME;
import static org.infinispan.configuration.global.TransportConfiguration.STACK;
import java.lang.invoke.MethodHandles;
import java.security.KeyManagementException;
import java.security.NoSuchAlgorithmException;
@ -25,12 +28,15 @@ import java.util.List;
import java.util.Map;
import java.util.Objects;
import jakarta.persistence.Query;
import org.infinispan.commons.configuration.attributes.Attribute;
import org.infinispan.configuration.global.TransportConfigurationBuilder;
import org.infinispan.configuration.parsing.ConfigurationBuilderHolder;
import org.infinispan.remoting.transport.jgroups.EmbeddedJGroupsChannelConfigurator;
import org.infinispan.remoting.transport.jgroups.JGroupsTransport;
import org.jboss.logging.Logger;
import org.jgroups.Address;
import org.jgroups.JChannel;
import org.jgroups.conf.ClassConfigurator;
import org.jgroups.conf.ProtocolConfiguration;
import org.jgroups.protocols.TCP;
@ -38,8 +44,11 @@ import org.jgroups.protocols.TCP_NIO2;
import org.jgroups.protocols.UDP;
import org.jgroups.stack.Protocol;
import org.jgroups.util.DefaultSocketFactory;
import org.jgroups.util.ExtendedUUID;
import org.jgroups.util.SocketFactory;
import org.jgroups.util.UUID;
import org.keycloak.Config;
import org.keycloak.common.util.Retry;
import org.keycloak.connections.infinispan.InfinispanConnectionProvider;
import org.keycloak.connections.infinispan.InfinispanConnectionSpi;
import org.keycloak.connections.jpa.JpaConnectionProvider;
@ -48,6 +57,7 @@ import org.keycloak.connections.jpa.util.JpaUtils;
import org.keycloak.infinispan.util.InfinispanUtils;
import org.keycloak.jgroups.protocol.KEYCLOAK_JDBC_PING2;
import org.keycloak.models.KeycloakSession;
import org.keycloak.models.utils.KeycloakModelUtils;
import org.keycloak.spi.infinispan.JGroupsCertificateProvider;
import javax.net.ssl.KeyManager;
@ -56,7 +66,7 @@ import javax.net.ssl.SSLParameters;
import javax.net.ssl.SSLServerSocket;
import javax.net.ssl.TrustManager;
import static org.infinispan.configuration.global.TransportConfiguration.STACK;
import org.keycloak.storage.configuration.ServerConfigStorageProvider;
/**
* Utility class to configure JGroups based on the Keycloak configuration.
@ -66,6 +76,7 @@ public final class JGroupsConfigurator {
private static final Logger logger = Logger.getLogger(MethodHandles.lookup().lookupClass());
private static final String TLS_PROTOCOL_VERSION = "TLSv1.3";
private static final String TLS_PROTOCOL = "TLS";
public static final String JGROUPS_ADDRESS_SEQUENCE = "JGROUPS_ADDRESS_SEQUENCE";
private JGroupsConfigurator() {
}
@ -176,12 +187,49 @@ public final class JGroupsConfigurator {
var tableName = JpaUtils.getTableNameForNativeQuery("JGROUPS_PING", em);
var stack = getProtocolConfigurations(tableName, isUdp);
var connectionFactory = (JpaConnectionProviderFactory) session.getKeycloakSessionFactory().getProviderFactory(JpaConnectionProvider.class);
holder.addJGroupsStack(new JpaFactoryAwareJGroupsChannelConfigurator(stackName, stack, connectionFactory, isUdp), null);
String clusterName = transportOf(holder).attributes().attribute(CLUSTER_NAME).get();
Address address = Retry.call(ignored -> KeycloakModelUtils.runJobInTransactionWithResult(session.getKeycloakSessionFactory(),
s -> prepareJGroupsAddress(s, clusterName)),
50, 10);
holder.addJGroupsStack(new JpaFactoryAwareJGroupsChannelConfigurator(stackName, stack, connectionFactory, isUdp, address), null);
transportOf(holder).stack(stackName);
JGroupsConfigurator.logger.info("JGroups JDBC_PING discovery enabled.");
}
/**
* Generate the next sequence of the address, and place it into the JGROUPS_PING table so other nodes can see it.
* If we are the first = smallest entry, the other nodes will wait for us to become a coordinator
* for max_join_attempts x all_clients_retry_timeout = 10 x 100 ms = 1 second. Otherwise, we will wait for that
* one second. This prevents a split-brain scenario on a concurrent startup.
*/
private static Address prepareJGroupsAddress(KeycloakSession session, String clusterName) {
var storage = session.getProvider(ServerConfigStorageProvider.class);
String seq = storage.loadOrCreate(JGROUPS_ADDRESS_SEQUENCE, () -> "0");
long value = Long.parseLong(seq) + 1;
String newSeq = Long.toString(value);
storage.replace(JGROUPS_ADDRESS_SEQUENCE, seq, newSeq);
var cp = session.getProvider(JpaConnectionProvider.class);
var tableName = JpaUtils.getTableNameForNativeQuery("JGROUPS_PING", cp.getEntityManager());
String statement = String.format("INSERT INTO %s values (?, ?, ?, ?, ?)", tableName);
ExtendedUUID address = new ExtendedUUID(0, value);
Query s = cp.getEntityManager().createNativeQuery(statement);
s.setParameter(1, org.jgroups.util.Util.addressToString(new UUID(address.getMostSignificantBits(), address.getLeastSignificantBits()))); // address
s.setParameter(2, "(starting)"); // name
s.setParameter(3, clusterName); // cluster name
s.setParameter(4, "127.0.0.1:0"); // ip = new IpAddress("localhost", 0).toString()
s.setParameter(5, false); // coord
s.executeUpdate();
return address;
}
private static List<ProtocolConfiguration> getProtocolConfigurations(String tableName, boolean udp) {
var list = new ArrayList<ProtocolConfiguration>(udp ? 1 : 2);
list.add(new ProtocolConfiguration(KEYCLOAK_JDBC_PING2.class.getName(),
@ -258,10 +306,18 @@ public final class JGroupsConfigurator {
private static class JpaFactoryAwareJGroupsChannelConfigurator extends EmbeddedJGroupsChannelConfigurator {
private final JpaConnectionProviderFactory factory;
private final Address address;
public JpaFactoryAwareJGroupsChannelConfigurator(String name, List<ProtocolConfiguration> stack, JpaConnectionProviderFactory factory, boolean isUdp) {
public JpaFactoryAwareJGroupsChannelConfigurator(String name, List<ProtocolConfiguration> stack, JpaConnectionProviderFactory factory, boolean isUdp, Address address) {
super(name, stack, null, isUdp ? "udp" : "tcp");
this.factory = Objects.requireNonNull(factory);
this.address = address;
}
@Override
protected JChannel amendChannel(JChannel channel) {
channel.addAddressGenerator(() -> address);
return super.amendChannel(channel);
}
@Override

View File

@ -0,0 +1,99 @@
package org.keycloak.jgroups.protocol;
import org.jboss.logging.Logger;
import org.jgroups.JChannel;
import org.jgroups.conf.ClassConfigurator;
import org.jgroups.util.ThreadFactory;
import org.jgroups.util.Util;
import org.junit.Ignore;
import org.junit.Test;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.stream.Collectors;
import java.util.stream.Stream;
/**
* Misc tests for {@link KEYCLOAK_JDBC_PING2}, running against H2
* @author Bela Ban
* @author Alexander Schwartz
*/
public class JdbcPing2Test {
protected static Logger log = Logger.getLogger(JdbcPing2Test.class);
protected static final String CLUSTER="jdbc-test";
protected static final int NUM_NODES=8;
public static final String PROTOCOL_STACK = "jdbc-h2.xml";
static {
ClassConfigurator.addProtocol((short) 1026, KEYCLOAK_JDBC_PING2_FOR_TESTING.class);
}
/**
* 100 iterations would run approx 8 minutes and should complete successfully,
* with an average of 3.3 seconds in converging.
*/
@Test
@Ignore
public void testConcurrentStartupMultipleTimes() throws Exception {
int count = 100;
long sum = 0;
for (int j = 0; j < 100; j++) {
sum += runSingleTest();
}
log.info("Average time to form the cluster: " + Duration.ofNanos(sum / count));
}
private static long runSingleTest() throws Exception {
JChannel[] channels = new JChannel[NUM_NODES];
List<Thread> threads = new ArrayList<>();
try {
for (int i = 0; i < channels.length; i++) {
channels[i] = createChannel(PROTOCOL_STACK, String.valueOf(i + 1));
}
CountDownLatch latch = new CountDownLatch(1);
int index = 1;
for (JChannel ch : channels) {
ThreadFactory thread_factory = ch.stack().getTransport().getThreadFactory();
Connector connector = new Connector(latch, ch);
Thread thread = thread_factory.newThread(connector, "connector-" + index++);
threads.add(thread);
thread.start();
}
latch.countDown();
long start = System.nanoTime();
Util.waitUntilAllChannelsHaveSameView(40000, 100, channels);
long time = System.nanoTime() - start;
log.infof("-- cluster of %d formed in %s:\n%s\n", NUM_NODES, Duration.ofNanos(time),
Stream.of(channels).map(ch -> String.format("%s: %s", ch.address(), ch.view()))
.collect(Collectors.joining("\n")));
return time;
} finally {
for (Thread thread : threads) {
thread.join();
}
Arrays.stream(channels).filter(ch -> ch.view().getCoord() != ch.getAddress()).forEach(JChannel::close);
Arrays.stream(channels).filter(ch -> !ch.isClosed()).forEach(JChannel::close);
log.infof("Closed");
}
}
protected static JChannel createChannel(String cfg, String name) throws Exception {
return new JChannel(cfg).name(name);
}
protected record Connector(CountDownLatch latch, JChannel ch) implements Runnable {
@Override
public void run() {
try {
latch.await();
ch.connect(CLUSTER);
} catch (Exception e) {
throw new RuntimeException(e);
}
}
}
}

View File

@ -0,0 +1,37 @@
/*
* Copyright 2025 Red Hat, Inc. and/or its affiliates
* and other contributors as indicated by the @author tags.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.keycloak.jgroups.protocol;
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.SQLException;
/**
* This overwrites the get connection method again to avoid the JPA style connection handling
*/
public class KEYCLOAK_JDBC_PING2_FOR_TESTING extends KEYCLOAK_JDBC_PING2 {
@Override
protected Connection getConnection() {
try {
return dataSource != null? dataSource.getConnection() :
DriverManager.getConnection(connection_url, connection_username, connection_password);
} catch (SQLException e) {
throw new RuntimeException(e);
}
}
}

View File

@ -0,0 +1,48 @@
<!--
JDBC_PING2 for Postgresql
-->
<config xmlns="urn:org:jgroups"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="urn:org:jgroups http://www.jgroups.org/schema/jgroups.xsd">
<TCP
bind_addr="localhost"
bind_port="0"
recv_buf_size="150000"
send_buf_size="640000"
sock_conn_timeout="300ms"
/>
<org.keycloak.jgroups.protocol.KEYCLOAK_JDBC_PING2_FOR_TESTING
connection_driver="org.h2.Driver"
connection_url="jdbc:h2:mem:test;DB_CLOSE_DELAY=-1"
connection_username=""
connection_password=""
remove_all_data_on_view_change="true"
register_shutdown_hook="true"
return_entire_cache="false"
write_data_on_find="true"
/>
<!-- very aggressive merging to speed up the test -->
<MERGE3 min_interval="2000"
max_interval="4000"/>
<FD_ALL3 timeout="40s" interval="5s" />
<VERIFY_SUSPECT2 />
<pbcast.NAKACK2
use_mcast_xmit="false"
xmit_interval="500ms"/>
<UNICAST3
xmit_interval="500ms"/>
<pbcast.STABLE
desired_avg_gossip="5s"
max_bytes="1000000"/>
<pbcast.GMS
print_local_addr="false"
join_timeout="3s"
max_join_attempts="5"/>
<UFC max_credits="2M"
min_threshold="0.40"/>
<MFC max_credits="2M"
min_threshold="0.4"/>
<FRAG3 frag_size="60000" />
<pbcast.STATE_TRANSFER/>
</config>

View File

@ -0,0 +1,32 @@
#
# Copyright 2023 Red Hat, Inc. and/or its affiliates
# and other contributors as indicated by the @author tags.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# https://github.com/jboss-logging/jboss-logmanager
logger.level=INFO
logger.handlers=CONSOLE
handler.CONSOLE=org.jboss.logmanager.handlers.ConsoleHandler
handler.CONSOLE.properties=autoFlush
handler.CONSOLE.level=DEBUG
handler.CONSOLE.autoFlush=true
handler.CONSOLE.formatter=PATTERN
# The log format pattern for both logs
formatter.PATTERN=org.jboss.logmanager.formatters.PatternFormatter
formatter.PATTERN.properties=pattern
formatter.PATTERN.pattern=%d{HH:mm:ss,SSS} %-5p %t [%c] %m%n