# cpushare client
# Copyright (C) 2005-2006  Andrea Arcangeli <andrea@cpushare.com>
#
# This library is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License as published by the Free Software Foundation;
# only version 2.1 of the License.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public
# License along with this library; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA

from twisted.copyright import version
# When True, Int32StringReceiver and ReconnectingClientFactory are imported
# from cpushare.twisted_bugs instead of from Twisted (see the imports below).
# The version test singles out Twisted 2.5.0, but the trailing "or True"
# forces the workaround on every Twisted version.
twisted_bug_workaround = '2.5.0' in version or True

from twisted.python import failure
from twisted.internet import main, error, reactor
from twisted.internet.protocol import ClientFactory
if not twisted_bug_workaround:
	from twisted.protocols.basic import Int32StringReceiver
	from twisted.internet.protocol import ReconnectingClientFactory
else:
	from cpushare.twisted_bugs import Int32StringReceiver
	from cpushare.twisted_bugs import ReconnectingClientFactory

import os, sys, struct, time, md5

# local
from cpushare.config import *
from cpushare.proto_const import *
from cpushare.exceptions import *
from cpushare.seccomp import seccomp_class
from cpushare.seccomp_gen import PROTO_SECCOMP_MAX_LENGTH

# Map the os.uname() machine name to the bytecode family a seller
# advertises at login; 64-bit machines are mapped onto the 32-bit
# bytecode of the same family.
arch2bytecodes = dict(
	x86_64='i686',
	i686='i686',
	ppc64='ppc',
	ppc='ppc'
	)

