Skip to content

Commit

Permalink
Resolved comments PR#44: Added MutoxConfig opt layer, style changes, …
Browse files Browse the repository at this point in the history
…repo decoupling, other
  • Loading branch information
David-OC17 committed Nov 9, 2024
1 parent 60f1816 commit 55a342d
Show file tree
Hide file tree
Showing 9 changed files with 100 additions and 50 deletions.
22 changes: 19 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -144,11 +144,19 @@ Detailed model cards with more examples: [facebook/blaser-2.0-ref](https://huggi

### Classifying the toxicity of sentences with MuTox

[MuTox](https://github.com/facebookresearch/seamless_communication/tree/main/src/seamless_communication/cli/toxicity/mutox), the first highly multilingual audio-based dataset with toxicity labels. The dataset consists of 20k audio utterances for English and Spanish, and 4k for the other 19 languages, and uses the multi-model and multilingual encoders from SONAR.
[MuTox](https://github.com/facebookresearch/seamless_communication/tree/main/src/seamless_communication/cli/toxicity/mutox), the first highly multilingual audio-based classifier (binary) and dataset with toxicity labels. The dataset consists of 20k audio utterances for English and Spanish, and 4k for the other 19 languages, and uses the multi-model and multilingual encoders from SONAR. The output of the MuTox classifier is a probability of the evaluated being _"toxic"_, according to the definition adopted in the corresponding dataset.

```Python
from sonar.models.mutox.loader import load_mutox_model
from sonar.inference_pipelines.text import TextToEmbeddingModelPipeline
import torch

if torch.cuda.is_available():
device = torch.device("cuda:0")
dtype = torch.float16
else:
device = torch.device("cpu")
dtype = torch.float32

t2vec_model = TextToEmbeddingModelPipeline(
encoder="text_sonar_basic_encoder",
Expand All @@ -157,14 +165,22 @@ t2vec_model = TextToEmbeddingModelPipeline(
)
text_column='lang_txt'
classifier = load_mutox_model(
"mutox",
"sonar_mutox",
device=device,
dtype=dtype,
).eval()

with torch.inference_mode():
emb = t2vec_model.predict(["De peur que le pays ne se prostitue et ne se remplisse de crimes."], source_lang='fra_Latn')
x = classifier(emb.to(device).half()) # tensor([[-19.7812]], device='cuda:0', dtype=torch.float16)
x = classifier(emb.to(device).to(dtype)) # tensor([[-19.7812]], device='cuda:0', dtype=torch.float16)

with torch.inference_mode():
emb = t2vec_model.predict(["She worked hard and made a significant contribution to the team."], source_lang='fra_Latn')
x = classifier(emb.to(device).to(dtype)) # tensor([[-58.0625]], device='cuda:0', dtype=torch.float16)

with torch.inference_mode():
emb = t2vec_model.predict(["El no tiene ni el más mínimo talento, todo lo que ha logrado ha sido gracias a sobornos y manipulaciones."], source_lang='spa_Latn')
x = classifier(emb.to(device).to(dtype)) # tensor([[-24.6094]], device='cuda:0', dtype=torch.float16)
```

For a CLI way of running the MuTox pipeline, go to [Seamless Communication/.../MuTox](https://github.com/facebookresearch/seamless_communication/tree/main/src/seamless_communication/cli/toxicity/mutox).
Expand Down
15 changes: 11 additions & 4 deletions examples/mutox_example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -83,15 +83,15 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sonar.inference_pipelines.speech import SpeechInferenceParams\n",
"from sonar.inference_pipelines.mutox_speech import MutoxSpeechClassifierPipeline\n",
"\n",
"pipeline_builder = MutoxSpeechClassifierPipeline.load_model_from_name(\n",
" mutox_classifier_name =\"mutox\",\n",
" mutox_classifier_name =\"sonar_mutox\",\n",
" encoder_name=f\"sonar_speech_encoder_eng\",\n",
" device=device,\n",
")"
Expand All @@ -116,6 +116,13 @@
"))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Note:** This model was trained using a \"Binary Cross Entropy loss with logits\" objective (as described in the paper). To convert the model's output into probabilities, apply a sigmoid function to the output.\n"
]
},
{
"cell_type": "code",
"execution_count": 7,
Expand Down Expand Up @@ -162,7 +169,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": null,
"metadata": {},
"outputs": [
{
Expand All @@ -184,7 +191,7 @@
")\n",
"text_column='lang_txt'\n",
"classifier = load_mutox_model(\n",
" \"mutox\",\n",
" \"sonar_mutox\",\n",
" device=device,\n",
" dtype=dtype,\n",
").eval()"
Expand Down
11 changes: 11 additions & 0 deletions sonar/cards/sonar_mutox.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

name: sonar_mutox
model_type: mutox_classifier
model_arch: mutox
checkpoint: "https://dl.fbaipublicfiles.com/seamless/models/mutox.pt"
input_size: 1024
21 changes: 12 additions & 9 deletions sonar/inference_pipelines/mutox_speech.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,19 @@
# MIT_LICENSE file in the root directory of this source tree.

from typing import Union
import torch

import torch
from fairseq2.data import DataPipelineBuilder
from fairseq2.typing import Device
from fairseq2.data import (
DataPipelineBuilder,
)

from sonar.models.sonar_speech.loader import load_sonar_speech_model
from sonar.models.encoder_model import SonarEncoderModel
from sonar.inference_pipelines.speech import (
SpeechToEmbeddingPipeline,
SpeechInferenceParams,
SpeechToEmbeddingPipeline,
)

from sonar.models.encoder_model import SonarEncoderModel
from sonar.models.mutox.classifier import MutoxClassifier
from sonar.models.mutox.loader import load_mutox_model
from sonar.models.sonar_speech.loader import load_sonar_speech_model

CPU_DEVICE = torch.device("cpu")

Expand All @@ -32,7 +29,13 @@ def __init__(
encoder: Union[str, SonarEncoderModel],
device: Device = CPU_DEVICE,
) -> None:
super().__init__(encoder)
if isinstance(encoder, str):
model = self.load_model_from_name("sonar_mutox", encoder, device=device)
else:
model = encoder

super().__init__(model)

self.model.to(device).eval()
self.mutox_classifier = mutox_classifier.to(device).eval()

Expand Down
2 changes: 1 addition & 1 deletion sonar/models/mutox/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@
# All rights reserved.
#
# This source code is licensed under the license found in the
# MIT_LICENSE file in the root directory of this source tree.
# MIT_LICENSE file in the root directory of this source tree.
24 changes: 12 additions & 12 deletions sonar/models/mutox/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,12 @@
# MIT_LICENSE file in the root directory of this source tree.

import typing as tp

import torch
from torch import nn
from fairseq2.typing import DataType, Device
from torch import nn

from .classifier import (
MutoxClassifier,
MutoxConfig,
)
from .classifier import MutoxClassifier, MutoxConfig


class MutoxClassifierBuilder:
Expand Down Expand Up @@ -42,29 +40,31 @@ def __init__(
self.config = config
self.device, self.dtype = device, dtype

def build_model(self) -> MutoxClassifier:
def build_model(self, activation=nn.ReLU) -> MutoxClassifier:
model_h1 = nn.Sequential(
nn.Dropout(0.01),
nn.Linear(self.config.input_size, 512),
)

model_h2 = nn.Sequential(
nn.ReLU(),
activation,
nn.Linear(512, 128),
)

model_h3 = nn.Sequential(
nn.ReLU(),
nn.Linear(128, 1),
)
if self.config.output_prob:
model_h3 = nn.Sequential(activation(), nn.Linear(128, 1), nn.Sigmoid())
else:
model_h3 = nn.Sequential(activation(), nn.Linear(128, 1))

model_all = nn.Sequential(
model_h1,
model_h2,
model_h3,
)

return MutoxClassifier(model_all,).to(
return MutoxClassifier(
model_all,
).to(
device=self.device,
dtype=self.dtype,
)
Expand Down
12 changes: 7 additions & 5 deletions sonar/models/mutox/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,12 @@
# MIT_LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
import torch
from torch import nn
from typing import Optional

from fairseq2.typing import DataType, Device
import torch
from fairseq2.models.utils.arch_registry import ArchitectureRegistry

from typing import Optional
from fairseq2.typing import DataType, Device
from torch import nn


class MutoxClassifier(nn.Module):
Expand All @@ -33,5 +32,8 @@ class MutoxConfig:
# size of the input embedding supported by this model
input_size: int

# add sigmoid as last layer to output probability
output_prob: bool = False


mutox_archs = ArchitectureRegistry[MutoxConfig]("mutox_classifier")
8 changes: 3 additions & 5 deletions sonar/models/mutox/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,9 @@
from fairseq2.models.utils import ConfigLoader, ModelLoader

from .builder import create_mutox_model
from .classifier import (
MutoxClassifier,
MutoxConfig,
mutox_archs,
)
from .classifier import MutoxClassifier, MutoxConfig, mutox_archs

__import__("sonar") # Import only to update asset_store


@mutox_archs.decorator("mutox")
Expand Down
35 changes: 24 additions & 11 deletions tests/unit_tests/test_mutox.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,18 @@
import pytest
import torch
from torch import nn
from unittest.mock import Mock

from sonar.models.mutox.builder import MutoxConfig, MutoxClassifierBuilder, create_mutox_model
from sonar.models.mutox.classifier import MutoxClassifier
from sonar.models.mutox.loader import (
convert_mutox_checkpoint,
from sonar.models.mutox.builder import (
MutoxClassifierBuilder,
MutoxConfig,
create_mutox_model,
)
from sonar.models.mutox.classifier import MutoxClassifier
from sonar.models.mutox.loader import convert_mutox_checkpoint

# Builder tests


@pytest.mark.parametrize("input_size", [256, 512, 1024])
@pytest.mark.parametrize("device", [torch.device("cpu")])
@pytest.mark.parametrize("dtype", [torch.float32, torch.float64])
Expand Down Expand Up @@ -52,9 +54,10 @@ def test_create_mutox_model(input_size):

# Classifier tests


def test_mutox_classifier_forward():
"""Test that MutoxClassifier forward pass returns expected output shape."""
test_model= nn.Sequential(
test_model = nn.Sequential(
nn.Linear(10, 5),
nn.ReLU(),
nn.Linear(5, 1),
Expand All @@ -63,30 +66,40 @@ def test_mutox_classifier_forward():

test_input = torch.randn(3, 10)
output = model(test_input)
assert output.shape == (3, 1), f"Expected output shape (3, 1), but instead got {output.shape}"
assert output.shape == (
3,
1,
), f"Expected output shape (3, 1), but instead got {output.shape}"


def test_mutox_config():
"""Test that MutoxConfig stores the configuration for a model."""
config = MutoxConfig(input_size=512)
assert config.input_size == 512, f"Config input_size should be 512, but got {config.input_size}"
assert (
config.input_size == 512
), f"Config input_size should be 512, but got {config.input_size}"


# Loader tests


def test_convert_mutox_checkpoint():
"""Test convert_mutox_checkpoint correctly filters keys in the checkpoint."""
# Create a mock checkpoint with both 'model_all.' prefixed keys and other keys
checkpoint = {
"model_all.layer1.weight": torch.tensor([1.0]),
"model_all.layer1.bias": torch.tensor([0.5]),
"non_model_key": torch.tensor([3.0])
"non_model_key": torch.tensor([3.0]),
}
config = MutoxConfig(input_size=1024)
converted = convert_mutox_checkpoint(checkpoint, config)

# Verify only 'model_all.' keys are retained in the converted dictionary
assert "model" in converted, "Converted checkpoint should contain a 'model' key"
assert "model_all.layer1.weight" in converted["model"], "Expected 'model_all.layer1.weight'"
assert "model_all.layer1.bias" in converted["model"], "Expected 'model_all.layer1.bias'"
assert (
"model_all.layer1.weight" in converted["model"]
), "Expected 'model_all.layer1.weight'"
assert (
"model_all.layer1.bias" in converted["model"]
), "Expected 'model_all.layer1.bias'"
assert "non_model_key" not in converted["model"], "Unexpected 'non_model_key'"

0 comments on commit 55a342d

Please sign in to comment.