Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cleanup some of the C++ #871

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 9 additions & 7 deletions src/stan_math_backend/Cpp_Json.ml
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@ module Str = Re.Str

let rec sizedtype_to_json (st : Expr.Typed.t SizedType.t) : Yojson.Basic.t =
let emit_cpp_expr e =
Fmt.strf "<< %a >>" Expression_gen.pp_expr e
|> Str.global_replace (Str.regexp "[\n\r\t ]+") " "
Fmt.strf "+ std::to_string(%a) +" Expression_gen.pp_expr e
in
match st with
| SInt -> `Assoc [("name", `String "int")]
Expand Down Expand Up @@ -44,17 +43,20 @@ let%expect_test "outvar to json pretty" =
"name": "var_one",
"type": {
"name": "array",
"length": "<< K >>",
"element_type": { "name": "vector", "length": "<< N >>" }
"length": "+ std::to_string(K) +",
"element_type": { "name": "vector", "length": "+ std::to_string(N) +" }
},
"block": "parameters"
} |}]

(*Adds a backslash to all the inner quotes and then
unslash the ones near a plus*)
let replace_cpp_expr s =
s
|> Str.global_replace (Str.regexp {|"|}) {|\"|}
|> Str.global_replace (Str.regexp {|\\"<<|}) {|" <<|}
|> Str.global_replace (Str.regexp {|>>\\"|}) {|<< "|}
|> Str.global_replace (Str.regexp {|\\"\+|}) {|" +|}
|> Str.global_replace (Str.regexp {|\+\\"|}) {|+ "|}
|> Str.global_replace (Str.regexp {|\\n|}) {||}

let wrap_in_quotes s = "\"" ^ s ^ "\""

Expand All @@ -70,4 +72,4 @@ let%expect_test "outvar to json" =
|> out_var_interpolated_json_str |> print_endline ;
[%expect
{|
"[{\"name\":\"var_one\",\"type\":{\"name\":\"array\",\"length\":" << K << ",\"element_type\":{\"name\":\"vector\",\"length\":" << N << "}},\"block\":\"parameters\"}]" |}]
"[{\"name\":\"var_one\",\"type\":{\"name\":\"array\",\"length\":" + std::to_string(K) + ",\"element_type\":{\"name\":\"vector\",\"length\":" + std::to_string(N) + "}},\"block\":\"parameters\"}]" |}]
138 changes: 75 additions & 63 deletions src/stan_math_backend/Stan_math_code_gen.ml
Original file line number Diff line number Diff line change
Expand Up @@ -354,11 +354,15 @@ let includes = "#include <stan/model/model_header.hpp>"
let pp_validate_data ppf (name, st) =
if String.is_suffix ~suffix:"__" name then ()
else
let pp_stdvector ppf args =
let pp_cast ppf x = pf ppf "static_cast<size_t>(%a)" pp_expr x in
pf ppf "@[<hov 2> std::vector<size_t>{@,%a}@]" (list ~sep:comma pp_cast)
args
in
pf ppf "@[<hov 4>context__.validate_dims(@,%S,@,%S,@,%S,@,%a);@]@ "
"data initialization" name
(stantype_prim_str (SizedType.to_unsized st))
pp_call
("context__.to_vec", pp_expr, SizedType.get_dims st)
pp_stdvector (SizedType.get_dims st)

(** Print the constructor of the model class.
Read in data steps:
Expand All @@ -374,26 +378,19 @@ let pp_ctor ppf p =
in
pf ppf "%s(@[<hov 0>%a) : model_base_crtp(0) @]" p.Program.prog_name
(list ~sep:comma string) params ;
let pp_mul ppf () = pf ppf " * " in
let pp_num_param ppf dims =
pf ppf "num_params_r__ += %a;" (list ~sep:pp_mul pp_expr) dims
in
let get_param_st = function
| _, {Program.out_block= Parameters; out_unconstrained_st= st; _} -> (
match SizedType.get_dims st with
| [] -> Some [Expr.Helpers.loop_bottom]
| ls -> Some ls )
| _ -> None
in
let data_idents = List.map ~f:fst p.input_vars |> String.Set.of_list in
let pp_stmt_topdecl_size_only ppf (Stmt.Fixed.({pattern; meta}) as s) =
match pattern with
| Decl {decl_id; decl_type; _} when decl_id <> "pos__" -> (
match decl_type with
| Sized st ->
| Sized st -> (
Locations.pp_smeta ppf meta ;
if Set.mem data_idents decl_id then pp_validate_data ppf (decl_id, st) ;
pp_set_size ppf (decl_id, st, DataOnly)
let is_input_data = Set.mem data_idents decl_id in
match is_input_data with
| true ->
pp_validate_data ppf (decl_id, st) ;
pp_set_size ppf (decl_id, st, DataOnly, false)
| false -> pp_set_size ppf (decl_id, st, DataOnly, true) )
| Unsized _ -> () )
| _ -> pp_statement ppf s
in
Expand All @@ -412,11 +409,29 @@ let pp_ctor ppf p =
pp_located_error ppf
(pp_block, (list ~sep:cut pp_stmt_topdecl_size_only, prepare_data)) ;
cut ppf () ;
pf ppf "num_params_r__ = 0U;@ " ;
pp_located_error ppf
( pp_block
, ( list ~sep:cut pp_num_param
, List.filter_map ~f:get_param_st output_vars ) ) )
let get_param_st = function
| _, {Program.out_block= Parameters; out_unconstrained_st= st; _} -> (
match SizedType.get_dims st with
| [] -> Some [Expr.Helpers.loop_bottom]
| ls -> Some ls )
| _ -> None
in
let output_params = List.filter_map ~f:get_param_st output_vars in
let pp_mul ppf () = pf ppf " * " in
let pp_num_param ppf (dims : Expr.Typed.t list) =
match dims with
| [a] -> pf ppf "@[%a@]@," (list ~sep:pp_mul pp_expr) [a]
| _ -> pf ppf "@[(%a)@]@," (list ~sep:pp_mul pp_expr) dims
in
let pp_plus ppf () = pf ppf " + " in
let pp_set_params ppf pars =
(list ~sep:pp_plus pp_num_param) ppf pars
in
match output_params with
| [] -> pf ppf "num_params_r__ = 0U;@,"
| _ ->
pf ppf "@[<hov 2>num_params_r__ = %a;@]@," pp_set_params
output_params )
, p )

let rec top_level_decls Stmt.Fixed.({pattern; _}) =
Expand All @@ -440,8 +455,8 @@ let pp_model_private ppf {Program.prepare_data; _} =
@param cv_attr Optional parameter to add method attributes.
@param ppbody (?A pretty printer of the method's body)
*)
let pp_method ppf rt name params intro ?(outro = nop) ?(cv_attr = ["const"])
ppbody =
let pp_method ppf rt name params intro ?(outro = nop)
?(cv_attr : string list = ["const"]) ppbody =
pf ppf "@[<v 2>inline %s %s(@[<hov>@,%a@]) %a " rt name
(list ~sep:comma string) params (list ~sep:cut string) cv_attr ;
pf ppf "{@,%a@ " intro () ;
Expand All @@ -453,39 +468,40 @@ let pp_method ppf rt name params intro ?(outro = nop) ?(cv_attr = ["const"])
@param ppf A pretty printer.
*)
let pp_get_param_names ppf {Program.output_vars; _} =
let add_param = fmt "names__.emplace_back(%S);" in
pp_method ppf "void" "get_param_names" ["std::vector<std::string>& names__"]
nop (fun ppf ->
pf ppf "names__.clear();@ " ;
(list ~sep:cut add_param) ppf (List.map ~f:fst output_vars) )
let add_param = fmt "%S" in
pp_method ppf "void" "get_param_names"
["std::vector<std::string>& names__"]
nop
(fun ppf ->
pf ppf "@[<hov 2>names__ = std::vector<std::string>{%a};@]@,"
(list ~sep:comma add_param)
(List.map ~f:fst output_vars) )
~cv_attr:["const"]

(** Print the `get_dims` method of the model class. *)
let pp_get_dims ppf {Program.output_vars; _} =
let pp_cast ppf cast_dims =
pf ppf "static_cast<size_t>(%a)@," pp_expr cast_dims
pf ppf "@[<hov 2>static_cast<size_t>(%a)@]@," pp_expr cast_dims
in
let pp_pack ppf inner_dims =
pf ppf "std::vector<size_t>{@[<hov>@,%a@]}" (list ~sep:comma pp_cast)
inner_dims
in
let pp_add_pack ppf dims =
pf ppf "dimss__.emplace_back(%a);@," pp_pack dims
let pp_add_pack ppf dims = pf ppf "%a" pp_pack dims in
let dim_list =
List.(
map ~f:(fun (_, {Program.out_constrained_st= st; _}) -> st) output_vars)
in
let pp_output_var ppf =
(list ~sep:cut pp_add_pack)
ppf
List.(
map ~f:SizedType.get_dims
(map
~f:(fun (_, {Program.out_constrained_st= st; _}) -> st)
output_vars))
let pp_output_var ppf dims =
(list ~sep:comma pp_add_pack) ppf List.(map ~f:SizedType.get_dims dims)
in
let params = ["std::vector<std::vector<size_t>>& dimss__"] in
let cv_attr = ["const"] in
pp_method ppf "void" "get_dims" params
(const string "dimss__.clear();")
(fun ppf -> pp_output_var ppf)
~cv_attr
pp_method ppf "void" "get_dims"
["std::vector<std::vector<size_t>>& dimss__"]
nop
(fun ppf ->
pf ppf "@[<hov 2>dimss__ = std::vector<std::vector<size_t>>{%a};@]@,"
pp_output_var dim_list )
~cv_attr:["const"]

let pp_method_b ppf rt name params intro ?(outro = nop) ?(cv_attr = ["const"])
body =
Expand Down Expand Up @@ -544,7 +560,7 @@ let rec pp_for_loop_iteratee ?(index_ids = []) ppf (iteratee, dims, pp_body) =
| [] -> pp_body ppf (iteratee, index_ids)
| dim :: dims ->
iter dim (fun ppf (i, idcs) ->
pf ppf "%a" pp_block
pf ppf "@[%a @]" pp_block
(pp_for_loop_iteratee ~index_ids:idcs, (i, dims, pp_body)) )

(** Print the `constrained_param_names` method of the model class. *)
Expand Down Expand Up @@ -575,15 +591,14 @@ let pp_constrained_param_names ppf {Program.output_vars; _} =
let dims = List.rev (SizedType.get_dims st) in
pp_for_loop_iteratee ppf (decl_id, dims, emit_name)
in
let cv_attr = ["const"; "final"] in
pp_method ppf "void" "constrained_param_names" params nop
(fun ppf ->
(list ~sep:cut pp_param_names) ppf paramvars ;
pf ppf "@,if (emit_transformed_parameters__) %a@," pp_block
(list ~sep:cut pp_param_names, tparamvars) ;
pf ppf "@,if (emit_generated_quantities__) %a@," pp_block
(list ~sep:cut pp_param_names, gqvars) )
~cv_attr
~cv_attr:["const"; "final"]

(* Print the `unconstrained_param_names` method of the model class.
This is just a copy of constrained, I need to figure out which one is wrong
Expand Down Expand Up @@ -695,11 +710,9 @@ let pp_log_prob ppf Program.({prog_name; log_prob; _}) =
@param outvars The parameters to gather the sizes for.
*)
let pp_outvar_metadata ppf (method_name, outvars) =
let intro = const string "std::stringstream s__;" in
let outro ppf () = pf ppf "@ return s__.str();" in
let json_str = Cpp_Json.out_var_interpolated_json_str outvars in
let ppbody ppf = pf ppf "s__ << %s;" json_str in
pp_method ppf "std::string" method_name [] intro ~outro ppbody
let ppbody ppf = pf ppf "@[<hov 2>return std::string(%s);@]@," json_str in
pp_method ppf "std::string" method_name [] nop ppbody ~cv_attr:["const"]

(** Print the `get_unconstrained_sizedtypes` method of the model class *)
let pp_unconstrained_types ppf {Program.output_vars; _} =
Expand Down Expand Up @@ -729,14 +742,13 @@ let pp_overloads ppf () =
const bool emit_transformed_parameters = true,
const bool emit_generated_quantities = true,
std::ostream* pstream = nullptr) const {
std::vector<double> vars_vec(vars.size());
std::vector<double> vars_vec;
vars_vec.reserve(vars.size());
std::vector<int> params_i;
write_array_impl(base_rng, params_r, params_i, vars_vec,
emit_transformed_parameters, emit_generated_quantities, pstream);
vars.resize(vars_vec.size());
for (int i = 0; i < vars.size(); ++i) {
vars.coeffRef(i) = vars_vec[i];
}
vars = Eigen::Map<Eigen::Matrix<double,Eigen::Dynamic,1>>(
vars_vec.data(), vars_vec.size());
}

