# Source code for obci.control.launcher.eeg_experiment_finder

#!/usr/bin/python
# -*- coding: utf-8 -*-

import json
import zmq
import time
import socket

import obci.control.common.net_tools as net
from obci.control.common.message import OBCIMessageTool, send_msg, recv_msg, PollingObject
from obci.control.launcher.launcher_messages import message_templates

import obci.control.launcher.launcher_logging as logger
import obci.control.launcher.launcher_tools as launcher_tools
from obci.control.common.obci_control_settings import PORT_RANGE

from obci.drivers.eeg.driver_discovery.driver_discovery import find_drivers, \
    find_bluetooth_amps, find_virtual_amps, find_usb_amps

LOGGER = logger.get_logger("eeg_experiment_finder", "info")


class EEGExperimentFinder(object):
    """Find running experiments that contain an amplifier-like peer.

    The finder asks an obci_server (over a ZeroMQ REQ socket) for its list of
    experiments, then queries every running/launching experiment for its
    peers and collects a description of each peer that looks like an
    amplifier.
    """

    def __init__(self, srv_addrs, ctx, client_push_address, nearby_servers):
        """Connect a REQ socket to every address in *srv_addrs*.

        :param srv_addrs: iterable of obci_server REP addresses to connect to
        :param ctx: ZeroMQ context used to create all sockets
        :param client_push_address: accepted for interface compatibility;
            not used by this class
        :param nearby_servers: stored as-is on the instance
        """
        self.ctx = ctx
        self.server_req_socket = self.ctx.socket(zmq.REQ)
        for addr in srv_addrs:
            self.server_req_socket.connect(addr)
        self.poller = PollingObject()
        self.mtool = OBCIMessageTool(message_templates)
        self.nearby_servers = nearby_servers
        self._amplified_cache = {}

    def _running_experiments(self):
        """Return descriptions of experiments that are running or launching.

        :return: list of experiment description dicts, or ``None`` when the
            server did not answer the ``list_experiments`` request in time
        """
        send_msg(self.server_req_socket, self.mtool.fill_msg("list_experiments"))
        exp_list, details = self.poll_recv(self.server_req_socket, 2000)
        if not exp_list:
            LOGGER.error("Connection to obci_server failed. (list_experiments)")
            return None
        active_states = (launcher_tools.RUNNING, launcher_tools.LAUNCHING)
        return [exp for exp in exp_list.exp_data.values()
                if exp['status_name'] in active_states]

    def find_amplified_experiments(self):
        """Return amplifier descriptions for all running experiments.

        :return: list of dicts as built by :meth:`_create_exp_data`; empty
            when the server could not be reached or nothing was found
        """
        running_exps = self._running_experiments()
        if running_exps is None:
            # The failure was already logged by _running_experiments();
            # report "nothing found" instead of iterating over None.
            return []
        amplified = []
        for exp in running_exps:
            LOGGER.info("Found running experiment: " + str(exp['name']))
            infos = self._info_amplified(exp)
            if infos is not None:
                print("Found experiments...", str(infos)[:500])
                amplified += infos
        return amplified

    def _info_amplified(self, exp_desc):
        """Collect amplifier data for a single experiment.

        :param exp_desc: experiment description dict; the keys ``name``,
            ``rep_addrs``, ``pub_addrs`` and ``tcp_addrs`` are read
        :return: list of amplifier data dicts, or ``None`` when the
            experiment did not answer, has no multiplexer, or has no
            amplifier-like peer
        """
        amp_options = []
        LOGGER.info("Processing experiment " + str(exp_desc['name']) +
                    " w/ addr: " + str(exp_desc['rep_addrs']))

        # Prefer non-local addresses and fall back to local ones.
        rep_addrs = net.choose_not_local(exp_desc['rep_addrs'])
        if not rep_addrs:
            rep_addrs = net.choose_local(exp_desc['rep_addrs'], ip=True)
        rep_addr = rep_addrs.pop()

        pub_addrs = net.choose_not_local(exp_desc['pub_addrs'])
        if not pub_addrs:
            pub_addrs = net.choose_local(exp_desc['pub_addrs'], ip=True)
        pub_addr = pub_addrs.pop()

        # Take the last TCP address without mutating the caller's dict.
        tcp_addr = exp_desc['tcp_addrs'][-1]

        LOGGER.info("Chosen experiment addresses: REP -- " +
                    str(rep_addr) + ", PUB -- " + str(pub_addr))

        req_sock = self.ctx.socket(zmq.REQ)
        try:
            req_sock.connect(rep_addr)
            send_msg(req_sock, self.mtool.fill_msg('get_experiment_info'))
            res, details = self.poll_recv(req_sock, 4000)
        finally:
            req_sock.close()

        if not res:
            LOGGER.error("Connection failed (experiment " + exp_desc['name'] +
                         "), get_experiment_info")
            return None
        exp_info = res.dict()
        for field in ["peers_status", "details"]:
            # Tolerate replies that already lack these fields.
            exp_info["experiment_status"].pop(field, None)

        peer_list = exp_info["peers"]
        if not self._has_mx(peer_list):
            LOGGER.info("Experiment " + exp_desc['name'] +
                        " does not have a multiplexer.")
            return None

        maybe_amps = self._amp_like_peers(peer_list)
        if not maybe_amps:
            LOGGER.info("Experiment " + exp_desc['name'] + " -- no amplifier.")
            return None

        req_sock = self.ctx.socket(zmq.REQ)
        try:
            req_sock.connect(rep_addr)
            for peer in maybe_amps:
                info, params = self._get_amp_info(req_sock, peer)
                if info is None or params is None:
                    # A REQ socket that got no reply cannot send again, so
                    # the remaining peers cannot be queried on this socket.
                    LOGGER.error("Experiment " + exp_desc['name'] +
                                 " -- no info for peer " + str(peer) +
                                 ", giving up on remaining peers.")
                    break
                if not self._is_amplifier(info, params):
                    LOGGER.info("Experiment " + exp_desc['name'] +
                                " -- peer " + str(peer) + " is not an amplifier.")
                    continue
                amp_options.append(
                    self._create_exp_data(exp_info, info, params['param_values'],
                                          rep_addr, pub_addr, tcp_addr))
        finally:
            req_sock.close()

        return amp_options

    def _get_amp_info(self, exp_sock, peer_id):
        """Request peer info and parameter values for *peer_id*.

        :param exp_sock: REQ socket already connected to the experiment
        :param peer_id: identifier of the peer to query
        :return: ``(info, params)`` dicts, or ``(None, None)`` when either
            request got no reply within the timeout
        """
        send_msg(exp_sock, self.mtool.fill_msg('get_peer_info', peer_id=peer_id))
        info, details = self.poll_recv(exp_sock, 4000)
        if not info:
            # Do not send the second request: the REQ socket is still
            # waiting for the first reply and would refuse to send.
            LOGGER.error("get_peer_info failed " + str(peer_id) + " " + str(details))
            return None, None

        send_msg(exp_sock, self.mtool.fill_msg('get_peer_param_values',
                                               peer_id=peer_id))
        params, details = self.poll_recv(exp_sock, 4000)
        if not params:
            LOGGER.error("get_peer_info failed " + str(peer_id) + " " + str(details))
            return None, None

        info = info.dict()
        params = params.dict()
        # Strip fields that are not forwarded to the client.
        for field in ["sender", "sender_ip", "receiver", "type", "local_params",
                      "external_params", "config_sources", "launch_dependencies"]:
            info.pop(field, None)
        return info, params

    def _is_amplifier(self, peer_info, peer_params):
        """Return True when the peer is an obci_peer exposing channel params.

        :param peer_info: dict with at least ``peer_id`` and ``peer_type``
        :param peer_params: dict with a ``param_values`` mapping
        """
        peer_id = peer_info['peer_id']

        if peer_info['peer_type'] != 'obci_peer':
            LOGGER.info("Peer " + str(peer_id) + " not obci_peer")
            return False

        params = peer_params['param_values']
        if 'channels_info' not in params or 'active_channels' not in params:
            LOGGER.info('Peer ' + str(peer_id) + " no channels_info param.")
            return False
        return True

    def _create_exp_data(self, exp_info, peer_info, params, rep_addr, pub_addr, tcp_addr):
        """Bundle experiment and amplifier information into one dict."""
        return {
            'amplifier_params': params,
            'amplifier_peer_info': peer_info,
            'experiment_info': exp_info,
            'rep_addrs': [rep_addr],
            'pub_addrs': [pub_addr],
            'tcp_addrs': [tcp_addr],
        }

    def _has_mx(self, peer_list):
        """Return True when any peer id starts with ``'mx'``."""
        return any(peer.startswith('mx') for peer in peer_list)

    def _amp_like_peers(self, peer_list):
        """Return the peer ids that start with ``'amplifier'``."""
        return [peer for peer in peer_list if peer.startswith('amplifier')]

    def poll_recv(self, socket, timeout):
        """Poll *socket* for a message and unpack it.

        :param socket: socket to receive from
        :param timeout: timeout passed to the poller
        :return: ``(result, details)``; *result* is the unpacked message, or
            the poller's falsy result when nothing was received
        """
        result, details = self.poller.poll_recv(socket, timeout)
        if result:
            result = self.mtool.unpack_msg(result)
        return result, details
