# Copyright (c) 2012-2016 Hewlett Packard Enterprise Development LP
#
# Permission to use, copy, modify, and/or distribute this software for
# any purpose with or without fee is hereby granted, provided that the
# above copyright notice and this permission notice appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.

from __future__ import absolute_import

import cgi
import collections
import datetime
import functools
import io
import logging
import os.path
import random
import time

import requests.exceptions
import six
import six.moves.urllib_parse as urlparse

from requestbuilder.exceptions import (ClientError, ServerError,
                                       ServiceInitError, TimeoutError)
from requestbuilder.mixins import RegionConfigurableMixin


class BaseService(RegionConfigurableMixin):
    NAME = None
    DESCRIPTION = ''
    API_VERSION = ''
    MAX_RETRIES = 2
    TIMEOUT = 30  # socket timeout in seconds

    REGION_ENVVAR = None
    URL_ENVVAR = None

    ARGS = []

    def __init__(self, config, loglevel=None, max_retries=None, timeout=None,
                 **kwargs):
        self.args = kwargs
        self.config = config
        self.endpoint = None
        self.log = logging.getLogger(self.__class__.__name__)
        if loglevel is not None:
            self.log.level = loglevel
        self.max_retries = max_retries
        self.region_name = None  # Note this can differ from config.region
        self.session_args = {}
        self.timeout = timeout
        self._session = None

    @classmethod
    def from_other(cls, other, **kwargs):
        kwargs.setdefault('loglevel', other.log.level)
        kwargs.setdefault('max_retries', other.max_retries)
        kwargs.setdefault('session_args', dict(other.session_args))
        kwargs.setdefault('timeout', other.timeout)
        if 'region' in other.args:
            kwargs.setdefault('region', other.args['region'])
        new = cls(other.config, **kwargs)
        new.configure()
        return new

    def configure(self):
        # TODO: rename this to setup
        #
        # Configure user and region before grabbing endpoint info since
        # the latter may depend upon the former
        self.update_config_view()
        self.__configure_endpoint()

        # Configure timeout and retry handlers
        if self.max_retries is None:
            config_max_retries = self.config.get_global_option('max-retries')
            if config_max_retries is not None:
                self.max_retries = int(config_max_retries)
            else:
                self.max_retries = self.MAX_RETRIES
        if self.timeout is None:
            config_timeout = self.config.get_global_option('timeout')
            if config_timeout is not None:
                self.timeout = float(config_timeout)
            else:
                self.timeout = self.TIMEOUT

        self.session_args.setdefault('stream', True)

        # SSL cert verification is opt-in
        verify_ssl = self.config.get_region_option('verify-tls')
        if verify_ssl is None:
            verify_ssl = self.config.get_region_option('verify-ssl')
        self.session_args.setdefault(
            'verify', self.config.convert_to_bool(verify_ssl, default=False))

        # Ensure everything is okay and finish up
        self.validate_config()

    @property
    def session(self):
        if self._session is None:
            self._session = requests.session()
            for key, val in six.iteritems(self.session_args):
                setattr(self._session, key, val)
            for adapter in self._session.adapters.values():
                # send_request handles retries to allow for re-signing
                adapter.max_retries = 0
        return self._session
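
    # Note (added commentary, not from the original source): entries in
    # ``session_args`` are copied onto the underlying requests.Session the
    # first time the ``session`` property is accessed, so callers may seed
    # extra session settings before any request is made -- for example a
    # CA bundle path for TLS verification (the path below is only an
    # illustration):
    #
    #     service.session_args['verify'] = '/etc/pki/tls/certs/ca-bundle.crt'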

    def validate_config(self):
        if self.endpoint is None:
            if self.NAME is not None:
                url_opt = '{0}-url'.format(self.NAME)
                available_regions = self.config.get_all_region_options(
                    url_opt)
                if len(available_regions) > 0:
                    msg = ('No {0} endpoint to connect to was given. '
                           'Configured regions with {0} endpoints are: '
                           '{1}').format(self.NAME,
                                         ', '.join(sorted(available_regions)))
                else:
                    msg = ('No {0} endpoint to connect to was given. {0} '
                           'endpoints may be specified in a config file with '
                           '"{1}".').format(self.NAME, url_opt)
            else:
                msg = 'No endpoint to connect to was given'
            raise ServiceInitError(msg)

    def get_request_url(self, method='GET', path=None, params=None,
                        headers=None, data=None, files=None, auth=None):
        url = self.__get_url_for_path(path)
        headers = dict(headers or {})
        if 'host' not in [header.lower() for header in headers]:
            headers['Host'] = urlparse.urlparse(self.endpoint).netloc
        p_request = self.__log_and_prepare_request(method, url, params, data,
                                                   files, headers, auth)
        return p_request.url

    def send_request(self, method='GET', path=None, params=None, headers=None,
                     data=None, files=None, auth=None):
        url = self.__get_url_for_path(path)
        headers = dict(headers or {})
        if 'host' not in [header.lower() for header in headers]:
            headers['Host'] = urlparse.urlparse(self.endpoint).netloc

        try:
            max_tries = self.max_retries + 1
            assert max_tries >= 1
            redirects_left = 5

            if hasattr(data, 'seek') and hasattr(data, 'tell'):
                # If we're redirected we need to be able to reset data.
                # (This replaces a Python 2-only ``isinstance(data, file)``
                # check so seekable bodies also retry on Python 3.)
                data_file_offset = data.tell()
            else:
                data_file_offset = None

            while True:
                for attempt_no, delay in enumerate(
                        _generate_delays(max_tries), 1):
                    # Use exponential backoff if this is a retry
                    if delay > 0:
                        self.log.debug('will retry after %.3f seconds', delay)
                        time.sleep(delay)

                    self.log.info('sending request (attempt %i of %i)',
                                  attempt_no, max_tries)
                    p_request = self.__log_and_prepare_request(
                        method, url, params, data, files, headers, auth)
                    proxies = requests.utils.get_environ_proxies(url)
                    for key, val in sorted(proxies.items()):
                        self.log.debug('request proxy: %s=%s', key, val)
                    p_request.start_time = datetime.datetime.now()

                    try:
                        response = self.session.send(
                            p_request, timeout=self.timeout, proxies=proxies,
                            allow_redirects=False)
                    except requests.exceptions.Timeout:
                        if attempt_no < max_tries:
                            self.log.debug('timeout', exc_info=True)
                            if data_file_offset is not None:
                                self.log.debug('re-seeking body to '
                                               'beginning of file')
                                # pylint: disable=E1101
                                data.seek(data_file_offset)
                                # pylint: enable=E1101
                                continue
                            elif not hasattr(data, 'tell'):
                                continue
                            # Fallthrough -- if it has a file pointer but
                            # not seek we can't retry because we can't
                            # rewind.
                        raise
                    if response.status_code not in (500, 503):
                        break
                    # If it *was* in that list, retry

                if (response.status_code in (301, 302, 307, 308) and
                        redirects_left > 0 and
                        'Location' in response.headers):
                    # Standard redirect -- we need to handle this ourselves
                    # because we have to re-sign requests when their URLs
                    # change.
                    redirects_left -= 1
                    parsed_rdr = urlparse.urlparse(
                        response.headers['Location'])
                    parsed_url = urlparse.urlparse(url)
                    new_url_bits = []
                    for rdr_bit, url_bit in zip(parsed_rdr, parsed_url):
                        new_url_bits.append(rdr_bit or url_bit)
                    if 'Host' in headers:
                        headers['Host'] = new_url_bits[1]  # netloc
                    url = urlparse.urlunparse(new_url_bits)
                    self.log.debug('redirecting to %s (%i redirect(s) '
                                   'remaining)', url, redirects_left)
                    if data_file_offset is not None:
                        self.log.debug('re-seeking body to beginning of file')
                        # pylint: disable=E1101
                        data.seek(data_file_offset)
                        # pylint: enable=E1101
                    continue
                elif response.status_code >= 300:
                    # We include 30x because we've handled the standard method
                    # of redirecting, but the server might still be trying to
                    # redirect another way for some reason.
                    self.handle_http_error(response)
                return response
        except requests.exceptions.Timeout as exc:
            self.log.debug('timeout', exc_info=True)
            raise TimeoutError('request timed out', exc)
        except requests.exceptions.ConnectionError as exc:
            self.log.debug('connection error', exc_info=True)
            return self.__handle_connection_error(exc)
        except requests.exceptions.HTTPError:
            return self.handle_http_error(response)
        except requests.exceptions.RequestException as exc:
            self.log.debug('request error', exc_info=True)
            raise ClientError(exc)

    def __handle_connection_error(self, err):
        if isinstance(err, six.string_types):
            msg = err
        elif isinstance(err, Exception) and len(err.args) > 0:
            if hasattr(err.args[0], 'reason'):
                msg = err.args[0].reason
            elif isinstance(err.args[0], Exception):
                return self.__handle_connection_error(err.args[0])
            else:
                msg = err.args[0]
        else:
            raise ClientError('connection error')
        raise ClientError('connection error ({0})'.format(msg))

    def handle_http_error(self, response):
        self.log.debug('HTTP error', exc_info=True)
        raise ServerError(response)

    def __get_url_for_path(self, path):
        if path:
            # We can't simply use urljoin because a path might start with '/'
            # like it could for S3 keys that start with that character.
            if self.endpoint.endswith('/'):
                return self.endpoint + path
            else:
                return self.endpoint + '/' + path
        else:
            return self.endpoint

    def __log_and_prepare_request(self, method, url, params, data, files,
                                  headers, auth):
        hooks = {'response': functools.partial(_log_response_data, self.log)}
        if auth:
            bound_auth = auth.bind_to_service(self)
        else:
            bound_auth = None
        request = requests.Request(method=method, url=url, params=params,
                                   data=data, files=files, headers=headers,
                                   auth=bound_auth)
        p_request = self.session.prepare_request(request)
        p_request.hooks = {'response': hooks['response']}

        self.log.debug('request method: %s', request.method)
        self.log.debug('request url: %s', p_request.url)
        if isinstance(p_request.headers, (dict, collections.Mapping)):
            for key, val in sorted(six.iteritems(p_request.headers)):
                if key.lower().endswith('password'):
                    val = ''
                self.log.debug('request header: %s: %s', key, val)
        if isinstance(request.params, (dict, collections.Mapping)):
            for key, val in sorted(urlparse.parse_qsl(
                    urlparse.urlparse(p_request.url).query,
                    keep_blank_values=True)):
                if key.lower().endswith('password'):
                    val = ''
                self.log.debug('request param: %s: %s', key, val)
        if isinstance(request.data, (dict, collections.Mapping)):
            content_type, content_type_params = cgi.parse_header(
                p_request.headers.get('content-type') or '')
            if content_type == 'multipart/form-data':
                data = cgi.parse_multipart(io.BytesIO(p_request.body),
                                           content_type_params)
            elif content_type == 'application/x-www-form-urlencoded':
                data = dict(urlparse.parse_qsl(p_request.body,
                                               keep_blank_values=True))
            else:
                data = request.data
            for key, val in sorted(data.items()):
                # pylint: disable=superfluous-parens
                if key in (request.files or {}):
                    # We probably don't want to include the contents of
                    # entire files in debug output.
                    continue
                # pylint: enable=superfluous-parens
                if key.lower().endswith('password'):
                    val = ''
                self.log.debug('request data: %s: %s', key, val)
        if isinstance(request.files, (dict, collections.Mapping)):
            for key, val in sorted(six.iteritems(request.files)):
                if hasattr(val, '__len__'):
                    val = '<{0} bytes>'.format(len(val))
                self.log.debug('request file: %s: %s', key, val)
        return p_request

    def __configure_endpoint(self):
        # Default to no endpoint; validate_config() reports the error later
        url = None
        region_name = None
        # self.args gets highest precedence
        if self.args.get('url'):
            url, region_name = _parse_endpoint_url(self.args['url'])
        # Environment comes next
        elif self.URL_ENVVAR and os.getenv(self.URL_ENVVAR):
            url, region_name = _parse_endpoint_url(os.getenv(self.URL_ENVVAR))
        # Try the config file
        elif self.NAME:
            url, section = self.config.get_region_option2(self.NAME + '-url')
            if section:
                # Check to see if the region name is explicitly specified
                region_name = self.config.get_region_option('name', section)
                if region_name is None:
                    # If it isn't then just grab the end of the section name
                    region_name = section.rsplit(':', 1)[-1]
            else:
                region_name = None
        self.endpoint = url
        self.region_name = region_name


def _log_response_data(logger, response, **_):
    if hasattr(response.request, 'start_time'):
        duration = datetime.datetime.now() - response.request.start_time
        logger.debug('response time: %i.%03i seconds', duration.seconds,
                     duration.microseconds // 1000)
    if response.status_code >= 400:
        logger.error('response status: %i', response.status_code)
    else:
        logger.info('response status: %i', response.status_code)
    if isinstance(response.headers, (dict, collections.Mapping)):
        for key, val in sorted(response.headers.items()):
            logger.debug('response header: %s: %s', key, val)


def _generate_delays(max_tries):
    if max_tries >= 1:
        yield 0
        for retry_no in range(1, max_tries):
            # Exponential backoff with jitter: roughly (1 to 2) *
            # 2 ** (retry_no - 1) seconds, capped at 15 seconds
            next_delay = (random.random() + 1) * 2 ** (retry_no - 1)
            yield min((next_delay, 15))


def _parse_endpoint_url(urlish):
    """
    If given a URL, return the URL and None.

    If given a URL with a string and "::" prepended to it, return the URL
    and the prepended string.  This is meant to give one a means to supply
    a region name via arguments and variables that normally only accept
    URLs.
    """
    if '::' in urlish:
        region, url = urlish.split('::', 1)
    else:
        region = None
        url = urlish
    return url, region
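

# Illustrative sketch (added commentary, not part of the original module):
# demonstrates the "region::url" convention documented in
# _parse_endpoint_url above.  The URLs and the "us-east-1" region name are
# made-up examples.
if __name__ == '__main__':
    for urlish in ('https://svc.example.com/',
                   'us-east-1::https://svc.example.com/'):
        print(_parse_endpoint_url(urlish))
    # Expected output:
    #   ('https://svc.example.com/', None)
    #   ('https://svc.example.com/', 'us-east-1')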