#!/usr/bin/python
# -*- coding: utf-8 -*-
import zmq
import uuid
import signal
import os
import sys
import threading
import time
import socket
import argparse
from obci.control.launcher.launcher_messages import message_templates
from obci.control.common.message import OBCIMessageTool, send_msg, recv_msg, PollingObject
import obci.control.common.net_tools as net
import obci.control.common.obci_control_settings as settings
from obci.control.launcher.subprocess_monitor import SubprocessMonitor
from obci.utils.openbci_logging import get_logger, log_crash
from obci.control.common import net_tools
class HandlerCollection(object):
    """Registry mapping message-type names to handler functions.

    Used as a class attribute on control peers: handlers are registered
    with the ``handler(message_type)`` decorator, and the three fallback
    slots can be replaced via the ``default_handler`` / ``error_handler``
    / ``unsupported_handler`` decorators.  ``copy()`` and ``new_from()``
    produce an independent registry that shares the same handler
    functions, so a subclass can extend its parent's table without
    mutating it.
    """

    def __init__(self):
        self.handlers = {}
        self.default = self._default_handler
        self.error = self._error_handler
        self.unsupported = self._error_handler

    # Fix: the four functions below were plain functions in the class body
    # (usable only when accessed through the class object; calling them on
    # an instance passed the instance as ``other``/swallowed it into
    # ``*args``).  Marking them @staticmethod keeps every existing call
    # style working and additionally makes instance calls behave sanely.
    @staticmethod
    def new_from(other):
        """Return a new collection initialized from ``other``."""
        return HandlerCollection._new_from(other)

    def copy(self):
        """Return an independent copy of this collection."""
        return HandlerCollection._new_from(self)

    @staticmethod
    def _new_from(other):
        new = HandlerCollection()
        new.handlers = dict(other.handlers)  # shallow copy: functions shared
        new.default = other.default
        new.error = other.error
        new.unsupported = other.unsupported
        return new

    @staticmethod
    def _default_handler(*args):
        # fallback when no type-specific handler matches: ignore
        pass

    @staticmethod
    def _error_handler(*args):
        # fallback for parse errors / unsupported types: ignore
        pass

    def handler(self, message_type):
        """Decorator: register the decorated function for ``message_type``."""
        def save_handler(fun):
            self.handlers[message_type] = fun
            return fun
        return save_handler

    def default_handler(self):
        """Decorator: replace the default (no-match) handler."""
        def save_default_handler(fun):
            self.default = fun
            return fun
        return save_default_handler

    def error_handler(self):
        """Decorator: replace the bad-message-format handler."""
        def save_error_handler(fun):
            self.error = fun
            return fun
        return save_error_handler

    def unsupported_handler(self):
        """Decorator: replace the unsupported-message-type handler."""
        def save_unsupported_handler(fun):
            self.unsupported = fun
            return fun
        return save_unsupported_handler

    def handler_for(self, message_name):
        """Return the handler registered for ``message_name``, or None."""
        handler = self.handlers.get(message_name, None)
        return handler
