Scala函数式编程
函数式数据结构
函数式编程的特点是数据结构不可变,函数的操作每次都生成新的值作为返回值, 而不会去修改传入的实参的值。
单向链表
列表类型抽象为特质,内容类型为泛型A
:
sealed trait List[+A]
空列表,作为列表类子类的单例实现:
case object Nil extends List[Nothing]
非空列表类型,构造器由表头第一个元素和其他元素列表两部分组成:
case class Cons[+A](head: A, tail: List[A]) extends List[A]
列表的常用方法:
/* 返回新列表,删除列表的第一个元素 */ def tail[A](l: List[A]): List[A] = l match { case Nil => sys.error("tail of empty list") case Cons(_, t) => t } /* 返回新列表,替换列表的第一个元素 */ def setHead[A](l: List[A], h: A): List[A] = l match { case Nil => sys.error("setHead on empty list") case Cons(_, t) => Cons(h, t) } /* 返回新列表,删除列表前n个元素 */ def drop[A](l: List[A], n: Int): List[A] = if (n <= 0) l else l match { case Nil => Nil case Cons(_, t) => drop(t, n - 1) } /* 返回新列表,删除列表前缀所有符合条件的元素 */ def dropWhile1[A](l: List[A], f: A => Boolean): List[A] = l match { case Cons(h, t) if f(h) => dropWhile1(t, f) case _ => l } /* 返回新列表,构建列表,以可变长的多个参数拼接成列表 */ def apply[A](as: A*): List[A] = if (as.isEmpty) Nil else { Cons(as.head, apply(as.tail: _*)) } /* 返回新列表,拼接列表,把a1的元素都加到a2里 */ def append[A](a1: List[A], a2: List[A]): List[A] = a1 match { case Nil => a2 case Cons(h, t) => Cons(h, append(t, a2)) }
通过柯里化,在使用用这个函数时,第二个参数的类型可以直接类型推导出来,不用注明类型:
def dropWhile[A](l: List[A])(f: A => Boolean): List[A] = l match { case Cons(h, t) if f(h) => dropWhile(t)(f) case _ => l } val ld = List(1, 2, 3, 4) val ls = dropWhile(ld)(x => x < 4)
折叠操作
def foldRight2[A, B](as: List[A], z: B)(f: (A, B) => B): B = as match { case Nil => z case Cons(x, xs) => f(x, foldRight2(xs, z)(f)) // 无法尾递归优化 }
foldRight2
函数替换为它的定义,来演示折叠过程:
foldRight2(Cons(1, Cons(2, Cons(3, Nil))), Nil: List[Int])(Cons(_, _)) Cons(1, foldRight2(Cons(2, Cons(3, Nil)), Nil: List[Int])(Cons(_, _))) Cons(1, Cons(2, foldRight2(Cons(3, Nil), Nil: List[Int])(Cons(_, _)))) Cons(1, Cons(2, Cons(3, foldRight2(Nil, Nil: List[Int])(Cons(_, _))))) Cons(1, Cons(2, Cons(3, Nil)))
flodRight不是尾递归,改进为可以尾递归优化的foldLeft:
@annotation.tailrec def foldLeft[A, B](l: List[A], z: B)(f: (B, A) => B): B = l match { case Nil => z case Cons(h, t) => foldLeft(t, f(z, h))(f) }
通过foldLeft
来实现foldRight
:
def foldRightViaFoldLeft_1[A, B](l: List[A], z: B)(f: (A, B) => B): B = foldLeft(reverse(l), z)((b, a) => f(a, b))
另一种实现的代码:
def foldRightViaFoldLeft_2[A, B](l: List[A], z: B)(f: (A, B) => B): B = { def identity(p1: B): B = { p1 } // 把值包装为函数`B=>B` // 匹配到的类型应该为: // foldLeft( // l: List[A], fun: B=>B // )( // (g: B=>B, a:A) => (B=>B) // ) def func: B => B = foldLeft( l, identity(_) // 把函数`identity: B => B`作为不动点 )( (g, a) => { // 实参`identity: B=>B`代入形参`g: B=>B` b => { g(f(a, b)) } } ) func(z) // `z`作为一开始代入`identity` }
简写为:
def foldRightViaFoldLeft[A, B](l: List[A], z: B)(f: (A, B) => B): B = foldLeft(l, (b: B) => b)((g, a) => b => g(f(a, b)))(z)
通过foldRight
来实现foldLeft
:
def foldLeft2[A, B](l: List[A], z: B)(f: (B, A) => B): B = foldRight(l, (b: B) => b)((a, g) => b => g(f(b, a)))(z)
通过foldRight
来实现append
:
def append2[A](l: List[A], r: List[A]): List[A] = foldRight(l, r)(Cons(_, _))
常用的折叠操作有三种,主要的区别是fold函数操作遍历问题集合的顺序:
- foldLeft是从左开始计算。
- foldRight是从右开始算。
- fold遍历没有特殊的次序,所以对fold的初始化参数和返回值都有限制。
以Scala自带的源代码来说明:
def fold[A1 >: A](z: A1)(op: (A1, A1) => A1): A1 = foldLeft(z)(op) def foldLeft[B](z: B)(op: (B, A) => B): B = { var result = z this.seq foreach (x => result = op(result, x)) result } def foldRight[B](z: B)(op: (A, B) => B): B = reversed.foldLeft(z)((x, y) => op(y, x))
由于fold函数遍历没有特殊的次序,所以对fold的初始化参数和返回值都有限制。 在这三个函数中,初始化参数和返回值的参数类型必须相同。
-
第一个限制是初始值的类型必须是list中元素类型的超类。在我们的例子中,我们的对
List[Int]
进行fold计算,而初始值是Int类型的,它是List[Int]
的超类。 - 第二个限制是初始值必须是中立的(neutral)。也就是它不能改变结果。比如对「数字」 这个范围与「加法」这个操作组成的「范畴」来说,中立的值是0(在范畴论中被称为 幺元),因为任何数加上0都等于它本身;而对于数字与乘法组成的范畴来说, 中立值则是1。
val lst = ch03.List(1, 2, 3) ch03.List.foldLeft(lst, ch03.Nil: ch03.List[Int])((acc, o) => ch03.Cons(o, acc)) //> res0: fpinscala.ch03.datastructure.List[Int] = Cons(3,Cons(2,Cons(1,Nil))) ch03.List.foldRight(lst, ch03.Nil: ch03.List[Int])((o, acc) => ch03.Cons(o, acc)) //> res1: fpinscala.ch03.datastructure.List[Int] = Cons(1,Cons(2,Cons(3,Nil)))
/* 加法 */ def sum(l: List[Int]) = foldLeft(l, 0)(_ + _) /* 乘法 */ def product(l: List[Double]) = foldLeft(l, 1.0)(_ * _) /* 计算长度 */ def length[A](l: List[A]): Int = foldLeft(l, 0)((acc, h) => acc + 1) /* 反转列表 */ def reverse[A](l: List[A]): List[A] = foldLeft(l, List[A]())((acc, h) => Cons(h, acc)) // 拼接多个列表为一个列表 def concat[A](l: List[List[A]]): List[A] = foldRight(l, Nil: List[A])(append)
映射操作
// 列表中的每个元素值加1 def add1(l: List[Int]): List[Int] = foldRight(l, Nil: List[Int])((h, t) => Cons(h + 1, t)) // 列表中double转为字符串 def doubleToString(l: List[Double]): List[String] = foldRight(l, Nil: List[String])((h, t) => Cons(h.toString, t)) def map[A, B](l: List[A])(f: A => B): List[B] = foldRight(l, Nil: List[B])((h, t) => Cons(f(h), t)) def map_2[A, B](l: List[A])(f: A => B): List[B] = { val buf = new collection.mutable.ListBuffer[B] def go(l: List[A]): Unit = l match { case Nil => () case Cons(h, t) => buf += f(h); go(t) } go(l) List(buf.toList: _*) // 从Scala内部的List转为我们自己实现的List }
过滤操作
def filter[A](l: List[A])(f: A => Boolean): List[A] = foldRight(l, Nil: List[A])((h, t) => if (f(h)) Cons(h, t) else t) def filter_1[A](l: List[A])(f: A => Boolean): List[A] = foldRightViaFoldLeft(l, Nil: List[A])((h, t) => if (f(h)) Cons(h, t) else t) def filter_2[A](l: List[A])(f: A => Boolean): List[A] = { val buf = new collection.mutable.ListBuffer[A] def go(l: List[A]): Unit = l match { case Nil => () case Cons(h, t) => if (f(h)) buf += h; go(t) } go(l) List(buf.toList: _*) // 从Scala内部的List转为我们自己实现的List }
flatMap
flatMap
与映射很像,区别是函数f
返回的是列表而不是单个结果
def flatMap[A, B](l: List[A])(f: A => List[B]): List[B] = concat(map(l)(f))
调用:
flatMap(List(1, 2, 3))(i => List(i, i)) // > List(1, 1, 2, 2, 3, 3)
用flatMap实现Filter
def filterViaFlatMap[A](l: List[A])(f: A => Boolean): List[A] = flatMap(l)(a => if (f(a)) List(a) else Nil)
zip操作
把两个列表按索引相同的值加加起来,形成一个新的列表:
def addPairwise(a: List[Int], b: List[Int]): List[Int] = (a, b) match { case (Nil, _) => Nil case (_, Nil) => Nil case (Cons(h1, t1), Cons(h2, t2)) => Cons(h1 + h2, addPairwise(t1, t2)) } addPairwise(List(1, 2, 3), List(4, 5, 6)) // > List(5, 7, 9)
抽像成更通用的方法zipWith
:
def zipWith[A, B, C](a: List[A], b: List[B])(f: (A, B) => C): List[C] = { (a, b) match { case (Nil, _) => Nil case (_, Nil) => Nil case (Cons(h1, t1), Cons(h2, t2)) => { Cons(f(h1, h2), zipWith(t1, t2)(f)) } } }
二叉树
用二元组分别指向左右子树:
sealed trait Tree[+A] case class Leaf[A](value: A) extends Tree[A] case class Branch[A](left: Tree[A], right: Tree[A]) extends Tree[A] // 递归遍历统计节点数 def size[A](t: Tree[A]): Int = t match { case Leaf(_) => 1 case Branch(l, r) => 1 + size(l) + size(r) } // 找到最大值 def maximum(t: Tree[Int]): Int = t match { case Leaf(n) => n case Branch(l, r) => maximum(l) max maximum(r) } // 层数 def depth[A](t: Tree[A]): Int = t match { case Leaf(_) => 0 case Branch(l, r) => 1 + (depth(l) max depth(r)) } // 映射 def map[A, B](t: Tree[A])(f: A => B): Tree[B] = t match { case Leaf(a) => Leaf(f(a)) case Branch(l, r) => Branch(map(l)(f), map(r)(f)) }
树的折叠
和列表一样,fold
函数通过递归处理Tree
类型构造器的参数来折叠树形结构:
def fold[A, B](t: Tree[A])(f: A => B)(g: (B, B) => B): B = t match { case Leaf(a) => f(a) case Branch(l, r) => g(fold(l)(f)(g), fold(r)(f)(g)) } def sizeViaFold[A](t: Tree[A]): Int = fold(t)(a => 1)(1 + _ + _) def maximumViaFold(t: Tree[Int]): Int = fold(t)(a => a)(_ max _) def depthViaFold[A](t: Tree[A]): Int = fold(t)(a => 0)((d1, d2) => 1 + (d1 max d2)) def mapViaFold[A, B](t: Tree[A])(f: A => B): Tree[B] = fold(t)(a => Leaf(f(a)): Tree[B])(Branch(_, _))
对于像是Leaf(f(a))
这样的表达式要注明类型,不然Scala的类型推导会出错:
type mismatch; found : fpinscala.datastructures.Branch[B] required: fpinscala.datastructures.Leaf[B] fold(t)(a => Leaf(f(a)))(Branch(_,_)) ^
这是Scala使用一个类的子类应用到代数数据类型
(subtyping to encode algebraic data types)时引发的错误。在不注明的情况下fold的
返回值被推导为Leaf[B]
。在这个基础上假定fold
的第二个函数的返回值类型也是
Leaf[B]
(但实际上应该是Branch[B]
)。
从期望上讲,如果Scala的类型推导出Tree[B]
是最好的情况,因为这样可以适用到Tree
的所有子类。当在Scala中使用代数数据类型时,常常会定义一些辅助函数直接调用恰当的
构造函数,同时让返回值的类型是更加通用的类型:
def leaf[A](a: A): Tree[A] = Leaf(a) def branch[A](l: Tree[A], r: Tree[A]): Tree[A] = Branch(l, r)