Skip to content

Commit

Permalink
refactor(exporter/id_table): simplify vastly the deserializer of `Wit…
Browse files Browse the repository at this point in the history
…hTable` (thanks @Nadrieril!)
  • Loading branch information
W95Psp committed Oct 23, 2024
1 parent 2463c31 commit 5813ed4
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 122 deletions.
54 changes: 26 additions & 28 deletions engine/bin/lib.ml
Original file line number Diff line number Diff line change
Expand Up @@ -142,46 +142,44 @@ let run (options : Types.engine_options) : Types.output =

(** Shallow parses a `id_table::Node<T>` (or a raw `T`) JSON *)
let parse_id_table_node (json : Yojson.Safe.t) :
(int64 * Yojson.Safe.t) list option * Yojson.Safe.t =
let expect_assoc = function
| `Assoc alist -> Some (List.Assoc.find ~equal:[%eq: string] alist)
| _ -> None
in
(int64 * Yojson.Safe.t) list * Yojson.Safe.t =
let expect_uint64 = function
| `Intlit str -> Some (Int64.of_string str)
| `Int id -> Some (Int.to_int64 id)
| _ -> None
in
(let* assoc = expect_assoc json in
let* table = assoc "table" in
let* value = assoc "value" in
let table =
match table with
| `List json_list -> json_list
| _ -> failwith "parse_cached: `map` is supposed to be a list"
in
let table =
List.map
~f:(function
| `List [ id; `Assoc [ (_, contents) ] ] ->
let id =
expect_uint64 id
|> Option.value_exn ~message:"parse_cached: id: expected int64"
in
(id, contents)
| _ -> failwith "parse_cached: expected a list of size two")
table
in
Some (Some table, value))
|> Option.value ~default:(None, json)
let table, value =
match json with
| `List [ table; value ] -> (table, value)
| _ -> failwith "parse_id_table_node: expected a tuple at top-level"
in
let table =
match table with
| `List json_list -> json_list
| _ -> failwith "parse_id_table_node: `map` is supposed to be a list"
in
let table =
List.map
~f:(function
| `List [ id; `Assoc [ (_, contents) ] ] ->
let id =
expect_uint64 id
|> Option.value_exn
~message:"parse_id_table_node: id: expected int64"
in
(id, contents)
| _ -> failwith "parse_id_table_node: expected a list of size two")
table
in
(table, value)

