# ---LICENSE-BEGIN - DO NOT CHANGE OR MOVE THIS HEADER
# This file is part of the Neurorobotics Platform software
# Copyright (C) 2014,2015,2016,2017 Human Brain Project
# https://www.humanbrainproject.eu
#
# The Human Brain Project is a European Commission funded project
# in the frame of the Horizon2020 FET Flagship plan.
# http://ec.europa.eu/programmes/horizon2020/en/h2020-section/fet-flagships
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
# ---LICENSE-END
"""
This module inspects a transfer function and returns a formal object describing it
"""
from builtins import map
from builtins import object
__author__ = 'Claudio Sousa'
from cle_ros_msgs.msg import TransferFunction, Device, Topic, Variable, ExperimentPopulationInfo
import ast
import astunparse # provides unparse method in python < 3.9, otherwise use ast.unparse
import json
import textwrap
import logging
logger = logging.getLogger(__name__)
[docs]class StructureType(object):
"""
Transfer function type
"""
Neuron2Robot = 2
Robot2Neuron = 1
NeuronMonitor = 3
device_types = {
'ac_source': 'ACSource',
'dc_source': 'DCSource',
'fixed_frequency': 'FixedFrequency',
'leaky_integrator_alpha': 'LeakyIntegratorAlpha',
'leaky_integrator_exp': 'LeakyIntegratorExp',
'nc_source': 'NCSource',
'poisson': 'Poisson',
'spike_recorder': 'SpikeRecorder',
'population_rate': 'PopulationRate',
'raw_signal': 'RawSignal',
'injector': 'SpikeInjector'
}
[docs]class TransferFunctionASTParser(object):
"""
Transfer function source code parser that uses the generated AST to instrospect the TF structure
"""
def __init__(self):
self._tf = TransferFunction()
self._imports = dict()
[docs] def parse(self, code):
"""
Parses a TF code source
:returns: the structured TF
"""
tfAst = ast.parse(code)
self._tf.code = TransferFunctionASTParser.__get_tf_code(code)
self._visitAST(tfAst)
return self._tf
def _visitAST(self, tfAst):
"""
Visits the TF generated AST
"""
functions = [instr for instr in tfAst.body if isinstance(instr, ast.FunctionDef)]
importsFrom = [instr for instr in tfAst.body if isinstance(instr, ast.ImportFrom)]
for importFrom in importsFrom:
self._visit_ImportFrom(importFrom)
assert len(functions) == 1, 'Expected one function on TF, found %i' % len(functions)
self._visit_FunctionDef(functions[0])
def _visit_ImportFrom(self, importFrom):
"""
Handles import AST nodes (ie 'from ... import [as ...]')
"""
module = importFrom.module
for nameNode in importFrom.names:
alias = nameNode.asname if nameNode.asname else nameNode.name
self._imports[alias] = (nameNode.name, module)
def _visit_FunctionDef(self, functionNode):
"""
Handles function definition AST node
"""
self._tf.name = functionNode.name
for decorator in functionNode.decorator_list:
self._visit_decorator(decorator)
def _visit_decorator(self, decorator):
"""
Handles a function decorator
"""
decorator_name = self._get_attribute_path(decorator.func)[-1]
method = '_visit_' + decorator_name
visitor = getattr(self, method, self._visit_UndeclaredDecorator)
visitor(decorator)
def _visit_MapRobotPublisher(self, decorator):
"""
Handles MapRobotPublisher decorators
"""
self._visit_publisher(decorator)
def _visit_publisher(self, decorator):
"""
Handles generic publisher decorators
"""
assert len(
decorator.args) == 2, '%s: expected two arguments, found %i' % (decorator.func.attr,
len(decorator.args))
topic = self._visit_Topic(decorator.args[1])
topic.name = TransferFunctionASTParser._get_value(decorator.args[0])
topic.publishing = True
self._tf.topics.append(topic)
def _visit_subscriber(self, decorator):
"""
Handles generic subscriber decorators
"""
assert len(
decorator.args) == 2, '%s: expected two arguments, found %i' % (decorator.func.attr,
len(decorator.args))
topic = self._visit_Topic(decorator.args[1])
topic.name = TransferFunctionASTParser._get_value(decorator.args[0])
topic.publishing = False
self._tf.topics.append(topic)
def _visit_MapSpikeSource(self, decorator):
"""
Handles MapSpikeSource decorators
"""
self._visit_DeviceDecorator(decorator)
def _visit_MapSpikeSink(self, decorator):
"""
Handles MapSpikeSink decorators
"""
self._visit_DeviceDecorator(decorator)
def _visit_DeviceDecorator(self, decorator):
"""
Handles Device decorators
"""
assert len(
decorator.args) == 3, "%s: expected 3 args, found %i" % (self.str_node(decorator),
len(decorator.args))
device_name = self._get_attribute_path(decorator.args[2])[-1]
device = Device()
device.type = device_types[device_name]
device.name = TransferFunctionASTParser._get_value(decorator.args[0])
device.neurons = self._get_neurons(decorator.args[1])
self._tf.devices.append(device)
def _visit_NeuronMonitor(self, decorator):
"""
Handles NeuronMonitor decorators
"""
self._tf.type = StructureType.NeuronMonitor
device_name = self._get_attribute_path(decorator.args[1])[-1]
topic = Topic()
topic.topic = '/monitor/' + device_name
topic.type = 'cle_ros_msgs/' + \
('SpikeEvent' if device_name == 'spike_recorder' else 'SpikeRate')
topic.name = 'publisher'
topic.publishing = True
self._tf.topics.append(topic)
device = Device()
device.type = device_types[device_name]
device.name = 'device'
device.neurons = self._get_neurons(decorator.args[0])
self._tf.devices.append(device)
def _get_neurons(self, neuron_node):
"""
Parses the Population information
:returns: the list of neurons and population name
"""
neurons = ExperimentPopulationInfo()
if isinstance(neuron_node, ast.Subscript):
pop_path = self._get_attribute_path(neuron_node.value)
neurons.name = pop_path[-1]
neurons.start, neurons.stop, neurons.step = self._get_slice(neuron_node.slice)
neurons.type = 1
elif isinstance(neuron_node, ast.Attribute):
pop_path = self._get_attribute_path(neuron_node)
neurons.name = pop_path[-1]
neurons.start, neurons.stop, neurons.step = [0, 0, 0]
neurons.type = 0
else:
neurons.name = "Custom logic"
neurons.start, neurons.stop, neurons.step = [0, 0, 0]
neurons.type = 0
return neurons
def _get_slice(self, index):
"""
Parses a slice instance (eg. slice(1, 4, 1)), an index (eg. [1]) or a slice (eg. [1:4:1])
"""
if isinstance(index, ast.Index):
if (isinstance(index.value, ast.Call) and isinstance(index.value.func, ast.Name) and
index.value.func.id == "slice"):
# eg: slice(1, 4, 1)
args = [TransferFunctionASTParser._get_value(arg) for arg in index.value.args]
if len(args) == 2:
args.append(1)
return args
elif isinstance(index.value, ast.Constant):
# eg: [1]
try:
num = TransferFunctionASTParser._get_value(index.value)
except ValueError:
raise Exception(
"_get_slice unknown index value %s" % self.str_node(index.value))
return [num, num + 1, 1]
else:
raise Exception("_get_slice unknown index value %s" % self.str_node(index.value))
elif isinstance(index, ast.Slice):
# [1:4:1]
return list(map(self._get_value, [index.lower, index.upper, index.step]))
else:
raise Exception("get_slice unexpected argument, got %s" % self.str_node(index))
# pylint: disable=unused-argument
def _visit_Robot2Neuron(self, decorator):
"""
Handles the Robot2Neuron decorator
"""
self._tf.type = StructureType.Robot2Neuron
def _visit_Neuron2Robot(self, decorator):
"""
Handles the Neuron2Robot decorator
"""
if decorator.args:
assert len(decorator.args) == 1, 'Neuron2Robot: expected one arguments, found %i' % len(
decorator.args)
topic = self._visit_Topic(decorator.args[0])
topic.name = '__return__'
topic.publishing = True
self._tf.topics.append(topic)
self._tf.type = StructureType.Neuron2Robot
def _visit_MapRobotSubscriber(self, decorator):
"""
Handles the MapRobotSubscriber decorator
"""
self._visit_subscriber(decorator)
def _visit_Topic(self, node):
"""
Parses the Topic AST node and returns it's information
"""
assert isinstance(node, ast.Call) and node.func.id == 'Topic', \
'expected a Topic instantiation, found %s' % \
self.str_node(node)
assert len(node.args) == 2, 'Expected Topic to have two arguments'
topic = Topic()
# print(self.str_node(node.args[1]))
topic.topic = TransferFunctionASTParser._get_value(node.args[0])
topic.type = self._get_topic_ros_type(node.args[1])
return topic
def _visit_MapCSVRecorder(self, decorator):
"""
Handles the MapCSVRecorder decorator
"""
variable = Variable()
variable.name = TransferFunctionASTParser._get_value(decorator.args[0])
variable.type = "csv"
initialValue = {}
if len(decorator.args) >= 2:
initialValue['filename'] = TransferFunctionASTParser._get_value(decorator.args[1])
if len(decorator.args) >= 3:
initialValue['headers'] = TransferFunctionASTParser._get_list_value(decorator.args[2])
for keyword in decorator.keywords:
if keyword.arg == 'filename':
initialValue['filename'] = TransferFunctionASTParser._get_value(keyword.value)
elif keyword.arg == 'headers':
initialValue['headers'] = TransferFunctionASTParser._get_list_value(keyword.value)
variable.initial_value = json.dumps(initialValue)
self._tf.variables.append(variable)
def _visit_MapVariable(self, decorator):
"""
Handles the MapVariable decorator
"""
assert len(decorator.args) == 1, \
'Expected MapVariable %s to have at one argument, found %i' \
% (decorator.func.attr, len(decorator.args))
variable = Variable(name=TransferFunctionASTParser._get_value(decorator.args[0]))
for initial_value_node in (k.value for k in decorator.keywords if k.arg == 'initial_value'):
if isinstance(initial_value_node, (ast.Constant, ast.Name)):
v = TransferFunctionASTParser._get_value(initial_value_node)
variable.initial_value = str(v)
variable.type = type(v).__name__
else: # any other expression
variable.initial_value = astunparse.unparse(initial_value_node).strip()
variable.type = "expression"
self._tf.variables.append(variable)
def _get_topic_ros_type(self, node):
"""
Parses the ROS topic type
:returns: the ROS topic type
"""
path = self._get_attribute_path(node)
if path[0] in self._imports:
name, module = self._imports[path[0]]
path = module.split('.') + [name] + path[1:]
if len(path) == 3:
path = path[::2]
return '/'.join(path)
def _get_attribute_path(self, node):
"""
Parses an attribute AST node
:returns: list of string words in the attributs
"""
if isinstance(node, ast.Name):
return [TransferFunctionASTParser._get_value(node)]
path = self._get_attribute_path(node.value)
path.append(node.attr)
return path
@staticmethod
def __get_tf_code(tf):
'''
Extracts the body from the given in-memory transfer function
:param tf: The transfer function
:return: The extracted body of the transfer function
'''
import re
fnbody = re.search(r'.*\ndef +[a-zA-Z0-9]* *[^\)]*[^\n]*(?P<body>(\n( [^\n]*|))*)', tf)
if fnbody:
fnbody = fnbody.group('body')
if not fnbody:
return tf.source
return textwrap.dedent(fnbody[1:]).rstrip('\n')
@staticmethod
def _get_list_value(node):
"""
Parses a list AST node
"""
return [TransferFunctionASTParser._get_value(v) for v in node.elts]
@staticmethod
def _get_value(node):
"""
Parses a Value/Name AST Node
"""
scalarNames = {'None': None, 'False': False, 'True': True}
scalarObjects = {
'Constant': lambda o: o.value,
'Name': lambda o: scalarNames[o.id] if o.id in scalarNames else o.id,
'Call': lambda o: o.func
}
class_name = node.__class__.__name__
assert class_name in scalarObjects, 'Unknown scalar class %s' % class_name
return scalarObjects[class_name](node)
def _visit_UndeclaredDecorator(self, decorator):
"""
Fallback when an unknown decorator is found
"""
logger.debug('Undeclared decorator: %s', decorator.func.attr)
logger.debug(self.str_node(decorator.func))
logger.debug('Args::')
for arg in decorator.args:
logger.debug('\t-: %s', self.str_node(arg))
logger.debug('keywords::')
for keyword in decorator.keywords:
logger.debug('\t-: %s', self.str_node(keyword))
[docs] def str_node(self, node):
"""
Stringifies an AST node
"""
if isinstance(node, ast.AST):
fields = [(name, self.str_node(val)) for name, val in ast.iter_fields(node)
if name not in ('left', 'right')]
rv = '%s(%s' % (node.__class__.__name__, ', '.join('%s=%s' % field for field in fields))
return rv + ')'
else:
return repr(node)