Skip to content

Commit

Permalink
[2.5] Support explicit message authentication (NVIDIA#3096)
Browse files Browse the repository at this point in the history
* add client auth to 2.5

* fix test case

* remove unused func

* support explicit client auth

* fix unused imports

* add docstring

* fix insecure processing

* fix simulator
  • Loading branch information
yanchengnv authored Dec 9, 2024
1 parent 245f0f5 commit 397983b
Show file tree
Hide file tree
Showing 21 changed files with 855 additions and 137 deletions.
6 changes: 6 additions & 0 deletions nvflare/apis/fl_constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,9 @@ class FLContextKey(object):
FILTER_DIRECTION = "__filter_dir__"
ROOT_URL = "__root_url__" # the URL for accessing the FL Server
NOT_READY_TO_END_RUN = "not_ready_to_end_run__" # component sets this to indicate it's not ready to end run yet
CLIENT_CONFIG = "__client_config__"
SERVER_CONFIG = "__server_config__"
SERVER_HOST_NAME = "__server_host_name__"


class ReservedTopic(object):
Expand Down Expand Up @@ -480,6 +483,9 @@ class ConfigVarName:
# client and server: max amount of time to wait for communication cell to be created
CELL_WAIT_TIMEOUT = "cell_wait_timeout"

# these vars are set in Server's startup config (fed_server.json)
MAX_REG_DURATION = "max_reg_duration"


class SystemVarName:
"""
Expand Down
19 changes: 19 additions & 0 deletions nvflare/fuel/f3/cellnet/core_cell.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,7 @@ def __init__(
self.communicator.register_message_receiver(app_id=self.APP_ID, receiver=self)
self.communicator.register_monitor(monitor=self)
self.req_reg = Registry()
self.in_filter_reg = Registry() # for any incoming messages
self.in_req_filter_reg = Registry() # for request received
self.out_reply_filter_reg = Registry() # for reply going out
self.out_req_filter_reg = Registry() # for request sent
Expand Down Expand Up @@ -991,6 +992,11 @@ def decrypt_payload(self, message: Message):
if len(message.payload) != payload_len:
raise RuntimeError(f"Payload size changed after decryption {len(message.payload)} <> {payload_len}")

def add_incoming_filter(self, channel: str, topic: str, cb, *args, **kwargs):
    """Register a filter callback for incoming messages on a channel/topic.

    Args:
        channel: channel the filter is registered under
        topic: topic the filter is registered under
        cb: the callable to register
        *args: positional args stored with the callback
        **kwargs: keyword args stored with the callback

    Raises:
        ValueError: if cb is not callable
    """
    if callable(cb):
        # bundle the callable with its extra args so they travel together in the registry
        entry = Callback(cb, args, kwargs)
        self.in_filter_reg.append(channel, topic, entry)
        return
    raise ValueError(f"specified incoming_filter {type(cb)} is not callable")

def add_incoming_request_filter(self, channel: str, topic: str, cb, *args, **kwargs):
if not callable(cb):
raise ValueError(f"specified incoming_request_filter {type(cb)} is not callable")
Expand Down Expand Up @@ -1856,6 +1862,19 @@ def _process_received_msg(self, endpoint: Endpoint, connection: Connection, mess
category=self._stats_category(message), counter_name=_CounterName.RECEIVED
)

# invoke incoming filters
channel = message.get_header(MessageHeaderKey.CHANNEL, "")
topic = message.get_header(MessageHeaderKey.TOPIC, "")
in_filters = self.in_filter_reg.find(channel, topic)
if in_filters:
self.logger.debug(f"{self.my_info.fqcn}: invoking incoming filters")
assert isinstance(in_filters, list)
for f in in_filters:
assert isinstance(f, Callback)
reply = self._try_cb(message, f.cb, *f.args, **f.kwargs)
if reply:
return reply

if msg_type == MessageType.REQ and self.message_interceptor is not None:
reply = self._try_cb(
message, self.message_interceptor, *self.message_interceptor_args, **self.message_interceptor_kwargs
Expand Down
10 changes: 9 additions & 1 deletion nvflare/fuel/f3/cellnet/defs.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ class CellChannel:


class CellChannelTopic:

Challenge = "challenge"
Register = "register"
Quit = "quit"
GET_TASK = "get_task"
Expand All @@ -171,3 +171,11 @@ class CellChannelTopic:
REPORT_JOB_FAILURE = "report_job_failure"

SIMULATOR_WORKER_INIT = "simulator_worker_init"


class IdentityChallengeKey:
    """String keys for identity-challenge data.

    NOTE(review): presumably these key the fields exchanged during the
    challenge/register handshake - confirm against the callers.
    """

    NONCE = "nonce"
    CERT = "cert"
    SIGNATURE = "signature"
    COMMON_NAME = "cn"
155 changes: 100 additions & 55 deletions nvflare/lighter/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,36 +19,94 @@
from base64 import b64decode, b64encode

import yaml
from cryptography import x509
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import padding

from nvflare.lighter.impl.cert import load_crt
from nvflare.lighter.tool_consts import NVFLARE_SIG_FILE, NVFLARE_SUBMITTER_CRT_FILE


def serialize_pri_key(pri_key):
    """Return the private key as unencrypted PEM bytes (TraditionalOpenSSL format).

    Args:
        pri_key: a private key object exposing ``private_bytes``

    Returns:
        whatever ``pri_key.private_bytes`` returns for these options
    """
    pem_options = {
        "encoding": serialization.Encoding.PEM,
        "format": serialization.PrivateFormat.TraditionalOpenSSL,
        # the key is written without a passphrase
        "encryption_algorithm": serialization.NoEncryption(),
    }
    return pri_key.private_bytes(**pem_options)


def serialize_cert(cert):
    """Return the certificate serialized with PEM encoding.

    Args:
        cert: a certificate object exposing ``public_bytes``
    """
    pem_encoding = serialization.Encoding.PEM
    return cert.public_bytes(pem_encoding)


def load_crt(path):
    """Read the file at ``path`` in binary mode and parse it via ``load_crt_bytes``.

    Args:
        path: location of the certificate file

    Returns:
        the object returned by ``load_crt_bytes``
    """
    with open(path, "rb") as crt_file:
        pem_data = crt_file.read()
    return load_crt_bytes(pem_data)


def load_crt_bytes(data: bytes):
    """Parse PEM-encoded certificate bytes into an x509 certificate object.

    Args:
        data: PEM bytes of the certificate
    """
    backend = default_backend()
    return x509.load_pem_x509_certificate(data, backend)


def generate_password(passlen=16):
    """Generate a random alphanumeric password.

    Characters are drawn independently from a cryptographically secure source,
    so any length is supported and characters may repeat.

    Args:
        passlen: number of characters in the password (default 16)

    Returns:
        a string of ``passlen`` characters from [a-z0-9A-Z]

    Raises:
        ValueError: if passlen is negative
    """
    # local import so this function does not depend on a file-level import being added
    import secrets

    if passlen < 0:
        raise ValueError(f"passlen must be non-negative but got {passlen}")

    # each character listed exactly once so the distribution is uniform
    alphabet = "abcdefghijklmnopqrstuvwxyz0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
    return "".join(secrets.choice(alphabet) for _ in range(passlen))


def sign_one(content, signing_pri_key):
def sign_content(content, signing_pri_key, return_str=True):
if isinstance(content, str):
content = content.encode("utf-8") # to bytes
signature = signing_pri_key.sign(
data=content,
padding=padding.PSS(
mgf=padding.MGF1(hashes.SHA256()),
salt_length=padding.PSS.MAX_LENGTH,
),
algorithm=hashes.SHA256(),
padding=_content_padding(),
algorithm=_content_hash_algo(),
)

# signature is bytes
if return_str:
return b64encode(signature).decode("utf-8")
else:
return signature


def _content_padding():
    """Build the PSS padding (MGF1/SHA256, max salt length) used for content signatures."""
    mask_gen = padding.MGF1(hashes.SHA256())
    return padding.PSS(mgf=mask_gen, salt_length=padding.PSS.MAX_LENGTH)


def _content_hash_algo():
    """Return a new SHA256 hash algorithm object for content signatures."""
    sha256 = hashes.SHA256()
    return sha256


def verify_content(content, signature, public_key):
    """Verify a signature over content with the given public key.

    Args:
        content: the signed content; a str is UTF-8 encoded to bytes first
        signature: the signature; a str is treated as base64 text and decoded to bytes
        public_key: key object exposing ``verify``

    Returns:
        None. Any exception from ``public_key.verify`` propagates to the caller
        (presumably cryptography's InvalidSignature on mismatch - confirm).
    """
    data = content.encode("utf-8") if isinstance(content, str) else content
    raw_signature = b64decode(signature.encode("utf-8")) if isinstance(signature, str) else signature
    public_key.verify(
        signature=raw_signature,
        data=data,
        padding=_content_padding(),
        algorithm=_content_hash_algo(),
    )


def verify_cert(cert_to_be_verified, root_ca_public_key):
    """Check the certificate's signature against the root CA public key.

    Args:
        cert_to_be_verified: certificate whose signature is checked
        root_ca_public_key: public key object exposing ``verify``

    Returns:
        None. Any exception from ``root_ca_public_key.verify`` propagates.
    """
    cert_signature = cert_to_be_verified.signature
    signed_bytes = cert_to_be_verified.tbs_certificate_bytes
    # PKCS1v15 padding is passed unconditionally
    # NOTE(review): looks like this assumes an RSA CA key - confirm
    sig_padding = padding.PKCS1v15()
    hash_algo = cert_to_be_verified.signature_hash_algorithm
    root_ca_public_key.verify(cert_signature, signed_bytes, sig_padding, hash_algo)
return b64encode(signature).decode("utf-8")


def load_private_key(data: str):
    """Load an unencrypted PEM private key from its text form.

    Args:
        data: PEM text of the key; encoded as ASCII before parsing

    Returns:
        the object returned by ``serialization.load_pem_private_key``
    """
    pem_bytes = data.encode("ascii")
    # password=None: the key is loaded without a passphrase
    return serialization.load_pem_private_key(pem_bytes, password=None, backend=default_backend())


def load_private_key_file(file_path):
with open(file_path, "rt") as f:
pri_key = serialization.load_pem_private_key(f.read().encode("ascii"), password=None, backend=default_backend())
return pri_key
return load_private_key(f.read())


def sign_folders(folder, signing_pri_key, crt_path, max_depth=9999):
Expand All @@ -59,27 +117,19 @@ def sign_folders(folder, signing_pri_key, crt_path, max_depth=9999):
for file in files:
if file == NVFLARE_SIG_FILE or file == NVFLARE_SUBMITTER_CRT_FILE:
continue
signature = signing_pri_key.sign(
data=open(os.path.join(root, file), "rb").read(),
padding=padding.PSS(
mgf=padding.MGF1(hashes.SHA256()),
salt_length=padding.PSS.MAX_LENGTH,
),
algorithm=hashes.SHA256(),
)
signatures[file] = b64encode(signature).decode("utf-8")
with open(os.path.join(root, file), "rb") as f:
signatures[file] = sign_content(
content=f.read(),
signing_pri_key=signing_pri_key,
)
for folder in folders:
signature = signing_pri_key.sign(
data=folder.encode("utf-8"),
padding=padding.PSS(
mgf=padding.MGF1(hashes.SHA256()),
salt_length=padding.PSS.MAX_LENGTH,
),
algorithm=hashes.SHA256(),
signatures[folder] = sign_content(
content=folder,
signing_pri_key=signing_pri_key,
)
signatures[folder] = b64encode(signature).decode("utf-8")

json.dump(signatures, open(os.path.join(root, NVFLARE_SIG_FILE), "wt"))
with open(os.path.join(root, NVFLARE_SIG_FILE), "wt") as f:
json.dump(signatures, f)
shutil.copyfile(crt_path, os.path.join(root, NVFLARE_SUBMITTER_CRT_FILE))
if depth >= max_depth:
break
Expand All @@ -91,35 +141,32 @@ def verify_folder_signature(src_folder, root_ca_path):
root_ca_public_key = root_ca_cert.public_key()
for root, folders, files in os.walk(src_folder):
try:
signatures = json.load(open(os.path.join(root, NVFLARE_SIG_FILE), "rt"))
with open(os.path.join(root, NVFLARE_SIG_FILE), "rt") as f:
signatures = json.load(f)
cert = load_crt(os.path.join(root, NVFLARE_SUBMITTER_CRT_FILE))
public_key = cert.public_key()
except:
continue # TODO: shall return False
root_ca_public_key.verify(
cert.signature, cert.tbs_certificate_bytes, padding.PKCS1v15(), cert.signature_hash_algorithm
)
for k in signatures:
signatures[k] = b64decode(signatures[k].encode("utf-8"))

verify_cert(cert_to_be_verified=cert, root_ca_public_key=root_ca_public_key)
for file in files:
if file == NVFLARE_SIG_FILE or file == NVFLARE_SUBMITTER_CRT_FILE:
continue
signature = signatures.get(file)
if signature:
public_key.verify(
signature=signature,
data=open(os.path.join(root, file), "rb").read(),
padding=padding.PSS(mgf=padding.MGF1(hashes.SHA256()), salt_length=padding.PSS.MAX_LENGTH),
algorithm=hashes.SHA256(),
)
with open(os.path.join(root, file), "rb") as f:
verify_content(
content=f.read(),
signature=signature,
public_key=public_key,
)
for folder in folders:
signature = signatures.get(folder)
if signature:
public_key.verify(
verify_content(
content=folder,
signature=signature,
data=folder.encode("utf-8"),
padding=padding.PSS(mgf=padding.MGF1(hashes.SHA256()), salt_length=padding.PSS.MAX_LENGTH),
algorithm=hashes.SHA256(),
public_key=public_key,
)
return True
except Exception as e:
Expand All @@ -131,21 +178,18 @@ def sign_all(content_folder, signing_pri_key):
for f in os.listdir(content_folder):
path = os.path.join(content_folder, f)
if os.path.isfile(path):
signature = signing_pri_key.sign(
data=open(path, "rb").read(),
padding=padding.PSS(
mgf=padding.MGF1(hashes.SHA256()),
salt_length=padding.PSS.MAX_LENGTH,
),
algorithm=hashes.SHA256(),
)
signatures[f] = b64encode(signature).decode("utf-8")
with open(path, "rb") as file:
signatures[f] = sign_content(
content=file.read(),
signing_pri_key=signing_pri_key,
)
return signatures


def load_yaml(file):
if isinstance(file, str):
return yaml.safe_load(open(file, "r"))
with open(file, "r") as f:
return yaml.safe_load(f)
elif isinstance(file, bytes):
return yaml.safe_load(file)
else:
Expand Down Expand Up @@ -181,7 +225,8 @@ def update_participant_server_name(project_config, old_server_name, new_server_n
for p in participants:
if p["type"] == "server" and p["name"] == old_server_name:
p["name"] = new_server_name
return
break
return project_config


def update_project_server_name(project_file: str, old_server_name, server_name):
Expand Down
24 changes: 21 additions & 3 deletions nvflare/private/defs.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
import uuid

# this import is to let existing scripts import from nvflare.private.defs
from nvflare.fuel.f3.cellnet.defs import CellChannel, CellChannelTopic, SSLConstants # noqa: F401
Expand All @@ -31,8 +33,8 @@ class TaskConstant(object):
class EngineConstant(object):

FEDERATE_CLIENT = "federate_client"
FL_TOKEN = "fl_token"
CLIENT_TOKEN_FILE = "client_token.txt"
AUTH_TOKEN = "auth_token"
AUTH_TOKEN_SIGNATURE = "auth_token_signature"
ENGINE_TASK_NAME = "engine_task_name"


Expand Down Expand Up @@ -138,7 +140,8 @@ class CellMessageHeaderKeys:
CLIENT_NAME = "client_name"
CLIENT_IP = "client_ip"
PROJECT_NAME = "project_name"
TOKEN = "token"
TOKEN = "__token__"
TOKEN_SIGNATURE = "__token_signature__"
SSID = "ssid"
UNAUTHENTICATED = "unauthenticated"
JOB_ID = "job_id"
Expand All @@ -147,13 +150,28 @@ class CellMessageHeaderKeys:
ABORT_JOBS = "abort_jobs"


AUTH_CLIENT_NAME_FOR_SJ = "server_job"


class JobFailureMsgKey:
    """String keys for job-failure data.

    NOTE(review): presumably these key the payload of a job failure report - confirm.
    """

    JOB_ID = "job_id"
    CODE = "code"
    REASON = "reason"


class InternalFLContextKey:
    """Context property key names.

    NOTE(review): presumably used as FLContext prop keys for internal use - confirm.
    """

    CLIENT_REG_SESSION = "client_reg_session"


class ClientRegSession:
    """State recorded when a client registration session is created.

    Attributes are set once in the constructor:
        client_name: the name passed in
        nonce: a fresh uuid4 rendered as a string
        reg_start_time: ``time.time()`` at construction
    """

    def __init__(self, client_name: str):
        session_nonce = uuid.uuid4()
        self.client_name = client_name
        self.nonce = str(session_nonce)
        self.reg_start_time = time.time()


def new_cell_message(headers: dict, payload=None):
msg_headers = {}
if headers:
Expand Down
19 changes: 19 additions & 0 deletions nvflare/private/fed/app/client/client_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,25 @@ def main(args):
time.sleep(1.0)

with client_engine.new_context() as fl_ctx:
fl_ctx.set_prop(
key=FLContextKey.CLIENT_CONFIG,
value=deployer.client_config,
private=True,
sticky=True,
)
fl_ctx.set_prop(
key=FLContextKey.SERVER_CONFIG,
value=deployer.server_config,
private=True,
sticky=True,
)
fl_ctx.set_prop(
key=FLContextKey.SECURE_MODE,
value=deployer.secure_train,
private=True,
sticky=True,
)

fl_ctx.set_prop(FLContextKey.WORKSPACE_OBJECT, workspace, private=True)
client_engine.fire_event(EventType.SYSTEM_BOOTSTRAP, fl_ctx)

Expand Down
Loading

0 comments on commit 397983b

Please sign in to comment.