# Set is a collection of unordered values with no duplicates. Members are
# stored as the keys of an internal Hash (@hash) whose values are all +true+,
# so membership follows Hash key lookup.
class Set
  include Enumerable

  # Creates a new set containing the given objects.
  def self.[](*ary)
    new(ary)
  end

  # Creates a new set containing the elements of +enum+. When +enum+ is nil
  # the set starts empty. When a block is given, each element is passed
  # through it and the block's result is stored instead.
  #
  # Raises ArgumentError (via #merge) when no block is given and +enum+ is
  # not Enumerable.
  def initialize(enum = nil, &block) # :yields: o
    # ||= keeps a store that a subclass assigned before calling super.
    @hash ||= Hash.new

    enum.nil? and return

    if block
      enum.each { |o| add(block[o]) }
    else
      merge(enum)
    end
  end

  # Gives the copy its own internal hash so it does not share storage with
  # +orig+.
  def initialize_copy(orig)
    @hash = orig.instance_eval { @hash }.dup
  end

  # Returns the number of elements.
  def size
    @hash.size
  end
  alias length size

  # Returns true if the set contains no elements.
  def empty?
    @hash.empty?
  end

  # Removes all elements and returns self.
  def clear
    @hash.clear
    self
  end

  # Replaces the contents of the set with the contents of +enum+ and returns
  # self. Raises ArgumentError if +enum+ is neither an instance of the
  # receiver's exact class nor Enumerable.
  def replace(enum)
    if enum.class == self.class
      @hash.replace(enum.instance_eval { @hash })
    else
      enum.is_a?(Enumerable) or raise ArgumentError, "value must be enumerable"
      clear
      enum.each { |o| add(o) }
    end

    self
  end

  # Returns the elements as a new array.
  def to_a
    @hash.keys
  end

  # Adds every non-Set element reachable from +set+ to the receiver,
  # descending into nested sets. +seen+ holds the object ids of the sets
  # currently being descended into; meeting one of them again means the
  # structure contains itself, which raises ArgumentError.
  def flatten_merge(set, seen = Set.new)
    set.each { |e|
      if e.is_a?(Set)
        if seen.include?(e_id = e.object_id)
          raise ArgumentError, "tried to flatten recursive Set"
        end

        seen.add(e_id)
        flatten_merge(e, seen)
        # Only ancestors count as recursion; the same set may legitimately
        # appear again in a sibling branch.
        seen.delete(e_id)
      else
        add(e)
      end
    }

    self
  end
  protected :flatten_merge

  # Returns a new set (of the receiver's class) that is a recursively
  # flattened copy of the receiver.
  def flatten
    self.class.new.flatten_merge(self)
  end

  # Flattens the receiver in place. Returns self, or nil when no element is
  # a Set and nothing had to change.
  def flatten!
    if detect { |e| e.is_a?(Set) }
      replace(flatten())
    else
      nil
    end
  end

  # Returns true if the set contains +o+.
  def include?(o)
    @hash.include?(o)
  end
  alias member? include?

  # Returns true if the receiver contains every element of +set+.
  # Raises ArgumentError unless +set+ is a Set.
  def superset?(set)
    set.is_a?(Set) or raise ArgumentError, "value must be a set"
    return false if size < set.size
    set.all? { |o| include?(o) }
  end

  # Like #superset?, but also requires the receiver to be strictly larger.
  def proper_superset?(set)
    set.is_a?(Set) or raise ArgumentError, "value must be a set"
    return false if size <= set.size
    set.all? { |o| include?(o) }
  end

  # Returns true if +set+ contains every element of the receiver.
  # Raises ArgumentError unless +set+ is a Set.
  def subset?(set)
    set.is_a?(Set) or raise ArgumentError, "value must be a set"
    return false if set.size < size
    all? { |o| set.include?(o) }
  end

  # Like #subset?, but also requires +set+ to be strictly larger.
  def proper_subset?(set)
    set.is_a?(Set) or raise ArgumentError, "value must be a set"
    return false if set.size <= size
    all? { |o| set.include?(o) }
  end

  # Yields each element once and returns self. Returns an enumerator when
  # no block is given.
  def each
    block_given? or return enum_for(__method__)
    @hash.each_key { |o| yield(o) }
    self
  end

  # Adds +o+ to the set and returns self.
  def add(o)
    @hash[o] = true
    self
  end
  alias << add

  # Adds +o+ and returns self, or returns nil if +o+ is already a member.
  def add?(o)
    if include?(o)
      nil
    else
      add(o)
    end
  end

  # Removes +o+ from the set (if present) and returns self.
  def delete(o)
    @hash.delete(o)
    self
  end

  # Removes +o+ and returns self, or returns nil if +o+ is not a member.
  def delete?(o)
    if include?(o)
      delete(o)
    else
      nil
    end
  end

  # Deletes every element for which the block returns a true value and
  # returns self. Returns an enumerator when no block is given.
  def delete_if
    block_given? or return enum_for(__method__)
    # Iterate over a snapshot (to_a) because the hash is modified in the loop.
    to_a.each { |o| @hash.delete(o) if yield(o) }
    self
  end

  # Replaces each element with the block's result for it and returns self.
  # Returns an enumerator when no block is given.
  def collect!
    block_given? or return enum_for(__method__)
    set = self.class.new
    each { |o| set << yield(o) }
    replace(set)
  end
  alias map! collect!

  # Same as #delete_if, but returns nil when no element was removed.
  # Returns an enumerator when no block is given.
  def reject!
    block_given? or return enum_for(__method__)
    n = size
    delete_if { |o| yield(o) }
    size == n ? nil : self
  end

  # Adds every element of +enum+ to the set and returns self.
  # Raises ArgumentError if +enum+ is not Enumerable.
  def merge(enum)
    if enum.is_a?(Set)
      # Another Set's keys can be copied over directly.
      @hash.update(enum.instance_eval { @hash })
    else
      enum.is_a?(Enumerable) or raise ArgumentError, "value must be enumerable"
      enum.each { |o| add(o) }
    end

    self
  end

  # Removes every element of +enum+ from the set and returns self.
  # Raises ArgumentError if +enum+ is not Enumerable.
  def subtract(enum)
    enum.is_a?(Enumerable) or raise ArgumentError, "value must be enumerable"
    enum.each { |o| delete(o) }
    self
  end

  # Returns a new set containing the elements of the receiver and of +enum+.
  # Raises ArgumentError if +enum+ is not Enumerable.
  def |(enum)
    enum.is_a?(Enumerable) or raise ArgumentError, "value must be enumerable"
    dup.merge(enum)
  end
  alias + |
  alias union |

  # Returns a copy of the receiver without the elements of +enum+.
  # Raises ArgumentError if +enum+ is not Enumerable.
  def -(enum)
    enum.is_a?(Enumerable) or raise ArgumentError, "value must be enumerable"
    dup.subtract(enum)
  end
  alias difference -

  # Returns a new set (of the receiver's class) containing the elements
  # found both in the receiver and in +enum+.
  # Raises ArgumentError if +enum+ is not Enumerable.
  def &(enum)
    enum.is_a?(Enumerable) or raise ArgumentError, "value must be enumerable"
    n = self.class.new
    enum.each { |o| n.add(o) if include?(o) }
    n
  end
  alias intersection &

  # Returns a new Set containing the elements that are in exactly one of the
  # receiver and +enum+. Raises ArgumentError if +enum+ is not Enumerable.
  def ^(enum)
    enum.is_a?(Enumerable) or raise ArgumentError, "value must be enumerable"
    n = Set.new(enum)
    each { |o| if n.include?(o) then n.delete(o) else n.add(o) end }
    n
  end

  # Returns true if +set+ is the receiver itself, or is a Set of the same
  # size whose elements are all members of the receiver.
  def ==(set)
    equal?(set) and return true

    set.is_a?(Set) && size == set.size or return false

    # Lookups go against a copy of the internal hash, as before.
    # NOTE(review): the reason for the copy is not evident from this file.
    hash = @hash.dup
    set.all? { |o| hash.include?(o) }
  end

  # Hash value derived from the internal hash.
  def hash # :nodoc:
    @hash.hash
  end

  # True when +o+ is a Set whose internal hash is eql? to the receiver's.
  def eql?(o) # :nodoc:
    return false unless o.is_a?(Set)
    @hash.eql?(o.instance_eval { @hash })
  end

  # Groups the elements by the block's return value. Returns a Hash mapping
  # each distinct return value to a set (of the receiver's class) of the
  # elements that produced it. Returns an enumerator when no block is given.
  def classify # :yields: o
    block_given? or return enum_for(__method__)
    h = {}

    each { |i|
      x = yield(i)
      (h[x] ||= self.class.new).add(i)
    }

    h
  end

  # Splits the set into a Set of subsets.
  #
  # With a block of arity 2, the block is called for every ordered pair of
  # elements; the pairs for which it returns true form the edges of a graph
  # and each strongly connected component (computed with TSort) becomes one
  # subset. With any other arity, elements are grouped by the block's return
  # value as in #classify. Returns an enumerator when no block is given.
  def divide(&func)
    func or return enum_for(__method__)

    if func.arity == 2
      require 'tsort'

      # dig maps each element to the list of elements it is related to; its
      # singleton class gets the two methods TSort needs to walk that graph.
      class << dig = {}
        include TSort

        alias tsort_each_node each_key
        def tsort_each_child(node, &block)
          fetch(node).each(&block)
        end
      end

      each { |u|
        dig[u] = a = []
        each { |v| func.call(u, v) and a << v }
      }

      set = Set.new()
      dig.each_strongly_connected_component { |css|
        set.add(self.class.new(css))
      }
      set
    else
      Set.new(classify(&func).values)
    end
  end

  # Thread-local key under which #inspect records the sets being inspected.
  InspectKey = :__inspect_key__ # :nodoc:

  # Returns a string such as "#<Set: {1, 2}>". A set that (directly or
  # indirectly) contains itself is rendered as "{...}" when it is reached a
  # second time during the same call chain.
  def inspect
    ids = (Thread.current[InspectKey] ||= [])

    if ids.include?(object_id)
      return sprintf('#<%s: {...}>', self.class.name)
    end

    begin
      ids << object_id
      # to_a.inspect yields "[...]"; [1..-2] strips the surrounding brackets.
      return sprintf('#<%s: {%s}>', self.class, to_a.inspect[1..-2])
    ensure
      ids.pop
    end
  end

  # Writes the set to the pretty-printer +pp+ in the same shape as #inspect.
  def pretty_print(pp) # :nodoc:
    pp.text sprintf('#<%s: {', self.class.name)
    pp.nest(1) {
      pp.seplist(self) { |o|
        pp.pp o
      }
    }
    pp.text "}>"
  end

  # Writes the abbreviated form the pretty-printer uses for a cycle.
  def pretty_print_cycle(pp) # :nodoc:
    pp.text sprintf('#<%s: {%s}>', self.class.name, empty? ? '' : '...')
  end
