// Copyright 2023 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. //go:build goexperiment.rangefunc package iter import ( "fmt" "runtime" "testing" ) func count(n int) Seq[int] { return func(yield func(int) bool) { for i := range n { if !yield(i) { break } } } } func squares(n int) Seq2[int, int64] { return func(yield func(int, int64) bool) { for i := range n { if !yield(i, int64(i)*int64(i)) { break } } } } func TestPull(t *testing.T) { for end := 0; end <= 3; end++ { t.Run(fmt.Sprint(end), func(t *testing.T) { ng := runtime.NumGoroutine() wantNG := func(want int) { if xg := runtime.NumGoroutine() - ng; xg != want { t.Helper() t.Errorf("have %d extra goroutines, want %d", xg, want) } } wantNG(0) next, stop := Pull(count(3)) wantNG(1) for i := range end { v, ok := next() if v != i || ok != true { t.Fatalf("next() = %d, %v, want %d, %v", v, ok, i, true) } wantNG(1) } wantNG(1) if end < 3 { stop() wantNG(0) } for range 2 { v, ok := next() if v != 0 || ok != false { t.Fatalf("next() = %d, %v, want %d, %v", v, ok, 0, false) } wantNG(0) } wantNG(0) stop() stop() stop() wantNG(0) }) } } func TestPull2(t *testing.T) { for end := 0; end <= 3; end++ { t.Run(fmt.Sprint(end), func(t *testing.T) { ng := runtime.NumGoroutine() wantNG := func(want int) { if xg := runtime.NumGoroutine() - ng; xg != want { t.Helper() t.Errorf("have %d extra goroutines, want %d", xg, want) } } wantNG(0) next, stop := Pull2(squares(3)) wantNG(1) for i := range end { k, v, ok := next() if k != i || v != int64(i*i) || ok != true { t.Fatalf("next() = %d, %d, %v, want %d, %d, %v", k, v, ok, i, i*i, true) } wantNG(1) } wantNG(1) if end < 3 { stop() wantNG(0) } for range 2 { k, v, ok := next() if v != 0 || ok != false { t.Fatalf("next() = %d, %d, %v, want %d, %d, %v", k, v, ok, 0, 0, false) } wantNG(0) } wantNG(0) stop() stop() stop() wantNG(0) }) } }