class state_machine(object):
	def __init__(self, protocol):
		self.protocol = protocol
		self.handler = None

	def connectionLost(self):
		pass

	def login(self):
		self.protocol.sendString(self.protocol.factory.order[0]) # buy/sell
		self.protocol.sendString(self.protocol.factory.order[1]) # ID
		self.protocol.sendString(self.protocol.factory.order[2]) # password
		for i in os.uname():
			self.protocol.sendString(i)
		self.protocol.sendString(sys.version)
		from twisted.copyright import version
		self.protocol.sendString(version)
		self.protocol.sendString(CPUSHARE_PROTOCOL_VERSION[self.protocol.factory.order[0]])
		self.protocol.sendString(CPUSHARE_SECURITY_VERSION)
		self.protocol.sendString(CPUSHARE_SECURITY_VENDOR)
		self.protocol.sendString(self.protocol.factory.factory_id) # this factory id
		self.protocol.sendString(self.protocol.factory.factory_type)

		if self.protocol.factory.order[0] == 'buy':
			if self.protocol.factory.factory_type == \
			       PROTO_FACTORY_TYPE_CONTROL:
				self.handler = self.buy_handler
			else:
				from cpushare_buy import buy_state_machine_class
				self.protocol.state_machine = buy_state_machine_class(self.protocol)
				if self.protocol.state_machine.buy_api != CPUSHARE_PROTOCOL_VERSION['buy_api']:
					print 'Buy API mismatch, project should be updated to %s' % \
					      CPUSHARE_PROTOCOL_VERSION['buy_api']
					reactor.stop()
					raise Exception()
				self.protocol.state_machine.start()
		elif self.protocol.factory.order[0] == 'sell':
			arch = os.uname()[4]
			bytecodes = arch2bytecodes[arch]
			self.protocol.sendString(bytecodes)
			self.bytecodes = bytecodes.split()
			print 'max_ram_mbytes_per_cpu = %d' % self.protocol.factory.max_ram_mbytes_per_cpu
			self.protocol.sendString(str(self.protocol.factory.max_ram_mbytes_per_cpu))
			self.handler = self.sell_handler
			self.seccomp = None

	def host_unpack(self, data):
		hostname_len = ord(data[0])
		hostname = data[1:hostname_len+1]
		data = data[hostname_len+1:]
		ssl_port = struct.unpack('!H', data)[0]
		return hostname, ssl_port

	def redirect(self, data):
		hostname, ssl_port = self.host_unpack(data)

		def clone_cpushare_factory(old_factory):
			factory = cpushare_factory()
			factory.server_certificate = old_factory.server_certificate
			factory.order = old_factory.order
			return factory

		factory = clone_cpushare_factory(self.protocol.factory)
		from cpushare.client import set_context_factory
		client_context_factory = set_context_factory(factory)
		reactor.connectSSL(hostname, ssl_port, factory,
				   client_context_factory)
		self.protocol.factory.stopTrying()
		self.protocol.transport.loseConnection()

	def seccomp_factory_start(self, data):
		hostname, ssl_port = self.host_unpack(data)

		factory = self.protocol.factory.seccomp_factory
		from cpushare.client import set_context_factory
		client_context_factory = set_context_factory(factory)
		reactor.connectSSL(hostname, ssl_port, factory,
				   client_context_factory)

	def sell_handler(self, string):
		def end_common():
			self.seccomp = None # free everything
		def end_success(result):
			end_common()
			self.protocol.sendString(PROTO_SECCOMP +
						 PROTO_SECCOMP_SUCCESS)
		def end_failure(failure):
			end_common()
			self.protocol.sendString(PROTO_SECCOMP +
						 PROTO_SECCOMP_FAILURE +
						 struct.pack('!i', failure.value.status))

		def started(result):
			self.protocol.sendString(PROTO_SECCOMP +
						 PROTO_SECCOMP_RUN)

		def seccomp_run():
			d_start, d_end = self.seccomp.run()
			d_end.addCallbacks(end_success, end_failure)
			d_start.addCallback(started)

		control = string[0]
		data = string[1:]

		if control == PROTO_REDIRECT:
			self.redirect(data)
		elif control == PROTO_LOG:
			print repr(data)
		elif control == PROTO_SECCOMP_FACTORY_SIGNAL:
			factory = self.protocol.factory
			assert factory.factory_type == PROTO_FACTORY_TYPE_CONTROL

			seccomp = []
			for p in factory.seccomp_factory.protocols:
				seccomp.append(p)

			if data == PROTO_SECCOMP_SIGNAL_SIGSTOP:
				print 'Seccomp Factory Received SIGSTOP'
				for p in seccomp:
					if p.seccomp:
						p.seccomp.sigstop()
				factory.seccomp_factory.sigstop = True
			elif data == PROTO_SECCOMP_SIGNAL_SIGCONT:
				print 'Seccomp Factory Received SIGCONT'
				for p in seccomp:
					if p.seccomp:
						#raise Exception('x') # test failure
						p.seccomp.sigcont()
				factory.seccomp_factory.sigstop = False
			elif data == PROTO_SECCOMP_SIGNAL_SIGQUIT:
				print 'Seccomp Factory Received SIGQUIT'
				for p in seccomp:
					if p.seccomp:
						p.seccomp.sigquit()
			elif data == PROTO_SECCOMP_SIGNAL_SIGKILL:
				print 'Seccomp Factory Received SIGKILL'
				for p in seccomp:
					if p.seccomp:
						p.seccomp.sigkill()
					p.transport.loseConnection()
			else:
				print 'PROTO_SIGNAL_FACTORY_TYPE_SECCOMP wrong protocol %r' % data
				reactor.stop()
				return
			self.protocol.sendString(PROTO_SECCOMP_FACTORY_SIGNAL + data)
		elif control == PROTO_SECCOMP_FACTORY_START:
			self.seccomp_factory_start(data)
		elif control == PROTO_SECCOMP:
			control = data[0]
			data = data[1:]

			if control == PROTO_SECCOMP_LOAD:
				seccomp = seccomp_class(data, self)
				assert seccomp.name in self.bytecodes
				assert not self.seccomp
				self.seccomp = seccomp
				if not seccomp.find_text_data():
					self.protocol.sendString(PROTO_SECCOMP +
								 PROTO_SECCOMP_NEED_TEXT_DATA)
				else:
					seccomp_run()
			elif control == PROTO_SECCOMP_NEED_TEXT_DATA:
				self.seccomp.store_text_data(data)
				seccomp_run()
			elif control == PROTO_SECCOMP_FORWARD:
				if self.protocol.seccomp:
					self.protocol.seccomp.recv_from_server(data)
				else:
					self.protocol.sendString(PROTO_LOG +
								 'WARNING: PROTO_SECCOMP_FORWARD and not seccomp')
			elif control == PROTO_SECCOMP_SIGNAL:
				s = self.protocol.seccomp
				if not s:
					self.protocol.sendString(PROTO_LOG +
								 'WARNING: PROTO_SECCOMP_SIGNAL and not seccomp')
				elif data == PROTO_SECCOMP_SIGNAL_SIGQUIT:
					print 'Received SIGQUIT'
					s.sigquit()
				else:
					print 'PROTO_SECCOMP_SIGNAL wrong protocol %r' % data
					reactor.stop()
					return
				self.protocol.sendString(PROTO_SECCOMP +
							 control + data)
			else:
				print 'Unknown seccomp_handler failure %d - %s' % (ord(control), repr(string))
				reactor.stop()
		else:
			if control == PROTO_FAILURE:
				print repr(data)
			else:
				print 'Unknown sell_handler failure %d - %s' % (ord(control), repr(string))
			reactor.stop()

	def buy_handler(self, string):
		control = string[0]
		data = string[1:]

		if control == PROTO_REDIRECT:
			self.redirect(data)
		elif control == PROTO_LOG:
			print repr(data)
		elif control == PROTO_BUY_LOAD:
			assert not data

			try:
				from cpushare_buy import seccomp_gen_hash, checkpoint_sec
			except ImportError:
				print 'You should select the cpushare_buy project for example'
				print 'with "ln -s projects/bench cpushare_buy".'
				reactor.stop()
				return
			except CompilationError:
				print 'The compilation of the bytecode failed, please make sure that the cross'
				print 'compiler is available. You can install it with "./cross-compiler.sh".'
				reactor.stop()
				return
			data = PROTO_BUY_LOAD + chr(len(seccomp_gen_hash))
			for bytecode, seccomp in seccomp_gen_hash.iteritems():
				data += seccomp.build_header(seccomp.heap_kbytes * 1024,
							     seccomp.stack_kbytes * 1024)
			data += struct.pack('!I', checkpoint_sec)
			self.protocol.sendString(data)
		elif control == PROTO_SECCOMP_NEED_TEXT_DATA:
			nr = ord(data[0])
			data = data[1:]

			ask_bytecode = []
			for i in xrange(nr):
				digest = data[:16]
				end = 16+struct.calcsize('!I')
				size = struct.unpack('!I', data[16:end])[0]
				ask_bytecode.append((digest, size))
				data = data[end:]
			assert not data

			data = PROTO_SECCOMP_NEED_TEXT_DATA
			from cpushare_buy import seccomp_gen_hash
			for digest, size in ask_bytecode:
				for bytecode, seccomp in seccomp_gen_hash.iteritems():
					if seccomp.digest == digest and seccomp.size() == size:
						text_data = seccomp_gen_hash[bytecode].text_data
						data += struct.pack('!I', len(text_data)) + text_data
						break
			self.protocol.sendString(data)
		elif control == PROTO_SECCOMP_FACTORY_START:
			self.seccomp_factory_start(data)
		else:
			if control == PROTO_FAILURE:
				print repr(data)
			else:
				print 'Unknown buy_handler failure %d - %s' % (ord(control), repr(string))
			reactor.stop()

