From 08057df67bb5a8d6530d96cfb4da3d1ecacc6de3 Mon Sep 17 00:00:00 2001 From: Nikita Date: Tue, 14 Jan 2014 16:24:37 +0400 Subject: [PATCH] SocketIOClient.getHandshakeData method added --- .../socketio/HandshakeData.java | 33 +++++++++++++++---- .../socketio/SocketIOClient.java | 7 ++++ .../socketio/handler/AuthorizeHandler.java | 17 +++++----- .../store/pubsub/BaseStoreFactory.java | 2 +- .../store/pubsub/HandshakeMessage.java | 10 +++++- .../socketio/transport/MainBaseClient.java | 9 ++++- .../socketio/transport/NamespaceClient.java | 7 +++- .../socketio/transport/WebSocketClient.java | 5 +-- .../transport/WebSocketTransport.java | 6 ++-- .../socketio/transport/XHRPollingClient.java | 6 ++-- .../transport/XHRPollingTransport.java | 10 +++--- 11 files changed, 82 insertions(+), 30 deletions(-) diff --git a/src/main/java/com/corundumstudio/socketio/HandshakeData.java b/src/main/java/com/corundumstudio/socketio/HandshakeData.java index e9c1553..6b475c5 100644 --- a/src/main/java/com/corundumstudio/socketio/HandshakeData.java +++ b/src/main/java/com/corundumstudio/socketio/HandshakeData.java @@ -21,16 +21,19 @@ import java.util.Date; import java.util.List; import java.util.Map; -public final class HandshakeData implements Serializable { +public class HandshakeData implements Serializable { private static final long serialVersionUID = 1196350300161819978L; - private final Map> headers; - private final InetSocketAddress address; - private final Date time = new Date(); - private final String url; - private final Map> urlParams; - private final boolean xdomain; + private Map> headers; + private InetSocketAddress address; + private Date time = new Date(); + private String url; + private Map> urlParams; + private boolean xdomain; + + public HandshakeData() { + } public HandshakeData(Map> headers, Map> urlParams, InetSocketAddress address, String url, boolean xdomain) { super(); @@ -49,6 +52,14 @@ public final class HandshakeData implements Serializable { return headers; } + public String getSingleHeader(String name) { + List values = headers.get(name); + if (values != null && values.size() == 1) { + return values.iterator().next(); + } + return null; + } + public Date getTime() { return time; } @@ -65,4 +76,12 @@ public final class HandshakeData implements Serializable { return urlParams; } + public String getSingleUrlParam(String name) { + List values = urlParams.get(name); + if (values != null && values.size() == 1) { + return values.iterator().next(); + } + return null; + } + } diff --git a/src/main/java/com/corundumstudio/socketio/SocketIOClient.java b/src/main/java/com/corundumstudio/socketio/SocketIOClient.java index 695e575..f0d3e54 100644 --- a/src/main/java/com/corundumstudio/socketio/SocketIOClient.java +++ b/src/main/java/com/corundumstudio/socketio/SocketIOClient.java @@ -29,6 +29,13 @@ import com.corundumstudio.socketio.store.Store; */ public interface SocketIOClient extends ClientOperations, Store { + /** + * Handshake data used during client connection + * + * @return HandshakeData + */ + HandshakeData getHandshakeData(); + /** * Current client transport protocol * diff --git a/src/main/java/com/corundumstudio/socketio/handler/AuthorizeHandler.java b/src/main/java/com/corundumstudio/socketio/handler/AuthorizeHandler.java index aa28346..645a68f 100644 --- a/src/main/java/com/corundumstudio/socketio/handler/AuthorizeHandler.java +++ b/src/main/java/com/corundumstudio/socketio/handler/AuthorizeHandler.java @@ -34,8 +34,8 @@ import java.net.InetSocketAddress; import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.Set; import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.TimeUnit; import org.slf4j.Logger; @@ -46,7 +46,6 @@ import com.corundumstudio.socketio.Disconnectable; import com.corundumstudio.socketio.HandshakeData; import com.corundumstudio.socketio.SocketIOClient; import com.corundumstudio.socketio.messages.AuthorizeMessage; -import com.corundumstudio.socketio.misc.ConcurrentHashSet; import com.corundumstudio.socketio.namespace.Namespace; import com.corundumstudio.socketio.namespace.NamespacesHub; import com.corundumstudio.socketio.parser.Packet; @@ -65,7 +64,7 @@ public class AuthorizeHandler extends ChannelInboundHandlerAdapter implements Di private final Logger log = LoggerFactory.getLogger(getClass()); private final CancelableScheduler disconnectScheduler; - private final Set authorizedSessionIds = new ConcurrentHashSet(); + private final Map authorizedSessionIds = new ConcurrentHashMap(); private final String connectPath; private final Configuration configuration; @@ -131,8 +130,8 @@ public class AuthorizeHandler extends ChannelInboundHandlerAdapter implements Di } channel.write(new AuthorizeMessage(msg, jsonpParam, origin, sessionId)); - handshake(sessionId); - HandshakeMessage message = new HandshakeMessage(sessionId); + handshake(sessionId, data); + HandshakeMessage message = new HandshakeMessage(sessionId, data); configuration.getStoreFactory().getPubSubStore().publish(PubSubStore.HANDSHAKE, message); log.debug("Handshake authorized for sessionId: {}", sessionId); } else { @@ -169,12 +168,12 @@ public class AuthorizeHandler extends ChannelInboundHandlerAdapter implements Di }); } - public boolean isSessionAuthorized(UUID sessionId) { - return authorizedSessionIds.contains(sessionId); + public HandshakeData getHandshakeData(UUID sessionId) { + return authorizedSessionIds.get(sessionId); } - public void handshake(UUID sessionId) { - authorizedSessionIds.add(sessionId); + public void handshake(UUID sessionId, HandshakeData data) { + authorizedSessionIds.put(sessionId, data); } public void connect(UUID sessionId) { diff --git a/src/main/java/com/corundumstudio/socketio/store/pubsub/BaseStoreFactory.java b/src/main/java/com/corundumstudio/socketio/store/pubsub/BaseStoreFactory.java index 463f0fa..e98b1d2 100644 --- a/src/main/java/com/corundumstudio/socketio/store/pubsub/BaseStoreFactory.java +++ b/src/main/java/com/corundumstudio/socketio/store/pubsub/BaseStoreFactory.java @@ -54,7 +54,7 @@ public abstract class BaseStoreFactory implements StoreFactory { getPubSubStore().subscribe(PubSubStore.HANDSHAKE, new PubSubListener() { @Override public void onMessage(HandshakeMessage msg) { - authorizeHandler.handshake(msg.getSessionId()); + authorizeHandler.handshake(msg.getSessionId(), msg.getData()); log.debug("{} sessionId: {}", PubSubStore.HANDSHAKE, msg.getSessionId()); } }, HandshakeMessage.class); diff --git a/src/main/java/com/corundumstudio/socketio/store/pubsub/HandshakeMessage.java b/src/main/java/com/corundumstudio/socketio/store/pubsub/HandshakeMessage.java index 2edd257..b251ce5 100644 --- a/src/main/java/com/corundumstudio/socketio/store/pubsub/HandshakeMessage.java +++ b/src/main/java/com/corundumstudio/socketio/store/pubsub/HandshakeMessage.java @@ -17,22 +17,30 @@ package com.corundumstudio.socketio.store.pubsub; import java.util.UUID; +import com.corundumstudio.socketio.HandshakeData; + public class HandshakeMessage extends PubSubMessage { private static final long serialVersionUID = 5767127795325210150L; private UUID sessionId; + private HandshakeData data; public HandshakeMessage() { } - public HandshakeMessage(UUID sessionId) { + public HandshakeMessage(UUID sessionId, HandshakeData data) { super(); this.sessionId = sessionId; + this.data = data; } public UUID getSessionId() { return sessionId; } + public HandshakeData getData() { + return data; + } + } diff --git a/src/main/java/com/corundumstudio/socketio/transport/MainBaseClient.java b/src/main/java/com/corundumstudio/socketio/transport/MainBaseClient.java index 4cc6ebe..85c7acf 100644 --- a/src/main/java/com/corundumstudio/socketio/transport/MainBaseClient.java +++ b/src/main/java/com/corundumstudio/socketio/transport/MainBaseClient.java @@ -26,6 +26,7 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; import com.corundumstudio.socketio.DisconnectableHub; +import com.corundumstudio.socketio.HandshakeData; import com.corundumstudio.socketio.SocketIOClient; import com.corundumstudio.socketio.Transport; import com.corundumstudio.socketio.ack.AckManager; @@ -53,14 +54,16 @@ public abstract class MainBaseClient { private final UUID sessionId; private final Transport transport; private Channel channel; + private final HandshakeData handshakeData; public MainBaseClient(UUID sessionId, AckManager ackManager, DisconnectableHub disconnectable, - Transport transport, StoreFactory storeFactory) { + Transport transport, StoreFactory storeFactory, HandshakeData handshakeData) { this.sessionId = sessionId; this.ackManager = ackManager; this.disconnectable = disconnectable; this.transport = transport; this.store = storeFactory.create(sessionId); + this.handshakeData = handshakeData; } public Transport getTransport() { @@ -98,6 +101,10 @@ public abstract class MainBaseClient { } } + public HandshakeData getHandshakeData() { + return handshakeData; + } + public AckManager getAckManager() { return ackManager; } diff --git a/src/main/java/com/corundumstudio/socketio/transport/NamespaceClient.java b/src/main/java/com/corundumstudio/socketio/transport/NamespaceClient.java index eb27cf7..2733801 100644 --- a/src/main/java/com/corundumstudio/socketio/transport/NamespaceClient.java +++ b/src/main/java/com/corundumstudio/socketio/transport/NamespaceClient.java @@ -16,12 +16,12 @@ package com.corundumstudio.socketio.transport; import java.net.SocketAddress; -import java.util.Collection; import java.util.Collections; import java.util.List; import java.util.UUID; import com.corundumstudio.socketio.AckCallback; +import com.corundumstudio.socketio.HandshakeData; import com.corundumstudio.socketio.SocketIOClient; import com.corundumstudio.socketio.Transport; import com.corundumstudio.socketio.namespace.Namespace; @@ -206,4 +206,9 @@ public class NamespaceClient implements SocketIOClient { return namespace.getRooms(this); } + @Override + public HandshakeData getHandshakeData() { + return baseClient.getHandshakeData(); + } + } diff --git a/src/main/java/com/corundumstudio/socketio/transport/WebSocketClient.java b/src/main/java/com/corundumstudio/socketio/transport/WebSocketClient.java index a9ffc11..57c90e0 100644 --- a/src/main/java/com/corundumstudio/socketio/transport/WebSocketClient.java +++ b/src/main/java/com/corundumstudio/socketio/transport/WebSocketClient.java @@ -21,6 +21,7 @@ import io.netty.channel.ChannelFuture; import java.util.UUID; import com.corundumstudio.socketio.DisconnectableHub; +import com.corundumstudio.socketio.HandshakeData; import com.corundumstudio.socketio.Transport; import com.corundumstudio.socketio.ack.AckManager; import com.corundumstudio.socketio.messages.WebSocketPacketMessage; @@ -31,8 +32,8 @@ public class WebSocketClient extends MainBaseClient { public WebSocketClient(Channel channel, AckManager ackManager, DisconnectableHub disconnectable, UUID sessionId, - Transport transport, StoreFactory storeFactory) { - super(sessionId, ackManager, disconnectable, transport, storeFactory); + Transport transport, StoreFactory storeFactory, HandshakeData handshakeData) { + super(sessionId, ackManager, disconnectable, transport, storeFactory, handshakeData); setChannel(channel); } diff --git a/src/main/java/com/corundumstudio/socketio/transport/WebSocketTransport.java b/src/main/java/com/corundumstudio/socketio/transport/WebSocketTransport.java index 495b12e..a200be7 100644 --- a/src/main/java/com/corundumstudio/socketio/transport/WebSocketTransport.java +++ b/src/main/java/com/corundumstudio/socketio/transport/WebSocketTransport.java @@ -39,6 +39,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import com.corundumstudio.socketio.DisconnectableHub; +import com.corundumstudio.socketio.HandshakeData; import com.corundumstudio.socketio.SocketIOClient; import com.corundumstudio.socketio.SocketIOChannelInitializer; import com.corundumstudio.socketio.Transport; @@ -155,14 +156,15 @@ public class WebSocketTransport extends BaseTransport { } private void connectClient(Channel channel, UUID sessionId) { - if (!authorizeHandler.isSessionAuthorized(sessionId)) { + HandshakeData data = authorizeHandler.getHandshakeData(sessionId); + if (data == null) { log.warn("Unauthorized client with sessionId: {}, from ip: {}. Channel closed!", sessionId, channel.remoteAddress()); channel.close(); return; } - WebSocketClient client = new WebSocketClient(channel, ackManager, disconnectableHub, sessionId, getTransport(), storeFactory); + WebSocketClient client = new WebSocketClient(channel, ackManager, disconnectableHub, sessionId, getTransport(), storeFactory, data); channelId2Client.put(channel, client); sessionId2Client.put(sessionId, client); diff --git a/src/main/java/com/corundumstudio/socketio/transport/XHRPollingClient.java b/src/main/java/com/corundumstudio/socketio/transport/XHRPollingClient.java index 58d4a68..36566c6 100644 --- a/src/main/java/com/corundumstudio/socketio/transport/XHRPollingClient.java +++ b/src/main/java/com/corundumstudio/socketio/transport/XHRPollingClient.java @@ -24,6 +24,7 @@ import java.util.UUID; import java.util.concurrent.ConcurrentLinkedQueue; import com.corundumstudio.socketio.DisconnectableHub; +import com.corundumstudio.socketio.HandshakeData; import com.corundumstudio.socketio.Transport; import com.corundumstudio.socketio.ack.AckManager; import com.corundumstudio.socketio.messages.XHRSendPacketsMessage; @@ -37,8 +38,9 @@ public class XHRPollingClient extends MainBaseClient { private final Queue packetQueue = new ConcurrentLinkedQueue(); private String origin; - public XHRPollingClient(AckManager ackManager, DisconnectableHub disconnectable, UUID sessionId, Transport transport, StoreFactory storeFactory) { - super(sessionId, ackManager, disconnectable, transport, storeFactory); + public XHRPollingClient(AckManager ackManager, DisconnectableHub disconnectable, + UUID sessionId, Transport transport, StoreFactory storeFactory, HandshakeData handshakeData) { + super(sessionId, ackManager, disconnectable, transport, storeFactory, handshakeData); } public void bindChannel(Channel channel, String origin) { diff --git a/src/main/java/com/corundumstudio/socketio/transport/XHRPollingTransport.java b/src/main/java/com/corundumstudio/socketio/transport/XHRPollingTransport.java index eafd253..acdb2d4 100644 --- a/src/main/java/com/corundumstudio/socketio/transport/XHRPollingTransport.java +++ b/src/main/java/com/corundumstudio/socketio/transport/XHRPollingTransport.java @@ -38,6 +38,7 @@ import org.slf4j.LoggerFactory; import com.corundumstudio.socketio.Configuration; import com.corundumstudio.socketio.DisconnectableHub; +import com.corundumstudio.socketio.HandshakeData; import com.corundumstudio.socketio.SocketIOClient; import com.corundumstudio.socketio.Transport; import com.corundumstudio.socketio.ack.AckManager; @@ -171,14 +172,15 @@ public class XHRPollingTransport extends BaseTransport { } private void onGet(UUID sessionId, ChannelHandlerContext ctx, String origin) { - if (!authorizeHandler.isSessionAuthorized(sessionId)) { + HandshakeData data = authorizeHandler.getHandshakeData(sessionId); + if (data == null) { sendError(ctx, origin, sessionId); return; } XHRPollingClient client = (XHRPollingClient) sessionId2Client.get(sessionId); if (client == null) { - client = createClient(origin, ctx.channel(), sessionId); + client = createClient(origin, ctx.channel(), sessionId, data); } client.bindChannel(ctx.channel(), origin); @@ -187,8 +189,8 @@ public class XHRPollingTransport extends BaseTransport { scheduleNoop(sessionId); } - private XHRPollingClient createClient(String origin, Channel channel, UUID sessionId) { - XHRPollingClient client = new XHRPollingClient(ackManager, disconnectable, sessionId, Transport.XHRPOLLING, configuration.getStoreFactory()); + private XHRPollingClient createClient(String origin, Channel channel, UUID sessionId, HandshakeData data) { + XHRPollingClient client = new XHRPollingClient(ackManager, disconnectable, sessionId, Transport.XHRPOLLING, configuration.getStoreFactory(), data); sessionId2Client.put(sessionId, client); client.bindChannel(channel, origin);