# --------------------------------------------------------------------------
#
# Copyright (c) Microsoft Corporation. All rights reserved.
#
# The MIT License (MIT)
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the ""Software""), to
# deal in the Software without restriction, including without limitation the
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
# sell copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
# IN THE SOFTWARE.
#
# --------------------------------------------------------------------------
import asyncio
import logging
from collections.abc import AsyncIterator
from typing import Any, Optional, AsyncIterator as AsyncIteratorType

import aiohttp
from requests.exceptions import (
    ChunkedEncodingError,
    StreamConsumedError)

from azure.core.configuration import ConnectionConfiguration
from azure.core.exceptions import ServiceRequestError, ServiceResponseError, AzureError
from azure.core.pipeline import Pipeline

from ._base import HttpRequest
from ._base_async import (
    AsyncHttpTransport,
    AsyncHttpResponse,
    _ResponseStopIteration)

# Matching requests, because why not?
CONTENT_CHUNK_SIZE = 10 * 1024
_LOGGER = logging.getLogger(__name__)


class AioHttpTransport(AsyncHttpTransport):
    """AioHttp HTTP sender implementation.

    Fully asynchronous implementation using the aiohttp library.

    :param session: The client session. If omitted and ``session_owner`` is True,
        an ``aiohttp.ClientSession`` is created by :meth:`open`.
    :param loop: The event loop, passed to the ``aiohttp.ClientSession`` this
        transport creates.
    :param bool session_owner: Whether this transport creates and closes the
        session itself. Defaults True.

    :keyword bool use_env_settings: Passed as ``trust_env`` to the session this
        transport creates, so proxy settings are read from the environment.
        Defaults to True.

    .. admonition:: Example:

        .. literalinclude:: ../samples/
            :start-after: [START aiohttp]
            :end-before: [END aiohttp]
            :language: python
            :dedent: 4
            :caption: Asynchronous transport with aiohttp.
    """
    def __init__(self, *, session=None, loop=None, session_owner=True, **kwargs):
        self._loop = loop
        self._session_owner = session_owner
        self.session = session
        self.connection_config = ConnectionConfiguration(**kwargs)
        self._use_env_settings = kwargs.pop('use_env_settings', True)

    async def __aenter__(self):
        await self.open()
        return self

    async def __aexit__(self, *args):  # pylint: disable=arguments-differ
        await self.close()

    async def open(self):
        """Opens the connection.

        Creates an ``aiohttp.ClientSession`` when this transport owns the session
        and none exists yet, then enters the session's async context. Called by
        :meth:`send` before every request.
        """
        if not self.session and self._session_owner:
            # DummyCookieJar discards cookies, so no cookie state is carried
            # between requests made through this session.
            jar = aiohttp.DummyCookieJar()
            self.session = aiohttp.ClientSession(
                loop=self._loop,
                trust_env=self._use_env_settings,
                cookie_jar=jar
            )
        if self.session is not None:
            await self.session.__aenter__()

    async def close(self):
        """Closes the connection.

        Only a session owned by this transport is closed. Ownership is dropped
        afterwards, so a later :meth:`open` will not create a new session.
        """
        if self._session_owner and self.session:
            await self.session.close()
            self._session_owner = False
            self.session = None

    def _build_ssl_config(self, cert, verify):  # pylint: disable=no-self-use
        """Build the value passed as aiohttp's ``ssl`` request argument.

        :param cert: Client certificate. Either a single file path, or a sequence
            unpacked as ``SSLContext.load_cert_chain`` arguments
            (certfile, keyfile[, password]). Falsy to send no client certificate.
        :param verify: True/False to enable/disable server certificate
            verification, or a path to a CA bundle file.
        :return: An ``ssl.SSLContext`` when a client certificate or CA bundle is
            given; otherwise ``verify`` unchanged.
        """
        if not cert and verify in (True, False):
            # Nothing to customise: aiohttp accepts the bool directly.
            return verify

        import ssl
        if verify in (True, False):
            ssl_ctx = ssl.create_default_context()
            if not verify:
                # A client certificate forces a custom context; still honour
                # verify=False. check_hostname must be cleared before CERT_NONE.
                ssl_ctx.check_hostname = False
                ssl_ctx.verify_mode = ssl.CERT_NONE
        else:
            ssl_ctx = ssl.create_default_context(cafile=verify)
        if cert:
            if isinstance(cert, str):
                # A single path: combined certificate/key file. Splatting a
                # string would pass each character as a separate argument.
                ssl_ctx.load_cert_chain(cert)
            else:
                ssl_ctx.load_cert_chain(*cert)
        return ssl_ctx

    def _get_request_data(self, request):  # pylint: disable=no-self-use
        """Return the body to pass as aiohttp's ``data`` argument.

        When ``request.files`` is set, every entry
        ``field_name: (filename, content[, content_type])`` is added to an
        ``aiohttp.FormData``; otherwise ``request.data`` is returned unchanged.

        :raises ValueError: If a files entry has fewer than two elements.
        """
        if not request.files:
            return request.data
        form_data = aiohttp.FormData()
        for form_file, data in request.files.items():
            content_type = data[2] if len(data) > 2 else None
            try:
                form_data.add_field(form_file, data[1], filename=data[0], content_type=content_type)
            except IndexError as err:
                raise ValueError("Invalid formdata formatting: {}".format(data)) from err
        return form_data

    async def send(self, request: HttpRequest, **config: Any) -> Optional[AsyncHttpResponse]:
        """Send the request using this HTTP sender.

        Will pre-load the body into memory to be available with a sync method.
        Pass stream=True to avoid this behavior.

        :param request: The HttpRequest object
        :type request: ~azure.core.pipeline.transport.HttpRequest
        :param config: Any keyword arguments. Keywords not listed below are passed
            through to ``aiohttp.ClientSession.request``.
        :return: The AsyncHttpResponse
        :rtype: ~azure.core.pipeline.transport.AsyncHttpResponse

        :keyword bool stream: If True, do not pre-load the body. Defaults to False.
        :keyword dict proxies: Mapping of protocol (URL prefix such as ``"http"`` or
            ``"https"``) to proxy URL. The entry whose protocol the request URL
            starts with is used. Ignored when ``proxy`` is given.
        :keyword str proxy: Proxy URL to use for this request regardless of protocol.
        :keyword connection_timeout: Socket connect timeout. Defaults to the
            connection configuration's ``timeout``.
        :keyword read_timeout: Socket read timeout. Defaults to the connection
            configuration's ``read_timeout``.
        :keyword connection_cert: Client certificate. Defaults to the connection
            configuration's ``cert``.
        :keyword connection_verify: Server certificate verification flag or CA
            bundle path. Defaults to the connection configuration's ``verify``.
        :raises ~azure.core.exceptions.ServiceRequestError: If aiohttp cannot
            connect (``ClientConnectorError``).
        :raises ~azure.core.exceptions.ServiceResponseError: On
            ``asyncio.TimeoutError``.
        """
        await self.open()

        proxies = config.pop('proxies', None)
        if proxies and 'proxy' not in config:
            # aiohttp needs a single proxy, so pick the one whose protocol prefixes
            # the URL. Reverse lexicographic order puts a string before any of its
            # own prefixes, so "https" is tried before "http".
            for protocol in sorted(proxies.keys(), reverse=True):
                if request.url.startswith(protocol):
                    config['proxy'] = proxies[protocol]
                    break

        error = None  # type: Optional[AzureError]
        response = None
        config['ssl'] = self._build_ssl_config(
            cert=config.pop('connection_cert', self.connection_config.cert),
            verify=config.pop('connection_verify', self.connection_config.verify)
        )
        # If we know for sure there is no body, disable "auto content type".
        # Otherwise, aiohttp will send "application/octet-stream" even for an empty
        # POST request, and that breaks services like storage signature.
        if not request.data and not request.files:
            config['skip_auto_headers'] = ['Content-Type']
        try:
            stream_response = config.pop("stream", False)
            timeout = config.pop('connection_timeout', self.connection_config.timeout)
            read_timeout = config.pop('read_timeout', self.connection_config.read_timeout)
            socket_timeout = aiohttp.ClientTimeout(sock_connect=timeout, sock_read=read_timeout)
            result = await self.session.request(
                request.method,
                request.url,
                headers=request.headers,
                data=self._get_request_data(request),
                timeout=socket_timeout,
                allow_redirects=False,
                **config
            )
            response = AioHttpTransportResponse(request, result, self.connection_config.data_block_size)
            if not stream_response:
                await response.load_body()
        except aiohttp.client_exceptions.ClientConnectorError as err:
            error = ServiceRequestError(err, error=err)
        except asyncio.TimeoutError as err:
            error = ServiceResponseError(err, error=err)
        if error:
            raise error
        return response


