diff --git a/.gitignore b/.gitignore index c8329ab7..b093d80f 100644 --- a/.gitignore +++ b/.gitignore @@ -155,3 +155,5 @@ arm64 *.neux /test_*plot*png /examples/plot.py +/tests/utils/test_rotation.net.nml +/tests/utils/test_rotation.net.png diff --git a/examples/generate_if_curve.py b/examples/generate_if_curve.py index 0a1a4bf8..b7d8183f 100644 --- a/examples/generate_if_curve.py +++ b/examples/generate_if_curve.py @@ -16,4 +16,5 @@ plot_voltage_traces=not nogui, plot_if=not nogui, plot_iv=not nogui, + save_if_data_to="if_data.dat" ) diff --git a/pyneuroml/analysis/__init__.py b/pyneuroml/analysis/__init__.py index dc1d0482..f11a65b6 100644 --- a/pyneuroml/analysis/__init__.py +++ b/pyneuroml/analysis/__init__.py @@ -60,7 +60,28 @@ def generate_current_vs_frequency_curve( segment_id: typing.Optional[str] = None, fraction_along: typing.Optional[float] = None ): - """Generate current vs firing rate frequency curve for provided cell. + """Generate current vs firing rate frequency curves for provided cell. + + It runs a number of simulations of the cell with different input currents, + and generates the following metrics/graphs: + + - sub-threshold potentials for all currents + - F-I curve for the cell + - membrane potential traces for each stimulus + + Using the method arguments, these graphs and the data they are generated + from may be enabled/disabled/saved. + + When the I-F curve plotting is enabled, it also notes the spiking threshold + current value in a file. Note that this value is simply the lowest input + stimulus current at which spiking was detected, so should be taken as an + approximate value. It does not, for example, implement a bisection based + method to find the accurate spiking threshold current. This is also true + for the I-F curves: the resolution is constrained by the values of the + stimulus currents. + + The various plotting related arguments to this method are passed on to + :py:method`pynml.generate_plot` :param nml2_file: name of NeuroML file containing cell definition :type nml2_file: str @@ -303,6 +324,10 @@ def generate_current_vs_frequency_curve( volts_labels = [] if_results = {} iv_results = {} + + # arbitrarily large value to start with + spike_threshold_current = float(math.inf) + for i in range(number_cells): t = np.array(results["t"]) * 1000 v = np.array(results["%s[%i]/v" % (pop.id, i)]) * 1000 @@ -315,6 +340,7 @@ def generate_current_vs_frequency_curve( mm = max_min(v, t, delta=0, peak_threshold=spike_threshold_mV) spike_times = mm["maxima_times"] freq = 0.0 + if len(spike_times) > 2: count = 0 for s in spike_times: @@ -323,6 +349,9 @@ def generate_current_vs_frequency_curve( ): count += 1 freq = 1000 * count / float(analysis_duration) + if count > 0: + if stims[i] < spike_threshold_current: + spike_threshold_current = stims[i] mean_freq = mean_spike_frequency(spike_times) logger.debug( @@ -399,6 +428,8 @@ def generate_current_vs_frequency_curve( with open(save_if_data_to, "w") as if_file: for i in range(len(stims_pA)): if_file.write("%s\t%s\n" % (stims_pA[i], freqs[i])) + with open(f"threshold_i_{save_if_data_to}", "w") as if_file: + print(spike_threshold_current, file=if_file) if plot_iv: stims = sorted(iv_results.keys()) stims_pA = [ii * 1000 for ii in sorted(iv_results.keys())] diff --git a/pyneuroml/plot/Plot.py b/pyneuroml/plot/Plot.py index 9387aa6a..09a352a7 100644 --- a/pyneuroml/plot/Plot.py +++ b/pyneuroml/plot/Plot.py @@ -12,7 +12,6 @@ import typing import matplotlib import matplotlib.axes -import plotly.graph_objects as go logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) @@ -374,6 +373,7 @@ def generate_interactive_plot( Note: you can also save the file from the interactive web page. :type save_figure_to: str """ + import plotly.graph_objects as go fig = go.Figure() if len(xvalues) != len(yvalues): diff --git a/pyneuroml/plot/PlotMorphology.py b/pyneuroml/plot/PlotMorphology.py index d65c0612..61651e86 100644 --- a/pyneuroml/plot/PlotMorphology.py +++ b/pyneuroml/plot/PlotMorphology.py @@ -16,28 +16,24 @@ import typing import logging -from vispy import app, scene import numpy import matplotlib from matplotlib import pyplot as plt -import plotly.graph_objects as go from pyneuroml.pynml import read_neuroml2_file from pyneuroml.utils.cli import build_namespace from pyneuroml.utils import extract_position_info from pyneuroml.utils.plot import ( add_text_to_matplotlib_2D_plot, - add_text_to_vispy_3D_plot, get_next_hex_color, add_box_to_matplotlib_2D_plot, get_new_matplotlib_morph_plot, autoscale_matplotlib_plot, add_scalebar_to_matplotlib_plot, add_line_to_matplotlib_2D_plot, - create_new_vispy_canvas, - get_cell_bound_box, + DEFAULTS, ) -from neuroml import SegmentGroup, Cell, Segment +from neuroml import SegmentGroup, Cell, Segment, NeuroMLDocument from neuroml.neuro_lex_ids import neuro_lex_ids @@ -45,19 +41,6 @@ logger.setLevel(logging.INFO) -DEFAULTS = { - "v": False, - "nogui": False, - "saveToFile": None, - "interactive3d": False, - "plane2d": "xy", - "minWidth": 0.8, - "square": False, - "plotType": "constant", - "theme": "light", -} - - def process_args(): """ Parse command-line arguments. @@ -156,6 +139,8 @@ def plot_from_console(a: typing.Optional[typing.Any] = None, **kwargs: str): a = build_namespace(DEFAULTS, a, **kwargs) print(a) if a.interactive3d: + from pyneuroml.plot.PlotMorphologyVispy import plot_interactive_3D + plot_interactive_3D( nml_file=a.nml_file, min_width=a.min_width, @@ -176,199 +161,10 @@ def plot_from_console(a: typing.Optional[typing.Any] = None, **kwargs: str): ) -def plot_interactive_3D( - nml_file: str, - min_width: float = DEFAULTS["minWidth"], - verbose: bool = False, - plot_type: str = "constant", - title: typing.Optional[str] = None, - theme: str = "light", - nogui: bool = False, -): - """Plot interactive plots in 3D using Vispy - - https://vispy.org - - :param nml_file: path to NeuroML cell file - :type nml_file: str - :param min_width: minimum width for segments (useful for visualising very - thin segments): default 0.8um - :type min_width: float - :param verbose: show extra information (default: False) - :type verbose: bool - :param plot_type: type of plot, one of: - - - "detailed": show detailed morphology taking into account each segment's - width - - "constant": show morphology, but use constant line widths - - "schematic": only plot each unbranched segment group as a straight - line, not following each segment - - This is only applicable for neuroml.Cell cells (ones with some - morphology) - - :type plot_type: str - :param title: title of plot - :type title: str - :param theme: theme to use (light/dark) - :type theme: str - :param nogui: toggle showing gui (for testing only) - :type nogui: bool - """ - if plot_type not in ["detailed", "constant", "schematic"]: - raise ValueError( - "plot_type must be one of 'detailed', 'constant', or 'schematic'" - ) - - if verbose: - print(f"Plotting {nml_file}") - - nml_model = read_neuroml2_file( - nml_file, - include_includes=True, - check_validity_pre_include=False, - verbose=False, - optimized=True, - ) - - ( - cell_id_vs_cell, - pop_id_vs_cell, - positions, - pop_id_vs_color, - pop_id_vs_radii, - ) = extract_position_info(nml_model, verbose) - - # Collect all markers and only plot one markers object - # this is more efficient than multiple markers, one for each point. - # TODO: also collect all line points and only use one object rather than a - # new object for each cell: will only work for the case where all lines - # have the same width - marker_sizes = [] - marker_points = [] - marker_colors = [] - - if title is None: - title = f"{nml_model.networks[0].id} from {nml_file}" - - logger.debug(f"positions: {positions}") - logger.debug(f"pop_id_vs_cell: {pop_id_vs_cell}") - logger.debug(f"cell_id_vs_cell: {cell_id_vs_cell}") - logger.debug(f"pop_id_vs_color: {pop_id_vs_color}") - logger.debug(f"pop_id_vs_radii: {pop_id_vs_radii}") - - if len(positions.keys()) > 1: - only_pos = [] - for posdict in positions.values(): - for poss in posdict.values(): - only_pos.append(poss) - - pos_array = numpy.array(only_pos) - center = numpy.array( - [ - numpy.mean(pos_array[:, 0]), - numpy.mean(pos_array[:, 1]), - numpy.mean(pos_array[:, 2]), - ] - ) - x_min = numpy.min(pos_array[:, 0]) - x_max = numpy.max(pos_array[:, 0]) - x_len = abs(x_max - x_min) - - y_min = numpy.min(pos_array[:, 1]) - y_max = numpy.max(pos_array[:, 1]) - y_len = abs(y_max - y_min) - - z_min = numpy.min(pos_array[:, 2]) - z_max = numpy.max(pos_array[:, 2]) - z_len = abs(z_max - z_min) - - view_min = center - numpy.array([x_len, y_len, z_len]) - view_max = center + numpy.array([x_len, y_len, z_len]) - logger.debug(f"center, view_min, max are {center}, {view_min}, {view_max}") - - else: - cell = list(pop_id_vs_cell.values())[0] - if cell is not None: - view_min, view_max = get_cell_bound_box(cell) - else: - logger.debug("Got a point cell") - pos = list((list(positions.values())[0]).values())[0] - view_min = list(numpy.array(pos)) - view_min = list(numpy.array(pos)) - - current_scene, current_view = create_new_vispy_canvas( - view_min, view_max, title, theme=theme - ) - - logger.debug(f"figure extents are: {view_min}, {view_max}") - - for pop_id in pop_id_vs_cell: - cell = pop_id_vs_cell[pop_id] - pos_pop = positions[pop_id] - - for cell_index in pos_pop: - pos = pos_pop[cell_index] - radius = pop_id_vs_radii[pop_id] if pop_id in pop_id_vs_radii else 10 - color = pop_id_vs_color[pop_id] if pop_id in pop_id_vs_color else None - - try: - logging.info(f"Plotting {cell.id}") - except AttributeError: - logging.info(f"Plotting a point cell at {pos}") - - if cell is None: - marker_points.extend([pos]) - marker_sizes.extend([radius]) - marker_colors.extend([color]) - else: - if plot_type == "schematic": - plot_3D_schematic( - offset=pos, - cell=cell, - segment_groups=None, - labels=True, - verbose=verbose, - current_scene=current_scene, - current_view=current_view, - nogui=True, - ) - else: - pts, sizes, colors = plot_3D_cell_morphology( - offset=pos, - cell=cell, - color=color, - plot_type=plot_type, - verbose=verbose, - current_scene=current_scene, - current_view=current_view, - min_width=min_width, - nogui=True, - ) - marker_points.extend(pts) - marker_sizes.extend(sizes) - marker_colors.extend(colors) - - if len(marker_points) > 0: - scene.Markers( - pos=numpy.array(marker_points), - size=numpy.array(marker_sizes), - spherical=True, - face_color=marker_colors, - edge_color=marker_colors, - edge_width=0, - parent=current_view.scene, - scaling=True, - antialias=0, - ) - if not nogui: - app.run() - - def plot_2D( - nml_file: str, + nml_file: typing.Union[str, NeuroMLDocument, Cell], plane2d: str = "xy", - min_width: float = DEFAULTS["minWidth"], + min_width: float = DEFAULTS["minWidth"], # noqa verbose: bool = False, nogui: bool = False, save_to_file: typing.Optional[str] = None, @@ -386,8 +182,9 @@ def plot_2D( This method uses matplotlib. - :param nml_file: path to NeuroML cell file - :type nml_file: str + :param nml_file: path to NeuroML cell file, or a NeuroMLDocument object + :type nml_file: str or :py:class:`neuroml.NeuroMLDocument` or + :py:class:`neuroml.Cell` :param plane2d: what plane to plot (xy/yx/yz/zy/zx/xz) :type plane2d: str :param min_width: minimum width for segments (useful for visualising very @@ -427,13 +224,25 @@ def plot_2D( if verbose: print("Plotting %s" % nml_file) - nml_model = read_neuroml2_file( - nml_file, - include_includes=True, - check_validity_pre_include=False, - verbose=False, - optimized=True, - ) + if type(nml_file) == str: + nml_model = read_neuroml2_file( + nml_file, + include_includes=True, + check_validity_pre_include=False, + verbose=False, + optimized=True, + ) + + elif isinstance(nml_file, Cell): + nml_model = NeuroMLDocument(id="newdoc") + nml_model.add(nml_file) + + elif isinstance(nml_file, NeuroMLDocument): + nml_model = nml_file + else: + raise TypeError( + "Passed model is not a NeuroML file path, nor a neuroml.Cell, nor a neuroml.NeuroMLDocument" + ) ( cell_id_vs_cell, @@ -444,7 +253,10 @@ def plot_2D( ) = extract_position_info(nml_model, verbose) if title is None: - title = "2D plot of %s from %s" % (nml_model.networks[0].id, nml_file) + if len(nml_model.networks) > 0: + title = "2D plot of %s from %s" % (nml_model.networks[0].id, nml_file) + else: + title = "2D plot of %s" % (nml_model.cells[0].id) if verbose: logger.debug(f"positions: {positions}") @@ -527,118 +339,6 @@ def plot_2D( plt.close() -def plot_3D_cell_morphology_plotly( - nml_file: str, - min_width: float = 0.8, - verbose: bool = False, - nogui: bool = False, - save_to_file: typing.Optional[str] = None, - plot_type: str = "detailed", -): - """Plot NeuroML2 cell morphology interactively using Plot.ly - - Please note that the interactive plot uses Plotly, which uses WebGL. So, - you need to use a WebGL enabled browser, and performance here may be - hardware dependent. - - https://plotly.com/python/webgl-vs-svg/ - https://en.wikipedia.org/wiki/WebGL - - :param nml_file: path to NeuroML cell file - :type nml_file: str - :param min_width: minimum width for segments (useful for visualising very - thin segments): default 0.8um - :type min_width: float - :param verbose: show extra information (default: False) - :type verbose: bool - :param nogui: do not show matplotlib GUI (default: false) - :type nogui: bool - :param save_to_file: optional filename to save generated morphology to - :type save_to_file: str - :param plot_type: type of plot, one of: - - - detailed: show detailed morphology taking into account each segment's - width - - constant: show morphology, but use constant line widths - - :type plot_type: str - """ - if plot_type not in ["detailed", "constant"]: - raise ValueError( - "plot_type must be one of 'detailed', 'constant', or 'schematic'" - ) - - nml_model = read_neuroml2_file(nml_file) - - fig = go.Figure() - for cell in nml_model.cells: - title = f"3D plot of {cell.id} from {nml_file}" - - for seg in cell.morphology.segments: - p = cell.get_actual_proximal(seg.id) - d = seg.distal - if verbose: - print( - f"\nSegment {seg.name}, id: {seg.id} has proximal point: {p}, distal: {d}" - ) - width = max(p.diameter, d.diameter) - if width < min_width: - width = min_width - if plot_type == "constant": - width = min_width - fig.add_trace( - go.Scatter3d( - x=[p.x, d.x], - y=[p.y, d.y], - z=[p.z, d.z], - name=None, - marker={"size": 2, "color": "blue"}, - line={"width": width, "color": "blue"}, - mode="lines", - showlegend=False, - hoverinfo="skip", - ) - ) - - fig.update_layout( - title={"text": title}, - hovermode=False, - plot_bgcolor="white", - scene=dict( - xaxis=dict( - backgroundcolor="white", - showbackground=False, - showgrid=False, - showspikes=False, - title=dict(text="extent (um)"), - ), - yaxis=dict( - backgroundcolor="white", - showbackground=False, - showgrid=False, - showspikes=False, - title=dict(text="extent (um)"), - ), - zaxis=dict( - backgroundcolor="white", - showbackground=False, - showgrid=False, - showspikes=False, - title=dict(text="extent (um)"), - ), - ), - ) - if not nogui: - fig.show() - if save_to_file: - logger.info( - "Saving image to %s of plot: %s" - % (os.path.abspath(save_to_file), title) - ) - fig.write_image(save_to_file, scale=2, width=1024, height=768) - logger.info("Saved image to %s of plot: %s" % (save_to_file, title)) - - def plot_2D_cell_morphology( offset: typing.List[float] = [0, 0], cell: Cell = None, @@ -657,7 +357,7 @@ def plot_2D_cell_morphology( plot_type: str = "detailed", save_to_file: typing.Optional[str] = None, close_plot: bool = False, - overlay_data: typing.Dict[int, float] = None, + overlay_data: typing.Optional[typing.Dict[int, float]] = None, overlay_data_label: typing.Optional[str] = None, datamin: typing.Optional[float] = None, datamax: typing.Optional[float] = None, @@ -899,224 +599,6 @@ def plot_2D_cell_morphology( plt.close() -def plot_3D_cell_morphology( - offset: typing.List[float] = [0, 0, 0], - cell: Cell = None, - color: typing.Optional[str] = None, - title: str = "", - verbose: bool = False, - current_scene: scene.SceneCanvas = None, - current_view: scene.ViewBox = None, - min_width: float = DEFAULTS["minWidth"], - axis_min_max: typing.List = [float("inf"), -1 * float("inf")], - nogui: bool = True, - plot_type: str = "constant", - theme="light", -): - """Plot the detailed 3D morphology of a cell using vispy. - https://vispy.org/ - - .. versionadded:: 1.0.0 - - .. seealso:: - - :py:func:`plot_2D` - general function for plotting - - :py:func:`plot_2D_schematic` - for plotting only segmeng groups with their labels - - :py:func:`plot_2D_point_cells` - for plotting point cells - - :param offset: offset for cell - :type offset: [float, float] - :param cell: cell to plot - :type cell: neuroml.Cell - :param color: color to use for segments: - - - if None, each segment is given a new unique color - - if "Groups", each unbranched segment group is given a unique color, - and segments that do not belong to an unbranched segment group are in - white - - if "Default Groups", axonal segments are in red, dendritic in blue, - somatic in green, and others in white - - :type color: str - :param min_width: minimum width for segments (useful for visualising very - :type min_width: float - :param axis_min_max: min, max value of axes - :type axis_min_max: [float, float] - :param title: title of plot - :type title: str - :param verbose: show extra information (default: False) - :type verbose: bool - :param nogui: do not show image immediately - :type nogui: bool - :param current_scene: vispy SceneCanvas to use (a new one is created if it is not - provided) - :type current_scene: SceneCanvas - :param current_view: vispy viewbox to use - :type current_view: ViewBox - :param plot_type: type of plot, one of: - - - "detailed": show detailed morphology taking into account each segment's - width. This is not performant, because a new line is required for - each segment. To only be used for cells with small numbers of - segments - - "constant": show morphology, but use constant line widths - - This is only applicable for neuroml.Cell cells (ones with some - morphology) - - :type plot_type: str - :param theme: theme to use (dark/light) - :type theme: str - :raises: ValueError if `cell` is None - - """ - if cell is None: - raise ValueError( - "No cell provided. If you would like to plot a network of point neurons, consider using `plot_2D_point_cells` instead" - ) - - try: - soma_segs = cell.get_all_segments_in_group("soma_group") - except Exception: - soma_segs = [] - try: - dend_segs = cell.get_all_segments_in_group("dendrite_group") - except Exception: - dend_segs = [] - try: - axon_segs = cell.get_all_segments_in_group("axon_group") - except Exception: - axon_segs = [] - - if current_scene is None or current_view is None: - view_min, view_max = get_cell_bound_box(cell) - current_scene, current_view = create_new_vispy_canvas( - view_min, view_max, title, theme=theme - ) - - if color == "Groups": - color_dict = {} - # if no segment groups are given, do them all - segment_groups = [] - for sg in cell.morphology.segment_groups: - if sg.neuro_lex_id == neuro_lex_ids["section"]: - segment_groups.append(sg.id) - - ord_segs = cell.get_ordered_segments_in_groups( - segment_groups, check_parentage=False - ) - - for sgs, segs in ord_segs.items(): - c = get_next_hex_color() - for s in segs: - color_dict[s.id] = c - - # for lines/segments - points = [] - toconnect = [] - colors = [] - # for any spheres which we plot as markers at once - marker_points = [] - marker_colors = [] - marker_sizes = [] - - for seg in cell.morphology.segments: - p = cell.get_actual_proximal(seg.id) - d = seg.distal - width = (p.diameter + d.diameter) / 2 - - if width < min_width: - width = min_width - - if plot_type == "constant": - width = min_width - - seg_color = "white" - if color is None: - seg_color = get_next_hex_color() - elif color == "Groups": - try: - seg_color = color_dict[seg.id] - except KeyError: - print(f"Unbranched segment found: {seg.id}") - if seg.id in soma_segs: - seg_color = "green" - elif seg.id in axon_segs: - seg_color = "red" - elif seg.id in dend_segs: - seg_color = "blue" - elif color == "Default Groups": - if seg.id in soma_segs: - seg_color = "green" - elif seg.id in axon_segs: - seg_color = "red" - elif seg.id in dend_segs: - seg_color = "blue" - else: - seg_color = color - - # check if for a spherical segment, add extra spherical node - if p.x == d.x and p.y == d.y and p.z == d.z and p.diameter == d.diameter: - marker_points.append([offset[0] + p.x, offset[1] + p.y, offset[2] + p.z]) - marker_colors.append(seg_color) - marker_sizes.append(p.diameter) - - if plot_type == "constant": - points.append([offset[0] + p.x, offset[1] + p.y, offset[2] + p.z]) - colors.append(seg_color) - points.append([offset[0] + d.x, offset[1] + d.y, offset[2] + d.z]) - colors.append(seg_color) - toconnect.append([len(points) - 2, len(points) - 1]) - # every segment plotted individually - elif plot_type == "detailed": - points = [] - toconnect = [] - colors = [] - points.append([offset[0] + p.x, offset[1] + p.y, offset[2] + p.z]) - colors.append(seg_color) - points.append([offset[0] + d.x, offset[1] + d.y, offset[2] + d.z]) - colors.append(seg_color) - toconnect.append([len(points) - 2, len(points) - 1]) - scene.Line( - pos=points, - color=colors, - connect=numpy.array(toconnect), - parent=current_view.scene, - width=width, - ) - - if plot_type == "constant": - scene.Line( - pos=points, - color=colors, - connect=numpy.array(toconnect), - parent=current_view.scene, - width=width, - ) - - if not nogui: - # markers - if len(marker_points) > 0: - scene.Markers( - pos=numpy.array(marker_points), - size=numpy.array(marker_sizes), - spherical=True, - face_color=marker_colors, - edge_color=marker_colors, - edge_width=0, - parent=current_view.scene, - scaling=True, - antialias=0, - ) - app.run() - return marker_points, marker_sizes, marker_colors - - def plot_2D_point_cells( offset: typing.List[float] = [0, 0], plane2d: str = "xy", @@ -1516,159 +998,6 @@ def plot_2D_schematic( plt.close() -def plot_3D_schematic( - cell: Cell, - segment_groups: typing.Optional[typing.List[SegmentGroup]], - offset: typing.List[float] = [0, 0, 0], - labels: bool = False, - width: float = 5.0, - verbose: bool = False, - nogui: bool = False, - title: str = "", - current_scene: scene.SceneCanvas = None, - current_view: scene.ViewBox = None, - theme: str = "light", -) -> None: - """Plot a 3D schematic of the provided segment groups in Napari as a new - layer.. - - This plots each segment group as a straight line between its first and last - segment. - - .. versionadded:: 1.0.0 - - .. seealso:: - - :py:func:`plot_2D_schematic` - general function for plotting - - :py:func:`plot_2D` - general function for plotting - - :py:func:`plot_2D_point_cells` - for plotting point cells - - :py:func:`plot_2D_cell_morphology` - for plotting cells with detailed morphologies - - :param offset: offset for cell - :type offset: [float, float, float] - :param cell: cell to plot - :type cell: neuroml.Cell - :param segment_groups: list of unbranched segment groups to plot - :type segment_groups: list(SegmentGroup) - :param labels: toggle labelling of segment groups - :type labels: bool - :param width: width for lines for segment groups - :type width: float - :param verbose: show extra information (default: False) - :type verbose: bool - :param title: title of plot - :type title: str - :param nogui: toggle if plot should be shown or not - :type nogui: bool - :param current_scene: vispy SceneCanvas to use (a new one is created if it is not - provided) - :type current_scene: SceneCanvas - :param current_view: vispy viewbox to use - :type current_view: ViewBox - :param theme: theme to use (light/dark) - :type theme: str - """ - if title == "": - title = f"3D schematic of segment groups from {cell.id}" - - # if no segment groups are given, do them all - if segment_groups is None: - segment_groups = [] - for sg in cell.morphology.segment_groups: - if sg.neuro_lex_id == neuro_lex_ids["section"]: - segment_groups.append(sg.id) - - ord_segs = cell.get_ordered_segments_in_groups( - segment_groups, check_parentage=False - ) - - # if no canvas is defined, define a new one - if current_scene is None or current_view is None: - view_min, view_max = get_cell_bound_box(cell) - current_scene, current_view = create_new_vispy_canvas( - view_min, view_max, title, theme=theme - ) - - points = [] - toconnect = [] - colors = [] - text = [] - textpoints = [] - - for sgid, segs in ord_segs.items(): - sgobj = cell.get_segment_group(sgid) - if sgobj.neuro_lex_id != neuro_lex_ids["section"]: - raise ValueError( - f"{sgobj} does not have neuro_lex_id set to indicate it is an unbranched segment" - ) - - # get proximal and distal points - first_seg = segs[0] # type: Segment - last_seg = segs[-1] # type: Segment - first_prox = cell.get_actual_proximal(first_seg.id) - - points.append( - [ - offset[0] + first_prox.x, - offset[1] + first_prox.y, - offset[2] + first_prox.z, - ] - ) - points.append( - [ - offset[0] + last_seg.distal.x, - offset[1] + last_seg.distal.y, - offset[2] + last_seg.distal.z, - ] - ) - colors.append(get_next_hex_color()) - colors.append(get_next_hex_color()) - toconnect.append([len(points) - 2, len(points) - 1]) - - # TODO: needs fixing to show labels - if labels: - text.append(f"{sgid}") - textpoints.append( - [ - offset[0] + (first_prox.x + last_seg.distal.x) / 2, - offset[1] + (first_prox.y + last_seg.distal.y) / 2, - offset[2] + (first_prox.z + last_seg.distal.z) / 2, - ] - ) - """ - - alabel = add_text_to_vispy_3D_plot(current_scene=current_view.scene, text=f"{sgid}", - xv=[offset[0] + first_seg.proximal.x, offset[0] + last_seg.distal.x], - yv=[offset[0] + first_seg.proximal.y, offset[0] + last_seg.distal.y], - zv=[offset[1] + first_seg.proximal.z, offset[1] + last_seg.distal.z], - color=colors[-1]) - alabel.font_size = 30 - """ - - scene.Line( - points, - parent=current_view.scene, - color=colors, - width=width, - connect=numpy.array(toconnect), - ) - if labels: - print("Text rendering") - scene.Text( - text, pos=textpoints, font_size=30, color="black", parent=current_view.scene - ) - - if not nogui: - app.run() - - def plot_segment_groups_curtain_plots( cell: Cell, segment_groups: typing.List[SegmentGroup], @@ -1676,7 +1005,7 @@ def plot_segment_groups_curtain_plots( verbose: bool = False, nogui: bool = False, save_to_file: typing.Optional[str] = None, - overlay_data: typing.Dict[str, typing.List[typing.Any]] = None, + overlay_data: typing.Optional[typing.Dict[str, typing.List[typing.Any]]] = None, overlay_data_label: str = "", width: typing.Union[float, int] = 4, colormap_name: str = "viridis", diff --git a/pyneuroml/plot/PlotMorphologyPlotly.py b/pyneuroml/plot/PlotMorphologyPlotly.py new file mode 100644 index 00000000..85fbe8d9 --- /dev/null +++ b/pyneuroml/plot/PlotMorphologyPlotly.py @@ -0,0 +1,147 @@ +#!/usr/bin/env python3 +""" +Methods to plot morphology using Plot.ly + +Note: the vispy methods are more performant. + +File: pyneuroml/plot/PlotMorphologyPlotly.py + +Copyright 2023 NeuroML contributors +""" + +import os +import typing +import logging +import plotly.graph_objects as go +from neuroml import Cell, NeuroMLDocument +from pyneuroml.pynml import read_neuroml2_file + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + + +def plot_3D_cell_morphology_plotly( + nml_file: typing.Union[str, Cell, NeuroMLDocument], + min_width: float = 0.8, + verbose: bool = False, + nogui: bool = False, + save_to_file: typing.Optional[str] = None, + plot_type: str = "detailed", +): + """Plot NeuroML2 cell morphology interactively using Plot.ly + + Please note that the interactive plot uses Plotly, which uses WebGL. So, + you need to use a WebGL enabled browser, and performance here may be + hardware dependent. + + https://plotly.com/python/webgl-vs-svg/ + https://en.wikipedia.org/wiki/WebGL + + :param nml_file: path to NeuroML cell file, or a + :py:class:`neuroml.NeuroMLDocument` or :py:class:`neuroml.Cell` object + :type nml_file: str or neuroml.NeuroMLDocument or neuroml.Cell + :param min_width: minimum width for segments (useful for visualising very + thin segments): default 0.8um + :type min_width: float + :param verbose: show extra information (default: False) + :type verbose: bool + :param nogui: do not show matplotlib GUI (default: false) + :type nogui: bool + :param save_to_file: optional filename to save generated morphology to + :type save_to_file: str + :param plot_type: type of plot, one of: + + - detailed: show detailed morphology taking into account each segment's + width + - constant: show morphology, but use constant line widths + + :type plot_type: str + """ + if plot_type not in ["detailed", "constant"]: + raise ValueError( + "plot_type must be one of 'detailed', 'constant', or 'schematic'" + ) + + if type(nml_file) == str: + nml_model = read_neuroml2_file( + nml_file, + include_includes=True, + check_validity_pre_include=False, + verbose=False, + optimized=True, + ) + elif isinstance(nml_file, Cell): + nml_model = NeuroMLDocument(id="newdoc") + nml_model.add(nml_file) + elif isinstance(nml_file, NeuroMLDocument): + nml_model = nml_file + else: + raise TypeError("Passed model is not a NeuroML file path, nor a neuroml.Cell, nor a neuroml.NeuroMLDocument") + + fig = go.Figure() + for cell in nml_model.cells: + title = f"3D plot of {cell.id} from {nml_file}" + + for seg in cell.morphology.segments: + p = cell.get_actual_proximal(seg.id) + d = seg.distal + if verbose: + print( + f"\nSegment {seg.name}, id: {seg.id} has proximal point: {p}, distal: {d}" + ) + width = max(p.diameter, d.diameter) + if width < min_width: + width = min_width + if plot_type == "constant": + width = min_width + fig.add_trace( + go.Scatter3d( + x=[p.x, d.x], + y=[p.y, d.y], + z=[p.z, d.z], + name=None, + marker={"size": 2, "color": "blue"}, + line={"width": width, "color": "blue"}, + mode="lines", + showlegend=False, + hoverinfo="skip", + ) + ) + + fig.update_layout( + title={"text": title}, + hovermode=False, + plot_bgcolor="white", + scene=dict( + xaxis=dict( + backgroundcolor="white", + showbackground=False, + showgrid=False, + showspikes=False, + title=dict(text="extent (um)"), + ), + yaxis=dict( + backgroundcolor="white", + showbackground=False, + showgrid=False, + showspikes=False, + title=dict(text="extent (um)"), + ), + zaxis=dict( + backgroundcolor="white", + showbackground=False, + showgrid=False, + showspikes=False, + title=dict(text="extent (um)"), + ), + ), + ) + if not nogui: + fig.show() + if save_to_file: + logger.info( + "Saving image to %s of plot: %s" + % (os.path.abspath(save_to_file), title) + ) + fig.write_image(save_to_file, scale=2, width=1024, height=768) + logger.info("Saved image to %s of plot: %s" % (save_to_file, title)) diff --git a/pyneuroml/plot/PlotMorphologyVispy.py b/pyneuroml/plot/PlotMorphologyVispy.py new file mode 100644 index 00000000..119415cc --- /dev/null +++ b/pyneuroml/plot/PlotMorphologyVispy.py @@ -0,0 +1,855 @@ +#!/usr/bin/env python3 +""" +Vispy interactive plotting. + +Kept in a separate file so that the vispy dependency is not required elsewhere. + +File: pyneuroml/plot/PlotMorphologyVispy.py + +Copyright 2023 NeuroML contributors +""" + + +import logging +import typing +import numpy +import textwrap +from vispy import scene, app + +from pyneuroml.utils.plot import ( + DEFAULTS, + get_cell_bound_box, + get_next_hex_color, +) +from pyneuroml.pynml import read_neuroml2_file +from pyneuroml.utils import extract_position_info + +from neuroml import Cell, NeuroMLDocument, SegmentGroup, Segment +from neuroml.neuro_lex_ids import neuro_lex_ids + + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + + +VISPY_THEME = { + "light": {"bg": "white", "fg": "black"}, + "dark": {"bg": "black", "fg": "white"}, +} +PYNEUROML_VISPY_THEME = "light" + + +def add_text_to_vispy_3D_plot( + current_scene: scene.SceneCanvas, + xv: typing.List[float], + yv: typing.List[float], + zv: typing.List[float], + color: str, + text: str, +): + """Add text to a vispy plot between two points. + + Wrapper around vispy.scene.visuals.Text + + Rotates the text label to ensure it is at the same angle as the line. + + :param scene: vispy scene object + :type scene: scene.SceneCanvas + :param xv: start and end coordinates in one axis + :type xv: list[x1, x2] + :param yv: start and end coordinates in second axis + :type yv: list[y1, y2] + :param zv: start and end coordinates in third axix + :type zv: list[z1, z2] + :param color: color of text + :type color: str + :param text: text to write + :type text: str + """ + + angle = int(numpy.rad2deg(numpy.arctan2((yv[1] - yv[0]), (xv[1] - xv[0])))) + if angle > 90: + angle -= 180 + elif angle < -90: + angle += 180 + + return scene.Text( + pos=((xv[0] + xv[1]) / 2, (yv[0] + yv[1]) / 2, (zv[0] + zv[1]) / 2), + text=text, + color=color, + rotation=angle, + parent=current_scene, + ) + + +def create_new_vispy_canvas( + view_min: typing.Optional[typing.List[float]] = None, + view_max: typing.Optional[typing.List[float]] = None, + title: str = "", + console_font_size: float = 10, + axes_pos: typing.Optional[typing.List] = None, + axes_length: float = 100, + axes_width: int = 2, + theme=PYNEUROML_VISPY_THEME, +): + """Create a new vispy scene canvas with a view and optional axes lines + + Reference: https://vispy.org/gallery/scene/axes_plot.html + + :param view_min: min view co-ordinates + :type view_min: [float, float, float] + :param view_max: max view co-ordinates + :type view_max: [float, float, float] + :param title: title of plot + :type title: str + :param axes_pos: position to draw axes at + :type axes_pos: [float, float, float] + :param axes_length: length of axes + :type axes_length: float + :param axes_width: width of axes lines + :type axes_width: float + :returns: scene, view + """ + canvas = scene.SceneCanvas( + keys="interactive", + show=True, + bgcolor=VISPY_THEME[theme]["bg"], + size=(800, 600), + title="NeuroML viewer (VisPy)", + ) + grid = canvas.central_widget.add_grid(margin=10) + grid.spacing = 0 + + title_widget = scene.Label(title, color=VISPY_THEME[theme]["fg"]) + title_widget.height_max = 80 + grid.add_widget(title_widget, row=0, col=0, col_span=1) + + console_widget = scene.Console( + text_color=VISPY_THEME[theme]["fg"], + font_size=console_font_size, + ) + console_widget.height_max = 80 + grid.add_widget(console_widget, row=2, col=0, col_span=1) + + bottom_padding = grid.add_widget(row=3, col=0, col_span=1) + bottom_padding.height_max = 10 + + view = grid.add_view(row=1, col=0, border_color=None) + + # create cameras + # https://vispy.org/gallery/scene/flipped_axis.html + cam1 = scene.cameras.PanZoomCamera(parent=view.scene, name="PanZoom") + + cam2 = scene.cameras.TurntableCamera(parent=view.scene, name="Turntable") + + cam3 = scene.cameras.ArcballCamera(parent=view.scene, name="Arcball") + + cam4 = scene.cameras.FlyCamera(parent=view.scene, name="Fly") + # do not keep z up + cam4.autoroll = False + + cams = [cam4, cam2] + + # console text + console_text = "Controls: reset view: 0; quit: Esc/9" + if len(cams) > 1: + console_text += "; cycle camera: 1, 2 (fwd/bwd)" + + cam_text = { + cam1: textwrap.dedent( + """ + Left mouse button: pans view; right mouse button or scroll: + zooms""" + ), + cam2: textwrap.dedent( + """ + Left mouse button: orbits view around center point; right mouse + button or scroll: change zoom level; Shift + left mouse button: + translate center point; Shift + right mouse button: change field of + view; r/R: view rotation animation""" + ), + cam3: textwrap.dedent( + """ + Left mouse button: orbits view around center point; right + mouse button or scroll: change zoom level; Shift + left mouse + button: translate center point; Shift + right mouse button: change + field of view""" + ), + cam4: textwrap.dedent( + """ + Arrow keys/WASD to move forward/backwards/left/right; F/C to move + up and down; Space to brake; Left mouse button/I/K/J/L to control + pitch and yaw; Q/E for rolling""" + ), + } + + # Turntable is default + cam_index = 1 + view.camera = cams[cam_index] + + if view_min is not None and view_max is not None: + view_center = (numpy.array(view_max) + numpy.array(view_min)) / 2 + logger.debug(f"Center is {view_center}") + cam1.center = [view_center[0], view_center[1]] + cam2.center = view_center + cam3.center = view_center + cam4.center = view_center + + for acam in cams: + x_width = abs(view_min[0] - view_max[0]) + y_width = abs(view_min[1] - view_max[1]) + z_width = abs(view_min[2] - view_max[2]) + + xrange = ( + (view_min[0] - x_width * 0.02, view_max[0] + x_width * 0.02) + if x_width > 0 + else (-100, 100) + ) + yrange = ( + (view_min[1] - y_width * 0.02, view_max[1] + y_width * 0.02) + if y_width > 0 + else (-100, 100) + ) + zrange = ( + (view_min[2] - z_width * 0.02, view_max[2] + z_width * 0.02) + if z_width > 0 + else (-100, 100) + ) + logger.debug(f"{xrange}, {yrange}, {zrange}") + + acam.set_range(x=xrange, y=yrange, z=zrange) + + for acam in cams: + acam.set_default_state() + + console_widget.write(f"Center: {view.camera.center}") + console_widget.write(console_text) + console_widget.write( + f"Current camera: {view.camera.name}: " + cam_text[view.camera].replace("\n", " ").strip() + ) + + if axes_pos: + points = [ + axes_pos, # origin + [axes_pos[0] + axes_length, axes_pos[1], axes_pos[2]], + [axes_pos[0], axes_pos[1] + axes_length, axes_pos[2]], + [axes_pos[0], axes_pos[1], axes_pos[2] + axes_length], + ] + scene.Line( + points, + connect=numpy.array([[0, 1], [0, 2], [0, 3]]), + parent=view.scene, + color=VISPY_THEME[theme]["fg"], + width=axes_width, + ) + + def vispy_rotate(self): + view.camera.orbit(azim=1, elev=0) + + rotation_timer = app.Timer(connect=vispy_rotate) + + @canvas.events.key_press.connect + def vispy_on_key_press(event): + nonlocal cam_index + + # Disable camera cycling. The fly camera looks sufficient. + # Keeping views/ranges same when switching cameras is not simple. + # Prev + if event.text == "1": + cam_index = (cam_index - 1) % len(cams) + view.camera = cams[cam_index] + # next + elif event.text == "2": + cam_index = (cam_index + 1) % len(cams) + view.camera = cams[cam_index] + # for turntable only: rotate animation + elif event.text == "R" or event.text == "r": + if view.camera == cam2: + if rotation_timer.running: + rotation_timer.stop() + else: + rotation_timer.start() + # reset + elif event.text == "0": + view.camera.reset() + # quit + elif event.text == "9": + canvas.app.quit() + + console_widget.clear() + # console_widget.write(f"Center: {view.camera.center}") + console_widget.write(console_text) + console_widget.write( + f"Current camera: {view.camera.name}: " + cam_text[view.camera].replace("\n", " ").strip() + ) + + return scene, view + + +def plot_interactive_3D( + nml_file: typing.Union[str, Cell, NeuroMLDocument], + min_width: float = DEFAULTS["minWidth"], + verbose: bool = False, + plot_type: str = "constant", + title: typing.Optional[str] = None, + theme: str = "light", + nogui: bool = False, +): + """Plot interactive plots in 3D using Vispy + + https://vispy.org + + :param nml_file: path to NeuroML cell file or + :py:class:`neuroml.NeuroMLDocument` or :py:class:`neuroml.Cell` object + :type nml_file: str or neuroml.NeuroMLDocument or neuroml.Cell + :param min_width: minimum width for segments (useful for visualising very + thin segments): default 0.8um + :type min_width: float + :param verbose: show extra information (default: False) + :type verbose: bool + :param plot_type: type of plot, one of: + + - "detailed": show detailed morphology taking into account each segment's + width + - "constant": show morphology, but use constant line widths + - "schematic": only plot each unbranched segment group as a straight + line, not following each segment + + This is only applicable for neuroml.Cell cells (ones with some + morphology) + + :type plot_type: str + :param title: title of plot + :type title: str + :param theme: theme to use (light/dark) + :type theme: str + :param nogui: toggle showing gui (for testing only) + :type nogui: bool + """ + if plot_type not in ["detailed", "constant", "schematic"]: + raise ValueError( + "plot_type must be one of 'detailed', 'constant', or 'schematic'" + ) + + if verbose: + print(f"Plotting {nml_file}") + + if type(nml_file) == str: + nml_model = read_neuroml2_file( + nml_file, + include_includes=True, + check_validity_pre_include=False, + verbose=False, + optimized=True, + ) + elif isinstance(nml_file, Cell): + nml_model = NeuroMLDocument(id="newdoc") + nml_model.add(nml_file) + elif isinstance(nml_file, NeuroMLDocument): + nml_model = nml_file + else: + raise TypeError("Passed model is not a NeuroML file path, nor a neuroml.Cell, nor a neuroml.NeuroMLDocument") + + ( + cell_id_vs_cell, + pop_id_vs_cell, + positions, + pop_id_vs_color, + pop_id_vs_radii, + ) = extract_position_info(nml_model, verbose) + + # Collect all markers and only plot one markers object + # this is more efficient than multiple markers, one for each point. + # TODO: also collect all line points and only use one object rather than a + # new object for each cell: will only work for the case where all lines + # have the same width + marker_sizes = [] + marker_points = [] + marker_colors = [] + + if title is None: + title = f"{nml_model.networks[0].id} from {nml_file}" + + logger.debug(f"positions: {positions}") + logger.debug(f"pop_id_vs_cell: {pop_id_vs_cell}") + logger.debug(f"cell_id_vs_cell: {cell_id_vs_cell}") + logger.debug(f"pop_id_vs_color: {pop_id_vs_color}") + logger.debug(f"pop_id_vs_radii: {pop_id_vs_radii}") + + if len(positions.keys()) > 1: + only_pos = [] + for posdict in positions.values(): + for poss in posdict.values(): + only_pos.append(poss) + + pos_array = numpy.array(only_pos) + center = numpy.array( + [ + numpy.mean(pos_array[:, 0]), + numpy.mean(pos_array[:, 1]), + numpy.mean(pos_array[:, 2]), + ] + ) + x_min = numpy.min(pos_array[:, 0]) + x_max = numpy.max(pos_array[:, 0]) + x_len = abs(x_max - x_min) + + y_min = numpy.min(pos_array[:, 1]) + y_max = numpy.max(pos_array[:, 1]) + y_len = abs(y_max - y_min) + + z_min = numpy.min(pos_array[:, 2]) + z_max = numpy.max(pos_array[:, 2]) + z_len = abs(z_max - z_min) + + view_min = center - numpy.array([x_len, y_len, z_len]) + view_max = center + numpy.array([x_len, y_len, z_len]) + logger.debug(f"center, view_min, max are {center}, {view_min}, {view_max}") + + else: + cell = list(pop_id_vs_cell.values())[0] + if cell is not None: + view_min, view_max = get_cell_bound_box(cell) + else: + logger.debug("Got a point cell") + pos = list((list(positions.values())[0]).values())[0] + view_min = list(numpy.array(pos)) + view_min = list(numpy.array(pos)) + + current_scene, current_view = create_new_vispy_canvas( + view_min, view_max, title, theme=theme + ) + + logger.debug(f"figure extents are: {view_min}, {view_max}") + + for pop_id in pop_id_vs_cell: + cell = pop_id_vs_cell[pop_id] + pos_pop = positions[pop_id] + + for cell_index in pos_pop: + pos = pos_pop[cell_index] + radius = pop_id_vs_radii[pop_id] if pop_id in pop_id_vs_radii else 10 + color = pop_id_vs_color[pop_id] if pop_id in pop_id_vs_color else None + + try: + logging.info(f"Plotting {cell.id}") + except AttributeError: + logging.info(f"Plotting a point cell at {pos}") + + if cell is None: + marker_points.extend([pos]) + marker_sizes.extend([radius]) + marker_colors.extend([color]) + else: + if plot_type == "schematic": + plot_3D_schematic( + offset=pos, + cell=cell, + segment_groups=None, + labels=True, + verbose=verbose, + current_scene=current_scene, + current_view=current_view, + nogui=True, + ) + else: + pts, sizes, colors = plot_3D_cell_morphology( + offset=pos, + cell=cell, + color=color, + plot_type=plot_type, + verbose=verbose, + current_scene=current_scene, + current_view=current_view, + min_width=min_width, + nogui=True, + ) + marker_points.extend(pts) + marker_sizes.extend(sizes) + marker_colors.extend(colors) + + if len(marker_points) > 0: + scene.Markers( + pos=numpy.array(marker_points), + size=numpy.array(marker_sizes), + spherical=True, + face_color=marker_colors, + edge_color=marker_colors, + edge_width=0, + parent=current_view.scene, + scaling=True, + antialias=0, + ) + if not nogui: + app.run() + + +def plot_3D_cell_morphology( + offset: typing.List[float] = [0, 0, 0], + cell: Cell = None, + color: typing.Optional[str] = None, + title: str = "", + verbose: bool = False, + current_scene: scene.SceneCanvas = None, + current_view: scene.ViewBox = None, + min_width: float = DEFAULTS["minWidth"], + axis_min_max: typing.List = [float("inf"), -1 * float("inf")], + nogui: bool = True, + plot_type: str = "constant", + theme="light", +): + """Plot the detailed 3D morphology of a cell using vispy. + https://vispy.org/ + + .. versionadded:: 1.0.0 + + .. seealso:: + + :py:func:`plot_2D` + general function for plotting + + :py:func:`plot_2D_schematic` + for plotting only segmeng groups with their labels + + :py:func:`plot_2D_point_cells` + for plotting point cells + + :param offset: offset for cell + :type offset: [float, float] + :param cell: cell to plot + :type cell: neuroml.Cell + :param color: color to use for segments: + + - if None, each segment is given a new unique color + - if "Groups", each unbranched segment group is given a unique color, + and segments that do not belong to an unbranched segment group are in + white + - if "Default Groups", axonal segments are in red, dendritic in blue, + somatic in green, and others in white + + :type color: str + :param min_width: minimum width for segments (useful for visualising very + :type min_width: float + :param axis_min_max: min, max value of axes + :type axis_min_max: [float, float] + :param title: title of plot + :type title: str + :param verbose: show extra information (default: False) + :type verbose: bool + :param nogui: do not show image immediately + :type nogui: bool + :param current_scene: vispy scene.SceneCanvas to use (a new one is created if it is not + provided) + :type current_scene: scene.SceneCanvas + :param current_view: vispy viewbox to use + :type current_view: ViewBox + :param plot_type: type of plot, one of: + + - "detailed": show detailed morphology taking into account each segment's + width. This is not performant, because a new line is required for + each segment. To only be used for cells with small numbers of + segments + - "constant": show morphology, but use constant line widths + + This is only applicable for neuroml.Cell cells (ones with some + morphology) + + :type plot_type: str + :param theme: theme to use (dark/light) + :type theme: str + :raises: ValueError if `cell` is None + + """ + if cell is None: + raise ValueError( + "No cell provided. If you would like to plot a network of point neurons, consider using `plot_2D_point_cells` instead" + ) + + try: + soma_segs = cell.get_all_segments_in_group("soma_group") + except Exception: + soma_segs = [] + try: + dend_segs = cell.get_all_segments_in_group("dendrite_group") + except Exception: + dend_segs = [] + try: + axon_segs = cell.get_all_segments_in_group("axon_group") + except Exception: + axon_segs = [] + + if current_scene is None or current_view is None: + view_min, view_max = get_cell_bound_box(cell) + current_scene, current_view = create_new_vispy_canvas( + view_min, view_max, title, theme=theme + ) + + if color == "Groups": + color_dict = {} + # if no segment groups are given, do them all + segment_groups = [] + for sg in cell.morphology.segment_groups: + if sg.neuro_lex_id == neuro_lex_ids["section"]: + segment_groups.append(sg.id) + + ord_segs = cell.get_ordered_segments_in_groups( + segment_groups, check_parentage=False + ) + + for sgs, segs in ord_segs.items(): + c = get_next_hex_color() + for s in segs: + color_dict[s.id] = c + + # for lines/segments + points = [] + toconnect = [] + colors = [] + # for any spheres which we plot as markers at once + marker_points = [] + marker_colors = [] + marker_sizes = [] + + for seg in cell.morphology.segments: + p = cell.get_actual_proximal(seg.id) + d = seg.distal + width = (p.diameter + d.diameter) / 2 + + if width < min_width: + width = min_width + + if plot_type == "constant": + width = min_width + + seg_color = "white" + if color is None: + seg_color = get_next_hex_color() + elif color == "Groups": + try: + seg_color = color_dict[seg.id] + except KeyError: + print(f"Unbranched segment found: {seg.id}") + if seg.id in soma_segs: + seg_color = "green" + elif seg.id in axon_segs: + seg_color = "red" + elif seg.id in dend_segs: + seg_color = "blue" + elif color == "Default Groups": + if seg.id in soma_segs: + seg_color = "green" + elif seg.id in axon_segs: + seg_color = "red" + elif seg.id in dend_segs: + seg_color = "blue" + else: + seg_color = color + + # check if for a spherical segment, add extra spherical node + if p.x == d.x and p.y == d.y and p.z == d.z and p.diameter == d.diameter: + marker_points.append([offset[0] + p.x, offset[1] + p.y, offset[2] + p.z]) + marker_colors.append(seg_color) + marker_sizes.append(p.diameter) + + if plot_type == "constant": + points.append([offset[0] + p.x, offset[1] + p.y, offset[2] + p.z]) + colors.append(seg_color) + points.append([offset[0] + d.x, offset[1] + d.y, offset[2] + d.z]) + colors.append(seg_color) + toconnect.append([len(points) - 2, len(points) - 1]) + # every segment plotted individually + elif plot_type == "detailed": + points = [] + toconnect = [] + colors = [] + points.append([offset[0] + p.x, offset[1] + p.y, offset[2] + p.z]) + colors.append(seg_color) + points.append([offset[0] + d.x, offset[1] + d.y, offset[2] + d.z]) + colors.append(seg_color) + toconnect.append([len(points) - 2, len(points) - 1]) + scene.Line( + pos=points, + color=colors, + connect=numpy.array(toconnect), + parent=current_view.scene, + width=width, + ) + + if plot_type == "constant": + scene.Line( + pos=points, + color=colors, + connect=numpy.array(toconnect), + parent=current_view.scene, + width=width, + ) + + if not nogui: + # markers + if len(marker_points) > 0: + scene.Markers( + pos=numpy.array(marker_points), + size=numpy.array(marker_sizes), + spherical=True, + face_color=marker_colors, + edge_color=marker_colors, + edge_width=0, + parent=current_view.scene, + scaling=True, + antialias=0, + ) + app.run() + return marker_points, marker_sizes, marker_colors + + +def plot_3D_schematic( + cell: Cell, + segment_groups: typing.Optional[typing.List[SegmentGroup]], + offset: typing.List[float] = [0, 0, 0], + labels: bool = False, + width: float = 5.0, + verbose: bool = False, + nogui: bool = False, + title: str = "", + current_scene: scene.SceneCanvas = None, + current_view: scene.ViewBox = None, + theme: str = "light", +) -> None: + """Plot a 3D schematic of the provided segment groups using vispy. + layer.. + + This plots each segment group as a straight line between its first and last + segment. + + .. versionadded:: 1.0.0 + + .. seealso:: + + :py:func:`plot_2D_schematic` + general function for plotting + + :py:func:`plot_2D` + general function for plotting + + :py:func:`plot_2D_point_cells` + for plotting point cells + + :py:func:`plot_2D_cell_morphology` + for plotting cells with detailed morphologies + + :param offset: offset for cell + :type offset: [float, float, float] + :param cell: cell to plot + :type cell: neuroml.Cell + :param segment_groups: list of unbranched segment groups to plot + :type segment_groups: list(SegmentGroup) + :param labels: toggle labelling of segment groups + :type labels: bool + :param width: width for lines for segment groups + :type width: float + :param verbose: show extra information (default: False) + :type verbose: bool + :param title: title of plot + :type title: str + :param nogui: toggle if plot should be shown or not + :type nogui: bool + :param current_scene: vispy scene.SceneCanvas to use (a new one is created if it is not + provided) + :type current_scene: scene.SceneCanvas + :param current_view: vispy viewbox to use + :type current_view: ViewBox + :param theme: theme to use (light/dark) + :type theme: str + """ + if title == "": + title = f"3D schematic of segment groups from {cell.id}" + + # if no segment groups are given, do them all + if segment_groups is None: + segment_groups = [] + for sg in cell.morphology.segment_groups: + if sg.neuro_lex_id == neuro_lex_ids["section"]: + segment_groups.append(sg.id) + + ord_segs = cell.get_ordered_segments_in_groups( + segment_groups, check_parentage=False + ) + + # if no canvas is defined, define a new one + if current_scene is None or current_view is None: + view_min, view_max = get_cell_bound_box(cell) + current_scene, current_view = create_new_vispy_canvas( + view_min, view_max, title, theme=theme + ) + + points = [] + toconnect = [] + colors = [] + text = [] + textpoints = [] + + for sgid, segs in ord_segs.items(): + sgobj = cell.get_segment_group(sgid) + if sgobj.neuro_lex_id != neuro_lex_ids["section"]: + raise ValueError( + f"{sgobj} does not have neuro_lex_id set to indicate it is an unbranched segment" + ) + + # get proximal and distal points + first_seg = segs[0] # type: Segment + last_seg = segs[-1] # type: Segment + first_prox = cell.get_actual_proximal(first_seg.id) + + points.append( + [ + offset[0] + first_prox.x, + offset[1] + first_prox.y, + offset[2] + first_prox.z, + ] + ) + points.append( + [ + offset[0] + last_seg.distal.x, + offset[1] + last_seg.distal.y, + offset[2] + last_seg.distal.z, + ] + ) + colors.append(get_next_hex_color()) + colors.append(get_next_hex_color()) + toconnect.append([len(points) - 2, len(points) - 1]) + + # TODO: needs fixing to show labels + if labels: + text.append(f"{sgid}") + textpoints.append( + [ + offset[0] + (first_prox.x + last_seg.distal.x) / 2, + offset[1] + (first_prox.y + last_seg.distal.y) / 2, + offset[2] + (first_prox.z + last_seg.distal.z) / 2, + ] + ) + """ + + alabel = add_text_to_vispy_3D_plot(current_scene=current_view.scene, text=f"{sgid}", + xv=[offset[0] + first_seg.proximal.x, offset[0] + last_seg.distal.x], + yv=[offset[0] + first_seg.proximal.y, offset[0] + last_seg.distal.y], + zv=[offset[1] + first_seg.proximal.z, offset[1] + last_seg.distal.z], + color=colors[-1]) + alabel.font_size = 30 + """ + + scene.Line( + points, + parent=current_view.scene, + color=colors, + width=width, + connect=numpy.array(toconnect), + ) + if labels: + print("Text rendering") + scene.Text( + text, pos=textpoints, font_size=30, color="black", parent=current_view.scene + ) + + if not nogui: + app.run() diff --git a/pyneuroml/utils/__init__.py b/pyneuroml/utils/__init__.py index 044c3781..78790342 100644 --- a/pyneuroml/utils/__init__.py +++ b/pyneuroml/utils/__init__.py @@ -5,14 +5,17 @@ Copyright 2023 NeuroML Contributors """ - -import typing +import math +import copy import logging import re -import neuroml +import numpy +import neuroml +from neuroml.loaders import read_neuroml2_file logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) def extract_position_info( @@ -36,6 +39,14 @@ def extract_position_info( :rtype: tuple of dicts """ + nml_model_copy = copy.deepcopy(nml_model) + + # add any included cells to the main document + for inc in nml_model_copy.includes: + inc = read_neuroml2_file(inc.href) + for acell in inc.cells: + nml_model_copy.add(acell) + cell_id_vs_cell = {} positions = {} pop_id_vs_cell = {} @@ -43,18 +54,18 @@ def extract_position_info( pop_id_vs_radii = {} cell_elements = [] - cell_elements.extend(nml_model.cells) - cell_elements.extend(nml_model.cell2_ca_poolses) + cell_elements.extend(nml_model_copy.cells) + cell_elements.extend(nml_model_copy.cell2_ca_poolses) for cell in cell_elements: cell_id_vs_cell[cell.id] = cell - if len(nml_model.networks) > 0: - popElements = nml_model.networks[0].populations + if len(nml_model_copy.networks) > 0: + popElements = nml_model_copy.networks[0].populations else: popElements = [] net = neuroml.Network(id="x") - nml_model.networks.append(net) + nml_model_copy.networks.append(net) cell_str = "" for cell in cell_elements: pop = neuroml.Population( @@ -64,7 +75,7 @@ def extract_position_info( cell_str += cell.id + "__" net.id = cell_str[:-2] - popElements = nml_model.networks[0].populations + popElements = nml_model_copy.networks[0].populations for pop in popElements: name = pop.id @@ -133,3 +144,124 @@ def convert_case(name): """Converts from camelCase to under_score""" s1 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name) return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower() + + +def rotate_cell( + cell: neuroml.Cell, + x: float = 0, + y: float = 0, + z: float = 0, + order: str = "xyz", + relative_to_soma: bool = False +) -> neuroml.Cell: + """Return a new cell object rotated in the provided order along the + provided angles (in radians) relative to the soma position. + + :param cell: cell object to rotate + :type cell: neuroml.Cell + :param x: angle to rotate around x axis, in radians + :type x: float + :param y: angle to rotate around y axis, in radians + :type y: float + :param z: angle to rotate around z axis, in radians + :type z: float + :param order: rotation order in terms of x, y, and z + :type order: str + :param relative_to_soma: whether rotation is relative to soma + :type relative_to_soma: bool + :returns: new neuroml.Cell object + :rtype: neuroml.Cell + + Derived from LFPy's implementation: + https://github.com/LFPy/LFPy/blob/master/LFPy/cell.py#L1600 + """ + + valid_orders = [ + "xyz", "yzx", "zxy", "xzy", "yxz", "zyx" + ] + if order not in valid_orders: + raise ValueError(f"order must be one of {valid_orders}") + + soma_seg_id = cell.get_morphology_root() + soma_seg = cell.get_segment(soma_seg_id) + cell_origin = numpy.array([soma_seg.proximal.x, soma_seg.proximal.y, soma_seg.proximal.z]) + newcell = copy.deepcopy(cell) + print(f"Rotating {newcell.id} by {x}, {y}, {z}") + + # calculate rotations + if x != 0: + anglex = x + rotation_x = numpy.array([[1, 0, 0], + [0, math.cos(anglex), -math.sin(anglex)], + [0, math.sin(anglex), math.cos(anglex)] + ]) + logger.debug(f"x matrix is: {rotation_x}") + + if y != 0: + angley = y + rotation_y = numpy.array([[math.cos(angley), 0, math.sin(angley)], + [0, 1, 0], + [-math.sin(angley), 0, math.cos(angley)] + ]) + logger.debug(f"y matrix is: {rotation_y}") + + if z != 0: + anglez = z + rotation_z = numpy.array([[math.cos(anglez), -math.sin(anglez), 0], + [math.sin(anglez), math.cos(anglez), 0], + [0, 0, 1] + ]) + logger.debug(f"z matrix is: {rotation_z}") + + # rotate each segment + for aseg in newcell.morphology.segments: + prox = dist = numpy.array([]) + # may not have a proximal + try: + prox = numpy.array([aseg.proximal.x, aseg.proximal.y, aseg.proximal.z]) + except AttributeError: + pass + + # must have distal + dist = numpy.array([aseg.distal.x, aseg.distal.y, aseg.distal.z]) + + if relative_to_soma: + if prox.any(): + prox = prox - cell_origin + dist = dist - cell_origin + + # rotate + for axis in order: + if axis == 'x' and x != 0: + if prox.any(): + prox = numpy.dot(prox, rotation_x) + dist = numpy.dot(dist, rotation_x) + + if axis == 'y' and y != 0: + if prox.any(): + prox = numpy.dot(prox, rotation_y) + dist = numpy.dot(dist, rotation_y) + + if axis == 'z' and z != 0: + if prox.any(): + prox = numpy.dot(prox, rotation_z) + dist = numpy.dot(dist, rotation_z) + + if relative_to_soma: + if prox.any(): + prox = prox + cell_origin + dist = dist + cell_origin + + if prox.any(): + aseg.proximal.x = prox[0] + aseg.proximal.y = prox[1] + aseg.proximal.z = prox[2] + + aseg.distal.x = dist[0] + aseg.distal.y = dist[1] + aseg.distal.z = dist[2] + + logger.debug(f"prox is: {aseg.proximal}") + logger.debug(f"distal is: {aseg.distal}") + + return newcell diff --git a/pyneuroml/utils/plot.py b/pyneuroml/utils/plot.py index 82af7e20..3f5c1d58 100644 --- a/pyneuroml/utils/plot.py +++ b/pyneuroml/utils/plot.py @@ -8,7 +8,6 @@ """ import logging -import textwrap import numpy import typing import random @@ -24,66 +23,18 @@ logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) -VISPY_THEME = { - "light": {"bg": "white", "fg": "black"}, - "dark": {"bg": "black", "fg": "white"}, -} -PYNEUROML_VISPY_THEME = "light" - -try: - from vispy import scene - from vispy.scene import SceneCanvas - from vispy.app import Timer - - - def add_text_to_vispy_3D_plot( - current_scene: SceneCanvas, - xv: typing.List[float], - yv: typing.List[float], - zv: typing.List[float], - color: str, - text: str, - ): - """Add text to a vispy plot between two points. - - Wrapper around vispy.scene.visuals.Text - - Rotates the text label to ensure it is at the same angle as the line. - - :param scene: vispy scene object - :type scene: SceneCanvas - :param xv: start and end coordinates in one axis - :type xv: list[x1, x2] - :param yv: start and end coordinates in second axis - :type yv: list[y1, y2] - :param zv: start and end coordinates in third axix - :type zv: list[z1, z2] - :param color: color of text - :type color: str - :param text: text to write - :type text: str - """ - - angle = int(numpy.rad2deg(numpy.arctan2((yv[1] - yv[0]), (xv[1] - xv[0])))) - if angle > 90: - angle -= 180 - elif angle < -90: - angle += 180 - - return scene.Text( - pos=((xv[0] + xv[1]) / 2, (yv[0] + yv[1]) / 2, (zv[0] + zv[1]) / 2), - text=text, - color=color, - rotation=angle, - parent=current_scene, - ) - -except: - print("\n**************************\n Please install vispy to use 3D plotting features!\n**************************") - - - +DEFAULTS = { + "v": False, + "nogui": False, + "saveToFile": None, + "interactive3d": False, + "plane2d": "xy", + "minWidth": 0.8, + "square": False, + "plotType": "constant", + "theme": "light", +} # type: dict[str, typing.Any] def add_text_to_matplotlib_2D_plot( @@ -356,212 +307,6 @@ def add_line_to_matplotlib_2D_plot(ax, xv, yv, width, color, axis_min_max): axis_min_max[1] = max(axis_min_max[1], xv[1]) -def create_new_vispy_canvas( - view_min: typing.Optional[typing.List[float]] = None, - view_max: typing.Optional[typing.List[float]] = None, - title: str = "", - console_font_size: float = 10, - axes_pos: typing.Optional[typing.List] = None, - axes_length: float = 100, - axes_width: int = 2, - theme=PYNEUROML_VISPY_THEME, -): - """Create a new vispy scene canvas with a view and optional axes lines - - Reference: https://vispy.org/gallery/scene/axes_plot.html - - :param view_min: min view co-ordinates - :type view_min: [float, float, float] - :param view_max: max view co-ordinates - :type view_max: [float, float, float] - :param title: title of plot - :type title: str - :param axes_pos: position to draw axes at - :type axes_pos: [float, float, float] - :param axes_length: length of axes - :type axes_length: float - :param axes_width: width of axes lines - :type axes_width: float - :returns: scene, view - """ - canvas = scene.SceneCanvas( - keys="interactive", - show=True, - bgcolor=VISPY_THEME[theme]["bg"], - size=(800, 600), - title="NeuroML viewer (VisPy)", - ) - grid = canvas.central_widget.add_grid(margin=10) - grid.spacing = 0 - - title_widget = scene.Label(title, color=VISPY_THEME[theme]["fg"]) - title_widget.height_max = 80 - grid.add_widget(title_widget, row=0, col=0, col_span=1) - - console_widget = scene.Console( - text_color=VISPY_THEME[theme]["fg"], - font_size=console_font_size, - ) - console_widget.height_max = 80 - grid.add_widget(console_widget, row=2, col=0, col_span=1) - - bottom_padding = grid.add_widget(row=3, col=0, col_span=1) - bottom_padding.height_max = 10 - - view = grid.add_view(row=1, col=0, border_color=None) - - # create cameras - # https://vispy.org/gallery/scene/flipped_axis.html - cam1 = scene.cameras.PanZoomCamera(parent=view.scene, name="PanZoom") - - cam2 = scene.cameras.TurntableCamera(parent=view.scene, name="Turntable") - - cam3 = scene.cameras.ArcballCamera(parent=view.scene, name="Arcball") - - cam4 = scene.cameras.FlyCamera(parent=view.scene, name="Fly") - # do not keep z up - cam4.autoroll = False - - cams = [cam4, cam2] - - # console text - console_text = "Controls: reset view: 0; quit: Esc/9" - if len(cams) > 1: - console_text += "; cycle camera: 1, 2 (fwd/bwd)" - - cam_text = { - cam1: textwrap.dedent( - """ - Left mouse button: pans view; right mouse button or scroll: - zooms""" - ), - cam2: textwrap.dedent( - """ - Left mouse button: orbits view around center point; right mouse - button or scroll: change zoom level; Shift + left mouse button: - translate center point; Shift + right mouse button: change field of - view; r/R: view rotation animation""" - ), - cam3: textwrap.dedent( - """ - Left mouse button: orbits view around center point; right - mouse button or scroll: change zoom level; Shift + left mouse - button: translate center point; Shift + right mouse button: change - field of view""" - ), - cam4: textwrap.dedent( - """ - Arrow keys/WASD to move forward/backwards/left/right; F/C to move - up and down; Space to brake; Left mouse button/I/K/J/L to control - pitch and yaw; Q/E for rolling""" - ), - } - - # Turntable is default - cam_index = 1 - view.camera = cams[cam_index] - - if view_min is not None and view_max is not None: - view_center = (numpy.array(view_max) + numpy.array(view_min)) / 2 - logger.debug(f"Center is {view_center}") - cam1.center = [view_center[0], view_center[1]] - cam2.center = view_center - cam3.center = view_center - cam4.center = view_center - - for acam in cams: - x_width = abs(view_min[0] - view_max[0]) - y_width = abs(view_min[1] - view_max[1]) - z_width = abs(view_min[2] - view_max[2]) - - xrange = ( - (view_min[0] - x_width * 0.02, view_max[0] + x_width * 0.02) - if x_width > 0 - else (-100, 100) - ) - yrange = ( - (view_min[1] - y_width * 0.02, view_max[1] + y_width * 0.02) - if y_width > 0 - else (-100, 100) - ) - zrange = ( - (view_min[2] - z_width * 0.02, view_max[2] + z_width * 0.02) - if z_width > 0 - else (-100, 100) - ) - logger.debug(f"{xrange}, {yrange}, {zrange}") - - acam.set_range(x=xrange, y=yrange, z=zrange) - - for acam in cams: - acam.set_default_state() - - console_widget.write(f"Center: {view.camera.center}") - console_widget.write(console_text) - console_widget.write( - f"Current camera: {view.camera.name}: " - + cam_text[view.camera].replace("\n", " ").strip() - ) - - if axes_pos: - points = [ - axes_pos, # origin - [axes_pos[0] + axes_length, axes_pos[1], axes_pos[2]], - [axes_pos[0], axes_pos[1] + axes_length, axes_pos[2]], - [axes_pos[0], axes_pos[1], axes_pos[2] + axes_length], - ] - scene.Line( - points, - connect=numpy.array([[0, 1], [0, 2], [0, 3]]), - parent=view.scene, - color=VISPY_THEME[theme]["fg"], - width=axes_width, - ) - - def vispy_rotate(self): - view.camera.orbit(azim=1, elev=0) - - rotation_timer = Timer(connect=vispy_rotate) - - @canvas.events.key_press.connect - def vispy_on_key_press(event): - nonlocal cam_index - - # Disable camera cycling. The fly camera looks sufficient. - # Keeping views/ranges same when switching cameras is not simple. - # Prev - if event.text == "1": - cam_index = (cam_index - 1) % len(cams) - view.camera = cams[cam_index] - # next - elif event.text == "2": - cam_index = (cam_index + 1) % len(cams) - view.camera = cams[cam_index] - # for turntable only: rotate animation - elif event.text == "R" or event.text == "r": - if view.camera == cam2: - if rotation_timer.running: - rotation_timer.stop() - else: - rotation_timer.start() - # reset - elif event.text == "0": - view.camera.reset() - # quit - elif event.text == "9": - canvas.app.quit() - - console_widget.clear() - # console_widget.write(f"Center: {view.camera.center}") - console_widget.write(console_text) - console_widget.write( - f"Current camera: {view.camera.name}: " - + cam_text[view.camera].replace("\n", " ").strip() - ) - - return scene, view - - def get_cell_bound_box(cell: Cell): """Get a boundary box for a cell diff --git a/setup.cfg b/setup.cfg index b6ee9ea1..a9a74a3f 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,6 +1,6 @@ [metadata] name = pyNeuroML -version = 1.0.8 +version = 1.0.9 author = Padraig Gleeson author_email = p.gleeson@gmail.com url = https://github.com/NeuroML/pyNeuroML diff --git a/tests/plot/test_morphology_plot.py b/tests/plot/test_morphology_plot.py index 67831da0..15162521 100644 --- a/tests/plot/test_morphology_plot.py +++ b/tests/plot/test_morphology_plot.py @@ -17,10 +17,13 @@ from pyneuroml.plot.PlotMorphology import ( plot_2D, plot_2D_cell_morphology, - plot_3D_cell_morphology_plotly, plot_2D_schematic, plot_segment_groups_curtain_plots, - plot_2D_point_cells, +) +from pyneuroml.plot.PlotMorphologyPlotly import ( + plot_3D_cell_morphology_plotly, +) +from pyneuroml.plot.PlotMorphologyVispy import ( plot_3D_schematic, plot_3D_cell_morphology, plot_interactive_3D, diff --git a/tests/utils/test_utils.py b/tests/utils/test_utils.py index ff0dff68..14a23ee8 100644 --- a/tests/utils/test_utils.py +++ b/tests/utils/test_utils.py @@ -10,9 +10,11 @@ import logging import pathlib as pl +import math -from pyneuroml.pynml import read_neuroml2_file -from pyneuroml.utils import extract_position_info +import neuroml +from pyneuroml.pynml import read_neuroml2_file, write_neuroml2_file +from pyneuroml.utils import extract_position_info, rotate_cell from .. import BaseTestCase @@ -46,3 +48,48 @@ def test_extract_position_info(self): for c in ["HL23PV", "HL23PYR", "HL23VIP", "HL23SST"]: self.assertIn(c, cell_id_vs_cell.keys()) + + def test_rotate_cell(self): + """Test rotate_cell""" + acell = neuroml.utils.component_factory("Cell", id="test_cell", validate=False) # type: neuroml.Cell + + soma = acell.add_segment(prox=[0, 0, 0, 15], dist=[0, 0, 0, 15], seg_id=0, + use_convention=False, reorder_segment_groups=False, + optimise_segment_groups=False) + + acell.add_segment(prox=[0, 0, 0, 12], dist=[100, 0, 0, 12], seg_id=1, + use_convention=False, reorder_segment_groups=False, + optimise_segment_groups=False, parent=soma) + + acell.add_segment(prox=[0, 0, 0, 7], dist=[0, 150, 0, 7], seg_id=2, + use_convention=False, reorder_segment_groups=False, + optimise_segment_groups=False, parent=soma) + + acell.add_segment(prox=[0, 0, 0, 4], dist=[0, 0, 200, 4], seg_id=3, + use_convention=False, reorder_segment_groups=False, + optimise_segment_groups=False, parent=soma) + + print(acell) + + rotated_cell = rotate_cell(acell, x=math.pi / 20, y=0, z=0, order="xyz") + rotated_cell.id = "test_rotated_cell" + print(rotated_cell) + + newdoc = neuroml.utils.component_factory("NeuroMLDocument", + id="test_doc") # type: neuroml.NeuroMLDocument + newdoc.add(acell) + newdoc.add(rotated_cell) + + net = newdoc.add("Network", id="test_net", validate=False) + pop1 = net.add("Population", id="test_pop1", size=1, component=acell.id, + type="populationList", validate=False) + pop1.add("Instance", id=0, location=pop1.component_factory("Location", x=0, y=0, z=0)) + + pop2 = net.add("Population", id="test_pop2", size=1, + component=rotated_cell.id, + type="populationList", validate=False) + pop2.add("Instance", id=0, location=pop1.component_factory("Location", + x=200, y=0, z=0)) + + newdoc.validate(recursive=True) + write_neuroml2_file(newdoc, "tests/utils/test_rotation.net.nml", validate=True)