Blob Blame History Raw
From f9fcf18105845fbb933925ae7b0a2f1033f75127 Mon Sep 17 00:00:00 2001
From: Eduardo Otubo <otubo@redhat.com>
Date: Wed, 20 May 2020 10:11:14 +0200
Subject: [PATCH] url_helper: read_file_or_url should pass headers param into
 readurl (#66)

RH-Author: Eduardo Otubo <otubo@redhat.com>
Message-id: <20200519105653.20249-1-otubo@redhat.com>
Patchwork-id: 96613
O-Subject: [RHEL-7.8.z cloud-init PATCH] url_helper: read_file_or_url should pass headers param into readurl (#66)
Bugzilla: 1832177
RH-Acked-by: Cathy Avery <cavery@redhat.com>
RH-Acked-by: Mohammed Gamal <mgamal@redhat.com>
RH-Acked-by: Vitaly Kuznetsov <vkuznets@redhat.com>

commit f69d33a723b805fec3ee70c3a6127c8cadcb02d8
Author: Chad Smith <chad.smith@canonical.com>
Date:   Mon Dec 2 16:24:18 2019 -0700

    url_helper: read_file_or_url should pass headers param into readurl (#66)

    Headers param was accidentally omitted and no longer passed through to
    readurl due to a previous commit.

    To avoid this omission of params in the future, drop positional param
    definitions from read_file_or_url and pass all kwargs through to readurl
    when we are not operating on a file.

    In util:read_seeded, correct the case where invalid positional param
    file_retries was being passed into read_file_or_url.

    Also drop duplicated file:// prefix addition from read_seeded because
    read_file_or_url does that work anyway.

    LP: #1854084

Signed-off-by: Eduardo Otubo <otubo@redhat.com>
Signed-off-by: Miroslav Rezanina <mrezanin@redhat.com>
---
 cloudinit/sources/helpers/azure.py                 |  6 ++-
 cloudinit/tests/test_url_helper.py                 | 52 ++++++++++++++++++++++
 cloudinit/url_helper.py                            | 47 +++++++++++++++----
 cloudinit/user_data.py                             |  2 +-
 cloudinit/util.py                                  | 15 ++-----
 .../unittests/test_datasource/test_azure_helper.py | 18 +++++---
 6 files changed, 112 insertions(+), 28 deletions(-)

diff --git a/cloudinit/sources/helpers/azure.py b/cloudinit/sources/helpers/azure.py
index c2a57cc..b99c484 100755
--- a/cloudinit/sources/helpers/azure.py
+++ b/cloudinit/sources/helpers/azure.py
@@ -103,14 +103,16 @@ class AzureEndpointHttpClient(object):
         if secure:
             headers = self.headers.copy()
             headers.update(self.extra_secure_headers)
-        return url_helper.read_file_or_url(url, headers=headers)
+        return url_helper.read_file_or_url(url, headers=headers, timeout=5,
+                                           retries=10)
 
     def post(self, url, data=None, extra_headers=None):
         headers = self.headers
         if extra_headers is not None:
             headers = self.headers.copy()
             headers.update(extra_headers)
-        return url_helper.read_file_or_url(url, data=data, headers=headers)
+        return url_helper.read_file_or_url(url, data=data, headers=headers,
+                                           timeout=5, retries=10)
 
 
 class GoalState(object):
diff --git a/cloudinit/tests/test_url_helper.py b/cloudinit/tests/test_url_helper.py
index aa9f3ec..e883ddc 100644
--- a/cloudinit/tests/test_url_helper.py
+++ b/cloudinit/tests/test_url_helper.py
@@ -4,6 +4,7 @@ from cloudinit.url_helper import (
     NOT_FOUND, UrlError, oauth_headers, read_file_or_url, retry_on_url_exc)
 from cloudinit.tests.helpers import CiTestCase, mock, skipIf
 from cloudinit import util
+from cloudinit import version
 
 import httpretty
 import requests
@@ -17,6 +18,9 @@ except ImportError:
     _missing_oauthlib_dep = True
 
 
+M_PATH = 'cloudinit.url_helper.'
+
+
 class TestOAuthHeaders(CiTestCase):
 
     def test_oauth_headers_raises_not_implemented_when_oathlib_missing(self):
@@ -67,6 +71,54 @@ class TestReadFileOrUrl(CiTestCase):
         self.assertEqual(result.contents, data)
         self.assertEqual(str(result), data.decode('utf-8'))
 
+    @mock.patch(M_PATH + 'readurl')
+    def test_read_file_or_url_passes_params_to_readurl(self, m_readurl):
+        """read_file_or_url passes all params through to readurl."""
+        url = 'http://hostname/path'
+        response = 'This is my url content\n'
+        m_readurl.return_value = response
+        params = {'url': url, 'timeout': 1, 'retries': 2,
+                  'headers': {'somehdr': 'val'},
+                  'data': 'data', 'sec_between': 1,
+                  'ssl_details': {'cert_file': '/path/cert.pem'},
+                  'headers_cb': 'headers_cb', 'exception_cb': 'exception_cb'}
+        self.assertEqual(response, read_file_or_url(**params))
+        params.pop('url')  # url is passed in as a positional arg
+        self.assertEqual([mock.call(url, **params)], m_readurl.call_args_list)
+
+    def test_wb_read_url_defaults_honored_by_read_file_or_url_callers(self):
+        """Readurl param defaults used when unspecified by read_file_or_url
+
+        Param defaults tested are as follows:
+            retries: 0, additional headers None beyond default, method: GET,
+            data: None, check_status: True and allow_redirects: True
+        """
+        url = 'http://hostname/path'
+
+        m_response = mock.MagicMock()
+
+        class FakeSession(requests.Session):
+            def request(cls, **kwargs):
+                self.assertEqual(
+                    {'url': url, 'allow_redirects': True, 'method': 'GET',
+                     'headers': {
+                         'User-Agent': 'Cloud-Init/%s' % (
+                             version.version_string())}},
+                    kwargs)
+                return m_response
+
+        with mock.patch(M_PATH + 'requests.Session') as m_session:
+            error = requests.exceptions.HTTPError('broke')
+            m_session.side_effect = [error, FakeSession()]
+            # assert no retries and check_status == True
+            with self.assertRaises(UrlError) as context_manager:
+                response = read_file_or_url(url)
+            self.assertEqual('broke', str(context_manager.exception))
+            # assert default headers, method, url and allow_redirects True
+            # Success on 2nd call with FakeSession
+            response = read_file_or_url(url)
+        self.assertEqual(m_response, response._response)
+
 
 class TestRetryOnUrlExc(CiTestCase):
 
diff --git a/cloudinit/url_helper.py b/cloudinit/url_helper.py
index a951b8b..beb6873 100644
--- a/cloudinit/url_helper.py
+++ b/cloudinit/url_helper.py
@@ -81,14 +81,19 @@ def combine_url(base, *add_ons):
     return url
 
 
-def read_file_or_url(url, timeout=5, retries=10,
-                     headers=None, data=None, sec_between=1, ssl_details=None,
-                     headers_cb=None, exception_cb=None):
+def read_file_or_url(url, **kwargs):
+    """Wrapper function around readurl to allow passing a file path as url.
+
+    When url is not a local file path, passthrough any kwargs to readurl.
+
+    In the case of parameter passthrough to readurl, default values for some
+    parameters. See: call-signature of readurl in this module for param docs.
+    """
     url = url.lstrip()
     if url.startswith("/"):
         url = "file://%s" % url
     if url.lower().startswith("file://"):
-        if data:
+        if kwargs.get("data"):
             LOG.warning("Unable to post data to file resource %s", url)
         file_path = url[len("file://"):]
         try:
@@ -101,10 +106,7 @@ def read_file_or_url(url, timeout=5, retries=10,
             raise UrlError(cause=e, code=code, headers=None, url=url)
         return FileResponse(file_path, contents=contents)
     else:
-        return readurl(url, timeout=timeout, retries=retries,
-                       headers_cb=headers_cb, data=data,
-                       sec_between=sec_between, ssl_details=ssl_details,
-                       exception_cb=exception_cb)
+        return readurl(url, **kwargs)
 
 
 # Made to have same accessors as UrlResponse so that the
@@ -201,6 +203,35 @@ def readurl(url, data=None, timeout=None, retries=0, sec_between=1,
             check_status=True, allow_redirects=True, exception_cb=None,
             session=None, infinite=False, log_req_resp=True,
             request_method=None):
+    """Wrapper around requests.Session to read the url and retry if necessary
+
+    :param url: Mandatory url to request.
+    :param data: Optional form data to post the URL. Will set request_method
+        to 'POST' if present.
+    :param timeout: Timeout in seconds to wait for a response
+    :param retries: Number of times to retry on exception if exception_cb is
+        None or exception_cb returns True for the exception caught. Default is
+        to fail with 0 retries on exception.
+    :param sec_between: Default 1: amount of seconds passed to time.sleep
+        between retries. None or -1 means don't sleep.
+    :param headers: Optional dict of headers to send during request
+    :param headers_cb: Optional callable returning a dict of values to send as
+        headers during request
+    :param ssl_details: Optional dict providing key_file, ca_certs, and
+        cert_file keys for use on in ssl connections.
+    :param check_status: Optional boolean set True to raise when HTTPError
+        occurs. Default: True.
+    :param allow_redirects: Optional boolean passed straight to Session.request
+        as 'allow_redirects'. Default: True.
+    :param exception_cb: Optional callable which accepts the params
+        msg and exception and returns a boolean True if retries are permitted.
+    :param session: Optional exiting requests.Session instance to reuse.
+    :param infinite: Bool, set True to retry indefinitely. Default: False.
+    :param log_req_resp: Set False to turn off verbose debug messages.
+    :param request_method: String passed as 'method' to Session.request.
+        Typically GET, or POST. Default: POST if data is provided, GET
+        otherwise.
+    """
     url = _cleanurl(url)
     req_args = {
         'url': url,
diff --git a/cloudinit/user_data.py b/cloudinit/user_data.py
index ed83d2d..15af1da 100644
--- a/cloudinit/user_data.py
+++ b/cloudinit/user_data.py
@@ -224,7 +224,7 @@ class UserDataProcessor(object):
                 content = util.load_file(include_once_fn)
             else:
                 try:
-                    resp = read_file_or_url(include_url,
+                    resp = read_file_or_url(include_url, timeout=5, retries=10,
                                             ssl_details=self.ssl_details)
                     if include_once_on and resp.ok():
                         util.write_file(include_once_fn, resp.contents,
diff --git a/cloudinit/util.py b/cloudinit/util.py
index 2c9ac66..db9a229 100644
--- a/cloudinit/util.py
+++ b/cloudinit/util.py
@@ -966,13 +966,6 @@ def load_yaml(blob, default=None, allowed=(dict,)):
 
 
 def read_seeded(base="", ext="", timeout=5, retries=10, file_retries=0):
-    if base.startswith("/"):
-        base = "file://%s" % base
-
-    # default retries for file is 0. for network is 10
-    if base.startswith("file://"):
-        retries = file_retries
-
     if base.find("%s") >= 0:
         ud_url = base % ("user-data" + ext)
         md_url = base % ("meta-data" + ext)
@@ -980,14 +973,14 @@ def read_seeded(base="", ext="", timeout=5, retries=10, file_retries=0):
         ud_url = "%s%s%s" % (base, "user-data", ext)
         md_url = "%s%s%s" % (base, "meta-data", ext)
 
-    md_resp = url_helper.read_file_or_url(md_url, timeout, retries,
-                                          file_retries)
+    md_resp = url_helper.read_file_or_url(md_url, timeout=timeout,
+                                          retries=retries)
     md = None
     if md_resp.ok():
         md = load_yaml(decode_binary(md_resp.contents), default={})
 
-    ud_resp = url_helper.read_file_or_url(ud_url, timeout, retries,
-                                          file_retries)
+    ud_resp = url_helper.read_file_or_url(ud_url, timeout=timeout,
+                                          retries=retries)
     ud = None
     if ud_resp.ok():
         ud = ud_resp.contents
diff --git a/tests/unittests/test_datasource/test_azure_helper.py b/tests/unittests/test_datasource/test_azure_helper.py
index 7ad5cc1..007df09 100644
--- a/tests/unittests/test_datasource/test_azure_helper.py
+++ b/tests/unittests/test_datasource/test_azure_helper.py
@@ -231,8 +231,10 @@ class TestAzureEndpointHttpClient(CiTestCase):
         response = client.get(url, secure=False)
         self.assertEqual(1, self.read_file_or_url.call_count)
         self.assertEqual(self.read_file_or_url.return_value, response)
-        self.assertEqual(mock.call(url, headers=self.regular_headers),
-                         self.read_file_or_url.call_args)
+        self.assertEqual(
+            mock.call(url, headers=self.regular_headers, retries=10,
+                      timeout=5),
+            self.read_file_or_url.call_args)
 
     def test_secure_get(self):
         url = 'MyTestUrl'
@@ -246,8 +248,10 @@ class TestAzureEndpointHttpClient(CiTestCase):
         response = client.get(url, secure=True)
         self.assertEqual(1, self.read_file_or_url.call_count)
         self.assertEqual(self.read_file_or_url.return_value, response)
-        self.assertEqual(mock.call(url, headers=expected_headers),
-                         self.read_file_or_url.call_args)
+        self.assertEqual(
+            mock.call(url, headers=expected_headers, retries=10,
+                      timeout=5),
+            self.read_file_or_url.call_args)
 
     def test_post(self):
         data = mock.MagicMock()
@@ -257,7 +261,8 @@ class TestAzureEndpointHttpClient(CiTestCase):
         self.assertEqual(1, self.read_file_or_url.call_count)
         self.assertEqual(self.read_file_or_url.return_value, response)
         self.assertEqual(
-            mock.call(url, data=data, headers=self.regular_headers),
+            mock.call(url, data=data, headers=self.regular_headers, retries=10,
+                      timeout=5),
             self.read_file_or_url.call_args)
 
     def test_post_with_extra_headers(self):
@@ -269,7 +274,8 @@ class TestAzureEndpointHttpClient(CiTestCase):
         expected_headers = self.regular_headers.copy()
         expected_headers.update(extra_headers)
         self.assertEqual(
-            mock.call(mock.ANY, data=mock.ANY, headers=expected_headers),
+            mock.call(mock.ANY, data=mock.ANY, headers=expected_headers,
+                      retries=10, timeout=5),
             self.read_file_or_url.call_args)
 
 
-- 
1.8.3.1