class AioHttpStreamDownloadGenerator(AsyncIterator):
    """Streams the response body data.

    Each chunk is read ``response.block_size`` bytes at a time from the
    underlying aiohttp response.

    :param pipeline: The pipeline object, used to re-issue the request with a
        ``range`` header when a read fails.
    :param response: The client response object.
    :type response: ~azure.core.pipeline.transport.AsyncHttpResponse
    """
    def __init__(self, pipeline: Pipeline, response: AsyncHttpResponse) -> None:
        self.pipeline = pipeline
        self.request = response.request
        self.response = response
        self.block_size = response.block_size
        # 0 when the server sent no Content-Length header.
        self.content_length = int(response.internal_response.headers.get('Content-Length', 0))
        # Bytes received so far; used as the start offset of a resumed download.
        self.downloaded = 0

    def __len__(self):
        """Return the advertised Content-Length (0 if absent)."""
        return self.content_length

    async def __anext__(self):
        """Return the next chunk of the body.

        A read failing with ``ChunkedEncodingError`` or ``ConnectionError`` is
        attempted up to three times in total for each chunk. Before each retry
        the request is re-issued through the pipeline with a ``range`` header
        starting at the bytes already downloaded.

        :raises StopAsyncIteration: When the body is exhausted; the response is closed.
        """
        retry_total = 3
        retry_interval = 1  # seconds -- asyncio.sleep() takes seconds, not milliseconds
        while True:
            try:
                chunk = await self.response.internal_response.content.read(self.block_size)
                if not chunk:
                    raise _ResponseStopIteration()
                # Count the bytes actually received: the final chunk may be shorter
                # than block_size, and this offset feeds the range header on retry.
                self.downloaded += len(chunk)
                return chunk
            except _ResponseStopIteration:
                self.response.internal_response.close()
                raise StopAsyncIteration()
            except (ChunkedEncodingError, ConnectionError):
                retry_total -= 1
                if retry_total <= 0:
                    # Out of retries: surface the read error instead of falling out
                    # of the loop and handing None back as a "chunk".
                    self.response.internal_response.close()
                    raise
                await asyncio.sleep(retry_interval)
                headers = {'range': 'bytes=' + str(self.downloaded) + '-'}
                resp = self.pipeline.run(self.request, stream=True, headers=headers)
                if asyncio.iscoroutine(resp):
                    # An async pipeline hands back a coroutine; its result must be
                    # awaited before the status code can be inspected.
                    resp = await resp
                if resp.status_code == 416:
                    # Range Not Satisfiable: give up and re-raise the read error.
                    self.response.internal_response.close()
                    raise
                # NOTE(review): the next attempt still reads from the original
                # self.response, not from `resp` -- confirm this is intended.
            except StreamConsumedError:
                raise
            except Exception as err:
                _LOGGER.warning("Unable to stream download: %s", err)
                self.response.internal_response.close()
                raise


