diff --git a/spec/lucky/mime_type_spec.cr b/spec/lucky/mime_type_spec.cr
index 17b98561f..779e1358d 100644
--- a/spec/lucky/mime_type_spec.cr
+++ b/spec/lucky/mime_type_spec.cr
@@ -42,6 +42,148 @@ describe Lucky::MimeType do
       format = determine_format(default_format: :csv)
       format.should eq(:csv)
     end
+
+    describe "when the 'Accept' header accepts all images" do
+      before_each do
+        Lucky::MimeType.register "image/png", :png
+        Lucky::MimeType.register "image/x-icon", :ico
+      end
+
+      after_each do
+        Lucky::MimeType.deregister "image/png"
+        Lucky::MimeType.deregister "image/x-icon"
+      end
+
+      it "returns the default accepted mime type that matches the prefix" do
+        any_image = "image/*;q=0.8"
+        format = determine_format(default_format: :ico, headers: {"accept": any_image}, accepted_formats: [:png, :ico])
+        format.should eq(:png)
+      end
+    end
+
+    describe "when the 'Accept' header accepts anything with a lower quality factor" do
+      # Regression test for https://github.com/luckyframework/lucky/issues/1766
+      it "returns an accepted format" do
+        accept = "*/*; q=0.5, application/xml"
+        format = determine_format(default_format: :html, headers: {"accept": accept}, accepted_formats: [:json])
+        format.should eq(:json)
+      end
+    end
+  end
+
+  describe Lucky::MimeType::MediaRange do
+    it "accepts valid values" do
+      [
+        {"*/*", Lucky::MimeType::MediaRange.new("*", "*", 1000)},
+        {"image/*", Lucky::MimeType::MediaRange.new("image", "*", 1000)},
+        {"text/plain", Lucky::MimeType::MediaRange.new("text", "plain", 1000)},
+      ].each do |test|
+        Lucky::MimeType::MediaRange.parse(test[0]).should eq(test[1])
+      end
+    end
+
+    it "rejects invalid values" do
+      [
+        {"*/image", "*/image is not a valid media range"},
+        {"asdf", "asdf is not a valid media range"},
+        {"text/plain; q=1.9", "qvalue 1.9 is not within 0 to 1.0"},
+        {"text/plain; q=1.2.3", "1.2.3 is not a valid qvalue"},
+      ].each do |range, message|
+        expect_raises(Lucky::MimeType::InvalidMediaRange, message) do
+          Lucky::MimeType::MediaRange.parse(range)
+        end
+      end
+    end
+
+    it "accepts parameters" do
+      expected = Lucky::MimeType::MediaRange.new("text", "plain", 1000)
+      [
+        "text/plain;format=flowed",
+        "text/plain\t; format=flowed",
+        "text/plain;format=fixed",
+        "text/plain; format=fixed",
+        "text/plain \t; \tformat=fixed",
+        "text/plain;format=fixed;charset=UTF-8",
+      ].each do |input|
+        Lucky::MimeType::MediaRange.parse(input).should eq(expected)
+      end
+    end
+
+    it "ignores case" do
+      expected = Lucky::MimeType::MediaRange.new("text", "html", 1000)
+      [
+        "text/html;charset=utf-8",
+        "Text/HTML;Charset=\"utf-8\"",
+        "text/html; charset=\"utf-8\"",
+        "text/html;charset=UTF-8",
+      ].each do |input|
+        Lucky::MimeType::MediaRange.parse(input).should eq(expected)
+      end
+    end
+
+    it "parses the qvalue" do
+      [
+        {"*/*; q=0", Lucky::MimeType::MediaRange.new("*", "*", 0)},
+        {"*/*; q=1", Lucky::MimeType::MediaRange.new("*", "*", 1000)},
+        {"*/*; q=0.1", Lucky::MimeType::MediaRange.new("*", "*", 100)},
+        {"image/*; q=0.12", Lucky::MimeType::MediaRange.new("image", "*", 120)},
+        {"text/plain; q=0.123", Lucky::MimeType::MediaRange.new("text", "plain", 123)},
+        {"text/plain;format=fixed;q=0.4", Lucky::MimeType::MediaRange.new("text", "plain", 400)},
+        # The qvalue must be the last parameter; here it is not, so it is ignored and the default applies
+        {"text/plain;q=0.4;format=fixed", Lucky::MimeType::MediaRange.new("text", "plain", 1000)},
+      ].each do |test|
+        Lucky::MimeType::MediaRange.parse(test[0]).should eq(test[1])
+      end
+    end
+  end
+
+  describe Lucky::MimeType::AcceptList do
+    it "is empty when the Accept value is nil" do
+      Lucky::MimeType::AcceptList.new(nil).list.should be_empty
+    end
+
+    it "accepts single values" do
+      expected = [Lucky::MimeType::MediaRange.new("text", "html", 1000)]
+      Lucky::MimeType::AcceptList.new("text/html").list.should eq(expected)
+    end
+
+    it "accepts multiple values" do
+      expected = [
+        Lucky::MimeType::MediaRange.new("audio", "basic", 1000),
+        Lucky::MimeType::MediaRange.new("audio", "*", 200),
+      ]
+      Lucky::MimeType::AcceptList.new("audio/*; q=0.2, audio/basic").list.should eq(expected)
+    end
+
+    it "sorts multiple values by qvalue" do
+      expected = [
+        Lucky::MimeType::MediaRange.new("text", "html", 1000),
+        Lucky::MimeType::MediaRange.new("text", "x-c", 1000),
+        Lucky::MimeType::MediaRange.new("text", "x-dvi", 800),
+        Lucky::MimeType::MediaRange.new("text", "plain", 500),
+      ]
+      Lucky::MimeType::AcceptList.new("text/plain; q=0.5, text/html, text/x-dvi; q=0.8, text/x-c").list.should eq(expected)
+    end
+
+    it "parses a default browser Accept value" do
+      expected = [
+        Lucky::MimeType::MediaRange.new("text", "html", 1000),
+        Lucky::MimeType::MediaRange.new("application", "xhtml+xml", 1000),
+        Lucky::MimeType::MediaRange.new("image", "avif", 1000),
+        Lucky::MimeType::MediaRange.new("image", "webp", 1000),
+        Lucky::MimeType::MediaRange.new("application", "xml", 900),
+        Lucky::MimeType::MediaRange.new("*", "*", 800),
+      ]
+      # Accept value sent by Firefox when requesting a web page
+      Lucky::MimeType::AcceptList.new("text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,*/*;q=0.8").list.should eq(expected)
+    end
+
+    it "skips invalid media ranges" do
+      expected = [
+        Lucky::MimeType::MediaRange.new("audio", "basic", 1000),
+      ]
+      Lucky::MimeType::AcceptList.new("*/invalid; q=0.2, audio/basic").list.should eq(expected)
+    end
   end
 end
 
