diff --git a/dhg/random/hypergraphs/hypergraph.py b/dhg/random/hypergraphs/hypergraph.py index 202411b..88dc55f 100644 --- a/dhg/random/hypergraphs/hypergraph.py +++ b/dhg/random/hypergraphs/hypergraph.py @@ -74,18 +74,21 @@ def hypergraph_Gnm(num_v: int, num_e: int, method:str="low_order_first", prob_k_ Args: ``num_v`` (``int``): The Number of vertices. ``num_e`` (``int``): The Number of hyperedges. - ``method`` (``str``): The method to generate hyperedges must be one of ``"uniform"``, ``"low_order_first"``, ``"high_order_first"``. Defaults to ``"uniform"``. + ``method`` (``str``): The method to generate hyperedges must be one of ``"uniform"``, ``"low_order_first"``, ``"high_order_first"``, ``"custom"``. Defaults to ``"uniform"``. Examples: >>> import dhg.random as random >>> hg = random.hypergraph_Gnm(5, 4) >>> hg.e ([(0, 1, 3, 4), (0, 2, 3, 4), (0, 2, 3), (0, 2, 4)], [1.0, 1.0, 1.0, 1.0]) + >>> hg = dhg.random.hypergraph_Gnm(5, 4, 'custom', [0, 0, 0.8, 0.2]) + >>> hg.e + ([(1, 2, 3, 4), (0, 2, 3, 4), (0, 1, 2, 3), (0, 1, 2, 3, 4)], [1.0, 1.0, 1.0, 1.0]) """ # similar to nauty in sagemath, https://doc.sagemath.org/html/en/reference/graphs/sage/graphs/hypergraph_generators.html assert num_v > 1, "num_v must be greater than 1" assert num_e > 0, "num_e must be greater than 0" - assert method in ("uniform", "low_order_first", "high_order_first"), "method must be one of 'uniform', 'low_order_first', 'high_order_first'" + assert method in ("uniform", "low_order_first", "high_order_first", "custom"), "method must be one of 'uniform', 'low_order_first', 'high_order_first', 'custom'" deg_e_list = list(range(2, num_v + 1)) if method == "uniform": prob_k_list = [C(num_v, k) / (2 ** num_v - 1) for k in deg_e_list] @@ -97,6 +100,11 @@ def hypergraph_Gnm(num_v: int, num_e: int, method:str="low_order_first", prob_k_ prob_k_list = [3 ** (-k) for k in range(len(deg_e_list))].reverse() sum_of_prob_k_list = sum(prob_k_list) prob_k_list = [prob_k / sum_of_prob_k_list for prob_k in prob_k_list] + elif method == "custom": + assert prob_k_list is not None, "prob_k_list must be provided when method is custom" + assert len(prob_k_list) == num_v - 1, "prob_k_list must have length `num_v - 1'" + sum_of_prob_k_list = sum(prob_k_list) + prob_k_list = [prob_k / sum_of_prob_k_list for prob_k in prob_k_list] else: raise ValueError(f"Unknown method: {method}")