Skip to content

Commit

Permalink
CombineSets() implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
gobwas committed Oct 9, 2020
1 parent 0548b99 commit b52c183
Show file tree
Hide file tree
Showing 3 changed files with 278 additions and 136 deletions.
137 changes: 47 additions & 90 deletions flagutil.go
Original file line number Diff line number Diff line change
Expand Up @@ -427,7 +427,7 @@ var MergeUsage = func(name string, usage0, usage1 string) string {
return usage0 + " / " + usage1
}

// Merge merges new flagset into superset and resolves any name collisions.
// MergeInto merges new flagset into superset and resolves any name collisions.
// It calls setup function to let caller register needed flags within subset
// before they are merged into the superset.
//
Expand All @@ -440,7 +440,7 @@ var MergeUsage = func(name string, usage0, usage1 string) string {
//
// Note that default values (and initial values of where flag.Value points to)
// are kept untouched and may differ if no value is set during parsing phase.
func Merge(super *flag.FlagSet, setup func(*flag.FlagSet)) {
func MergeInto(super *flag.FlagSet, setup func(*flag.FlagSet)) {
fs := flag.NewFlagSet("", flag.ContinueOnError)
setup(fs)
fs.VisitAll(func(next *flag.Flag) {
Expand All @@ -449,47 +449,58 @@ func Merge(super *flag.FlagSet, setup func(*flag.FlagSet)) {
super.Var(next.Value, next.Name, next.Usage)
return
}
*prev = *Combine(prev, next)
*prev = *CombineFlags(prev, next)
})
}

// MergeFlags makes all given flags look like single one. That is, setting
// value of any given flag will cause value of all flags change.
func MergeFlags(fs ...*flag.Flag) {
if len(fs) < 2 {
return
}
var (
noDef bool
latest *flag.Flag
)
for i := 1; i < len(fs); i++ {
f0 := fs[i-1]
f1 := fs[i-0]
c := Combine(f0, f1)
*f0 = *c
*f1 = *c

latest = c
if c.DefValue == "" {
noDef = true
// CombineSets combines given sets into a third one.
// Every collided flags are combined into third one in a way that setting value
// to it sets value of both original flags.
func CombineSets(fs0, fs1 *flag.FlagSet) *flag.FlagSet {
// TODO: join Name().
super := flag.NewFlagSet("", flag.ContinueOnError)
fs0.VisitAll(func(f0 *flag.Flag) {
var v flag.Value
f1 := fs1.Lookup(f0.Name)
if f1 != nil {
// Same flag exists in fs1 flag set.
f0 = CombineFlags(f0, f1)
}
}
for i := 0; i < len(fs); i++ {
fs[i].Usage = latest.Usage
fs[i].Value = latest.Value
if noDef {
fs[i].DefValue = ""
v = OverrideSet(f0.Value, func(value string) (err error) {
err = fs0.Set(f0.Name, value)
if err != nil {
return
}
if f1 == nil {
return
}
err = fs1.Set(f1.Name, value)
if err != nil {
return
}
return nil
})
super.Var(v, f0.Name, f0.Usage)
})
fs1.VisitAll(func(f1 *flag.Flag) {
if super.Lookup(f1.Name) != nil {
// Already combined.
return
}
}
v := OverrideSet(f1.Value, func(value string) error {
return fs1.Set(f1.Name, value)
})
super.Var(v, f1.Name, f1.Usage)
})
return super
}

// Combine combines given flags into a third one.
// Setting value of returned flag will cause both given flags change their
// values as well.
// CombineFlags combines given flags into a third one. Setting value of
// returned flag will cause both given flags change their values as well.
// However, flag sets of both flags will not be aware that the flags were set.
//
// Description of each flag (if differ) is joined by MergeUsage().
func Combine(f0, f1 *flag.Flag) *flag.Flag {
func CombineFlags(f0, f1 *flag.Flag) *flag.Flag {
if f0.Name != f1.Name {
panic(fmt.Sprintf(
"flagutil: can't combine flags with different names: %q vs %q",
Expand All @@ -501,11 +512,8 @@ func Combine(f0, f1 *flag.Flag) *flag.Flag {
Value: valuePair{f0.Value, f1.Value},
Usage: mergeUsage(f0.Name, f0.Usage, f1.Usage),
}
// Clear default values to be printed in usage to empty string since
// they are different.
if f0.DefValue == f1.DefValue {
r.DefValue = f0.DefValue
}
// This is how flag.FlagSet() does it in its Var() method.
r.DefValue = r.Value.String()
return &r
}

Expand All @@ -521,54 +529,3 @@ func mergeUsage(name, s0, s1 string) string {
return MergeUsage(name, s0, s1)
}
}

type valuePair [2]flag.Value

func (p valuePair) Set(val string) error {
for _, v := range p {
if err := v.Set(val); err != nil {
return err
}
}
return nil
}

func (p valuePair) Get() interface{} {
var (
v0 interface{}
v1 interface{}
)
if g0, ok := p[0].(flag.Getter); ok {
v0 = g0.Get()
}
if g1, ok := p[1].(flag.Getter); ok {
v1 = g1.Get()
}
if !reflect.DeepEqual(v0, v1) {
return nil
}
return v0
}

func (p valuePair) String() string {
if p.isZero() {
return ""
}
s0 := p[0].String()
s1 := p[1].String()
if s0 != s1 {
return ""
}
return s0
}

func (p valuePair) IsBoolFlag() bool {
if isBoolValue(p[0]) && isBoolValue(p[1]) {
return true
}
return false
}

func (p valuePair) isZero() bool {
return p == valuePair{}
}
151 changes: 105 additions & 46 deletions flagutil_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -250,78 +250,143 @@ func TestMerge(t *testing.T) {
}
}

func TestMergeFlags(t *testing.T) {
func TestCombineSets(t *testing.T) {
var (
nameInBoth = "both"
nameInFirst = "first"
nameInSecond = "second"
nameUnknown = "whoa"
)
var (
fs0 = flag.NewFlagSet("FlagSet#0", flag.ContinueOnError)
fs1 = flag.NewFlagSet("FlagSet#1", flag.ContinueOnError)
)
fs0.String(nameInFirst, "first-default", "")
fs0.String(nameInBoth, "both-default-0", "")
fs1.String(nameInBoth, "both-default-1", "")
fs1.String(nameInSecond, "second-default", "")

fs := CombineSets(fs0, fs1)

mustNotBeDefined(t, fs, nameUnknown)
mustBeEqualTo(t, fs, nameInFirst, "first-default")
mustBeEqualTo(t, fs, nameInSecond, "second-default")
mustBeEqualTo(t, fs, nameInBoth, "")

mustNotSet(t, fs, nameUnknown, "want error")

mustSet(t, fs, nameInFirst, "first")
mustBeEqualTo(t, fs, nameInFirst, "first")
mustBeEqualTo(t, fs0, nameInFirst, "first")

mustSet(t, fs, nameInSecond, "second")
mustBeEqualTo(t, fs, nameInSecond, "second")
mustBeEqualTo(t, fs1, nameInSecond, "second")

mustSet(t, fs, nameInBoth, "both")
mustBeEqualTo(t, fs, nameInBoth, "both")
mustBeEqualTo(t, fs0, nameInBoth, "both")
mustBeEqualTo(t, fs1, nameInBoth, "both")
}

func mustNotSet(t *testing.T, fs *flag.FlagSet, name, value string) {
if err := fs.Set(name, value); err == nil {
t.Fatalf(
"want error on setting flag %q value to %q: %v",
name, value, err,
)
}
}

func mustSet(t *testing.T, fs *flag.FlagSet, name, value string) {
if err := fs.Set(name, value); err != nil {
t.Fatalf("can't set flag %q value to %q: %v", name, value, err)
}
}

func mustBeEqualTo(t *testing.T, fs *flag.FlagSet, name, value string) {
mustBeDefined(t, fs, name)
if act, exp := fs.Lookup(name).Value.String(), value; act != exp {
t.Fatalf("flag %q value is %q; want %q", name, act, exp)
}
}

func mustNotBeDefined(t *testing.T, fs *flag.FlagSet, name string) {
if fs.Lookup(name) != nil {
t.Fatalf("want flag %q to not be present in set", name)
}
}

func mustBeDefined(t *testing.T, fs *flag.FlagSet, name string) {
if fs.Lookup(name) == nil {
t.Fatalf("want flag %q to be present in set", name)
}
}

func TestCombineFlags(t *testing.T) {
for _, test := range []struct {
name string
flags []flag.Flag
exp []flag.Flag
flags [2]flag.Flag
exp flag.Flag
panic bool
}{
{
name: "different names",
flags: []flag.Flag{
flags: [2]flag.Flag{
stringFlag("foo", "def", "desc#0"),
stringFlag("bar", "def", "desc#1"),
},
panic: true,
},
{
name: "different default values",
flags: []flag.Flag{
flags: [2]flag.Flag{
stringFlag("foo", "def#0", "desc#0"),
stringFlag("foo", "def#1", "desc#1"),
},
exp: []flag.Flag{
stringFlag("foo", "", "desc#0 / desc#1"),
stringFlag("foo", "", "desc#0 / desc#1"),
},
exp: stringFlag("foo", "", "desc#0 / desc#1"),
},
{
name: "basic",
flags: []flag.Flag{
flags: [2]flag.Flag{
stringFlag("foo", "def", "desc#0"),
stringFlag("foo", "def", "desc#1"),
},
exp: []flag.Flag{
stringFlag("foo", "def", "desc#0 / desc#1"),
stringFlag("foo", "def", "desc#0 / desc#1"),
},
exp: stringFlag("foo", "def", "desc#0 / desc#1"),
},
{
name: "basic",
flags: []flag.Flag{
flags: [2]flag.Flag{
stringFlag("foo", "def", "desc#0"),
stringFlag("foo", "def", "desc#1"),
stringFlag("foo", "", "desc#2"),
},
exp: []flag.Flag{
stringFlag("foo", "", "desc#0 / desc#1 / desc#2"),
stringFlag("foo", "", "desc#0 / desc#1 / desc#2"),
stringFlag("foo", "", "desc#0 / desc#1 / desc#2"),
stringFlag("foo", "", "desc#1"),
},
exp: stringFlag("foo", "", "desc#0 / desc#1"),
},
} {
t.Run(test.name, func(t *testing.T) {
if !test.panic && len(test.flags) != len(test.exp) {
t.Skip("malformed test")
type flagOrPanic struct {
flag *flag.Flag
panic interface{}
}
ptrs := make([]*flag.Flag, len(test.flags))
for i := range test.flags {
ptrs[i] = &test.flags[i]
}
done := make(chan interface{})
done := make(chan flagOrPanic)
go func() {
defer func() {
done <- recover()
if p := recover(); p != nil {
done <- flagOrPanic{
panic: p,
}
}
}()
MergeFlags(ptrs...)
done <- flagOrPanic{
flag: CombineFlags(&test.flags[0], &test.flags[1]),
}
}()
p := <-done
if !test.panic && p != nil {
t.Fatalf("panic() recovered: %s", p)
x := <-done
if !test.panic && x.panic != nil {
t.Fatalf("panic() recovered: %s", x.panic)
}
if test.panic {
if p == nil {
if x.panic == nil {
t.Fatalf("want panic; got nothing")
}
return
Expand All @@ -331,22 +396,16 @@ func TestMergeFlags(t *testing.T) {
return v.String()
}),
}
for i, exp := range test.exp {
act := test.flags[i]
if !cmp.Equal(act, exp, opts...) {
t.Errorf("unexpected #%d flag:\n%s", i, cmp.Diff(exp, act, opts...))
}
if act, exp := x.flag, &test.exp; !cmp.Equal(act, exp, opts...) {
t.Errorf("unexpected flag:\n%s", cmp.Diff(exp, act, opts...))
}
exp := fmt.Sprintf("%x", rand.Int63())
if err := test.flags[0].Value.Set(exp); err != nil {
if err := x.flag.Value.Set(exp); err != nil {
t.Fatalf("unexpected Set() error: %v", err)
}
for i, f := range test.flags {
for _, f := range test.flags {
if act := f.Value.String(); act != exp {
t.Errorf(
"unexpected #%d flag value: %q; want %q",
i, act, exp,
)
t.Errorf("unexpected flag value: %s; want %s", act, exp)
}
}
})
Expand Down
Loading

0 comments on commit b52c183

Please sign in to comment.