Make the IGOR binary-wave and variables-record parsers per-thread.

What this patch does:
  * igor2/struct.py: adds clone_structure(), which rebuilds a
    Structure/DynamicStructure tree; nested Structure formats are cloned
    recursively and memoised by id() so a structure is cloned only once.
  * igor2/binarywave.py: setup_wave() builds the 'Wave' structure from
    clones of Wave1/Wave2/Wave3/Wave5 and records them in
    wave._wave_versions; post_unpack() picks the version format from that
    map (falling back to the module-level structures when the attribute is
    absent); load() takes its parser from a threading.local() cache and
    resets it before unpacking.
  * igor2/record/variables.py: the module-level VariablesRecordStructure is
    replaced by setup_variables_record() plus the same thread-local
    cache/reset pair, used by VariablesRecord.__init__().
  * tests/test_pxp.py: the concurrency tests compare per-file fingerprints
    against values computed before the worker threads start, and collect
    worker exceptions instead of losing them.

NOTE(review): clone_structure() re-creates every field by calling
field.__class__(format, name, default=, help=, count=, array=); any other
attribute a Field subclass carries is not copied. Confirm no subclass needs
more than these six values.

NOTE(review): a cached parser is reused by later calls on the same thread.
The reset helpers restore only byte_order and the last field's format before
calling setup(); confirm nothing else is left over from a previous unpack.

NOTE(review): line breaks, indentation and blank context lines in this patch
were reconstructed from the hunk line counts. Runs of spaces inside a line
could not be recovered and are kept as found, so context lines may only match
the target files with whitespace-insensitive application
(e.g. git apply --ignore-whitespace).

diff --git a/igor2/binarywave.py b/igor2/binarywave.py
index 957ff6c..4e831e8 100644
--- a/igor2/binarywave.py
+++ b/igor2/binarywave.py
@@ -1,5 +1,6 @@
 """Read IGOR Binary Wave files into Numpy arrays."""
 import logging
+import threading as _threading
 # Based on WaveMetric's Technical Note 003, "Igor Binary Format"
 # ftp://ftp.wavemetrics.net/IgorPro/Technical_Notes/TN003.zip
 # From ftp://ftp.wavemetrics.net/IgorPro/Technical_Notes/TN000.txt
@@ -13,11 +14,13 @@
 from .struct import DynamicStructure as _DynamicStructure
 from .struct import Field as _Field
 from .struct import DynamicField as _DynamicField
+from .struct import clone_structure as _clone_structure
 from .util import byte_order as _byte_order
 from .util import need_to_reorder_bytes as _need_to_reorder_bytes
 
 
 logger = logging.getLogger(__name__)
+_thread_local = _threading.local()
 
 # Numpy doesn't support complex integers by default, see
 # http://mail.python.org/pipermail/python-dev/2002-April/022408.html
@@ -625,15 +628,15 @@ def post_unpack(self, parents, data):
         else:
             need_to_reorder_bytes = False
 
+        version_map = getattr(wave_structure, '_wave_versions', {
+            1: Wave1,
+            2: Wave2,
+            3: Wave3,
+            5: Wave5,
+        })
         old_format = wave_structure.fields[-1].format
-        if version == 1:
-            wave_structure.fields[-1].format = Wave1
-        elif version == 2:
-            wave_structure.fields[-1].format = Wave2
-        elif version == 3:
-            wave_structure.fields[-1].format = Wave3
-        elif version == 5:
-            wave_structure.fields[-1].format = Wave5
+        if version in version_map:
+            wave_structure.fields[-1].format = version_map[version]
         elif not need_to_reorder_bytes:
             raise ValueError(
                 'invalid binary wave version: {}'.format(version))
@@ -795,6 +798,10 @@ def post_unpack(self, parents, data):
 
 
 def setup_wave(byte_order='='):
