Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: implement aggregate_all #21

Merged
merged 2 commits into from
Aug 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 14 additions & 7 deletions examples/sample_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import asyncio
import logging

from myo import MyoClient
from myo import AggregatedData, MyoClient
from myo.types import (
ClassifierEvent,
ClassifierMode,
Expand All @@ -23,15 +23,22 @@
class SampleClient(MyoClient):
async def on_classifier_event(self, ce: ClassifierEvent):
logging.info(ce.json())
pass

async def on_aggregated_data(self, ad: AggregatedData):
logging.info(ad.json())

async def on_emg_data(self, emg: EMGData):
logging.info(emg)
# logging.info(emg)
pass

async def on_fv_data(self, fvd: FVData):
logging.info(fvd.json())
# logging.info(fvd.json())
pass

async def on_imu_data(self, imu: IMUData):
logging.info(imu.json())
# logging.info(imu.json())
pass

async def on_motion_event(self, me: MotionEvent):
logging.info(me.json())
Expand All @@ -40,7 +47,7 @@ async def on_motion_event(self, me: MotionEvent):
async def main(args: argparse.Namespace):
logging.info("scanning for a Myo device...")

sc = await SampleClient.with_device(mac=args.mac)
sc = await SampleClient.with_device(mac=args.mac, aggregate_all=True)

# get the available services on the myo device
info = await sc.get_services()
Expand All @@ -49,8 +56,8 @@ async def main(args: argparse.Namespace):
# setup the MyoClient
await sc.setup(
classifier_mode=ClassifierMode.ENABLED,
emg_mode=EMGMode.SEND_FILT,
imu_mode=IMUMode.SEND_EVENTS,
emg_mode=EMGMode.SEND_FILT, # for aggregate_all
imu_mode=IMUMode.SEND_ALL, # for aggregate_all
)

# start the indicate/notify
Expand Down
7 changes: 6 additions & 1 deletion myo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,12 @@
"""
from __future__ import absolute_import, annotations

from .core import Myo, MyoClient
from .core import (
AggregatedData,
EMGDataSingle,
Myo,
MyoClient,
)
from .profile import Handle
from .types import (
ClassifierEvent,
Expand Down
7 changes: 5 additions & 2 deletions myo/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,16 @@
class SetMode(Command):
cmd = 0x01

def __init__(self, emg_mode, imu_mode, classifier_mode):
def __init__(self, classifier_mode, emg_mode, imu_mode):
self.classifier_mode = classifier_mode

Check warning on line 33 in myo/commands.py

View check run for this annotation

Codecov / codecov/patch

myo/commands.py#L33

Added line #L33 was not covered by tests
self.emg_mode = emg_mode
self.imu_mode = imu_mode
self.classifier_mode = classifier_mode

@property
def payload(self) -> bytearray:
"""
notice that the payload requires the bytearray in this order
"""
return bytearray(
(
self.emg_mode.value,
Expand Down
121 changes: 105 additions & 16 deletions myo/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,37 @@
logger = logging.getLogger(__name__)


# this is a custom data type for fv and imu
class AggregatedData:
def __init__(self, fvd: FVData, imu: IMUData):
self.fvd = fvd
self.imu = imu

def __str__(self):
return f"{repr(self.fvd)},{repr(self.imu)}"

def json(self):
return json.dumps(self.to_dict())

def to_dict(self):
return {"fvd": self.fvd.to_dict(), "imu": self.imu.to_dict()}


# this is just one sample in EMGData
class EMGDataSingle:
def __init__(self, data):
self.data = data

def __str__(self):
return str(self.data)

def json(self):
return json.dumps(self.to_dict())

def to_dict(self):
return {"data": self.data}


class Myo:
__slots__ = "_device"

Expand Down Expand Up @@ -137,12 +168,25 @@ async def led(self, client: BleakClient, *args):

await self.command(client, LED(args[0], args[1]))

async def set_mode(self, client: BleakClient, emg_mode, imu_mode, classifier_mode):
async def set_mode(
self,
client: BleakClient,
classifier_mode: ClassifierMode,
emg_mode: EMGMode,
imu_mode: IMUMode,
):
"""
Set Mode Command
- configures EMG, IMU, and Classifier modes
"""
await self.command(client, SetMode(emg_mode, imu_mode, classifier_mode))
await self.command(
client,
SetMode(
classifier_mode=classifier_mode,
emg_mode=emg_mode,
imu_mode=imu_mode,
),
)

async def set_sleep_mode(self, client: BleakClient, sleep_mode):
"""
Expand Down Expand Up @@ -182,17 +226,20 @@ async def write(self, client: BleakClient, handle, value):


class MyoClient:
def __init__(self):
def __init__(self, aggregate_all=False, aggregate_emg=False):
self.m = None
self.aggregate_emg = False
self.aggregate_all = aggregate_all
self.aggregate_emg = aggregate_emg
self.classifier_mode = None
self.emg_mode = None
self.imu_mode = None
self._client = None
self.fv_aggregated = None # for aggregate_all
self.imu_aggregated = None # for aggregate_all

@classmethod
async def with_device(cls, mac=None):
self = cls()
async def with_device(cls, mac=None, aggregate_all=False, aggregate_emg=False):
self = cls(aggregate_all=aggregate_all, aggregate_emg=aggregate_emg)
while self.m is None:
if mac and mac != "":
self.m = await Myo.with_mac(mac)
Expand Down Expand Up @@ -281,10 +328,32 @@ async def led(self, color):
async def on_classifier_event(self, ce: ClassifierEvent):
raise NotImplementedError()

async def on_data(self, data):
"""
for on_aggregated_data
data is either FVData or IMUData
"""
if isinstance(data, FVData):
self.fv_aggregated = data
elif isinstance(data, IMUData):
self.imu_aggregated = data
# trigger on_aggregated_data when both FVData and IMUData are ready
if all((self.fv_aggregated, self.imu_aggregated)):
await self.on_aggregated_data(AggregatedData(self.fv_aggregated, self.imu_aggregated))
self.fv_aggregated = None
self.imu_aggregated = None

async def on_aggregated_data(self, ad: AggregatedData):
"""
on_aggregated_data is invoked when both FVData and IMUData are ready.
it doesn't support EMGData since it is collected at different interval (200HZ instead of 50Hz)
"""
raise NotImplementedError()

async def on_emg_data(self, emg: EMGData): # data: list of 8 8-bit unsigned short
raise NotImplementedError()

async def on_emg_data_aggregated(self, data):
async def on_emg_data_aggregated(self, eds: EMGDataSingle):
"""
<> aggregate the raw EMG data channels
"""
Expand All @@ -303,14 +372,21 @@ async def notify_callback(self, sender: BleakGATTCharacteristic, data: bytearray
"""
<> invoke the on_* callbacks
"""

handle = Handle(sender.handle)
logger.debug(f"notify_callback ({handle}): {data}")
if handle == Handle.CLASSIFIER_EVENT:
await self.on_classifier_event(ClassifierEvent(data))
elif handle == Handle.FV_DATA:
await self.on_fv_data(FVData(data))
if self.aggregate_all:
await self.on_data(FVData(data))
else:
await self.on_fv_data(FVData(data))
elif handle == Handle.IMU_DATA:
await self.on_imu_data(IMUData(data))
if self.aggregate_all:
await self.on_data(IMUData(data))
else:
await self.on_imu_data(IMUData(data))
elif handle == Handle.MOTION_EVENT:
await self.on_motion_event(MotionEvent(data))
elif handle in [
Expand All @@ -321,17 +397,22 @@ async def notify_callback(self, sender: BleakGATTCharacteristic, data: bytearray
]:
emg = EMGData(data)
if self.aggregate_emg:
await self.on_emg_data_aggregated(emg.sample1)
await self.on_emg_data_aggregated(emg.sample2)
await self.on_emg_data_aggregated(EMGDataSingle(emg.sample1))
await self.on_emg_data_aggregated(EMGDataSingle(emg.sample2))
else:
await self.on_emg_data(emg)

async def set_mode(self, emg_mode, imu_mode, classifier_mode):
async def set_mode(self, classifier_mode: ClassifierMode, emg_mode: EMGMode, imu_mode: IMUMode):
"""
Set Mode Command
- configures EMG, IMU, and Classifier modes
"""
await self.m.set_mode(self._client, emg_mode, imu_mode, classifier_mode)
await self.m.set_mode(
client=self._client,
classifier_mode=classifier_mode,
emg_mode=emg_mode,
imu_mode=imu_mode,
)

async def set_sleep_mode(self, sleep_mode):
"""
Expand Down Expand Up @@ -362,10 +443,18 @@ async def setup(
self.emg_mode = emg_mode
self.imu_mode = imu_mode
self.classifier_mode = classifier_mode

# enforce the modes when aggregate_all
if self.aggregate_all and (
self.emg_mode != EMGMode.SEND_FILT or self.imu_mode in (IMUMode.NONE, IMUMode.SEND_EVENTS, IMUMode.SEND_RAW)
):
self.emg_mode = EMGMode.SEND_FILT
self.imu_mode = IMUMode.SEND_ALL

await self.set_mode(
emg_mode,
imu_mode,
classifier_mode,
classifier_mode=self.classifier_mode,
emg_mode=self.emg_mode,
imu_mode=self.imu_mode,
)
await self.led(RGB_PINK)

Expand Down
2 changes: 1 addition & 1 deletion myo/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ class Pose(Enum):
REST = 0x0000
FIST = 0x0001
WAVE_IN = 0x0002
WAVE_OUT = 0x003
WAVE_OUT = 0x0003
FINGERS_SPREAD = 0x0004
DOUBLE_TAP = 0x0005
UNKNOWN = 0xFFFF
Expand Down