Skip to content

Commit

Permalink
Implement transform precoding (DFT-s-OFDM)
Browse files Browse the repository at this point in the history
Signed-off-by: Daniel Schäufele <Daniel.Schaeufele@hhi.fraunhofer.de>
  • Loading branch information
danielschaeufele committed Jun 4, 2024
1 parent 2cb12fd commit d2e03a8
Show file tree
Hide file tree
Showing 18 changed files with 666 additions and 34 deletions.
9 changes: 9 additions & 0 deletions sionna/mimo/detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,11 @@ class LinearDetector(Layer):
constellation point indices instead of soft-values.
Defaults to `False`.
post_equalizer_transformation: None or Layer
Optional layer that applies a transformation after the equalizer and
before the demapper. This can be used to apply transform precoding
when DFT-s-OFDM is enabled in NR PUSCH.
dtype : One of [tf.complex64, tf.complex128] tf.DType (dtype)
The dtype of ``y``. Defaults to tf.complex64.
The output dtype is the corresponding real dtype (tf.float32 or tf.float64).
Expand Down Expand Up @@ -96,11 +101,13 @@ def __init__(self,
num_bits_per_symbol=None,
constellation=None,
hard_out=False,
post_equalizer_transformation=None,
dtype=tf.complex64,
**kwargs):
super().__init__(dtype=dtype, **kwargs)
self._output = output
self._hard_out = hard_out
self._post_equalizer_transformation = post_equalizer_transformation

