bird_list_ip/net_tree.py
2024-09-28 01:00:13 +10:00

214 lines
7.7 KiB
Python

BIG_MASK = (1 << 32) - 1
def getMaskByMaskSize(mask_size):
return BIG_MASK ^ ((1 << (32 - mask_size)) - 1)
def getIpVolumeByMaskSize(mask_size):
return 1 << (32 - mask_size)
class Net:
__slots__ = ['mask_size', 'net', 'mask', 'ip_volume']
def __init__(self, net: int, mask_size: int):
self.mask_size = mask_size
self.net = net & getMaskByMaskSize(mask_size)
self.mask = getMaskByMaskSize(self.mask_size)
self.ip_volume = getIpVolumeByMaskSize(mask_size)
def hasSubnet(self, Net: 'Net'):
if Net.mask_size <= self.mask_size: return 0
return self.net == Net.net & self.mask
def isSameNet(self, Net: 'Net'):
return (Net.mask_size == self.mask_size) and (Net.net == self.net)
def getCommonNet(self, OtherNet: 'Net', min_mask_size: int):
if self.mask_size <= min_mask_size: return 0
if OtherNet.mask_size <= min_mask_size: return 0
for mask_size in range(min(self.mask_size, OtherNet.mask_size) - 1, min_mask_size - 1, -1):
mask = getMaskByMaskSize(mask_size)
if (self.net & mask) == (OtherNet.net & mask):
return Net(self.net, mask_size)
return 0
def getAsString(self, fmt='{addr}/{masklen}'):
net = self.net
mask = self.mask
addrbytes = []
maskbytes = []
for i in range(4):
addrbytes.append(str(net % 256))
maskbytes.append(str(mask % 256))
net = net >> 8
mask = mask >> 8
return fmt.format(addr='.'.join(reversed(addrbytes)), mask='.'.join(reversed(maskbytes)), masklen=self.mask_size)
class Node:
__slots__ = ['net', 'child1', 'child2', 'is_real_net', 'real_ip_volume', 'real_ip_records_count', 'weight', 'max_child_weight', 'added_fake_ip_volume']
def __init__(self, net: Net, is_real_net: int):
self.net = net
self.child1 = None
self.child2 = None
self.is_real_net = is_real_net
self.real_ip_volume = 0
self.real_ip_records_count = 0
self.weight = 0.0
self.max_child_weight = 0.0
self.added_fake_ip_volume = 0
def getNet(self):
return self.net
def addSubnet(self, NewNode: 'Node'):
if self.net.isSameNet(NewNode.net):
if not self.is_real_net and NewNode.is_real_net:
self.is_real_net = 1
self.child1 = None
self.child2 = None
return 1
if self.is_real_net and self.net.hasSubnet(NewNode.net):
return 1
if not self.net.hasSubnet(NewNode.net):
return 0
for Child in (self.child1, self.child2):
if Child and Child.addSubnet(NewNode):
return 1
if self.child1:
CommonNet = self.child1.net.getCommonNet(NewNode.net, self.net.mask_size + 1)
if CommonNet:
CommonNode = Node(CommonNet, 0)
CommonNode.addSubnet(NewNode)
CommonNode.addSubnet(self.child1)
self.child1 = CommonNode
return 1
if self.child2:
CommonNet = self.child2.net.getCommonNet(NewNode.net, self.net.mask_size + 1)
if CommonNet:
CommonNode = Node(CommonNet, 0)
CommonNode.addSubnet(NewNode)
CommonNode.addSubnet(self.child2)
self.child2 = CommonNode
return 1
if not self.child1:
self.child1 = NewNode
else:
self.child2 = NewNode
return 1
def printTree(self, level):
prefix = ''
for i in range(level):
prefix = prefix + ' '
if self.is_real_net: sign = '*'
elif self.weight == 0: sign = '.'
else: sign = ''
print(prefix + self.net.getAsString() + ' ' + str(self.real_ip_records_count))
if self.child1:
self.child1.printTree(level + 1)
if self.child2:
self.child2.printTree(level + 1)
def finishTreeFirst(self):
if self.is_real_net:
self.real_ip_volume = self.net.ip_volume
self.real_ip_records_count = 1
self.weight = 0
self.max_child_weight = 0
else:
self.real_ip_volume = 0
self.real_ip_records_count = 0
self.max_child_weight = 0
for Child in (self.child1, self.child2):
if Child:
Child.finishTreeFirst()
self.real_ip_volume += Child.real_ip_volume
self.real_ip_records_count += Child.real_ip_records_count
self.max_child_weight = max(self.max_child_weight, Child.weight, Child.max_child_weight)
self.recalcWeight()
def collapse(self, min_weight, max_net_delta):
# trying to collapse self
if self.weight >= min_weight:
self.weight = 0
self.max_child_weight = 0
delta = (self.net.ip_volume - self.real_ip_volume) - self.added_fake_ip_volume
self.added_fake_ip_volume = self.net.ip_volume - self.real_ip_volume
return self.real_ip_records_count - 1, delta
net_delta = 0
fake_ip_delta = 0
self.max_child_weight = 0
for Child in (self.child1, self.child2):
if Child:
if net_delta < max_net_delta and min_weight <= max(Child.weight, Child.max_child_weight):
child_net_delta, child_fake_ip_count = Child.collapse(min_weight, max_net_delta - net_delta)
net_delta += child_net_delta
fake_ip_delta += child_fake_ip_count
self.max_child_weight = max(self.max_child_weight, Child.weight, Child.max_child_weight)
if net_delta > 0:
self.added_fake_ip_volume += fake_ip_delta
self.real_ip_records_count -= net_delta
self.recalcWeight()
# trying to collapse self
if self.weight >= min_weight:
self.weight = 0
self.max_child_weight = 0
delta = (self.net.ip_volume - self.real_ip_volume) - (self.added_fake_ip_volume - fake_ip_delta)
self.added_fake_ip_volume = self.net.ip_volume - self.real_ip_volume
return self.real_ip_records_count - 1, delta
else:
return net_delta, fake_ip_delta
def collapseRoot(self, required_net_delta):
while required_net_delta > 0:
delta, fake_ip_volume = self.collapse(self.max_child_weight, required_net_delta)
required_net_delta -= delta
def printCollapsedTree(self, fmt='{addr}/{masklen}'):
if self.is_real_net or self.weight == 0:
print(self.net.getAsString(fmt))
else:
for Child in (self.child1, self.child2):
if Child:
Child.printCollapsedTree(fmt)
def returnCollapsedTree(self, fmt='{addr}/{masklen}'):
if self.is_real_net or self.weight == 0:
return self.net.getAsString(fmt) + "\n"
else:
res = ""
for Child in (self.child1, self.child2):
if Child:
res += Child.returnCollapsedTree(fmt)
return res
def recalcWeight(self):
fake_ip_delta = self.net.ip_volume - self.real_ip_volume - self.added_fake_ip_volume
if fake_ip_delta:
self.weight = (self.real_ip_records_count - 1) / fake_ip_delta
else:
self.weight = float('Inf')
def getNotRealIpCount(self):
if self.is_real_net: return 0
if self.weight == 0: return self.net.ip_volume - self.real_ip_volume
res = 0
for Child in (self.child1, self.child2):
if Child:
res = res + Child.getNotRealIpCount()
return res