def _gather_other_server_results(ctx, this_addr, ip_list): exps = [] if not ip_list: return exps mtool = OBCIMessageTool(message_templates) other_exps_pull = ctx.socket(zmq.PULL) port = other_exps_pull.bind_to_random_port('tcp://*', min_port=PORT_RANGE[0], max_port=PORT_RANGE[1], max_tries=500) my_push_addr = this_addr + ':' + str(port) LOGGER.info("my exp_data pull: " + my_push_addr) harvester = zmq.Poller() harvester.register(other_exps_pull) reqs = {} for srv_ip in ip_list: addr = 'tcp://' + srv_ip + ':' + net.server_rep_port() req = ctx.socket(zmq.REQ) req.connect(addr) send_msg(req, mtool.fill_msg('find_eeg_experiments', client_push_address='tcp://' + my_push_addr, checked_srvs=[this_addr])) harvester.register(req, zmq.POLLIN) reqs[req] = addr srv_responses = 0 for i in range(len(reqs) * 50): socks = dict(harvester.poll(timeout=300)) for req in socks: msg = recv_msg(req) msg = mtool.unpack_msg(msg) if msg.type == 'rq_ok': LOGGER.info("waiting for experiments from server: " + str(reqs[req])) harvester.unregister(req) req.close() elif msg.type == 'eeg_experiments': LOGGER.info("GOT EXPERIMENTS from: " + msg.sender_ip) exps += msg.experiment_list srv_responses += 1 else: LOGGER.warning('strange msg: ' + str(msg)) if srv_responses == len(ip_list): print("GOT ALL eeg_experiment RESPONSES :-) [addr list: ", ip_list, "]") break other_exps_pull.close() return exps
def find_eeg_experiments_and_push_results(ctx, srv_addrs, rq_message, nearby_servers):
    """Find amplified experiments and push the list to the requesting client.

    Experiments are looked up on the servers in *srv_addrs*. If the request
    carries no already-checked servers and nearby servers are known, those
    servers are queried too. The combined list is sent as an
    ``eeg_experiments`` message to ``rq_message.client_push_address``.

    :param ctx: ZeroMQ context used to create the sockets
    :param srv_addrs: obci_server REP addresses for the local lookup
    :param rq_message: request message; ``client_push_address`` and
        ``checked_srvs`` are read
    :param nearby_servers: object providing ``snapshot()``,
        ``dict_snapshot()`` and ``ip(hostname=...)``
    """
    LOGGER = logger.get_logger("eeg_experiment_finder", "info")
    finder = EEGExperimentFinder(srv_addrs, ctx, rq_message.client_push_address, nearby_servers)
    to_client = None
    try:
        exps = finder.find_amplified_experiments()
        checked = rq_message.checked_srvs
        if not isinstance(checked, list):
            checked = []

        # Nearby servers that have not been checked yet.
        nrb = {}
        for uid, srv in nearby_servers.snapshot().items():
            if srv.ip not in checked:
                nrb[uid] = srv

        if not checked and nearby_servers.dict_snapshot():
            my_addr = nearby_servers.ip(hostname=socket.gethostname())
            LOGGER.info("checking other servers")
            print([(srv.hostname, srv.ip) for srv in nrb.values()])

            ip_list = [srv.ip for srv in nrb.values() if srv.ip != my_addr]
            LOGGER.info("number of servers to query: " + str(len(ip_list)))

            exps += _gather_other_server_results(ctx, my_addr, ip_list)
        else:
            LOGGER.info("not checking other servers")

        LOGGER.info("return to: " + rq_message.client_push_address)
        to_client = ctx.socket(zmq.PUSH)
        to_client.connect(rq_message.client_push_address)

        send_msg(to_client, finder.mtool.fill_msg('eeg_experiments',
                                                  sender_ip=socket.gethostname(),
                                                  experiment_list=exps))
        LOGGER.info("sent exp data... " + str(exps)[:500] + ' [...]')
        # Give the message a moment to leave before the socket is closed.
        time.sleep(0.1)
    finally:
        # Close sockets explicitly instead of relying on garbage collection.
        if to_client is not None:
            to_client.close()
        finder.server_req_socket.close()