+    wave1 = _clone_structure(Wave1)
+    wave2 = _clone_structure(Wave2)
+    wave3 = _clone_structure(Wave3)
+    wave5 = _clone_structure(Wave5)
     wave = _DynamicStructure(
         name='Wave',
         fields=[
@@ -803,22 +810,43 @@ def setup_wave(byte_order='='):
                 'version',
                 help='Version number for backwards compatibility.'),
             DynamicWaveField(
-                Wave1,
+                wave1,
                 'wave',
                 help='The rest of the wave data.'),
         ],
         byte_order=byte_order)
+    wave._wave_versions = {
+        1: wave1,
+        2: wave2,
+        3: wave3,
+        5: wave5,
+    }
     wave.setup()
     return wave
 
 
+def _get_thread_local_wave():
+    wave = getattr(_thread_local, 'wave', None)
+    if wave is None:
+        wave = setup_wave(byte_order='=')
+        _thread_local.wave = wave
+    return wave
+
+
+def _reset_wave_parser(wave):
+    wave.byte_order = '='
+    wave.fields[-1].format = wave._wave_versions[1]
+    wave.setup()
+
+
 def load(filename):
     if hasattr(filename, 'read'):
         f = filename # filename is actually a stream object
     else:
         f = open(filename, 'rb')
     try:
-        wave = setup_wave()
+        wave = _get_thread_local_wave()
+        _reset_wave_parser(wave)
         data = wave.unpack_stream(f)
     finally:
         if not hasattr(filename, 'read'):
diff --git a/igor2/record/variables.py b/igor2/record/variables.py
index 95db9b7..47933ca 100644
--- a/igor2/record/variables.py
+++ b/igor2/record/variables.py
@@ -1,5 +1,6 @@
 import io as _io
 import logging
+import threading as _threading
 
 from ..binarywave import TYPE_TABLE as _TYPE_TABLE
 from ..binarywave import NullStaticStringField as _NullStaticStringField
@@ -8,12 +9,14 @@
 from ..struct import DynamicStructure as _DynamicStructure
 from ..struct import Field as _Field
 from ..struct import DynamicField as _DynamicField
+from ..struct import clone_structure as _clone_structure
 from ..util import byte_order as _byte_order
 from ..util import need_to_reorder_bytes as _need_to_reorder_bytes
 from .base import Record
 
 
 logger = logging.getLogger(__name__)
+_thread_local = _threading.local()
 
 
 class ListedStaticStringField(_NullStaticStringField):
@@ -297,11 +300,13 @@ def post_unpack(self, parents, data):
         else:
             need_to_reorder_bytes = False
 
+        version_map = getattr(variables_structure, '_version_structures', {
+            1: Variables1,
+            2: Variables2,
+        })
         old_format = variables_structure.fields[-1].format
-        if version == 1:
-            variables_structure.fields[-1].format = Variables1
-        elif version == 2:
-            variables_structure.fields[-1].format = Variables2
+        if version in version_map:
+            variables_structure.fields[-1].format = version_map[version]
         elif not need_to_reorder_bytes:
             raise ValueError(
                 'invalid variables record version: {}'.format(version))
@@ -318,26 +323,52 @@ def post_unpack(self, parents, data):
         return need_to_reorder_bytes
 
 
