Jade Dungeon

Scala函数式编程

函数式数据结构

函数式编程的特点是数据结构不可变,函数的操作每次都生成新的值作为返回值, 而不会去修改传入的实参的值。

单向链表

列表类型抽象为特质,内容类型为泛型A

sealed trait List[+A]

空列表,作为列表类子类的单例实现:

case object Nil extends List[Nothing]

非空列表类型,构造器由表头第一个元素和其他元素列表两部分组成:

case class Cons[+A](head: A, tail: List[A]) extends List[A]

列表的常用方法:

/* 返回新列表,删除列表的第一个元素 */
def tail[A](l: List[A]): List[A] = l match {
	case Nil => sys.error("tail of empty list")
	case Cons(_, t) => t
}

/* 返回新列表,替换列表的第一个元素 */
def setHead[A](l: List[A], h: A): List[A] = l match {
	case Nil => sys.error("setHead on empty list")
	case Cons(_, t) => Cons(h, t)
}

/* 返回新列表,删除列表前n个元素 */
def drop[A](l: List[A], n: Int): List[A] = if (n <= 0) l else l match {
	case Nil => Nil
	case Cons(_, t) => drop(t, n - 1)
}

/* 返回新列表,删除列表前缀所有符合条件的元素 */
def dropWhile1[A](l: List[A], f: A => Boolean): List[A] = l match {
	case Cons(h, t) if f(h) => dropWhile1(t, f)
	case _ => l
}

/* 返回新列表,构建列表,以可变长的多个参数拼接成列表 */
def apply[A](as: A*): List[A] = if (as.isEmpty) Nil else {
	Cons(as.head, apply(as.tail: _*))
}

/* 返回新列表,拼接列表,把a1的元素都加到a2里 */
def append[A](a1: List[A], a2: List[A]): List[A] = a1 match {
	case Nil => a2
	case Cons(h, t) => Cons(h, append(t, a2))
}

通过柯里化,在使用用这个函数时,第二个参数的类型可以直接类型推导出来,不用注明类型:

def dropWhile[A](l: List[A])(f: A => Boolean): List[A] = l match {
	case Cons(h, t) if f(h) => dropWhile(t)(f)
	case _ => l
}

val ld = List(1, 2, 3, 4)
val ls = dropWhile(ld)(x => x < 4)

折叠操作

def foldRight2[A, B](as: List[A], z: B)(f: (A, B) => B): B = as match {
	case Nil => z
	case Cons(x, xs) => f(x, foldRight2(xs, z)(f)) // 无法尾递归优化
}

foldRight2函数替换为它的定义,来演示折叠过程:

foldRight2(Cons(1, Cons(2, Cons(3, Nil))), Nil: List[Int])(Cons(_, _))
Cons(1, foldRight2(Cons(2, Cons(3, Nil)), Nil: List[Int])(Cons(_, _)))
Cons(1, Cons(2, foldRight2(Cons(3, Nil), Nil: List[Int])(Cons(_, _))))
Cons(1, Cons(2, Cons(3, foldRight2(Nil, Nil: List[Int])(Cons(_, _)))))
Cons(1, Cons(2, Cons(3, Nil)))

flodRight不是尾递归,改进为可以尾递归优化的foldLeft:

@annotation.tailrec
def foldLeft[A, B](l: List[A], z: B)(f: (B, A) => B): B = l match {
	case Nil => z
	case Cons(h, t) => foldLeft(t, f(z, h))(f)
}

通过foldLeft来实现foldRight

def foldRightViaFoldLeft_1[A, B](l: List[A], z: B)(f: (A, B) => B): B =
	foldLeft(reverse(l), z)((b, a) => f(a, b))

另一种实现的代码:

def foldRightViaFoldLeft_2[A, B](l: List[A], z: B)(f: (A, B) => B): B = {

	def identity(p1: B): B = { p1 }   // 把值包装为函数`B=>B`

	// 匹配到的类型应该为:
	// foldLeft(
	//   l: List[A], fun: B=>B
	// )(
	//   (g: B=>B, a:A) => (B=>B)
	// )
	def func: B => B = foldLeft(
		l, identity(_)    // 把函数`identity: B => B`作为不动点
	)(
		(g, a) => {       // 实参`identity: B=>B`代入形参`g: B=>B`
			b => { g(f(a, b)) }
		}
	)

	func(z)             // `z`作为一开始代入`identity`
}

简写为:

