Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Properly handle nested container expression promotion #1292

Merged
merged 3 commits into from
Mar 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 9 additions & 7 deletions src/frontend/Promotion.ml
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,10 @@ let rec promote (exp : Ast.typed_expression) prom =
| ArrayExpr es ->
let pes = List.map ~f:(fun e -> promote e prom) es in
let fst = List.hd_exn pes in
let type_, ad_level = (fst.emeta.type_, fst.emeta.ad_level) in
{ expr= ArrayExpr pes
; emeta=
{ exp.emeta with
type_= UnsizedType.promote_container exp.emeta.type_ type_
; ad_level } }
let type_, ad_level =
( UnsizedType.promote_container exp.emeta.type_ fst.emeta.type_
, fst.emeta.ad_level ) in
{expr= ArrayExpr pes; emeta= {exp.emeta with type_; ad_level}}
| RowVectorExpr (_ :: _ as es) ->
let pes = List.map ~f:(fun e -> promote e prom) es in
let fst = List.hd_exn pes in
Expand All @@ -79,7 +77,11 @@ let promote_list es promotions = List.map2_exn es promotions ~f:promote
*)
let rec get_type_promotion_exn (ad_orig, ty_orig) (ad_expect, ty_expect) =
match (ty_orig, ty_expect) with
| UnsizedType.(UReal, (UReal | UInt) | UVector, UVector | UMatrix, UMatrix)
| UnsizedType.(
( UReal, (UReal | UInt)
| UVector, UVector
| URowVector, URowVector
| UMatrix, UMatrix ))
when ad_orig <> ad_expect ->
ToVar
| UComplex, (UReal | UInt | UComplex)
Expand Down
21 changes: 10 additions & 11 deletions src/frontend/Typechecker.ml
Original file line number Diff line number Diff line change
Expand Up @@ -296,10 +296,9 @@ let check_variable cf loc tenv id =
let ad_level, type_ = check_id cf loc tenv id in
mk_typed_expression ~expr:(Variable id) ~ad_level ~type_ ~loc

