From ee2bced55fef1182b93c038f70a9749bec452a1d Mon Sep 17 00:00:00 2001 From: Jordan Doyle Date: Mon, 31 Jul 2017 08:36:43 +0100 Subject: [PATCH] Implement ratelimiting on external calls per user (and on tell) --- dave/dave.py | 31 +++++++++++++++++++++++-------- dave/module.py | 16 ++++++++++++++++ dave/modules/pollen.py | 1 + dave/modules/reddit.py | 3 +++ dave/modules/speedtest.py | 1 + dave/modules/stock.py | 1 + dave/modules/tell.py | 1 + dave/modules/title.py | 1 + dave/modules/urbandictionary.py | 1 + dave/modules/weather.py | 1 + dave/modules/wolfram.py | 1 + dave/modules/youtube.py | 1 + dave/ratelimit.py | 30 ++++++++++++++++++++++++++++++ 13 files changed, 81 insertions(+), 8 deletions(-) create mode 100644 dave/ratelimit.py diff --git a/dave/dave.py b/dave/dave.py index c0e5fe2..cb9c920 100644 --- a/dave/dave.py +++ b/dave/dave.py @@ -14,6 +14,7 @@ import re import subprocess import dave.config as config import requests +from dave.ratelimit import ratelimit class Dave(irc.IRCClient): @@ -57,7 +58,8 @@ class Dave(irc.IRCClient): def privmsg(self, user, channel, msg): """This will get called when the bot receives a message.""" nick = user.split("!", 1)[0] - log.msg("<{}> {}".format(nick, msg)) + userhost = user.split("!", 1)[1] + log.msg("<{}> {}".format(user, msg)) path = modules.__path__ prefix = "{}.".format(modules.__name__) @@ -97,16 +99,29 @@ class Dave(irc.IRCClient): if hasattr(val, "always_run"): run.append((val, match.groups())) else: - method = (priority, val, match.groups()) + method = (priority, val, match.groups(), + rule["named"]) + + ignore_dont_always_run = False if method[1] is not None: # we matched a command - deferToThread(method[1], self, method[2], nick, channel) - - if not (hasattr(method[1], "dont_always_run") and method[1].dont_always_run): - # if dont_always_run is set, the command the user sent doesn't - # want "always run" modules to run. - for m in run: + if ratelimit(method[1], userhost): + # ratelimit returned true, we can run our function! + deferToThread(method[1], self, method[2], nick, channel) + elif method[3]: + # if this was a direct command to the bot, tell them they've been r/l'd + self.reply(channel, nick, "You have been ratelimited for this command.") + else: + # if it wasn't, let the always_run functions run. + ignore_dont_always_run = True + + if method[1] is None or ignore_dont_always_run or \ + not (hasattr(method[1], "dont_always_run") and method[1].dont_always_run): + # if dont_always_run is set, the command the user sent doesn't + # want "always run" modules to run. + for m in run: + if not hasattr(m[0], "ratelimit") or ratelimit(m[0], userhost): # modules that should always be run regardless of priority deferToThread(m[0], self, m[1], nick, channel) diff --git a/dave/module.py b/dave/module.py index 52e8bc2..78a321b 100644 --- a/dave/module.py +++ b/dave/module.py @@ -5,6 +5,22 @@ from enum import Enum import dave.config as config +def ratelimit(value, per): + """Decorate a function to be ratelimited by the command processor + + Args: + value: accepted amount of requests per "per" seconds + per: amount of seconds before the ratelimit should be cleared + """ + def add_attribute(function): + function.ratelimit = { + "value": value, + "per": per + } + return function + + return add_attribute + def match(value): """Decorate a function to be called whenever a message matches the given pattern. diff --git a/dave/modules/pollen.py b/dave/modules/pollen.py index 12aff4f..0e7d5a6 100644 --- a/dave/modules/pollen.py +++ b/dave/modules/pollen.py @@ -10,6 +10,7 @@ import dave.config @dave.module.help("Syntax: pollen [first part of postcode]. Get the forecast in the specified location. Only works for UK postcodes.") @dave.module.command(["pollen"], "(([gG][iI][rR] {0,}0[aA]{2})|((([a-pr-uwyzA-PR-UWYZ][a-hk-yA-HK-Y]?[0-9][0-9]?)|(([a-pr-uwyzA-PR-UWYZ][0-9][a-hjkstuwA-HJKSTUW])|([a-pr-uwyzA-PR-UWYZ][a-hk-yA-HK-Y][0-9][abehmnprv-yABEHMNPRV-Y])))))$") @dave.module.priority(dave.module.Priority.HIGHEST) +@dave.module.ratelimit(1, 1) def pollen(bot, args, sender, source): postcode = args[0].lower() diff --git a/dave/modules/reddit.py b/dave/modules/reddit.py index 487915d..629b897 100644 --- a/dave/modules/reddit.py +++ b/dave/modules/reddit.py @@ -9,6 +9,7 @@ from humanize import naturaltime, naturaldelta, intcomma @dave.module.match(r'.*(?:https?://(?:www\.)?reddit.com)?(/r/(.+)/comments/([^\s]+)).*') @dave.module.match(r'.*https?://(?:www\.)?redd.it/([^\s]+).*') +@dave.module.ratelimit(1, 1) @dave.module.dont_always_run_if_run() def post(bot, args, sender, source): """Ran whenever a reddit post is sent""" @@ -44,6 +45,7 @@ def post(bot, args, sender, source): )) @dave.module.match(r'.*(?:https?://(?:www\.)?reddit.com)?/?r/(([^\s/]+))/?(?: |$).*') +@dave.module.ratelimit(1, 1) @dave.module.dont_always_run_if_run() def subreddit(bot, args, sender, source): """Ran whenever a subreddit is mentioned""" @@ -83,6 +85,7 @@ def subreddit(bot, args, sender, source): @dave.module.match(r'.*(?:https?://(?:www\.)?reddit.com)?/?(?:u|user)/(([^\s]+)/?)(?: |$).*') +@dave.module.ratelimit(1, 1) @dave.module.dont_always_run_if_run() def user(bot, args, sender, source): if not dave.config.redis.exists("reddit:user:{}".format(args[0])): diff --git a/dave/modules/speedtest.py b/dave/modules/speedtest.py index 8c520ee..bbb95c6 100644 --- a/dave/modules/speedtest.py +++ b/dave/modules/speedtest.py @@ -5,6 +5,7 @@ import dave.module from twisted.words.protocols.irc import assembleFormattedText, attributes as A @dave.module.match(r'.*https?://(?:www\.|beta\.)?speedtest\.net/(?:my-)?result/([0-9]+)(?:.png)?.*') +@dave.module.ratelimit(2, 2) @dave.module.dont_always_run_if_run() def speedtest(bot, args, sender, source): res = get("http://www.speedtest.net/result/{}".format(args[0]), timeout=3) diff --git a/dave/modules/stock.py b/dave/modules/stock.py index 5bafb16..56c613e 100644 --- a/dave/modules/stock.py +++ b/dave/modules/stock.py @@ -9,6 +9,7 @@ from twisted.words.protocols.irc import assembleFormattedText, attributes as A @dave.module.help("Syntax: stock [symbol].") +@dave.module.ratelimit(1, 1) @dave.module.command(["stock"], "([a-zA-Z.]+)") def stock(bot, args, sender, source): try: diff --git a/dave/modules/tell.py b/dave/modules/tell.py index afbb22a..89082d9 100644 --- a/dave/modules/tell.py +++ b/dave/modules/tell.py @@ -7,6 +7,7 @@ import pickle @dave.module.help("Syntax: tell [user] [message]. Tell a user something when we next " "see them") @dave.module.command(["tell"], "([A-Za-z_\-\[\]\\^{}|`][A-Za-z0-9_\-\[\]\\^{}|`]*) (.*)") +@dave.module.ratelimit(1, 3) def tell(bot, args, sender, source): dave.config.redis.lpush("tell:{}".format(args[0].lower()), pickle.dumps({ "sender": sender, diff --git a/dave/modules/title.py b/dave/modules/title.py index 928f285..3045294 100644 --- a/dave/modules/title.py +++ b/dave/modules/title.py @@ -12,6 +12,7 @@ parse = re.compile(r"(?:(?:https?):\/\/)(?:\S+(?::\S*)?@)?(?:(?!(?:10|127)(?:\.\ @dave.module.match(r"(.*)") @dave.module.always_run() +@dave.module.ratelimit(2, 2) def link_parse(bot, args, sender, source): matches = parse.findall(args[0]) diff --git a/dave/modules/urbandictionary.py b/dave/modules/urbandictionary.py index b6f6ec7..385ce24 100644 --- a/dave/modules/urbandictionary.py +++ b/dave/modules/urbandictionary.py @@ -12,6 +12,7 @@ from twisted.words.protocols.irc import assembleFormattedText, attributes as A @dave.module.help("Get results for an urbandictionary query. Syntax: urban [result #] (query)") @dave.module.command(["urbandictionary", "ub", "urban"], "(\d+ )?([a-zA-Z0-9 ]+)$") @dave.module.priority(dave.module.Priority.HIGHEST) +@dave.module.ratelimit(1, 1) def urbandictionary(bot, args, sender, source): result = int(args[0].strip()) - 1 if args[0] else 0 query = args[1].strip().lower() diff --git a/dave/modules/weather.py b/dave/modules/weather.py index 884eae9..b3ecc57 100644 --- a/dave/modules/weather.py +++ b/dave/modules/weather.py @@ -14,6 +14,7 @@ import socket @dave.module.help("Syntax: weather [location]. Get the forecast in the specified location.") @dave.module.command(["weather"], "?( .*)?$") @dave.module.priority(dave.module.Priority.HIGHEST) +@dave.module.ratelimit(1, 5) def weather(bot, args, sender, source): location = args[0] diff --git a/dave/modules/wolfram.py b/dave/modules/wolfram.py index c9a0070..5e0c547 100644 --- a/dave/modules/wolfram.py +++ b/dave/modules/wolfram.py @@ -9,6 +9,7 @@ from twisted.words.protocols.irc import assembleFormattedText, attributes as A @dave.module.help("Query the Wolfram API and return the result back to the user.") @dave.module.command(["wolfram", "w", "wolframalpha", "wa"], "(.+)$") @dave.module.priority(dave.module.Priority.HIGHEST) +@dave.module.ratelimit(1, 3) def wolfram(bot, args, sender, source): query = args[0].strip() diff --git a/dave/modules/youtube.py b/dave/modules/youtube.py index 073eb08..ec39881 100644 --- a/dave/modules/youtube.py +++ b/dave/modules/youtube.py @@ -13,6 +13,7 @@ BASE_URL = "https://www.googleapis.com/youtube/v3/videos?part=contentDetails,sni @dave.module.match(r'.*https?://(?:www\.)?youtu(?:be\.com/watch\?v=|\.be/)([\w\-\_]*)(&(amp;)?[\w\=]*)?.*') @dave.module.dont_always_run_if_run() +@dave.module.ratelimit(1, 1) def youtubevideo(bot, args, sender, source): """Ran whenever a YouTube video is sent""" if not dave.config.redis.exists("youtube:{}".format(args[0])): diff --git a/dave/ratelimit.py b/dave/ratelimit.py new file mode 100644 index 0000000..d577753 --- /dev/null +++ b/dave/ratelimit.py @@ -0,0 +1,30 @@ +from dave.config import redis + +def ratelimit(fun, userhost): + """ + Ratelimit a function + + :param fun: Function to ratelimit + :param userhost: Host of the user + :return: True, if this function is allowed to be executed + """ + if not hasattr(fun, "ratelimit"): + return True + + # how many requests are allowed per "per" seconds + value = fun.ratelimit["value"] + # how long before the ratelimit is reset + per = fun.ratelimit["per"] + + key = "ratelimit:{}:{}:{}".format(userhost, fun.__module__, fun.__qualname__) + + if not redis.exists(key): + # ratelimit doesn't exist, make a new one + redis.setex(key, per, 0) + elif int(redis.get(key)) >= value: + # ratelimit has been exceed + return False + + redis.incr(key) + return True + -- libgit2 1.7.2