aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authoralandonovan <adonovan@google.com>2017-10-19 10:18:36 -0400
committerGitHub <noreply@github.com>2017-10-19 10:18:36 -0400
commitf3709075f99e3c2168378dd5f537f39b146b3eff (patch)
tree2567f1409f1725a4a3784747d4f209bacdef5be2
parent5cb1ab577184d6f26040b58b71de23f1ff20035c (diff)
downloadstarlark-go-f3709075f99e3c2168378dd5f537f39b146b3eff.tar.gz
set: in x|y, require that y is a set if x is a set (#29)
The old functionality can still be accessed using x.union(y).
-rw-r--r--doc/spec.md7
-rw-r--r--eval.go3
-rw-r--r--testdata/set.sky27
3 files changed, 22 insertions, 15 deletions
diff --git a/doc/spec.md b/doc/spec.md
index 08d3815..e8216da 100644
--- a/doc/spec.md
+++ b/doc/spec.md
@@ -1813,7 +1813,7 @@ String interpolation
Sets
int | int # bitwise union (OR)
- set | iterable # set union
+ set | set # set union
int & int # bitwise intersection (AND)
set & set # set intersection
```
@@ -1864,9 +1864,7 @@ elements of the operand sets, preserving the element order of the left
operand.
The `|` operator likewise computes bitwise or set unions.
-However, if the left operand of `|` is a set, the right operand may be
-any iterable, not necessarily another set.
-The result of `set | iterable` is a new set whose elements are the
+The result of `set | set` is a new set whose elements are the
union of the operands, preserving the order of the elements of the
operands, left before right.
@@ -1876,7 +1874,6 @@ operands, left before right.
set([1, 2]) & set([2, 3]) # set([2])
set([1, 2]) | set([2, 3]) # set([1, 2, 3])
-set([1, 2]) | [2,3] # set([1, 2, 3])
```
<b>Implementation note:</b>
diff --git a/eval.go b/eval.go
index d50a129..fa94569 100644
--- a/eval.go
+++ b/eval.go
@@ -1178,7 +1178,8 @@ func Binary(op syntax.Token, x, y Value) (Value, error) {
return x.Or(y), nil
}
case *Set: // union
- if iter := Iterate(y); iter != nil {
+ if y, ok := y.(*Set); ok {
+ iter := Iterate(y)
defer iter.Done()
return x.Union(iter)
}
diff --git a/testdata/set.sky b/testdata/set.sky
index 7d96a69..ed67010 100644
--- a/testdata/set.sky
+++ b/testdata/set.sky
@@ -31,8 +31,10 @@ assert.eq(type(set([1, 3, 2, 3])), "set")
assert.eq(list(set([1, 3, 2, 3])), [1, 3, 2])
assert.eq(type(set("hello".split_bytes())), "set")
assert.eq(list(set("hello".split_bytes())), ["h", "e", "l", "o"])
+assert.eq(list(set(range(3))), [0, 1, 2])
assert.fails(lambda: set(1), "got int, want iterable")
assert.fails(lambda: set(1, 2, 3), "got 3 arguments")
+assert.fails(lambda: set([1, 2, {}]), "unhashable type: dict")
# truth
assert.true(not set())
@@ -45,31 +47,38 @@ y = set([3, 4, 5])
# set + any is not defined
assert.fails(lambda: x + y, "unknown.*: set \+ set")
-# union, set | iterable
+# set | set
assert.eq(list(set("a".split_bytes()) | set("b".split_bytes())), ["a", "b"])
assert.eq(list(set("ab".split_bytes()) | set("bc".split_bytes())), ["a", "b", "c"])
-assert.eq(list(set("ab".split_bytes()) | "bc".split_bytes()), ["a", "b", "c"])
+assert.fails(lambda: set() | [], "unknown binary op: set | list")
assert.eq(type(x | y), "set")
assert.eq(list(x | y), [1, 2, 3, 4, 5])
-assert.eq(list(x | [5, 1]), [1, 2, 3, 5])
-assert.eq(list(x | (6, 5, 4)), [1, 2, 3, 6, 5, 4])
-assert.fails(lambda: x | [1, 2, {}], "unhashable type: dict")
+assert.eq(list(x | set([5, 1])), [1, 2, 3, 5])
+assert.eq(list(x | set((6, 5, 4))), [1, 2, 3, 6, 5, 4])
+
+# set.union (allows any iterable for right operand)
+assert.eq(list(set("a".split_bytes()).union("b".split_bytes())), ["a", "b"])
+assert.eq(list(set("ab".split_bytes()).union("bc".split_bytes())), ["a", "b", "c"])
+assert.eq(set().union([]), set())
+assert.eq(type(x.union(y)), "set")
+assert.eq(list(x.union(y)), [1, 2, 3, 4, 5])
+assert.eq(list(x.union([5, 1])), [1, 2, 3, 5])
+assert.eq(list(x.union((6, 5, 4))), [1, 2, 3, 6, 5, 4])
+assert.fails(lambda: x.union([1, 2, {}]), "unhashable type: dict")
# intersection, set & set
assert.eq(list(set("a".split_bytes()) & set("b".split_bytes())), [])
assert.eq(list(set("ab".split_bytes()) & set("bc".split_bytes())), ["b"])
-# set.union
-assert.eq(list(x.union(y)), [1, 2, 3, 4, 5])
-
# len
assert.eq(len(x), 3)
assert.eq(len(y), 3)
assert.eq(len(x | y), 5)
# str
-# TODO(adonovan): make output deterministic when len > 1?
assert.eq(str(set([1])), "set([1])")
+assert.eq(str(set([2, 3])), "set([2, 3])")
+assert.eq(str(set([3, 2])), "set([3, 2])")
# comparison
assert.eq(x, x)