end
# SortedSet is a Set whose real instance methods are installed lazily by
# SortedSet.setup the first time an instance is initialized.
class SortedSet < Set
# Set to true once setup has installed the implementation.
@@setup = false
class << self
# Creates a new SortedSet from the given objects (passed to new as one array).
def [](*ary) new(ary)
end
# Installs the implementation once; later calls return immediately.
#
# If 'rbtree' can be required, initialize is redefined to use an RBTree as
# @hash (presumably so keys stay ordered -- confirm against rbtree's docs).
# On LoadError, clear, replace, add, delete and merge are redefined to reset
# the cached key array @keys (delete_if resets it only when the size
# changed), to_a rebuilds that cache by sorting @hash.keys, and each iterates
# over to_a.
def setup # :nodoc:
@@setup and return
# Alias the bootstrap initialize and remove the alias again.
# NOTE(review): presumably done to avoid a "method redefined" warning when
# initialize is replaced below -- confirm.
module_eval {
alias old_init initialize
remove_method :old_init
}
begin
require 'rbtree'
module_eval %{
def initialize(*args, &block)
@hash = RBTree.new
super
end
}
rescue LoadError
# rbtree could not be loaded: keep the Hash store and add a lazily sorted
# key cache instead.
module_eval %{
def initialize(*args, &block)
@keys = nil
super
end
def clear
@keys = nil
super
end
def replace(enum)
@keys = nil
super
end
def add(o)
@keys = nil
@hash[o] = true
self
end
alias << add
def delete(o)
@keys = nil
@hash.delete(o)
self
end
def delete_if
n = @hash.size
super
@keys = nil if @hash.size != n
self
end
def merge(enum)
@keys = nil
super
end
def each
block_given? or return enum_for(__method__)
to_a.each { |o| yield(o) }
self
end
def to_a
(@keys = @hash.keys).sort! unless @keys
@keys
end
}
end
@@setup = true
end
end
# Bootstrap constructor. setup replaces this method with the real initialize,
# so the call below dispatches to the replacement rather than recursing.
def initialize(*args, &block) SortedSet.setup
initialize(*args, &block)
end
end
module Enumerable
  # Builds a collection from the receiver by calling +klass.new+ with the
  # receiver as the first argument, followed by any extra +args+, and
  # forwarding the given block. +klass+ defaults to Set.
  def to_set(klass = Set, *args, &block)
    constructor_args = [self] + args
    klass.new(*constructor_args, &block)
  end