def foldRightViaFoldLeft[A, B](l: List[A], z: B)(f: (A, B) => B): B =
	foldLeft(l, (b: B) => b)((g, a) => b => g(f(a, b)))(z)

通过foldRight来实现foldLeft

def foldLeft2[A, B](l: List[A], z: B)(f: (B, A) => B): B =
	foldRight(l, (b: B) => b)((a, g) => b => g(f(b, a)))(z)

通过foldRight来实现append

def append2[A](l: List[A], r: List[A]): List[A] =
	foldRight(l, r)(Cons(_, _))

常用的折叠操作有三种,主要的区别是fold函数操作遍历问题集合的顺序:

  • foldLeft是从左开始计算。
  • foldRight是从右开始算。
  • fold遍历没有特殊的次序,所以对fold的初始化参数和返回值都有限制。

以Scala自带的源代码来说明:

def fold[A1 >: A](z: A1)(op: (A1, A1) => A1): A1 = foldLeft(z)(op)
 
def foldLeft[B](z: B)(op: (B, A) => B): B = {
  var result = z
  this.seq foreach (x => result = op(result, x))
  result
}
 
def foldRight[B](z: B)(op: (A, B) => B): B =
  reversed.foldLeft(z)((x, y) => op(y, x))

由于fold函数遍历没有特殊的次序,所以对fold的初始化参数和返回值都有限制。 在这三个函数中,初始化参数和返回值的参数类型必须相同。

  • 第一个限制是初始值的类型必须是list中元素类型的超类。在我们的例子中,我们的对 List[Int]进行fold计算,而初始值是Int类型的,它是List[Int]的超类。
  • 第二个限制是初始值必须是中立的(neutral)。也就是它不能改变结果。比如对「数字」 这个范围与「加法」这个操作组成的「范畴」来说,中立的值是0(在范畴论中被称为 幺元),因为任何数加上0都等于它本身;而对于数字与乘法组成的范畴来说, 中立值则是1。
val lst = ch03.List(1, 2, 3)

ch03.List.foldLeft(lst, ch03.Nil: ch03.List[Int])((acc, o) => ch03.Cons(o, acc))
//> res0: fpinscala.ch03.datastructure.List[Int] = Cons(3,Cons(2,Cons(1,Nil)))

ch03.List.foldRight(lst, ch03.Nil: ch03.List[Int])((o, acc) => ch03.Cons(o, acc))
//> res1: fpinscala.ch03.datastructure.List[Int] = Cons(1,Cons(2,Cons(3,Nil)))
/* 加法 */
def sum(l: List[Int]) = foldLeft(l, 0)(_ + _)

/* 乘法 */
def product(l: List[Double]) = foldLeft(l, 1.0)(_ * _)

/* 计算长度 */
def length[A](l: List[A]): Int = foldLeft(l, 0)((acc, h) => acc + 1)

/* 反转列表 */
def reverse[A](l: List[A]): List[A] = 
                        foldLeft(l, List[A]())((acc, h) => Cons(h, acc))
	
// 拼接多个列表为一个列表
def concat[A](l: List[List[A]]): List[A] = foldRight(l, Nil: List[A])(append)

映射操作

// 列表中的每个元素值加1
def add1(l: List[Int]): List[Int] =
	foldRight(l, Nil: List[Int])((h, t) => Cons(h + 1, t))

// 列表中double转为字符串
def doubleToString(l: List[Double]): List[String] =
	foldRight(l, Nil: List[String])((h, t) => Cons(h.toString, t))

def map[A, B](l: List[A])(f: A => B): List[B] =
	foldRight(l, Nil: List[B])((h, t) => Cons(f(h), t))

def map_2[A, B](l: List[A])(f: A => B): List[B] = {
	val buf = new collection.mutable.ListBuffer[B]
	def go(l: List[A]): Unit = l match {
		case Nil => ()
		case Cons(h, t) => buf += f(h); go(t)
	}
	go(l)
	List(buf.toList: _*) // 从Scala内部的List转为我们自己实现的List
}

过滤操作

def filter[A](l: List[A])(f: A => Boolean): List[A] =
	foldRight(l, Nil: List[A])((h, t) => if (f(h)) Cons(h, t) else t)

def filter_1[A](l: List[A])(f: A => Boolean): List[A] =
	foldRightViaFoldLeft(l, Nil: List[A])((h, t) => if (f(h)) Cons(h, t) else t)

