Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add FitFileEncoder for writing FIT files #58

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 8 additions & 26 deletions fitparse/base.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,6 @@
import io
import os
import struct

# Python 2 compat
try:
num_types = (int, float, long)
str = basestring
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

str is still used as a typecheck in get_messages in the following code:

names = set([
    int(n) if (isinstance(n, str) and n.isdigit()) else n
    for n in names
])

However, looking at it now, this code can probably be refactored to:

def try_int(obj):
    try:
        return int(obj)
    except ValueError:
        return obj

names = set(try_int(n) for n in names)

This would remove the typecheck so the shim wouldn't be needed.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMO this code should be removed - the caller should make sure a message number is int, not str if wants to use the number.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@pR0Ps Should I remove this from the FitFile?

except NameError:
num_types = (int, float)

from fitparse.processors import FitFileDataProcessor
from fitparse.profile import FIELD_TYPE_TIMESTAMP, MESSAGE_TYPES
from fitparse.records import (
@@ -94,6 +86,7 @@ def _parse_file_header(self):

# Initialize data
self._accumulators = {}
self.data_size = 0
self._bytes_left = -1
self._complete = False
self._compressed_ts_accumulator = 0
@@ -106,7 +99,7 @@ def _parse_file_header(self):
raise FitHeaderError("Invalid .FIT File Header")

# Larger fields are explicitly little endian from SDK
header_size, protocol_ver_enc, profile_ver_enc, data_size = self._read_struct('2BHI4x', data=header_data)
header_size, protocol_ver_enc, profile_ver_enc, self.data_size = self._read_struct('2BHI4x', data=header_data)

# Decode the same way the SDK does
self.protocol_version = float("%d.%d" % (protocol_ver_enc >> 4, protocol_ver_enc & ((1 << 4) - 1)))
@@ -127,7 +120,7 @@ def _parse_file_header(self):
self._read(extra_header_size - 2)

# After we've consumed the header, set the bytes left to be read
self._bytes_left = data_size
self._bytes_left = self.data_size

def _parse_message(self):
# When done, calculate the CRC and return None
@@ -239,7 +232,7 @@ def _parse_definition_message(self, header):
def _parse_raw_values_from_data_message(self, def_mesg):
# Go through mesg's field defs and read them
raw_values = []
for field_def in def_mesg.field_defs + def_mesg.dev_field_defs:
for field_def in def_mesg.all_field_defs():
base_type = field_def.base_type
is_byte = base_type.name == 'byte'
# Struct to read n base types (field def size / base type size)
@@ -277,18 +270,6 @@ def _resolve_subfield(field, def_mesg, raw_values):
return sub_field, field
return field, None

def _apply_scale_offset(self, field, raw_value):
# Apply numeric transformations (scale+offset)
if isinstance(raw_value, tuple):
# Contains multiple values, apply transformations to all of them
return tuple(self._apply_scale_offset(field, x) for x in raw_value)
elif isinstance(raw_value, num_types):
if field.scale:
raw_value = float(raw_value) / field.scale
if field.offset:
raw_value = raw_value - field.offset
return raw_value

@staticmethod
def _apply_compressed_accumulation(raw_value, accumulation, num_bits):
max_value = (1 << num_bits)
@@ -311,7 +292,7 @@ def _parse_data_message(self, header):

# TODO: Maybe refactor this and make it simpler (or at least broken
# up into sub-functions)
for field_def, raw_value in zip(def_mesg.field_defs + def_mesg.dev_field_defs, raw_values):
for field_def, raw_value in zip(def_mesg.all_field_defs(), raw_values):
field, parent_field = field_def.field, None
if field:
field, parent_field = self._resolve_subfield(field, def_mesg, raw_values)
@@ -332,7 +313,7 @@ def _parse_data_message(self, header):

# Apply scale and offset from component, not from the dynamic field
# as they may differ
cmp_raw_value = self._apply_scale_offset(component, cmp_raw_value)
cmp_raw_value = component.apply_scale_offset(cmp_raw_value)

# Extract the component's dynamic field from def_mesg
cmp_field = def_mesg.mesg_type.fields[component.def_num]
@@ -354,7 +335,8 @@ def _parse_data_message(self, header):

# TODO: Do we care about a base_type and a resolved field mismatch?
# My hunch is we don't
value = self._apply_scale_offset(field, field.render(raw_value))
value = field.render(raw_value)
value = field.apply_scale_offset(value)
else:
value = raw_value

436 changes: 436 additions & 0 deletions fitparse/encoder.py

Large diffs are not rendered by default.

83 changes: 73 additions & 10 deletions fitparse/processors.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,36 @@
import datetime
from fitparse.utils import scrub_method_name

# Datetimes (uint32) represent seconds since this UTC_REFERENCE
UTC_REFERENCE = 631065600 # timestamp for UTC 00:00 Dec 31 1989
from fitparse.utils import scrub_method_name, fit_from_datetime, fit_to_datetime, fit_semicircles_to_deg


class FitFileDataProcessor(object):
class DataProcessorBase(object):
"""Empty, no-op fit file data processor."""
def run_type_processor(self, field_data):
pass

def unparse_type(self, field_data):
pass

def run_field_processor(self, field_data):
pass

def unparse_field(self, field_data):
pass

def run_unit_processor(self, field_data):
pass

def unparse_unit(self, field_data):
pass

def run_message_processor(self, data_message):
pass

def unparse_message(self, data_message):
pass


class FitFileDataProcessor(DataProcessorBase):
# TODO: Document API
# Functions that will be called to do the processing:
#def run_type_processor(field_data)
@@ -44,19 +69,36 @@ def run_type_processor(self, field_data):
self._run_processor(self._scrub_method_name(
'process_type_%s' % field_data.type.name), field_data)

def unparse_type(self, field_data):
self._run_processor(self._scrub_method_name(
'unparse_type_%s' % field_data.type.name), field_data)

def run_field_processor(self, field_data):
self._run_processor(self._scrub_method_name(
'process_field_%s' % field_data.name), field_data)

def unparse_field(self, field_data):
self._run_processor(self._scrub_method_name(
'unparse_field_%s' % field_data.name), field_data)

def run_unit_processor(self, field_data):
if field_data.units:
self._run_processor(self._scrub_method_name(
'process_units_%s' % field_data.units), field_data)

def unparse_unit(self, field_data):
if field_data.units:
self._run_processor(self._scrub_method_name(
'unparse_units_%s' % field_data.units), field_data)

def run_message_processor(self, data_message):
self._run_processor(self._scrub_method_name(
'process_message_%s' % data_message.def_mesg.name), data_message)

def unparse_message(self, data_message):
self._run_processor(self._scrub_method_name(
'unparse_message_%s' % data_message.def_mesg.name), data_message)

def _run_processor(self, processor_name, data):
try:
getattr(self, processor_name)(data)
@@ -67,27 +109,48 @@ def process_type_bool(self, field_data):
if field_data.value is not None:
field_data.value = bool(field_data.value)

def unparse_type_bool(self, field_data):
if field_data.value is not None:
field_data.raw_value = int(field_data.value)

def process_type_date_time(self, field_data):
value = field_data.value
if value is not None and value >= 0x10000000:
field_data.value = datetime.datetime.utcfromtimestamp(UTC_REFERENCE + value)
field_data.value = fit_to_datetime(value)
field_data.units = None # Units were 's', set to None

def unparse_type_date_time(self, field_data):
value = field_data.value
if value is not None and isinstance(value, datetime.datetime):
field_data.raw_value = fit_from_datetime(value)
field_data.units = 's'

def process_type_local_date_time(self, field_data):
if field_data.value is not None:
value = field_data.value
if value is not None:
# NOTE: This value was created on the device using it's local timezone.
# Unless we know that timezone, this value won't be correct. However, if we
# assume UTC, at least it'll be consistent.
field_data.value = datetime.datetime.utcfromtimestamp(UTC_REFERENCE + field_data.value)
field_data.value = fit_to_datetime(value)
field_data.units = None

def unparse_type_local_date_time(self, field_data):
self.unparse_type_date_time(field_data)

def process_type_localtime_into_day(self, field_data):
if field_data.value is not None:
m, s = divmod(field_data.value, 60)
value = field_data.value
if value is not None:
m, s = divmod(value, 60)
h, m = divmod(m, 60)
field_data.value = datetime.time(h, m, s)
field_data.units = None

def unparse_type_localtime_into_day(self, field_data):
value = field_data.value
if value is not None and isinstance(value, datetime.time):
field_data.raw_value = value.hour * 3600 + value.minute * 60 + value.second
field_data.units = 's'


class StandardUnitsDataProcessor(FitFileDataProcessor):
def run_field_processor(self, field_data):
@@ -112,5 +175,5 @@ def process_field_speed(self, field_data):

def process_units_semicircles(self, field_data):
if field_data.value is not None:
field_data.value *= 180.0 / (2 ** 31)
field_data.value = fit_semicircles_to_deg(field_data.value)
field_data.units = 'deg'
178 changes: 152 additions & 26 deletions fitparse/records.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,20 @@
import itertools
import math
import struct

# Python 2 compat
try:
# Python 2
int_types = (int, long,)
num_types = (int, float, long)
int_type = long
math_nan = float('nan')
byte_iter = bytearray
except NameError:
# Python 3
int_types = (int,)
num_types = (int, float)
int_type = int
math_nan = math.nan
byte_iter = lambda x: x

try:
@@ -55,14 +63,25 @@ def name(self):
return self.mesg_type.name if self.mesg_type else 'unknown_%d' % self.mesg_num

def __repr__(self):
return '<DefinitionMessage: %s (#%d) -- local mesg: #%d, field defs: [%s], dev field defs: [%s]>' % (
return '<DefinitionMessage: %s (#%s) -- local mesg: #%s, field defs: [%s], dev field defs: [%s]>' % (
self.name,
self.mesg_num,
self.header.local_mesg_num,
', '.join([fd.name for fd in self.field_defs]),
', '.join([fd.name for fd in self.dev_field_defs]),
)

def all_field_defs(self):
if not self.dev_field_defs:
return self.field_defs
return itertools.chain(self.field_defs, self.dev_field_defs)

def get_field_def(self, name):
for field_def in self.all_field_defs():
if field_def.is_named(name):
return field_def
return None


class FieldDefinition(RecordBase):
__slots__ = ('field', 'def_num', 'base_type', 'size')
@@ -76,13 +95,17 @@ def type(self):
return self.field.type if self.field else self.base_type

def __repr__(self):
return '<FieldDefinition: %s (#%d) -- type: %s (%s), size: %d byte%s>' % (
return '<FieldDefinition: %s (#%s) -- type: %s (%s), size: %s byte%s>' % (
self.name,
self.def_num,
self.type.name, self.base_type.name,
self.size, 's' if self.size != 1 else '',
)

def is_named(self, name):
return self.field.is_named(name)



class DevFieldDefinition(RecordBase):
__slots__ = ('field', 'dev_data_index', 'base_type', 'def_num', 'size')
@@ -156,7 +179,7 @@ def __iter__(self):
return iter(sorted(self.fields, key=lambda fd: (int(fd.field is None), fd.name)))

def __repr__(self):
return '<DataMessage: %s (#%d) -- local mesg: #%d, fields: [%s]>' % (
return '<DataMessage: %s (#%s) -- local mesg: #%s, fields: [%s]>' % (
self.name, self.mesg_num, self.header.local_mesg_num,
', '.join(["%s: %s" % (fd.name, fd.value) for fd in self.fields]),
)
@@ -237,13 +260,25 @@ def __str__(self):
)


class BaseType(RecordBase):
__slots__ = ('name', 'identifier', 'fmt', 'parse')
class BaseType(object):
__slots__ = ('name', 'identifier', 'fmt', 'invalid_value', 'parse', 'unparse', 'in_range', '_size')
values = None # In case we're treated as a FieldType

def __init__(self, name, identifier, fmt, invalid_value=None, parse=None, unparse=None, in_range=None):
self.name = name
self.identifier = identifier
self.fmt = fmt
self.invalid_value = invalid_value
self.parse = parse or self._parse
self.unparse = unparse or self._unparse
self.in_range = in_range or self._in_range
self._size = None

@property
def size(self):
return struct.calcsize(self.fmt)
if self._size is None:
self._size = struct.calcsize(self.fmt)
return self._size

@property
def type_num(self):
@@ -254,6 +289,17 @@ def __repr__(self):
self.name, self.type_num, self.identifier,
)

def _parse(self, x):
return None if x == self.invalid_value else x

def _unparse(self, x):
return self.invalid_value if x is None else x

def _in_range(self, x):
# basic implementation for int types
return self.invalid_value if x.bit_length() > self.size * 8 else x



class FieldType(RecordBase):
__slots__ = ('name', 'base_type', 'values')
@@ -268,8 +314,51 @@ class MessageType(RecordBase):
def __repr__(self):
return '<MessageType: %s (#%d)>' % (self.name, self.mesg_num)

def get_field_and_subfield(self, name):
"""
Get field by name.
:rtype tuple(Field, SubField) or tuple(Field, None) or (None, None)
"""
for field in self.fields.values():
if field.is_named(name):
return (field, None)
if field.subfields:
subfield = next((f for f in field.subfields if f.is_named(name)), None)
if subfield:
return (field, subfield)

return (None, None)


class ScaleOffsetMixin(object):
"""Common methods for classes with scale and offset."""

def apply_scale_offset(self, raw_value):
if isinstance(raw_value, tuple):
# Contains multiple values, apply transformations to all of them
return tuple(self.apply_scale_offset(x) for x in raw_value)
elif isinstance(raw_value, num_types):
if self.scale:
raw_value = float(raw_value) / self.scale
if self.offset:
raw_value = raw_value - self.offset
return raw_value

class FieldAndSubFieldBase(RecordBase):
def unapply_scale_offset(self, value):
if isinstance(value, tuple):
# Contains multiple values, apply transformations to all of them
return tuple(self.unapply_scale_offset(x) for x in value)
elif isinstance(value, num_types):
if self.offset:
value = value + self.offset
if self.scale:
value = float(value) * self.scale
if isinstance(value, float):
value = int_type(round(value))
return value


class FieldAndSubFieldBase(RecordBase, ScaleOffsetMixin):
__slots__ = ()

@property
@@ -280,9 +369,26 @@ def base_type(self):
def is_base_type(self):
return isinstance(self.type, BaseType)

def __repr__(self):
return '<%s: %s (#%s) -- type: %s (%s)>' % (
self.__class__.__name__,
self.name,
self.def_num,
self.type.name,
self.base_type
)

def is_named(self, name):
return self.name == name or self.def_num == name

def render(self, raw_value):
if self.type.values and (raw_value in self.type.values):
return self.type.values[raw_value]
if self.type.values:
return self.type.values.get(raw_value, raw_value)
return raw_value

def unrender(self, raw_value):
if self.type.values:
return next((k for k, v in self.type.values.items() if v == raw_value), raw_value)
return raw_value


@@ -307,7 +413,7 @@ class ReferenceField(RecordBase):
__slots__ = ('name', 'def_num', 'value', 'raw_value')


class ComponentField(RecordBase):
class ComponentField(RecordBase, ScaleOffsetMixin):
__slots__ = ('name', 'def_num', 'scale', 'offset', 'units', 'accumulate', 'bits', 'bit_offset')
field_type = 'component'

@@ -382,28 +488,48 @@ def calculate(cls, byte_arr, crc=0):
def parse_string(string):
try:
end = string.index(0x00)
except TypeError: # Python 2 compat
except TypeError: # Python 2 compat
end = string.index('\x00')

return string[:end].decode('utf-8', errors='replace') or None


def unparse_string(string):
if string is None:
string = ''
sbytes = string.encode('utf-8', errors='replace') + b'\0'
return sbytes


_FLOAT32_INVALID_VALUE = struct.unpack('f', bytes(b'\xff' * 4))[0]
_FLOAT32_MIN = -3.4028235e+38
_FLOAT32_MAX = 3.4028235e+38
_FLOAT64_INVALID_VALUE = struct.unpack('d', bytes(b'\xff' * 8))[0]

# The default base type
BASE_TYPE_BYTE = BaseType(name='byte', identifier=0x0D, fmt='B', parse=lambda x: None if all(b == 0xFF for b in x) else x)
BASE_TYPE_BYTE = BaseType(name='byte', identifier=0x0D, fmt='B',
parse=lambda x: None if all(b == 0xFF for b in x) else x,
unparse=lambda x: b'\xFF' if x is None else x,
in_range=lambda x: x)

BASE_TYPES = {
0x00: BaseType(name='enum', identifier=0x00, fmt='B', parse=lambda x: None if x == 0xFF else x),
0x01: BaseType(name='sint8', identifier=0x01, fmt='b', parse=lambda x: None if x == 0x7F else x),
0x02: BaseType(name='uint8', identifier=0x02, fmt='B', parse=lambda x: None if x == 0xFF else x),
0x83: BaseType(name='sint16', identifier=0x83, fmt='h', parse=lambda x: None if x == 0x7FFF else x),
0x84: BaseType(name='uint16', identifier=0x84, fmt='H', parse=lambda x: None if x == 0xFFFF else x),
0x85: BaseType(name='sint32', identifier=0x85, fmt='i', parse=lambda x: None if x == 0x7FFFFFFF else x),
0x86: BaseType(name='uint32', identifier=0x86, fmt='I', parse=lambda x: None if x == 0xFFFFFFFF else x),
0x07: BaseType(name='string', identifier=0x07, fmt='s', parse=parse_string),
0x88: BaseType(name='float32', identifier=0x88, fmt='f', parse=lambda x: None if math.isnan(x) else x),
0x89: BaseType(name='float64', identifier=0x89, fmt='d', parse=lambda x: None if math.isnan(x) else x),
0x0A: BaseType(name='uint8z', identifier=0x0A, fmt='B', parse=lambda x: None if x == 0x0 else x),
0x8B: BaseType(name='uint16z', identifier=0x8B, fmt='H', parse=lambda x: None if x == 0x0 else x),
0x8C: BaseType(name='uint32z', identifier=0x8C, fmt='I', parse=lambda x: None if x == 0x0 else x),
0x00: BaseType(name='enum', identifier=0x00, fmt='B', invalid_value=0xFF),
0x01: BaseType(name='sint8', identifier=0x01, fmt='b', invalid_value=0x7F),
0x02: BaseType(name='uint8', identifier=0x02, fmt='B', invalid_value=0xFF),
0x83: BaseType(name='sint16', identifier=0x83, fmt='h', invalid_value=0x7FFF),
0x84: BaseType(name='uint16', identifier=0x84, fmt='H', invalid_value=0xFFFF),
0x85: BaseType(name='sint32', identifier=0x85, fmt='i', invalid_value=0x7FFFFFFF),
0x86: BaseType(name='uint32', identifier=0x86, fmt='I', invalid_value=0xFFFFFFFF),
0x07: BaseType(name='string', identifier=0x07, fmt='s', parse=parse_string, unparse=unparse_string, in_range=lambda x: x),
0x88: BaseType(name='float32', identifier=0x88, fmt='f', invalid_value=_FLOAT32_INVALID_VALUE,
parse=lambda x: None if math.isnan(x) else x,
in_range=lambda x: x if _FLOAT32_MIN < x < _FLOAT32_MAX else _FLOAT32_INVALID_VALUE),
0x89: BaseType(name='float64', identifier=0x89, fmt='d', invalid_value=_FLOAT64_INVALID_VALUE,
parse=lambda x: None if math.isnan(x) else x,
in_range=lambda x: x),
0x0A: BaseType(name='uint8z', identifier=0x0A, fmt='B', invalid_value=0x0),
0x8B: BaseType(name='uint16z', identifier=0x8B, fmt='H', invalid_value=0x0),
0x8C: BaseType(name='uint32z', identifier=0x8C, fmt='I', invalid_value=0x0),
0x0D: BASE_TYPE_BYTE,
}

34 changes: 32 additions & 2 deletions fitparse/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import re

import datetime
import io
import re
from collections import Iterable


class FitParseError(ValueError):
@@ -16,13 +17,35 @@ class FitHeaderError(FitParseError):
pass


UTC_REFERENCE = datetime.datetime(1989, 12, 31) # timestamp for UTC 00:00 Dec 31 1989
METHOD_NAME_SCRUBBER = re.compile(r'\W|^(?=\d)')
UNIT_NAME_TO_FUNC_REPLACEMENTS = (
('/', ' per '),
('%', 'percent'),
('*', ' times '),
)


def fit_to_datetime(sec):
"""Convert FIT seconds to datetime."""
return UTC_REFERENCE + datetime.timedelta(seconds=sec)


def fit_from_datetime(dt):
"""Convert datetime to FIT seconds."""
return int((dt - UTC_REFERENCE).total_seconds())


def fit_semicircles_to_deg(sc):
"""Convert FIT semicircles to deg (for the GPS lat, long)."""
return sc * 180.0 / (2 ** 31)


def fit_deg_to_semicircles(deg):
"""Convert deg to FIT semicircles (for the GPS lat, long)."""
return int(deg / 180.0 * (2 ** 31))


def scrub_method_name(method_name, convert_units=False):
if convert_units:
for replace_from, replace_to in UNIT_NAME_TO_FUNC_REPLACEMENTS:
@@ -56,3 +79,10 @@ def fileish_open(fileish, mode):
else:
# Python 3 - file contents
return io.BytesIO(fileish)


def is_iterable(obj):
"""Check, if the obj is iterable but not string or bytes.
:rtype bool"""
# Speed: do not use iter() although it's more robust, see also https://stackoverflow.com/questions/1952464/
return isinstance(obj, Iterable) and not isinstance(obj, (str, bytes))
8 changes: 2 additions & 6 deletions tests/test.py
Original file line number Diff line number Diff line change
@@ -7,7 +7,7 @@
import sys

from fitparse import FitFile
from fitparse.processors import UTC_REFERENCE, StandardUnitsDataProcessor
from fitparse.processors import fit_to_datetime, StandardUnitsDataProcessor
from fitparse.records import BASE_TYPES, Crc
from fitparse.utils import FitEOFError, FitCRCError, FitHeaderError

@@ -68,10 +68,6 @@ def generate_fitfile(data=None, endian='<'):
return file_data + pack('<' + Crc.FMT, Crc.calculate(file_data))


def secs_to_dt(secs):
return datetime.datetime.utcfromtimestamp(secs + UTC_REFERENCE)


def testfile(filename):
return os.path.join(os.path.dirname(os.path.realpath(__file__)), 'files', filename)

@@ -99,7 +95,7 @@ def test_basic_file_with_one_record(self, endian='<'):
for field in ('serial_number', 3):
self.assertEqual(file_id.get_value(field), 558069241)
for field in ('time_created', 4):
self.assertEqual(file_id.get_value(field), secs_to_dt(723842606))
self.assertEqual(file_id.get_value(field), fit_to_datetime(723842606))
self.assertEqual(file_id.get(field).raw_value, 723842606)
for field in ('number', 5):
self.assertEqual(file_id.get_value(field), None)
168 changes: 168 additions & 0 deletions tests/test_encoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
#!/usr/bin/env python
import copy
import datetime
import io
import os
import sys

from fitparse import FitFile
from fitparse.encoder import FitFileEncoder, DataMessageCreator

if sys.version_info >= (2, 7):
import unittest
else:
import unittest2 as unittest


def testfile(filename):
return os.path.join(os.path.dirname(os.path.realpath(__file__)), 'files', filename)


class FitFileEncoderTestCase(unittest.TestCase):

def test_header(self):
file = io.BytesIO()
with FitFileEncoder(file) as fwrite:
fwrite.finish()
buff = file.getvalue()
pass
self.assertTrue(fwrite.completed)
self.assertEqual(16, len(buff))

with FitFile(buff) as fread:
self.assertEqual(0, len(fread.messages))
self.assertEqual(fwrite.protocol_version, fread.protocol_version)
self.assertEqual(fwrite.profile_version, fread.profile_version)
self.assertEqual(fwrite.data_size, fread.data_size)

def test_basic_activity_create(self):
file = io.BytesIO()
# copy of written messages
messages = []
time_created = datetime.datetime(2017, 12, 13, 14, 15, 16)
with FitFileEncoder(file) as fwrite:
def write(mesg):
fwrite.write(mesg)
messages.append(copy.deepcopy(mesg.mesg))

mesg = DataMessageCreator('file_id')
mesg.set_values((
('serial_number', 123456),
('manufacturer', 'dynastream'),
('garmin_product', 'hrm1'), # test subfield
('type', 'activity'),
('time_created', time_created)
))
write(mesg)

mesg = DataMessageCreator('device_info')
mesg.set_values((
('manufacturer', 284),
('product', 1),
('product_name', 'unit test') # test string
))
write(mesg)

rec_mesg = DataMessageCreator('record', local_mesg_num=1)
rec_mesg.set_values((
('timestamp', time_created),
('altitude', 100),
('distance', 0)
))
write(rec_mesg)

rec_mesg2 = DataMessageCreator('record', local_mesg_num=2)
rec_mesg2.set_values((
('altitude', 102),
('distance', 2)
))
rec_mesg2.set_header_timestamp(time_created + datetime.timedelta(seconds=2))
write(rec_mesg2)

rec_mesg2.set_values((
('altitude', 40000), # out of sint16 range
('distance', 4)
))
rec_mesg2.set_header_timestamp(time_created + datetime.timedelta(seconds=4))
write(rec_mesg2)
messages[-1].get('altitude').value = None # to conform the assert

mesg = DataMessageCreator('session')
mesg.set_values((
('start_time', time_created),
('timestamp', time_created),
('total_distance', 20.5),
('total_ascent', 1234),
('total_descent', 654),
('total_elapsed_time', 3661.5),
('avg_altitude', 821),
('sport', 'cycling'),
('event', 'session'),
('event_type', 'start')
))
write(mesg)

fwrite.finish()
buff = file.getvalue()

with FitFile(buff) as fread:
rmessages = fread.messages

self._assert_messages(messages, rmessages)

def test_basic_activity_read_write(self):
# note: 'Activity.fit' has some useless definition messages
with FitFile(testfile('Activity.fit')) as fread:
messages = fread.messages

file = io.BytesIO()
with FitFileEncoder(file) as fwrite:
for m in messages:
# current encoder can do just basic fields
m.fields = [f for f in m.fields if f.field_def or FitFileEncoder._is_ts_field(f)]
# need to unset raw_value
for field_data in m.fields:
field_data.raw_value = None
fwrite.write(m)
fwrite.finish()
buff = file.getvalue()

with FitFile(buff) as fread:
messages_buff = fread.messages

self._assert_messages(messages, messages_buff)

def _assert_messages(self, expected, actual):
self.assertEqual(len(expected), len(actual), msg='#messages')
for emsg, amsg in zip(expected, actual):
self.assertEqual(emsg.name, amsg.name)
self._assert_message_headers(emsg.header, amsg.header)
self.assertEqual(self._get_header_ts(emsg.fields), self._get_header_ts(amsg.fields), msg='message: {} header timestamp'.format(emsg.name))
efields = self._filter_fields_for_test(emsg.fields)
afields = self._filter_fields_for_test(amsg.fields)
self.assertEqual(len(efields), len(afields), msg='message: {} #fields'.format(emsg.name))
for efield, afield in zip(efields, afields):
self.assertEqual(efield.name, afield.name, msg='message: {} field names'.format(emsg.name))
self.assertEqual(efield.value, afield.value,
msg='message: {}, field: {} values'.format(emsg.name, efield.name))

def _assert_message_headers(self, expected, actual):
self.assertEqual(expected.is_definition, actual.is_definition)
self.assertEqual(expected.is_developer_data, actual.is_developer_data)
self.assertEqual(expected.local_mesg_num, actual.local_mesg_num)
self.assertEqual(expected.time_offset, actual.time_offset)

@staticmethod
def _filter_fields_for_test(fields):
"""Take only base field for the test."""
return [f for f in fields if f.field_def]

@staticmethod
def _get_header_ts(fields):
"""Get timestamp related to the compressed header."""
field_data = next((f for f in fields if f.field_def is None and FitFileEncoder._is_ts_field(f)), None)
return field_data.value if field_data else None


if __name__ == '__main__':
unittest.main()
42 changes: 42 additions & 0 deletions tests/test_processors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
#!/usr/bin/env python
import datetime
import sys

from fitparse import FitFileDataProcessor
from fitparse.profile import FIELD_TYPE_TIMESTAMP
from fitparse.records import FieldData

if sys.version_info >= (2, 7):
import unittest
else:
import unittest2 as unittest


class ProcessorsTestCase(unittest.TestCase):

def test_fitfiledataprocessor(self):
raw_value = 3600 + 60 + 1
fd = FieldData(
field_def=None,
field=FIELD_TYPE_TIMESTAMP,
parent_field=None,
value=raw_value,
raw_value=raw_value,
)
pr = FitFileDataProcessor()
# local_date_time
pr.process_type_local_date_time(fd)
self.assertEqual(datetime.datetime(1989, 12, 31, 1, 1, 1), fd.value)
pr.unparse_type_local_date_time(fd)
self.assertEqual(raw_value, fd.raw_value)
# localtime_into_day
fd.value = raw_value
fd.raw_value = None
pr.process_type_localtime_into_day(fd)
self.assertEqual(datetime.time(1, 1, 1), fd.value)
pr.unparse_type_localtime_into_day(fd)
self.assertEqual(raw_value, fd.raw_value)


if __name__ == '__main__':
unittest.main()
8 changes: 8 additions & 0 deletions tests/test_records.py
Original file line number Diff line number Diff line change
@@ -2,6 +2,7 @@

import sys

from fitparse import records
from fitparse.records import Crc

if sys.version_info >= (2, 7):
@@ -11,6 +12,13 @@


class RecordsTestCase(unittest.TestCase):

def test_string_parse(self):
sb = b'Test string\0'
s = records.parse_string(sb)
self.assertEqual('Test string', s)
self.assertEqual(sb, records.unparse_string(s))

def test_crc(self):
crc = Crc()
self.assertEqual(0, crc.value)
31 changes: 29 additions & 2 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
#!/usr/bin/env python

import datetime
import io
import os
import sys
import tempfile

from fitparse.utils import fileish_open
from fitparse import utils
from fitparse.utils import fileish_open, is_iterable

if sys.version_info >= (2, 7):
import unittest
@@ -19,6 +20,22 @@ def testfile(filename):

class UtilsTestCase(unittest.TestCase):

def test_fit_to_datetime(self):
sec = 3600 + 60 + 1
dt = datetime.datetime(1989, 12, 31, 1, 1, 1)
self.assertEqual(dt, utils.fit_from_datetime(sec))
self.assertEqual(sec, utils.fit_to_datetime(dt))

def test_fit_semicircles_to_deg(self):
sc = 495280430
deg = 41.513926070183516
self.assertEqual(deg, utils.fit_semicircles_to_deg(sc))
self.assertEqual(sc, utils.fit_deg_to_semicircles(deg))
# test rounding errors
for i in range(100):
sc += 1
self.assertEqual(sc, utils.fit_deg_to_semicircles(utils.fit_semicircles_to_deg(sc)))

def test_fileish_open_read(self):
"""Test the constructor does the right thing when given different types
(specifically, test files with 8 characters, followed by an uppercase.FIT
@@ -61,6 +78,16 @@ def test_fopen(fileish):
except OSError:
pass

def test_is_iterable(self):
self.assertFalse(is_iterable(None))
self.assertFalse(is_iterable(1))
self.assertFalse(is_iterable('1'))
self.assertFalse(is_iterable(b'1'))

self.assertTrue(is_iterable((1, 2)))
self.assertTrue(is_iterable([1, 2]))
self.assertTrue(is_iterable(range(2)))


if __name__ == '__main__':
unittest.main()