Blob Blame History Raw
From 68b3718124b63fdf0c077452b559f0fccb01200d Mon Sep 17 00:00:00 2001
From: Eduardo Otubo <otubo@redhat.com>
Date: Tue, 5 May 2020 08:08:32 +0200
Subject: [PATCH 5/5] ec2: Add support for AWS IMDS v2 (session-oriented) (#55)

RH-Author: Eduardo Otubo <otubo@redhat.com>
Message-id: <20200504085238.25884-6-otubo@redhat.com>
Patchwork-id: 96245
O-Subject: [RHEL-7.8.z cloud-init PATCH 5/5] ec2: Add support for AWS IMDS v2 (session-oriented) (#55)
Bugzilla: 1827207
RH-Acked-by: Cathy Avery <cavery@redhat.com>
RH-Acked-by: Mohammed Gamal <mgamal@redhat.com>
RH-Acked-by: Vitaly Kuznetsov <vkuznets@redhat.com>

commit 4bc399e0cd0b7e9177f948aecd49f6b8323ff30b
Author: Ryan Harper <ryan.harper@canonical.com>
Date:   Fri Nov 22 21:05:44 2019 -0600

    ec2: Add support for AWS IMDS v2 (session-oriented) (#55)

    * ec2: Add support for AWS IMDS v2 (session-oriented)

    AWS now supports a new version of fetching Instance Metadata[1].

    Update cloud-init's ec2 utility functions and update ec2 derived
    datasources accordingly.  For DataSourceEc2 (versus ec2-look-alikes)
    cloud-init will issue the PUT request to obtain an API token for
    the maximum lifetime and then all subsequent interactions with the
    IMDS will include the token in the header.

    If the API token endpoint is unreachable on Ec2 platform, log a
    warning and fallback to using IMDS v1 and which does not use
    session tokens when communicating with the Instance metadata
    service.

    We handle read errors, typically seen if the IMDS is beyond one
    etwork hop (IMDSv2 responses have a ttl=1), by setting the api token
    to a disabled value and then using IMDSv1 paths.

    To support token-based headers, ec2_utils functions were updated
    to support custom headers_cb and exception_cb callback functions
    so Ec2 could store, or refresh API tokens in the event of token
    becoming stale.

    [1] https://docs.aws.amazon.com/AWSEC2/latest/ \
    UserGuide/ec2-instance-metadata.html \
    #instance-metadata-v2-how-it-works

Signed-off-by: Eduardo Otubo <otubo@redhat.com>
Signed-off-by: Miroslav Rezanina <mrezanin@redhat.com>
---
 cloudinit/ec2_utils.py                             |  37 +++--
 cloudinit/sources/DataSourceCloudStack.py          |   2 +-
 cloudinit/sources/DataSourceEc2.py                 | 166 ++++++++++++++++++---
 cloudinit/sources/DataSourceExoscale.py            |   2 +-
 cloudinit/sources/DataSourceMAAS.py                |   2 +-
 cloudinit/sources/DataSourceOpenStack.py           |   2 +-
 cloudinit/url_helper.py                            |  15 +-
 tests/unittests/test_datasource/test_cloudstack.py |  21 ++-
 tests/unittests/test_datasource/test_ec2.py        |   6 +-
 9 files changed, 201 insertions(+), 52 deletions(-)

diff --git a/cloudinit/ec2_utils.py b/cloudinit/ec2_utils.py
index 3b7b17f..57708c1 100644
--- a/cloudinit/ec2_utils.py
+++ b/cloudinit/ec2_utils.py
@@ -134,25 +134,28 @@ class MetadataMaterializer(object):
         return joined
 
 
-def _skip_retry_on_codes(status_codes, _request_args, cause):
+def skip_retry_on_codes(status_codes, _request_args, cause):
     """Returns False if cause.code is in status_codes."""
     return cause.code not in status_codes
 
 
 def get_instance_userdata(api_version='latest',
                           metadata_address='http://169.254.169.254',
-                          ssl_details=None, timeout=5, retries=5):
+                          ssl_details=None, timeout=5, retries=5,
+                          headers_cb=None, exception_cb=None):
     ud_url = url_helper.combine_url(metadata_address, api_version)
     ud_url = url_helper.combine_url(ud_url, 'user-data')
     user_data = ''
     try:
-        # It is ok for userdata to not exist (thats why we are stopping if
-        # NOT_FOUND occurs) and just in that case returning an empty string.
-        exception_cb = functools.partial(_skip_retry_on_codes,
-                                         SKIP_USERDATA_CODES)
+        if not exception_cb:
+            # It is ok for userdata to not exist (thats why we are stopping if
+            # NOT_FOUND occurs) and just in that case returning an empty
+            # string.
+            exception_cb = functools.partial(skip_retry_on_codes,
+                                             SKIP_USERDATA_CODES)
         response = url_helper.read_file_or_url(
             ud_url, ssl_details=ssl_details, timeout=timeout,
-            retries=retries, exception_cb=exception_cb)
+            retries=retries, exception_cb=exception_cb, headers_cb=headers_cb)
         user_data = response.contents
     except url_helper.UrlError as e:
         if e.code not in SKIP_USERDATA_CODES:
@@ -165,11 +168,13 @@ def get_instance_userdata(api_version='latest',
 def _get_instance_metadata(tree, api_version='latest',
                            metadata_address='http://169.254.169.254',
                            ssl_details=None, timeout=5, retries=5,
-                           leaf_decoder=None):
+                           leaf_decoder=None, headers_cb=None,
+                           exception_cb=None):
     md_url = url_helper.combine_url(metadata_address, api_version, tree)
     caller = functools.partial(
         url_helper.read_file_or_url, ssl_details=ssl_details,
-        timeout=timeout, retries=retries)
+        timeout=timeout, retries=retries, headers_cb=headers_cb,
+        exception_cb=exception_cb)
 
     def mcaller(url):
         return caller(url).contents
