001/*
002 * The contents of this file are subject to the terms of the Common Development and
003 * Distribution License (the License). You may not use this file except in compliance with the
004 * License.
005 *
006 * You can obtain a copy of the License at legal/CDDLv1.0.txt. See the License for the
007 * specific language governing permission and limitations under the License.
008 *
009 * When distributing Covered Software, include this CDDL Header Notice in each file and include
010 * the License file at legal/CDDLv1.0.txt. If applicable, add the following below the CDDL
011 * Header, with the fields enclosed by brackets [] replaced by your own identifying
012 * information: "Portions Copyright [year] [name of copyright owner]".
013 *
014 * Copyright 2008 Sun Microsystems, Inc.
015 * Portions Copyright 2015 ForgeRock AS.
016 */
017
018package org.opends.admin.ads.util;
019
020import java.io.IOException;
021import java.net.Socket;
022import java.net.InetAddress;
023import java.util.Map;
024import java.util.HashMap;
025
026import java.security.GeneralSecurityException;
027
028import javax.net.SocketFactory;
029import javax.net.ssl.KeyManager;
030import javax.net.ssl.SSLContext;
031import javax.net.ssl.SSLSocketFactory;
032import javax.net.ssl.SSLKeyException;
033import javax.net.ssl.TrustManager;
034
/**
 * An {@link SSLSocketFactory} whose SSL context is built from an explicit
 * {@link TrustManager} and/or {@link KeyManager}.
 * <p>
 * The managers used by {@link #getDefault()} are looked up for the calling
 * thread; they are registered with
 * {@link #setCurrentThreadTrustManager(TrustManager, KeyManager)},
 * {@link #setThreadTrustManager(TrustManager, Thread)} and
 * {@link #setThreadKeyManager(KeyManager, Thread)}. All static registry methods
 * are {@code synchronized} on the class.
 */
public class TrustedSocketFactory extends SSLSocketFactory
{
  /**
   * Trust manager registered per thread. An entry stays until it is replaced or
   * cleared by registering {@code null} for that thread.
   */
  private static final Map<Thread, TrustManager> hmTrustManager = new HashMap<>();
  /**
   * Key manager registered per thread. An entry stays until it is replaced or
   * cleared by registering {@code null} for that thread.
   */
  private static final Map<Thread, KeyManager> hmKeyManager = new HashMap<>();

  /** Cached factories, indexed by the trust manager they were created with. */
  private static final Map<TrustManager, SocketFactory> hmDefaultFactoryTm = new HashMap<>();
  /** Cached factories, indexed by the key manager they were created with. */
  private static final Map<KeyManager, SocketFactory> hmDefaultFactoryKm = new HashMap<>();

  /** SSL socket factory built lazily by {@link #getInnerFactory()}; guarded by {@code this}. */
  private SSLSocketFactory innerFactory;
  /** Trust manager for the SSL context, or {@code null} to let {@link SSLContext#init} pick its default. */
  private final TrustManager trustManager;
  /** Key manager for the SSL context, or {@code null} to let {@link SSLContext#init} pick its default. */
  private final KeyManager keyManager;

  /**
   * Constructor of the TrustedSocketFactory.
   *
   * @param trustManager the trust manager to use (may be {@code null}).
   * @param keyManager   the key manager to use (may be {@code null}).
   */
  public TrustedSocketFactory(TrustManager trustManager, KeyManager keyManager)
  {
    this.trustManager = trustManager;
    this.keyManager   = keyManager;
  }

  /**
   * Sets the provided trust and key manager for the operations in the
   * current thread. Passing {@code null} clears the corresponding registration.
   *
   * @param trustManager
   *          the trust manager to use.
   * @param keyManager
   *          the key manager to use.
   */
  public static synchronized void setCurrentThreadTrustManager(
      TrustManager trustManager, KeyManager keyManager)
  {
    Thread current = Thread.currentThread();
    setThreadTrustManager(trustManager, current);
    setThreadKeyManager(keyManager, current);
  }

  /**
   * Sets the provided trust manager for the operations in the provided thread.
   * Any previously registered trust manager for that thread is removed together
   * with the factory cached under it. Passing {@code null} only clears.
   *
   * @param trustManager the trust manager to use.
   * @param thread the thread where we want to use the provided trust manager.
   */
  public static synchronized void setThreadTrustManager(
      TrustManager trustManager, Thread thread)
  {
    TrustManager previous = hmTrustManager.remove(thread);
    if (previous != null)
    {
      hmDefaultFactoryTm.remove(previous);
    }
    if (trustManager != null)
    {
      hmTrustManager.put(thread, trustManager);
    }
  }

  /**
   * Sets the provided key manager for the operations in the provided thread.
   * Any previously registered key manager for that thread is removed together
   * with the factory cached under it. Passing {@code null} only clears.
   *
   * @param keyManager the key manager to use.
   * @param thread the thread where we want to use the provided key manager.
   */
  public static synchronized void setThreadKeyManager(
      KeyManager keyManager, Thread thread)
  {
    KeyManager previous = hmKeyManager.remove(thread);
    if (previous != null)
    {
      hmDefaultFactoryKm.remove(previous);
    }
    if (keyManager != null)
    {
      hmKeyManager.put(thread, keyManager);
    }
  }

  //
  // SocketFactory implementation
  //