(** Entrypoint of the engine. Assumes `Hax_io.init` was called. *)
let main () =
let options =
let table, json =
Hax_io.read_json () |> Option.value_exn |> parse_id_table_node
in
table |> Option.value ~default:[]
table
|> List.iter ~f:(fun (id, json) ->
Hashtbl.add_exn Types.cache_map ~key:id ~data:(`JSON json));
Types.parse_engine_options json
Expand Down
113 changes: 19 additions & 94 deletions frontend/exporter/src/id_table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ impl<T: Sync + Send + Clone + 'static + SupportedType<Value>> Node<T> {
/// inside it. Serializing `WithTable<T>` will serialize IDs only,
/// skipping values. Deserialization of a `WithTable<T>` will
/// automatically use the table and IDs to reconstruct skipped values.
#[derive(Serialize, Debug)]
#[derive(Debug)]
pub struct WithTable<T> {
table: Table,
value: T,
Expand Down Expand Up @@ -213,11 +213,14 @@ impl<T> WithTable<T> {
}
}

/// Helper function that makes sure no nested deserializations occur.
fn full_id_deserialization<T>(f: impl FnOnce() -> T) -> T {
let _lock: MutexGuard<_> = DESERIALIZATION_STATE_LOCK.try_lock().expect("CACHE_MAP_LOCK: only one WithTable deserialization can occur at a time (nesting is forbidden)");
let result = f();
result
impl<T: Serialize> Serialize for WithTable<T> {
fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
let mut ts = serializer.serialize_tuple_struct("WithTable", 2)?;
use serde::ser::SerializeTupleStruct;
ts.serialize_field(&self.table)?;
ts.serialize_field(&self.value)?;
ts.end()
}
}

/// The deserializer of `WithTable<T>` is special. We first decode the
Expand All @@ -230,94 +233,13 @@ impl<'de, T: Deserialize<'de>> serde::Deserialize<'de> for WithTable<T> {
where
D: serde::Deserializer<'de>,
{
use serde::de::{self, Deserialize, MapAccess, SeqAccess, Visitor};
use std::fmt;
use std::marker::PhantomData;
#[derive(Deserialize, Debug)]
#[serde(field_identifier, rename_all = "lowercase")]
enum Field {
Table,
Value,
}

struct WithTableVisitor<T>(PhantomData<T>);

impl<'de, T: Deserialize<'de>> Visitor<'de> for WithTableVisitor<T> {
type Value = WithTable<T>;

fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("struct WithTable")
}

fn visit_seq<V>(self, mut seq: V) -> Result<Self::Value, V::Error>
where
V: SeqAccess<'de>,
{
let (table, value) =
full_id_deserialization::<Result<(Table, T), V::Error>>(|| {
let previous = std::mem::take(&mut *DESERIALIZATION_STATE.lock().unwrap());
// Deserializing `Node<T>`s populates `DESERIALIZATION_STATE`: the table
// is already constructed in `DESERIALIZATION_STATE`, we discard it below.
let _: Table = seq
.next_element()?
.ok_or_else(|| de::Error::invalid_length(0, &self))?;
let value = seq.next_element::<T>().and_then(|inner| {
inner.ok_or_else(|| de::Error::invalid_length(1, &self))
});
let table = std::mem::replace(
&mut *DESERIALIZATION_STATE.lock().unwrap(),
previous,
);
let value = value?;
Ok((table, value))
})?;
Ok(Self::Value { table, value })
}

fn visit_map<V>(self, mut map: V) -> Result<Self::Value, V::Error>
where
V: MapAccess<'de>,
{
let (table, value) = full_id_deserialization::<Result<(Table, T), V::Error>>(
|| {
let (table_key, table): (Field, Table) = map
.next_entry()?
.ok_or(de::Error::custom("Expected field `table` *first*"))?;
if !matches!(table_key, Field::Table) {
Err(de::Error::custom("Expected field `table` *first*"))?
}
let previous =
std::mem::replace(&mut *DESERIALIZATION_STATE.lock().unwrap(), table);

let value_result = map.next_entry::<Field, T>().and_then(|inner| {
inner.ok_or_else(|| {
de::Error::custom("Expected field `value`, got something else? Seems like the type is wrong.")
})
});
let table = std::mem::replace(
&mut *DESERIALIZATION_STATE.lock().unwrap(),
previous,
);
let (value_key, value) = value_result?;
if !matches!(value_key, Field::Value) {
Err(de::Error::custom(&format!(
"Expected field `value`, found {:#?}",
value_key
)))?
}
if let Some(field) = map.next_key()? {
Err(de::Error::unknown_field(field, &["nothing left"]))?
}
Ok((table, value))
},
)?;
Ok(Self::Value { table, value })
}
}

const FIELDS: &[&str] = &["table", "value"];
let r = deserializer.deserialize_struct("WithTable", FIELDS, WithTableVisitor(PhantomData));
r
let _lock: MutexGuard<_> = DESERIALIZATION_STATE_LOCK.try_lock().expect("CACHE_MAP_LOCK: only one WithTable deserialization can occur at a time (nesting is forbidden)");
use serde_repr::WithTableRepr;
let previous = std::mem::take(&mut *DESERIALIZATION_STATE.lock().unwrap());
let with_table_repr = WithTableRepr::deserialize(deserializer);
*DESERIALIZATION_STATE.lock().unwrap() = previous;
let WithTableRepr(table, value) = with_table_repr?;
Ok(Self { table, value })
}
}

Expand All @@ -336,6 +258,9 @@ mod serde_repr {
pub(super) struct Pair(Id, Value);
pub(super) type SortedIdValuePairs = Vec<Pair>;

#[derive(Serialize, Deserialize)]
pub(super) struct WithTableRepr<T>(pub(super) Table, pub(super) T);

impl<T: SupportedType<Value>> Into<NodeRepr<T>> for Node<T> {
fn into(self) -> NodeRepr<T> {
let value = if serialize_use_id() {
Expand Down

0 comments on commit 5813ed4

Please sign in to comment.