utils/soap: ignore content before and after SOAP XML content (#50260)
This commit is contained in:
parent
e70576bc6c
commit
090c2f4062
|
@ -36,19 +36,44 @@ class SOAPClient(Client):
|
|||
"""
|
||||
def __init__(self, resource, **kwargs):
|
||||
wsdl_url = kwargs.pop('wsdl_url', None) or resource.wsdl_url
|
||||
transport_kwargs = kwargs.pop('transport_kwargs', {})
|
||||
transport_class = getattr(resource, 'soap_transport_class', SOAPTransport)
|
||||
transport = transport_class(resource, wsdl_url, session=resource.requests, cache=InMemoryCache())
|
||||
transport = transport_class(resource, wsdl_url,
|
||||
session=resource.requests,
|
||||
cache=InMemoryCache(), **transport_kwargs)
|
||||
super(SOAPClient, self).__init__(wsdl_url, transport=transport, **kwargs)
|
||||
|
||||
|
||||
class ResponseFixContentWrapper:
|
||||
def __init__(self, response):
|
||||
self.response = response
|
||||
|
||||
def __getattr__(self, name):
|
||||
return getattr(self.response, name)
|
||||
|
||||
@property
|
||||
def content(self):
|
||||
content = self.response.content
|
||||
if 'multipart/related' not in self.response.headers.get('Content-Type', ''):
|
||||
try:
|
||||
first_less_than_sign = content.index(b'<')
|
||||
last_greater_than_sign = content.rindex(b'>')
|
||||
content = content[first_less_than_sign:last_greater_than_sign + 1]
|
||||
except ValueError:
|
||||
pass
|
||||
return content
|
||||
|
||||
|
||||
class SOAPTransport(Transport):
|
||||
"""Wrapper around zeep.Transport
|
||||
|
||||
disable basic_authentication hosts unrelated to wsdl's endpoints
|
||||
"""
|
||||
def __init__(self, resource, wsdl_url, **kwargs):
|
||||
def __init__(self, resource, wsdl_url, remove_first_bytes_for_xml=False, **kwargs):
|
||||
self.resource = resource
|
||||
self.wsdl_host = urlparse.urlparse(wsdl_url).netloc
|
||||
# fix content for servers returning unexpected characters before XML document start
|
||||
self.remove_first_bytes_for_xml = remove_first_bytes_for_xml
|
||||
super(SOAPTransport, self).__init__(**kwargs)
|
||||
|
||||
def _load_remote_data(self, url):
|
||||
|
@ -60,3 +85,11 @@ class SOAPTransport(Transport):
|
|||
return super(SOAPTransport, self)._load_remote_data(url)
|
||||
except RequestException as e:
|
||||
raise SOAPError('SOAP service is down, location %r cannot be loaded: %s' % (url, e), exception=e, url=url)
|
||||
|
||||
def post_xml(self, *args, **kwargs):
|
||||
response = super().post_xml(*args, **kwargs)
|
||||
|
||||
if self.remove_first_bytes_for_xml:
|
||||
return ResponseFixContentWrapper(response)
|
||||
|
||||
return response
|
||||
|
|
|
@ -19,7 +19,7 @@ import mock
|
|||
import requests
|
||||
from zeep import Settings
|
||||
from zeep.plugins import Plugin
|
||||
from zeep.exceptions import XMLParseError
|
||||
from zeep.exceptions import XMLParseError, TransportError
|
||||
|
||||
from passerelle.utils.soap import SOAPClient
|
||||
|
||||
|
@ -78,3 +78,33 @@ def test_disable_strict_mode(mocked_post):
|
|||
assert len(result) == 2
|
||||
assert result['skipMe'] is None
|
||||
assert result['price'] == 4.2
|
||||
|
||||
|
||||
@mock.patch('requests.sessions.Session.post')
|
||||
def test_remove_first_bytes_for_xml(mocked_post):
|
||||
response = requests.Response()
|
||||
response.status_code = 200
|
||||
response._content = force_bytes('''blabla \n<?xml version='1.0' encoding='utf-8'?>
|
||||
<soap-env:Envelope xmlns:soap-env="http://schemas.xmlsoap.org/soap/envelope/">
|
||||
<soap-env:Body>
|
||||
<ns0:TradePrice xmlns:ns0="http://example.com/stockquote.xsd">
|
||||
<skipMe>1.2</skipMe>
|
||||
<price>4.20</price>
|
||||
</ns0:TradePrice>
|
||||
</soap-env:Body>
|
||||
</soap-env:Envelope>\n bloublou''')
|
||||
mocked_post.return_value = response
|
||||
|
||||
soap_resource = SOAPResource()
|
||||
|
||||
client = SOAPClient(soap_resource)
|
||||
with pytest.raises(TransportError):
|
||||
client.service.GetLastTradePrice(tickerSymbol='banana')
|
||||
|
||||
client = SOAPClient(soap_resource,
|
||||
transport_kwargs={'remove_first_bytes_for_xml': True})
|
||||
result = client.service.GetLastTradePrice(tickerSymbol='banana')
|
||||
assert len(result) == 2
|
||||
assert result['skipMe'] == 1.2
|
||||
assert result['price'] == 4.2
|
||||
|
||||
|
|
Loading…
Reference in New Issue