Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

数据维度问题 #2

Open
unlessyo opened this issue May 27, 2024 · 7 comments
Open

数据维度问题 #2

unlessyo opened this issue May 27, 2024 · 7 comments

Comments

@unlessyo
Copy link

class TrafficBase(Dataset):
def init(self,
data_root,
txt_file,
size = None,
):
self.data_root = data_root
self.data_paths = txt_file
with open(self.data_paths, "r") as f:
self.image_paths = f.read().splitlines()
self._length = len(self.image_paths)
self.size = size

def __len__(self):
    return self._length

def __getitem__(self, i):
    example = dict()
    path = self.image_paths[i]
    traffic_npy = np.load(path)

    # traffic_npy = np.array(traffic_npy).astype(np.uint8)
    traffic_npy[:,:,2][traffic_npy[:,:,2] > 3600] = 3600
    traffic_npy[:,:,1][traffic_npy[:,:,1] > 150] = 150

    traffic_npy[:,:,0] = (traffic_npy[:,:,0] / 5.0).astype(np.float32)
    traffic_npy[:,:,1] = (traffic_npy[:,:,1] / 150.0).astype(np.float32)
    traffic_npy[:,:,2] = (traffic_npy[:,:,2] / 3600.0).astype(np.float32)
    example['image'] = traffic_npy

    if 'train' in path:
        textpath = './datasets/traffic/train/' + 'text/' + path.split('/')[-1].split('.')[0] + '.txt'
    else:
        textpath = './datasets/traffic/validation/' + 'text/' + path.split('/')[-1].split('.')[0] + '.txt'
    with open(textpath, "r") as f:
        text = str(f.read().splitlines()[0])
    example['caption'] = text
    # example['structure'] = np.load('/home/zcy/latent-diffusion-main/datasets/traffic/matrix_roadclass&length.npy')
    return example

此函数中:
traffic_npy[:,:,0] = (traffic_npy[:,:,0] / 5.0).astype(np.float32)
traffic_npy[:,:,1] = (traffic_npy[:,:,1] / 150.0).astype(np.float32)
traffic_npy[:,:,2] = (traffic_npy[:,:,2] / 3600.0).astype(np.float32)
显示需要处理数据大小为[:,:,3]的数据。根据通信作者提供的数据集,.npy中的数据大小为[:,:,2]。请问第三维数据是什么?提供的代码或数据是否有误?

@ChyaZhang
Copy link
Owner

注释掉第三维即可

@yue0708
Copy link

yue0708 commented Jun 7, 2024