# Determine the equalizer to use
if isinstance(equalizer, str):
Expand Down Expand Up @@ -137,6 +144,8 @@ def __init__(self,

def call(self, inputs):
x_hat, no_eff = self._equalizer(*inputs)
if self._post_equalizer_transformation is not None:
x_hat = self._post_equalizer_transformation(x_hat)
z = self._demapper([x_hat, no_eff])

# Reshape to the expected output shape
Expand Down
3 changes: 2 additions & 1 deletion sionna/nr/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,12 @@
from .pusch_dmrs_config import PUSCHDMRSConfig
from .pusch_pilot_pattern import PUSCHPilotPattern
from .pusch_precoder import PUSCHPrecoder
from .pusch_transform_precoder import PUSCHTransformPrecoder, PUSCHTransformDeprecoder
from .pusch_transmitter import PUSCHTransmitter
from .pusch_receiver import PUSCHReceiver
from .pusch_channel_estimation import PUSCHLSChannelEstimator
from .tb_config import TBConfig
from .utils import generate_prng_seq, select_mcs, calculate_tb_size
from .utils import generate_prng_seq, generate_low_papr_seq_type_1, select_mcs, calculate_tb_size
from .tb_encoder import TBEncoder
from .tb_decoder import TBDecoder
from .layer_mapping import LayerMapper, LayerDemapper
119 changes: 102 additions & 17 deletions sionna/nr/pusch_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@
"""
# pylint: disable=line-too-long

import functools
import numpy as np
from .utils import generate_prng_seq
from .utils import generate_prng_seq, generate_low_papr_seq_type_1
from .config import Config
from sionna import nr
from .utils import calculate_tb_size
Expand Down Expand Up @@ -233,7 +234,7 @@ def n_rnti(self, value):
assert value in range(65536), "n_rnti must be in [0, 65535]"
self._n_rnti = value

#---transform_precoding---#
#---precoding---#
@property
def precoding(self):
"""
Expand Down Expand Up @@ -427,9 +428,9 @@ def n(self):
used for DMRS generation
"""
if self.dmrs.config_type==1:
n_max = self.num_resource_blocks*12//4 -1
n_max = self.num_effective_subcarriers//4 -1
elif self.dmrs.config_type==2:
n_max = self.num_resource_blocks*12//6 -1
n_max = self.num_effective_subcarriers//6 -1
return list(range(n_max+1))

@property
Expand All @@ -450,6 +451,31 @@ def num_resource_blocks(self):
else:
return self.n_size_bwp

@property
def num_effective_resource_blocks(self):
"""
int, read-only : Number of allocated resource blocks for the
PUSCH transmissions, that are actually used (can differ from
num_subcarriers when transform precoding is enabled,
because of constraints on the largest prime factor of the
subcarrier count)
"""
@functools.lru_cache
def adjust_prbs_to_prime_factor_constraints(prbs):
# Decreases the number of PRBs until the largest prime factor is at most 5
for eff_prbs in range(prbs, 1, -1):
n = eff_prbs
for p in [2, 3, 5]:
while n % p == 0:
n /= p
if n == 1:
return eff_prbs

if self.transform_precoding:
return adjust_prbs_to_prime_factor_constraints(self.num_resource_blocks)
else:
return self.num_resource_blocks

@property
def num_subcarriers(self):
"""
Expand All @@ -458,6 +484,17 @@ def num_subcarriers(self):
"""
return 12*self.num_resource_blocks

@property
def num_effective_subcarriers(self):
"""
int, read-only : Number of allocated subcarriers for the
PUSCH transmissions, that are actually used (can differ from
num_subcarriers when transform precoding is enabled,
because of constraints on the largest prime factor of the
subcarrier count)
"""
return 12 * self.num_effective_resource_blocks

@property
def num_res_per_prb(self):
"""
Expand Down Expand Up @@ -488,7 +525,7 @@ def dmrs_mask(self):
resource elements in the resource grid. `True` corresponds to
resource elements on which no data is transmitted.
"""
mask = np.zeros([self.num_subcarriers,
mask = np.zeros([self.num_effective_subcarriers,
self.carrier.num_symbols_per_slot],
dtype=bool)

Expand All @@ -503,7 +540,7 @@ def dmrs_mask(self):
cdm_ind[:,i] = np.array([0,1, 6, 7])+2*i

for i in self.dmrs_symbol_indices:
for j in range(self.num_resource_blocks):
for j in range(self.num_effective_resource_blocks):
for k in range(num_cdm_groups):
mask[cdm_ind[:, k] + 12*j, i] = True
return mask
Expand All @@ -518,7 +555,7 @@ def dmrs_grid(self):
This property returns for each configured DMRS port an empty
resource grid filled with DMRS signals as defined in
Section 6.4.1.1 [3GPP38211]. Not all possible options are implemented,
e.g., frequency hopping and transform precoding are not available.
e.g., frequency hopping is not available.
This property provides the *unprecoded* DMRS for each configured DMRS port.
Precoding might be applied to map the DMRS to the antenna ports. However,
Expand All @@ -536,7 +573,7 @@ def dmrs_grid(self):

# Generate empty resource grid for each port
a_tilde = np.zeros([len(self.dmrs.dmrs_port_set),
self.num_subcarriers,
self.num_effective_subcarriers,
self.carrier.num_symbols_per_slot],
dtype=complex)

Expand All @@ -546,15 +583,23 @@ def dmrs_grid(self):
# For every l_prime
for l_prime in self.l_prime:

# Compute c_init
l = l_bar + l_prime
c_init = self.c_init(l)

# Generate RNG
c = generate_prng_seq(2*self.num_subcarriers, c_init=c_init)
if self.transform_precoding:
if self.dmrs.n_sid is None:
n_id = self.carrier.n_cell_id
else:
n_id = self.dmrs.n_sid
r = generate_low_papr_seq_type_1(self.num_effective_subcarriers // 2, n_id % 30, 0, 0)
else:
# Compute c_init
c_init = self.c_init(l)

# Generate RNG
c = generate_prng_seq(2*self.num_effective_subcarriers, c_init=c_init)

# Map to QAM
r = 1/np.sqrt(2)*((1-2*c[::2]) + 1j*(1-2*c[1::2]))
# Map to QAM
r = 1/np.sqrt(2)*((1-2*c[::2]) + 1j*(1-2*c[1::2]))

# For every port in the dmrs port set
for j_ind, _ in enumerate(self.dmrs.dmrs_port_set):
Expand Down Expand Up @@ -625,8 +670,38 @@ def precoding_matrix(self):

w /= np.sqrt(2)

# Table 6.3.1.5-2
elif self.transform_precoding and self.num_antenna_ports == 4:
w = np.zeros([28, 4, 1], complex)

# TPMI index 0-7
w[:8,0,0] = [ 1, 0, 0, 0, 1, 1, 1, 1]
w[:8,1,0] = [ 0, 1, 0, 0, 0, 0, 0, 0]
w[:8,2,0] = [ 0, 0, 1, 0, 1, -1, 1j,-1j]
w[:8,3,0] = [ 0, 0, 0, 1, 0, 0, 0, 0]

# TPMI index 8-15
w[8:16,0,0] = [ 0, 0, 0, 0, 1, 1, 1, 1]
w[8:16,1,0] = [ 1, 1, 1, 1, 1, 1, 1, 1]
w[8:16,2,0] = [ 0, 0, 0, 0, 1, 1j, -1,-1j]
w[8:16,3,0] = [ 1, -1, 1j,-1j, -1, 1j, 1,-1j]

# TPMI index 16-23
w[16:24,0,0] = [ 1, 1, 1, 1, 1, 1, 1, 1]
w[16:24,1,0] = [ 1j, 1j, 1j, 1j, -1, -1, -1, -1]
w[16:24,2,0] = [ 1, 1j, -1,-1j, 1, 1j, -1,-1j]
w[16:24,3,0] = [ 1j, 1,-1j, -1, 1,-1j, -1, 1j]

# TPMI index 24-27
w[24:28,0,0] = [ 1, 1, 1, 1]
w[24:28,1,0] = [-1j,-1j,-1j,-1j]
w[24:28,2,0] = [ 1, 1j, -1,-1j]
w[24:28,3,0] = [-1j, -1, 1j, 1]

w /= 2

# Table 6.3.1.5-3
elif self.num_antenna_ports==4:
elif not self.transform_precoding and self.num_antenna_ports==4:
w = np.zeros([28,4,1], complex)

# TPMI index 0-7
Expand Down Expand Up @@ -825,7 +900,7 @@ def num_coded_bits(self):
n_re_per_prb = self.num_res_per_prb - self.num_ov

# number of allocated REs
n_re = n_re_per_prb * self.num_resource_blocks
n_re = n_re_per_prb * self.num_effective_resource_blocks

# total number of bits per slot
num_coded_bits = int(self.tb.tb_scaling * self.tb.num_bits_per_symbol \
Expand All @@ -842,7 +917,7 @@ def tb_size(self):

# number of allocated REs
# the max. number of REs per PRB is limited to 156 in 38.214
n_re = min(156, n_re_per_prb) * self.num_resource_blocks
n_re = min(156, n_re_per_prb) * self.num_effective_resource_blocks

# include tb_scaling as defined in Tab. 5.1.3.2-2 38.214
target_tb_size = int(self.tb.target_coderate * self.tb.tb_scaling \
Expand Down Expand Up @@ -924,6 +999,14 @@ def check_config(self):
assert self.num_layers == self.num_antenna_ports,\
"num_layers must be == num_antenna_ports"

if self.transform_precoding:
assert self.num_layers == 1,\
"When transform precoding is used, only a single MIMO layer is supported"
assert self.dmrs.config_type == 1, \
"When transform precoding is used, DMRS config type must be 1"
assert self.dmrs.num_cdm_groups_without_data == 2, \
"When transform precoding is used, num_cdm_groups_without_data must be 2"

# Check Tables 6.4.1.1.3-3/4 are valid
if self.dmrs.length==1:
if self.mapping_type=="A":
Expand Down Expand Up @@ -1033,11 +1116,13 @@ def check_pusch_configs(pusch_configs):
"num_tx" : len(pusch_configs),
"num_layers" : pc.num_layers,
"num_subcarriers" : pc.num_subcarriers,
"num_effective_subcarriers": pc.num_effective_subcarriers,
"num_ofdm_symbols" : pc.symbol_allocation[1],
"subcarrier_spacing" : pc.carrier.subcarrier_spacing*1e3,
"num_antenna_ports" : pc.num_antenna_ports,
"precoding" : pc.precoding,
"precoding_matrices" : [],
"transform_precoding" : pc.transform_precoding,
"pusch_config" : pc,
"carrier_config" : pc.carrier,
"num_coded_bits" : pc.num_coded_bits,
Expand Down
19 changes: 17 additions & 2 deletions sionna/nr/pusch_dmrs_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,14 +151,29 @@ def n_id(self, value):
if value is None:
self._n_id = None
elif isinstance(value, int):
assert value in list(range(65536)), "n_id must be in [0, 65535]"
assert value in range(65536), "n_id must be in [0, 65535]"
self._n_id = [value, value]
else:
assert len(value)==2, "n_id must be either [] or a two-tuple"
for e in value:
assert e in list(range(65536)), "Each element of n_id must be in [0, 65535]"
assert e in range(65536), "Each element of n_id must be in [0, 65535]"
self._n_id = value

#---n_sid---#
@property
def n_sid(self):
r"""
None (default), [0,...,1007] : DMRS scrambling identity for DFT-s-OFDM
:math:`n_\text{ID}^\text{PUSCH}`
"""
self._ifndef("n_sid", None)
return self._n_sid

@n_sid.setter
def n_sid(self, value):
assert value is None or (isinstance(value, int) and value in range(1008)), "n_sid must None or in [0, 1007]"
self._n_sid = value

#---n_scid---#
@property
def n_scid(self):
Expand Down
17 changes: 11 additions & 6 deletions sionna/nr/pusch_receiver.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from sionna.ofdm import OFDMDemodulator, LinearDetector
from sionna.utils import insert_dims
from sionna.channel import time_to_ofdm_channel
from .pusch_transform_precoder import PUSCHTransformDeprecoder

class PUSCHReceiver(Layer):
# pylint: disable=line-too-long
Expand Down Expand Up @@ -197,14 +198,19 @@ def __init__(self,
# Use or create default MIMODetector
if mimo_detector is None:
# Default MIMO detector
transformation = PUSCHTransformDeprecoder(pusch_transmitter.resource_grid.num_effective_subcarriers,
dtype=dtype) if pusch_transmitter._transform_precoding else None
self._mimo_detector = LinearDetector("lmmse", "bit", "maxlog",
pusch_transmitter.resource_grid,
self._stream_management,
"qam",
pusch_transmitter._num_bits_per_symbol,
dtype=dtype)
pusch_transmitter.resource_grid,
self._stream_management,
"qam",
pusch_transmitter._num_bits_per_symbol,
post_equalizer_transformation=transformation,
dtype=dtype)
else:
# User-provided MIMO detector
if pusch_transmitter._transform_precoding:
print("WARNING: Using custom mimo detector which might not support transform precoding.")
self._mimo_detector = mimo_detector

# Create LayerDemapper
Expand Down Expand Up @@ -248,7 +254,6 @@ def call(self, inputs):
if self._input_domain=="time":
h = time_to_ofdm_channel(h, self.resource_grid, self._l_min)


if self._w is not None:
# Reshape h to put channel matrix dimensions last
# [batch size, num_rx, num_tx, num_ofdm_symbols,...
Expand Down
Loading

0 comments on commit d2e03a8

Please sign in to comment.