Blame SOURCES/00295-fix-https-behind-proxy.patch

25c6e4
diff --git a/Lib/httplib.py b/Lib/httplib.py
25c6e4
index 592ee57..b69145b 100644
25c6e4
--- a/Lib/httplib.py
25c6e4
+++ b/Lib/httplib.py
25c6e4
@@ -735,25 +735,40 @@ class HTTPConnection:
25c6e4
         self._tunnel_host = None
25c6e4
         self._tunnel_port = None
25c6e4
         self._tunnel_headers = {}
25c6e4
-
25c6e4
-        self._set_hostport(host, port)
25c6e4
         if strict is not None:
25c6e4
             self.strict = strict
25c6e4
 
25c6e4
+        (self.host, self.port) = self._get_hostport(host, port)
25c6e4
+
25c6e4
+        # This is stored as an instance variable to allow unittests
25c6e4
+        # to replace with a suitable mock
25c6e4
+        self._create_connection = socket.create_connection
25c6e4
+
25c6e4
     def set_tunnel(self, host, port=None, headers=None):
25c6e4
-        """ Sets up the host and the port for the HTTP CONNECT Tunnelling.
25c6e4
+        """ Set up host and port for HTTP CONNECT tunnelling.
25c6e4
+
25c6e4
+        In a connection that uses HTTP Connect tunneling, the host passed to the
25c6e4
+        constructor is used as proxy server that relays all communication to the
25c6e4
+        endpoint passed to set_tunnel. This is done by sending a HTTP CONNECT
25c6e4
+        request to the proxy server when the connection is established.
25c6e4
+
25c6e4
+        This method must be called before the HTTP connection has been
25c6e4
+        established.
25c6e4
 
25c6e4
         The headers argument should be a mapping of extra HTTP headers
25c6e4
         to send with the CONNECT request.
