001/**
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017package org.apache.activemq.transport.auto;
018
import java.io.IOException;
import java.io.InputStream;
import java.net.Socket;
import java.net.URI;
import java.net.URISyntaxException;
import java.nio.ByteBuffer;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.Future;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicInteger;

import javax.net.ServerSocketFactory;

import org.apache.activemq.broker.BrokerService;
import org.apache.activemq.broker.BrokerServiceAware;
import org.apache.activemq.openwire.OpenWireFormatFactory;
import org.apache.activemq.transport.InactivityIOException;
import org.apache.activemq.transport.Transport;
import org.apache.activemq.transport.TransportFactory;
import org.apache.activemq.transport.TransportServer;
import org.apache.activemq.transport.protocol.AmqpProtocolVerifier;
import org.apache.activemq.transport.protocol.MqttProtocolVerifier;
import org.apache.activemq.transport.protocol.OpenWireProtocolVerifier;
import org.apache.activemq.transport.protocol.ProtocolVerifier;
import org.apache.activemq.transport.protocol.StompProtocolVerifier;
import org.apache.activemq.transport.tcp.TcpTransport;
import org.apache.activemq.transport.tcp.TcpTransport.InitBuffer;
import org.apache.activemq.transport.tcp.TcpTransportFactory;
import org.apache.activemq.transport.tcp.TcpTransportServer;
import org.apache.activemq.util.FactoryFinder;
import org.apache.activemq.util.IOExceptionSupport;
import org.apache.activemq.util.IntrospectionSupport;
import org.apache.activemq.util.ServiceStopper;
import org.apache.activemq.wireformat.WireFormat;
import org.apache.activemq.wireformat.WireFormatFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
063
064/**
065 * A TCP based implementation of {@link TransportServer}
066 */
067public class AutoTcpTransportServer extends TcpTransportServer {
068
069    private static final Logger LOG = LoggerFactory.getLogger(AutoTcpTransportServer.class);
070
071    protected Map<String, Map<String, Object>> wireFormatOptions;
072    protected Map<String, Object> autoTransportOptions;
073    protected Set<String> enabledProtocols;
074    protected final Map<String, ProtocolVerifier> protocolVerifiers = new ConcurrentHashMap<String, ProtocolVerifier>();
075
076    protected BrokerService brokerService;
077
078    protected final ThreadPoolExecutor newConnectionExecutor;
079    protected final ThreadPoolExecutor protocolDetectionExecutor;
080    protected int maxConnectionThreadPoolSize = Integer.MAX_VALUE;
081    protected int protocolDetectionTimeOut = 30000;
082
083    private static final FactoryFinder TRANSPORT_FACTORY_FINDER = new FactoryFinder("META-INF/services/org/apache/activemq/transport/");
084    private final ConcurrentMap<String, TransportFactory> transportFactories = new ConcurrentHashMap<String, TransportFactory>();
085
086    private static final FactoryFinder WIREFORMAT_FACTORY_FINDER = new FactoryFinder("META-INF/services/org/apache/activemq/wireformat/");
087
088    public WireFormatFactory findWireFormatFactory(String scheme, Map<String, Map<String, Object>> options) throws IOException {
089        WireFormatFactory wff = null;
090        try {
091            wff = (WireFormatFactory)WIREFORMAT_FACTORY_FINDER.newInstance(scheme);
092            if (options != null) {
093                final Map<String, Object> wfOptions = new HashMap<>();
094                if (options.get(AutoTransportUtils.ALL) != null) {
095                    wfOptions.putAll(options.get(AutoTransportUtils.ALL));
096                }
097                if (options.get(scheme) != null) {
098                    wfOptions.putAll(options.get(scheme));
099                }
100                IntrospectionSupport.setProperties(wff, wfOptions);
101            }
102            if (wff instanceof OpenWireFormatFactory) {
103                protocolVerifiers.put(AutoTransportUtils.OPENWIRE, new OpenWireProtocolVerifier((OpenWireFormatFactory) wff));
104            }
105            return wff;
106        } catch (Throwable e) {
107           throw IOExceptionSupport.create("Could not create wire format factory for: " + scheme + ", reason: " + e, e);
108        }
109    }
110
111    public TransportFactory findTransportFactory(String scheme, Map<String, ?> options) throws IOException {
112        scheme = append(scheme, "nio");
113        scheme = append(scheme, "ssl");
114
115        if (scheme.isEmpty()) {
116            scheme = "tcp";
117        }
118
119        TransportFactory tf = transportFactories.get(scheme);
120        if (tf == null) {
121            // Try to load if from a META-INF property.
122            try {
123                tf = (TransportFactory)TRANSPORT_FACTORY_FINDER.newInstance(scheme);
124                if (options != null) {
125                    IntrospectionSupport.setProperties(tf, options);
126                }
127                transportFactories.put(scheme, tf);
128            } catch (Throwable e) {
129                throw IOExceptionSupport.create("Transport scheme NOT recognized: [" + scheme + "]", e);
130            }
131        }
132        return tf;
133    }
134
135    protected String append(String currentScheme, String scheme) {
136        if (this.getBindLocation().getScheme().contains(scheme)) {
137            if (!currentScheme.isEmpty()) {
138                currentScheme += "+";
139            }
140            currentScheme += scheme;
141        }
142        return currentScheme;
143    }
144
145    /**
146     * @param transportFactory
147     * @param location
148     * @param serverSocketFactory
149     * @throws IOException
150     * @throws URISyntaxException
151     */
152    public AutoTcpTransportServer(TcpTransportFactory transportFactory,
153            URI location, ServerSocketFactory serverSocketFactory, BrokerService brokerService,
154            Set<String> enabledProtocols)
155            throws IOException, URISyntaxException {
156        super(transportFactory, location, serverSocketFactory);
157
158        //Use an executor service here to handle new connections.  Setting the max number
159        //of threads to the maximum number of connections the thread count isn't unbounded
160        newConnectionExecutor = new ThreadPoolExecutor(maxConnectionThreadPoolSize,
161                maxConnectionThreadPoolSize,
162                30L, TimeUnit.SECONDS,
163                new LinkedBlockingQueue<Runnable>());
164        //allow the thread pool to shrink if the max number of threads isn't needed
165        //and the pool can grow and shrink as needed if contention is high
166        newConnectionExecutor.allowCoreThreadTimeOut(true);
167
168        //Executor for waiting for bytes to detection of protocol
169        protocolDetectionExecutor = new ThreadPoolExecutor(maxConnectionThreadPoolSize,
170                maxConnectionThreadPoolSize,
171                30L, TimeUnit.SECONDS,
172                new LinkedBlockingQueue<Runnable>());
173        //allow the thread pool to shrink if the max number of threads isn't needed
174        protocolDetectionExecutor.allowCoreThreadTimeOut(true);
175
176        this.brokerService = brokerService;
177        this.enabledProtocols = enabledProtocols;
178        initProtocolVerifiers();
179    }
180
181    public int getMaxConnectionThreadPoolSize() {
182        return maxConnectionThreadPoolSize;
183    }
184
185    /**
186     * Set the number of threads to be used for processing connections.  Defaults
187     * to Integer.MAX_SIZE.  Set this value to be lower to reduce the
188     * number of simultaneous connection attempts.  If not set then the maximum number of
189     * threads will generally be controlled by the transport maxConnections setting:
190     * {@link TcpTransportServer#setMaximumConnections(int)}.
191     *<p>
192     * Note that this setter controls two thread pools because connection attempts
193     * require 1 thread to start processing the connection and another thread to read from the
194     * socket and to detect the protocol. Two threads are needed because some transports
195     * block on socket read so the first thread needs to be able to abort the second thread on timeout.
196     * Therefore this setting will set each thread pool to the size passed in essentially giving
197     * 2 times as many potential threads as the value set.
198     *<p>
199     * Both thread pools will close idle threads after a period of time
200     * essentially allowing the thread pools to grow and shrink dynamically based on load.
201     *
202     * @see {@link TcpTransportServer#setMaximumConnections(int)}.
203     * @param maxConnectionThreadPoolSize
204     */
205    public void setMaxConnectionThreadPoolSize(int maxConnectionThreadPoolSize) {
206        this.maxConnectionThreadPoolSize = maxConnectionThreadPoolSize;
207        newConnectionExecutor.setCorePoolSize(maxConnectionThreadPoolSize);
208        newConnectionExecutor.setMaximumPoolSize(maxConnectionThreadPoolSize);
209        protocolDetectionExecutor.setCorePoolSize(maxConnectionThreadPoolSize);
210        protocolDetectionExecutor.setMaximumPoolSize(maxConnectionThreadPoolSize);
211    }
212
213    public void setProtocolDetectionTimeOut(int protocolDetectionTimeOut) {
214        this.protocolDetectionTimeOut = protocolDetectionTimeOut;
215    }
216
217    @Override
218    public void setWireFormatFactory(WireFormatFactory factory) {
219        super.setWireFormatFactory(factory);
220        initOpenWireProtocolVerifier();
221    }
222
223    protected void initProtocolVerifiers() {
224        initOpenWireProtocolVerifier();
225
226        if (isAllProtocols() || enabledProtocols.contains(AutoTransportUtils.AMQP)) {
227            protocolVerifiers.put(AutoTransportUtils.AMQP, new AmqpProtocolVerifier());
228        }
229        if (isAllProtocols() || enabledProtocols.contains(AutoTransportUtils.STOMP)) {
230            protocolVerifiers.put(AutoTransportUtils.STOMP, new StompProtocolVerifier());
231        }
232        if (isAllProtocols()|| enabledProtocols.contains(AutoTransportUtils.MQTT)) {
233            protocolVerifiers.put(AutoTransportUtils.MQTT, new MqttProtocolVerifier());
234        }
235    }
236
237    protected void initOpenWireProtocolVerifier() {
238        if (isAllProtocols() || enabledProtocols.contains(AutoTransportUtils.OPENWIRE)) {
239            OpenWireProtocolVerifier owpv;
240            if (wireFormatFactory instanceof OpenWireFormatFactory) {
241                owpv = new OpenWireProtocolVerifier((OpenWireFormatFactory) wireFormatFactory);
242            } else {
243                owpv = new OpenWireProtocolVerifier(new OpenWireFormatFactory());
244            }
245            protocolVerifiers.put(AutoTransportUtils.OPENWIRE, owpv);
246        }
247    }
248
249    protected boolean isAllProtocols() {
250        return enabledProtocols == null || enabledProtocols.isEmpty();
251    }
252
253    @Override
254    protected void handleSocket(final Socket socket) {
255        final AutoTcpTransportServer server = this;
256        //This needs to be done in a new thread because
257        //the socket might be waiting on the client to send bytes
258        //doHandleSocket can't complete until the protocol can be detected
259        newConnectionExecutor.submit(new Runnable() {
260            @Override
261            public void run() {
262                server.doHandleSocket(socket);
263            }
264        });
265    }
266
267    @Override
268    protected TransportInfo configureTransport(final TcpTransportServer server, final Socket socket) throws Exception {
269        final InputStream is = socket.getInputStream();
270        final AtomicInteger readBytes = new AtomicInteger(0);
271        final ByteBuffer data = ByteBuffer.allocate(8);
272
273        // We need to peak at the first 8 bytes of the buffer to detect the protocol
274        Future<?> future = protocolDetectionExecutor.submit(new Runnable() {
275            @Override
276            public void run() {
277                try {
278                    do {
279                        //will block until enough bytes or read or a timeout
280                        //and the socket is closed
281                        int read = is.read();
282                        if (read == -1) {
283                            throw new IOException("Connection failed, stream is closed.");
284                        }
285                        data.put((byte) read);
286                        readBytes.incrementAndGet();
287                    } while (readBytes.get() < 8 && !Thread.interrupted());
288                } catch (Exception e) {
289                    throw new IllegalStateException(e);
290                }
291            }
292        });
293
294        try {
295            //If this fails and throws an exception and the socket will be closed
296            waitForProtocolDetectionFinish(future, readBytes);
297        } finally {
298            //call cancel in case task didn't complete
299            future.cancel(true);
300        }
301        data.flip();
302        ProtocolInfo protocolInfo = detectProtocol(data.array());
303
304        InitBuffer initBuffer = new InitBuffer(readBytes.get(), ByteBuffer.allocate(readBytes.get()));
305        initBuffer.buffer.put(data.array());
306
307        if (protocolInfo.detectedTransportFactory instanceof BrokerServiceAware) {
308            ((BrokerServiceAware) protocolInfo.detectedTransportFactory).setBrokerService(brokerService);
309        }
310
311        WireFormat format = protocolInfo.detectedWireFormatFactory.createWireFormat();
312        Transport transport = createTransport(socket, format, protocolInfo.detectedTransportFactory, initBuffer);
313
314        return new TransportInfo(format, transport, protocolInfo.detectedTransportFactory);
315    }
316
317    protected void waitForProtocolDetectionFinish(final Future<?> future, final AtomicInteger readBytes) throws Exception {
318        try {
319            //Wait for protocolDetectionTimeOut if defined
320            if (protocolDetectionTimeOut > 0) {
321                future.get(protocolDetectionTimeOut, TimeUnit.MILLISECONDS);
322            } else {
323                future.get();
324            }
325        } catch (TimeoutException e) {
326            throw new InactivityIOException("Client timed out before wire format could be detected. " +
327                    " 8 bytes are required to detect the protocol but only: " + readBytes.get() + " byte(s) were sent.");
328        }
329    }
330
331    /**
332     * @param socket
333     * @param format
334     * @param detectedTransportFactory
335     * @return
336     */
337    protected TcpTransport createTransport(Socket socket, WireFormat format,
338            TcpTransportFactory detectedTransportFactory, InitBuffer initBuffer) throws IOException {
339        return new TcpTransport(format, socket, initBuffer);
340    }
341
342    public void setWireFormatOptions(Map<String, Map<String, Object>> wireFormatOptions) {
343        this.wireFormatOptions = wireFormatOptions;
344    }
345
346    public void setEnabledProtocols(Set<String> enabledProtocols) {
347        this.enabledProtocols = enabledProtocols;
348    }
349
350    public void setAutoTransportOptions(Map<String, Object> autoTransportOptions) {
351        this.autoTransportOptions = autoTransportOptions;
352        if (autoTransportOptions.get("protocols") != null) {
353            this.enabledProtocols = AutoTransportUtils.parseProtocols((String) autoTransportOptions.get("protocols"));
354        }
355    }
356    @Override
357    protected void doStop(ServiceStopper stopper) throws Exception {
358        if (newConnectionExecutor != null) {
359            newConnectionExecutor.shutdownNow();
360            try {
361                if (!newConnectionExecutor.awaitTermination(3, TimeUnit.SECONDS)) {
362                    LOG.warn("Auto Transport newConnectionExecutor didn't shutdown cleanly");
363                }
364            } catch (InterruptedException e) {
365            }
366        }
367        if (protocolDetectionExecutor != null) {
368            protocolDetectionExecutor.shutdownNow();
369            try {
370                if (!protocolDetectionExecutor.awaitTermination(3, TimeUnit.SECONDS)) {
371                    LOG.warn("Auto Transport protocolDetectionExecutor didn't shutdown cleanly");
372                }
373            } catch (InterruptedException e) {
374            }
375        }
376        super.doStop(stopper);
377    }
378
379    protected ProtocolInfo detectProtocol(byte[] buffer) throws IOException {
380        TcpTransportFactory detectedTransportFactory = transportFactory;
381        WireFormatFactory detectedWireFormatFactory = wireFormatFactory;
382
383        boolean found = false;
384        for (String scheme : protocolVerifiers.keySet()) {
385            if (protocolVerifiers.get(scheme).isProtocol(buffer)) {
386                LOG.debug("Detected protocol " + scheme);
387                detectedWireFormatFactory = findWireFormatFactory(scheme, wireFormatOptions);
388
389                if (scheme.equals("default")) {
390                    scheme = "";
391                }
392
393                detectedTransportFactory = (TcpTransportFactory) findTransportFactory(scheme, transportOptions);
394                found = true;
395                break;
396            }
397        }
398
399        if (!found) {
400            throw new IllegalStateException("Could not detect the wire format");
401        }
402
403        return new ProtocolInfo(detectedTransportFactory, detectedWireFormatFactory);
404
405    }
406
407    protected class ProtocolInfo {
408        public final TcpTransportFactory detectedTransportFactory;
409        public final WireFormatFactory detectedWireFormatFactory;
410
411        public ProtocolInfo(TcpTransportFactory detectedTransportFactory,
412                WireFormatFactory detectedWireFormatFactory) {
413            super();
414            this.detectedTransportFactory = detectedTransportFactory;
415            this.detectedWireFormatFactory = detectedWireFormatFactory;
416        }
417    }
418
419}