def find_new_experiments_and_push_results(ctx, rq_message):
    """Discover amplifiers and push the list to the requesting client.

    When ``rq_message.amplifier_types`` is empty, ``find_drivers()`` is used.
    Otherwise each listed type (``'bt'``/``'bluetooth'``, ``'usb'``,
    ``'virtual'``) is searched in order; unknown types are logged and
    skipped. The result is sent as an ``eeg_amplifiers`` message to
    ``rq_message.client_push_address``.

    :param ctx: ZeroMQ context used to create the PUSH socket
    :param rq_message: request message; ``amplifier_types`` and
        ``client_push_address`` are read
    """
    LOGGER = logger.get_logger("eeg_AMPLIFIER_finder", "info")

    if not rq_message.amplifier_types:
        LOGGER.info("AMPLIFIER TYPES NOT SET, FINDING ALL...")
        driv = find_drivers()
    else:
        finders_by_type = {
            'bt': find_bluetooth_amps,
            'bluetooth': find_bluetooth_amps,
            'usb': find_usb_amps,
            'virtual': find_virtual_amps,
        }
        driv = []
        for amptype in rq_message.amplifier_types:
            find_amps = finders_by_type.get(amptype)
            if find_amps is None:
                LOGGER.warning("unknown amplifier type: " + str(amptype))
                continue
            driv += find_amps()

    LOGGER.info("amplifiers! return to: " + rq_message.client_push_address)
    mtool = OBCIMessageTool(message_templates)
    to_client = ctx.socket(zmq.PUSH)
    try:
        to_client.connect(rq_message.client_push_address)
        send_msg(to_client, mtool.fill_msg('eeg_amplifiers',
                                           sender_ip=socket.gethostname(),
                                           amplifier_list=driv))
        LOGGER.info("sent amplifier data... " + str(driv)[:500] + ' [...]')
        # Give the message a moment to leave before the socket is closed.
        time.sleep(0.1)
    finally:
        # Close the socket explicitly instead of relying on garbage collection.
        to_client.close()
if __name__ == '__main__':
    # Manual smoke run: query a server on localhost and dump what was found.
    local_server_addrs = ['tcp://127.0.0.1:54654']
    finder = EEGExperimentFinder(local_server_addrs, zmq.Context(),
                                 'tcp://127.0.0.1:12345', [])
    found = finder.find_amplified_experiments()
    print(json.dumps(found, indent=4))