template <typename RNG>
Expand All @@ -746,7 +758,8 @@ let pp_overloads ppf () =
bool emit_transformed_parameters = true,
bool emit_generated_quantities = true,
std::ostream* pstream = nullptr) const {
write_array_impl(base_rng, params_r, params_i, vars, emit_transformed_parameters, emit_generated_quantities, pstream);
write_array_impl(base_rng, params_r, params_i, vars,
emit_transformed_parameters, emit_generated_quantities, pstream);
}

template <bool propto__, bool jacobian__, typename T_>
Expand All @@ -767,13 +780,12 @@ let pp_overloads ppf () =
inline void transform_inits(const stan::io::var_context& context,
Eigen::Matrix<double, Eigen::Dynamic, 1>& params_r,
std::ostream* pstream = nullptr) const final {
std::vector<double> params_r_vec(params_r.size());
std::vector<double> params_r_vec;
params_r_vec.reserve(params_r.size());
std::vector<int> params_i;
transform_inits_impl(context, params_i, params_r_vec, pstream);
params_r.resize(params_r_vec.size());
for (int i = 0; i < params_r.size(); ++i) {
params_r.coeffRef(i) = params_r_vec[i];
}
params_r = Eigen::Map<Eigen::Matrix<double,Eigen::Dynamic,1>>(
params_r_vec.data(), params_r_vec.size());
}
inline void transform_inits(const stan::io::var_context& context,
std::vector<int>& params_i,
Expand Down
28 changes: 17 additions & 11 deletions src/stan_math_backend/Statement_gen.ml
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,13 @@ let pp_profile ppf (pp_body, name, body) =
in
pf ppf "{@;<1 2>@[<v>%s@;@;%a@]@,}" profile pp_body body

let rec contains_eigen = function
let rec contains_eigen (ut : UnsizedType.t) : bool =
match ut with
| UnsizedType.UArray t -> contains_eigen t
| UMatrix | URowVector | UVector -> true
| UInt | UReal | UMathLibraryFunction | UFun _ -> false

let pp_set_size ppf (decl_id, st, adtype) =
let pp_set_size ppf (decl_id, st, adtype, (needs_filled : bool)) =
(* TODO: generate optimal adtypes for expressions and declarations *)
let real_nan =
match adtype with
Expand All @@ -38,25 +39,30 @@ let pp_set_size ppf (decl_id, st, adtype) =
| SMatrix (d1, d2) -> pf ppf "%a(%a, %a)" pp_st st pp_expr d1 pp_expr d2
| SArray (t, d) -> pf ppf "%a(%a, %a)" pp_st st pp_expr d pp_size_ctor t
in
pf ppf "@[<hov 2>%s = %a;@]@," decl_id pp_size_ctor st ;
if contains_eigen (SizedType.to_unsized st) then
pf ppf "@[<hov 2>stan::math::fill(%s, %s);@]@," decl_id real_nan
let print_fill ppf st =
match (contains_eigen (SizedType.to_unsized st), needs_filled) with
| true, true -> pf ppf "stan::math::fill(%s, %s);" decl_id real_nan
| _, _ -> ()
in
pf ppf "@[<hov 0>%s = %a;@,%a @]@," decl_id pp_size_ctor st print_fill st

let%expect_test "set size mat array" =
let int = Expr.Helpers.int in
strf "@[<v>%a@]" pp_set_size
("d", SArray (SArray (SMatrix (int 2, int 3), int 4), int 5), DataOnly)
( "d"
, SArray (SArray (SMatrix (int 2, int 3), int 4), int 5)
, DataOnly
, false )
|> print_endline ;
[%expect
{|
d = std::vector<std::vector<Eigen::Matrix<double, -1, -1>>>(5, std::vector<Eigen::Matrix<double, -1, -1>>(4, Eigen::Matrix<double, -1, -1>(2, 3)));
stan::math::fill(d, std::numeric_limits<double>::quiet_NaN()); |}]
d = std::vector<std::vector<Eigen::Matrix<double, -1, -1>>>(5, std::vector<Eigen::Matrix<double, -1, -1>>(4, Eigen::Matrix<double, -1, -1>(2, 3))); |}]

(** [pp_for_loop ppf (loopvar, lower, upper, pp_body, body)] tries to
pretty print a for-loop from lower to upper given some loopvar.*)
let pp_for_loop ppf (loopvar, lower, upper, pp_body, body) =
pf ppf "@[<hov>for (@[<hov>int %s = %a;@ %s <= %a;@ ++%s@])" loopvar pp_expr
lower loopvar pp_expr upper loopvar ;
pf ppf "@[for (@[int %s = %a;@ %s <= %a;@ ++%s@])" loopvar pp_expr lower
loopvar pp_expr upper loopvar ;
pf ppf " %a@]" pp_body body

let rec integer_el_type = function
Expand All @@ -76,7 +82,7 @@ let pp_decl ppf (vident, ut, adtype) =
let pp_sized_decl ppf (vident, st, adtype) =
pf ppf "%a@,%a" pp_decl
(vident, SizedType.to_unsized st, adtype)
pp_set_size (vident, st, adtype)
pp_set_size (vident, st, adtype, true)

let pp_possibly_sized_decl ppf (vident, pst, adtype) =
match pst with
Expand Down
Loading