From 3a8b78b0688b69cdce9f0ba27908cf158e911701 Mon Sep 17 00:00:00 2001 From: nexistao <978452096@qq.com> Date: Wed, 21 Feb 2024 22:32:37 +0800 Subject: [PATCH 1/3] add: object-3d numpy array supported --- swanlab/__init__.py | 9 +- swanlab/data/__init__.py | 1 + swanlab/data/modules/__init__.py | 4 +- swanlab/data/modules/chart.py | 2 + swanlab/data/modules/object_3d.py | 135 ++++++++++++++++++++++++++++++ test/create_experiment.py | 36 ++++++-- 6 files changed, 170 insertions(+), 17 deletions(-) create mode 100644 swanlab/data/modules/object_3d.py diff --git a/swanlab/__init__.py b/swanlab/__init__.py index c46f7795..c1673598 100755 --- a/swanlab/__init__.py +++ b/swanlab/__init__.py @@ -1,12 +1,5 @@ # 导出初始化函数和log函数 -from .data import ( - init, - log, - finish, - config, - Audio, - Image, -) +from .data import init, log, finish, config, Audio, Image, Object3D from .utils import get_package_version diff --git a/swanlab/data/__init__.py b/swanlab/data/__init__.py index 58b63998..9104f461 100644 --- a/swanlab/data/__init__.py +++ b/swanlab/data/__init__.py @@ -12,6 +12,7 @@ Audio, Image, Text, + Object3D, ) from .sdk import ( init, diff --git a/swanlab/data/modules/__init__.py b/swanlab/data/modules/__init__.py index dfc02b22..72541366 100644 --- a/swanlab/data/modules/__init__.py +++ b/swanlab/data/modules/__init__.py @@ -2,12 +2,12 @@ from .audio import Audio from .image import Image from .text import Text +from .object_3d import Object3D from typing import Protocol, Union class FloatConvertible(Protocol): - def __float__(self) -> float: - ... + def __float__(self) -> float: ... DataType = Union[float, FloatConvertible, int, BaseType] diff --git a/swanlab/data/modules/chart.py b/swanlab/data/modules/chart.py index 0ed4d7d8..11e7d322 100644 --- a/swanlab/data/modules/chart.py +++ b/swanlab/data/modules/chart.py @@ -20,3 +20,5 @@ class Chart: audio = "audio", [list, str] # 文本类型,代表一步多文本 text = "text", [list, str] + # 3D点云类型,代表一步多 object3d + object3d = "object3d", [list, str] diff --git a/swanlab/data/modules/object_3d.py b/swanlab/data/modules/object_3d.py new file mode 100644 index 00000000..4c22f902 --- /dev/null +++ b/swanlab/data/modules/object_3d.py @@ -0,0 +1,135 @@ +# -*- coding: utf-8 -*- +""" +Author: nexisato +Date: 2024-02-21 17:52:22 +FilePath: /SwanLab/swanlab/data/modules/object_3d.py +Description: + 3D Point Cloud data parsing +""" + +from .base import BaseType +import numpy as np +from typing import Union, ClassVar, Set, List +from ..utils.file import get_file_hash_numpy_array, get_file_hash_path +import os +import json +import codecs + + +class Object3D(BaseType): + """Object 3D class constructor + + Parameters + ---------- + data_or_path: numpy.array + Path to an object3d format file or numpy array of object3d. + + numpy.array: 3D point cloud data, shape (N, 3), (N, 4) or (N, 6). + (N, 3) : N * (x, y, z) coordinates + (N, 4) : N * (x, y, z, c) coordinates, where c in range [1, 14] + (N, 6) : N * (x, y, z, r, g, b) coordinates + """ + + SUPPORTED_TYPES: ClassVar[Set[str]] = { + "obj", + "gltf", + "glb", + "babylon", + "stl", + "pts.json", + } + + def __init__( + self, + data_or_path: Union[np.ndarray, List["Object3D"]], + caption: str = None, + ): + super().__init__(data_or_path) + self.object3d_data = None + self.caption = self.__convert_caption(caption) + + def get_data(self): + # 如果传入的是Object3D类列表 + if isinstance(self.value, list): + return self.get_data_list() + self.__preprocess(self.value) + # 目前只支持numpy.array格式的3D点云数据 + hash_name = get_file_hash_numpy_array(self.object3d_data)[:16] + save_dir = os.path.join(self.settings.static_dir, self.tag) + save_name = ( + f"{self.caption}-step{self.step}-{hash_name}.pts.json" + if self.caption is not None + else f"object3d-step{self.step}-{hash_name}.pts.json" + ) + # 如果不存在目录则创建 + if os.path.exists(save_dir) is False: + os.makedirs(save_dir) + save_path = os.path.join(save_dir, save_name) + # 保存图像到指定目录 + self.__save(save_path) + return save_name + + def __preprocess(self, data_or_path): + """根据输入不同的输入类型进行不同处理""" + if isinstance(data_or_path, str): + # TODO + raise NotImplementedError("The input type of string is not supported yet.") + + elif isinstance(data_or_path, np.ndarray): + # 如果输入为numpy.array,那么必须是满足我们所定义的shape的3D点云数据 + if len(data_or_path.shape) != 2 or data_or_path.shape[1] not in {3, 4, 6}: + raise TypeError( + """ + The shape of the numpy array must be one of either: + (N, 3) : N * (x, y, z) coordinates + (N, 4) : N * (x, y, z, c) coordinates, where c in range [1, 14] + (N, 6) : N * (x, y, z, r, g, b) coordinates + """ + ) + self.object3d_data = data_or_path + else: + raise TypeError("Invalid data type, only support string or numpy.array.") + + def __convert_caption(self, caption): + """将caption转换为字符串""" + # 如果类型是字符串,则不做转换 + if isinstance(caption, str): + caption = caption + # 如果类型是数字,则转换为字符串 + elif isinstance(caption, (int, float)): + caption = str(caption) + # 如果类型是None,则转换为默认字符串 + elif caption is None: + caption = None + else: + raise TypeError("caption must be a string, int or float.") + return caption + + def __save(self, save_path): + """ + 保存 3D点云资源文件 .pts.json 到指定路径 + """ + try: + list_data = self.object3d_data.tolist() + with codecs.open(save_path, "w", encoding="utf-8") as fp: + json.dump( + list_data, + fp, + separators=(",", ":"), + sort_keys=True, + indent=4, + ) + except Exception as e: + raise ValueError(f"Could not save the 3D point cloud data to the path: {save_path}") from e + + def expect_types(self, *args, **kwargs) -> list: + """返回支持的文件类型""" + return ["str", "numpy.array", "io"] + + def get_namespace(self, *args, **kwargs) -> str: + """设定分组名""" + return "Object3D" + + def get_chart_type(self) -> str: + """设定图表类型""" + return self.chart.object3d diff --git a/test/create_experiment.py b/test/create_experiment.py index f3c1f67b..adda2f2b 100644 --- a/test/create_experiment.py +++ b/test/create_experiment.py @@ -18,22 +18,44 @@ }, logggings=True, ) + + +def generate_random_nx3(n): + """生成形状为nx3的随机数组""" + return np.random.rand(n, 3) + + +def generate_random_nx4(n): + """生成形状为nx4的随机数组,最后一列是[1,14]范围内的整数分类""" + xyz = np.random.rand(n, 3) + c = np.random.randint(1, 15, size=(n, 1)) + return np.hstack((xyz, c)) + + +def generate_random_nx6(n): + """生成形状为nx6的随机数组,包含RGB颜色""" + xyz = np.random.rand(n, 3) + rgb = np.random.rand(n, 3) # RGB颜色值也可以是[0,1]之间的随机数 + rgb = (rgb * 255).astype(np.uint8) # 转换为[0,255]之间的整数 + return np.hstack((xyz, rgb)) + + for epoch in range(2, epochs): if epoch % 10 == 0: - # # 测试audio - # sample_rate = 44100 - # test_audio_arr = np.random.randn(2, 100000) + arr1 = generate_random_nx3(1000) + arr2 = generate_random_nx4(1000) + arr3 = generate_random_nx6(1000) # swanlab.log( # { - # "test/audio": [swanlab.Audio(test_audio_arr, sample_rate, caption="test")] * (epoch // 10), + # "test/object3d1": # }, # step=epoch, # ) - # 测试image - test_image = np.random.randint(0, 255, (100, 100, 3)) swanlab.log( { - "test/image": swanlab.Image(test_image, caption="test"), + "test-object3d1": swanlab.Object3D(arr1, caption="3D点云数据"), + "test-object3d2": swanlab.Object3D(arr2, caption="3D点云数据"), + "test-object3d3": swanlab.Object3D(arr3, caption="3D点云数据"), }, step=epoch, ) From 720f536e7fef8ae0733fd7adf4ed128eb3726c07 Mon Sep 17 00:00:00 2001 From: ZeYi Lin <944270057@qq.com> Date: Sun, 3 Mar 2024 19:56:36 +0800 Subject: [PATCH 2/3] del obejct3d --- swanlab/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/swanlab/__init__.py b/swanlab/__init__.py index 820c3bea..5b55d242 100755 --- a/swanlab/__init__.py +++ b/swanlab/__init__.py @@ -6,7 +6,6 @@ config, Audio, Image, - Object3D, Text, Run, ) From f10c8f78928c688e7bbaec88344617381f422ddc Mon Sep 17 00:00:00 2001 From: nexisato <978452096@qq.com> Date: Sun, 3 Mar 2024 22:28:02 +0800 Subject: [PATCH 3/3] fix: restore create_experiment.py --- test/create_experiment.py | 50 ++++----------------------------------- 1 file changed, 4 insertions(+), 46 deletions(-) diff --git a/test/create_experiment.py b/test/create_experiment.py index e846dc9e..5892a6fa 100644 --- a/test/create_experiment.py +++ b/test/create_experiment.py @@ -1,61 +1,19 @@ import swanlab import random -import numpy as np -epochs = 50 -lr = 0.01 offset = random.random() / 5 run = swanlab.init( experiment_name="Example", description="这是一个机器学习模拟实验", config={ - "epochs": epochs, - "learning_rate": lr, - "test": 1, - "debug": "这是一串" + "很长" * 100 + "的字符串", - "verbose": 1, + "learning_rate": 0.01, + "epochs": 20, }, - logggings=True, ) - -def generate_random_nx3(n): - """生成形状为nx3的随机数组""" - return np.random.rand(n, 3) - - -def generate_random_nx4(n): - """生成形状为nx4的随机数组,最后一列是[1,14]范围内的整数分类""" - xyz = np.random.rand(n, 3) - c = np.random.randint(1, 15, size=(n, 1)) - return np.hstack((xyz, c)) - - -def generate_random_nx6(n): - """生成形状为nx6的随机数组,包含RGB颜色""" - xyz = np.random.rand(n, 3) - rgb = np.random.rand(n, 3) # RGB颜色值也可以是[0,1]之间的随机数 - rgb = (rgb * 255).astype(np.uint8) # 转换为[0,255]之间的整数 - return np.hstack((xyz, rgb)) - - -for epoch in range(2, epochs): - if epoch % 10 == 0: - - # swanlab.log( - # { - # "test/object3d1": - # }, - # step=epoch, - # ) - swanlab.log( - { - "test-object3d1": swanlab.Object3D("./assets/bunny.obj", caption="bunny-obj"), - "test-object3d2": swanlab.Object3D("./assets/test1.pts.json", caption="test1-pts"), - }, - step=epoch, - ) +# 模拟机器学习训练过程 +for epoch in range(2, run.config.epochs): acc = 1 - 2**-epoch - random.random() / epoch - offset loss = 2**-epoch + random.random() / epoch + offset swanlab.log({"loss": loss, "accuracy": acc})