summaryrefslogtreecommitdiff
blob: 26ffcc1a48c38a12c86b17f5f8acba2ab258778b (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
import httplib
import xmlrpclib
import socket
import time
import os
import select
import string
import sys
import urllib
import traceback
from OpenSSL import SSL, crypto

def display_traceback():
	"""Debug helper: print the current call stack to stdout, one stripped entry per line."""
	frames = traceback.format_stack()
	for entry in frames:
		print(entry.strip())

class SecureXMLRPCClient(xmlrpclib.ServerProxy):
	"""
	XML-RPC proxy for https://host:port that authenticates with a client
	certificate.

	client_cert / client_key are handed to SafeTransport (and from there to
	the SSL context); verify_cert_func, when given, is used as the peer
	certificate verification callback.
	"""

	def __init__(self, host, port, client_cert, client_key, verify_cert_func=None):
		# Use the constructor argument: this object has no __host attribute
		# of its own, so the old self.__host lookup handed the transport an
		# unrelated value instead of the host name.
		self._transport = SafeTransport(host, client_cert, client_key, verify_cert_func)
		xmlrpclib.ServerProxy.__init__(self, "https://" + host + ":" + str(port), transport=self._transport, encoding="utf-8", allow_none=True)

	def cancel(self):
		"""Ask the transport to drop its connection by calling its close()."""
		self._transport.close()

class SafeTransport(xmlrpclib.Transport):
	"""
	xmlrpclib transport whose connections are HTTPS objects built with the
	client certificate, key and optional verification callback given here.
	"""

	def __init__(self, host, client_cert, client_key, verify_cert_func=None):
		# Run the base initializer when this Python's xmlrpclib defines one
		# (later releases set attributes such as _use_datetime there, which
		# the inherited response parsing reads). Older releases define no
		# Transport.__init__, hence the guarded lookup.
		base_init = getattr(xmlrpclib.Transport, "__init__", None)
		if base_init is not None:
			base_init(self)
		self.__host = host	# kept for reference; not read by this class
		self.__client_cert = client_cert
		self.__client_key = client_key
		self.__verify_cert_func = verify_cert_func
		self._https = None	# last object returned by make_connection()

	def make_connection(self, host):
		"""
		Create and return an HTTPS object for *host* ("name" or "name:port").

		When no port is present, the HTTPS default 443 is used instead of
		failing on int(None).
		"""
		host = self.get_host_info(host)[0]
		name, port = urllib.splitport(host)
		if port is None:
			port = 443
		self._https = HTTPS(name, int(port), self.__client_cert, self.__client_key, self.__verify_cert_func)
		return self._https

	def close(self):
		"""Intentionally a no-op: closing self._https here is disabled."""
		pass

class HTTPSConnection(httplib.HTTPConnection):
	"""httplib connection whose socket is a pyOpenSSL client connection wrapped in SSLConnection."""

	response_class = httplib.HTTPResponse

	def __init__(self, host, port=None, cert_file=None, key_file=None, verify_cert_func=None):
		httplib.HTTPConnection.__init__(self, host, port, None)
		# Client credentials and optional peer-verification callback, all
		# consumed later by connect().
		self.cert_file = cert_file
		self.key_file = key_file
		self.verify_cert_func = verify_cert_func
		self.sock = None

	def connect(self):
		"""Build an SSL context, open an IPv4 TCP socket, and connect to (host, port)."""
		context = SSL.Context(SSL.SSLv23_METHOD)
		verifier = self.verify_cert_func
		if verifier:
			# Demand a peer certificate only when a callback was supplied;
			# otherwise the context's default verify mode is left in place
			# (NOTE(review): presumably no peer verification - confirm).
			context.set_verify(SSL.VERIFY_PEER, verifier)
		context.use_privatekey_file(self.key_file)
		context.use_certificate_file(self.cert_file)

		raw_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
		self.sock = SSLConnection(SSL.Connection(context, raw_sock))
		self.sock.connect((self.host, self.port))

class HTTPS(httplib.HTTP):
	"""Legacy httplib.HTTP-style wrapper whose connection is an HTTPSConnection."""

	_connection_class = HTTPSConnection

	def __init__(self, host='', port=None, cert_file=None, key_file=None, verify_cert_func=None):
		connection = self._connection_class(host, port, cert_file, key_file, verify_cert_func)
		self._setup(connection)

# Higher-level SSL objects used by rpclib
#
# Copyright (c) 2002 Red Hat, Inc.
#
# Author: Mihai Ibanescu <misa@redhat.com>
# Modifications by Dan Williams <dcbw@redhat.com>

class SSLConnection:
	"""
	Socket-like proxy around a pyOpenSSL Connection.

	It exists to filter out the argument SimpleXMLRPCServer's doPOST()
	passes to shutdown(), to refcount close() across the file objects
	returned by makefile(), and to provide sendall()/recv() loops that
	retry on SSL WantRead/WantWrite.

	Attribute reads not handled here, and every attribute assignment, go to
	the wrapped connection; the proxy's own state lives in self.__dict__
	directly so that __setattr__ does not forward it.
	"""

	# Initial stored timeout, and the value settimeout(None) stores.
	DEFAULT_TIMEOUT = 0

	def __init__(self, conn):
		"""
		Wrap *conn*. Connection is not a new-style class, so this is a
		delegating proxy rather than a subclass.
		"""
		self.__dict__["conn"] = conn
		self.__dict__["close_refcount"] = 0	# outstanding makefile() handles
		self.__dict__["closed"] = False	# nothing in this class sets it True
		self.__dict__["timeout"] = self.DEFAULT_TIMEOUT

	def __del__(self):
		# close() only adjusts a refcount, so the wrapped connection is
		# really closed when the proxy itself is collected.
		self.__dict__["conn"].close()

	def __getattr__(self, name):
		return getattr(self.__dict__["conn"], name)

	def __setattr__(self, name, value):
		setattr(self.__dict__["conn"], name, value)

	def settimeout(self, timeout):
		"""
		Forward *timeout* to the wrapped connection and remember it (None is
		remembered as DEFAULT_TIMEOUT). sendall() and recv() do not consult
		the remembered value.
		"""
		if timeout is None:
			self.__dict__["timeout"] = self.DEFAULT_TIMEOUT
		else:
			self.__dict__["timeout"] = timeout
		self.__dict__["conn"].settimeout(timeout)

	def shutdown(self, how=1):
		"""
		SimpleXMLRpcServer.doPOST calls shutdown(1), but Connection.shutdown()
		takes no argument, so *how* is discarded.
		"""
		self.__dict__["conn"].shutdown()

	def accept(self):
		"""
		Accept on the wrapped connection and wrap the new connection as well,
		so server-side sockets get the same shutdown() workaround.
		"""
		c, a = self.__dict__["conn"].accept()
		return (SSLConnection(c), a)

	def makefile(self, mode="r", bufsize=-1):
		"""
		Return a PlgFileObject reading from / writing to this proxy.

		SSL.Connection has no dup(), so socket._fileobject is used, and each
		file counts as one reference in close_refcount: httplib closes both
		the socket and the file it made from it, expecting a duplicated
		socket. Defaults match socket.socket.makefile(); existing callers
		pass both arguments explicitly.
		"""
		self.__dict__["close_refcount"] += 1
		return PlgFileObject(self, mode, bufsize)

	def close(self):
		"""
		Decrement close_refcount. Actually shutting down and closing the
		connection when the count reaches zero is deliberately disabled;
		__del__ closes the wrapped connection instead.
		"""
		if self.__dict__["closed"]:
			return
		self.__dict__["close_refcount"] -= 1

	def sendall(self, data, flags=0):
		"""
		Send all of *data* through the wrapped connection.

		pyOpenSSL's own sendall() is avoided because it loops on WantRead /
		WantWrite consuming 100% CPU; here those are retried after a short
		sleep. NOTE: the select()-based timeout that used to bound this loop
		is disabled, so a connection that never becomes writable still waits
		forever.

		Returns the number of bytes sent (len(data) on success).
		Raises socket.error when the SSL layer reports a syscall error,
		including a broken pipe (after calling close()).
		"""
		con = self.__dict__["conn"]
		total = len(data)
		while data:
			try:
				sent = con.send(data, flags)
			except SSL.SysCallError:
				# sys.exc_info() keeps this valid on every Python version the
				# surrounding file may run on.
				err = sys.exc_info()[1]
				if err.args and err.args[0] == 32:	# EPIPE: the peer went away
					# Previously this set sent = 0 and looped again; with the
					# timeout disabled nothing ever left that loop. Fail like
					# any other syscall error instead.
					self.close()
				raise socket.error(err)
			except (SSL.WantWriteError, SSL.WantReadError):
				time.sleep(0.2)
				continue
			data = data[sent:]
		return total - len(data)

	def recv(self, bufsize, flags=0):
		"""
		Read up to *bufsize* bytes from the wrapped connection.

		Returns None (not an empty string) when this proxy is marked closed,
		or when the peer shut the SSL session down (SSL.ZeroReturnError).
		SSL.WantReadError is retried after a short sleep; the stored timeout
		is not applied.
		"""
		if self.__dict__["closed"]:
			return None
		con = self.__dict__["conn"]
		while True:
			try:
				return con.recv(bufsize, flags)
			except SSL.ZeroReturnError:
				return None
			except SSL.WantReadError:
				time.sleep(0.2)

class PlgFileObject(socket._fileobject):

	def close(self):
		"""
		socket._fileobject doesn't actually _close_ the socket,
		which we want it to do, so we have to override.
		"""
		try:
			if self._sock:
				self.flush()
				self._sock.close()
		finally:
			self._sock = None