Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Language Detector task #22

Merged
merged 23 commits into from
May 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
10990cf
Adds mediapipe_core package (#11)
craiglabenz Nov 13, 2023
7850d32
Add utility to collect headers from google/mediapipe (#10)
craiglabenz Nov 13, 2023
6a7dacb
[FFI] MediaPipe SDKs finder automation (#16)
craiglabenz Jan 12, 2024
8980132
Adds mediapipe_text package (#12)
craiglabenz Mar 11, 2024
4c159a7
Native Assets CI fix (#20)
craiglabenz Mar 12, 2024
4304801
Text Embedding task (#21)
craiglabenz Apr 3, 2024
3fc2d11
updated and re-ran generators
craiglabenz Mar 11, 2024
e01f26b
fixed embedding header file and bindings
craiglabenz Mar 11, 2024
ce42365
adds text embedding classes to text pkg
craiglabenz Mar 11, 2024
d36976b
moved worker dispose method to base class
craiglabenz Mar 25, 2024
f0344b9
class hierarchy improvements
craiglabenz Apr 1, 2024
992332b
cleaned up dispose methods
craiglabenz Apr 1, 2024
4ac6d20
initial commit of language detection task
craiglabenz Apr 1, 2024
b9b6c89
finishes language detection impl
craiglabenz Apr 2, 2024
bed398a
adds language detection demo
craiglabenz Apr 2, 2024
73bada3
backfilling improvements to classification and embedding
craiglabenz Apr 2, 2024
a392f83
adds language detection tests
craiglabenz Apr 2, 2024
4b742a3
add new model download to CI script
craiglabenz Apr 3, 2024
65cc969
fixes stale classification widget test, adds language detection widge…
craiglabenz Apr 3, 2024
078fe62
Merge branch 'main' into ffi-wrapper-language-detection
craiglabenz Apr 22, 2024
d81b148
copied latest headers from mediapipe
craiglabenz Apr 22, 2024
d3ac430
Update packages/mediapipe-task-text/lib/src/io/tasks/language_detecti…
craiglabenz May 12, 2024
fad589c
comments / documentation improvements
craiglabenz May 12, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ headers:
models:
cd tool/builder && dart bin/main.dart model -m textclassification
cd tool/builder && dart bin/main.dart model -m textembedding

cd tool/builder && dart bin/main.dart model -m languagedetection

# Runs `ffigen` for all packages
generate: generate_core generate_text
Expand Down
2 changes: 1 addition & 1 deletion packages/mediapipe-core/lib/universal_mediapipe_core.dart
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ class Classifications extends BaseClassifications {
});

@override
Iterable<BaseCategory> get categories => throw UnimplementedError();
Iterable<Category> get categories => throw UnimplementedError();

@override
int get headIndex => throw UnimplementedError();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ struct Classifications {
// The array of predicted categories, usually sorted by descending scores,
// e.g. from high to low probability.
struct Category* categories;

// The number of elements in the categories array.
uint32_t categories_count;

Expand Down Expand Up @@ -58,7 +57,6 @@ struct ClassificationResult {
// exceed the maximum size that the model can process: to solve this, the
// input data is split into multiple chunks starting at different timestamps.
int64_t timestamp_ms;

// Specifies whether the timestamp contains a valid value.
bool has_timestamp_ms;
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,15 +41,13 @@ struct ClassifierOptions {
// category name is not in this set will be filtered out. Duplicate or unknown
// category names are ignored. Mutually exclusive with category_denylist.
const char** category_allowlist;

// The number of elements in the category allowlist.
uint32_t category_allowlist_count;

// The denylist of category names. If non-empty, detection results whose
// category name is in this set will be filtered out. Duplicate or unknown
// category names are ignored. Mutually exclusive with category_allowlist.
const char** category_denylist;

// The number of elements in the category denylist.
uint32_t category_denylist_count;
};
Expand Down
26 changes: 24 additions & 2 deletions packages/mediapipe-task-text/example/lib/enumerate.dart
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,34 @@
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

extension Enumeratable<T> on List<T> {
Iterable<S> enumerate<S>(S Function(T, int) fn) sync* {
extension EnumeratableList<T> on List<T> {
/// Invokes the callback on each element of the list, optionally stopping
/// after [max] (inclusive) invocations.
Iterable<S> enumerate<S>(S Function(T, int) fn, {int? max}) sync* {
int count = 0;
while (count < length) {
yield fn(this[count], count);
count++;

if (max != null && count >= max) {
craiglabenz marked this conversation as resolved.
Show resolved Hide resolved
return;
}
}
}
}

extension EnumeratableIterable<T> on Iterable<T> {
/// Invokes the callback on each element of the iterable, optionally stopping
/// after [max] (inclusive) invocations.
Iterable<S> enumerate<S>(S Function(T, int) fn, {int? max}) sync* {
int count = 0;
for (final T obj in this) {
yield fn(obj, count);
count++;

if (max != null && count >= max) {
craiglabenz marked this conversation as resolved.
Show resolved Hide resolved
return;
}
}
}
}
158 changes: 158 additions & 0 deletions packages/mediapipe-task-text/example/lib/language_detection_demo.dart
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
// Copyright 2014 The Flutter Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

import 'dart:async';
import 'dart:typed_data';
import 'package:flutter/material.dart';
import 'package:getwidget/getwidget.dart';
import 'package:mediapipe_text/mediapipe_text.dart';
import 'enumerate.dart';

class LanguageDetectionDemo extends StatefulWidget {
const LanguageDetectionDemo({super.key, this.detector});

final LanguageDetector? detector;

@override
State<LanguageDetectionDemo> createState() => _LanguageDetectionDemoState();
}

class _LanguageDetectionDemoState extends State<LanguageDetectionDemo>
with AutomaticKeepAliveClientMixin<LanguageDetectionDemo> {
final TextEditingController _controller = TextEditingController();
final Completer<LanguageDetector> _completer = Completer<LanguageDetector>();
final results = <Widget>[];
String? _isProcessing;

@override
void initState() {
super.initState();
_controller.text = 'Quiero agua, por favor';
_initDetector();
}

Future<void> _initDetector() async {
if (widget.detector != null) {
return _completer.complete(widget.detector!);
}

ByteData? bytes = await DefaultAssetBundle.of(context)
.load('assets/language_detector.tflite');

final detector = LanguageDetector(
LanguageDetectorOptions.fromAssetBuffer(
bytes.buffer.asUint8List(),
),
);
_completer.complete(detector);
bytes = null;
}

void _prepareForDetection() {
setState(() {
_isProcessing = _controller.text;
results.add(const CircularProgressIndicator.adaptive());
});
}

Future<void> _detect() async {
_prepareForDetection();
_completer.future.then((detector) async {
final result = await detector.detect(_controller.text);
_showDetectionResults(result);
result.dispose();
});
}

void _showDetectionResults(LanguageDetectorResult result) {
setState(
() {
results.last = Card(
key: Key('prediction-"$_isProcessing" ${results.length}'),
margin: const EdgeInsets.all(10),
child: Column(
children: [
Padding(
padding: const EdgeInsets.all(10),
child: Text(_isProcessing!),
),
Padding(
padding: const EdgeInsets.all(10.0),
child: Wrap(
children: <Widget>[
...result.predictions
.enumerate<Widget>(
(prediction, index) => _languagePrediction(
prediction,
predictionColors[index],
),
// Take first 4 because the model spits out dozens of
// astronomically low probability language predictions
max: predictionColors.length,
)
.toList(),
],
),
),
],
),
);
_isProcessing = null;
},
);
}

static final predictionColors = <Color>[
Colors.blue[300]!,
Colors.orange[300]!,
Colors.green[300]!,
Colors.red[300]!,
];

Widget _languagePrediction(LanguagePrediction prediction, Color color) {
return Padding(
padding: const EdgeInsets.only(right: 8),
child: GFButton(
onPressed: null,
text: '${prediction.languageCode} :: '
'${prediction.probability.roundTo(8)}',
shape: GFButtonShape.pills,
color: color,
),
);
}

@override
Widget build(BuildContext context) {
super.build(context);
return Scaffold(
body: SafeArea(
child: Padding(
padding: const EdgeInsets.all(16.0),
child: SingleChildScrollView(
child: Column(
children: <Widget>[
TextField(controller: _controller),
...results.reversed,
],
),
),
),
),
floatingActionButton: FloatingActionButton(
onPressed:
_isProcessing != null && _controller.text != '' ? null : _detect,
child: const Icon(Icons.search),
),
);
}

@override
bool get wantKeepAlive => true;
}

extension on double {
double roundTo(int decimalPlaces) =>
double.parse(toStringAsFixed(decimalPlaces));
}
50 changes: 31 additions & 19 deletions packages/mediapipe-task-text/example/lib/main.dart
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import 'package:flutter/material.dart';
import 'language_detection_demo.dart';
import 'logging.dart';
import 'text_classification_demo.dart';
import 'text_embedding_demo.dart';
Expand Down Expand Up @@ -31,7 +32,7 @@ class TextTaskPages extends StatefulWidget {
class TextTaskPagesState extends State<TextTaskPages> {
final PageController controller = PageController();

final titles = <String>['Classify', 'Embed'];
final titles = <String>['Classify', 'Embed', 'Detect Languages'];
int titleIndex = 0;

void switchToPage(int index) {
Expand Down Expand Up @@ -61,28 +62,39 @@ class TextTaskPagesState extends State<TextTaskPages> {
children: const <Widget>[
TextClassificationDemo(),
TextEmbeddingDemo(),
LanguageDetectionDemo(),
],
),
bottomNavigationBar: ColoredBox(
color: Colors.blueGrey,
child: Row(
mainAxisAlignment: MainAxisAlignment.spaceEvenly,
children: <Widget>[
TextButton(
onPressed: () => switchToPage(0),
child: Text(
'Classify',
style: titleIndex == 0 ? activeTextStyle : inactiveTextStyle,
bottomNavigationBar: SizedBox(
height: 50,
child: ColoredBox(
color: Colors.blueGrey,
child: Row(
mainAxisAlignment: MainAxisAlignment.spaceEvenly,
children: <Widget>[
TextButton(
onPressed: () => switchToPage(0),
child: Text(
'Classify',
style: titleIndex == 0 ? activeTextStyle : inactiveTextStyle,
),
),
),
TextButton(
onPressed: () => switchToPage(1),
child: Text(
'Embed',
style: titleIndex == 1 ? activeTextStyle : inactiveTextStyle,
TextButton(
onPressed: () => switchToPage(1),
child: Text(
'Embed',
style: titleIndex == 1 ? activeTextStyle : inactiveTextStyle,
),
),
),
],
TextButton(
onPressed: () => switchToPage(2),
child: Text(
'Detect Languages',
style: titleIndex == 2 ? activeTextStyle : inactiveTextStyle,
),
),
],
),
),
),
);
Expand Down
Loading