from requests import Session
from requests.adapters import HTTPAdapter
from requests.auth import HTTPBasicAuth
from requests_toolbelt import MultipartEncoder
from requests.exceptions import (
ConnectionError,
HTTPError,
ReadTimeout,
)
from posixpath import join as urljoin
import time
import os
import logging
from ..utils.exceptions import OasisException
[docs]
class APISession(Session):
    """``requests.Session`` subclass adding token authentication and retries.

    On construction the session mounts retrying HTTP adapters for the API base
    URL (and for ``token_url`` when given), calls the ``healthcheck/``
    endpoint, resolves the authentication type (the explicit ``auth_type``
    argument, otherwise the ``API_AUTH_TYPE`` value reported by
    ``server_info/``) and, unless authentication is ``'disabled'``, obtains an
    access token.

    ``get``, ``post``, ``put``, ``delete``, ``upload`` and ``upload_byte``
    retry on connection errors, read timeouts, HTTP 502/503/504 and -- when
    the access token can be renewed -- HTTP 401/403.
    """

    def __init__(
        self,
        api_url,
        auth_type=None,
        username=None,
        password=None,
        client_id=None,
        client_secret=None,
        access_token=None,
        refresh_token=None,
        token_url=None,
        scope=None,
        timeout=25,
        retries=5,
        retry_delay=1,
        request_interval=0.02,
        logger=None,
        **kwargs
    ):
        """Create the session, check connectivity and authenticate.

        Args:
            api_url: Base URL of the API; a trailing slash is appended.
            auth_type: One of ``'disabled'``, ``'simple'``, ``'oidc'``,
                ``'m2m'`` or ``'token'``. When ``None`` the value reported by
                the server's ``server_info/`` endpoint is used.
            username, password: Credentials for ``'simple'``.
            client_id, client_secret: Credentials for ``'oidc'`` / ``'m2m'``.
            access_token, refresh_token: Pre-issued tokens for ``'token'``.
            token_url: Token endpoint used by ``'m2m'`` (required for it).
            scope: Optional ``scope`` value sent with the ``'m2m'`` request.
            timeout: Timeout passed to every request made by this session.
            retries: Maximum attempts per request; also the adapter's
                ``max_retries``.
            retry_delay: Base delay in seconds; multiplied by the attempt
                number after a connection error or read timeout.
            request_interval: Pause in seconds after each successful request.
            logger: Logger to use; defaults to this module's logger.
            **kwargs: Forwarded to ``Session.__init__``.

        Raises:
            OasisException: If the health check fails, credentials for the
                resolved auth type are missing, or authentication fails.
        """
        super(APISession, self).__init__(**kwargs)
        self.logger = logger or logging.getLogger(__name__)

        # Extended class vars. Everything the adapters and request helpers
        # read (retry_max, timeout, scope) must exist before the first request.
        self.tkn_access = None
        self.tkn_refresh = None
        self.url_base = urljoin(api_url, '')
        self.token_url = token_url
        self.scope = scope
        self.timeout = timeout
        self.retry_max = retries
        self.retry_delay = retry_delay
        self.request_interval = request_interval
        self._mount_adapters()

        # Check connectivity & authentication
        self.health_check()
        server_auth = self._fetch_server_auth_type()
        resolved_type = auth_type if auth_type is not None else server_auth

        if resolved_type == 'disabled':
            self.auth_type = 'disabled'
            self.auth_credentials = {}
            return

        self.auth_type = resolved_type
        if resolved_type == "simple" and username and password:
            self.auth_credentials = {"username": username, "password": password}
        elif resolved_type in ("oidc", "m2m") and client_id and client_secret:
            if resolved_type == "m2m" and not token_url:
                raise OasisException("auth_type 'm2m' requires token_url")
            self.auth_credentials = {"client_id": client_id, "client_secret": client_secret}
        elif resolved_type == "token" and access_token and refresh_token:
            self.auth_credentials = {"access_token": access_token, "refresh_token": refresh_token}
        else:
            raise OasisException(
                f"Missing credentials for auth_type '{resolved_type}': simple requires username/password, "
                "oidc/m2m requires client_id/client_secret (m2m also requires token_url), "
                "token requires access_token/refresh_token.")
        self.__get_access_token()

    def _mount_adapters(self):
        """Drop all adapters and mount fresh retrying ones.

        Both the API base URL and, when configured, the token endpoint are
        (re)mounted, so resetting the connection pool never leaves
        ``token_url`` without an adapter.
        """
        self.adapters.clear()
        self.mount(self.url_base, HTTPAdapter(max_retries=self.retry_max))
        if self.token_url:
            self.mount(self.token_url, HTTPAdapter(max_retries=self.retry_max))

    def _set_bearer(self, access_token, refresh_token=None):
        """Store the tokens and set the session's bearer authorization header."""
        self.tkn_access = access_token
        self.tkn_refresh = refresh_token
        self.headers['authorization'] = 'Bearer {}'.format(access_token)

    def __get_access_token(self):
        """Obtain an access token according to ``self.auth_type``.

        Raises:
            OasisException: On an unknown auth type, or when the token request
                fails or returns an unusable payload.
        """
        # Token given directly
        if self.auth_type == "token":
            self._set_bearer(
                self.auth_credentials["access_token"],
                self.auth_credentials["refresh_token"],
            )
            return

        # Attempt to fetch token
        try:
            if self.auth_type == "m2m":
                # client_credentials grant - direct request to the token endpoint
                data = {'grant_type': 'client_credentials'}
                if self.scope:
                    data['scope'] = self.scope
                r = super(APISession, self).post(
                    self.token_url,
                    data=data,
                    auth=HTTPBasicAuth(
                        self.auth_credentials['client_id'],
                        self.auth_credentials['client_secret'],
                    ),
                    headers={'Content-Type': 'application/x-www-form-urlencoded'},
                    timeout=self.timeout,
                )
                r.raise_for_status()
                # No refresh token is kept; renewal re-requests a token.
                self._set_bearer(r.json()['access_token'])
            elif self.auth_type == "oidc":
                # Client credentials are posted as JSON to the API's
                # access_token/ endpoint; no refresh token is kept.
                url = urljoin(self.url_base, 'access_token/')
                r = super(APISession, self).post(url, json=self.auth_credentials, timeout=self.timeout)
                r.raise_for_status()
                self._set_bearer(r.json()['access_token'])
            elif self.auth_type == "simple":
                # username / password ~ token fetch from the API, with refresh token
                url = urljoin(self.url_base, 'access_token/')
                r = self.post(url, json=self.auth_credentials)
                r.raise_for_status()
                payload = r.json()
                self._set_bearer(payload['access_token'], payload['refresh_token'])
            else:
                # Unsupported auth type
                raise OasisException(f"Unknown auth_type: '{self.auth_type}'")
        except (TypeError, AttributeError, KeyError, ValueError, BytesWarning,
                HTTPError, ConnectionError, ReadTimeout) as e:
            if isinstance(e, HTTPError) and e.response is not None:
                self.logger.error('Access token request failed (%s): %s', e.response.status_code, e.response.text)
            raise OasisException('Authentication Error', e) from e

    def _refresh_token(self):
        """Renew the access token.

        ``'m2m'`` and ``'oidc'`` sessions request a brand new token. Other
        sessions post the stored refresh token to ``refresh_token/`` and adopt
        the returned access token (and refresh token, when one is returned).

        Raises:
            OasisException: When the refresh request fails or its payload is
                unusable.
        """
        if self.auth_type in ("m2m", "oidc"):
            self.__get_access_token()
            return

        try:
            self.logger.debug("Refreshing access token")
            url = urljoin(self.url_base, 'refresh_token/')
            # The refresh token is sent as a per-request header so a failed
            # refresh does not leave it installed as the session's bearer token.
            r = super(APISession, self).post(
                url,
                headers={'authorization': 'Bearer {}'.format(self.tkn_refresh)},
                timeout=self.timeout,
            )
            r.raise_for_status()
            payload = r.json()
            self.tkn_access = payload['access_token']
            if 'refresh_token' in payload:
                self.tkn_refresh = payload['refresh_token']
            self.headers['authorization'] = 'Bearer {}'.format(self.tkn_access)
        except (TypeError, AttributeError, KeyError, ValueError, BytesWarning,
                HTTPError, ConnectionError, ReadTimeout) as e:
            err_msg = 'Authentication Error'
            raise OasisException(err_msg, e) from e

    def unrecoverable_error(self, error, msg=None):
        """Log ``msg`` (if given) and raise ``OasisException`` describing ``error``.

        Args:
            error: The exception to report; its ``response`` (when present)
                supplies status code, URL and body text for the message.
            msg: Optional extra message logged at error level.

        Raises:
            OasisException: Always.
        """
        # Compare with None explicitly: a Response is falsy for error statuses.
        err_r = getattr(error, 'response', None)
        if err_r is not None:
            err_msg = 'api error: {}, url: {}, msg: {}'.format(err_r.status_code, err_r.url, err_r.text)
        else:
            err_msg = 'api error: {}'.format(error)
        if msg:
            self.logger.error(msg)
        raise OasisException(err_msg, error)

    # Connection Error Handlers
    def __recoverable(self, error, url, request, counter=1):
        """Decide whether a failed request should be retried.

        Args:
            error: The exception raised by the request.
            url: URL of the failed request (used for logging).
            request: HTTP verb of the failed request (used for logging).
            counter: 1-based number of the attempt that just failed.

        Returns:
            ``True`` if the caller should retry, ``False`` otherwise.

        Raises:
            OasisException: When ``counter`` has reached ``retry_max``, or
                when a token renewal triggered here fails.
        """
        self.logger.debug(f"Connection exception handler: {error}")
        if counter >= self.retry_max:
            self.logger.debug("Max retries of '{}' reached".format(self.retry_max))
            raise OasisException(f'Failed to recover from error: {error}', error)

        if isinstance(error, (ConnectionError, ReadTimeout)):
            self.logger.debug(f"Recoverable error [{error}] from {request} {url}")
            self.logger.debug(f"Backoff timer: {self.retry_delay * counter}")
            # Reset HTTPAdapter & clear connection pool
            self._mount_adapters()
            time.sleep(self.retry_delay * counter)
            return True

        if isinstance(error, HTTPError):
            http_err_code = error.response.status_code
            self.logger.debug(f"Recoverable error [{error}] from {request} {url}")
            if http_err_code in (502, 503, 504):
                return True
            if http_err_code in (401, 403):
                can_renew = self.tkn_refresh is not None or self.auth_type in ('m2m', 'oidc')
                if self.auth_type != 'disabled' and can_renew:
                    self.logger.debug("requesting refresh token")
                    self._refresh_token()
                    return True
        return False

    def _fetch_server_auth_type(self):
        """Return the server's ``API_AUTH_TYPE`` from ``server_info/``, or ``None``.

        This is best effort: any failure is logged at debug level and
        ``None`` is returned.
        """
        try:
            url = urljoin(self.url_base, 'server_info/')
            r = super(APISession, self).get(url, timeout=self.timeout)
            if r.status_code == 200:
                return r.json().get('config', {}).get('API_AUTH_TYPE')
        except Exception as e:
            # Deliberately non-fatal, but leave a trace instead of hiding it.
            self.logger.debug('Unable to fetch server auth type: %s', e)
        return None

    # @oasis_log
    def health_check(self):
        """
        Checks the health of the server.

        Returns:
            The response of the ``healthcheck/`` request.

        Raises:
            OasisException: If the request fails or returns an error status.
        """
        try:
            url = urljoin(self.url_base, 'healthcheck/')
            # Bounded by the session timeout so an unresponsive host cannot hang us.
            r = super(APISession, self).get(url, timeout=self.timeout)
            r.raise_for_status()
            return r
        except (TypeError, AttributeError, BytesWarning, HTTPError, ConnectionError, ReadTimeout) as e:
            err_msg = 'Health check failed: Unable to connect to {}'.format(self.url_base)
            raise OasisException(err_msg, e) from e

    def _request_with_retry(self, verb, url, attempt):
        """Run ``attempt`` until it succeeds or the error is not recoverable.

        Args:
            verb: HTTP verb, used for logging by the retry handler.
            url: Request URL, used for logging by the retry handler.
            attempt: Zero-argument callable performing one request and
                returning its response.

        Returns:
            The first response that passes ``raise_for_status``.

        Raises:
            HTTPError, ConnectionError, ReadTimeout: The last error, when it
                is judged unrecoverable.
            OasisException: When the retry budget is exhausted.
        """
        counter = 0
        while True:
            counter += 1
            try:
                r = attempt()
                r.raise_for_status()
                time.sleep(self.request_interval)
            except (HTTPError, ConnectionError, ReadTimeout) as e:
                if self.__recoverable(e, url, verb, counter):
                    continue
                self.logger.debug(f'Unrecoverable error: {e}')
                raise
            else:
                return r

    def upload(self, url, filepath, content_type, **kwargs):
        """POST the file at ``filepath`` as a multipart ``file`` field, with retries.

        Args:
            url: Target URL.
            filepath: Path of the file to send (``~`` is expanded).
            content_type: MIME type declared for the file part.
            **kwargs: Extra arguments for ``Session.post``; ``timeout``
                defaults to the session timeout.

        Returns:
            The successful response.
        """
        abs_fp = os.path.realpath(os.path.expanduser(filepath))
        kwargs.setdefault('timeout', self.timeout)

        def attempt():
            # Re-open per attempt: the encoder consumes the stream, and the
            # context manager guarantees the handle is closed every time.
            with open(abs_fp, 'rb') as fh:
                m = MultipartEncoder(fields={'file': (os.path.basename(filepath), fh, content_type)})
                return super(APISession, self).post(
                    url, data=m, headers={'Content-Type': m.content_type}, **kwargs)

        return self._request_with_retry('POST', url, attempt)

    def upload_byte(self, url, file_bytes, filename, content_type, **kwargs):
        """POST in-memory ``file_bytes`` as a multipart ``file`` field, with retries.

        Args:
            url: Target URL.
            file_bytes: Content of the file part.
            filename: File name declared for the file part.
            content_type: MIME type declared for the file part.
            **kwargs: Extra arguments for ``Session.post``; ``timeout``
                defaults to the session timeout.

        Returns:
            The successful response.
        """
        kwargs.setdefault('timeout', self.timeout)

        def attempt():
            # A fresh encoder is built for every attempt.
            m = MultipartEncoder(fields={'file': (filename, file_bytes, content_type)})
            return super(APISession, self).post(
                url, data=m, headers={'Content-Type': m.content_type}, **kwargs)

        return self._request_with_retry('POST', url, attempt)

    def get(self, url, **kwargs):
        """GET ``url`` with retries; ``timeout`` defaults to the session timeout."""
        kwargs.setdefault('timeout', self.timeout)
        return self._request_with_retry(
            'GET', url, lambda: super(APISession, self).get(url, **kwargs))

    def post(self, url, **kwargs):
        """POST to ``url`` with retries; ``timeout`` defaults to the session timeout."""
        kwargs.setdefault('timeout', self.timeout)
        return self._request_with_retry(
            'POST', url, lambda: super(APISession, self).post(url, **kwargs))

    def delete(self, url, **kwargs):
        """DELETE ``url`` with retries; ``timeout`` defaults to the session timeout."""
        kwargs.setdefault('timeout', self.timeout)
        return self._request_with_retry(
            'DELETE', url, lambda: super(APISession, self).delete(url, **kwargs))

    def put(self, url, **kwargs):
        """PUT to ``url`` with retries; ``timeout`` defaults to the session timeout."""
        kwargs.setdefault('timeout', self.timeout)
        return self._request_with_retry(
            'PUT', url, lambda: super(APISession, self).put(url, **kwargs))