diff --git a/.gitignore b/.gitignore index ac72d1c..a109ed7 100644 --- a/.gitignore +++ b/.gitignore @@ -4,3 +4,4 @@ docs/.build/ pyrad.egg-info/ *.pyc *__pycache__ +*.egg/ diff --git a/CHANGES.rst b/CHANGES.rst index 2bab2d1..5ec3079 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -1,8 +1,18 @@ Changelog ========= -2.1 - Unreleased ------------------- +* Add experimental async client and server implementation for python >=3.5. + +* Add IPv6 bind support for client and server. + +* Add support of tlv and integer64 attributes. + +* Multiple minor enhancements and fixes. + +2.1 - Feb 2, 2017 +----------------- + +* Add CoA support (client and server). * Add tagged attribute support (send only). diff --git a/MANIFEST.in b/MANIFEST.in index 6cb73b5..ce75542 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,2 +1,3 @@ +include LICENSE.txt recursive-include example * prune example/.svn diff --git a/README.rst b/README.rst index b9a3646..d21e921 100644 --- a/README.rst +++ b/README.rst @@ -4,6 +4,9 @@ :target: https://pypi.python.org/pypi/pyrad .. image:: https://img.shields.io/pypi/dm/pyrad.svg :target: https://pypi.python.org/pypi/pyrad +.. image:: https://readthedocs.org/projects/pyrad/badge/?version=latest + :target: http://pyrad.readthedocs.io/en/latest/?badge=latest + :alt: Documentation Status Introduction ============ @@ -55,8 +58,9 @@ Python modules:: Author, Copyright, Availability =============================== -pyrad was written by Wichert Akkerman and is licensed -under a BSD license. +pyrad was written by Wichert Akkerman and is maintained by Christian Giese (GIC-de). + +This project is licensed under a BSD license. Copyright and license information can be found in the LICENSE.txt file. diff --git a/TODO.rst b/TODO.rst new file mode 100644 index 0000000..165b971 --- /dev/null +++ b/TODO.rst @@ -0,0 +1,2 @@ +ToDo +==== diff --git a/TODO.txt b/TODO.txt deleted file mode 100644 index d614787..0000000 --- a/TODO.txt +++ /dev/null @@ -1,11 +0,0 @@ -RFC5176 -======= -github ticket #2 -github ticket #3 - -Add Change of Authorization packet type. - -- use UDP port 3799 as default destination port -- CoA packets have same format as Access packets - -In general CoA requests are already working as shown in `example/coa.py`! diff --git a/docs/.static/repoze.css b/docs/.static/repoze.css deleted file mode 100644 index 0935625..0000000 --- a/docs/.static/repoze.css +++ /dev/null @@ -1,33 +0,0 @@ -@import url('default.css'); -body { - background-color: #006339; -} - -div.document { - background-color: #dad3bd; -} - -div.sphinxsidebar h3,h4,h5,li,a { - color: #127c56 !important; -} - -div.related { - color: #dad3bd; - background-color: #00744a; -} - -div.related a { - color: #dad3bd; -} - -div.body h3 { - font-size: 120%; -} - -div.body h3, -div.body h4, -div.body h5, -div.body h6 { - background-color: transparent; -} - diff --git a/docs/Makefile b/docs/Makefile index e0984c5..1446ca6 100644 --- a/docs/Makefile +++ b/docs/Makefile @@ -1,70 +1,20 @@ -# Makefile for Sphinx documentation +# Minimal makefile for Sphinx documentation # # You can set these variables from the command line. SPHINXOPTS = -SPHINXBUILD = env PYTHONPATH=.. sphinx-build -PAPER = a4 - -# Internal variables. -PAPEROPT_a4 = -D latex_paper_size=a4 -PAPEROPT_letter = -D latex_paper_size=letter -ALLSPHINXOPTS = -d .build/doctrees $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) . - -.PHONY: help clean html web pickle htmlhelp latex changes linkcheck +SPHINXBUILD = sphinx-build +SPHINXPROJ = pyrad +SOURCEDIR = source +BUILDDIR = build +# Put it first so that "make" without argument is like "make help". help: - @echo "Please use \`make ' where is one of" - @echo " html to make standalone HTML files" - @echo " pickle to make pickle files (usable by e.g. sphinx-web)" - @echo " htmlhelp to make HTML files and a HTML help project" - @echo " latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter" - @echo " changes to make an overview over all changed/added/deprecated items" - @echo " linkcheck to check all external links for integrity" - -clean: - -rm -rf .build/* - -html: - mkdir -p .build/html .build/doctrees - $(SPHINXBUILD) -b html $(ALLSPHINXOPTS) .build/html - @echo - @echo "Build finished. The HTML pages are in .build/html." - -pickle: - mkdir -p .build/pickle .build/doctrees - $(SPHINXBUILD) -b pickle $(ALLSPHINXOPTS) .build/pickle - @echo - @echo "Build finished; now you can process the pickle files or run" - @echo " sphinx-web .build/pickle" - @echo "to start the sphinx-web server." - -web: pickle - -htmlhelp: - mkdir -p .build/htmlhelp .build/doctrees - $(SPHINXBUILD) -b htmlhelp $(ALLSPHINXOPTS) .build/htmlhelp - @echo - @echo "Build finished; now you can run HTML Help Workshop with the" \ - ".hhp project file in .build/htmlhelp." - -latex: - mkdir -p .build/latex .build/doctrees - $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) .build/latex - @echo - @echo "Build finished; the LaTeX files are in .build/latex." - @echo "Run \`make all-pdf' or \`make all-ps' in that directory to" \ - "run these through (pdf)latex." + @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) -changes: - mkdir -p .build/changes .build/doctrees - $(SPHINXBUILD) -b changes $(ALLSPHINXOPTS) .build/changes - @echo - @echo "The overview file is in .build/changes." +.PHONY: help Makefile -linkcheck: - mkdir -p .build/linkcheck .build/doctrees - $(SPHINXBUILD) -b linkcheck $(ALLSPHINXOPTS) .build/linkcheck - @echo - @echo "Link check complete; look for any errors in the above output " \ - "or in .build/linkcheck/output.txt." +# Catch-all target: route all unknown targets to Sphinx using the new +# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). +%: Makefile + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) \ No newline at end of file diff --git a/docs/conf.py b/docs/conf.py deleted file mode 100644 index 9df2a73..0000000 --- a/docs/conf.py +++ /dev/null @@ -1,192 +0,0 @@ -# -*- coding: utf-8 -*- -# -# repoze.atemplate documentation build configuration file -# -# This file is execfile()d with the current directory set to its containing -# dir. -# -# The contents of this file are pickled, so don't put values in the -# namespace that aren't pickleable (module imports are okay, they're -# removed automatically). -# -# All configuration values have a default value; values that are commented -# out serve to show the default value. - -import sys, os - -# If your extensions are in another directory, add it here. If the -# directory is relative to the documentation root, use os.path.abspath to -# make it absolute, like shown here. -#sys.path.append(os.path.abspath('some/directory')) - - -# General configuration -# --------------------- - -# Add any Sphinx extension module names here, as strings. They can be -# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom ones. -extensions = [ 'sphinx.ext.autodoc', 'sphinx.ext.todo' ] - -# Add any paths that contain templates here, relative to this directory. -templates_path = ['.templates'] - -# The suffix of source filenames. -source_suffix = '.rst' - -# The master toctree document. -master_doc = 'index' - -# General substitutions. -project = 'pyrad' -copyright = u'2002-2009 Wichert Akkerman, 2009 Kristoffer Gronlund' - -# The default replacements for |version| and |release|, also used in various -# other places throughout the built documents. -# -# The short X.Y version. -version = '1.2' -# The full version, including alpha/beta/rc tags. -release = '1.2' - -# There are two options for replacing |today|: either, you set today to -# some non-false value, then it is used: -#today = '' -# Else, today_fmt is used as the format for a strftime call. -today_fmt = '%B %d, %Y' - -# List of documents that shouldn't be included in the build. -#unused_docs = [] - -# List of directories, relative to source directories, that shouldn't be -# searched for source files. -#exclude_dirs = [] - -# The reST default role (used for this markup: `text`) to use for all -# documents. -#default_role = None - -# If true, '()' will be appended to :func: etc. cross-reference text. -#add_function_parentheses = True - -# If true, the current module name will be prepended to all description -# unit titles (such as .. function::). -#add_module_names = True - -# If true, sectionauthor and moduleauthor directives will be shown in the -# output. They are ignored by default. -#show_authors = False - -# The name of the Pygments (syntax highlighting) style to use. -pygments_style = 'sphinx' - - -# Options for HTML output -# ----------------------- - -# The style sheet to use for HTML and HTML Help pages. A file of that name -# must exist either in Sphinx' static/ path, or in one of the custom paths -# given in html_static_path. -html_style = 'repoze.css' - -# The name for this set of Sphinx documents. If None, it defaults to -# " v documentation". -#html_title = None - -# A shorter title for the navigation bar. Default is the same as -# html_title. -#html_short_title = None - -# The name of an image file (within the static path) to place at the top of -# the sidebar. -html_logo = '.static/logo.png' - -# The name of an image file (within the static path) to use as favicon of -# the docs. This file should be a Windows icon file (.ico) being 16x16 or -# 32x32 pixels large. -#html_favicon = None - -# Add any paths that contain custom static files (such as style sheets) -# here, relative to this directory. They are copied after the builtin -# static files, so a file named "default.css" will overwrite the builtin -# "default.css". -html_static_path = ['.static'] - -# If not '', a 'Last updated on:' timestamp is inserted at every page -# bottom, using the given strftime format. -html_last_updated_fmt = '%b %d, %Y' - -# If true, SmartyPants will be used to convert quotes and dashes to -# typographically correct entities. -#html_use_smartypants = True - -# Custom sidebar templates, maps document names to template names. -#html_sidebars = {} - -# Additional templates that should be rendered to pages, maps page names to -# template names. -#html_additional_pages = {} - -# If false, no module index is generated. -#html_use_modindex = True - -# If false, no index is generated. -#html_use_index = True - -# If true, the index is split into individual pages for each letter. -#html_split_index = False - -# If true, the reST sources are included in the HTML build as -# _sources/. -#html_copy_source = True - -# If true, an OpenSearch description file will be output, and all pages -# will contain a tag referring to it. The value of this option must -# be the base URL from which the finished HTML is served. -#html_use_opensearch = '' - -# If nonempty, this is the file name suffix for HTML files (e.g. ".xhtml"). -#html_file_suffix = '' - -# Output file base name for HTML help builder. -htmlhelp_basename = 'euphoriecontent' - - -# Options for LaTeX output -# ------------------------ - -# The paper size ('letter' or 'a4'). -#latex_paper_size = 'a4' - -# The font size ('10pt', '11pt' or '12pt'). -#latex_font_size = '10pt' - -# Grouping the document tree into LaTeX files. List of tuples -# (source start file, target name, title, -# author, document class [howto/manual]). -latex_documents = [ - ('index', 'euphoriecontent.tex', 'euphorie.content Documentation', - 'Simplon', 'manual'), -] - -# The name of an image file (relative to this directory) to place at the -# top of the title page. -latex_logo = '.static/logo.png' - -# For "manual" documents, if this is true, then toplevel headings are -# parts, not chapters. -#latex_use_parts = False - -# Additional stuff for the LaTeX preamble. -#latex_preamble = '' - -# Documents to append as an appendix to all manuals. -#latex_appendices = [] - -# If false, no module index is generated. -#latex_use_modindex = True - - -# Options for extras -# ------------------ -todo_include_todos = True - diff --git a/docs/index.rst b/docs/index.rst deleted file mode 100644 index d0a42a6..0000000 --- a/docs/index.rst +++ /dev/null @@ -1,74 +0,0 @@ -.. _index: - -********************************* -:mod:`pyrad` -- RADIUS for Python -********************************* - -:Author: Wichert Akkerman -:Version: |version| - -.. module:: pyrad - -Introduction -============ - -pyrad is an implementation of a RADIUS client as described in RFC2865. -It takes care of all the details like building RADIUS packets, sending -them and decoding responses. - -Here is an example of doing a authentication request:: - - import pyrad.packet - from pyrad.client import Client - from pyrad.dictionary import Dictionary - - srv=Client(server="radius.my.domain", secret="s3cr3t", - dict=Dictionary("dicts/dictionary", "dictionary.acc")) - - req=srv.CreateAuthPacket(code=pyrad.packet.AccessRequest, - User_Name="wichert", NAS_Identifier="localhost") - req["User-Password"]=req.PwCrypt("password") - - reply=srv.SendPacket(req) - if reply.code==pyrad.packet.AccessAccept: - print "access accepted" - else: - print "access denied" - - print "Attributes returned by server:" - for i in reply.keys(): - print "%s: %s" % (i, reply[i]) - - -Requirements & Installation -=========================== - -pyrad requires Python 2.6 or later, or Python 3.2 or later - -Installing is simple; pyrad uses the standard distutils system for installing -Python modules:: - - python setup.py install - - -API Documentation -================= - -Per-module :mod:`pyrad` API documentation. - -.. toctree:: - :maxdepth: 2 - - api/client - api/dictionary - api/host - api/packet - api/proxy - api/server - -Indices and tables ------------------- - -* :ref:`genindex` -* :ref:`modindex` -* :ref:`search` diff --git a/docs/make.bat b/docs/make.bat new file mode 100644 index 0000000..9fa49d9 --- /dev/null +++ b/docs/make.bat @@ -0,0 +1,36 @@ +@ECHO OFF + +pushd %~dp0 + +REM Command file for Sphinx documentation + +if "%SPHINXBUILD%" == "" ( + set SPHINXBUILD=sphinx-build +) +set SOURCEDIR=source +set BUILDDIR=build +set SPHINXPROJ=pyrad + +if "%1" == "" goto help + +%SPHINXBUILD% >NUL 2>NUL +if errorlevel 9009 ( + echo. + echo.The 'sphinx-build' command was not found. Make sure you have Sphinx + echo.installed, then set the SPHINXBUILD environment variable to point + echo.to the full path of the 'sphinx-build' executable. Alternatively you + echo.may add the Sphinx directory to PATH. + echo. + echo.If you don't have Sphinx installed, grab it from + echo.http://sphinx-doc.org/ + exit /b 1 +) + +%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% +goto end + +:help +%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% + +:end +popd diff --git a/docs/.static/logo.png b/docs/source/_static/logo.png similarity index 100% rename from docs/.static/logo.png rename to docs/source/_static/logo.png diff --git a/docs/api/client.rst b/docs/source/api/client.rst similarity index 98% rename from docs/api/client.rst rename to docs/source/api/client.rst index ecf94e8..ae90c0d 100644 --- a/docs/api/client.rst +++ b/docs/source/api/client.rst @@ -1,7 +1,6 @@ :mod:`pyrad.client` -- basic client =================================== - .. automodule:: pyrad.client .. autoclass:: Timeout @@ -9,4 +8,3 @@ .. autoclass:: Client :members: - diff --git a/docs/api/dictionary.rst b/docs/source/api/dictionary.rst similarity index 99% rename from docs/api/dictionary.rst rename to docs/source/api/dictionary.rst index 21f7883..da370e8 100644 --- a/docs/api/dictionary.rst +++ b/docs/source/api/dictionary.rst @@ -1,7 +1,6 @@ :mod:`pyrad.dictionary` -- RADIUS dictionary ============================================ - .. automodule:: pyrad.dictionary .. autoclass:: ParseError @@ -9,4 +8,3 @@ .. autoclass:: Dictionary :members: - diff --git a/docs/api/host.rst b/docs/source/api/host.rst similarity index 98% rename from docs/api/host.rst rename to docs/source/api/host.rst index 13f2624..29e1760 100644 --- a/docs/api/host.rst +++ b/docs/source/api/host.rst @@ -1,9 +1,7 @@ :mod:`pyrad.host` -- RADIUS host definition =========================================== - .. automodule:: pyrad.host .. autoclass:: Host :members: - diff --git a/docs/api/packet.rst b/docs/source/api/packet.rst similarity index 89% rename from docs/api/packet.rst rename to docs/source/api/packet.rst index b149d23..05c4a21 100644 --- a/docs/api/packet.rst +++ b/docs/source/api/packet.rst @@ -1,7 +1,6 @@ :mod:`pyrad.packet` -- packet encoding and decoding =================================================== - .. automodule:: pyrad.packet .. autoclass:: Packet @@ -13,6 +12,9 @@ .. autoclass:: AcctPacket :members: + .. autoclass:: CoAPacket + :members: + .. autoclass:: PacketError :members: @@ -20,7 +22,7 @@ Constants --------- -The :mod:`pyrad.packet` module defines several common constants +The :mod:`pyrad.packet` module defines several common constants that are useful when dealing with RADIUS packets. The following packet codes are defined: @@ -44,5 +46,3 @@ CoARequest 43 CoAACK 44 CoANAK 45 ================== ====== - - diff --git a/docs/api/proxy.rst b/docs/source/api/proxy.rst similarity index 99% rename from docs/api/proxy.rst rename to docs/source/api/proxy.rst index 43c7064..34017a4 100644 --- a/docs/api/proxy.rst +++ b/docs/source/api/proxy.rst @@ -1,7 +1,6 @@ :mod:`pyrad.proxy` -- basic proxy ================================= - .. automodule:: pyrad.proxy .. autoclass:: Proxy diff --git a/docs/api/server.rst b/docs/source/api/server.rst similarity index 99% rename from docs/api/server.rst rename to docs/source/api/server.rst index 70c981f..2cfbd38 100644 --- a/docs/api/server.rst +++ b/docs/source/api/server.rst @@ -1,7 +1,6 @@ :mod:`pyrad.server` -- basic server =================================== - .. automodule:: pyrad.server .. autoclass:: RemoteHost @@ -12,4 +11,3 @@ .. autoclass:: Server :members: - diff --git a/docs/source/conf.py b/docs/source/conf.py new file mode 100644 index 0000000..3cf1b40 --- /dev/null +++ b/docs/source/conf.py @@ -0,0 +1,162 @@ +# -*- coding: utf-8 -*- +# +# pyrad documentation build configuration file, created by +# sphinx-quickstart on Thu Feb 2 15:16:16 2017. +# +# This file is execfile()d with the current directory set to its +# containing dir. +# +# Note that not all possible configuration values are present in this +# autogenerated file. +# +# All configuration values have a default; values that are commented out +# serve to show the default. + +# If extensions (or modules to document with autodoc) are in another directory, +# add these directories to sys.path here. If the directory is relative to the +# documentation root, use os.path.abspath to make it absolute, like shown here. + +import os +import sys +sys.path.insert(0, os.path.abspath('../../')) + + +# -- General configuration ------------------------------------------------ + +# If your documentation needs a minimal Sphinx version, state it here. +# +# needs_sphinx = '1.0' + +# Add any Sphinx extension module names here, as strings. They can be +# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom +# ones. +extensions = ['sphinx.ext.autodoc', + 'sphinx.ext.intersphinx', + 'sphinx.ext.todo', + 'sphinx.ext.viewcode'] + +# Add any paths that contain templates here, relative to this directory. +templates_path = ['_templates'] + +# The suffix(es) of source filenames. +# You can specify multiple suffix as a list of string: +# +# source_suffix = ['.rst', '.md'] +source_suffix = '.rst' + +# The master toctree document. +master_doc = 'index' + +# General information about the project. +project = u'pyrad' +copyright = u'2017, Wichert Akkerman' +author = u'Wichert Akkerman' + +# The version info for the project you're documenting, acts as replacement for +# |version| and |release|, also used in various other places throughout the +# built documents. +# +# The short X.Y version. +version = u'2.1' +# The full version, including alpha/beta/rc tags. +release = u'2.1' + +# The language for content autogenerated by Sphinx. Refer to documentation +# for a list of supported languages. +# +# This is also used if you do content translation via gettext catalogs. +# Usually you set "language" from the command line for these cases. +language = None + +# List of patterns, relative to source directory, that match files and +# directories to ignore when looking for source files. +# This patterns also effect to html_static_path and html_extra_path +exclude_patterns = [] + +# The name of the Pygments (syntax highlighting) style to use. +pygments_style = 'sphinx' + +# If true, `todo` and `todoList` produce output, else they produce nothing. +todo_include_todos = True + + +# -- Options for HTML output ---------------------------------------------- + +# The theme to use for HTML and HTML Help pages. See the documentation for +# a list of builtin themes. +# +# html_theme = 'alabaster' + +# Theme options are theme-specific and customize the look and feel of a theme +# further. For a list of options available for each theme, see the +# documentation. +# +# html_theme_options = {} + +# Add any paths that contain custom static files (such as style sheets) here, +# relative to this directory. They are copied after the builtin static files, +# so a file named "default.css" will overwrite the builtin "default.css". +html_static_path = ['_static'] + +html_logo = '_static/logo.png' + +# -- Options for HTMLHelp output ------------------------------------------ + +# Output file base name for HTML help builder. +htmlhelp_basename = 'pyraddoc' + +# -- Options for LaTeX output --------------------------------------------- + +latex_elements = { + # The paper size ('letterpaper' or 'a4paper'). + # + # 'papersize': 'letterpaper', + + # The font size ('10pt', '11pt' or '12pt'). + # + # 'pointsize': '10pt', + + # Additional stuff for the LaTeX preamble. + # + # 'preamble': '', + + # Latex figure (float) alignment + # + # 'figure_align': 'htbp', +} + +# Grouping the document tree into LaTeX files. List of tuples +# (source start file, target name, title, +# author, documentclass [howto, manual, or own class]). +latex_documents = [ + (master_doc, 'pyrad.tex', u'pyrad Documentation', + u'Wichert Akkerman', 'manual'), +] + + +# -- Options for manual page output --------------------------------------- + +# One entry per manual page. List of tuples +# (source start file, name, description, authors, manual section). +man_pages = [ + (master_doc, 'pyrad', u'pyrad Documentation', + [author], 1) +] + + +# -- Options for Texinfo output ------------------------------------------- + +# Grouping the document tree into Texinfo files. List of tuples +# (source start file, target name, title, author, +# dir menu entry, description, category) +texinfo_documents = [ + (master_doc, 'pyrad', u'pyrad Documentation', + author, 'pyrad', 'One line description of project.', + 'Miscellaneous'), +] + + + + +# Example configuration for intersphinx: refer to the Python standard library. +intersphinx_mapping = {'https://docs.python.org/': None} diff --git a/docs/source/index.rst b/docs/source/index.rst new file mode 100644 index 0000000..27998ec --- /dev/null +++ b/docs/source/index.rst @@ -0,0 +1,76 @@ + +********************************* +:mod:`pyrad` -- RADIUS for Python +********************************* + +:Author: Wichert Akkerman +:Version: |version| + +Introduction +============ + +pyrad is an implementation of a RADIUS client/server as described in RFC2865. +It takes care of all the details like building RADIUS packets, sending +them and decoding responses. + +Here is an example of doing a authentication request:: + + from __future__ import print_function + from pyrad.client import Client + from pyrad.dictionary import Dictionary + import pyrad.packet + + srv = Client(server="localhost", secret=b"Kah3choteereethiejeimaeziecumi", + dict=Dictionary("dictionary")) + + # create request + req = srv.CreateAuthPacket(code=pyrad.packet.AccessRequest, + User_Name="wichert", NAS_Identifier="localhost") + req["User-Password"] = req.PwCrypt("password") + + # send request + reply = srv.SendPacket(req) + + if reply.code == pyrad.packet.AccessAccept: + print("access accepted") + else: + print("access denied") + + print("Attributes returned by server:") + for i in reply.keys(): + print("%s: %s" % (i, reply[i])) + + +Requirements & Installation +=========================== + +pyrad requires Python 2.6 or later, or Python 3.2 or later + +Installing is simple; pyrad uses the standard distutils system for installing +Python modules:: + + python setup.py install + + +API Documentation +================= + +Per-module :mod:`pyrad` API documentation. + +.. toctree:: + :maxdepth: 2 + + api/client + api/dictionary + api/host + api/packet + api/proxy + api/server + + +Indices and tables +================== + +* :ref:`genindex` +* :ref:`modindex` +* :ref:`search` diff --git a/example/auth_async.py b/example/auth_async.py new file mode 100644 index 0000000..9ce4a41 --- /dev/null +++ b/example/auth_async.py @@ -0,0 +1,164 @@ +#!/usr/bin/python + +import asyncio + +import logging +import traceback +from pyrad.dictionary import Dictionary +from pyrad.client_async import ClientAsync +from pyrad.packet import AccessAccept + +logging.basicConfig(level="DEBUG", + format="%(asctime)s [%(levelname)-8s] %(message)s") +client = ClientAsync(server="localhost", + secret=b"Kah3choteereethiejeimaeziecumi", + timeout=4, + dict=Dictionary("dictionary")) + +loop = asyncio.get_event_loop() + + +def create_request(client, user): + req = client.CreateAuthPacket(User_Name=user) + + req["NAS-IP-Address"] = "192.168.1.10" + req["NAS-Port"] = 0 + req["Service-Type"] = "Login-User" + req["NAS-Identifier"] = "trillian" + req["Called-Station-Id"] = "00-04-5F-00-0F-D1" + req["Calling-Station-Id"] = "00-01-24-80-B3-9C" + req["Framed-IP-Address"] = "10.0.0.100" + + return req + +def print_reply(reply): + if reply.code == AccessAccept: + print("Access accepted") + else: + print("Access denied") + + print("Attributes returned by server:") + for i in reply.keys(): + print("%s: %s" % (i, reply[i])) + +def test_auth1(): + + global client + + try: + # Initialize transports + loop.run_until_complete( + asyncio.ensure_future( + client.initialize_transports(enable_auth=True, + local_addr='127.0.0.1', + local_auth_port=8000, + enable_acct=True, + enable_coa=True))) + + + + req = client.CreateAuthPacket(User_Name="wichert") + + req["NAS-IP-Address"] = "192.168.1.10" + req["NAS-Port"] = 0 + req["Service-Type"] = "Login-User" + req["NAS-Identifier"] = "trillian" + req["Called-Station-Id"] = "00-04-5F-00-0F-D1" + req["Calling-Station-Id"] = "00-01-24-80-B3-9C" + req["Framed-IP-Address"] = "10.0.0.100" + + future = client.SendPacket(req) + + # loop.run_until_complete(future) + loop.run_until_complete(asyncio.ensure_future( + asyncio.gather( + future, + return_exceptions=True + ) + + )) + + if future.exception(): + print('EXCEPTION ', future.exception()) + else: + reply = future.result() + + if reply.code == AccessAccept: + print("Access accepted") + else: + print("Access denied") + + print("Attributes returned by server:") + for i in reply.keys(): + print("%s: %s" % (i, reply[i])) + + # Close transports + loop.run_until_complete(asyncio.ensure_future( + client.deinitialize_transports())) + print('END') + + del client + except Exception as exc: + print('Error: ', exc) + print('\n'.join(traceback.format_exc().splitlines())) + # Close transports + loop.run_until_complete(asyncio.ensure_future( + client.deinitialize_transports())) + + loop.close() + +def test_multi_auth(): + + global client + + try: + # Initialize transports + loop.run_until_complete( + asyncio.ensure_future( + client.initialize_transports(enable_auth=True, + local_addr='127.0.0.1', + local_auth_port=8000, + enable_acct=True, + enable_coa=True))) + + + + reqs = [] + for i in range(255): + req = create_request(client, "user%s" % i) + future = client.SendPacket(req) + reqs.append(future) + + # loop.run_until_complete(future) + loop.run_until_complete(asyncio.ensure_future( + asyncio.gather( + *reqs, + return_exceptions=True + ) + + )) + + for future in reqs: + if future.exception(): + print('EXCEPTION ', future.exception()) + else: + reply = future.result() + print_reply(reply) + + # Close transports + loop.run_until_complete(asyncio.ensure_future( + client.deinitialize_transports())) + print('END') + + del client + except Exception as exc: + print('Error: ', exc) + print('\n'.join(traceback.format_exc().splitlines())) + # Close transports + loop.run_until_complete(asyncio.ensure_future( + client.deinitialize_transports())) + + loop.close() + +#test_multi_auth() +test_auth1() diff --git a/example/client-coa.py b/example/client-coa.py new file mode 100755 index 0000000..30e30b9 --- /dev/null +++ b/example/client-coa.py @@ -0,0 +1,62 @@ +#!/usr/bin/python +# +# Copyright 6WIND, 2017 +# + +from __future__ import print_function +from pyrad import dictionary, packet, server +import sys +import prctl + +class FakeCoA(server.Server): + + def HandleCoaPacket(self, pkt): + """Accounting packet handler. + Function that is called when a valid + accounting packet has been received. + + :param pkt: packet to process + :type pkt: Packet class instance + """ + print("Received a coa request %d" % pkt.code) + print(" Attributes: ") + for attr in pkt.keys(): + print(" %s: %s" % (attr, pkt[attr])) + + reply = self.CreateReplyPacket(pkt) + # try ACK or NACK + # reply.code = packet.CoANAK + reply.code = packet.CoAACK + self.SendReplyPacket(pkt.fd, reply) + + def HandleDisconnectPacket(self, pkt): + print("Received a disconnect request %d" % pkt.code) + print(" Attributes: ") + for attr in pkt.keys(): + print(" %s: %s" % (attr, pkt[attr])) + + reply = self.CreateReplyPacket(pkt) + # try ACK or NACK + # reply.code = packet.DisconnectNAK + reply.code = packet.DisconnectACK + self.SendReplyPacket(pkt.fd, reply) + +if __name__ == '__main__': + + prctl.set_name('radius-FakeCoA-client') + + if len(sys.argv) != 2: + print ("usage: client-coa.py 3799") + sys.exit(1) + + bindport=int(sys.argv[1]) + + # create server/coa only and read dictionary + # bind and listen only on 127.0.0.1:argv[1] + coa = FakeCoA(addresses=["127.0.0.1"], dict=dictionary.Dictionary("dictionary"), coaport=bindport, auth_enabled=False, acct_enabled=False, coa_enabled=True) + + # add peers (address, secret, name) + coa.hosts["127.0.0.1"] = server.RemoteHost("127.0.0.1", b"Kah3choteereethiejeimaeziecumi", "localhost") + + # start + coa.Run() diff --git a/example/coa.py b/example/coa.py old mode 100644 new mode 100755 index ec0479c..0d462bf --- a/example/coa.py +++ b/example/coa.py @@ -3,6 +3,11 @@ from pyrad.client import Client from pyrad import dictionary from pyrad import packet +import sys + +if len(sys.argv) != 3: + print ("usage: coa.py {coa|dis} daemon-1234") + sys.exit(1) ADDRESS = "127.0.0.1" SECRET = b"Kah3choteereethiejeimaeziecumi" @@ -10,6 +15,8 @@ "Acct-Session-Id": "1337" } +ATTRIBUTES["NAS-Identifier"] = sys.argv[2] + # create coa client client = Client(server=ADDRESS, secret=SECRET, dict=dictionary.Dictionary("dictionary")) @@ -19,10 +26,14 @@ # create coa request packet attributes = {k.replace("-", "_"): ATTRIBUTES[k] for k in ATTRIBUTES} -# create coa request -request = client.CreateCoAPacket(**attributes) -# create disconnect request -# request = client.CreateCoAPacket(code=packet.DisconnectRequest, **attributes) +if sys.argv[1] == "coa": + # create coa request + request = client.CreateCoAPacket(**attributes) +elif sys.argv[1] == "dis": + # create disconnect request + request = client.CreateCoAPacket(code=packet.DisconnectRequest, **attributes) +else: + sys.exit(1) # send request result = client.SendPacket(request) diff --git a/example/pyrad.log b/example/pyrad.log new file mode 100644 index 0000000..e69de29 diff --git a/example/server_async.py b/example/server_async.py new file mode 100644 index 0000000..3b893da --- /dev/null +++ b/example/server_async.py @@ -0,0 +1,117 @@ +#!/usr/bin/python + +import asyncio + +import logging +import traceback +from pyrad.dictionary import Dictionary +from pyrad.server_async import ServerAsync +from pyrad.packet import AccessAccept +from pyrad.server import RemoteHost + +try: + import uvloop + asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) +except: + pass + +logging.basicConfig(level="DEBUG", + format="%(asctime)s [%(levelname)-8s] %(message)s") + +class FakeServer(ServerAsync): + + def __init__(self, loop, dictionary): + + ServerAsync.__init__(self, loop=loop, dictionary=dictionary, + enable_pkt_verify=True, debug=True) + + + def handle_auth_packet(self, protocol, pkt, addr): + + print("Received an authentication request with id ", pkt.id) + print('Authenticator ', pkt.authenticator.hex()) + print('Secret ', pkt.secret) + print("Attributes: ") + for attr in pkt.keys(): + print("%s: %s" % (attr, pkt[attr])) + + reply = self.CreateReplyPacket(pkt, **{ + "Service-Type": "Framed-User", + "Framed-IP-Address": '192.168.0.1', + "Framed-IPv6-Prefix": "fc66::1/64" + }) + + reply.code = AccessAccept + protocol.send_response(reply, addr) + + def handle_acct_packet(self, protocol, pkt, addr): + + print("Received an accounting request") + print("Attributes: ") + for attr in pkt.keys(): + print("%s: %s" % (attr, pkt[attr])) + + reply = self.CreateReplyPacket(pkt) + protocol.send_response(reply, addr) + + def handle_coa_packet(self, protocol, pkt, addr): + + print("Received an coa request") + print("Attributes: ") + for attr in pkt.keys(): + print("%s: %s" % (attr, pkt[attr])) + + reply = self.CreateReplyPacket(pkt) + protocol.send_response(reply, addr) + + def handle_disconnect_packet(self, protocol, pkt, addr): + + print("Received an disconnect request") + print("Attributes: ") + for attr in pkt.keys(): + print("%s: %s" % (attr, pkt[attr])) + + reply = self.CreateReplyPacket(pkt) + # COA NAK + reply.code = 45 + protocol.send_response(reply, addr) + + +if __name__ == '__main__': + + # create server and read dictionary + loop = asyncio.get_event_loop() + server = FakeServer(loop=loop, dictionary=Dictionary('dictionary')) + + # add clients (address, secret, name) + server.hosts["127.0.0.1"] = RemoteHost("127.0.0.1", + b"Kah3choteereethiejeimaeziecumi", + "localhost") + + try: + + # Initialize transports + loop.run_until_complete( + asyncio.ensure_future( + server.initialize_transports(enable_auth=True, + enable_acct=True, + enable_coa=True))) + + try: + # start server + loop.run_forever() + except KeyboardInterrupt as k: + pass + + # Close transports + loop.run_until_complete(asyncio.ensure_future( + server.deinitialize_transports())) + + except Exception as exc: + print('Error: ', exc) + print('\n'.join(traceback.format_exc().splitlines())) + # Close transports + loop.run_until_complete(asyncio.ensure_future( + server.deinitialize_transports())) + + loop.close() diff --git a/pyrad/client.py b/pyrad/client.py index 9369010..b335e24 100644 --- a/pyrad/client.py +++ b/pyrad/client.py @@ -53,6 +53,7 @@ def __init__(self, server, authport=1812, acctport=1813, self._socket = None self.retries = 3 self.timeout = 5 + self._poll = select.poll() def bind(self, addr): """Bind socket to an address. @@ -67,14 +68,20 @@ def bind(self, addr): self._socket.bind(addr) def _SocketOpen(self): + try: + family = socket.getaddrinfo(self.server, 'www')[0][0] + except: + family = socket.AF_INET if not self._socket: - self._socket = socket.socket(socket.AF_INET, + self._socket = socket.socket(family, socket.SOCK_DGRAM) self._socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + self._poll.register(self._socket, select.POLLIN) def _CloseSocket(self): if self._socket: + self._poll.unregister(self._socket) self._socket.close() self._socket = None @@ -86,7 +93,7 @@ def CreateAuthPacket(self, **args): dictionary and secret used for the client. :return: a new empty packet instance - :rtype: pyrad.packet.Packet + :rtype: pyrad.packet.AuthPacket """ return host.Host.CreateAuthPacket(self, secret=self.secret, **args) @@ -101,7 +108,7 @@ def CreateAcctPacket(self, **args): :rtype: pyrad.packet.Packet """ return host.Host.CreateAcctPacket(self, secret=self.secret, **args) - + def CreateCoAPacket(self, **args): """Create a new RADIUS packet. This utility function creates a new RADIUS packet which can @@ -134,16 +141,16 @@ def _SendPacket(self, pkt, port): pkt["Acct-Delay-Time"][0] + self.timeout else: pkt["Acct-Delay-Time"] = self.timeout - self._socket.sendto(pkt.RequestPacket(), (self.server, port)) now = time.time() waitto = now + self.timeout + self._socket.sendto(pkt.RequestPacket(), (self.server, port)) + while now < waitto: - ready = select.select([self._socket], [], [], - (waitto - now)) + ready = self._poll.poll((waitto - now) * 1000) - if ready[0]: + if ready: rawreply = self._socket.recv(4096) else: now = time.time() diff --git a/pyrad/client_async.py b/pyrad/client_async.py new file mode 100644 index 0000000..de08917 --- /dev/null +++ b/pyrad/client_async.py @@ -0,0 +1,407 @@ +# client_async.py +# +# Copyright 2018-2020 Geaaru gmail.com> + +__docformat__ = "epytext en" + +from datetime import datetime +import asyncio +import six +import logging +import random + +from pyrad.packet import Packet, AuthPacket, AcctPacket, CoAPacket + + +class DatagramProtocolClient(asyncio.Protocol): + + def __init__(self, server, port, logger, + client, retries=3, timeout=30): + self.transport = None + self.port = port + self.server = server + self.logger = logger + self.retries = retries + self.timeout = timeout + self.client = client + + # Map of pending requests + self.pending_requests = {} + + # Use cryptographic-safe random generator as provided by the OS. + random_generator = random.SystemRandom() + self.packet_id = random_generator.randrange(0, 256) + + self.timeout_future = None + + async def __timeout_handler__(self): + + try: + + while True: + + req2delete = [] + now = datetime.now() + next_weak_up = self.timeout + # noinspection PyShadowingBuiltins + for id, req in self.pending_requests.items(): + + secs = (req['send_date'] - now).seconds + if secs > self.timeout: + if req['retries'] == self.retries: + self.logger.debug('[%s:%d] For request %d execute all retries', self.server, self.port, id) + req['future'].set_exception( + TimeoutError('Timeout on Reply') + ) + req2delete.append(id) + else: + # Send again packet + req['send_date'] = now + req['retries'] += 1 + self.logger.debug('[%s:%d] For request %d execute retry %d', self.server, self.port, id, req['retries']) + self.transport.sendto(req['packet'].RequestPacket()) + elif next_weak_up > secs: + next_weak_up = secs + + # noinspection PyShadowingBuiltins + for id in req2delete: + # Remove request for map + del self.pending_requests[id] + + await asyncio.sleep(next_weak_up) + + except asyncio.CancelledError: + pass + + def send_packet(self, packet, future): + if packet.id in self.pending_requests: + raise Exception('Packet with id %d already present' % packet.id) + + # Store packet on pending requests map + self.pending_requests[packet.id] = { + 'packet': packet, + 'creation_date': datetime.now(), + 'retries': 0, + 'future': future, + 'send_date': datetime.now() + } + + # In queue packet raw on socket buffer + self.transport.sendto(packet.RequestPacket()) + + def connection_made(self, transport): + self.transport = transport + socket = transport.get_extra_info('socket') + self.logger.info( + '[%s:%d] Transport created with binding in %s:%d', + self.server, self.port, + socket.getsockname()[0], + socket.getsockname()[1] + ) + + pre_loop = asyncio.get_event_loop() + asyncio.set_event_loop(loop=self.client.loop) + # Start asynchronous timer handler + self.timeout_future = asyncio.ensure_future( + self.__timeout_handler__() + ) + asyncio.set_event_loop(loop=pre_loop) + + def error_received(self, exc): + self.logger.error('[%s:%d] Error received: %s', self.server, self.port, exc) + + def connection_lost(self, exc): + if exc: + self.logger.warn('[%s:%d] Connection lost: %s', self.server, self.port, str(exc)) + else: + self.logger.info('[%s:%d] Transport closed', self.server, self.port) + + # noinspection PyUnusedLocal + def datagram_received(self, data, addr): + try: + reply = Packet(packet=data, dict=self.client.dict) + + if reply and reply.id in self.pending_requests: + req = self.pending_requests[reply.id] + packet = req['packet'] + + reply.dict = packet.dict + reply.secret = packet.secret + + if packet.VerifyReply(reply, data): + req['future'].set_result(reply) + # Remove request for map + del self.pending_requests[reply.id] + else: + self.logger.warn('[%s:%d] Ignore invalid reply for id %d. %s', self.server, self.port, reply.id) + else: + self.logger.warn('[%s:%d] Ignore invalid reply: %d', self.server, self.port, data) + + except Exception as exc: + self.logger.error('[%s:%d] Error on decode packet: %s', self.server, self.port, exc) + + async def close_transport(self): + if self.transport: + self.logger.debug('[%s:%d] Closing transport...', self.server, self.port) + self.transport.close() + self.transport = None + if self.timeout_future: + self.timeout_future.cancel() + await self.timeout_future + self.timeout_future = None + + def create_id(self): + self.packet_id = (self.packet_id + 1) % 256 + return self.packet_id + + def __str__(self): + return 'DatagramProtocolClient(server?=%s, port=%d)' % (self.server, self.port) + + # Used as protocol_factory + def __call__(self): + return self + + +class ClientAsync: + """Basic RADIUS client. + This class implements a basic RADIUS client. It can send requests + to a RADIUS server, taking care of timeouts and retries, and + validate its replies. + + :ivar retries: number of times to retry sending a RADIUS request + :type retries: integer + :ivar timeout: number of seconds to wait for an answer + :type timeout: integer + """ + # noinspection PyShadowingBuiltins + def __init__(self, server, auth_port=1812, acct_port=1813, + coa_port=3799, secret=six.b(''), dict=None, + loop=None, retries=3, timeout=30, + logger_name='pyrad'): + + """Constructor. + + :param server: hostname or IP address of RADIUS server + :type server: string + :param auth_port: port to use for authentication packets + :type auth_port: integer + :param acct_port: port to use for accounting packets + :type acct_port: integer + :param coa_port: port to use for CoA packets + :type coa_port: integer + :param secret: RADIUS secret + :type secret: string + :param dict: RADIUS dictionary + :type dict: pyrad.dictionary.Dictionary + :param loop: Python loop handler + :type loop: asyncio event loop + """ + if not loop: + self.loop = asyncio.get_event_loop() + else: + self.loop = loop + self.logger = logging.getLogger(logger_name) + + self.server = server + self.secret = secret + self.retries = retries + self.timeout = timeout + self.dict = dict + + self.auth_port = auth_port + self.protocol_auth = None + + self.acct_port = acct_port + self.protocol_acct = None + + self.protocol_coa = None + self.coa_port = coa_port + + async def initialize_transports(self, enable_acct=False, + enable_auth=False, enable_coa=False, + local_addr=None, local_auth_port=None, + local_acct_port=None, local_coa_port=None): + + task_list = [] + + if not enable_acct and not enable_auth and not enable_coa: + raise Exception('No transports selected') + + if enable_acct and not self.protocol_acct: + self.protocol_acct = DatagramProtocolClient( + self.server, + self.acct_port, + self.logger, self, + retries=self.retries, + timeout=self.timeout + ) + bind_addr = None + if local_addr and local_acct_port: + bind_addr = (local_addr, local_acct_port) + + acct_connect = self.loop.create_datagram_endpoint( + self.protocol_acct, + reuse_address=True, reuse_port=True, + remote_addr=(self.server, self.acct_port), + local_addr=bind_addr + ) + task_list.append(acct_connect) + + if enable_auth and not self.protocol_auth: + self.protocol_auth = DatagramProtocolClient( + self.server, + self.auth_port, + self.logger, self, + retries=self.retries, + timeout=self.timeout + ) + bind_addr = None + if local_addr and local_auth_port: + bind_addr = (local_addr, local_auth_port) + + auth_connect = self.loop.create_datagram_endpoint( + self.protocol_auth, + reuse_address=True, reuse_port=True, + remote_addr=(self.server, self.auth_port), + local_addr=bind_addr + ) + task_list.append(auth_connect) + + if enable_coa and not self.protocol_coa: + self.protocol_coa = DatagramProtocolClient( + self.server, + self.coa_port, + self.logger, self, + retries=self.retries, + timeout=self.timeout + ) + bind_addr = None + if local_addr and local_coa_port: + bind_addr = (local_addr, local_coa_port) + + coa_connect = self.loop.create_datagram_endpoint( + self.protocol_coa, + reuse_address=True, reuse_port=True, + remote_addr=(self.server, self.coa_port), + local_addr=bind_addr + ) + task_list.append(coa_connect) + + await asyncio.ensure_future( + asyncio.gather( + *task_list, + return_exceptions=False, + ), + loop=self.loop + ) + + # noinspection SpellCheckingInspection + async def deinitialize_transports(self, deinit_coa=True, + deinit_auth=True, + deinit_acct=True): + if self.protocol_coa and deinit_coa: + await self.protocol_coa.close_transport() + del self.protocol_coa + self.protocol_coa = None + if self.protocol_auth and deinit_auth: + await self.protocol_auth.close_transport() + del self.protocol_auth + self.protocol_auth = None + if self.protocol_acct and deinit_acct: + await self.protocol_acct.close_transport() + del self.protocol_acct + self.protocol_acct = None + + # noinspection PyPep8Naming + def CreateAuthPacket(self, **args): + """Create a new RADIUS packet. + This utility function creates a new RADIUS packet which can + be used to communicate with the RADIUS server this client + talks to. This is initializing the new packet with the + dictionary and secret used for the client. + + :return: a new empty packet instance + :rtype: pyrad.packet.Packet + """ + if not self.protocol_auth: + raise Exception('Transport not initialized') + + return AuthPacket(dict=self.dict, + id=self.protocol_auth.create_id(), + secret=self.secret, **args) + + # noinspection PyPep8Naming + def CreateAcctPacket(self, **args): + """Create a new RADIUS packet. + This utility function creates a new RADIUS packet which can + be used to communicate with the RADIUS server this client + talks to. This is initializing the new packet with the + dictionary and secret used for the client. + + :return: a new empty packet instance + :rtype: pyrad.packet.Packet + """ + if not self.protocol_acct: + raise Exception('Transport not initialized') + + return AcctPacket(id=self.protocol_acct.create_id(), + dict=self.dict, + secret=self.secret, **args) + + # noinspection PyPep8Naming + def CreateCoAPacket(self, **args): + """Create a new RADIUS packet. + This utility function creates a new RADIUS packet which can + be used to communicate with the RADIUS server this client + talks to. This is initializing the new packet with the + dictionary and secret used for the client. + + :return: a new empty packet instance + :rtype: pyrad.packet.Packet + """ + + if not self.protocol_acct: + raise Exception('Transport not initialized') + + return CoAPacket(id=self.protocol_coa.create_id(), + dict=self.dict, + secret=self.secret, **args) + + # noinspection PyPep8Naming + # noinspection PyShadowingBuiltins + def CreatePacket(self, id, **args): + if not id: + raise Exception('Missing mandatory packet id') + + return Packet(id=id, dict=self.dict, + secret=self.secret, **args) + + # noinspection PyPep8Naming + def SendPacket(self, pkt): + """Send a packet to a RADIUS server. + + :param pkt: the packet to send + :type pkt: pyrad.packet.Packet + :return: Future related with packet to send + :rtype: asyncio.Future + """ + + ans = asyncio.Future(loop=self.loop) + + if isinstance(pkt, AuthPacket): + if not self.protocol_auth: + raise Exception('Transport not initialized') + + self.protocol_auth.send_packet(pkt, ans) + + elif isinstance(pkt, AcctPacket): + if not self.protocol_acct: + raise Exception('Transport not initialized') + + elif isinstance(pkt, CoAPacket): + if not self.protocol_coa: + raise Exception('Transport not initialized') + else: + raise Exception('Unsupported packet') + + return ans diff --git a/pyrad/dictionary.py b/pyrad/dictionary.py index 6c415e2..33639c8 100644 --- a/pyrad/dictionary.py +++ b/pyrad/dictionary.py @@ -30,21 +30,35 @@ The datatypes currently supported are: -======= ====================== -type description -======= ====================== -string ASCII string -ipaddr IPv4 address -date 32 bits UNIX timestamp -octets arbitrary binary data -abinary ascend binary data -ipv6addr 16 octets in network byte order -ipv6prefix 18 octets in network byte order -integer 32 bits unsigned number -signed 32 bits signed number -short 16 bits unsigned number -byte 8 bits unsigned number -======= ====================== ++---------------+----------------------------------------------+ +| type | description | ++===============+==============================================+ +| string | ASCII string | ++---------------+----------------------------------------------+ +| ipaddr | IPv4 address | ++---------------+----------------------------------------------+ +| date | 32 bits UNIX | ++---------------+----------------------------------------------+ +| octets | arbitrary binary data | ++---------------+----------------------------------------------+ +| abinary | ascend binary data | ++---------------+----------------------------------------------+ +| ipv6addr | 16 octets in network byte order | ++---------------+----------------------------------------------+ +| ipv6prefix | 18 octets in network byte order | ++---------------+----------------------------------------------+ +| integer | 32 bits unsigned number | ++---------------+----------------------------------------------+ +| signed | 32 bits signed number | ++---------------+----------------------------------------------+ +| short | 16 bits unsigned number | ++---------------+----------------------------------------------+ +| byte | 8 bits unsigned number | ++---------------+----------------------------------------------+ +| tlv | Nested tag-length-value | ++---------------+----------------------------------------------+ +| integer64 | 64 bits unsigned number | ++---------------+----------------------------------------------+ These datatypes are parsed but not supported: @@ -68,7 +82,7 @@ DATATYPES = frozenset(['string', 'ipaddr', 'integer', 'date', 'octets', 'abinary', 'ipv6addr', 'ipv6prefix', 'short', 'byte', - 'signed', 'ifid', 'ether']) + 'signed', 'ifid', 'ether', 'tlv', 'integer64']) class ParseError(Exception): @@ -101,7 +115,7 @@ def __str__(self): class Attribute(object): - def __init__(self, name, code, datatype, vendor='', values={}, + def __init__(self, name, code, datatype, is_sub_attribute=False, vendor='', values=None, encrypt=0, has_tag=False): if datatype not in DATATYPES: raise ValueError('Invalid data type') @@ -112,8 +126,12 @@ def __init__(self, name, code, datatype, vendor='', values={}, self.encrypt = encrypt self.has_tag = has_tag self.values = bidict.BiDict() - for (key, value) in values.items(): - self.values.Add(key, value) + self.sub_attributes = {} + self.parent = None + self.is_sub_attribute = is_sub_attribute + if values: + for (key, value) in values.items(): + self.values.Add(key, value) class Dictionary(object): @@ -201,11 +219,16 @@ def keyval(o): (attribute, code, datatype) = tokens[1:4] - try: - # todo: check if float like for extended attributes - code = int(code, 0) - except: - return None + codes = code.split('.') + is_sub_attribute = (len(codes) > 1) + if len(codes) == 2: + code = int(codes[1]) + parent_code = int(codes[0]) + elif len(codes) == 1: + code = int(codes[0]) + parent_code = None + else: + raise ParseError('nested tlvs are not supported') datatype = datatype.split("[")[0] @@ -214,12 +237,25 @@ def keyval(o): file=state['file'], line=state['line']) if vendor: - key = (self.vendors.GetForward(vendor), code) + if is_sub_attribute: + key = (self.vendors.GetForward(vendor), parent_code, code) + else: + key = (self.vendors.GetForward(vendor), code) else: - key = code + if is_sub_attribute: + key = (parent_code, code) + else: + key = code self.attrindex.Add(attribute, key) - self.attributes[attribute] = Attribute(attribute, code, datatype, vendor, encrypt=encrypt, has_tag=has_tag) + self.attributes[attribute] = Attribute(attribute, code, datatype, is_sub_attribute, vendor, encrypt=encrypt, has_tag=has_tag) + if datatype == 'tlv': + # save attribute in tlvs + state['tlvs'][code] = self.attributes[attribute] + if is_sub_attribute: + # save sub attribute in parent tlv and update their parent field + state['tlvs'][parent_code].sub_attributes[code] = attribute + self.attributes[attribute].parent = state['tlvs'][parent_code] def __ParseValue(self, state, tokens, defer): if len(tokens) != 4: @@ -239,7 +275,7 @@ def __ParseValue(self, state, tokens, defer): file=state['file'], line=state['line']) - if adef.type in ['integer','signed','short','byte']: + if adef.type in ['integer', 'signed', 'short', 'byte', 'integer64']: value = int(value, 0) value = tools.EncodeAttr(adef.type, value) self.attributes[attr].values.Add(key, value) @@ -322,7 +358,7 @@ class instance. state = {} state['vendor'] = '' - + state['tlvs'] = {} self.defer_parse = [] for line in fil: state['file'] = fil.File() diff --git a/pyrad/packet.py b/pyrad/packet.py index dfb9b58..27c14f4 100644 --- a/pyrad/packet.py +++ b/pyrad/packet.py @@ -15,7 +15,9 @@ import md5 md5_constructor = md5.new import six + from pyrad import tools +from pyrad import dictionary # Packet codes AccessRequest = 1 @@ -85,7 +87,7 @@ def __init__(self, code=0, id=None, secret=six.b(''), authenticator=None, **attr self.secret = secret if authenticator is not None and \ not isinstance(authenticator, six.binary_type): - raise TypeError('authenticator must be a binary string') + raise TypeError('authenticator must be a binary string') self.authenticator = authenticator if 'dict' in attributes: @@ -133,13 +135,8 @@ def _EncodeKeyValues(self, key, values): return (key, values) key, _, tag = key.partition(":") - attr = self.dict.attributes[key] - if attr.vendor: - key = (self.dict.vendors.GetForward(attr.vendor), attr.code) - else: - key = attr.code - + key = self._EncodeKey(key) if tag: tag = struct.pack('B', int(tag)) if attr.type == "integer": @@ -154,7 +151,7 @@ def _EncodeKey(self, key): return key attr = self.dict.attributes[key] - if attr.vendor: + if attr.vendor and not attr.is_sub_attribute: #sub attribute keys don't need vendor return (self.dict.vendors.GetForward(attr.vendor), attr.code) else: return attr.code @@ -175,13 +172,20 @@ def AddAttribute(self, key, value): :param value: value :type value: depends on type of attribute """ + attr = self.dict.attributes[key] + if isinstance(value, list): (key, value) = self._EncodeKeyValues(key, value) - self.setdefault(key, []).extend(value) else: (key, value) = self._EncodeKeyValues(key, [value]) - value = value[0] - self.setdefault(key, []).append(value) + + if attr.is_sub_attribute: + tlv = self.setdefault(self._EncodeKey(attr.parent.name), {}) + encoded = tlv.setdefault(key, []) + else: + encoded = self.setdefault(key, []) + + encoded.extend(value) def __getitem__(self, key): if not isinstance(key, six.string_types): @@ -189,10 +193,19 @@ def __getitem__(self, key): values = dict.__getitem__(self, self._EncodeKey(key)) attr = self.dict.attributes[key] - res = [] - for v in values: - res.append(self._DecodeValue(attr, v)) - return res + if attr.type == 'tlv': # return map from sub attribute code to its values + res = {} + for (sub_attr_key, sub_attr_val) in values.items(): + sub_attr_name = attr.sub_attributes[sub_attr_key] + sub_attr = self.dict.attributes[sub_attr_name] + for v in sub_attr_val: + res.setdefault(sub_attr_name, []).append(self._DecodeValue(sub_attr, v)) + return res + else: + res = [] + for v in values: + res.append(self._DecodeValue(attr, v)) + return res def __contains__(self, key): try: @@ -287,12 +300,46 @@ def _PktEncodeAttribute(self, key, value): return struct.pack('!BB', key, (len(value) + 2)) + value + def _PktEncodeTlv(self, tlv_key, tlv_value): + tlv_attr = self.dict.attributes[self._DecodeKey(tlv_key)] + curr_avp = six.b('') + avps = [] + max_sub_attribute_len = max(map(lambda item: len(item[1]), tlv_value.items())) + for i in range(max_sub_attribute_len): + sub_attr_encoding = six.b('') + for (code, datalst) in tlv_value.items(): + if i < len(datalst): + sub_attr_encoding += self._PktEncodeAttribute(code, datalst[i]) + # split above 255. assuming len of one instance of all sub tlvs is lower than 255 + if (len(sub_attr_encoding) + len(curr_avp)) < 245: + curr_avp += sub_attr_encoding + else: + avps.append(curr_avp) + curr_avp = sub_attr_encoding + avps.append(curr_avp) + tlv_avps = [] + for avp in avps: + value = struct.pack('!BB', tlv_attr.code, (len(avp) + 2)) + avp + tlv_avps.append(value) + if tlv_attr.vendor: + vendor_avps = six.b('') + for avp in tlv_avps: + vendor_avps += struct.pack( + '!BBL', 26, (len(avp) + 6), + self.dict.vendors.GetForward(tlv_attr.vendor) + ) + avp + return vendor_avps + else: + return b''.join(tlv_avps) + def _PktEncodeAttributes(self): result = six.b('') for (code, datalst) in self.items(): - for data in datalst: - result += self._PktEncodeAttribute(code, data) - + if self.dict.attributes[self._DecodeKey(code)].type == 'tlv': + result += self._PktEncodeTlv(code, datalst) + else: + for data in datalst: + result += self._PktEncodeAttribute(code, data) return result def _PktDecodeVendorAttribute(self, data): @@ -303,7 +350,14 @@ def _PktDecodeVendorAttribute(self, data): (vendor, type, length) = struct.unpack('!LBB', data[:6])[0:3] - tlvs = [((vendor, type), data[6:length+4])] + try: + if self.dict.attributes[self._DecodeKey((vendor, type))].type == 'tlv': + self._PktDecodeTlvAttribute((vendor, type), data[6:length + 4]) + tlvs = [] # tlv is added to the packet inside _PktDecodeTlvAttribute + else: + tlvs = [((vendor, type), data[6:length + 4])] + except: + return [(26, data)] sumlength = 4 + length while len(data) > sumlength: @@ -315,6 +369,30 @@ def _PktDecodeVendorAttribute(self, data): sumlength += length return tlvs + def _PktDecodeTlvAttribute(self, code, data): + + sub_attributes = self.setdefault(code, {}) + loc = 0 + + while loc < len(data): + type, length = struct.unpack('!BB', data[loc:loc+2])[0:2] + sub_attributes.setdefault(type, []).append(data[loc+2:loc+length]) + loc += length + + def _DictionaryHasAttribute(self, decoded_key): + """Determines if the dictionary has an attribute""" + attribute = self.dict.attributes.get(decoded_key) + if attribute is None: + message = 'Attribute "{}" does not exist in the dictionaries.'.format(decoded_key) + raise dictionary.ParseError(message) + + return True + + def _DictionaryAttributeTypeIs(self, decoded_key, _type): + """Determines if a dictionary attribute is certain type""" + attribute = self.dict.attributes.get(decoded_key) + return attribute.type == _type + def DecodePacket(self, packet): """Initialize the object from raw packet data. Decode a packet as received from the network and decode it. @@ -325,6 +403,7 @@ def DecodePacket(self, packet): try: (self.code, self.id, length, self.authenticator) = \ struct.unpack('!BBH16s', packet[0:20]) + except struct.error: raise PacketError('Packet header is corrupt') if len(packet) != length: @@ -346,9 +425,13 @@ def DecodePacket(self, packet): 'Attribute length is too small (%d)' % attrlen) value = packet[2:attrlen] + decoded_key = self._DecodeKey(key) + if key == 26: for (key, value) in self._PktDecodeVendorAttribute(value): self.setdefault(key, []).append(value) + elif self._DictionaryHasAttribute(decoded_key) and self._DictionaryAttributeTypeIs(decoded_key, 'tlv'): + self._PktDecodeTlvAttribute(key,value) else: self.setdefault(key, []).append(value) @@ -415,6 +498,8 @@ def __init__(self, code=AccessRequest, id=None, secret=six.b(''), :type packet: string """ Packet.__init__(self, code, id, secret, authenticator, **attributes) + if 'packet' in attributes: + self.raw_packet = attributes['packet'] def CreateReply(self, **attributes): """Create a new packet as a reply to this one. This method @@ -546,6 +631,17 @@ def VerifyChapPasswd(self, userpwd): return password == md5_constructor("%s%s%s" % (chapid, userpwd, challenge)).digest() + def VerifyAuthRequest(self): + """Verify request authenticator. + + :return: True if verification failed else False + :rtype: boolean + """ + assert(self.raw_packet) + hash = md5_constructor(self.raw_packet[0:4] + 16 * six.b('\x00') + + self.raw_packet[20:] + self.secret).digest() + return hash == self.authenticator + class AcctPacket(Packet): """RADIUS accounting packets. This class is a specialization @@ -651,7 +747,7 @@ def VerifyCoARequest(self): """ assert(self.raw_packet) hash = md5_constructor(self.raw_packet[0:4] + 16 * six.b('\x00') + - self.raw_packet[20:] + self.secret).digest() + self.raw_packet[20:] + self.secret).digest() return hash == self.authenticator def RequestPacket(self): diff --git a/pyrad/server.py b/pyrad/server.py index 92434a5..d732798 100644 --- a/pyrad/server.py +++ b/pyrad/server.py @@ -103,6 +103,27 @@ def __init__(self, addresses=[], authport=1812, acctport=1813, coaport=3799, for addr in addresses: self.BindToAddress(addr) + def _GetAddrInfo(self, addr): + """Use getaddrinfo to lookup all addresses for each address. + + Returns a list of tuples or an empty list: + [(family, address)] + + :param addr: IP address to lookup + :type addr: string + """ + results = set() + try: + tmp = socket.getaddrinfo(addr, 'www') + except socket.gaierror: + return [] + + for el in tmp: + results.add((el[0], el[4][0])) + + return results + + def BindToAddress(self, addr): """Add an address to listen to. An empty string indicated you want to listen on all addresses. @@ -110,23 +131,25 @@ def BindToAddress(self, addr): :param addr: IP address to listen on :type addr: string """ - if self.auth_enabled: - authfd = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) - authfd.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - authfd.bind((addr, self.authport)) - self.authfds.append(authfd) - - if self.acct_enabled: - acctfd = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) - acctfd.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - acctfd.bind((addr, self.acctport)) - self.acctfds.append(acctfd) - - if self.coa_enabled: - coafd = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) - coafd.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - coafd.bind((addr, self.coaport)) - self.coafds.append(coafd) + addrFamily = self._GetAddrInfo(addr) + for (family, address) in addrFamily: + if self.auth_enabled: + authfd = socket.socket(family, socket.SOCK_DGRAM) + authfd.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + authfd.bind((address, self.authport)) + self.authfds.append(authfd) + + if self.acct_enabled: + acctfd = socket.socket(family, socket.SOCK_DGRAM) + acctfd.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + acctfd.bind((address, self.acctport)) + self.acctfds.append(acctfd) + + if self.coa_enabled: + coafd = socket.socket(family, socket.SOCK_DGRAM) + coafd.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + coafd.bind((address, self.coaport)) + self.coafds.append(coafd) def HandleAuthPacket(self, pkt): @@ -169,6 +192,19 @@ def HandleDisconnectPacket(self, pkt): :type pkt: Packet class instance """ + def _AddSecret(self, pkt): + """Add secret to packets received and raise ServerPacketError + for unknown hosts. + + :param pkt: packet to process + :type pkt: Packet class instance + """ + if pkt.source[0] in self.hosts: + pkt.secret = self.hosts[pkt.source[0]].secret + elif '0.0.0.0' in self.hosts: + pkt.secret = self.hosts['0.0.0.0'].secret + else: + raise ServerPacketError('Received packet from unknown host') def _HandleAuthPacket(self, pkt): """Process a packet received on the authentication port. @@ -179,10 +215,7 @@ def _HandleAuthPacket(self, pkt): :param pkt: packet to process :type pkt: Packet class instance """ - if pkt.source[0] not in self.hosts: - raise ServerPacketError('Received packet from unknown host') - - pkt.secret = self.hosts[pkt.source[0]].secret + self._AddSecret(pkt) if pkt.code != packet.AccessRequest: raise ServerPacketError( 'Received non-authentication packet on authentication port') @@ -197,10 +230,7 @@ def _HandleAcctPacket(self, pkt): :param pkt: packet to process :type pkt: Packet class instance """ - if pkt.source[0] not in self.hosts: - raise ServerPacketError('Received packet from unknown host') - - pkt.secret = self.hosts[pkt.source[0]].secret + self._AddSecret(pkt) if pkt.code not in [packet.AccountingRequest, packet.AccountingResponse]: raise ServerPacketError( @@ -216,9 +246,7 @@ def _HandleCoaPacket(self, pkt): :param pkt: packet to process :type pkt: Packet class instance """ - if pkt.source[0] not in self.hosts: - raise ServerPacketError('Received packet from unknown host') - + self._AddSecret(pkt) pkt.secret = self.hosts[pkt.source[0]].secret if pkt.code == packet.CoARequest: self.HandleCoaPacket(pkt) @@ -281,15 +309,17 @@ def _ProcessInput(self, fd): :param fd: socket to read packet from :type fd: socket class instance """ - if fd.fileno() in self._realauthfds: + if self.auth_enabled and fd.fileno() in self._realauthfds: pkt = self._GrabPacket(lambda data, s=self: s.CreateAuthPacket(packet=data), fd) self._HandleAuthPacket(pkt) - elif fd.fileno() in self._realacctfds: + elif self.acct_enabled and fd.fileno() in self._realacctfds: pkt = self._GrabPacket(lambda data, s=self: s.CreateAcctPacket(packet=data), fd) self._HandleAcctPacket(pkt) - else: + elif self.coa_enabled: pkt = self._GrabPacket(lambda data, s=self: s.CreateCoAPacket(packet=data), fd) self._HandleCoaPacket(pkt) + else: + raise ServerPacketError('Received packet for unknown handler') def Run(self): """Main loop. diff --git a/pyrad/server_async.py b/pyrad/server_async.py new file mode 100644 index 0000000..070754d --- /dev/null +++ b/pyrad/server_async.py @@ -0,0 +1,336 @@ +# server_async.py +# +# Copyright 2018-2019 Geaaru + +import asyncio +import logging +import traceback + +from abc import abstractmethod, ABCMeta +from enum import Enum +from datetime import datetime +from pyrad.packet import Packet, AccessAccept, AccessReject, \ + AccountingRequest, AccountingResponse, \ + DisconnectACK, DisconnectNAK, DisconnectRequest, CoARequest, \ + CoAACK, CoANAK, AccessRequest, AuthPacket, AcctPacket, CoAPacket, \ + PacketError + +from pyrad.server import ServerPacketError + + +class ServerType(Enum): + Auth = 'Authentication' + Acct = 'Accounting' + Coa = 'Coa' + + +class DatagramProtocolServer(asyncio.Protocol): + + def __init__(self, ip, port, logger, server, server_type, hosts, + request_callback): + self.transport = None + self.ip = ip + self.port = port + self.logger = logger + self.server = server + self.hosts = hosts + self.server_type = server_type + self.request_callback = request_callback + + def connection_made(self, transport): + self.transport = transport + self.logger.info('[%s:%d] Transport created', self.ip, self.port) + + def connection_lost(self, exc): + if exc: + self.logger.warn('[%s:%d] Connection lost: %s', self.ip, self.port, str(exc)) + else: + self.logger.info('[%s:%d] Transport closed', self.ip, self.port) + + def send_response(self, reply, addr): + self.transport.sendto(reply.ReplyPacket(), addr) + + def datagram_received(self, data, addr): + self.logger.debug('[%s:%d] Received %d bytes from %s', self.ip, self.port, len(data), addr) + + receive_date = datetime.utcnow() + + if addr[0] in self.hosts: + remote_host = self.hosts[addr[0]] + elif '0.0.0.0' in self.hosts: + remote_host = self.hosts['0.0.0.0'].secret + else: + self.logger.warn('[%s:%d] Drop package from unknown source %s', self.ip, self.port, addr) + return + + try: + self.logger.debug('[%s:%d] Received from %s packet: %s', self.ip, self.port, addr, data.hex()) + req = Packet(packet=data, dict=self.server.dict) + except Exception as exc: + self.logger.error('[%s:%d] Error on decode packet: %s', self.ip, self.port, exc) + return + + try: + if req.code in (AccountingResponse, AccessAccept, AccessReject, CoANAK, CoAACK, DisconnectNAK, DisconnectACK): + raise ServerPacketError('Invalid response packet %d' % req.code) + + elif self.server_type == ServerType.Auth: + if req.code != AccessRequest: + raise ServerPacketError('Received non-auth packet on auth port') + req = AuthPacket(secret=remote_host.secret, + dict=self.server.dict, + packet=data) + if self.server.enable_pkt_verify: + if req.VerifyAuthRequest(): + raise PacketError('Packet verification failed') + + elif self.server_type == ServerType.Coa: + if req.code != DisconnectRequest and req.code != CoARequest: + raise ServerPacketError('Received non-coa packet on coa port') + req = CoAPacket(secret=remote_host.secret, + dict=self.server.dict, + packet=data) + if self.server.enable_pkt_verify: + if req.VerifyCoARequest(): + raise PacketError('Packet verification failed') + + elif self.server_type == ServerType.Acct: + + if req.code != AccountingRequest: + raise ServerPacketError('Received non-acct packet on acct port') + req = AcctPacket(secret=remote_host.secret, + dict=self.server.dict, + packet=data) + if self.server.enable_pkt_verify: + if req.VerifyAcctRequest(): + raise PacketError('Packet verification failed') + + # Call request callback + self.request_callback(self, req, addr) + except Exception as exc: + if self.server.debug: + self.logger.exception('[%s:%d] Error for packet from %s', self.ip, self.port, addr) + else: + self.logger.error('[%s:%d] Error for packet from %s: %s', self.ip, self.port, addr, exc) + + process_date = datetime.utcnow() + self.logger.debug('[%s:%d] Request from %s processed in %d ms', self.ip, self.port, addr, (process_date-receive_date).microseconds/1000) + + def error_received(self, exc): + self.logger.error('[%s:%d] Error received: %s', self.ip, self.port, exc) + + async def close_transport(self): + if self.transport: + self.logger.debug('[%s:%d] Close transport...', self.ip, self.port) + self.transport.close() + self.transport = None + + def __str__(self): + return 'DatagramProtocolServer(ip=%s, port=%d)' % (self.ip, self.port) + + # Used as protocol_factory + def __call__(self): + return self + + +class ServerAsync(metaclass=ABCMeta): + + def __init__(self, auth_port=1812, acct_port=1813, + coa_port=3799, hosts=None, dictionary=None, + loop=None, logger_name='pyrad', + enable_pkt_verify=False, + debug=False): + + if not loop: + self.loop = asyncio.get_event_loop() + else: + self.loop = loop + self.logger = logging.getLogger(logger_name) + + if hosts is None: + self.hosts = {} + else: + self.hosts = hosts + + self.auth_port = auth_port + self.auth_protocols = [] + + self.acct_port = acct_port + self.acct_protocols = [] + + self.coa_port = coa_port + self.coa_protocols = [] + + self.dict = dictionary + self.enable_pkt_verify = enable_pkt_verify + + self.debug = debug + + def __request_handler__(self, protocol, req, addr): + + try: + if protocol.server_type == ServerType.Acct: + self.handle_acct_packet(protocol, req, addr) + elif protocol.server_type == ServerType.Auth: + self.handle_auth_packet(protocol, req, addr) + elif protocol.server_type == ServerType.Coa and \ + req.code == CoARequest: + self.handle_coa_packet(protocol, req, addr) + elif protocol.server_type == ServerType.Coa and \ + req.code == DisconnectRequest: + self.handle_disconnect_packet(protocol, req, addr) + else: + self.logger.error('[%s:%s] Unexpected request found', protocol.ip, protocol.port) + except Exception as exc: + if self.debug: + self.logger.exception('[%s:%s] Unexpected error', protocol.ip, protocol.port) + + else: + self.logger.error('[%s:%s] Unexpected error: %s', protocol.ip, protocol.port, exc) + + def __is_present_proto__(self, ip, port): + if port == self.auth_port: + for proto in self.auth_protocols: + if proto.ip == ip: + return True + elif port == self.acct_port: + for proto in self.acct_protocols: + if proto.ip == ip: + return True + elif port == self.coa_port: + for proto in self.coa_protocols: + if proto.ip == ip: + return True + return False + + # noinspection PyPep8Naming + @staticmethod + def CreateReplyPacket(pkt, **attributes): + """Create a reply packet. + Create a new packet which can be returned as a reply to a received + packet. + + :param pkt: original packet + :type pkt: Packet instance + """ + reply = pkt.CreateReply(**attributes) + return reply + + async def initialize_transports(self, enable_acct=False, + enable_auth=False, enable_coa=False, + addresses=None): + + task_list = [] + + if not enable_acct and not enable_auth and not enable_coa: + raise Exception('No transports selected') + if not addresses or len(addresses) == 0: + addresses = ['127.0.0.1'] + + # noinspection SpellCheckingInspection + for addr in addresses: + + if enable_acct and not self.__is_present_proto__(addr, self.acct_port): + protocol_acct = DatagramProtocolServer( + addr, + self.acct_port, + self.logger, self, + ServerType.Acct, + self.hosts, + self.__request_handler__ + ) + + bind_addr = (addr, self.acct_port) + acct_connect = self.loop.create_datagram_endpoint( + protocol_acct, + reuse_address=True, reuse_port=True, + local_addr=bind_addr + ) + self.acct_protocols.append(protocol_acct) + task_list.append(acct_connect) + + if enable_auth and not self.__is_present_proto__(addr, self.auth_port): + protocol_auth = DatagramProtocolServer( + addr, + self.auth_port, + self.logger, self, + ServerType.Auth, + self.hosts, + self.__request_handler__ + ) + bind_addr = (addr, self.auth_port) + + auth_connect = self.loop.create_datagram_endpoint( + protocol_auth, + reuse_address=True, reuse_port=True, + local_addr=bind_addr + ) + self.auth_protocols.append(protocol_auth) + task_list.append(auth_connect) + + if enable_coa and not self.__is_present_proto__(addr, self.coa_port): + protocol_coa = DatagramProtocolServer( + addr, + self.coa_port, + self.logger, self, + ServerType.Coa, + self.hosts, + self.__request_handler__ + ) + bind_addr = (addr, self.coa_port) + + coa_connect = self.loop.create_datagram_endpoint( + protocol_coa, + reuse_address=True, reuse_port=True, + local_addr=bind_addr + ) + self.coa_protocols.append(protocol_coa) + task_list.append(coa_connect) + + await asyncio.ensure_future( + asyncio.gather( + *task_list, + return_exceptions=False, + ), + loop=self.loop + ) + + # noinspection SpellCheckingInspection + async def deinitialize_transports(self, deinit_coa=True, deinit_auth=True, deinit_acct=True): + + if deinit_coa: + for proto in self.coa_protocols: + await proto.close_transport() + del proto + + self.coa_protocols = [] + + if deinit_auth: + for proto in self.auth_protocols: + await proto.close_transport() + del proto + + self.auth_protocols = [] + + if deinit_acct: + for proto in self.acct_protocols: + await proto.close_transport() + del proto + + self.acct_protocols = [] + + @abstractmethod + def handle_auth_packet(self, protocol, pkt, addr): + pass + + @abstractmethod + def handle_acct_packet(self, protocol, pkt, addr): + pass + + @abstractmethod + def handle_coa_packet(self, protocol, pkt, addr): + pass + + @abstractmethod + def handle_disconnect_packet(self, protocol, pkt, addr): + pass diff --git a/pyrad/tests/data/full b/pyrad/tests/data/full index 1bdc870..a4162c6 100644 --- a/pyrad/tests/data/full +++ b/pyrad/tests/data/full @@ -10,15 +10,25 @@ VALUE Test-Integer Two 2 VALUE Test-Integer Three 3 VALUE Test-Integer Four 4 +ATTRIBUTE Test-Tlv 4 tlv +ATTRIBUTE Test-Tlv-Str 4.1 string +ATTRIBUTE Test-Tlv-Int 4.2 integer + VENDOR Simplon 16 BEGIN-VENDOR Simplon ATTRIBUTE Simplon-Number 1 integer +ATTRIBUTE Simplon-String 2 string VALUE Simplon-Number Zero 0 VALUE Simplon-Number One 1 VALUE Simplon-Number Two 2 VALUE Simplon-Number Three 3 VALUE Simplon-Number Four 4 + +ATTRIBUTE Simplon-Tlv 3 tlv +ATTRIBUTE Simplon-Tlv-Str 3.1 string +ATTRIBUTE Simplon-Tlv-Int 3.2 integer + END-VENDOR Simplon diff --git a/pyrad/tests/data/simple b/pyrad/tests/data/simple index f9694c6..2cb4858 100644 --- a/pyrad/tests/data/simple +++ b/pyrad/tests/data/simple @@ -8,4 +8,7 @@ ATTRIBUTE Test-Ipv6-Address 5 ipv6addr ATTRIBUTE Test-If-Id 6 ifid ATTRIBUTE Test-Date 7 date ATTRIBUTE Test-Abinary 8 abinary - +ATTRIBUTE Test-Tlv 9 tlv +ATTRIBUTE Test-Tlv-Str 9.1 string +ATTRIBUTE Test-Tlv-Int 9.2 integer +ATTRIBUTE Test-Integer64 10 integer64 diff --git a/pyrad/tests/mock.py b/pyrad/tests/mock.py index 47c6746..ee71fb3 100644 --- a/pyrad/tests/mock.py +++ b/pyrad/tests/mock.py @@ -79,12 +79,18 @@ class MockPoll: results = [] def __init__(self): - self.registry = [] + self.registry = {} def register(self, fd, options): - self.registry.append((fd, options)) + self.registry[fd] = options - def poll(self): + def unregister(self, fd): + try: + del self.registry[fd] + except KeyError: + pass + + def poll(self, timeout=None): for result in self.results: yield result raise MockFinished diff --git a/pyrad/tests/testClient.py b/pyrad/tests/testClient.py index 708525f..9ecf42e 100644 --- a/pyrad/tests/testClient.py +++ b/pyrad/tests/testClient.py @@ -1,3 +1,4 @@ +import select import socket import unittest import six @@ -8,6 +9,7 @@ from pyrad.packet import AccessRequest from pyrad.packet import AccountingRequest from pyrad.tests.mock import MockPacket +from pyrad.tests.mock import MockPoll from pyrad.tests.mock import MockSocket BIND_IP = "127.0.0.1" @@ -56,6 +58,7 @@ def setUp(self): self.orgsocket = socket.socket socket.socket = MockSocket + def tearDown(self): socket.socket = self.orgsocket @@ -74,6 +77,7 @@ def testBind(self): def testBindClosesSocket(self): s = MockSocket(socket.AF_INET, socket.SOCK_DGRAM) self.client._socket = s + self.client._poll = MockPoll() self.client.bind((BIND_IP, BIND_PORT)) self.assertEqual(s.closed, True) @@ -146,6 +150,8 @@ def testValidReply(self): self.client.retries = 1 self.client.timeout = 1 self.client._socket = MockSocket(1, 2, six.b("valid reply")) + self.client._poll = MockPoll() + MockPoll.results = [(1, select.POLLIN)] packet = MockPacket(AccountingRequest, verify=True) reply = self.client._SendPacket(packet, 432) self.failUnless(reply is packet.reply) @@ -154,6 +160,7 @@ def testInvalidReply(self): self.client.retries = 1 self.client.timeout = 1 self.client._socket = MockSocket(1, 2, six.b("invalid reply")) + MockPoll.results = [(1, select.POLLIN)] packet = MockPacket(AccountingRequest, verify=False) self.assertRaises(Timeout, self.client._SendPacket, packet, 432) diff --git a/pyrad/tests/testDictionary.py b/pyrad/tests/testDictionary.py index dd6a45e..ab67738 100644 --- a/pyrad/tests/testDictionary.py +++ b/pyrad/tests/testDictionary.py @@ -15,12 +15,14 @@ def testInvalidDataType(self): self.assertRaises(ValueError, Attribute, 'name', 'code', 'datatype') def testConstructionParameters(self): - attr = Attribute('name', 'code', 'integer', 'vendor') + attr = Attribute('name', 'code', 'integer', False, 'vendor') self.assertEqual(attr.name, 'name') self.assertEqual(attr.code, 'code') self.assertEqual(attr.type, 'integer') + self.assertEqual(attr.is_sub_attribute, False) self.assertEqual(attr.vendor, 'vendor') self.assertEqual(len(attr.values), 0) + self.assertEqual(len(attr.sub_attributes), 0) def testNamedConstructionParameters(self): attr = Attribute(name='name', code='code', datatype='integer', @@ -32,7 +34,7 @@ def testNamedConstructionParameters(self): self.assertEqual(len(attr.values), 0) def testValues(self): - attr = Attribute('name', 'code', 'integer', 'vendor', + attr = Attribute('name', 'code', 'integer', False, 'vendor', dict(pie='custard', shake='vanilla')) self.assertEqual(len(attr.values), 2) self.assertEqual(attr.values['shake'], 'vanilla') @@ -63,6 +65,22 @@ def testReadonlyContainer(self): class DictionaryParsingTests(unittest.TestCase): + + simple_dict_values = [ + ('Test-String', 1, 'string'), + ('Test-Octets', 2, 'octets'), + ('Test-Integer', 3, 'integer'), + ('Test-Ip-Address', 4, 'ipaddr'), + ('Test-Ipv6-Address', 5, 'ipv6addr'), + ('Test-If-Id', 6, 'ifid'), + ('Test-Date', 7, 'date'), + ('Test-Abinary', 8, 'abinary'), + ('Test-Tlv', 9, 'tlv'), + ('Test-Tlv-Str', 1, 'string'), + ('Test-Tlv-Int', 2, 'integer'), + ('Test-Integer64', 10, 'integer64') + ] + def setUp(self): self.path = os.path.join(home, 'tests', 'data') self.dict = Dictionary(os.path.join(self.path, 'simple')) @@ -80,19 +98,8 @@ def testParseMultipleDictionaries(self): self.assertEqual(len(dict), 2) def testParseSimpleDictionary(self): - self.assertEqual(len(self.dict), 8) - values = [ - ('Test-String', 1, 'string'), - ('Test-Octets', 2, 'octets'), - ('Test-Integer', 3, 'integer'), - ('Test-Ip-Address', 4, 'ipaddr'), - ('Test-Ipv6-Address', 5, 'ipv6addr'), - ('Test-If-Id', 6, 'ifid'), - ('Test-Date', 7, 'date'), - ('Test-Abinary', 8, 'abinary'), - ] - - for (attr, code, type) in values: + self.assertEqual(len(self.dict),len(self.simple_dict_values)) + for (attr, code, type) in self.simple_dict_values: attr = self.dict[attr] self.assertEqual(attr.code, code) self.assertEqual(attr.type, type) @@ -163,6 +170,15 @@ def testIntegerValueParsing(self): self.dict['Test-Integer'].values['Value-Six']), 5) + def testInteger64ValueParsing(self): + self.assertEqual(len(self.dict['Test-Integer64'].values), 0) + self.dict.ReadDictionary(StringIO('VALUE Test-Integer64 Value-Six 5')) + self.assertEqual(len(self.dict['Test-Integer64'].values), 1) + self.assertEqual( + DecodeAttr('integer64', + self.dict['Test-Integer64'].values['Value-Six']), + 5) + def testStringValueParsing(self): self.assertEqual(len(self.dict['Test-String'].values), 0) self.dict.ReadDictionary(StringIO( @@ -173,6 +189,27 @@ def testStringValueParsing(self): self.dict['Test-String'].values['Value-Custard']), 'custardpie') + def testTlvParsing(self): + self.assertEqual(len(self.dict['Test-Tlv'].sub_attributes), 2) + self.assertEqual(self.dict['Test-Tlv'].sub_attributes, {1:'Test-Tlv-Str', 2: 'Test-Tlv-Int'}) + + def testSubTlvParsing(self): + for (attr, _, _) in self.simple_dict_values: + if attr.startswith('Test-Tlv-'): + self.assertEqual(self.dict[attr].is_sub_attribute, True) + self.assertEqual(self.dict[attr].parent, self.dict['Test-Tlv']) + else: + self.assertEqual(self.dict[attr].is_sub_attribute, False) + self.assertEqual(self.dict[attr].parent, None) + + # tlv with vendor + full_dict = Dictionary(os.path.join(self.path, 'full')) + self.assertEqual(full_dict['Simplon-Tlv-Str'].is_sub_attribute, True) + self.assertEqual(full_dict['Simplon-Tlv-Str'].parent, full_dict['Simplon-Tlv']) + self.assertEqual(full_dict['Simplon-Tlv-Int'].is_sub_attribute, True) + self.assertEqual(full_dict['Simplon-Tlv-Int'].parent, full_dict['Simplon-Tlv']) + + def testVenderTooFewColumnsError(self): try: self.dict.ReadDictionary(StringIO('VENDOR Simplon')) @@ -188,7 +225,7 @@ def testVendorParsing(self): self.assertEqual(self.dict.vendors['Simplon'], 42) self.dict.ReadDictionary(StringIO( 'ATTRIBUTE Test-Type 1 integer Simplon')) - self.assertEquals(self.dict.attrindex['Test-Type'], (42, 1)) + self.assertEqual(self.dict.attrindex['Test-Type'], (42, 1)) def testVendorOptionError(self): self.assertRaises(ParseError, self.dict.ReadDictionary, @@ -243,7 +280,7 @@ def testBeginVendorParsing(self): 'VENDOR Simplon 42\n' 'BEGIN-VENDOR Simplon\n' 'ATTRIBUTE Test-Type 1 integer')) - self.assertEquals(self.dict.attrindex['Test-Type'], (42, 1)) + self.assertEqual(self.dict.attrindex['Test-Type'], (42, 1)) def testEndVendorUnknownVendor(self): try: @@ -270,7 +307,7 @@ def testEndVendorParsing(self): 'BEGIN-VENDOR Simplon\n' 'END-VENDOR Simplon\n' 'ATTRIBUTE Test-Type 1 integer')) - self.assertEquals(self.dict.attrindex['Test-Type'], 1) + self.assertEqual(self.dict.attrindex['Test-Type'], 1) def testInclude(self): try: @@ -290,14 +327,14 @@ def testDictFilePostParse(self): 'VENDOR Simplon 42\n')) for _ in f: pass - self.assertEquals(f.File(), '') - self.assertEquals(f.Line(), -1) + self.assertEqual(f.File(), '') + self.assertEqual(f.Line(), -1) def testDictFileParseError(self): tmpdict = Dictionary() try: tmpdict.ReadDictionary(os.path.join(self.path, 'dictfiletest')) except ParseError as e: - self.assertEquals('dictfiletest' in str(e), True) + self.assertEqual('dictfiletest' in str(e), True) else: self.fail() diff --git a/pyrad/tests/testPacket.py b/pyrad/tests/testPacket.py index 2cd0cbf..8cdda38 100644 --- a/pyrad/tests/testPacket.py +++ b/pyrad/tests/testPacket.py @@ -9,7 +9,7 @@ class UtilityTests(unittest.TestCase): def testGenerateID(self): id = packet.CreateID() - self.failUnless(isinstance(id, int)) + self.assertTrue(isinstance(id, int)) newid = packet.CreateID() self.assertNotEqual(id, newid) @@ -23,9 +23,9 @@ def setUp(self): def testBasicConstructor(self): pkt = self.klass() - self.failUnless(isinstance(pkt.code, int)) - self.failUnless(isinstance(pkt.id, int)) - self.failUnless(isinstance(pkt.secret, six.binary_type)) + self.assertTrue(isinstance(pkt.code, int)) + self.assertTrue(isinstance(pkt.id, int)) + self.assertTrue(isinstance(pkt.secret, six.binary_type)) def testNamedConstructor(self): pkt = self.klass(code=26, id=38, secret=six.b('secret'), @@ -39,20 +39,28 @@ def testNamedConstructor(self): def testConstructWithDictionary(self): pkt = self.klass(dict=self.dict) - self.failUnless(pkt.dict is self.dict) + self.assertTrue(pkt.dict is self.dict) def testConstructorIgnoredParameters(self): marker = [] pkt = self.klass(fd=marker) - self.failIf(getattr(pkt, 'fd', None) is marker) + self.assertFalse(getattr(pkt, 'fd', None) is marker) def testSecretMustBeBytestring(self): self.assertRaises(TypeError, self.klass, secret=six.u('secret')) def testConstructorWithAttributes(self): - pkt = self.klass(dict=self.dict, Test_String='this works') + pkt = self.klass(**{'Test-String' :'this works', 'dict' : self.dict}) self.assertEqual(pkt['Test-String'], ['this works']) + def testConstructorWithTlvAttribute(self): + pkt = self.klass(**{ + 'Test-Tlv-Str': 'this works', + 'Test-Tlv-Int': 10, + 'dict': self.dict + }) + self.assertEqual(pkt['Test-Tlv'], {'Test-Tlv-Str': ['this works'], 'Test-Tlv-Int' : [10]} ) + class PacketTests(unittest.TestCase): def setUp(self): @@ -62,7 +70,7 @@ def setUp(self): authenticator=six.b('01234567890ABCDEF'), dict=self.dict) def testCreateReply(self): - reply = self.packet.CreateReply(Test_Integer=10) + reply = self.packet.CreateReply(**{'Test-Integer' : 10}) self.assertEqual(reply.id, self.packet.id) self.assertEqual(reply.secret, self.packet.secret) self.assertEqual(reply.authenticator, self.packet.authenticator) @@ -94,9 +102,9 @@ def testVendorAttributeAccess(self): def testRawAttributeAccess(self): marker = [six.b('')] self.packet[1] = marker - self.failUnless(self.packet[1] is marker) + self.assertTrue(self.packet[1] is marker) self.packet[(16, 1)] = marker - self.failUnless(self.packet[(16, 1)] is marker) + self.assertTrue(self.packet[(16, 1)] is marker) def testHasKey(self): self.assertEqual(self.packet.has_key('Test-String'), False) @@ -130,7 +138,7 @@ def testKeys(self): def testCreateAuthenticator(self): a = packet.Packet.CreateAuthenticator() - self.failUnless(isinstance(a, six.binary_type)) + self.assertTrue(isinstance(a, six.binary_type)) self.assertEqual(len(a), 16) b = packet.Packet.CreateAuthenticator() @@ -138,7 +146,7 @@ def testCreateAuthenticator(self): def testGenerateID(self): id = self.packet.CreateID() - self.failUnless(isinstance(id, int)) + self.assertTrue(isinstance(id, int)) newid = self.packet.CreateID() self.assertNotEqual(id, newid) @@ -176,15 +184,48 @@ def testPktEncodeAttribute(self): encode((1, 2), six.b('value')), six.b('\x1a\x0d\x00\x00\x00\x01\x02\x07value')) + def testPktEncodeTlvAttribute(self): + encode = self.packet._PktEncodeTlv + + # Encode a normal tlv attribute + self.assertEqual( + encode(4, {1:[six.b('value')], 2:[six.b('\x00\x00\x00\x02')]}), + six.b('\x04\x0f\x01\x07value\x02\x06\x00\x00\x00\x02')) + + # Encode a normal tlv attribute with several sub attribute instances + self.assertEqual( + encode(4, {1:[six.b('value'), six.b('other')], 2:[six.b('\x00\x00\x00\x02')]}), + six.b('\x04\x16\x01\x07value\x02\x06\x00\x00\x00\x02\x01\x07other')) + # Encode a vendor tlv attribute + self.assertEqual( + encode((16, 3), {1:[six.b('value')], 2:[six.b('\x00\x00\x00\x02')]}), + six.b('\x1a\x15\x00\x00\x00\x10\x03\x0f\x01\x07value\x02\x06\x00\x00\x00\x02')) + + def testPktEncodeLongTlvAttribute(self): + encode = self.packet._PktEncodeTlv + + long_str = 'a' * 245 + # Encode a long tlv attribute - check it is split between AVPs + self.assertEqual( + encode(4, {1:[six.b('value'), six.b(long_str)], 2:[six.b('\x00\x00\x00\x02')]}), + six.b('\x04\x0f\x01\x07value\x02\x06\x00\x00\x00\x02\x04\xf9\x01\xf7' + long_str)) + + # Encode a long vendor tlv attribute + first_avp = '\x1a\x15\x00\x00\x00\x10\x03\x0f\x01\x07value\x02\x06\x00\x00\x00\x02' + second_avp = '\x1a\xff\x00\x00\x00\x10\x03\xf9\x01\xf7' + long_str + self.assertEqual( + encode((16, 3), {1:[six.b('value'), six.b(long_str)], 2:[six.b('\x00\x00\x00\x02')]}), + six.b(first_avp + second_avp)) + def testPktEncodeAttributes(self): self.packet[1] = [six.b('value')] self.assertEqual(self.packet._PktEncodeAttributes(), six.b('\x01\x07value')) self.packet.clear() - self.packet[(1, 2)] = [six.b('value')] + self.packet[(16, 2)] = [six.b('value')] self.assertEqual(self.packet._PktEncodeAttributes(), - six.b('\x1a\x0d\x00\x00\x00\x01\x02\x07value')) + six.b('\x1a\x0d\x00\x00\x00\x10\x02\x07value')) self.packet.clear() self.packet[1] = [six.b('one'), six.b('two'), six.b('three')] @@ -193,10 +234,10 @@ def testPktEncodeAttributes(self): self.packet.clear() self.packet[1] = [six.b('value')] - self.packet[(1, 2)] = [six.b('value')] + self.packet[(16, 2)] = [six.b('value')] self.assertEqual( self.packet._PktEncodeAttributes(), - six.b('\x1a\x0d\x00\x00\x00\x01\x02\x07value\x01\x07value')) + six.b('\x01\x07value\x1a\x0d\x00\x00\x00\x10\x02\x07value')) def testPktDecodeVendorAttribute(self): decode = self.packet._PktDecodeVendorAttribute @@ -212,14 +253,31 @@ def testPktDecodeVendorAttribute(self): # Proper RFC2865 recommended form self.assertEqual( - decode(six.b('\x00\x00\x00\x01\x02\x07value')), - [((1, 2), six.b('value'))]) + decode(six.b('\x00\x00\x00\x10\x02\x07value')), + [((16, 2), six.b('value'))]) + + def testPktDecodeTlvAttribute(self): + decode = self.packet._PktDecodeTlvAttribute + + decode(4,six.b('\x01\x07value')) + self.assertEqual(self.packet[4], {1: [six.b('value')]}) + + #add another instance of the same sub attribute + decode(4,six.b('\x01\x07other')) + self.assertEqual(self.packet[4], {1: [six.b('value'), six.b('other')]}) + + #add a different sub attribute + decode(4,six.b('\x02\x07\x00\x00\x00\x01')) + self.assertEqual(self.packet[4], { + 1: [six.b('value'), six.b('other')], + 2: [six.b('\x00\x00\x00\x01')] + }) def testDecodePacketWithEmptyPacket(self): try: self.packet.DecodePacket(six.b('')) except packet.PacketError as e: - self.failUnless('header is corrupt' in str(e)) + self.assertTrue('header is corrupt' in str(e)) else: self.fail() @@ -227,7 +285,7 @@ def testDecodePacketWithInvalidLength(self): try: self.packet.DecodePacket(six.b('\x00\x00\x00\x001234567890123456')) except packet.PacketError as e: - self.failUnless('invalid length' in str(e)) + self.assertTrue('invalid length' in str(e)) else: self.fail() @@ -235,7 +293,7 @@ def testDecodePacketWithTooBigPacket(self): try: self.packet.DecodePacket(six.b('\x00\x00\x24\x00') + (0x2400 - 4) * six.b('X')) except packet.PacketError as e: - self.failUnless('too long' in str(e)) + self.assertTrue('too long' in str(e)) else: self.fail() @@ -244,7 +302,7 @@ def testDecodePacketWithPartialAttributes(self): self.packet.DecodePacket( six.b('\x01\x02\x00\x151234567890123456\x00')) except packet.PacketError as e: - self.failUnless('header is corrupt' in str(e)) + self.assertTrue('header is corrupt' in str(e)) else: self.fail() @@ -260,30 +318,49 @@ def testDecodePacketWithBadAttribute(self): self.packet.DecodePacket( six.b('\x01\x02\x00\x161234567890123456\x00\x01')) except packet.PacketError as e: - self.failUnless('too small' in str(e)) + self.assertTrue('too small' in str(e)) else: self.fail() def testDecodePacketWithEmptyAttribute(self): self.packet.DecodePacket( - six.b('\x01\x02\x00\x161234567890123456\x00\x02')) - self.assertEqual(self.packet[0], [six.b('')]) + six.b('\x01\x02\x00\x161234567890123456\x01\x02')) + self.assertEqual(self.packet[1], [six.b('')]) def testDecodePacketWithAttribute(self): self.packet.DecodePacket( - six.b('\x01\x02\x00\x1b1234567890123456\x00\x07value')) - self.assertEqual(self.packet[0], [six.b('value')]) + six.b('\x01\x02\x00\x1b1234567890123456\x01\x07value')) + self.assertEqual(self.packet[1], [six.b('value')]) + + def testDecodePacketWithTlvAttribute(self): + self.packet.DecodePacket( + six.b('\x01\x02\x00\x1d1234567890123456\x04\x09\x01\x07value')) + self.assertEqual(self.packet[4], {1:[six.b('value')]}) + + def testDecodePacketWithVendorTlvAttribute(self): + self.packet.DecodePacket( + six.b('\x01\x02\x00\x231234567890123456\x1a\x0f\x00\x00\x00\x10\x03\x09\x01\x07value')) + self.assertEqual(self.packet[(16,3)], {1:[six.b('value')]}) + + def testDecodePacketWithTlvAttributeWith2SubAttributes(self): + self.packet.DecodePacket( + six.b('\x01\x02\x00\x231234567890123456\x04\x0f\x01\x07value\x02\x06\x00\x00\x00\x09')) + self.assertEqual(self.packet[4], {1:[six.b('value')], 2:[six.b('\x00\x00\x00\x09')]}) + + def testDecodePacketWithSplitTlvAttribute(self): + self.packet.DecodePacket( + six.b('\x01\x02\x00\x251234567890123456\x04\x09\x01\x07value\x04\x09\x02\x06\x00\x00\x00\x09')) + self.assertEqual(self.packet[4], {1:[six.b('value')], 2:[six.b('\x00\x00\x00\x09')]}) def testDecodePacketWithMultiValuedAttribute(self): self.packet.DecodePacket( - six.b('\x01\x02\x00\x1e1234567890123456\x00\x05one\x00\x05two')) - self.assertEqual(self.packet[0], [six.b('one'), six.b('two')]) + six.b('\x01\x02\x00\x1e1234567890123456\x01\x05one\x01\x05two')) + self.assertEqual(self.packet[1], [six.b('one'), six.b('two')]) def testDecodePacketWithTwoAttributes(self): self.packet.DecodePacket( - six.b('\x01\x02\x00\x1e1234567890123456\x00\x05one\x01\x05two')) - self.assertEqual(self.packet[0], [six.b('one')]) - self.assertEqual(self.packet[1], [six.b('two')]) + six.b('\x01\x02\x00\x1e1234567890123456\x01\x05one\x01\x05two')) + self.assertEqual(self.packet[1], [six.b('one'), six.b('two')]) def testDecodePacketWithVendorAttribute(self): self.packet.DecodePacket( @@ -297,11 +374,12 @@ def testEncodeKey(self): self.assertEqual(self.packet._EncodeKey(1), 1) def testAddAttribute(self): - self.packet.AddAttribute(1, 1) - self.assertEqual(dict.__getitem__(self.packet, 1), [1]) - self.packet.AddAttribute(1, 1) - self.assertEqual(dict.__getitem__(self.packet, 1), [1, 1]) - + self.packet.AddAttribute('Test-String', '1') + self.assertEqual(self.packet['Test-String'], ['1']) + self.packet.AddAttribute('Test-String', '1') + self.assertEqual(self.packet['Test-String'], ['1', '1']) + self.packet.AddAttribute('Test-String', ['2', '3']) + self.assertEqual(self.packet['Test-String'], ['1', '1', '2', '3']) class AuthPacketConstructionTests(PacketConstructionTests): klass = packet.AuthPacket @@ -319,7 +397,7 @@ def setUp(self): authenticator=six.b('01234567890ABCDEF'), dict=self.dict) def testCreateReply(self): - reply = self.packet.CreateReply(Test_Integer=10) + reply = self.packet.CreateReply(**{'Test-Integer' : 10}) self.assertEqual(reply.code, packet.AccessAccept) self.assertEqual(reply.id, self.packet.id) self.assertEqual(reply.secret, self.packet.secret) @@ -333,12 +411,12 @@ def testRequestPacket(self): def testRequestPacketCreatesAuthenticator(self): self.packet.authenticator = None self.packet.RequestPacket() - self.failUnless(self.packet.authenticator is not None) + self.assertTrue(self.packet.authenticator is not None) def testRequestPacketCreatesID(self): self.packet.id = None self.packet.RequestPacket() - self.failUnless(self.packet.id is not None) + self.assertTrue(self.packet.id is not None) def testPwCryptEmptyPassword(self): self.assertEqual(self.packet.PwCrypt(''), six.b('')) @@ -350,7 +428,7 @@ def testPwCryptPassword(self): def testPwCryptSetsAuthenticator(self): self.packet.authenticator = None self.packet.PwCrypt(six.u('')) - self.failUnless(self.packet.authenticator is not None) + self.assertTrue(self.packet.authenticator is not None) def testPwDecryptEmptyPassword(self): self.assertEqual(self.packet.PwDecrypt(six.b('')), six.u('')) @@ -383,7 +461,7 @@ def setUp(self): authenticator=six.b('01234567890ABCDEF'), dict=self.dict) def testCreateReply(self): - reply = self.packet.CreateReply(Test_Integer=10) + reply = self.packet.CreateReply(**{'Test-Integer' : 10}) self.assertEqual(reply.code, packet.AccountingResponse) self.assertEqual(reply.id, self.packet.id) self.assertEqual(reply.secret, self.packet.secret) @@ -409,4 +487,4 @@ def testRequestPacket(self): def testRequestPacketSetsId(self): self.packet.id = None self.packet.RequestPacket() - self.failUnless(self.packet.id is not None) + self.assertTrue(self.packet.id is not None) diff --git a/pyrad/tests/testProxy.py b/pyrad/tests/testProxy.py index 946cd4f..6a4eb77 100644 --- a/pyrad/tests/testProxy.py +++ b/pyrad/tests/testProxy.py @@ -33,7 +33,7 @@ def testProxyFd(self): self.failUnless(isinstance(self.proxy._proxyfd, MockSocket)) self.assertEqual(list(self.proxy._fdmap.keys()), [1]) self.assertEqual(self.proxy._poll.registry, - [(1, select.POLLIN | select.POLLPRI | select.POLLERR)]) + {1: select.POLLIN | select.POLLPRI | select.POLLERR}) class ProxyPacketHandlingTests(unittest.TestCase): diff --git a/pyrad/tests/testServer.py b/pyrad/tests/testServer.py index e1c3c68..d4558c6 100644 --- a/pyrad/tests/testServer.py +++ b/pyrad/tests/testServer.py @@ -91,6 +91,16 @@ def testBind(self): self.assertEqual(self.server.acctfds[0].address, ('192.168.13.13', 1813)) + def testBindv6(self): + self.server.BindToAddress('2001:db8:123::1') + self.assertEqual(len(self.server.authfds), 1) + self.assertEqual(self.server.authfds[0].address, + ('2001:db8:123::1', 1812)) + + self.assertEqual(len(self.server.acctfds), 1) + self.assertEqual(self.server.acctfds[0].address, + ('2001:db8:123::1', 1813)) + def testGrabPacket(self): def gen(data): res = TrivialObject() @@ -109,7 +119,7 @@ def testPrepareSocketNoFds(self): self.server._poll = MockPoll() self.server._PrepareSockets() - self.assertEqual(self.server._poll.registry, []) + self.assertEqual(self.server._poll.registry, {}) self.assertEqual(self.server._realauthfds, []) self.assertEqual(self.server._realacctfds, []) @@ -121,8 +131,8 @@ def testPrepareSocketAuthFds(self): self.assertEqual(list(self.server._fdmap.keys()), [12, 14]) self.assertEqual(self.server._poll.registry, - [(12, select.POLLIN | select.POLLPRI | select.POLLERR), - (14, select.POLLIN | select.POLLPRI | select.POLLERR)]) + {12: select.POLLIN | select.POLLPRI | select.POLLERR, + 14: select.POLLIN | select.POLLPRI | select.POLLERR}) def testPrepareSocketAcctFds(self): self.server._poll = MockPoll() @@ -132,8 +142,8 @@ def testPrepareSocketAcctFds(self): self.assertEqual(list(self.server._fdmap.keys()), [12, 14]) self.assertEqual(self.server._poll.registry, - [(12, select.POLLIN | select.POLLPRI | select.POLLERR), - (14, select.POLLIN | select.POLLPRI | select.POLLERR)]) + {12: select.POLLIN | select.POLLPRI | select.POLLERR, + 14: select.POLLIN | select.POLLPRI | select.POLLERR}) class AuthPacketHandlingTests(unittest.TestCase): diff --git a/pyrad/tests/testTools.py b/pyrad/tests/testTools.py index 581ef53..8dfd55f 100644 --- a/pyrad/tests/testTools.py +++ b/pyrad/tests/testTools.py @@ -28,6 +28,11 @@ def testInvalidAddressEncodingRaisesTypeError(self): def testIntegerEncoding(self): self.assertEqual(tools.EncodeInteger(0x01020304), six.b('\x01\x02\x03\x04')) + def testInteger64Encoding(self): + self.assertEqual( + tools.EncodeInteger64(0xFFFFFFFFFFFFFFFF), six.b('\xff' * 8) + ) + def testUnsignedIntegerEncoding(self): self.assertEqual(tools.EncodeInteger(0xFFFFFFFF), six.b('\xff\xff\xff\xff')) @@ -60,6 +65,11 @@ def testIntegerDecoding(self): tools.DecodeInteger(six.b('\x01\x02\x03\x04')), 0x01020304) + def testInteger64Decoding(self): + self.assertEqual( + tools.DecodeInteger64(six.b('\xff' * 8)), 0xFFFFFFFFFFFFFFFF + ) + def testDateDecoding(self): self.assertEqual( tools.DecodeDate(six.b('\x01\x02\x03\x04')), @@ -87,6 +97,9 @@ def testEncodeFunction(self): self.assertEqual( tools.EncodeAttr('date', 0x01020304), six.b('\x01\x02\x03\x04')) + self.assertEqual( + tools.EncodeAttr('integer64', 0xFFFFFFFFFFFFFFFF), + six.b('\xff'*8)) def testDecodeFunction(self): self.assertEqual( @@ -101,6 +114,9 @@ def testDecodeFunction(self): self.assertEqual( tools.DecodeAttr('integer', six.b('\x01\x02\x03\x04')), 0x01020304) + self.assertEqual( + tools.DecodeAttr('integer64', six.b('\xff'*8)), + 0xFFFFFFFFFFFFFFFF) self.assertEqual( tools.DecodeAttr('date', six.b('\x01\x02\x03\x04')), 0x01020304) diff --git a/pyrad/tools.py b/pyrad/tools.py index e3ec5a8..9b330b1 100644 --- a/pyrad/tools.py +++ b/pyrad/tools.py @@ -125,6 +125,12 @@ def EncodeInteger(num, format='!I'): raise TypeError('Can not encode non-integer as integer') return struct.pack(format, num) +def EncodeInteger64(num, format='!Q'): + try: + num = int(num) + except: + raise TypeError('Can not encode non-integer as integer64') + return struct.pack(format, num) def EncodeDate(num): if not isinstance(num, int): @@ -149,13 +155,13 @@ def DecodeAddress(addr): def DecodeIPv6Prefix(addr): addr = addr + b'\x00' * (18-len(addr)) - _, length, prefix = ':'.join(map('{:x}'.format, struct.unpack('!BB'+'H'*8, addr))).split(":", 2) + _, length, prefix = ':'.join(map('{0:x}'.format, struct.unpack('!BB'+'H'*8, addr))).split(":", 2) return str(IPNetwork("%s/%s" % (prefix, int(length, 16)))) def DecodeIPv6Address(addr): addr = addr + b'\x00' * (16-len(addr)) - prefix = ':'.join(map('{:x}'.format, struct.unpack('!'+'H'*8, addr))) + prefix = ':'.join(map('{0:x}'.format, struct.unpack('!'+'H'*8, addr))) return str(IPAddress(prefix)) @@ -166,6 +172,8 @@ def DecodeAscendBinary(str): def DecodeInteger(num, format='!I'): return (struct.unpack(format, num))[0] +def DecodeInteger64(num, format='!Q'): + return (struct.unpack(format, num))[0] def DecodeDate(num): return (struct.unpack('!I', num))[0] @@ -194,6 +202,8 @@ def EncodeAttr(datatype, value): return EncodeInteger(value, '!B') elif datatype == 'date': return EncodeDate(value) + elif datatype == 'integer64': + return EncodeInteger64(value) else: raise ValueError('Unknown attribute type %s' % datatype) @@ -221,5 +231,7 @@ def DecodeAttr(datatype, value): return DecodeInteger(value, '!B') elif datatype == 'date': return DecodeDate(value) + elif datatype == 'integer64': + return DecodeInteger64(value) else: raise ValueError('Unknown attribute type %s' % datatype) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..b5007e2 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,2 @@ +six +netaddr diff --git a/setup.py b/setup.py index d40412d..5a9ea4e 100755 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ from setuptools import setup, find_packages -version = '2.0' +version = '2.1' setup(name='pyrad', @@ -18,9 +18,9 @@ 'Development Status :: 6 - Mature', 'Intended Audience :: Developers', 'License :: OSI Approved :: BSD License', - 'Programming Language :: Python :: 2.6', 'Programming Language :: Python :: 2.7', 'Programming Language :: Python :: 3.2', + 'Programming Language :: Python :: 3.6', 'Topic :: Software Development :: Libraries :: Python Modules', 'Topic :: System :: Systems Administration :: Authentication/Directory', ],