class OBCIControlPeer(object):
    """Base class for OBCI launcher control peers.

    A peer owns:

    * a REP socket answering requests (``rep_addresses``),
    * a PUB socket for broadcasts, serviced by a dedicated thread and fed
      through an inproc PUSH/PULL pair (``_publish_socket``),
    * a REQ socket to its supervisor (``source_addresses``), used for
      registration,
    * a background thread that watches child processes via
      ``SubprocessMonitor`` and reports dead ones as 'dead_process'
      messages.

    Incoming messages are dispatched through the class-level
    ``msg_handlers`` registry (see ``HandlerCollection``); subclasses
    register handlers with ``@msg_handlers.handler("type")``.
    """

    msg_handlers = HandlerCollection()

    def __init__(self, source_addresses=None,
                 rep_addresses=None, pub_addresses=None, name='obci_control_peer'):
        # TODO TODO TODO !!!!
        # cleaner subclassing of obci_control_peer!!!
        self.hostname = socket.gethostname()
        self.source_addresses = source_addresses if source_addresses else []
        self.rep_addresses = rep_addresses
        self.pub_addresses = pub_addresses
        self._all_sockets = []

        # inproc endpoints wiring the main thread to the publisher and
        # subprocess-monitor threads.  NOTE(review): _pull_addr/_push_addr
        # are used with swapped roles in _prepare_publisher -- see the
        # comment there before renaming anything.
        self._pull_addr = 'inproc://publisher_msg'
        self._push_addr = 'inproc://publisher'
        self._subpr_push_addr = 'inproc://subprocess_info'

        self.uuid = str(uuid.uuid4())
        self.name = str(name)
        self.type = self.peer_type()

        # one log directory per peer instance (name + uuid prefix)
        log_dir = os.path.join(settings.OBCI_CONTROL_LOG_DIR,
                               self.name + '-' + self.uuid[:8])
        # a subclass may have configured self.logger before calling us
        if not hasattr(self, 'logger'):
            if not os.path.exists(log_dir):
                os.makedirs(log_dir)
            self.logger = get_logger(self.peer_type(), log_dir=log_dir,
                                     stream_level=net_tools.peer_loglevel(), obci_peer=self)

        # message (de)serialization helper.  NOTE(review): message_tool()
        # is not defined in this chunk -- presumably provided by this class
        # elsewhere or by a subclass; verify.
        self.mtool = self.message_tool()

        # reuse a zmq context prepared by a subclass, if any
        if not hasattr(self, "ctx"):
            self.ctx = zmq.Context()

        self.subprocess_mgr = SubprocessMonitor(self.ctx, self.uuid, logger=self.logger)

        self.net_init()
        if self.source_addresses:
            self.registration_response = self.register()
            self._handle_registration_response(self.registration_response)
        else:
            self.registration_response = None

        # set by the signal handler; checked once per pass of run()'s loop
        self.interrupted = False
        signal.signal(signal.SIGTERM, self.signal_handler())
        signal.signal(signal.SIGINT, self.signal_handler())

    def signal_handler(self):
        """Return a SIGTERM/SIGINT handler that only flags
        ``self.interrupted``; run() notices the flag and exits cleanly."""
        def handler(signum, frame):
            self.logger.info("[!!!!] %s %s %s %s",
                             self.name, "got signal", signum, frame)
            self.interrupted = True
        return handler

    def peer_type(self):
        """Identifier string for this peer kind; subclasses override."""
        return 'obci_control_peer'

    def _publisher_thread(self, pub_addrs, pull_address, push_addr):
        """PUB-thread body: bind the PUB socket, signal readiness, then
        forward everything received on the inproc PULL socket to PUB.

        ``pull_address`` is where this thread *binds* its PULL socket
        (messages to publish arrive here); ``push_addr`` is where it
        *connects* a one-shot PUSH socket used to tell the main thread
        that the PUB socket is bound.
        """
        # FIXME aaaaahhh pub_addresses are set here, not in the main thread
        # (which reads them in _register method)
        pub_sock, self.pub_addresses = self._init_socket(
            pub_addrs, zmq.PUB)
        pull_sock = self.ctx.socket(zmq.PULL)
        pull_sock.bind(pull_address)

        push_sock = self.ctx.socket(zmq.PUSH)
        push_sock.connect(push_addr)
        # ready signal: unblocks recv_msg() in _prepare_publisher
        send_msg(push_sock, b'1')

        po = PollingObject()
        while not self._stop_publishing:
            try:
                to_publish, det = po.poll_recv(pull_sock, 500)
                if to_publish:
                    send_msg(pub_sock, to_publish)
            except:
                # NOTE(review): bare except -- any error (typically context
                # termination at shutdown) silently ends the thread
                # print self.name, '.Publisher -- STOP.'
                break
        # self.logger.info( "close sock %s %s", pub_addrs, pub_sock)
        pub_sock.close()
        pull_sock.close()
        push_sock.close()

    def _subprocess_info(self, push_addr):
        """Monitor-thread body: every 0.5 s push a 'dead_process' message
        to ``push_addr`` for each child process that stopped running."""
        push_sock = self.ctx.socket(zmq.PUSH)
        push_sock.connect(push_addr)
        # ready signal for _prepare_subprocess_info
        send_msg(push_sock, b'1')
        while not self._stop_monitoring:
            dead = self.subprocess_mgr.not_running_processes()
            if dead:
                # self.logger.warning("DEAD process" + str(dead))
                # key is presumably (machine, pid) -- inferred from the
                # indexing below; verify against SubprocessMonitor
                for key, status in dead.items():
                    send_msg(push_sock, self.mtool.fill_msg('dead_process', machine=key[0],
                                                            pid=key[1], status=status))
            time.sleep(0.5)
        push_sock.close()

    def _push_sock(self, ctx, addr):
        """Create a PUSH socket connected to ``addr``."""
        sock = ctx.socket(zmq.PUSH)
        sock.connect(addr)
        return sock

    def _prepare_publisher(self):
        """Start the publisher thread and block until its PUB socket is
        bound (so self.pub_addresses is valid afterwards).

        NOTE(review): the positional args map onto
        _publisher_thread(pub_addrs, pull_address, push_addr) as
        (pub_addresses, _push_addr, _pull_addr): the thread binds its PULL
        socket to ``_push_addr`` (where ``_publish_socket`` pushes) and
        sends its ready signal to ``_pull_addr`` (where ``tmp_pull``
        listens).  The attribute names are misleading but the wiring is
        internally consistent.
        """
        tmp_pull = self.ctx.socket(zmq.PULL)
        tmp_pull.bind(self._pull_addr)
        self.pub_thr = threading.Thread(target=self._publisher_thread,
                                        args=[self.pub_addresses,
                                              self._push_addr,
                                              self._pull_addr])
        self.pub_thr.daemon = True
        self._stop_publishing = False
        self.pub_thr.start()
        # wait for the thread's ready signal
        recv_msg(tmp_pull)
        self._publish_socket = self._push_sock(self.ctx, self._push_addr)
        self._all_sockets.append(self._publish_socket)
        tmp_pull.close()

    def _prepare_subprocess_info(self):
        """Start the subprocess-monitor thread and block until it is up;
        its 'dead_process' messages arrive on self._subprocess_pull."""
        self._subprocess_pull = self.ctx.socket(zmq.PULL)
        self._subprocess_pull.bind(self._subpr_push_addr)
        self.subprocess_thr = threading.Thread(target=self._subprocess_info,
                                               args=[self._subpr_push_addr])
        self.subprocess_thr.daemon = True
        self._stop_monitoring = False
        self.subprocess_thr.start()
        # first message received is the thread's ready signal
        recv_msg(self._subprocess_pull)
        self._all_sockets.append(self._subprocess_pull)

    def net_init(self):
        """Create all sockets: publisher (via its thread), subprocess
        monitor, public REP socket and the REQ socket to the supervisor."""
        # (self.pub_socket, self.pub_addresses) = self._init_socket(
        #                                self.pub_addresses, zmq.PUB)
        self._all_sockets = []
        self._prepare_publisher()
        self._prepare_subprocess_info()

        (self.rep_socket, self.rep_addresses) = self._init_socket(
            self.rep_addresses, zmq.REP)
        # LINGER 0: don't block on close over unsent replies
        self.rep_socket.setsockopt(zmq.LINGER, 0)
        self._all_sockets.append(self.rep_socket)

        print("\n\tname: {0}\n\tpeer_type: {1}\n\tuuid: {2}\n".format(
            self.name, self.peer_type(), self.uuid))
        print("rep: {0}".format(self.rep_addresses))
        print("pub: {0}\n".format(self.pub_addresses))

        self.source_req_socket = self.ctx.socket(zmq.REQ)

        if self.source_addresses:
            for addr in self.source_addresses:
                self.source_req_socket.connect(addr)
        self._all_sockets.append(self.source_req_socket)
        self._set_poll_sockets()

    def _init_socket(self, addrs, zmq_type):
        """Create and bind a socket of ``zmq_type``.

        When ``addrs`` is falsy, bind 'tcp://*' to a random port from
        settings.PORT_RANGE.  Returns ``(socket, advertised_addrs)``;
        wildcard addresses are expanded into hostname and localhost
        forms for advertising.
        """
        # print self.peer_type(), "addresses for socket init:", addrs
        addresses = addrs if addrs else ['tcp://*']
        random_port = True if not addrs else False

        sock = self.ctx.socket(zmq_type)

        port = None
        try:
            for i, addr in enumerate(addresses):
                if random_port and net.is_net_addr(addr):
                    port = str(sock.bind_to_random_port(addr,
                                                        min_port=settings.PORT_RANGE[0],
                                                        max_port=settings.PORT_RANGE[1]))
                    addresses[i] = addr + ':' + str(port)
                else:
                    sock.bind(addr)
        except Exception as e:
            self.logger.critical("CRITICAL error: %s", str(e))
            raise(e)

        advertised_addrs = []
        for addr in addresses:
            if addr.startswith('tcp://*'):
                port = addr.rsplit(':', 1)[1]
                advertised_addrs.append('tcp://' + socket.gethostname() + ':' + str(port))
                advertised_addrs.append('tcp://' + 'localhost:' + str(port))
            else:
                advertised_addrs.append(addr)
        return sock, advertised_addrs

    def _register(self, rep_addrs, pub_addrs, params):
        """Send a 'register_peer' request to the supervisor and return the
        unpacked response; exits the process (code 2) on 'rq_error'."""
        message = self.mtool.fill_msg("register_peer", peer_type=self.type,
                                      uuid=self.uuid,
                                      rep_addrs=rep_addrs,
                                      pub_addrs=pub_addrs,
                                      name=self.name,
                                      other_params=params)
        self.logger.debug("_register() " + str(message))
        send_msg(self.source_req_socket, message)
        response_str = recv_msg(self.source_req_socket)
        response = self.mtool.unpack_msg(response_str)

        if response.type == "rq_error":
            self.logger.critical("Registration failed: {0}".format(response_str))
            sys.exit(2)
        return response

    def register(self):
        """Register this peer with its supervisor."""
        params = self.params_for_registration()
        return self._register(self.rep_addresses, self.pub_addresses, params)

    def _handle_registration_response(self, response):
        # hook for subclasses; default: ignore the response
        pass

    def shutdown(self):
        """Log and terminate the process."""
        self.logger.info("SHUTTING DOWN")
        sys.exit(0)

    def params_for_registration(self):
        """Extra key/values for the 'register_peer' message; subclasses
        override."""
        return {}

    def basic_sockets(self):
        """Sockets every peer polls: REP requests + monitor messages."""
        return [self.rep_socket, self._subprocess_pull]

    def custom_sockets(self):
        """
        subclass this
        """
        return []

    def all_sockets(self):
        """All sockets the main loop should poll."""
        return self.basic_sockets() + self.custom_sockets()

    def _set_poll_sockets(self):
        self._poll_sockets = self.all_sockets()

    @log_crash
    def run(self):
        """Main loop: poll the registered sockets, drain each readable one
        with non-blocking reads and dispatch via handle_message(); leave
        the loop when ``self.interrupted`` is set, cleaning up on the way
        out (also on an unhandled exception, which is re-raised)."""
        self.pre_run()
        poller = zmq.Poller()
        poll_sockets = list(self._poll_sockets)
        for sock in poll_sockets:
            poller.register(sock, zmq.POLLIN)

        try:
            while True:
                socks = []
                try:
                    socks = dict(poller.poll())
                except zmq.ZMQError as e:
                    self.logger.warning(": zmq.poll(): " + str(e.strerror))

                for sock in socks:
                    if socks[sock] == zmq.POLLIN:
                        more = True
                        while more:
                            try:
                                msg = recv_msg(sock, flags=zmq.NOBLOCK)
                            except zmq.ZMQError as e:
                                # EAGAIN: socket drained.  REP sockets also
                                # stop here: they can't recv again before a
                                # reply is sent.
                                if e.errno == zmq.EAGAIN or sock.getsockopt(zmq.TYPE) == zmq.REP:
                                    more = False
                                else:
                                    self.logger.error("handling socket read error: %s %d %s",
                                                      e, e.errno, sock)
                                    poller.unregister(sock)
                                    if sock in poll_sockets:
                                        poll_sockets.remove(sock)
                                    self.handle_socket_read_error(sock, e)
                                    break
                            else:
                                self.handle_message(msg, sock)
                    else:
                        self.logger.warning("sock not zmq.POLLIN! Ignore !")
                if self.interrupted:
                    break
                self._update_poller(poller, poll_sockets)
        except Exception as e:
            # from urllib2 import HTTPError
            # try:
            #     self.logger.critical("UNHANDLED EXCEPTION IN %s!!! ABORTING! Exception data: %s, e.args: %s, %s",
            #                          self.name, e, e.args, vars(e), exc_info=True,
            #                          extra={'stack': True})
            # except HTTPError, e:
            #     self.logger.info('sentry sending failed....')
            self._clean_up()
            raise(e)
        self._clean_up()

    def _crash_extra_description(self, exception=None):
        # extra text for crash reports (used by log_crash); hook
        return ""

    def _crash_extra_data(self, exception=None):
        # extra structured data for crash reports; hook
        return {}

    def _crash_extra_tags(self, exception=None):
        # tags attached to crash reports; hook
        return {'obci_part': 'launcher'}

    def _update_poller(self, poller, curr_sockets):
        """Re-read all_sockets() and (un)register the difference with the
        poller.

        NOTE(review): the final assignment rebinds only the *local* name;
        the caller's ``poll_sockets`` list is never updated, so removed
        sockets would be unregistered again on the next pass (KeyError
        from Poller.unregister) -- verify whether sockets are ever removed
        in practice.
        """
        self._set_poll_sockets()
        new_sockets = list(self._poll_sockets)

        for sock in new_sockets:
            if sock not in curr_sockets:
                poller.register(sock, zmq.POLLIN)
        for sock in curr_sockets:
            if sock not in new_sockets:
                poller.unregister(sock)
        curr_sockets = new_sockets

    def handle_socket_read_error(self, socket, error):
        # hook called after a socket read failed and it was unregistered
        pass

    def pre_run(self):
        # hook executed right before the main loop starts
        pass

    def _clean_up(self):
        """Stop both worker threads, close every tracked socket, then call
        the clean_up() subclass hook."""
        time.sleep(0.01)
        self._stop_publishing = True
        self._stop_monitoring = True
        self.pub_thr.join()
        self.subprocess_thr.join()

        for sock in self._all_sockets:
            # print self.name, "closing ", sock
            sock.close()
        # try:
        #     self.ctx.term()
        # except zmq.ZMQError(), e:
        #     print "Ctx closing interrupted."

        self.clean_up()

    def clean_up(self):
        """Subclass hook run at the end of _clean_up()."""
        self.logger.info("CLEANING UP")

    # message handling ######################################

    def handle_message(self, message, sock):
        """Unpack ``message`` and dispatch it to the registered handler:
        type-specific handler if any, else ``unsupported``; on a parse
        error on a REP socket, the ``error`` handler (so the requester
        still gets a reply)."""
        handler = self.msg_handlers.default
        try:
            msg = self.mtool.unpack_msg(message)
            if msg.type != "ping" and msg.type != "rq_ok":
                self.logger.debug("got message: {0}".format(msg.type))
            if msg.type == "get_tail":
                print(self.msg_handlers)
        except ValueError as e:
            print("{0} [{1}], Bad message format! {2}".format(
                self.name, self.peer_type(), message))
            if sock.getsockopt(zmq.TYPE) == zmq.REP:
                handler = self.msg_handlers.error
                msg = message
            print(e)
        else:
            msg_type = msg.type
            handler = self.msg_handlers.handler_for(msg_type)
            if handler is None:
                # print "{0} [{1}], Unknown message type: {2}".format(
                #                 self.name, self.peer_type(),msg_type)
                # print message
                handler = self.msg_handlers.unsupported
        # handlers are stored as plain functions, hence the explicit self
        handler(self, msg, sock)

    @msg_handlers.handler("register_peer")
    def handle_register_peer(self, message, sock):
        """Subclass this."""
        result = self.mtool.fill_msg("rq_error",
                                     request=vars(message), err_code="unsupported_peer_type")
        send_msg(sock, result)

    @msg_handlers.handler("ping")
    def handle_ping(self, message, sock):
        # reply only on socket types that are allowed to send one
        if sock.socket_type in [zmq.REP, zmq.ROUTER]:
            send_msg(sock, self.mtool.fill_msg("pong"))

    @msg_handlers.default_handler()
    def default_handler(self, message, sock):
        """Ignore message"""
        pass

    @msg_handlers.unsupported_handler()
    def unsupported_msg_handler(self, message, sock):
        # tell the requester we don't know this message type
        if sock.socket_type in [zmq.REP, zmq.ROUTER]:
            msg = self.mtool.fill_msg("rq_error",
                                      request=vars(message), err_code="unsupported_msg_type", sender=self.uuid)
            send_msg(sock, msg)
        # print "--"

    @msg_handlers.error_handler()
    def bad_msg_handler(self, message, sock):
        # only ever invoked for REP sockets (see handle_message), so a
        # reply is always due here
        msg = self.mtool.fill_msg("rq_error",
                                  request=message, err_code="invalid_msg_format")
        send_msg(sock, msg)

    @msg_handlers.handler("kill")
    def handle_kill(self, message, sock):
        # kill addressed to us (or broadcast, empty receiver): tear down
        if not message.receiver or message.receiver == self.uuid:
            self.cleanup_before_net_shutdown(message, sock)
            self._clean_up()
            self.shutdown()

    @msg_handlers.handler("dead_process")
    def handle_dead_process(self, message, sock):
        # default: ignore; subclasses react to child-process deaths
        pass

    def cleanup_before_net_shutdown(self, kill_message, sock=None):
        """Close all tracked sockets prior to process shutdown."""
        for sock in self._all_sockets:
            sock.close()