diff --git a/spec/lucky/request_type_helper_spec.cr b/spec/lucky/request_type_helper_spec.cr
index 6896acd5a..8b33be871 100644
--- a/spec/lucky/request_type_helper_spec.cr
+++ b/spec/lucky/request_type_helper_spec.cr
@@ -13,7 +13,7 @@ end
 describe Lucky::RequestTypeHelpers do
   it "determines the format from 'Accept' header correctly" do
     Lucky::MimeType.accept_header_formats.each do |header, format|
-      override_accept_header header do |action|
+      override_accept_header header.to_s do |action|
         action.accepts?(format).should be_true
       end
     end
diff --git a/src/lucky/mime_type.cr b/src/lucky/mime_type.cr
index 1cc00b057..ee07af266 100644
--- a/src/lucky/mime_type.cr
+++ 
b/src/lucky/mime_type.cr
@@ -2,7 +2,22 @@ class Lucky::MimeType
   alias Format = Symbol
   alias AcceptHeaderSubstring = String
 
-  class_getter accept_header_formats = {} of AcceptHeaderSubstring => Format
+  # Registered media types mapped to the format symbol each was registered with.
+  class_getter accept_header_formats = {} of MediaType => Format
+
+  # A media type such as "text/html", split into its type and subtype.
+  struct MediaType
+    property type, subtype
+
+    def initialize(@type : String, @subtype : String)
+    end
+
+    # Writes the media type as "type/subtype". Implementing the IO overload
+    # means `to_s`, string interpolation and `IO#<<` all produce this form.
+    def to_s(io : IO) : Nil
+      io << type << '/' << subtype
+    end
+  end
 
   register "text/html", :html
   register "application/json", :json
@@ -24,7 +39,7 @@ class Lucky::MimeType
   register "application/x-www-form-urlencoded", :url_encoded_form
 
   def self.known_accept_headers : Array(String)
