Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/object3d chart #382

Merged
merged 4 commits into from
Mar 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions swanlab/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
Image,
Text,
Video,
Object3D,
)
from .sdk import (
init,
Expand Down
4 changes: 2 additions & 2 deletions swanlab/data/modules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@
from .image import Image
from .text import Text
from .video import Video
from .object_3d import Object3D

# from .video import Video
from typing import Protocol, Union


class FloatConvertible(Protocol):
def __float__(self) -> float:
...
def __float__(self) -> float: ...


DataType = Union[float, FloatConvertible, int, BaseType]
2 changes: 2 additions & 0 deletions swanlab/data/modules/chart.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,5 @@ class Chart:
text = "text", [list, str]
# 视频类型 list代表一步多视频
video = "video", [list, str]
# 3D点云类型,list代表一步多3D点云
object3d = "object3d", [list, str]
191 changes: 191 additions & 0 deletions swanlab/data/modules/object_3d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
# -*- coding: utf-8 -*-
"""
Author: nexisato
Date: 2024-02-21 17:52:22
FilePath: /SwanLab/swanlab/data/modules/object_3d.py
Description:
3D Point Cloud data parsing
"""

from .base import BaseType
import numpy as np
from typing import Union, ClassVar, Set, List, Optional
from ..utils.file import get_file_hash_numpy_array, get_file_hash_path
import os
import json
import shutil
from io import BytesIO

# 格式化输出 json
import codecs


class Object3D(BaseType):
"""Object 3D class constructor

Parameters
----------
data_or_path: numpy.array, string, io
Path to an object3d format file or numpy array of object3d or Bytes.IO.

numpy.array: 3D point cloud data, shape (N, 3), (N, 4) or (N, 6).
(N, 3) : N * (x, y, z) coordinates
(N, 4) : N * (x, y, z, c) coordinates, where c in range [1, 14]
(N, 6) : N * (x, y, z, r, g, b) coordinates
caption: str
caption associated with the object3d for display
"""

SUPPORTED_TYPES: ClassVar[Set[str]] = {
"obj",
"gltf",
"glb",
"babylon",
"stl",
"pts.json",
}

def __init__(
self,
data_or_path: Union[str, "np.ndarray", "BytesIO", List["Object3D"]],
caption: Optional[str] = None,
):
super().__init__(data_or_path)
self.object3d_data = None
self.caption = self.__convert_caption(caption)
self.extension = None

def get_data(self):
# 如果传入的是Object3D类列表
if isinstance(self.value, list):
return self.get_data_list()

self.object3d_data = self.__preprocess(self.value)

# 根据不同的输入类型进行不同的哈希校验
hash_name = (
get_file_hash_numpy_array(self.object3d_data)[:16]
if isinstance(self.object3d_data, np.ndarray)
else get_file_hash_path(self.object3d_data)[:16]
)

save_dir = os.path.join(self.settings.static_dir, self.tag)
save_name = (
f"{self.caption}-step{self.step}-{hash_name}.{self.extension}"
if self.caption is not None
else f"object3d-step{self.step}-{hash_name}.{self.extension}"
)
# 如果不存在目录则创建
if os.path.exists(save_dir) is False:
os.makedirs(save_dir)
save_path = os.path.join(save_dir, save_name)

self.__save(save_path)
return save_name

def __preprocess(self, data_or_path):
"""根据输入不同的输入类型进行不同处理"""
# 如果类型为 str,进行文件后缀格式检查
if isinstance(data_or_path, str):
extension = None
for SUPPORTED_TYPE in Object3D.SUPPORTED_TYPES:
if data_or_path.endswith(SUPPORTED_TYPE):
extension = SUPPORTED_TYPE
break
if not extension:
raise TypeError(
"File '"
+ data_or_path
+ "' is not compatible with Object3D: supported types are: "
+ ", ".join(Object3D.SUPPORTED_TYPES)
)
self.extension = extension
return data_or_path
# 如果类型为 io.BytesIO 二进制流,直接返回
elif isinstance(data_or_path, BytesIO):
self.extension = "pts.json"
return data_or_path

# 如果类型为 numpy.array,进行numpy格式检查
elif isinstance(data_or_path, np.ndarray):
if len(data_or_path.shape) != 2 or data_or_path.shape[1] not in {3, 4, 6}:
raise TypeError(
"""
The shape of the numpy array must be one of either:
(N, 3) : N * (x, y, z) coordinates
(N, 4) : N * (x, y, z, c) coordinates, where c in range [1, 14]
(N, 6) : N * (x, y, z, r, g, b) coordinates
"""
)
self.extension = "pts.json"
return data_or_path
else:
raise TypeError("swanlab.Object3D accepts a file path or numpy like data as input")

def __convert_caption(self, caption):
"""将caption转换为字符串"""
# 如果类型是字符串,则不做转换
if isinstance(caption, str):
caption = caption
# 如果类型是数字,则转换为字符串
elif isinstance(caption, (int, float)):
caption = str(caption)
# 如果类型是None,则转换为默认字符串
elif caption is None:
caption = None
else:
raise TypeError("caption must be a string, int or float.")
return caption

def __save_numpy(self, save_path):
"""保存 numpy.array 格式的 3D点云资源文件 .pts.json 到指定路径"""
try:
list_data = self.object3d_data.tolist()
with codecs.open(save_path, "w", encoding="utf-8") as fp:
json.dump(
list_data,
fp,
separators=(",", ":"),
sort_keys=True,
indent=4,
)
except Exception as e:
raise TypeError(f"Could not save the 3D point cloud data to the path: {save_path}") from e

def __save(self, save_path):
"""
保存 3D点云资源文件到指定路径
"""
if isinstance(self.object3d_data, str):
shutil.copy(self.object3d_data, save_path)
elif isinstance(self.object3d_data, BytesIO):
with open(save_path, "wb") as f:
f.write(self.object3d_data.read())
elif isinstance(self.object3d_data, np.ndarray):
self.__save_numpy(save_path)

def get_more(self, *args, **kwargs) -> dict:
"""返回config数据"""
# 如果传入的是Objet3d类列表
if isinstance(self.value, list):
return self.get_more_list()
else:
return (
{
"caption": self.caption,
}
if self.caption is not None
else None
)

def expect_types(self, *args, **kwargs) -> list:
"""返回支持的文件类型"""
return ["str", "numpy.array", "io"]

def get_namespace(self, *args, **kwargs) -> str:
"""设定分组名"""
return "Object3D"

def get_chart_type(self) -> str:
"""设定图表类型"""
return self.chart.object3d