-VariablesRecordStructure = _DynamicStructure(
-    name='VariablesRecord',
-    fields=[
-        DynamicVersionField(
-            'h', 'version', help='Version number for this header.'),
-        _Field(
-            Variables1,
-            'variables',
-            help='The rest of the variables data.'),
-    ])
+def setup_variables_record(byte_order='='):
+    variables1 = _clone_structure(Variables1)
+    variables2 = _clone_structure(Variables2)
+    variables_record_structure = _DynamicStructure(
+        name='VariablesRecord',
+        fields=[
+            DynamicVersionField(
+                'h', 'version', help='Version number for this header.'),
+            _Field(
+                variables1,
+                'variables',
+                help='The rest of the variables data.'),
+        ],
+        byte_order=byte_order)
+    variables_record_structure._version_structures = {
+        1: variables1,
+        2: variables2,
+    }
+    variables_record_structure.setup()
+    return variables_record_structure
+
+
+def _get_thread_local_variables_record():
+    variables_record_structure = getattr(
+        _thread_local, 'variables_record_structure', None)
+    if variables_record_structure is None:
+        variables_record_structure = setup_variables_record(byte_order='=')
+        _thread_local.variables_record_structure = variables_record_structure
+    return variables_record_structure
+
+
+def _reset_variables_record_parser(variables_record_structure):
+    variables_record_structure.byte_order = '='
+    variables_record_structure.fields[-1].format = (
+        variables_record_structure._version_structures[1])
+    variables_record_structure.setup()
 
 
 class VariablesRecord (Record):
     def __init__(self, *args, **kwargs):
         super(VariablesRecord, self).__init__(*args, **kwargs)
         # self.header['version'] # record version always 0?
-        VariablesRecordStructure.byte_order = '='
-        VariablesRecordStructure.setup()
+        variables_record_structure = _get_thread_local_variables_record()
+        _reset_variables_record_parser(variables_record_structure)
         stream = _io.BytesIO(bytes(self.data))
-        self.variables = VariablesRecordStructure.unpack_stream(stream)
+        self.variables = variables_record_structure.unpack_stream(stream)
         self.namespace = {}
         for key, value in self.variables['variables'].items():
             if key not in ['var_header']:
diff --git a/igor2/struct.py b/igor2/struct.py
index b7ad6f3..c25b9e5 100644
--- a/igor2/struct.py
+++ b/igor2/struct.py
@@ -826,3 +826,42 @@ def unpack_from(self, buffer, offset=0, *args, **kwargs):
         args = super(Structure, self).unpack_from(
             buffer, offset, *args, **kwargs)
         return self._unpack_item(args)
+
+
+def clone_structure(structure, _memo=None):
+    """Recursively clone a Structure/DynamicStructure tree.
+
+    struct.Struct-derived instances cannot be copied with copy.copy/deepcopy,
+    but we need independent parser instances for thread-safe dynamic unpacking.
+    """
+    if _memo is None:
+        _memo = {}
+    sid = id(structure)
+    if sid in _memo:
+        return _memo[sid]
+
+    clone = structure.__class__(
+        name=structure.name,
+        fields=[],
+        byte_order=structure.byte_order,
+    )
+    _memo[sid] = clone
+
+    fields = []
+    for field in structure.fields:
+        field_format = field.format
+        if isinstance(field_format, Structure):
+            field_format = clone_structure(field_format, _memo=_memo)
+        field_clone = field.__class__(
+            field_format,
+            field.name,
+            default=field.default,
+            help=field.help,
+            count=field.count,
+            array=field.array,
+        )
+        fields.append(field_clone)
+
+    clone.fields = fields
+    clone.setup()
+    return clone
diff --git a/tests/test_pxp.py b/tests/test_pxp.py
index 3f258cc..2e2969c 100644
--- a/tests/test_pxp.py
+++ b/tests/test_pxp.py
@@ -1,7 +1,9 @@
+import hashlib
 import threading
 
 import numpy as np
 
+from igor2.binarywave import load as loadibw
 from igor2.packed import load as loadpxp
 
 from helper import data_dir
@@ -13,6 +15,67 @@ def tostr(data):
     return data
 
 
