Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix two issues with tuple functions #1356

Merged
merged 16 commits into from
Sep 8, 2023
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Jenkinsfile
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def runPerformanceTests(String testsPath, String stancFlags = ""){
sh """
cd performance-tests-cmdstan/cmdstan
echo 'O=0' >> make/local
echo 'CXXFLAGS+=-Wall' >> make/local
make -j${env.PARALLEL} build; cd ..
./runPerformanceTests.py -j${env.PARALLEL} --runs=0 ${testsPath}
"""
Expand Down
20 changes: 10 additions & 10 deletions src/analysis_and_optimization/Memory_patterns.ml
Original file line number Diff line number Diff line change
Expand Up @@ -189,17 +189,17 @@ and query_initial_demotable_funs (in_loop : bool) (acc : string Set.Poly.t)
match is_fun_soa_supported name exprs with
| true -> Set.Poly.union acc demoted_eigen_names
| false -> Set.Poly.union acc demoted_and_top_level_names ) )
| CompilerInternal (Internal_fun.FnMakeArray | FnMakeRowVec) ->
| CompilerInternal (Internal_fun.FnMakeArray | FnMakeRowVec | FnMakeTuple) ->
Set.Poly.union acc demoted_and_top_level_names
| CompilerInternal (_ : 'a Internal_fun.t) -> acc
| UserDefined ((_ : string), (_ : bool Fun_kind.suffix)) ->
Set.Poly.union acc demoted_and_top_level_names

(**
* Recurse through subexpressions and return a list of Unsized types.
* Recursion continues until
* 1. A non-autodiffable type is found
* 2. An autodiffable scalar is found
(**
* Recurse through subexpressions and return a list of Unsized types.
* Recursion continues until
* 1. A non-autodiffable type is found
* 2. An autodiffable scalar is found
* 3. A `Var` type is found that is an autodiffable matrix
*)
let rec extract_nonderived_admatrix_types
Expand All @@ -225,11 +225,11 @@ let rec extract_nonderived_admatrix_types
else [(adlevel, type_)]

(**
* Recurse through functions to find nonderived ad matrix types.
* Special cases for StanLib functions are for
* Recurse through functions to find nonderived ad matrix types.
* Special cases for StanLib functions are for
* - `check_matching_dims`: compiler function that has no effect on optimization
* - `rep_*vector` These are templated in the C++ to cast up to `Var<Matrix>` types
* - `rep_matrix`. When it's only a scalar being propogated an math library overload can upcast to `Var<Matrix>`
* - `rep_*vector` These are templated in the C++ to cast up to `Var<Matrix>` types
* - `rep_matrix`. When it's only a scalar being propogated an math library overload can upcast to `Var<Matrix>`
*)
and extract_nonderived_admatrix_types_fun (kind : 'a Fun_kind.t)
(exprs : Expr.Typed.t list) =
Expand Down
18 changes: 13 additions & 5 deletions src/frontend/Ast_to_Mir.ml
Original file line number Diff line number Diff line change
Expand Up @@ -143,10 +143,17 @@ let truncate_dist ud_dists (id : Ast.identifier)
, Some y ) } in
let funapp meta kind name args =
Expr.{Fixed.pattern= FunApp (trans_fn_kind kind name, args); meta} in
let ensure_type tp lb : Expr.Typed.t =
WardBrian marked this conversation as resolved.
Show resolved Hide resolved
match (tp, Expr.Typed.type_of lb) with
| UnsizedType.UInt, _ -> lb
| _, UInt ->
{ pattern= Promotion (lb, UReal, lb.meta.adlevel)
; meta= {lb.meta with type_= UReal} }
| _ -> lb in
let inclusive_bound tp (lb : Expr.Typed.t) =
if UnsizedType.is_int_type tp then
Expr.Helpers.binop lb Minus Expr.Helpers.one
else lb in
else ensure_type tp lb in
let size_adjust e =
if
(not (UnsizedType.is_container ast_obs.Ast.emeta.type_))
Expand All @@ -172,18 +179,19 @@ let truncate_dist ud_dists (id : Ast.identifier)
(funapp lb.meta fk fn
(inclusive_bound tp lb :: trans_exprs ast_args) ) ) ) ]
| TruncateDownFrom ub ->
let fk, fn, _ = find_function_info cdf_suffices in
let fk, fn, tp = find_function_info cdf_suffices in
let ub = trans_expr ub in
[ trunc Greater "max" ub
(targetme ub.meta.loc
(size_adjust (funapp ub.meta fk fn (ub :: trans_exprs ast_args))) )
]
(size_adjust
(funapp ub.meta fk fn
(ensure_type tp ub :: trans_exprs ast_args) ) ) ) ]
| TruncateBetween (lb, ub) ->
let fk, fn, tp = find_function_info cdf_suffices in
let lb, ub = (trans_expr lb, trans_expr ub) in
let expr args =
funapp ub.meta (Ast.StanLib FnPlain) "log_diff_exp"
[ funapp ub.meta fk fn (ub :: args)
[ funapp ub.meta fk fn (ensure_type tp ub :: args)
; funapp ub.meta fk fn (inclusive_bound tp lb :: args) ] in
let statement =
match
Expand Down
31 changes: 23 additions & 8 deletions src/stan_math_backend/Cpp.ml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ module Types = struct
let local_scalar = TypeLiteral "local_scalar_t__"

(** A [std::vector<t>] *)
let std_vector t = StdVector t
let rec std_vector ?(dims = 1) t =
if dims = 0 then t else std_vector ~dims:(dims - 1) (StdVector t)

let bool = TypeLiteral "bool"
let complex s = Complex s
Expand Down Expand Up @@ -266,6 +267,7 @@ module Decls = struct
VariableDefn
(make_variable_defn ~type_:Int ~name:"current_statement__"
~init:(Assignment (Literal "0")) () )
:: Stmts.unused "current_statement__"

let dummy_var =
VariableDefn
Expand Down Expand Up @@ -299,7 +301,7 @@ end

type template_parameter =
| Typename of string (** The name of a template typename *)
| RequireIs of string * string
| RequireAllCondition of [`Exact of string | `OneOf of string list] * type_
(** A C++ type trait (e.g. is_arithmetic) and the template
name which needs to satisfy that.
These are collated into one require_all_t<> *)
Expand Down Expand Up @@ -412,15 +414,22 @@ module Printing = struct

let pp_requires ~default ppf requires =
if not (List.is_empty requires) then
let pp_require ppf (trait, name) = pf ppf "%s<%s>" trait name in
let pp_single_require t ppf trait = pf ppf "%s<%a>" trait pp_type_ t in
let pp_require ppf (req, t) =
match req with
| `Exact trait -> pp_single_require t ppf trait
| `OneOf traits ->
pf ppf "stan::math::disjunction<@[%a@]>"
(list ~sep:comma (pp_single_require t))
traits in
pf ppf ",@ stan::require_all_t<@[%a@]>*%s"
(list ~sep:comma pp_require)
requires
(if default then " = nullptr" else "")

(**
Pretty print a list of templates as [template <parameter-list>].name
This function pools together [RequireIs] nodes into a [require_all_t]
This function pools together [RequireAllCondition] nodes into a [require_all_t]
*)
let pp_template ~default ppf template_parameters =
let pp_basic_template ppf = function
Expand All @@ -432,7 +441,7 @@ module Printing = struct
if not (List.is_empty template_parameters) then
let templates, requires =
List.partition_map template_parameters ~f:(function
| RequireIs (trait, name) -> Second (trait, name)
| RequireAllCondition (trait, name) -> Second (trait, name)
| Typename name -> First (`Typename name)
| Bool name -> First (`Bool name)
| Require (requirement, args) -> First (`Require (requirement, args)) )
Expand Down Expand Up @@ -727,7 +736,7 @@ module Tests = struct
let ts =
let open Types in
[ matrix (complex local_scalar); const_char_array 43
; std_vector (std_vector Double); const_ref (TemplateType "T0__") ] in
; std_vector ~dims:2 Double; const_ref (TemplateType "T0__") ] in
let open Fmt in
pf stdout "@[<v>%a@]" (list ~sep:comma Printing.pp_type_) ts ;
[%expect
Expand Down Expand Up @@ -762,15 +771,21 @@ module Tests = struct
let funs =
[ make_fun_defn
~templates_init:
([[Typename "T0__"; RequireIs ("stan::is_foobar", "T0__")]], true)
( [ [ Typename "T0__"
; RequireAllCondition
(`Exact "stan::is_foobar", TemplateType "T0__") ] ]
, true )
~name:"foobar" ~return_type:Void ~inline:true ()
; (let s =
[ Comment "A potentially \n long comment"
; Expression (Assign (Var "foo", Literal "3")) ] in
let rethrow = Stmts.rethrow_located s in
make_fun_defn
~templates_init:
([[Typename "T0__"; RequireIs ("stan::is_foobar", "T0__")]], false)
( [ [ Typename "T0__"
; RequireAllCondition
(`Exact "stan::is_foobar", TemplateType "T0__") ] ]
, false )
~name:"foobar" ~return_type:Void ~inline:true ~body:rethrow () ) ]
in
let open Fmt in
Expand Down
35 changes: 17 additions & 18 deletions src/stan_math_backend/Lower_expr.ml
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,8 @@ and lower_functionals fname suffix es mem_pattern =
| _, args -> (fname, args @ [msgs]) in
let fname = stan_namespace_qualify fname in
let templates = templates false suffix in
Exprs.templated_fun_call fname templates (lower_exprs args) in
Exprs.templated_fun_call fname templates
(lower_exprs ~promote_reals:true args) in
Some lower_hov

and lower_fun_app suffix fname es mem_pattern
Expand All @@ -400,23 +401,21 @@ and lower_user_defined_fun f suffix es =

and lower_compiler_internal ad ut f es =
let open Expression_syntax in
let gen_tuple_literal es : expr =
(* NB: This causes some inefficencies such as eagerly
evaluating eigen expressions and copying data vectors *)
let is_simple (e : Expr.Typed.t) =
match e.pattern with
| Var _ -> e.meta.adlevel <> DataOnly
| Lit _ -> true
| Promotion ({pattern= Var _ | Lit _; _}, _, _) -> is_scalar e
| _ -> false in
if List.for_all ~f:is_simple es then
fun_call "std::forward_as_tuple" (lower_exprs es)
else
Constructor
( Tuple
(List.map es ~f:(fun {meta= {adlevel; type_; _}; _} ->
lower_unsizedtype_local adlevel type_ ) )
, lower_exprs es ) in
let gen_tuple_literal (es : Expr.Typed.t list) : expr =
(* we make full copies of tuples
due to a lack of templating sophistication
in function generation *)
let types =
List.map es ~f:(fun {meta= {adlevel; type_; _}; _} ->
let base_type = lower_unsizedtype_local adlevel type_ in
if
UnsizedType.is_dataonlytype adlevel
&& not
( UnsizedType.is_scalar_type type_
|| UnsizedType.contains_tuple type_ )
then Types.const_ref base_type
else base_type ) in
Constructor (Tuple types, lower_exprs es) in
match f with
| Internal_fun.FnMakeArray ->
let ut =
Expand Down
Loading