return self.training_type_plugin.validation_step(*step_kwargs.values()) File "/root/miniconda3/envs/ChatTraffic/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/ddp.py", line 386, in validation_step return self.model(*args, **kwargs) File "/root/miniconda3/envs/ChatTraffic/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl result = self.forward(*input, **kwargs) File "/root/miniconda3/envs/ChatTraffic/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 705, in forward output = self.module(*inputs[0], **kwargs[0]) File "/root/miniconda3/envs/ChatTraffic/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl result = self.forward(*input, **kwargs) File "/root/miniconda3/envs/ChatTraffic/lib/python3.8/site-packages/pytorch_lightning/overrides/base.py", line 93, in forward output = self.module.validation_step(*inputs, **kwargs) File "/root/autodl-fs/ChatTraffic/ldm/models/autoencoder.py", line 374, in validation_step reconstructions, posterior = self(inputs) File "/root/miniconda3/envs/ChatTraffic/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl result = self.forward(*input, **kwargs) File "/root/autodl-fs/ChatTraffic/ldm/models/autoencoder.py", line 336, in forward posterior = self.encode(input) File "/root/autodl-fs/ChatTraffic/ldm/models/autoencoder.py", line 325, in encode h = self.encoder(x) File "/root/miniconda3/envs/ChatTraffic/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl result = self.forward(*input, **kwargs) File "/root/autodl-fs/ChatTraffic/ldm/modules/diffusionmodules/model.py", line 439, in forward hs = [self.conv_in(x)] File "/root/miniconda3/envs/ChatTraffic/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl result = self.forward(*input, **kwargs) File "/root/miniconda3/envs/ChatTraffic/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 399, in forward return self._conv_forward(input, self.weight, self.bias) File "/root/miniconda3/envs/ChatTraffic/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 395, in _conv_forward return F.conv2d(input, weight, bias, self.stride, RuntimeError: Given groups=1, weight of size [128, 3, 3, 3], expected input[64, 2, 36, 36] to have 3 channels, but got 2 channels instead
我也有相似的问题,论文中提到的traffic data维度是36,36,3,3代表3个特征么,但是数据集的数据维度只有36,36,2,把代码的第三维注释了还是有报错,提示通道数不匹配,请问这是否是数据集不完整导致的呢

@ChyaZhang
Copy link
Owner

return self.training_type_plugin.validation_step(*step_kwargs.values()) File "/root/miniconda3/envs/ChatTraffic/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/ddp.py", line 386, in validation_step return self.model(*args, **kwargs) File "/root/miniconda3/envs/ChatTraffic/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl result = self.forward(*input, **kwargs) File "/root/miniconda3/envs/ChatTraffic/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 705, in forward output = self.module(*inputs[0], **kwargs[0]) File "/root/miniconda3/envs/ChatTraffic/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl result = self.forward(*input, **kwargs) File "/root/miniconda3/envs/ChatTraffic/lib/python3.8/site-packages/pytorch_lightning/overrides/base.py", line 93, in forward output = self.module.validation_step(*inputs, **kwargs) File "/root/autodl-fs/ChatTraffic/ldm/models/autoencoder.py", line 374, in validation_step reconstructions, posterior = self(inputs) File "/root/miniconda3/envs/ChatTraffic/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl result = self.forward(*input, **kwargs) File "/root/autodl-fs/ChatTraffic/ldm/models/autoencoder.py", line 336, in forward posterior = self.encode(input) File "/root/autodl-fs/ChatTraffic/ldm/models/autoencoder.py", line 325, in encode h = self.encoder(x) File "/root/miniconda3/envs/ChatTraffic/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl result = self.forward(*input, **kwargs) File "/root/autodl-fs/ChatTraffic/ldm/modules/diffusionmodules/model.py", line 439, in forward hs = [self.conv_in(x)] File "/root/miniconda3/envs/ChatTraffic/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl result = self.forward(*input, **kwargs) File "/root/miniconda3/envs/ChatTraffic/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 399, in forward return self._conv_forward(input, self.weight, self.bias) File "/root/miniconda3/envs/ChatTraffic/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 395, in _conv_forward return F.conv2d(input, weight, bias, self.stride, RuntimeError: Given groups=1, weight of size [128, 3, 3, 3], expected input[64, 2, 36, 36] to have 3 channels, but got 2 channels instead 我也有相似的问题,论文中提到的traffic data维度是36,36,3,3代表3个特征么,但是数据集的数据维度只有36,36,2,把代码的第三维注释了还是有报错,提示通道数不匹配,请问这是否是数据集不完整导致的呢

发布的数据集确实没有第三个维度的数据,因为这个维度的数据经过我们检查在某些月份存在部分缺失,因此就没有提供。如果只是要预测速度数据的话,要解决这个问题最简单的方法就是可以把速度维度repeat3次。

@JsonL0
Copy link

JsonL0 commented Sep 20, 2024

论文中的输入数据x_i维度为(1260 * 2), 但在implementation detail小节将其重塑为36 * 36 * 2,reshape之后是1296 * 2, 和1260个顶点似乎对不上。我想请问数据维度这个36, 36, 2的数据维度分别代表的是什么?或者我可能漏掉了哪些信息?

@ChyaZhang
Copy link
Owner

论文中的输入数据x_i维度为(1260 * 2), 但在implementation detail小节将其重塑为36 * 36 * 2,reshape之后是1296 * 2, 和1260个顶点似乎对不上。我想请问数据维度这个36, 36, 2的数据维度分别代表的是什么?或者我可能漏掉了哪些信息?

1260是节点数,然后用0填充到1296再reshape,2个维度分别是速度数据和拥挤程度数据

@1013764208
Copy link

你好,是否可以提供下 有关包含第三个维度“流量”的 完整的数据集

@ChyaZhang
Copy link
Owner

ChyaZhang commented Dec 2, 2024

你好,是否可以提供下 有关包含第三个维度“流量”的 完整的数据集

数据集中没有流量数据,只有速度和拥挤程度,第三个维度原本是通行时间,但是这个维度缺失比较严重就没放出来

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

5 participants