+def _array_fingerprint(data):
+    # Array signature for equality checks
+    array = np.ascontiguousarray(data)
+    return (
+        str(array.dtype),
+        tuple(int(i) for i in array.shape),
+        hashlib.sha256(array.tobytes()).hexdigest(),
+    )
+
+
+def _ibw_fingerprint(path):
+    data = loadibw(path)
+    wave = data["wave"]["wData"]
+    header = data["wave"]["wave_header"]
+    # Include metadata plus payload signature for equality checks.
+    return (
+        int(data["version"]),
+        tostr(header["bname"]),
+        _array_fingerprint(wave),
+    )
+
+
+def _pxp_fingerprint(path, initial_byte_order):
+    records, filesystem = loadpxp(path, initial_byte_order=initial_byte_order)
+    # Check both tree shape and wave payloads for equality checks.
+    root_keys = tuple(sorted(tostr(key) for key in filesystem["root"].keys()))
+    waves = []
+    for record in records:
+        if hasattr(record, "wave"):
+            wave = record.wave["wave"]["wData"]
+            name = tostr(record.wave["wave"]["wave_header"]["bname"])
+            waves.append((name,) + _array_fingerprint(wave))
+    return (len(records), root_keys, tuple(waves))
+
+
+def _run_concurrent_workload(worker_count, iterations_per_worker, task):
+    barrier = threading.Barrier(worker_count)
+    errors = []
+    lock = threading.Lock()
+
+    def worker(thread_id):
+        try:
+            barrier.wait()
+            for iteration in range(iterations_per_worker):
+                task(thread_id, iteration)
+        except Exception as exc:
+            with lock:
+                errors.append(f"thread {thread_id}: {exc!r}")
+
+    threads = []
+    for thread_id in range(worker_count):
+        thread = threading.Thread(target=worker, args=(thread_id,))
+        threads.append(thread)
+        thread.start()
+
+    for thread in threads:
+        thread.join()
+
+    assert not errors, "\n".join(errors[:10])
+
+
 def test_pxp():
     data = loadpxp(data_dir / 'polar-graphs-demo.pxp')
     records = data[0]
@@ -157,22 +220,50 @@ def test_pxt():
 
 
 def test_thread_safe():
+    jobs = [
+        (data_dir / "polar-graphs-demo.pxp", None),
+        (data_dir / "packed-byteorder.pxt", ">"),
+    ]
+    expected = {job: _pxp_fingerprint(*job) for job in jobs}
 
-    def worker(fileobj, thread_id):
-        expt = None
-        for bo in ('<', '>'):
-            try:
-                _, expt = loadpxp(fileobj, initial_byte_order=bo)
-            except ValueError:
-                pass
-        if expt is None:
-            raise ValueError(f"No experiment loaded for thread {thread_id}")
+    def task(thread_id, iteration):
+        job = jobs[(thread_id + iteration) % len(jobs)]
+        assert _pxp_fingerprint(*job) == expected[job]
+
+    _run_concurrent_workload(
+        worker_count=32,
+        iterations_per_worker=12,
+        task=task,
+    )
 
-    threads = []
-    for i, fname in enumerate([data_dir / 'packed-byteorder.pxt'] * 100):
-        t = threading.Thread(target=worker, args=(fname, i))
-        threads.append(t)
-        t.start()
 
-    for t in threads:
-        t.join()
+def test_thread_safe_mixed():
+    ibw_jobs = [
+        data_dir / "mac-double.ibw",
+        data_dir / "win-double.ibw",
+        data_dir / "mac-version5.ibw",
+        data_dir / "win-version5.ibw",
+    ]
+    pxp_jobs = [
+        (data_dir / "polar-graphs-demo.pxp", None),
+        (data_dir / "packed-byteorder.pxt", ">"),
+    ]
+
+    expected_ibw = {job: _ibw_fingerprint(job) for job in ibw_jobs}
+    expected_pxp = {job: _pxp_fingerprint(*job) for job in pxp_jobs}
+    all_jobs = (
+        [("ibw", job) for job in ibw_jobs] + [("pxp", job) for job in pxp_jobs]
+    )
+
+    def task(thread_id, iteration):
+        kind, payload = all_jobs[(thread_id * 3 + iteration) % len(all_jobs)]
+        if kind == "ibw":
+            assert _ibw_fingerprint(payload) == expected_ibw[payload]
+        else:
+            assert _pxp_fingerprint(*payload) == expected_pxp[payload]
+
+    _run_concurrent_workload(
+        worker_count=32,
+        iterations_per_worker=10,
+        task=task,
+    )