-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathDRG_ml.jl
97 lines (71 loc) · 2.33 KB
/
DRG_ml.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
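# Read Stata (.dta) discharge files and fit a classifier to them.
#=
Packages required by this script (run once to install):
Pkg.add("StatFiles")
Pkg.add("Query")
(DataFrames, DecisionTree and ScikitLearn are also used below.)
=#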
using DataFrames, StatFiles, Query, DecisionTree, ScikitLearn
# Load data: one quarter of the Texas inpatient discharge public-use file, stored as Stata .dta.
# NOTE(review): hard-coded, user-specific absolute path — consider lifting it into a constant.
df1 = DataFrame(load("/Users/austinbean/Google Drive/Texas Inpatient Discharge/Full Versions/2010 4 Quarter PUDF.dta"))
# column names (only displayed when run interactively):
names(df1)
# Query just the childbirth-related records into a new dataframe:
#   apr_mdc is 14 or 15, pat_age <= 1, and apr_drg in 789:795.
# Every comparison is parenthesized because `&` has the same precedence as `*` in Julia:
# the previous `i.apr_drg >=789 & i.apr_drg <=795` parsed as the chained comparison
# `i.apr_drg >= (789 & i.apr_drg) <= 795`, so the DRG range was never actually enforced.
x1 = @from i in df1 begin
    @where (i.apr_mdc == 14 || i.apr_mdc == 15) & (i.pat_age <= 1) & (i.apr_drg >= 789) & (i.apr_drg <= 795)
    # 26 feature columns (admitting diagnosis, principal diagnosis, oth_diag_code_1..24)
    # followed by the label column apr_drg. oth_diag_code_8 was previously left out.
    @select {i.admitting_diagnosis, i.princ_diag_code,
             i.oth_diag_code_1, i.oth_diag_code_2, i.oth_diag_code_3, i.oth_diag_code_4,
             i.oth_diag_code_5, i.oth_diag_code_6, i.oth_diag_code_7, i.oth_diag_code_8,
             i.oth_diag_code_9, i.oth_diag_code_10, i.oth_diag_code_11, i.oth_diag_code_12,
             i.oth_diag_code_13, i.oth_diag_code_14, i.oth_diag_code_15, i.oth_diag_code_16,
             i.oth_diag_code_17, i.oth_diag_code_18, i.oth_diag_code_19, i.oth_diag_code_20,
             i.oth_diag_code_21, i.oth_diag_code_22, i.oth_diag_code_23, i.oth_diag_code_24,
             i.apr_drg}
    @collect DataFrame
end
# Slice the 26 feature columns out as a matrix...
features = convert(Array{Union{Missing, String}, 2}, x1[1:26])
# ...and make every cell a String (a `missing` cell becomes the string "missing").
features = string.(features)
# labels: the APR-DRG code of each record.
labels = convert(Array{Union{Missing, Int16}, 1}, x1[:apr_drg])
# model: a shallow decision tree.
mod1 = DecisionTreeClassifier(max_depth = 3)
# fit on the full sample:
fit!(mod1, features, labels)
# print the class labels the fitted model knows about:
println(get_classes(mod1))
# 3-fold cross-validated accuracy (refits copies of the model on each fold):
using ScikitLearn.CrossValidation: cross_val_score
acc = cross_val_score(mod1, features, labels, cv=3)
"""
`Tab`
Like Stata's tab function.
Returns unique values plus frequencies.
"""
function Tab(df::DataFrame, name::Symbol; noprint = false )
    # Counts how often each distinct value appears in column `name` of `df`
    # and returns a Dict value => count. The key type follows the column's element type,
    # so integer, string, and `missing`-containing columns are all supported; for an
    # Int64 column the result is a Dict{Int64,Int64}, exactly as before.
    # Unless `noprint` is true, the counts are also printed, in Dict iteration order
    # (which is unspecified).
    # Throws KeyError(name) when `df` has no such column.

    # `names` returns Symbols in older DataFrames releases and Strings in newer ones;
    # normalizing to Symbol makes this membership test work with either.
    # (The former `err1 == KeyError` compared an exception instance to a type, which is
    # always false, so a missing column used to be silently ignored here.)
    if !(name in Symbol.(names(df)))
        throw(KeyError(name))
    end
    # `df[:, name]` is accepted by both old and current DataFrames (`df[name]` was removed).
    col = df[:, name]
    outp = Dict{eltype(col), Int}()
    for el in col
        outp[el] = get(outp, el, 0) + 1
    end
    if !noprint
        println(" ***** Results ***** *****")
        println(" Item Count")
        for (k1, cnt) in outp
            println("| ", k1, " ", cnt, " |")
        end
        println(" ***** ****** ****** ***** *****")
    end
    return outp
end
# model...