-    accept_header_formats.keys
+    accept_header_formats.keys.map(&.to_s)
   end
 
   def self.known_formats : Array(Symbol)
@@ -36,7 +51,23 @@ class Lucky::MimeType
   end
 
+  # Registers a media type string such as "text/html" under the given format.
+  # Raises if the string is not of the form "type/subtype".
   def self.register(accept_header_substring : AcceptHeaderSubstring, format : Format) : Nil
-    accept_header_formats[accept_header_substring] = format
+    # Check the number of parts explicitly so that a string without a "/"
+    # gets the descriptive error below.
+    parts = accept_header_substring.split("/", 2)
+    unless parts.size == 2
+      raise "#{accept_header_substring} is not a valid media type"
+    end
+
+    accept_header_formats[MediaType.new(parts[0], parts[1])] = format
+  end
+
+  # Removes a previously registered media type string such as "image/png".
+  def self.deregister(accept_header_substring : AcceptHeaderSubstring) : Nil
+    type, subtype = accept_header_substring.split("/", 2)
+    # The registry is keyed by MediaType, so the key to delete must be one too.
+    accept_header_formats.delete(MediaType.new(type, subtype))
   end
 
   # :nodoc:
@@ -44,6 +75,10 @@ class Lucky::MimeType
     DetermineClientsDesiredFormat.new(request, default_format, accepted_formats).call
   end
 
+  # Raised when a media range or its qvalue cannot be parsed.
+  class InvalidMediaRange < Exception
+  end
+
   private class DetermineClientsDesiredFormat
     private getter request, default_format, accepted_formats
 
@@ -51,37 +86,25 @@ class Lucky::MimeType
     end
 
     def call : Symbol?
-      accept = accept_header
-
-      if usable_accept_header? && accept
+      if accept = accept_header
         from_accept_header(accept)
-      elsif accepts_html? && default_accept_header_that_browsers_send?
-        :html
       else
         default_format
       end
     end
 
-    private def accepts_html? : Bool
-      @accepted_formats.includes? :html
-    end
-
     private def from_accept_header(accept : String) : Symbol?
       # If the request accepts anything with no particular preference, return
       # the default format
       if accept == "*/*"
         default_format
       else
-        Lucky::MimeType.accept_header_formats.find do |accept_header_substring, _format|
-          accept.includes?(accept_header_substring)
-        end.try(&.[1])
+        # Parse the value we were handed rather than re-reading the header.
+        accept_list = AcceptList.new(accept)
+        accept_list.find_match(Lucky::MimeType.accept_header_formats, accepted_formats, default_format)
       end
     end
 
-    private def usable_accept_header? : Bool
-      !!(accept_header && !default_accept_header_that_browsers_send?)
-    end
-
     private def accept_header : String?
       accept = request.headers["Accept"]?
 
@@ -89,15 +112,153 @@ class Lucky::MimeType
         accept
       end
     end
