from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union

import torch
import torch.utils.data as torchdata

from ..models import nn as mnn
from ..models.utils import top_n_accuracy  # noqa: F401
from ..utils.const import CACHED_DATA_DIR
from ._register import register_fed_dataset  # noqa: F401
from .fed_dataset import FedNLPDataset  # noqa: F401

__all__ = [
    "LeafSent140",
]

LEAF_SENT140_DATA_DIR = CACHED_DATA_DIR / "leaf_sent140"
LEAF_SENT140_DATA_DIR.mkdir(parents=True, exist_ok=True)

# @register_fed_dataset()
class LeafSent140(FedNLPDataset):
    """Federated Sentiment140 dataset from Leaf.

    The Sentiment140 dataset [1]_ is built from Twitter tweets
    and is used for sentiment analysis tasks. The Leaf library [2]_
    further processed the data.

    Parameters
    ----------
    datadir : Union[pathlib.Path, str], optional
        Directory to store data.
        If ``None``, use default directory.
    seed : int, default 0
        Random seed for data partitioning.
    **extra_config : dict, optional
        Extra configurations.

    References
    ----------
    .. [1] http://help.sentiment140.com
    .. [2] https://github.com/TalwalkarLab/leaf/tree/master/data/sent140

    """

    __name__ = "LeafSent140"

    def _preload(self, datadir: Optional[Union[str, Path]] = None) -> None:
        """Preload the dataset.

        Parameters
        ----------
        datadir : Union[pathlib.Path, str], optional
            Directory to store data.
            If ``None``, use default directory.

        Returns
        -------
        None

        """
        self.criterion = torch.nn.CrossEntropyLoss()
        raise NotImplementedError
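        # NOTE: hypothetical outline only, not part of the original implementation.
        # A concrete ``_preload`` would typically need to:
        #   1. download the archive at ``self.url`` into ``datadir``
        #      (default ``LEAF_SENT140_DATA_DIR``) if it is not cached yet,
        #   2. tokenize the tweets and build or load a vocabulary/embedding,
        #   3. partition samples by Twitter user into clients (using the
        #      ``seed`` described in the class docstring),
        #   4. store the resulting train/test splits on ``self`` for
        #      ``get_dataloader`` to consume.
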
    def get_dataloader(
        self,
        train_bs: Optional[int] = None,
        test_bs: Optional[int] = None,
        client_idx: Optional[int] = None,
    ) -> Tuple[torchdata.DataLoader, torchdata.DataLoader]:
        """Get local dataloader at client `client_idx` or get the global dataloader.

        Parameters
        ----------
        train_bs : int, optional
            Batch size for training dataloader.
            If ``None``, use default batch size.
        test_bs : int, optional
            Batch size for testing dataloader.
            If ``None``, use default batch size.
        client_idx : int, optional
            Index of the client to get dataloader.
            If ``None``, get the dataloader containing all data.
            Usually used for centralized training.

        Returns
        -------
        train_dl : :class:`torch.utils.data.DataLoader`
            Training dataloader.
        test_dl : :class:`torch.utils.data.DataLoader`
            Testing dataloader.

        """
        raise NotImplementedError
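        # Usage sketch (hypothetical, since this method is still a stub):
        #   train_dl, test_dl = ds.get_dataloader(train_bs=32, client_idx=3)  # client 3's local split
        #   train_dl, test_dl = ds.get_dataloader(client_idx=None)  # pooled data for centralized training
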
    def evaluate(self, probs: torch.Tensor, truths: torch.Tensor) -> Dict[str, float]:
        """Evaluation using predictions and ground truth.

        Parameters
        ----------
        probs : torch.Tensor
            Predicted probabilities.
        truths : torch.Tensor
            Ground truth labels.

        Returns
        -------
        Dict[str, float]
            Evaluation results.

        """
        return {
            "acc": top_n_accuracy(probs, truths, 1),
            "loss": self.criterion(probs, truths).item(),
            "num_samples": probs.shape[0],
        }
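    # Example (illustrative only): ``probs`` has shape (n_samples, n_classes)
    # and ``truths`` holds integer class labels of shape (n_samples,), e.g.
    #   probs = torch.tensor([[0.2, 0.8], [0.9, 0.1]])
    #   truths = torch.tensor([1, 0])
    #   ds.evaluate(probs, truths)  # -> {"acc": ..., "loss": ..., "num_samples": 2}
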
    @property
    def url(self) -> str:
        """URL for downloading the dataset."""
        return "http://cs.stanford.edu/people/alecmgo/trainingandtestdata.zip"

    @property
    def candidate_models(self) -> Dict[str, torch.nn.Module]:
        """A set of candidate models."""
        return {
            "rnn": mnn.RNN_Sent140(),
        }

    @property
    def doi(self) -> List[str]:
        """DOI(s) related to the dataset."""
        return [
            "10.48550/ARXIV.1812.01097",  # LEAF
        ]
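

# ---------------------------------------------------------------------------
# Minimal usage sketch (hypothetical): assumes ``_preload`` and
# ``get_dataloader`` above have been implemented, and that the constructor
# accepts the ``datadir``/``seed`` arguments described in the class docstring.
# Nothing here is exercised by the library itself.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    ds = LeafSent140(datadir=None, seed=0)  # default cache dir, fixed partition seed
    model = ds.candidate_models["rnn"]  # the RNN baseline registered above
    train_dl, test_dl = ds.get_dataloader(train_bs=32, test_bs=64)
    print(model.__class__.__name__, len(train_dl.dataset), len(test_dl.dataset))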