diff --git a/src/limits/MemoryLimitListener.cc b/src/limits/MemoryLimitListener.cc index feb5bcc..06f3b0a 100644 --- a/src/limits/MemoryLimitListener.cc +++ b/src/limits/MemoryLimitListener.cc @@ -8,9 +8,16 @@ #include "seccomp/action/ActionTrace.h" #include "seccomp/filter/LibSeccompFilter.h" +#include #include #include +#ifndef MREMAP_DONTUNMAP +// Due to this flag being introduced in Linux 5.7 there are many system which do +// not define it +#define MREMAP_DONTUNMAP 4 +#endif + #include #include #include @@ -40,33 +47,57 @@ MemoryLimitListener::MemoryLimitListener(uint64_t memoryLimitKb) if (!vmPeakValid_) { return tracer::TraceAction::CONTINUE; } - - uint64_t memoryUsage = getMemoryUsageKb() + - tracee.getSyscallArgument(1) / 1024; - memoryPeakKb_ = std::max(memoryPeakKb_, memoryUsage); - outputBuilder_->setMemoryPeak(memoryPeakKb_); - logger::debug( - "Memory usage after mmap ", - VAR(memoryUsage), - ", ", - VAR(memoryPeakKb_)); - - if (memoryUsage > memoryLimitKb_) { - outputBuilder_->setKillReason( - printer::OutputBuilder::KillReason::MLE, - "memory limit exceeded"); - logger::debug( - "Limit ", - VAR(memoryLimitKb_), - " exceeded, killing tracee"); - return tracer::TraceAction::KILL; - } - return tracer::TraceAction::CONTINUE; + return handleMemoryAllocation( + tracee.getSyscallArgument(1) / 1024); }), Arg(0) == 0 && Arg(1) > MEMORY_LIMIT_MARGIN / 2)); } + syscallRules_.emplace_back(seccomp::SeccompRule( + "mremap", + seccomp::action::ActionTrace([this](tracer::Tracee& tracee) { + TRACE(); + uint64_t old_size = tracee.getSyscallArgument(1); + uint64_t new_size = tracee.getSyscallArgument(2); + uint64_t flags = tracee.getSyscallArgument(3); + if (!vmPeakValid_) { + return tracer::TraceAction::CONTINUE; + } + bool doesntUnmap = + (flags & MREMAP_DONTUNMAP) == MREMAP_DONTUNMAP; + // Allow user to shrink its memory + if (!doesntUnmap && old_size >= new_size) { + return tracer::TraceAction::CONTINUE; + } + uint64_t newMemoryAllocated = new_size; + // Do not count already allocated memory + if (!doesntUnmap) + newMemoryAllocated -= old_size; + newMemoryAllocated /= 1024; + return handleMemoryAllocation(newMemoryAllocated); + }), + Arg(2) > MEMORY_LIMIT_MARGIN / 2)); } +tracer::TraceAction MemoryLimitListener::handleMemoryAllocation( + uint64_t allocatedMemoryKb) { + uint64_t memoryUsage = getMemoryUsageKb() + allocatedMemoryKb; + memoryPeakKb_ = std::max(memoryPeakKb_, memoryUsage); + outputBuilder_->setMemoryPeak(memoryPeakKb_); + logger::debug( + "Memory usage after allocation ", + VAR(memoryUsage), + ", ", + VAR(memoryPeakKb_)); + if (memoryLimitKb_ > 0 && memoryUsage > memoryLimitKb_) { + outputBuilder_->setKillReason( + printer::OutputBuilder::KillReason::MLE, + "memory limit exceeded"); + logger::debug( + "Limit ", VAR(memoryLimitKb_), " exceeded, killing tracee"); + return tracer::TraceAction::KILL; + } + return tracer::TraceAction::CONTINUE; +} void MemoryLimitListener::onPostForkChild() { TRACE(); diff --git a/src/limits/MemoryLimitListener.h b/src/limits/MemoryLimitListener.h index 7430fa1..fb7bf50 100644 --- a/src/limits/MemoryLimitListener.h +++ b/src/limits/MemoryLimitListener.h @@ -40,6 +40,7 @@ class MemoryLimitListener pid_t childPid_; std::vector syscallRules_; + tracer::TraceAction handleMemoryAllocation(uint64_t allocatedMemoryKb); }; } // namespace limits