@@ -191,22 +196,28 @@ def _get_instance_metadata(tree, api_version='latest',
 def get_instance_metadata(api_version='latest',
                           metadata_address='http://169.254.169.254',
                           ssl_details=None, timeout=5, retries=5,
-                          leaf_decoder=None):
+                          leaf_decoder=None, headers_cb=None,
+                          exception_cb=None):
     # Note, 'meta-data' explicitly has trailing /.
     # this is required for CloudStack (LP: #1356855)
     return _get_instance_metadata(tree='meta-data/', api_version=api_version,
                                   metadata_address=metadata_address,
                                   ssl_details=ssl_details, timeout=timeout,
-                                  retries=retries, leaf_decoder=leaf_decoder)
+                                  retries=retries, leaf_decoder=leaf_decoder,
+                                  headers_cb=headers_cb,
+                                  exception_cb=exception_cb)
 
 
 def get_instance_identity(api_version='latest',
                           metadata_address='http://169.254.169.254',
                           ssl_details=None, timeout=5, retries=5,
-                          leaf_decoder=None):
+                          leaf_decoder=None, headers_cb=None,
+                          exception_cb=None):
     return _get_instance_metadata(tree='dynamic/instance-identity',
                                   api_version=api_version,
                                   metadata_address=metadata_address,
                                   ssl_details=ssl_details, timeout=timeout,
-                                  retries=retries, leaf_decoder=leaf_decoder)
+                                  retries=retries, leaf_decoder=leaf_decoder,
+                                  headers_cb=headers_cb,
+                                  exception_cb=exception_cb)
 # vi: ts=4 expandtab
diff --git a/cloudinit/sources/DataSourceCloudStack.py b/cloudinit/sources/DataSourceCloudStack.py
index d4b758f..6bd2efe 100644
--- a/cloudinit/sources/DataSourceCloudStack.py
+++ b/cloudinit/sources/DataSourceCloudStack.py
@@ -93,7 +93,7 @@ class DataSourceCloudStack(sources.DataSource):
         urls = [uhelp.combine_url(self.metadata_address,
                                   'latest/meta-data/instance-id')]
         start_time = time.time()
