@@ -8,6 +8,7 @@ def _valid_column(column_name):
8
8
9
9
class _PandasPlotter :
10
10
"""Base class for pandas plotting."""
11
+
11
12
@classmethod
12
13
def create (cls , data ):
13
14
if isinstance (data , pd .Series ):
@@ -20,6 +21,7 @@ def create(cls, data):
20
21
21
22
class _SeriesPlotter (_PandasPlotter ):
22
23
"""Functionality for plotting of pandas Series."""
24
+
23
25
def __init__ (self , data ):
24
26
if not isinstance (data , pd .Series ):
25
27
raise ValueError (f"data: expected pd.Series; got { type (data )} " )
@@ -32,7 +34,8 @@ def _preprocess_data(self, with_index=True):
32
34
if isinstance (data .index , pd .MultiIndex ):
33
35
data = data .copy ()
34
36
data .index = pd .Index (
35
- [str (i ) for i in data .index ], name = data .index .name )
37
+ [str (i ) for i in data .index ], name = data .index .name
38
+ )
36
39
data = data .reset_index ()
37
40
else :
38
41
data = data .to_frame ()
@@ -41,25 +44,29 @@ def _preprocess_data(self, with_index=True):
41
44
42
45
def _xy (self , mark , ** kwargs ):
43
46
data = self ._preprocess_data (with_index = True )
44
- return alt .Chart (data , mark = mark ).encode (
45
- x = alt .X (data .columns [0 ], title = None ),
46
- y = alt .Y (data .columns [1 ], title = None ),
47
- tooltip = list (data .columns )
48
- ).interactive ()
47
+ return (
48
+ alt .Chart (data , mark = mark )
49
+ .encode (
50
+ x = alt .X (data .columns [0 ], title = None ),
51
+ y = alt .Y (data .columns [1 ], title = None ),
52
+ tooltip = list (data .columns ),
53
+ )
54
+ .interactive ()
55
+ )
49
56
50
57
def line (self , ** kwargs ):
51
- return self ._xy (' line' , ** kwargs )
58
+ return self ._xy (" line" , ** kwargs )
52
59
53
60
def bar (self , ** kwargs ):
54
- return self ._xy ({' type' : ' bar' , ' orient' : ' vertical' }, ** kwargs )
61
+ return self ._xy ({" type" : " bar" , " orient" : " vertical" }, ** kwargs )
55
62
56
63
def barh (self , ** kwargs ):
57
- chart = self ._xy ({' type' : ' bar' , ' orient' : ' horizontal' }, ** kwargs )
64
+ chart = self ._xy ({" type" : " bar" , " orient" : " horizontal" }, ** kwargs )
58
65
chart .encoding .x , chart .encoding .y = chart .encoding .y , chart .encoding .x
59
66
return chart
60
67
61
68
def area (self , ** kwargs ):
62
- return self ._xy (mark = ' area' , ** kwargs )
69
+ return self ._xy (mark = " area" , ** kwargs )
63
70
64
71
def scatter (self , ** kwargs ):
65
72
raise ValueError ("kind='scatter' can only be used for DataFrames." )
@@ -71,23 +78,28 @@ def hist(self, bins=None, **kwargs):
71
78
bins = alt .Bin (maxbins = bins )
72
79
elif bins is None :
73
80
bins = True
74
- return alt .Chart (data ).mark_bar ().encode (
75
- x = alt .X (column , title = None , bin = bins ),
76
- y = alt .Y ('count()' , title = 'Frequency' ),
81
+ return (
82
+ alt .Chart (data )
83
+ .mark_bar ()
84
+ .encode (
85
+ x = alt .X (column , title = None , bin = bins ),
86
+ y = alt .Y ("count()" , title = "Frequency" ),
87
+ )
77
88
)
78
89
79
90
def box (self , ** kwargs ):
80
91
data = self ._preprocess_data (with_index = False )
81
- return alt . Chart ( data ). transform_fold (
82
- list (data . columns ), as_ = [ 'column' , 'value' ]
83
- ). mark_boxplot (). encode (
84
- x = alt . X ( 'column:N' , title = None ),
85
- y = ' value:Q' ,
92
+ return (
93
+ alt . Chart (data )
94
+ . transform_fold ( list ( data . columns ), as_ = [ "column" , "value" ])
95
+ . mark_boxplot ()
96
+ . encode ( x = alt . X ( "column:N" , title = None ), y = " value:Q" )
86
97
)
87
98
88
99
89
100
class _DataFramePlotter (_PandasPlotter ):
90
101
"""Functionality for plotting of pandas DataFrames."""
102
+
91
103
def __init__ (self , data ):
92
104
if not isinstance (data , pd .DataFrame ):
93
105
raise ValueError (f"data: expected pd.DataFrame; got { type (data )} " )
@@ -100,7 +112,8 @@ def _preprocess_data(self, with_index=True, usecols=None):
100
112
if with_index :
101
113
if isinstance (data .index , pd .MultiIndex ):
102
114
data .index = pd .Index (
103
- [str (i ) for i in data .index ], name = data .index .name )
115
+ [str (i ) for i in data .index ], name = data .index .name
116
+ )
104
117
return data .reset_index ()
105
118
return data
106
119
@@ -120,82 +133,80 @@ def _xy(self, mark, x=None, y=None, **kwargs):
120
133
assert y in data .columns
121
134
y_values = [y ]
122
135
123
- return alt . Chart (
124
- data ,
125
- mark = mark
126
- ). transform_fold (
127
- y_values , as_ = [ 'column' , 'value' ]
128
- ). encode (
129
- x = x ,
130
- y = alt . Y ( 'value:Q' , title = None ) ,
131
- color = alt . Color ( 'column:N' , title = None ),
132
- tooltip = [ x ] + y_values ,
133
- ). interactive ()
136
+ return (
137
+ alt . Chart ( data , mark = mark )
138
+ . transform_fold ( y_values , as_ = [ "column" , "value" ])
139
+ . encode (
140
+ x = x ,
141
+ y = alt . Y ( "value:Q" , title = None ),
142
+ color = alt . Color ( "column:N" , title = None ) ,
143
+ tooltip = [ x ] + y_values ,
144
+ )
145
+ . interactive ()
146
+ )
134
147
135
148
def line (self , x = None , y = None , ** kwargs ):
136
- return self ._xy (' line' , x , y , ** kwargs )
149
+ return self ._xy (" line" , x , y , ** kwargs )
137
150
138
151
def area (self , x = None , y = None , ** kwargs ):
139
- return self ._xy (' area' , x , y , ** kwargs )
152
+ return self ._xy (" area" , x , y , ** kwargs )
140
153
141
154
# TODO: bars should be grouped, not stacked.
142
155
def bar (self , x = None , y = None , ** kwargs ):
143
- return self ._xy (
144
- {'type' : 'bar' , 'orient' : 'vertical' }, x , y , ** kwargs )
156
+ return self ._xy ({"type" : "bar" , "orient" : "vertical" }, x , y , ** kwargs )
145
157
146
158
def barh (self , x = None , y = None , ** kwargs ):
147
- chart = self ._xy (
148
- {'type' : 'bar' , 'orient' : 'horizontal' }, x , y , ** kwargs )
159
+ chart = self ._xy ({"type" : "bar" , "orient" : "horizontal" }, x , y , ** kwargs )
149
160
chart .encoding .x , chart .encoding .y = chart .encoding .y , chart .encoding .x
150
161
return chart
151
162
152
163
def scatter (self , x , y , c = None , s = None , ** kwargs ):
153
164
if x is None or y is None :
154
165
raise ValueError ("kind='scatter' requires 'x' and 'y' arguments." )
155
- encodings = {'x' : _valid_column (x ), 'y' : _valid_column (y )}
166
+ encodings = {"x" : _valid_column (x ), "y" : _valid_column (y )}
156
167
if c is not None :
157
- encodings [' color' ] = _valid_column (c )
168
+ encodings [" color" ] = _valid_column (c )
158
169
if s is not None :
159
- encodings [' size' ] = _valid_column (s )
170
+ encodings [" size" ] = _valid_column (s )
160
171
columns = list (set (encodings .values ()))
161
172
data = self ._preprocess_data (with_index = False , usecols = columns )
162
- encodings ['tooltip' ] = columns
163
- return alt .Chart (data ).mark_point ().encode (
164
- ** encodings
165
- ).interactive ()
173
+ encodings ["tooltip" ] = columns
174
+ return alt .Chart (data ).mark_point ().encode (** encodings ).interactive ()
166
175
167
176
def hist (self , bins = None , stacked = None , ** kwargs ):
168
177
data = self ._preprocess_data (with_index = False )
169
178
if isinstance (bins , int ):
170
179
bins = alt .Bin (maxbins = bins )
171
180
elif bins is None :
172
181
bins = True
173
- return alt .Chart (data ).transform_fold (
174
- list (data .columns ), as_ = ['column' , 'value' ]
175
- ).mark_bar ().encode (
176
- x = alt .X ('value:Q' , title = None , bin = bins ),
177
- y = alt .Y ('count()' , title = 'Frequency' , stack = stacked ),
178
- color = alt .Color ('column:N' )
182
+ return (
183
+ alt .Chart (data )
184
+ .transform_fold (list (data .columns ), as_ = ["column" , "value" ])
185
+ .mark_bar ()
186
+ .encode (
187
+ x = alt .X ("value:Q" , title = None , bin = bins ),
188
+ y = alt .Y ("count()" , title = "Frequency" , stack = stacked ),
189
+ color = alt .Color ("column:N" ),
190
+ )
179
191
)
180
192
181
193
def box (self , ** kwargs ):
182
194
data = self ._preprocess_data (with_index = False )
183
- return alt . Chart ( data ). transform_fold (
184
- list (data . columns ), as_ = [ 'column' , 'value' ]
185
- ). mark_boxplot (). encode (
186
- x = alt . X ( 'column:N' , title = None ),
187
- y = ' value:Q' ,
195
+ return (
196
+ alt . Chart (data )
197
+ . transform_fold ( list ( data . columns ), as_ = [ "column" , "value" ])
198
+ . mark_boxplot ()
199
+ . encode ( x = alt . X ( "column:N" , title = None ), y = " value:Q" )
188
200
)
189
201
190
202
191
- def plot (data , kind = ' line' , ** kwargs ):
203
+ def plot (data , kind = " line" , ** kwargs ):
192
204
"""Pandas plotting interface for Altair."""
193
205
plotter = _PandasPlotter .create (data )
194
206
195
207
if hasattr (plotter , kind ):
196
208
plotfunc = getattr (plotter , kind )
197
209
else :
198
- raise NotImplementedError (
199
- f"kind='{ kind } ' for data of type { type (data )} " )
210
+ raise NotImplementedError (f"kind='{ kind } ' for data of type { type (data )} " )
200
211
201
212
return plotfunc (** kwargs )
0 commit comments