class RegistrationDescription(object):
    """Plain record holding the data a peer sends when registering:
    identity (uuid, name), its advertised REP/PUB addresses, and where it
    runs (machine, pid), plus an optional free-form ``other`` payload."""

    def __init__(self, uuid, name, rep_addrs, pub_addrs, machine, pid, other=None):
        self.machine_ip = machine
        self.pid = pid
        self.uuid = uuid
        self.name = name
        self.rep_addrs = rep_addrs
        self.pub_addrs = pub_addrs
        self.other = other

    def info(self):
        """Return the description as a plain dict."""
        return {
            'machine': self.machine_ip,
            'pid': self.pid,
            'uuid': self.uuid,
            'name': self.name,
            'rep_addrs': self.rep_addrs,
            'pub_addrs': self.pub_addrs,
            'other': self.other,
        }
def basic_arg_parser():
    """Build the ArgumentParser with the address options shared by all
    OBCI control peers (``--sv-addresses``, ``--rep-addresses``,
    ``--pub-addresses``).  ``add_help`` is disabled so the result can be
    used as a *parent* parser by peer-specific parsers.
    """
    parser = argparse.ArgumentParser(
        add_help=False,
        description='Basic OBCI control peer with public PUB and REP sockets.')
    # NOTE: the help strings below are kept byte-for-byte (incl. the
    # backslash continuation) -- they are runtime-visible text.
    parser.add_argument('--sv-addresses', nargs='+',
                        help='REP Addresses of the peer supervisor,\
                        for example an OBCI Experiment controller may need OBCI Server addresses')
    parser.add_argument('--rep-addresses', nargs='+',
                        help='REP Addresses of the peer.')
    parser.add_argument('--pub-addresses', nargs='+',
                        help='PUB Addresses of the peer.')
    return parser
class OBCIControlPeerError(Exception):
    """Base class for errors raised by OBCI control-peer code."""
class MessageHandlingError(OBCIControlPeerError):
    """Raised when an incoming control message cannot be processed."""
if __name__ == '__main__':
    # Stand-alone entry point: run a bare OBCIControlPeer with the shared
    # address options plus a human-readable --name.
    cli = argparse.ArgumentParser(parents=[basic_arg_parser()])
    cli.add_argument('--name', default='obci_control_peer',
                     help='Human readable name of this process')
    cmdline = cli.parse_args()

    peer = OBCIControlPeer(cmdline.sv_addresses,
                           cmdline.rep_addresses,
                           cmdline.pub_addresses,
                           cmdline.name)
    peer.run()