end
# When this file is run directly as a script, evaluate the self-tests stored
# after __END__. The starting line number handed to eval is derived from
# DATA's byte offset (newlines before the data section, plus one), so line
# numbers reported for the tests match this file regardless of how many
# lines separate this statement from __END__.
if $0 == __FILE__
  first_data_line = File.open(__FILE__, 'rb') { |f| f.read(DATA.pos) }.count("\n") + 1
  eval DATA.read, nil, $0, first_data_line
end
__END__
require 'test/unit'
# Unit tests for Set. Each test builds small sets and checks both the return
# value and the resulting contents of the operation under test.
class TC_Set < Test::Unit::TestCase
# Set[] accepts zero or more objects, keeps nil and arrays as single
# elements, and collapses duplicates.
def test_aref
assert_nothing_raised {
Set[]
Set[nil]
Set[1,2,3]
}
assert_equal(0, Set[].size)
assert_equal(1, Set[nil].size)
assert_equal(1, Set[[]].size)
assert_equal(1, Set[[nil]].size)
set = Set[2,4,6,4]
assert_equal(Set.new([2,4,6]), set)
end
# Set.new: accepted and rejected arguments, independence from the source
# array, and element preprocessing through a block.
def test_s_new
assert_nothing_raised {
Set.new()
Set.new(nil)
Set.new([])
Set.new([1,2])
Set.new('a'..'c')
Set.new('XYZ')
}
assert_raises(ArgumentError) {
Set.new(false)
}
assert_raises(ArgumentError) {
Set.new(1)
}
assert_raises(ArgumentError) {
Set.new(1,2)
}
assert_equal(0, Set.new().size)
assert_equal(0, Set.new(nil).size)
assert_equal(0, Set.new([]).size)
assert_equal(1, Set.new([nil]).size)
ary = [2,4,6,4]
set = Set.new(ary)
# Clearing the source array must not affect the set built from it.
ary.clear
assert_equal(false, set.empty?)
assert_equal(3, set.size)
ary = [1,2,3]
s = Set.new(ary) { |o| o * 2 }
assert_equal([2,4,6], s.sort)
end
# A clone does not see elements added to the original afterwards.
def test_clone
set1 = Set.new
set2 = set1.clone
set1 << 'abc'
assert_equal(Set.new, set2)
end
# dup returns a distinct but equal set that diverges once the original changes.
def test_dup
set1 = Set[1,2]
set2 = set1.dup
assert_not_same(set1, set2)
assert_equal(set1, set2)
set1.add(3)
assert_not_equal(set1, set2)
end
# size counts distinct elements.
def test_size
assert_equal(0, Set[].size)
assert_equal(2, Set[1,2].size)
assert_equal(2, Set[1,2,1].size)
end
# empty? is true only for a set without elements.
def test_empty?
assert_equal(true, Set[].empty?)
assert_equal(false, Set[1, 2].empty?)
end
# clear empties the set and returns the receiver.
def test_clear
set = Set[1,2]
ret = set.clear
assert_same(set, ret)
assert_equal(true, set.empty?)
end
# replace swaps in the contents of another enumerable and returns the receiver.
def test_replace
set = Set[1,2]
ret = set.replace('a'..'c')
assert_same(set, ret)
assert_equal(Set['a','b','c'], set)
end
# to_a returns the distinct elements (order is not asserted, hence the sort).
def test_to_a
set = Set[1,2,3,2]
ary = set.to_a
assert_equal([1,2,3], ary.sort)
end
# flatten / flatten!: nested sets, a set shared between branches, recursive
# sets, and the nil return of flatten! when there is nothing to flatten.
def test_flatten
set1 = Set[
1,
Set[
5,
Set[7,
Set[0]
],
Set[6,2],
1
],
3,
Set[3,4]
]
set2 = set1.flatten
set3 = Set.new(0..7)
assert_not_same(set2, set1)
assert_equal(set3, set2)
orig_set1 = set1
set1.flatten!
assert_same(orig_set1, set1)
assert_equal(set3, set1)
# The same set appearing in two branches is not a recursion.
set1 = Set[1, 2]
set2 = Set[set1, Set[set1, 4], 3]
assert_nothing_raised {
set2.flatten!
}
assert_equal(Set.new(1..4), set2)
# Two sets that contain each other must be rejected.
set2 = Set[]
set1 = Set[1, set2]
set2.add(set1)
assert_raises(ArgumentError) {
set1.flatten!
}
empty = Set[]
set = Set[Set[empty, "a"],Set[empty, "b"]]
assert_nothing_raised {
set.flatten
}
set1 = empty.merge(Set["no_more", set])
assert_nil(Set.new(0..31).flatten!)
x = Set[Set[],Set[1,2]].flatten!
y = Set[1,2]
assert_equal(x, y)
end
# include? for present and absent values, including nil and false as members.
def test_include?
set = Set[1,2,3]
assert_equal(true, set.include?(1))
assert_equal(true, set.include?(2))
assert_equal(true, set.include?(3))
assert_equal(false, set.include?(0))
assert_equal(false, set.include?(nil))
set = Set["1",nil,"2",nil,"0","1",false]
assert_equal(true, set.include?(nil))
assert_equal(true, set.include?(false))
assert_equal(true, set.include?("1"))
assert_equal(false, set.include?(0))
assert_equal(false, set.include?(true))
end
# superset? rejects missing and non-Set arguments and accepts equal sets.
def test_superset?
set = Set[1,2,3]
assert_raises(ArgumentError) {
set.superset?()
}
assert_raises(ArgumentError) {
set.superset?(2)
}
assert_raises(ArgumentError) {
set.superset?([2])
}
assert_equal(true, set.superset?(Set[]))
assert_equal(true, set.superset?(Set[1,2]))
assert_equal(true, set.superset?(Set[1,2,3]))
assert_equal(false, set.superset?(Set[1,2,3,4]))
assert_equal(false, set.superset?(Set[1,4]))
assert_equal(true, Set[].superset?(Set[]))
end
# proper_superset? is like superset? but false for equal sets.
def test_proper_superset?
set = Set[1,2,3]
assert_raises(ArgumentError) {
set.proper_superset?()
}
assert_raises(ArgumentError) {
set.proper_superset?(2)
}
assert_raises(ArgumentError) {
set.proper_superset?([2])
}
assert_equal(true, set.proper_superset?(Set[]))
assert_equal(true, set.proper_superset?(Set[1,2]))
assert_equal(false, set.proper_superset?(Set[1,2,3]))
assert_equal(false, set.proper_superset?(Set[1,2,3,4]))
assert_equal(false, set.proper_superset?(Set[1,4]))
assert_equal(false, Set[].proper_superset?(Set[]))
end
# subset? rejects missing and non-Set arguments and accepts equal sets.
def test_subset?
set = Set[1,2,3]
assert_raises(ArgumentError) {
set.subset?()
}
assert_raises(ArgumentError) {
set.subset?(2)
}
assert_raises(ArgumentError) {
set.subset?([2])
}
assert_equal(true, set.subset?(Set[1,2,3,4]))
assert_equal(true, set.subset?(Set[1,2,3]))
assert_equal(false, set.subset?(Set[1,2]))
assert_equal(false, set.subset?(Set[]))
assert_equal(true, Set[].subset?(Set[1]))
assert_equal(true, Set[].subset?(Set[]))
end
# proper_subset? is like subset? but false for equal sets.
def test_proper_subset?
set = Set[1,2,3]
assert_raises(ArgumentError) {
set.proper_subset?()
}
assert_raises(ArgumentError) {
set.proper_subset?(2)
}
assert_raises(ArgumentError) {
set.proper_subset?([2])
}
assert_equal(true, set.proper_subset?(Set[1,2,3,4]))
assert_equal(false, set.proper_subset?(Set[1,2,3]))
assert_equal(false, set.proper_subset?(Set[1,2]))
assert_equal(false, set.proper_subset?(Set[]))
assert_equal(false, Set[].proper_subset?(Set[]))
end
# each returns the receiver with a block, an enumerator without one, and
# yields every element exactly once.
def test_each
ary = [1,3,5,7,10,20]
set = Set.new(ary)
ret = set.each { |o| }
assert_same(set, ret)
e = set.each
# NOTE(review): Enumerable::Enumerator is assumed to be the enumerator class
# of the Ruby version this file targets -- confirm before running elsewhere.
assert_instance_of(Enumerable::Enumerator, e)
assert_nothing_raised {
set.each { |o|
ary.delete(o) or raise "unexpected element: #{o}"
}
ary.empty? or raise "forgotten elements: #{ary.join(', ')}"
}
end
# add always returns the receiver; add? returns nil for an existing element.
def test_add
set = Set[1,2,3]
ret = set.add(2)
assert_same(set, ret)
assert_equal(Set[1,2,3], set)
ret = set.add?(2)
assert_nil(ret)
assert_equal(Set[1,2,3], set)
ret = set.add(4)
assert_same(set, ret)
assert_equal(Set[1,2,3,4], set)
ret = set.add?(5)
assert_same(set, ret)
assert_equal(Set[1,2,3,4,5], set)
end
# delete always returns the receiver; delete? returns nil for a missing element.
def test_delete
set = Set[1,2,3]
ret = set.delete(4)
assert_same(set, ret)
assert_equal(Set[1,2,3], set)
ret = set.delete?(4)
assert_nil(ret)
assert_equal(Set[1,2,3], set)
ret = set.delete(2)
assert_equal(set, ret)
assert_equal(Set[1,3], set)
ret = set.delete?(1)
assert_equal(set, ret)
assert_equal(Set[3], set)
end
# delete_if returns the receiver whether or not anything was removed.
def test_delete_if
set = Set.new(1..10)
ret = set.delete_if { |i| i > 10 }
assert_same(set, ret)
assert_equal(Set.new(1..10), set)
set = Set.new(1..10)
ret = set.delete_if { |i| i % 3 == 0 }
assert_same(set, ret)
assert_equal(Set[1,2,4,5,7,8,10], set)
end
# collect! replaces elements in place; results that coincide (nil for both
# ranges) collapse into one element.
def test_collect!
set = Set[1,2,3,'a','b','c',-1..1,2..4]
ret = set.collect! { |i|
case i
when Numeric
i * 2
when String
i.upcase
else
nil
end
}
assert_same(set, ret)
assert_equal(Set[2,4,6,'A','B','C',nil], set)
end
# reject! returns nil when nothing was removed and the receiver otherwise.
def test_reject!
set = Set.new(1..10)
ret = set.reject! { |i| i > 10 }
assert_nil(ret)
assert_equal(Set.new(1..10), set)
ret = set.reject! { |i| i % 3 == 0 }
assert_same(set, ret)
assert_equal(Set[1,2,4,5,7,8,10], set)
end
# merge adds elements in place and returns the receiver.
def test_merge
set = Set[1,2,3]
ret = set.merge([2,4,6])
assert_same(set, ret)
assert_equal(Set[1,2,3,4,6], set)
end
# subtract removes elements in place and returns the receiver.
def test_subtract
set = Set[1,2,3]
ret = set.subtract([2,4,6])
assert_same(set, ret)
assert_equal(Set[1,3], set)
end
# + returns a new set holding the union.
def test_plus
set = Set[1,2,3]
ret = set + [2,4,6]
assert_not_same(set, ret)
assert_equal(Set[1,2,3,4,6], ret)
end
# - returns a new set holding the difference.
def test_minus
set = Set[1,2,3]
ret = set - [2,4,6]
assert_not_same(set, ret)
assert_equal(Set[1,3], ret)
end
# & returns a new set holding the intersection.
def test_and
set = Set[1,2,3,4]
ret = set & [2,4,6]
assert_not_same(set, ret)
assert_equal(Set[2,4], ret)
end
# ^ returns a new set holding the elements found in exactly one operand.
def test_xor
set = Set[1,2,3,4]
ret = set ^ [2,4,5,5]
assert_not_same(set, ret)
assert_equal(Set[1,3,5], ret)
end
# == ignores order, rejects non-sets, and works for sets that contain sets.
def test_eq
set1 = Set[2,3,1]
set2 = Set[1,2,3]
assert_equal(set1, set1)
assert_equal(set1, set2)
assert_not_equal(Set[1], [1])
set1 = Class.new(Set)["a", "b"]
set2 = Set["a", "b", set1]
set1 = set1.add(set1.clone)
assert_equal(set2, set2.clone)
assert_equal(set1.clone, set1)
assert_not_equal(Set[Exception.new,nil], Set[Exception.new,Exception.new], "[ruby-dev:26127]")
end
# classify returns a Hash from block result to a Set of matching elements.
def test_classify
set = Set.new(1..10)
ret = set.classify { |i| i % 3 }
assert_equal(3, ret.size)
assert_instance_of(Hash, ret)
ret.each_value { |value| assert_instance_of(Set, value) }
assert_equal(Set[3,6,9], ret[0])
assert_equal(Set[1,4,7,10], ret[1])
assert_equal(Set[2,5,8], ret[2])
end
# divide with a one-argument block (grouping by result) and with a
# two-argument block (grouping elements linked by the relation).
def test_divide
set = Set.new(1..10)
ret = set.divide { |i| i % 3 }
assert_equal(3, ret.size)
n = 0
ret.each { |s| n += s.size }
assert_equal(set.size, n)
assert_equal(set, ret.flatten)
set = Set[7,10,5,11,1,3,4,9,0]
ret = set.divide { |a,b| (a - b).abs == 1 }
assert_equal(4, ret.size)
n = 0
ret.each { |s| n += s.size }
assert_equal(set.size, n)
assert_equal(set, ret.flatten)
ret.each { |s|
if s.include?(0)
assert_equal(Set[0,1], s)
elsif s.include?(3)
assert_equal(Set[3,4,5], s)
elsif s.include?(7)
assert_equal(Set[7], s)
elsif s.include?(9)
assert_equal(Set[9,10,11], s)
else
raise "unexpected group: #{s.inspect}"
end
}
end
# inspect output for a simple set, and the '{...}' marker once two sets
# contain each other.
def test_inspect
set1 = Set[1]
assert_equal('#<Set: {1}>', set1.inspect)
set2 = Set[Set[0], 1, 2, set1]
assert_equal(false, set2.inspect.include?('#<Set: {...}>'))
set1.add(set2)
assert_equal(true, set1.inspect.include?('#<Set: {...}>'))
end
end
# Unit tests for SortedSet.
class TC_SortedSet < Test::Unit::TestCase
# Checks that to_a and each are in ascending order after construction, after
# map!, and when built with a block; then checks that delete_if and reject!
# visit the elements in sorted order and return the receiver or nil as
# asserted below.
def test_sortedset
s = SortedSet[4,5,3,1,2]
assert_equal([1,2,3,4,5], s.to_a)
prev = nil
s.each { |o| assert(prev < o) if prev; prev = o }
assert_not_nil(prev)
s.map! { |o| -2 * o }
assert_equal([-10,-8,-6,-4,-2], s.to_a)
prev = nil
ret = s.each { |o| assert(prev < o) if prev; prev = o }
assert_not_nil(prev)
assert_same(s, ret)
s = SortedSet.new([2,1,3]) { |o| o * -2 }
assert_equal([-6,-4,-2], s.to_a)
# a records the order in which delete_if yields the elements.
s = SortedSet.new(['one', 'two', 'three', 'four'])
a = []
ret = s.delete_if { |o| a << o; o.start_with?('t') }
assert_same(s, ret)
assert_equal(['four', 'one'], s.to_a)
assert_equal(['four', 'one', 'three', 'two'], a)
# Same scenario through reject!, which returns the receiver when it removed
# something.
s = SortedSet.new(['one', 'two', 'three', 'four'])
a = []
ret = s.reject! { |o| a << o; o.start_with?('t') }
assert_same(s, ret)
assert_equal(['four', 'one'], s.to_a)
assert_equal(['four', 'one', 'three', 'two'], a)
# reject! that removes nothing returns nil and leaves the set unchanged.
s = SortedSet.new(['one', 'two', 'three', 'four'])
a = []
ret = s.reject! { |o| a << o; false }
assert_same(nil, ret)
assert_equal(['four', 'one', 'three', 'two'], s.to_a)
assert_equal(['four', 'one', 'three', 'two'], a)
end
end
# Unit tests for Enumerable#to_set.
class TC_Enumerable < Test::Unit::TestCase
  # to_set builds a Set by default and an instance of the given class when
  # one is passed; a block, if given, transforms the elements first.
  def test_to_set
    source = [2,5,4,3,2,1,3]

    # Default target class, no block.
    plain = source.to_set
    assert_instance_of(Set, plain)
    assert_equal([1,2,3,4,5], plain.sort)

    # Default target class, elements transformed by the block.
    plain_mapped = source.to_set { |o| o * -2 }
    assert_instance_of(Set, plain_mapped)
    assert_equal([-10,-8,-6,-4,-2], plain_mapped.sort)

    # Explicit target class, no block.
    sorted = source.to_set(SortedSet)
    assert_instance_of(SortedSet, sorted)
    assert_equal([1,2,3,4,5], sorted.to_a)

    # Explicit target class, elements transformed by the block.
    sorted_mapped = source.to_set(SortedSet) { |o| o * -2 }
    assert_instance_of(SortedSet, sorted_mapped)
    assert_equal([-10,-8,-6,-4,-2], sorted_mapped.sort)
  end
end