def filter_2[A](l: List[A])(f: A => Boolean): List[A] = {
	val buf = new collection.mutable.ListBuffer[A]
	def go(l: List[A]): Unit = l match {
		case Nil => ()
		case Cons(h, t) => if (f(h)) buf += h; go(t)
	}
	go(l)
	List(buf.toList: _*) // 从Scala内部的List转为我们自己实现的List
}

flatMap

flatMap与映射很像,区别是函数f返回的是列表而不是单个结果

def flatMap[A, B](l: List[A])(f: A => List[B]): List[B] =
	concat(map(l)(f))

调用:

flatMap(List(1, 2, 3))(i => List(i, i))
// > List(1, 1, 2, 2, 3, 3)

用flatMap实现Filter

def filterViaFlatMap[A](l: List[A])(f: A => Boolean): List[A] =
	flatMap(l)(a => if (f(a)) List(a) else Nil)

zip操作

把两个列表按索引相同的值加加起来,形成一个新的列表:

def addPairwise(a: List[Int], b: List[Int]): List[Int] = (a, b) match {
	case (Nil, _) => Nil
	case (_, Nil) => Nil
	case (Cons(h1, t1), Cons(h2, t2)) => Cons(h1 + h2, addPairwise(t1, t2))
}

addPairwise(List(1, 2, 3), List(4, 5, 6))
// > List(5, 7, 9)

抽像成更通用的方法zipWith

def zipWith[A, B, C](a: List[A], b: List[B])(f: (A, B) => C): List[C] = 
{
	(a, b) match {
		case (Nil, _) => Nil
		case (_, Nil) => Nil
		case (Cons(h1, t1), Cons(h2, t2)) => {
			Cons(f(h1, h2), zipWith(t1, t2)(f))
		}
	}
}

二叉树

用二元组分别指向左右子树:

sealed trait Tree[+A]

case class Leaf[A](value: A) extends Tree[A]

case class Branch[A](left: Tree[A], right: Tree[A]) extends Tree[A]

// 递归遍历统计节点数
def size[A](t: Tree[A]): Int = t match {
	case Leaf(_) => 1
	case Branch(l, r) => 1 + size(l) + size(r)
}

// 找到最大值
def maximum(t: Tree[Int]): Int = t match {
	case Leaf(n) => n
	case Branch(l, r) => maximum(l) max maximum(r)
}

// 层数
def depth[A](t: Tree[A]): Int = t match {
	case Leaf(_) => 0
	case Branch(l, r) => 1 + (depth(l) max depth(r))
}

// 映射
def map[A, B](t: Tree[A])(f: A => B): Tree[B] = t match {
	case Leaf(a) => Leaf(f(a))
	case Branch(l, r) => Branch(map(l)(f), map(r)(f))
}

树的折叠

和列表一样,fold函数通过递归处理Tree类型构造器的参数来折叠树形结构:

def fold[A, B](t: Tree[A])(f: A => B)(g: (B, B) => B): B = t match {
	case Leaf(a) => f(a)
	case Branch(l, r) => g(fold(l)(f)(g), fold(r)(f)(g))
}

def sizeViaFold[A](t: Tree[A]): Int =
	fold(t)(a => 1)(1 + _ + _)

def maximumViaFold(t: Tree[Int]): Int =
	fold(t)(a => a)(_ max _)

def depthViaFold[A](t: Tree[A]): Int =
	fold(t)(a => 0)((d1, d2) => 1 + (d1 max d2))

def mapViaFold[A, B](t: Tree[A])(f: A => B): Tree[B] =
	fold(t)(a => Leaf(f(a)): Tree[B])(Branch(_, _))

对于像是Leaf(f(a))这样的表达式要注明类型,不然Scala的类型推导会出错:

  type mismatch;
    found   : fpinscala.datastructures.Branch[B]
    required: fpinscala.datastructures.Leaf[B]
       fold(t)(a => Leaf(f(a)))(Branch(_,_))
                                      ^

这是Scala使用一个类的子类应用到代数数据类型 (subtyping to encode algebraic data types)时引发的错误。在不注明的情况下fold的 返回值被推导为Leaf[B]。在这个基础上假定fold的第二个函数的返回值类型也是 Leaf[B](但实际上应该是Branch[B])。

从期望上讲,如果Scala的类型推导出Tree[B]是最好的情况,因为这样可以适用到Tree 的所有子类。当在Scala中使用代数数据类型时,常常会定义一些辅助函数直接调用恰当的 构造函数,同时让返回值的类型是更加通用的类型:

def leaf[A](a: A): Tree[A] = Leaf(a)

def branch[A](l: Tree[A], r: Tree[A]): Tree[A] = Branch(l, r)