Skip to content

Commit

Permalink
Merge pull request #417 from moloney/enh-csa-write
Browse files Browse the repository at this point in the history
ENH: Add writer for Siemens CSA header
  • Loading branch information
effigies authored Sep 5, 2024
2 parents 977d044 + 122a923 commit 09d3f95
Show file tree
Hide file tree
Showing 2 changed files with 121 additions and 0 deletions.
110 changes: 110 additions & 0 deletions nibabel/nicom/csareader.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""CSA header reader from SPM spec"""

import numpy as np
import struct

from .structreader import Unpacker
from .utils import find_private_section
Expand Down Expand Up @@ -28,6 +29,10 @@ class CSAReadError(CSAError):
pass


class CSAWriteError(CSAError):
pass


def get_csa_header(dcm_data, csa_type='image'):
"""Get CSA header information from DICOM header
Expand Down Expand Up @@ -161,6 +166,96 @@ def read(csa_str):
return csa_dict


def write(csa_header):
''' Write string from CSA header `csa_header`
Parameters
----------
csa_header : dict
header information as dict, where `header` has fields (at least)
``type, n_tags, tags``. ``header['tags']`` is also a dictionary
with one key, value pair for each tag in the header.
Returns
-------
csa_str : str
byte string containing CSA header information
'''
result = []
if csa_header['type'] == 2:
result.append(b'SV10')
result.append(csa_header['unused0'])
if not 0 < csa_header['n_tags'] <= 128:
raise CSAWriteError('Number of tags `t` should be '
'0 < t <= 128')
result.append(struct.pack('2I',
csa_header['n_tags'],
csa_header['check'])
)

# Build list of tags in correct order
tags = list(csa_header['tags'].items())
tags.sort(key=lambda x: x[1]['tag_no'])
tag0_n_items = tags[0][1]['n_items']

# Add the information for each tag
for tag_name, tag_dict in tags:
vm = tag_dict['vm']
vr = tag_dict['vr']
n_items = tag_dict['n_items']
assert n_items < 100
result.append(struct.pack('64si4s3i',
make_nt_str(tag_name),
vm,
make_nt_str(vr),
tag_dict['syngodt'],
n_items,
tag_dict['last3'])
)

# Figure out the number of values for this tag
if vm == 0:
n_values = n_items
else:
n_values = vm

# Add each item for this tag
for item_no in range(n_items):
# Figure out the item length
if item_no >= n_values or tag_dict['items'][item_no] == '':
item_len = 0
else:
item = tag_dict['items'][item_no]
if not isinstance(item, str):
item = str(item)
item_nt_str = make_nt_str(item)
item_len = len(item_nt_str)

# These values aren't actually preserved in the dict
# representation of the header. Best we can do is set the ones
# that determine the item length appropriately.
x0, x1, x2, x3 = 0, 0, 0, 0
if csa_header['type'] == 1: # CSA1 - odd length calculation
x0 = tag0_n_items + item_len
if item_len < 0 or (ptr + item_len) > csa_len:
if item_no < vm:
items.append('')
break
else: # CSA2
x1 = item_len
result.append(struct.pack('4i', x0, x1, x2, x3))

if item_len == 0:
continue

result.append(item_nt_str)
# go to 4 byte boundary
plus4 = item_len % 4
if plus4 != 0:
result.append(b'\x00' * (4 - plus4))
return b''.join(result)


def get_scalar(csa_dict, tag_name):
try:
items = csa_dict['tags'][tag_name]['items']
Expand Down Expand Up @@ -258,3 +353,18 @@ def nt_str(s):
if zero_pos == -1:
return s
return s[:zero_pos].decode('latin-1')


def make_nt_str(s):
''' Create a null terminated byte string from a unicode object.
Parameters
----------
s : unicode
Returns
-------
result : bytes
s encoded as latin-1 with a null char appended
'''
return s.encode('latin-1') + b'\x00'
11 changes: 11 additions & 0 deletions nibabel/nicom/tests/test_csareader.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,3 +130,14 @@ def test_missing_csa_elem():
del dcm[csa_tag]
hdr = csa.get_csa_header(dcm, 'image')
assert hdr is None


def test_read_write_rt():
# Try doing a read-write-read round trip and make sure the dictionary
# representation of the header is the same. We can't exactly reproduce the
# original string representation currently.
for csa_str in (CSA2_B0, CSA2_B1000):
csa_info = csa.read(csa_str)
new_csa_str = csa.write(csa_info)
new_csa_info = csa.read(new_csa_str)
assert csa_info == new_csa_info

0 comments on commit 09d3f95

Please sign in to comment.