diff --git a/requests/sessions.py b/requests/sessions.py index ef3f22b..a0a23ee 100644 --- a/requests/sessions.py +++ b/requests/sessions.py @@ -89,6 +89,23 @@ def merge_hooks(request_hooks, session_hooks, dict_class=OrderedDict): class SessionRedirectMixin(object): + + def should_strip_auth(self, old_url, new_url): + """Decide whether Authorization header should be removed when redirecting""" + old_parsed = urlparse(old_url) + new_parsed = urlparse(new_url) + if old_parsed.hostname != new_parsed.hostname: + return True + # Special case: allow http -> https redirect when using the standard + # ports. This isn't specified by RFC 7235, but is kept to avoid + # breaking backwards compatibility with older versions of requests + # that allowed any redirects on the same host. + if (old_parsed.scheme == 'http' and old_parsed.port in (80, None) + and new_parsed.scheme == 'https' and new_parsed.port in (443, None)): + return False + # Standard case: root URI must match + return old_parsed.port != new_parsed.port or old_parsed.scheme != new_parsed.scheme + def resolve_redirects(self, resp, req, stream=False, timeout=None, verify=True, cert=None, proxies=None): """Receives a Response. Returns a generator of Responses.""" @@ -209,14 +226,10 @@ class SessionRedirectMixin(object): headers = prepared_request.headers url = prepared_request.url - if 'Authorization' in headers: + if 'Authorization' in headers and self.should_strip_auth(response.request.url, url): # If we get redirected to a new host, we should strip out any # authentication headers. - original_parsed = urlparse(response.request.url) - redirect_parsed = urlparse(url) - - if (original_parsed.hostname != redirect_parsed.hostname): - del headers['Authorization'] + del headers['Authorization'] # .netrc might have more auth for us on our new host. new_auth = get_netrc_auth(url) if self.trust_env else None diff --git a/test_requests.py b/test_requests.py index 15406a2..e19b436 100755 --- a/test_requests.py +++ b/test_requests.py @@ -991,6 +991,27 @@ class RequestsTestCase(unittest.TestCase): assert h1 == h2 + def test_should_strip_auth_host_change(self): + s = requests.Session() + assert s.should_strip_auth('http://example.com/foo', 'http://another.example.com/') + + def test_should_strip_auth_http_downgrade(self): + s = requests.Session() + assert s.should_strip_auth('https://example.com/foo', 'http://example.com/bar') + + def test_should_strip_auth_https_upgrade(self): + s = requests.Session() + assert not s.should_strip_auth('http://example.com/foo', 'https://example.com/bar') + assert not s.should_strip_auth('http://example.com:80/foo', 'https://example.com/bar') + assert not s.should_strip_auth('http://example.com/foo', 'https://example.com:443/bar') + # Non-standard ports should trigger stripping + assert s.should_strip_auth('http://example.com:8080/foo', 'https://example.com/bar') + assert s.should_strip_auth('http://example.com/foo', 'https://example.com:8443/bar') + + def test_should_strip_auth_port_change(self): + s = requests.Session() + assert s.should_strip_auth('http://example.com:1234/foo', 'https://example.com:4321/bar') + def test_manual_redirect_with_partial_body_read(self): s = requests.Session() r1 = s.get(httpbin('redirect/2'), allow_redirects=False, stream=True)