+  end
+
+  # An Accept header value parsed into a list of media ranges, ordered from
+  # highest to lowest quality value.
+  class AcceptList
+    getter list
+
+    ACCEPT_SEP = /[ \t]*,[ \t]*/
+
+    # Parses the value of an Accept header and returns an array of MediaRanges sorted by
+    # quality value. Media ranges that fail to parse are logged and skipped.
+    def self.parse(accept : String) : Array(MediaRange)
+      list = accept.split(ACCEPT_SEP).compact_map do |range|
+        begin
+          MediaRange.parse(range)
+        rescue ex : InvalidMediaRange
+          Log.debug { "invalid media range in Accept: #{accept} - #{ex}" }
+          nil
+        end
+      end
+      # A stable sort keeps ranges with equal qvalues in the order the client
+      # listed them.
+      list.sort_by! { |range| -range.qvalue.to_i32 }
+    end
+
+    # An absent or empty Accept value produces an empty list.
+    def initialize(accept : String?)
+      if accept && !accept.empty?
+        @list = AcceptList.parse(accept)
+      else
+        @list = [] of MediaRange
+      end
+    end
+
+    # Find a matching accepted format by accept list priority
+    def find_match(known_formats : Hash(MediaType, Format), accepted_formats : Array(Symbol), default_format : Symbol) : Symbol?
+      # If we find a match in the things we accept then pick one of those
+      formats_in_common = known_formats.select { |_media, format| accepted_formats.includes?(format) }
+      unless formats_in_common.empty?
+        self.list.each do |media_range|
+          if match = formats_in_common.find { |media, _format| media_range.matches?(media) }
+            return match[1]
+          end
+        end
+      end
+
+      # Otherwise if the client doesn't just accept anything then try to find something they
+      # do accept in the list of known formats
+      unless includes_catch_all?
+        self.list.each do |media_range|
+          if match = known_formats.find { |media, _format| media_range.matches?(media) }
+            return match[1]
+          end
+        end
+
+        # No known formats match the ones requested
+        return nil
+      end
+
+      # Finally the client accepts anything so use the default format
+      default_format
+    end
+
+    # True when the list contains a `*/*` media range.
+    def includes_catch_all?
+      @list.any? &.catch_all?
+    end
+  end
+
+  # A single media range from an Accept header (e.g. `text/*`) together with
+  # its quality value stored as an integer from 0 to 1000.
+  class MediaRange
+    TOKEN = /[!#$%&'*+.^_`|~0-9A-Za-z-]+/
+    MEDIA_TYPE = /^(#{TOKEN})\/(#{TOKEN})$/
+    PARAM_SEP = /[ \t]*;[ \t]*/
+    QVALUE_RE = /^[qQ]=([01][0-9.]*)$/
+
+    getter type, subtype, qvalue
+
+    # Raises InvalidMediaRange for a wildcard type with a concrete subtype
+    # (e.g. `*/image`) or a qvalue greater than 1000.
+    def initialize(type : String, @subtype : String, qvalue : UInt16)
+      if type == "*" && @subtype != "*"
+        raise InvalidMediaRange.new("#{type}/#{@subtype} is not a valid media range")
+      end
+      unless (0..1000).includes?(qvalue)
+        raise InvalidMediaRange.new("qvalue #{qvalue.to_f32 / 1000f32} is not within 0 to 1.0")
+      end
+
+      @type = type
+      @qvalue = qvalue
+    end
+
+    # Parse a single media range with optional parameters
+    # https://httpwg.org/specs/rfc9110.html#field.accept
+    def self.parse(input : String)
+      parameters = input.split(PARAM_SEP)
+      media = parameters.shift
+
+      # For now we're only interested in the weight, which must be the last parameter
+      qvalue = MediaRange.parse_qvalue(parameters.last?)
+
+      if media =~ MEDIA_TYPE
+        type = $1
+        subtype = $2
+        MediaRange.new(type.downcase, subtype.downcase, qvalue)
+      else
+        raise InvalidMediaRange.new("#{input} is not a valid media range")
+      end
+    end
+
+    # Converts a parameter matching QVALUE_RE into an integer from 0 to 1000. A nil or
+    # non-matching parameter yields 1000; raises InvalidMediaRange if conversion fails.
+    def self.parse_qvalue(parameter : String?) : UInt16
+      if parameter && parameter =~ QVALUE_RE
+        # qvalues start with 0 or 1 and can have up to three digits after the
+        # decimal point. To avoid needing to deal with floats, the value is
+        # multiplied by 1000 and then handled as an integer.
+        begin
+          ($1.to_f32 * 1000).round.to_u16
+        rescue ArgumentError | OverflowError
+          raise InvalidMediaRange.new("#{parameter} is not a valid qvalue")
+        end
+      else
+        1000u16
+      end
+    end
+
+    # Two ranges are equal when type, subtype and qvalue all match.
+    def ==(other)
+      @type == other.type &&
+        @subtype == other.subtype &&
+        @qvalue == other.qvalue
+    end
 
-    # This checks if the "Accept" header is from a browser. Browsers typically
-    # include "*/*" along with other characters in the request's "Accept" header.
-    # This method handles those intricacies and determines if the header is from
-    # a browser.
-    private def default_accept_header_that_browsers_send? : Bool
-      accept = accept_header
+    # True when this range covers the given media type, honouring `*` wildcards.
+    def matches?(media : MediaType) : Bool
+      @type == "*" || (@type == media.type && self.class.match_type?(@subtype, media.subtype))
+    end
+
+    # True for the `*/*` range.
+    def catch_all?
+      @type == "*" && @subtype == "*"
+    end
 
-      !!accept && !!(accept =~ /,\s*\*\/\*|\*\/\*\s*,/)
+    protected def self.match_type?(pattern, value)
+      pattern == "*" || pattern == value
     end
   end
 end