| Paste number 93940: | Compile-time type-checking multimethods |
| Pasted by: | stuartsierra |
| When: | 6 months, 2 days ago |
| Share: | Tweet this! | http://paste.lisp.org/+20HG |
| Channel: | #clojure |
| Paste contents: |
;;; COMPILE-TIME TYPE-CHECKING FOR MULTIMETHODS
;; These macros create a multimethod with compile-time type checking.
;; If the multimethod is called with literal arguments, the correct
;; method expansion will be selected at compile time and inlined.
;;
;; If the argument types cannot be determined at compile time, it will
;; fall back to a normal multimethod with run-time dispatch.
;;
;; This could be integrated with the Clojure compiler's type-hinting
;; for fast multiple-argument dispatch.
(defn- impl-map-name
  "Returns the unqualified symbol naming the var that holds the map of
  compile-time method bodies for the typed multimethod SYM: SYM's name
  followed by \"**\". Any namespace on SYM is dropped."
  [sym]
  (-> sym name (str "**") symbol))
(defn- multimethod-name
  "Returns the unqualified symbol naming the ordinary (run-time
  dispatched) multimethod behind the typed multimethod SYM: SYM's name
  followed by a single \"*\". Any namespace on SYM is dropped."
  [sym]
  (-> sym name (str "*") symbol))
(defn- expand-multi-typed
  "Builds the code for a call to the typed multimethod SYM.

  SYM  - the typed multimethod's symbol.
  ARGV - the argument names declared with defmulti-typed.
  ARGS - the argument forms at the call site.

  The classes of the ARGS forms themselves (e.g. the class of a literal)
  are looked up in the SYM** map. On a hit, the stored body is returned
  wrapped in a let that binds ARGV to ARGS. On a miss (or a nil body),
  a plain call to the SYM* multimethod is returned, which dispatches at
  run time.

  Both SYM** and SYM* are resolved against the current *ns*."
  [sym argv args]
  (let [arg-classes  (vec (map type args))
        bodies       (var-get (resolve (impl-map-name sym)))
        inlined-body (get bodies arg-classes)]
    (if-not inlined-body
      (list* (multimethod-name sym) args)
      `(let ~(vec (interleave argv args))
         ~@inlined-body))))
(defmacro defmulti-typed
  "Declares a typed multimethod SYM taking the arguments named in ARGV.

  Expands into three definitions:
    - SYM*  : an ordinary multimethod dispatching on the vector of the
              arguments' types;
    - SYM** : an initially empty map from class vectors to method bodies,
              filled in by defmethod-typed;
    - SYM   : a definline whose expansion is produced by
              expand-multi-typed."
  [sym argv]
  (let [multi-name    (multimethod-name sym)
        bodies-name   (impl-map-name sym)
        dispatch-vec  (vec (map #(list `type %) argv))]
    `(do (defmulti ~multi-name (fn ~argv ~dispatch-vec))
         (def ~bodies-name {})
         (definline ~sym ~argv
           (expand-multi-typed '~sym '~argv ~argv)))))
(defmacro defmethod-typed
  "Adds a method to the typed multimethod SYM for the dispatch vector
  TYPES (a vector of symbols naming classes, one per argument).

  Registers BODY in two places:
    - as a method of the run-time multimethod SYM*, used when a call
      site's argument types cannot be matched at compile time;
    - unevaluated, in the SYM** map under the vector of classes that
      TYPES resolves to (resolved here, at macroexpansion time), so that
      expand-multi-typed can inline it.

  NOTE: when inlined, BODY is wrapped in a let that binds the argument
  names given to defmulti-typed, not ARGV. ARGV should therefore use the
  same names as the defmulti-typed declaration."
  [sym types argv & body]
  ;; The original emitted (println \"Multimethod called\") here. That
  ;; debug trace wrote to *out* on every run-time-dispatched call, so it
  ;; has been removed.
  `(do (defmethod ~(multimethod-name sym) ~types ~argv
         ~@body)
       (alter-var-root (var ~(impl-map-name sym))
                       assoc ~(vec (map resolve types))
                       '~body)))
;;; NORMAL MULTIMETHOD
;; Ordinary multimethod that dispatches on the classes of both arguments.
(defmulti add (fn [x y] [(type x) (type y)]))

;; Before Clojure 1.3, integer literals read as java.lang.Integer.
(defmethod add [Integer Integer] [x y]
  (+ x y))
(defmethod add [Double Double] [x y]
  (+ x y))
(defmethod add [Double Integer] [x y]
  (+ x y))
(defmethod add [Integer Double] [x y]
  (+ x y))

;; Since Clojure 1.3, integer literals read as java.lang.Long. Without
;; these methods, (add 1 2), (add 1.0 2) and (add 2 1.0) throw
;; "No method in multimethod".
(defmethod add [Long Long] [x y]
  (+ x y))
(defmethod add [Double Long] [x y]
  (+ x y))
(defmethod add [Long Double] [x y]
  (+ x y))
;;; COMPILE-TIME TYPED MULTIMETHOD
;; Typed multimethod: calls with literal arguments are inlined at compile
;; time, and other calls go through the run-time multimethod add-typed*.
;; Method argument names match the [x y] declared here, because the
;; inlined body is bound with these names.
(defmulti-typed add-typed [x y])

;; Before Clojure 1.3, integer literals read as java.lang.Integer.
(defmethod-typed add-typed [Integer Integer] [x y]
  (+ x y))
(defmethod-typed add-typed [Double Double] [x y]
  (+ x y))
(defmethod-typed add-typed [Double Integer] [x y]
  (+ x y))
(defmethod-typed add-typed [Integer Double] [x y]
  (+ x y))

;; Since Clojure 1.3, integer literals read as java.lang.Long. Without
;; these methods, literal calls such as (add-typed 1 2) find no inline
;; expansion and then fail in run-time dispatch.
(defmethod-typed add-typed [Long Long] [x y]
  (+ x y))
(defmethod-typed add-typed [Double Long] [x y]
  (+ x y))
(defmethod-typed add-typed [Long Double] [x y]
  (+ x y))
;;; MICROBENCHMARK
(defn benchmark
  "Compares the plain multimethod ADD with the typed ADD-TYPED.

  For each, prints a heading and then five timings. Each timing covers
  100000 rounds of four additions. The arguments are written as literals
  so the ADD-TYPED calls can be expanded inline at compile time."
  []
  (println "Normal multimethod")
  (doseq [_ (range 5)]
    (time
     (dotimes [_ 100000]
       (add 1 2)
       (add 3.0 4.0)
       (add 1.0 2)
       (add 2 1.0))))
  (println "Compile-time typed multimethod")
  (doseq [_ (range 5)]
    (time
     (dotimes [_ 100000]
       (add-typed 1 2)
       (add-typed 3.0 4.0)
       (add-typed 1.0 2)
       (add-typed 2 1.0)))))
;; user> (benchmark)
;; Normal multimethod
;; "Elapsed time: 790.124804 msecs"
;; "Elapsed time: 477.639847 msecs"
;; "Elapsed time: 189.584692 msecs"
;; "Elapsed time: 186.345104 msecs"
;; "Elapsed time: 184.11397 msecs"
;; Compile-time typed multimethod
;; "Elapsed time: 27.185498 msecs"
;; "Elapsed time: 15.220522 msecs"
;; "Elapsed time: 12.525659 msecs"
;; "Elapsed time: 12.378429 msecs"
;; "Elapsed time: 13.43024 msecs"
This paste has no annotations.