-        url = uhelp.wait_for_url(
+        url, _response = uhelp.wait_for_url(
             urls=urls, max_wait=url_params.max_wait_seconds,
             timeout=url_params.timeout_seconds, status_cb=LOG.warn)
 
diff --git a/cloudinit/sources/DataSourceEc2.py b/cloudinit/sources/DataSourceEc2.py
index 9ccf2cd..fbe8f3f 100644
--- a/cloudinit/sources/DataSourceEc2.py
+++ b/cloudinit/sources/DataSourceEc2.py
@@ -27,6 +27,10 @@ SKIP_METADATA_URL_CODES = frozenset([uhelp.NOT_FOUND])
 STRICT_ID_PATH = ("datasource", "Ec2", "strict_id")
 STRICT_ID_DEFAULT = "warn"
 
+API_TOKEN_ROUTE = 'latest/api/token'
+API_TOKEN_DISABLED = '_ec2_disable_api_token'
+AWS_TOKEN_TTL_SECONDS = '21600'
+
 
 class CloudNames(object):
     ALIYUN = "aliyun"
@@ -59,6 +63,7 @@ class DataSourceEc2(sources.DataSource):
     url_max_wait = 120
     url_timeout = 50
 
+    _api_token = None  # API token for accessing the metadata service
     _network_config = sources.UNSET  # Used to cache calculated network cfg v1
 
     # Whether we want to get network configuration from the metadata service.
@@ -132,11 +137,12 @@ class DataSourceEc2(sources.DataSource):
         min_metadata_version.
         """
         # Assumes metadata service is already up
+        url_tmpl = '{0}/{1}/meta-data/instance-id'
+        headers = self._get_headers()
         for api_ver in self.extended_metadata_versions:
-            url = '{0}/{1}/meta-data/instance-id'.format(
-                self.metadata_address, api_ver)
+            url = url_tmpl.format(self.metadata_address, api_ver)
             try:
-                resp = uhelp.readurl(url=url)
+                resp = uhelp.readurl(url=url, headers=headers)
             except uhelp.UrlError as e:
                 LOG.debug('url %s raised exception %s', url, e)
             else:
@@ -156,12 +162,39 @@ class DataSourceEc2(sources.DataSource):
                 # setup self.identity. So we need to do that now.
                 api_version = self.get_metadata_api_version()
                 self.identity = ec2.get_instance_identity(
-                    api_version, self.metadata_address).get('document', {})
+                    api_version, self.metadata_address,
+                    headers_cb=self._get_headers,
+                    exception_cb=self._refresh_stale_aws_token_cb).get(
+                        'document', {})
             return self.identity.get(
                 'instanceId', self.metadata['instance-id'])
         else:
             return self.metadata['instance-id']
 
+    def _maybe_fetch_api_token(self, mdurls, timeout=None, max_wait=None):
+        if self.cloud_name != CloudNames.AWS:
+            return
+
+        urls = []
+        url2base = {}
+        url_path = API_TOKEN_ROUTE
+        request_method = 'PUT'
+        for url in mdurls:
+            cur = '{0}/{1}'.format(url, url_path)
+            urls.append(cur)
+            url2base[cur] = url
+
+        # use the self._status_cb to check for Read errors, which means
+        # we can't reach the API token URL, so we should disable IMDSv2
+        LOG.debug('Fetching Ec2 IMDSv2 API Token')
+        url, response = uhelp.wait_for_url(
+            urls=urls, max_wait=1, timeout=1, status_cb=self._status_cb,
+            headers_cb=self._get_headers, request_method=request_method)
+
+        if url and response:
+            self._api_token = response
+            return url2base[url]
+
     def wait_for_metadata_service(self):
         mcfg = self.ds_cfg
 
@@ -183,27 +216,39 @@ class DataSourceEc2(sources.DataSource):
             LOG.warning("Empty metadata url list! using default list")
             mdurls = self.metadata_urls
 
-        urls = []
-        url2base = {}
-        for url in mdurls:
-            cur = '{0}/{1}/meta-data/instance-id'.format(
-                url, self.min_metadata_version)
-            urls.append(cur)
-            url2base[cur] = url
-
-        start_time = time.time()
-        url = uhelp.wait_for_url(
-            urls=urls, max_wait=url_params.max_wait_seconds,
-            timeout=url_params.timeout_seconds, status_cb=LOG.warn)
-
-        if url:
-            self.metadata_address = url2base[url]
+        # try the api token path first
+        metadata_address = self._maybe_fetch_api_token(mdurls)
+        if not metadata_address:
+            if self._api_token == API_TOKEN_DISABLED:
+                LOG.warning('Retrying with IMDSv1')
+            # if we can't get a token, use instance-id path
+            urls = []
+            url2base = {}
+            url_path = '{ver}/meta-data/instance-id'.format(
+                ver=self.min_metadata_version)
+            request_method = 'GET'
+            for url in mdurls:
+                cur = '{0}/{1}'.format(url, url_path)
+                urls.append(cur)
+                url2base[cur] = url
+
+            start_time = time.time()
+            url, _ = uhelp.wait_for_url(
+                urls=urls, max_wait=url_params.max_wait_seconds,
+                timeout=url_params.timeout_seconds, status_cb=LOG.warning,
+                headers_cb=self._get_headers, request_method=request_method)
+
+            if url:
+                metadata_address = url2base[url]
+
+        if metadata_address:
+            self.metadata_address = metadata_address
             LOG.debug("Using metadata source: '%s'", self.metadata_address)
         else:
             LOG.critical("Giving up on md from %s after %s seconds",
                          urls, int(time.time() - start_time))
 
-        return bool(url)
+        return bool(metadata_address)
 
     def device_name_to_device(self, name):
         # Consult metadata service, that has
@@ -349,14 +394,22 @@ class DataSourceEc2(sources.DataSource):
             return {}
         api_version = self.get_metadata_api_version()
         crawled_metadata = {}
+        if self.cloud_name == CloudNames.AWS:
+            exc_cb = self._refresh_stale_aws_token_cb
+            exc_cb_ud = self._skip_or_refresh_stale_aws_token_cb
+        else:
+            exc_cb = exc_cb_ud = None
         try:
             crawled_metadata['user-data'] = ec2.get_instance_userdata(
-                api_version, self.metadata_address)
+                api_version, self.metadata_address,
+                headers_cb=self._get_headers, exception_cb=exc_cb_ud)
             crawled_metadata['meta-data'] = ec2.get_instance_metadata(
-                api_version, self.metadata_address)
+                api_version, self.metadata_address,
+                headers_cb=self._get_headers, exception_cb=exc_cb)
             if self.cloud_name == CloudNames.AWS:
                 identity = ec2.get_instance_identity(
-                    api_version, self.metadata_address)
+                    api_version, self.metadata_address,
+                    headers_cb=self._get_headers, exception_cb=exc_cb)
                 crawled_metadata['dynamic'] = {'instance-identity': identity}
         except Exception:
             util.logexc(
@@ -366,6 +419,73 @@ class DataSourceEc2(sources.DataSource):
         crawled_metadata['_metadata_api_version'] = api_version
         return crawled_metadata
 
+    def _refresh_api_token(self, seconds=AWS_TOKEN_TTL_SECONDS):
+        """Request new metadata API token.
+        @param seconds: The lifetime of the token in seconds
+
+        @return: The API token or None if unavailable.
+        """
+        if self.cloud_name != CloudNames.AWS:
+            return None
+        LOG.debug("Refreshing Ec2 metadata API token")
+        request_header = {'X-aws-ec2-metadata-token-ttl-seconds': seconds}
+        token_url = '{}/{}'.format(self.metadata_address, API_TOKEN_ROUTE)
+        try:
+            response = uhelp.readurl(
+                token_url, headers=request_header, request_method="PUT")
+        except uhelp.UrlError as e:
+            LOG.warning(
+                'Unable to get API token: %s raised exception %s',
+                token_url, e)
+            return None
+        return response.contents
+
+    def _skip_or_refresh_stale_aws_token_cb(self, msg, exception):
+        """Callback will not retry on SKIP_USERDATA_CODES or if no token
+           is available."""
+        retry = ec2.skip_retry_on_codes(
+            ec2.SKIP_USERDATA_CODES, msg, exception)
+        if not retry:
+            return False  # False raises exception
+        return self._refresh_stale_aws_token_cb(msg, exception)
+
+    def _refresh_stale_aws_token_cb(self, msg, exception):
+        """Exception handler for Ec2 to refresh token if token is stale."""
+        if isinstance(exception, uhelp.UrlError) and exception.code == 401:
+            # With _api_token as None, _get_headers will _refresh_api_token.
+            LOG.debug("Clearing cached Ec2 API token due to expiry")
+            self._api_token = None
+        return True  # always retry
+
+    def _status_cb(self, msg, exc=None):
+        LOG.warning(msg)
+        if 'Read timed out' in msg:
+            LOG.warning('Cannot use Ec2 IMDSv2 API tokens, using IMDSv1')
+            self._api_token = API_TOKEN_DISABLED
+
+    def _get_headers(self, url=''):
+        """Return a dict of headers for accessing a url.
+
+        If _api_token is unset on AWS, attempt to refresh the token via a PUT
+        and then return the updated token header.
+        """
+        if self.cloud_name != CloudNames.AWS or (self._api_token ==
+                                                 API_TOKEN_DISABLED):
+            return {}
+        # Request a 6 hour token if URL is API_TOKEN_ROUTE
+        request_token_header = {
+            'X-aws-ec2-metadata-token-ttl-seconds': AWS_TOKEN_TTL_SECONDS}
+        if API_TOKEN_ROUTE in url:
+            return request_token_header
+        if not self._api_token:
+            # If we don't yet have an API token, get one via a PUT against
+            # API_TOKEN_ROUTE. This _api_token may get unset by a 403 due
+            # to an invalid or expired token
+            self._api_token = self._refresh_api_token()
+            if not self._api_token:
+                return {}
+        return {'X-aws-ec2-metadata-token': self._api_token}
+
 
 class DataSourceEc2Local(DataSourceEc2):
     """Datasource run at init-local which sets up network to query metadata.
diff --git a/cloudinit/sources/DataSourceExoscale.py b/cloudinit/sources/DataSourceExoscale.py
index 4616daa..d59aefd 100644
--- a/cloudinit/sources/DataSourceExoscale.py
+++ b/cloudinit/sources/DataSourceExoscale.py
@@ -61,7 +61,7 @@ class DataSourceExoscale(sources.DataSource):
         metadata_url = "{}/{}/meta-data/instance-id".format(
             self.metadata_url, self.api_version)
 
-        url = url_helper.wait_for_url(
+        url, _response = url_helper.wait_for_url(
             urls=[metadata_url],
             max_wait=self.url_max_wait,
             timeout=self.url_timeout,
diff --git a/cloudinit/sources/DataSourceMAAS.py b/cloudinit/sources/DataSourceMAAS.py
index 61aa6d7..517913a 100644
--- a/cloudinit/sources/DataSourceMAAS.py
+++ b/cloudinit/sources/DataSourceMAAS.py
@@ -136,7 +136,7 @@ class DataSourceMAAS(sources.DataSource):
             url = url[:-1]
         check_url = "%s/%s/meta-data/instance-id" % (url, MD_VERSION)
         urls = [check_url]
-        url = self.oauth_helper.wait_for_url(
+        url, _response = self.oauth_helper.wait_for_url(
             urls=urls, max_wait=max_wait, timeout=timeout)
 
         if url:
diff --git a/cloudinit/sources/DataSourceOpenStack.py b/cloudinit/sources/DataSourceOpenStack.py
index 4a01524..7a5e71b 100644
--- a/cloudinit/sources/DataSourceOpenStack.py
+++ b/cloudinit/sources/DataSourceOpenStack.py
@@ -76,7 +76,7 @@ class DataSourceOpenStack(openstack.SourceMixin, sources.DataSource):
 
         url_params = self.get_url_params()
         start_time = time.time()
-        avail_url = url_helper.wait_for_url(
+        avail_url, _response = url_helper.wait_for_url(
             urls=md_urls, max_wait=url_params.max_wait_seconds,
             timeout=url_params.timeout_seconds)
         if avail_url:
diff --git a/cloudinit/url_helper.py b/cloudinit/url_helper.py
index 1b0721b..a951b8b 100644
--- a/cloudinit/url_helper.py
+++ b/cloudinit/url_helper.py
@@ -101,7 +101,7 @@ def read_file_or_url(url, timeout=5, retries=10,
             raise UrlError(cause=e, code=code, headers=None, url=url)
         return FileResponse(file_path, contents=contents)
     else:
-        return readurl(url, timeout=timeout, retries=retries, headers=headers,
+        return readurl(url, timeout=timeout, retries=retries,
                        headers_cb=headers_cb, data=data,
                        sec_between=sec_between, ssl_details=ssl_details,
                        exception_cb=exception_cb)
@@ -310,7 +310,7 @@ def readurl(url, data=None, timeout=None, retries=0, sec_between=1,
 
 def wait_for_url(urls, max_wait=None, timeout=None,
                  status_cb=None, headers_cb=None, sleep_time=1,
-                 exception_cb=None, sleep_time_cb=None):
+                 exception_cb=None, sleep_time_cb=None, request_method=None):
     """
     urls:      a list of urls to try
     max_wait:  roughly the maximum time to wait before giving up
@@ -325,6 +325,8 @@ def wait_for_url(urls, max_wait=None, timeout=None,
                   'exception', the exception that occurred.
     sleep_time_cb: call method with 2 arguments (response, loop_n) that
                    generates the next sleep time.
+    request_method: indicate the type of HTTP request, GET, PUT, or POST
+    returns: tuple of (url, response contents), on failure, (False, None)
 
     the idea of this routine is to wait for the EC2 metdata service to
     come up.  On both Eucalyptus and EC2 we have seen the case where
@@ -381,8 +383,9 @@ def wait_for_url(urls, max_wait=None, timeout=None,
                 else:
                     headers = {}
 
-                response = readurl(url, headers=headers, timeout=timeout,
-                                   check_status=False)
+                response = readurl(
+                    url, headers=headers, timeout=timeout,
+                    check_status=False, request_method=request_method)
                 if not response.contents:
                     reason = "empty response [%s]" % (response.code)
                     url_exc = UrlError(ValueError(reason), code=response.code,
@@ -392,7 +395,7 @@ def wait_for_url(urls, max_wait=None, timeout=None,
                     url_exc = UrlError(ValueError(reason), code=response.code,
                                        headers=response.headers, url=url)
                 else:
-                    return url
+                    return url, response.contents
             except UrlError as e:
                 reason = "request error [%s]" % e
                 url_exc = e
@@ -421,7 +424,7 @@ def wait_for_url(urls, max_wait=None, timeout=None,
                   sleep_time)
         time.sleep(sleep_time)
 
-    return False
+    return False, None
 
 
 class OauthUrlHelper(object):
diff --git a/tests/unittests/test_datasource/test_cloudstack.py b/tests/unittests/test_datasource/test_cloudstack.py
index d6d2d6b..83c2f75 100644
--- a/tests/unittests/test_datasource/test_cloudstack.py
+++ b/tests/unittests/test_datasource/test_cloudstack.py
@@ -10,6 +10,9 @@ from cloudinit.tests.helpers import CiTestCase, ExitStack, mock
 import os
 import time
 
+MOD_PATH = 'cloudinit.sources.DataSourceCloudStack'
+DS_PATH = MOD_PATH + '.DataSourceCloudStack'
+
 
 class TestCloudStackPasswordFetching(CiTestCase):
 
@@ -17,7 +20,7 @@ class TestCloudStackPasswordFetching(CiTestCase):
         super(TestCloudStackPasswordFetching, self).setUp()
         self.patches = ExitStack()
         self.addCleanup(self.patches.close)
-        mod_name = 'cloudinit.sources.DataSourceCloudStack'
+        mod_name = MOD_PATH
         self.patches.enter_context(mock.patch('{0}.ec2'.format(mod_name)))
         self.patches.enter_context(mock.patch('{0}.uhelp'.format(mod_name)))
         default_gw = "192.201.20.0"
@@ -56,7 +59,9 @@ class TestCloudStackPasswordFetching(CiTestCase):
         ds.get_data()
         self.assertEqual({}, ds.get_config_obj())
 
-    def test_password_sets_password(self):
+    @mock.patch(DS_PATH + '.wait_for_metadata_service')
+    def test_password_sets_password(self, m_wait):
+        m_wait.return_value = True
         password = 'SekritSquirrel'
         self._set_password_server_response(password)
         ds = DataSourceCloudStack(
@@ -64,7 +69,9 @@ class TestCloudStackPasswordFetching(CiTestCase):
         ds.get_data()
         self.assertEqual(password, ds.get_config_obj()['password'])
 
-    def test_bad_request_doesnt_stop_ds_from_working(self):
+    @mock.patch(DS_PATH + '.wait_for_metadata_service')
+    def test_bad_request_doesnt_stop_ds_from_working(self, m_wait):
+        m_wait.return_value = True
         self._set_password_server_response('bad_request')
         ds = DataSourceCloudStack(
             {}, None, helpers.Paths({'run_dir': self.tmp}))
@@ -79,7 +86,9 @@ class TestCloudStackPasswordFetching(CiTestCase):
                     request_types.append(arg.split()[1])
         self.assertEqual(expected_request_types, request_types)
 
-    def test_valid_response_means_password_marked_as_saved(self):
+    @mock.patch(DS_PATH + '.wait_for_metadata_service')
+    def test_valid_response_means_password_marked_as_saved(self, m_wait):
+        m_wait.return_value = True
         password = 'SekritSquirrel'
         subp = self._set_password_server_response(password)
         ds = DataSourceCloudStack(
@@ -92,7 +101,9 @@ class TestCloudStackPasswordFetching(CiTestCase):
         subp = self._set_password_server_response(response_string)
         ds = DataSourceCloudStack(
             {}, None, helpers.Paths({'run_dir': self.tmp}))
-        ds.get_data()
+        with mock.patch(DS_PATH + '.wait_for_metadata_service') as m_wait:
+            m_wait.return_value = True
+            ds.get_data()
         self.assertRequestTypesSent(subp, ['send_my_password'])
 
     def test_password_not_saved_if_empty(self):
diff --git a/tests/unittests/test_datasource/test_ec2.py b/tests/unittests/test_datasource/test_ec2.py
index 1a5956d..5c5c787 100644
--- a/tests/unittests/test_datasource/test_ec2.py
+++ b/tests/unittests/test_datasource/test_ec2.py
@@ -191,7 +191,9 @@ def register_mock_metaserver(base_url, data):
             register(base_url, 'not found', status=404)
 
     def myreg(*argc, **kwargs):
-        return httpretty.register_uri(httpretty.GET, *argc, **kwargs)
+        url = argc[0]
+        method = httpretty.PUT if ec2.API_TOKEN_ROUTE in url else httpretty.GET
+        return httpretty.register_uri(method, *argc, **kwargs)
 
     register_helper(myreg, base_url, data)
 
@@ -237,6 +239,8 @@ class TestEc2(test_helpers.HttprettyTestCase):
         if md:
             all_versions = (
                 [ds.min_metadata_version] + ds.extended_metadata_versions)
+            token_url = self.data_url('latest', data_item='api/token')
+            register_mock_metaserver(token_url, 'API-TOKEN')
             for version in all_versions:
                 metadata_url = self.data_url(version) + '/'
                 if version == md_version:
-- 
1.8.3.1