  /**
   * Returns a socket factory configured with the trust manager and key manager
   * registered for the calling thread.
   * <p>
   * This method does not read the {@code ssl.SocketFactory.provider} security
   * property. When the calling thread has no registered managers, a new factory
   * without explicit managers is returned on every call. Otherwise a cached
   * factory is reused only if it was created with exactly the calling thread's
   * (trust manager, key manager) pair; if not, a new one is created and cached.
   *
   * @return the socket factory for the calling thread's managers
   */
  public static synchronized SocketFactory getDefault()
  {
    Thread currentThread = Thread.currentThread();
    TrustManager trustManager = hmTrustManager.get(currentThread);
    KeyManager   keyManager   = hmKeyManager.get(currentThread);

    if (trustManager == null && keyManager == null)
    {
      return new TrustedSocketFactory(null, null);
    }

    // A cache entry indexed by one manager may have been created by another
    // thread together with a different partner manager, so it is only reused
    // after checking that both managers match.
    SocketFactory cached = null;
    if (trustManager != null)
    {
      cached = hmDefaultFactoryTm.get(trustManager);
    }
    if (!isFactoryFor(cached, trustManager, keyManager) && keyManager != null)
    {
      cached = hmDefaultFactoryKm.get(keyManager);
    }
    if (isFactoryFor(cached, trustManager, keyManager))
    {
      return cached;
    }

    TrustedSocketFactory created = new TrustedSocketFactory(trustManager, keyManager);
    if (trustManager != null)
    {
      hmDefaultFactoryTm.put(trustManager, created);
    }
    if (keyManager != null)
    {
      hmDefaultFactoryKm.put(keyManager, created);
    }
    return created;
  }

  /**
   * Tells whether {@code factory} is a {@code TrustedSocketFactory} created with
   * exactly the given managers (compared by identity, {@code null} included).
   */
  private static boolean isFactoryFor(SocketFactory factory,
      TrustManager trustManager, KeyManager keyManager)
  {
    if (!(factory instanceof TrustedSocketFactory))
    {
      return false;
    }
    TrustedSocketFactory candidate = (TrustedSocketFactory) factory;
    return candidate.trustManager == trustManager
        && candidate.keyManager == keyManager;
  }

  /**
   * Creates an unconnected SSL socket from the inner factory.
   * <p>
   * The inherited {@link SocketFactory#createSocket()} throws instead of
   * creating a socket; delegating it lets callers create the socket first and
   * connect it afterwards (for instance with a connect timeout).
   */
  @Override
  public Socket createSocket() throws IOException
  {
    return getInnerFactory().createSocket();
  }

  /** {@inheritDoc} */
  @Override
  public Socket createSocket(InetAddress address, int port) throws IOException
  {
    return getInnerFactory().createSocket(address, port);
  }

  /** {@inheritDoc} */
  @Override
  public Socket createSocket(InetAddress address, int port,
      InetAddress clientAddress, int clientPort) throws IOException
  {
    return getInnerFactory().createSocket(address, port, clientAddress,
        clientPort);
  }

  /** {@inheritDoc} */
  @Override
  public Socket createSocket(String host, int port) throws IOException
  {
    return getInnerFactory().createSocket(host, port);
  }

  /** {@inheritDoc} */
  @Override
  public Socket createSocket(String host, int port, InetAddress clientHost,
      int clientPort) throws IOException
  {
    return getInnerFactory().createSocket(host, port, clientHost, clientPort);
  }

  /** {@inheritDoc} */
  @Override
  public Socket createSocket(Socket s, String host, int port, boolean autoClose)
      throws IOException
  {
    return getInnerFactory().createSocket(s, host, port, autoClose);
  }

  /**
   * {@inheritDoc}
   * <p>
   * Returns an empty array if the SSL context cannot be created, because this
   * method cannot propagate the {@link IOException}.
   */
  @Override
  public String[] getDefaultCipherSuites()
  {
    try
    {
      return getInnerFactory().getDefaultCipherSuites();
    }
    catch (IOException x)
    {
      return new String[0];
    }
  }

  /**
   * {@inheritDoc}
   * <p>
   * Returns an empty array if the SSL context cannot be created, because this
   * method cannot propagate the {@link IOException}.
   */
  @Override
  public String[] getSupportedCipherSuites()
  {
    try
    {
      return getInnerFactory().getSupportedCipherSuites();
    }
    catch (IOException x)
    {
      return new String[0];
    }
  }

  /**
   * Returns the SSL socket factory backing this instance, creating it on first
   * use from an {@link SSLContext} initialized with this instance's managers.
   * Synchronized because instances are shared between threads via the static
   * caches used by {@link #getDefault()}.
   *
   * @return the inner SSL socket factory
   * @throws SSLKeyException if the SSL context cannot be created or initialized
   */
  private synchronized SSLSocketFactory getInnerFactory() throws IOException
  {
    if (innerFactory == null)
    {
      // The generic "TLS" protocol lets the provider negotiate any enabled TLS
      // version; a "TLSv1" context only enables TLS 1.0 for clients, which
      // recent JDKs disable by default.
      String algorithm = "TLS";
      TrustManager[] tm = trustManager != null ? new TrustManager[] { trustManager } : null;
      KeyManager[] km = keyManager != null ? new KeyManager[] { keyManager } : null;

      try
      {
        SSLContext sslCtx = SSLContext.getInstance(algorithm);
        sslCtx.init(km, tm, new java.security.SecureRandom());
        innerFactory = sslCtx.getSocketFactory();
      }
      catch (GeneralSecurityException x)
      {
        SSLKeyException xx = new SSLKeyException("Failed to create SSLContext for " +
            algorithm);
        xx.initCause(x);
        throw xx;
      }
    }
    return innerFactory;
  }
}
277