Skip to content

Commit

Permalink
Merge pull request #21 from iomz/feat-imu-emg-agg
Browse files Browse the repository at this point in the history
feat: implement aggregate_all
  • Loading branch information
iomz authored Aug 10, 2023
2 parents 63451e2 + 253e6d3 commit 3944058
Show file tree
Hide file tree
Showing 5 changed files with 131 additions and 27 deletions.
21 changes: 14 additions & 7 deletions examples/sample_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import asyncio
import logging

from myo import MyoClient
from myo import AggregatedData, MyoClient
from myo.types import (
ClassifierEvent,
ClassifierMode,
Expand All @@ -23,15 +23,22 @@
class SampleClient(MyoClient):
async def on_classifier_event(self, ce: ClassifierEvent):
logging.info(ce.json())
pass

async def on_aggregated_data(self, ad: AggregatedData):
logging.info(ad.json())

async def on_emg_data(self, emg: EMGData):
logging.info(emg)
# logging.info(emg)
pass

async def on_fv_data(self, fvd: FVData):
logging.info(fvd.json())
# logging.info(fvd.json())
pass

async def on_imu_data(self, imu: IMUData):
logging.info(imu.json())
# logging.info(imu.json())
pass

async def on_motion_event(self, me: MotionEvent):
logging.info(me.json())
Expand All @@ -40,7 +47,7 @@ async def on_motion_event(self, me: MotionEvent):
async def main(args: argparse.Namespace):
logging.info("scanning for a Myo device...")

sc = await SampleClient.with_device(mac=args.mac)
sc = await SampleClient.with_device(mac=args.mac, aggregate_all=True)

# get the available services on the myo device
info = await sc.get_services()
Expand All @@ -49,8 +56,8 @@ async def main(args: argparse.Namespace):
# setup the MyoClient
await sc.setup(
classifier_mode=ClassifierMode.ENABLED,
emg_mode=EMGMode.SEND_FILT,
imu_mode=IMUMode.SEND_EVENTS,
emg_mode=EMGMode.SEND_FILT, # for aggregate_all
imu_mode=IMUMode.SEND_ALL, # for aggregate_all
)

# start the indicate/notify
Expand Down
7 changes: 6 additions & 1 deletion myo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,12 @@
"""
from __future__ import absolute_import, annotations

from .core import Myo, MyoClient
from .core import (
AggregatedData,
EMGDataSingle,
Myo,
MyoClient,
)
from .profile import Handle
from .types import (
ClassifierEvent,
Expand Down
7 changes: 5 additions & 2 deletions myo/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,16 @@ def __str__(self):
class SetMode(Command):
cmd = 0x01

def __init__(self, emg_mode, imu_mode, classifier_mode):
def __init__(self, classifier_mode, emg_mode, imu_mode):
self.classifier_mode = classifier_mode
self.emg_mode = emg_mode
self.imu_mode = imu_mode
self.classifier_mode = classifier_mode

@property
def payload(self) -> bytearray:
"""
notice that the payload requires the bytearray in this order
"""
return bytearray(
(
self.emg_mode.value,
Expand Down
121 changes: 105 additions & 16 deletions myo/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,37 @@
logger = logging.getLogger(__name__)


# this is a custom data type for fv and imu
class AggregatedData:
def __init__(self, fvd: FVData, imu: IMUData):
self.fvd = fvd
self.imu = imu

def __str__(self):
return f"{repr(self.fvd)},{repr(self.imu)}"

def json(self):
return json.dumps(self.to_dict())

def to_dict(self):
return {"fvd": self.fvd.to_dict(), "imu": self.imu.to_dict()}


# this is just one sample in EMGData
class EMGDataSingle:
def __init__(self, data):
self.data = data

def __str__(self):
return str(self.data)

def json(self):
return json.dumps(self.to_dict())

def to_dict(self):
return {"data": self.data}


class Myo:
__slots__ = "_device"

Expand Down Expand Up @@ -137,12 +168,25 @@ async def led(self, client: BleakClient, *args):

await self.command(client, LED(args[0], args[1]))

async def set_mode(self, client: BleakClient, emg_mode, imu_mode, classifier_mode):
async def set_mode(
self,
client: BleakClient,
classifier_mode: ClassifierMode,
emg_mode: EMGMode,
imu_mode: IMUMode,
):
"""
Set Mode Command
- configures EMG, IMU, and Classifier modes
"""
await self.command(client, SetMode(emg_mode, imu_mode, classifier_mode))
await self.command(
client,
SetMode(
classifier_mode=classifier_mode,
emg_mode=emg_mode,
imu_mode=imu_mode,
),
)

async def set_sleep_mode(self, client: BleakClient, sleep_mode):
"""
Expand Down Expand Up @@ -182,17 +226,20 @@ async def write(self, client: BleakClient, handle, value):


class MyoClient:
def __init__(self):
def __init__(self, aggregate_all=False, aggregate_emg=False):
self.m = None
self.aggregate_emg = False
self.aggregate_all = aggregate_all
self.aggregate_emg = aggregate_emg
self.classifier_mode = None
self.emg_mode = None
self.imu_mode = None
self._client = None
self.fv_aggregated = None # for aggregate_all
self.imu_aggregated = None # for aggregate_all

@classmethod
async def with_device(cls, mac=None):
self = cls()
async def with_device(cls, mac=None, aggregate_all=False, aggregate_emg=False):
self = cls(aggregate_all=aggregate_all, aggregate_emg=aggregate_emg)
while self.m is None:
if mac and mac != "":
self.m = await Myo.with_mac(mac)
Expand Down Expand Up @@ -281,10 +328,32 @@ async def led(self, color):
async def on_classifier_event(self, ce: ClassifierEvent):
raise NotImplementedError()

async def on_data(self, data):
"""
for on_aggregated_data
data is either FVData or IMUData
"""
if isinstance(data, FVData):
self.fv_aggregated = data
elif isinstance(data, IMUData):
self.imu_aggregated = data
# trigger on_aggregated_data when both FVData and IMUData are ready
if all((self.fv_aggregated, self.imu_aggregated)):
await self.on_aggregated_data(AggregatedData(self.fv_aggregated, self.imu_aggregated))
self.fv_aggregated = None
self.imu_aggregated = None

async def on_aggregated_data(self, ad: AggregatedData):
"""
on_aggregated_data is invoked when both FVData and IMUData are ready.
it doesn't support EMGData since it is collected at different interval (200HZ instead of 50Hz)
"""
raise NotImplementedError()

async def on_emg_data(self, emg: EMGData): # data: list of 8 8-bit unsigned short
raise NotImplementedError()

async def on_emg_data_aggregated(self, data):
async def on_emg_data_aggregated(self, eds: EMGDataSingle):
"""
<> aggregate the raw EMG data channels
"""
Expand All @@ -303,14 +372,21 @@ async def notify_callback(self, sender: BleakGATTCharacteristic, data: bytearray
"""
<> invoke the on_* callbacks
"""

handle = Handle(sender.handle)
logger.debug(f"notify_callback ({handle}): {data}")
if handle == Handle.CLASSIFIER_EVENT:
await self.on_classifier_event(ClassifierEvent(data))
elif handle == Handle.FV_DATA:
await self.on_fv_data(FVData(data))
if self.aggregate_all:
await self.on_data(FVData(data))
else:
await self.on_fv_data(FVData(data))
elif handle == Handle.IMU_DATA:
await self.on_imu_data(IMUData(data))
if self.aggregate_all:
await self.on_data(IMUData(data))
else:
await self.on_imu_data(IMUData(data))
elif handle == Handle.MOTION_EVENT:
await self.on_motion_event(MotionEvent(data))
elif handle in [
Expand All @@ -321,17 +397,22 @@ async def notify_callback(self, sender: BleakGATTCharacteristic, data: bytearray
]:
emg = EMGData(data)
if self.aggregate_emg:
await self.on_emg_data_aggregated(emg.sample1)
await self.on_emg_data_aggregated(emg.sample2)
await self.on_emg_data_aggregated(EMGDataSingle(emg.sample1))
await self.on_emg_data_aggregated(EMGDataSingle(emg.sample2))
else:
await self.on_emg_data(emg)

async def set_mode(self, emg_mode, imu_mode, classifier_mode):
async def set_mode(self, classifier_mode: ClassifierMode, emg_mode: EMGMode, imu_mode: IMUMode):
"""
Set Mode Command
- configures EMG, IMU, and Classifier modes
"""
await self.m.set_mode(self._client, emg_mode, imu_mode, classifier_mode)
await self.m.set_mode(
client=self._client,
classifier_mode=classifier_mode,
emg_mode=emg_mode,
imu_mode=imu_mode,
)

async def set_sleep_mode(self, sleep_mode):
"""
Expand Down Expand Up @@ -362,10 +443,18 @@ async def setup(
self.emg_mode = emg_mode
self.imu_mode = imu_mode
self.classifier_mode = classifier_mode

# enforce the modes when aggregate_all
if self.aggregate_all and (
self.emg_mode != EMGMode.SEND_FILT or self.imu_mode in (IMUMode.NONE, IMUMode.SEND_EVENTS, IMUMode.SEND_RAW)
):
self.emg_mode = EMGMode.SEND_FILT
self.imu_mode = IMUMode.SEND_ALL

await self.set_mode(
emg_mode,
imu_mode,
classifier_mode,
classifier_mode=self.classifier_mode,
emg_mode=self.emg_mode,
imu_mode=self.imu_mode,
)
await self.led(RGB_PINK)

Expand Down
2 changes: 1 addition & 1 deletion myo/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ class Pose(Enum):
REST = 0x0000
FIST = 0x0001
WAVE_IN = 0x0002
WAVE_OUT = 0x003
WAVE_OUT = 0x0003
FINGERS_SPREAD = 0x0004
DOUBLE_TAP = 0x0005
UNKNOWN = 0xFFFF
Expand Down

0 comments on commit 3944058

Please sign in to comment.