25c6e4
         """
25c6e4
-        self._tunnel_host = host
25c6e4
-        self._tunnel_port = port
25c6e4
+        # Verify if this is required.
25c6e4
+        if self.sock:
25c6e4
+            raise RuntimeError("Can't setup tunnel for established connection.")
25c6e4
+
25c6e4
+        self._tunnel_host, self._tunnel_port = self._get_hostport(host, port)
25c6e4
         if headers:
25c6e4
             self._tunnel_headers = headers
25c6e4
         else:
25c6e4
             self._tunnel_headers.clear()
25c6e4
 
25c6e4
-    def _set_hostport(self, host, port):
25c6e4
+    def _get_hostport(self, host, port):
25c6e4
         if port is None:
25c6e4
             i = host.rfind(':')
25c6e4
             j = host.rfind(']')         # ipv6 addresses have [...]
25c6e4
@@ -770,15 +785,14 @@ class HTTPConnection:
25c6e4
                 port = self.default_port
25c6e4
             if host and host[0] == '[' and host[-1] == ']':
25c6e4
                 host = host[1:-1]
25c6e4
-        self.host = host
25c6e4
-        self.port = port
25c6e4
+        return (host, port)
25c6e4
 
25c6e4
     def set_debuglevel(self, level):
25c6e4
         self.debuglevel = level
25c6e4
 
25c6e4
     def _tunnel(self):
25c6e4
-        self._set_hostport(self._tunnel_host, self._tunnel_port)
25c6e4
-        self.send("CONNECT %s:%d HTTP/1.0\r\n" % (self.host, self.port))
25c6e4
+        self.send("CONNECT %s:%d HTTP/1.0\r\n" % (self._tunnel_host,
25c6e4
+            self._tunnel_port))
25c6e4
         for header, value in self._tunnel_headers.iteritems():
25c6e4
             self.send("%s: %s\r\n" % (header, value))
25c6e4
         self.send("\r\n")
25c6e4
@@ -803,8 +817,8 @@ class HTTPConnection:
25c6e4
 
25c6e4
     def connect(self):
25c6e4
         """Connect to the host and port specified in __init__."""
25c6e4
-        self.sock = socket.create_connection((self.host,self.port),
25c6e4
-                                             self.timeout, self.source_address)
25c6e4
+        self.sock = self._create_connection((self.host,self.port),
25c6e4
+                                           self.timeout, self.source_address)
25c6e4
 
25c6e4
         if self._tunnel_host:
25c6e4
             self._tunnel()
25c6e4
@@ -942,17 +956,24 @@ class HTTPConnection:
25c6e4
                         netloc_enc = netloc.encode("idna")
25c6e4
                     self.putheader('Host', netloc_enc)
25c6e4
                 else:
25c6e4
+                    if self._tunnel_host:
25c6e4
+                        host = self._tunnel_host
25c6e4
+                        port = self._tunnel_port
25c6e4
+                    else:
25c6e4
+                        host = self.host
25c6e4
+                        port = self.port
25c6e4
+
25c6e4
                     try:
25c6e4
-                        host_enc = self.host.encode("ascii")
25c6e4
+                        host_enc = host.encode("ascii")
25c6e4
                     except UnicodeEncodeError:
25c6e4
-                        host_enc = self.host.encode("idna")
25c6e4
+                        host_enc = host.encode("idna")
25c6e4
                     # Wrap the IPv6 Host Header with [] (RFC 2732)
25c6e4
                     if host_enc.find(':') >= 0:
25c6e4
                         host_enc = "[" + host_enc + "]"
25c6e4
-                    if self.port == self.default_port:
25c6e4
+                    if port == self.default_port:
25c6e4
                         self.putheader('Host', host_enc)
25c6e4
                     else:
25c6e4
-                        self.putheader('Host', "%s:%s" % (host_enc, self.port))
25c6e4
+                        self.putheader('Host', "%s:%s" % (host_enc, port))
25c6e4
 
25c6e4
             # note: we are assuming that clients will not attempt to set these
25c6e4
             #       headers since *this* library must deal with the
25c6e4
@@ -1141,7 +1162,7 @@ class HTTP:
25c6e4
         "Accept arguments to set the host/port, since the superclass doesn't."
25c6e4
 
25c6e4
         if host is not None:
25c6e4
-            self._conn._set_hostport(host, port)
25c6e4
+            (self._conn.host, self._conn.port) = self._conn._get_hostport(host, port)
25c6e4
         self._conn.connect()
25c6e4
 
25c6e4
     def getfile(self):
25c6e4
diff --git a/Lib/test/test_httplib.py b/Lib/test/test_httplib.py
25c6e4
index 29af589..9db30cc 100644
25c6e4
--- a/Lib/test/test_httplib.py
25c6e4
+++ b/Lib/test/test_httplib.py
25c6e4
@@ -21,10 +21,12 @@ CERT_selfsigned_pythontestdotnet = os.path.join(here, 'selfsigned_pythontestdotn
25c6e4
 HOST = test_support.HOST
25c6e4
 
25c6e4
 class FakeSocket:
25c6e4
-    def __init__(self, text, fileclass=StringIO.StringIO):
25c6e4
+    def __init__(self, text, fileclass=StringIO.StringIO, host=None, port=None):
25c6e4
         self.text = text
25c6e4
         self.fileclass = fileclass
25c6e4
         self.data = ''
25c6e4
+        self.host = host
25c6e4
+        self.port = port
25c6e4
 
25c6e4
     def sendall(self, data):
25c6e4
         self.data += ''.join(data)
25c6e4
@@ -34,6 +36,9 @@ class FakeSocket:
25c6e4
             raise httplib.UnimplementedFileMode()
25c6e4
         return self.fileclass(self.text)
25c6e4
 
25c6e4
+    def close(self):
25c6e4
+        pass
25c6e4
+
25c6e4
 class EPipeSocket(FakeSocket):
25c6e4
 
25c6e4
     def __init__(self, text, pipe_trigger):
25c6e4
@@ -487,7 +492,11 @@ class OfflineTest(TestCase):
25c6e4
         self.assertEqual(httplib.responses[httplib.NOT_FOUND], "Not Found")
25c6e4
 
25c6e4
 
25c6e4
-class SourceAddressTest(TestCase):
25c6e4
+class TestServerMixin:
25c6e4
+    """A limited socket server mixin.
25c6e4
+
25c6e4
+    This is used by test cases for testing http connection end points.
25c6e4
+    """
25c6e4
     def setUp(self):
25c6e4
         self.serv = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
25c6e4
         self.port = test_support.bind_port(self.serv)
25c6e4
@@ -502,6 +511,7 @@ class SourceAddressTest(TestCase):
25c6e4
         self.serv.close()
25c6e4
         self.serv = None
25c6e4
 
25c6e4
+class SourceAddressTest(TestServerMixin, TestCase):
25c6e4
     def testHTTPConnectionSourceAddress(self):
25c6e4
         self.conn = httplib.HTTPConnection(HOST, self.port,
25c6e4
                 source_address=('', self.source_port))
25c6e4
@@ -518,6 +528,24 @@ class SourceAddressTest(TestCase):
25c6e4
         # for an ssl_wrapped connect() to actually return from.
25c6e4
 
25c6e4
 
25c6e4
+class HTTPTest(TestServerMixin, TestCase):
25c6e4
+    def testHTTPConnection(self):
25c6e4
+        self.conn = httplib.HTTP(host=HOST, port=self.port, strict=None)
25c6e4
+        self.conn.connect()
25c6e4
+        self.assertEqual(self.conn._conn.host, HOST)
25c6e4
+        self.assertEqual(self.conn._conn.port, self.port)
25c6e4
+
25c6e4
+    def testHTTPWithConnectHostPort(self):
25c6e4
+        testhost = 'unreachable.test.domain'
25c6e4
+        testport = '80'
25c6e4
+        self.conn = httplib.HTTP(host=testhost, port=testport)
25c6e4
+        self.conn.connect(host=HOST, port=self.port)
25c6e4
+        self.assertNotEqual(self.conn._conn.host, testhost)
25c6e4
+        self.assertNotEqual(self.conn._conn.port, testport)
25c6e4
+        self.assertEqual(self.conn._conn.host, HOST)
25c6e4
+        self.assertEqual(self.conn._conn.port, self.port)
25c6e4
+
25c6e4
+
25c6e4
 class TimeoutTest(TestCase):
25c6e4
     PORT = None
25c6e4
 
25c6e4
@@ -716,13 +744,54 @@ class HTTPSTest(TestCase):
25c6e4
             c = httplib.HTTPSConnection(hp, context=context)
25c6e4
             self.assertEqual(h, c.host)
25c6e4
             self.assertEqual(p, c.port)
25c6e4
- 
25c6e4
+
25c6e4
+class TunnelTests(TestCase):
25c6e4
+    def test_connect(self):
25c6e4
+        response_text = (
25c6e4
+            'HTTP/1.0 200 OK\r\n\r\n'   # Reply to CONNECT
25c6e4
+            'HTTP/1.1 200 OK\r\n'       # Reply to HEAD
25c6e4
+            'Content-Length: 42\r\n\r\n'
25c6e4
+        )
25c6e4
+
25c6e4
+        def create_connection(address, timeout=None, source_address=None):
25c6e4
+            return FakeSocket(response_text, host=address[0], port=address[1])
25c6e4
+
25c6e4
+        conn = httplib.HTTPConnection('proxy.com')
25c6e4
+        conn._create_connection = create_connection
25c6e4
+
25c6e4
+        # Once connected, we should not be able to tunnel anymore
25c6e4
+        conn.connect()
25c6e4
+        self.assertRaises(RuntimeError, conn.set_tunnel, 'destination.com')
25c6e4
+
25c6e4
+        # But if close the connection, we are good.
25c6e4
+        conn.close()
25c6e4
+        conn.set_tunnel('destination.com')
25c6e4
+        conn.request('HEAD', '/', '')
25c6e4
+
25c6e4
+        self.assertEqual(conn.sock.host, 'proxy.com')
25c6e4
+        self.assertEqual(conn.sock.port, 80)
25c6e4
+        self.assertIn('CONNECT destination.com', conn.sock.data)
25c6e4
+        # issue22095
25c6e4
+        self.assertNotIn('Host: destination.com:None', conn.sock.data)
25c6e4
+        # issue22095
25c6e4
+
25c6e4
+        self.assertNotIn('Host: proxy.com', conn.sock.data)
25c6e4
+
25c6e4
+        conn.close()
25c6e4
+
25c6e4
+        conn.request('PUT', '/', '')
25c6e4
+        self.assertEqual(conn.sock.host, 'proxy.com')
25c6e4
+        self.assertEqual(conn.sock.port, 80)
25c6e4
+        self.assertTrue('CONNECT destination.com' in conn.sock.data)
25c6e4
+        self.assertTrue('Host: destination.com' in conn.sock.data)
25c6e4
+
25c6e4
 
25c6e4
 
25c6e4
 @test_support.reap_threads
25c6e4
 def test_main(verbose=None):
25c6e4
     test_support.run_unittest(HeaderTests, OfflineTest, BasicTest, TimeoutTest,
25c6e4
-                              HTTPSTest, SourceAddressTest)
25c6e4
+                              HTTPTest, HTTPSTest, SourceAddressTest,
25c6e4
+                               TunnelTests)
25c6e4
 
25c6e4
 if __name__ == '__main__':
25c6e4
     test_main()