diff options
author | Sorabh Hamirwasia <shamirwasia@maprtech.com> | 2018-03-01 15:08:10 -0800 |
---|---|---|
committer | Parth Chandra <parthc@apache.org> | 2018-03-11 11:33:36 +0530 |
commit | 920a12a6f01b99a677dd26223d9b1228a35fe3f5 (patch) | |
tree | 77bd35e606ca974e600b1ae11c3826fc8dce9642 /exec/java-exec/src/main/java/org/apache/drill/exec/rpc | |
parent | 035010c1781b480ee2ee2d33209103e902eb0ccf (diff) |
DRILL-6187: Exception in RPC communication between DataClient/ControlClient and respective servers when bit-to-bit security is on
This closes #1145
Diffstat (limited to 'exec/java-exec/src/main/java/org/apache/drill/exec/rpc')
8 files changed, 272 insertions, 841 deletions
diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/rpc/BitRpcUtility.java b/exec/java-exec/src/main/java/org/apache/drill/exec/rpc/BitRpcUtility.java new file mode 100644 index 000000000..c71363dc7 --- /dev/null +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/rpc/BitRpcUtility.java @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.drill.exec.rpc; + +import com.google.common.collect.ImmutableList; +import com.google.protobuf.Internal.EnumLite; +import com.google.protobuf.MessageLite; +import org.apache.drill.exec.proto.CoordinationProtos.DrillbitEndpoint; +import org.apache.drill.exec.rpc.security.AuthenticatorFactory; +import org.apache.drill.exec.rpc.security.SaslProperties; +import org.apache.hadoop.security.UserGroupInformation; + +import java.io.IOException; +import java.util.List; +import java.util.Map; + +/** + * Utility class providing common methods shared between {@link org.apache.drill.exec.rpc.data.DataClient} and + * {@link org.apache.drill.exec.rpc.control.ControlClient} + */ +public final class BitRpcUtility { + private static final org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(BitRpcUtility.class); + + /** + * Method to do validation on the handshake message received from server side. Only used by BitClients NOT UserClient. + * Verify if rpc version of handshake message matches the supported RpcVersion and also validates the + * security configuration between client and server + * @param handshakeRpcVersion - rpc version received in handshake message + * @param remoteAuthMechs - authentication mechanisms supported by server + * @param rpcVersion - supported rpc version on client + * @param connection - client connection + * @param config - client connectin config + * @param client - data client or control client + * @return - Immutable list of authentication mechanisms supported by server or null + * @throws RpcException - exception is thrown if rpc version or authentication configuration mismatch is found + */ + public static List<String> validateHandshake(int handshakeRpcVersion, List<String> remoteAuthMechs, int rpcVersion, + ClientConnection connection, BitConnectionConfig config, + BasicClient client) throws RpcException { + + if (handshakeRpcVersion != rpcVersion) { + throw new RpcException(String.format("Invalid rpc version. Expected %d, actual %d.", + handshakeRpcVersion, rpcVersion)); + } + + if (remoteAuthMechs.size() != 0) { // remote requires authentication + client.setAuthComplete(false); + return ImmutableList.copyOf(remoteAuthMechs); + } else { + if (config.getAuthMechanismToUse() != null) { // local requires authentication + throw new RpcException(String.format("Remote Drillbit does not require auth, but auth is enabled in " + + "local Drillbit configuration. [Details: connection: (%s) and LocalAuthMechanism: (%s). Please check " + + "security configuration for bit-to-bit.", connection.getName(), config.getAuthMechanismToUse())); + } + } + return null; + } + + /** + * Creates various instances needed to start the SASL handshake. This is called from + * {@link BasicClient#prepareSaslHandshake(RpcConnectionHandler, List)} only for + * {@link org.apache.drill.exec.rpc.data.DataClient} and {@link org.apache.drill.exec.rpc.control.ControlClient} + * + * @param connectionHandler - Connection handler used by client's to know about success/failure conditions. + * @param serverAuthMechanisms - List of auth mechanisms configured on server side + * @param connection - ClientConnection used for authentication + * @param config - ClientConnection config + * @param endpoint - Remote DrillbitEndpoint + * @param client - Either of DataClient/ControlClient instance + * @param saslRpcType - SASL_MESSAGE RpcType for Data and Control channel + */ + public static <T extends EnumLite, CC extends ClientConnection, HS extends MessageLite, HR extends MessageLite> + void prepareSaslHandshake(final RpcConnectionHandler<CC> connectionHandler, List<String> serverAuthMechanisms, + CC connection, BitConnectionConfig config, DrillbitEndpoint endpoint, + final BasicClient<T, CC, HS, HR> client, T saslRpcType) { + try { + final Map<String, String> saslProperties = SaslProperties.getSaslProperties(connection.isEncryptionEnabled(), + connection.getMaxWrappedSize()); + final UserGroupInformation ugi = UserGroupInformation.getLoginUser(); + final AuthenticatorFactory factory = config.getAuthFactory(serverAuthMechanisms); + client.startSaslHandshake(connectionHandler, config.getSaslClientProperties(endpoint, saslProperties), + ugi, factory, saslRpcType); + } catch (final IOException e) { + logger.error("Failed while doing setup for starting sasl handshake for connection", connection.getName()); + final Exception ex = new RpcException(String.format("Failed to initiate authentication to %s", + endpoint.getAddress()), e); + connectionHandler.connectionFailed(RpcConnectionHandler.FailureType.AUTHENTICATION, ex); + } + } + + // Suppress default constructor + private BitRpcUtility() { + } +} diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/rpc/control/ControlClient.java b/exec/java-exec/src/main/java/org/apache/drill/exec/rpc/control/ControlClient.java index 1e0313a51..1df5ff1b5 100644 --- a/exec/java-exec/src/main/java/org/apache/drill/exec/rpc/control/ControlClient.java +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/rpc/control/ControlClient.java @@ -17,36 +17,26 @@ */ package org.apache.drill.exec.rpc.control; -import com.google.common.util.concurrent.SettableFuture; +import com.google.common.collect.ImmutableList; import com.google.protobuf.MessageLite; - import io.netty.buffer.ByteBuf; import io.netty.channel.ChannelFuture; import io.netty.channel.socket.SocketChannel; import io.netty.util.concurrent.GenericFutureListener; - import org.apache.drill.exec.memory.BufferAllocator; import org.apache.drill.exec.proto.BitControl.BitControlHandshake; import org.apache.drill.exec.proto.BitControl.RpcType; import org.apache.drill.exec.proto.CoordinationProtos.DrillbitEndpoint; import org.apache.drill.exec.rpc.BasicClient; -import org.apache.drill.exec.rpc.security.AuthenticationOutcomeListener; +import org.apache.drill.exec.rpc.BitRpcUtility; +import org.apache.drill.exec.rpc.FailingRequestHandler; import org.apache.drill.exec.rpc.OutOfMemoryHandler; import org.apache.drill.exec.rpc.ProtobufLengthDecoder; import org.apache.drill.exec.rpc.ResponseSender; -import org.apache.drill.exec.rpc.RpcCommand; +import org.apache.drill.exec.rpc.RpcConnectionHandler; import org.apache.drill.exec.rpc.RpcException; -import org.apache.drill.exec.rpc.RpcOutcomeListener; -import org.apache.drill.exec.rpc.FailingRequestHandler; -import org.apache.drill.exec.rpc.security.SaslProperties; - -import org.apache.hadoop.security.UserGroupInformation; -import javax.security.sasl.SaslClient; -import javax.security.sasl.SaslException; -import java.io.IOException; -import java.util.Map; -import java.util.concurrent.ExecutionException; +import java.util.List; public class ControlClient extends BasicClient<RpcType, ControlConnection, BitControlHandshake, BitControlHandshake> { private static final org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(ControlClient.class); @@ -104,119 +94,21 @@ public class ControlClient extends BasicClient<RpcType, ControlConnection, BitCo } @Override - protected void validateHandshake(BitControlHandshake handshake) throws RpcException { - if (handshake.getRpcVersion() != ControlRpcConfig.RPC_VERSION) { - throw new RpcException(String.format("Invalid rpc version. Expected %d, actual %d.", - handshake.getRpcVersion(), ControlRpcConfig.RPC_VERSION)); - } - - if (handshake.getAuthenticationMechanismsCount() != 0) { // remote requires authentication - final SaslClient saslClient; - try { - final Map<String, String> saslProperties = SaslProperties.getSaslProperties(connection.isEncryptionEnabled(), - connection.getMaxWrappedSize()); - - saslClient = config.getAuthFactory(handshake.getAuthenticationMechanismsList()) - .createSaslClient(UserGroupInformation.getLoginUser(), - config.getSaslClientProperties(remoteEndpoint, saslProperties)); - } catch (final IOException e) { - throw new RpcException(String.format("Failed to initiate authenticate to %s", remoteEndpoint.getAddress()), e); - } - if (saslClient == null) { - throw new RpcException("Unexpected failure. Could not initiate SASL exchange."); - } - connection.setSaslClient(saslClient); - } else { - if (config.getAuthMechanismToUse() != null) { // local requires authentication - throw new RpcException(String.format("Drillbit (%s) does not require auth, but auth is enabled.", - remoteEndpoint.getAddress())); - } - } + protected void prepareSaslHandshake(final RpcConnectionHandler<ControlConnection> connectionHandler, + List<String> serverAuthMechanisms) { + BitRpcUtility.prepareSaslHandshake(connectionHandler, serverAuthMechanisms, connection, config, remoteEndpoint, + this, RpcType.SASL_MESSAGE); } @Override - protected void finalizeConnection(BitControlHandshake handshake, ControlConnection connection) { - connection.setEndpoint(handshake.getEndpoint()); + protected List<String> validateHandshake(BitControlHandshake handshake) throws RpcException { + return BitRpcUtility.validateHandshake(handshake.getRpcVersion(), handshake.getAuthenticationMechanismsList(), + ControlRpcConfig.RPC_VERSION, connection, config, this); } @Override - protected <M extends MessageLite> RpcCommand<M, ControlConnection> - getInitialCommand(final RpcCommand<M, ControlConnection> command) { - final RpcCommand<M, ControlConnection> initialCommand = super.getInitialCommand(command); - if (config.getAuthMechanismToUse() == null) { - return initialCommand; - } else { - return new AuthenticationCommand<>(initialCommand); - } - } - - private class AuthenticationCommand<M extends MessageLite> implements RpcCommand<M, ControlConnection> { - - private final RpcCommand<M, ControlConnection> command; - - AuthenticationCommand(RpcCommand<M, ControlConnection> command) { - this.command = command; - } - - @Override - public void connectionAvailable(ControlConnection connection) { - command.connectionFailed(FailureType.AUTHENTICATION, new SaslException("Should not reach here.")); - } - - @Override - public void connectionSucceeded(final ControlConnection connection) { - final UserGroupInformation loginUser; - try { - loginUser = UserGroupInformation.getLoginUser(); - } catch (final IOException e) { - logger.debug("Unexpected failure trying to login.", e); - command.connectionFailed(FailureType.AUTHENTICATION, e); - return; - } - - final SettableFuture<Void> future = SettableFuture.create(); - new AuthenticationOutcomeListener<>(ControlClient.this, connection, RpcType.SASL_MESSAGE, - loginUser, - new RpcOutcomeListener<Void>() { - @Override - public void failed(RpcException ex) { - logger.debug("Authentication failed.", ex); - future.setException(ex); - } - - @Override - public void success(Void value, ByteBuf buffer) { - connection.changeHandlerTo(config.getMessageHandler()); - future.set(null); - } - - @Override - public void interrupted(InterruptedException e) { - logger.debug("Authentication failed.", e); - future.setException(e); - } - }).initiate(config.getAuthMechanismToUse()); - - - try { - logger.trace("Waiting until authentication completes.."); - future.get(); - command.connectionSucceeded(connection); - } catch (InterruptedException e) { - command.connectionFailed(FailureType.AUTHENTICATION, e); - // Preserve evidence that the interruption occurred so that code higher up on the call stack can learn of the - // interruption and respond to it if it wants to. - Thread.currentThread().interrupt(); - } catch (ExecutionException e) { - command.connectionFailed(FailureType.AUTHENTICATION, e); - } - } - - @Override - public void connectionFailed(FailureType type, Throwable t) { - logger.debug("Authentication failed.", t); - command.connectionFailed(FailureType.AUTHENTICATION, t); - } + protected void finalizeConnection(BitControlHandshake handshake, ControlConnection connection) { + connection.setEndpoint(handshake.getEndpoint()); } @Override diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/rpc/control/ControlConnection.java b/exec/java-exec/src/main/java/org/apache/drill/exec/rpc/control/ControlConnection.java index 70189d78a..c7d4d8ec0 100644 --- a/exec/java-exec/src/main/java/org/apache/drill/exec/rpc/control/ControlConnection.java +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/rpc/control/ControlConnection.java @@ -78,7 +78,7 @@ public class ControlConnection extends AbstractServerConnection<ControlConnectio @Override public boolean isActive() { - return active; + return active && super.isActive(); } @Override diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/rpc/data/DataClient.java b/exec/java-exec/src/main/java/org/apache/drill/exec/rpc/data/DataClient.java index cba323e96..267b483d7 100644 --- a/exec/java-exec/src/main/java/org/apache/drill/exec/rpc/data/DataClient.java +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/rpc/data/DataClient.java @@ -17,7 +17,7 @@ */ package org.apache.drill.exec.rpc.data; -import com.google.common.util.concurrent.SettableFuture; +import com.google.common.collect.ImmutableList; import com.google.protobuf.MessageLite; import io.netty.buffer.ByteBuf; import io.netty.channel.ChannelFuture; @@ -29,21 +29,14 @@ import org.apache.drill.exec.proto.BitData.BitServerHandshake; import org.apache.drill.exec.proto.BitData.RpcType; import org.apache.drill.exec.proto.CoordinationProtos.DrillbitEndpoint; import org.apache.drill.exec.rpc.BasicClient; +import org.apache.drill.exec.rpc.BitRpcUtility; import org.apache.drill.exec.rpc.OutOfMemoryHandler; import org.apache.drill.exec.rpc.ProtobufLengthDecoder; import org.apache.drill.exec.rpc.ResponseSender; -import org.apache.drill.exec.rpc.RpcCommand; +import org.apache.drill.exec.rpc.RpcConnectionHandler; import org.apache.drill.exec.rpc.RpcException; -import org.apache.drill.exec.rpc.RpcOutcomeListener; -import org.apache.drill.exec.rpc.security.AuthenticationOutcomeListener; -import org.apache.drill.exec.rpc.security.SaslProperties; -import org.apache.hadoop.security.UserGroupInformation; -import javax.security.sasl.SaslClient; -import javax.security.sasl.SaslException; -import java.io.IOException; -import java.util.Map; -import java.util.concurrent.ExecutionException; +import java.util.List; public class DataClient extends BasicClient<RpcType, DataClientConnection, BitClientHandshake, BitServerHandshake> { private static final org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(DataClient.class); @@ -103,114 +96,17 @@ public class DataClient extends BasicClient<RpcType, DataClientConnection, BitCl } @Override - protected void validateHandshake(BitServerHandshake handshake) throws RpcException { - if (handshake.getRpcVersion() != DataRpcConfig.RPC_VERSION) { - throw new RpcException(String.format("Invalid rpc version. Expected %d, actual %d.", - handshake.getRpcVersion(), DataRpcConfig.RPC_VERSION)); - } - - if (handshake.getAuthenticationMechanismsCount() != 0) { // remote requires authentication - final SaslClient saslClient; - try { - - final Map<String, String> saslProperties = SaslProperties.getSaslProperties(connection.isEncryptionEnabled(), - connection.getMaxWrappedSize()); - - saslClient = config.getAuthFactory(handshake.getAuthenticationMechanismsList()) - .createSaslClient(UserGroupInformation.getLoginUser(), - config.getSaslClientProperties(remoteEndpoint, saslProperties)); - } catch (final IOException e) { - throw new RpcException(String.format("Failed to initiate authenticate to %s", remoteEndpoint.getAddress()), e); - } - if (saslClient == null) { - throw new RpcException("Unexpected failure. Could not initiate SASL exchange."); - } - connection.setSaslClient(saslClient); - } else { - if (config.getAuthMechanismToUse() != null) { - throw new RpcException(String.format("Drillbit (%s) does not require auth, but auth is enabled.", - remoteEndpoint.getAddress())); - } - } - } - - protected <M extends MessageLite> RpcCommand<M, DataClientConnection> - getInitialCommand(final RpcCommand<M, DataClientConnection> command) { - final RpcCommand<M, DataClientConnection> initialCommand = super.getInitialCommand(command); - if (config.getAuthMechanismToUse() == null) { - return initialCommand; - } else { - return new AuthenticationCommand<>(initialCommand); - } + protected void prepareSaslHandshake(final RpcConnectionHandler<DataClientConnection> connectionHandler, List<String> serverAuthMechanisms) { + BitRpcUtility.prepareSaslHandshake(connectionHandler, serverAuthMechanisms, connection, config, remoteEndpoint, + this, RpcType.SASL_MESSAGE); } - private class AuthenticationCommand<M extends MessageLite> implements RpcCommand<M, DataClientConnection> { - - private final RpcCommand<M, DataClientConnection> command; - - AuthenticationCommand(RpcCommand<M, DataClientConnection> command) { - this.command = command; - } - @Override - public void connectionAvailable(DataClientConnection connection) { - command.connectionFailed(FailureType.AUTHENTICATION, new SaslException("Should not reach here.")); + protected List<String> validateHandshake(BitServerHandshake handshake) throws RpcException { + return BitRpcUtility.validateHandshake(handshake.getRpcVersion(), handshake.getAuthenticationMechanismsList(), + DataRpcConfig.RPC_VERSION, connection, config, this); } - @Override - public void connectionSucceeded(final DataClientConnection connection) { - final UserGroupInformation loginUser; - try { - loginUser = UserGroupInformation.getLoginUser(); - } catch (final IOException e) { - logger.debug("Unexpected failure trying to login.", e); - command.connectionFailed(FailureType.AUTHENTICATION, e); - return; - } - - final SettableFuture<Void> future = SettableFuture.create(); - new AuthenticationOutcomeListener<>(DataClient.this, connection, RpcType.SASL_MESSAGE, - loginUser, - new RpcOutcomeListener<Void>() { - @Override - public void failed(RpcException ex) { - logger.debug("Authentication failed.", ex); - future.setException(ex); - } - - @Override - public void success(Void value, ByteBuf buffer) { - future.set(null); - } - - @Override - public void interrupted(InterruptedException e) { - logger.debug("Authentication failed.", e); - future.setException(e); - } - }).initiate(config.getAuthMechanismToUse()); - - try { - logger.trace("Waiting until authentication completes.."); - future.get(); - command.connectionSucceeded(connection); - } catch (InterruptedException e) { - command.connectionFailed(FailureType.AUTHENTICATION, e); - // Preserve evidence that the interruption occurred so that code higher up on the call stack can learn of the - // interruption and respond to it if it wants to. - Thread.currentThread().interrupt(); - } catch (ExecutionException e) { - command.connectionFailed(FailureType.AUTHENTICATION, e); - } - } - - @Override - public void connectionFailed(FailureType type, Throwable t) { - logger.debug("Authentication failed.", t); - command.connectionFailed(FailureType.AUTHENTICATION, t); - } - } - @Override public ProtobufLengthDecoder getDecoder(BufferAllocator allocator) { return new DataProtobufLengthDecoder.Client(allocator, OutOfMemoryHandler.DEFAULT_INSTANCE); diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/rpc/security/AuthenticationOutcomeListener.java b/exec/java-exec/src/main/java/org/apache/drill/exec/rpc/security/AuthenticationOutcomeListener.java deleted file mode 100644 index 5c34d012c..000000000 --- a/exec/java-exec/src/main/java/org/apache/drill/exec/rpc/security/AuthenticationOutcomeListener.java +++ /dev/null @@ -1,300 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.drill.exec.rpc.security; - -import com.google.common.collect.ImmutableMap; -import com.google.common.collect.Maps; -import com.google.protobuf.ByteString; -import com.google.protobuf.Internal.EnumLite; -import com.google.protobuf.MessageLite; -import io.netty.buffer.ByteBuf; -import org.apache.drill.exec.proto.UserBitShared.SaslMessage; -import org.apache.drill.exec.proto.UserBitShared.SaslStatus; -import org.apache.drill.exec.rpc.BasicClient; -import org.apache.drill.exec.rpc.ClientConnection; -import org.apache.drill.exec.rpc.RpcException; -import org.apache.drill.exec.rpc.RpcOutcomeListener; -import org.apache.hadoop.security.UserGroupInformation; - -import javax.security.sasl.Sasl; -import javax.security.sasl.SaslClient; -import javax.security.sasl.SaslException; -import java.io.IOException; -import java.lang.reflect.UndeclaredThrowableException; -import java.security.PrivilegedExceptionAction; -import java.util.EnumMap; -import java.util.Map; - -import static com.google.common.base.Preconditions.checkNotNull; - -/** - * Handles SASL exchange, on the client-side. - * - * @param <T> handshake rpc type - * @param <C> Client connection type - * @param <HS> Handshake send type - * @param <HR> Handshake receive type - */ -public class AuthenticationOutcomeListener<T extends EnumLite, C extends ClientConnection, - HS extends MessageLite, HR extends MessageLite> - implements RpcOutcomeListener<SaslMessage> { - private static final org.slf4j.Logger logger = - org.slf4j.LoggerFactory.getLogger(AuthenticationOutcomeListener.class); - - private static final ImmutableMap<SaslStatus, SaslChallengeProcessor> - CHALLENGE_PROCESSORS; - static { - final Map<SaslStatus, SaslChallengeProcessor> map = new EnumMap<>(SaslStatus.class); - map.put(SaslStatus.SASL_IN_PROGRESS, new SaslInProgressProcessor()); - map.put(SaslStatus.SASL_SUCCESS, new SaslSuccessProcessor()); - map.put(SaslStatus.SASL_FAILED, new SaslFailedProcessor()); - CHALLENGE_PROCESSORS = Maps.immutableEnumMap(map); - } - - private final BasicClient<T, C, HS, HR> client; - private final C connection; - private final T saslRpcType; - private final UserGroupInformation ugi; - private final RpcOutcomeListener<?> completionListener; - - public AuthenticationOutcomeListener(BasicClient<T, C, HS, HR> client, - C connection, T saslRpcType, UserGroupInformation ugi, - RpcOutcomeListener<?> completionListener) { - this.client = client; - this.connection = connection; - this.saslRpcType = saslRpcType; - this.ugi = ugi; - this.completionListener = completionListener; - } - - public void initiate(final String mechanismName) { - logger.trace("Initiating SASL exchange."); - try { - final ByteString responseData; - final SaslClient saslClient = connection.getSaslClient(); - if (saslClient.hasInitialResponse()) { - responseData = ByteString.copyFrom(evaluateChallenge(ugi, saslClient, new byte[0])); - } else { - responseData = ByteString.EMPTY; - } - client.send(new AuthenticationOutcomeListener<>(client, connection, saslRpcType, ugi, completionListener), - connection, - saslRpcType, - SaslMessage.newBuilder() - .setMechanism(mechanismName) - .setStatus(SaslStatus.SASL_START) - .setData(responseData) - .build(), - SaslMessage.class, - true /* the connection will not be backed up at this point */); - logger.trace("Initiated SASL exchange."); - } catch (final Exception e) { - completionListener.failed(RpcException.mapException(e)); - } - } - - @Override - public void failed(RpcException ex) { - completionListener.failed(RpcException.mapException(ex)); - } - - @Override - public void success(SaslMessage value, ByteBuf buffer) { - logger.trace("Server responded with message of type: {}", value.getStatus()); - final SaslChallengeProcessor processor = CHALLENGE_PROCESSORS.get(value.getStatus()); - if (processor == null) { - completionListener.failed(RpcException.mapException( - new SaslException("Server sent a corrupt message."))); - } else { - // SaslSuccessProcessor.process disposes saslClient so get mechanism here to use later in logging - final String mechanism = connection.getSaslClient().getMechanismName(); - try { - final SaslChallengeContext<C> context = new SaslChallengeContext<>(value, ugi, connection); - final SaslMessage saslResponse = processor.process(context); - - if (saslResponse != null) { - client.send(new AuthenticationOutcomeListener<>(client, connection, saslRpcType, ugi, completionListener), - connection, saslRpcType, saslResponse, SaslMessage.class, - true /* the connection will not be backed up at this point */); - } else { - // success - completionListener.success(null, null); - if (logger.isTraceEnabled()) { - logger.trace("Successfully authenticated to server using {} mechanism and encryption context: {}", - mechanism, connection.getEncryptionCtxtString()); - } - } - } catch (final Exception e) { - logger.error("Authentication with encryption context: {} using mechanism {} failed with {}", - connection.getEncryptionCtxtString(), mechanism, e.getMessage()); - completionListener.failed(RpcException.mapException(e)); - } - } - } - - @Override - public void interrupted(InterruptedException e) { - completionListener.interrupted(e); - } - - private static class SaslChallengeContext<C extends ClientConnection> { - - final SaslMessage challenge; - final UserGroupInformation ugi; - final C connection; - - SaslChallengeContext(SaslMessage challenge, UserGroupInformation ugi, C connection) { - this.challenge = checkNotNull(challenge); - this.ugi = checkNotNull(ugi); - this.connection = checkNotNull(connection); - } - } - - private interface SaslChallengeProcessor { - - /** - * Process challenge from server, and return a response. - * - * Returns null iff SASL exchange is complete and successful. - * - * @param context challenge context - * @return response - * @throws Exception in case of any failure - */ - <CC extends ClientConnection> - SaslMessage process(SaslChallengeContext<CC> context) throws Exception; - - } - - private static class SaslInProgressProcessor implements SaslChallengeProcessor { - - @Override - public <CC extends ClientConnection> SaslMessage process(SaslChallengeContext<CC> context) throws Exception { - final SaslMessage.Builder response = SaslMessage.newBuilder(); - final SaslClient saslClient = context.connection.getSaslClient(); - - final byte[] responseBytes = evaluateChallenge(context.ugi, saslClient, - context.challenge.getData().toByteArray()); - - final boolean isComplete = saslClient.isComplete(); - logger.trace("Evaluated challenge. Completed? {}.", isComplete); - response.setData(responseBytes != null ? ByteString.copyFrom(responseBytes) : ByteString.EMPTY); - // if isComplete, the client will get one more response from server - response.setStatus(isComplete ? SaslStatus.SASL_SUCCESS : SaslStatus.SASL_IN_PROGRESS); - return response.build(); - } - } - - private static class SaslSuccessProcessor implements SaslChallengeProcessor { - - @Override - public <CC extends ClientConnection> SaslMessage process(SaslChallengeContext<CC> context) throws Exception { - final SaslClient saslClient = context.connection.getSaslClient(); - - if (saslClient.isComplete()) { - handleSuccess(context); - return null; - } else { - // server completed before client; so try once, fail otherwise - evaluateChallenge(context.ugi, saslClient, context.challenge.getData().toByteArray()); // discard response - - if (saslClient.isComplete()) { - handleSuccess(context); - return null; - } else { - throw new SaslException("Server allegedly succeeded authentication, but client did not. Suspicious?"); - } - } - } - } - - private static class SaslFailedProcessor implements SaslChallengeProcessor { - - @Override - public <CC extends ClientConnection> SaslMessage process(SaslChallengeContext<CC> context) throws Exception { - throw new SaslException(String.format("Authentication failed. Incorrect credentials? [Details: %s]", - context.connection.getEncryptionCtxtString())); - } - } - - private static byte[] evaluateChallenge(final UserGroupInformation ugi, final SaslClient saslClient, - final byte[] challengeBytes) throws SaslException { - try { - return ugi.doAs(new PrivilegedExceptionAction<byte[]>() { - @Override - public byte[] run() throws Exception { - return saslClient.evaluateChallenge(challengeBytes); - } - }); - } catch (final UndeclaredThrowableException e) { - throw new SaslException( - String.format("Unexpected failure (%s)", saslClient.getMechanismName()), e.getCause()); - } catch (final IOException | InterruptedException e) { - if (e instanceof SaslException) { - throw (SaslException) e; - } else { - throw new SaslException( - String.format("Unexpected failure (%s)", saslClient.getMechanismName()), e); - } - } - } - - - private static <CC extends ClientConnection> void handleSuccess(SaslChallengeContext<CC> context) throws - SaslException { - final CC connection = context.connection; - final SaslClient saslClient = connection.getSaslClient(); - - try { - // Check if connection was marked for being secure then verify for negotiated QOP value for - // correctness. - final String negotiatedQOP = saslClient.getNegotiatedProperty(Sasl.QOP).toString(); - final String expectedQOP = connection.isEncryptionEnabled() - ? SaslProperties.QualityOfProtection.PRIVACY.getSaslQop() - : SaslProperties.QualityOfProtection.AUTHENTICATION.getSaslQop(); - - if (!(negotiatedQOP.equals(expectedQOP))) { - throw new SaslException(String.format("Mismatch in negotiated QOP value: %s and Expected QOP value: %s", - negotiatedQOP, expectedQOP)); - } - - // Update the rawWrapChunkSize with the negotiated buffer size since we cannot call encode with more than - // negotiated size of buffer. - if (connection.isEncryptionEnabled()) { - final int negotiatedRawSendSize = Integer.parseInt( - saslClient.getNegotiatedProperty(Sasl.RAW_SEND_SIZE).toString()); - if (negotiatedRawSendSize <= 0) { - throw new SaslException(String.format("Negotiated rawSendSize: %d is invalid. Please check the configured " + - "value of encryption.sasl.max_wrapped_size. It might be configured to a very small value.", - negotiatedRawSendSize)); - } - connection.setWrapSizeLimit(negotiatedRawSendSize); - } - } catch (Exception e) { - throw new SaslException(String.format("Unexpected failure while retrieving negotiated property values (%s)", - e.getMessage()), e); - } - - if (connection.isEncryptionEnabled()) { - connection.addSecurityHandlers(); - } else { - // Encryption is not required hence we don't need to hold on to saslClient object. - connection.disposeSaslClient(); - } - } -} diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/rpc/security/AuthenticatorFactory.java b/exec/java-exec/src/main/java/org/apache/drill/exec/rpc/security/AuthenticatorFactory.java deleted file mode 100644 index 307ae979c..000000000 --- a/exec/java-exec/src/main/java/org/apache/drill/exec/rpc/security/AuthenticatorFactory.java +++ /dev/null @@ -1,80 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.drill.exec.rpc.security; - -import org.apache.hadoop.security.UserGroupInformation; - -import javax.security.sasl.SaslClient; -import javax.security.sasl.SaslException; -import javax.security.sasl.SaslServer; -import java.io.IOException; -import java.util.Map; - -/** - * An implementation of this factory will be initialized once at startup, if the authenticator is enabled - * (see {@link #getSimpleName}). For every request for this mechanism (i.e. after establishing a connection), - * {@link #createSaslServer} will be invoked on the server-side and {@link #createSaslClient} will be invoked - * on the client-side. - * - * Note: - * + Custom authenticators must have a default constructor. - * - * Examples: PlainFactory and KerberosFactory. - */ -public interface AuthenticatorFactory extends AutoCloseable { - - /** - * Name of the mechanism, in upper case. - * - * If this mechanism is present in the list of enabled mechanisms, an instance of this factory is loaded. Note - * that the simple name maybe the same as it's SASL name. - * - * @return mechanism name - */ - String getSimpleName(); - - /** - * Create and get the login user based on the given properties. - * - * @param properties config properties - * @return ugi - * @throws IOException - */ - UserGroupInformation createAndLoginUser(Map<String, ?> properties) throws IOException; - - /** - * The caller is responsible for {@link SaslServer#dispose disposing} the returned SaslServer. - * - * @param ugi ugi - * @param properties config properties - * @return sasl server - * @throws SaslException - */ - SaslServer createSaslServer(UserGroupInformation ugi, Map<String, ?> properties) throws SaslException; - - /** - * The caller is responsible for {@link SaslClient#dispose disposing} the returned SaslClient. - * - * @param ugi ugi - * @param properties config properties - * @return sasl client - * @throws SaslException - */ - SaslClient createSaslClient(UserGroupInformation ugi, Map<String, ?> properties) throws SaslException; - -} diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/rpc/security/SaslProperties.java b/exec/java-exec/src/main/java/org/apache/drill/exec/rpc/security/SaslProperties.java deleted file mode 100644 index 9ed85ce6e..000000000 --- a/exec/java-exec/src/main/java/org/apache/drill/exec/rpc/security/SaslProperties.java +++ /dev/null @@ -1,70 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.drill.exec.rpc.security; - -import javax.security.sasl.Sasl; -import java.util.HashMap; -import java.util.Map; - -public final class SaslProperties { - - /** - * All supported Quality of Protection values which can be negotiated - */ - enum QualityOfProtection { - AUTHENTICATION("auth"), - INTEGRITY("auth-int"), - PRIVACY("auth-conf"); - - public final String saslQop; - - QualityOfProtection(String saslQop) { - this.saslQop = saslQop; - } - - public String getSaslQop() { - return saslQop; - } - } - - /** - * Get's the map of minimum set of SaslProperties required during negotiation process either for encryption - * or authentication - * @param encryptionEnabled - Flag to determine if property needed is for encryption or authentication - * @param wrappedChunkSize - Configured wrappedChunkSize to negotiate for. - * @return Map of SaslProperties which will be used in negotiation. - */ - public static Map<String, String> getSaslProperties(boolean encryptionEnabled, int wrappedChunkSize) { - Map<String, String> saslProps = new HashMap<>(); - - if (encryptionEnabled) { - saslProps.put(Sasl.STRENGTH, "high"); - saslProps.put(Sasl.QOP, QualityOfProtection.PRIVACY.getSaslQop()); - saslProps.put(Sasl.MAX_BUFFER, Integer.toString(wrappedChunkSize)); - saslProps.put(Sasl.POLICY_NOPLAINTEXT, "true"); - } else { - saslProps.put(Sasl.QOP, QualityOfProtection.AUTHENTICATION.getSaslQop()); - } - - return saslProps; - } - - private SaslProperties() { - - } -}
\ No newline at end of file diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/rpc/user/UserClient.java b/exec/java-exec/src/main/java/org/apache/drill/exec/rpc/user/UserClient.java index 131febf32..1504ce936 100644 --- a/exec/java-exec/src/main/java/org/apache/drill/exec/rpc/user/UserClient.java +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/rpc/user/UserClient.java @@ -17,28 +17,24 @@ */ package org.apache.drill.exec.rpc.user; -import java.io.IOException; -import java.util.List; -import java.util.Map; -import java.util.Properties; -import java.util.Set; -import java.util.concurrent.ExecutionException; -import java.util.concurrent.Executor; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.TimeoutException; - -import javax.net.ssl.SSLEngine; -import javax.security.sasl.SaslClient; -import javax.security.sasl.SaslException; - +import com.google.common.base.Strings; +import com.google.common.base.Throwables; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Sets; +import com.google.common.util.concurrent.AbstractCheckedFuture; +import com.google.common.util.concurrent.CheckedFuture; +import com.google.common.util.concurrent.SettableFuture; +import com.google.protobuf.MessageLite; +import io.netty.buffer.ByteBuf; import io.netty.channel.ChannelPipeline; +import io.netty.channel.EventLoopGroup; +import io.netty.channel.socket.SocketChannel; import io.netty.handler.ssl.SslHandler; import org.apache.drill.common.KerberosUtil; import org.apache.drill.common.config.DrillConfig; import org.apache.drill.common.config.DrillProperties; import org.apache.drill.common.exceptions.DrillException; import org.apache.drill.exec.client.InvalidConnectionInfoException; -import org.apache.drill.exec.ssl.SSLConfig; import org.apache.drill.exec.memory.BufferAllocator; import org.apache.drill.exec.proto.CoordinationProtos.DrillbitEndpoint; import org.apache.drill.exec.proto.GeneralRPCProtos.Ack; @@ -76,28 +72,26 @@ import org.apache.drill.exec.rpc.RpcConstants; import org.apache.drill.exec.rpc.RpcException; import org.apache.drill.exec.rpc.RpcOutcomeListener; import org.apache.drill.exec.rpc.security.AuthStringUtil; -import org.apache.drill.exec.rpc.security.AuthenticationOutcomeListener; import org.apache.drill.exec.rpc.security.AuthenticatorFactory; import org.apache.drill.exec.rpc.security.ClientAuthenticatorProvider; -import org.apache.drill.exec.rpc.security.plain.PlainFactory; import org.apache.drill.exec.rpc.security.SaslProperties; +import org.apache.drill.exec.rpc.security.plain.PlainFactory; +import org.apache.drill.exec.ssl.SSLConfig; import org.apache.drill.exec.ssl.SSLConfigBuilder; import org.apache.hadoop.security.UserGroupInformation; import org.slf4j.Logger; -import com.google.common.base.Strings; -import com.google.common.base.Throwables; -import com.google.common.collect.ImmutableList; -import com.google.common.collect.Sets; -import com.google.common.util.concurrent.AbstractCheckedFuture; -import com.google.common.util.concurrent.CheckedFuture; -import com.google.common.util.concurrent.SettableFuture; -import com.google.protobuf.MessageLite; - - -import io.netty.buffer.ByteBuf; -import io.netty.channel.EventLoopGroup; -import io.netty.channel.socket.SocketChannel; +import javax.net.ssl.SSLEngine; +import javax.security.sasl.SaslException; +import java.io.IOException; +import java.util.List; +import java.util.Map; +import java.util.Properties; +import java.util.Set; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Executor; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; public class UserClient extends BasicClient<RpcType, UserClient.UserToBitConnection, UserToBitHandshake, BitToUserHandshake> { @@ -111,12 +105,11 @@ public class UserClient private RpcEndpointInfos serverInfos = null; private Set<RpcType> supportedMethods = null; - // these are used for authentication - private volatile List<String> serverAuthMechanisms = null; - private volatile boolean authComplete = true; private SSLConfig sslConfig; private DrillbitEndpoint endpoint; + private DrillProperties properties; + public UserClient(String clientName, DrillConfig config, Properties properties, boolean supportComplexTypes, BufferAllocator allocator, EventLoopGroup eventLoopGroup, Executor eventExecutor, DrillbitEndpoint endpoint) throws NonTransientRpcException { @@ -133,6 +126,8 @@ public class UserClient throw new InvalidConnectionInfoException(e.getMessage()); } + // Keep a copy of properties in UserClient + this.properties = DrillProperties.createFromProperties(properties); } @Override protected void setupSSL(ChannelPipeline pipe, @@ -195,30 +190,25 @@ public class UserClient SaslSupport.valueOf(Integer.parseInt(properties.getProperty(DrillProperties.TEST_SASL_LEVEL)))); } - if (sslConfig.isUserSslEnabled()) { - try { - connect(hsBuilder.build(), endpoint) - .checkedGet(sslConfig.getHandshakeTimeout(), TimeUnit.MILLISECONDS); - } catch (TimeoutException e) { - String msg = new StringBuilder().append( - "Connecting to the server timed out. This is sometimes due to a mismatch in the SSL configuration between" + - " client and server. [ Exception: ") - .append(e.getMessage()).append("]").toString(); - throw new NonTransientRpcException(msg); - } - } else { - connect(hsBuilder.build(), endpoint).checkedGet(); - } - - // Validate if both client and server are compatible in their security requirements for the connection - validateSaslCompatibility(properties); - - if (serverAuthMechanisms != null) { - try { - authenticate(properties).checkedGet(); - } catch (final SaslException e) { - throw new NonTransientRpcException(e); + try { + if (sslConfig.isUserSslEnabled()) { + connect(hsBuilder.build(), endpoint).checkedGet(sslConfig.getHandshakeTimeout(), TimeUnit.MILLISECONDS); + } else { + connect(hsBuilder.build(), endpoint).checkedGet(); } + } // Treat all authentication related exception as NonTransientException, since in those cases retry by client + // should not happen + catch(TimeoutException e) { + String msg = new StringBuilder().append("Connecting to the server timed out. This is sometimes due to a " + + "mismatch in the SSL configuration between" + " client and server. [ Exception: ").append(e.getMessage()) + .append("]").toString(); + throw new NonTransientRpcException(msg); + } catch (SaslException e) { + throw new NonTransientRpcException(e); + } catch (RpcException e) { + throw e; + } catch (Exception e) { + throw new RpcException(e); } } @@ -226,14 +216,18 @@ public class UserClient * Validate that security requirements from client and Drillbit side is compatible. For example: It verifies if one * side needs authentication / encryption then other side is also configured to support that security properties. * @param properties - DrillClient connection parameters + * @param serverAuthMechs - list of auth mechanisms supported by server * @throws NonTransientRpcException - When DrillClient security requirements doesn't match Drillbit side of security * configurations. */ - private void validateSaslCompatibility(DrillProperties properties) throws NonTransientRpcException { + private void validateSaslCompatibility(DrillProperties properties, List<String> serverAuthMechs) + throws NonTransientRpcException { final boolean clientNeedsEncryption = properties.containsKey(DrillProperties.SASL_ENCRYPT) && Boolean.parseBoolean(properties.getProperty(DrillProperties.SASL_ENCRYPT)); + final boolean serverAuthConfigured = (serverAuthMechs != null); + // Check if client needs encryption and server is not configured for encryption. if (clientNeedsEncryption && !connection.isEncryptionEnabled()) { throw new NonTransientRpcException( @@ -243,7 +237,7 @@ public class UserClient } // Check if client needs encryption and server doesn't support any security mechanisms. - if (clientNeedsEncryption && serverAuthMechanisms == null) { + if (clientNeedsEncryption && !serverAuthConfigured) { throw new NonTransientRpcException( "Client needs encrypted connection but server doesn't support any security mechanisms." + " Please contact your administrator. [Warn: It may be due to wrong config or a security attack in" + @@ -251,7 +245,7 @@ public class UserClient } // Check if client needs authentication and server doesn't support any security mechanisms. - if (clientNeedsAuthExceptPlain(properties) && serverAuthMechanisms == null) { + if (clientNeedsAuthExceptPlain(properties) && !serverAuthConfigured) { throw new NonTransientRpcException( "Client needs authentication but server doesn't support any security mechanisms." + " Please contact your administrator. [Warn: It may be due to wrong config or a security attack in" + @@ -280,15 +274,24 @@ public class UserClient return clientNeedsAuth; } - private CheckedFuture<Void, RpcException> connect(final UserToBitHandshake handshake, + private CheckedFuture<Void, IOException> connect(final UserToBitHandshake handshake, final DrillbitEndpoint endpoint) { final SettableFuture<Void> connectionSettable = SettableFuture.create(); - final CheckedFuture<Void, RpcException> connectionFuture = - new AbstractCheckedFuture<Void, RpcException>(connectionSettable) { - @Override protected RpcException mapException(Exception e) { + final CheckedFuture<Void, IOException> connectionFuture = + new AbstractCheckedFuture<Void, IOException>(connectionSettable) { + @Override protected IOException mapException(Exception e) { + if (e instanceof SaslException) { + return (SaslException)e; + } else if (e instanceof ExecutionException) { + final Throwable cause = Throwables.getRootCause(e); + if (cause instanceof SaslException) { + return (SaslException)cause; + } + } return RpcException.mapException(e); } }; + final RpcConnectionHandler<UserToBitConnection> connectionHandler = new RpcConnectionHandler<UserToBitConnection>() { @Override public void connectionSucceeded(UserToBitConnection connection) { @@ -296,100 +299,40 @@ public class UserClient } @Override public void connectionFailed(FailureType type, Throwable t) { - connectionSettable - .setException(new RpcException(String.format("%s : %s", type.name(), t.getMessage()), t)); - } - }; - - connectAsClient(queryResultHandler.getWrappedConnectionHandler(connectionHandler), handshake, - endpoint.getAddress(), endpoint.getUserPort()); - - return connectionFuture; - } - - private CheckedFuture<Void, SaslException> authenticate(final DrillProperties properties) { - final Map<String, String> propertiesMap = properties.stringPropertiesAsMap(); - - // Set correct QOP property and Strength based on server needs encryption or not. - // If ChunkMode is enabled then negotiate for buffer size equal to wrapChunkSize, - // If ChunkMode is disabled then negotiate for MAX_WRAPPED_SIZE buffer size. - propertiesMap.putAll( - SaslProperties.getSaslProperties(connection.isEncryptionEnabled(), connection.getMaxWrappedSize())); - - final SettableFuture<Void> authSettable = - SettableFuture.create(); // use handleAuthFailure to setException - final CheckedFuture<Void, SaslException> authFuture = - new AbstractCheckedFuture<Void, SaslException>(authSettable) { - - @Override protected SaslException mapException(Exception e) { - if (e instanceof ExecutionException) { - final Throwable cause = Throwables.getRootCause(e); + // Don't wrap NonTransientRpcException inside RpcException, since called should not retry to connect in + // this case + if (t instanceof NonTransientRpcException || t instanceof SaslException) { + connectionSettable.setException(t); + } else if (t instanceof RpcException) { + final Throwable cause = t.getCause(); if (cause instanceof SaslException) { - return new SaslException(String.format("Authentication failed. [Details: %s, Error %s]", - connection.getEncryptionCtxtString(), cause.getMessage()), cause); + connectionSettable.setException(cause); + return; } + connectionSettable.setException(t); + } else { + connectionSettable.setException( + new RpcException(String.format("%s : %s", type.name(), t.getMessage()), t)); } - return new SaslException(String - .format("Authentication failed unexpectedly. [Details: %s, Error %s]", - connection.getEncryptionCtxtString(), e.getMessage()), e); } }; - final AuthenticatorFactory factory; - final String mechanismName; - final UserGroupInformation ugi; - final SaslClient saslClient; - try { - factory = getAuthenticatorFactory(properties); - mechanismName = factory.getSimpleName(); - logger.trace("Will try to authenticate to server using {} mechanism with encryption context {}", - mechanismName, connection.getEncryptionCtxtString()); - - // Update the thread context class loader to current class loader - // See DRILL-6063 for detailed description - final ClassLoader oldThreadCtxtCL = Thread.currentThread().getContextClassLoader(); - final ClassLoader newThreadCtxtCL = this.getClass().getClassLoader(); - Thread.currentThread().setContextClassLoader(newThreadCtxtCL); - - ugi = factory.createAndLoginUser(propertiesMap); - - // Reset the thread context class loader to original one - Thread.currentThread().setContextClassLoader(oldThreadCtxtCL); - - saslClient = factory.createSaslClient(ugi, propertiesMap); - if (saslClient == null) { - throw new SaslException(String.format( - "Cannot initiate authentication using %s mechanism. Insufficient " - + "credentials or selected mechanism doesn't support configured security layers?", - factory.getSimpleName())); - } - connection.setSaslClient(saslClient); - } catch (final IOException e) { - authSettable.setException(e); - return authFuture; - } - - logger.trace("Initiating SASL exchange."); - new AuthenticationOutcomeListener<>(this, connection, RpcType.SASL_MESSAGE, ugi, - new RpcOutcomeListener<Void>() { - @Override public void failed(RpcException ex) { - authSettable.setException(ex); - } - - @Override public void success(Void value, ByteBuf buffer) { - authComplete = true; - authSettable.set(null); - } + connectAsClient(queryResultHandler.getWrappedConnectionHandler(connectionHandler), handshake, + endpoint.getAddress(), endpoint.getUserPort()); - @Override public void interrupted(InterruptedException e) { - authSettable.setException(e); - } - }).initiate(mechanismName); - return authFuture; + return connectionFuture; } - private AuthenticatorFactory getAuthenticatorFactory(final DrillProperties properties) - throws SaslException { + /** + * Get's the authenticator factory for the mechanism required by client if it's supported on the server side too. + * Otherwise it throws {@link SaslException} + * @param properties - client connection properties + * @param serverAuthMechanisms - list of authentication mechanisms supported by server + * @return - {@link AuthenticatorFactory} for the mechanism required by client for authentication + * @throws SaslException - In case of failure + */ + private AuthenticatorFactory getAuthenticatorFactory(final DrillProperties properties, + List<String> serverAuthMechanisms) throws SaslException { final Set<String> mechanismSet = AuthStringUtil.asSet(serverAuthMechanisms); // first, check if a certain mechanism must be used @@ -421,7 +364,7 @@ public class UserClient throw new SaslException(String .format("Server requires authentication using %s. Insufficient credentials?. " + "[Details: %s]. ", - serverAuthMechanisms, connection.getEncryptionCtxtString())); + mechanismSet, connection.getEncryptionCtxtString())); } protected <SEND extends MessageLite, RECEIVE extends MessageLite> void send( @@ -464,7 +407,7 @@ public class UserClient @Override protected void handle(UserToBitConnection connection, int rpcType, ByteBuf pBody, ByteBuf dBody, ResponseSender sender) throws RpcException { - if (!authComplete) { + if (!isAuthComplete()) { // Remote should not be making any requests before authenticating, drop connection throw new RpcException(String.format("Request of type %d is not allowed without authentication. " + "Remote on %s must authenticate before making requests. Connection dropped.", rpcType, @@ -484,8 +427,45 @@ public class UserClient } } - @Override protected void validateHandshake(BitToUserHandshake inbound) throws RpcException { + @Override + protected void prepareSaslHandshake(final RpcConnectionHandler<UserToBitConnection> connectionHandler, + List<String> serverAuthMechanisms) { + try { + final Map<String, String> saslProperties = properties.stringPropertiesAsMap(); + + // Set correct QOP property and Strength based on server needs encryption or not. + // If ChunkMode is enabled then negotiate for buffer size equal to wrapChunkSize, + // If ChunkMode is disabled then negotiate for MAX_WRAPPED_SIZE buffer size. + saslProperties.putAll( + SaslProperties.getSaslProperties(connection.isEncryptionEnabled(), connection.getMaxWrappedSize())); + + final AuthenticatorFactory factory = getAuthenticatorFactory(properties, serverAuthMechanisms); + final String mechanismName = factory.getSimpleName(); + logger.trace("Will try to authenticate to server using {} mechanism with encryption context {}", + mechanismName, connection.getEncryptionCtxtString()); + + // Update the thread context class loader to current class loader + // See DRILL-6063 for detailed description + final ClassLoader oldThreadCtxtCL = Thread.currentThread().getContextClassLoader(); + final ClassLoader newThreadCtxtCL = this.getClass().getClassLoader(); + Thread.currentThread().setContextClassLoader(newThreadCtxtCL); + final UserGroupInformation ugi = factory.createAndLoginUser(saslProperties); + // Reset the thread context class loader to original one + Thread.currentThread().setContextClassLoader(oldThreadCtxtCL); + + startSaslHandshake(connectionHandler, saslProperties, ugi, factory, RpcType.SASL_MESSAGE); + } catch (final IOException e) { + logger.error("Failed while doing setup for starting SASL handshake for connection", connection.getName()); + final Exception ex = new RpcException(String.format("Failed to initiate authentication for connection %s", + connection.getName()), e); + connectionHandler.connectionFailed(RpcConnectionHandler.FailureType.AUTHENTICATION, ex); + } + } + + @Override protected List<String> validateHandshake(BitToUserHandshake inbound) throws RpcException { // logger.debug("Handling handshake from bit to user. {}", inbound); + List<String> serverAuthMechanisms = null; + if (inbound.hasServerInfos()) { serverInfos = inbound.getServerInfos(); } @@ -494,9 +474,9 @@ public class UserClient switch (inbound.getStatus()) { case SUCCESS: break; - case AUTH_REQUIRED: { - authComplete = false; + case AUTH_REQUIRED: serverAuthMechanisms = ImmutableList.copyOf(inbound.getAuthenticationMechanismsList()); + setAuthComplete(false); connection.setEncryption(inbound.hasEncrypted() && inbound.getEncrypted()); if (inbound.hasMaxWrappedSize()) { @@ -506,7 +486,6 @@ public class UserClient .format("Server requires authentication with encryption context %s before proceeding.", connection.getEncryptionCtxtString())); break; - } case AUTH_FAILED: case RPC_VERSION_MISMATCH: case UNKNOWN_FAILURE: @@ -516,6 +495,11 @@ public class UserClient logger.error(errMsg); throw new NonTransientRpcException(errMsg); } + + // Before starting SASL handshake validate if both client and server are compatible in their security + // requirements for the connection + validateSaslCompatibility(properties, serverAuthMechanisms); + return serverAuthMechanisms; } @Override protected UserToBitConnection initRemoteConnection(SocketChannel channel) { |