Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/object3d chart #382

Merged
merged 4 commits into from
Mar 4, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions swanlab/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
Image,
Text,
Video,
Object3D,
)
from .sdk import (
init,
Expand Down
4 changes: 2 additions & 2 deletions swanlab/data/modules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@
from .image import Image
from .text import Text
from .video import Video
from .object_3d import Object3D

# from .video import Video
from typing import Protocol, Union


class FloatConvertible(Protocol):
def __float__(self) -> float:
...
def __float__(self) -> float: ...


DataType = Union[float, FloatConvertible, int, BaseType]
2 changes: 2 additions & 0 deletions swanlab/data/modules/chart.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,5 @@ class Chart:
text = "text", [list, str]
# 视频类型 list代表一步多视频
video = "video", [list, str]
# 3D点云类型,list代表一步多3D点云
object3d = "object3d", [list, str]
191 changes: 191 additions & 0 deletions swanlab/data/modules/object_3d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
# -*- coding: utf-8 -*-
"""
Author: nexisato
Date: 2024-02-21 17:52:22
FilePath: /SwanLab/swanlab/data/modules/object_3d.py
Description:
3D Point Cloud data parsing
"""

from .base import BaseType
import numpy as np
from typing import Union, ClassVar, Set, List, Optional
from ..utils.file import get_file_hash_numpy_array, get_file_hash_path
import os
import json
import shutil
from io import BytesIO

# 格式化输出 json
import codecs


class Object3D(BaseType):
"""Object 3D class constructor

Parameters
----------
data_or_path: numpy.array, string, io
Path to an object3d format file or numpy array of object3d or Bytes.IO.

numpy.array: 3D point cloud data, shape (N, 3), (N, 4) or (N, 6).
(N, 3) : N * (x, y, z) coordinates
(N, 4) : N * (x, y, z, c) coordinates, where c in range [1, 14]
(N, 6) : N * (x, y, z, r, g, b) coordinates
caption: str
caption associated with the object3d for display
"""

SUPPORTED_TYPES: ClassVar[Set[str]] = {
"obj",
"gltf",
"glb",
"babylon",
"stl",
"pts.json",
}

def __init__(
self,
data_or_path: Union[str, "np.ndarray", "BytesIO", List["Object3D"]],
caption: Optional[str] = None,
):
super().__init__(data_or_path)
self.object3d_data = None
self.caption = self.__convert_caption(caption)
self.extension = None

def get_data(self):
# 如果传入的是Object3D类列表
if isinstance(self.value, list):
return self.get_data_list()

self.object3d_data = self.__preprocess(self.value)

# 根据不同的输入类型进行不同的哈希校验
hash_name = (
get_file_hash_numpy_array(self.object3d_data)[:16]
if isinstance(self.object3d_data, np.ndarray)
else get_file_hash_path(self.object3d_data)[:16]
)

save_dir = os.path.join(self.settings.static_dir, self.tag)
save_name = (
f"{self.caption}-step{self.step}-{hash_name}.{self.extension}"
if self.caption is not None
else f"object3d-step{self.step}-{hash_name}.{self.extension}"
)
# 如果不存在目录则创建
if os.path.exists(save_dir) is False:
os.makedirs(save_dir)
save_path = os.path.join(save_dir, save_name)

self.__save(save_path)
return save_name

def __preprocess(self, data_or_path):
"""根据输入不同的输入类型进行不同处理"""
# 如果类型为 str,进行文件后缀格式检查
if isinstance(data_or_path, str):
extension = None
for SUPPORTED_TYPE in Object3D.SUPPORTED_TYPES:
if data_or_path.endswith(SUPPORTED_TYPE):
extension = SUPPORTED_TYPE
break
if not extension:
raise TypeError(
"File '"
+ data_or_path
+ "' is not compatible with Object3D: supported types are: "
+ ", ".join(Object3D.SUPPORTED_TYPES)
)
self.extension = extension
return data_or_path
# 如果类型为 io.BytesIO 二进制流,直接返回
elif isinstance(data_or_path, BytesIO):
self.extension = "pts.json"
return data_or_path

# 如果类型为 numpy.array,进行numpy格式检查
elif isinstance(data_or_path, np.ndarray):
if len(data_or_path.shape) != 2 or data_or_path.shape[1] not in {3, 4, 6}:
raise TypeError(
"""
The shape of the numpy array must be one of either:
(N, 3) : N * (x, y, z) coordinates
(N, 4) : N * (x, y, z, c) coordinates, where c in range [1, 14]
(N, 6) : N * (x, y, z, r, g, b) coordinates
"""
)
self.extension = "pts.json"
return data_or_path
else:
raise TypeError("swanlab.Object3D accepts a file path or numpy like data as input")

def __convert_caption(self, caption):
"""将caption转换为字符串"""
# 如果类型是字符串,则不做转换
if isinstance(caption, str):
caption = caption
# 如果类型是数字,则转换为字符串
elif isinstance(caption, (int, float)):
caption = str(caption)
# 如果类型是None,则转换为默认字符串
elif caption is None:
caption = None
else:
raise TypeError("caption must be a string, int or float.")
return caption

def __save_numpy(self, save_path):
"""保存 numpy.array 格式的 3D点云资源文件 .pts.json 到指定路径"""
try:
list_data = self.object3d_data.tolist()
with codecs.open(save_path, "w", encoding="utf-8") as fp:
json.dump(
list_data,
fp,
separators=(",", ":"),
sort_keys=True,
indent=4,
)
except Exception as e:
raise TypeError(f"Could not save the 3D point cloud data to the path: {save_path}") from e

def __save(self, save_path):
"""
保存 3D点云资源文件到指定路径
"""
if isinstance(self.object3d_data, str):
shutil.copy(self.object3d_data, save_path)
elif isinstance(self.object3d_data, BytesIO):
with open(save_path, "wb") as f:
f.write(self.object3d_data.read())
elif isinstance(self.object3d_data, np.ndarray):
self.__save_numpy(save_path)

def get_more(self, *args, **kwargs) -> dict:
"""返回config数据"""
# 如果传入的是Objet3d类列表
if isinstance(self.value, list):
return self.get_more_list()
else:
return (
{
"caption": self.caption,
}
if self.caption is not None
else None
)

def expect_types(self, *args, **kwargs) -> list:
"""返回支持的文件类型"""
return ["str", "numpy.array", "io"]

def get_namespace(self, *args, **kwargs) -> str:
"""设定分组名"""
return "Object3D"

def get_chart_type(self) -> str:
"""设定图表类型"""
return self.chart.object3d
50 changes: 46 additions & 4 deletions test/create_experiment.py
Nexisato marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -1,19 +1,61 @@
import swanlab
import random
import numpy as np

epochs = 50
lr = 0.01
offset = random.random() / 5

run = swanlab.init(
experiment_name="Example",
description="这是一个机器学习模拟实验",
config={
"learning_rate": 0.01,
"epochs": 20,
"epochs": epochs,
"learning_rate": lr,
"test": 1,
"debug": "这是一串" + "很长" * 100 + "的字符串",
"verbose": 1,
},
logggings=True,
)

# 模拟机器学习训练过程
for epoch in range(2, run.config.epochs):

def generate_random_nx3(n):
"""生成形状为nx3的随机数组"""
return np.random.rand(n, 3)


def generate_random_nx4(n):
"""生成形状为nx4的随机数组,最后一列是[1,14]范围内的整数分类"""
xyz = np.random.rand(n, 3)
c = np.random.randint(1, 15, size=(n, 1))
return np.hstack((xyz, c))


def generate_random_nx6(n):
"""生成形状为nx6的随机数组,包含RGB颜色"""
xyz = np.random.rand(n, 3)
rgb = np.random.rand(n, 3) # RGB颜色值也可以是[0,1]之间的随机数
rgb = (rgb * 255).astype(np.uint8) # 转换为[0,255]之间的整数
return np.hstack((xyz, rgb))


for epoch in range(2, epochs):
if epoch % 10 == 0:

# swanlab.log(
# {
# "test/object3d1":
# },
# step=epoch,
# )
swanlab.log(
{
"test-object3d1": swanlab.Object3D("./assets/bunny.obj", caption="bunny-obj"),
"test-object3d2": swanlab.Object3D("./assets/test1.pts.json", caption="test1-pts"),
},
step=epoch,
)
acc = 1 - 2**-epoch - random.random() / epoch - offset
loss = 2**-epoch + random.random() / epoch + offset
swanlab.log({"loss": loss, "accuracy": acc})