# -*- coding: utf-8 -*-
"""
base_experiment.py
==================
Base class for the experiments. ``BaseExperiment`` defines the common patterns that every experiment should have.
Importantly, it starts an independent process called publisher, that will be responsible for broadcasting messages
that are appended to a queue. The messages rely on the pyZMQ library and should be tested further in order to
assess their limitations. The general pattern is that of the PUB/SUB, with one publisher and several subscribers.
The messages should include a *topic* and data. For this, the elements in the queue should be dictionaries with two
keywords: **data** and **topic**. ``data['data']`` will be serialized through the use of pickle, and is handled
automatically by pyZMQ through the use of ``send_pyobj``. The subscribers should be aware of this and use either
unpickle or ``recv_pyobj``.
In order to stop the publisher process, the string ``'stop'`` should be placed in ``data['data']``. The message
will be broadcast and can be used to stop other processes, such as subscribers.
.. TODO:: Check whether the serialization of objects with pickle may be a bottleneck for performance.
:copyright: Aquiles Carattino <aquiles@uetke.com>
:license: GPLv3, see LICENSE for more details
"""
from multiprocessing import Process, Event
import yaml
from pynta.util import get_logger
from pynta.model.experiment.publisher import Publisher
from pynta.model.experiment.subscriber import subscriber
class BaseExperiment:
    """ Base class to define experiments.

    Keeps track of the basic methods needed regardless of the experiment to be performed, for instance a way
    to start and a way to finalize a measurement. Creating an instance builds a ``Publisher`` and starts it
    immediately; every call to :meth:`connect` spawns one subscriber process.

    The class can be used as a context manager: :meth:`set_up` runs on entry, :meth:`finalize` runs on exit and
    the exit then waits for the subscriber processes that are still alive.
    """

    def __init__(self, filename=None):
        """
        :param str filename: optional path to a YAML configuration file. When given (and truthy) it is loaded
            through :meth:`load_configuration`.
        """
        self.config = {}  # Dictionary storing the configuration of the experiment
        self.logger = get_logger(name=__name__)
        # Each entry is indexed as ``entry[1].is_alive()`` by the thread helpers below, i.e. the object that
        # can be polled for liveness sits at position 1 of every entry.
        self._threads = []
        self.publisher = Publisher()
        self.publisher.start()
        # One dictionary per subscriber created by ``connect``: keys 'method', 'topic', 'process' and 'event'.
        self._connections = []
        # Extra events that are set when stopping; nothing in this class appends to it.
        self.subscriber_events = []
        if filename:
            self.load_configuration(filename)

    def stop_publisher(self):
        """ Stops the running publisher and, right after, every subscriber (see :meth:`stop_subscribers`).
        """
        self.logger.info('Stopping the publisher')
        self.publisher.stop()
        self.stop_subscribers()

    def stop_subscribers(self):
        """ Sets every event in ``subscriber_events`` and the event of every subscriber process that is
        still alive, in order to signal them to stop.
        """
        self.logger.info('Stopping the subscribers')
        for event in self.subscriber_events:
            event.set()

        for connection in self._connections:
            if connection['process'].is_alive():
                self.logger.info('Stopping {}'.format(connection['method']))
                connection['event'].set()

    def connect(self, method, topic, *args, **kwargs):
        """ Async method that connects the running publisher to the given method on a specific topic.

        A new ``Process`` running ``subscriber(method, topic, event, *args, **kwargs)`` is created, registered
        in the list of connections and started. The ``event`` is created here and stored with the connection
        so that the subscriber can be signalled later on.

        :param method: method that will be connected on a given topic
        :param str topic: the topic that will be used by the subscriber to discriminate what information to collect.
        :param args: extra arguments will be passed to the subscriber, which in turn will pass them to the function
        :param kwargs: extra keyword arguments will be passed to the subscriber, which in turn will pass them to the function
        """
        event = Event()
        arguments = [method, topic, event, *args]
        self.logger.info('Connecting {} on topic {}'.format(method.__name__, topic))
        self.logger.debug('Arguments: {}'.format(args))
        self.logger.debug('KWarguments: {}'.format(kwargs))
        process = Process(target=subscriber, args=arguments, kwargs=kwargs)
        # Register before starting so the connection is tracked even if ``start`` fails.
        self._connections.append({
            'method': method.__name__,
            'topic': topic,
            'process': process,
            'event': event,
        })
        process.start()

    def load_configuration(self, filename):
        """ Loads the configuration file in YAML format and stores the result in ``self.config``.

        :param str filename: full path to where the configuration file is located.
        :raises FileNotFoundError: if the file does not exist.
        :raises Exception: any other error while opening or parsing the file is logged and re-raised unchanged.
        """
        self.logger.info('Loading configuration file {}'.format(filename))
        try:
            with open(filename, 'r') as f:
                # NOTE(review): FullLoader accepts more than ``yaml.safe_load``; consider the safe loader if
                # configuration files can come from untrusted sources.
                self.config = yaml.load(f, Loader=yaml.FullLoader)
                self.logger.debug('Config loaded')
                self.logger.debug(self.config)
        except FileNotFoundError:
            self.logger.error('The specified file {} could not be found'.format(filename))
            raise
        except Exception:
            self.logger.exception('Unhandled exception')
            raise

    def clear_threads(self):
        """ Keep only the threads that are alive.
        """
        self._threads = [thread for thread in self._threads if thread[1].is_alive()]

    @property
    def num_threads(self):
        """ Number of registered threads, regardless of whether they are still alive. """
        return len(self._threads)

    @property
    def connections(self):
        """ List of the connections whose subscriber process is still alive. """
        return [conn for conn in self._connections if conn['process'].is_alive()]

    @property
    def alive_threads(self):
        """ Number of registered threads that are still alive. """
        return sum(1 for thread in self._threads if thread[1].is_alive())

    @property
    def list_alive_threads(self):
        """ List with the registered thread entries that are still alive. """
        return [thread for thread in self._threads if thread[1].is_alive()]

    def set_up(self):
        """ Needs to be overridden by child classes.
        """

    def finalize(self):
        """ Needs to be overridden by child classes. The base implementation only stops the publisher.
        """
        self.publisher.stop()

    def update_config(self, **kwargs):
        """ Updates the configuration dictionary in place with the given keyword arguments.

        :param kwargs: key/value pairs that are merged into ``self.config``, overwriting existing keys.
        """
        self.logger.info('Updating config')
        self.logger.debug('Config params: {}'.format(kwargs))
        self.config.update(**kwargs)

    def __enter__(self):
        """ Runs :meth:`set_up` and returns the experiment itself. """
        self.set_up()
        return self

    def __exit__(self, *args):
        """ Finalizes the experiment, waits for the alive subscriber processes and stops the publisher.

        Nothing is returned, therefore an exception raised inside the ``with`` block is not suppressed.
        """
        self.logger.info("Exiting the experiment")
        self.finalize()

        self.logger.debug('Number of open connections: {}'.format(len(self.connections)))
        # NOTE(review): only ``subscriber_events`` are set here, not the per-connection events; the join below
        # presumably relies on the subscribers stopping once the publisher is stopped -- TODO confirm.
        for event in self.subscriber_events:
            event.set()

        for conn in self.connections:
            self.logger.debug('Waiting for {} to finish'.format(conn['method']))
            conn['process'].join()

        self.logger.info('Finished the base experiment')

        # ``finalize`` may be overridden without stopping the publisher, so stop it here as well.
        # NOTE(review): with the base ``finalize`` this is a second call -- assumes ``Publisher.stop`` tolerates it.
        self.publisher.stop()