class AioHttpTransportResponse(AsyncHttpResponse):
    """Methods for accessing response body data.

    :param request: The HttpRequest object
    :type request: ~azure.core.pipeline.transport.HttpRequest
    :param aiohttp_response: Returned from ClientSession.request().
    :type aiohttp_response: aiohttp.ClientResponse object
    :param block_size: Chunk size, in bytes, forwarded to the base class; the
        stream download generator reads ``block_size`` bytes per chunk.
    :type block_size: int
    """
    def __init__(self, request: HttpRequest, aiohttp_response: aiohttp.ClientResponse, block_size=None) -> None:
        super().__init__(request, aiohttp_response, block_size=block_size)
        # Mirror the aiohttp response metadata onto the transport-agnostic attributes.
        self.status_code = aiohttp_response.status
        self.headers = aiohttp_response.headers
        self.reason = aiohttp_response.reason
        self.content_type = aiohttp_response.headers.get('content-type')
        # Filled by load_body(); None until then.
        self._body = None  # type: Optional[bytes]

    def body(self) -> bytes:
        """Return the whole body as bytes in memory.

        :raises ValueError: If :meth:`load_body` has not been awaited yet.
        """
        if self._body is None:
            raise ValueError("Body is not available. Call async method load_body, or do your call with stream=False.")
        return self._body

    async def load_body(self) -> None:
        """Load in memory the body, so it could be accessible from sync methods."""
        self._body = await self.internal_response.read()

    def stream_download(self, pipeline) -> AsyncIteratorType[bytes]:
        """Generator for streaming response body data.

        :param pipeline: The pipeline object
        :type pipeline: azure.core.pipeline
        :return: An async iterator yielding the body in ``block_size`` chunks.
        """
        return AioHttpStreamDownloadGenerator(pipeline, self)