let get_consistent_types ad_level type_ es =
let get_consistent_types type_ es =
let ad =
UnsizedType.lub_ad_type
(ad_level :: List.map ~f:(fun e -> e.emeta.ad_level) es) in
UnsizedType.lub_ad_type (List.map ~f:(fun e -> e.emeta.ad_level) es) in
let f state e =
match state with
| Error e -> Error e
Expand All @@ -319,8 +318,8 @@ let check_array_expr loc es =
| [] ->
(* NB: This is actually disallowed by parser *)
Semantic_error.empty_array loc |> error
| {emeta= {ad_level; type_; _}; _} :: _ -> (
match get_consistent_types ad_level type_ es with
| {emeta= {type_; _}; _} :: _ -> (
match get_consistent_types type_ es with
| Error (ty, meta) ->
Semantic_error.mismatched_array_types meta.loc ty meta.type_ |> error
| Ok (ad_level, type_, promotions) ->
Expand All @@ -331,26 +330,26 @@ let check_array_expr loc es =

let check_rowvector loc es =
match es with
| {emeta= {ad_level; type_= UnsizedType.URowVector; _}; _} :: _ -> (
match get_consistent_types ad_level URowVector es with
| {emeta= {type_= UnsizedType.URowVector; _}; _} :: _ -> (
match get_consistent_types URowVector es with
| Ok (ad_level, typ, promotions) ->
mk_typed_expression
~expr:(RowVectorExpr (Promotion.promote_list es promotions))
~ad_level
~type_:(if typ = UComplex then UComplexMatrix else UMatrix)
~type_:(if typ = UComplexRowVector then UComplexMatrix else UMatrix)
~loc
| Error (_, meta) ->
Semantic_error.invalid_matrix_types meta.loc meta.type_ |> error )
| {emeta= {ad_level; type_= UnsizedType.UComplexRowVector; _}; _} :: _ -> (
match get_consistent_types ad_level UComplexRowVector es with
| {emeta= {type_= UnsizedType.UComplexRowVector; _}; _} :: _ -> (
match get_consistent_types UComplexRowVector es with
| Ok (ad_level, _, promotions) ->
mk_typed_expression
~expr:(RowVectorExpr (Promotion.promote_list es promotions))
~ad_level ~type_:UComplexMatrix ~loc
| Error (_, meta) ->
Semantic_error.invalid_matrix_types meta.loc meta.type_ |> error )
| _ -> (
match get_consistent_types DataOnly UReal es with
match get_consistent_types UReal es with
| Ok (ad_level, typ, promotions) ->
mk_typed_expression
~expr:(RowVectorExpr (Promotion.promote_list es promotions))
Expand Down
26 changes: 21 additions & 5 deletions src/middle/UnsizedType.ml
Original file line number Diff line number Diff line change
Expand Up @@ -79,11 +79,27 @@ and pp_returntype ppf = function
let autodifftype_can_convert at1 at2 =
match (at1, at2) with DataOnly, AutoDiffable -> false | _ -> true

let rec lub_ad_type = function
| [] -> DataOnly
| x :: xs ->
let y = lub_ad_type xs in
if compare_autodifftype x y < 0 then y else x
let lub_ad_type xs =
List.max_elt ~compare:compare_autodifftype xs
|> Option.value ~default:DataOnly

let%expect_test "lub_ad_type1" =
let ads = [DataOnly; DataOnly; DataOnly; AutoDiffable] in
let lub = lub_ad_type ads in
print_s [%sexp (lub : autodifftype)] ;
[%expect "AutoDiffable"]

let%expect_test "lub_ad_type2" =
let ads = [DataOnly; DataOnly; DataOnly] in
let lub = lub_ad_type ads in
print_s [%sexp (lub : autodifftype)] ;
[%expect "DataOnly"]

let%expect_test "lub_ad_type3" =
let ads = [AutoDiffable; DataOnly; DataOnly; DataOnly] in
let lub = lub_ad_type ads in
print_s [%sexp (lub : autodifftype)] ;
[%expect "AutoDiffable"]

(** Given two types find the minimal type both can convert to *)
let rec common_type = function
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,6 @@ transformed parameters {
complex_matrix[2,2] zs = [[1,2],[3,4]];
row_vector[2] x = [1,2];
complex_row_vector[2] zx = x;

complex_matrix[2,2] cm = [[1,2], [3,4i]];
}
79 changes: 69 additions & 10 deletions test/integration/good/code-gen/complex_numbers/cpp.expected
Original file line number Diff line number Diff line change
Expand Up @@ -7968,12 +7968,13 @@ namespace complex_vectors_model_namespace {
using stan::model::model_base_crtp;
using namespace stan::math;
stan::math::profile_map profiles__;
static constexpr std::array<const char*, 5> locations_array__ =
static constexpr std::array<const char*, 6> locations_array__ =
{" (found before start of program)",
" (in 'complex_vectors.stan', line 4, column 3 to column 37)",
" (in 'complex_vectors.stan', line 5, column 3 to column 42)",
" (in 'complex_vectors.stan', line 6, column 3 to column 27)",
" (in 'complex_vectors.stan', line 7, column 3 to column 32)"};
" (in 'complex_vectors.stan', line 7, column 3 to column 32)",
" (in 'complex_vectors.stan', line 9, column 3 to column 44)"};
class complex_vectors_model final : public model_base_crtp<complex_vectors_model> {
private:

Expand Down Expand Up @@ -8071,6 +8072,23 @@ class complex_vectors_model final : public model_base_crtp<complex_vectors_model
std::complex<local_scalar_t__>(DUMMY_VAR__, DUMMY_VAR__));
current_statement__ = 4;
stan::model::assign(zx, x, "assigning variable zx");
Eigen::Matrix<std::complex<local_scalar_t__>,-1,-1> cm =
Eigen::Matrix<std::complex<local_scalar_t__>,-1,-1>::Constant(2, 2,
std::complex<local_scalar_t__>(DUMMY_VAR__, DUMMY_VAR__));
current_statement__ = 5;
stan::model::assign(cm,
stan::math::to_matrix(
std::vector<Eigen::Matrix<std::complex<double>,1,-1>>{(Eigen::Matrix<std::complex<double>,1,-1>(2) <<
stan::math::to_complex(
1, 0),
stan::math::to_complex(
2, 0)).finished(),
(Eigen::Matrix<std::complex<double>,1,-1>(2) <<
stan::math::to_complex(
3, 0),
stan::math::to_complex(
0, 4)).finished()}),
"assigning variable cm");
} catch (const std::exception& e) {
stan::lang::rethrow_located(e, locations_array__[current_statement__]);
}
Expand Down Expand Up @@ -8123,6 +8141,10 @@ class complex_vectors_model final : public model_base_crtp<complex_vectors_model
Eigen::Matrix<std::complex<double>,1,-1>::Constant(2,
std::complex<double>(std::numeric_limits<double>::quiet_NaN(),
std::numeric_limits<double>::quiet_NaN()));
Eigen::Matrix<std::complex<double>,-1,-1> cm =
Eigen::Matrix<std::complex<double>,-1,-1>::Constant(2, 2,
std::complex<double>(std::numeric_limits<double>::quiet_NaN(),
std::numeric_limits<double>::quiet_NaN()));
if (stan::math::logical_negation(
(stan::math::primitive_value(emit_transformed_parameters__) ||
stan::math::primitive_value(emit_generated_quantities__)))) {
Expand Down Expand Up @@ -8158,11 +8180,26 @@ class complex_vectors_model final : public model_base_crtp<complex_vectors_model
"assigning variable x");
current_statement__ = 4;
stan::model::assign(zx, x, "assigning variable zx");
current_statement__ = 5;
stan::model::assign(cm,
stan::math::to_matrix(
std::vector<Eigen::Matrix<std::complex<double>,1,-1>>{(Eigen::Matrix<std::complex<double>,1,-1>(2) <<
stan::math::to_complex(
1, 0),
stan::math::to_complex(
2, 0)).finished(),
(Eigen::Matrix<std::complex<double>,1,-1>(2) <<
stan::math::to_complex(
3, 0),
stan::math::to_complex(
0, 4)).finished()}),
"assigning variable cm");
if (emit_transformed_parameters__) {
out__.write(z);
out__.write(zs);
out__.write(x);
out__.write(zx);
out__.write(cm);
}
if (stan::math::logical_negation(emit_generated_quantities__)) {
return ;
Expand Down Expand Up @@ -8197,7 +8234,7 @@ class complex_vectors_model final : public model_base_crtp<complex_vectors_model
emit_generated_quantities__ = true) const {
names__ = std::vector<std::string>{};
if (emit_transformed_parameters__) {
std::vector<std::string> temp{"z", "zs", "x", "zx"};
std::vector<std::string> temp{"z", "zs", "x", "zx", "cm"};
names__.reserve(names__.size() + temp.size());
names__.insert(names__.end(), temp.begin(), temp.end());
}
Expand All @@ -8216,7 +8253,9 @@ class complex_vectors_model final : public model_base_crtp<complex_vectors_model
static_cast<size_t>(2), static_cast<size_t>(2)},
std::vector<size_t>{static_cast<size_t>(2)},
std::vector<size_t>{static_cast<size_t>(2),
static_cast<size_t>(2)}};
static_cast<size_t>(2)},
std::vector<size_t>{static_cast<size_t>(2),
static_cast<size_t>(2), static_cast<size_t>(2)}};
dimss__.reserve(dimss__.size() + temp.size());
dimss__.insert(dimss__.end(), temp.begin(), temp.end());
}
Expand Down Expand Up @@ -8253,6 +8292,16 @@ class complex_vectors_model final : public model_base_crtp<complex_vectors_model
param_names__.emplace_back(std::string() + "zx" + '.' +
std::to_string(sym1__) + '.' + "imag");
}
for (int sym1__ = 1; sym1__ <= 2; ++sym1__) {
for (int sym2__ = 1; sym2__ <= 2; ++sym2__) {
param_names__.emplace_back(std::string() + "cm" + '.' +
std::to_string(sym2__) + '.' + std::to_string(sym1__) + '.' +
"real");
param_names__.emplace_back(std::string() + "cm" + '.' +
std::to_string(sym2__) + '.' + std::to_string(sym1__) + '.' +
"imag");
}
}
}
if (emit_generated_quantities__) {}
}
Expand Down Expand Up @@ -8287,14 +8336,24 @@ class complex_vectors_model final : public model_base_crtp<complex_vectors_model
param_names__.emplace_back(std::string() + "zx" + '.' +
std::to_string(sym1__) + '.' + "imag");
}
for (int sym1__ = 1; sym1__ <= 2; ++sym1__) {
for (int sym2__ = 1; sym2__ <= 2; ++sym2__) {
param_names__.emplace_back(std::string() + "cm" + '.' +
std::to_string(sym2__) + '.' + std::to_string(sym1__) + '.' +
"real");
param_names__.emplace_back(std::string() + "cm" + '.' +
std::to_string(sym2__) + '.' + std::to_string(sym1__) + '.' +
"imag");
}
}
}
if (emit_generated_quantities__) {}
}
inline std::string get_constrained_sizedtypes() const {
return std::string("[{\"name\":\"z\",\"type\":{\"name\":\"complex_vector\",\"length\":" + std::to_string(3) + "},\"block\":\"transformed_parameters\"},{\"name\":\"zs\",\"type\":{\"name\":\"complex_matrix\",\"rows\":" + std::to_string(2) + ",\"cols\":" + std::to_string(2) + "},\"block\":\"transformed_parameters\"},{\"name\":\"x\",\"type\":{\"name\":\"vector\",\"length\":" + std::to_string(2) + "},\"block\":\"transformed_parameters\"},{\"name\":\"zx\",\"type\":{\"name\":\"complex_vector\",\"length\":" + std::to_string(2) + "},\"block\":\"transformed_parameters\"}]");
return std::string("[{\"name\":\"z\",\"type\":{\"name\":\"complex_vector\",\"length\":" + std::to_string(3) + "},\"block\":\"transformed_parameters\"},{\"name\":\"zs\",\"type\":{\"name\":\"complex_matrix\",\"rows\":" + std::to_string(2) + ",\"cols\":" + std::to_string(2) + "},\"block\":\"transformed_parameters\"},{\"name\":\"x\",\"type\":{\"name\":\"vector\",\"length\":" + std::to_string(2) + "},\"block\":\"transformed_parameters\"},{\"name\":\"zx\",\"type\":{\"name\":\"complex_vector\",\"length\":" + std::to_string(2) + "},\"block\":\"transformed_parameters\"},{\"name\":\"cm\",\"type\":{\"name\":\"complex_matrix\",\"rows\":" + std::to_string(2) + ",\"cols\":" + std::to_string(2) + "},\"block\":\"transformed_parameters\"}]");
}
inline std::string get_unconstrained_sizedtypes() const {
return std::string("[{\"name\":\"z\",\"type\":{\"name\":\"complex_vector\",\"length\":" + std::to_string(3) + "},\"block\":\"transformed_parameters\"},{\"name\":\"zs\",\"type\":{\"name\":\"complex_matrix\",\"rows\":" + std::to_string(2) + ",\"cols\":" + std::to_string(2) + "},\"block\":\"transformed_parameters\"},{\"name\":\"x\",\"type\":{\"name\":\"vector\",\"length\":" + std::to_string(2) + "},\"block\":\"transformed_parameters\"},{\"name\":\"zx\",\"type\":{\"name\":\"complex_vector\",\"length\":" + std::to_string(2) + "},\"block\":\"transformed_parameters\"}]");
return std::string("[{\"name\":\"z\",\"type\":{\"name\":\"complex_vector\",\"length\":" + std::to_string(3) + "},\"block\":\"transformed_parameters\"},{\"name\":\"zs\",\"type\":{\"name\":\"complex_matrix\",\"rows\":" + std::to_string(2) + ",\"cols\":" + std::to_string(2) + "},\"block\":\"transformed_parameters\"},{\"name\":\"x\",\"type\":{\"name\":\"vector\",\"length\":" + std::to_string(2) + "},\"block\":\"transformed_parameters\"},{\"name\":\"zx\",\"type\":{\"name\":\"complex_vector\",\"length\":" + std::to_string(2) + "},\"block\":\"transformed_parameters\"},{\"name\":\"cm\",\"type\":{\"name\":\"complex_matrix\",\"rows\":" + std::to_string(2) + ",\"cols\":" + std::to_string(2) + "},\"block\":\"transformed_parameters\"}]");
}
// Begin method overload boilerplate
template <typename RNG> inline void
Expand All @@ -8304,8 +8363,8 @@ class complex_vectors_model final : public model_base_crtp<complex_vectors_model
emit_generated_quantities = true, std::ostream*
pstream = nullptr) const {
const size_t num_params__ = 0;
const size_t num_transformed = emit_transformed_parameters * (((((3 * 2)
+ ((2 * 2) * 2)) + 2) + (2 * 2)));
const size_t num_transformed = emit_transformed_parameters * ((((((3 * 2)
+ ((2 * 2) * 2)) + 2) + (2 * 2)) + ((2 * 2) * 2)));
const size_t num_gen_quantities = emit_generated_quantities * (0);
const size_t num_to_write = num_params__ + num_transformed +
num_gen_quantities;
Expand All @@ -8322,8 +8381,8 @@ class complex_vectors_model final : public model_base_crtp<complex_vectors_model
emit_generated_quantities = true, std::ostream*
pstream = nullptr) const {
const size_t num_params__ = 0;
const size_t num_transformed = emit_transformed_parameters * (((((3 * 2)
+ ((2 * 2) * 2)) + 2) + (2 * 2)));
const size_t num_transformed = emit_transformed_parameters * ((((((3 * 2)
+ ((2 * 2) * 2)) + 2) + (2 * 2)) + ((2 * 2) * 2)));
const size_t num_gen_quantities = emit_generated_quantities * (0);
const size_t num_to_write = num_params__ + num_transformed +
num_gen_quantities;
Expand Down
26 changes: 26 additions & 0 deletions test/integration/good/code-gen/container-promotion.stan
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
transformed data {
array[2, 2] real Arr;
Arr = {{1, 2}, {3, 4.5}};
}
parameters {
real y;
}
transformed parameters {
vector[2] V;
V = [1, y]';

array[2] row_vector[2] arRV;
arRV = {[1, 2], [3, y]};

array[2, 2] real Mar;
Mar = {{1, 2}, {3, y}};

matrix[2, 2] M;
M = [[1, 2], [3, y]];

array[2, 2, 2] real deep_Mar;
deep_Mar = {{{0, 0}, {0, 0}}, {{1, 2}, {3, y}}};

array[2] matrix[2, 2] deep_M;
deep_M = {[[0, 0], [0, 0]], [[1, 2], [y, 4]]};
}
Loading