Skip to content

Commit 8a91774

Browse files
committed
Fix partial functions with non trivial selector
E.g. ``` class Foo(val field: Option[Int]) val p: PartialFunction[Foo, Int] = foo => foo.field match { case Some(x) => x } ```
1 parent d03b226 commit 8a91774

File tree

2 files changed

+27
-22
lines changed

2 files changed

+27
-22
lines changed

compiler/src/dotty/tools/dotc/transform/ExpandSAMs.scala

+19-19
Original file line numberDiff line numberDiff line change
@@ -74,38 +74,38 @@ class ExpandSAMs extends MiniPhase {
7474
val isDefinedAtFn = overrideSym(defn.PartialFunction_isDefinedAt)
7575
val applyOrElseFn = overrideSym(defn.PartialFunction_applyOrElse)
7676

77+
def translateMatch(tree: Match, pfParam: Symbol, cases: List[CaseDef], defaultValue: Tree) = {
78+
val selector = tree.selector
79+
val selectorTpe = selector.tpe.widen
80+
val defaultSym = ctx.newSymbol(pfParam.owner, nme.WILDCARD, Synthetic, selectorTpe)
81+
val defaultCase =
82+
CaseDef(
83+
Bind(defaultSym, Underscore(selectorTpe)),
84+
EmptyTree,
85+
defaultValue)
86+
val unchecked = Annotated(selector, New(ref(defn.UncheckedAnnotType)))
87+
cpy.Match(tree)(unchecked, cases :+ defaultCase)
88+
.subst(param.symbol :: Nil, pfParam :: Nil)
89+
// Needed because a partial function can be written as:
90+
// param => param match { case "foo" if foo(param) => param }
91+
// And we need to update all references to 'param'
92+
}
93+
7794
def isDefinedAtRhs(paramRefss: List[List[Tree]]) = {
7895
val tru = Literal(Constant(true))
7996
def translateCase(cdef: CaseDef) =
8097
cpy.CaseDef(cdef)(body = tru).changeOwner(anonSym, isDefinedAtFn)
8198
val paramRef = paramRefss.head.head
8299
val defaultValue = Literal(Constant(false))
83-
translateMatch(pf, paramRef, pf.cases.map(translateCase), defaultValue)
100+
translateMatch(pf, paramRef.symbol, pf.cases.map(translateCase), defaultValue)
84101
}
85102

86103
def applyOrElseRhs(paramRefss: List[List[Tree]]) = {
87104
val List(paramRef, defaultRef) = paramRefss.head
88105
def translateCase(cdef: CaseDef) =
89106
cdef.changeOwner(anonSym, applyOrElseFn)
90107
val defaultValue = defaultRef.select(nme.apply).appliedTo(paramRef)
91-
translateMatch(pf, paramRef, pf.cases.map(translateCase), defaultValue)
92-
}
93-
94-
def translateMatch(tree: Match, selector: Tree, cases: List[CaseDef], defaultValue: Tree) = {
95-
assert(tree.selector.symbol == param.symbol)
96-
val selectorTpe = selector.tpe.widen
97-
val defaultSym = ctx.newSymbol(selector.symbol.owner, nme.WILDCARD, Synthetic, selectorTpe)
98-
val defaultCase =
99-
CaseDef(
100-
Bind(defaultSym, Underscore(selectorTpe)),
101-
EmptyTree,
102-
defaultValue)
103-
val unchecked = Annotated(selector, New(ref(defn.UncheckedAnnotType)))
104-
cpy.Match(tree)(unchecked, cases :+ defaultCase)
105-
.subst(param.symbol :: Nil, selector.symbol :: Nil)
106-
// Needed because a partial function can be written as:
107-
// param => param match { case "foo" if foo(param) => param }
108-
// And we need to update all references to 'param'
108+
translateMatch(pf, paramRef.symbol, pf.cases.map(translateCase), defaultValue)
109109
}
110110

111111
val isDefinedAtDef = transformFollowingDeep(DefDef(isDefinedAtFn, isDefinedAtRhs(_)))

tests/run/partialFunctions.scala

+8-3
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,15 @@
11
object Test {
22

3-
def takesPartialFunction(a: PartialFunction[Int, Int]) = a(1)
3+
def takesPartialFunction(a: PartialFunction[Int, Int]) = a(1)
4+
class Foo(val field: Option[Int])
45

56
def main(args: Array[String]): Unit = {
6-
val partialFunction: PartialFunction[Int, Int] = {case a: Int => a}
7+
val p1: PartialFunction[Int, Int] = { case a: Int => a }
8+
assert(takesPartialFunction(p1) == 1)
79

8-
assert(takesPartialFunction(partialFunction) == 1)
10+
val p2: PartialFunction[Foo, Int] =
11+
foo => foo.field match { case Some(x) => x }
12+
assert(p2.isDefinedAt(new Foo(Some(1))))
13+
assert(!p2.isDefinedAt(new Foo(None)))
914
}
1015
}

0 commit comments

Comments
 (0)