Skip to content

Commit

Permalink
Merge pull request #1292 from stan-dev/fix/1291-matrix-expression-adtype
Browse files Browse the repository at this point in the history
Properly handle nested container expression promotion
  • Loading branch information
WardBrian authored Mar 16, 2023
2 parents 8bc5aa9 + c3734f5 commit 2eb708a
Show file tree
Hide file tree
Showing 7 changed files with 711 additions and 33 deletions.
16 changes: 9 additions & 7 deletions src/frontend/Promotion.ml
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,10 @@ let rec promote (exp : Ast.typed_expression) prom =
| ArrayExpr es ->
let pes = List.map ~f:(fun e -> promote e prom) es in
let fst = List.hd_exn pes in
let type_, ad_level = (fst.emeta.type_, fst.emeta.ad_level) in
{ expr= ArrayExpr pes
; emeta=
{ exp.emeta with
type_= UnsizedType.promote_container exp.emeta.type_ type_
; ad_level } }
let type_, ad_level =
( UnsizedType.promote_container exp.emeta.type_ fst.emeta.type_
, fst.emeta.ad_level ) in
{expr= ArrayExpr pes; emeta= {exp.emeta with type_; ad_level}}
| RowVectorExpr (_ :: _ as es) ->
let pes = List.map ~f:(fun e -> promote e prom) es in
let fst = List.hd_exn pes in
Expand All @@ -79,7 +77,11 @@ let promote_list es promotions = List.map2_exn es promotions ~f:promote
*)
let rec get_type_promotion_exn (ad_orig, ty_orig) (ad_expect, ty_expect) =
match (ty_orig, ty_expect) with
| UnsizedType.(UReal, (UReal | UInt) | UVector, UVector | UMatrix, UMatrix)
| UnsizedType.(
( UReal, (UReal | UInt)
| UVector, UVector
| URowVector, URowVector
| UMatrix, UMatrix ))
when ad_orig <> ad_expect ->
ToVar
| UComplex, (UReal | UInt | UComplex)
Expand Down
21 changes: 10 additions & 11 deletions src/frontend/Typechecker.ml
Original file line number Diff line number Diff line change
Expand Up @@ -296,10 +296,9 @@ let check_variable cf loc tenv id =
let ad_level, type_ = check_id cf loc tenv id in
mk_typed_expression ~expr:(Variable id) ~ad_level ~type_ ~loc

