diff --git a/Manifest.toml b/Manifest.toml index eb644d8ef..adc573ada 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -218,6 +218,12 @@ uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" deps = ["Serialization"] uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +[[RecipesBase]] +deps = ["Random", "Test"] +git-tree-sha1 = "0b3cb370ee4dc00f47f1193101600949f3dcf884" +uuid = "3cdcf5f2-1ef4-517c-9805-6587b60abb01" +version = "0.6.0" + [[Reexport]] deps = ["Pkg"] git-tree-sha1 = "7b1d07f411bc8ddb7977ec7f377b97b158514fe0" diff --git a/Project.toml b/Project.toml index 62802f10c..69c913c1b 100644 --- a/Project.toml +++ b/Project.toml @@ -16,6 +16,7 @@ MLJModels = "d491faf4-2d78-11e9-2867-c94bc002c0b7" MultivariateStats = "6f286f6a-111f-5878-ab1e-185364afe411" ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01" RemoteFiles = "cbe49d4c-5af1-5b60-bb70-0a60aa018e1b" Requires = "ae029012-a4dd-5104-9daa-d747884805df" Revise = "295af30f-e4ad-537b-8983-00126c2a3abe" diff --git a/src/MLJ.jl b/src/MLJ.jl index 11774e696..af7e7184d 100644 --- a/src/MLJ.jl +++ b/src/MLJ.jl @@ -77,6 +77,9 @@ using LinearAlgebra using Random import Distributed: @distributed, nworkers, pmap +# for plotting +using RecipesBase + const srcdir = dirname(@__FILE__) # the directory containing this file: include("utilities.jl") # general purpose utilities @@ -104,5 +107,8 @@ include("builtins/LocalMultivariateStats.jl") include("loading.jl") # model metadata processing +## SIMPLE PLOTTING RECIPE + +include("plotrecipes.jl") end # module diff --git a/src/plotrecipes.jl b/src/plotrecipes.jl new file mode 100644 index 000000000..7c5182841 --- /dev/null +++ b/src/plotrecipes.jl @@ -0,0 +1,28 @@ +@recipe function f(mach::MLJ.Machine{<:MLJ.EitherTunedModel}) + r = report(mach) + z = r.measurements + x = r.parameter_values[:,1] + y = r.parameter_values[:,2] + xsc, ysc = r.parameter_scales + + xguide --> r.parameter_names[1] + yguide --> r.parameter_names[2] + xscale --> (xsc == :linear ? :identity : xsc) + yscale --> (ysc == :linear ? :identity : ysc) + + st = get(plotattributes, :seriestype, :scatter) + + if st ∈ (:surface, :heatmap, :contour, :contourf, :wireframe) + ux = unique(x) + uy = unique(y) + m = reshape(z, (length(ux), length(uy)))' + ux, uy, m + else + label --> "" + seriestype := st + ms = get(plotattributes, :markersize, 4) + markersize := 3ms * sqrt.(z) + marker_z --> z + x, y + end +end