Skip to content

Commit

Permalink
Improve to_dict (#14)
Browse files Browse the repository at this point in the history
* Improve to_dict
* Extract version
  • Loading branch information
aamalev authored Oct 22, 2023
1 parent ef6d5dd commit d394b48
Show file tree
Hide file tree
Showing 4 changed files with 182 additions and 20 deletions.
5 changes: 4 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ target-version = [
line-length = 120

[tool.pytest.ini_options]
markers = [
"redis",
]
asyncio_mode = "auto"
testpaths = [
"redis_rs",
Expand All @@ -69,7 +72,7 @@ path = "Cargo.toml"
dependencies = [
"maturin",
"mypy",
"black",
"black",
"ruff",
"isort",
"pytest-asyncio",
Expand Down
46 changes: 28 additions & 18 deletions src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,27 +45,37 @@ pub fn decode(py: Python, v: Vec<u8>, encoding: &str) -> PyObject {
}
}

fn _to_dict(py: Python<'_>, value: Value, encoding: &str) -> PyObject {
match value {
Value::Data(v) => decode(py, v, encoding),
Value::Nil => py.None(),
Value::Int(i) => i.to_object(py),
Value::Bulk(_) => to_dict(py, value, encoding),
Value::Status(s) => s.to_object(py),
Value::Okay => true.to_object(py),
}
}

pub fn to_dict(py: Python, value: Value, encoding: &str) -> PyObject {
let result = PyDict::new(py);

if let Value::Bulk(v) = value {
if let Some(Value::Bulk(_)) = v.get(0) {
for item in v.into_iter() {
if let Value::Bulk(mut pair) = item {
let rkey: Result<String, redis::RedisError> =
FromRedisValue::from_redis_value(pair.get(0).unwrap_or(&Value::Nil));
if let Ok(key) = rkey {
let value = to_dict(py, pair.pop().unwrap_or(Value::Nil), encoding);
result.set_item(key, value).unwrap();
}
let map: HashMap<String, Value> = FromRedisValue::from_redis_value(&value).unwrap_or_default();
if !map.is_empty() {
for (k, value) in map.into_iter() {
let val = _to_dict(py, value, encoding);
result.set_item(k, val).unwrap();
}
} else if let Value::Bulk(v) = value {
for (n, value) in v.into_iter().enumerate() {
let map: HashMap<String, Value> =
FromRedisValue::from_redis_value(&value).unwrap_or_default();
if map.len() == 1 {
for (k, value) in map.into_iter() {
let val = _to_dict(py, value, encoding);
result.set_item(k, val).unwrap();
}
}
} else if let Some(Value::Data(_)) = v.get(1) {
let map: HashMap<String, Vec<u8>> =
FromRedisValue::from_redis_value(&Value::Bulk(v)).unwrap_or_default();
for (k, v) in map.into_iter() {
let val = decode(py, v, encoding);
result.set_item(k, val).unwrap();
} else {
let value = _to_dict(py, value, encoding);
result.set_item(n, value).unwrap();
}
}
}
Expand Down
50 changes: 49 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,57 @@
import asyncio
import os

import pytest

import redis_rs


def to_addr(s: str) -> str:
if s.isdigit():
return f"redis://localhost:{s}"
elif not s.startswith("redis://"):
return f"redis://{s}"
return s


async def get_redis_version(nodes: list) -> str:
key = "redis_version"
async with redis_rs.create_client(
*nodes,
) as c:
infos = await c.execute("INFO", "SERVER", encoding="info")
if isinstance(infos, dict):
result = infos.get(key)
else:
infos = filter(lambda x: isinstance(x, dict), infos)
result = min(map(lambda x: x[key], infos))
return result


NODES = [to_addr(node) for node in os.environ.get("REDIS_NODES", "").split(",") if node]
IS_CLUSTER = os.environ.get("REDIS_CLUSTER", "0") not in {"0"}
VERSION = ""


@pytest.fixture
async def async_client():
async with redis_rs.create_client() as c:
async with redis_rs.create_client(
*NODES,
cluster=IS_CLUSTER,
) as c:
yield c


def pytest_runtest_setup(item):
for marker in item.iter_markers():
if marker.name == "redis":
if marker.kwargs.get("cluster") and not IS_CLUSTER:
pytest.skip("Single redis")
elif marker.kwargs.get("single") and IS_CLUSTER:
pytest.skip("Cluster redis")
elif version := marker.kwargs.get("version"):
global VERSION
if not VERSION:
VERSION = asyncio.run(get_redis_version(NODES))
if str(version) > VERSION:
pytest.skip(f"redis_version:{VERSION} < {version}")
101 changes: 101 additions & 0 deletions tests/test_fetch_dict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
from uuid import uuid4

import pytest

import redis_rs


async def test_hgetall(async_client: redis_rs.AsyncClient):
key = uuid4().hex

await async_client.hset(key, "a", 1)
await async_client.hset(key, "b", 2)
result = await async_client.fetch_dict("HGETALL", key)
assert result
assert result == {"a": b"1", "b": b"2"}


async def test_zrange(async_client: redis_rs.AsyncClient):
key = uuid4().hex

await async_client.execute("ZADD", key, 1.5678, "b")
await async_client.execute("ZADD", key, 2.6, "a")
result = await async_client.fetch_dict("ZRANGE", key, 0, -1, "WITHSCORES", encoding="float")
assert result
assert result == {"a": 2.6, "b": 1.5678}


async def test_xread(async_client: redis_rs.AsyncClient):
stream = uuid4().hex

result = await async_client.xread({stream: 0})
assert isinstance(result, dict)
assert result == {}

ident = await async_client.xadd(stream, {"a": "bcd"})
assert isinstance(ident, str)

result = await async_client.xread(stream)
assert isinstance(result, dict)
assert result == {stream: {ident: {"a": b"bcd"}}}


async def test_xinfo_stream(async_client: redis_rs.AsyncClient):
stream = f"stream-{uuid4()}"
group = f"group-{uuid4()}"

await async_client.execute("XGROUP", "CREATE", stream, group, 0, "MKSTREAM")
xinfo = await async_client.fetch_dict("XINFO", "STREAM", stream)
assert xinfo
assert len(xinfo) > 6


async def test_xinfo_groups(async_client: redis_rs.AsyncClient):
stream = f"stream-{uuid4()}"
group = f"group-{uuid4()}"

await async_client.execute("XGROUP", "CREATE", stream, group, 0, "MKSTREAM")
xinfo = await async_client.fetch_dict("XINFO", "GROUPS", stream)
assert xinfo
assert len(xinfo) == 1
assert len(xinfo[0]) > 3

group = f"group-{uuid4()}"
await async_client.execute("XGROUP", "CREATE", stream, group, 0, "MKSTREAM")
xinfo = await async_client.fetch_dict("XINFO", "GROUPS", stream)
assert xinfo
assert len(xinfo) == 2
assert len(xinfo[0]) > 3
assert len(xinfo[1]) > 3


@pytest.mark.redis(version=6)
async def test_xinfo_consumers(async_client: redis_rs.AsyncClient):
stream = f"stream-{uuid4()}"
group = f"group-{uuid4()}"
consumer = f"consumer-{uuid4()}"

await async_client.execute("XGROUP", "CREATE", stream, group, 0, "MKSTREAM")
await async_client.execute("XGROUP", "CREATECONSUMER", stream, group, consumer)
xinfo = await async_client.fetch_dict("XINFO", "CONSUMERS", stream, group)
assert xinfo
assert len(xinfo) == 1
assert len(xinfo[0]) > 2


@pytest.mark.redis(version=7, cluster=True)
async def test_cluster_shards(async_client: redis_rs.AsyncClient):
result = await async_client.fetch_dict("CLUSTER", "SHARDS")
assert result
assert len(result) > 1
assert len(result[0]) > 1


@pytest.mark.redis(cluster=True)
async def test_info(async_client: redis_rs.AsyncClient):
result = await async_client.fetch_dict("INFO", "SERVER", encoding="info")
assert result
assert isinstance(result, dict)
for k, v in result.items():
assert isinstance(k, str)
assert "redis_version" in v

0 comments on commit d394b48

Please sign in to comment.