Skip to content

Commit

Permalink
Add some type annotations, most notably to smbclient.open_file
Browse files Browse the repository at this point in the history
  • Loading branch information
mon committed Oct 17, 2024
1 parent 42804ca commit 1187457
Show file tree
Hide file tree
Showing 3 changed files with 152 additions and 7 deletions.
138 changes: 137 additions & 1 deletion src/smbclient/_os.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# Copyright: (c) 2019, Jordan Borean (@jborean93) <jborean93@gmail.com>
# MIT License (see LICENSE or https://opensource.org/licenses/MIT)

from __future__ import annotations

import collections
import datetime
import errno
Expand Down Expand Up @@ -306,6 +308,140 @@ def makedirs(path, exist_ok=False, **kwargs):
create_queue.pop(-1)


# Taken from stdlib typeshed but removed the unused 'U' flag
OpenTextModeUpdating: t.TypeAlias = t.Literal[
"r+",

Check warning on line 313 in src/smbclient/_os.py

View check run for this annotation

Codecov / codecov/patch

src/smbclient/_os.py#L313

Added line #L313 was not covered by tests
"+r",
"rt+",
"r+t",
"+rt",
"tr+",
"t+r",
"+tr",
"w+",
"+w",
"wt+",
"w+t",
"+wt",
"tw+",
"t+w",
"+tw",
"a+",
"+a",
"at+",
"a+t",
"+at",
"ta+",
"t+a",
"+ta",
"x+",
"+x",
"xt+",
"x+t",
"+xt",
"tx+",
"t+x",
"+tx",
]
OpenTextModeWriting: t.TypeAlias = t.Literal["w", "wt", "tw", "a", "at", "ta", "x", "xt", "tx"]
OpenTextModeReading: t.TypeAlias = t.Literal["r", "rt", "tr"]
OpenTextMode: t.TypeAlias = OpenTextModeUpdating | OpenTextModeWriting | OpenTextModeReading
OpenBinaryModeUpdating: t.TypeAlias = t.Literal[
"rb+",
"r+b",
"+rb",
"br+",
"b+r",
"+br",
"wb+",
"w+b",
"+wb",
"bw+",
"b+w",
"+bw",
"ab+",
"a+b",
"+ab",
"ba+",
"b+a",
"+ba",
"xb+",
"x+b",
"+xb",
"bx+",
"b+x",
"+bx",
]
OpenBinaryModeWriting: t.TypeAlias = t.Literal["wb", "bw", "ab", "ba", "xb", "bx"]
OpenBinaryModeReading: t.TypeAlias = t.Literal["rb", "br"]
OpenBinaryMode: t.TypeAlias = OpenBinaryModeUpdating | OpenBinaryModeReading | OpenBinaryModeWriting
FileType: t.TypeAlias = t.Literal["file", "dir", "pipe"]


# Text mode: always returns a TextIOWrapper
@t.overload
def open_file(
path,
mode: OpenTextMode = "r",
buffering=-1,
file_type: FileType = "file",
encoding=None,
errors=None,
newline=None,
share_access=None,
desired_access=None,
file_attributes=None,
**kwargs,
) -> io.TextIOWrapper[io.BufferedRandom | io.BufferedReader | io.BufferedWriter]: ...


# Otherwise return BufferedRandom, BufferedReader, or BufferedWriter
# NOTE: This incorrectly returns unbuffered opens as Buffered types, due to difficulties
# in annotating that case
@t.overload
def open_file(
path,
mode: OpenBinaryModeUpdating,
buffering=-1,
encoding=None,
errors=None,
newline=None,
share_access=None,
desired_access=None,
file_attributes=None,
file_type: FileType = "file",
**kwargs,
) -> io.BufferedRandom: ...
@t.overload
def open_file(
path,
mode: OpenBinaryModeReading,
buffering=-1,
encoding=None,
errors=None,
newline=None,
share_access=None,
desired_access=None,
file_attributes=None,
file_type: FileType = "file",
**kwargs,
) -> io.BufferedReader: ...
@t.overload
def open_file(
path,
mode: OpenBinaryModeWriting,
buffering=-1,
encoding=None,
errors=None,
newline=None,
share_access=None,
desired_access=None,
file_attributes=None,
file_type: FileType = "file",
**kwargs,
) -> io.BufferedWriter: ...


def open_file(
path,
mode="r",
Expand All @@ -316,7 +452,7 @@ def open_file(
share_access=None,
desired_access=None,
file_attributes=None,
file_type="file",
file_type: t.Literal["file", "dir", "pipe"] = "file",
**kwargs,
):
"""
Expand Down
11 changes: 6 additions & 5 deletions src/smbclient/_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import logging
import ntpath
import uuid
from typing import Literal, Optional

from smbprotocol._text import to_text
from smbprotocol.connection import Capabilities, Connection
Expand Down Expand Up @@ -366,14 +367,14 @@ def get_smb_tree(


def register_session(
server,
username=None,
password=None,
server: str,
username: Optional[str] = None,
password: Optional[str] = None,
port=445,
encrypt=None,
encrypt: Optional[bool] = None,
connection_timeout=60,
connection_cache=None,
auth_protocol="negotiate",
auth_protocol: Literal["negotiate", "ntlm", "kerberos"] = "negotiate",
require_signing=True,
):
"""
Expand Down
10 changes: 9 additions & 1 deletion src/smbprotocol/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import logging
import random
from collections import OrderedDict
from typing import Literal, Optional

import spnego
from cryptography.hazmat.backends import default_backend
Expand Down Expand Up @@ -170,7 +171,14 @@ def __init__(self):


class Session:
def __init__(self, connection, username=None, password=None, require_encryption=True, auth_protocol="negotiate"):
def __init__(
self,
connection,
username: Optional[str] = None,
password: Optional[str] = None,
require_encryption=True,
auth_protocol: Literal["negotiate", "ntlm", "kerberos"] = "negotiate",
):
"""
[MS-SMB2] v53.0 2017-09-15
Expand Down

0 comments on commit 1187457

Please sign in to comment.