class cpushare_protocol(Int32StringReceiver):
	"""Length-prefixed string protocol for a cpushare connection.

	The login handshake is sent as soon as the connection is made, and
	every received string is handed to the state machine's handler.
	"""
	MAX_LENGTH = PROTO_SECCOMP_MAX_LENGTH
	# Job attached to this connection, or None.  It is assigned outside
	# this class -- presumably by seccomp_class; TODO confirm.
	seccomp = None

	def connectionMade(self):
		# enable TCP keepalive and disable Nagle's algorithm
		self.transport.setTcpKeepAlive(1)
		self.transport.setTcpNoDelay(1)

		self.state_machine = state_machine(self)
		self.state_machine.login()

	def stringReceived(self, string):
		self.state_machine.handler(string)

	def connectionLost(self, reason):
		"""Kill the attached job and drop every seccomp connection."""
		if self.seccomp:
			self.seccomp.sigkill()
		# Iterate over a copy: each cpushare_seccomp_protocol removes
		# itself from this list in its own connectionLost().
		for proto in list(self.factory.seccomp_factory.protocols):
			proto.transport.loseConnection()

	def stopProducing(self):
		# don't disconnect the cpushare protocol when seccomp protocol exit
		pass

	def pausedConnectionLost(self, reason):
		# NOTE(review): forwards the loss straight to the transport; the
		# caller of this hook is not visible in this file
		self.transport.connectionLost(reason)