let get_consistent_types ad_level type_ es =
let get_consistent_types type_ es =
let ad =
UnsizedType.lub_ad_type
(ad_level :: List.map ~f:(fun e -> e.emeta.ad_level) es) in
UnsizedType.lub_ad_type (List.map ~f:(fun e -> e.emeta.ad_level) es) in
let f state e =
match state with
| Error e -> Error e
Expand All @@ -319,8 +318,8 @@ let check_array_expr loc es =
| [] ->
(* NB: This is actually disallowed by parser *)
Semantic_error.empty_array loc |> error
| {emeta= {ad_level; type_; _}; _} :: _ -> (
match get_consistent_types ad_level type_ es with
| {emeta= {type_; _}; _} :: _ -> (
match get_consistent_types type_ es with
| Error (ty, meta) ->
Semantic_error.mismatched_array_types meta.loc ty meta.type_ |> error
| Ok (ad_level, type_, promotions) ->
Expand All @@ -331,26 +330,26 @@ let check_array_expr loc es =

let check_rowvector loc es =
match es with
| {emeta= {ad_level; type_= UnsizedType.URowVector; _}; _} :: _ -> (
match get_consistent_types ad_level URowVector es with
| {emeta= {type_= UnsizedType.URowVector; _}; _} :: _ -> (
match get_consistent_types URowVector es with
| Ok (ad_level, typ, promotions) ->
mk_typed_expression
~expr:(RowVectorExpr (Promotion.promote_list es promotions))
~ad_level
~type_:(if typ = UComplex then UComplexMatrix else UMatrix)
~type_:(if typ = UComplexRowVector then UComplexMatrix else UMatrix)
~loc
| Error (_, meta) ->
Semantic_error.invalid_matrix_types meta.loc meta.type_ |> error )
| {emeta= {ad_level; type_= UnsizedType.UComplexRowVector; _}; _} :: _ -> (
match get_consistent_types ad_level UComplexRowVector es with
| {emeta= {type_= UnsizedType.UComplexRowVector; _}; _} :: _ -> (
match get_consistent_types UComplexRowVector es with
| Ok (ad_level, _, promotions) ->
mk_typed_expression
~expr:(RowVectorExpr (Promotion.promote_list es promotions))
~ad_level ~type_:UComplexMatrix ~loc
| Error (_, meta) ->
Semantic_error.invalid_matrix_types meta.loc meta.type_ |> error )
| _ -> (
match get_consistent_types DataOnly UReal es with
match get_consistent_types UReal es with
| Ok (ad_level, typ, promotions) ->
mk_typed_expression
~expr:(RowVectorExpr (Promotion.promote_list es promotions))
Expand Down
26 changes: 21 additions & 5 deletions src/middle/UnsizedType.ml
Original file line number Diff line number Diff line change
Expand Up @@ -79,11 +79,27 @@ and pp_returntype ppf = function
let autodifftype_can_convert at1 at2 =
match (at1, at2) with DataOnly, AutoDiffable -> false | _ -> true

let rec lub_ad_type = function
| [] -> DataOnly
| x :: xs ->
let y = lub_ad_type xs in
if compare_autodifftype x y < 0 then y else x
let lub_ad_type xs =
List.max_elt ~compare:compare_autodifftype xs
|> Option.value ~default:DataOnly

let%expect_test "lub_ad_type1" =
let ads = [DataOnly; DataOnly; DataOnly; AutoDiffable] in
let lub = lub_ad_type ads in
print_s [%sexp (lub : autodifftype)] ;
[%expect "AutoDiffable"]

let%expect_test "lub_ad_type2" =
let ads = [DataOnly; DataOnly; DataOnly] in
let lub = lub_ad_type ads in
print_s [%sexp (lub : autodifftype)] ;
[%expect "DataOnly"]

let%expect_test "lub_ad_type3" =
let ads = [AutoDiffable; DataOnly; DataOnly; DataOnly] in
let lub = lub_ad_type ads in
print_s [%sexp (lub : autodifftype)] ;
[%expect "AutoDiffable"]

(** Given two types find the minimal type both can convert to *)
let rec common_type = function
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,6 @@ transformed parameters {
complex_matrix[2,2] zs = [[1,2],[3,4]];
row_vector[2] x = [1,2];
complex_row_vector[2] zx = x;

complex_matrix[2,2] cm = [[1,2], [3,4i]];
}
79 changes: 69 additions & 10 deletions test/integration/good/code-gen/complex_numbers/cpp.expected
Original file line number Diff line number Diff line change
Expand Up @@ -7968,12 +7968,13 @@ namespace complex_vectors_model_namespace {
using stan::model::model_base_crtp;
using namespace stan::math;
stan::math::profile_map profiles__;
static constexpr std::array<const char*, 5> locations_array__ =
static constexpr std::array<const char*, 6> locations_array__ =
{" (found before start of program)",
" (in 'complex_vectors.stan', line 4, column 3 to column 37)",
" (in 'complex_vectors.stan', line 5, column 3 to column 42)",
" (in 'complex_vectors.stan', line 6, column 3 to column 27)",
" (in 'complex_vectors.stan', line 7, column 3 to column 32)"};
" (in 'complex_vectors.stan', line 7, column 3 to column 32)",
" (in 'complex_vectors.stan', line 9, column 3 to column 44)"};
class complex_vectors_model final : public model_base_crtp<complex_vectors_model> {
private:

Expand Down Expand Up @@ -8071,6 +8072,23 @@ class complex_vectors_model final : public model_base_crtp<complex_vectors_model
std::complex<local_scalar_t__>(DUMMY_VAR__, DUMMY_VAR__));
current_statement__ = 4;
stan::model::assign(zx, x, "assigning variable zx");
Eigen::Matrix<std::complex<local_scalar_t__>,-1,-1> cm =
Eigen::Matrix<std::complex<local_scalar_t__>,-1,-1>::Constant(2, 2,
std::complex<local_scalar_t__>(DUMMY_VAR__, DUMMY_VAR__));
current_statement__ = 5;
stan::model::assign(cm,
stan::math::to_matrix(
std::vector<Eigen::Matrix<std::complex<double>,1,-1>>{(Eigen::Matrix<std::complex<double>,1,-1>(2) <<
stan::math::to_complex(
1, 0),
stan::math::to_complex(
2, 0)).finished(),
(Eigen::Matrix<std::complex<double>,1,-1>(2) <<
stan::math::to_complex(
3, 0),
stan::math::to_complex(
0, 4)).finished()}),
"assigning variable cm");
} catch (const std::exception& e) {
stan::lang::rethrow_located(e, locations_array__[current_statement__]);
}
Expand Down Expand Up @@ -8123,6 +8141,10 @@ class complex_vectors_model final : public model_base_crtp<complex_vectors_model
Eigen::Matrix<std::complex<double>,1,-1>::Constant(2,
std::complex<double>(std::numeric_limits<double>::quiet_NaN(),
std::numeric_limits<double>::quiet_NaN()));
Eigen::Matrix<std::complex<double>,-1,-1> cm =
Eigen::Matrix<std::complex<double>,-1,-1>::Constant(2, 2,
std::complex<double>(std::numeric_limits<double>::quiet_NaN(),
std::numeric_limits<double>::quiet_NaN()));
if (stan::math::logical_negation(
(stan::math::primitive_value(emit_transformed_parameters__) ||
stan::math::primitive_value(emit_generated_quantities__)))) {
Expand Down Expand Up @@ -8158,11 +8180,26 @@ class complex_vectors_model final : public model_base_crtp<complex_vectors_model
"assigning variable x");
current_statement__ = 4;
stan::model::assign(zx, x, "assigning variable zx");
current_statement__ = 5;
stan::model::assign(cm,
stan::math::to_matrix(
std::vector<Eigen::Matrix<std::complex<double>,1,-1>>{(Eigen::Matrix<std::complex<double>,1,-1>(2) <<
stan::math::to_complex(
1, 0),
stan::math::to_complex(
2, 0)).finished(),
(Eigen::Matrix<std::complex<double>,1,-1>(2) <<
stan::math::to_complex(
3, 0),
stan::math::to_complex(
0, 4)).finished()}),
"assigning variable cm");
if (emit_transformed_parameters__) {
out__.write(z);
out__.write(zs);
out__.write(x);
out__.write(zx);
out__.write(cm);
}
if (stan::math::logical_negation(emit_generated_quantities__)) {
return ;
Expand Down Expand Up @@ -8197,7 +8234,7 @@ class complex_vectors_model final : public model_base_crtp<complex_vectors_model
emit_generated_quantities__ = true) const {
names__ = std::vector<std::string>{};
if (emit_transformed_parameters__) {
std::vector<std::string> temp{"z", "zs", "x", "zx"};
std::vector<std::string> temp{"z", "zs", "x", "zx", "cm"};
names__.reserve(names__.size() + temp.size());
names__.insert(names__.end(), temp.begin(), temp.end());
}
Expand All @@ -8216,7 +8253,9 @@ class complex_vectors_model final : public model_base_crtp<complex_vectors_model
static_cast<size_t>(2), static_cast<size_t>(2)},
std::vector<size_t>{static_cast<size_t>(2)},
std::vector<size_t>{static_cast<size_t>(2),
static_cast<size_t>(2)}};
static_cast<size_t>(2)},
std::vector<size_t>{static_cast<size_t>(2),
static_cast<size_t>(2), static_cast<size_t>(2)}};
dimss__.reserve(dimss__.size() + temp.size());
dimss__.insert(dimss__.end(), temp.begin(), temp.end());
}
Expand Down Expand Up @@ -8253,6 +8292,16 @@ class complex_vectors_model final : public model_base_crtp<complex_vectors_model
param_names__.emplace_back(std::string() + "zx" + '.' +
std::to_string(sym1__) + '.' + "imag");
}
for (int sym1__ = 1; sym1__ <= 2; ++sym1__) {
for (int sym2__ = 1; sym2__ <= 2; ++sym2__) {
param_names__.emplace_back(std::string() + "cm" + '.' +
std::to_string(sym2__) + '.' + std::to_string(sym1__) + '.' +
"real");
param_names__.emplace_back(std::string() + "cm" + '.' +
std::to_string(sym2__) + '.' + std::to_string(sym1__) + '.' +
"imag");
}
}
}
if (emit_generated_quantities__) {}
}
Expand Down Expand Up @@ -8287,14 +8336,24 @@ class complex_vectors_model final : public model_base_crtp<complex_vectors_model
param_names__.emplace_back(std::string() + "zx" + '.' +
std::to_string(sym1__) + '.' + "imag");
}
for (int sym1__ = 1; sym1__ <= 2; ++sym1__) {
for (int sym2__ = 1; sym2__ <= 2; ++sym2__) {
param_names__.emplace_back(std::string() + "cm" + '.' +
std::to_string(sym2__) + '.' + std::to_string(sym1__) + '.' +
"real");
param_names__.emplace_back(std::string() + "cm" + '.' +
std::to_string(sym2__) + '.' + std::to_string(sym1__) + '.' +
"imag");
}
}
}
if (emit_generated_quantities__) {}
}
inline std::string get_constrained_sizedtypes() const {
return std::string("[{\"name\":\"z\",\"type\":{\"name\":\"complex_vector\",\"length\":" + std::to_string(3) + "},\"block\":\"transformed_parameters\"},{\"name\":\"zs\",\"type\":{\"name\":\"complex_matrix\",\"rows\":" + std::to_string(2) + ",\"cols\":" + std::to_string(2) + "},\"block\":\"transformed_parameters\"},{\"name\":\"x\",\"type\":{\"name\":\"vector\",\"length\":" + std::to_string(2) + "},\"block\":\"transformed_parameters\"},{\"name\":\"zx\",\"type\":{\"name\":\"complex_vector\",\"length\":" + std::to_string(2) + "},\"block\":\"transformed_parameters\"}]");
return std::string("[{\"name\":\"z\",\"type\":{\"name\":\"complex_vector\",\"length\":" + std::to_string(3) + "},\"block\":\"transformed_parameters\"},{\"name\":\"zs\",\"type\":{\"name\":\"complex_matrix\",\"rows\":" + std::to_string(2) + ",\"cols\":" + std::to_string(2) + "},\"block\":\"transformed_parameters\"},{\"name\":\"x\",\"type\":{\"name\":\"vector\",\"length\":" + std::to_string(2) + "},\"block\":\"transformed_parameters\"},{\"name\":\"zx\",\"type\":{\"name\":\"complex_vector\",\"length\":" + std::to_string(2) + "},\"block\":\"transformed_parameters\"},{\"name\":\"cm\",\"type\":{\"name\":\"complex_matrix\",\"rows\":" + std::to_string(2) + ",\"cols\":" + std::to_string(2) + "},\"block\":\"transformed_parameters\"}]");
}
inline std::string get_unconstrained_sizedtypes() const {
return std::string("[{\"name\":\"z\",\"type\":{\"name\":\"complex_vector\",\"length\":" + std::to_string(3) + "},\"block\":\"transformed_parameters\"},{\"name\":\"zs\",\"type\":{\"name\":\"complex_matrix\",\"rows\":" + std::to_string(2) + ",\"cols\":" + std::to_string(2) + "},\"block\":\"transformed_parameters\"},{\"name\":\"x\",\"type\":{\"name\":\"vector\",\"length\":" + std::to_string(2) + "},\"block\":\"transformed_parameters\"},{\"name\":\"zx\",\"type\":{\"name\":\"complex_vector\",\"length\":" + std::to_string(2) + "},\"block\":\"transformed_parameters\"}]");
return std::string("[{\"name\":\"z\",\"type\":{\"name\":\"complex_vector\",\"length\":" + std::to_string(3) + "},\"block\":\"transformed_parameters\"},{\"name\":\"zs\",\"type\":{\"name\":\"complex_matrix\",\"rows\":" + std::to_string(2) + ",\"cols\":" + std::to_string(2) + "},\"block\":\"transformed_parameters\"},{\"name\":\"x\",\"type\":{\"name\":\"vector\",\"length\":" + std::to_string(2) + "},\"block\":\"transformed_parameters\"},{\"name\":\"zx\",\"type\":{\"name\":\"complex_vector\",\"length\":" + std::to_string(2) + "},\"block\":\"transformed_parameters\"},{\"name\":\"cm\",\"type\":{\"name\":\"complex_matrix\",\"rows\":" + std::to_string(2) + ",\"cols\":" + std::to_string(2) + "},\"block\":\"transformed_parameters\"}]");
}
// Begin method overload boilerplate
template <typename RNG> inline void
Expand All @@ -8304,8 +8363,8 @@ class complex_vectors_model final : public model_base_crtp<complex_vectors_model
emit_generated_quantities = true, std::ostream*
pstream = nullptr) const {
const size_t num_params__ = 0;
const size_t num_transformed = emit_transformed_parameters * (((((3 * 2)
+ ((2 * 2) * 2)) + 2) + (2 * 2)));
const size_t num_transformed = emit_transformed_parameters * ((((((3 * 2)
+ ((2 * 2) * 2)) + 2) + (2 * 2)) + ((2 * 2) * 2)));
const size_t num_gen_quantities = emit_generated_quantities * (0);
const size_t num_to_write = num_params__ + num_transformed +
num_gen_quantities;
Expand All @@ -8322,8 +8381,8 @@ class complex_vectors_model final : public model_base_crtp<complex_vectors_model
emit_generated_quantities = true, std::ostream*
pstream = nullptr) const {
const size_t num_params__ = 0;
const size_t num_transformed = emit_transformed_parameters * (((((3 * 2)
+ ((2 * 2) * 2)) + 2) + (2 * 2)));
const size_t num_transformed = emit_transformed_parameters * ((((((3 * 2)
+ ((2 * 2) * 2)) + 2) + (2 * 2)) + ((2 * 2) * 2)));
const size_t num_gen_quantities = emit_generated_quantities * (0);
const size_t num_to_write = num_params__ + num_transformed +
num_gen_quantities;
Expand Down
26 changes: 26 additions & 0 deletions test/integration/good/code-gen/container-promotion.stan
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
transformed data {
array[2, 2] real Arr;
Arr = {{1, 2}, {3, 4.5}};
}
parameters {
real y;
}
transformed parameters {
vector[2] V;
V = [1, y]';

array[2] row_vector[2] arRV;
arRV = {[1, 2], [3, y]};

array[2, 2] real Mar;
Mar = {{1, 2}, {3, y}};

matrix[2, 2] M;
M = [[1, 2], [3, y]];

array[2, 2, 2] real deep_Mar;
deep_Mar = {{{0, 0}, {0, 0}}, {{1, 2}, {3, y}}};

array[2] matrix[2, 2] deep_M;
deep_M = {[[0, 0], [0, 0]], [[1, 2], [y, 4]]};
}
Loading

0 comments on commit 2eb708a

Please sign in to comment.