Skip to content

Commit

Permalink
feat: pull neuron ground truths from allen brain (#48)
Browse files Browse the repository at this point in the history
* feat: pull neurons from allen brain

* update binary mask

* save stuff

* more docs and refactor

* add tests

* don't require requests

* lint
  • Loading branch information
tlambert03 authored Jun 16, 2024
1 parent c3d2c72 commit 96cf3cf
Show file tree
Hide file tree
Showing 7 changed files with 483 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import numpy as np


def bres_draw_segment_2d(
def draw_line_2d(
y0: int, x0: int, y1: int, x1: int, grid: np.ndarray, max_r: float
) -> None:
"""Bresenham's algorithm.
Expand Down Expand Up @@ -41,8 +41,16 @@ def bres_draw_segment_2d(
y0 += sy


def bres_draw_segment_3d(
x0: int, y0: int, z0: int, x1: int, y1: int, z1: int, grid: np.ndarray, max_r: float
def draw_line_3d(
x0: int,
y0: int,
z0: int,
x1: int,
y1: int,
z1: int,
grid: np.ndarray,
max_r: float,
width: float = 1.0,
) -> None:
"""Bresenham's algorithm.
Expand All @@ -65,10 +73,18 @@ def bres_draw_segment_3d(
yr = grid.shape[1] / 2
xr = grid.shape[2] / 2

if max_r < 0:
max_r = sqrt(zr**2 + yr**2 + xr**2)

while True:
r = ((x0 - xr) / xr) ** 2 + ((y0 - yr) / yr) ** 2 + ((z0 - zr) / zr) ** 2
if sqrt(r) <= max_r:
grid[z0, y0, x0] += 1
if width != 1:
# Draw a sphere around the current point with the given width
draw_sphere(grid, x0, y0, z0, width)
else:
r = ((x0 - xr) / xr) ** 2 + ((y0 - yr) / yr) ** 2 + ((z0 - zr) / zr) ** 2
if sqrt(r) <= max_r:
grid[z0, y0, x0] += 1

if i == 0:
break

Expand All @@ -87,10 +103,24 @@ def bres_draw_segment_3d(
i -= 1


def draw_sphere(grid: np.ndarray, x0: int, y0: int, z0: int, radius: float) -> None:
"""Draw a sphere of a given radius around a point in a 3D grid."""
z_range = range(int(max(0, z0 - radius)), int(min(grid.shape[0], z0 + radius + 1)))
y_range = range(int(max(0, y0 - radius)), int(min(grid.shape[1], y0 + radius + 1)))
x_range = range(int(max(0, x0 - radius)), int(min(grid.shape[2], x0 + radius + 1)))
for z in z_range:
for y in y_range:
for x in x_range:
distance = (x - x0) ** 2 + (y - y0) ** 2 + (z - z0) ** 2
if distance <= radius**2:
grid[z, y, x] += 1


try:
from numba import jit
from numba import njit
except Exception:
pass
else:
bres_draw_segment_2d = jit(nopython=True)(bres_draw_segment_2d)
bres_draw_segment_3d = jit(nopython=True)(bres_draw_segment_3d)
draw_line_2d = njit(draw_line_2d)
draw_line_3d = njit(draw_line_3d)
draw_sphere = njit(draw_sphere)
15 changes: 15 additions & 0 deletions src/microsim/allen/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from ._fetch import (
ApiCellTypesSpecimenDetail,
NeuronReconstruction,
Specimen,
get_reconstructions,
)
from ._swc import SWC

__all__ = [
"ApiCellTypesSpecimenDetail",
"get_reconstructions",
"NeuronReconstruction",
"Specimen",
"SWC",
]
228 changes: 228 additions & 0 deletions src/microsim/allen/_fetch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,228 @@
from __future__ import annotations

from functools import cache, cached_property
from typing import TYPE_CHECKING, Literal, cast

from pydantic import BaseModel, Field

from microsim.util import http_get

if TYPE_CHECKING:
from collections.abc import Iterable

import numpy as np

from ._swc import SWC

ALLEN_ROOT = "http://api.brain-map.org"
ALLEN_V2_API = f"{ALLEN_ROOT}/api/v2/data"
ALLEN_V2_QUERY = ALLEN_V2_API + "/query.json"
SWC_FILE_TYPE = "3DNeuronReconstruction"


class WellKnownFileType(BaseModel):
"""Model representing a well-known file type in the Allen Brain Map API."""

id: int
name: str # something like '3DNeuronReconstruction'


class WellKnownFile(BaseModel):
"""Model representing a file in the Allen Brain Map API."""

attachable_id: int | None
attachable_type: str | None
download_link: str | None
id: int | None
path: str | None
well_known_file_type_id: int | None
well_known_file_type: WellKnownFileType | None


class NeuronReconstruction(BaseModel):
"""Model representing a neuron reconstruction in the Allen Brain Map API."""

id: int
specimen_id: int
number_nodes: int
number_branches: int
number_stems: int
number_bifurcations: int
max_euclidean_distance: float
neuron_reconstruction_type: str
overall_height: float
overall_width: float
overall_depth: float
scale_factor_x: float
scale_factor_y: float
scale_factor_z: float
total_length: float
total_surface: float
total_volume: float
well_known_files: list[WellKnownFile] = Field(default_factory=list)

@property
def swc_path(self) -> str:
"""The SWC file for this reconstruction."""
for f in self.well_known_files:
if (
getattr(f.well_known_file_type, "name", None) == SWC_FILE_TYPE
and f.download_link
):
return ALLEN_ROOT + f.download_link
raise ValueError(
"No SWC file found for this reconstruction."
) # pragma: no cover

@cached_property
def swc(self) -> SWC:
"""Load the SWC file for this reconstruction."""
from ._swc import SWC

return SWC.from_path(self.swc_path)

def binary_mask(self, voxel_size: float = 1, scale_factor: float = 3) -> np.ndarray:
"""Return 3D binary mask for this neuron reconstructions."""
return self.swc.binary_mask(voxel_size=voxel_size, scale_factor=scale_factor)

@classmethod
@cache
def fetch(cls, id: int) -> NeuronReconstruction:
"""Fetch NeuronReconstruction by ID from the Allen brain map API."""
q = [
"model::NeuronReconstruction",
f"rma::criteria[id$eq{id}],well_known_files",
f"rma::include,well_known_files(well_known_file_type[name$eq'{SWC_FILE_TYPE}'])",
# get all rows
"rma::options[num_rows$eq'all']",
]
response = http_get(ALLEN_V2_QUERY, params={"q": ",".join(q)})
qr = _QueryResponse.model_validate_json(response)
if not qr.success: # pragma: no cover
raise ValueError(qr.msg)
return cast("NeuronReconstruction", qr.msg[0])

def specimen(self) -> Specimen:
"""Fetch the specimen that owns this neuron reconstruction."""
return Specimen.fetch(self.specimen_id)


class Structure(BaseModel):
"""Speciment structure model from the Allen Brain Map API."""

id: int
name: str
acronym: str
structure_id_path: str


class Specimen(BaseModel):
"""Model representing a specimen in the Allen Brain Map API."""

id: int
name: str
is_cell_specimen: bool
specimen_id_path: str
structure: Structure
neuron_reconstructions: list[NeuronReconstruction] = Field(default_factory=list)

@classmethod
@cache
def fetch(cls, id: int) -> Specimen:
"""Fetch this specimen from the Allen brain map API."""
q = [
# query the Specimen model
"model::Specimen",
# limit to the specimen with the given ID
# and join on NeuronReconstruction and WellKnownFile
f"rma::criteria[id$eq{id}],neuron_reconstructions(well_known_files)",
# include structure
# and neuron_reconstructions where the well_known_file_type is SWC
"rma::include,structure,neuron_reconstructions(well_known_files("
f"well_known_file_type[name$eq'{SWC_FILE_TYPE}']))",
# get all rows
"rma::options[num_rows$eq'all']",
]
response = http_get(ALLEN_V2_QUERY, params={"q": ",".join(q)})
qr = _QueryResponse.model_validate_json(response)
if not qr.success: # pragma: no cover
raise ValueError(qr.msg)
return cast("Specimen", qr.msg[0])

def binary_masks(
self, voxel_size: float = 1, scale_factor: float = 3
) -> list[np.ndarray]:
"""Return all binary masks for this specimen's neuron reconstructions."""
masks = []
for recon in self.neuron_reconstructions:
masks.append(
recon.binary_mask(voxel_size=voxel_size, scale_factor=scale_factor)
)
return masks

@property
def url(self) -> str:
"""Return the URL for this specimen on the Allen Brain Map."""
return f"http://celltypes.brain-map.org/experiment/morphology/{self.id}"

def open_webpage(self) -> None: # pragma: no cover
"""Open the webpage for this specimen in the Allen Brain Map."""
import webbrowser

webbrowser.open(self.url)


class ApiCellTypesSpecimenDetail(BaseModel):
"""Model representing Specimen details from the Allen Brain Map API."""

specimen__id: int
structure__name: str | None
structure__acronym: str | None
donor__species: Literal["Homo Sapiens", "Mus musculus"]
nr__reconstruction_type: str | None # probably just 'full' or 'dendrite-only'
nr__max_euclidean_distance: float | None
nr__number_bifurcations: int | None
nr__number_stems: int | None

@classmethod
@cache
def all_reconstructions(cls) -> tuple[ApiCellTypesSpecimenDetail, ...]:
"""Fetch details for all Specimens with reconstruction info."""
q = (
"model::ApiCellTypesSpecimenDetail",
"rma::criteria[nr__reconstruction_type$ne'null']",
"rma::options[num_rows$eq'all']",
)
response = http_get(ALLEN_V2_QUERY, params={"q": ",".join(q)})
qr = _QueryResponse.model_validate_json(response)
if not qr.success: # pragma: no cover
raise ValueError(qr.msg)
return tuple(qr.msg) # type: ignore[arg-type]

def specimen(self) -> Specimen:
"""Return associated Specimen object."""
return Specimen.fetch(self.specimen__id)


class _QueryResponse(BaseModel):
"""Query response from the Allen Brain Map API."""

success: bool
msg: (
list[NeuronReconstruction]
| list[Specimen]
| list[ApiCellTypesSpecimenDetail]
| str
)


def get_reconstructions(
species: Literal["Homo Sapiens", "Mus musculus"] | None = None,
reconstruction_type: Literal["full", "dendrite-only"] | None = None,
) -> tuple[ApiCellTypesSpecimenDetail, ...]:
recons: Iterable = ApiCellTypesSpecimenDetail.all_reconstructions()
if species is not None:
recons = (x for x in recons if x.donor__species == species)
if reconstruction_type is not None:
recons = (x for x in recons if x.nr__reconstruction_type == reconstruction_type)
return tuple(recons)
Loading

0 comments on commit 96cf3cf

Please sign in to comment.