class cpushare_seccomp_protocol(cpushare_protocol):
	"""cpushare_protocol used for the per-job seccomp connections.

	While connected, each instance is listed in its factory's protocols
	list, which the control connection walks to signal or drop it.
	"""
	def connectionMade(self):
		# register first, then run the common setup and login
		registry = self.factory.protocols
		registry.append(self)
		super(cpushare_seccomp_protocol, self).connectionMade()

	def connectionLost(self, reason):
		# kill the job (if any), unregister, then notify the state machine
		job = self.seccomp
		if job:
			job.sigkill()
		self.factory.protocols.remove(self)
		self.state_machine.connectionLost()

class cpushare_seccomp_factory(ClientFactory):
	"""Non-reconnecting factory for seccomp job connections.

	cpushare_factory.startFactory() fills in factory_id,
	server_certificate, order, cache_dir, seccomp_loader, protocols and
	max_ram_mbytes_per_cpu on each instance.
	"""
	# set True by a factory-wide SIGSTOP and back to False by SIGCONT
	sigstop = False
	factory_type = PROTO_FACTORY_TYPE_SECCOMP
	protocol = cpushare_seccomp_protocol

class cpushare_factory(ReconnectingClientFactory):
	maxDelay = 600 # limit the maximum delay to 10 min

	protocol = cpushare_protocol

	# this allows to detect a disconnect/reconnect due a firewall/NAT interruption
	factory_id = md5.new(file('/dev/urandom').read(100) + os.uname()[1]).hexdigest()
	#factory_id = md5.new('x').hexdigest() # test NAT timeout detection
	factory_type = PROTO_FACTORY_TYPE_CONTROL

	def startFactory(self):
		self.seccomp_factory = cpushare_seccomp_factory()
		self.seccomp_factory.factory_id = self.factory_id
		self.seccomp_factory.server_certificate = self.server_certificate
		self.seccomp_factory.order = self.order
		self.seccomp_factory.cache_dir = self.cache_dir
		self.seccomp_factory.seccomp_loader = self.seccomp_loader
		self.seccomp_factory.protocols = []
		self.seccomp_factory.max_ram_mbytes_per_cpu = self.max_ram_mbytes_per_cpu

	def buildProtocol(self, addr):
		self.resetDelay()

		import gc
		gc.collect()

		return ReconnectingClientFactory.buildProtocol(self, addr)

	def clientConnectionLost(self, connector, reason):
		print 'Lost connection. Reason:', reason
		ReconnectingClientFactory.clientConnectionLost(self, connector, reason)

	def clientConnectionFailed(self, connector, reason):
		print 'Connection failed. Reason:', reason
		ReconnectingClientFactory.clientConnectionFailed(self, connector, reason)
